NLP Primer - Task 10: BERT (Part 1)

BERT

  • Walkthrough of the BERT text-classification code

Walkthrough of the BERT text-classification code

Before applying the BERT model, download the open-source code from GitHub (cloning the repository directly works). It contains a file named run_classifier.py; for a text-classification project this is the file to modify, by adding a data-preprocessing class. The DataProcessor class in run_classifier.py looks like this:

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

Four methods are declared here: get_train_examples, get_dev_examples, get_test_examples, and get_labels. Since each of them only raises NotImplementedError, each must be overridden in a subclass for your own data set.
The _read_tsv method, by contrast, is already implemented: it reads the rows of a tab-separated (TSV) file.
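Setting aside the TensorFlow file wrapper, the behaviour of _read_tsv can be sketched with the standard csv module alone (the demo file and its contents here are made up for illustration):

```python
import csv
import os
import tempfile

def read_tsv(input_file, quotechar=None):
    """Read a tab-separated file into a list of rows (one list of columns per
    line), mirroring DataProcessor._read_tsv but with plain open() in place of
    tf.gfile.Open."""
    with open(input_file, "r", encoding="utf-8") as f:
        return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

# Tiny demo file: one "label<TAB>text" row per line, like the cnews data.
path = os.path.join(tempfile.mkdtemp(), "demo.tsv")
with open(path, "w", encoding="utf-8") as f:
    f.write("sports\tThe match ended 2-1.\n")
    f.write("finance\tShares rose on Monday.\n")

rows = read_tsv(path)
print(rows[0])  # ['sports', 'The match ended 2-1.']
```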
Following the example of the processors shipped with BERT, such as XnliProcessor, a subclass of DataProcessor that reads the training data for our own task can be written as follows:

class NewsProcessor(bert.run_classifier.DataProcessor):
    """Processor for the cnews text-classification data set."""

    def __init__(self):
        self.labels = set()

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "cnews_train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "cnews_test.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return list(self.labels)

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[0])
            self.labels.add(label)
            examples.append(
                 bert.run_classifier.InputExample(guid=guid, text_a=text_a, text_b=None, label=label))

        return examples

The _create_examples method walks over every row of the data set, taking the label from line[0] and the text (text_a) from line[1], and adds each label to self.labels. It then wraps each row in an InputExample to build the list of examples. (If predictions on a test set are needed, get_test_examples should be overridden in the same way; it is omitted here.)
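One detail worth noticing: self.labels is only filled inside _create_examples, so get_labels() returns a complete label list only after the training examples have been read. A stripped-down mimic of the processor (no BERT dependency, in-memory rows instead of a TSV file, and sorted() added for a stable label order) illustrates this:

```python
class MiniProcessor:
    """Simplified stand-in for NewsProcessor: labels are collected as a side
    effect of creating the examples."""

    def __init__(self):
        self.labels = set()

    def get_labels(self):
        # sorted() gives a deterministic label-to-id mapping; the original
        # code returns list(self.labels), whose order is unspecified.
        return sorted(self.labels)

    def create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            label, text_a = line[0], line[1]
            self.labels.add(label)
            examples.append((guid, text_a, label))
        return examples

proc = MiniProcessor()
print(proc.get_labels())  # [] -- empty before any data has been read
rows = [["sports", "text one"], ["finance", "text two"], ["sports", "text three"]]
examples = proc.create_examples(rows, "train")
print(proc.get_labels())  # ['finance', 'sports']
```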
InputExample is a class defined in run_classifier.py:

class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs a InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

Since InputExample defines both text_a and text_b, the class supports sentence-pair input. Text classification takes only a single sentence, so text_b can simply be left unset.
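As a quick sanity check, the class can be exercised on its own (constructor reproduced here without the docstring): a single-sentence example simply leaves text_b at its default of None.

```python
class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

# Single-sentence classification: only text_a and label are needed.
ex = InputExample(guid="train-0", text_a="Shares rose on Monday.", label="finance")
print(ex.text_b)  # None
```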
Also, as the custom processor above shows, the training set and the dev set live in separate files. The preprocessed data therefore has to be split into a training set and a dev set in advance, with both files stored in the same directory and named exactly as in the processor's methods (cnews_train.tsv and cnews_test.tsv).
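A minimal way to produce that layout is to shuffle the preprocessed (label, text) rows and write two TSV files with exactly the names the processor expects; the helper function and the 9:1 split ratio below are illustrative choices, not part of the BERT code:

```python
import csv
import os
import random
import tempfile

def split_tsv(rows, data_dir, train_ratio=0.9, seed=42):
    """Shuffle (label, text) rows and write cnews_train.tsv / cnews_test.tsv
    into data_dir, matching the file names used by NewsProcessor."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    cut = int(len(rows) * train_ratio)
    for name, part in [("cnews_train.tsv", rows[:cut]),
                       ("cnews_test.tsv", rows[cut:])]:
        with open(os.path.join(data_dir, name), "w",
                  encoding="utf-8", newline="") as f:
            csv.writer(f, delimiter="\t").writerows(part)

data_dir = tempfile.mkdtemp()
rows = [("sports", "text %d" % i) for i in range(10)]
split_tsv(rows, data_dir)
print(sorted(os.listdir(data_dir)))  # ['cnews_test.tsv', 'cnews_train.tsv']
```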

When building the features, the convert_examples_to_features function from run_classifier.py is used; its code is as follows:

# This function is not used by this file but is still used by the Colab and people who depend on it.
def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer):
  """Convert a set of `InputExample`s to a list of `InputFeatures`."""

  features = []
  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    features.append(feature)
  return features

The examples parameter is the list of examples produced by the data processing above, and max_seq_length is the predefined maximum sequence length. Each example is converted individually by convert_single_example, which handles tokenization, truncation, and padding.
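In simplified form, the per-example work done by convert_single_example amounts to: truncate the token list to fit max_seq_length (reserving two positions for the [CLS] and [SEP] markers), look up token ids, and zero-pad to the fixed length. The toy vocabulary below is made up; the real code uses the WordPiece tokenizer and additionally builds input_mask and segment_ids:

```python
def to_input_ids(tokens, vocab, max_seq_length):
    """Simplified sketch of the per-example conversion in convert_single_example:
    truncate, add [CLS]/[SEP], look up ids, pad with zeros to max_seq_length."""
    tokens = tokens[: max_seq_length - 2]          # leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    input_ids += [0] * (max_seq_length - len(input_ids))  # zero-pad to fixed length
    return input_ids

vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "shares": 3, "rose": 4, "monday": 5}
ids = to_input_ids(["shares", "rose", "monday"], vocab, max_seq_length=8)
print(ids)  # [1, 3, 4, 5, 2, 0, 0, 0]
```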

(To be continued)

References
Text Classification in Practice (10): the BERT Pre-trained Model
BERT Series (1): Running the Demo
