In [9]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

# Select the compute device: GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing: resize to the 224x224 input used by this model, convert to a
# tensor, and normalize with the standard ImageNet channel mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class labels, indexed by the model's output position.
# NOTE(review): this list is alphabetical except that "brick" precedes "bottle".
# If the checkpoint was trained with torchvision.datasets.ImageFolder (which
# sorts class folder names), indices 1 and 2 would be swapped here; confirm
# against the training-time class_to_idx mapping.
class_names = ["battery","brick","bottle","butt","cans","carrot_piece","fruits","leaf","nothing","paper","potato","vegetable"]

# Derive the head width from the label list so the two cannot drift apart.
num_classes = len(class_names)

MODEL_PATH = "./model/1/epochs10 96.13.pt"

# Build a MobileNetV2 without pretrained weights (the default in both the legacy
# `pretrained=` and the current `weights=` torchvision APIs; `pretrained` is
# deprecated), then load the fine-tuned checkpoint below.
model = models.mobilenet_v2()
# Replace the final linear layer so it outputs one logit per class; reuse the
# existing layer's input width instead of hard-coding it.
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=num_classes)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (newer PyTorch versions also accept weights_only=True).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()  # evaluation mode (affects layers such as Dropout / BatchNorm)

def classify_image(image_path):
    """Predict the class label of a single image file.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The entry of the module-level ``class_names`` at the index of the
        largest logit produced by the module-level ``model``.
    """
    # Use a context manager so the file handle is released. Force 3-channel RGB:
    # grayscale ("L") or RGBA/palette images would otherwise produce 1 or 4
    # channels, which does not match the 3-element Normalize mean/std.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)  # add a batch dimension
    batch = batch.to(device)

    with torch.no_grad():
        logits = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    class_idx = logits.argmax(dim=1).item()
    return class_names[class_idx]


In [10]:
# Run the classifier on one sample image and print the predicted label.
sample_image_path = "./test_image/1690813083599.jpg"
print(classify_image(sample_image_path))

bottle
