1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
// Sets *need_sync to true when the vector holds valid data somewhere but has no valid view
// in the requested memory space mem_type; leaves *need_sync untouched for other mem_type values.
static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
  bool            any_valid = false;
  CeedVector_Hip *data;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorHasValidArray(vec, &any_valid));

  if (mem_type == CEED_MEM_HOST) {
    *need_sync = any_valid && !data->h_array;
  } else if (mem_type == CEED_MEM_DEVICE) {
    *need_sync = any_valid && !data->d_array;
  }
  return CEED_ERROR_SUCCESS;
}
37
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
CeedVectorSyncH2D_Hip(const CeedVector vec)41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42 CeedSize length;
43 size_t bytes;
44 CeedVector_Hip *impl;
45
46 CeedCallBackend(CeedVectorGetData(vec, &impl));
47
48 CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49
50 CeedCallBackend(CeedVectorGetLength(vec, &length));
51 bytes = length * sizeof(CeedScalar);
52 if (impl->d_array_borrowed) {
53 impl->d_array = impl->d_array_borrowed;
54 } else if (impl->d_array_owned) {
55 impl->d_array = impl->d_array_owned;
56 } else {
57 CeedCallHip(CeedVectorReturnCeed(vec), hipMalloc((void **)&impl->d_array_owned, bytes));
58 impl->d_array = impl->d_array_owned;
59 }
60 CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
61 return CEED_ERROR_SUCCESS;
62 }
63
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
CeedVectorSyncD2H_Hip(const CeedVector vec)67 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
68 CeedSize length;
69 size_t bytes;
70 CeedVector_Hip *impl;
71
72 CeedCallBackend(CeedVectorGetData(vec, &impl));
73
74 CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
75
76 if (impl->h_array_borrowed) {
77 impl->h_array = impl->h_array_borrowed;
78 } else if (impl->h_array_owned) {
79 impl->h_array = impl->h_array_owned;
80 } else {
81 CeedSize length;
82
83 CeedCallBackend(CeedVectorGetLength(vec, &length));
84 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
85 impl->h_array = impl->h_array_owned;
86 }
87
88 CeedCallBackend(CeedVectorGetLength(vec, &length));
89 bytes = length * sizeof(CeedScalar);
90 CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
91 return CEED_ERROR_SUCCESS;
92 }
93
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
// Makes the requested memory space hold current data.
// Unified-addressing vectors without a borrowed host array only wait for the device to finish;
// otherwise a host<->device copy is performed when CeedVectorNeedSync_Hip reports one is required.
// Returns CEED_ERROR_UNSUPPORTED for an unknown mem_type that needs a sync.
static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
  CeedVector_Hip *data;
  bool            must_copy = false;

  CeedCallBackend(CeedVectorGetData(vec, &data));

  if (data->has_unified_addressing && !data->h_array_borrowed) {
    CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
    return CEED_ERROR_SUCCESS;
  }

  CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &must_copy));
  if (must_copy) {
    if (mem_type == CEED_MEM_HOST) return CeedVectorSyncD2H_Hip(vec);
    if (mem_type == CEED_MEM_DEVICE) return CeedVectorSyncH2D_Hip(vec);
    return CEED_ERROR_UNSUPPORTED;
  }
  return CEED_ERROR_SUCCESS;
}
120
121 //------------------------------------------------------------------------------
122 // Set all pointers as invalid
123 //------------------------------------------------------------------------------
CeedVectorSetAllInvalid_Hip(const CeedVector vec)124 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
125 CeedVector_Hip *impl;
126
127 CeedCallBackend(CeedVectorGetData(vec, &impl));
128 impl->h_array = NULL;
129 impl->d_array = NULL;
130 return CEED_ERROR_SUCCESS;
131 }
132
133 //------------------------------------------------------------------------------
134 // Check if CeedVector has any valid pointer
135 //------------------------------------------------------------------------------
CeedVectorHasValidArray_Hip(const CeedVector vec,bool * has_valid_array)136 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
137 CeedVector_Hip *impl;
138
139 CeedCallBackend(CeedVectorGetData(vec, &impl));
140 *has_valid_array = impl->h_array || impl->d_array;
141 return CEED_ERROR_SUCCESS;
142 }
143
144 //------------------------------------------------------------------------------
145 // Check if has array of given type
146 //------------------------------------------------------------------------------
// Reports whether a buffer (borrowed or owned, valid or not) exists in the given memory space.
// *has_array_of_type is left untouched for unknown mem_type values.
static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
  CeedVector_Hip *data;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  if (mem_type == CEED_MEM_HOST) {
    *has_array_of_type = (data->h_array_borrowed != NULL) || (data->h_array_owned != NULL);
  } else if (mem_type == CEED_MEM_DEVICE) {
    *has_array_of_type = (data->d_array_borrowed != NULL) || (data->d_array_owned != NULL);
  }
  return CEED_ERROR_SUCCESS;
}
161
162 //------------------------------------------------------------------------------
163 // Check if has borrowed array of given type
164 //------------------------------------------------------------------------------
// Reports whether a borrowed buffer exists in the given memory space.
// Unified-addressing vectors without a borrowed host array are always queried on the device side.
static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
  CeedVector_Hip *data;
  CeedMemType     effective_type;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  effective_type = (data->has_unified_addressing && !data->h_array_borrowed) ? CEED_MEM_DEVICE : mem_type;

  if (effective_type == CEED_MEM_HOST) {
    *has_borrowed_array_of_type = data->h_array_borrowed != NULL;
  } else if (effective_type == CEED_MEM_DEVICE) {
    *has_borrowed_array_of_type = data->d_array_borrowed != NULL;
  }
  return CEED_ERROR_SUCCESS;
}
183
184 //------------------------------------------------------------------------------
185 // Set array from host
186 //------------------------------------------------------------------------------
// Installs a host array according to copy_mode, delegating the owned/borrowed/valid pointer
// bookkeeping to CeedSetHostCeedScalarArray.
static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
  CeedVector_Hip *data;
  CeedSize        num_entries;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorGetLength(vec, &num_entries));
  CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, num_entries, (const CeedScalar **)&data->h_array_owned,
                                             (const CeedScalar **)&data->h_array_borrowed, (const CeedScalar **)&data->h_array));
  return CEED_ERROR_SUCCESS;
}
198
199 //------------------------------------------------------------------------------
200 // Set array from device
201 //------------------------------------------------------------------------------
// Installs a device array according to copy_mode, delegating the owned/borrowed/valid pointer
// bookkeeping to CeedSetDeviceCeedScalarArray_Hip.
static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
  CeedSize        length;
  CeedVector_Hip *impl;

  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedVectorGetLength(vec, &length));

  // Borrow the Ceed handle (no reference taken): with CeedVectorGetCeed an early error return
  // from the helper would skip CeedDestroy and leak the reference.
  CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(CeedVectorReturnCeed(vec), array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
                                                   (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
  return CEED_ERROR_SUCCESS;
}
216
217 //------------------------------------------------------------------------------
218 // Set array with unified memory
219 //------------------------------------------------------------------------------
// Installs a host array on a unified-addressing vector.
// - CEED_COPY_VALUES / CEED_OWN_POINTER: the host data (if any) is copied into a device buffer
//   (borrowed first, then owned, allocating an owned one if needed); an owned pointer is then freed.
// - CEED_USE_POINTER: the caller's pointer becomes the borrowed host array and is also recorded as
//   the device view; owned host/device buffers are released.
static int CeedVectorSetArrayUnifiedHostToDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
  CeedSize        length;
  CeedVector_Hip *impl;
  // Borrowed handle (no reference taken) so early error returns cannot leak a reference
  Ceed ceed = CeedVectorReturnCeed(vec);

  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedVectorGetLength(vec, &length));

  switch (copy_mode) {
    case CEED_COPY_VALUES:
    case CEED_OWN_POINTER:
      if (!impl->d_array) {
        if (impl->d_array_borrowed) {
          impl->d_array = impl->d_array_borrowed;
        } else {
          if (!impl->d_array_owned) CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, sizeof(CeedScalar) * length));
          impl->d_array = impl->d_array_owned;
        }
      }
      if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, sizeof(CeedScalar) * length, hipMemcpyHostToDevice));
      // Ownership was transferred to us and the values now live in the device buffer
      if (copy_mode == CEED_OWN_POINTER) CeedCallBackend(CeedFree(&array));
      break;
    case CEED_USE_POINTER:
      CeedCallHip(ceed, hipFree(impl->d_array_owned));
      // Clear the freed pointer: leaving it dangling lets later syncs reuse freed memory
      // and lets vector destruction free it a second time
      impl->d_array_owned = NULL;
      CeedCallBackend(CeedFree(&impl->h_array_owned));
      impl->h_array_owned    = NULL;
      impl->h_array_borrowed = array;
      impl->d_array          = impl->h_array_borrowed;
      break;
  }
  return CEED_ERROR_SUCCESS;
}
253
254 //------------------------------------------------------------------------------
255 // Set the array used by a vector,
256 // freeing any previously allocated array if applicable
257 //------------------------------------------------------------------------------
// Replaces the vector's data with array (interpreted per mem_type and copy_mode).
// All current views are invalidated first; host arrays on unified-addressing vectors are routed
// through the unified setter. Returns CEED_ERROR_UNSUPPORTED for an unknown mem_type.
static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
  CeedVector_Hip *data;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));

  if (mem_type == CEED_MEM_DEVICE) return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
  if (mem_type == CEED_MEM_HOST) {
    return data->has_unified_addressing ? CeedVectorSetArrayUnifiedHostToDevice_Hip(vec, copy_mode, array)
                                        : CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
  }
  return CEED_ERROR_UNSUPPORTED;
}
275
276 //------------------------------------------------------------------------------
277 // Copy host array to value strided
278 //------------------------------------------------------------------------------
// Copies entries start, start + step, start + 2*step, ... (strictly below stop) from h_array
// into the same positions of h_copy_array. Both arrays must cover those indices.
static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *h_copy_array) {
  CeedSize idx = start;

  while (idx < stop) {
    h_copy_array[idx] = h_array[idx];
    idx += step;
  }
  return CEED_ERROR_SUCCESS;
}
283
284 //------------------------------------------------------------------------------
285 // Copy device array to value strided (impl in .hip.cpp file)
286 //------------------------------------------------------------------------------
287 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *d_copy_array);
288
289 //------------------------------------------------------------------------------
290 // Copy a vector to a value strided
291 //------------------------------------------------------------------------------
// Copies entries start, start + step, ... (below stop; stop == -1 means the shorter of the two
// lengths) from vec into the same positions of vec_copy. Uses the device when vec has a valid
// device view, otherwise the host. Errors if vec has no valid data.
static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedVector vec_copy) {
  CeedSize        length;
  CeedVector_Hip *impl;

  CeedCallBackend(CeedVectorGetData(vec, &impl));
  {
    CeedSize length_vec, length_copy;

    CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
    CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
    length = length_vec < length_copy ? length_vec : length_copy;
  }
  if (stop == -1) stop = length;
  // vec is only read here, so its currently valid host/device views are left unchanged
  if (impl->d_array) {
    CeedScalar *copy_array;

    CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
#if (HIP_VERSION >= 60000000)
    {
      Ceed            ceed = CeedVectorReturnCeed(vec);
      hipblasHandle_t handle;
      hipStream_t     stream;
      // BLAS ?copy takes an element COUNT, not an index range. The strided range
      // start, start + step, ... (< stop) holds ceil((stop - start) / step) entries;
      // passing stop - start would read/write past stop whenever step > 1.
      const int64_t num_entries = stop > start ? (int64_t)((stop - start - 1) / step + 1) : 0;

      CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
      CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
#if defined(CEED_SCALAR_IS_FP32)
      CeedCallHipblas(ceed, hipblasScopy_64(handle, num_entries, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
#else  /* CEED_SCALAR */
      CeedCallHipblas(ceed, hipblasDcopy_64(handle, num_entries, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
#endif /* CEED_SCALAR */
      CeedCallHip(ceed, hipStreamSynchronize(stream));
    }
#else  /* HIP_VERSION */
    CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, stop, step, copy_array));
#endif /* HIP_VERSION */
    CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
  } else if (impl->h_array) {
    CeedScalar *copy_array;

    CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
    CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, stop, step, copy_array));
    CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
  } else {
    return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
  }
  return CEED_ERROR_SUCCESS;
}
341
342 //------------------------------------------------------------------------------
343 // Set host array to value
344 //------------------------------------------------------------------------------
// Writes val into each of the first length entries of h_array (nothing for length <= 0).
static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
  CeedSize remaining = length;

  while (remaining > 0) {
    remaining--;
    h_array[remaining] = val;
  }
  return CEED_ERROR_SUCCESS;
}
349
350 //------------------------------------------------------------------------------
351 // Set device array to value (impl in .hip file)
352 //------------------------------------------------------------------------------
353 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
354
355 //------------------------------------------------------------------------------
356 // Set a vector to a value
357 //------------------------------------------------------------------------------
// Sets every entry of vec to val.
// If no view is valid yet, an existing buffer is adopted in the order device-borrowed,
// host-borrowed, device-owned, host-owned; a device buffer is allocated when none exists.
// The fill happens on the device when a device view is valid (host view then marked stale),
// otherwise on the host (device view then marked stale).
static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
  CeedSize        length;
  CeedVector_Hip *impl;

  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedVectorGetLength(vec, &length));
  if (!impl->d_array && !impl->h_array) {
    if (impl->d_array_borrowed) {
      impl->d_array = impl->d_array_borrowed;
    } else if (impl->h_array_borrowed) {
      impl->h_array = impl->h_array_borrowed;
    } else if (impl->d_array_owned) {
      impl->d_array = impl->d_array_owned;
    } else if (impl->h_array_owned) {
      impl->h_array = impl->h_array_owned;
    } else {
      CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
    }
  }
  if (impl->d_array) {
    // hipMemset is a byte fill, so it is only usable for zero. It is skipped when a borrowed
    // host array exists, presumably because d_array may then be that host pointer (the
    // unified-addressing CEED_USE_POINTER setter does this) — NOTE(review): confirm
    if (val == 0 && !impl->h_array_borrowed) {
      CeedCallHip(CeedVectorReturnCeed(vec), hipMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
    } else {
      CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
    }
    impl->h_array = NULL;
  } else if (impl->h_array) {
    CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
    impl->d_array = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
393
394 //------------------------------------------------------------------------------
395 // Set host array to value strided
396 //------------------------------------------------------------------------------
// Writes val into entries start, start + step, ... (strictly below stop) of h_array.
static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
  CeedSize idx = start;

  while (idx < stop) {
    h_array[idx] = val;
    idx += step;
  }
  return CEED_ERROR_SUCCESS;
}
401
402 //------------------------------------------------------------------------------
403 // Set device array to value strided (impl in .hip.cpp file)
404 //------------------------------------------------------------------------------
405 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val);
406
407 //------------------------------------------------------------------------------
408 // Set a vector to a value strided
409 //------------------------------------------------------------------------------
// Sets entries start, start + step, ... (below stop; stop == -1 means the vector length) to val.
// Works on the device view when valid (host view then marked stale), otherwise on the host view
// (device view then marked stale). Errors if neither view is valid.
static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
  CeedVector_Hip *data;
  CeedSize        num_entries;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorGetLength(vec, &num_entries));
  if (stop == -1) stop = num_entries;

  if (!data->d_array && !data->h_array) {
    return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
  }
  if (data->d_array) {
    CeedCallBackend(CeedDeviceSetValueStrided_Hip(data->d_array, start, stop, step, val));
    data->h_array = NULL;
  } else {
    CeedCallBackend(CeedHostSetValueStrided_Hip(data->h_array, start, stop, step, val));
    data->d_array = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
429
430 //------------------------------------------------------------------------------
431 // Vector Take Array
432 //------------------------------------------------------------------------------
// Syncs the requested memory space, then returns its borrowed buffer through *array and drops
// the vector's borrowed and valid pointers for that space.
static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
  CeedVector_Hip *data;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorSyncArray(vec, mem_type));

  if (mem_type == CEED_MEM_HOST) {
    *array                 = data->h_array_borrowed;
    data->h_array_borrowed = NULL;
    data->h_array          = NULL;
  } else if (mem_type == CEED_MEM_DEVICE) {
    *array                 = data->d_array_borrowed;
    data->d_array_borrowed = NULL;
    data->d_array          = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
456
457 //------------------------------------------------------------------------------
458 // Core logic for array synchronization for GetArray.
459 // If a different memory type is most up to date, this will perform a copy
460 //------------------------------------------------------------------------------
// Shared GetArray logic: syncs the target memory space and returns its valid view.
// Unified-addressing vectors without a borrowed host array always use the device view.
static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
  CeedVector_Hip *data;
  CeedMemType     target;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  target = (data->has_unified_addressing && !data->h_array_borrowed) ? CEED_MEM_DEVICE : mem_type;

  CeedCallBackend(CeedVectorSyncArray(vec, target));
  if (target == CEED_MEM_HOST) {
    *array = data->h_array;
  } else if (target == CEED_MEM_DEVICE) {
    *array = data->d_array;
  }
  return CEED_ERROR_SUCCESS;
}
483
484 //------------------------------------------------------------------------------
485 // Get read-only access to a vector via the specified mem_type
486 //------------------------------------------------------------------------------
CeedVectorGetArrayRead_Hip(const CeedVector vec,const CeedMemType mem_type,const CeedScalar ** array)487 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
488 return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
489 }
490
491 //------------------------------------------------------------------------------
492 // Get read/write access to a vector via the specified mem_type
493 //------------------------------------------------------------------------------
// Read/write access: sync and return the view, then mark only the returned view as valid since
// the caller may modify it. Unified-addressing vectors without a borrowed host array use the device.
static int CeedVectorGetArray_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
  CeedVector_Hip *data;
  CeedMemType     target;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  target = (data->has_unified_addressing && !data->h_array_borrowed) ? CEED_MEM_DEVICE : mem_type;

  CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, target, array));
  CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
  if (target == CEED_MEM_DEVICE) {
    data->d_array = *array;
  } else if (target == CEED_MEM_HOST) {
    data->h_array = *array;
    // With unified addressing the same pointer is recorded as the device view as well
    if (data->has_unified_addressing) data->d_array = *array;
  }
  return CEED_ERROR_SUCCESS;
}
516
517 //------------------------------------------------------------------------------
518 // Get write access to a vector via the specified mem_type
519 //------------------------------------------------------------------------------
// Write-only access. Allocates a buffer for the memory space if none exists; otherwise marks the
// existing buffer (borrowed preferred) as the valid view so no copy of stale contents is made,
// then hands out the array via CeedVectorGetArray_Hip.
static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
  bool            has_array_of_type = true;
  CeedVector_Hip *impl;

  CeedCallBackend(CeedVectorGetData(vec, &impl));

  // Use device memory for unified memory
  mem_type = impl->has_unified_addressing && !impl->h_array_borrowed ? CEED_MEM_DEVICE : mem_type;

  CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
  if (!has_array_of_type) {
    // Allocate if array is not yet allocated
    CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
  } else {
    // Selecting the buffer as valid means the later sync check sees nothing to copy
    switch (mem_type) {
      case CEED_MEM_HOST:
        impl->h_array = impl->h_array_borrowed ? impl->h_array_borrowed : impl->h_array_owned;
        break;
      case CEED_MEM_DEVICE:
        impl->d_array = impl->d_array_borrowed ? impl->d_array_borrowed : impl->d_array_owned;
        break;
    }
  }
  return CeedVectorGetArray_Hip(vec, mem_type, array);
}
549
550 //------------------------------------------------------------------------------
551 // Get the norm of a CeedVector
552 //------------------------------------------------------------------------------
// Computes the requested norm (1, 2 or max) of vec on the device with hipBLAS.
// With ROCm >= 6 the 64-bit-integer hipBLAS interface handles any length in one call;
// older ROCm splits the vector into chunks of at most INT_MAX entries.
static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
  // Borrowed handle (no reference taken) so early error returns cannot leak a reference
  Ceed     ceed = CeedVectorReturnCeed(vec);
  CeedSize length;
