Further chance of optimization of Thrust operation of CUDA kernel

77 Views Asked by At

I have a CUDA kernel which essentially looks like the following.

__global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
  int index = blockIdx.x*blockDim.x + threadIdx.x;

  if (index >= N) return;

  double _crit = crit[index];
  for (int i=0; i<nums[index]; i++) {
    double _res = vals[index*K + i];
    if (data[index*K + i] >= _crit) { _res = 0.0; }

    vals[index*K + i] = _res;
  }
}

This kernel evaluates vals[N*K] based on its data[N*K] compared to crit[N], and the comparison is conducted on the first nums[N] elements of the vals's segment (width K). If the data is smaller than crit, it leaves vals unchanged.

For example, data under consideration will look like the following

  int N = 3;
  int K = 5;

  vals[ 0] = 1.0; data[ 0] = 5.1; crit[0] = 5.0; nums[0] = 3;
  vals[ 1] = 1.0; data[ 1] = 4.9;
  vals[ 2] = 1.0; data[ 2] = 3.0;
  vals[ 3] = 0.0; data[ 3] = 0.0;
  vals[ 4] = 0.0; data[ 4] = 0.0;
  //-----------------------
  vals[ 5] = 1.0; data[ 5] = 2.9; crit[1] = 3.0; nums[1] = 2;
  vals[ 6] = 1.0; data[ 6] = 3.1;
  vals[ 7] = 0.0; data[ 7] = 0.0;
  vals[ 8] = 0.0; data[ 8] = 0.0;
  vals[ 9] = 0.0; data[ 9] = 0.0;
  //-----------------------
  vals[10] = 1.0; data[10] = 8.1; crit[2] = 9.0; nums[2] = 5;
  vals[11] = 1.0; data[11] = 7.8;
  vals[12] = 1.0; data[12] = 9.1;
  vals[13] = 1.0; data[13] = 200.;
  vals[14] = 1.0; data[14] = -1.0;

I noticed that this kind of operation is one of top 3 time-consuming kernels, and am considering Thrust-based acceleration.

What I came up with so far looks like the following. It uses expand provided on Thrust samples (https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu).

struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                        // vals   data   1/K 1%K nums crit
  __host__ __device__                   // 0      1      2   3   4    5
    double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
      double res;

      if (thrust::get<2>(t) >= thrust::get<4>(t)) {
        res = thrust::get<0>(t);  // do nothing
      }else {
        if (thrust::get<3>(t) >= thrust::get<4>(t)) {
          res = thrust::get<0>(t); // do nothing
        }else {
          double tmp = thrust::get<0>(t);
          if (thrust::get<1>(t) >= thrust::get<5>(t)) { tmp = 0.0; }
          res = tmp;
        }
      }

      return res;
    }
};

