While reading the Keras documentation recently, I came across an interesting application of RNNs: using an RNN to learn addition. After reading it I implemented it myself; the original code is here: https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py
The task is to fit integer addition with an RNN, where both addends lie in the interval [0, 1000).
Let's start with the data. Since both addends lie in [0, 1000), an addition expression has a maximum length of 3 + 1 + 3 = 7 characters, and the only characters that can appear are {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' '}. An expression can therefore be represented as a 7×12 boolean matrix: one row per character position, one column per symbol. Taking 223+11 as an example (padded with a trailing space to length 7), its boolean matrix is:
[[False False True False False False False False False False False False]
[False False True False False False False False False False False False]
[False False False True False False False False False False False False]
[False False False False False False False False False False True False]
[False True False False False False False False False False False False]
[False True False False False False False False False False False False]
[False False False False False False False False False False False True]]
Similarly, the result of the addition can be represented as a 4×12 boolean matrix (the largest possible sum, 999 + 999 = 1998, has 4 digits). The result 234 from the example above is:
[[False False True False False False False False False False False False]
[False False False True False False False False False False False False]
[False False False False True False False False False False False False]
[False False False False False False False False False False False True]]
With the expressions and results encoded, the data can be fed to an RNN for training. I followed the network structure from https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py: an encoder LSTM reads the expression and compresses it into a single vector, a RepeatVector feeds that vector to the decoder once per output character, and a decoder LSTM followed by a TimeDistributed Dense layer with softmax produces one symbol distribution per output position.
The experiment uses 50,000 randomly generated samples: 35,000 for training and 15,000 for testing.
#encoding: utf-8
import numpy
from keras.models import Sequential
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector
from keras.layers.recurrent import LSTM

DATA_SIZE = 50000
NUMBER_LEN = 3
HIDDEN_SIZE = 128
BATCH_SIZE = 128
EXPRESSION_LEN = NUMBER_LEN + 1 + NUMBER_LEN

symbol_table = '0123456789+ '
index_2_symbol = {}
symbol_2_index = {}
for i, ch in enumerate(symbol_table):
    index_2_symbol[i] = ch
    symbol_2_index[ch] = i

def encode(s):
    # One-hot encode a string into a len(s) x len(symbol_table) boolean matrix
    feat_mat = numpy.zeros((len(s), len(symbol_table)), dtype=numpy.bool)
    for i, ch in enumerate(s):
        feat_mat[i, symbol_2_index[ch]] = 1
    return feat_mat

def load_data(split):
    print('Generating data...')
    used = set()
    exp_list = []
    ans_list = []
    while len(exp_list) < DATA_SIZE:
        # Random integer with 1 to NUMBER_LEN digits
        f = lambda: int(''.join(numpy.random.choice(list('0123456789'))
                                for i in range(numpy.random.randint(1, NUMBER_LEN + 1))))
        a, b = f(), f()
        key = (a, b)
        if key in used:  # skip duplicate addend pairs
            continue
        used.add(key)
        exp = '{}+{}'.format(a, b)
        exp += ' ' * (EXPRESSION_LEN - len(exp))  # pad the expression to length 7
        ans = str(a + b)
        ans += ' ' * (NUMBER_LEN + 1 - len(ans))  # pad the answer to length 4
        exp_list.append(exp)
        ans_list.append(ans)
    data = numpy.zeros((len(exp_list), EXPRESSION_LEN, len(symbol_table)), dtype=numpy.bool)
    label = numpy.zeros((len(ans_list), NUMBER_LEN + 1, len(symbol_table)), dtype=numpy.bool)
    for i, exp in enumerate(exp_list):
        data[i] = encode(exp)
    for i, ans in enumerate(ans_list):
        label[i] = encode(ans)
    # Shuffle, then split into training and test sets
    dividing_line = int(split * DATA_SIZE)
    indices = numpy.random.permutation(DATA_SIZE)
    train_idx, test_idx = indices[:dividing_line], indices[dividing_line:]
    train_data, test_data = data[train_idx], data[test_idx]
    train_label, test_label = label[train_idx], label[test_idx]
    return (train_data, train_label), (test_data, test_label)

def deal():
    (train_data, train_label), (test_data, test_label) = load_data(0.7)
    print('Build model...')
    model = Sequential()
    # Encoder: read the expression and compress it into a single hidden vector
    model.add(LSTM(HIDDEN_SIZE, input_shape=(EXPRESSION_LEN, len(symbol_table))))
    # Repeat that vector once per output character for the decoder
    model.add(RepeatVector(NUMBER_LEN + 1))
    # Decoder: emit a symbol distribution at each output position
    model.add(LSTM(HIDDEN_SIZE, return_sequences=True))
    model.add(TimeDistributed(Dense(len(symbol_table))))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    for i in range(10, 100, 10):
        print('iter: ', i)
        model.fit(train_data, train_label, batch_size=BATCH_SIZE, nb_epoch=10,
                  validation_split=0.1, verbose=0)
        score = model.evaluate(test_data, test_label, batch_size=BATCH_SIZE, verbose=0)
        print('acc: ', score[1])

if __name__ == "__main__":
    deal()
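After training, the model's softmax output still has to be turned back into a string. A sketch of a `decode` helper that inverts the one-hot encoding (my own addition, not part of the original example):

```python
import numpy as np

symbol_table = '0123456789+ '

def decode(mat):
    # Pick the most likely symbol at each position, then drop the space padding
    return ''.join(symbol_table[i] for i in mat.argmax(axis=-1)).strip()

# Usage against a trained model would look like (hypothetical):
#   pred = model.predict(test_data[:1])  # shape (1, 4, 12)
#   print(decode(pred[0]))

# It also works on the boolean label matrices directly:
onehot = np.eye(12, dtype=bool)[[2, 3, 4, 11]]  # encodes '234 '
print(decode(onehot))  # 234
```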
The accuracy on the 15,000-sample test set, evaluated after every 10 epochs:
iter: 10
acc: 0.717766666667
iter: 20
acc: 0.913150000032
iter: 30
acc: 0.951416666667
iter: 40
acc: 0.968233333302
iter: 50
acc: 0.978
iter: 60
acc: 0.983066666635
iter: 70
acc: 0.9844
iter: 80
acc: 0.985700000032
iter: 90
acc: 0.981200000032
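Note that the `accuracy` Keras reports here is computed per output character, so a prediction that gets 3 of the 4 characters right still scores 0.75. A stricter whole-expression accuracy can be computed from the model's predictions along these lines (`pred_probs` would come from `model.predict(test_data)`; the function name is my own):

```python
import numpy as np

def sequence_accuracy(pred_probs, label):
    # An expression counts as correct only if every output position matches
    pred = pred_probs.argmax(axis=-1)
    true = label.argmax(axis=-1)
    return float((pred == true).all(axis=-1).mean())
```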
Corrections are welcome if you spot any mistakes.