# flyai文本分类 (FlyAI text classification)

# -*- coding: utf-8 -*-
import os
import argparse
import tensorflow as tf
import keras
from keras.layers import *
from keras.optimizers import RMSprop
from sklearn.model_selection import train_test_split
from flyai.framework import FlyAI
from flyai.data_helper import DataHelper
from path import MODEL_PATH, DATA_PATH
import pandas as pd
import numpy as np
from data_helper import load_dict, load_labeldict, get_batches, read_data, get_val_batch


'''
此项目为FlyAI2.0新版本框架,数据读取,评估方式与之前不同
2.0框架不再限制数据如何读取
样例代码仅供参考学习,可以自己修改实现逻辑。
模版项目下载支持 PyTorch、Tensorflow、Keras、MXNET、scikit-learn等机器学习框架
第一次使用请看项目中的:FlyAI2.0竞赛框架使用说明.html
使用FlyAI提供的预训练模型可查看:https://www.flyai.com/models
学习资料可查看文档中心:https://doc.flyai.com/
常见问题:https://doc.flyai.com/question.html
遇到问题不要着急,添加小姐姐微信,扫描项目里面的:FlyAI小助手二维码-小姐姐在线解答您的问题.png
'''

# Project hyper-parameters; delete them if unused.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=32, type=int, help="batch size")
# parse_known_args instead of parse_args: this runs at import time, and
# parse_args would abort the whole process on any argument it does not know
# (e.g. a test runner's own command-line flags in sys.argv). Unknown arguments
# are reported instead of being dropped silently, so typos remain visible.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print('Ignoring unrecognized arguments:', _unknown_args)


