Nenhuma transmissão para tf.matmul no TensorFlow

Eu tenho um problema com o qual tenho lutado. Está relacionado atf.matmul() e sua ausência de transmissão.

Estou ciente de um problema semelhante emhttps://github.com/tensorflow/tensorflow/issues/216, mastf.batch_matmul() não parece uma solução para o meu caso.

Preciso codificar meus dados de entrada como um tensor 4D:X = tf.placeholder(tf.float32, shape=(None, None, None, 100)) A primeira dimensão é o tamanho de um lote, a segunda o número de entradas no lote. Você pode imaginar cada entrada como uma composição de vários objetos (terceira dimensão). Finalmente, cada objeto é descrito por um vetor de 100 valores flutuantes.

Observe que eu usei Nenhum para a segunda e terceira dimensões, porque os tamanhos reais podem mudar em cada lote. No entanto, para simplificar, vamos modelar o tensor com números reais:X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))

Estas são as etapas do meu cálculo:

calcular uma função de cada vetor de 100 valores flutuantes (por exemplo, função linear)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1)) Y = tf.matmul(X, W) problema: sem transmissão paratf.matmul() e sem sucesso usandotf.batch_matmul() forma esperada de Y: (5, 10, 4, 50)

aplicação de pool médio para cada entrada do lote (sobre os objetos de cada entrada):Y_avg = tf.reduce_mean(Y, 2) forma esperada de Y_avg: (5, 10, 50)

eu esperei issotf.matmul() teria suportado a transmissão. Então eu acheitf.batch_matmul(), mas ainda assim parece que não se aplica ao meu caso (por exemplo, W precisa ter pelo menos três dimensões, sem saber por que).

BTW, acima, usei uma função linear simples (cujos pesos são armazenados em W). Mas, no meu modelo, tenho uma rede profunda. Portanto, o problema mais geral que tenho é calcular automaticamente uma função para cada fatia de um tensor. É por isso que eu esperava quetf.matmul() teria tido um comportamento de transmissão (se sim, talveztf.batch_matmul() nem seria necessário).

Ansiosos para aprender com você! Alessio

questionAnswers(2)

yourAnswerToTheQuestion