
Deep Bilateral Learning for Real-Time Image Enhancement


The model architecture is:

[Figure: model architecture]

Low-resolution image feature extraction

1 Low-level features

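As shown in the figure above, a stack of n_S = 4 convolutional layers (3\times3 kernels, stride s = 2) extracts the low-level features S^i from the low-resolution image:

S^i_c[x, y] = \sigma\left( b^i_c + \sum_{x', y', c'} w^i_{cc'}[x', y'] \, S^{i-1}_{c'}[s x + x', s y + y'] \right)

Here i = 1, ..., n_S indexes the convolutional layers, c and c' index the channels of a layer, w^i is the matrix of kernel weights, and b^i is the bias. The activation function \sigma is ReLU, and the convolutions use zero-padding.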

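The strided convolution above can be sketched in plain NumPy. This is only an illustration under stated assumptions: the helper name strided_conv_relu, the 32\times32 test input and the channel widths 8, 16, 32, 64 are chosen for the example, and the window alignment at the borders may differ from TensorFlow's padding convention.

```python
import numpy as np

def strided_conv_relu(feat, weights, bias, stride=2):
    """feat: (H, W, C_in); weights: (3, 3, C_in, C_out); bias: (C_out,)."""
    h, w, _ = feat.shape
    # Zero-pad by one pixel so the 3x3 window is defined at the borders.
    padded = np.pad(feat, ((1, 1), (1, 1), (0, 0)))
    out_h, out_w = h // stride, w // stride
    out = np.zeros((out_h, out_w, weights.shape[-1]), dtype=feat.dtype)
    for y in range(out_h):
        for x in range(out_w):
            # 3x3 window centred on input pixel (stride*y, stride*x).
            window = padded[stride * y:stride * y + 3, stride * x:stride * x + 3, :]
            # Sum over x', y', c' as in the formula, then add the bias.
            out[y, x] = np.tensordot(window, weights, axes=3) + bias
    return np.maximum(out, 0.0)  # ReLU

rng = np.random.default_rng(0)
feat = rng.standard_normal((32, 32, 3)).astype(np.float32)
channels = [3, 8, 16, 32, 64]  # assumed widths, doubling at each layer
for c_in, c_out in zip(channels[:-1], channels[1:]):
    w = (rng.standard_normal((3, 3, c_in, c_out)) * 0.1).astype(np.float32)
    feat = strided_conv_relu(feat, w, np.zeros(c_out, np.float32))
print(feat.shape)  # (2, 2, 64)
```

Each of the four layers halves the spatial resolution, so a 256\times256 input would end up at 16\times16 in the same way.
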
2 Local features path

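The low-level features are fed into n_L = 2 convolutional layers to produce the local features L^i. The total depth n_S + n_L is important for capturing semantic features. To obtain a higher spatial resolution, decrease n_S and increase n_L to compensate.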

3 Global features path

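The global path consists of two convolutional layers with stride 2 followed by three fully connected layers, for a total of n_G = 5 layers. The effect of the global features:
[Figure: effect of the global features]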

4 Fusion and linear prediction

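A pointwise affine transformation followed by a ReLU activation fuses the global and local features:
F_c[x, y] = \sigma\left( a_c + \sum_{c'} w'_{cc'} G^{n_G}_{c'} + \sum_{c'} w_{cc'} L^{n_L}_{c'}[x, y] \right)

This yields a 16\times16\times64 feature array, which is passed through a 1\times1 convolutional layer to produce a 16\times16 map with 96 output channels:
A_c[x, y] = b_c + \sum_{c'} F_{c'}[x, y] \, w_{cc'}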

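The fusion and the 1\times1 prediction layer are simple enough to sketch with NumPy broadcasting. The random inputs and weights below are placeholders, and the affine weights on the global and local branches are assumed to be already folded into the two inputs, which is how the TensorFlow code later in this post handles them.

```python
import numpy as np

rng = np.random.default_rng(0)
local_feat = rng.standard_normal((16, 16, 64))  # one vector per grid position
global_feat = rng.standard_normal(64)           # one vector for the whole image

# Fusion: the global vector is broadcast over all 16x16 positions, then ReLU.
fused = np.maximum(local_feat + global_feat, 0.0)

# A 1x1 convolution is a matrix product applied independently at each position.
w = rng.standard_normal((64, 96)) * 0.1
b = np.zeros(96)
coeffs = fused @ w + b

print(fused.shape, coeffs.shape)  # (16, 16, 64) (16, 16, 96)
```
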
The layer parameters are as follows:
[Table: layer parameters]

Image features as a bilateral grid

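The features extracted from the low-resolution image form a 16\times16\times96 feature map. It can be viewed as a multi-channel bilateral grid of depth d:
A_{dc+z}[x, y] \leftrightarrow A_c[x, y, z]
With d = 8, this is equivalent to a 16\times16\times8 bilateral grid in which each grid cell holds 12 numbers, one for each coefficient of a 3\times4 affine color transformation matrix.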

Upsampling with a trainable slicing layer

Guidance map auxiliary network

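g is defined as a pointwise nonlinear transformation of the full-resolution features \phi:
g[x, y] = b + \sum_{c=0}^{2} \rho_c\left( M_c^T \cdot \phi[x, y] + b'_c \right)
\rho_c(x) = \sum_{i=0}^{15} a_{c,i} \max(x - t_{c,i}, 0)
Here M_c^T are the rows of the 3\times3 color transformation matrix, and M_c^T, a, t, b, b' are the parameters the network learns.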

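As a rough NumPy sketch of this guidance map: a color transform, a per-channel curve built from 16 shifted ReLUs, then a channel mix and a clamp to [0, 1]. The function name guide_map and its argument names are made up for the example, and the initial values (exact identity matrix, first slope 1, uniform shifts, mixing weights 1/3) are assumptions modeled on the _guide code shown later in this post.

```python
import numpy as np

def guide_map(image, ccm, ccm_bias, shifts, slopes, mix_w, mix_b):
    """image: (H, W, 3) in [0, 1]; returns an (H, W) guide in [0, 1]."""
    # Pointwise colour transform: every pixel is multiplied by a 3x3 matrix.
    x = image @ ccm + ccm_bias                                   # (H, W, 3)
    # Per-channel piecewise-linear curve: sum of scaled, shifted ReLUs.
    x = x[..., np.newaxis]                                       # (H, W, 3, 1)
    x = np.sum(slopes * np.maximum(x - shifts, 0.0), axis=-1)    # (H, W, 3)
    # Mix the channels down to one value per pixel and clamp it.
    g = x @ mix_w + mix_b                                        # (H, W)
    return np.clip(g, 0.0, 1.0)

npts, nchans = 16, 3
ccm = np.eye(nchans)
ccm_bias = np.zeros(nchans)
shifts = np.tile(np.linspace(0, 1, npts, endpoint=False), (nchans, 1))  # (3, 16)
slopes = np.zeros((nchans, npts))
slopes[:, 0] = 1.0
mix_w = np.full(nchans, 1.0 / nchans)

image = np.random.default_rng(0).uniform(size=(4, 5, 3))
g = guide_map(image, ccm, ccm_bias, shifts, slopes, mix_w, 0.0)
print(g.shape)  # (4, 5)
```

With these initial values the curve is the identity on [0, 1], so the guide starts out as the plain channel average of the input; training then bends it.
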
Assembling the final output

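The final output O_c is obtained by applying the affine transformation stored in the sliced feature map \bar{A} to the full-resolution features \phi, which have n_\phi channels:
O_c[x, y] = \bar{A}_{n_\phi + (n_\phi + 1) c}[x, y] + \sum_{c'=0}^{n_\phi - 1} \bar{A}_{c' + (n_\phi + 1) c}[x, y] \, \phi_{c'}[x, y]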

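To make the last step concrete, here is a small NumPy sketch that applies one 3\times4 affine matrix per pixel to an RGB image. The helper name apply_affine and the layout (three linear columns followed by one offset column) are assumptions for illustration, not the layout used inside the compiled op.

```python
import numpy as np

def apply_affine(sliced, image):
    """sliced: (H, W, 3, 4) per-pixel affine matrices; image: (H, W, 3)."""
    # Linear part: a 3x3 matrix times the pixel colour, evaluated per pixel.
    linear = np.einsum('hwij,hwj->hwi', sliced[..., :3], image)
    # The fourth column is a per-pixel, per-channel offset.
    return linear + sliced[..., 3]

img = np.random.default_rng(0).uniform(size=(2, 2, 3))

# Identity transform everywhere: the output equals the input.
identity = np.tile(np.eye(3, 4), (2, 2, 1, 1))
same = apply_affine(identity, img)

# Gain of 2 and offset of 0.1 on every channel.
gain = np.tile(np.hstack([2 * np.eye(3), np.full((3, 1), 0.1)]), (2, 2, 1, 1))
brighter = apply_affine(gain, img)
```
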
The model's inference code is:

    @classmethod
    def inference(cls, lowres_input, fullres_input, params,
                  is_training=False):

      # Predict the bilateral grid of affine coefficients from the low-res input.
      with tf.variable_scope('coefficients'):
        bilateral_coeffs = cls._coefficients(lowres_input, params, is_training)
        tf.add_to_collection('bilateral_coefficients', bilateral_coeffs)

      # Compute the guidance map from the full-res input.
      with tf.variable_scope('guide'):
        guide = cls._guide(fullres_input, params, is_training)
        tf.add_to_collection('guide', guide)

      # Slice the grid with the guide and apply the result to the full-res input.
      with tf.variable_scope('output'):
        output = cls._output(
            fullres_input, guide, bilateral_coeffs)
        tf.add_to_collection('output', output)

      return output

Code walkthrough for each module

1 Low-level features

The input is the low-res input. The network is n_S convolutional layers with 3\times3 kernels and stride 2. The code is:

    with tf.variable_scope('splat'):
      # cm, gd and spatial_bin are defined earlier in the enclosing function (not shown).
      # Each stride-2 layer halves the resolution until it reaches spatial_bin.
      n_ds_layers = int(np.log2(params['net_input_size']/spatial_bin))

      current_layer = input_tensor
      for i in range(n_ds_layers):
        if i > 0:  # don't normalize first layer
          use_bn = params['batch_norm']
        else:
          use_bn = False
        # The channel count doubles at every layer.
        current_layer = conv(current_layer, cm*(2**i)*gd, 3, stride=2,
                             batch_norm=use_bn, is_training=is_training,
                             scope='conv{}'.format(i+1))

      splat_features = current_layer

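To see what this loop builds, the layer count and per-layer sizes can be tabulated outside TensorFlow. The values net_input_size = 256, spatial_bin = 16, cm = 1 and gd = 8 are assumed here because they reproduce the 16\times16\times64 features described earlier; the real values come from the training parameters.

```python
import numpy as np

# Assumed example values; the real ones come from the training parameters.
params = {'net_input_size': 256, 'batch_norm': True}
spatial_bin, cm, gd = 16, 1, 8

n_ds_layers = int(np.log2(params['net_input_size'] / spatial_bin))

size = params['net_input_size']
layers = []
for i in range(n_ds_layers):
    size //= 2  # each stride-2 convolution halves the resolution
    use_bn = params['batch_norm'] if i > 0 else False  # first layer is not normalized
    layers.append((size, cm * (2 ** i) * gd, use_bn))

# (output size, output channels, batch norm) for each layer
print(layers)
# [(128, 8, False), (64, 16, True), (32, 32, True), (16, 64, True)]
```
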
2 local features

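This path extracts the local features of the image. The network is n_L = 2 convolutional layers with 3\times3 kernels and stride 1; only the first layer uses batch normalization.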

    with tf.variable_scope('local'):
      current_layer = splat_features
      current_layer = conv(current_layer, 8*cm*gd, 3,
                           batch_norm=params['batch_norm'],
                           is_training=is_training,
                           scope='conv1')
      # don't normalize before fusion
      current_layer = conv(current_layer, 8*cm*gd, 3, activation_fn=None,
                           use_bias=False, scope='conv2')
      grid_features = current_layer

3 global features G^i

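This path extracts the global features. The network is two convolutional layers with 3\times3 kernels and stride 2, followed by three fully connected layers. The code is: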

    with tf.variable_scope('global'):
      n_global_layers = int(np.log2(spatial_bin/4))  # 4x4 at the coarsest lvl

      current_layer = splat_features
      # Note: the loop is fixed at two layers; n_global_layers is not used below.
      for i in range(2):
        current_layer = conv(current_layer, 8*cm*gd, 3, stride=2,
            batch_norm=params['batch_norm'], is_training=is_training,
            scope="conv{}".format(i+1))
      # Flatten the feature map; bs (defined earlier, not shown) is the batch dimension.
      _, lh, lw, lc = current_layer.get_shape().as_list()
      current_layer = tf.reshape(current_layer, [bs, lh*lw*lc])

      current_layer = fc(current_layer, 32*cm*gd,
                         batch_norm=params['batch_norm'], is_training=is_training,
                         scope="fc1")
      current_layer = fc(current_layer, 16*cm*gd,
                         batch_norm=params['batch_norm'], is_training=is_training,
                         scope="fc2")
      # don't normalize before fusion
      current_layer = fc(current_layer, 8*cm*gd, activation_fn=None, scope="fc3")
      global_features = current_layer

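The sizes along the global path can be checked with a few lines of arithmetic. This assumes spatial_bin = 16, cm = 1 and gd = 8 (example values, not read from the repository) and that each stride-2 convolution exactly halves the resolution, as the "4x4 at the coarsest lvl" comment suggests.

```python
spatial_bin, cm, gd = 16, 1, 8  # assumed example values

size = spatial_bin
channels = 8 * cm * gd
for _ in range(2):
    size //= 2  # two stride-2 convolutions: 16 -> 8 -> 4

flat = size * size * channels  # length of the flattened vector fed to fc1
fc_sizes = [32 * cm * gd, 16 * cm * gd, 8 * cm * gd]  # fc1, fc2, fc3 widths

print(size, flat, fc_sizes)  # 4 1024 [256, 128, 64]
```

The last width, 64, matches the channel count of the local features, which is what allows the two paths to be added in the fusion step.
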
The local features and global features are added to obtain the fused features:

    with tf.name_scope('fusion'):
      fusion_grid = grid_features
      # Reshape so the global vector broadcasts over every spatial position.
      fusion_global = tf.reshape(global_features, [bs, 1, 1, 8*cm*gd])
      fusion = tf.nn.relu(fusion_grid+fusion_global)

bilateral grid of coefficients

    with tf.variable_scope('prediction'):
      current_layer = fusion
      # 1x1 convolution producing gd*n_out*n_in coefficients per grid position.
      current_layer = conv(current_layer, gd*cls.n_out()*cls.n_in(), 1,
                           activation_fn=None, scope='conv1')

      with tf.name_scope('unroll_grid'):
        # [bs, h, w, gd*n_out*n_in] -> [bs, h, w, gd, n_out*n_in]
        current_layer = tf.stack(
            tf.split(current_layer, cls.n_out()*cls.n_in(), axis=3), axis=4)
        # -> [bs, h, w, gd, n_out, n_in]
        current_layer = tf.stack(
            tf.split(current_layer, cls.n_in(), axis=4), axis=5)
      tf.add_to_collection('packed_coefficients', current_layer)

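The two split/stack calls in unroll_grid are easier to follow when replayed in NumPy, where np.split and np.stack behave like their TensorFlow counterparts for this case. The sketch below assumes n_out = 3 and n_in = 4 (three output channels, three input channels plus an offset) and gd = 8; it also shows that the whole unrolling equals one reshape plus one transpose.

```python
import numpy as np

bs, h, w, gd, n_out, n_in = 2, 16, 16, 8, 3, 4  # assumed example sizes
x = np.arange(bs * h * w * gd * n_out * n_in, dtype=np.float32)
x = x.reshape(bs, h, w, gd * n_out * n_in)

# Same split/stack sequence as the 'unroll_grid' scope.
g = np.stack(np.split(x, n_out * n_in, axis=3), axis=4)  # (bs, h, w, gd, 12)
g = np.stack(np.split(g, n_in, axis=4), axis=5)          # (bs, h, w, gd, 3, 4)

# Equivalent single reshape + transpose.
g2 = x.reshape(bs, h, w, n_in, n_out, gd).transpose(0, 1, 2, 5, 4, 3)

print(g.shape, np.array_equal(g, g2))  # (2, 16, 16, 8, 3, 4) True
```

In other words, entry [z, i, j] of a grid cell is read from channel (j*n_out + i)*gd + z of the 96-channel feature map, so the depth index z varies fastest, in line with the A_{dc+z} layout described earlier.
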
guidance map g

The input is the full-res input I.

    @classmethod
    def _guide(cls, input_tensor, params, is_training):
      npts = 16  # number of control points for the curve
      nchans = input_tensor.get_shape().as_list()[-1]

      guidemap = input_tensor

      # Color space change
      idtity = np.identity(nchans, dtype=np.float32) + np.random.randn(1).astype(np.float32)*1e-4
      ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity)
      with tf.name_scope('ccm'):
        ccm_bias = tf.get_variable('ccm_bias', shape=[nchans,], dtype=tf.float32, initializer=tf.constant_initializer(0.0))

        guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)
        guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')

        guidemap = tf.reshape(guidemap, tf.shape(input_tensor))

      # Per-channel curve
      with tf.name_scope('curve'):
        shifts_ = np.linspace(0, 1, npts, endpoint=False, dtype=np.float32)
        shifts_ = shifts_[np.newaxis, np.newaxis, np.newaxis, :]
        shifts_ = np.tile(shifts_, (1, 1, nchans, 1))

        guidemap = tf.expand_dims(guidemap, 4)
        shifts = tf.get_variable('shifts', dtype=tf.float32, initializer=shifts_)

        slopes_ = np.zeros([1, 1, 1, nchans, npts], dtype=np.float32)
        slopes_[:, :, :, :, 0] = 1.0
        slopes = tf.get_variable('slopes', dtype=tf.float32, initializer=slopes_)

        guidemap = tf.reduce_sum(slopes*tf.nn.relu(guidemap-shifts), reduction_indices=[4])

      guidemap = tf.contrib.layers.convolution2d(
          inputs=guidemap,
          num_outputs=1, kernel_size=1,
          weights_initializer=tf.constant_initializer(1.0/nchans),
          biases_initializer=tf.constant_initializer(0),
          activation_fn=None,
          variables_collections={'weights':[tf.GraphKeys.WEIGHTS], 'biases':[tf.GraphKeys.BIASES]},
          outputs_collections=[tf.GraphKeys.ACTIVATIONS],
          scope='channel_mixing')

      guidemap = tf.clip_by_value(guidemap, 0, 1)
      guidemap = tf.squeeze(guidemap, squeeze_dims=[3,])

      return guidemap

Sliced coefficients and full-res output

    @classmethod
    def _output(cls, im, guide, coeffs):
      with tf.device('/gpu:0'):
        out = bilateral_slice_apply(coeffs, guide, im, has_offset=True, name='slice')
      return out

The Python wrapper around the slicing op:

    def bilateral_slice_apply(grid, guide, input_image, has_offset=True, name=None):
      """Slices into a bilateral grid using the guide map.

      Args:
        grid: (Tensor) [batch_size, grid_h, grid_w, depth, n_outputs]
          grid to slice from.
        guide: (Tensor) [batch_size, h, w ] guide map to slice along.
        input_image: (Tensor) [batch_size, h, w, n_input] input data onto which to
          apply the affine transform.
        name: (string) name for the operation.
      Returns:
        sliced: (Tensor) [batch_size, h, w, n_outputs] sliced output.
      """

      with tf.name_scope(name):
        gridshape = grid.get_shape().as_list()
        if len(gridshape) == 6:
          # Flatten the trailing [n_out, n_in] dimensions into one coefficient axis.
          gs = tf.shape(grid)
          _, _, _, _, n_out, n_in = gridshape
          grid = tf.reshape(grid, tf.stack([gs[0], gs[1], gs[2], gs[3], gs[4]*gs[5]]))
          # grid = tf.concat(tf.unstack(grid, None, axis=5), 4)

        sliced = hdrnet_ops.bilateral_slice_apply(grid, guide, input_image, has_offset=has_offset)
        return sliced

The op itself is taken from the compiled module:

    bilateral_slice_apply = _hdrnet.bilateral_slice_apply

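Since the slicing itself happens inside the compiled op, a NumPy reference may help to see what it computes. The sketch below is an assumption-laden stand-in, not the op's implementation: the function name bilateral_slice_apply_np, the half-pixel sampling offsets, the border clamping and the (3, 4) coefficient layout are all choices made for the example. It interpolates the grid trilinearly at (x, y, guide) for every pixel and then applies the resulting affine transform.

```python
import numpy as np

def bilateral_slice_apply_np(grid, guide, image):
    """grid: (gh, gw, gd, 3, 4); guide: (H, W) in [0, 1]; image: (H, W, 3)."""
    gh, gw, gd = grid.shape[:3]
    h, w = guide.shape
    # Continuous grid coordinates of every pixel (assumed half-pixel convention).
    gy = np.broadcast_to((np.arange(h)[:, None] + 0.5) * gh / h - 0.5, (h, w))
    gx = np.broadcast_to((np.arange(w)[None, :] + 0.5) * gw / w - 0.5, (h, w))
    gz = guide * gd - 0.5
    y0, x0, z0 = (np.floor(c).astype(int) for c in (gy, gx, gz))

    sliced = np.zeros((h, w) + grid.shape[3:])
    for dy in (0, 1):
        for dx in (0, 1):
            for dz in (0, 1):
                # Tent (linear) weight of this neighbouring cell along each axis.
                wgt = (np.maximum(1 - np.abs(gy - (y0 + dy)), 0)
                       * np.maximum(1 - np.abs(gx - (x0 + dx)), 0)
                       * np.maximum(1 - np.abs(gz - (z0 + dz)), 0))
                # Clamp indices so border pixels reuse the nearest cell.
                yi = np.clip(y0 + dy, 0, gh - 1)
                xi = np.clip(x0 + dx, 0, gw - 1)
                zi = np.clip(z0 + dz, 0, gd - 1)
                sliced += wgt[..., None, None] * grid[yi, xi, zi]

    # Apply the per-pixel 3x4 affine transform to the full-resolution image.
    return np.einsum('hwij,hwj->hwi', sliced[..., :3], image) + sliced[..., 3]

# Sanity check: if every cell stores the same transform A, slicing must return A.
A = np.array([[1.5, 0.0, 0.0, 0.1],
              [0.0, 0.5, 0.0, 0.0],
              [0.0, 0.0, 1.0, -0.2]])
grid = np.tile(A, (16, 16, 8, 1, 1))
rng = np.random.default_rng(0)
image = rng.uniform(size=(64, 48, 3))
guide = rng.uniform(size=(64, 48))
out = bilateral_slice_apply_np(grid, guide, image)
expected = image @ A[:, :3].T + A[:, 3]
```

The interesting behavior appears when the grid varies along its depth axis: pixels that are neighbors in the image but differ in guide value then read different transforms, which is what makes the upsampling edge-aware.
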
GitHub code: https://github.com/mgharbi/hdrnet

Download:

git clone https://github.com/mgharbi/hdrnet

Install the dependencies:

cd hdrnet

sudo pip2 install -r requirements.txt

Build:

    cd hdrnet
    make

Test:

    cd hdrnet
    py.test test

Return to the directory that contains train.py and start training:

    cd ..

    python train.py checkpoint/ sample_data/identity/filelist.txt

checkpoint is the directory where the model is saved, and sample_data/identity/filelist.txt is the path to the training data.

Run the trained model:

    python run.py checkpoint/ input_val/   test_output/
