
PyTorch Knowledge Distillation Test

A self-contained CIFAR-10 script: train a deep teacher CNN, then compare three ways of training a lightweight student: plain cross-entropy, logit (soft-target) distillation, and cosine-loss distillation on hidden representations.

    import os
    # Select the visible GPU before CUDA is initialized.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    import torch
    import torch.utils.data
    from torch import nn, optim
    from torchvision import transforms, datasets

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class DeepNN(nn.Module):
        def __init__(self, num_classes=10):
            super(DeepNN, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 128, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x
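
    # Same architecture as DeepNN, but forward() additionally returns the
    # flattened feature map (average-pooled 2048 -> 1024) for comparison with
    # the student's hidden representation via a cosine loss.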
    
    class ModifiedDeepNNCosine(nn.Module):
        def __init__(self, num_classes=10):
            super(ModifiedDeepNNCosine, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 128, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            self.classifier = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            flattened_conv_output = torch.flatten(x, 1)
            x = self.classifier(flattened_conv_output)
            flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(
                flattened_conv_output, 2
            )
            return x, flattened_conv_output_after_pooling
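
    # Shape check: CIFAR-10 inputs are 3x32x32; two 2x2 max-pools leave 8x8
    # feature maps. Teacher: 32 channels * 8 * 8 = 2048 -> avg_pool1d(2) -> 1024.
    # Student (below): 16 channels * 8 * 8 = 1024, already the matching size.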
    
    
    
    class LightNN(nn.Module):
        def __init__(self, num_classes=10):
            super(LightNN, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(1024, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x
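
    # Student variant whose forward() also returns its 1024-dim flattened
    # features, to be aligned with the teacher's pooled representation.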
    
    class ModifiedLightNNCosine(nn.Module):
        def __init__(self, num_classes=10):
            super(ModifiedLightNNCosine, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )

            self.classifier = nn.Sequential(
                nn.Linear(1024, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            flattened_conv_output = torch.flatten(x, 1)
            x = self.classifier(flattened_conv_output)
            return x, flattened_conv_output
    
    
    
    def train(model, train_loader, epochs, learning_rate):
        # Plain cross-entropy training loop, used for both the teacher and the
        # baseline student.
        loss_func = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        model.train()

        for epoch in range(epochs):
            running_loss = 0.0
            for inputs, labels in train_loader:
                # inputs: [batch_size, 3, 32, 32], labels: [batch_size]
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)  # [batch_size, num_classes]

                loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
    
    
    
    def test(model, test_loader):
        model.to(device)
        model.eval()

        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f"Test acc: {accuracy:.2f}%")

        return accuracy
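
    # Response-based knowledge distillation (Hinton et al., 2015): soften the
    # teacher and student logits with a temperature T, then train the student
    # on a weighted sum of (a) the KL divergence between the two softened
    # distributions and (b) the usual cross-entropy on the hard labels.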
    
    
    def train_knowledge_distillation(teacher, student, train_loader, epochs,
                                     learning_rate, T, soft_target_loss_weight,
                                     ce_loss_weight):
        ce_loss = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(student.parameters(), lr=learning_rate)

        teacher.eval()
        student.train()

        for epoch in range(epochs):
            running_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                # The teacher only provides targets; no gradients through it.
                with torch.no_grad():
                    teacher_logits = teacher(inputs)

                student_logits = student(inputs)

                # Soften both distributions with temperature T.
                soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
                soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

                # KL(teacher || student), averaged over the batch; the T**2
                # factor keeps gradient magnitudes comparable across T values.
                soft_target_loss = torch.sum(
                    soft_targets * (soft_targets.log() - soft_prob)
                ) / soft_prob.size()[0] * (T ** 2)

                label_loss = ce_loss(student_logits, labels)

                loss = soft_target_loss_weight * soft_target_loss + ce_loss_weight * label_loss

                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
    
    
    def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate,
                          hidden_rep_loss_weight, ce_loss_weight):
        ce_loss = nn.CrossEntropyLoss()
        cosine_loss = nn.CosineEmbeddingLoss()
        optimizer = optim.AdamW(student.parameters(), lr=learning_rate)

        teacher.to(device)
        student.to(device)
        teacher.eval()  # set teacher to eval mode
        student.train()

        for epoch in range(epochs):
            running_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                with torch.no_grad():
                    _, teacher_hidden_representation = teacher(inputs)

                student_logits, student_hidden_representation = student(inputs)

                hidden_rep_loss = cosine_loss(
                    student_hidden_representation,
                    teacher_hidden_representation,
                    target=torch.ones(inputs.size(0)).to(device)
                )

                label_loss = ce_loss(student_logits, labels)

                loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
    
    
    def test_multiple_outputs(model, test_loader):
        # Same as test(), but for models whose forward() returns
        # (logits, hidden_representation).
        model.to(device)
        model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs, _ = model(inputs)  # ignore the hidden representation
                _, predicted = torch.max(outputs, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print(f"Test acc: {acc:.2f}%")
        return acc
    
    
    if __name__ == '__main__':
        transform_cifar = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )

        train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                         transform=transform_cifar)
        test_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                        transform=transform_cifar)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
                                                   shuffle=True, num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128,
                                                  shuffle=False, num_workers=2)

        # Train the teacher from scratch with plain cross-entropy.
        torch.manual_seed(42)
        nn_deep = DeepNN(num_classes=10).to(device)
        train(nn_deep, train_loader, epochs=10, learning_rate=1e-3)
        test_acc_deep = test(nn_deep, test_loader)

        # Two identically initialized students: one trained with CE only,
        # one trained with CE + knowledge distillation.
        torch.manual_seed(42)
        nn_light = LightNN(num_classes=10).to(device)

        torch.manual_seed(42)
        new_nn_light = LightNN(num_classes=10).to(device)

        # Same seed => same initial weights; the two norms should match.
        print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
        print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

        total_params_deep = sum(p.numel() for p in nn_deep.parameters())
        print(f"DeepNN parameters: {total_params_deep}")
        total_params_light = sum(p.numel() for p in nn_light.parameters())
        print(f"LightNN parameters: {total_params_light}")

        # Baseline: student trained on hard labels only.
        train(nn_light, train_loader, epochs=10, learning_rate=1e-3)
        test_acc_light_ce = test(nn_light, test_loader)

        # Logit distillation: student learns from soft targets + hard labels.
        train_knowledge_distillation(teacher=nn_deep, student=new_nn_light,
                                     train_loader=train_loader,
                                     epochs=10, learning_rate=1e-3,
                                     T=2,
                                     soft_target_loss_weight=0.25,
                                     ce_loss_weight=0.75)

        test_acc_light_ce_and_kd = test(new_nn_light, test_loader)

        print(f"Teacher acc: {test_acc_deep:.2f}%")
        print(f"Student acc without teacher: {test_acc_light_ce:.2f}%")
        print(f"Student acc with CE+KD: {test_acc_light_ce_and_kd:.2f}%")

        #######################################################################
        # Cosine-loss distillation on hidden representations.
        # The modified teacher reuses the already trained DeepNN weights.
        modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
        modified_nn_deep.load_state_dict(nn_deep.state_dict())

        print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
        print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

        torch.manual_seed(42)
        modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
        print("Norm of 1st layer of modified_nn_light:", torch.norm(modified_nn_light.features[0].weight).item())

        # Sanity-check the two output shapes with a dummy batch.
        sample_input = torch.randn(128, 3, 32, 32).to(device)

        logits, hidden_representation = modified_nn_light(sample_input)
        print("Student logits shape:", logits.shape)
        print("Student hidden representation shape:", hidden_representation.shape)

        logits, hidden_representation = modified_nn_deep(sample_input)
        print(f"Teacher logits shape: {logits.shape}")
        print(f"Teacher hidden representation shape: {hidden_representation.shape}")

        train_cosine_loss(teacher=modified_nn_deep,
                          student=modified_nn_light,
                          train_loader=train_loader,
                          epochs=10,
                          learning_rate=1e-3,
                          hidden_rep_loss_weight=0.25,
                          ce_loss_weight=0.75)
        test_acc_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light,
                                                                  test_loader)
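
        # Hypothetical addition (not in the original script): print the cosine
        # result alongside the earlier runs for a complete comparison.
        print(f"Student acc with CE+CosineLoss: {test_acc_light_ce_and_cosine_loss:.2f}%")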
