Python itertools.product reordena a geração

Eu tenho isto

shape = (2, 4) # arbitrary, could be 3 dimensions such as (3, 5, 7), etc...

for i in itertools.product(*(range(x) for x in shape)):
    print(i)

# output: (0, 0) (0, 1) (0, 2) (0, 3) (1, 0) (1, 1) (1, 2) (1, 3)

Por enquanto, tudo bem,itertools.product avança o elemento mais à direita em todas as iterações. Mas agora eu quero poder especificar a ordem da iteração de acordo com o seguinte:

axes = (0, 1) # normal order
# output: (0, 0) (0, 1) (0, 2) (0, 3) (1, 0) (1, 1) (1, 2) (1, 3)

axes = (1, 0) # reversed order
# output: (0, 0) (1, 0) (2, 0) (3, 0) (0, 1) (1, 1) (2, 1) (3, 1)

E seshapes tinha três dimensões,axes poderia ter sido, por exemplo,(0, 1, 2) ou(2, 0, 1) etc, então não se trata de simplesmente usarreversed(). Então, eu escrevi um código que faz isso, mas parece muito ineficiente:

axes = (1, 0)

# transposed axes
tpaxes = [0]*len(axes)
for i in range(len(axes)):
    tpaxes[axes[i]] = i

for i in itertools.product(*(range(x) for x in shape)):
    # reorder the output of itertools.product
    x = (i[y] for y in tpaxes)
    print(tuple(x))

Alguma idéia de como fazer isso corretamente?

questionAnswers(12)

yourAnswerToTheQuestion