int main() {

  using namespace thrust::placeholders;

  thrust::device_vector<double> vals(N*K);
  thrust::device_vector<double> data(N*K);
  thrust::device_vector<double> crit(N);
  thrust::device_vector<int>    nums(N);

  thrust::device_vector<double> res(N*K);

  // ... fill values ...

  thrust::device_vector<int>    nums_expand(N*K);
  thrust::device_vector<double> crit_expand(N*K);

  // 'expand()' does something like [1,2,3] -> [1,1,1,2,2,2,3,3,3]
  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         nums.begin(),
         nums_expand.begin());

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         crit.begin(),
         crit_expand.begin());

  thrust::transform(thrust::make_zip_iterator(vals.begin(),
                                              data.begin(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // index related to N
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // index related to K
                                              nums_expand.begin(),
                                              crit_expand.begin()),
                    thrust::make_zip_iterator(vals.end(),
                                              data.end(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K,
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K,
                                              nums_expand.end(),
                                              crit_expand.end()),
                    res.begin(),
                    myOp());

  ...

}

When I tried this with arbitrary values in the arrays with sets of [N,K] = [1000,256], [10000,256], [50000,256], [100000,256], already the performance is satisfactory.

enter image description here

But I wonder if there is any further chance of speed-up with my Thrust operations. I am expanding some values to take them into if statements, but maybe this can be avoided by permutation_iterator and so on, but I cannot come up with how. Also, I am doing _1/K, _1%K stuff to get the global and local index of the elements, which could be somehow avoided with more clever mind.

At least, for the cosmetics point of view, I would love to insert expand(...) into thrust::transform(...) directly without having to define another vector such as nums_expand.

Any suggestions for any chance of improvements are welcome.

Full code used for the comparison

//https://stackoverflow.com/questions/31955505/can-thrust-transform-reduce-work-with-2-arrays%5B/url%5D

#include <thrust/device_vector.h>

#include <thrust/reduce.h>
#include <thrust/gather.h>
#include <thrust/copy.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
//#include <thrust/execution_policy.h>
#include <iostream>
#include <iomanip>

#include <thrust/transform.h>
#include <thrust/functional.h>

#include <helper_timer.h>

/////////  https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu //////////
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator>
OutputIterator expand(InputIterator1 first1,
                      InputIterator1 last1,
                      InputIterator2 first2,
                      OutputIterator output)
{
  typedef typename thrust::iterator_difference<InputIterator1>::type difference_type;
  
  difference_type input_size  = thrust::distance(first1, last1);
  difference_type output_size = thrust::reduce(first1, last1);

  // scan the counts to obtain output offsets for each input element
  thrust::device_vector<difference_type> output_offsets(input_size, 0);
  thrust::exclusive_scan(first1, last1, output_offsets.begin()); 

  // scatter the nonzero counts into their corresponding output positions
  thrust::device_vector<difference_type> output_indices(output_size, 0);
  thrust::scatter_if
    (thrust::counting_iterator<difference_type>(0),
     thrust::counting_iterator<difference_type>(input_size),
     output_offsets.begin(),
     first1,
     output_indices.begin());

  // compute max-scan over the output indices, filling in the holes
  thrust::inclusive_scan
    (output_indices.begin(),
     output_indices.end(),
     output_indices.begin(),
     thrust::maximum<difference_type>());

  // gather input values according to index array (output = first2[output_indices])
  thrust::gather(output_indices.begin(),
                 output_indices.end(),
                 first2,
                 output);

  // return output + output_size
  thrust::advance(output, output_size);
  return output;
}

/////////////////////////////////////////////////////////////////////////////////////

template<typename T>
void print_vector(T& vec) {
  for (const auto& elem : vec) {
    std::cout << std::setw(5) << elem; 
  }
  std::cout << std::endl;
}

void printSdkTimer(StopWatchInterface **timer, int average) {
  float fAvgSeconds =
    ((float)1.0e-3 * (float)sdkGetTimerValue(timer) / (float)average);
  printf(" - Elapsed time: %.5f sec \n", fAvgSeconds);
}

struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                  // vals   data   1/K 1%K nums crit
  __host__ __device__             // 0      1      2   3   4    5
    double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
      double res;

      if (thrust::get<2>(t) >= thrust::get<4>(t)) {
        res = thrust::get<0>(t);  // do nothing
      }else {
        if (thrust::get<3>(t) >= thrust::get<4>(t)) {
          res = thrust::get<0>(t); // do nothing
        }else {
          double tmp = thrust::get<0>(t);
          if (thrust::get<1>(t) >= thrust::get<5>(t)) { tmp = 0.0; }
          res = tmp;
        }
      }

      return res;
    }
};

__global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
  int index = blockIdx.x*blockDim.x + threadIdx.x;

  if (index >= N) return;

  double _crit = crit[index];
  for (int i=0; i<nums[index]; i++) {
    double _res = vals[index*K + i];
    if (data[index*K + i] >= _crit) { _res = 0.0; }  

    vals[index*K + i] = _res;
  }
}

