
NLP Text Classification: TextCNN

A complete PyTorch implementation of TextCNN for Chinese text classification, covering data loading, vocabulary construction, training, validation on a dev set, and final testing:
    # -*- coding: utf-8 -*-
    # @Author: SuperLong
    # @Email: miu_zxl@163.com
    # @Time: 2024/8/28 9:19
    import os
    import argparse
    import time

    import torch
    import torch.nn as nn
    import pickle as pkl
    from pathlib import Path
    import torch.nn.functional as F
    from sklearn.metrics import accuracy_score
    from torch.utils.data import DataLoader, Dataset

    # Uncomment to pin the process to a specific GPU:
    # os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    path = Path(__file__).parent
    
    
    def read_data(args, data_str):
        """Read {data_str}.txt and return the texts, the labels, and the length
        of the longest text (used later as the padding length)."""
        text, label, len_max = [], [], []
        with open(args.data_path + f"/{data_str}.txt", "r", encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # skip blank lines
                    continue
                text_i, label_i = line.split('\t')
                text.append(text_i)
                label.append(label_i)
                len_max.append(len(text_i))
        return text, label, max(len_max)
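    # Note: read_data expects args.data_path to contain train.txt, dev.txt and
    # test.txt, one sample per line in the form "text\tinteger_label", e.g. a
    # hypothetical line: "今日股市大涨\t2".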
    
    
    def build_dict(args, train_text):
        """Build a character-level vocabulary from the training texts, create a
        randomly initialised embedding table, and cache both to disk."""
        word2idx = {"<PAD>": 0, "<UNK>": 1}
        for text_i in train_text:
            for word in text_i:
                if word not in word2idx:
                    word2idx[word] = len(word2idx)
        embedding = nn.Embedding(len(word2idx), args.embedding_dim)
        pkl.dump([word2idx, embedding], open(args.emb_path, 'wb'))
        return word2idx, embedding
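    # Example (hypothetical): after the texts ["ab", "ac"], word2idx would be
    # {"<PAD>": 0, "<UNK>": 1, "a": 2, "b": 3, "c": 4}; iterating a string
    # yields characters, which suits unsegmented Chinese text.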
    
    class Datasets(Dataset):
        def __init__(self, text, label, len_max, dict_ids):
            self.text = text
            self.label = label
            self.len_max = len_max
            self.dict_ids = dict_ids
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        def __len__(self):
            return len(self.text)

        def __getitem__(self, index):
            # Truncate to len_max, map unknown characters to <UNK> (id 1),
            # then pad with <PAD> (id 0) up to len_max.
            text_i = self.text[index][:self.len_max]
            text_ids = [self.dict_ids.get(word, 1) for word in text_i]
            text_ids = text_ids + [0] * (self.len_max - len(text_ids))
            label = self.label[index]
            # unsqueeze adds the channel dimension, so a collated batch has
            # shape (batch_size, 1, len_max) as Conv2d(in_channels=1) expects.
            return torch.tensor(text_ids).to(self.device).unsqueeze(dim=0), torch.tensor(int(label))
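    # Example (hypothetical): with len_max=5 and dict_ids={"<PAD>": 0, "<UNK>": 1,
    # "你": 2, "好": 3}, the text "你好" becomes the ids [2, 3, 0, 0, 0].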
    
    
    class MaxPool(nn.Module):
        """Max-pooling over the whole convolution output length. Note: this
        class is defined but unused; TextCNN calls F.max_pool1d directly."""

        def __init__(self, kernel, len_max):
            super(MaxPool, self).__init__()
            self.kernel = kernel
            self.max_pool = nn.MaxPool1d(kernel_size=len_max - self.kernel + 1)

        def forward(self, x):
            return self.max_pool(x)
    
    
    class TextCNN(nn.Module):
        def __init__(self, args, Embeddings):
            super(TextCNN, self).__init__()
            self.emb = Embeddings
            self.embedding_dim = args.embedding_dim
            self.kernels = list(map(int, args.kernel_sizes.split(',')))
            self.num_filters = args.num_filters
            # One Conv2d per kernel size; each filter spans the full embedding
            # dimension, so it slides only along the sequence axis.
            self.convs = nn.ModuleList([nn.Conv2d(1, self.num_filters, (k, self.embedding_dim)) for k in self.kernels])
            self.dropout = nn.Dropout(args.dropout)
            self.classifier = nn.Linear(self.num_filters * len(self.kernels), args.num_class)

        def conv_and_pool(self, x, conv):
            x = F.relu(conv(x)).squeeze(3)             # (batch, num_filters, seq_len - k + 1)
            x = F.max_pool1d(x, x.size(2)).squeeze(2)  # (batch, num_filters)
            return x

        def forward(self, x):
            x = self.emb(x)  # (batch, 1, seq_len) -> (batch, 1, seq_len, embedding_dim)
            out = torch.cat([self.conv_and_pool(x, conv) for conv in self.convs], 1)
            out = self.dropout(out)
            return self.classifier(out)
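    # Concrete shape walk-through (hypothetical sizes: batch=4, seq_len=50,
    # embedding_dim=128, kernel_sizes="3,4,5", num_filters=6, num_class=10):
    #   input ids         (4, 1, 50)
    #   after emb         (4, 1, 50, 128)
    #   after conv (k=3)  (4, 6, 48, 1) -> squeeze(3) -> (4, 6, 48)
    #   after max-pool    (4, 6)
    #   after cat         (4, 18)
    #   after classifier  (4, 10)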
    
    
    def main(args):
        start = time.time()

        train_text, train_label, len_max = read_data(args, "train")
        # Build the vocabulary and embedding on the first run; reuse the cache after.
        if not os.path.exists(args.emb_path):
            dict_ids, Embeddings = build_dict(args, train_text)
        else:
            dict_ids, Embeddings = pkl.load(open(args.emb_path, 'rb'))
        train_dataset = DataLoader(Datasets(train_text, train_label, len_max, dict_ids), batch_size=args.batch_size,
                                   shuffle=True)
        dev_text, dev_label, _ = read_data(args, "dev")
        dev_dataset = DataLoader(Datasets(dev_text, dev_label, len_max, dict_ids), batch_size=args.batch_size, shuffle=True)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = TextCNN(args, Embeddings).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        loss_func = nn.CrossEntropyLoss()
        best_acc = 0
        for epoch in range(args.epoch):
            model.train()
            train_loss, counts = 0, 0
            train_pred, train_true = [], []
            for step, (text, label) in enumerate(train_dataset):
                label = label.to(device)
                out = model(text)
                loss = loss_func(out, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                counts += 1

                train_pred.extend(out.argmax(dim=1).cpu().numpy().tolist())
                train_true.extend(label.cpu().numpy().tolist())
                if step % 100 == 0:
                    train_acc = accuracy_score(train_true, train_pred)
                    print("Epoch:{0} step:{1} loss:{2:.4f} train_acc:{3:.4f}".format(epoch, step, train_loss / counts,
                                                                                     train_acc))
            # Evaluate on the dev set after each epoch; keep the best checkpoint.
            model.eval()
            dev_pred, dev_true = [], []
            with torch.no_grad():
                for step, (text, label) in enumerate(dev_dataset):
                    label = label.to(device)
                    pred = model(text)
                    dev_pred.extend(pred.argmax(dim=1).cpu().numpy().tolist())
                    dev_true.extend(label.cpu().numpy().tolist())
                dev_acc = accuracy_score(dev_true, dev_pred)
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    torch.save(model.state_dict(), args.model_path)
                print("Epoch:{0} dev_acc:{1:.4f} best_acc:{2:.4f}".format(epoch, dev_acc, best_acc))
    
    print("训练结束!开始测试")
    
    dev_text, dev_label, _ = read_data(args, "test")
    test_dataset = DataLoader(Datasets(dev_text, dev_label, len_max, dict_ids), batch_size=args.batch_size, shuffle=True)
    model = TextCNN(args, Embeddings).to(device)
    model.load_state_dict(torch.load('best_model-TextCNN.pth'))
    dev_pre, dev_true = [], []
    for step, (text, label) in enumerate(test_dataset):
        label = label.to(device)
        with torch.no_grad():
            out = model(text)
            pre = out.argmax(dim=1).cpu().numpy().tolist()
            label = label.cpu().numpy().tolist()
            dev_pre.extend(pre), dev_true.extend(label)
    acc = accuracy_score(dev_pre, dev_true)
    print("测试结果为:{:.4f}".format(acc))
    
    if __name__ == '__main__':
        args_parser = argparse.ArgumentParser()
        args_parser.add_argument("--data_path", type=str, default=os.path.join(path, "data", "class_data"))
        args_parser.add_argument("--emb_path", type=str, default=os.path.join(path, "emb_words-TextCNN.pkl"))
        args_parser.add_argument("--model_path", type=str, default=os.path.join(path, "best_model-TextCNN.pth"))
        args_parser.add_argument("--epoch", type=int, default=10)
        args_parser.add_argument("--batch_size", type=int, default=64)
        args_parser.add_argument("--learning_rate", type=float, default=0.001)
        args_parser.add_argument("--embedding_dim", type=int, default=1024)
        args_parser.add_argument("--kernel_sizes", type=str, default="3,4,5")
        args_parser.add_argument("--num_filters", type=int, default=6)
        args_parser.add_argument("--num_class", type=int, default=10)
        args_parser.add_argument("--dropout", type=float, default=0.2)
        args = args_parser.parse_args()
        main(args)
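
To sanity-check the tensor shapes without preparing any data files, a minimal smoke test along these lines can be run with the TextCNN class above in scope (a sketch; all sizes here are arbitrary):

    # Hypothetical smoke test: verifies input/output shapes only.
    from types import SimpleNamespace

    toy_args = SimpleNamespace(embedding_dim=128, kernel_sizes="3,4,5",
                               num_filters=6, num_class=10, dropout=0.2)
    toy_emb = nn.Embedding(5000, toy_args.embedding_dim)  # toy vocabulary of 5000 ids
    toy_model = TextCNN(toy_args, toy_emb)
    x = torch.randint(0, 5000, (4, 1, 50))  # (batch, channel, seq_len)
    print(toy_model(x).shape)               # expected: torch.Size([4, 10])

With the data files in place, training runs by executing the script with python, optionally overriding defaults such as --batch_size or --embedding_dim. Note that 1024 is a fairly large default for character embeddings; values in the 128-300 range are more common for TextCNN.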

Currently focused on learning and sharing NLP techniques.
Thanks for everyone's attention and support!
