Como usar os iteradores inicializáveis do tf.data dentro do input_fn de um tf.estimator?
Eu gostaria de gerenciar meu treinamento com umtf.estimator.Estimator
mas tenha problemas para usá-lo junto com otf.data
API.
Eu tenho algo parecido com isto:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
Como não posso usar ummake_one_shot_iterator
para o meu caso de uso, meu problema é queinput_fn
contém um iterador que deve ser inicializado dentromodel_fn
(aqui eu usotf.train.Scaffold
para inicializar operações locais).
Além disso, eu entendi que não podemos usar apenasinput_fn = iterator.get_next
caso contrário, as outras operações não serão adicionadas ao mesmo gráfico.
Qual é a maneira recomendada de inicializar o iterador?