BERT开源代码分析(一)——dataset数据加载模块

dataset模块

pytorch版本的BERT的源码链接
我将会为大家梳理代码,解读代码。并提出自己的一些见解
基本注释已经穿插在代码块中,另外一些看法单独拿出来说

1.vocab.py

#导包
import pickle
import tqdm
from collections import Counter

class TorchVocab(object):
    def __init__(self,counter,max_sizeNone,min_freq=1,specials=['',''],
                 vectors=None,unk_init=None,vectors_cache=None):
        
        """
        counter:

        类型:collections.Counter对象
        描述:用于存储数据中每个单词出现的频率。这个计数器对象通常是通过在数据集上进行词频统计得到的。
        max_size:

        类型:int或None
        默认值:None
        描述:词汇表的最大大小。如果设置为None,则不限制词汇表的大小。如果设置为一个具体的数字,词汇表的大小将限制为最多包含这么多单词。
        min_freq:

        类型:int
        默认值:1
        描述:包含在词汇表中的最小频率阈值。频率低于这个阈值的单词将不会被包括在词汇表中。小于1的值将被设置为1。
        specials:

        类型:字符串列表
        默认值:['', '']
        描述:一组特殊标记,它们将被添加到词汇表中。这些特殊标记通常包括用于填充()、表示未知单词()等特殊用途的标记。
        vectors:

        类型:预训练向量,可以是字符串列表
        默认值:None
        描述:用于指定预训练的词向量。这可以是预训练向量的名称,或者是用户自定义的预训练向量。
        unk_init:

        类型:回调函数
        默认值:torch.Tensor.zero_
        描述:用于初始化未知单词(OOV,out-of-vocabulary)向量的函数。默认情况下,未知单词的向量被初始化为零向量。
        vectors_cache:

        类型:字符串
        默认值:'.vector_cache'
        描述:用于缓存预训练向量的目录路径。如果提供了向量路径,则预训练向量将被下载或加载到这个目录中。
        """
        
        
        
        self.freqs=counter
        counter=counter.copy()
        min_freq=max(min_freq,1)
        
        self.itos=list(specials)
        
        for tok in specials:
            del counter[tok]
        
        max_size=None if max_size is None else max_size+len(self.itos)
        
        #按频率排序,然后按字母顺序排序
        words_and_frequencies = sorted(counter.items(),key=lambda tup:tup[0])#按字母顺序排序
        words_and_frequencies.sort(key=lambda tup:tup[1],reverse=True)#再按照频率从大到小排序
        
        for word,freq in words_and_frequencies:
            if freq<min_freq or len(self.itos)==max_size:
                break
            self.itos.append(word)
        
        #stoi就是itos的翻转字典
        self.stoi={tok:i for i,tok in enumerate(self.itos)}
        
        self.vectors=None
        if vectors is not None:
            self.load_vectors(vectors,unk_init=unk_init,cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None
    
    def __eq__(self,other):#判断两个字典是否相等
        if self.freqs!=other.freqs:
            return False
        if self.stoi!=other.stoi:
            return False
        if self.itos!=other.itos:
            return False
        if self.vectors!=other.vectors:
            return False
        return True
    
    def __len__(self):
        return len(self.itos)
    
    def vocab_rerank(self):#再次翻转为stoi
        self.stoi={word: i for i,word in enumerate(self.itos)}
    
    def extend(self,v,sort=False):#将另一个字典加入到当前字典中
        words=sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w]=len(self.itos)-1
        
class Vocab(TorchVocab): 
    def __init__(self,counter,max_size=None,min_freq=1):
        self.pad_index=0
        self.unk_index=1
        self.eos_index=2
        self.sos_index=3
        self.mask_index=4
        super().__init__(counter,specials=["","","","",""],max_size=max_size,min_freq=min_freq)
    
    def to_seq(self,sentence,seq_len,with_eos=False,with_sos=False)->list:#字符序列转换为数字序列
        pass
    def from_seq(self,seq,join=False,with_pad=False):#数字序列转换为字符序列
        pass
    
    @staticmethod
    def load_vocab(vocab_path:str)->'Vocab':#静态方法,利用pickle将字符加载出来
        with open(vocab_path,'rb') as f:
            return pickle.load(f)
    def save_vocab(self,vocab_path):#利用pickle保存数据
        with open(vocab_path,'wb' ) as f:
            pickle.dump(self,f)

