¿Cómo usar tf.contrib.model_pruning en MNIST?

Estoy luchando por usar la biblioteca de poda de Tensorflow y no he encontrado muchos ejemplos útiles, por lo que estoy buscando ayuda para podar un modelo simple entrenado en el conjunto de datos MNIST. Si alguien puede ayudar a solucionar mi intento o proporcionar un ejemplo de cómo usar la biblioteca en MNIST, estaría muy agradecido.

La primera mitad de mi código es bastante estándar, excepto que mi modelo tiene 2 capas ocultas de 300 unidades de ancho usandolayers.masked_fully_connected para poda.

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Luego intento definir las operaciones de poda necesarias, pero aparece un error.

############ Pruning Operations ##############
# Create global step variable
global_step = tf.contrib.framework.get_or_create_global_step()

# Create a pruning object using the pruning specification
pruning_hparams = pruning.get_pruning_hparams()
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Mask Update op
mask_update_op = p.conditional_mask_update_op()

# Set up the specification for model pruning
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

Error en esta línea:

prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

InvalidArgumentError (ver arriba para el rastreo): debe alimentar un valor para el tensor de marcador de posición 'Placeholder_1' con dtype float y shape [?, 10] [[Node: Placeholder_1 = Placeholderdtype = DT_FLOAT, shape = [?, 10], _device = "/ job: localhost / replica: 0 / task: 0 / device: GPU: 0"]] [[Node: global_step / _57 = _Recv_start_time = 0, client_terminated = false, recv_device = "/ job: localhost / replica: 0 / tarea: 0 / dispositivo: CPU: 0 ", send_device =" / job: localhost / replica: 0 / task: 0 / device: GPU: 0 ", send_device_incarnation = 1, tensor_name =" edge_71_global_step ", tensor_type = DT_INT64, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0"]]

Supongo que quiere un tipo diferente de operación en lugar de train_op, pero no he encontrado ningún ajuste que funcione.

De nuevo, si tiene un ejemplo de trabajo diferente que elimina un modelo capacitado en MNIST, consideraría que es una respuesta.

Respuestas a la pregunta(2)

Su respuesta a la pregunta