人名分类器(RNN案例)

案例介绍:

  • 人名分类案例是多分类问题,根据人名预测属于哪个国家
  • 人名->x,国家->y
  • 监督学习,历史数据中已知y

 案例步骤:

1.数据预处理

 获取常用字符以及国家类别

# 导入torch工具
from cProfile import label
import torch
# 导入nn准备构建模型
import torch.nn as nn
import torch.optim as optim
from jax.experimental.rnn import rnn_bwd_p
from sympy.physics.vector.printing import params
# 导入torch的数据源 数据迭代器工具包
from torch.utils.data import DataLoader, Dataset
# 用于获得常见字母及字符规范化
import string
# 导入时间工具包
import time
# 引入制图工具包
import matplotlib.pyplot as plt

# todo:1-获取字符数量,26个大小写字母和常用标点符号,57个字符数量->one-hot编码维度数对应
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
print('all_letters(字符)-->', all_letters)
print('n_letters(字符长度)-->', n_letters)

# todo:2-获取国家和数量   国家名->预测结果   数量->输出层的输出特征数
categories = ['Italian', 'English', 'Arabic', 'Spanish', 'Scottish', 'Irish', 'Chinese', 'Vietnamese', 'Japanese',
              'French', 'Greek', 'Dutch', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Czech', 'German']
n_categories = len(categories)
print('categories(国家)-->', categories)
print('n_categories(国家长度)-->', n_categories)

加载数据获取x和y列表

def read_data(filepath):
    # todo:1-定义两个空列表,用来分别存储x和y数据
    x_list, y_list = [], []
    # todo:2-将数据文件读取到内存中
    with open(filepath, 'r', encoding='utf8') as f:
        # todo:3-循环遍历文本中的每行数据
        for line in f.readlines():
            # todo:4-删除长度小于等于5的行数据:异常样本
            if len(line) <= 5:
                continue  # 跳出当前循环,进入下一步循环
            # todo:5-先将\n删除,然后再根据\t分割符对每行数据进行分割,x和y
            x, y = line.strip().split('\t')
            # print('x-->', x)
            # print('y-->', y)
            # todo:6-将每行数据的x和y分别保存到列表中
            x_list.append(x)
            y_list.append(y)
        # print('x_list的样本数-->', x_list)
        # print('y_list的样本数-->', y_list)
    return x_list, y_list

构建张量数据源 

# todo:1-创建张量数据集类
class NameClassDataset(Dataset):
    # todo:2-定义构造方法 __init__方法
    def __init__(self, x_list, y_list):
        # 定义三个属性
        self.x_list = x_list
        self.y_list = y_list
        # 计算样本长度
        self.sample_len = len(x_list)

    # todo:3-定义__len__方法,调用len()函数的时候,返回魔法方法的值
    def __len__(self):
        return self.sample_len

    # todo:4-定义__getitem__方法,调用索引的时候,返回魔法方法的值
    def __getitem__(self, index):
        # 修正索引,解决越界问题,不报错
        index = min(max(index, 0), self.sample_len - 1)
        # print(f'index(索引)-->', index)
        # 获取index下对应的x和y的数据
        x = self.x_list[index]
        y = self.y_list[index]
        # print(f'x(x数据)-->', x)
        # print(f'y(y数据)-->', y)
        # 对x进行one-hot编码
        # 每个tensor_x都是一个二位向量,第一个维度:人名的字符数量(词数量),第二个维度:词向量维度->词汇表的长度
        tensor_x = torch.zeros(size=(len(x), n_letters))
        # print(f'tensor_x: {tensor_x}, tensor_x.shape: {tensor_x.shape}')
        for i, letter in enumerate(x):  # enumerate(x)->{0:A, 1:b, 2:l}
            # tensor_x对应的索引的位置上设置为1
            # [i]->获取二维全0张量的第i个样本(字符), 返回一维张量
            # all_letters.find(letter)->根据字符获取字符串中的对应位置下标
            # [i][all_letters.find(letter)] -> 获取一维张量中字符所在位置, 将位置的0值修改为1
            tensor_x[i][all_letters.find(letter)] = 1
        # categories.index(y): 根据y元素获取categories列表中的位置下标
        tensor_y = torch.tensor(data=categories.index(y), dtype=torch.long)
        # print(f'tensor_x(one-hot)--->: {tensor_x}, tensor_x.shape: {tensor_x.shape}')
        # print(f'tensor_y(one-hot): {tensor_y}, tensor_y.shape: {tensor_y.shape}')
        return tensor_x, tensor_y

2.构建RNN模型 

传统RNN模型 

class RNN(nn.Module):
    # todo:1-定义构造方法,定义模型参数
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        # 定义属性
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        # 定义rnn层
        self.rnn = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers)
        # 定义输出层
        # 线性计算
        # in_features:rnn隐藏层维度数 128
        # out_feature:国家类别个数 18
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=self.output_size)
        # logsoftmax计算
        # Log

你可能感兴趣的:(rnn,深度学习,人工智能)