class WordVocab(Vocab):
    def __init__(self,texts,max_size=None,min_freq=1):
        print("Building Vocab")
        counter=Counter()
        for line in tqdm.tqdm(texts):
            if isinstance(line,list):
                words=line
            else:
                words=line.replace("\n","").replace("\t","").split()
                
            for word in words:
                counter[word]+=1
        super().__init__(counter,max_size=max_size,min_freq=min_freq)
    def to_seq(self,sentence,seq_len=None,with_eos=False,with_sos=False,with_len=False):
        if isinstance(sentence,str):
            sentence=sentence.split()
        if with_eos:#如果with_eos为真,意味着需要在序列末尾添加结束标记
            seq+=[self.eos_index]
        if with_sos:#如果with_sos为真,意味着需要在序列开头添加开始标记。
            seq=[seq.sos_index]+seq
        
        origin_seq_len=len(seq)#记录转换后序列的长度
        
        if seq_len is None:
            pass
        elif len(seq)<=seq_len:
            seq+=[self.pad_index]*(seq_len-len(seq))
        else:
            seq=seq[:seq_len]
        
        return (seq,origin_seq_len) if with_len else seq #如果 with_len 为 True,函数将返回一个元组,其中origin_seq_len为不加其他标识符之前的句子
    
    def from_seq(self,seq,join=False,with_pad=False):
        """
        Args:
        
        seq:一个整数列表,通常表示单词的索引序列。
        join:一个布尔值,指示是否应该将转换后的单词列表拼接为一个字符串。
        with_pad:一个布尔值,指示是否应该在输出中包含填充(PAD)标记。
        """
        
        
        words=[self.itos[idx] 
               if idx<len(self.itos) 
               else "<%d>"% idx
               for idx in seq if not with_pad or idx!=self.pad_index]#当with_pad为True时,所有的索引对应的单词都将加载到words中
        return "".join(words) if join else words
    
    @staticmethod
    def load_vocab(vocab_path:str)->'WordVocab':#,用于从指定路径加载一个之前保存的WordVocab对象。
        with open(vocab_path,"r") as f:
            return pickle.load(f)

def build():
    import argparse

    parser = argparse.ArgumentParser()
    
    """
    -c 或 --corpus_path:指定语料库文件的路径。
    -o 或 --output_path:指定输出词汇表文件的路径。
    -s 或 --vocab_size:指定词汇表的最大大小。
    -e 或 --encoding:指定文件的编码格式,默认为utf-8。
    -m 或 --min_freq:指定单词在语料库中出现的最小频率。
    """
    
    parser.add_argument("-c", "--corpus_path", required=True, type=str)
    parser.add_argument("-o", "--output_path", required=True, type=str)
    parser.add_argument("-s", "--vocab_size", type=int, default=None)
    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
    parser.add_argument("-m", "--min_freq", type=int, default=1)
    args = parser.parse_args()

    with open(args.corpus_path, "r", encoding=args.encoding) as f:
        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) #读取语料库并构建词汇表

    print("VOCAB SIZE:", len(vocab)) #打印词汇表的大小(即其中包含的单词数量)
    vocab.save_vocab(args.output_path) #调用vocab.save_vocab 方法将词汇表保存到指定的输出文件路径            

