模型准备

重训练基于已经训练好的物体识别模型。具体理论和分析参阅Tensorflow Retraining,本例中我们使用以上链接中Tensorflow提供的物体识别模型。

UAI-Train训练平台无互联网权限,因此模型需保存并上传至UFile。我们使用Tensorflow_hub模块自带的下载组件缓存模型(Tensorflow Retraining: Models)。Tensorflow_hub模组中的代码在加载在线获取的模型时会将模型缓存在本地默认目录下,我们可以修改此目录以将模型保存在我们想要的路径下。

  • 在本地主机安装Tensorflow_hub:
sudo pip install tensorflow_hub
  • 转至图片保存的根目录(请参阅前一节查看保存图片的根目录,例如"/data/pets"):
cd /data/pets
  • 将默认模型缓存目录转至当前目录(注意,此设置仅在本对话中有效。若重新连接本地主机以下载模型,则需重新设置此路径):
export TFHUB_CACHE_DIR=/data/checkpoint_dir
  • 创建文件,编写简单脚本缓存模型(此处链接下载的为mobilenetv1100_224模型,你可以参考上方链接获取其他模型):
vim cache_module.py
输入:
import tensorflow_hub as hub
m = hub.Module("https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1")

等待脚本运行完成。此时模型已经保存在checkpoint_dir路径下的一个子路径中,子路径的名字为模型数据的哈希值(此处为mobilenet_v1_1.0_224模型:cc14ad57953629a2bbc0ffe52de5afb5518150b2),不同的模型此路径的名称不同。方便起见,我们将此路径下的所有文件和文件夹放入checkpoint_dir路径中。此时准备好的图片和模型路径应如:

|_ data
|  |_ pets
|  |  |_ Abyssinian
|  |  |  |_ Abyssinian_01.jpg
|  |  |  |_ Abyssinian_02.jpg
|  |  |_ Persian
|  |  |  |_ Persian_01.jpg
|  |  |  |_ Persian_02.jpg
|  |_ checkpoint_dir
|  |  |_ assets
|  |  |_ variables
|  |  |_ saved_model.pb
|  |  |_ tfhub_module.pb