ai-trash-can/README.md
2025-01-17 12:28:16 +08:00

66 lines
1.5 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# MobileNetV2 图像分类项目
本项目使用PyTorch框架实现基于MobileNetV2的图像分类模型。
## 环境要求
- Python 3.7+
- PyTorch 1.10+
- torchvision
- tqdm
安装依赖:
```bash
pip install torch torchvision tqdm
```
## 数据准备
1. 创建以下目录结构:
```
train_data/
1/
train/
test/
model/
1/
test_image/
```
2. 将训练图像放入`train_data/1/train`目录,每个类别一个子目录
3. 将测试图像放入`train_data/1/test`目录,保持相同的类别结构
## 训练模型
运行训练脚本:
```bash
python train_mobilenetv2.py
```
训练参数:
- 训练轮数20
- 批量大小64
- 学习率0.0001
- 优化器Adam
- 学习率调度器ReduceLROnPlateau
## 模型保存
训练好的模型将保存在`model/1/`目录下,文件名包含训练轮数和准确率。
## 目录结构
```
.
├── train_mobilenetv2.py # 主训练脚本
├── pretreatment.ipynb # 数据预处理notebook
├── test.ipynb # 测试notebook
├── train_data/ # 训练数据git忽略
│ └── 1/
│ ├── train/ # 训练图像
│ └── test/ # 测试图像
├── model/ # 保存的模型git忽略
│ └── 1/
└── test_image/ # 测试图像git忽略
```
## 注意事项
- 项目使用预训练的MobileNetV2模型
- 数据增强包括随机裁剪和水平翻转
- 如果有GPU会自动使用GPU进行训练