# PyTorch 知识蒸馏测试 (PyTorch knowledge-distillation test)
# 发布时间:
# 阅读量:
import torch
from torch import nn,optim
import torch.utils
import torch.utils.data
import torch.utils.data.dataloader
from torchvision import transforms,datasets
import os
# Expose only GPU index 0 to this process via CUDA_VISIBLE_DEVICES.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
# Use the GPU when torch reports one as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class DeepNN(nn.Module):
    """Deeper CNN classifier (used as the teacher in the experiments).

    Four 3x3 convolutions (3 -> 128 -> 64 -> 64 -> 32 channels) with two
    2x2 max-pools, followed by a two-layer fully connected head.  The head
    expects 2048 flattened features, i.e. 32 channels * 8 * 8, which is
    what 3-channel 32x32 inputs produce after the two pooling steps.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layer order is significant: it fixes the ``state_dict`` keys
        # (``features.0.weight`` ...) and the order in which the initial
        # weights are drawn from the RNG.
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: Image batch; must flatten to 2048 features per sample after
                the convolutional stack (e.g. ``(batch, 3, 32, 32)``).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch axis, merge C/H/W
        return self.classifier(x)
class ModifiedDeepNNCosine(nn.Module):
    """``DeepNN``-shaped teacher that also exposes a hidden representation.

    The layers (and therefore the ``state_dict`` keys) are laid out exactly
    like the plain deep network, so weights can be copied between the two
    with ``load_state_dict``.  ``forward`` additionally returns the
    flattened convolutional features, average-pooled from 2048 down to
    1024 values per sample.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, pooled_hidden)`` for a batch of images.

        Args:
            x: Image batch; must flatten to 2048 features per sample after
                the convolutional stack (e.g. ``(batch, 3, 32, 32)``).

        Returns:
            A tuple ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, num_classes)`` and ``hidden`` has shape
            ``(batch, 1024)``.
        """
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        logits = self.classifier(flattened_conv_output)
        # avg_pool1d sees the 2-D (batch, 2048) tensor as an unbatched
        # (channels, length) input, so every row is pooled independently:
        # adjacent feature pairs are averaged, halving 2048 to 1024.
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(
            flattened_conv_output, 2
        )
        return logits, flattened_conv_output_after_pooling
class LightNN(nn.Module):
    """Small CNN classifier (used as the student in the experiments).

    Two 3x3 convolutions (3 -> 16 -> 16 channels), each followed by a 2x2
    max-pool, and a two-layer fully connected head.  The head expects 1024
    flattened features, i.e. 16 channels * 8 * 8, which is what 3-channel
    32x32 inputs produce after the two pooling steps.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: Image batch; must flatten to 1024 features per sample after
                the convolutional stack (e.g. ``(batch, 3, 32, 32)``).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch axis, merge C/H/W
        return self.classifier(x)
class ModifiedLightNNCosine(nn.Module):
    """``LightNN``-shaped student that also exposes a hidden representation.

    Same layers as the plain light network; ``forward`` additionally
    returns the flattened convolutional features (1024 values per sample)
    so they can be compared against a teacher's hidden representation.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return ``(logits, hidden)`` for a batch of images.

        Args:
            x: Image batch; must flatten to 1024 features per sample after
                the convolutional stack (e.g. ``(batch, 3, 32, 32)``).

        Returns:
            A tuple ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, num_classes)`` and ``hidden`` is the flattened
            convolutional output of shape ``(batch, 1024)``.
        """
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        logits = self.classifier(flattened_conv_output)
        return logits, flattened_conv_output
def train(model, train_loader, epochs, learning_rate):
    """Train ``model`` with cross-entropy loss and the AdamW optimizer.

    The model is moved to the module-level ``device``, switched to training
    mode and optimized in place.  After every epoch the mean batch loss is
    printed.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to ``optim.AdamW``.

    Returns:
        None.
    """
    loss_func = nn.CrossEntropyLoss()
    # Make sure the weights live where the batches are sent below, the same
    # way the evaluation helpers do, instead of relying on the caller.
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)  # (batch_size, num_classes)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Count the batches actually seen so an empty loader reports NaN
        # instead of raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
