Mirror of https://github.com/tengge1/ShadowEditor.git
Synced 2026-01-25 15:08:11 +00:00
54 lines · 1.5 KiB · Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
||
|
||
# https://tensorflow.google.cn/beta/tutorials/keras/basic_classification
|
||
|
||
# TensorFlow and tf.keras
|
||
import tensorflow as tf
|
||
from tensorflow import keras
|
||
|
||
# Helper libraries
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Let the GPU allocate memory on demand instead of reserving it all up front;
# without this the script errors out on Windows.
#
# Apply the setting to every detected GPU rather than indexing [0]:
#   * on CPU-only machines the device list is empty, and [0] would raise
#     IndexError before any training could happen;
#   * TensorFlow requires the memory-growth setting to agree across all
#     visible GPUs, so configuring only the first one fails on multi-GPU hosts.
# This must run before any GPU has been initialized (i.e. before the first
# TensorFlow op), otherwise set_memory_growth raises RuntimeError.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
|
||
|
||
# Fetch Fashion-MNIST; load_data() yields a (train, test) pair, each of which
# is an (images, labels) pair.
fashion_mnist = keras.datasets.fashion_mnist
train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Human-readable names indexed by integer label: class_names[label] is the
# caption for an image whose label is `label`.
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

# Rescale pixel values (expected to be 0-255 integers) into the [0, 1] range
# before feeding them to the network.
PIXEL_MAX = 255.0
train_images, test_images = train_images / PIXEL_MAX, test_images / PIXEL_MAX
|
||
|
||
# Show the first 25 training images in a 5x5 grid, each captioned with its
# class name, as a quick sanity check of the data and labels.
n_rows = n_cols = 5
plt.figure(figsize=(10, 10))
for k in range(n_rows * n_cols):
    plt.subplot(n_rows, n_cols, k + 1)  # subplot positions are 1-based
    # Strip ticks and grid so each cell shows only the image.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[k], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[k]])
plt.show()
|
||
|
||
# A small fully connected classifier:
#   Flatten  - turn each 28x28 input into a single vector,
#   Dense    - one hidden layer of 128 ReLU units,
#   Dense    - 10 softmax outputs, one per entry in class_names.
layer_stack = [
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
]
model = keras.Sequential(layer_stack)

# Labels are plain integer class ids (they index class_names directly), so the
# sparse variant of categorical cross-entropy is used instead of one-hot.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
|
||
|
||
# Train for 10 passes over the training split.
model.fit(x=train_images, y=train_labels, epochs=10)

# Score the trained model on the held-out test split. With a single metric
# configured in compile(), evaluate() reports the loss followed by accuracy.
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)
print('\nTest accuracy:', test_acc)
|