class Main(FlyAI):
    '''
    Training entry point for the MedicalClass text-classification task.

    The project class must inherit from FlyAI; otherwise the online run fails.
    '''
    def download_data(self):
        '''Download the "MedicalClass" dataset through FlyAI's DataHelper.'''
        data_helper = DataHelper()
        data_helper.download_from_ids("MedicalClass")
        print('=*=数据下载完成=*=')

    def deal_with_data(self):
        '''
        Load the training CSV, split off a validation set and encode both splits.

        Sets on ``self``:
          - ``data``: the full DataFrame read from ``MedicalClass/train.csv``;
          - ``train_data`` / ``valid_data``: a 99% / 1% split of ``data``;
          - ``text2id`` / ``label2id``: vocabularies returned by ``load_dict``
            and ``load_labeldict``;
          - ``train_text`` / ``train_label`` and ``val_text`` / ``val_label``:
            the encoded splits returned by ``read_data``.
        :return: None
        '''
        # Load the raw training data.
        self.data = pd.read_csv(os.path.join(DATA_PATH, 'MedicalClass/train.csv'))
        # Hold out 1% of the rows for validation; fixed seed for reproducibility.
        self.train_data, self.valid_data = train_test_split(self.data, test_size=0.01, random_state=6, shuffle=True)
        self.text2id, _ = load_dict(os.path.join(DATA_PATH, 'MedicalClass/words_fr.dict'))
        self.label2id, _ = load_labeldict(os.path.join(DATA_PATH, 'MedicalClass/label.dict'))
        self.train_text, self.train_label = read_data(self.train_data, self.text2id, self.label2id)
        self.val_text, self.val_label = read_data(self.valid_data, self.text2id, self.label2id)
        print('=*=数据处理完成=*=')

    def _build_model(self):
        '''
        Build and compile the text classifier.

        Architecture: Embedding -> BatchNormalization -> two stacked
        bidirectional GRUs -> Dropout -> concatenated global average / max
        pooling -> softmax over ``len(self.label2id)`` classes. Compiled with
        RMSprop and categorical cross-entropy, tracking accuracy as 'acc'.

        :return: the compiled keras Model (its summary is printed).
        '''
        rnn_unit_1 = 128    # number of units in each GRU layer (per direction)
        embed_dim = 64      # embedding vector size
        class_num = len(self.label2id)
        num_word = len(self.text2id)
        # Maximum sentence length; presumably the length get_batches /
        # get_val_batch pad to -- TODO confirm against data_helper.
        MAX_SQUES_LEN = 68

        text_input = Input(shape=(MAX_SQUES_LEN,), dtype='int32')
        embedden_seq = Embedding(input_dim=num_word, output_dim=embed_dim, input_length=MAX_SQUES_LEN)(text_input)
        BN1 = BatchNormalization()(embedden_seq)
        bGRU1 = Bidirectional(GRU(rnn_unit_1, activation='selu', return_sequences=True,
                                  implementation=1), merge_mode='concat')(BN1)
        bGRU2 = Bidirectional(GRU(rnn_unit_1, activation='selu', return_sequences=True,
                                  implementation=1), merge_mode='concat')(bGRU1)

        drop = Dropout(0.5)(bGRU2)
        avgP = GlobalAveragePooling1D()(drop)
        maxP = GlobalMaxPooling1D()(drop)
        # Combine mean-pooled and max-pooled sequence features.
        conc = concatenate([avgP, maxP])

        pred = Dense(class_num, activation='softmax')(conc)
        k_model = keras.Model(text_input, pred)
        k_model.summary()
        k_model.compile(optimizer=RMSprop(lr=0.0005), loss='categorical_crossentropy', metrics=['acc'])
        return k_model

    def _evaluate_and_checkpoint(self, k_model, best_score, pad_id):
        '''
        Score ``k_model`` on a validation batch and save it if accuracy improved.

        :param k_model: the compiled model being trained.
        :param best_score: best validation accuracy seen so far.
        :param pad_id: vocabulary id used to pad validation sequences.
        :return: the (possibly updated) best validation accuracy.
        '''
        # Built fresh for every evaluation, as before; whether get_val_batch
        # samples randomly or is deterministic is up to data_helper.
        x_val, y_val = get_val_batch(self.val_text, self.val_label,
                                     batch_size=1024, text_padding=pad_id)
        score = k_model.evaluate(np.array(x_val), np.array(y_val), batch_size=1024)
        # evaluate returns [loss, acc] because the model is compiled with metrics=['acc'].
        acc = score[1]
        if acc > best_score:
            best_score = acc
            k_model.save(os.path.join(MODEL_PATH, 'model.h5'))
        print('best acc:', best_score)
        return best_score

    def train(self):
        '''
        Build the classifier and train it one batch at a time.

        The model is evaluated on a validation batch every 100 batches
        (batch index % 100 == 0, so including each epoch's first batch) and
        at the end of every epoch unless its last batch was just evaluated.
        Whenever validation accuracy improves, the model is saved to
        ``MODEL_PATH/model.h5``.

        Requires ``deal_with_data`` to have been called first.
        :return: None
        '''
        batch_size = args.BATCH
        k_model = self._build_model()
        pad_id = self.text2id['_pad_']

        # Batch total used only for the progress display (floor division).
        batch_nums = int(len(self.train_data) / batch_size)
        # Ensure the checkpoint directory exists before the first save.
        os.makedirs(MODEL_PATH, exist_ok=True)

        best_score = 0
        for epoch in range(args.EPOCHS):
            CurEpoch = str(epoch + 1) + "/" + str(args.EPOCHS)
            batch_i = None  # stays None if get_batches yields nothing
            batches = get_batches(self.train_text, self.train_label,
                                  batch_size=batch_size, text_padding=pad_id)
            for batch_i, (x_train, y_train) in enumerate(batches):
                history = k_model.fit(np.array(x_train), np.array(y_train), batch_size=batch_size, verbose=0)
                CurBatch = str(batch_i + 1) + "/" + str(batch_nums)
                print('CurEpoch: {} | CurBatch: {}| ACC: {}'.format(CurEpoch, CurBatch, history.history['acc'][0]))

                if batch_i % 100 == 0:
                    best_score = self._evaluate_and_checkpoint(k_model, best_score, pad_id)

            # Also evaluate the weights reached at the end of the epoch;
            # otherwise up to 99 trailing updates (including the very last of
            # training) could never be picked as the best checkpoint.
            if batch_i is not None and batch_i % 100 != 0:
                best_score = self._evaluate_and_checkpoint(k_model, best_score, pad_id)


if __name__ == '__main__':
    # Run the full pipeline: download -> preprocess -> train.
    main = Main()
    main.download_data()
    main.deal_with_data()
    main.train()

    # Raise SystemExit directly instead of calling the site-provided exit()
    # helper, which is intended for the interactive shell and is absent
    # when Python runs with -S.
    raise SystemExit(0)

 

# 你可能感兴趣的:(flyai文本分类)