Calcular a distância aos pares em um lote sem replicar o tensor no fluxo de tensão?
Desejo calcular a distância quadrada aos pares de um lote de recursos no Tensorflow. Eu tenho uma implementação simples usando operações + e * lado a lado com o tensor original:
def pairwise_l2_norm2(x, y, scope=None):
with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
size_x = tf.shape(x)[0]
size_y = tf.shape(y)[0]
xx = tf.expand_dims(x, -1)
xx = tf.tile(xx, tf.pack([1, 1, size_y]))
yy = tf.expand_dims(y, -1)
yy = tf.tile(yy, tf.pack([1, 1, size_x]))
yy = tf.transpose(yy, perm=[2, 1, 0])
diff = tf.sub(xx, yy)
square_diff = tf.square(diff)
square_dist = tf.reduce_sum(square_diff, 1)
return square_dist
Essa função usa como entrada duas matrizes de tamanho (m, d) e (n, d) e calcula a distância ao quadrado entre cada vetor de linha. A saída é uma matriz de tamanho (m, n) com o elemento 'd_ij = dist (x_i, y_j)'.
O problema é que eu tenho um lote grande e recursos de baixa luminosidade 'm, n, d' replicando o tensor consomem muita memória. Estou procurando outra maneira de implementar isso sem aumentar o uso de memória e apenas armazenar o tensor de distância final. Tipo de loop duplo do tensor original.