xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision 0226ec35a99031c5bdd82a492055a25793359ffb)
1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2 #define PETSC_CUPM_THRUST_UTILITY_HPP
3 
4 #if defined(__cplusplus)
  #include <petsclog.h>         // PetscLogGpuTimeBegin()/End()
  #include <petscerror.h>       // SETERRQ()
  #include <petscdevice_cupm.h> // PETSC_USING_NVCC

  #include <thrust/system_error.h>     // thrust::system_error
  #include <thrust/execution_policy.h> // thrust::cuda/hip::par

  #include <exception> // std::exception
  #include <new>       // std::bad_alloc
11 
12 namespace Petsc
13 {
14 
15 namespace device
16 {
17 
18 namespace cupm
19 {
20 
21   #if PetscDefined(USING_NVCC)
22     #if !defined(THRUST_VERSION)
23       #error "THRUST_VERSION not defined!"
24     #endif
25     #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
26       #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
27     #else
28       #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
29     #endif
30   #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
31     #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
32   #else
33     #define PETSC_THRUST_CALL_PAR_ON(func, s, ...) func(__VA_ARGS__)
34   #endif
35 
36 namespace detail
37 {
38 
39 struct PetscLogGpuTimer {
40   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
41   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
42 };
43 
44 } // namespace detail
45 
46   #define THRUST_CALL(...) \
47     [&] { \
48       const auto timer = ::Petsc::device::cupm::detail::PetscLogGpuTimer{}; \
49       return PETSC_THRUST_CALL_PAR_ON(__VA_ARGS__); \
50     }()
51 
52   #define PetscCallThrust(...) \
53     do { \
54       try { \
55         { \
56           __VA_ARGS__; \
57         } \
58       } catch (const thrust::system_error &ex) { \
59         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
60       } \
61     } while (0)
62 
63 } // namespace cupm
64 
65 } // namespace device
66 
67 } // namespace Petsc
68 
69 #endif // __cplusplus
70 
71 #endif // PETSC_CUPM_THRUST_UTILITY_HPP
72