5-classifier_example.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
"""
To learn more or get code samples, please visit my website:
https://mofanpy.com/tutorials/
Or search: 莫烦Python
Thank you for your support!
"""
# Please note: all tutorial code runs under Python 3.5.
# If you use another version, such as Python 2.7, please modify the code accordingly.
# 5 - Classifier example
  10. import numpy as np
  11. np.random.seed(1337) # for reproducibility
  12. from keras.datasets import mnist
  13. from keras.utils import np_utils
  14. from keras.models import Sequential
  15. from keras.layers import Dense, Activation
  16. from keras.optimizers import RMSprop
  17. # download the mnist to the path '~/.keras/datasets/' if it is the first time to be called
  18. # X shape (60,000 28x28), y shape (10,000, )
  19. (X_train, y_train), (X_test, y_test) = mnist.load_data()
  20. # data pre-processing
  21. X_train = X_train.reshape(X_train.shape[0], -1) / 255. # normalize
  22. X_test = X_test.reshape(X_test.shape[0], -1) / 255. # normalize
  23. y_train = np_utils.to_categorical(y_train, num_classes=10)
  24. y_test = np_utils.to_categorical(y_test, num_classes=10)
  25. # Another way to build your neural net
  26. model = Sequential([
  27. Dense(32, input_dim=784),
  28. Activation('relu'),
  29. Dense(10),
  30. Activation('softmax'),
  31. ])
  32. # Another way to define your optimizer
  33. rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
  34. # We add metrics to get more results you want to see
  35. model.compile(optimizer=rmsprop,
  36. loss='categorical_crossentropy',
  37. metrics=['accuracy'])
  38. print('Training ------------')
  39. # Another way to train the model
  40. model.fit(X_train, y_train, epochs=2, batch_size=32)
  41. print('\nTesting ------------')
  42. # Evaluate the model with the metrics we defined earlier
  43. loss, accuracy = model.evaluate(X_test, y_test)
  44. print('test loss: ', loss)
  45. print('test accuracy: ', accuracy)