Sensorflow: Wiederherstellen eines Diagramms und eines Modells und anschließende Auswertung eines einzelnen Bildes

Ich denke, es wäre für die Tensorflow-Community sehr hilfreich, wenn es eine gut dokumentierte Lösung für die entscheidende Aufgabe gäbe, ein einzelnes neues Image gegen das vom @ erstellte Modell zu testeconvnet im CIFAR-10-Tutorial.

Ich kann mich irren, aber dieser kritische Schritt, der das trainierte Modell in der Praxis nutzbar macht, scheint zu fehlen. In diesem Lernprogramm gibt es ein "fehlendes Glied" - ein Skript, das ein einzelnes Bild (als Array oder Binärdatei) direkt lädt, es mit dem trainierten Modell vergleicht und eine Klassifizierung zurückgibt.

Prior-Antworten enthalten Teillösungen, die den Gesamtansatz erklären, von denen ich jedoch keine erfolgreich implementieren konnte. Andere Teile sind hier und da zu finden, haben aber leider nicht zu einer funktionierenden Lösung geführt. Überlegen Sie sich bitte die Recherchen, die ich durchgeführt habe, bevor Sie diese als doppelt oder bereits beantwortet kennzeichnen.

Tensorflow: Wie speichere / restauriere ich ein Modell?

Wiederherstellung des TensorFlow-Modells

Modelle in Tensorflow v0.8 können nicht wiederhergestellt werden.

https: //gist.github.com/nikitakit/6ef3b72be67b86cb786

Die beliebteste Antwort ist die erste, in der @RyanSepassi und @YaroslavBulatov das Problem und einen Ansatz beschreiben: Man muss "manuell einen Graphen mit identischen Knotennamen erstellen und Saver verwenden, um die Gewichte darin zu laden". Obwohl beide Antworten hilfreich sind, ist nicht ersichtlich, wie man dies in das CIFAR-10-Projekt einbinden würde.

Eine voll funktionsfähige Lösung wäre sehr wünschenswert, damit wir sie auf andere Probleme bei der Klassifizierung einzelner Bilder übertragen können. Es gibt diesbezüglich mehrere Fragen zu SO, die danach fragen, aber immer noch keine vollständige Antwort (zum Beispiel Checkpoint laden und Einzelbild mit Tensorflow DNN auswerten).

Ich hoffe, wir können uns auf ein funktionierendes Skript einigen, das jeder verwenden kann.

Das folgende Skript ist noch nicht funktionsfähig, und ich würde mich freuen, von Ihnen zu hören, wie dies verbessert werden kann, um eine Lösung für die Einzelbildklassifizierung mithilfe des CIFAR-10 TF-Modells mit Tutorial bereitzustellen.

Angenommen, alle Variablen, Dateinamen usw. bleiben vom Original-Tutorial unberührt.

Neue Datei: cifar10_eval_single.py

import cv2
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', './input/eval',
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', './input/train',
                           """Directory where to read model checkpoints.""")

def get_single_img():
    file_path = './input/data/single/test_image.tif'
    pixels = cv2.imread(file_path, 0)
    return pixels

def eval_single_img():

    # below code adapted from @RyanSepassi, however not functional
    # among other errors, saver throws an error that there are no
    # variables to save
    with tf.Graph().as_default():

        # Get image.
        image = get_single_img()

        # Build a Graph.
        # TODO

        # Create dummy variables.
        x = tf.placeholder(tf.float32)
        w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
        b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
        y_hat = tf.add(b, tf.matmul(x, w))

        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)

            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Checkpoint found')
            else:
                print('No checkpoint found')

            # Run the model to get predictions
            predictions = sess.run(y_hat, feed_dict={x: image})
            print(predictions)

def main(argv=None):
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    eval_single_img()

if __name__ == '__main__':
    tf.app.run()

Antworten auf die Frage(8)

Ihre Antwort auf die Frage