Erro StackOverflow ao aplicar o "recommendProductsForUsers" do pyspark ALS (embora haja um cluster de> 300 GB de RAM disponível)
Procurando experiência para me guiar na questão abaixo.
Fundo:
Estou tentando seguir com um script básico do PySpark inspirado emeste exemploComo infraestrutura de implantação, uso um cluster do Google Cloud Dataproc.A pedra angular no meu código é a função "recommendProductsForUsers" documentadaaqui o que me devolve os principais produtos X para todos os usuários do modeloEmissão em que incorri
O script ALS.Train funciona sem problemas e se adapta bem ao GCP (facilmente> 1 milhão de clientes).
No entanto, aplicar as previsões: ou seja, usar as funções 'PredictAll' ou 'recommendProductsForUsers', não é escalável. Meu script é tranquilo para um pequeno conjunto de dados (<100 Cliente com <100 produtos). No entanto, quando o tamanho é relevante para os negócios, não consigo escalá-lo (por exemplo,> 50 mil clientes e> 10 mil produtos)
O erro que recebo está abaixo:
16/08/16 14:38:56 WARN org.apache.spark.scheduler.TaskSetManager:
Lost task 22.0 in stage 411.0 (TID 15139,
productrecommendation-high-w-2.c.main-nova-558.internal):
java.lang.StackOverflowError
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1909)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
Até cheguei a obter um cluster de 300 GB (1 nó principal de 108 GB + 2 nós de 108 GB de RAM) para tentar executá-lo; funciona para 50 mil clientes, mas não para mais nada
A ambição é ter uma configuração na qual eu possa executar mais de 800 mil clientes
Detalhes
Linha de código onde falha
predictions = model.recommendProductsForUsers(10).flatMap(lambda p: p[1]).map(lambda p: (str(p[0]), str(p[1]), float(p[2])))
pprint.pprint(predictions.take(10))
schema = StructType([StructField("customer", StringType(), True), StructField("sku", StringType(), True), StructField("prediction", FloatType(), True)])
dfToSave = sqlContext.createDataFrame(predictions, schema).dropDuplicates()
Como você sugere que prossiga? Eu sinto que a parte 'mesclando' no final do meu script (ou seja, quando eu o escrevo no dfToSave) causa o erro; existe uma maneira de contornar isso e salvar parte por parte?