ValueError: Floating point image RGB values must be in the 0..1 range when using matplotlib
I want to visualize the weights of a layer in a neural network. I am using PyTorch.
```python
import torch
import torchvision.models as models
from matplotlib import pyplot as plt

def plot_kernels(tensor, num_cols=6):
    if not tensor.ndim == 4:
        raise Exception("assumes a 4D tensor")
    if not tensor.shape[-1] == 3:
        raise Exception("last dim needs to be 3 to plot")
    num_kernels = tensor.shape[0]
    num_rows = 1 + num_kernels // num_cols
    fig = plt.figure(figsize=(num_cols, num_rows))
    for i in range(tensor.shape[0]):
        ax1 = fig.add_subplot(num_rows, num_cols, i + 1)
        ax1.imshow(tensor[i])
        ax1.axis('off')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

vgg = models.vgg16(pretrained=True)
mm = vgg.double()
filters = mm.modules
body_model = [i for i in mm.children()][0]
layer1 = body_model[0]
tensor = layer1.weight.data.numpy()
plot_kernels(tensor)
```
The code above raises this error:

ValueError: Floating point image RGB values must be in the 0..1 range.
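For context, here is a sketch of the shape side of the problem, assuming the weights are a NumPy array in PyTorch's `Conv2d` layout `(out_channels, in_channels, kH, kW)`. `imshow` treats the last axis of a 3-D array as the colour channels, so `tensor[i]` (shape `(3, 3, 3)` here) is read as rows = input channels, columns = kH, colour = kW, not as an RGB picture of the kernel. Moving the input-channel axis last fixes the layout, but the values are still outside 0..1, which is what triggers the error (depending on the matplotlib version, out-of-range float RGB data either raises this `ValueError` or is clipped with a warning). The helper names `kernels_to_rgb_images` and `in_imshow_float_range` are hypothetical, not part of matplotlib or PyTorch.

```python
import numpy as np


def kernels_to_rgb_images(weights):
    """Turn a (out_channels, 3, kH, kW) weight array into a stack of
    (kH, kW, 3) images, the layout plt.imshow expects for RGB data."""
    weights = np.asarray(weights)
    if weights.ndim != 4 or weights.shape[1] != 3:
        raise ValueError(f"expected shape (N, 3, kH, kW), got {weights.shape}")
    # Move the input-channel axis to the end: (N, 3, kH, kW) -> (N, kH, kW, 3)
    return weights.transpose(0, 2, 3, 1)


def in_imshow_float_range(image):
    """Return True if a float RGB image already lies in matplotlib's 0..1 range."""
    image = np.asarray(image)
    return bool(image.min() >= 0.0 and image.max() <= 1.0)
```

In `plot_kernels`, each `tensor[i]` would then be replaced by the corresponding entry of `kernels_to_rgb_images(tensor)`, and `in_imshow_float_range` shows whether that image still needs rescaling before `imshow`.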
My question is: should I normalize the weights and take their absolute value to get around this error, or is there another way? If I normalize and take the absolute value, I think the plots no longer mean the same thing.
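A minimal sketch of two rescalings that avoid the absolute value, assuming each kernel is a NumPy array already arranged as `(kH, kW, 3)` (the helper names are hypothetical). Both are increasing affine maps into [0, 1], so they preserve the ordering of the weights within a kernel; taking the absolute value does not, because it gives `w` and `-w` the same colour. `rescale_symmetric` additionally fixes zero at 0.5, so negative weights stay below mid-grey and positive weights stay above it.

```python
import numpy as np


def rescale_minmax(kernel):
    """Map the kernel's minimum to 0 and its maximum to 1.
    Ordering is preserved, but where zero lands depends on the kernel."""
    kernel = np.asarray(kernel, dtype=np.float64)
    lo, hi = kernel.min(), kernel.max()
    if hi == lo:  # constant kernel: avoid dividing by zero
        return np.full_like(kernel, 0.5)
    return (kernel - lo) / (hi - lo)


def rescale_symmetric(kernel):
    """Map the kernel into [0, 1] with zero fixed at 0.5:
    -max|w| -> 0, 0 -> 0.5, +max|w| -> 1."""
    kernel = np.asarray(kernel, dtype=np.float64)
    m = np.abs(kernel).max()
    if m == 0:  # all-zero kernel: show it as uniform mid-grey
        return np.full_like(kernel, 0.5)
    return kernel / (2 * m) + 0.5
```

Used per kernel, e.g. `ax1.imshow(rescale_symmetric(kernel))`, each kernel gets its own scale, so brightness is not comparable across kernels. Computing `m` once over the whole layer instead would keep them comparable at the cost of contrast in weak kernels.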
```
[[[[ 0.02240197 -1.22057354 -0.55051649]
   [-0.50310904 0.00891289 0.15427093]
   [ 0.42360783 -0.23392732 -0.56789106]]

  [[ 1.12248898 0.99013627 1.6526649 ]
   [ 1.09936976 2.39608836 1.83921957]
   [ 1.64557672 1.4093554 0.76332706]]

  [[ 0.26969245 -1.2997849 -0.64577204]
   [-1.88377869 -2.0100112 -1.43068039]
   [-0.44531786 -1.67845118 -1.33723605]]]

 [[[ 0.71286005 1.45265901 0.64986968]
   [ 0.75984162 1.8061738 1.06934202]
   [-0.08650422 0.83452386 -0.04468433]]

  [[-1.36591709 -2.01630116 -1.54488969]
   [-1.46221244 -2.5365622 -1.91758668]
   [-0.88827479 -1.59151018 -1.47308767]]

  [[ 0.93600738 0.98174071 1.12213969]
   [ 1.03908169 0.83749604 1.09565806]
   [ 0.71188802 0.85773659 0.86840987]]]

 [[[-0.48592842 0.2971966 1.3365227 ]
   [ 0.47920835 -0.18186836 0.59673625]
   [-0.81358945 1.23862112 0.13635623]]

  [[-0.75361633 -1.074965 0.70477796]
   [ 1.24439156 -1.53563368 -1.03012812]
   [ 0.97597247 0.83084011 -1.81764793]]

  [[-0.80762428 -0.62829626 1.37428832]
   [ 1.01448071 -0.81775147 -0.41943246]
   [ 1.02848887 1.39178836 -1.36779451]]]

 ...,

 [[[ 1.28134537 -0.00482408 0.71610934]
   [ 0.95264435 -0.09291686 -0.28001019]
   [ 1.34494913 0.64477581 0.96984017]]

  [[-0.34442815 -1.40002513 1.66856039]
   [-2.21281362 -3.24513769 -1.17751861]
   [-0.93520379 -1.99811196 0.72937071]]

  [[ 0.63388056 -0.17022935 2.06905985]
   [-0.7285465 -1.24722099 0.30488953]
   [ 0.24900314 -0.19559766 1.45432627]]]

 [[[-0.80684513 2.1764245 -0.73765725]
   [-1.35886598 1.71875226 -1.73327696]
   [-0.75233924 2.14700699 -0.71064663]]

  [[-0.79627383 2.21598244 -0.57396138]
   [-1.81044972 1.88310981 -1.63758397]
   [-0.6589964 2.013237 -0.48532376]]

  [[-0.3710472 1.4949851 -0.30245575]
   [-1.25448656 1.20453358 -1.29454732]
   [-0.56755757 1.30994892 -0.39370224]]]

 [[[-0.67361742 -3.69201088 -1.23768616]
   [ 3.12674141 1.70414758 -1.76272404]
   [-0.22565465 1.66484773 1.38172317]]

  [[ 0.28095332 -2.03035069 0.69989491]
   [ 1.97936332 1.76992691 -1.09842575]
   [-2.22433758 0.52577412 0.18292744]]

  [[ 0.48471382 -1.1984663 1.57565165]
   [ 1.09911084 1.31910467 -0.51982772]
   [-2.76202297 -0.47073677 0.03936549]]]]
```
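The printed values above contain both signs and many exceed 1 in magnitude. Another option, sketched here under the same `(out_channels, in_channels, kH, kW)` layout assumption, is to skip RGB entirely: plot each input channel of each kernel as a 2-D array with a diverging colormap and symmetric `vmin`/`vmax`. For 2-D float data, `imshow` maps values through the colormap using `vmin`/`vmax`, so the 0..1 RGB restriction does not apply, and the sign of each weight stays visible. `symmetric_limits`, `plot_kernel_channels` and its `max_kernels` parameter are hypothetical names.

```python
import numpy as np


def symmetric_limits(weights):
    """Return (vmin, vmax) centred on zero, so 0 maps to the middle of a diverging colormap."""
    m = float(np.abs(np.asarray(weights)).max())
    return -m, m


def plot_kernel_channels(weights, max_kernels=8, cmap="bwr"):
    """Show each 2-D (kH, kW) slice of a (out_channels, in_channels, kH, kW) array:
    one row per kernel, one column per input channel."""
    import matplotlib.pyplot as plt  # imported here so symmetric_limits works without matplotlib

    weights = np.asarray(weights)[:max_kernels]
    n_out, n_in = weights.shape[:2]
    # One shared scale over all plotted kernels keeps their colours comparable
    vmin, vmax = symmetric_limits(weights)
    fig, axes = plt.subplots(n_out, n_in, figsize=(n_in, n_out), squeeze=False)
    im = None
    for i in range(n_out):
        for j in range(n_in):
            # 2-D float input goes through the colormap, scaled by vmin/vmax
            im = axes[i, j].imshow(weights[i, j], cmap=cmap, vmin=vmin, vmax=vmax)
            axes[i, j].axis("off")
    fig.colorbar(im, ax=axes, shrink=0.6)
    plt.show()
```

With the question's variables this would be called as `plot_kernel_channels(layer1.weight.data.numpy())`. Per-kernel limits would bring out detail in weak kernels, but the shared scale lets you compare kernels directly, with white meaning a weight of zero in the `bwr` colormap.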