27 lines
667 B
Python
27 lines
667 B
Python
import torch
|
|
import torch.utils.data
|
|
|
|
from architecture import model
|
|
from dataset import ImagesDataset
|
|
|
|
def count_correct(net, loader):
    """Count the samples in *loader* that *net* classifies correctly.

    Args:
        net: Callable that takes the image tensor yielded by *loader* and
            returns a tensor of scores; the index of its largest element is
            taken as the predicted class.
        loader: Iterable yielding 4-tuples ``(image, class_id, _, _)``; the
            last two items are ignored. Each item is expected to carry a
            single sample (the ``DataLoader`` default batch size of 1), since
            the prediction is compared against ``class_id`` as one value.

    Returns:
        The number of items whose predicted class equals ``class_id``.
    """
    hits = 0
    # Evaluation only: disable autograd so no graph is built per forward pass.
    with torch.no_grad():
        for image_t, class_id, _, _ in loader:
            out = net(image_t)
            if out.argmax() == class_id:
                hits += 1
    return hits


if __name__ == "__main__":
    # map_location lets a checkpoint saved on a GPU be read on a CPU-only
    # machine; load_state_dict then copies the values into the model's tensors.
    model_params = torch.load(
        "submit_attempts/Nr1/model-20240713-164545-epoch-30.pth",
        map_location="cpu",
    )
    model.load_state_dict(model_params)
    model.eval()

    dataset = ImagesDataset("training_data")
    total = len(dataset)

    print("evaluating...")
    correct = count_correct(model, torch.utils.data.DataLoader(dataset))

    print(f"Identified {correct} images out of {total} correctly")
    if total:
        print(f"Accuracy: {100 * correct / total}%")
    else:
        # Avoid ZeroDivisionError when the dataset contains no samples.
        print("Accuracy: n/a (dataset is empty)")
|