Tensorflow: uso de una tubería de entrada (.csv) como diccionario para el entrenamiento
Estoy tratando de entrenar un modelo en un conjunto de datos .csv (5008 columnas, 533 filas). Estoy usando un lector de texto para analizar los datos en dos tensores, uno con los datos para entrenar en [ejemplo] y otro con las etiquetas correctas [etiqueta]:
def read_my_file_format(filename_queue):
reader = tf.TextLineReader()
key, record_string = reader.read(filename_queue)
record_defaults = [[0.5] for row in range(5008)]
#Left out most of the columns for obvious reasons
col1, col2, col3, ..., col5008 = tf.decode_csv(record_string, record_defaults=record_defaults)
example = tf.stack([col1, col2, col3, ..., col5007])
label = col5008
return example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
Esta parte está funcionando cuando se ejecuta algo como:
with tf.Session() as sess:
ex_b, l_b = input_pipeline(["Tensorflow_vectors.csv"], 10, 1)
print("Test: ",ex_b)
mi resultado esTest: Tensor("shuffle_batch:0", shape=(10, 5007), dtype=float32)
Hasta ahora esto me parece bien. A continuación, he creado un modelo simple que consta de dos capas ocultas (512 y 256 nodos respectivamente). Cuando las cosas salen mal es cuando intento entrenar al modelo:
batch_x, batch_y = input_pipeline(["Tensorflow_vectors.csv"], batch_size)
_, cost = sess.run([optimizer, cost], feed_dict={x: batch_x.eval(), y: batch_y.eval()})
He basado este enfoque eneste ejemplo que usa la base de datos MNIST. Sin embargo, cuando estoy ejecutando esto, incluso cuando solo estoy usandobatch_size = 1
, Tensorflow simplemente se cuelga. Si dejo fuera el.eval()
funciones que deberían obtener los datos reales de los tensores, obtengo la siguiente respuesta:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
Ahora puedo entender esto, pero no entiendo por qué el programa se bloquea cuando incluyo el.eval()
función y no sé dónde podría encontrar información sobre este problema.
EDITAR: Incluí la versión más reciente de mi guión completoaquí. El programa todavía se cuelga a pesar de que implementé (hasta donde sé correctamente) la solución ofrecida porvijay m