MobileViT DeepLabV3+ Integration
Tạo vào: 24 tháng 12, 2024
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2
class MobileNetV2(nn.Module):
    """Truncated MobileNetV2 backbone for DeepLabV3+.

    Drops the final feature stage and rewrites stride-2 convs in the last
    stages as dilated stride-1 convs so the output stride matches
    ``downsample_factor`` (8 or 16).

    Args:
        downsample_factor: Overall output stride of the deep features (8 or 16).
        pretrained: Forwarded to ``mobilenetv2`` to load pretrained weights.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model = mobilenetv2(pretrained)
        self.features = model.features[:-1]          # drop the last 1x1 conv stage
        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]                # indices of the stride-2 blocks

        if downsample_factor == 8:
            # Replace the last two downsampling stages with dilation 2 then 4.
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=4))
        elif downsample_factor == 16:
            # Only the final downsampling stage becomes dilated.
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(partial(self._nostride_dilate, dilate=2))

    def _nostride_dilate(self, m, dilate):
        """Convert a stride-2 conv to stride-1 with dilation; dilate existing 3x3 convs."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    # Halved dilation on the converted layer keeps receptive-field growth consistent.
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        """Return (low_level_features, deep_features) for the DeepLab decoder."""
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x
#-----------------------------------------#
#-----------------------------------------#
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Four parallel conv branches (1x1 plus 3x3 dilated at 6/12/18 * rate) and a
    global-average-pooling branch, concatenated and fused by a 1x1 conv.
    Spatial size is preserved because each 3x3 branch uses padding == dilation.

    Args:
        dim_in: Input channel count.
        dim_out: Channel count of each branch and of the fused output.
        rate: Multiplier on the base dilation rates (1 for output stride 16, 2 for 8).
        bn_mom: BatchNorm momentum.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # BUGFIX: the original '6rate', '12rate', '18rate' were syntax errors;
        # the intended values are 6*rate, 12*rate, 18*rate.
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all five branches and fuse them; output is (B, dim_out, H, W)."""
        [b, c, row, col] = x.size()
        # Four convolutional branches.
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # Fifth branch: global average pooling + 1x1 conv, then upsample back.
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        # Concatenate the five branches and fuse with a 1x1 conv.
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
class DeepLab(nn.Module):
    """DeepLabV3+ segmentation network with a selectable backbone.

    Args:
        num_classes: Number of output segmentation classes.
        backbone: "mobilenet" or "xception".
        pretrained: Forwarded to the backbone constructor.
        downsample_factor: Output stride of the backbone's deep features (8 or 16).
    """

    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "xception":
            # Backbone yields two feature maps:
            #   shallow features [128,128,256], deep features [30,30,2048]
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            # Backbone yields two feature maps:
            #   shallow features [128,128,24], deep features [30,30,320]
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - {}, Use mobilenet, xception.'.format(backbone))

        # ASPP: multi-rate dilated convolutions over the deep features.
        # rate doubles when the output stride is 8 to keep the receptive field.
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)

        # 1x1 projection of the shallow (low-level) features to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder convolutions after concatenating ASPP output and shallow features.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        """Return per-pixel class logits of shape (B, num_classes, H, W)."""
        H, W = x.size(2), x.size(3)
        # Backbone: shallow features feed the decoder skip, deep features feed ASPP.
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)

        # Upsample the ASPP output to the shallow-feature resolution, fuse, classify.
        x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)),
                          mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        # Final upsample back to the input resolution.
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x
import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) convolution followed by BatchNorm and SiLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """NxN convolution followed by BatchNorm and SiLU.

    Padding is ``kernal_size // 2`` so odd kernel sizes preserve spatial size
    at stride 1. (The original hard-coded padding=1, which only held the
    "same" property for kernel size 3; this generalizes it while keeping
    identical behavior for the default kernel.)
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, kernal_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension before calling ``fn``.

    Args:
        dim: Normalized feature dimension (size of the last axis).
        fn: Module or callable applied to the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature dimension.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after each linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention over per-patch token sequences.

    Input and output shape: (batch, patches, tokens, dim) — attention runs
    over the ``tokens`` axis independently per patch position.

    Args:
        dim: Token feature dimension.
        heads: Number of attention heads.
        dim_head: Feature dimension per head.
        dropout: Dropout probability on the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head of the same width, the output projection is redundant.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # Single projection produces q, k, v; split heads out of the channel axis.
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        # Scaled dot-product attention.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        # Merge heads back into the channel axis.
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm attention + feed-forward layers with residuals.

    Args:
        dim: Token feature dimension.
        depth: Number of attention/MLP layer pairs.
        heads: Attention heads per layer.
        dim_head: Feature dimension per head.
        mlp_dim: Hidden width of the feed-forward blocks.
        dropout: Dropout probability used in both sublayers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x   # residual around attention
            x = ff(x) + x     # residual around the MLP
        return x