#if (HIP_VERSION < 60000000)
  CeedSize num_calls;
#endif /* HIP_VERSION */
  const CeedScalar *d_array;
  CeedVector_Hip   *impl;
  hipblasHandle_t   handle;
  hipStream_t       stream;
  Ceed_Hip         *hip_data;

  CeedCallBackend(CeedVectorGetLength(vec, &length));
  // Every norm of an empty vector is zero. Returning here also protects the max-norm path:
  // BLAS i?amax reports index 0 for n == 0, and index - 1 would point before the array.
  if (length == 0) {
    *norm = 0.0;
    return CEED_ERROR_SUCCESS;
  }
  CeedCallBackend(CeedGetData(ceed, &hip_data));
  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
  CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
#if (HIP_VERSION < 60000000)
  // With ROCm 6, we can use the 64-bit integer interface. Prior to that,
  // we need to check if the vector is too long to handle with int32,
  // and if so, divide it into subsections for repeated hipBLAS calls.
  num_calls = length / INT_MAX;
  if (length % INT_MAX > 0) num_calls += 1;
#endif /* HIP_VERSION */

  // Compute norm
  CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
  switch (type) {
    case CEED_NORM_1: {
      *norm = 0.0;
#if defined(CEED_SCALAR_IS_FP32)
#if (HIP_VERSION >= 60000000)  // We have ROCm 6, and can use 64-bit integers
      CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
      CeedCallHip(ceed, hipStreamSynchronize(stream));
#else  /* HIP_VERSION */
      float  sub_norm = 0.0;
      float *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start                = (float *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length    = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length          = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        *norm += sub_norm;
      }
#endif /* HIP_VERSION */
#else  /* CEED_SCALAR */
#if (HIP_VERSION >= 60000000)
      CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
      CeedCallHip(ceed, hipStreamSynchronize(stream));
#else  /* HIP_VERSION */
      double  sub_norm = 0.0;
      double *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        *norm += sub_norm;
      }
#endif /* HIP_VERSION */
#endif /* CEED_SCALAR */
      break;
    }
    case CEED_NORM_2: {
#if defined(CEED_SCALAR_IS_FP32)
#if (HIP_VERSION >= 60000000)
      CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
      CeedCallHip(ceed, hipStreamSynchronize(stream));
#else  /* HIP_VERSION */
      float  sub_norm = 0.0, norm_sum = 0.0;
      float *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        // Combine chunk 2-norms through their squares
        norm_sum += sub_norm * sub_norm;
      }
      *norm = sqrt(norm_sum);
