1 #pragma once
2
3 #include "vecmpicupm.hpp"
4
5 #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
6
7 #include <petsc/private/sfimpl.h> // for vec->localupdate (_p_VecScatter) in duplicate()
8
9 namespace Petsc
10 {
11
12 namespace vec
13 {
14
15 namespace cupm
16 {
17
18 namespace impl
19 {
20
21 template <device::cupm::DeviceType T>
VecIMPLCast_(Vec v)22 inline Vec_MPI *VecMPI_CUPM<T>::VecIMPLCast_(Vec v) noexcept
23 {
24 return static_cast<Vec_MPI *>(v->data);
25 }
26
27 template <device::cupm::DeviceType T>
VECIMPLCUPM_()28 inline constexpr VecType VecMPI_CUPM<T>::VECIMPLCUPM_() noexcept
29 {
30 return VECMPICUPM();
31 }
32
33 template <device::cupm::DeviceType T>
VECIMPL_()34 inline constexpr VecType VecMPI_CUPM<T>::VECIMPL_() noexcept
35 {
36 return VECMPI;
37 }
38
39 template <device::cupm::DeviceType T>
VecDestroy_IMPL_(Vec v)40 inline PetscErrorCode VecMPI_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
41 {
42 PetscFunctionBegin;
43 PetscCall(VecSeq_T::ClearAsyncFunctions(v));
44 PetscCall(VecDestroy_MPI(v));
45 PetscFunctionReturn(PETSC_SUCCESS);
46 }
47
48 template <device::cupm::DeviceType T>
VecResetArray_IMPL_(Vec v)49 inline PetscErrorCode VecMPI_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
50 {
51 return VecResetArray_MPI(v);
52 }
53
54 template <device::cupm::DeviceType T>
VecPlaceArray_IMPL_(Vec v,const PetscScalar * a)55 inline PetscErrorCode VecMPI_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
56 {
57 return VecPlaceArray_MPI(v, a);
58 }
59
60 template <device::cupm::DeviceType T>
VecCreate_IMPL_Private_(Vec v,PetscBool * alloc_missing,PetscInt nghost,PetscScalar *)61 inline PetscErrorCode VecMPI_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *) noexcept
62 {
63 PetscFunctionBegin;
64 if (alloc_missing) *alloc_missing = PETSC_TRUE;
65 // note host_array is always ignored, we never create it as part of the construction sequence
66 // for VecMPI since we always want to either allocate it ourselves with pinned memory or set
67 // it in Initialize_CUPMBase()
68 PetscCall(VecCreate_MPI_Private(v, PETSC_FALSE, nghost, nullptr));
69 PetscCall(VecSeq_T::InitializeAsyncFunctions(v));
70 PetscFunctionReturn(PETSC_SUCCESS);
71 }
72
73 template <device::cupm::DeviceType T>
CreateMPICUPM_(Vec v,PetscDeviceContext dctx,PetscBool allocate_missing,PetscInt nghost,PetscScalar * host_array,PetscScalar * device_array)74 inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM_(Vec v, PetscDeviceContext dctx, PetscBool allocate_missing, PetscInt nghost, PetscScalar *host_array, PetscScalar *device_array) noexcept
75 {
76 PetscFunctionBegin;
77 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, nghost));
78 PetscCall(Initialize_CUPMBase(v, allocate_missing, host_array, device_array, dctx));
79 PetscFunctionReturn(PETSC_SUCCESS);
80 }
81
82 // ================================================================================== //
83 // //
84 // public methods //
85 // //
86 // ================================================================================== //
87
88 // ================================================================================== //
89 // constructors/destructors //
90
91 // VecCreateMPICUPM()
92 template <device::cupm::DeviceType T>
CreateMPICUPM(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,Vec * v,PetscBool call_set_type)93 inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type) noexcept
94 {
95 PetscFunctionBegin;
96 PetscCall(Create_CUPMBase(comm, bs, n, N, v, call_set_type));
97 PetscFunctionReturn(PETSC_SUCCESS);
98 }
99
100 // VecCreateMPICUPMWithArray[s]()
101 template <device::cupm::DeviceType T>
CreateMPICUPMWithArrays(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,const PetscScalar host_array[],const PetscScalar device_array[],Vec * v)102 inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
103 {
104 PetscDeviceContext dctx;
105
106 PetscFunctionBegin;
107 PetscCall(GetHandles_(&dctx));
108 // do NOT call VecSetType(), otherwise ops->create() -> create() ->
109 // CreateMPICUPM_() is called!
110 PetscCall(CreateMPICUPM(comm, bs, n, N, v, PETSC_FALSE));
111 PetscCall(CreateMPICUPM_(*v, dctx, PETSC_FALSE, 0, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
112 PetscFunctionReturn(PETSC_SUCCESS);
113 }
114
115 // v->ops->duplicate
116 template <device::cupm::DeviceType T>
Duplicate(Vec v,Vec * y)117 inline PetscErrorCode VecMPI_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
118 {
119 const auto vimpl = VecIMPLCast(v);
120 const auto nghost = vimpl->nghost;
121 PetscDeviceContext dctx;
122
123 PetscFunctionBegin;
124 PetscCall(GetHandles_(&dctx));
125 // does not call VecSetType(), we set up the data structures ourselves
126 PetscCall(Duplicate_CUPMBase(v, y, dctx, [=](Vec z) { return CreateMPICUPM_(z, dctx, PETSC_FALSE, nghost); }));
127
128 /* save local representation of the parallel vector (and scatter) if it exists */
129 if (const auto locrep = vimpl->localrep) {
130 const auto yimpl = VecIMPLCast(*y);
131 auto &ylocrep = yimpl->localrep;
132 PetscScalar *array;
133
134 PetscCall(VecGetArray(*y, &array));
135 PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, v->map->bs, v->map->n + nghost, array, &ylocrep));
136 PetscCall(VecRestoreArray(*y, &array));
137 ylocrep->ops[0] = locrep->ops[0];
138 if (const auto scatter = (yimpl->localupdate = vimpl->localupdate)) PetscCall(PetscObjectReference(PetscObjectCast(scatter)));
139 }
140 PetscFunctionReturn(PETSC_SUCCESS);
141 }
142
143 // v->ops->bintocpu
144 template <device::cupm::DeviceType T>
BindToCPU(Vec v,PetscBool usehost)145 inline PetscErrorCode VecMPI_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
146 {
147 PetscDeviceContext dctx;
148
149 PetscFunctionBegin;
150 PetscCall(GetHandles_(&dctx));
151 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));
152
153 VecSetOp_CUPM(dot, VecDot_MPI, Dot);
154 VecSetOp_CUPM(mdot, VecMDot_MPI, MDot);
155 VecSetOp_CUPM(norm, VecNorm_MPI, Norm);
156 VecSetOp_CUPM(tdot, VecTDot_MPI, TDot);
157 VecSetOp_CUPM(resetarray, VecResetArray_MPI, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
158 VecSetOp_CUPM(placearray, VecPlaceArray_MPI, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
159 VecSetOp_CUPM(max, VecMax_MPI, Max);
160 VecSetOp_CUPM(min, VecMin_MPI, Min);
161 PetscFunctionReturn(PETSC_SUCCESS);
162 }
163
164 // ================================================================================== //
165 // compute methods //
166
167 template <device::cupm::DeviceType T>
Norm(Vec v,NormType type,PetscReal * z)168 inline PetscErrorCode VecMPI_CUPM<T>::Norm(Vec v, NormType type, PetscReal *z) noexcept
169 {
170 PetscFunctionBegin;
171 PetscCall(VecNorm_MPI_Default(v, type, z, VecSeq_T::Norm));
172 PetscFunctionReturn(PETSC_SUCCESS);
173 }
174
175 template <device::cupm::DeviceType T>
ErrorWnorm(Vec U,Vec Y,Vec E,NormType wnormtype,PetscReal atol,Vec vatol,PetscReal rtol,Vec vrtol,PetscReal ignore_max,PetscReal * norm,PetscInt * norm_loc,PetscReal * norma,PetscInt * norma_loc,PetscReal * normr,PetscInt * normr_loc)176 inline PetscErrorCode VecMPI_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
177 {
178 PetscFunctionBegin;
179 PetscCall(VecErrorWeightedNorms_MPI_Default(U, Y, E, wnormtype, atol, vatol, rtol, vrtol, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc, VecSeq_T::ErrorWnorm));
180 PetscFunctionReturn(PETSC_SUCCESS);
181 }
182
183 template <device::cupm::DeviceType T>
Dot(Vec x,Vec y,PetscScalar * z)184 inline PetscErrorCode VecMPI_CUPM<T>::Dot(Vec x, Vec y, PetscScalar *z) noexcept
185 {
186 PetscFunctionBegin;
187 PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::Dot));
188 PetscFunctionReturn(PETSC_SUCCESS);
189 }
190
191 template <device::cupm::DeviceType T>
TDot(Vec x,Vec y,PetscScalar * z)192 inline PetscErrorCode VecMPI_CUPM<T>::TDot(Vec x, Vec y, PetscScalar *z) noexcept
193 {
194 PetscFunctionBegin;
195 PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::TDot));
196 PetscFunctionReturn(PETSC_SUCCESS);
197 }
198
199 template <device::cupm::DeviceType T>
MDot(Vec x,PetscInt nv,const Vec y[],PetscScalar * z)200 inline PetscErrorCode VecMPI_CUPM<T>::MDot(Vec x, PetscInt nv, const Vec y[], PetscScalar *z) noexcept
201 {
202 PetscFunctionBegin;
203 PetscCall(VecMXDot_MPI_Default(x, nv, y, z, VecSeq_T::MDot));
204 PetscFunctionReturn(PETSC_SUCCESS);
205 }
206
207 template <device::cupm::DeviceType T>
DotNorm2(Vec x,Vec y,PetscScalar * dp,PetscScalar * nm)208 inline PetscErrorCode VecMPI_CUPM<T>::DotNorm2(Vec x, Vec y, PetscScalar *dp, PetscScalar *nm) noexcept
209 {
210 PetscFunctionBegin;
211 PetscCall(VecDotNorm2_MPI_Default(x, y, dp, nm, VecSeq_T::DotNorm2));
212 PetscFunctionReturn(PETSC_SUCCESS);
213 }
214
215 template <device::cupm::DeviceType T>
Max(Vec x,PetscInt * idx,PetscReal * z)216 inline PetscErrorCode VecMPI_CUPM<T>::Max(Vec x, PetscInt *idx, PetscReal *z) noexcept
217 {
218 const MPI_Op ops[] = {MPIU_MAXLOC, MPIU_MAX};
219
220 PetscFunctionBegin;
221 PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Max, ops));
222 PetscFunctionReturn(PETSC_SUCCESS);
223 }
224
225 template <device::cupm::DeviceType T>
Min(Vec x,PetscInt * idx,PetscReal * z)226 inline PetscErrorCode VecMPI_CUPM<T>::Min(Vec x, PetscInt *idx, PetscReal *z) noexcept
227 {
228 const MPI_Op ops[] = {MPIU_MINLOC, MPIU_MIN};
229
230 PetscFunctionBegin;
231 PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Min, ops));
232 PetscFunctionReturn(PETSC_SUCCESS);
233 }
234
235 template <device::cupm::DeviceType T>
SetPreallocationCOO(Vec x,PetscCount ncoo,const PetscInt coo_i[])236 inline PetscErrorCode VecMPI_CUPM<T>::SetPreallocationCOO(Vec x, PetscCount ncoo, const PetscInt coo_i[]) noexcept
237 {
238 PetscDeviceContext dctx;
239
240 PetscFunctionBegin;
241 PetscCall(GetHandles_(&dctx));
242 PetscCall(VecSetPreallocationCOO_MPI(x, ncoo, coo_i));
243 // both of these must exist for this to work
244 PetscCall(VecCUPMAllocateCheck_(x));
245 {
246 const auto vcu = VecCUPMCast(x);
247 const auto vmpi = VecIMPLCast(x);
248
249 // clang-format off
250 PetscCall(
251 SetPreallocationCOO_CUPMBase(
252 x, ncoo, coo_i, dctx,
253 util::make_array(
254 make_coo_pair(vcu->imap2_d, vmpi->imap2, vmpi->nnz2),
255 make_coo_pair(vcu->jmap2_d, vmpi->jmap2, vmpi->nnz2 + 1),
256 make_coo_pair(vcu->perm2_d, vmpi->perm2, vmpi->recvlen),
257 make_coo_pair(vcu->Cperm_d, vmpi->Cperm, vmpi->sendlen)
258 ),
259 util::make_array(
260 make_coo_pair(vcu->sendbuf_d, vmpi->sendbuf, vmpi->sendlen),
261 make_coo_pair(vcu->recvbuf_d, vmpi->recvbuf, vmpi->recvlen)
262 )
263 )
264 );
265 // clang-format on
266 }
267 PetscFunctionReturn(PETSC_SUCCESS);
268 }
269
namespace kernels
{

namespace
{

// Gather kernel: writes buf[i] = vv[perm[i]] for each index i handed out by the
// grid_stride_1D helper (presumably every i in [0, nnz)). SetValuesCOO() uses it to pack the
// entries destined for remote processes into the send buffer.
PETSC_KERNEL_DECL void pack_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz, const PetscCount *PETSC_RESTRICT perm, PetscScalar *PETSC_RESTRICT buf)
{
  Petsc::device::cupm::kernels::util::grid_stride_1D(nnz, [=](PetscCount i) { buf[i] = vv[perm[i]]; });
  return;
}

// Accumulate received remote COO values into xv.
// Forwards to add_coo_values_impl() with ADD_VALUES (remote contributions are always added).
// The lambda passed along maps i -> imap2[i]; presumably that is the destination index in xv
// of the i-th of the nnz2 entries.
PETSC_KERNEL_DECL void add_remote_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz2, const PetscCount *PETSC_RESTRICT imap2, const PetscCount *PETSC_RESTRICT jmap2, const PetscCount *PETSC_RESTRICT perm2, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(vv, nnz2, jmap2, perm2, ADD_VALUES, xv, [=](PetscCount i) { return imap2[i]; });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is spurious, since the function *is* used. However, the host compiler does not
// appear to see this, likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_pack_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)pack_coo_values;
}