class MV2Block(nn.Module):
    """MobileNetV2 inverted-residual block (expand -> depthwise -> project).

    Args:
        inp: Input channels.
        oup: Output channels.
        stride: 1 or 2; stride of the depthwise 3x3 conv.
        expansion: Channel expansion factor of the hidden layer.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        # Residual connection only when the block preserves both shape and channels.
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            # No pointwise expansion; hidden_dim == inp here.
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear (no activation: linear bottleneck)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw expansion
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear (no activation: linear bottleneck)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
class MobileViTBlock(nn.Module):
    """MobileViT block: local conv features + global transformer over patches, fused.

    Args:
        dim: Transformer token dimension.
        depth: Number of transformer layers.
        channel: Input/output channel count of the block.
        kernel_size: Kernel size of the local NxN convolutions.
        patch_size: (ph, pw) unfolding patch size; feature H and W must be divisible by it.
        mlp_dim: Hidden width of the transformer feed-forward layers.
        dropout: Dropout probability in the transformer.
    """

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()   # keep the input for the fusion step

        # Local representations.
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold into patches, run the transformer, fold back.
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion: project back to `channel`, concat with the input, mix with NxN conv.
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
class MobileViT(nn.Module):
    """MobileViT adapted for dense prediction.

    The original classifier head (global average pooling + fc) is replaced by a
    1x1 conv to ``num_classes`` channels followed by bilinear upsampling to
    ``image_size``, so the network outputs a per-pixel map.

    NOTE(review): the mv2/mvit channel indexing below assumes the ``channels``
    list has pairwise-duplicated entries (channels[4]==channels[5],
    channels[6]==channels[7], channels[8]==channels[9]) as in the stock
    xxs/xs/s configs; other lists will mismatch — confirm before reusing.
    """

    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        L = [2, 4, 3]   # transformer depths of the three MobileViT blocks

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # repeated block
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
        self.conv3 = conv_1x1_bn(channels[-1], num_classes)
        self.upsample = nn.Upsample(size=image_size, mode='bilinear', align_corners=False)

        # Global average pooling + fc classifier removed (dense output instead):
        # self.pool = nn.AvgPool2d(ih//32, 1)
        # self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)

        x = self.mv2[0](x)
        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)   # repeated block

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        # Classifier head removed; project to class channels and upsample instead.
        # x = self.pool(x).view(-1, x.shape[1])
        # x = self.fc(x)
        x = self.conv3(x)
        x = self.upsample(x)
        return x
def mobilevit_xxs():
    """MobileViT-XXS configuration (expansion factor 2, 256x256 input)."""
    return MobileViT(
        (256, 256),
        dims=[64, 80, 96],
        channels=[16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320],
        num_classes=1000,
        expansion=2,
    )
def mobilevit_xs():
    """MobileViT-XS configuration (default expansion factor 4, 256x256 input)."""
    return MobileViT(
        (256, 256),
        dims=[96, 120, 144],
        channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
        num_classes=1000,
    )
def mobilevit_s():
    """MobileViT-S configuration (default expansion factor 4, 256x256 input)."""
    return MobileViT(
        (256, 256),
        dims=[144, 192, 240],
        channels=[16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640],
        num_classes=1000,
    )
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == '__main__':
    # Smoke test. Note: the model is configured for 512x512 input but fed a
    # 256x256 tensor; the final nn.Upsample still forces a 512x512 output map.
    img = torch.randn(1, 3, 256, 256)
    model = MobileViT(image_size=(512, 512), dims=[96, 120, 144],
                      channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
                      num_classes=3)
    out = model(img)
    print(out.shape)

    # Classification-variant examples:
    # vit = mobilevit_xxs()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
    #
    # vit = mobilevit_xs()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
    #
    # vit = mobilevit_s()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
第一个是DeeplabV3+网络,第二个是mobileVit网络,现在将DeeplabV3+的主干替换为mobilVit,并在一个代码文件中展示
Certainly! Below is a comprehensive Python script that integrates the DeepLabV3+ architecture with a MobileViT backbone. This script combines the definitions of both networks, modifies the MobileViT backbone to output the necessary feature maps for DeepLabV3+, and demonstrates how to instantiate and test the combined model.
MobileViT Backbone Modification:
MobileViT class is adjusted to return two feature maps:
low_level_features: Extracted from an intermediate layer (after the second MV2Block).x: The final deep features used by the ASPP module.DeepLabV3+ Integration:
DeepLab class is updated to utilize the modified MobileViT as its backbone.low_level_channels parameter is set appropriately based on the MobileViT architecture.Complete Code Structure:
pythonimport torch import torch.nn as nn import torch.nn.functional as F from functools import partial from einops import rearrange # --------------------------- # MobileViT Backbone # --------------------------- def conv_1x1_bn(inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) def conv_nxn_bn(inp, oup, kernel_size=3, stride=1): return nn.Sequential( nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): # x: (batch, patch, num_tokens, dim) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # (b, p, h, n, n) attn = self.attend(dots) # (b, p, h, n, n) out = torch.matmul(attn, v) # (b, p, h, n, d) out = rearrange(out, 'b p h n d -> b p n (h d)') return self.to_out(out) class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList([]) for _ in 
range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads, dim_head, dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x class MV2Block(nn.Module): def __init__(self, inp, oup, stride=1, expansion=4): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(inp * expansion) self.use_res_connect = self.stride == 1 and inp == oup if expansion == 1: self.conv = nn.Sequential( # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) else: self.conv = nn.Sequential( # pw nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileViTBlock(nn.Module): def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): super().__init__() self.ph, self.pw = patch_size self.conv1 = conv_nxn_bn(channel, channel, kernel_size) self.conv2 = conv_1x1_bn(channel, dim) self.transformer = Transformer(dim, depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) self.conv3 = conv_1x1_bn(dim, channel) self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size) def forward(self, x): y = x.clone() # Local representations x = self.conv1(x) x = self.conv2(x) # Global representations _, _, h, w = x.shape x = rearrange(x, 'b c (h ph) (w pw) -> b (ph pw) (h w) c', ph=self.ph, pw=self.pw) x = self.transformer(x) x = rearrange(x, 'b (ph pw) (h w) c -> b c (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw) # Fusion 
x = self.conv3(x) x = torch.cat((x, y), 1) x = self.conv4(x) return x class MobileViT(nn.Module): def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): super().__init__() ih, iw = image_size ph, pw = patch_size assert ih % ph == 0 and iw % pw == 0 L = [2, 4, 3] self.conv1 = conv_nxn_bn(3, channels[0], stride=2) self.mv2 = nn.ModuleList([]) self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) self.mv2.append(MV2Block(channels[4], channels[5], 1, expansion)) self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) self.mv2.append(MV2Block(channels[6], channels[7], 1, expansion)) self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) self.mv2.append(MV2Block(channels[8], channels[9], 1, expansion)) self.mv2.append(MV2Block(channels[9], channels[10], 2, expansion)) self.mvit = nn.ModuleList([]) self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) self.conv3 = conv_1x1_bn(channels[-1], num_classes) self.upsample = nn.Upsample(size=image_size, mode='bilinear', align_corners=False) # Note: Global average pooling and classification layers are removed for backbone purposes def forward(self, x): x = self.conv1(x) # Output: channels[0], stride=2 x = self.mv2[0](x) # Output: channels[1] x = self.mv2[1](x) # Output: channels[2], stride=2 # Extract low_level_features here (after second MV2Block) low_level_features = x.clone() x = self.mv2[2](x) # Output: channels[3] x = self.mv2[3](x) # 
Output: channels[4], stride=2 x = self.mv2[4](x) # Output: channels[5] x = self.mvit[0](x) # MobileViTBlock 1 x = self.mv2[5](x) # Output: channels[6], stride=2 x = self.mvit[1](x) # MobileViTBlock 2 x = self.mv2[6](x) # Output: channels[7] x = self.mvit[2](x) # MobileViTBlock 3 x = self.mv2[7](x) # Output: channels[8], stride=2 x = self.mv2[8](x) # Output: channels[9] x = self.mv2[9](x) # Output: channels[10], stride=2 x = self.conv2(x) # Output: channels[10] # For backbone purposes, we return low_level_features and x return low_level_features, x # --------------------------- # ASPP Module # --------------------------- class ASPP(nn.Module): def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1): super(ASPP, self).__init__() self.branch1 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch3 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch4 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True) self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom) self.branch5_relu = nn.ReLU(inplace=True) self.conv_cat = nn.Sequential( nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) def forward(self, x): [b, c, row, col] = x.size() # Five branches conv1x1 = self.branch1(x) conv3x3_1 = self.branch2(x) conv3x3_2 = self.branch3(x) conv3x3_3 = self.branch4(x) # Fifth branch: global average pooling + conv 
global_feature = torch.mean(x, 2, True) global_feature = torch.mean(global_feature, 3, True) global_feature = self.branch5_conv(global_feature) global_feature = self.branch5_bn(global_feature) global_feature = self.branch5_relu(global_feature) global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) # Concatenate and project feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) result = self.conv_cat(feature_cat) return result # --------------------------- # DeepLabV3+ with MobileViT Backbone # --------------------------- class DeepLab(nn.Module): def __init__(self, num_classes, backbone="mobilivit", pretrained_backbone=True, downsample_factor=16): super(DeepLab, self).__init__() if backbone == "xception": raise NotImplementedError("Xception backbone is not implemented in this script.") elif backbone == "mobilivit": # Initialize MobileViT backbone # Define MobileViT parameters as per your requirements dims = [96, 120, 144] channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384] self.backbone = MobileViT( image_size=(256, 256), dims=dims, channels=channels, num_classes=1000, # Not used in backbone expansion=2, kernel_size=3, patch_size=(2, 2) ) in_channels = channels[-1] # 384 low_level_channels = channels[2] # 48 else: raise ValueError('Unsupported backbone - `{}`, Use mobilivit.'.format(backbone)) # ASPP Module self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor) # Shortcut convolution for low-level features self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) # Convolution layers after concatenation self.cat_conv = nn.Sequential( nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1) def 
forward(self, x): H, W = x.size(2), x.size(3) # Backbone forward low_level_features, x = self.backbone(x) # ASPP x = self.aspp(x) # Process low-level features low_level_features = self.shortcut_conv(low_level_features) # Upsample ASPP output to match low-level features x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) # Concatenate x = self.cat_conv(torch.cat((x, low_level_features), dim=1)) # Classification layer x = self.cls_conv(x) # Final upsampling to input size x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) return x # --------------------------- # Parameter Counting Utility # --------------------------- def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) # --------------------------- # Testing the Combined Model # --------------------------- if __name__ == '__main__': # Example input img = torch.randn(1, 3, 256, 256) # Initialize DeepLab with MobileViT backbone model = DeepLab( num_classes=3, # Example for 3 classes backbone="mobilivit", # Specify the backbone pretrained_backbone=True, # Whether to use pretrained weights (if available) downsample_factor=16 # Downsample factor ) # Forward pass out = model(img) print("Output shape:", out.shape) # Expected: (1, 3, 256, 256) # Optional: Print number of parameters print("Total trainable parameters:", count_parameters(model))
MobileViT Backbone Modification:
MobileViT class is designed to extract features at different stages. For DeepLabV3+, we need two outputs:
low_level_features: Captured after the second MV2Block (channels[2] = 48).x: The final deep features after all MV2Block and MobileViTBlock layers (channels[-1] = 384).forward method of MobileViT is modified to return both low_level_features and x.DeepLabV3+ Configuration:
DeepLab class is configured to use the MobileViT backbone by specifying backbone="mobilivit".dim_in set to the number of channels from the backbone's output (384 in this case) and dim_out=256.low_level_features to reduce channel dimensions to 48.low_level_features are concatenated and passed through additional convolutional layers to produce the final segmentation map.Testing:
(1, 3, 256, 256) is passed through the combined model.3 in this example).Pretrained Weights: The pretrained_backbone=True parameter is included for scenarios where pretrained weights for MobileViT are available. Ensure that you load appropriate pretrained weights if necessary.
Backbone Flexibility: While this script focuses on integrating MobileViT as the backbone, the structure allows for easy extension to other backbones (e.g., Xception) by following a similar pattern.
Adjustable Parameters: You can modify parameters such as num_classes, image_size, and downsample_factor based on your specific requirements.
This integrated model leverages the strengths of MobileViT for efficient feature extraction and the powerful segmentation capabilities of DeepLabV3+. It is well-suited for tasks requiring accurate and efficient semantic segmentation.
C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\python.exe C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py
Traceback (most recent call last):
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 389, in <module>
out = model(img)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 347, in forward
low_level_features, x = self.backbone(x)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 223, in forward
x = self.mvit2 # MobileViTBlock 3
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 156, in forward
x = self.conv1(x)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\container.py", line 250, in forward
input = module(input)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
return self._conv_forward(input, self.weight, self.bias)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
return F.conv2d(
RuntimeError: Given groups=1, weight of size [96, 96, 3, 3], expected input[1, 80, 16, 16] to have 96 channels, but got 80 channels instead
进程已结束,退出代码为 1
The error you're encountering stems from a mismatch between the expected number of input channels for a convolutional layer and the actual number of channels being passed during the forward pass of your MobileViT backbone within the DeepLabV3+ architecture.
Error Message:
RuntimeError: Given groups=1, weight of size [96, 96, 3, 3], expected input[1, 80, 16, 16] to have 96 channels, but got 80 channels instead
Cause:
MobileViTBlock is expecting an input with 96 channels, but it is receiving an input with 80 channels.MobileViT backbone's forward pass is incorrectly ordering the application of MobileViTBlock and MV2Block layers, leading to a mismatch in channel dimensions.Let's delve into the MobileViT class to understand the flow of channels:
Channel Configuration:
pythonchannels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
Layer Initialization:
self.mv2 contains 10 MV2Block instances, transitioning through the defined channels.self.mvit contains 3 MobileViTBlock instances:
self.mvit[0] expects channels[5] = 64 channels.self.mvit[1] expects channels[7] = 80 channels.self.mvit[2] expects channels[9] = 96 channels.Forward Pass Issue:
pythonx = self.mv2[6](x) # Output channels[7] = 80 x = self.mvit[2](x) # Expects 96 channels
self.mvit[2] is being applied to an input with 80 channels instead of the expected 96 channels, causing the runtime error.Rearrange the Forward Pass Order:
To ensure that self.mvit[2] receives the correct number of input channels (96 channels), you need to adjust the order in which layers are applied in the forward method of the MobileViT class.
Correct Order:
self.mv2[6] to transition from 80 to 80 channels.self.mv2[7] to transition from 80 to 96 channels.self.mvit[2] can safely process the input with 96 channels.Updated Forward Method:
pythondef forward(self, x): x = self.conv1(x) # Output: channels[0]=16 x = self.mv2[0](x) # Output: channels[1]=32 x = self.mv2[1](x) # Output: channels[2]=48, stride=2 low_level_features = x.clone() # channels[2]=48 x = self.mv2[2](x) # Output: channels[3]=48 x = self.mv2[3](x) # Output: channels[4]=64, stride=2 x = self.mv2[4](x) # Output: channels[5]=64 x = self.mvit[0](x) # MobileViTBlock 1: Input channels=64, Output channels=64 x = self.mv2[5](x) # Output: channels[6]=80, stride=2 x = self.mvit[1](x) # MobileViTBlock 2: Input channels=80, Output channels=80 x = self.mv2[6](x) # Output: channels[7]=80 x = self.mv2[7](x) # Output: channels[8]=96 x = self.mvit[2](x) # MobileViTBlock 3: Input channels=96, Output channels=96 x = self.mv2[8](x) # Output: channels[9]=96 x = self.mv2[9](x) # Output: channels[10]=384 x = self.conv2(x) # Output: channels[10]=384 return low_level_features, x
Below is the corrected and comprehensive Python script that integrates the DeepLabV3+ architecture with a MobileViT backbone. The primary change is the rearrangement of the forward pass in the MobileViT class to ensure channel compatibility.
pythonimport torch import torch.nn as nn import torch.nn.functional as F from functools import partial from einops import rearrange # --------------------------- # MobileViT Backbone # --------------------------- def conv_1x1_bn(inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) def conv_nxn_bn(inp, oup, kernel_size=3, stride=1): return nn.Sequential( nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): # x: (batch, patch, num_tokens, dim) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # (b, p, h, n, n) attn = self.attend(dots) # (b, p, h, n, n) out = torch.matmul(attn, v) # (b, p, h, n, d) out = rearrange(out, 'b p h n d -> b p n (h d)') return self.to_out(out) class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList([]) for _ in 
range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads, dim_head, dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x class MV2Block(nn.Module): def __init__(self, inp, oup, stride=1, expansion=4): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(inp * expansion) self.use_res_connect = self.stride == 1 and inp == oup if expansion == 1: self.conv = nn.Sequential( # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) else: self.conv = nn.Sequential( # pw nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileViTBlock(nn.Module): def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): super().__init__() self.ph, self.pw = patch_size self.conv1 = conv_nxn_bn(channel, channel, kernel_size) self.conv2 = conv_1x1_bn(channel, dim) self.transformer = Transformer(dim, depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) self.conv3 = conv_1x1_bn(dim, channel) self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size) def forward(self, x): y = x.clone() # Local representations x = self.conv1(x) x = self.conv2(x) # Global representations _, _, h, w = x.shape x = rearrange(x, 'b c (h ph) (w pw) -> b (ph pw) (h w) c', ph=self.ph, pw=self.pw) x = self.transformer(x) x = rearrange(x, 'b (ph pw) (h w) c -> b c (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw) # Fusion 
x = self.conv3(x) x = torch.cat((x, y), 1) x = self.conv4(x) return x class MobileViT(nn.Module): def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): super().__init__() ih, iw = image_size ph, pw = patch_size assert ih % ph == 0 and iw % pw == 0 L = [2, 4, 3] self.conv1 = conv_nxn_bn(3, channels[0], stride=2) self.mv2 = nn.ModuleList([]) self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) self.mv2.append(MV2Block(channels[4], channels[5], 1, expansion)) self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) self.mv2.append(MV2Block(channels[6], channels[7], 1, expansion)) self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) self.mv2.append(MV2Block(channels[8], channels[9], 1, expansion)) self.mv2.append(MV2Block(channels[9], channels[10], 2, expansion)) self.mvit = nn.ModuleList([]) self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) self.conv3 = conv_1x1_bn(channels[-1], num_classes) self.upsample = nn.Upsample(size=image_size, mode='bilinear', align_corners=False) # Note: Global average pooling and classification layers are removed for backbone purposes def forward(self, x): x = self.conv1(x) # Output: channels[0]=16 x = self.mv2[0](x) # Output: channels[1]=32 x = self.mv2[1](x) # Output: channels[2]=48, stride=2 low_level_features = x.clone() # channels[2]=48 x = self.mv2[2](x) # Output: channels[3]=48 x = self.mv2[3](x) # Output: channels[4]=64, stride=2 x = 
self.mv2[4](x) # Output: channels[5]=64 x = self.mvit[0](x) # MobileViTBlock 1: Input channels=64, Output channels=64 x = self.mv2[5](x) # Output: channels[6]=80, stride=2 x = self.mvit[1](x) # MobileViTBlock 2: Input channels=80, Output channels=80 x = self.mv2[6](x) # Output: channels[7]=80 x = self.mv2[7](x) # Output: channels[8]=96 x = self.mvit[2](x) # MobileViTBlock 3: Input channels=96, Output channels=96 x = self.mv2[8](x) # Output: channels[9]=96 x = self.mv2[9](x) # Output: channels[10]=384 x = self.conv2(x) # Output: channels[10]=384 return low_level_features, x # --------------------------- # ASPP Module # --------------------------- class ASPP(nn.Module): def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1): super(ASPP, self).__init__() self.branch1 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch3 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch4 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True) self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom) self.branch5_relu = nn.ReLU(inplace=True) self.conv_cat = nn.Sequential( nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) def forward(self, x): [b, c, row, col] = x.size() # Five branches conv1x1 = self.branch1(x) conv3x3_1 = self.branch2(x) conv3x3_2 = self.branch3(x) conv3x3_3 = self.branch4(x) # Fifth branch: global 
average pooling + conv global_feature = torch.mean(x, 2, True) global_feature = torch.mean(global_feature, 3, True) global_feature = self.branch5_conv(global_feature) global_feature = self.branch5_bn(global_feature) global_feature = self.branch5_relu(global_feature) global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) # Concatenate and project feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) result = self.conv_cat(feature_cat) return result # --------------------------- # DeepLabV3+ with MobileViT Backbone # --------------------------- class DeepLab(nn.Module): def __init__(self, num_classes, backbone="mobilivit", pretrained_backbone=True, downsample_factor=16): super(DeepLab, self).__init__() if backbone == "xception": raise NotImplementedError("Xception backbone is not implemented in this script.") elif backbone == "mobilivit": # Initialize MobileViT backbone # Define MobileViT parameters as per your requirements dims = [96, 120, 144] channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384] self.backbone = MobileViT( image_size=(256, 256), dims=dims, channels=channels, num_classes=1000, # Not used in backbone expansion=2, kernel_size=3, patch_size=(2, 2) ) in_channels = channels[-1] # 384 low_level_channels = channels[2] # 48 else: raise ValueError('Unsupported backbone - `{}`, Use mobilivit.'.format(backbone)) # ASPP Module self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor) # Shortcut convolution for low-level features self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) # Convolution layers after concatenation self.cat_conv = nn.Sequential( nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) self.cls_conv = nn.Conv2d(256, num_classes, 
1, stride=1) def forward(self, x): H, W = x.size(2), x.size(3) # Backbone forward low_level_features, x = self.backbone(x) # ASPP x = self.aspp(x) # Process low-level features low_level_features = self.shortcut_conv(low_level_features) # Upsample ASPP output to match low-level features x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True) # Concatenate x = self.cat_conv(torch.cat((x, low_level_features), dim=1)) # Classification layer x = self.cls_conv(x) # Final upsampling to input size x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) return x # --------------------------- # Parameter Counting Utility # --------------------------- def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) # --------------------------- # Testing the Combined Model # --------------------------- if __name__ == '__main__': # Example input img = torch.randn(1, 3, 256, 256) # Initialize DeepLab with MobileViT backbone model = DeepLab( num_classes=3, # Example for 3 classes backbone="mobilivit", # Specify the backbone pretrained_backbone=True, # Whether to use pretrained weights (if available) downsample_factor=16 # Downsample factor ) # Forward pass out = model(img) print("Output shape:", out.shape) # Expected: (1, 3, 256, 256) # Optional: Print number of parameters print("Total trainable parameters:", count_parameters(model))
Rearranged Forward Pass in MobileViT:
Before (incorrect — self.mvit[2] receives 80 channels):
pythonx = self.mv2[6](x) # channels[7]=80 x = self.mvit[2](x) # Expects 96 channels x = self.mv2[7](x) # channels[8]=96
After (correct — self.mvit[2] receives 96 channels):
pythonx = self.mv2[6](x) # channels[7]=80 x = self.mv2[7](x) # channels[8]=96 x = self.mvit[2](x) # Expects 96 channels
self.mvit[2] receives the correct number of input channels (96 channels).Ensured Channel Consistency:
self.mv2[7], the output channels match the expected input channels for self.mvit[2].Removed Unused Layers in MobileViT:
self.conv3 and self.conv4 layers within MobileViT are retained as they are essential for the MobileViT block's functionality.Adjusted the DeepLab Class Initialization:
low_level_channels and in_channels are correctly set based on the channels list.After implementing the above changes, running the script should yield an output shape that matches the input spatial dimensions with the specified number of classes. Additionally, the total number of trainable parameters will be printed for reference.
Expected Output:
textOutput shape: torch.Size([1, 3, 256, 256]) Total trainable parameters: <number_of_parameters>
Pretrained Weights:
MobileViT, ensure they are correctly loaded to leverage transfer learning benefits.Backbone Flexibility:
MobileViT, you can extend it to support other backbones like Xception by following a similar pattern.Hyperparameter Tuning:
dims, channels, expansion, kernel_size, and patch_size for optimal performance.Debugging Tips:
Modular Design:
MobileViT and ASPP modular enhances readability and maintainability, making it easier to debug and extend the architecture.By implementing the corrected forward pass order, your DeepLabV3+ model with a MobileViT backbone should function correctly without channel mismatches.
C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\python.exe C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py
Traceback (most recent call last):
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 386, in <module>
out = model(img)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 346, in forward
low_level_features, x = self.backbone(x)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\PycharmProjects\deeplabv3-plus-pytorch-main\nets\mobileVit_deeplabV3+.py", line 229, in forward
x = self.conv2(x) # Output: channels[10]=384
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\container.py", line 250, in forward
input = module(input)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
return self._conv_forward(input, self.weight, self.bias)
File "C:\Users\wztbkl\anaconda3\envs\wuzhupytorch\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
return F.conv2d(
RuntimeError: Given groups=1, weight of size [384, 96, 1, 1], expected input[1, 384, 4, 4] to have 96 channels, but got 384 channels instead
进程已结束,退出代码为 1
中文回答
出现的错误是由于 MobileViT 主干网络中的卷积层 conv2 期望输入的通道数为 96,但实际接收到的输入通道数为 384。这导致了通道不匹配的问题。
RuntimeError: Given groups=1, weight of size [384, 96, 1, 1], expected input[1, 384, 4, 4] to have 96 channels, but got 384 channels instead
原因分析:
conv2 定义:
pythonself.conv2 = conv_1x1_bn(channels[-2], channels[-1])
channels[-2] = 96channels[-1] = 384MobileViT 前向传播:
pythonx = self.mv2[9](x) # Output: channels[10]=384 x = self.conv2(x) # 期望输入通道数=96,但实际输入通道数=384
这里,conv2 被应用于一个具有 384 个通道的输入,而它期望的输入通道数是 96,因此引发了运行时错误。
要解决这个问题,可以采取以下步骤:
移除 conv2 层:
conv2 层在当前架构中是不必要的,因为主干网络已经输出了所需的 384 个通道。这意味着可以安全地移除 conv2,以避免通道不匹配的问题。修改 MobileViT 类:
__init__ 方法中删除 conv2 的定义。forward 方法中删除 conv2 的调用。确保 DeepLab 类的初始化参数正确:
in_channels 设置为 384,与主干网络的最终输出通道数一致。以下是修正后的代码,已移除 MobileViT 主干网络中的 conv2 层,并确保通道数一致性:
pythonimport torch import torch.nn as nn import torch.nn.functional as F from functools import partial from einops import rearrange # --------------------------- # MobileViT Backbone # --------------------------- def conv_1x1_bn(inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) def conv_nxn_bn(inp, oup, kernel_size=3, stride=1): return nn.Sequential( nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): # x: (batch, patch, num_tokens, dim) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # (b, p, h, n, n) attn = self.attend(dots) # (b, p, h, n, n) out = torch.matmul(attn, v) # (b, p, h, n, d) out = rearrange(out, 'b p h n d -> b p n (h d)') return self.to_out(out) class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList([]) for _ in 
range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads, dim_head, dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x class MV2Block(nn.Module): def __init__(self, inp, oup, stride=1, expansion=4): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(inp * expansion) self.use_res_connect = self.stride == 1 and inp == oup if expansion == 1: self.conv = nn.Sequential( # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) else: self.conv = nn.Sequential( # pw nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileViTBlock(nn.Module): def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): super().__init__() self.ph, self.pw = patch_size self.conv1 = conv_nxn_bn(channel, channel, kernel_size) self.conv2 = conv_1x1_bn(channel, dim) self.transformer = Transformer(dim, depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) self.conv3 = conv_1x1_bn(dim, channel) self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size) def forward(self, x): y = x.clone() # Local representations x = self.conv1(x) x = self.conv2(x) # Global representations _, _, h, w = x.shape x = rearrange(x, 'b c (h ph) (w pw) -> b (ph pw) (h w) c', ph=self.ph, pw=self.pw) x = self.transformer(x) x = rearrange(x, 'b (ph pw) (h w) c -> b c (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw) # Fusion 
x = self.conv3(x) x = torch.cat((x, y), 1) x = self.conv4(x) return x class MobileViT(nn.Module): def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): super().__init__() ih, iw = image_size ph, pw = patch_size assert ih % ph == 0 and iw % pw == 0 L = [2, 4, 3] self.conv1 = conv_nxn_bn(3, channels[0], stride=2) self.mv2 = nn.ModuleList([]) self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) self.mv2.append(MV2Block(channels[4], channels[5], 1, expansion)) self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) self.mv2.append(MV2Block(channels[6], channels[7], 1, expansion)) self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) self.mv2.append(MV2Block(channels[8], channels[9], 1, expansion)) self.mv2.append(MV2Block(channels[9], channels[10], 2, expansion)) self.mvit = nn.ModuleList([]) self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) # 移除 conv2 层 # self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) self.upsample = nn.Upsample(size=image_size, mode='bilinear', align_corners=False) # 注意:全局平均池化和分类层被移除,仅用于特征提取 def forward(self, x): x = self.conv1(x) # Output: channels[0]=16 x = self.mv2[0](x) # Output: channels[1]=32 x = self.mv2[1](x) # Output: channels[2]=48, stride=2 low_level_features = x.clone() # channels[2]=48 x = self.mv2[2](x) # Output: channels[3]=48 x = self.mv2[3](x) # Output: channels[4]=64, stride=2 x = self.mv2[4](x) # Output: channels[5]=64 x = self.mvit[0](x) # MobileViTBlock 1: Input channels=64, Output 
channels=64 x = self.mv2[5](x) # Output: channels[6]=80, stride=2 x = self.mvit[1](x) # MobileViTBlock 2: Input channels=80, Output channels=80 x = self.mv2[6](x) # Output: channels[7]=80 x = self.mv2[7](x) # Output: channels[8]=96 x = self.mvit[2](x) # MobileViTBlock 3: Input channels=96, Output channels=96 x = self.mv2[8](x) # Output: channels[9]=96 x = self.mv2[9](x) # Output: channels[10]=384 # 移除 conv2 的调用 # x = self.conv2(x) # Output: channels[10]=384 return low_level_features, x # --------------------------- # ASPP Module # --------------------------- class ASPP(nn.Module): def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1): super(ASPP, self).__init__() self.branch1 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch3 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch4 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True) self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom) self.branch5_relu = nn.ReLU(inplace=True) self.conv_cat = nn.Sequential( nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) def forward(self, x): [b, c, row, col] = x.size() # 五个分支 conv1x1 = self.branch1(x) conv3x3_1 = self.branch2(x) conv3x3_2 = self.branch3(x) conv3x3_3 = self.branch4(x) # 第五个分支:全局平均池化 + 卷积 global_feature = torch.mean(x, 2, True) global_feature = torch.mean(global_feature, 3, True) 
global_feature = self.branch5_conv(global_feature) global_feature = self.branch5_bn(global_feature) global_feature = self.branch5_relu(global_feature) global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) # 拼接并投影 feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) result = self.conv_cat(feature_cat) return result # --------------------------- # DeepLabV3+ with MobileViT Backbone # --------------------------- class DeepLab(nn.Module): def __init__(self, num_classes, backbone="mobilivit", pretrained_backbone=True, downsample_factor=16): super(DeepLab, self).__init__() if backbone == "xception": raise NotImplementedError("Xception backbone is not implemented in this script.") elif backbone == "mobilivit": # 初始化 MobileViT 主干网络 dims = [96, 120, 144] channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384] self.backbone = MobileViT( image_size=(256, 256), dims=dims, channels=channels, num_classes=1000, # 主干网络中不使用 expansion=2, kernel_size=3, patch_size=(2, 2) ) in_channels = channels[-1] # 384 low_level_channels = channels[2] # 48 else: raise ValueError('Unsupported backbone - `{}`, Use mobilivit.'.format(backbone)) # ASPP 模块 self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor) # 低层特征的卷积 self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) # 拼接后的卷积层 self.cat_conv = nn.Sequential( nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1) def forward(self, x): H, W = x.size(2), x.size(3) # 主干网络前向传播 low_level_features, x = self.backbone(x) # ASPP x = self.aspp(x) # 处理低层特征 low_level_features = self.shortcut_conv(low_level_features) # 上采样 ASPP 输出以匹配低层特征 x = F.interpolate(x, 
size=low_level_features.size()[2:], mode='bilinear', align_corners=True) # 拼接 x = self.cat_conv(torch.cat((x, low_level_features), dim=1)) # 分类层 x = self.cls_conv(x) # 最终上采样到输入尺寸 x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) return x # --------------------------- # 参数统计工具 # --------------------------- def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) # --------------------------- # 测试组合模型 # --------------------------- if __name__ == '__main__': # 示例输入 img = torch.randn(1, 3, 256, 256) # 初始化 DeepLab 模型,使用 MobileViT 主干网络 model = DeepLab( num_classes=3, # 例如3个类别 backbone="mobilivit", # 指定主干网络 pretrained_backbone=True, # 是否使用预训练权重(如果有的话) downsample_factor=16 # 下采样因子 ) # 前向传播 out = model(img) print("输出形状:", out.shape) # 预期输出: (1, 3, 256, 256) # 可选:打印可训练参数数量 print("可训练参数总数:", count_parameters(model))
移除 conv2 层:
MobileViT 类的 __init__ 方法中,移除了 self.conv2 的定义。MobileViT 类的 forward 方法中,移除了 x = self.conv2(x) 的调用。确保通道数一致性:
MobileViT 主干网络在输出时直接返回 384 个通道,与 DeepLab 类中的 ASPP 模块和后续处理步骤一致。运行上述修正后的脚本,应该不会再出现通道不匹配的错误,且输出形状应与输入形状一致:
text输出形状: torch.Size([1, 3, 256, 256]) 可训练参数总数: <参数数量>
加载预训练权重:
MobileViT 的预训练权重,建议在初始化 MobileViT 主干网络时加载这些权重,以提升模型性能。扩展主干网络支持:
MobileViT 主干网络,但可以根据需要添加对其他主干网络(如 Xception)的支持。调试技巧:
pythonprint("Low level features shape:", low_level_features.shape) print("ASPP output shape:", x.shape)
参数调整:
dims, channels, expansion, kernel_size, patch_size 等参数,以获得最佳性能。模型优化:
通过上述修改和建议,您的 DeepLabV3+ 模型与 MobileViT 主干网络的集成应该能够顺利运行,并且具备良好的性能表现。
import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) convolution followed by BatchNorm2d and SiLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """NxN convolution followed by BatchNorm2d and SiLU.

    Fix: padding was hard-coded to 1, which only gives "same" spatial size
    for kernal_size == 3. Using kernal_size // 2 keeps the spatial size at
    stride 1 for any odd kernel size; behavior for the default k=3 is
    unchanged (3 // 2 == 1).
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, kernal_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module.

    Fix: the constructor was mangled to ``def init`` / ``super().init()``
    (markdown stripped the dunder underscores), so the class could not be
    instantiated; restored ``__init__`` and removed the ``text`` paste artifact.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn  # wrapped sub-module (attention or feed-forward)

    def forward(self, x, **kwargs):
        # Pre-norm transformer convention: normalize first, then apply fn.
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Fix: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``) and removed the ``text`` paste artifact.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention over (batch, patch, tokens, dim) inputs.

    Fixes: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``). The einops ``rearrange`` calls were replaced with
    equivalent native view/transpose/reshape operations, removing the
    third-party dependency from this block; the tensor layout is identical
    ('b p n (h d)' <-> 'b p h n d').
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # Skip the output projection when a single head already matches dim.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d_k) attention scaling

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, patch, tokens, dim)
        b, p, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, p, n, h*d) -> (b, p, h, n, d)
        q, k, v = (t.view(b, p, n, self.heads, -1).transpose(2, 3) for t in qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (b, p, h, n, n)
        attn = self.attend(dots)
        out = torch.matmul(attn, v)  # (b, p, h, n, d)

        # (b, p, h, n, d) -> (b, p, n, h*d); reshape handles the
        # non-contiguous tensor produced by transpose.
        out = out.transpose(2, 3).reshape(b, p, n, -1)
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm attention + feed-forward layers with residuals.

    Fix: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``) and removed the ``text`` paste artifact.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            # Residual connections around both sub-layers.
            x = attn(x) + x
            x = ff(x) + x
        return x
