复现:Residual channel prior-guided multi-scale progressive dehazing network with hybrid attention
发布时间
阅读量:
阅读量
Xing, Y., Zhang, J. Residual channel prior-guided multi-scale progressive dehazing network with hybrid attention. Multimedia Systems 31, 161 (2025). https://doi.org/10.1007/s00530-025-01754-0
复现代码见最后
文章目录
- 研究背景与动机
- 研究方法
- 实验
- 关键结论
- 研究贡献
- 复现
研究背景与动机
- 去雾任务的重要性 :在雾霾天气下,空气中的气体和液体颗粒增多,成像设备获取的图像对比度、亮度和可见度降低,影响后续智能系统的有效性。因此,单图像去雾对于提高图像质量、保障计算机视觉任务的准确性和可靠性具有重要意义。
- 现有方法的局限性 :早期基于大气散射模型的去雾方法易受不同场景先验的影响,鲁棒性差。而基于深度学习的去雾方法虽能提升去雾效果,但一些方法为提高性能盲目增加网络参数,导致计算开销大、参数冗余,且存在内存消耗大、速度慢等问题。

研究方法
- MPDNet 网络架构 :提出了残差通道先验引导的多尺度渐进去雾网络 MPDNet,其基于编码器 - 解码器结构,融合残差学习和注意力机制。网络通过下采样降低特征图分辨率,再利用特征细化模块逐步恢复低分辨率的无雾图像,同时采用大核设计增加模型感受野以捕获全局特征,提升特征学习能力。
- 残差通道先验 :发现雾图像的残差通道先验包含丰富的结构信息,设计先验引导块(PGB),在不同尺度上提取残差通道先验图,以更好地引导网络学习雾的结构信息,帮助网络更准确地识别和去除雾成分。
- 混合注意力机制 :引入混合注意力机制,自适应地调整空间权重和通道权重,使网络更加关注与任务相关的对象和区域,从而实现更准确的去雾效果,同时减少无关信息的干扰,提高特征表示的质量和有效性。
- 多尺度渐进去雾策略 :采用多尺度渐进去雾策略,从不同尺度对雾图像进行处理,逐步去除雾成分。这种策略能够充分利用不同尺度的特征信息,更好地捕捉图像中的细节和全局结构,提升去雾效果的准确性和鲁棒性。


实验
- 实验设置 :使用 Haze4K 数据集进行训练和测试,该数据集从 NYU-Depth 和 OTS 数据集中随机选取 500 张室内图和 500 张室外图,生成 4000 张雾图像,其中 3000 张用于训练,1000 张用于测试。实验平台采用 Windows10 操作系统、Intel® Core ™ i9-9900 CPU、NVIDIA RTX 2070SUPER GPU,开发环境为 Python 3.8、Pytorch 1.11.0、CUDA 11.3。训练时,将图像随机裁剪为 256×256 大小,使用随机翻转和裁剪等数据增强技术,并采用 AdamW 优化器,初始学习率为 0.0001,使用余弦退火学习率调整策略,共训练 300 个 epoch,批量大小为 16。
- 评估指标 :采用峰值信噪比(PSNR)和结构相似性(SSIM)作为客观评估指标,PSNR 用于评估图像在像素级别的质量,值越高表示图像质量越好;SSIM 从亮度、对比度和结构三个方面衡量图像相似度,值越大表示保留的结构信息越多。
- 实验结果 :实验结果表明,MPDNet 在合成数据集和真实数据集上均取得了较好的去雾效果,能够有效去除雾成分,恢复出清晰、自然的图像。与其他方法相比,MPDNet 在 PSNR 和 SSIM 等评估指标上均表现出色,且生成的图像在视觉效果上更加逼真,细节更加丰富,颜色更加准确。


