- 本文为365天深度学习训练营 中的学习记录博客
- 原作者:K同学啊
这是一个使用PyTorch通过Transformer算法实现简单的文本分类实战案例。
import torch, torchvision
print(torch.__version__)
print(torchvision.__version__)
2.3.0.post100
0.18.1
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
warnings.filterwarnings("ignore") ## 忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type=‘cpu’)
import pandas as pd
train_data = pd.read_csv('train.csv', sep='\t', header=None)
train_data.head()
0 | 1 | |
---|---|---|
0 | 还有双鸭山到淮阴的汽车票吗13号的 | Travel-Query |
1 | 从这里怎么回家 | Travel-Query |
2 | 随便播放一首专辑阁楼里的佛里的歌 | Music-Play |
3 | 给看一下墓王之王嘛 | FilmTele-Play |
4 | 我想看挑战两把s686打突变团竞的游戏视频 | Video-Play |
## 构造数据集迭代器
def coustom_data_iter(texts, labels):
for x, y in zip(texts, labels):
yield x, y
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba
## 中文分词法
tokenizer = jieba.lcut
def yield_tokens(data_iter):
for text,_ in data_iter:
yield tokenizer(text)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["" ])
vocab.set_default_index(vocab["" ]) ## 设置为默认索引,如果找不到单词,则会选择默认索引
Building prefix dict from the default dictionary …
Loading model from cache /var/folders/8g/07dt5gr9275f63l59fkg9_n00000gn/T/jieba.cache
Loading model cost 0.196 seconds.
Prefix dict has been built successfully.
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
label_name = list(set(train_data[1].values[:]))
print(label_name)
[‘Music-Play’, ‘Travel-Query’, ‘Audio-Play’, ‘Radio-Listen’, ‘Video-Play’, ‘Alarm-Update’, ‘Calendar-Query’, ‘FilmTele-Play’, ‘Other’, ‘Weather-Query’, ‘HomeAppliance-Control’, ‘TVProgram-Play’]
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)
print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
4
from torch.utils.data import DataLoader
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_text, _label) in batch:
## 标签数据
label_list.append(label_pipeline(_label))
## 文本列表
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
## 偏移量,即语句的总词汇量
offsets.append(processed_text.size(0))
label_list = torch.tensor(label_list, dtype=torch.int64)
text_list = torch.cat(text_list)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) ## 返回维度dim中输入元素的累积和
return text_list.to(device), label_list.to(device), offsets.to(device)
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
BATCH_SIZE = 4
train_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
import math,os, torch
class PositionalEncoding(nn.Module):
def __init__(self, embed_dim, max_len=500):
super(PositionalEncoding, self).__init__()
## 创建一个大小为[max_len, embed_dim]的零张量
pe = torch.zeros(max_len, embed_dim)
## 创建一个形状为[max_len, 1]的位置张量
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(100.0) / embed_dim))
pe[:, 0::2] = torch.sin(position * div_term) ## 计算PE(pos, 2i)
pe[:, 1::2] = torch.cos(position * div_term) ## 计算PE(pos, 2i+1)
pe = pe.unsqueeze(0).transpose(0,1)
## 将位置编码张量注册为模型的缓冲区,参数不参与梯度下降,保存model的时候会将其保存下来
self.register_buffer('pe', pe)
def forward(self, x):
## 将位置编码添加到输入张量中,注意位置编码的形状
x = x + self.pe[:x.size(0)]
return x
from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class, nhead=8, d_hid=256, nlayers=12, dropout=0.1):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, ## 词典大小
embed_dim, ## 词嵌入的维度
sparse=False)
self.pos_encoder = PositionalEncoding(embed_dim)
## 定义编码器层
encoder_layers = TransformerEncoderLayer(embed_dim, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.embed_dim = embed_dim
self.linear = nn.Linear(embed_dim*4, num_class)
def forward(self, src, offsets, src_mask=None):
src = self.embedding(src, offsets)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = output.view(4, embed_dim*4)
output = self.linear(output)
return output
vocab_size = len(vocab) ## 词汇表的大小
embed_dim = 64 ## 嵌入维度
num_class = len(label_name)
## 创建Transformer模型,并将其移动到设备上
model = TransformerModel(vocab_size,
embed_dim,
num_class).to(device)
import time
def train(dataloader):
total_acc, train_loss, total_count = 0, 0, 0
log_interval = 300
start_time = time.time()
for idx, (text,label,offsets) in enumerate(dataloader):
predicted_label = model(text, offsets)
optimizer.zero_grad() ## grad属性归零
loss = criterion(predicted_label, label) ## 计算网络输出和真实值之间的差距,label为真实值
loss.backward() ## 反向传播
optimizer.step() ## 每一步自动更新
## 纪录acc与loss
total_acc += (predicted_label.argmax(1) == label).sum().item()
train_loss += loss.item()
total_count += label.size(0)
if idx % log_interval == 0 and idx > 0:
elapsed = time.time() - start_time
print('| epoch {:1d} | {:4d}/{:4d} batches'
'| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
total_acc/total_count, train_loss/total_count))
total_acc, train_loss, total_count = 0, 0, 0
start_time = time.time()
def evaluate(dataloader):
model.eval() ## 切换为测试模式
total_acc, train_loss, total_count = 0, 0, 0
with torch.no_grad():
for idx, (text,label,offsets) in enumerate(dataloader):
predicted_label = model(text, offsets)
loss = criterion(predicted_label, label) ## 计算loss值
## 纪录测试数据
total_acc += (predicted_label.argmax(1) == label).sum().item()
train_loss += loss.item()
total_count += label.size(0)
return total_acc/total_count, train_loss/total_count
## 超参数
EPOCHS = 10
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) ## 使用Adam优化器,并将学习率设计为1e-4
for epoch in range(1, EPOCHS + 1):
epoch_start_time = time.time()
train(train_dataloader)
val_acc, val_loss = evaluate(valid_dataloader)
## 获取当前的学习率
lr = optimizer.state_dict()['param_groups'][0]['lr']
print('-' * 69)
print('| epoch {:1d} valid_loss{:4.3f} | lr{:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
print('-' * 69)
| epoch 1 | 300/2420 batches| train_acc 0.096 train_loss 0.60036
| epoch 1 | 600/2420 batches| train_acc 0.102 train_loss 0.59268
| epoch 1 | 900/2420 batches| train_acc 0.126 train_loss 0.58958
| epoch 1 | 1200/2420 batches| train_acc 0.105 train_loss 0.59722
| epoch 1 | 1500/2420 batches| train_acc 0.122 train_loss 0.59231
| epoch 1 | 1800/2420 batches| train_acc 0.107 train_loss 0.58772
| epoch 1 | 2100/2420 batches| train_acc 0.103 train_loss 0.59647
| epoch 1 | 2400/2420 batches| train_acc 0.137 train_loss 0.58352
---------------------------------------------------------------------
| epoch 1 valid_loss23.072 | lr0.153719
---------------------------------------------------------------------
| epoch 2 | 300/2420 batches| train_acc 0.149 train_loss 0.57707
| epoch 2 | 600/2420 batches| train_acc 0.168 train_loss 0.57212
| epoch 2 | 900/2420 batches| train_acc 0.168 train_loss 0.57754
| epoch 2 | 1200/2420 batches| train_acc 0.153 train_loss 0.58067
| epoch 2 | 1500/2420 batches| train_acc 0.171 train_loss 0.57044
| epoch 2 | 1800/2420 batches| train_acc 0.191 train_loss 0.56242
| epoch 2 | 2100/2420 batches| train_acc 0.182 train_loss 0.56749
| epoch 2 | 2400/2420 batches| train_acc 0.207 train_loss 0.55959
---------------------------------------------------------------------
| epoch 2 valid_loss28.620 | lr0.190909
---------------------------------------------------------------------
| epoch 3 | 300/2420 batches| train_acc 0.190 train_loss 0.56846
| epoch 3 | 600/2420 batches| train_acc 0.219 train_loss 0.55344
| epoch 3 | 900/2420 batches| train_acc 0.202 train_loss 0.55777
| epoch 3 | 1200/2420 batches| train_acc 0.195 train_loss 0.55114
| epoch 3 | 1500/2420 batches| train_acc 0.209 train_loss 0.55462
| epoch 3 | 1800/2420 batches| train_acc 0.212 train_loss 0.55450
| epoch 3 | 2100/2420 batches| train_acc 0.228 train_loss 0.54016
| epoch 3 | 2400/2420 batches| train_acc 0.240 train_loss 0.54188
---------------------------------------------------------------------
| epoch 3 valid_loss28.475 | lr0.210331
---------------------------------------------------------------------
| epoch 4 | 300/2420 batches| train_acc 0.218 train_loss 0.53826
| epoch 4 | 600/2420 batches| train_acc 0.257 train_loss 0.52673
| epoch 4 | 900/2420 batches| train_acc 0.238 train_loss 0.54038
| epoch 4 | 1200/2420 batches| train_acc 0.242 train_loss 0.53388
| epoch 4 | 1500/2420 batches| train_acc 0.242 train_loss 0.53499
| epoch 4 | 1800/2420 batches| train_acc 0.244 train_loss 0.53181
| epoch 4 | 2100/2420 batches| train_acc 0.239 train_loss 0.53427
| epoch 4 | 2400/2420 batches| train_acc 0.255 train_loss 0.53116
---------------------------------------------------------------------
| epoch 4 valid_loss28.343 | lr0.257851
---------------------------------------------------------------------
| epoch 5 | 300/2420 batches| train_acc 0.272 train_loss 0.51579
| epoch 5 | 600/2420 batches| train_acc 0.273 train_loss 0.51938
| epoch 5 | 900/2420 batches| train_acc 0.300 train_loss 0.51277
| epoch 5 | 1200/2420 batches| train_acc 0.312 train_loss 0.50920
| epoch 5 | 1500/2420 batches| train_acc 0.318 train_loss 0.49847
| epoch 5 | 1800/2420 batches| train_acc 0.320 train_loss 0.49063
| epoch 5 | 2100/2420 batches| train_acc 0.321 train_loss 0.49132
| epoch 5 | 2400/2420 batches| train_acc 0.337 train_loss 0.48170
---------------------------------------------------------------------
| epoch 5 valid_loss28.563 | lr0.331818
---------------------------------------------------------------------
| epoch 6 | 300/2420 batches| train_acc 0.355 train_loss 0.47457
| epoch 6 | 600/2420 batches| train_acc 0.369 train_loss 0.47280
| epoch 6 | 900/2420 batches| train_acc 0.398 train_loss 0.44898
| epoch 6 | 1200/2420 batches| train_acc 0.370 train_loss 0.45624
| epoch 6 | 1500/2420 batches| train_acc 0.418 train_loss 0.43283
| epoch 6 | 1800/2420 batches| train_acc 0.379 train_loss 0.45049
| epoch 6 | 2100/2420 batches| train_acc 0.401 train_loss 0.43735
| epoch 6 | 2400/2420 batches| train_acc 0.428 train_loss 0.42844
---------------------------------------------------------------------
| epoch 6 valid_loss28.588 | lr0.426860
---------------------------------------------------------------------
| epoch 7 | 300/2420 batches| train_acc 0.453 train_loss 0.42141
| epoch 7 | 600/2420 batches| train_acc 0.468 train_loss 0.39567
| epoch 7 | 900/2420 batches| train_acc 0.480 train_loss 0.39902
| epoch 7 | 1200/2420 batches| train_acc 0.496 train_loss 0.39537
| epoch 7 | 1500/2420 batches| train_acc 0.516 train_loss 0.38615
| epoch 7 | 1800/2420 batches| train_acc 0.498 train_loss 0.39730
| epoch 7 | 2100/2420 batches| train_acc 0.533 train_loss 0.37394
| epoch 7 | 2400/2420 batches| train_acc 0.547 train_loss 0.37372
---------------------------------------------------------------------
| epoch 7 valid_loss28.788 | lr0.552479
---------------------------------------------------------------------
| epoch 8 | 300/2420 batches| train_acc 0.553 train_loss 0.35435
| epoch 8 | 600/2420 batches| train_acc 0.609 train_loss 0.33628
| epoch 8 | 900/2420 batches| train_acc 0.590 train_loss 0.34807
| epoch 8 | 1200/2420 batches| train_acc 0.627 train_loss 0.32164
| epoch 8 | 1500/2420 batches| train_acc 0.627 train_loss 0.32383
| epoch 8 | 1800/2420 batches| train_acc 0.625 train_loss 0.32545
| epoch 8 | 2100/2420 batches| train_acc 0.645 train_loss 0.31090
| epoch 8 | 2400/2420 batches| train_acc 0.645 train_loss 0.31012
---------------------------------------------------------------------
| epoch 8 valid_loss28.297 | lr0.623554
---------------------------------------------------------------------
| epoch 9 | 300/2420 batches| train_acc 0.683 train_loss 0.29333
| epoch 9 | 600/2420 batches| train_acc 0.682 train_loss 0.28600
| epoch 9 | 900/2420 batches| train_acc 0.683 train_loss 0.28948
| epoch 9 | 1200/2420 batches| train_acc 0.684 train_loss 0.27552
| epoch 9 | 1500/2420 batches| train_acc 0.678 train_loss 0.28707
| epoch 9 | 1800/2420 batches| train_acc 0.720 train_loss 0.25827
| epoch 9 | 2100/2420 batches| train_acc 0.688 train_loss 0.28523
| epoch 9 | 2400/2420 batches| train_acc 0.739 train_loss 0.24568
---------------------------------------------------------------------
| epoch 9 valid_loss28.608 | lr0.675620
---------------------------------------------------------------------
| epoch 10 | 300/2420 batches| train_acc 0.731 train_loss 0.25230
| epoch 10 | 600/2420 batches| train_acc 0.735 train_loss 0.24387
| epoch 10 | 900/2420 batches| train_acc 0.734 train_loss 0.24514
| epoch 10 | 1200/2420 batches| train_acc 0.739 train_loss 0.23629
| epoch 10 | 1500/2420 batches| train_acc 0.732 train_loss 0.24510
| epoch 10 | 1800/2420 batches| train_acc 0.733 train_loss 0.24110
| epoch 10 | 2100/2420 batches| train_acc 0.724 train_loss 0.25269
| epoch 10 | 2400/2420 batches| train_acc 0.756 train_loss 0.22635
---------------------------------------------------------------------
| epoch 10 valid_loss28.491 | lr0.714876
---------------------------------------------------------------------
test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
模型准确率为:0.7045
本周学习了如何使用Transformer完成文本分类任务,更加深入地了解了Transformer模型。开始模型在测试集上的准确率只有50%,后期通过调整优化器(SGD–>Adam)和学习率(1e-2 --> 1e-4)使得测试集准确率达到70%。