Реализация seq2seq с поиском луча
Я сейчас реализуюseq2seq модель на основе примера кода, которыйtensorflow
обеспечивает. И я хочу получитьтоп-5 декодер выходы, чтобы сделать обучение подкрепления.
Тем не менее, они реализовали модель трансляции с помощью декодера внимания, поэтому я должен реализовать лучевой поиск для получениятоп-K Результаты.
Есть часть кода, которая сейчас реализуется (этот код добавляется вtranslate.py
).
Ссылка поhttps://github.com/tensorflow/tensorflow/issues/654
with tf.Graph().as_default():
beam_size = FLAGS.beam_size # Number of hypotheses in beam
num_symbols = FLAGS.tar_vocab_size # Output vocabulary size
embedding_size = 10
num_steps = 5
embedding = tf.zeros([num_symbols, embedding_size])
output_projection = None
log_beam_probs, beam_symbols, beam_path = [], [], []
def beam_search(prev, i):
if output_projection is not None:
prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
probs = tf.log(tf.nn.softmax(prev))
if i > 1:
probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])
best_probs, indices = tf.nn.top_k(probs, beam_size)
indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))
symbols = indices % num_symbols # which word in vocabulary
beam_parent = indices // num_symbols # which hypothesis it came from
beam_symbols.append(symbols)
beam_path.append(beam_parent)
log_beam_probs.append(best_probs)
return tf.nn.embedding_lookup(embedding, symbols)
# Setting up graph.
inputs = [tf.placeholder(tf.float32, shape=[None, num_symbols]) for i in range(num_steps)]
for i in range(num_steps):
beam_search(inputs[i], i+1)
input_vals = tf.zeros([1, beam_size], dtype=tf.float32)
input_feed = {inputs[i]: input_vals[i][:beam_size, :] for i in xrange(num_steps)}
output_feed = beam_symbols + beam_path + log_beam_probs
session = tf.InteractiveSession()
outputs = session.run(output_feed, feed_dict=input_feed)
print("Top_5 Sentences ")
for predicted in enumerate(outputs[:5]):
print(list(predicted))
print("\n")
В части input_feed есть ошибка:
ValueError: Shape (1, 12) must have rank 1
Есть ли проблема в моем коде, чтобы сделатьЛуч-поиск?