从代码学习深度学习 - RNN PyTorch版

文章目录

  • 前言
  • 一、数据预处理
  • 二、辅助训练工具函数
  • 三、绘图工具函数
  • 四、模型定义
  • 五、模型训练与预测
  • 六、实例化模型并训练
    • 训练结果可视化
  • 总结


前言

循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一个基于 PyTorch 的 RNN 实现,结合《The Time Machine》数据集,带你从零开始理解 RNN 的构建、训练和预测过程。我们将逐步剖析代码,展示如何加载数据、定义工具函数、构建模型、绘制训练过程图表,并最终训练一个字符级别的 RNN 模型。代码中包含了数据预处理、模型定义、梯度裁剪、困惑度计算等关键步骤,适合希望深入理解 RNN 的初学者和进阶者。

本文基于 PyTorch 实现,所有代码均来自附件,并辅以详细注释和图表说明。让我们开始吧!


一、数据预处理

首先,我们需要加载和预处理《The Time Machine》数据集,将其转化为适合 RNN 输入的格式。以下是数据预处理的完整代码:

import random
import re
import torch
from collections import Counter

def read_time_machine():
    """将时间机器数据集加载到文本行的列表中"""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    # 去除非字母字符并将每行转换为小写
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'错误:未知词元类型:{
     token}')

def count_corpus(tokens):
    """统计词元的频率"""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = [''] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def load_corpus_time_machine(max_tokens=-1):
    """返回时光机器数据集的词元索引列表和词表"""
    lines = read_time_machine()
    tokens = tokenize(lines, 'char')
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus

你可能感兴趣的:(深度学习-pytorch版,深度学习,pytorch)