xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision 9330daecb0fc008043eec1b94c46ef7aecbb00cd)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cublas_v2.h>
11 #include <cuda_runtime.h>
12 #include <math.h>
13 #include <stdbool.h>
14 #include <string.h>
15 
16 #include "../cuda/ceed-cuda-common.h"
17 #include "ceed-cuda-ref.h"
18 
19 //------------------------------------------------------------------------------
20 // Check if host/device sync is needed
21 //------------------------------------------------------------------------------
22 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
23   CeedVector_Cuda *impl;
24   CeedCallBackend(CeedVectorGetData(vec, &impl));
25 
26   bool has_valid_array = false;
27   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
28   switch (mem_type) {
29     case CEED_MEM_HOST:
30       *need_sync = has_valid_array && !impl->h_array;
31       break;
32     case CEED_MEM_DEVICE:
33       *need_sync = has_valid_array && !impl->d_array;
34       break;
35   }
36 
37   return CEED_ERROR_SUCCESS;
38 }
39 
40 //------------------------------------------------------------------------------
41 // Sync host to device
42 //------------------------------------------------------------------------------
43 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
44   Ceed ceed;
45   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
46   CeedVector_Cuda *impl;
47   CeedCallBackend(CeedVectorGetData(vec, &impl));
48 
49   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
50 
51   CeedSize length;
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   size_t bytes = length * sizeof(CeedScalar);
54 
55   if (impl->d_array_borrowed) {
56     impl->d_array = impl->d_array_borrowed;
57   } else if (impl->d_array_owned) {
58     impl->d_array = impl->d_array_owned;
59   } else {
60     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
61     impl->d_array = impl->d_array_owned;
62   }
63 
64   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
65 
66   return CEED_ERROR_SUCCESS;
67 }
68 
69 //------------------------------------------------------------------------------
70 // Sync device to host
71 //------------------------------------------------------------------------------
72 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
73   Ceed ceed;
74   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75   CeedVector_Cuda *impl;
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86     CeedCallBackend(CeedVectorGetLength(vec, &length));
87     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88     impl->h_array = impl->h_array_owned;
89   }
90 
91   CeedSize length;
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   size_t bytes = length * sizeof(CeedScalar);
94   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95 
96   return CEED_ERROR_SUCCESS;
97 }
98 
99 //------------------------------------------------------------------------------
100 // Sync arrays
101 //------------------------------------------------------------------------------
102 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
103   // Check whether device/host sync is needed
104   bool need_sync = false;
105   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Cuda(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Cuda(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121   CeedVector_Cuda *impl;
122   CeedCallBackend(CeedVectorGetData(vec, &impl));
123 
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126 
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 //------------------------------------------------------------------------------
131 // Check if CeedVector has any valid pointer
132 //------------------------------------------------------------------------------
133 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
134   CeedVector_Cuda *impl;
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136 
137   *has_valid_array = !!impl->h_array || !!impl->d_array;
138 
139   return CEED_ERROR_SUCCESS;
140 }
141 
142 //------------------------------------------------------------------------------
143 // Check if has array of given type
144 //------------------------------------------------------------------------------
145 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
146   CeedVector_Cuda *impl;
147   CeedCallBackend(CeedVectorGetData(vec, &impl));
148 
149   switch (mem_type) {
150     case CEED_MEM_HOST:
151       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
152       break;
153     case CEED_MEM_DEVICE:
154       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
155       break;
156   }
157 
158   return CEED_ERROR_SUCCESS;
159 }
160 
161 //------------------------------------------------------------------------------
162 // Check if has borrowed array of given type
163 //------------------------------------------------------------------------------
164 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
165   CeedVector_Cuda *impl;
166   CeedCallBackend(CeedVectorGetData(vec, &impl));
167 
168   switch (mem_type) {
169     case CEED_MEM_HOST:
170       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
171       break;
172     case CEED_MEM_DEVICE:
173       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
174       break;
175   }
176 
177   return CEED_ERROR_SUCCESS;
178 }
179 
180 //------------------------------------------------------------------------------
181 // Set array from host
182 //------------------------------------------------------------------------------
183 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
184   CeedVector_Cuda *impl;
185   CeedCallBackend(CeedVectorGetData(vec, &impl));
186 
187   switch (copy_mode) {
188     case CEED_COPY_VALUES: {
189       CeedSize length;
190       if (!impl->h_array_owned) {
191         CeedCallBackend(CeedVectorGetLength(vec, &length));
192         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
193       }
194       impl->h_array_borrowed = NULL;
195       impl->h_array          = impl->h_array_owned;
196       if (array) {
197         CeedSize length;
198         CeedCallBackend(CeedVectorGetLength(vec, &length));
199         size_t bytes = length * sizeof(CeedScalar);
200         memcpy(impl->h_array, array, bytes);
201       }
202     } break;
203     case CEED_OWN_POINTER:
204       CeedCallBackend(CeedFree(&impl->h_array_owned));
205       impl->h_array_owned    = array;
206       impl->h_array_borrowed = NULL;
207       impl->h_array          = array;
208       break;
209     case CEED_USE_POINTER:
210       CeedCallBackend(CeedFree(&impl->h_array_owned));
211       impl->h_array_borrowed = array;
212       impl->h_array          = array;
213       break;
214   }
215 
216   return CEED_ERROR_SUCCESS;
217 }
218 
219 //------------------------------------------------------------------------------
220 // Set array from device
221 //------------------------------------------------------------------------------
222 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
223   Ceed ceed;
224   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225   CeedVector_Cuda *impl;
226   CeedCallBackend(CeedVectorGetData(vec, &impl));
227 
228   switch (copy_mode) {
229     case CEED_COPY_VALUES: {
230       CeedSize length;
231       CeedCallBackend(CeedVectorGetLength(vec, &length));
232       size_t bytes = length * sizeof(CeedScalar);
233       if (!impl->d_array_owned) {
234         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
235         impl->d_array = impl->d_array_owned;
236       }
237       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238     } break;
239     case CEED_OWN_POINTER:
240       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241       impl->d_array_owned    = array;
242       impl->d_array_borrowed = NULL;
243       impl->d_array          = array;
244       break;
245     case CEED_USE_POINTER:
246       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247       impl->d_array_owned    = NULL;
248       impl->d_array_borrowed = array;
249       impl->d_array          = array;
250       break;
251   }
252 
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   Ceed ceed;
262   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263   CeedVector_Cuda *impl;
264   CeedCallBackend(CeedVectorGetData(vec, &impl));
265 
266   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267   switch (mem_type) {
268     case CEED_MEM_HOST:
269       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270     case CEED_MEM_DEVICE:
271       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272   }
273 
274   return CEED_ERROR_UNSUPPORTED;
275 }
276 
277 //------------------------------------------------------------------------------
278 // Set host array to value
279 //------------------------------------------------------------------------------
280 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedInt length, CeedScalar val) {
281   for (int i = 0; i < length; i++) h_array[i] = val;
282   return CEED_ERROR_SUCCESS;
283 }
284 
285 //------------------------------------------------------------------------------
286 // Set device array to value (impl in .cu file)
287 //------------------------------------------------------------------------------
288 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedInt length, CeedScalar val);
289 
290 //------------------------------------------------------------------------------
// Set a vector to a value
292 //------------------------------------------------------------------------------
293 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
294   Ceed ceed;
295   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296   CeedVector_Cuda *impl;
297   CeedCallBackend(CeedVectorGetData(vec, &impl));
298   CeedSize length;
299   CeedCallBackend(CeedVectorGetLength(vec, &length));
300 
301   // Set value for synced device/host array
302   if (!impl->d_array && !impl->h_array) {
303     if (impl->d_array_borrowed) {
304       impl->d_array = impl->d_array_borrowed;
305     } else if (impl->h_array_borrowed) {
306       impl->h_array = impl->h_array_borrowed;
307     } else if (impl->d_array_owned) {
308       impl->d_array = impl->d_array_owned;
309     } else if (impl->h_array_owned) {
310       impl->h_array = impl->h_array_owned;
311     } else {
312       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
313     }
314   }
315   if (impl->d_array) {
316     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
317     impl->h_array = NULL;
318   }
319   if (impl->h_array) {
320     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
321     impl->d_array = NULL;
322   }
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Vector Take Array
329 //------------------------------------------------------------------------------
330 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331   Ceed ceed;
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedVector_Cuda *impl;
334   CeedCallBackend(CeedVectorGetData(vec, &impl));
335 
336   // Sync array to requested mem_type
337   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338 
339   // Update pointer
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       (*array)               = impl->h_array_borrowed;
343       impl->h_array_borrowed = NULL;
344       impl->h_array          = NULL;
345       break;
346     case CEED_MEM_DEVICE:
347       (*array)               = impl->d_array_borrowed;
348       impl->d_array_borrowed = NULL;
349       impl->d_array          = NULL;
350       break;
351   }
352 
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
358 //   If a different memory type is most up to date, this will perform a copy
359 //------------------------------------------------------------------------------
360 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361   Ceed ceed;
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedVector_Cuda *impl;
364   CeedCallBackend(CeedVectorGetData(vec, &impl));
365 
366   // Sync array to requested mem_type
367   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368 
369   // Update pointer
370   switch (mem_type) {
371     case CEED_MEM_HOST:
372       *array = impl->h_array;
373       break;
374     case CEED_MEM_DEVICE:
375       *array = impl->d_array;
376       break;
377   }
378 
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read-only access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
387 }
388 
389 //------------------------------------------------------------------------------
390 // Get read/write access to a vector via the specified mem_type
391 //------------------------------------------------------------------------------
392 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393   CeedVector_Cuda *impl;
394   CeedCallBackend(CeedVectorGetData(vec, &impl));
395 
396   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
397 
398   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
399   switch (mem_type) {
400     case CEED_MEM_HOST:
401       impl->h_array = *array;
402       break;
403     case CEED_MEM_DEVICE:
404       impl->d_array = *array;
405       break;
406   }
407 
408   return CEED_ERROR_SUCCESS;
409 }
410 
411 //------------------------------------------------------------------------------
412 // Get write access to a vector via the specified mem_type
413 //------------------------------------------------------------------------------
414 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415   CeedVector_Cuda *impl;
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417 
418   bool has_array_of_type = true;
419   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
420   if (!has_array_of_type) {
421     // Allocate if array is not yet allocated
422     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423   } else {
424     // Select dirty array
425     switch (mem_type) {
426       case CEED_MEM_HOST:
427         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428         else impl->h_array = impl->h_array_owned;
429         break;
430       case CEED_MEM_DEVICE:
431         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432         else impl->d_array = impl->d_array_owned;
433     }
434   }
435 
436   return CeedVectorGetArray_Cuda(vec, mem_type, array);
437 }
438 
439 //------------------------------------------------------------------------------
440 // Get the norm of a CeedVector
441 //------------------------------------------------------------------------------
442 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443   Ceed ceed;
444   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445   CeedVector_Cuda *impl;
446   CeedCallBackend(CeedVectorGetData(vec, &impl));
447   CeedSize length;
448   CeedCallBackend(CeedVectorGetLength(vec, &length));
449   cublasHandle_t handle;
450   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
451 
452   // Compute norm
453   const CeedScalar *d_array;
454   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
455   switch (type) {
456     case CEED_NORM_1: {
457       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
458         CeedCallCublas(ceed, cublasSasum(handle, length, (float *)d_array, 1, (float *)norm));
459       } else {
460         CeedCallCublas(ceed, cublasDasum(handle, length, (double *)d_array, 1, (double *)norm));
461       }
462       break;
463     }
464     case CEED_NORM_2: {
465       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
466         CeedCallCublas(ceed, cublasSnrm2(handle, length, (float *)d_array, 1, (float *)norm));
467       } else {
468         CeedCallCublas(ceed, cublasDnrm2(handle, length, (double *)d_array, 1, (double *)norm));
469       }
470       break;
471     }
472     case CEED_NORM_MAX: {
473       CeedInt indx;
474       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
475         CeedCallCublas(ceed, cublasIsamax(handle, length, (float *)d_array, 1, &indx));
476       } else {
477         CeedCallCublas(ceed, cublasIdamax(handle, length, (double *)d_array, 1, &indx));
478       }
479       CeedScalar normNoAbs;
480       CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
481       *norm = fabs(normNoAbs);
482       break;
483     }
484   }
485   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
486 
487   return CEED_ERROR_SUCCESS;
488 }
489 
490 //------------------------------------------------------------------------------
491 // Take reciprocal of a vector on host
492 //------------------------------------------------------------------------------
493 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedInt length) {
494   for (int i = 0; i < length; i++) {
495     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
496   }
497   return CEED_ERROR_SUCCESS;
498 }
499 
500 //------------------------------------------------------------------------------
501 // Take reciprocal of a vector on device (impl in .cu file)
502 //------------------------------------------------------------------------------
503 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedInt length);
504 
505 //------------------------------------------------------------------------------
506 // Take reciprocal of a vector
507 //------------------------------------------------------------------------------
508 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
509   Ceed ceed;
510   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
511   CeedVector_Cuda *impl;
512   CeedCallBackend(CeedVectorGetData(vec, &impl));
513   CeedSize length;
514   CeedCallBackend(CeedVectorGetLength(vec, &length));
515 
516   // Set value for synced device/host array
517   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
518   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
519 
520   return CEED_ERROR_SUCCESS;
521 }
522 
523 //------------------------------------------------------------------------------
524 // Compute x = alpha x on the host
525 //------------------------------------------------------------------------------
526 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
527   for (int i = 0; i < length; i++) x_array[i] *= alpha;
528   return CEED_ERROR_SUCCESS;
529 }
530 
531 //------------------------------------------------------------------------------
532 // Compute x = alpha x on device (impl in .cu file)
533 //------------------------------------------------------------------------------
534 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length);
535 
536 //------------------------------------------------------------------------------
537 // Compute x = alpha x
538 //------------------------------------------------------------------------------
539 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
540   Ceed ceed;
541   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
542   CeedVector_Cuda *x_impl;
543   CeedCallBackend(CeedVectorGetData(x, &x_impl));
544   CeedSize length;
545   CeedCallBackend(CeedVectorGetLength(x, &length));
546 
547   // Set value for synced device/host array
548   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
549   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
550 
551   return CEED_ERROR_SUCCESS;
552 }
553 
554 //------------------------------------------------------------------------------
555 // Compute y = alpha x + y on the host
556 //------------------------------------------------------------------------------
557 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
558   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
559   return CEED_ERROR_SUCCESS;
560 }
561 
562 //------------------------------------------------------------------------------
563 // Compute y = alpha x + y on device (impl in .cu file)
564 //------------------------------------------------------------------------------
565 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length);
566 
567 //------------------------------------------------------------------------------
568 // Compute y = alpha x + y
569 //------------------------------------------------------------------------------
570 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
571   Ceed ceed;
572   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
573   CeedVector_Cuda *y_impl, *x_impl;
574   CeedCallBackend(CeedVectorGetData(y, &y_impl));
575   CeedCallBackend(CeedVectorGetData(x, &x_impl));
576   CeedSize length;
577   CeedCallBackend(CeedVectorGetLength(y, &length));
578 
579   // Set value for synced device/host array
580   if (y_impl->d_array) {
581     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
582     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
583   }
584   if (y_impl->h_array) {
585     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
586     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
587   }
588 
589   return CEED_ERROR_SUCCESS;
590 }
591 
592 //------------------------------------------------------------------------------
593 // Compute y = alpha x + beta y on the host
594 //------------------------------------------------------------------------------
595 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) {
596   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
597   return CEED_ERROR_SUCCESS;
598 }
599 
600 //------------------------------------------------------------------------------
601 // Compute y = alpha x + beta y on device (impl in .cu file)
602 //------------------------------------------------------------------------------
603 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length);
604 
605 //------------------------------------------------------------------------------
606 // Compute y = alpha x + beta y
607 //------------------------------------------------------------------------------
608 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
609   Ceed ceed;
610   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
611   CeedVector_Cuda *y_impl, *x_impl;
612   CeedCallBackend(CeedVectorGetData(y, &y_impl));
613   CeedCallBackend(CeedVectorGetData(x, &x_impl));
614   CeedSize length;
615   CeedCallBackend(CeedVectorGetLength(y, &length));
616 
617   // Set value for synced device/host array
618   if (y_impl->d_array) {
619     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
620     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
621   }
622   if (y_impl->h_array) {
623     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
624     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
625   }
626 
627   return CEED_ERROR_SUCCESS;
628 }
629 
630 //------------------------------------------------------------------------------
631 // Compute the pointwise multiplication w = x .* y on the host
632 //------------------------------------------------------------------------------
633 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
634   for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
635   return CEED_ERROR_SUCCESS;
636 }
637 
638 //------------------------------------------------------------------------------
639 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
640 //------------------------------------------------------------------------------
641 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length);
642 
643 //------------------------------------------------------------------------------
644 // Compute the pointwise multiplication w = x .* y
645 //------------------------------------------------------------------------------
646 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
647   Ceed ceed;
648   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
649   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
650   CeedCallBackend(CeedVectorGetData(w, &w_impl));
651   CeedCallBackend(CeedVectorGetData(x, &x_impl));
652   CeedCallBackend(CeedVectorGetData(y, &y_impl));
653   CeedSize length;
654   CeedCallBackend(CeedVectorGetLength(w, &length));
655 
656   // Set value for synced device/host array
657   if (!w_impl->d_array && !w_impl->h_array) {
658     CeedCallBackend(CeedVectorSetValue(w, 0.0));
659   }
660   if (w_impl->d_array) {
661     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
662     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
663     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
664   }
665   if (w_impl->h_array) {
666     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
667     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
668     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
669   }
670 
671   return CEED_ERROR_SUCCESS;
672 }
673 
674 //------------------------------------------------------------------------------
675 // Destroy the vector
676 //------------------------------------------------------------------------------
677 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
678   Ceed ceed;
679   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
680   CeedVector_Cuda *impl;
681   CeedCallBackend(CeedVectorGetData(vec, &impl));
682 
683   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
684   CeedCallBackend(CeedFree(&impl->h_array_owned));
685   CeedCallBackend(CeedFree(&impl));
686 
687   return CEED_ERROR_SUCCESS;
688 }
689 
690 //------------------------------------------------------------------------------
691 // Create a vector of the specified length (does not allocate memory)
692 //------------------------------------------------------------------------------
693 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
694   CeedVector_Cuda *impl;
695   Ceed             ceed;
696   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
697 
698   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
699   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
700   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
701   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
702   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
703   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
704   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
705   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
706   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
707   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
708   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
709   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
710   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
711   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
712   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
713   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
714 
715   CeedCallBackend(CeedCalloc(1, &impl));
716   CeedCallBackend(CeedVectorSetData(vec, impl));
717 
718   return CEED_ERROR_SUCCESS;
719 }
720 
721 //------------------------------------------------------------------------------
722