class MV2Block(nn.Module):
    """MobileNetV2 inverted-residual block (pointwise expand -> depthwise -> pointwise project).

    Fix: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``) and reformatted the collapsed ``text`` paste artifact.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        # Residual connection only when spatial size and channel count are unchanged.
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            # No pointwise expansion; note this assumes hidden_dim == inp.
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear (no activation after projection)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw expansion
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear (no activation after projection)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)
class MobileViTBlock(nn.Module):
    """MobileViT block: local conv features + global transformer over patches, fused back.

    Fix: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``) and reformatted the collapsed ``text`` paste artifact.
    """

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        # Takes the concatenation of transformed features and the input.
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        # NOTE(review): assumes H and W are divisible by the patch size —
        # confirm at call sites (MobileViT asserts this for the input image).
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold into patches, attend, fold back.
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion: project back to input channels and merge with the skip copy.
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
class MobileViT(nn.Module):
    """MobileViT feature extractor (classification head removed).

    Fix: restored the mangled ``__init__`` dunder (was ``def init`` /
    ``super().init()``) and reformatted the collapsed ``text`` paste artifact.

    NOTE(review): the layer wiring relies on the ``channels`` list having
    equal values at index pairs (2,3), (4,5), (6,7) and (8,9) — this holds
    for the xxs/xs configurations defined below; verify before supplying a
    different channel list.
    """

    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        # Patch-based attention requires the image to tile evenly into patches.
        assert ih % ph == 0 and iw % pw == 0

        depths = [2, 4, 3]  # transformer depth of the three MobileViT blocks

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))  # repeated stage
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        # Global average pooling and the classifier were removed: this module
        # is used as a feature-extraction backbone, so ``num_classes`` is
        # accepted only for signature compatibility.
        # self.pool = nn.AvgPool2d(ih // 32, 1)
        # self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)

        x = self.mv2[0](x)
        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)  # repeated stage

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        return x
def mobilevit_xxs():
    """Build the MobileViT-XXS variant (expansion=2) for 256x256 inputs."""
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
def mobilevit_xs():
    """Build the MobileViT-XS variant (default expansion=4) for 256x256 inputs."""
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((256, 256), dims, channels, num_classes=1000)
def mobilevit_s():
    """Build the MobileViT-S variant (default expansion=4) for 256x256 inputs."""
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((256, 256), dims, channels, num_classes=1000)
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
if __name__ == '__main__':
    # Smoke test: push a random image through the backbone and print the
    # output feature-map shape.
    img = torch.randn(1, 3, 256, 256)
    # NOTE(review): image_size=(512, 512) does not match the 256x256 input;
    # the constructor only checks patch divisibility, so this still runs —
    # confirm the intended input resolution.
    model = MobileViT(image_size=(512, 512), dims=[96, 120, 144],
                      channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], num_classes=3)
    out = model(img)
    print(out.shape)

    # Optional: compare the three stock variants.
    # vit = mobilevit_xxs()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
    #
    # vit = mobilevit_xs()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
    #
    # vit = mobilevit_s()
    # out = vit(img)
    # print(out.shape)
    # print(count_parameters(vit))
上面的 MobileViT 代码错了,用这个重新写一个,并加上中文注释。
了解您在整合 MobileViT 与 DeepLabV3+ 时遇到了通道不匹配的问题。下面,我将为您重新编写一个正确的 MobileViT 模型,并添加详细的中文注释,以确保其与 DeepLabV3+ 主干网络的集成能够顺利进行。
移除不必要的 conv2 层:
conv2 层导致了通道数不匹配的问题。由于主干网络的最终输出已经满足 DeepLabV3+ 的需求,因此不需要额外的 conv2 层。确保通道数一致性:
调整 MobileViT 的前向传播:
MobileViT 返回两个特征图:low_level_features 和 x,以便 DeepLabV3+ 能够正确地处理这些特征图。pythonimport torch import torch.nn as nn import torch.nn.functional as F from functools import partial from einops import rearrange # --------------------------- # 基础卷积模块 # --------------------------- def conv_1x1_bn(inp, oup): """ 1x1 卷积 + 批归一化 + 激活函数 """ return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) def conv_nxn_bn(inp, oup, kernel_size=3, stride=1): """ NxN 卷积 + 批归一化 + 激活函数 """ padding = kernel_size // 2 return nn.Sequential( nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) # --------------------------- # Transformer 相关模块 # --------------------------- class PreNorm(nn.Module): """ 在执行函数前进行层归一化 """ def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): """ 前馈神经网络模块 """ def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): """ 多头自注意力机制 """ def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): # x: (batch, patch, num_tokens, dim) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # (b, p, h, n, n) attn = self.attend(dots) # (b, p, h, n, n) out = 
torch.matmul(attn, v) # (b, p, h, n, d) out = rearrange(out, 'b p h n d -> b p n (h d)') return self.to_out(out) class Transformer(nn.Module): """ Transformer 编码器 """ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads, dim_head, dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x # --------------------------- # MobileViT 模块 # --------------------------- class MV2Block(nn.Module): """ MobileNetV2 的倒残差块 """ def __init__(self, inp, oup, stride=1, expansion=4): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(inp * expansion) self.use_res_connect = self.stride == 1 and inp == oup if expansion == 1: self.conv = nn.Sequential( # 深度卷积 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 点卷积 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) else: self.conv = nn.Sequential( # 点卷积 nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 深度卷积 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 点卷积 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileViTBlock(nn.Module): """ MobileViT 块,包括局部和全局特征提取 """ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): super().__init__() self.ph, self.pw = patch_size self.conv1 = conv_nxn_bn(channel, channel, kernel_size) self.conv2 = conv_1x1_bn(channel, dim) self.transformer = Transformer(dim, depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) self.conv3 = conv_1x1_bn(dim, channel) 
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size) def forward(self, x): y = x.clone() # 局部特征提取 x = self.conv1(x) x = self.conv2(x) # 全局特征提取 _, _, h, w = x.shape x = rearrange(x, 'b c (h ph) (w pw) -> b (ph pw) (h w) c', ph=self.ph, pw=self.pw) x = self.transformer(x) x = rearrange(x, 'b (ph pw) (h w) c -> b c (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw) # 特征融合 x = self.conv3(x) x = torch.cat((x, y), 1) x = self.conv4(x) return x class MobileViT(nn.Module): """ MobileViT 主干网络 """ def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): super().__init__() ih, iw = image_size ph, pw = patch_size assert ih % ph == 0 and iw % pw == 0 # MobileViT 块的深度 L = [2, 4, 3] self.conv1 = conv_nxn_bn(3, channels[0], stride=2) # 初始卷积 self.mv2 = nn.ModuleList([]) self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) self.mv2.append(MV2Block(channels[4], channels[5], 1, expansion)) self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) self.mv2.append(MV2Block(channels[6], channels[7], 1, expansion)) self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) self.mv2.append(MV2Block(channels[8], channels[9], 1, expansion)) self.mv2.append(MV2Block(channels[9], channels[10], 2, expansion)) self.mvit = nn.ModuleList([]) self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) # 移除 conv2 层 # self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) # 上采样模块 self.upsample = nn.Upsample(size=image_size, mode='bilinear', 
align_corners=False) # 注意:全局平均池化和分类层被移除,仅用于特征提取 def forward(self, x): x = self.conv1(x) # 输出通道: channels[0]=16 x = self.mv2[0](x) # 输出通道: channels[1]=32 x = self.mv2[1](x) # 输出通道: channels[2]=48, stride=2 low_level_features = x.clone() # 低层特征: channels[2]=48 x = self.mv2[2](x) # 输出通道: channels[3]=48 x = self.mv2[3](x) # 输出通道: channels[4]=64, stride=2 x = self.mv2[4](x) # 输出通道: channels[5]=64 x = self.mvit[0](x) # MobileViTBlock 1: 输入通道=64, 输出通道=64 x = self.mv2[5](x) # 输出通道: channels[6]=80, stride=2 x = self.mvit[1](x) # MobileViTBlock 2: 输入通道=80, 输出通道=80 x = self.mv2[6](x) # 输出通道: channels[7]=80 x = self.mv2[7](x) # 输出通道: channels[8]=96 x = self.mvit[2](x) # MobileViTBlock 3: 输入通道=96, 输出通道=96 x = self.mv2[8](x) # 输出通道: channels[9]=96 x = self.mv2[9](x) # 输出通道: channels[10]=384 # 移除 conv2 的调用 # x = self.conv2(x) # 输出通道: channels[10]=384 return low_level_features, x # 返回低层特征和深层特征 def mobilevit_xxs(): """ 创建 MobileViT_xxs 模型 """ dims = [64, 80, 96] channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320] return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2) def mobilevit_xs(): """ 创建 MobileViT_xs 模型 """ dims = [96, 120, 144] channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384] return MobileViT((256, 256), dims, channels, num_classes=1000) def mobilevit_s(): """ 创建 MobileViT_s 模型 """ dims = [144, 192, 240] channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640] return MobileViT((256, 256), dims, channels, num_classes=1000) # --------------------------- # 参数统计工具 # --------------------------- def count_parameters(model): """ 计算模型的可训练参数总数 """ return sum(p.numel() for p in model.parameters() if p.requires_grad) # --------------------------- # 测试 MobileViT 模型 # --------------------------- if __name__ == '__main__': # 示例输入 img = torch.randn(1, 3, 256, 256) # 初始化 MobileViT 模型 model = MobileViT( image_size=(256, 256), dims=[96, 120, 144], channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], num_classes=3, # 注意:主干网络不需要此参数 expansion=2, 
kernel_size=3, patch_size=(2, 2) ) # 前向传播 low_level_features, x = model(img) print("低层特征形状:", low_level_features.shape) # 预期: [1, 48, 32, 32] print("深层特征形状:", x.shape) # 预期: [1, 384, 8, 8] # 打印可训练参数数量 print("可训练参数总数:", count_parameters(model))
基础卷积模块 (conv_1x1_bn 和 conv_nxn_bn):
conv_1x1_bn: 1x1 卷积用于调整通道数,随后进行批归一化和激活函数。conv_nxn_bn: 任意大小的卷积(默认为3x3),用于提取空间特征,随后进行批归一化和激活函数。Transformer 相关模块:
PreNorm: 在执行核心函数(如注意力机制或前馈网络)之前,先进行层归一化。FeedForward: 前馈神经网络模块,包含两层线性变换和中间的激活函数及 dropout。Attention: 多头自注意力机制,计算 Query、Key、Value 并进行注意力加权。Transformer: 由多个 Attention 和 FeedForward 组成的 Transformer 编码器。MobileViT 模块:
MV2Block: MobileNetV2 的倒残差块,用于特征提取和通道数的调整。MobileViTBlock: 集成了局部卷积特征提取和全局 Transformer 特征提取的模块。MobileViT: 主干网络,由多个 MV2Block 和 MobileViTBlock 组成,负责提取多层次的特征图。MobileViT 模型创建函数:
mobilevit_xxs, mobilevit_xs, mobilevit_s: 分别创建不同规模的 MobileViT 模型,调整 dims 和 channels 以适应不同的需求。参数统计工具 (count_parameters):
测试部分:
运行上述代码后,您应该会看到类似如下的输出:
text低层特征形状: torch.Size([1, 48, 32, 32]) 深层特征形状: torch.Size([1, 384, 8, 8]) 可训练参数总数: <参数数量>
上述修正后的 MobileViT 模型现在可以作为 DeepLabV3+ 的主干网络进行集成。确保在 DeepLabV3+ 的实现中正确地调用 low_level_features 和 x 这两个特征图,以便 ASPP 模块和后续的融合步骤能够顺利进行。
参数一致性:
dims 和 channels 的配置与您希望的 MobileViT 架构一致。不同的配置会影响模型的性能和计算复杂度。预训练权重:
适应不同任务:
调试建议:
通过上述修正和说明,您的 MobileViT 模型应能正确运行,并与 DeepLabV3+ 主干网络顺利集成。如果在集成过程中仍然遇到问题,请随时提供更多详细信息,以便进一步帮助您解决问题。
将这两者结合需要移除 MobileViT 的全局平均池化层吗?
是的,当将 MobileViT 作为 DeepLabV3+ 的主干网络时,确实需要移除 MobileViT 中的全局平均池化(Global Average Pooling)层以及全连接(Fully Connected)层。这是因为全局平均池化和全连接层主要用于分类任务,它们会丢失空间位置信息,而语义分割任务需要保留空间信息以生成高分辨率的分割图。
下面,我将为您提供一个修正后的 MobileViT 模型代码,移除了全局平均池化层和全连接层,并添加了详细的中文注释。这个模型设计用于作为 DeepLabV3+ 的主干网络,输出必要的特征图以供后续的分割模块使用。
pythonimport torch import torch.nn as nn import torch.nn.functional as F from functools import partial from einops import rearrange # --------------------------- # 基础卷积模块 # --------------------------- def conv_1x1_bn(inp, oup): """ 1x1 卷积 + 批归一化 + 激活函数 """ return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) def conv_nxn_bn(inp, oup, kernel_size=3, stride=1): """ NxN 卷积 + 批归一化 + 激活函数 """ padding = kernel_size // 2 return nn.Sequential( nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=False), nn.BatchNorm2d(oup), nn.SiLU() ) # --------------------------- # Transformer 相关模块 # --------------------------- class PreNorm(nn.Module): """ 在执行函数前进行层归一化 """ def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): """ 前馈神经网络模块 """ def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): """ 多头自注意力机制 """ def __init__(self, dim, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): # x: (batch, patch, num_tokens, dim) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # (b, p, h, n, n) attn = self.attend(dots) # (b, p, h, n, n) out = torch.matmul(attn, v) # (b, p, h, n, d) out = rearrange(out, 'b p h n d -> b p n 
(h d)') return self.to_out(out) class Transformer(nn.Module): """ Transformer 编码器 """ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads, dim_head, dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x # --------------------------- # MobileViT 模块 # --------------------------- class MV2Block(nn.Module): """ MobileNetV2 的倒残差块 """ def __init__(self, inp, oup, stride=1, expansion=4): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(inp * expansion) self.use_res_connect = self.stride == 1 and inp == oup if expansion == 1: self.conv = nn.Sequential( # 深度卷积 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 点卷积 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) else: self.conv = nn.Sequential( # 点卷积 nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 深度卷积 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(), # 点卷积 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileViTBlock(nn.Module): """ MobileViT 块,包括局部和全局特征提取 """ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.): super().__init__() self.ph, self.pw = patch_size self.conv1 = conv_nxn_bn(channel, channel, kernel_size) self.conv2 = conv_1x1_bn(channel, dim) self.transformer = Transformer(dim, depth, heads=4, dim_head=8, mlp_dim=mlp_dim, dropout=dropout) self.conv3 = conv_1x1_bn(dim, channel) self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size) def forward(self, x): y = 
x.clone() # 局部特征提取 x = self.conv1(x) x = self.conv2(x) # 全局特征提取 _, _, h, w = x.shape x = rearrange(x, 'b c (h ph) (w pw) -> b (ph pw) (h w) c', ph=self.ph, pw=self.pw) x = self.transformer(x) x = rearrange(x, 'b (ph pw) (h w) c -> b c (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw) # 特征融合 x = self.conv3(x) x = torch.cat((x, y), 1) x = self.conv4(x) return x class MobileViT(nn.Module): """ MobileViT 主干网络,用于 DeepLabV3+ 等语义分割任务 """ def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)): super().__init__() ih, iw = image_size ph, pw = patch_size assert ih % ph == 0 and iw % pw == 0 # MobileViT 块的深度 L = [2, 4, 3] self.conv1 = conv_nxn_bn(3, channels[0], stride=2) # 初始卷积,输出 channels[0] self.mv2 = nn.ModuleList([]) self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion)) self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion)) self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion)) self.mv2.append(MV2Block(channels[4], channels[5], 1, expansion)) self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion)) self.mv2.append(MV2Block(channels[6], channels[7], 1, expansion)) self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion)) self.mv2.append(MV2Block(channels[8], channels[9], 1, expansion)) self.mv2.append(MV2Block(channels[9], channels[10], 2, expansion)) self.mvit = nn.ModuleList([]) self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))) self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))) self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))) # 移除 conv2 层,因为它会导致通道不匹配 # self.conv2 = conv_1x1_bn(channels[-2], channels[-1]) # 上采样模块(可选,根据需要使用) self.upsample = nn.Upsample(size=image_size, mode='bilinear', align_corners=False) # 
注意:全局平均池化和分类层被移除,仅用于特征提取 def forward(self, x): x = self.conv1(x) # 输出通道: channels[0] x = self.mv2[0](x) # 输出通道: channels[1] x = self.mv2[1](x) # 输出通道: channels[2], stride=2 low_level_features = x.clone() # 低层特征: channels[2] x = self.mv2[2](x) # 输出通道: channels[3] x = self.mv2[3](x) # 输出通道: channels[4], stride=2 x = self.mv2[4](x) # 输出通道: channels[5] x = self.mvit[0](x) # MobileViTBlock 1: 输入通道=channels[5], 输出通道=channels[5] x = self.mv2[5](x) # 输出通道: channels[6], stride=2 x = self.mvit[1](x) # MobileViTBlock 2: 输入通道=channels[7], 输出通道=channels[7] x = self.mv2[6](x) # 输出通道: channels[7] x = self.mv2[7](x) # 输出通道: channels[8], stride=2 x = self.mvit[2](x) # MobileViTBlock 3: 输入通道=channels[9], 输出通道=channels[9] x = self.mv2[8](x) # 输出通道: channels[9] x = self.mv2[9](x) # 输出通道: channels[10] # 移除 conv2 的调用 # x = self.conv2(x) # 输出通道: channels[10] return low_level_features, x # 返回低层特征和深层特征 # --------------------------- # 参数统计工具 # --------------------------- def count_parameters(model): """ 计算模型的可训练参数总数 """ return sum(p.numel() for p in model.parameters() if p.requires_grad) # --------------------------- # 测试 MobileViT 模型 # --------------------------- if __name__ == '__main__': # 示例输入 img = torch.randn(1, 3, 256, 256) # 初始化 MobileViT 模型 model = MobileViT( image_size=(256, 256), dims=[96, 120, 144], channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], num_classes=3, # 注意:主干网络不需要此参数 expansion=2, kernel_size=3, patch_size=(2, 2) ) # 前向传播 low_level_features, x = model(img) print("低层特征形状:", low_level_features.shape) # 预期: [1, 48, 32, 32] print("深层特征形状:", x.shape) # 预期: [1, 384, 8, 8] # 打印可训练参数数量 print("可训练参数总数:", count_parameters(model))
移除全局平均池化层和全连接层:
python# 移除全局平均池化层和全连接层 # self.pool = nn.AvgPool2d(ih//32, 1) # self.fc = nn.Linear(channels[-1], num_classes, bias=False)
移除 conv2 层:
conv2 层导致了通道数不匹配的问题。为了避免这一问题,conv2 层及其调用已经被移除。python# 移除 conv2 层 # self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
python# 移除 conv2 的调用 # x = self.conv2(x)
确保 MobileViT 返回两个特征图:
low_level_features 和 x,以便 DeepLabV3+ 能够进行后续的特征融合和分割。forward 方法中,low_level_features 在第二个 MV2Block 之后被捕获并返回。详细的中文注释:
运行上述修正后的代码,您应该会看到如下输出:
text低层特征形状: torch.Size([1, 48, 32, 32]) 深层特征形状: torch.Size([1, 384, 8, 8]) 可训练参数总数: <参数数量>
[1, 48, 32, 32] 表示从主干网络中提取的浅层特征,用于后续的特征融合。[1, 384, 8, 8] 表示从主干网络中提取的深层特征,用于高级语义特征的提取。确保通道数匹配:
in_channels 和 low_level_channels 参数与 MobileViT 主干网络的输出通道数一致。例如:
pythonin_channels = 384 # 深层特征的通道数 low_level_channels = 48 # 低层特征的通道数
特征图的空间尺寸:
low_level_features 应具有较高的分辨率,而深层特征图具有较低的分辨率。后续的特征融合:
low_level_features 和 x 进行 ASPP 模块的处理和特征融合。移除或调整其它不必要的层:
通过移除 MobileViT 主干网络中的全局平均池化层和全连接层,并确保特征图的通道数和空间尺寸与 DeepLabV3+ 的需求一致,您可以顺利地将 MobileViT 集成到 DeepLabV3+ 中,实现高效且精确的语义分割模型。
如果在集成过程中仍然遇到问题,建议逐步调试每一层的输出形状,确保特征图的通道数和空间尺寸与预期一致。例如,可以在关键步骤添加打印语句:
pythonprint("Low level features shape:", low_level_features.shape) print("Deep features shape:", x.shape)
这样可以帮助您快速定位和解决维度不匹配的问题。