int main(int argc, char **argv) {

  using namespace thrust::placeholders;

  int N = atoi(argv[1]); 
  int K = atoi(argv[2]); 

  std::cout << "N " << N << " K " << K << std::endl;

  thrust::device_vector<double> vals(N*K);
  thrust::device_vector<double> data(N*K);
  thrust::device_vector<double> crit(N);
  thrust::device_vector<int>    nums(N);

  thrust::device_vector<double> res(N*K);

  for (int i=0; i<N; i++) {
    crit[i] = 101.0; // arbitrary
    nums[i] = 200;   // arbitrary number less than 256
    for (int j=0; j<K; j++) {
      vals[i*K + j] = (double)(i*K + j); // arbitrary
      data[i*K + j] = (double)(i*K + j); // arbitrary
    }
  }

  // to be used for kernel
  thrust::device_vector<double> vals2 = vals;
  thrust::device_vector<double> data2 = data;
  thrust::device_vector<double> crit2 = crit;
  thrust::device_vector<int>    nums2 = nums;

  StopWatchInterface *timer=NULL;
 
//--- 1) thrust
  thrust::device_vector<int>    nums_expand(N*K);
  thrust::device_vector<double> crit_expand(N*K);

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         nums.begin(),
         nums_expand.begin());

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         crit.begin(),
         crit_expand.begin());

  sdkCreateTimer(&timer);
  sdkStartTimer(&timer);

  thrust::transform(thrust::make_zip_iterator(vals.begin(), 
                                              data.begin(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // for N
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // for K
                                              nums_expand.begin(),
                                              crit_expand.begin()),
                    thrust::make_zip_iterator(vals.end(), 
                                              data.end(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K, 
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K, 
                                              nums_expand.end(),
                                              crit_expand.end()),
                    res.begin(),
                    myOp());

  sdkStopTimer(&timer);
  printSdkTimer(&timer,1);

  cudaDeviceSynchronize();
  sdkResetTimer(&timer);
  sdkStartTimer(&timer);

//--- 2) kernel
  double *raw_vals2 = thrust::raw_pointer_cast(vals2.data());
  double *raw_data2 = thrust::raw_pointer_cast(data2.data());
  double *raw_crit2 = thrust::raw_pointer_cast(crit2.data());
  int    *raw_nums2 = thrust::raw_pointer_cast(nums2.data());

  int Nthreads = 256;
  int Nblocks = (N*K - 1) / Nthreads + 1;
  myOpKernel<<<Nblocks,Nthreads>>>(raw_vals2, raw_data2, raw_nums2, raw_crit2, N, K);

  cudaDeviceSynchronize();

  sdkStopTimer(&timer);
  printSdkTimer(&timer,1);

  sdkDeleteTimer(&timer);

  return 0;
}
1

There are 1 best solutions below

2
Abator Abetor On BEST ANSWER

Below is some modified benchmark code with improved kernels. Compiled with nvcc --extended-lambda -arch=sm_89 -O3 main.cu -o main

Since the timer is not included in your code, I use cudaEvents instead. Data buffers are initialized in host vectors to avoid millions of memcopies. I also noticed that the thrust approach does not produce identical results to your kernel for large N.

I added two kernels. myOpKernel3 simply uses 1 threadblock per index to access the num[index] values.

myOpKernel4 uses 1 thread per output element. this requires a prefix sum of nums, and the computation of index per thread. I chose to precompute the indices. An alternative approach would be to perform a binary search on the prefix sum within the kernel.

For full segments, i.e. nums[i] = 256, the output is

N 100000 K 256
expandtime 6.28806 ms
thrusttransformtime 1.05165 ms
myOpKerneltime 2.04301 ms
results from thrust and myOpKernel do not match
myOpKerneltime3 0.662016 ms
myOpKernel4_setuptime 0.723872 ms
myOpKerneltime4 0.785408 ms

For nums[i] = 128

N 100000 K 256
expandtime 6.2976 ms
thrusttransformtime 1.05472 ms
myOpKerneltime 1.04054 ms
results from thrust and myOpKernel do not match
myOpKerneltime3 0.337952 ms
myOpKernel4_setuptime 0.369152 ms
myOpKerneltime4 0.386048 ms

nums[i] = 4

N 100000 K 256
expandtime 6.2936 ms
thrusttransformtime 1.05165 ms
myOpKerneltime 0.273536 ms
results from thrust and myOpKernel do not match
myOpKerneltime3 0.119104 ms
myOpKernel4_setuptime 0.293824 ms
myOpKerneltime4 0.066848 ms

I did not test non-uniform segment sizes. Note that performance costs of temporary memory allocations and thrust calls can be reduced by using custom allocators in conjunction with thrust's thrust::cuda::par_nosync execution policy.

//https://stackoverflow.com/questions/31955505/can-thrust-transform-reduce-work-with-2-arrays%5B/url%5D

#include <thrust/device_vector.h>

#include <thrust/reduce.h>
#include <thrust/gather.h>
#include <thrust/copy.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>

#include <iostream>
#include <iomanip>

#include <thrust/transform.h>
#include <thrust/functional.h>


/////////  https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu //////////
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator>
OutputIterator expand(InputIterator1 first1,
                      InputIterator1 last1,
                      InputIterator2 first2,
                      OutputIterator output)
{
  typedef typename thrust::iterator_difference<InputIterator1>::type difference_type;
  
  difference_type input_size  = thrust::distance(first1, last1);
  difference_type output_size = thrust::reduce(first1, last1);

  // scan the counts to obtain output offsets for each input element
  thrust::device_vector<difference_type> output_offsets(input_size, 0);
  thrust::exclusive_scan(first1, last1, output_offsets.begin()); 

  // scatter the nonzero counts into their corresponding output positions
  thrust::device_vector<difference_type> output_indices(output_size, 0);
  thrust::scatter_if
    (thrust::counting_iterator<difference_type>(0),
     thrust::counting_iterator<difference_type>(input_size),
     output_offsets.begin(),
     first1,
     output_indices.begin());

  // compute max-scan over the output indices, filling in the holes
  thrust::inclusive_scan
    (output_indices.begin(),
     output_indices.end(),
     output_indices.begin(),
     thrust::maximum<difference_type>());

  // gather input values according to index array (output = first2[output_indices])
  thrust::gather(output_indices.begin(),
                 output_indices.end(),
                 first2,
                 output);

  // return output + output_size
  thrust::advance(output, output_size);
  return output;
}

/////////////////////////////////////////////////////////////////////////////////////

template<typename T>
void print_vector(T& vec) {
  for (const auto& elem : vec) {
    std::cout << std::setw(5) << elem; 
  }
  std::cout << std::endl;
}

struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                  // vals   data   1/K 1%K nums crit
  __host__ __device__             // 0      1      2   3   4    5
    double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
      double res;

      if (thrust::get<2>(t) >= thrust::get<4>(t)) {
        res = thrust::get<0>(t);  // do nothing
      }else {
        if (thrust::get<3>(t) >= thrust::get<4>(t)) {
          res = thrust::get<0>(t); // do nothing
        }else {
          double tmp = thrust::get<0>(t);
          if (thrust::get<1>(t) >= thrust::get<5>(t)) { tmp = 0.0; }
          res = tmp;
        }
      }

      return res;
    }
};

