from __future__ import absolute_import, division, print_function, unicode_literals

import time

import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_hub as hub

from tensorflow.keras import layers

# A pre-trained MobileNetV2 ImageNet classifier from TensorFlow Hub.
classifier_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/2"  # @param {type:"string"}

IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE + (3,))
])

# Try the classifier on a single image.
grace_hopper = tf.keras.utils.get_file(
    'image.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)

# The hub model expects float inputs in [0, 1].
grace_hopper = np.array(grace_hopper) / 255.0

# Add a batch dimension and predict.
result = classifier.predict(grace_hopper[np.newaxis, ...])

predicted_class = np.argmax(result[0], axis=-1)

# Decode the prediction using the ImageNet label file.
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())

plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

# Download the flower photos dataset used for transfer learning.
data_root = tf.keras.utils.get_file(
    'flower_photos', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Load the images, rescaled to [0, 1] and resized to 224x224.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(
    str(data_root), target_size=IMAGE_SHAPE)

for image_batch, label_batch in image_data:
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
    break
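
# With flow_from_directory's defaults (batch size 32, one-hot labels) and the
# five flower classes in this dataset, the shapes printed above should be
# roughly (32, 224, 224, 3) for images and (32, 5) for labels; the exact
# values depend on the generator settings, so treat this as an illustration.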

# Run the ImageNet classifier on a batch of flower images.
result_batch = classifier.predict(image_batch)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]

# Plot the first 30 images with their ImageNet predictions as titles.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6, 5, n+1)
    plt.imshow(image_batch[n])
    plt.title(predicted_class_names[n])
    plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

# Use the headless MobileNetV2 feature extractor from TensorFlow Hub.
feature_extractor_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/feature_vector/2"  # @param {type:"string"}

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=(224, 224, 3))

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
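
# MobileNetV2's pooled feature vector is 1280-dimensional, so the printed
# shape should be (batch_size, 1280) -- e.g. (32, 1280) with the default
# batch size used above.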

# Freeze the feature extractor so only the new classifier head is trained.
feature_extractor_layer.trainable = False

model = tf.keras.Sequential([
    feature_extractor_layer,
    layers.Dense(image_data.num_classes, activation='softmax')
])

model.summary()

predictions = model(image_batch)

# flow_from_directory yields one-hot labels, so use categorical crossentropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['acc'])

# Callback that records the loss and accuracy of every training batch.
class CollectBatchStats(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.batch_losses = []
        self.batch_acc = []

    def on_train_batch_end(self, batch, logs=None):
        self.batch_losses.append(logs['loss'])
        self.batch_acc.append(logs['acc'])
        # Reset the running metrics so each batch is logged independently.
        self.model.reset_metrics()

steps_per_epoch = np.ceil(image_data.samples / image_data.batch_size)

batch_stats_callback = CollectBatchStats()

history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[batch_stats_callback])

# Plot the per-batch training loss collected by the callback.
plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0, 2])
plt.plot(batch_stats_callback.batch_losses)
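
# The callback also records per-batch accuracy; a matching plot, sketched
# here in the same style as the loss plot above:
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0, 1])
plt.plot(batch_stats_callback.batch_acc)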

# Recover the class names, ordered by their generator indices.
class_names = sorted(image_data.class_indices.items(),
                     key=lambda pair: pair[1])
class_names = np.array([key.title() for key, value in class_names])

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

label_id = np.argmax(label_batch, axis=-1)

# Plot the batch again, coloring each title by whether the newly trained
# model's prediction matches the true label.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6, 5, n+1)
    plt.imshow(image_batch[n])
    color = "green" if predicted_id[n] == label_id[n] else "red"
    plt.title(predicted_label_batch[n].title(), color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")

t = time.time()

# Export the trained model as a SavedModel using the tf.keras.experimental API.
export_path = "/tmp/saved_models/{}".format(int(t))
tf.keras.experimental.export_saved_model(model, export_path)

# Reload it; hub.KerasLayer must be passed as a custom object.
reloaded = tf.keras.experimental.load_from_saved_model(
    export_path, custom_objects={'KerasLayer': hub.KerasLayer})

result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)

abs(reloaded_result_batch - result_batch).max()
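
# Optional sanity check (a sketch): the reloaded model should reproduce the
# original predictions to within floating-point tolerance.
np.testing.assert_allclose(result_batch, reloaded_result_batch, atol=1e-5)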