AI-Project/verify.py

27 lines
667 B
Python

import torch
import torch.utils.data
from architecture import model
from dataset import ImagesDataset
# Checkpoint to verify and the directory handed to ImagesDataset.
CHECKPOINT_PATH = "submit_attempts/Nr1/model-20240713-164545-epoch-30.pth"
DATA_DIR = "training_data"


def count_correct(net, loader):
    """Count the samples in *loader* whose predicted class matches the label.

    Args:
        net: Callable applied to each image tensor; its output is reduced
            with ``argmax()`` to obtain the predicted class index.
        loader: Iterable yielding 4-tuples ``(image, class_id, _, _)``; the
            last two entries are ignored.

    Returns:
        The number of samples for which ``net(image).argmax() == class_id``.

    Note:
        ``argmax()`` is taken over the whole output tensor, so this assumes
        one sample per batch (the ``DataLoader`` default used by ``main``).
    """
    correct = 0
    # Gradients are never needed for evaluation; disabling them avoids
    # building an autograd graph for every forward pass.
    with torch.no_grad():
        for image_t, class_id, _, _ in loader:
            out = net(image_t)
            if out.argmax() == class_id:
                correct += 1
    return correct


def main():
    """Load the saved weights, evaluate on the dataset and print the accuracy."""
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host;
    # load_state_dict then copies the values into the model's own parameters.
    model_params = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(model_params)
    model.eval()

    dataset = ImagesDataset(DATA_DIR)
    total = len(dataset)

    print("evaluating...")
    correct = count_correct(model, torch.utils.data.DataLoader(dataset))

    print(f"Identified {correct} images out of {total} correctly")
    if total == 0:
        # Avoid ZeroDivisionError when the dataset directory has no samples.
        print("Accuracy: n/a (dataset is empty)")
    else:
        print(f"Accuracy: {100 * correct / total}%")


if __name__ == "__main__":
    main()