__global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
  int index = blockIdx.x*blockDim.x + threadIdx.x;

  if (index >= N) return;

  double _crit = crit[index];
  for (int i=0; i<nums[index]; i++) {
    double _res = vals[index*K + i];
    if (data[index*K + i] >= _crit) { _res = 0.0; }  

    vals[index*K + i] = _res;
  }
}


  __global__ void myOpKernel3(
      double * __restrict__ vals, 
      const double * __restrict__ data, 
      const int * __restrict__ nums, 
      const double * __restrict__ crit, 
      int N, 
      int K
    ){
        for(int index = blockIdx.x; index < N; index += gridDim.x){   
            const double _crit = crit[index];
            const int num = nums[index];
            for(int i = threadIdx.x; i < num; i += blockDim.x){
                double _res = vals[index*K + i];
                if (data[index*K + i] >= _crit) { _res = 0.0; }          
                vals[index*K + i] = _res;
            }
        }
  }

  __global__ 
  void myOpKernel4(
    double * __restrict__ vals, 
    const double * __restrict__ data, 
    const int * __restrict__ nums, 
    const double * __restrict__ crit, 
    const int* __restrict__ numsPrefixSum,
    const int* __restrict__ indexForThread,
    int totalnums,
    int N, 
    int K
  ){
      const int tid = threadIdx.x + blockIdx.x * blockDim.x;
      const int numValid = totalnums;
      if(tid < numValid){
          const int index = indexForThread[tid];
          const int i = tid - numsPrefixSum[index];
          const double _crit = crit[index];
          double _res = vals[index*K + i];
            if (data[index*K + i] >= _crit) { _res = 0.0; }          
            vals[index*K + i] = _res;
      }
}

