bert4torch(参考bert4keras的pytorch实现)

背景

本人经常会阅读苏神的科学空间网站,里面有很多对前言paper浅显易懂的解释,以及很多苏神自己的创新实践;并且基于bert4keras框架都有了相应的代码实现。但是由于本人主要用pytorch开发,因此参考bert4keras开发了bert4torch项目,实现了bert4keras的主要功能。

简介

bert4torch是一个基于pytorch的训练框架,前期以效仿和实现bert4keras的主要功能为主,方便加载多类预训练模型进行finetune,提供了中文注释方便用户理解模型结构。主要是期望应对新项目时,可以直接调用不同的预训练模型直接finetune,或方便用户基于bert进行魔改,快速验证自己的idea;节省在github上clone各种项目耗时耗力,且本地文件各种copy的问题。

  • pip安装
pip install bert4torch
  • github链接

主要功能

1、加载预训练权重(bert、roberta、albert、nezha、bart、RoFormer、ELECTRA、GPT、GPT2、T5)继续进行finetune

bert4torch(参考bert4keras的pytorch实现)_第1张图片

目前支持的预训练模型一览

2、在bert基础上灵活定义自己模型:主要是可以接在bert的[btz, seq_len, hdsz]的隐含层向量后做各种魔改

3、调用方式和bert4keras基本一致,简洁高效

    model.fit(
        train_dataloader,
        steps_per_epoch=1000,
        epochs=epochs,
        callbacks=[evaluator]
    )

4、实现基于keras的训练进度条动态展示

bert4torch(参考bert4keras的pytorch实现)_第2张图片

仿照keras的模型训练进度条

5、配合torchinfo,实现打印各层参数量功能

bert4torch(参考bert4keras的pytorch实现)_第3张图片

打印参数

6、结合logger,或者tensorboard可以在后台打印日志

支持在训练开始/结束,batch开始/结束,epoch的开始/结束,记录日志,写tensorboard等

class Callback(object):
    '''Callback基类
    '''
    def __init__(self):
        pass
    def on_train_begin(self, logs=None):
        pass
    def on_train_end(self, logs=None):
        pass
    def on_epoch_begin(self, global_step, epoch, logs=None):
        pass
    def on_epoch_end(self, global_step, epoch, logs=None):
        pass
    def on_batch_begin(self, global_step, batch, logs=None):
        pass
    def on_batch_end(self, global_step, batch, logs=None):
        pass

7、集成多个example,可以作为自己的训练框架,方便在同一个数据集上尝试多种解决方案

bert4torch(参考bert4keras的pytorch实现)_第4张图片

实现多个example可供参考

未来计划

  • Transformer-XL、XLnet等其他网络架构
  • 前沿的各类模型idea实现,如苏神科学空间网站的诸多idea

你可能感兴趣的:(nlp,深度学习,python,人工智能,bert,nlp)