Deep Retinex Decomposition for Low-Light Enhancement
Abstract
该模型可视为实现低光图像增强的重要手段。它基于以下假设:观察到的图像可分解为反射率与照明度两种基本成分。现有基于Retinex的方法虽然针对这种高度不适定的分解问题设计了特殊的约束条件和参数设置,在实际应用中可能会受到模型容量限制的影响。为此,在本文中我们构建了一个包含低光与正常光照图像配对的数据集LOL,并在此基础上开发了一种深度学习框架——深度Retinex-Net(Decom-Net + Enhance-Net)。该框架由两个主要模块组成:用于分解成像组件Decom-Net以及用于实现光照补偿的任务网络Enhance-Net。在Decom-Net的学习过程中,并未假设分解出反射率与光照度的基本事实;相反地,则仅依据几个关键约束来进行训练:包括成对低光/正常光照图像共享的一致反射率特征以及光照度分布平滑性等特性。在此基础上通过Enhance-Net完成了对成像亮度值的提升操作,并结合去噪需求分别实现了反射率层面上的降噪处理。值得注意的是,在整个深度Retinex-Net的设计过程中我们并未引入端到端训练策略;相反地则采用了分步优化的方式来进行参数求解。经过大量实验验证表明:所提出的深度Retinex-Net不仅能够在视觉效果上达到令人满意的结果水平,在图像是分割质量方面也展现出了良好的性能表现
1 Introduction
图像捕捉中的光照不足会产生严重影响;丢失的细节与低对比度不仅会引起不愉快的人为感受;也会对专为正常光照下的图像设计而优化的计算机视觉系统造成负面影响;造成光照不足的原因多种多样;可能源于环境光照不足.摄制设备性能局限.设备设置不当等因素;为了恢复隐藏在暗光条件下的细节内容并提升当前计算机视觉系统的人机交互体验及实用性需求;应采用弱光条件下图像增强技术
在过去的几十年间,众多研究人员专注于解决低光图像增强问题。他们开发了许多技术以提升低光图像的主观感知与客观质量。直方图均衡方法(HE)[20]及其变体通过约束输出图像的直方图来满足特定条件。基于去雾霾的方法[5]利用了光照不足场景下的图像与其模拟雾霾环境图像之间的逆向关系。
另一类低光增强方法基于Retinex理论[12]构建,在该理论中观察到的彩色图像被认为由反射率与照明度组成。作为早期研究单尺度Retinex(SSR)[11]通过高斯滤波器限制了照明图的平滑性;而多尺度Retinex(MSR-CR)[10]则借助多尺度高斯滤波器与色彩恢复机制扩展了SSR的应用范围。文献[23]提出了一种基于亮度顺序误差测量的方法来保持光照自然度;傅等人[7]则提出了融合初始光照图多推导的方法。SRIE7则采用了加权变分模型来同时估计反射率与照度。值得注意的是,在此之后对操纵后的照明结果能够恢复出目标图像效果的研究逐渐增多;而LIME9则是一种较为独特的研究方向。此外,在现有研究的基础上还出现了基于Retinex理论联合实现低光增强与噪声消除的方法[14, 15]
虽然这些方法在某些情况下可能带来令人鼓舞的结果,但它们仍然受限于反射率以及照明分解模型的能力。开发出适用于各种场景的理想工作约束条件的图像分解非常具有挑战性。此外,在处理光照图时也需要人工干预
伴随着深度学习技术的迅速发展
为了应对这些挑战,我们开发了一种基于数据的Retinex分解方案。我们设计了一个整合图像分解与连续增强操作的深度学习架构称为Retinex-Net。具体而言利用Decom-Net将观测图像划分为与照明无关的反射率和具有结构感知和平滑光照的部分。Decom-Net的学习过程受到两个关键约束的影响首先是在低光及正常光照条件下获取的画面具有相同的反射特性其次在处理过程中必须确保照明图平滑且保留主要结构特征这是通过结合结构感知的信息损失实现的。随后另一个模块负责调节光照图以维持大范围的一致性同时通过多尺度串联机制来优化局部细节分布由于在黑暗区域噪声通常较大甚至可能放大因此我们引入了反射率去噪步骤以改善结果质量为了训练这一模型我们构建了包含真实拍摄的真实低光及正常光图像对以及从RAW数据集中生成的人工合成图像的数据集经过大量实验验证我们的方法不仅在低光增强任务中实现了令人满意的视觉效果而且还提供了高质量的图像分解表示总结我们的主要贡献包括
我们基于真实场景采集了配对低光/正常光图像,并生成了一个大型数据集。经研究发现,在弱光增强领域这是具有开创性的工作。
我们开发了一种基于Retinex模型的深度学习图像分解技术,在端到端训练过程中将该技术与连续的低光增强网络相结合;由此可知,该框架在光照条件调整方面具有显著的能力。
我们开发了一种基于结构感知的总变分约束模型用于深度图像分解。该约束通过弱化梯度强度大的区域对图像平滑的影响,实现了良好的平滑效果,并有效保持了图像的主要结构特征。
2 Retinex-Net for Low-Light Enhancement
经典的Retinex理论模拟了人类的色彩感知。该理论假设观察到的图像可以被分解为两个基本组成部分:反射率和照明度。令S代表源图像,则其数学表达式可表示为

