Inferencia de Tensorflow Java Multi-GPU

Tengo un servidor con múltiples GPU y quiero hacer un uso completo de ellas durante la inferencia del modelo dentro de una aplicación Java. Por defecto, tensorflow incauta todas las GPU disponibles, pero usa solo la primera.

Se me ocurren tres opciones para superar este problema:

Restrinja la visibilidad del dispositivo a nivel de proceso, es decir, utilizandoCUDA_VISIBLE_DEVICES Variable ambiental.

Eso requeriría que ejecute varias instancias de la aplicación Java y distribuya el tráfico entre ellas. No es esa idea tentadora.

Inicie varias sesiones dentro de una sola aplicación e intente asignar un dispositivo a cada una de ellas medianteConfigProto:

public class DistributedPredictor {

    private Predictor[] nested;
    private int[] counters;

    // ...

    public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) {
        nested = new Predictor[numDevices];
        counters = new int[numDevices];

        for (int i = 0; i < nested.length; i++) {
            nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice);
        }
    }

    public Prediction predict(Data data) {
        int i = acquirePredictorIndex();
        Prediction result = nested[i].predict(data);
        releasePredictorIndex(i);
        return result;
    }

    private synchronized int acquirePredictorIndex() {
        int i = argmin(counters);
        counters[i] += 1;
        return i;
    }

    private synchronized void releasePredictorIndex(int i) {
        counters[i] -= 1;
    }
}


public class Predictor {

    private Session session;

    public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {

        GPUOptions gpuOptions = GPUOptions.newBuilder()
                .setVisibleDeviceList("" + deviceIdx)
                .setAllowGrowth(true)
                .build();

        ConfigProto config = ConfigProto.newBuilder()
                .setGpuOptions(gpuOptions)
                .setInterOpParallelismThreads(numDevices * numThreadsPerDevice)
                .build();

        byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
        Graph graph = new Graph();
        graph.importGraphDef(graphDef);

        this.session = new Session(graph, config.toByteArray());
    }

    public Prediction predict(Data data) {
        // ...
    }
}

Este enfoque parece funcionar bien de un vistazo. Sin embargo, las sesiones ocasionalmente ignoransetVisibleDeviceList opción y todos van por el primer dispositivo que causa el bloqueo de memoria insuficiente.

Construye el modelo en una forma de múltiples torres en Python usandotf.device() especificación. En el lado de Java, dar diferentePredictors diferentes torres dentro de una sesión compartida.

Se siente engorroso e idiomáticamente equivocado para mí.

ACTUALIZAR: Como @ash propuso, hay otra opción:

Asigne un dispositivo apropiado a cada operación del gráfico existente modificando su definición (graphDef)

Para hacerlo, uno podría adaptar el código del Método 2:

public class Predictor {

    private Session session;

    public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {

        byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
        graphDef = setGraphDefDevice(graphDef, deviceIdx)

        Graph graph = new Graph();
        graph.importGraphDef(graphDef);

        ConfigProto config = ConfigProto.newBuilder()
                .setAllowSoftPlacement(true)
                .build();

        this.session = new Session(graph, config.toByteArray());
    }

    private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException {
        String deviceString = String.format("/gpu:%d", deviceIdx);

        GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
        for (int i = 0; i < builder.getNodeCount(); i++) {
            builder.getNodeBuilder(i).setDevice(deviceString);
        }
        return builder.build().toByteArray();
    }

    public Prediction predict(Data data) {
        // ...
    }
}

Al igual que otros enfoques mencionados, este no me libera de la distribución manual de datos entre dispositivos. Pero al menos funciona de manera estable y es comparablemente fácil de implementar. En general, esto parece una técnica (casi) normal.

¿Hay una manera elegante de hacer algo tan básico con tensorflow java API? Cualquier idea sería apreciada.

Respuestas a la pregunta(1)

Su respuesta a la pregunta