A total variation loss
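The script below adds Gaussian noise to a 1-D piecewise-constant signal x and then recovers a denoised estimate y by running gradient descent directly on the objective

    loss(y) = 1/2 * Σ_i (x'_i - y_i)^2 + lamda * Σ_i (y_{i+1} - y_i)^2

where x' (x_ in the code) is the noisy observation: the first term keeps y close to the observation, and the second term penalizes differences between neighbouring samples. Note that this penalty squares the differences; the classical total variation penalty uses the absolute difference |y_{i+1} - y_i| instead.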
import matplotlib.pyplot
import torch

# 1-D piecewise-constant signal to be recovered
x = torch.FloatTensor([1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3,3,3,3,1,1,1,1,1,0,0,1,1,1,1])
#x = torch.FloatTensor([1,1,1,1,1,20,20,20,20,20,20,3,3,3,3,3,3,3,3,1,1,1,1,1,0,0,1,1,1,1])

# additive Gaussian noise with standard deviation 0.3
m = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([0.3]))
noise = torch.squeeze(m.sample((x.size()[0],)))
x_ = x + noise  # noisy observation

# y is the denoised estimate, optimized directly by gradient descent
y = torch.zeros(x_.size()[0], requires_grad=True)

#optimizer = torch.optim.Adam([{'params':[y]}])
optimizer = torch.optim.SGD([{'params':[y], 'lr':1.0e-2}])
mse_loss = torch.nn.MSELoss(reduction='sum')
lamda = 0.3  #0.5

for iter_ in range(5000):
    optimizer.zero_grad()

    # data term: 1/2 * sum_i (x'_i - y_i)^2
    E_x_y = mse_loss(x_, y)/2
    # smoothness term: sum_i (y_{i+1} - y_i)^2
    tv_loss = torch.pow(y[1:] - y[:-1], 2).sum()
    loss = E_x_y + lamda*tv_loss

    loss.backward()
    optimizer.step()

    # re-evaluate the objective after the update and stop once it no longer changes
    with torch.no_grad():
        E_x_y = mse_loss(x_, y)/2
        tv_loss = torch.pow(y[1:] - y[:-1], 2).sum()
        loss_ = E_x_y + lamda*tv_loss
    print(iter_, ': ', loss_.item())
    if torch.abs(loss_ - loss) < 1e-8:
        break

# denoised estimate y against the clean signal x
matplotlib.pyplot.plot(x.numpy())
matplotlib.pyplot.plot(y.detach().numpy())
matplotlib.pyplot.show()

# noisy observation x_ against the clean signal x
matplotlib.pyplot.gcf().clear()
matplotlib.pyplot.plot(x_.numpy())
matplotlib.pyplot.plot(x.numpy())
matplotlib.pyplot.show()
'''last_10_loss'''
'''
725 : 1.8783742189407349
726 : 1.878373622894287
727 : 1.8783732652664185
728 : 1.8783729076385498
729 : 1.8783724308013916
730 : 1.8783719539642334
731 : 1.8783714771270752
732 : 1.878371238708496
733 : 1.8783705234527588
734 : 1.8783705234527588
'''
[Figure: denoised y vs. clean x]
[Figure: clean x vs. noisy x_]
'''
for iter_ in range(5000):
    noise = torch.squeeze(m.sample((x.size()[0],)))  # redraw the noise every iteration
    x_ = x + noise
    # rest of the loop body unchanged
'''
'''last_10_loss'''
'''
4990 : 2.249378204345703
4991 : 3.130645751953125
4992 : 2.8582403659820557
4993 : 2.522833824157715
4994 : 2.6637299060821533
4995 : 2.8069911003112793
4996 : 2.898977279663086
4997 : 2.4421730041503906
4998 : 2.4187495708465576
4999 : 3.2449662685394287
'''
[Figure: denoised y vs. clean x]
[Figure: clean x vs. noisy x_]
Conclusion
1:
loss1 vs. loss2
A curve that looks better visually does not necessarily have a lower loss value.
How should the results be judged, then? Show people the curves and let them rank them subjectively.
2:
How should lamda and the learning rate lr be chosen?
Step 1: with only lr to tune, pick lr from how quickly the loss decreases;
Step 2: keep lr fixed, increase lamda step by step from small to large, and pick the lamda that gives the smallest loss (a code sketch of this sweep follows the lamda=0.1 log below).
[Figure: result with lamda=1]
[Figure: result with lamda=0.5]
[Figure: result with lamda=0.3]
[Figure: result with lamda=0.1]
#lamda=0.1
'''last_10_loss'''
'''
4990 : 1.601554274559021
4991 : 2.305757522583008
4992 : 1.4222698211669922
4993 : 1.5561707019805908
4994 : 2.4976422786712646
4995 : 1.849165678024292
4996 : 2.165806770324707
4997 : 2.1707358360290527
4998 : 2.037616491317749
4999 : 1.7519365549087524
'''
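A minimal sketch of the step-2 sweep described in the conclusion, assuming the noisy signal x_ from the script at the top is in scope; denoise() is a hypothetical helper that simply wraps the same SGD loop, and the lamda grid matches the values tried above:

import torch

def denoise(x_, lamda, lr=1.0e-2, n_iter=5000, tol=1e-8):
    # same objective and SGD loop as the script above, wrapped as a function
    mse_loss = torch.nn.MSELoss(reduction='sum')
    y = torch.zeros(x_.size()[0], requires_grad=True)
    optimizer = torch.optim.SGD([{'params': [y], 'lr': lr}])
    final_loss = None
    for _ in range(n_iter):
        optimizer.zero_grad()
        loss = mse_loss(x_, y)/2 + lamda*torch.pow(y[1:] - y[:-1], 2).sum()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            final_loss = (mse_loss(x_, y)/2 + lamda*torch.pow(y[1:] - y[:-1], 2).sum()).item()
        if abs(final_loss - loss.item()) < tol:  # stop once the objective no longer changes
            break
    return y.detach(), final_loss

# step 1 fixed lr at 1e-2; step 2 sweeps lamda from small to large
for lamda in [0.1, 0.3, 0.5, 1.0]:
    _, final_loss = denoise(x_, lamda)
    print('lamda =', lamda, 'converged loss =', final_loss)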
ps:
variance vs. variation (standard deviation)
https://www.analystforum.com/forums/cfa-forums/cfa-level-ii-forum/91355300
The Stack Exchange thread at https://stats.stackexchange.com/questions/88348/is-variability-the-same-as-variance asks whether "variability" is the same thing as "variance".
Reference:
Total variation denoising is a regularization-based technique that reduces image noise while preserving the image's underlying structure.
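For comparison with the squared-difference penalty used in the script, a minimal sketch of the classical absolute-difference penalty, as a drop-in replacement for tv_loss in the loop above (PyTorch's autograd handles the absolute value via its subgradient):

# classical 1-D total variation: sum of absolute neighbouring differences
tv_loss = torch.abs(y[1:] - y[:-1]).sum()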