其中 R 表征反射率,I 表示照明度,◦ 则代表逐元素乘法运算。反射率表征了物体材料的固有属性特征,无论光照强度如何变化,其值均保持恒定不变特性。而照度则表征物体表面不同区域的亮度差异,在低光成像场景中通常会呈现暗区与光照分布不均的现象
基于Retinex理论的基础上
2.1 Data-Driven Image Decomposition
一种途径是基于低光输入图像来实现对观测图像的分解,并通过巧妙设计的约束条件计算反射率与照明信息。鉴于方程(1)高度不适定,在开发适用于不同场景的理想化约束条件方面仍面临诸多挑战。鉴于此,我们采用数据驱动的方法来寻求解决方案。
在训练过程中, Decom-Net系统会持续接收多对低光与正常光图像,并基于这些图像共有的指导信息,自主学习如何分解出低光及其对应的正常光图像中的相同反射特性.尽管分解过程基于成对数据进行,但它可以通过单个低光输入实现独立的分解任务.在整个训练过程中,系统无需掌握真实场景中反射率和光照的具体参数.相反,通过将强调了反射均匀性和光照图的空间平滑特性等关键知识融入网络损失函数,系统得以自主构建有效的分解模型.这种自适应的学习机制能够有效捕捉不同光照条件下的图像变换特征.
需要注意的是,在我们的研究任务中尽管从形式上看这个问题可能与本征图像分解存在相似之处然而实质上二者存在差异为此我们需要一种不同于传统本征图像分解的方法为此我们需要一种不同于传统本征图像分解的方法为此我们需要一种不同于传统本征图像分解的方法为此我们需要一种不同于传统本征图
如图1所示,Decom-Net采用低光图像Slow与正常光图像Snormal作为输入,并分别估计了 Slow 的反射率 R_low 和照度 I_low 以及 S_normal 的 R_normal 和 I_normal 。该网络首先通过 3×3 卷积层从输入图像中提取特征;随后利用多个以整流线性单元(ReLU)为激活函数的 3×3 卷积层将 RGB 图像转换为反射率 R 和照明度 I 。网络进一步通过基于特征空间的投影得到 R 和 I ,并运用 sigmoid 函数将其限制在 [0, 1] 区间内。

Figure 1: Retinex-Net 的拟议框架。增强过程划分为分解、调节和重构三个阶段。其中,在分解阶段中,Decom-Net子网络将输入图像分解为反射率层和照明度层。随后,在调节阶段中,则通过增强网络编码器与解码器的协同作用提升亮度水平。为了实现更加精细的调节效果,在此过程中我们还引入了多尺度串联结构以从多层次视角优化照明条件,并在此过程中,在反射率层上残留的噪声也能够被有效去除。最后经过重建处理后即可获得经过深度优化后的图像结果
损失L由三个主要组成部分构成:具体来说是重建L_{\text{recon}}、不变反射率L_{\text{ir}}以及照明平滑度L_{\text{is}}。

