# The code is as follows:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Training hyper-parameters.
batch_size = 100       # examples requested per mini-batch
n_times = 15           # number of passes of the outer training loop
learning_tare = 0.03   # passed as the optimizer's learning_rate (spelling kept: the name is referenced elsewhere)

# Fix the graph-level random seed so runs are repeatable.
tf.set_random_seed(1)

# Load MNIST from the `MNIST_data` directory with one-hot encoded labels.
mnist = input_data.read_data_sets(r'MNIST_data', one_hot=True)
class Model:
    """Convolutional digit classifier built on the TensorFlow 1.x graph API.

    Architecture: three (3x3 conv -> ReLU -> 2x2 max-pool -> dropout) stages
    with 32, 64 and 128 filters, followed by a 50-unit ReLU layer and a
    10-way softmax output.

    Attributes created by ``build_net``:
        X: float32 placeholder of shape [None, 784] (reshaped to 28x28x1).
        Y: float32 placeholder of shape [None, 10] holding the label rows.
        keep_prob: float32 placeholder with the dropout keep probability.
        l5: softmax class probabilities.
        cost: mean softmax cross-entropy between ``Y`` and the predictions.
        train: gradient-descent op that minimises ``cost``.
        acc: fraction of rows whose argmax of ``l5`` equals argmax of ``Y``.
    """

    def __init__(self, sess):
        """Keep a reference to the session and build the graph.

        Args:
            sess: session object whose ``run`` is used by ``trains`` and
                ``get_acc``.
        """
        self.sess = sess
        self.build_net()

    @staticmethod
    def _conv_block(inputs, weights, keep_prob):
        """Apply stride-1 SAME conv -> ReLU -> 2x2/stride-2 max-pool -> dropout."""
        # Full 4-element strides: accepted by every TF 1.x release and
        # consistent with the max_pool arguments below.
        out = tf.nn.conv2d(inputs, weights, strides=[1, 1, 1, 1], padding='SAME')
        out = tf.nn.relu(out)
        out = tf.nn.max_pool(out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
        return tf.nn.dropout(out, keep_prob)

    def build_net(self):
        """Create the placeholders, layers, loss, train op and accuracy op."""
        self.X = tf.placeholder(tf.float32, shape=[None, 784])
        self.Y = tf.placeholder(tf.float32, shape=[None, 10])
        self.keep_prob = tf.placeholder(tf.float32)

        # Flat 784-vectors -> 28x28 images with a single channel.
        X_img = tf.reshape(self.X, shape=[-1, 28, 28, 1])

        # Convolution kernels are created before any layer op, in the same
        # order as before, so seeded initial values are not shifted.
        W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
        W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
        W3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))

        # Each block halves the spatial size (rounded up): 28 -> 14 -> 7 -> 4.
        l1 = self._conv_block(X_img, W1, self.keep_prob)
        l2 = self._conv_block(l1, W2, self.keep_prob)
        l3 = self._conv_block(l2, W3, self.keep_prob)
        l3_flaten = tf.reshape(l3, shape=[-1, 4 * 4 * 128])

        # NOTE(review): these variables use the fixed names 'W4'/'W5';
        # presumably a second Model in the same graph would clash on them.
        W4 = tf.get_variable('W4', shape=[4 * 4 * 128, 50],
                             initializer=tf.contrib.layers.xavier_initializer())
        W5 = tf.get_variable('W5', shape=[50, 10],
                             initializer=tf.contrib.layers.xavier_initializer())
        b4 = tf.Variable(tf.random_normal([50]))
        b5 = tf.Variable(tf.random_normal([10]))

        l4 = tf.nn.relu(tf.matmul(l3_flaten, W4) + b4)
        l4 = tf.nn.dropout(l4, self.keep_prob)

        logits = tf.matmul(l4, W5) + b5
        self.l5 = tf.nn.softmax(logits)

        # Cross-entropy is computed from the logits rather than as
        # -sum(Y * log(softmax)): the explicit log yields NaN/inf as soon as
        # a softmax output underflows to 0.
        self.cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.Y, logits=logits))

        # `learning_tare` is the module-level learning-rate constant.
        self.train = tf.train.GradientDescentOptimizer(
            learning_rate=learning_tare).minimize(self.cost)

        correct = tf.equal(tf.argmax(self.Y, 1), tf.argmax(self.l5, 1))
        self.acc = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

    def trains(self, x_data, y_data, keep_prob=0.7):
        """Run one optimisation step on a batch.

        Args:
            x_data: value fed to the ``X`` placeholder.
            y_data: value fed to the ``Y`` placeholder.
            keep_prob: dropout keep probability used for this step.

        Returns:
            The result of running ``[cost, train]``; the first element is
            the batch cost.
        """
        feed = {self.X: x_data, self.Y: y_data, self.keep_prob: keep_prob}
        return self.sess.run([self.cost, self.train], feed_dict=feed)

    def get_acc(self, x_data, y_data, keep_prob=1):
        """Evaluate accuracy on the given data.

        Args:
            x_data: value fed to the ``X`` placeholder.
            y_data: value fed to the ``Y`` placeholder.
            keep_prob: dropout keep probability; the default of 1 keeps
                every unit.

        Returns:
            The result of running the ``acc`` op.
        """
        feed = {self.X: x_data, self.Y: y_data, self.keep_prob: keep_prob}
        return self.sess.run(self.acc, feed_dict=feed)
# Build the model in a fresh session and initialise every variable.
sess = tf.Session()
m1 = Model(sess)
sess.run(tf.global_variables_initializer())

# Mini-batches per epoch; it does not change between epochs, so it is
# computed once up front.
n_batch = mnist.train.num_examples // batch_size

for i in range(n_times):
    avg_cost = 0
    for n in range(n_batch):
        x_train, y_train = mnist.train.next_batch(batch_size)
        c, _ = m1.trains(x_train, y_train)
        # Accumulate each batch's share so avg_cost ends up as the mean
        # batch cost of the epoch (plain assignment kept only the last
        # batch's contribution).
        avg_cost += c / n_batch
    print('批次', i + 1)
    print('代价', avg_cost)

# Test-set accuracy, evaluated once after training has finished.
print('准确率', m1.get_acc(mnist.test.images, mnist.test.labels))