inline void silence_warning_function_add_remote_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_remote_coo_values;
}

} // namespace do_not_use
#endif

} // namespace kernels
316
// Insert or add the COO values v[] into x (layout fixed by SetPreallocationCOO()).
//
// Sequence, all issued on the stream of the current device context:
//   1. if v[] lives in host memory, copy its coo_n entries to a temporary device buffer;
//   2. pack the sendlen entries destined for other processes into sendbuf_d;
//   3. start the PetscSF reduction sendbuf_d -> recvbuf_d;
//   4. apply the locally owned contributions to x with the caller's imode;
//   5. finish the reduction, then ADD the nnz2 received remote contributions;
//   6. free the temporary buffer (if any) and synchronize the device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  PetscDeviceContext dctx;
  PetscMemType v_memtype;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &v_memtype));
  {
    const auto vmpi = VecIMPLCast(x);
    const auto vcu = VecCUPMCast(x);
    const auto sf = vmpi->coo_sf;
    const auto sendbuf_d = vcu->sendbuf_d;
    const auto recvbuf_d = vcu->recvbuf_d;
    // INSERT_VALUES overwrites, so x's device array is requested write-only; otherwise the
    // existing values are needed and it is requested read-write
    const auto xv = imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data();
    // device-accessible view of v[]; replaced by a temporary device copy below if v[] is on the host
    auto vv = const_cast<PetscScalar *>(v);

    if (PetscMemTypeHost(v_memtype)) {
      const auto size = vmpi->coo_n;

      /* If user gave v[] in host, we might need to copy it to device if any */
      PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
      PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
    }

    /* Pack entries to be sent to remote */
    if (const auto sendlen = vmpi->sendlen) {
      PetscCall(PetscCUPMLaunchKernel1D(sendlen, 0, stream, kernels::pack_coo_values, vv, sendlen, vcu->Cperm_d, sendbuf_d));
      // need to sync up here since we are about to send this to petscsf
      // REVIEW ME: no we dont, sf just needs to learn to use PetscDeviceContext
      PetscCallCUPM(cupmStreamSynchronize(stream));
    }

    PetscCall(PetscSFReduceWithMemTypeBegin(sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), sendbuf_d, PETSC_MEMTYPE_CUPM(), recvbuf_d, MPI_REPLACE));

    // local contributions are applied between ReduceBegin and ReduceEnd, presumably so the
    // kernel can overlap with the communication
    if (const auto n = x->map->n) PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, xv));

    PetscCall(PetscSFReduceEnd(sf, MPIU_SCALAR, sendbuf_d, recvbuf_d, MPI_REPLACE));

    /* Add received remote entries */
    if (const auto nnz2 = vmpi->nnz2) PetscCall(PetscCUPMLaunchKernel1D(nnz2, 0, stream, kernels::add_remote_coo_values, recvbuf_d, nnz2, vcu->imap2_d, vcu->jmap2_d, vcu->perm2_d, xv));

    // release the temporary device copy of a host v[]
    if (PetscMemTypeHost(v_memtype)) PetscCall(PetscDeviceFree(dctx, vv));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
366
367 } // namespace impl
368
369 } // namespace cupm
370
371 } // namespace vec
372
373 } // namespace Petsc
374