其中λir和λis表示平衡反射率一致性和光照平滑度的系数。
基于Rlow和Rhigh各自具有对应的光源分布这一假设,在计算重建损失Lrecon时


引入不变反射率损耗Lir来约束反射率的一致性:

照明平滑度损失Lis将在下面的部分中详细描述。
2.2 Structure-Aware Smoothness Loss
正如文献[9]所指出的那样,光照图的一个核心假设即为局部均匀性和结构辨识能力。换句话说,一个有效的光照贴图方案应当能够在纹理细节上呈现光滑过渡的同时,依然能够准确识别并保留关键边界特征。
其旨在最小化图像的整体梯度[2],常被用作图像恢复任务中的平滑先验工具。然而,在图像结构较为复杂或亮度变化剧烈的区域中,则直接使用 TV 作为损失函数会失效。这是因为无论该区域是文本细节还是强边界边缘,在照明图中梯度都会均匀减小。这正是导致 TV 损失存在结构性缺陷的根本原因。
为了使损失对图像结构进行理解, 传统的Total Variation (TV)函数基于反射图的梯度赋予权重. 最终的结果表示为:

其中符号∇代表梯度,在水平方向为梯度分量∇h,在垂直方向为梯度分量∇v;而λg则是衡量结构感知平衡程度的关键参数。借助权重因子exp(−λg∇Ri),系统Lis在反射率变化剧烈的位置上实施了更为宽松的约束措施;换言之,在图像中具有明确结构的位置区域以及光照条件预期不连续的地方,默认会对平滑性施加较弱的要求
虽然 LIME [9] 也考虑在加权TV约束的照明图中保留图像结构,但我们则认为这两种方法存在本质区别。对于 LIME,其总变化量受到原始光照分布的影响,而该分布通常仅反映 R、G 和 B 通道中每个像素的最大亮度值这一特征。相比之下,我们采用基于反射率的权重分配策略以优化图像平滑度损失。值得注意的是,尽管 LIME 中采用了静态初始估计这一假设,但这种方法难以准确描述图像结构这一物理特性,因为反射率被严格假定为固定不变的物理属性。基于外部大量数据集进行离线训练后,在模型参数更新阶段同时优化光照条件与权重分配
2.3 Multi-Scale Illumination Adjustment
照明增强网络基于编码器-解码器架构的整体结构设计。为了在不同层次上调整照明效果,在系统中引入了多尺度串联结构。
该架构通过编码器-解码器模式获取图像的大范围上下文信息。输入图像经过连续降采样处理至较小尺度后, 使得该系统能够识别出大尺寸范围内的光照分布情况。这种机制赋予网络自适应调节的能力。基于大规模照明数据, 在上采样块中重构局部区域的照明分布情况。通过逐元素求和的方式, 在下采样块与其镜像对齐的上采样块之间建立跳跃连接机制, 这一过程迫使网络学习各层之间的残差信息
为实现分层次调节光照,即保证全局光照的一致性的同时,能够灵活调节各区域局部光照分布,我们提出了一种多尺度串联框架.假设存在M个逐级上采样模块,每个模块能够提取对应的C通道空间特征图.通过最近邻插值将各模块的空间特征按比例放大至统一尺度后整合在一起,并将其融合至统一尺度后的空间特征映射融合至一个新的C×M通道的整体特征图中.随后采用1×1卷积层对融合后的整体特征进行降维处理.最后利用3×3卷积核对整体特征进行精细建模以恢复完整的光照信息

。
downsampling block is composed of convolution layers with stride 2 and ReLU activations. The upsampling block incorporates size-altering convolution layers. As demonstrated in [19], it effectively avoids the checkerboard artifacts. The size-altering convolution layers incorporate nearest-neighbor interpolation, convolution layers with stride of 1, and ReLU activations.
Enhance-Net 模型中的目标函数 L 由重建损失项 L_{\text{recon}} 和光照平滑度相关的项 L_{\text{is}} 构成。其中 L_{\text{recon}} 的作用在于生成法线向量。

