ai-trash-can/README.md

66 lines
1.5 KiB
Markdown
Raw Permalink Normal View History

2025-01-17 04:28:16 +00:00
# MobileNetV2 图像分类项目
本项目使用PyTorch框架实现基于MobileNetV2的图像分类模型。
## 环境要求
- Python 3.7+
- PyTorch 1.10+
- torchvision
- tqdm
安装依赖:
```bash
pip install torch torchvision tqdm
```
## 数据准备
1. 创建以下目录结构:
```
train_data/
1/
train/
test/
model/
1/
test_image/
```
2. 将训练图像放入`train_data/1/train`目录,每个类别一个子目录
3. 将测试图像放入`train_data/1/test`目录,保持相同的类别结构
## 训练模型
运行训练脚本:
```bash
python train_mobilenetv2.py
```
训练参数:
- 训练轮数20
- 批量大小64
- 学习率0.0001
- 优化器Adam
- 学习率调度器ReduceLROnPlateau
## 模型保存
训练好的模型将保存在`model/1/`目录下,文件名包含训练轮数和准确率。
## 目录结构
```
.
├── train_mobilenetv2.py # 主训练脚本
├── pretreatment.ipynb # 数据预处理notebook
├── test.ipynb # 测试notebook
├── train_data/ # 训练数据git忽略
│ └── 1/
│ ├── train/ # 训练图像
│ └── test/ # 测试图像
├── model/ # 保存的模型git忽略
│ └── 1/
└── test_image/ # 测试图像git忽略
```
## 注意事项
- 项目使用预训练的MobileNetV2模型
- 数据增强包括随机裁剪和水平翻转
- 如果有GPU会自动使用GPU进行训练