int main(int argc, char **argv) {

  using namespace thrust::placeholders;

  int N = atoi(argv[1]); 
  int K = atoi(argv[2]); 

  std::cout << "N " << N << " K " << K << std::endl;

  thrust::host_vector<double> h_vals(N*K);
  thrust::host_vector<double> h_data(N*K);
  thrust::host_vector<double> h_crit(N);
  thrust::host_vector<int>    h_nums(N);

  for (int i=0; i<N; i++) {
    h_crit[i] = 101.0; // arbitrary
    h_nums[i] = 4;   // arbitrary number less than 256
    for (int j=0; j<K; j++) {
        h_vals[i*K + j] = (double)(i*K + j); // arbitrary
        h_data[i*K + j] = (double)(i*K + j); // arbitrary
    }
  }

  thrust::device_vector<double> vals = h_vals;
  thrust::device_vector<double> data = h_data;
  thrust::device_vector<double> crit = h_crit;
  thrust::device_vector<int>    nums = h_nums;

  thrust::device_vector<double> res(vals.size());



  cudaEvent_t eventA; cudaEventCreate(&eventA);
  cudaEvent_t eventB; cudaEventCreate(&eventB);
 
  //--- 1) thrust
  cudaEventRecord(eventA);
  thrust::device_vector<int>    nums_expand(N*K);
  thrust::device_vector<double> crit_expand(N*K);


  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         nums.begin(),
         nums_expand.begin());

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         crit.begin(),
         crit_expand.begin());

  cudaEventRecord(eventB);
  cudaEventSynchronize(eventB);
  float expandtime; cudaEventElapsedTime(&expandtime, eventA, eventB);
  std::cout << "expandtime " << expandtime << " ms\n";

  cudaEventRecord(eventA);

  thrust::transform(thrust::make_zip_iterator(vals.begin(), 
                                              data.begin(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // for N
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // for K
                                              nums_expand.begin(),
                                              crit_expand.begin()),
                    thrust::make_zip_iterator(vals.end(), 
                                              data.end(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K, 
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K, 
                                              nums_expand.end(),
                                              crit_expand.end()),
                    res.begin(),
                    myOp());


cudaEventRecord(eventB);
cudaEventSynchronize(eventB);
float thrusttransformtime; cudaEventElapsedTime(&thrusttransformtime, eventA, eventB);
std::cout << "thrusttransformtime " << thrusttransformtime << " ms\n";
  cudaDeviceSynchronize();

//   std::cout << "vals after thrust\n";
//   for(int i = 0; i < res.size(); i++){
//     std::cout << res[i] << " ";
//   }
//   std::cout << "\n";


//--- 2) kernel
thrust::device_vector<double> vals2 = h_vals;
thrust::device_vector<double> data2 = h_data;
thrust::device_vector<double> crit2 = h_crit;
thrust::device_vector<int>    nums2 = h_nums;

  cudaEventRecord(eventA);

  int Nthreads = 256;
  int Nblocks = (N*K - 1) / Nthreads + 1;
  myOpKernel<<<Nblocks,Nthreads>>>(vals2.data().get(), data2.data().get(), nums2.data().get(), crit2.data().get(), N, K);

  cudaEventRecord(eventB);
cudaEventSynchronize(eventB);
float myOpKerneltime; cudaEventElapsedTime(&myOpKerneltime, eventA, eventB);
std::cout << "myOpKerneltime " << myOpKerneltime << " ms\n";

  cudaDeviceSynchronize();

  if(res == vals2){
      std::cout << "results from thrust and myOpKernel match\n";
  }else{
    std::cout << "results from thrust and myOpKernel do not match\n";
  }


  

  {
    //1 block per index

    thrust::device_vector<double> vals_new = h_vals;
      thrust::device_vector<double> data_new = h_data;
      thrust::device_vector<double> crit_new = h_crit;
      thrust::device_vector<int>    nums_new = h_nums;

      cudaEventRecord(eventA);

      int Nthreads = 256;
      int Nblocks = N;
      myOpKernel3<<<Nblocks,Nthreads>>>(vals_new.data().get(), data_new.data().get(), nums_new.data().get(), crit_new.data().get(), N, K);

      cudaEventRecord(eventB);
      cudaEventSynchronize(eventB);
      float myOpKerneltime3; cudaEventElapsedTime(&myOpKerneltime3, eventA, eventB);
      std::cout << "myOpKerneltime3 " << myOpKerneltime3 << " ms\n";

      cudaDeviceSynchronize();

      assert(vals_new == vals2);
    }
      {
        //1 thread per output position
    
        thrust::device_vector<double> vals_new = h_vals;
          thrust::device_vector<double> data_new = h_data;
          thrust::device_vector<double> crit_new = h_crit;
          thrust::device_vector<int>    nums_new = h_nums;

          
          cudaEventRecord(eventA);
          thrust::device_vector<int> numsPrefixSum(N+1);
          numsPrefixSum[0] = 0;
          thrust::inclusive_scan(
              nums_new.begin(),
              nums_new.end(),
              numsPrefixSum.begin() + 1
            );
        const int totalNums = numsPrefixSum.back();
        thrust::device_vector<int> indexForThread(totalNums, 0);

          thrust::scatter_if(
                thrust::make_counting_iterator(0),
                thrust::make_counting_iterator(0) + N, 
                numsPrefixSum.begin(),
                thrust::make_transform_iterator(
                    nums_new.begin(), 
                    [] __host__ __device__ (int i){return i > 0;}
                ),
                indexForThread.begin()
            );
        
            thrust::inclusive_scan(
                indexForThread.begin(), 
                indexForThread.begin() + totalNums, 
                indexForThread.begin(), 
                thrust::maximum<int>{}
            );

          cudaEventRecord(eventB);
          cudaEventSynchronize(eventB);
          float myOpKernel4_setuptime; cudaEventElapsedTime(&myOpKernel4_setuptime, eventA, eventB);
          std::cout << "myOpKernel4_setuptime " << myOpKernel4_setuptime << " ms\n";
    
          cudaEventRecord(eventA);
    
          int Nthreads = 256;
          int Nblocks = (totalNums + Nthreads - 1) / Nthreads;
          myOpKernel4<<<Nblocks,Nthreads>>>(
            vals_new.data().get(), 
            data_new.data().get(), 
            nums_new.data().get(), 
            crit_new.data().get(),
            numsPrefixSum.data().get(),
            indexForThread.data().get(),
            totalNums,
            N, 
            K
          );

          cudaEventRecord(eventB);
          cudaEventSynchronize(eventB);
          float myOpKerneltime4; cudaEventElapsedTime(&myOpKerneltime4, eventA, eventB);
          std::cout << "myOpKerneltime4 " << myOpKerneltime4 << " ms\n";
    
          cudaDeviceSynchronize();
    
          assert(vals_new == vals2);
        }

  cudaEventDestroy(eventA);
  cudaEventDestroy(eventB);

  return 0;
}