该仓库基于milesial的Pytorch-UNet进行的,非常感谢大佬无私的奉献。
- 数据集: Kaggle's Carvana Image Masking Challenge
- 网络:U-Net
- 所有参数均可在config.py中设置
- 重新整理结构,并加入大量代码注释
- loading
-
环境:
python版本 pytorch版本 3.5 0.4 -
依赖:
pip install pydensecrf
下载Kaggle's Carvana Image Masking Challenge数据集,并在utils/config.py中配置数据集的根目录。
CarvanaImageMaskingChallenge
│
└───train
│ │ xxx.gif
│ │ ...
│
└───train_masks
│ │ xxx.jpg
│ │ ...
1、在config.py中配置训练参数
2、执行train.py开始训练
功能:可视化一张预测图片
1、将预训练模型放到项目根目录下
预训练模型下载:MODEL.pth
2、预测单张图片
python predict.py -i image.jpg -o output.jpg
3、预测多张图片并显示
python predict.py -i image1.jpg image2.jpg --viz --no-save


