-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
32 lines (28 loc) · 1.42 KB
/
model.py
File metadata and controls
32 lines (28 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
def get_model(pretrained=True, *, num_classes=13, min_size=540, max_size=900,
              trainable_layers=None):
    """Build a Faster R-CNN detector on a ResNet-50 + FPN backbone.

    Args:
        pretrained: Forwarded positionally to ``resnet_fpn_backbone``; when
            truthy the ResNet-50 trunk starts from ImageNet weights. Newer
            torchvision still accepts this legacy argument but emits a
            deprecation warning in favour of ``weights=``.
        num_classes: Number of output classes of the detector, including
            the background class.
        min_size: Lower image-resize bound forwarded to ``FasterRCNN``.
        max_size: Upper image-resize bound forwarded to ``FasterRCNN``.
        trainable_layers: Number of ResNet stages (counted from the last,
            0-5) left trainable. ``None`` keeps ``resnet_fpn_backbone``'s own
            default (3 stages). NOTE: with ``pretrained=False`` that default
            leaves randomly initialised early stages frozen; pass 5 when
            training from scratch.

    Returns:
        A torchvision ``FasterRCNN`` model using a custom RPN anchor
        generator.
    """
    # Only forward trainable_layers when requested, so the default call is
    # identical to the original and works on torchvision versions whose
    # resnet_fpn_backbone predates that parameter.
    backbone_kwargs = {}
    if trainable_layers is not None:
        backbone_kwargs["trainable_layers"] = trainable_layers
    backbone = resnet_fpn_backbone("resnet50", pretrained, **backbone_kwargs)

    # One anchor-size tuple per backbone feature map. With its default
    # returned_layers, resnet_fpn_backbone yields five maps (four FPN levels
    # plus an extra max-pooled level), so exactly five entries are required;
    # a 7-entry tuple would not match. Sizes start at 16 rather than
    # torchvision's default of 32, presumably to favour small objects.
    anchor_sizes = ((16,), (32,), (64,), (128,), (256,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

    return FasterRCNN(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        min_size=min_size,
        max_size=max_size,
    )