mirror of
https://github.com/mapbox/robosat.git
synced 2025-12-08 20:14:01 +00:00
Uses EfficientNetB0 as segmentation model encoder backbone
This commit is contained in:
parent
1d0cf506cd
commit
4a3c1237bb
211
robosat/efficientnet.py
Normal file
211
robosat/efficientnet.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""EfficientNet architecture.
|
||||
|
||||
See:
|
||||
- https://arxiv.org/abs/1905.11946 - EfficientNet
|
||||
- https://arxiv.org/abs/1801.04381 - MobileNet V2
|
||||
- https://arxiv.org/abs/1905.02244 - MobileNet V3
|
||||
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation
|
||||
- https://arxiv.org/abs/1803.02579 - Concurrent spatial and channel squeeze-and-excitation
|
||||
- https://arxiv.org/abs/1812.01187 - Bag of Tricks for Image Classification with Convolutional Neural Networks
|
||||
|
||||
|
||||
Known issues:
|
||||
|
||||
- Not using swish activation function: unclear where, if, and how
|
||||
much it helps. Needs more experimentation. See also MobileNet V3.
|
||||
|
||||
- Not using squeeze and excitation blocks: I had significantly worse
|
||||
results with scse blocks, and cse blocks alone did not help either.
|
||||
Needs more experimentation as it was done on small datasets only.
|
||||
|
||||
- Not using DropConnect: no efficient native implementation in PyTorch.
|
||||
Unclear if and how much it helps over Dropout.
|
||||
"""
|
||||
|
||||
import math
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Compound-scaling configuration for one EfficientNet variant:
# channel width multiplier, depth (layer count) multiplier, the input
# resolution the variant was designed for, and classifier dropout rate.
EfficientNetParam = collections.namedtuple(
    "EfficientNetParam", ["width", "depth", "resolution", "dropout"])

# Scaling coefficients for the B0..B7 family, straight from the
# EfficientNet paper (https://arxiv.org/abs/1905.11946, Table 1 ff.).
EfficientNetParams = {
    variant: EfficientNetParam(*config)
    for variant, config in (
        ("B0", (1.0, 1.0, 224, 0.2)),
        ("B1", (1.0, 1.1, 240, 0.2)),
        ("B2", (1.1, 1.2, 260, 0.3)),
        ("B3", (1.2, 1.4, 300, 0.3)),
        ("B4", (1.4, 1.8, 380, 0.4)),
        ("B5", (1.6, 2.2, 456, 0.4)),
        ("B6", (1.8, 2.6, 528, 0.5)),
        ("B7", (2.0, 3.1, 600, 0.5)),
    )
}
|
||||
|
||||
|
||||
def efficientnet0(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B0 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B0"], num_classes=num_classes)
|
||||
|
||||
def efficientnet1(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B1 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B1"], num_classes=num_classes)
|
||||
|
||||
def efficientnet2(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B2 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B2"], num_classes=num_classes)
|
||||
|
||||
def efficientnet3(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B3 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B3"], num_classes=num_classes)
|
||||
|
||||
def efficientnet4(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B4 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B4"], num_classes=num_classes)
|
||||
|
||||
def efficientnet5(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B5 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B5"], num_classes=num_classes)
|
||||
|
||||
def efficientnet6(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B6 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B6"], num_classes=num_classes)
|
||||
|
||||
def efficientnet7(pretrained=False, progress=False, num_classes=1000):
    """EfficientNet-B7 classifier.

    Note: `pretrained` and `progress` exist for torchvision API parity
    only; no pre-trained weights are loaded by this constructor.
    """
    return EfficientNet(EfficientNetParams["B7"], num_classes=num_classes)
|
||||
|
||||
|
||||
class EfficientNet(nn.Module):
    """EfficientNet image classifier built from MobileNet V2-style Bottleneck blocks.

    The network is a stem conv, seven stages of inverted-residual blocks,
    a 1x1 feature conv, then global average pooling + dropout + linear head.
    Width and depth of every stage are scaled by the coefficients in `param`.
    """

    def __init__(self, param, num_classes=1000):
        """Build the network.

        Args:
          param: an EfficientNetParam tuple (width, depth, resolution, dropout)
            selecting the B0..B7 variant to construct.
          num_classes: number of output classes for the final linear layer.
        """
        super().__init__()

        # For the exact scaling technique we follow the official implementation as the paper does not tell us
        # https://github.com/tensorflow/tpu/blob/01574500090fa9c011cb8418c61d442286720211/models/official/efficientnet/efficientnet_model.py#L101-L125

        # Scale a per-stage block count by the depth coefficient, rounding up.
        def scaled_depth(n):
            return int(math.ceil(n * param.depth))

        # Snap number of channels to multiple of 8 for optimized implementations
        def scaled_width(n):
            n = n * param.width
            # Round to nearest multiple of 8 (8 / 2 == 4.0 is the half-step),
            # never below 8 — mirrors round_filters in the official TF code.
            m = max(8, int(n + 8 / 2) // 8 * 8)

            # Rounding down must not lose more than 10% of the channels.
            if m < 0.9 * n:
                m = m + 8

            return int(m)

        # Stem: 3x3 stride-2 conv from RGB, then BN + ReLU6.
        self.conv1 = nn.Conv2d(3, scaled_width(32), kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(scaled_width(32))
        self.relu = nn.ReLU6(inplace=True)

        # Seven stages; (n, expansion, cin, cout, kernel, stride) follow the
        # EfficientNet-B0 baseline before scaling.
        self.layer1 = self._make_layer(n=scaled_depth(1), expansion=1, cin=scaled_width(32), cout=scaled_width(16), kernel_size=3, stride=1)
        self.layer2 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(16), cout=scaled_width(24), kernel_size=3, stride=2)
        self.layer3 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(24), cout=scaled_width(40), kernel_size=5, stride=2)
        self.layer4 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(40), cout=scaled_width(80), kernel_size=3, stride=2)
        self.layer5 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(80), cout=scaled_width(112), kernel_size=5, stride=1)
        self.layer6 = self._make_layer(n=scaled_depth(4), expansion=6, cin=scaled_width(112), cout=scaled_width(192), kernel_size=5, stride=2)
        self.layer7 = self._make_layer(n=scaled_depth(1), expansion=6, cin=scaled_width(192), cout=scaled_width(320), kernel_size=3, stride=1)

        # Head: 1x1 conv to the final feature width (1280 before scaling).
        # NOTE(review): no BN/activation follows this conv here — appears
        # deliberate per the module docstring's "known issues"; confirm.
        self.features = nn.Conv2d(scaled_width(320), scaled_width(1280), kernel_size=1, bias=False)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(param.dropout, inplace=True)
        self.fc = nn.Linear(scaled_width(1280), num_classes)

        # Standard initialization: He-normal convs, unit BN, small-normal linear.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        # Zero BatchNorm weight at end of res-blocks: identity by default
        # See https://arxiv.org/abs/1812.01187 Section 3.1
        for m in self.modules():
            if isinstance(m, Bottleneck):
                # linear[1] is the BN that closes each Bottleneck's projection.
                nn.init.zeros_(m.linear[1].weight)


    def _make_layer(self, n, expansion, cin, cout, kernel_size=3, stride=1):
        """Stack `n` Bottleneck blocks into one stage.

        Only the first block changes channels (cin -> cout) and applies
        `stride`; the remaining blocks are cout -> cout at stride 1.
        """
        layers = []

        for i in range(n):
            if i == 0:
                planes = cin
                expand = cin * expansion
                squeeze = cout
                stride = stride
            else:
                planes = cout
                expand = cout * expansion
                squeeze = cout
                stride = 1  # rebinds the parameter for all later iterations

            layers += [Bottleneck(planes, expand, squeeze, kernel_size=kernel_size, stride=stride)]

        return nn.Sequential(*layers)


    def forward(self, x):
        """Run the classifier; x is an image batch, returns class logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)

        x = self.features(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)

        return x
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """MobileNet V2-style inverted residual block.

    1x1 expansion -> kxk depth-wise conv -> 1x1 linear projection, with a
    residual shortcut around the whole block. The shortcut is made shape
    compatible (AvgPool for stride, 1x1 Conv + BN for channels) so every
    block is skip-able — see https://arxiv.org/abs/1812.01187 Figure 2 c.
    """

    def __init__(self, planes, expand, squeeze, kernel_size, stride):
        super().__init__()

        # Point-wise expansion; a no-op when no widening is requested.
        if planes == expand:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Sequential(
                nn.Conv2d(planes, expand, kernel_size=1, bias=False),
                nn.BatchNorm2d(expand),
                nn.ReLU6(inplace=True),
            )

        # Depth-wise conv carries the spatial stride; padding keeps "same" size.
        self.depthwise = nn.Sequential(
            nn.Conv2d(expand, expand, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=expand, bias=False),
            nn.BatchNorm2d(expand),
            nn.ReLU6(inplace=True),
        )

        # Linear (activation-free) projection back down to `squeeze` channels.
        self.linear = nn.Sequential(
            nn.Conv2d(expand, squeeze, kernel_size=1, bias=False),
            nn.BatchNorm2d(squeeze),
        )

        # Shortcut path: match the main branch's stride and channel count.
        shortcut = []
        if stride != 1:
            shortcut.append(nn.AvgPool2d(kernel_size=stride, stride=stride))
        if planes != squeeze:
            shortcut.append(nn.Conv2d(planes, squeeze, kernel_size=1, stride=1, bias=False))
            shortcut.append(nn.BatchNorm2d(squeeze))
        self.downsample = nn.Sequential(*shortcut) if shortcut else nn.Identity()


    def forward(self, x):
        out = self.linear(self.depthwise(self.expand(x)))
        # In-place residual add; safe since BN backward does not need its output.
        out.add_(self.downsample(x))
        return out
