
Daily Attention Learning 22: Inverted Residual RWKV

##### Module Source

[arXiv 25] [link](https://github.com/juntaoJianggavin/RWKV-UNet) RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation

* * *

##### Module Name

Inverted Residual RWKV (IR-RWKV)

* * *

##### Module Purpose

An RWKV attention structure for vision. As the code below shows, it is an inverted-residual block (1x1 expansion, RWKV spatial mixing, depthwise convolution with optional SE, 1x1 projection) that brings linear-complexity long-range context into a UNet for medical image segmentation.

* * *

##### Module Structure
![IR-RWKV module structure](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-12/ycfkAR3YSum7eT6ZrIh5CWOBFnH2.jpeg)

* * *

##### Module Code

Note: for the C++/CUDA extension (`wkv_op.cpp` / `wkv_cuda.cu`), see the authors' original repository.
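
If you only want to trace the math without compiling the extension, the WKV recurrence can be prototyped in plain PyTorch. The sketch below is a naive O(T²) reference of the standard RWKV-v4 WKV formula, with a name (`naive_wkv`) of our own choosing; it skips the numerically stable running-max bookkeeping that the CUDA kernel performs, and the exact decay convention of the authors' kernel may differ (this module passes `spatial_decay / T` straight through, whereas the original RWKV-v4 wrapper passes a negated exponential decay), so treat it only as a readable approximation for small inputs:

```python
import torch

def naive_wkv(w, u, k, v):
    # Naive O(T^2) reference of the RWKV-v4 WKV recurrence, per channel:
    #   y_t = (sum_{i<t} e^{(t-1-i)w + k_i} v_i + e^{u + k_t} v_t)
    #         / (sum_{i<t} e^{(t-1-i)w + k_i} + e^{u + k_t})
    # w, u: (C,) per-channel log-decay (typically negative) and first-token
    # bonus; k, v: (B, T, C) keys and values.
    B, T, C = k.shape
    y = torch.zeros_like(v)
    for t in range(T):
        i = torch.arange(t, device=k.device)            # past positions i < t
        decay = (t - 1 - i).unsqueeze(-1) * w           # (t, C) log-space decay
        weights = torch.exp(decay + k[:, :t])           # (B, t, C)
        bonus = torch.exp(u + k[:, t])                  # current token uses u instead
        num = (weights * v[:, :t]).sum(dim=1) + bonus * v[:, t]
        den = weights.sum(dim=1) + bonus
        y[:, t] = num / den
    return y
```

With that caveat, the full module code follows.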
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load


# Maximum sequence length compiled into the CUDA kernel (H*W must not exceed it)
T_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])


def get_norm(norm_layer='in_1d'):
	eps = 1e-6
	norm_dict = {
		'none': nn.Identity,
		'in_1d': partial(nn.InstanceNorm1d, eps=eps),
		'in_2d': partial(nn.InstanceNorm2d, eps=eps),
		'in_3d': partial(nn.InstanceNorm3d, eps=eps),
		'bn_1d': partial(nn.BatchNorm1d, eps=eps),
		'bn_2d': partial(nn.BatchNorm2d, eps=eps),
		# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
		'bn_3d': partial(nn.BatchNorm3d, eps=eps),
		'gn': partial(nn.GroupNorm, eps=eps),
		'ln_1d': partial(nn.LayerNorm, eps=eps),
		# 'ln_2d': partial(LayerNorm2d, eps=eps),
	}
	return norm_dict[norm_layer]


def get_act(act_layer='relu'):
	act_dict = {
		'none': nn.Identity,
		'sigmoid': Sigmoid,
		'swish': Swish,
		'mish': Mish,
		'hsigmoid': HardSigmoid,
		'hswish': HardSwish,
		'hmish': HardMish,
		'tanh': Tanh,
		'relu': nn.ReLU,
		'relu6': nn.ReLU6,
		'prelu': PReLU,
		'gelu': GELU,
		'silu': nn.SiLU
	}
	return act_dict[act_layer]


class ConvNormAct(nn.Module):
	def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
				 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
		super(ConvNormAct, self).__init__()
		self.has_skip = skip and dim_in == dim_out
		padding = math.ceil((kernel_size - stride) / 2)
		self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
		self.norm = get_norm(norm_layer)(dim_out)
		self.act = get_act(act_layer)(inplace=inplace)
		self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
	
	def forward(self, x):
		shortcut = x
		x = self.conv(x)
		x = self.norm(x)
		x = self.act(x)
		if self.has_skip:
			x = self.drop_path(x) + shortcut
		return x
  

class SE(nn.Module):
    # Squeeze-and-Excitation channel attention
    def __init__(
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SE, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Global average pool -> bottleneck MLP -> channel-wise gate
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)


def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
    # Token shift for vision: split channels into four groups and shift each
    # group by one pixel (left, right, up, down), so every token mixes with
    # its 4-neighborhood before the key/value/receptance projections.
    assert gamma <= 1/4
    B, N, C = input.shape
    input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])
    B, C, H, W = input.shape
    output = torch.zeros_like(input)
    output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]
    output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]
    output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]
    output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]
    output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]
    return output.flatten(2).transpose(1, 2)


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


class WKV(torch.autograd.Function):
    # Autograd wrapper around the compiled WKV kernel: computes the RWKV
    # attention core in fp32 internally and casts back to the input dtype.
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0

        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        ctx.save_for_backward(w, u, k, v)
        w = w.float().contiguous()
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if half_mode:
            y = y.half()
        elif bf_mode:
            y = y.bfloat16()
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        wkv_cuda.backward(B, T, C,
                          w.float().contiguous(),
                          u.float().contiguous(),
                          k.float().contiguous(),
                          v.float().contiguous(),
                          gy.float().contiguous(),
                          gw, gu, gk, gv)
        if half_mode:
            gw = torch.sum(gw.half(), dim=0)
            gu = torch.sum(gu.half(), dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif bf_mode:
            gw = torch.sum(gw.bfloat16(), dim=0)
            gu = torch.sum(gu.bfloat16(), dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
        else:
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw, gu, gk, gv)
    

class VRWKV_SpatialMix(nn.Module):
    # RWKV spatial-mix layer for vision: q_shift token mixing followed by
    # key/value/receptance projections and the WKV recurrence.
    def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.n_embd = n_embd
        attn_sz = n_embd
        self._init_weights()
        self.shift_pixel = shift_pixel
        if shift_pixel > 0:
            self.channel_gamma = channel_gamma
        else:
            self.spatial_mix_k = None
            self.spatial_mix_v = None
            self.spatial_mix_r = None

        self.key = nn.Linear(n_embd, attn_sz, bias=False)
        self.value = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embd, attn_sz, bias=False)
        self.key_norm = nn.LayerNorm(n_embd)
        self.output = nn.Linear(attn_sz, n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def _init_weights(self):
        self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)

    def jit_func(self, x, patch_resolution):
        # Mix x with its shifted neighbors to produce xk, xv, xr
        B, T, C = x.size()
        # Use xk, xv, xr to produce k, v, r
        if self.shift_pixel > 0:
            xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)
            xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
            xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)
            xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
        else:
            xk = x
            xv = x
            xr = x
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)  # receptance acts as a gate on the WKV output
        return sr, k, v

    def forward(self, x, patch_resolution=None):
        B, T, C = x.size()
        sr, k, v = self.jit_func(x, patch_resolution)
        x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
        x = self.key_norm(x)
        x = sr * x
        x = self.output(x)
        return x


class iR_RWKV(nn.Module):
    # Inverted Residual RWKV block: 1x1 expansion -> RWKV spatial mix ->
    # depthwise conv (+ optional SE) -> 1x1 projection, with residual skips.
    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
                 act_layer='relu', dw_ks=3, stride=1, dilation=1, se_ratio=0.0,
                 attn_s=True, drop_path=0., drop=0., img_size=224, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.ln1 = nn.LayerNorm(dim_mid)
        self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        if attn_s:
            self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.attn_s = attn_s
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation,
                                      groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.conv(x)
        if self.attn_s:
            # Flatten the feature map to a token sequence for the RWKV mixer
            B, hidden, H, W = x.size()
            patch_resolution = (H, W)
            x = x.view(B, hidden, -1)  # (B, hidden, H*W)
            x = x.permute(0, 2, 1)     # (B, N, hidden)
            x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))
            # Reshape from (B, n_patch, hidden) back to (B, hidden, h, w);
            # note that this assumes a square token grid
            B, n_patch, hidden = x.size()
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
            x = x.permute(0, 2, 1)
            x = x.contiguous().view(B, hidden, h, w)
        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x


if __name__ == '__main__':
    x = torch.randn([1, 64, 11, 11]).cuda()
    ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()
    out = ir_rwkv(x)
    print(out.shape)  # [1, 64, 11, 11]


```
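
As a quick sanity check on `q_shift`, the snippet below (run in the same session as the listing above) pushes a 2x2 token grid through it; with `gamma=1/4`, the four channel quarters are filled from the left, right, top and bottom neighbor respectively, with zero padding at the borders:

```python
import torch

# 2x2 token grid with 8 channels: after q_shift, channel quarters
# [0:2], [2:4], [4:6], [6:8] hold the left, right, top and bottom
# neighbor's values respectively (zeros where no neighbor exists).
x = torch.arange(1 * 4 * 8, dtype=torch.float32).reshape(1, 4, 8)  # (B, N, C)
y = q_shift(x, shift_pixel=1, gamma=1/4, patch_resolution=(2, 2))
print(y.shape)  # torch.Size([1, 4, 8])
```

Also note that the compiled kernel caps the token count at `T_MAX = 1024` (`assert T <= T_MAX`), so feature maps larger than 32x32 tokens require recompiling the extension with a larger `Tmax`.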
* * *
