1.5 KiB
1.5 KiB
MobileNetV2 图像分类项目
本项目使用PyTorch框架实现基于MobileNetV2的图像分类模型。
环境要求
- Python 3.7+
- PyTorch 1.10+
- torchvision
- tqdm
安装依赖:
pip install torch torchvision tqdm
数据准备
- 创建以下目录结构:
train_data/
1/
train/
test/
model/
1/
test_image/
- 将训练图像放入
train_data/1/train
目录,每个类别一个子目录 - 将测试图像放入
train_data/1/test
目录,保持相同的类别结构
训练模型
运行训练脚本:
python train_mobilenetv2.py
训练参数:
- 训练轮数:20
- 批量大小:64
- 学习率:0.0001
- 优化器:Adam
- 学习率调度器:ReduceLROnPlateau
模型保存
训练好的模型将保存在model/1/
目录下,文件名包含训练轮数和准确率。
目录结构
.
├── train_mobilenetv2.py # 主训练脚本
├── pretreatment.ipynb # 数据预处理notebook
├── test.ipynb # 测试notebook
├── train_data/ # 训练数据(git忽略)
│ └── 1/
│ ├── train/ # 训练图像
│ └── test/ # 测试图像
├── model/ # 保存的模型(git忽略)
│ └── 1/
└── test_image/ # 测试图像(git忽略)
注意事项
- 项目使用预训练的MobileNetV2模型
- 数据增强包括随机裁剪和水平翻转
- 如果有GPU会自动使用GPU进行训练