,即

Lis 与式(5)相同,只是

通过 Rlow 的梯度图进行加权。
2.4 Denoising on Reflectance
在分解阶段中施加了几项约束条件,在其中一个约束条件下涉及到了基于光环境的空间平滑特性。如果估算出的光照场呈现空间均匀性,则这些细节特征主要体现在表面反射率的变化中,并包含增强噪声的影响;因此,在用照明图重建输出图像之前应用反射率降噪处理策略较为合理;值得注意的是,在分解过程中暗区区域中的噪声放大程度会随着亮度水平的变化而有所差异;我们的实现在第 4 节中进行了描述

3 Dataset
虽然这一领域已研究了数十年然而目前公开的数据集尚未提供真实场景下的配对低光/正常光图像其中一些采用 HDR 数据集[18]作为补充方案但此类数据集规模较小涉及的场景也较为有限因而无法用于训练深度学习模型为此我们开发了一种新架构它由两部分构成第一部分模仿真实拍摄中的降质特性第二部分则在数据增强场景多样性和对象多样性等方面发挥了作用
3.1 Dataset Captured in Real Scenes
该命名方案包含了500组低光与正常光图像对比样本,并经调查发现该方案是首个引入基于真实场景采集的用于低光增强的图像对比数据集合
大多数低光图像主要通过调节曝光时间和ISO值来捕获光线信息,在相机设置中其他参数保持不变的情况下完成这一过程。我们在多个不同场景下进行拍摄以丰富数据集:如住宅楼、校园景观、娱乐场所以及城市街道等。图 3 展示了所选场景的关键子区域。
因为相机振动、物体运动以及光线变化等因素可能导致同一场景中的不同幅图像配对出现偏差。参考文献[1]提出了一种分阶段的方法来消除数据集中各幅图像配对间的偏差。具体实施步骤可在附录中详细说明。经过尺寸调整后将原始图片转换为便于传输的手持设备兼容的网络图形格式。该数据集已作为开放资源发布
3.2 Synthetic Image Pairs from Raw Images
为了使合成图像具备真实暗摄影的特性,我们需要研究低光环境下的光照分布特征。本研究主要来源于MEF[18,23,9,13], VV-1以及Fusion[3]等公开数据集中的270张低照图片,在完成YCbCr空间变换后对Y分量进行分析;同时我们又从RAISE[4]中获取了包含约一千张正常曝光图片作为对照样本,并对YCbCr码流中的Y区域进行了统计直方图构建;如图所示的结果展示了各项实验指标的变化曲线

Figure 5展示了LOL数据集中的Bookshelf上分别应用Decom-Net和LIME进行图像分解的结果图示。研究发现,在真实场景中暗区噪声放大现象较为明显的情况下,尽管低光图像与正常光照下的反射率存在差异(如表1所示),但整体趋势却表现出较高的相似性。
原始图像所包含的信息量高于转换后的结果。当处理原始图像时, 所有用于计算像素值的过程均基于前一步骤的基础数据, 因此能够获得更为精确的结果。RAISE [4]项目中的1000张原始图片被用来合成低光场景图。借助Adobe Lightroom平台提供的接口, 通过尝试不同参数组合, 可以使Y通道的空间直方图呈现出弱光场景的特点。研究者们最终确定的最佳参数设置可在附录中查阅到(如本研究附图4所示)。最后, 我们将这些原始图片的比例调整至4:6, 并将其格式化为便携式的网络图形文件。
4 Experiments
4.1 Implementation Details
在第2节中所述的LOL数据集

