ai-trash-can/test.ipynb
2023-08-03 10:38:52 +08:00

102 lines
2.7 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"import torchvision.models as models\n",
"from PIL import Image\n",
"\n",
"# Select the compute device: a CUDA GPU when one is available, otherwise the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Preprocessing: resize to 224x224, convert to a (C, H, W) float tensor and\n",
"# normalize with the standard ImageNet per-channel mean/std.\n",
"transform = transforms.Compose([\n",
"    transforms.Resize((224, 224)),  # resize to the input size the model expects\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"])\n",
"\n",
"# Model output i is reported as class_names[i]; this order must match the\n",
"# label order used when the checkpoint was trained.\n",
"class_names = [\"battery\",\"brick\",\"bottle\",\"butt\",\"cans\",\"carrot_piece\",\"fruits\",\"leaf\",\"nothing\",\"paper\",\"potato\",\"vegetable\"]\n",
"\n",
"# Load the fine-tuned MobileNetV2 checkpoint.\n",
"MODEL_PATH = \"./model/1/epochs10 96.13.pt\"\n",
"model = models.mobilenet_v2(pretrained=False)  # architecture only; weights come from MODEL_PATH\n",
"# Size the classification head from the label list so the two cannot drift apart.\n",
"num_classes = len(class_names)\n",
"model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=num_classes)\n",
"# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.\n",
"# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
"model.load_state_dict(torch.load(MODEL_PATH, map_location=device))\n",
"model = model.to(device)\n",
"model.eval()\n",
"\n",
"def classify_image(image_path):\n",
"    \"\"\"Classify one image file and return its predicted label.\n",
"\n",
"    Args:\n",
"        image_path: path to the image file to classify.\n",
"\n",
"    Returns:\n",
"        The entry of `class_names` at the index of the highest model output.\n",
"    \"\"\"\n",
"    # Force 3-channel RGB so grayscale, RGBA or palette images match the\n",
"    # 3-channel Normalize; the `with` block closes the underlying file handle.\n",
"    with Image.open(image_path) as img:\n",
"        image = transform(img.convert(\"RGB\"))\n",
"    image = image.unsqueeze(0).to(device)  # add a batch dimension -> (1, C, H, W)\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = model(image)\n",
"\n",
"    class_idx = outputs.argmax(dim=1).item()\n",
"    return class_names[class_idx]\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bottle\n"
]
}
],
"source": [
"predicted_label = classify_image(\"./test_image/1690813083599.jpg\")\n",
"print(predicted_label)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}