ai-trash-can/README.md
2025-01-17 12:28:16 +08:00

1.5 KiB
Raw Blame History

MobileNetV2 图像分类项目

本项目使用PyTorch框架实现基于MobileNetV2的图像分类模型。

环境要求

  • Python 3.7+
  • PyTorch 1.10+
  • torchvision
  • tqdm

安装依赖:

pip install torch torchvision tqdm

数据准备

  1. 创建以下目录结构:
train_data/
    1/
        train/
        test/
model/
    1/
test_image/
  1. 将训练图像放入train_data/1/train目录,每个类别一个子目录
  2. 将测试图像放入train_data/1/test目录,保持相同的类别结构

训练模型

运行训练脚本:

python train_mobilenetv2.py

训练参数:

  • 训练轮数20
  • 批量大小64
  • 学习率0.0001
  • 优化器Adam
  • 学习率调度器ReduceLROnPlateau

模型保存

训练好的模型将保存在model/1/目录下,文件名包含训练轮数和准确率。

目录结构

.
├── train_mobilenetv2.py       # 主训练脚本
├── pretreatment.ipynb         # 数据预处理notebook
├── test.ipynb                 # 测试notebook
├── train_data/                # 训练数据git忽略
│   └── 1/
│       ├── train/             # 训练图像
│       └── test/              # 测试图像
├── model/                     # 保存的模型git忽略
│   └── 1/
└── test_image/                # 测试图像git忽略

注意事项

  • 项目使用预训练的MobileNetV2模型
  • 数据增强包括随机裁剪和水平翻转
  • 如果有GPU会自动使用GPU进行训练