Before applying the BERT model, you need to download the open-source code from GitHub (it can be cloned directly). The repository contains a run_classifier.py file; for a text-classification project this is the file to modify, by adding your own data-preprocessing class. The DataProcessor class in run_classifier.py looks like this:
class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines
The class defines four methods: get_train_examples, get_dev_examples, get_test_examples, and get_labels. As the method names suggest, each one simply raises NotImplementedError and has to be overridden in a task-specific subclass.
The _read_tsv method, by contrast, is already implemented: it reads the rows of a tab-separated (.tsv) file.
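To see what _read_tsv produces, here is a minimal stand-in (read_tsv is a hypothetical name used only in this sketch) with the same csv settings, but taking any file-like object so it runs without TensorFlow:

```python
import csv
import io

def read_tsv(file_obj, quotechar=None):
    # Same csv settings as DataProcessor._read_tsv: tab-delimited rows,
    # each returned as a list of column strings. The original opens the
    # path with tf.gfile.Open; a file-like object is used here instead.
    reader = csv.reader(file_obj, delimiter="\t", quotechar=quotechar)
    return [line for line in reader]

# Each row of a cnews-style file is "label<TAB>text".
sample = io.StringIO("sports\tThe team won the final.\n"
                     "tech\tA new chip was released.\n")
rows = read_tsv(sample)
```

Each element of rows is a list like ['sports', 'The team won the final.'], i.e. column 0 is the label and column 1 is the text.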
Following the pattern of the processors that ship with BERT, such as XnliProcessor, a DataProcessor subclass that reads our own training data can be written like this:
class NewsProcessor(bert.run_classifier.DataProcessor):
  """Processor for the cnews data set."""

  def __init__(self):
    self.labels = set()

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "cnews_train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "cnews_test.tsv")), "dev")

  def get_labels(self):
    """See base class."""
    return list(self.labels)

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[0])
      self.labels.add(label)
      examples.append(
          bert.run_classifier.InputExample(
              guid=guid, text_a=text_a, text_b=None, label=label))
    return examples
The _create_examples method walks through every row of the data set, taking the label from line[0] and the text from line[1], adding each label it sees to self.labels, and then wrapping each row in an InputExample.
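The logic of _create_examples can be sketched without the BERT repo by using a plain stand-in for InputExample (both create_examples and the stand-in class below are illustrative, not code from the repository):

```python
# Minimal stand-in for bert.run_classifier.InputExample.
class InputExample:
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

def create_examples(lines, set_type, labels):
    examples = []
    for i, line in enumerate(lines):
        guid = "%s-%s" % (set_type, i)    # e.g. "train-0", "train-1", ...
        label, text_a = line[0], line[1]  # column 0: label, column 1: text
        labels.add(label)                 # collect every label seen in the data
        examples.append(InputExample(guid=guid, text_a=text_a, label=label))
    return examples

labels = set()
lines = [["sports", "The team won the final."],
         ["tech", "A new chip was released."]]
examples = create_examples(lines, "train", labels)
```

Because get_labels simply returns list(self.labels), the label list is derived from whatever labels actually occur in the data rather than being hard-coded.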
InputExample is a class defined in run_classifier.py:
class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs a InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label
InputExample defines both text_a and text_b, which shows it supports sentence-pair inputs. Text classification has only a single input sentence, so text_b can simply be left as None.
Note also, from the custom processor above, that the training and validation sets are read from separate files. The preprocessed data therefore has to be split into a training set and a validation set in advance, with both files placed in the same directory and named exactly as in the processor's methods.
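As a sketch of that preprocessing step (split_and_save, the 10% dev ratio, and the scratch directory are all assumptions of this example, not part of the BERT code), the rows can be shuffled and written out under the file names NewsProcessor expects:

```python
import csv
import os
import random
import tempfile

def split_and_save(rows, data_dir, dev_ratio=0.1, seed=42):
    # Shuffle [label, text] rows and write cnews_train.tsv / cnews_test.tsv
    # into data_dir, matching the file names used by NewsProcessor.
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    rows = rows[:]
    rng.shuffle(rows)
    n_dev = int(len(rows) * dev_ratio)
    splits = {"cnews_test.tsv": rows[:n_dev],
              "cnews_train.tsv": rows[n_dev:]}
    os.makedirs(data_dir, exist_ok=True)
    for name, part in splits.items():
        with open(os.path.join(data_dir, name), "w",
                  newline="", encoding="utf-8") as f:
            csv.writer(f, delimiter="\t").writerows(part)

# Example: split ten dummy rows into a scratch directory.
rows = [["sports", "story %d" % i] for i in range(10)]
data_dir = tempfile.mkdtemp()
split_and_save(rows, data_dir)
```

With dev_ratio=0.1, ten rows yield nine training rows and one validation row, written in the label-tab-text layout that _read_tsv expects.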
Features are obtained with the convert_examples_to_features function from run_classifier, whose code is:
# This function is not used by this file but is still used by the Colab and
# people who depend on it.
def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer):
  """Convert a set of `InputExample`s to a list of `InputFeatures`."""

  features = []
  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)
    features.append(feature)
  return features
Here examples is the list of examples produced by the processor above, label_list is the label list returned by get_labels, max_seq_length is the maximum sentence length chosen in advance, and tokenizer is the tokenizer built from the BERT vocabulary.
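One practical way to pick max_seq_length is to look at the token-length distribution of the training texts and keep a high percentile, so that only a small fraction of examples get truncated. The sketch below (suggest_max_seq_length is a hypothetical helper, not part of the BERT code) counts whitespace tokens as a rough proxy; BERT's WordPiece tokenizer will generally produce somewhat more tokens per sentence:

```python
def suggest_max_seq_length(texts, percentile=0.95):
    # Sort whitespace-token counts and take the value at the given percentile.
    lengths = sorted(len(t.split()) for t in texts)
    idx = min(int(len(lengths) * percentile), len(lengths) - 1)
    # Add 2 to leave room for the [CLS] and [SEP] tokens BERT inserts.
    return lengths[idx] + 2

texts = ["short text",
         "a somewhat longer piece of text here",
         "mid length text now",
         "tiny"]
suggested = suggest_max_seq_length(texts)
```

Whatever value is chosen must not exceed the 512-token limit of the pretrained BERT models.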
(To be continued)