Нет трансляции для tf.matmul в TensorFlow
У меня есть проблема, с которой я боролся. Это связано сtf.matmul()
и его отсутствие вещания.
Я знаю о подобной проблеме наhttps://github.com/tensorflow/tensorflow/issues/216, ноtf.batch_matmul()
не похоже на решение для моего случая.
Мне нужно закодировать мои входные данные как тензор 4D:X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
Первое измерение - это размер пакета, второе - количество записей в пакете. Вы можете представить каждую запись как композицию из нескольких объектов (третье измерение). Наконец, каждый объект описывается вектором из 100 значений с плавающей запятой.
Обратите внимание, что я использовал None для второго и третьего измерений, потому что фактические размеры могут меняться в каждой партии. Однако, для простоты, давайте сформируем тензор с фактическими числами:X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
Вот шаги моего вычисления:
вычислить функцию каждого вектора из 100 значений с плавающей запятой (например, линейную функцию)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
Y = tf.matmul(X, W)
проблема: нет трансляции дляtf.matmul()
и безуспешно, используяtf.batch_matmul()
ожидаемая форма Y: (5, 10, 4, 50)
применение среднего пула для каждой записи пакета (по объектам каждой записи):Y_avg = tf.reduce_mean(Y, 2)
ожидаемая форма Y_avg: (5, 10, 50)
Я ожидал чтоtf.matmul()
поддержал бы вещание. Потом я нашелtf.batch_matmul()
, но, тем не менее, похоже, что это не относится к моему случаю (например, W должен иметь как минимум 3 измерения, не понятно почему).
Кстати, выше я использовал простую линейную функцию (веса которой хранятся в W). Но в моей модели у меня глубокая сеть. Итак, более общей проблемой, с которой я столкнулся, является автоматическое вычисление функции для каждого среза тензора. Вот почему я ожидал, чтоtf.matmul()
было бы поведение вещания (если так, может быть,tf.batch_matmul()
даже не было бы необходимости).
С нетерпением ждем учиться у вас! Alessio