Advertisement

中文新闻标题分类

阅读量:
复制代码
    import os
    import sys
    import pickle
    import logging
    
    # Route INFO-and-above log records to stdout, each prefixed with '>>>',
    # a timestamp and the level name.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='>>> %(asctime)s %(levelname)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        # force=True is intentionally left off; enable it only if an earlier
        # logging configuration has to be replaced.
    )
    
    import numpy as np
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.linear_model import PassiveAggressiveClassifier
    from sklearn.feature_extraction.text import TfidfVectorizer
    
    def load_txt(path, mode):
        """Load a UTF-8 text dataset that stores one sample per line.

        Args:
            path: Path of the text file to read.
            mode: Parsing scheme, not a file-open mode. ``'tr'`` reads
                labelled lines of the form ``text<TAB>label``; ``'te'`` reads
                unlabelled lines.

        Returns:
            For ``'tr'``: a tuple ``(texts, labels)`` of two parallel lists,
            with every label converted to ``int``. Lines that do not split
            into exactly two tab-separated fields are skipped.
            For ``'te'``: a list holding every line with surrounding
            whitespace stripped (empty lines are kept as ``''``).

        Raises:
            KeyError: If ``mode`` is neither ``'tr'`` nor ``'te'``.
            ValueError: If a label in ``'tr'`` mode cannot be parsed as int.
        """
        # Reject an unknown mode before touching the file system, so the
        # caller gets the real cause instead of an unrelated I/O error.
        if mode not in ('tr', 'te'):
            raise KeyError(f"unknown mode {mode!r}; expected 'tr' or 'te'")

        # The file is always opened as text; `mode` only selects the parser.
        with open(path, mode='rt', encoding='utf-8') as f:
            stripped = [line.strip() for line in f]

        if mode == 'te':
            return stripped

        xl, yl = [], []
        for line in stripped:
            parts = line.split('\t')
            # Anything other than exactly "text<TAB>label" is dropped.
            if len(parts) == 2:
                string, lbl = parts
                xl.append(string)
                yl.append(int(lbl))
        return xl, yl
    
    def load_data():
        """Return ``((train_texts, train_labels), test_texts)``.

        When both ``./data/dtr.pkl`` and ``./data/dte.pkl`` exist they are
        unpickled and returned. Otherwise the data is read from
        ``./data/dtr.txt``, ``./data/dtr.dev.txt`` (appended to the training
        set) and ``./data/dte.txt``, and both pickle caches are written
        before returning.

        NOTE(review): the caches are read with ``pickle``; only use cache
        files this script produced itself.
        """
        logging.info('-> load data')
        train_cache = './data/dtr.pkl'
        test_cache = './data/dte.pkl'

        # Fast path: both caches are already on disk.
        if os.path.exists(train_cache) and os.path.exists(test_cache):
            with open(train_cache, mode='rb') as src:
                texts, labels = pickle.load(src)
            with open(test_cache, mode='rb') as src:
                test_texts = pickle.load(src)
            return (texts, labels), test_texts

        # Slow path: parse the raw text files, folding the dev split into
        # the training set.
        texts, labels = load_txt('./data/dtr.txt', mode='tr')
        dev_texts, dev_labels = load_txt('./data/dtr.dev.txt', mode='tr')
        texts.extend(dev_texts)
        labels.extend(dev_labels)
        test_texts = load_txt('./data/dte.txt', mode='te')

        with open(train_cache, mode='wb') as dst:
            pickle.dump((texts, labels), dst)
        with open(test_cache, mode='wb') as dst:
            pickle.dump(test_texts, dst)
        return (texts, labels), test_texts
    
    def tfidf(xtr,xte,ltr,lte):
        """Vectorize the first ``ltr`` training and ``lte`` test texts with TF-IDF.

        The vectorizer is fitted on the training and test texts together, so
        both parts share one vocabulary. The result is cached in
        ``./data/tfidf-{ltr}-{lte}.pkl``; the cache key contains only the two
        sizes, so a stale cache is reused if the underlying texts change.

        Args:
            xtr: List of training texts.
            xte: List of test texts.
            ltr: Number of leading training texts to use.
            lte: Number of leading test texts to use.

        Returns:
            A tuple ``(train_matrix, test_matrix)`` with ``ltr`` and ``lte``
            rows respectively, as produced by ``TfidfVectorizer``.

        Raises:
            AssertionError: If a size is negative or exceeds the available data.

        NOTE(review): the cache is read with ``pickle``; only use cache files
        this script produced itself.
        """
        logging.info('-> tfidf')
        cache_path = f'./data/tfidf-{ltr}-{lte}.pkl'
        if os.path.exists(cache_path):
            with open(cache_path, mode='rb') as f:
                return pickle.load(f)

        # Negative sizes would silently mis-slice the stacked matrix below.
        assert 0 <= ltr <= len(xtr) and 0 <= lte <= len(xte)
        # token_pattern '(.)' makes each single character a token, so with
        # ngram_range (1, 8) the features are runs of 1 to 8 characters;
        # min_df=2 drops terms seen in fewer than two documents.
        para = {'input': 'content', 'encoding': 'utf-8', 'decode_error': 'strict', 'strip_accents': None, 'lowercase': True, 'preprocessor': None, 'tokenizer': None, 'stop_words': None, 'token_pattern': '(.)', 'ngram_range': (1, 8), 'analyzer': 'word', 'max_df': 1.0, 'min_df': 2, 'max_features': None, 'vocabulary': None, 'binary': False, 'dtype': np.float64, 'norm': 'l2', 'use_idf': True, 'smooth_idf': True, 'sublinear_tf': True}
        vectorizer = TfidfVectorizer(**para)
        allvec = vectorizer.fit_transform(xtr[:ltr] + xte[:lte])

        # Split at the train/test boundary. Slicing the tail with
        # allvec[-lte:] would return the whole matrix when lte == 0.
        ztr, zte = allvec[:ltr], allvec[ltr:]
        with open(cache_path, mode='wb') as f:
            pickle.dump((ztr, zte), f)
        return ztr, zte
    
    def model(model_class,para,xtr,ytr,xte,ltr,lte):
        """Fit ``model_class(**para)`` on the first ``ltr`` rows and predict ``lte`` rows.

        Args:
            model_class: Estimator class exposing ``fit(X, y)`` and ``predict(X)``.
            para: Keyword arguments used to construct the estimator.
            xtr: Training feature matrix (must expose ``shape``).
            ytr: Training labels (must support ``len`` and slicing).
            xte: Test feature matrix (must expose ``shape``).
            ltr: Number of leading training rows to fit on.
            lte: Number of leading test rows to predict.

        Returns:
            Whatever ``predict`` returns for ``xte[:lte]``.

        Raises:
            AssertionError: If ``ltr`` or ``lte`` exceeds the available rows.
        """
        logging.info('-> model')
        # Size checks, evaluated in order: train features, train labels,
        # test features.
        assert ltr <= xtr.shape[0]
        assert ltr <= len(ytr)
        assert lte <= xte.shape[0]

        estimator = model_class(**para)
        estimator.fit(xtr[:ltr], ytr[:ltr])
        predictions = estimator.predict(xte[:lte])
        return predictions
    
    def pac(xtr,ytr,xte,ltr,lte):
        """Predict test labels with a ``PassiveAggressiveClassifier``.

        The classifier is built with ``max_iter=1024`` and ``n_jobs=-1``;
        fitting and prediction are delegated to ``model``.
        """
        classifier_kwargs = dict(max_iter=1024, n_jobs=-1)
        return model(
            PassiveAggressiveClassifier,
            classifier_kwargs,
            xtr, ytr, xte, ltr, lte,
        )
    
    def nb(xtr,ytr,xte,ltr,lte):
        """Predict test labels with a ``MultinomialNB`` built with default settings.

        Fitting and prediction are delegated to ``model``.
        """
        classifier_kwargs = dict()
        return model(
            MultinomialNB,
            classifier_kwargs,
            xtr, ytr, xte, ltr, lte,
        )
    
    def gen_sub_file(yte,lte):
        """Write the first ``lte`` predictions to ``./data/191300000.txt``.

        Each prediction is written as ``str(value)`` on its own line.

        Args:
            yte: Indexable sequence of predictions.
            lte: Number of leading predictions to write.

        Raises:
            AssertionError: If ``lte`` exceeds ``len(yte)``.
        """
        logging.info('-> gen sub file')
        assert lte <= len(yte)
        with open('./data/191300000.txt', mode='wt') as out:
            out.writelines(str(yte[row]) + '\n' for row in range(lte))
        return None
    
    def main():
        """Run the pipeline: load data, vectorize, classify, write predictions.

        Uses the first 190000 training samples and the first 10000 test
        samples, and classifies with ``pac``.
        """
        logging.info('===== START =====')
        n_train, n_test = 190000, 10000

        (train_texts, train_labels), test_texts = load_data()
        train_vecs, test_vecs = tfidf(train_texts, test_texts, n_train, n_test)
        predicted = pac(train_vecs, train_labels, test_vecs, n_train, n_test)
        gen_sub_file(predicted, n_test)

        logging.info('=====  END  =====')
        return None
    
    # Run the pipeline only when executed as a script, not when imported.
    if __name__ == '__main__':
        main()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


BERT 代码(PyTorch & Transformers)和数据集
(运行平台: 华为云 ModelArts, pytorch1.8-cuda10.2-cudnn7-ubuntu18.04, GPU: 1*V100(32GB) | CPU: 8核 64GB.)

全部评论 (0)

还没有任何评论哟~