# 9-Autoencoder_example.py

"""
To learn more or get code samples, please visit my website:
https://mofanpy.com/tutorials/
Or search: 莫烦Python
Thank you for your support!
"""
# Please note: all tutorial code runs under Python 3.5.
# If you use another version, such as Python 2.7, please modify the code accordingly.
# 9 - Autoencoder example
# To try the TensorFlow backend, un-comment the following two lines:
# import os
# os.environ['KERAS_BACKEND']='tensorflow'
  13. import numpy as np
  14. np.random.seed(1337) # for reproducibility
  15. from keras.datasets import mnist
  16. from keras.models import Model
  17. from keras.layers import Dense, Input
  18. import matplotlib.pyplot as plt
  19. # download the mnist to the path '~/.keras/datasets/' if it is the first time to be called
  20. # X shape (60,000 28x28), y shape (10,000, )
  21. (x_train, _), (x_test, y_test) = mnist.load_data()
  22. # data pre-processing
  23. x_train = x_train.astype('float32') / 255. - 0.5 # minmax_normalized
  24. x_test = x_test.astype('float32') / 255. - 0.5 # minmax_normalized
  25. x_train = x_train.reshape((x_train.shape[0], -1))
  26. x_test = x_test.reshape((x_test.shape[0], -1))
  27. print(x_train.shape)
  28. print(x_test.shape)
  29. # in order to plot in a 2D figure
  30. encoding_dim = 2
  31. # this is our input placeholder
  32. input_img = Input(shape=(784,))
  33. # encoder layers
  34. encoded = Dense(128, activation='relu')(input_img)
  35. encoded = Dense(64, activation='relu')(encoded)
  36. encoded = Dense(10, activation='relu')(encoded)
  37. encoder_output = Dense(encoding_dim)(encoded)
  38. # decoder layers
  39. decoded = Dense(10, activation='relu')(encoder_output)
  40. decoded = Dense(64, activation='relu')(decoded)
  41. decoded = Dense(128, activation='relu')(decoded)
  42. decoded = Dense(784, activation='tanh')(decoded)
  43. # construct the autoencoder model
  44. autoencoder = Model(input=input_img, output=decoded)
  45. # construct the encoder model for plotting
  46. encoder = Model(input=input_img, output=encoder_output)
  47. # compile autoencoder
  48. autoencoder.compile(optimizer='adam', loss='mse')
  49. # training
  50. autoencoder.fit(x_train, x_train,
  51. epochs=20,
  52. batch_size=256,
  53. shuffle=True)
  54. # plotting
  55. encoded_imgs = encoder.predict(x_test)
  56. plt.scatter(encoded_imgs[:, 0], encoded_imgs[:, 1], c=y_test)
  57. plt.colorbar()
  58. plt.show()