Elegante Methode zum Auswählen eines Elements pro Zeile in Tensorflow
Gegeben..
a MatrixA
von Form[m, n]
ein TensorI
von Form[m]
Ich möchte eine Liste bekommenJ
von Elementen ausA
woJ[i] = A[i, I[i]]
.
Das ist,I
enthält den Index des Elements, das aus jeder Zeile in @ ausgewählt werden soA
.
Kontext: Ich habe bereits dasargmax(A, 1)
und jetzt will ich auch dasmax
. Ich weiß, dass ich nur @ verwenden kareduce_max
. Und nachdem ich ein bisschen rumprobiert hatte, kam ich auch auf folgendes:
J = tf.gather_nd(A,
tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))
Bei dem dieto_int64
wird benötigt, da range nur @ erzeuint32
undargmax
produziert nurint64
.
Keiner der beiden kommt mir besonders elegant vor. Man hat Laufzeit-Overhead (wahrscheinlich über Faktorn
) und der andere hat einen unbekannten kognitiven Overheadfaktor. Vermisse ich hier etwas?