可以看到vocab.py包含三个类,这三个类从上往下依次继承

  • TorchVocab:
    这个类是基础类,定义了词汇表的核心功能,如建立词汇表、索引到字符串(itos)和字符串到索引(stoi)的映射,处理预训练向量等。
    它包含了初始化方法(init),用于从词频计数器创建词汇表,以及一些基本方法,如比较两个词汇表是否相等(eq)、获取词汇表长度(len)、重新排列词汇表(vocab_rerank)和扩展当前词汇表(extend)。

  • Vocab:
    Vocab类继承自TorchVocab。它在TorchVocab的基础上增加了特定的特殊标记,如填充pad、未知单词unk、句末eos、句首sos和掩码mask。
    Vocab类提供了额外的方法,如to_seq和from_seq(这两个方法在提供的代码中没有具体实现,而是放在之后的WordVocab中实现了),以及用于保存和加载词汇表的方法。

  • WordVocab:
    WordVocab类继承自Vocab。它专门用于从文本数据构建词汇表。这个类的构造函数接受文本数据,计算单词频率,并使用这些信息来构建词汇表。
    它重写了to_seq和from_seq方法,这些方法用于将句子转换为索引序列,以及将索引序列转换回文本。

这三个类构成了一个层次结构,其中每个子类在其父类的基础上增加了更具体的功能。TorchVocab提供了基本的词汇表功能,Vocab在此基础上增加了对特殊标记的处理,而WordVocab则是专门针对从文本数据构建和使用词汇表的场景。这种设计使得代码更加模块化和可重用,同时也提供了灵活性,以便于根据特定需求进行扩展或修改。

2.dataset.py