时,λi j 设置为 0.001,当 i = j 时,λi j 设置为 1。
4.2 Decomposition Results
在图 5 中
4.3 Evaluation
本研究在公共 LIME [9]、MEF [18] 和 DICM [13] 数据集的真实场景图像上进行了评估。其中,LIME 包括 10 张测试图片,MEF 则提供了 17 张具有不同曝光度的连续图像序列,DICM 则通过商用数码相机捕捉到了 69 张高质量图片。我们还将 Retinex-Net 方法与当前最先进的一些对比方案进行了系统性比较分析,包括基于去雾的 DeHz 方法 [5],自然保留增强算法 NPE [23],同时考虑反射率与光照的 SRIE 方法 [8],以及基于照明图 LIME 的方案 [9].

图6展示了三幅自然图像进行视觉比较的结果。补充文件中提供了更多详细数据。每个红色矩形都表明我们提出的方法能够有效提升暗亮度中的物体对比度而不导致过度曝光的现象发生(这一效果得益于基于学习的方法以及多尺度定制化的照明方案)。与LIME相比我们发现该方法并未出现部分区域过度曝光的情况(可参考静物画中叶子部分以及房间外景中树叶部分的具体表现)。与DeHz相比我们所获得的结果呈现出良好的边缘特性德海兹则因采用加权TV损失函数而获得了更好的边缘表现效果(如街道边房屋轮廓的具体情况可见)。
4.4 Joint Low-Light Enhancement and Denoising
基于综合性能考量,在Retinex-Net中应用BM3D[3]作为降噪操作。鉴于噪声在反射率上的放大具有不均匀性特征,则采用了基于光照相对性的策略(如图7所示)。为了全面评估我们的联合降噪Retinex-Net的表现效果,在以下两个基准测试中进行了对比:首先是对经去噪处理后的LIME算法(LIME with denoising),其次则是最近提出的JED方法[22](a recent joint low-light enhancement and denoising method)。实验结果表明,在细节保留方面表现更为出色的是Retinex-Net算法(the proposed Retinex-Net),而LIME及其改进型JED则导致边缘模糊现象(edge blurring)出现
5 Conclusion
本文开发了一种基于深度学习的Retinex分解框架,该框架通过数据驱动的学习机制能够有效分离出图像中物体表面反射特性与环境光谱特征,并非依赖于反射率与光照分离的基本理论。在此基础上详细阐述了照明操作中的光增强技术和反射域上的降噪处理过程。研究设计了端到端优化的深度学习模型架构,并在此基础上构建了弱光增强网络模块。实验表明该方法不仅实现了图像增强效果具有视觉吸引力且各向异性特征明显,并获得了较为理想的图像分解结果。
代码解读
DecomNet
class DecomNet(nn.Module):
def __init__(self, channel=64, kernel_size=3):
super(DecomNet, self).__init__()
# Shallow feature extraction
self.net1_conv0 = nn.Conv2d(4, channel, kernel_size * 3,
padding=4, padding_mode='replicate')
# Activated layers!
self.net1_convs = nn.Sequential(nn.Conv2d(channel, channel, kernel_size,
padding=1, padding_mode='replicate'),
nn.ReLU(),
nn.Conv2d(channel, channel, kernel_size,
padding=1, padding_mode='replicate'),
nn.ReLU(),
nn.Conv2d(channel, channel, kernel_size,
padding=1, padding_mode='replicate'),
nn.ReLU(),
nn.Conv2d(channel, channel, kernel_size,
padding=1, padding_mode='replicate'),
nn.ReLU(),
nn.Conv2d(channel, channel, kernel_size,
padding=1, padding_mode='replicate'),
nn.ReLU())
# Final recon layer
self.net1_recon = nn.Conv2d(channel, 4, kernel_size,
padding=1, padding_mode='replicate')
def forward(self, input_im):
input_max= torch.max(input_im, dim=1, keepdim=True)[0]
input_img= torch.cat((input_max, input_im), dim=1)
feats0 = self.net1_conv0(input_img)
featss = self.net1_convs(feats0)
outs = self.net1_recon(featss)
R = torch.sigmoid(outs[:, 0:3, :, :])
L = torch.sigmoid(outs[:, 3:4, :, :])
return R, L