#endif /* HIP_VERSION */
#else  /* CEED_SCALAR */
#if (HIP_VERSION >= 60000000)
      CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
      CeedCallHip(ceed, hipStreamSynchronize(stream));
#else  /* HIP_VERSION */
      double  sub_norm = 0.0, norm_sum = 0.0;
      double *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        // Combine chunk 2-norms through their squares
        norm_sum += sub_norm * sub_norm;
      }
      *norm = sqrt(norm_sum);
#endif /* HIP_VERSION */
#endif /* CEED_SCALAR */
      break;
    }
    case CEED_NORM_MAX: {
      // i?amax returns a 1-based index of the entry with the largest magnitude
#if defined(CEED_SCALAR_IS_FP32)
#if (HIP_VERSION >= 60000000)
      int64_t    index;
      CeedScalar norm_no_abs;

      CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
      // Same unified-addressing handling as the FP64 path below
      if (hip_data->has_unified_addressing) {
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        norm_no_abs = d_array[index - 1];
      } else {
        CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
      }
      *norm = fabs(norm_no_abs);
#else  /* HIP_VERSION */
      CeedInt index;
      float   sub_max = 0.0, current_max = 0.0;
      float  *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
        if (hip_data->has_unified_addressing) {
          CeedCallHip(ceed, hipStreamSynchronize(stream));
          // index is relative to the current chunk, not to the whole array
          sub_max = fabs(d_array_start[index - 1]);
        } else {
          CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
          CeedCallHip(ceed, hipStreamSynchronize(stream));
        }
        if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
      }
      *norm = current_max;
