BERT Multi-Label Classification (BERT Multi Label Classifier)

This post adapts the Multi_Label_Classifier_finetune project:

https://github.com/Vincent131499/Multi_Label_Classifier_finetune

Environment: CentOS with the GPU build of TensorFlow, version 1.14.0.

Modifications:

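I ended up not using the line below. With it added, my experiments ran on the CPU instead of the GPU, which is probably down to which device the setting selects.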

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

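    # Left commented out: with this line enabled, the run fell back to the CPU.
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)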

If you are writing GPU enabled code, you would typically use a device query to select the desired GPUs. However, a quick and easy solution for testing is to use the environment variable CUDA_VISIBLE_DEVICES to restrict the devices that your CUDA application sees. This can be useful if you are attempting to share resources on a node or you want your GPU enabled executable to target a specific GPU.

Environment Variable Syntax      Results
CUDA_VISIBLE_DEVICES=1           Only device 1 will be seen
CUDA_VISIBLE_DEVICES=0,1         Devices 0 and 1 will be visible
CUDA_VISIBLE_DEVICES="0,1"       Same as above, quotation marks are optional
CUDA_VISIBLE_DEVICES=0,2,3       Devices 0, 2, 3 will be visible; device 1 is masked

CUDA will enumerate the visible devices starting at zero. In the last case, devices 0, 2, 3 will appear as devices 0, 1, 2. If you change the order of the string to “2,3,0”, devices 2,3,0 will be enumerated as 0,1,2 respectively. If CUDA_VISIBLE_DEVICES is set to a device that does not exist, all devices will be masked. You can specify a mix of valid and invalid device numbers. All devices before the invalid value will be enumerated, while all devices after the invalid value will be masked.
To determine the device ID for the available hardware in your system, you can run NVIDIA’s deviceQuery executable included in the CUDA SDK.
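
The enumeration rules quoted above can be checked with a small pure-Python model. This is only a sketch of the rules as the text states them, not a call into the CUDA driver: the function name `visible_devices` and its `num_physical` parameter are made up for illustration, and only plain integer indices are handled.

```python
def visible_devices(spec, num_physical):
    """Model the CUDA_VISIBLE_DEVICES rules described above.

    Returns the physical device IDs in the order CUDA would enumerate them,
    so the list index is the device ID the application sees.
    """
    visible = []
    for token in spec.split(","):
        token = token.strip()
        # Anything that is not an existing device index counts as invalid here.
        if not token.isdigit() or int(token) >= num_physical:
            break  # this entry and everything after it is masked
        visible.append(int(token))
    return visible


if __name__ == "__main__":
    # A hypothetical machine with 4 GPUs.
    for spec in ("1", "0,2,3", "2,3,0", "0,7,1"):
        print(spec, "->", visible_devices(spec, num_physical=4))
```

Under this model, `visible_devices("1", num_physical=1)` returns an empty list. If the machine had only one GPU, that would be consistent with the CPU fallback described earlier, though the post does not say how many GPUs were installed.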

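I wrote a new processor class to fix a problem in the original project, where the code was not tied to the actual number of labels. Here the label count is read from classes.txt.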

    import os
    import pandas as pd

    # DataProcessor, InputExample and FLAGS come from the project's BERT fine-tuning script.

    class Test(DataProcessor):
        """Processor for multi-label CSV data; the label count comes from classes.txt."""

        def get_train_examples(self, data_dir):
            """See base class."""
            filename = 'multi_train.csv'
            data_df = pd.read_csv(os.path.join(data_dir, filename))
            return self._create_examples(data_df, "train")

        def get_dev_examples(self, data_dir):
            """See base class."""
            filename = 'multi_dev.csv'
            data_df = pd.read_csv(os.path.join(data_dir, filename))
            return self._create_examples(data_df, "dev")

        def get_test_examples(self, data_dir):
            """See base class."""
            filename = 'multi_test.csv'
            data_df = pd.read_csv(os.path.join(data_dir, filename))
            return self._create_examples(data_df, "test")

        def get_labels(self):
            """See base class."""
            labels = list(pd.read_csv(os.path.join(FLAGS.data_dir, "classes.txt"), header=None)[0].values)
            # Remember the label count; _create_examples needs it for the test set.
            self.num_label = len(labels)
            return labels

        def _create_examples(self, df, set_type):
            """Creates examples for the training, dev and test sets."""
            examples = []
            for row in df.values:
                guid = int(row[0])
                text_a = row[1]
                if set_type == 'test':
                    # Test rows carry no gold labels, so use an all-zero placeholder.
                    # get_labels() must have been called first to set self.num_label.
                    labels = [0] * self.num_label
                else:
                    # Columns 2 onward hold one 0/1 flag per label.
                    labels = [int(a) for a in row[2:]]
                examples.append(
                    InputExample(guid=guid, text_a=text_a, label=labels))
            return examples
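
To make the expected file layout concrete, here is a minimal stand-alone sketch of the row-to-label conversion done in `_create_examples`. It uses only the standard library in place of pandas and `InputExample`. The sample rows, the label names, and the helper name `rows_to_examples` are invented for illustration, and it assumes the CSV has a header row followed by an id column, a text column, and one 0/1 column per label.

```python
import csv
import io

# Hypothetical file contents: id, text, then one 0/1 column per label.
SAMPLE_CSV = """id,text,sports,finance,tech
0,stock prices of chip makers rose,0,1,1
1,the match ended in a draw,1,0,0
"""


def rows_to_examples(csv_text, set_type, num_label):
    """Return (guid, text, labels) tuples, mirroring _create_examples."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row, as pd.read_csv does by default
    examples = []
    for row in reader:
        guid, text = int(row[0]), row[1]
        if set_type == "test":
            # No gold labels at prediction time: all-zero placeholder.
            labels = [0] * num_label
        else:
            # Columns 2 onward form a multi-hot vector, one entry per label.
            labels = [int(a) for a in row[2:]]
        examples.append((guid, text, labels))
    return examples


if __name__ == "__main__":
    for example in rows_to_examples(SAMPLE_CSV, "train", num_label=3):
        print(example)
```

Each label vector has the same length as the number of lines in classes.txt, which is why the test-set placeholder has to be sized from the label count.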

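The label_ids feature gets the same fix, so that its length follows the number of labels. (There are cleaner ways to engineer this; the approach used here is simply the quickest.)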

    def file_based_input_fn_builder(input_file, seq_length, is_training,
                                    drop_remainder):
        """Creates an `input_fn` closure to be passed to TPUEstimator."""
        # Number of labels = number of entries in classes.txt (one label per line).
        num_label = len(list(pd.read_csv(os.path.join(FLAGS.data_dir, "classes.txt"), header=None)[0].values))

        name_to_features = {
            "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
            # Original single-label version: one scalar label per example.
            #"label_ids": tf.FixedLenFeature([], tf.int64),
            # Multi-label version: one int64 entry per label.
            "label_ids": tf.FixedLenFeature([num_label], tf.int64),
            "is_real_example": tf.FixedLenFeature([], tf.int64),
        }
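
One possible version of the cleaner approach mentioned above is to compute the label count once and pass it in as an argument, rather than re-reading classes.txt inside the builder. The sketch below shows only that idea and leaves TensorFlow out so it runs anywhere. `count_labels` and `build_feature_shapes` are hypothetical helpers, not part of the project, and in the real TF 1.x code each shape would be wrapped as `tf.FixedLenFeature(shape, tf.int64)`.

```python
import os
import tempfile


def count_labels(classes_path):
    """Count non-empty lines in classes.txt (one label name per line)."""
    with open(classes_path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def build_feature_shapes(seq_length, num_labels):
    """Shapes of the features parsed from each serialized example."""
    return {
        "input_ids": [seq_length],
        "input_mask": [seq_length],
        "segment_ids": [seq_length],
        "label_ids": [num_labels],  # multi-hot vector, one slot per label
        "is_real_example": [],      # scalar
    }


if __name__ == "__main__":
    # Write a throwaway classes.txt with three made-up labels.
    with tempfile.TemporaryDirectory() as data_dir:
        path = os.path.join(data_dir, "classes.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write("sports\nfinance\ntech\n")
        num_labels = count_labels(path)
    print(build_feature_shapes(seq_length=128, num_labels=num_labels))
```

Passing the count explicitly keeps the file read in one place, so the processor and the input function cannot disagree about how many labels there are.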
