Como ler uma string binária codificada utf-8 no tensorflow?
Estou tentando converter uma sequência de bytes codificada de volta para a matriz original no gráfico tensorflow (usando operações de tensorflow) para fazer uma previsão em um modelo de tensorflow. A conversão de matriz em byte é baseada emesta resposta e é a entrada sugerida para a previsão do modelo tensorflow no ml-engine do Google Cloud.
def array_request_example(input_array):
input_array = input_array.astype(np.float32)
byte_string = input_array.tostring()
string_encoded_contents = base64.b64encode(byte_string)
return string_encoded_contents.decode('utf-8')}
Código Tensorflow
byte_string = tf.placeholder(dtype=tf.string)
audio_samples = tf.decode_raw(byte_string, tf.float32)
audio_array = np.array([1, 2, 3, 4])
bstring = array_request_example(audio_array)
fdict = {byte_string: bstring}
with tf.Session() as sess:
[tf_samples] = sess.run([audio_samples], feed_dict=fdict)
Eu tentei usardecode_raw edecode_base64 mas também não retorna os valores originais.
Eu tentei definir o out_type de decode bruto para os diferentes tipos de dados possíveis e tentei alterar em que tipo de dados estou convertendo a matriz original.
Então, como eu leria a matriz de bytes no fluxo tensor? Obrigado :)
Informação extraO objetivo por trás disso é criar a função de entrada de veiculação para um estimador personalizado fazer previsões usando previsão local do gcloud ml-engine (para teste) e usando a API REST para o modelo armazenado na nuvem.
A função de entrada de veiculação para o Estimador é
def serving_input_fn():
feature_placeholders = {'b64': tf.placeholder(dtype=tf.string,
shape=[None],
name='source')}
audio_samples = tf.decode_raw(feature_placeholders['b64'], tf.float32)
# Dummy function to save space
power_spectrogram = create_spectrogram_from_audio(audio_samples)
inputs = {'spectrogram': power_spectrogram}
return tf.estimator.export.ServingInputReceiver(inputs, feature_placeholders)
Pedido JsonEu uso .decode ('utf-8') porque, ao tentar json despejar as seqüências de bytes codificadas em base64, recebo esse erro
raise TypeError(repr(o) + " is not JSON serializable")
TypeError: b'longbytestring'
Erros de previsãoAo passar a solicitação json {'audio_bytes': 'b64': bytestring} com o gcloud local, recebo o erro
PredictionError: Invalid inputs: Expected tensor name: b64, got tensor name: [u'audio_bytes']
Então, talvez o Google Cloud Local Forecast não lide automaticamente com os bytes de áudio e a conversão da base64? Ou provavelmente algo de errado com a minha configuração do Estimador.
E a solicitação {'instance': [{'audio_bytes': 'b64': bytestring}]} para a API REST fornece
{'error': 'Prediction failed: Error during model execution: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="Input to DecodeRaw has length 793713 that is not a multiple of 4, the size of float\n\t [[Node: DecodeRaw = DecodeRaw[_output_shapes=[[?,?]], little_endian=true, out_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_source_0_0)]]")'}
o que me confunde ao definir explicitamente a solicitação como flutuante e fazer o mesmo no receptor de entrada de veiculação.
A remoção de audio_bytes da solicitação e a codificação utf-8 das sequências de bytes permitem-me obter previsões, embora ao testar a decodificação localmente, acho que o áudio está sendo convertido incorretamente da sequência de bytes.