#endif /* HIP_VERSION */
#else  /* CEED_SCALAR */
#if (HIP_VERSION >= 60000000)
      int64_t    index;
      CeedScalar norm_no_abs;

      CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
      if (hip_data->has_unified_addressing) {
        CeedCallHip(ceed, hipStreamSynchronize(stream));
        norm_no_abs = d_array[index - 1];
      } else {
        CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
        CeedCallHip(ceed, hipStreamSynchronize(stream));
      }
      *norm = fabs(norm_no_abs);
#else  /* HIP_VERSION */
      CeedInt index;
      double  sub_max = 0.0, current_max = 0.0;
      double *d_array_start;

      for (CeedInt i = 0; i < num_calls; i++) {
        d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
        CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
        CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

        CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
        if (hip_data->has_unified_addressing) {
          CeedCallHip(ceed, hipStreamSynchronize(stream));
          // index is relative to the current chunk, not to the whole array
          sub_max = fabs(d_array_start[index - 1]);
        } else {
          CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
          CeedCallHip(ceed, hipStreamSynchronize(stream));
        }
        if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
      }
      *norm = current_max;
#endif /* HIP_VERSION */
#endif /* CEED_SCALAR */
      break;
    }
  }
  CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
  return CEED_ERROR_SUCCESS;
}
741
742 //------------------------------------------------------------------------------
743 // Take reciprocal of a vector on host
744 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Replace each host entry by its reciprocal, in place.
//   h_array - host buffer of `length` scalars, modified in place
//   length  - number of entries to process
// Entries whose magnitude is not greater than CEED_EPSILON are left untouched
// so near-zero values are never inverted.
//------------------------------------------------------------------------------
static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
  CeedSize k = 0;

  while (k < length) {
    const CeedScalar value = h_array[k];

    if (fabs(value) > CEED_EPSILON) h_array[k] = 1. / value;
    k++;
  }
  return CEED_ERROR_SUCCESS;
}
751
752 //------------------------------------------------------------------------------
753 // Take reciprocal of a vector on device (impl in .hip.cpp file)
754 //------------------------------------------------------------------------------
755 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
756
757 //------------------------------------------------------------------------------
758 // Take reciprocal of a vector
759 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Take the entrywise reciprocal of a vector.
// Every copy of the data that is present (device and/or host) is updated, so
// copies that agreed before the call still agree afterwards.
//------------------------------------------------------------------------------
static int CeedVectorReciprocal_Hip(CeedVector vec) {
  CeedVector_Hip *data;
  CeedSize        n;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorGetLength(vec, &n));
  if (data->d_array) {
    CeedCallBackend(CeedDeviceReciprocal_Hip(data->d_array, n));
  }
  if (data->h_array) {
    CeedCallBackend(CeedHostReciprocal_Hip(data->h_array, n));
  }
  return CEED_ERROR_SUCCESS;
}
771
772 //------------------------------------------------------------------------------
773 // Compute x = alpha x on the host
774 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// In-place host scaling: x_array[k] = x_array[k] * alpha for k in [0, length).
//------------------------------------------------------------------------------
static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
  CeedSize k = 0;

  while (k < length) {
    x_array[k] = x_array[k] * alpha;
    k++;
  }
  return CEED_ERROR_SUCCESS;
}
779
780 //------------------------------------------------------------------------------
781 // Compute x = alpha x on device (impl in .hip.cpp file)
782 //------------------------------------------------------------------------------
783 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
784
785 //------------------------------------------------------------------------------
786 // Compute x = alpha x
787 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Compute x = alpha x.
// If a device copy exists, only it is scaled (via hipBLAS on HIP >= 6, else the
// custom kernel) and the host pointer is cleared to mark the host copy stale.
// Otherwise the host copy, if any, is scaled and the device pointer is cleared.
//------------------------------------------------------------------------------
static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
  CeedVector_Hip *data;
  CeedSize        n;

  CeedCallBackend(CeedVectorGetData(x, &data));
  CeedCallBackend(CeedVectorGetLength(x, &n));
  if (data->d_array) {
#if (HIP_VERSION >= 60000000)
    Ceed            ceed = CeedVectorReturnCeed(x);
    hipblasHandle_t blas_handle;
    hipStream_t     blas_stream;

    CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &blas_handle));
    CeedCallHipblas(ceed, hipblasGetStream(blas_handle, &blas_stream));