|
||||
@ -12,7 +12,7 @@ See:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torchvision.models import resnet50
|
||||
from robosat.efficientnet import efficientnet0
|
||||
|
||||
|
||||
class ConvRelu(nn.Module):
|
||||
@ -91,17 +91,17 @@ class UNet(nn.Module):
|
||||
|
||||
# Todo: make input channels configurable, not hard-coded to three channels for RGB
|
||||
|
||||
self.resnet = resnet50(pretrained=pretrained)
|
||||
self.net = efficientnet0(pretrained=pretrained)
|
||||
|
||||
# Access resnet directly in forward pass; do not store refs here due to
|
||||
# Access backbone directly in forward pass; do not store refs here due to
|
||||
# https://github.com/pytorch/pytorch/issues/8392
|
||||
|
||||
self.center = DecoderBlock(2048, num_filters * 8)
|
||||
self.center = DecoderBlock(1280, num_filters * 8)
|
||||
|
||||
self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
|
||||
self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
|
||||
self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
|
||||
self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
|
||||
self.dec0 = DecoderBlock(1280 + num_filters * 8, num_filters * 8)
|
||||
self.dec1 = DecoderBlock(112 + num_filters * 8, num_filters * 8)
|
||||
self.dec2 = DecoderBlock(40 + num_filters * 8, num_filters * 2)
|
||||
self.dec3 = DecoderBlock(24 + num_filters * 2, num_filters * 2 * 2)
|
||||
self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
|
||||
self.dec5 = ConvRelu(num_filters, num_filters)
|
||||
|
||||
@ -117,17 +117,33 @@ class UNet(nn.Module):
|
||||
The networks output tensor.
|
||||
"""
|
||||
size = x.size()
|
||||
assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for resnet"
|
||||
assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for backbone"
|
||||
|
||||
enc0 = self.resnet.conv1(x)
|
||||
enc0 = self.resnet.bn1(enc0)
|
||||
enc0 = self.resnet.relu(enc0)
|
||||
enc0 = self.resnet.maxpool(enc0)
|
||||
# 1, 3, 512, 512
|
||||
enc0 = self.net.conv1(x)
|
||||
enc0 = self.net.bn1(enc0)
|
||||
enc0 = self.net.relu(enc0)
|
||||
# 1, 32, 256, 256
|
||||
enc0 = self.net.layer1(enc0)
|
||||
# 1, 16, 256, 256
|
||||
|
||||
enc1 = self.resnet.layer1(enc0)
|
||||
enc2 = self.resnet.layer2(enc1)
|
||||
enc3 = self.resnet.layer3(enc2)
|
||||
enc4 = self.resnet.layer4(enc3)
|
||||
enc1 = self.net.layer2(enc0)
|
||||
# 1, 24, 128, 128
|
||||
|
||||
enc2 = self.net.layer3(enc1)
|
||||
# 1, 40, 64, 64
|
||||
|
||||
enc3 = self.net.layer4(enc2)
|
||||
# 1, 80, 32, 32
|
||||
enc3 = self.net.layer5(enc3)
|
||||
# 1, 112, 32, 32
|
||||
|
||||
enc4 = self.net.layer6(enc3)
|
||||
# 1, 192, 16, 16
|
||||
enc4 = self.net.layer7(enc4)
|
||||
# 1, 320, 16, 16
|
||||
enc4 = self.net.features(enc4)
|
||||
# 1, 1280, 16, 16
|
||||
|
||||
center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user