Skip to content

Latest commit

 

History

History

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

readme.md

U-Net网络


该仓库基于milesialPytorch-UNet进行的,非常感谢大佬无私的奉献。


目前支持:

相比原作者的特点:

  • 所有参数均可在config.py中设置
  • 重新整理结构,并加入大量代码注释
  • loading

  • 环境:

    python版本 pytorch版本
    3.5 0.4
  • 依赖:

    pip install pydensecrf
    

U-Net网络结构

  • 原论文左侧 conv 3x3 无pad,故每次conv后feature map尺寸缩小。故与右侧feature map融合之前需要裁剪。
  • 该仓库左侧 conv 3x3 pad=1,故每次conv后feature map尺寸不变。故反卷积后保证尺度统一与右侧feature map融合即可。

准备数据集:

下载Kaggle's Carvana Image Masking Challenge数据集,并在utils/config.py中配置数据集的根目录。

CarvanaImageMaskingChallenge
│
└───train
│   │   xxx.gif
│   │   ...
│   
└───train_masks
│   │   xxx.jpg
│   │   ...

Trian:

1、在config.py中配置训练参数

2、执行train.py开始训练


Eval:

每训练一轮epoch都将计算Dice距离(用于度量两个集合的相似性)

Predict:

功能:可视化一张预测图片

1、将预训练模型放到项目根目录下

预训练模型下载:MODEL.pth

2、预测单张图片

    python predict.py -i image.jpg -o output.jpg

3、预测多张图片并显示

    python predict.py -i image1.jpg image2.jpg --viz --no-save
图片说明图片说明

关于作者