from torch import Dataset
import tqdm
import torch
import random
class BERTDataset(Dataset):
    def __init__(self,corpus_path,vocab,seq_len,encoding="utf-8",corpus_lines=None,on_memory=True):
        self.vocab=vocab
        self.seq_len=seq_len
        
        self.on_memory=on_memory
        self.corpus_lines=corpus_lines
        self.corpus_path=corpus_path
        self.encoding=encoding
        
        with open(corpus_path,"r",encoding=encoding) as f:
            if self.corpus_lines is None and not on_memory:
                for _ in tqdm.tqdm(f,desc="Loading Dataset",total=corpus_lines):
                    self.corpus_lines+=1
                
                if on_memory:
                    self.lines=[line[:-1].split("\t") for line in tqdm.tqdm(f,desc="Loading Dataset",total=self.corpus_lines)]
                    self.corpus_lines=len(self.lines)
        if not on_memory:
            self.file=open(corpus_path,"r",encoding=encoding)
            self.random_file=open(corpus_path,"r",encoding=encoding)
            
            for _ in range(random.randint(self.corpus_lines if self.corpus_lines<1000 else 1000)):
                self.random_file.__next__()

    #返回语料的长度
    def __len__(self):
        return self.corpus_lines
    
    
    def __getitem__(self,item):
        t1,t2,is_next_label=self.random_sent(item)
        t1_random,t1_label=self.random_word(t1)
        t2_random,t2_label=self.random_word(t2)
        
        t1=[self.vocab.sos_index]+t1_random+[self.vocab.eos_index]#sos是文本开始标签 eos是文本结束也是分割标签
        t2=t2_random+[self.vocab.eos_index]
        
        #构建掩码标签
        t1_label=[self.vocab.pad_index]+t1_label+[self.vocab.pad_index]
        t2_label=t2_label+[self.vocab.pad_index]
        
        #生成段落标签
        segment_label=([1 for _ in range(len(t1))]+[2 for _ in range (len(t2))])[:self.seq_len]
        bert_input=(t1+t2)[:self.seq_len]
        bert_label=(t1_label+t2_label)[:self.seq_len]
        
        padding=[self.vocab.pad_index for _ in range(self.seq_len-len(bert_input))]
        bert_input.extend(padding),bert_label.extend(padding),segment_label.extend(padding)
        
        output={"bert_input":bert_input,"bert_label":bert_label,"segment_label":segment_label}
        
        return {key:torch.tensor(value) for key ,value in output.items()}
    
        
    
    def random_word(self,sentence):
        tokens=sentence.split()#将每个单词分词
        output_label=[]
        
        for i,token in enumerate(tokens):
            prob=random.random()
            if prob<0.15:
                prob/=0.15 #相对概率
                
                
                #80% randomly change token to mask token
                if prob<0.8:
                    tokens[i]=self.vocab.mask_index
                
                # 10% randomly change token to mask token
                elif prob<0.9:
                    tokens[i]=random.randrange(len(self.vocab))#返回一个随机的单词在词典中的索引
                
                # 10% randomly change token to current token
                else:
                    tokens[i]=self.vocab.stoi.get(token,self.vocab.unk_index)#返回本单词在词典中的索引
                
                output_label.append(self.vocab.stoi.get())#单词被替换,那么output_label需要记录ground_truth,即记录被替换的单词在词典中的索引
            else:
                tokens[i]=self.vocab.stoi.get(token,self.vocab.unk_index)
                output_label.append(0)
        
        return tokens,output_label
            
           
        
    """
    内存中读取 (self.on_memory 为 True):

    当 self.on_memory 为 True 时,函数直接从 self.lines 列表中返回索引为 item 的行,其中 self.lines[item][0] 和 self.lines[item][1] 分别代表该行的两个文本片段。
    文件中读取 (self.on_memory 为 False):

    如果 self.on_memory 为 False,函数尝试从打开的文件对象 self.file 中读取下一行。
    如果读取的行是 None(意味着到达了文件的末尾),它会关闭文件,重新打开文件,然后再次读取下一行。
    读取的行假定为一个用制表符(\t)分隔的两部分文本,这两部分文本通过 split("\t") 方法分割,并赋值给 t1 和 t2。
    """
    #"I like to eat pizza."\t"It is delicious."
    #         t1                        t2

        
    def get_corpus_line(self,item):
        if self.on_memory:
            return self.lines[item][0],self.lines[item][1]
        else:
            line=self.file.__next__()
            if line is None:
                self.file.close()
                self.file=open(self.corpus_path,"r",encoding=self.encoding)
                line=self.file.__next__()
                
            t1,t2=line[:-1].split("\t")
            return t1,t2
    
    #它的作用是从文本数据中随机选择并返回一行。它处理两种情况:数据全部加载到内存中,或者从文件中逐行读取
    def get_random_line(self):
        if self.on_memory:
            return self.line[random.randrange(len(self.lines))][1]
        line=self.file.__next__()
        if line is None:
            self.file.closee()
            self.file=open(self.corpus_path,"r",encoding=self.encoding)
            for _ in range(random.randint(self.corpus_lines if self.corpus_lines <1000 else 1000)):
                self.random_file.__next__()
            line=self.random_file.__next__()
        return line[:-1].split("\t")[1]
    
    
    #创建一个训练样本,其中包括从数据集中选定的文本行(可能是一个句子或者段落),以及一个标签,标识这个文本行是否是与另一个文本行“下一个”相关联(即是否是连续的文本行)
    def random_sent(self,index):
        t1,t2=self.get_corpus_line(index)
        
        if(random.random()>0.5):
            return t1,t2,1
        else:
            return t1,self.get_random_line(),0        

BERTDataset 类的核心功能是为训练 BERT 模型准备和处理数据。它处理文本数据,生成适合 BERT 训练的输入,包括掩码语言模型的输入和下一个句子预测任务的输入。通过随机替换、生成掩码和标签,这个类为 BERT 模型的预训练过程提供了必要的数据处理功能。

3. init.py

from .dataset import BERTDataset
from .vocab import WordVocab

这段代码就是象征性地引入两个上面提到的关键模块。

写在最后

了解一个模型,不能只了解他的核心架构,如果你有代码实现的需求,就必须把这些数据处理的部分掌握好,这些是模型运行的基础,更是成功的根基。

你可能感兴趣的:(机器学习—NLP领域,bert,人工智能,深度学习,python,nlp)