Uses EfficientNetB0 as segmentation model encoder backbone

Daniel J. Hofmann 2019-07-28 18:00:14 +02:00
parent 1d0cf506cd
commit 4a3c1237bb
2 changed files with 244 additions and 17 deletions

robosat/efficientnet.py (new file, 211 additions)

@@ -0,0 +1,211 @@
"""EfficientNet architecture.
See:
- https://arxiv.org/abs/1905.11946 - EfficientNet
- https://arxiv.org/abs/1801.04381 - MobileNet V2
- https://arxiv.org/abs/1905.02244 - MobileNet V3
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation
- https://arxiv.org/abs/1803.02579 - Concurrent spatial and channel squeeze-and-excitation
- https://arxiv.org/abs/1812.01187 - Bag of Tricks for Image Classification with Convolutional Neural Networks
Known issues:
- Not using the swish activation function: it is unclear where, if at all,
  and how much it helps. Needs more experimentation. See also MobileNet V3.
- Not using squeeze-and-excitation blocks: scSE blocks gave significantly
  worse results, and cSE blocks alone did not help either. Needs more
  experimentation, since it was only tried on small datasets.
- Not using DropConnect: there is no efficient native implementation in
  PyTorch, and it is unclear if and how much it helps over Dropout.
"""
import math
import collections

import torch
import torch.nn as nn


EfficientNetParam = collections.namedtuple(
    "EfficientNetParam", ["width", "depth", "resolution", "dropout"])
EfficientNetParams = {
    "B0": EfficientNetParam(1.0, 1.0, 224, 0.2),
    "B1": EfficientNetParam(1.0, 1.1, 240, 0.2),
    "B2": EfficientNetParam(1.1, 1.2, 260, 0.3),
    "B3": EfficientNetParam(1.2, 1.4, 300, 0.3),
    "B4": EfficientNetParam(1.4, 1.8, 380, 0.4),
    "B5": EfficientNetParam(1.6, 2.2, 456, 0.4),
    "B6": EfficientNetParam(1.8, 2.6, 528, 0.5),
    "B7": EfficientNetParam(2.0, 3.1, 600, 0.5),
}
# Note: no pretrained weights are shipped for this implementation; the
# pretrained and progress arguments mirror the torchvision constructor API
# but are currently unused.

def efficientnet0(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B0"], num_classes=num_classes)

def efficientnet1(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B1"], num_classes=num_classes)

def efficientnet2(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B2"], num_classes=num_classes)

def efficientnet3(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B3"], num_classes=num_classes)

def efficientnet4(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B4"], num_classes=num_classes)

def efficientnet5(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B5"], num_classes=num_classes)

def efficientnet6(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B6"], num_classes=num_classes)

def efficientnet7(pretrained=False, progress=False, num_classes=1000):
    return EfficientNet(param=EfficientNetParams["B7"], num_classes=num_classes)
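
# A minimal smoke test for the constructors above might look like this
# (hypothetical, not part of the module):
#
#   net = efficientnet0(num_classes=2)
#   logits = net(torch.rand(1, 3, 224, 224))  # -> tensor of shape (1, 2)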


class EfficientNet(nn.Module):
    def __init__(self, param, num_classes=1000):
        super().__init__()

        # For the exact scaling technique we follow the official implementation,
        # since the paper does not spell it out; see
        # https://github.com/tensorflow/tpu/blob/01574500090fa9c011cb8418c61d442286720211/models/official/efficientnet/efficientnet_model.py#L101-L125

        def scaled_depth(n):
            return int(math.ceil(n * param.depth))
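        # e.g. for B4 (depth multiplier 1.8): scaled_depth(3) == ceil(5.4) == 6 blocks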

        # Snap number of channels to a multiple of 8 for optimized implementations
        def scaled_width(n):
            n = n * param.width
            m = max(8, int(n + 8 / 2) // 8 * 8)

            # Rounding down must not remove more than 10% of the channels
            if m < 0.9 * n:
                m = m + 8

            return int(m)
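        # e.g. for B2 (width multiplier 1.1): scaled_width(32) rounds 35.2 down
        # to 32, which is still within 90% of 35.2, so 32 channels are kept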

        self.conv1 = nn.Conv2d(3, scaled_width(32), kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(scaled_width(32))
        self.relu = nn.ReLU6(inplace=True)

        self.layer1 = self._make_layer(n=scaled_depth(1), expansion=1, cin=scaled_width(32), cout=scaled_width(16), kernel_size=3, stride=1)
        self.layer2 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(16), cout=scaled_width(24), kernel_size=3, stride=2)
        self.layer3 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(24), cout=scaled_width(40), kernel_size=5, stride=2)
        self.layer4 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(40), cout=scaled_width(80), kernel_size=3, stride=2)
        self.layer5 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(80), cout=scaled_width(112), kernel_size=5, stride=1)
        self.layer6 = self._make_layer(n=scaled_depth(4), expansion=6, cin=scaled_width(112), cout=scaled_width(192), kernel_size=5, stride=2)
        self.layer7 = self._make_layer(n=scaled_depth(1), expansion=6, cin=scaled_width(192), cout=scaled_width(320), kernel_size=3, stride=1)

        self.features = nn.Conv2d(scaled_width(320), scaled_width(1280), kernel_size=1, bias=False)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(param.dropout, inplace=True)
        self.fc = nn.Linear(scaled_width(1280), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        # Zero BatchNorm weight at end of res-blocks: identity by default
        # See https://arxiv.org/abs/1812.01187 Section 3.1
        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.zeros_(m.linear[1].weight)
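        # With those gammas zeroed, the residual branch of every Bottleneck
        # initially outputs zeros, so each block starts out as (a projection
        # of) the identity via its downsample path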

    def _make_layer(self, n, expansion, cin, cout, kernel_size=3, stride=1):
        layers = []

        for i in range(n):
            if i == 0:
                # First block transitions cin -> cout and applies the layer's stride
                planes = cin
                expand = cin * expansion
                squeeze = cout
            else:
                # Remaining blocks keep cout channels and stride 1
                planes = cout
                expand = cout * expansion
                squeeze = cout
                stride = 1

            layers += [Bottleneck(planes, expand, squeeze, kernel_size=kernel_size, stride=stride)]

        return nn.Sequential(*layers)
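        # e.g. for B0, layer2 = _make_layer(n=2, expansion=6, cin=16, cout=24, kernel_size=3, stride=2)
        # builds Bottleneck(16, 96, 24, ...) with stride 2, then Bottleneck(24, 144, 24, ...) with stride 1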

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)

        x = self.features(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)

        return x


class Bottleneck(nn.Module):
    def __init__(self, planes, expand, squeeze, kernel_size, stride):
        super().__init__()

        # 1x1 point-wise expansion; skipped entirely when no expansion is needed
        self.expand = nn.Identity() if planes == expand else nn.Sequential(
            nn.Conv2d(planes, expand, kernel_size=1, bias=False),
            nn.BatchNorm2d(expand),
            nn.ReLU6(inplace=True))

        # k x k depth-wise convolution carrying the stride
        self.depthwise = nn.Sequential(
            nn.Conv2d(expand, expand, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=expand, bias=False),
            nn.BatchNorm2d(expand),
            nn.ReLU6(inplace=True))

        # 1x1 point-wise linear projection back down to squeeze channels
        self.linear = nn.Sequential(
            nn.Conv2d(expand, squeeze, kernel_size=1, bias=False),
            nn.BatchNorm2d(squeeze))

        # Make all blocks skip-able via AvgPool + 1x1 Conv
        # See https://arxiv.org/abs/1812.01187 Figure 2 c
        downsample = []

        if stride != 1:
            downsample += [nn.AvgPool2d(kernel_size=stride, stride=stride)]

        if planes != squeeze:
            downsample += [
                nn.Conv2d(planes, squeeze, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(squeeze)]

        self.downsample = nn.Identity() if not downsample else nn.Sequential(*downsample)
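        # e.g. Bottleneck(planes=16, expand=96, squeeze=24, kernel_size=3, stride=2):
        # 1x1 16->96, depth-wise 3x3 with stride 2, 1x1 96->24; the shortcut
        # becomes AvgPool2d(2) followed by a 1x1 16->24 projection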

    def forward(self, x):
        xx = self.expand(x)
        xx = self.depthwise(xx)
        xx = self.linear(xx)

        x = self.downsample(x)

        xx.add_(x)

        return xx

robosat/unet.py (33 additions, 17 deletions)

@@ -12,7 +12,7 @@ See:
 import torch
 import torch.nn as nn
-from torchvision.models import resnet50
+from robosat.efficientnet import efficientnet0

 class ConvRelu(nn.Module):
@@ -91,17 +91,17 @@ class UNet(nn.Module):
         # Todo: make input channels configurable, not hard-coded to three channels for RGB
-        self.resnet = resnet50(pretrained=pretrained)
+        self.net = efficientnet0(pretrained=pretrained)

-        # Access resnet directly in forward pass; do not store refs here due to
+        # Access backbone directly in forward pass; do not store refs here due to
         # https://github.com/pytorch/pytorch/issues/8392

-        self.center = DecoderBlock(2048, num_filters * 8)
+        self.center = DecoderBlock(1280, num_filters * 8)

-        self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
-        self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
-        self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
-        self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
+        self.dec0 = DecoderBlock(1280 + num_filters * 8, num_filters * 8)
+        self.dec1 = DecoderBlock(112 + num_filters * 8, num_filters * 8)
+        self.dec2 = DecoderBlock(40 + num_filters * 8, num_filters * 2)
+        self.dec3 = DecoderBlock(24 + num_filters * 2, num_filters * 2 * 2)
         self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
         self.dec5 = ConvRelu(num_filters, num_filters)
@@ -117,17 +117,33 @@ class UNet(nn.Module):
             The network's output tensor.
         """

         size = x.size()
-        assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for resnet"
+        assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for backbone"

-        enc0 = self.resnet.conv1(x)
-        enc0 = self.resnet.bn1(enc0)
-        enc0 = self.resnet.relu(enc0)
-        enc0 = self.resnet.maxpool(enc0)
+        # 1, 3, 512, 512
+        enc0 = self.net.conv1(x)
+        enc0 = self.net.bn1(enc0)
+        enc0 = self.net.relu(enc0)
+        # 1, 32, 256, 256
+        enc0 = self.net.layer1(enc0)
+        # 1, 16, 256, 256

-        enc1 = self.resnet.layer1(enc0)
-        enc2 = self.resnet.layer2(enc1)
-        enc3 = self.resnet.layer3(enc2)
-        enc4 = self.resnet.layer4(enc3)
+        enc1 = self.net.layer2(enc0)
+        # 1, 24, 128, 128
+        enc2 = self.net.layer3(enc1)
+        # 1, 40, 64, 64
+        enc3 = self.net.layer4(enc2)
+        # 1, 80, 32, 32
+        enc3 = self.net.layer5(enc3)
+        # 1, 112, 32, 32
+        enc4 = self.net.layer6(enc3)
+        # 1, 192, 16, 16
+        enc4 = self.net.layer7(enc4)
+        # 1, 320, 16, 16
+        enc4 = self.net.features(enc4)
+        # 1, 1280, 16, 16

         center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
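
For reference, the per-stage shapes annotated above can be sanity-checked by running the new encoder on a dummy tile. This sketch is illustrative, not part of the commit; it assumes a 512x512 RGB input and the B0 configuration:

import torch

from robosat.efficientnet import efficientnet0

# Build the B0 encoder and switch BatchNorm to inference mode
net = efficientnet0()
net.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 512, 512)
    x = net.relu(net.bn1(net.conv1(x)))

    enc0 = net.layer1(x)
    enc1 = net.layer2(enc0)
    enc2 = net.layer3(enc1)
    enc3 = net.layer5(net.layer4(enc2))
    enc4 = net.features(net.layer7(net.layer6(enc3)))

    # Channel counts feeding the dec1..dec3 skip connections and the center block
    assert enc1.shape == (1, 24, 128, 128)
    assert enc2.shape == (1, 40, 64, 64)
    assert enc3.shape == (1, 112, 32, 32)
    assert enc4.shape == (1, 1280, 16, 16)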