#if defined(CEED_SCALAR_IS_FP32)
    CeedCallHipblas(ceed, hipblasSscal_64(blas_handle, (int64_t)n, &alpha, data->d_array, 1));
#else  /* CEED_SCALAR */
    CeedCallHipblas(ceed, hipblasDscal_64(blas_handle, (int64_t)n, &alpha, data->d_array, 1));
#endif /* CEED_SCALAR */
    // Wait for the work queued on the hipBLAS handle's stream to finish
    CeedCallHip(ceed, hipStreamSynchronize(blas_stream));
#else  /* HIP_VERSION */
    CeedCallBackend(CeedDeviceScale_Hip(data->d_array, alpha, n));
#endif /* HIP_VERSION */
    data->h_array = NULL;
  } else if (data->h_array) {
    CeedCallBackend(CeedHostScale_Hip(data->h_array, alpha, n));
    data->d_array = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
819
820 //------------------------------------------------------------------------------
821 // Compute y = alpha x + y on the host
822 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Host AXPY: y_array[k] = y_array[k] + alpha * x_array[k] for k in [0, length).
//------------------------------------------------------------------------------
static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
  CeedSize k = 0;

  while (k < length) {
    y_array[k] = y_array[k] + alpha * x_array[k];
    k++;
  }
  return CEED_ERROR_SUCCESS;
}
827
828 //------------------------------------------------------------------------------
829 // Compute y = alpha x + y on device (impl in .hip.cpp file)
830 //------------------------------------------------------------------------------
831 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
832
833 //------------------------------------------------------------------------------
834 // Compute y = alpha x + y
835 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Compute y = alpha x + y.
// If y has a device copy, the update runs on the device (hipBLAS on HIP >= 6,
// else the custom kernel) and y's host pointer is cleared to mark it stale.
// Otherwise, if y has a host copy, the update runs on the host and y's device
// pointer is cleared. x is synced to whichever memory space is used.
//------------------------------------------------------------------------------
static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
  CeedSize        length;
  CeedVector_Hip *y_impl, *x_impl;

  CeedCallBackend(CeedVectorGetData(y, &y_impl));
  CeedCallBackend(CeedVectorGetData(x, &x_impl));
  CeedCallBackend(CeedVectorGetLength(y, &length));
  if (y_impl->d_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
#if (HIP_VERSION >= 60000000)
    // Use the Ceed of the vector being modified (y) for the hipBLAS handle as
    // well as for error reporting, matching CeedVectorScale_Hip.
    Ceed            ceed = CeedVectorReturnCeed(y);
    hipblasHandle_t handle;
    hipStream_t     stream;

    CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
    CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
#if defined(CEED_SCALAR_IS_FP32)
    CeedCallHipblas(ceed, hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
#else  /* CEED_SCALAR */
    CeedCallHipblas(ceed, hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
#endif /* CEED_SCALAR */
    // Wait for the work queued on the hipBLAS handle's stream to finish
    CeedCallHip(ceed, hipStreamSynchronize(stream));
#else  /* HIP_VERSION */
    CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
#endif /* HIP_VERSION */
    y_impl->h_array = NULL;
  } else if (y_impl->h_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
    CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
    y_impl->d_array = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
869
870 //------------------------------------------------------------------------------
871 // Compute y = alpha x + beta y on the host
872 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Host AXPBY: y_array[k] = alpha * x_array[k] + beta * y_array[k] for k in [0, length).
//------------------------------------------------------------------------------
static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
  CeedSize k = 0;

  while (k < length) {
    y_array[k] = alpha * x_array[k] + beta * y_array[k];
    k++;
  }
  return CEED_ERROR_SUCCESS;
}
877
878 //------------------------------------------------------------------------------
879 // Compute y = alpha x + beta y on device (impl in .hip.cpp file)
880 //------------------------------------------------------------------------------
881 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
882
883 //------------------------------------------------------------------------------
884 // Compute y = alpha x + beta y
885 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Compute y = alpha x + beta y.
// Like CeedVectorAXPY_Hip, the update is computed once. The device is used if y
// has a device copy, and y's host pointer is then cleared to mark it stale.
// Otherwise the host is used, and y's device pointer is cleared.
// Previously, when y was valid in both spaces, the result was computed twice.
// That meant a full serial host loop, and x was forced onto the host (a
// possible D2H copy) even though the device result was sufficient.
//------------------------------------------------------------------------------
static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
  CeedSize        length;
  CeedVector_Hip *y_impl, *x_impl;

  CeedCallBackend(CeedVectorGetData(y, &y_impl));
  CeedCallBackend(CeedVectorGetData(x, &x_impl));
  CeedCallBackend(CeedVectorGetLength(y, &length));
  if (y_impl->d_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
    CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
    // Host copy of y no longer matches the device result
    y_impl->h_array = NULL;
  } else if (y_impl->h_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
    CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
    // Device copy of y (if any) no longer matches the host result
    y_impl->d_array = NULL;
  }
  return CEED_ERROR_SUCCESS;
}
904
905 //------------------------------------------------------------------------------
906 // Compute the pointwise multiplication w = x .* y on the host
907 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Host pointwise product: w_array[k] = x_array[k] * y_array[k] for k in [0, length).
//------------------------------------------------------------------------------
static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
  CeedSize k = 0;

  while (k < length) {
    w_array[k] = x_array[k] * y_array[k];
    k++;
  }
  return CEED_ERROR_SUCCESS;
}
912
913 //------------------------------------------------------------------------------
914 // Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
915 //------------------------------------------------------------------------------
916 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
917
918 //------------------------------------------------------------------------------
919 // Compute the pointwise multiplication w = x .* y
920 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Compute the pointwise product w = x .* y.
// If w has no array in either memory space, it is first given one by setting it
// to zero. Each copy of w that is then present (device and/or host) is
// recomputed, after syncing x and y to the same memory space.
//------------------------------------------------------------------------------
static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
  CeedVector_Hip *w_data, *x_data, *y_data;
  CeedSize        n;

  CeedCallBackend(CeedVectorGetData(w, &w_data));
  CeedCallBackend(CeedVectorGetData(x, &x_data));
  CeedCallBackend(CeedVectorGetData(y, &y_data));
  CeedCallBackend(CeedVectorGetLength(w, &n));

  if (!(w_data->d_array || w_data->h_array)) {
    CeedCallBackend(CeedVectorSetValue(w, 0.0));
  }

  // The pointer fields are read fresh at each check, since the calls above can change them
  if (w_data->d_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
    CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
    CeedCallBackend(CeedDevicePointwiseMult_Hip(w_data->d_array, x_data->d_array, y_data->d_array, n));
  }
  if (w_data->h_array) {
    CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
    CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
    CeedCallBackend(CeedHostPointwiseMult_Hip(w_data->h_array, x_data->h_array, y_data->h_array, n));
  }
  return CEED_ERROR_SUCCESS;
}
946
947 //------------------------------------------------------------------------------
948 // Destroy the vector
949 //------------------------------------------------------------------------------
CeedVectorDestroy_Hip(const CeedVector vec)950 static int CeedVectorDestroy_Hip(const CeedVector vec) {
951 CeedVector_Hip *impl;
952
953 CeedCallBackend(CeedVectorGetData(vec, &impl));
954 CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
955 CeedCallBackend(CeedFree(&impl->h_array_owned));
956 CeedCallBackend(CeedFree(&impl));
957 return CEED_ERROR_SUCCESS;
958 }
959
960 //------------------------------------------------------------------------------
961 // Create a vector of the specified length (does not allocate memory)
962 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Create a vector of the specified length (does not allocate memory).
// Registers the HIP-ref implementations of the CeedVector interface, then
// attaches a zero-initialized CeedVector_Hip. The new data copies the
// has_unified_addressing flag from the Ceed's Ceed_Hip data.
//------------------------------------------------------------------------------
int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
  Ceed            ceed;
  Ceed_Hip       *ceed_data;
  CeedVector_Hip *vec_data;

  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));

  // REGISTER(Name) binds the interface name "Name" to CeedVectorName_Hip.
  // The string is produced by stringizing, so it is identical to a literal "Name".