输入由原始图像与单通道最大图叠加而成。总共有四个通道。经过连续应用卷积层和ReLU激活函数处理后,在输出结果中分为两部分:R代表反射面区域和L代表光照区域。
RelightNet
class RelightNet(nn.Module):
def __init__(self, channel=64, kernel_size=3):
super(RelightNet, self).__init__()
self.relu = nn.ReLU()
self.net2_conv0_1 = nn.Conv2d(4, channel, kernel_size,
padding=1, padding_mode='replicate')
self.net2_conv1_1 = nn.Conv2d(channel, channel, kernel_size, stride=2,
padding=1, padding_mode='replicate')
self.net2_conv1_2 = nn.Conv2d(channel, channel, kernel_size, stride=2,
padding=1, padding_mode='replicate')
self.net2_conv1_3 = nn.Conv2d(channel, channel, kernel_size, stride=2,
padding=1, padding_mode='replicate')
self.net2_deconv1_1= nn.Conv2d(channel*2, channel, kernel_size,
padding=1, padding_mode='replicate')
self.net2_deconv1_2= nn.Conv2d(channel*2, channel, kernel_size,
padding=1, padding_mode='replicate')
self.net2_deconv1_3= nn.Conv2d(channel*2, channel, kernel_size,
padding=1, padding_mode='replicate')
self.net2_fusion = nn.Conv2d(channel*3, channel, kernel_size=1,
padding=1, padding_mode='replicate')
self.net2_output = nn.Conv2d(channel, 1, kernel_size=3, padding=0)
def forward(self, input_L, input_R):
input_img = torch.cat((input_R, input_L), dim=1)
out0 = self.net2_conv0_1(input_img)
out1 = self.relu(self.net2_conv1_1(out0))
out2 = self.relu(self.net2_conv1_2(out1))
out3 = self.relu(self.net2_conv1_3(out2))
out3_up = F.interpolate(out3, size=(out2.size()[2], out2.size()[3]))
deconv1 = self.relu(self.net2_deconv1_1(torch.cat((out3_up, out2), dim=1)))
deconv1_up= F.interpolate(deconv1, size=(out1.size()[2], out1.size()[3]))
deconv2 = self.relu(self.net2_deconv1_2(torch.cat((deconv1_up, out1), dim=1)))
deconv2_up= F.interpolate(deconv2, size=(out0.size()[2], out0.size()[3]))
deconv3 = self.relu(self.net2_deconv1_3(torch.cat((deconv2_up, out0), dim=1)))
deconv1_rs= F.interpolate(deconv1, size=(input_R.size()[2], input_R.size()[3]))
deconv2_rs= F.interpolate(deconv2, size=(input_R.size()[2], input_R.size()[3]))
feats_all = torch.cat((deconv1_rs, deconv2_rs, deconv3), dim=1)
feats_fus = self.net2_fusion(feats_all)
output = self.net2_output(feats_fus)
return output

