训练代码简介
本代码基于TensorFlow 1.1 的mnist with summaries例子实现https://github.com/tensorflow/tensorflow/blob/v1.1.0/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py,并做了部分改动。
引入 UAI Train相关参数
mnist_summary.py#L34 import uaitrain官方flags
from uaitrain.arch.tensorflow import uflag
指定输入数据地址
mnist_summary.py#L44 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.data_dir参数作为输入数据集
\# Import data mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
定义算法网络和summary
mnist_summary.py#L51 ~ #L143
指定tensorboard summary路径
mnist_summary.py#L144,145 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.log_dir参数作为tensorboard summary输出的路径
merged = tf.summary.merge_all() train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph) test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test') tf.global_variables_initializer().run()
执行训练
mnist_summary.py#L146-178
保存训练模型
mnist_summary.py#L180 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.output_dir参数作为训练输出的路径
save_path = saver.save(sess, FLAGS.output_dir + "/model.ckpt") print("Model saved in file: %s" % save_path)