#define CEED_HIP_VECTOR_REGISTER(fn_name) \
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, #fn_name, CeedVector##fn_name##_Hip))
  CEED_HIP_VECTOR_REGISTER(HasValidArray);
  CEED_HIP_VECTOR_REGISTER(HasBorrowedArrayOfType);
  CEED_HIP_VECTOR_REGISTER(SetArray);
  CEED_HIP_VECTOR_REGISTER(TakeArray);
  CEED_HIP_VECTOR_REGISTER(CopyStrided);
  CEED_HIP_VECTOR_REGISTER(SetValue);
  CEED_HIP_VECTOR_REGISTER(SetValueStrided);
  CEED_HIP_VECTOR_REGISTER(SyncArray);
  CEED_HIP_VECTOR_REGISTER(GetArray);
  CEED_HIP_VECTOR_REGISTER(GetArrayRead);
  CEED_HIP_VECTOR_REGISTER(GetArrayWrite);
  CEED_HIP_VECTOR_REGISTER(Norm);
  CEED_HIP_VECTOR_REGISTER(Reciprocal);
  CEED_HIP_VECTOR_REGISTER(Scale);
  CEED_HIP_VECTOR_REGISTER(AXPY);
  CEED_HIP_VECTOR_REGISTER(AXPBY);
  CEED_HIP_VECTOR_REGISTER(PointwiseMult);
  CEED_HIP_VECTOR_REGISTER(Destroy);
#undef CEED_HIP_VECTOR_REGISTER

  CeedCallBackend(CeedCalloc(1, &vec_data));
  CeedCallBackend(CeedGetData(ceed, &ceed_data));
  // Drop the reference taken by CeedVectorGetCeed
  CeedCallBackend(CeedDestroy(&ceed));
  vec_data->has_unified_addressing = ceed_data->has_unified_addressing;
  CeedCallBackend(CeedVectorSetData(vec, vec_data));
  return CEED_ERROR_SUCCESS;
}
994
995 //------------------------------------------------------------------------------
996