def test(model, test_loader):
    """Measure the top-1 accuracy of ``model`` on ``test_loader``.

    The model is moved to the module-level ``device`` and left in eval
    mode.  Gradients are disabled during the evaluation pass.

    Args:
        model: Module mapping a batch of inputs to class logits.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the loader
        yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard the division so an empty loader does not raise.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Test acc: {accuracy:.2f}%")
    return accuracy


# The name starts with "test": mark it so pytest-style collectors do not
# mistake this evaluation helper for a test case when it is imported.
test.__test__ = False
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight):
    """Train ``student`` on hard labels plus the teacher's softened outputs.

    Per batch the loss is
    ``soft_target_loss_weight * KL(teacher || student) * T**2
    + ce_loss_weight * CrossEntropy(student_logits, labels)``,
    where both distributions are softmaxes of the logits divided by ``T``
    and the KL term is averaged over the batch.  Only the student's
    parameters are optimized; the teacher runs in eval mode without
    gradients.  The mean batch loss is printed after every epoch.

    Args:
        teacher: Module producing class logits; not updated.
        student: Module producing class logits; optimized in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to ``optim.AdamW``.
        T: Softmax temperature applied to both sets of logits.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.

    Returns:
        None.
    """
    ce_loss = nn.CrossEntropyLoss()
    # Keep both networks on the device the batches are sent to, matching
    # what train_cosine_loss does, instead of relying on the caller.
    teacher.to(device)
    student.to(device)
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
    teacher.eval()
    student.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)
            # Work in log space on both sides.  Taking softmax(...).log()
            # of the teacher yields 0 * -inf = NaN as soon as one of its
            # probabilities underflows to zero; log_softmax stays finite.
            teacher_log_prob = nn.functional.log_softmax(teacher_logits / T, dim=-1)
            student_log_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
            # "batchmean" divides the summed KL by the batch size; T**2
            # rescales the term as in the original formulation.
            soft_target_loss = nn.functional.kl_div(
                student_log_prob,
                teacher_log_prob,
                reduction="batchmean",
                log_target=True,
            ) * (T ** 2)
            label_loss = ce_loss(student_logits, labels)
            loss = soft_target_loss_weight * soft_target_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Count the batches actually seen so an empty loader reports NaN
        # instead of raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate,
                      hidden_rep_loss_weight, ce_loss_weight):
    """Train ``student`` on hard labels plus a hidden-representation match.

    Both networks must return ``(logits, hidden_representation)``, with
    hidden representations of the same shape.  Per batch the loss is
    ``hidden_rep_loss_weight * CosineEmbeddingLoss(student_hidden,
    teacher_hidden, target=1) + ce_loss_weight * CrossEntropy(
    student_logits, labels)``.  Only the student's parameters are
    optimized; the teacher runs in eval mode without gradients.  The mean
    batch loss is printed after every epoch.

    Args:
        teacher: Module returning ``(logits, hidden)``; not updated.
        student: Module returning ``(logits, hidden)``; optimized in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to ``optim.AdamW``.
        hidden_rep_loss_weight: Weight of the cosine term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.

    Returns:
        None.
    """
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    teacher.to(device)
    student.to(device)
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate)
    teacher.eval()  # the teacher is only queried, never trained
    student.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)
            student_logits, student_hidden_representation = student(inputs)
            # A target of all ones asks the loss to pull each student
            # vector towards its teacher counterpart.  Allocate it on the
            # training device directly rather than on the CPU first.
            target = torch.ones(inputs.size(0), device=device)
            hidden_rep_loss = cosine_loss(
                student_hidden_representation,
                teacher_hidden_representation,
                target=target,
            )
            label_loss = ce_loss(student_logits, labels)
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Count the batches actually seen so an empty loader reports NaN
        # instead of raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
def test_multiple_outputs(model, test_loader):
    """Measure the top-1 accuracy of a model that returns a tuple.

    Works like the single-output evaluation, but ``model`` must return
    ``(logits, extra)``; only the logits are used.  The model is moved to
    the module-level ``device`` and left in eval mode.

    Args:
        model: Module mapping a batch of inputs to ``(logits, extra)``.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the loader
        yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard the division so an empty loader does not raise.
    acc = 100 * correct / total if total else 0.0
    print(f"The acc: {acc:.2f}%")
    return acc


