tf19_saver.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # View more python tutorials on my Youtube and Youku channel!!!
  2. # Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
  3. # Youku video tutorial: http://i.youku.com/pythontutorial
  4. """
  5. Please note: this code is written for Python 3. If you are using Python 2, please modify the code accordingly.
  6. """
  7. from __future__ import print_function
  8. import tensorflow as tf
  9. import numpy as np
  10. # Save to file
  11. # remember to define the same dtype and shape when restore
  12. # W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
  13. # b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
  14. # tf.initialize_all_variables() is no longer valid from
  15. # 2017-03-02 if using tensorflow >= 0.12
  16. # if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
  17. #     init = tf.initialize_all_variables()
  18. # else:
  19. #     init = tf.global_variables_initializer()
  20. #
  21. # saver = tf.train.Saver()
  22. #
  23. # with tf.Session() as sess:
  24. #     sess.run(init)
  25. #     save_path = saver.save(sess, "my_net/save_net.ckpt")
  26. #     print("Save to path: ", save_path)
  27. ################################################
  28. # restore variables
  29. # redefine the same shape and same type for your variables
  30. W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
  31. b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
  32. # not need init step
  33. saver = tf.train.Saver()
  34. with tf.Session() as sess:
  35. saver.restore(sess, "my_net/save_net.ckpt")
  36. print("weights:", sess.run(W))
  37. print("biases:", sess.run(b))