Нет трансляции для tf.matmul в TensorFlow

У меня есть проблема, с которой я боролся. Это связано сtf.matmul() и его отсутствие вещания.

Я знаю о подобной проблеме наhttps://github.com/tensorflow/tensorflow/issues/216, ноtf.batch_matmul() не похоже на решение для моего случая.

Мне нужно закодировать мои входные данные как тензор 4D:X = tf.placeholder(tf.float32, shape=(None, None, None, 100)) Первое измерение - это размер пакета, второе - количество записей в пакете. Вы можете представить каждую запись как композицию из нескольких объектов (третье измерение). Наконец, каждый объект описывается вектором из 100 значений с плавающей запятой.

Обратите внимание, что я использовал None для второго и третьего измерений, потому что фактические размеры могут меняться в каждой партии. Однако, для простоты, давайте сформируем тензор с фактическими числами:X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))

Вот шаги моего вычисления:

вычислить функцию каждого вектора из 100 значений с плавающей запятой (например, линейную функцию)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1)) Y = tf.matmul(X, W) проблема: нет трансляции дляtf.matmul() и безуспешно, используяtf.batch_matmul() ожидаемая форма Y: (5, 10, 4, 50)

применение среднего пула для каждой записи пакета (по объектам каждой записи):Y_avg = tf.reduce_mean(Y, 2) ожидаемая форма Y_avg: (5, 10, 50)

Я ожидал чтоtf.matmul() поддержал бы вещание. Потом я нашелtf.batch_matmul(), но, тем не менее, похоже, что это не относится к моему случаю (например, W должен иметь как минимум 3 измерения, не понятно почему).

Кстати, выше я использовал простую линейную функцию (веса которой хранятся в W). Но в моей модели у меня глубокая сеть. Итак, более общей проблемой, с которой я столкнулся, является автоматическое вычисление функции для каждого среза тензора. Вот почему я ожидал, чтоtf.matmul() было бы поведение вещания (если так, может быть,tf.batch_matmul() даже не было бы необходимости).

С нетерпением ждем учиться у вас! Alessio

Ответы на вопрос(2)

Ваш ответ на вопрос