/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Modifications Copyright© 2019-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/functional.h>
#include <thrust/set_operations.h>
#include <thrust/sort.h>

#include <unittest/unittest.h>

// Small, hand-checked case for set_intersection_by_key with a descending
// ordering (thrust::greater<T>): verifies the returned end iterators, the
// intersected keys, and that each output value is taken from the element of
// a_val paired with the matching key in the first key range.
template <typename Vector>
void TestSetIntersectionByKeyDescendingSimple()
{
  using T        = typename Vector::value_type;
  using Iterator = typename Vector::iterator;

  // Both key ranges are sorted in descending order, matching thrust::greater<T>.
  Vector a_key{4, 2, 0}, b_key{4, 3, 3, 0};
  // Distinct values so the test can detect which a_val element is copied for
  // each matching key; an all-equal value range would hide a wrong selection.
  Vector a_val{0, 1, 2};

  // Keys 4 and 0 occur in both ranges; their values are a_val[0] and a_val[2].
  Vector ref_key{4, 0}, ref_val{0, 2};
  Vector result_key(2), result_val(2);

  thrust::pair<Iterator, Iterator> end = thrust::set_intersection_by_key(
    a_key.begin(),
    a_key.end(),
    b_key.begin(),
    b_key.end(),
    a_val.begin(),
    result_key.begin(),
    result_val.begin(),
    thrust::greater<T>());

  // The outputs were sized exactly for the expected intersection, so both
  // returned iterators must land at the end of their output ranges.
  ASSERT_EQUAL_QUIET(result_key.end(), end.first);
  ASSERT_EQUAL_QUIET(result_val.end(), end.second);
  ASSERT_EQUAL(ref_key, result_key);
  ASSERT_EQUAL(ref_val, result_val);
}
DECLARE_VECTOR_UNITTEST(TestSetIntersectionByKeyDescendingSimple);

template <typename T>
void TestSetIntersectionByKeyDescending(const size_t n)
{
  thrust::host_vector<T> temp = unittest::random_integers<T>(2 * n);
  thrust::host_vector<T> h_a_key(temp.begin(), temp.begin() + n);
  thrust::host_vector<T> h_b_key(temp.begin() + n, temp.end());

  thrust::sort(h_a_key.begin(), h_a_key.end(), thrust::greater<T>());
  thrust::sort(h_b_key.begin(), h_b_key.end(), thrust::greater<T>());

  thrust::host_vector<T> h_a_val = unittest::random_integers<T>(h_a_key.size());

  thrust::device_vector<T> d_a_key = h_a_key;
  thrust::device_vector<T> d_b_key = h_b_key;

  thrust::device_vector<T> d_a_val = h_a_val;

  thrust::host_vector<T> h_result_key(n), h_result_val(n);
  thrust::device_vector<T> d_result_key(n), d_result_val(n);

  thrust::pair<typename thrust::host_vector<T>::iterator, typename thrust::host_vector<T>::iterator> h_end;

  thrust::pair<typename thrust::device_vector<T>::iterator, typename thrust::device_vector<T>::iterator> d_end;

  h_end = thrust::set_intersection_by_key(
    h_a_key.begin(),
    h_a_key.end(),
    h_b_key.begin(),
    h_b_key.end(),
    h_a_val.begin(),
    h_result_key.begin(),
    h_result_val.begin(),
    thrust::greater<T>());
  h_result_key.erase(h_end.first, h_result_key.end());
  h_result_val.erase(h_end.second, h_result_val.end());

  d_end = thrust::set_intersection_by_key(
    d_a_key.begin(),
    d_a_key.end(),
    d_b_key.begin(),
    d_b_key.end(),
    d_a_val.begin(),
    d_result_key.begin(),
    d_result_val.begin(),
    thrust::greater<T>());
  d_result_key.erase(d_end.first, d_result_key.end());
  d_result_val.erase(d_end.second, d_result_val.end());

  ASSERT_EQUAL(h_result_key, d_result_key);
  ASSERT_EQUAL(h_result_val, d_result_val);
}
DECLARE_VARIABLE_UNITTEST(TestSetIntersectionByKeyDescending);
