xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision fc47f7259de0629fb6058a3a6076517fd44721f2)
1 #pragma once
2 
#include <petsclog.h>         // PetscLogGpuTimeBegin()/End()
#include <petscsys.h>         // SETERRQ()
#include <petscdevice_cupm.h> // PETSC_USING_NVCC

#include <thrust/version.h>          // THRUST_VERSION
#include <thrust/system_error.h>     // thrust::system_error
#include <thrust/execution_policy.h> // thrust::cuda/hip::par

#include <exception> // std::exception
#include <new>       // std::bad_alloc
10 
11 namespace Petsc
12 {
13 
14 namespace device
15 {
16 
17 namespace cupm
18 {
19 
20 #if PetscDefined(USING_NVCC)
21   #if !defined(THRUST_VERSION)
22     #error "THRUST_VERSION not defined!"
23   #endif
24   #if THRUST_VERSION >= 101600
25     #define PETSC_THRUST_HAS_ASYNC                 1
26     #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
27   #else
28     #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
29   #endif
30 #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
31   #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
32 #else
33   #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(__VA_ARGS__)
34 #endif
35 
36 #ifndef PETSC_THRUST_HAS_ASYNC
37   #define PETSC_THRUST_HAS_ASYNC 0
38 #endif
39 
40 namespace detail
41 {
42 
43 struct PetscLogGpuTimer {
44   PetscLogGpuTimer() noexcept
45   {
46     PetscFunctionBegin;
47     PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin());
48     PetscFunctionReturnVoid();
49   }
50 
51   ~PetscLogGpuTimer() noexcept
52   {
53     PetscFunctionBegin;
54     PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd());
55     PetscFunctionReturnVoid();
56   }
57 };
58 
59 } // namespace detail
60 
/*
  THRUST_CALL(func, s, args...) - call the Thrust algorithm func on stream s, logging it as GPU time

  Expands to an immediately-invoked lambda (capturing by reference, so args... may name caller locals) that:
    1. constructs a ::Petsc::device::cupm::detail::PetscLogGpuTimer, which calls PetscLogGpuTimeBegin();
    2. evaluates PETSC_THRUST_CALL_PAR_ON(func, s, args...) and returns its result by value;
    3. destroys the timer on exit, which calls PetscLogGpuTimeEnd().

  The guard variable has a deliberately unusual name so it cannot shadow a caller variable used in
  args.... It is direct-initialized, so the guard is never copied or moved.
  Wrap the expression in PetscCallThrust() to turn exceptions thrown by Thrust into PETSc errors.
*/
#define THRUST_CALL(...) \
  [&] { \
    const ::Petsc::device::cupm::detail::PetscLogGpuTimer petsc_thrust_call_timer_{}; \
    return PETSC_THRUST_CALL_PAR_ON(__VA_ARGS__); \
  }()
66 
/*
  PetscCallThrust(statement) - run statement and convert C++ exceptions it throws into PETSc errors

  The handlers use SETERRQ(), so this must be used inside a function returning PetscErrorCode.
    - thrust::system_error     -> PETSC_ERR_LIB
    - std::bad_alloc           -> PETSC_ERR_MEM. Thrust reports failed temporary allocations with
                                  thrust::system::detail::bad_alloc, which derives from std::bad_alloc
                                  and not from thrust::system_error.
    - any other std::exception -> PETSC_ERR_LIB
  Without the last two handlers those exceptions would escape into PETSc's C error-code path. The
  handlers are ordered from most to least specific: thrust::system_error derives from
  std::runtime_error, so it must be tried before std::exception.
*/
#define PetscCallThrust(...) \
  do { \
    try { \
      { \
        __VA_ARGS__; \
      } \
    } catch (const thrust::system_error &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
    } catch (const std::bad_alloc &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_MEM, "Thrust memory allocation failed: %s", ex.what()); \
    } catch (const std::exception &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Unexpected C++ exception from Thrust call: %s", ex.what()); \
    } \
  } while (0)
77 
78 } // namespace cupm
79 
80 } // namespace device
81 
82 } // namespace Petsc
83