Combinando duas listas por chave usando Thrust

Dadas duas listas de valores-chave, estou tentando combinar os dois lados combinando as chaves e aplicando uma função aos dois valores quando as chaves coincidem. No meu caso, quero multiplicar os valores. Um pequeno exemplo para deixar mais claro:

Left keys:   { 1, 2, 4, 5, 6 }
Left values: { 3, 4, 1, 2, 1 }

Right keys:   { 1, 3, 4, 5, 6, 7 };
Right values: { 2, 1, 1, 4, 1, 2 };

Expected output keys:   { 1, 4, 5, 6 }
Expected output values: { 6, 1, 8, 1 }

Consegui implementar isso na CPU usando C ++ usando o próximo código:

int main() {
    int leftKeys[5] =   { 1, 2, 4, 5, 6 };
    int leftValues[5] = { 3, 4, 1, 2, 1 };
    int rightKeys[6] =   { 1, 3, 4, 5, 6, 7 };
    int rightValues[6] = { 2, 1, 1, 4, 1, 2 };

    int leftIndex = 0, rightIndex = 0;
    std::vector<std::tuple<int, int>> result;
    while (leftIndex < 5 && rightIndex < 6) {
        if (leftKeys[leftIndex] < rightKeys[rightIndex]) {
            leftIndex++;
        }
        if (leftKeys[leftIndex] > rightKeys[rightIndex]) {
            rightIndex++;
        }
        result.push_back(std::make_tuple(leftKeys[leftIndex], leftValues[leftIndex] * rightValues[rightIndex]));
        leftIndex++;
        rightIndex++;
    }

    // Print results
    for (int i = 0; i < result.size(); i++) {
        std::cout << "Key: " << std::get<0>(result[i]) << "; Value: " << std::get<1>(result[i]) << "\n";
    }
}

No entanto, eu tenho as chaves e valores de entrada no Thrust'sdevice_vectorse eu também precisar dos resultados na GPU. Portanto, seria mais eficiente se eu não precisasse copiar todas as entradas para o host e todas as saídas de volta para o dispositivo.

O problema é que não consigo encontrar uma função Thrust que possa ser usada para combinar duas listas usando um conjunto de chaves (e aplicar uma função aos dois valores). Existe uma função desse tipo ou existe uma maneira fácil de implementá-la, devo fazer isso no host?

Atualizar:

As seguintes suposições podem ser feitas sobre a entrada:

As chaves são sempre classificadas.Não existem chaves duplicadas em uma única lista (entre as listas, é claro que existem chaves duplicadas, caso contrário, o resultado estaria vazio).

Atualização 2:

Ao implementar a segunda abordagem na resposta de @ Robert, fico preso na transformação. Meu código até agora está abaixo:

struct multiply_transformation : public thrust::binary_function<std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>>
{
    __host__ __device__
        thrust::tuple<int, int> operator()(thrust::tuple<int, int> d_left, thrust::tuple<int, int> d_right)
    {
        if (thrust::get<0>(d_left) == thrust::get<0>(d_right)) {
            return thrust::make_tuple(thrust::get<0>(d_left), thrust::get<1>(d_left) * thrust::get<1>(d_right));
        }
        return thrust::make_tuple(-1, -1);
    }
};


thrust::device_vector<int> d_mergedKeys(h_leftCount + h_rightCount);
thrust::device_vector<int> d_mergedValues(h_leftCount + h_rightCount);
thrust::merge_by_key(d_leftCountKeys.begin(), d_leftCountKeys.begin() + h_leftCount,
    d_rightCountKeys.begin(), d_rightCountKeys.begin() + h_rightCount,
    d_leftCounts.begin(), d_rightCounts.begin(), d_mergedKeys.begin(), d_mergedValues.begin());

typedef thrust::tuple<int, int> IntTuple;
thrust::zip_iterator<IntTuple> d_zippedCounts(thrust::make_tuple(d_mergedKeys.begin(), d_mergedValues.begin()));
thrust::zip_iterator<IntTuple> d_zippedCountsOffset(d_zippedCounts + 1);

multiply_transformation transformOperator;
thrust::device_vector<IntTuple> d_transformedResult(h_leftCount + h_rightCount);
thrust::transform(d_zippedCounts, d_zippedCounts + h_leftCount + h_rightCount - 1, d_zippedCountsOffset, d_transformedResult.begin(), transformOperator);

No entanto, recebo o erro de que nenhuma função sobrecarregadathrust::transform corresponde à lista de argumentos. No código acimah_leftCount eh_rightCount são os tamanhos das entradas esquerda e direita.d_leftCountKeys, d_rightCountKeys, d_leftCountsed_rightCounts estãothrust::device_vector<int>.

questionAnswers(3)

yourAnswerToTheQuestion