关键结论
- MPDNet 通过结合残差通道先验和混合注意力机制,在多尺度渐进去雾策略的指导下,能够有效地去除雾图像中的雾成分,恢复出高质量的无雾图像,为单图像去雾任务提供了一种有效的解决方案。
- 残差通道先验的引入为网络提供了丰富的结构信息,有助于更好地学习雾的特征和分布,从而提高去雾的准确性。混合注意力机制则增强了网络对关键区域和特征的关注能力,进一步提升了去雾效果的质量。
- 多尺度渐进去雾策略能够充分利用不同尺度的特征信息,使网络在处理复杂雾图像时更具优势,能够更好地捕捉图像中的细节和全局结构,提高去雾结果的鲁棒性和稳定性。
研究贡献
- 提出了一种新颖的单图像去雾网络 MPDNet,该网络在编码器 - 解码器结构的基础上,创新性地融合了残差学习、注意力机制、残差通道先验以及多尺度渐进去雾策略等多种先进技术,为单图像去雾领域提供了一种新的有效方法。
- 引入残差通道先验和先验引导块,为利用先验信息指导网络学习提供了一种新的思路和方法,能够更好地挖掘雾图像中的有用信息,提升去雾效果。
- 设计的混合注意力机制能够自适应地调整空间和通道权重,使网络更加关注与去雾任务相关的区域和特征,提高了特征表示的针对性和有效性,为提升去雾网络的性能提供了一种有效的手段。
- 构建的多尺度渐进去雾策略能够从不同尺度对雾图像进行处理,逐步去除雾成分,充分利用了不同尺度的特征信息,提高了去雾结果的准确性和鲁棒性,为解决复杂雾图像的去雾问题提供了一种有效的途径。
复现
由于篇幅限制,以下代码仅包含核心模块的实现,完整代码需要根据论文细节进一步完善。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义多尺度特征提取模块 (MSFEM)
class MSFEM(nn.Module):
def __init__(self):
super(MSFEM, self).__init__()
self.conv3x3 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.dconv1 = nn.Conv2d(64, 64, kernel_size=3, dilation=1, padding=1)
self.dconv3 = nn.Conv2d(64, 64, kernel_size=3, dilation=3, padding=3)
self.dconv5 = nn.Conv2d(64, 64, kernel_size=3, dilation=5, padding=5)
self.conv1x1 = nn.Conv2d(64*4, 64, kernel_size=1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.ca = ChannelAttention(64)
self.pa = PixelAttention(64)
def forward(self, x):
x = self.lrelu(self.conv3x3(x))
feat1 = self.lrelu(self.dconv1(x))
feat3 = self.lrelu(self.dconv3(x))
feat5 = self.lrelu(self.dconv5(x))
feat_concat = torch.cat([x, feat1, feat3, feat5], dim=1)
feat = self.lrelu(self.conv1x1(feat_concat))
feat_ca = self.ca(feat)
feat_pa = self.pa(feat_ca)
return feat_pa
# 定义通道注意力模块 (CA)
class ChannelAttention(nn.Module):
def __init__(self, num_channels):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(num_channels, num_channels//16, kernel_size=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(num_channels//16, num_channels, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
y = self.conv1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.sigmoid(y)
return x * y
# 定义像素注意力模块 (PA)
class PixelAttention(nn.Module):
def __init__(self, num_channels):
super(PixelAttention, self).__init__()
self.dconv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, dilation=1, padding=1)
self.dconv3 = nn.Conv2d(num_channels, num_channels, kernel_size=3, dilation=3, padding=3)
self.dconv5 = nn.Conv2d(num_channels, num_channels, kernel_size=3, dilation=5, padding=5)
self.conv = nn.Conv2d(num_channels*3, num_channels, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
feat1 = self.dconv1(x)
feat3 = self.dconv3(x)
feat5 = self.dconv5(x)
feat_concat = torch.cat([feat1, feat3, feat5], dim=1)
feat = self.conv(feat_concat)
feat = self.sigmoid(feat)
return x * feat
# 定义先验引导块 (PGB)
class PGB(nn.Module):
def __init__(self):
super(PGB, self).__init__()
self.conv3x3 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.msfem = MSFEM()
def forward(self, x):
rcp_map = torch.max(x[:, :3, :, :], dim=1, keepdim=True)[0] - torch.min(x[:, :3, :, :], dim=1, keepdim=True)[0]
feat = self.conv3x3(rcp_map)
feat = self.msfem(feat)
return feat
# 定义多尺度去雾单元 (MDU)
class MDU(nn.Module):
def __init__(self):
super(MDU, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
)
self.msfem = MSFEM()
self.conv3x3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
def forward(self, x):
enc = self.encoder(x)
dec = self.decoder(enc)
feat = self.msfem(dec)
out = self.conv3x3(feat)
return out
# 定义注意力引导特征记忆模块 (AFM)
class AFM(nn.Module):
def __init__(self):
super(AFM, self).__init__()
self.lstm = nn.LSTM(64, 64, batch_first=True)
self.iafm = InteractiveAttentionFusionModule()
def forward(self, x, h_t_1):
x = x.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, x.size(1))
h_t, _ = self.lstm(x, (h_t_1.contiguous().view(1, x.size(0), -1), torch.zeros(1, x.size(0), 64).to(x.device)))
h_t = h_t.view(x.size(0), x.size(1), x.size(2), -1).permute(0, 3, 1, 2).contiguous()
fused_feat = self.iafm(x.permute(0, 3, 1, 2).contiguous(), h_t)
return fused_feat, h_t
# 定义交互式注意力融合模块 (IAFM)
class InteractiveAttentionFusionModule(nn.Module):
def __init__(self):
super(InteractiveAttentionFusionModule, self).__init__()
self.conv_x = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv_h = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv1x1 = nn.Conv2d(64, 64, kernel_size=1)
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU(inplace=True)
def forward(self, x, h_t_1):
x_conv = self.conv_x(x)
h_conv = self.conv_h(h_t_1)
similarity_map = x_conv * h_conv
fusion_weights = self.sigmoid(self.conv1x1(self.relu(similarity_map)))
fused_feat = torch.cat([(x * fusion_weights + x), (h_t_1 * fusion_weights + h_t_1)], dim=1)
return fused_feat
# 定义 MPDNet 网络
class MPDNet(nn.Module):
def __init__(self, num_stages=5):
super(MPDNet, self).__init__()
self.num_stages = num_stages
self.afm = AFM()
self.pgb = PGB()
self.mdu = MDU()
self.h_t = None
def forward(self, x):
for _ in range(self.num_stages):
if self.h_t is None:
h_t_1 = torch.zeros(x.size(0), 64, x.size(2), x.size(3)).to(x.device)
else:
h_t_1 = self.h_t
fused_feat, self.h_t = self.afm(x, h_t_1)
rcp_feat = self.pgb(fused_feat)
x = self.mdu(rcp_feat)
return x
AI生成项目python

全部评论 (0)
还没有任何评论哟~