该模型的整体架构基于UNet框架设计,在输入端通过将R与L进行通道组合(cat)处理得到融合特征图。随后在下采样阶段采用卷积层(conv),同时确保输出的维度与输入一致。具体而言,在上采样阶段采用插值方法(如F.interpolate)进行放大,并在每次放大操作后将降采样的conv输出与跳跃连接的低分辨率特征图进行融合。随后通过一个卷积层进一步缩减特征空间维度,并经过一系列操作后,在最后一个卷积层中生成了一维向量L(light adjustment result)。
RetinexNet
class RetinexNet(nn.Module):
def __init__(self):
super(RetinexNet, self).__init__()
self.DecomNet = DecomNet()
self.RelightNet= RelightNet()
def forward(self, input_low, input_high):
# Forward DecompNet
input_low = Variable(torch.FloatTensor(torch.from_numpy(input_low))).cuda()
input_high= Variable(torch.FloatTensor(torch.from_numpy(input_high))).cuda()
R_low, I_low = self.DecomNet(input_low)
R_high, I_high = self.DecomNet(input_high)
# Forward RelightNet
I_delta = self.RelightNet(I_low, R_low)
# Other variables
I_low_3 = torch.cat((I_low, I_low, I_low), dim=1)
I_high_3 = torch.cat((I_high, I_high, I_high), dim=1)
I_delta_3= torch.cat((I_delta, I_delta, I_delta), dim=1)
# Compute losses
self.recon_loss_low = F.l1_loss(R_low * I_low_3, input_low)
self.recon_loss_high = F.l1_loss(R_high * I_high_3, input_high)
self.recon_loss_mutal_low = F.l1_loss(R_high * I_low_3, input_low)
self.recon_loss_mutal_high = F.l1_loss(R_low * I_high_3, input_high)
self.equal_R_loss = F.l1_loss(R_low, R_high.detach())
self.relight_loss = F.l1_loss(R_low * I_delta_3, input_high)
self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)
self.loss_Decom = self.recon_loss_low + \
self.recon_loss_high + \
0.001 * self.recon_loss_mutal_low + \
0.001 * self.recon_loss_mutal_high + \
0.1 * self.Ismooth_loss_low + \
0.1 * self.Ismooth_loss_high + \
0.01 * self.equal_R_loss
self.loss_Relight = self.relight_loss + \
3 * self.Ismooth_loss_delta
self.output_R_low = R_low.detach().cpu()
self.output_I_low = I_low_3.detach().cpu()
self.output_I_delta = I_delta_3.detach().cpu()
self.output_S = R_low.detach().cpu() * I_delta_3.detach().cpu()

主要讲解loss,
low和high分解的R应该保持一致
self.recon_loss_low = F.l1_loss(R_low * I_low_3, input_low)
self.recon_loss_high = F.l1_loss(R_high * I_high_3, input_high)
self.recon_loss_mutal_low = F.l1_loss(R_high * I_low_3, input_low)
self.recon_loss_mutal_high = F.l1_loss(R_low * I_high_3, input_high)
self.equal_R_loss = F.l1_loss(R_low, R_high.detach())
self.relight_loss = F.l1_loss(R_low * I_delta_3, input_high)
low和high分解的I应该光滑
self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)
值得注意的是该网络的训练并非完全端到端,在其中Decom和Relight分别进行独立优化
smooth
def gradient(self, input_tensor, direction):
self.smooth_kernel_x = torch.FloatTensor([[0, 0], [-1, 1]]).view((1, 1, 2, 2)).cuda()
self.smooth_kernel_y = torch.transpose(self.smooth_kernel_x, 2, 3)
if direction == "x":
kernel = self.smooth_kernel_x
elif direction == "y":
kernel = self.smooth_kernel_y
grad_out = torch.abs(F.conv2d(input_tensor, kernel,
stride=1, padding=1))
return grad_out
def ave_gradient(self, input_tensor, direction):
return F.avg_pool2d(self.gradient(input_tensor, direction),
kernel_size=3, stride=1, padding=1)
def smooth(self, input_I, input_R):
input_R = 0.299*input_R[:, 0, :, :] + 0.587*input_R[:, 1, :, :] + 0.114*input_R[:, 2, :, :]
input_R = torch.unsqueeze(input_R, dim=1)
return torch.mean(self.gradient(input_I, "x") * torch.exp(-10 * self.ave_gradient(input_R, "x")) +
self.gradient(input_I, "y") * torch.exp(-10 * self.ave_gradient(input_R, "y")))

首先将输入数据转换为灰度图后并对其计算出相应的梯度映射,并区分图像在x轴和y轴方向上的变化程度。其中输入R的梯度变化与所提取的方向具有一定的关联性;随后对输入R进行一次完整的梯度计算,并对该结果进行3×3窗口下的均值滤波操作;其中当输入R的边缘检测结果较小时(即平滑区域),赋予相应的权重系数以增强其作用效果;这会导致较小范围内的图像变得更加平滑;反之则会体现出一定的细节保留能力;而文中提到的光照结构感知损失函数正是基于这一原理设计而成;其中avgPool2d操作可能是在为了减少网络输出层的空间维度从而避免过拟合所采取的一种降维策略。
