2025年3月31日 星期一 乙巳(蛇)年 正月初一 设为首页 加入收藏
rss
您当前的位置:首页 > 计算机 > 编程开发 > Python

tensorflow存取,读取,及保存的文件的含义

时间:01-30来源:作者:点击数:46

1. 保存恢复命令

1) class tf.train.Saver

保存和恢复变量的类。

2) tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True)

保存变量。默认保存最近5步的所有变量。每次保存的文件格式为:model_epoch1.ckpt.data-00000-of-00001,model_epoch1.ckpt.index,model_epoch1.ckpt.meta。不过我不太明白每个里面存储的什么东西。

如果存储指定的变量,例子为:

  • # Create some variables.
  • v1 = tf.Variable(..., name="v1")
  • v2 = tf.Variable(..., name="v2")
  • ...
  • # Add ops to save and restore only 'v2' using the name "my_v2"
  • saver = tf.train.Saver({"my_v2": v2})
  • # Use the saver object normally after that.
  • ...
3) tf.train.Saver.restore(sess, save_path)
恢复之前保存的变量。saver还有其他一些函数。

2. tensorflow保存网络结构及恢复

保存:

v1=tf.Variable(...,name="v1")

v2=tf.Variable(...,name="v2”)

之后,创建一个saver对象,来进行保存,同时不要忘记设定保存的路径。

  • saver = tf.train.Saver()
  • save_path = saver.save(sess, "./MNISTmodel/model.ckpt")
  • print ("Model saved in file: ", save_path)

模型保存好之后,在需要再次使用这个模型时,同样需要再创建一个saver对象。不要忘记,要将模型中之前保存好的变量名称再赋给需要载入的模型,即

  • v1 = tf.Variable(..., name="v1")
  • v2 = tf.Variable(..., name=“v2”)

不过此时不需要对这些变量进行初始化了。

  • saver = tf.train.Saver()
  • ......
  • with tf.Session() as sess:
  • # Restore variables from disk.
  • saver.restore(sess, "./MNISTmodel/model.ckpt")
  • print "Model restored."

这样就可以直接恢复之前训练好的模型了。

方便获取更多学习、工作、生活信息请关注本站微信公众号城东书院 微信服务号城东书院 微信订阅号
推荐内容
相关内容
栏目更新
栏目热门