Implementando seq2seq com pesquisa de feixe
Agora estou implementandoseq2seq modelo baseado no código de exemplo quetensorflow
fornece. E eu quero ter umadecodificador top-5 resultados para fazer um aprendizado por reforço.
No entanto, eles implementaram o modelo de tradução com decodificador de atenção, portanto, eu deveria implementar a pesquisa por feixe para obtertop-k resultados.
Há uma parte do código que agora implementa (esse código é adicionado aotranslate.py
)
Referência porhttps://github.com/tensorflow/tensorflow/issues/654
with tf.Graph().as_default():
beam_size = FLAGS.beam_size # Number of hypotheses in beam
num_symbols = FLAGS.tar_vocab_size # Output vocabulary size
embedding_size = 10
num_steps = 5
embedding = tf.zeros([num_symbols, embedding_size])
output_projection = None
log_beam_probs, beam_symbols, beam_path = [], [], []
def beam_search(prev, i):
if output_projection is not None:
prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
probs = tf.log(tf.nn.softmax(prev))
if i > 1:
probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])
best_probs, indices = tf.nn.top_k(probs, beam_size)
indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))
symbols = indices % num_symbols # which word in vocabulary
beam_parent = indices // num_symbols # which hypothesis it came from
beam_symbols.append(symbols)
beam_path.append(beam_parent)
log_beam_probs.append(best_probs)
return tf.nn.embedding_lookup(embedding, symbols)
# Setting up graph.
inputs = [tf.placeholder(tf.float32, shape=[None, num_symbols]) for i in range(num_steps)]
for i in range(num_steps):
beam_search(inputs[i], i+1)
input_vals = tf.zeros([1, beam_size], dtype=tf.float32)
input_feed = {inputs[i]: input_vals[i][:beam_size, :] for i in xrange(num_steps)}
output_feed = beam_symbols + beam_path + log_beam_probs
session = tf.InteractiveSession()
outputs = session.run(output_feed, feed_dict=input_feed)
print("Top_5 Sentences ")
for predicted in enumerate(outputs[:5]):
print(list(predicted))
print("\n")
Na parte input_feed, há um erro:
ValueError: Shape (1, 12) must have rank 1
Existe algum problema no meu código para fazerpesquisa por feixe?