TFRecord实践

在工作中,需要在训练模型的过程中,读入大规模稀疏矩阵,因此考虑用tfrecord进行加载

1.生TFRecord

import tensorflow as tf
import numpy as np
"""
txt文件中保存的是矩阵每一行的行坐标,列坐标,以及元素值
数据格式为:‘行坐标’ + ‘[对应所有列坐标]’ + ‘[对应所有元素值]’
"""
def write_TFRecord(srcpath, dstpath):  
    writer = tf.python_io.TFRecordWriter(dstpath)
    f = open(srcpath)
    line = f.readline()
    while line:
        line_ = line.strip().split('\t')
        cols = eval(line[1])
        vals = eval(line[2])
        rows = [int(line[0])]
        features = tf.train.Features(
            feature={'rows': tf.train.Feature(int64_list=tf.train.Int64List(value=rows)),
                    'photos': tf.train.Feature(int64_list=tf.train.Int64List(value=cols)),
                    'vals': tf.train.Feature(float_list=tf.train.FloatList(value=vals))
                    })
            
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())
    writer.close()          
    f.close()

2. 利用tf.data.TFRecordDataset接口进行解析

2.1 将每一行的值解析为稠密张量

def parser(example):
    dicts = {
        'rows': tf.FixedLenFeature(shape=[],dtype=tf.int64),
        'cols': tf.VarLenFeature(dtype=tf.int64), #由于cols为变长,需要使用 tf.VarLenFeature
        'vals': tf.VarLenFeature(dtype=tf.float32)
    }
    parsed_example = tf.parse_single_example(example, dicts)
    rows = parsed_example['rows']
    cols = parsed_example['cols']
    vals = parsed_example['vals']
    return rows, tf.sparse_tensor_to_dense(rows), tf.sparse_tensor_to_dense(vals)
# 采用这种方式,返回的是稀疏张量,需要用tf.sparse_tensor_to_dense转化为稠密张量

def get_batch_dataset(recordfile, parser):
    dataset = tf.data.TFRecordDataset(recordfile).map(parser).padded_batch(2, padded_shapes=([],[None],[None]))
# 由于row_index跟vals均不为定长,无法进行batch,所以需要对其进行填充,将短的张量用0填充,直到其长度与batch中最长的张量相等
    return dataset

dataset = get_batch_dataset('tfrecord',  parser)

2.2 直接读取为稀疏张量

def parser1(example):
    my_example_features = {'sparse': tf.SparseFeature(index_key=['rows', 'cols'],
                                                  value_key='vals',
                                                  dtype=tf.float32,
                                                  size=[1,max_col])} #size[0]表示一行,size[1]表示稀疏矩阵的列数
    parsed_example = tf.parse_single_example(example, my_example_features)
    return parsed_example['sparse']

def get_batch_dataset(recordfile, parser):                            
    dataset = tf.data.TFRecordDataset(recordfile).map(parser).repeat(2).batch(20000)
    return dataset

dataset = get_batch_dataset('tfrecord',  parser)

3. Iterator的使用

你可能感兴趣的:(TFRecord实践)