
Reading Notes: "Single Image Reflection Removal Exploiting Misaligned Training Data and Network Enhancements"


1. Paper

Single Image Reflection Removal Exploiting Misaligned Training Data and Network Enhancements

Removing undesired reflections from a single image captured through a glass window is of practical importance to visual computing systems. Although state-of-the-art methods can obtain decent results in certain situations, performance degrades significantly when handling more general real-world cases. These failures stem from the intrinsic difficulty of single-image reflection removal: the fundamental ill-posedness of the problem, and the insufficiency of the densely labeled training data needed to resolve this ambiguity in learning-based neural network pipelines. In this paper, we address these issues by exploiting targeted network enhancements and a novel use of misaligned data. For the former, we augment a baseline network architecture by embedding context encoding modules that are capable of leveraging high-level contextual cues to reduce indeterminacy within regions containing strong reflections. For the latter, we introduce an alignment-invariant loss function that facilitates exploiting misaligned real-world training data, which is much easier to collect. Experimental results collectively show that our method outperforms the state of the art on aligned data, and that significant improvements are possible when additional misaligned data is used.
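One detail worth noting on the misaligned-data side: rather than a pixel-wise loss, the prediction and the (possibly misaligned) reference are compared in a high-level feature space, where small spatial shifts matter far less. Below is my own minimal sketch of that idea, assuming a pretrained VGG-19 and an arbitrary deep layer (relu5_2); the paper's exact layer choice and weighting may differ.

```python
# Minimal sketch of an alignment-invariant loss (my own illustration, not the authors' code):
# compare images in a deep VGG-19 feature space, which is far less sensitive to
# small spatial shifts than a per-pixel L1/L2 loss.
import torch
import torch.nn.functional as F
from torchvision import models

vgg = models.vgg19(pretrained=True).features[:32].eval()  # up to relu5_2 (assumed layer choice)
for p in vgg.parameters():
    p.requires_grad = False

def alignment_invariant_loss(pred, ref):
    # pred, ref: (B, 3, H, W) images, normalized the way VGG expects
    return F.l1_loss(vgg(pred), vgg(ref))
```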

2. Network Architecture

The two components I focus on here are the Pyramid Pooling module and the Residual Block.

3. Code

Code download: https://github.com/Vandermode/ERRNet

```python
# Define network components here

import torch
from torch import nn
import torch.nn.functional as F


class PyramidPooling(nn.Module):
    """Pools features at several scales, projects each to a few channels,
    upsamples them back, and fuses everything with a 1x1 bottleneck conv."""

    def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(in_channels, scale, ct_channels) for scale in scales])
        self.bottleneck = nn.Conv2d(in_channels + len(scales) * ct_channels, out_channels, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _make_stage(self, in_channels, scale, ct_channels):
        # prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        prior = nn.AvgPool2d(kernel_size=(scale, scale))
        conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False)
        relu = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(prior, conv, relu)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = torch.cat([F.interpolate(input=stage(feats), size=(h, w), mode='nearest') for stage in self.stages] + [feats], dim=1)
        return self.relu(self.bottleneck(priors))
```
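A quick shape check I find helpful (my own snippet, not from the repo): each stage average-pools with kernel `scale`, projects to `ct_channels` with a 1x1 convolution, and is upsampled back to the input resolution, so the input's spatial size must be at least as large as the largest scale.

```python
# Sanity check for PyramidPooling (illustrative values, not the paper's settings).
pp = PyramidPooling(in_channels=64, out_channels=64, scales=(4, 8, 16, 32), ct_channels=16)
feats = torch.randn(1, 64, 128, 128)   # H, W must be >= the largest pooling scale
out = pp(feats)
print(out.shape)  # torch.Size([1, 64, 128, 128])
```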
    
  
    
  
    
```python
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention: global average pooling followed
    by a two-layer bottleneck that produces per-channel gates in (0, 1)."""

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
```
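For completeness, a tiny usage example of the SE layer (my own snippet): the output has the same shape as the input, with each channel rescaled by a learned gate in (0, 1).

```python
# Illustrative only: channel attention keeps the tensor shape, reweights channels.
se = SELayer(channel=64, reduction=16)
x = torch.randn(2, 64, 32, 32)
print(se(x).shape)  # torch.Size([2, 64, 32, 32])
```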
    
      
    
  
    
```python
class DRNet(torch.nn.Module):
    """Encoder (two stride-1 convs + one stride-2 conv), a stack of residual
    blocks, and a decoder; optionally ends with a PyramidPooling module."""

    def __init__(self, in_channels, out_channels, n_feats, n_resblocks, norm=nn.BatchNorm2d,
                 se_reduction=None, res_scale=1, bottom_kernel_size=3, pyramid=False):
        super(DRNet, self).__init__()
        # Initial convolution layers
        conv = nn.Conv2d
        deconv = nn.ConvTranspose2d
        act = nn.ReLU(True)

        self.pyramid_module = None
        self.conv1 = ConvLayer(conv, in_channels, n_feats, kernel_size=bottom_kernel_size, stride=1, norm=None, act=act)
        self.conv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act)
        self.conv3 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=2, norm=norm, act=act)

        # Residual layers
        dilation_config = [1] * n_resblocks

        self.res_module = nn.Sequential(*[ResidualBlock(
            n_feats, dilation=dilation_config[i], norm=norm, act=act,
            se_reduction=se_reduction, res_scale=res_scale) for i in range(n_resblocks)])

        # Upsampling Layers
        self.deconv1 = ConvLayer(deconv, n_feats, n_feats, kernel_size=4, stride=2, padding=1, norm=norm, act=act)

        if not pyramid:
            self.deconv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act)
            self.deconv3 = ConvLayer(conv, n_feats, out_channels, kernel_size=1, stride=1, norm=None, act=act)
        else:
            self.deconv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act)
            self.pyramid_module = PyramidPooling(n_feats, n_feats, scales=(4, 8, 16, 32), ct_channels=n_feats // 4)
            self.deconv3 = ConvLayer(conv, n_feats, out_channels, kernel_size=1, stride=1, norm=None, act=act)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.res_module(x)

        x = self.deconv1(x)
        x = self.deconv2(x)
        if self.pyramid_module is not None:
            x = self.pyramid_module(x)
        x = self.deconv3(x)

        return x
```
    
  
    
  
    
```python
class ConvLayer(torch.nn.Sequential):
    """Conv (or transposed conv) + optional norm + optional activation.
    Padding defaults to 'same'-style padding for stride-1 convolutions."""

    def __init__(self, conv, in_channels, out_channels, kernel_size, stride, padding=None, dilation=1, norm=None, act=None):
        super(ConvLayer, self).__init__()
        # padding = padding or kernel_size // 2
        padding = padding or dilation * (kernel_size - 1) // 2
        self.add_module('conv2d', conv(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation))
        if norm is not None:
            self.add_module('norm', norm(out_channels))
            # self.add_module('norm', norm(out_channels, track_running_stats=True))
        if act is not None:
            self.add_module('act', act)
```
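Note the default padding `dilation * (kernel_size - 1) // 2`: for stride-1 convolutions, dilated or not, it preserves the spatial size. A quick check of my own:

```python
# 'Same'-style padding check: kernel 3, dilation 2 -> padding 2, size unchanged.
layer = ConvLayer(nn.Conv2d, 16, 16, kernel_size=3, stride=1, dilation=2)
x = torch.randn(1, 16, 40, 40)
print(layer(x).shape)  # torch.Size([1, 16, 40, 40])
```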
    
  
    
  
    
```python
class ResidualBlock(torch.nn.Module):
    """Two 3x3 convs (optionally dilated), optional SE channel attention,
    residual scaling, and a skip connection."""

    def __init__(self, channels, dilation=1, norm=nn.BatchNorm2d, act=nn.ReLU(True), se_reduction=None, res_scale=1):
        super(ResidualBlock, self).__init__()
        conv = nn.Conv2d
        self.conv1 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=act)
        self.conv2 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=None)
        self.se_layer = None
        self.res_scale = res_scale
        if se_reduction is not None:
            self.se_layer = SELayer(channels, se_reduction)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.se_layer:
            out = self.se_layer(out)
        out = out * self.res_scale
        out = out + residual
        return out

    def extra_repr(self):
        return 'res_scale={}'.format(self.res_scale)
```
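Putting the pieces together, here is a hypothetical instantiation (the hyperparameters below are my own guesses for illustration, not necessarily the configuration used in ERRNet): the network downsamples once, runs the residual blocks with SE attention, upsamples back, and applies PyramidPooling before the final 1x1 convolution when `pyramid=True`.

```python
# Hypothetical configuration for illustration; not necessarily the authors' settings.
net = DRNet(in_channels=3, out_channels=3, n_feats=64, n_resblocks=8,
            norm=None, se_reduction=8, res_scale=0.1, pyramid=True)
img = torch.randn(1, 3, 224, 224)
print(net(img).shape)  # torch.Size([1, 3, 224, 224])
```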
    
    
    
    

4. Related Material

Single Image Reflection Removal Exploiting Misaligned Training Data and Network Enhancements
