Devolución de múltiples matrices de la función agregada definida por el usuario (UDAF) en Apache Spark SQL
Estoy tratando de crear una función agregada definida por el usuario (UDAF) en Java usando Apache Spark SQL que devuelve múltiples matrices al finalizar. He buscado en línea y no puedo encontrar ningún ejemplo o sugerencia sobre cómo hacerlo.
Puedo devolver una sola matriz, pero no puedo entender cómo obtener los datos en el formato correcto en el método evaluado () para devolver múltiples matrices.
El UDAF funciona ya que puedo imprimir las matrices en el método de evaluación (), simplemente no puedo entender cómo devolver esas matrices al código de llamada (que se muestra a continuación como referencia).
UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col"))).as("processed_data");
He incluido toda la clase UDAF personalizada a continuación, pero los métodos clave son el tipo de datos () y los métodos de evaluación (), que se muestran primero.
Cualquier ayuda o consejo sería muy apreciado. Gracias.
public class CustomUDAF extends UserDefinedAggregateFunction {
@Override
public DataType dataType() {
// TODO: Is this the correct way to return 2 arrays?
return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
.add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
}
@Override
public Object evaluate(Row buffer) {
// Data conversion
List<Long> longList = new ArrayList<Long>(buffer.getList(0));
List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
// Processing of data (omitted)
// TODO: How to get data into format needed to return 2 arrays?
return dataList;
}
@Override
public StructType inputSchema() {
return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
}
@Override
public StructType bufferSchema() {
return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
.add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
}
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, new ArrayList<Long>());
buffer.update(1, new ArrayList<Double>());
}
@Override
public void update(MutableAggregationBuffer buffer, Row row) {
ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
longList.add(row.getLong(0));
ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
dataList.add(row.getDouble(1));
buffer.update(0, longList);
buffer.update(1, dataList);
}
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
longList.addAll(buffer2.getList(0));
ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
dataList.addAll(buffer2.getList(1));
buffer1.update(0, longList);
buffer1.update(1, dataList);
}
@Override
public boolean deterministic() {
return true;
}
}
Actualizar: Basado en la respuesta de zero323 pude devolver dos matrices usando:
return new Tuple2<>(longArray, dataArray);
Sacar los datos de esto fue un poco difícil, pero implicó la deconstrucción del DataFrame a las listas de Java y luego volver a construirlo en un DataFrame.