# The name starts with "test": mark it so pytest-style collectors do not
# mistake this evaluation helper for a test case when it is imported.
test_multiple_outputs.__test__ = False
def main():
    """Run the CIFAR-10 distillation experiments end to end.

    Steps, in order:
      1. Train and evaluate the deep (teacher) network with cross-entropy.
      2. Train and evaluate a light (student) network with cross-entropy.
      3. Train an identically initialised student with cross-entropy plus
         soft-target distillation from the teacher, then evaluate it.
      4. Train a student with cross-entropy plus a cosine loss on the
         hidden representations, then evaluate it.
    """
    # NOTE(review): these mean/std values look like the usual ImageNet
    # channel statistics rather than CIFAR-10's own - confirm intent.
    transform_cifar = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=transform_cifar)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                    transform=transform_cifar)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
                                               shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128,
                                              shuffle=False, num_workers=2)

    # --- Teacher: plain cross-entropy training ---------------------------
    torch.manual_seed(42)
    nn_deep = DeepNN(num_classes=10).to(device)
    train(nn_deep, train_loader, epochs=10, learning_rate=1e-3)
    test_acc_deep = test(nn_deep, test_loader)

    # --- Students: two copies built from the same seed -------------------
    torch.manual_seed(42)
    nn_light = LightNN(num_classes=10).to(device)
    torch.manual_seed(42)
    new_nn_light = LightNN(num_classes=10).to(device)
    # Equal norms show that both students start from the same weights.
    print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
    print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
    total_params_deep = sum(p.numel() for p in nn_deep.parameters())
    print(f"DeepNN parameters: {total_params_deep}")
    total_params_light = sum(p.numel() for p in nn_light.parameters())
    print(f"LightNN parameters: {total_params_light}")

    train(nn_light, train_loader, epochs=10, learning_rate=1e-3)
    test_acc_light_ce = test(nn_light, test_loader)

    train_knowledge_distillation(teacher=nn_deep, student=new_nn_light,
                                 train_loader=train_loader,
                                 epochs=10, learning_rate=1e-3,
                                 T=2,
                                 soft_target_loss_weight=0.25,
                                 ce_loss_weight=0.75)
    test_acc_light_ce_and_kd = test(new_nn_light, test_loader)
    print(f"Teacher acc: {test_acc_deep:.2f}%")
    print(f"Student acc without teacher: {test_acc_light_ce:.2f}%")
    print(f"Student acc with CE+KD: {test_acc_light_ce_and_kd:.2f}%")

    # --- Cosine loss on hidden representations ---------------------------
    # Reuse the trained teacher weights in the variant that also returns
    # its (pooled) hidden representation.
    modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
    modified_nn_deep.load_state_dict(nn_deep.state_dict())
    print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
    print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())
    torch.manual_seed(42)
    modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
    print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

    # Shape probe only: no autograd graph is needed for it.
    sample_input = torch.randn(128, 3, 32, 32).to(device)
    with torch.no_grad():
        logits, hidden_representation = modified_nn_light(sample_input)
        print("Student logits shape:", logits.shape)
        print("Student hidden representation shape:", hidden_representation.shape)
        logits, hidden_representation = modified_nn_deep(sample_input)
        print(f"Teacher logits shape: {logits.shape}")
        print(f"Teacher hidden representation shape: {hidden_representation.shape}")

    train_cosine_loss(teacher=modified_nn_deep,
                      student=modified_nn_light,
                      train_loader=train_loader,
                      epochs=10,
                      learning_rate=1e-3,
                      hidden_rep_loss_weight=0.25,
                      ce_loss_weight=0.75)
    test_acc_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light,
                                                              test_loader)
    # Report the last experiment with a label, like the earlier results.
    print(f"Student acc with CE+CosineLoss: {test_acc_light_ce_and_cosine_loss:.2f}%")


if __name__ == '__main__':
    main()
# 全部评论 (0)
# 还没有任何评论哟~
