训练代码简介

本代码基于TensorFlow 1.1 的mnist with summaries例子实现https://github.com/tensorflow/tensorflow/blob/v1.1.0/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py,并做了部分改动。

引入 UAI Train相关参数

mnist_summary.py#L34 import uaitrain官方flags

from uaitrain.arch.tensorflow import uflag

指定输入数据地址

mnist_summary.py#L44 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.data_dir参数作为输入数据集

  \# Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

定义算法网络和summary

mnist_summary.py#L51 ~ #L143

指定tensorboard summary路径

mnist_summary.py#L144,145 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.log_dir参数作为tensorboard summary输出的路径

  merged = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
  test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
  tf.global_variables_initializer().run()

执行训练

mnist_summary.py#L146-178

保存训练模型

mnist_summary.py#L180 使用uaitrain.arch.tensorflow.uflag中的定义的FLAGS.output_dir参数作为训练输出的路径

  save_path = saver.save(sess, FLAGS.output_dir + "/model.ckpt")
  print("Model saved in file: %s" % save_path)

完整代码请参见https://github.com/ucloud/uai-sdk/blob/master/examples/tensorflow/train/mnist_summary_1.1/code/mnist_summary.py