# theano6_shared_variable.py
  1. # View more python tutorials on my Youtube and Youku channel!!!
  2. # Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
  3. # Youku video tutorial: http://i.youku.com/pythontutorial
  4. # 6 - shared variables
  5. """
  6. Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly.
  7. """
  8. from __future__ import print_function
  9. import numpy as np
  10. import theano
  11. import theano.tensor as T
  12. state = theano.shared(np.array(0,dtype=np.float64), 'state') # inital state = 0
  13. inc = T.scalar('inc', dtype=state.dtype)
  14. accumulator = theano.function([inc], state, updates=[(state, state+inc)])
  15. # to get variable value
  16. print(state.get_value())
  17. accumulator(1) # return previous value, 0 in here
  18. print(state.get_value())
  19. accumulator(10) # return previous value, 1 in here
  20. print(state.get_value())
  21. # to set variable value
  22. state.set_value(-1)
  23. accumulator(3)
  24. print(state.get_value())
  25. # temporarily replace shared variable with another value in another function
  26. tmp_func = state * 2 + inc
  27. a = T.scalar(dtype=state.dtype)
  28. skip_shared = theano.function([inc, a], tmp_func, givens=[(state, a)]) # temporarily use a's value for the state
  29. print(skip_shared(2, 3))
  30. print(state.get_value()) # old state value