xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7ff1e7120SSebastian Grimberg 
8ff1e7120SSebastian Grimberg #include <ceed.h>
9ff1e7120SSebastian Grimberg #include <ceed/backend.h>
10ff1e7120SSebastian Grimberg #include <cuda_runtime.h>
11ff1e7120SSebastian Grimberg #include <math.h>
12ff1e7120SSebastian Grimberg #include <stdbool.h>
13ff1e7120SSebastian Grimberg #include <string.h>
14ff1e7120SSebastian Grimberg 
15ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h"
16ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h"
17ff1e7120SSebastian Grimberg 
18ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
19ff1e7120SSebastian Grimberg // Check if host/device sync is needed
20ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
21ff1e7120SSebastian Grimberg static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22ff1e7120SSebastian Grimberg   bool             has_valid_array = false;
23ca735530SJeremy L Thompson   CeedVector_Cuda *impl;
24ca735530SJeremy L Thompson 
25ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
26ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27ff1e7120SSebastian Grimberg   switch (mem_type) {
28ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
29ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->h_array;
30ff1e7120SSebastian Grimberg       break;
31ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
32ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->d_array;
33ff1e7120SSebastian Grimberg       break;
34ff1e7120SSebastian Grimberg   }
35ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
36ff1e7120SSebastian Grimberg }
37ff1e7120SSebastian Grimberg 
38ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
39ff1e7120SSebastian Grimberg // Sync host to device
40ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
41ff1e7120SSebastian Grimberg static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
42ca735530SJeremy L Thompson   CeedSize         length;
43672b0f2aSSebastian Grimberg   size_t           bytes;
446e536b99SJeremy L Thompson   Ceed             ceed;
45ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
46ca735530SJeremy L Thompson 
47ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
49ff1e7120SSebastian Grimberg 
506e536b99SJeremy L Thompson   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
51ff1e7120SSebastian Grimberg 
52ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
53672b0f2aSSebastian Grimberg   bytes = length * sizeof(CeedScalar);
54ff1e7120SSebastian Grimberg   if (impl->d_array_borrowed) {
55ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_borrowed;
56ff1e7120SSebastian Grimberg   } else if (impl->d_array_owned) {
57ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
58ff1e7120SSebastian Grimberg   } else {
59ff1e7120SSebastian Grimberg     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
60ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
61ff1e7120SSebastian Grimberg   }
62ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
63ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
64ff1e7120SSebastian Grimberg }
65ff1e7120SSebastian Grimberg 
66ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
67ff1e7120SSebastian Grimberg // Sync device to host
68ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
69ff1e7120SSebastian Grimberg static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
70ca735530SJeremy L Thompson   CeedSize         length;
716e536b99SJeremy L Thompson   Ceed             ceed;
72ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
73ca735530SJeremy L Thompson 
74ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
76ff1e7120SSebastian Grimberg 
771a210a00SJeremy L Thompson   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78ff1e7120SSebastian Grimberg 
79ff1e7120SSebastian Grimberg   if (impl->h_array_borrowed) {
80ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_borrowed;
81ff1e7120SSebastian Grimberg   } else if (impl->h_array_owned) {
82ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
83ff1e7120SSebastian Grimberg   } else {
84ff1e7120SSebastian Grimberg     CeedSize length;
85ca735530SJeremy L Thompson 
86ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetLength(vec, &length));
87ff1e7120SSebastian Grimberg     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
89ff1e7120SSebastian Grimberg   }
90ff1e7120SSebastian Grimberg 
91ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
92ff1e7120SSebastian Grimberg   size_t bytes = length * sizeof(CeedScalar);
93ff1e7120SSebastian Grimberg 
94ca735530SJeremy L Thompson   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
96ff1e7120SSebastian Grimberg }
97ff1e7120SSebastian Grimberg 
98ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
99ff1e7120SSebastian Grimberg // Sync arrays
100ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
101ff1e7120SSebastian Grimberg static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
102ff1e7120SSebastian Grimberg   bool need_sync = false;
103ca735530SJeremy L Thompson 
104ca735530SJeremy L Thompson   // Check whether device/host sync is needed
105ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106ff1e7120SSebastian Grimberg   if (!need_sync) return CEED_ERROR_SUCCESS;
107ff1e7120SSebastian Grimberg 
108ff1e7120SSebastian Grimberg   switch (mem_type) {
109ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
110ff1e7120SSebastian Grimberg       return CeedVectorSyncD2H_Cuda(vec);
111ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
112ff1e7120SSebastian Grimberg       return CeedVectorSyncH2D_Cuda(vec);
113ff1e7120SSebastian Grimberg   }
114ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
115ff1e7120SSebastian Grimberg }
116ff1e7120SSebastian Grimberg 
117ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
118ff1e7120SSebastian Grimberg // Set all pointers as invalid
119ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
120ff1e7120SSebastian Grimberg static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
122ff1e7120SSebastian Grimberg 
123ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
124ff1e7120SSebastian Grimberg   impl->h_array = NULL;
125ff1e7120SSebastian Grimberg   impl->d_array = NULL;
126ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
127ff1e7120SSebastian Grimberg }
128ff1e7120SSebastian Grimberg 
129ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
130ff1e7120SSebastian Grimberg // Check if CeedVector has any valid pointer
131ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
132ff1e7120SSebastian Grimberg static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
133ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
134ca735530SJeremy L Thompson 
135ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
1361c66c397SJeremy L Thompson   *has_valid_array = impl->h_array || impl->d_array;
137ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
138ff1e7120SSebastian Grimberg }
139ff1e7120SSebastian Grimberg 
140ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
141ff1e7120SSebastian Grimberg // Check if has array of given type
142ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
143ff1e7120SSebastian Grimberg static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
145ff1e7120SSebastian Grimberg 
146ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
147ff1e7120SSebastian Grimberg   switch (mem_type) {
148ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
1491c66c397SJeremy L Thompson       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150ff1e7120SSebastian Grimberg       break;
151ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
1521c66c397SJeremy L Thompson       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153ff1e7120SSebastian Grimberg       break;
154ff1e7120SSebastian Grimberg   }
155ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
156ff1e7120SSebastian Grimberg }
157ff1e7120SSebastian Grimberg 
158ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
159ff1e7120SSebastian Grimberg // Check if has borrowed array of given type
160ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
161ff1e7120SSebastian Grimberg static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
163ff1e7120SSebastian Grimberg 
164ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
165ff1e7120SSebastian Grimberg   switch (mem_type) {
166ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
1671c66c397SJeremy L Thompson       *has_borrowed_array_of_type = impl->h_array_borrowed;
168ff1e7120SSebastian Grimberg       break;
169ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
1701c66c397SJeremy L Thompson       *has_borrowed_array_of_type = impl->d_array_borrowed;
171ff1e7120SSebastian Grimberg       break;
172ff1e7120SSebastian Grimberg   }
173ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
174ff1e7120SSebastian Grimberg }
175ff1e7120SSebastian Grimberg 
176ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
177ff1e7120SSebastian Grimberg // Set array from host
178ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
179ff1e7120SSebastian Grimberg static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180a267acd1SJeremy L Thompson   CeedSize         length;
181ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
182ff1e7120SSebastian Grimberg 
183ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
184ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
185a267acd1SJeremy L Thompson 
186f5d1e504SJeremy L Thompson   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
187f5d1e504SJeremy L Thompson                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
188ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
189ff1e7120SSebastian Grimberg }
190ff1e7120SSebastian Grimberg 
191ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
192ff1e7120SSebastian Grimberg // Set array from device
193ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
194ff1e7120SSebastian Grimberg static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
195a267acd1SJeremy L Thompson   CeedSize         length;
196ff1e7120SSebastian Grimberg   Ceed             ceed;
197ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
198ff1e7120SSebastian Grimberg 
199ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
200ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
201ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
202a267acd1SJeremy L Thompson 
203f5d1e504SJeremy L Thompson   CeedCallBackend(CeedSetDeviceCeedScalarArray_Cuda(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
204f5d1e504SJeremy L Thompson                                                     (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
205ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
206ff1e7120SSebastian Grimberg }
207ff1e7120SSebastian Grimberg 
208ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
209ff1e7120SSebastian Grimberg // Set the array used by a vector,
210ff1e7120SSebastian Grimberg //   freeing any previously allocated array if applicable
211ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
212ff1e7120SSebastian Grimberg static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
213ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
214ff1e7120SSebastian Grimberg 
215ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
216ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
217ff1e7120SSebastian Grimberg   switch (mem_type) {
218ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
219ff1e7120SSebastian Grimberg       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
220ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
221ff1e7120SSebastian Grimberg       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
222ff1e7120SSebastian Grimberg   }
223ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
224ff1e7120SSebastian Grimberg }
225ff1e7120SSebastian Grimberg 
226ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
227ff1e7120SSebastian Grimberg // Set host array to value
228ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
229f7c1b517Snbeams static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
230f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
231ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
232ff1e7120SSebastian Grimberg }
233ff1e7120SSebastian Grimberg 
234ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
235ff1e7120SSebastian Grimberg // Set device array to value (impl in .cu file)
236ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
237f7c1b517Snbeams int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
238ff1e7120SSebastian Grimberg 
239ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
240f7c1b517Snbeams // Set a vector to a value
241ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
242ff1e7120SSebastian Grimberg static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
243ff1e7120SSebastian Grimberg   CeedSize         length;
244ca735530SJeremy L Thompson   CeedVector_Cuda *impl;
245ff1e7120SSebastian Grimberg 
246ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
247ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
248ff1e7120SSebastian Grimberg   // Set value for synced device/host array
249ff1e7120SSebastian Grimberg   if (!impl->d_array && !impl->h_array) {
250ff1e7120SSebastian Grimberg     if (impl->d_array_borrowed) {
251ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_borrowed;
252ff1e7120SSebastian Grimberg     } else if (impl->h_array_borrowed) {
253ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_borrowed;
254ff1e7120SSebastian Grimberg     } else if (impl->d_array_owned) {
255ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_owned;
256ff1e7120SSebastian Grimberg     } else if (impl->h_array_owned) {
257ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_owned;
258ff1e7120SSebastian Grimberg     } else {
259ff1e7120SSebastian Grimberg       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
260ff1e7120SSebastian Grimberg     }
261ff1e7120SSebastian Grimberg   }
262ff1e7120SSebastian Grimberg   if (impl->d_array) {
263ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
264ff1e7120SSebastian Grimberg     impl->h_array = NULL;
265ff1e7120SSebastian Grimberg   }
266ff1e7120SSebastian Grimberg   if (impl->h_array) {
267ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
268ff1e7120SSebastian Grimberg     impl->d_array = NULL;
269ff1e7120SSebastian Grimberg   }
270ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
271ff1e7120SSebastian Grimberg }
272ff1e7120SSebastian Grimberg 
273ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
274ff1e7120SSebastian Grimberg // Vector Take Array
275ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
276ff1e7120SSebastian Grimberg static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
277ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
278ff1e7120SSebastian Grimberg 
279ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
280ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
281ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
282ff1e7120SSebastian Grimberg   // Update pointer
283ff1e7120SSebastian Grimberg   switch (mem_type) {
284ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
285ff1e7120SSebastian Grimberg       (*array)               = impl->h_array_borrowed;
286ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
287ff1e7120SSebastian Grimberg       impl->h_array          = NULL;
288ff1e7120SSebastian Grimberg       break;
289ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
290ff1e7120SSebastian Grimberg       (*array)               = impl->d_array_borrowed;
291ff1e7120SSebastian Grimberg       impl->d_array_borrowed = NULL;
292ff1e7120SSebastian Grimberg       impl->d_array          = NULL;
293ff1e7120SSebastian Grimberg       break;
294ff1e7120SSebastian Grimberg   }
295ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
296ff1e7120SSebastian Grimberg }
297ff1e7120SSebastian Grimberg 
298ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
300ff1e7120SSebastian Grimberg //   If a different memory type is most up to date, this will perform a copy
301ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
302ff1e7120SSebastian Grimberg static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
303ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
304ff1e7120SSebastian Grimberg 
305ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
306ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
307ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
308ff1e7120SSebastian Grimberg   // Update pointer
309ff1e7120SSebastian Grimberg   switch (mem_type) {
310ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
311ff1e7120SSebastian Grimberg       *array = impl->h_array;
312ff1e7120SSebastian Grimberg       break;
313ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
314ff1e7120SSebastian Grimberg       *array = impl->d_array;
315ff1e7120SSebastian Grimberg       break;
316ff1e7120SSebastian Grimberg   }
317ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
318ff1e7120SSebastian Grimberg }
319ff1e7120SSebastian Grimberg 
320ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
321ff1e7120SSebastian Grimberg // Get read-only access to a vector via the specified mem_type
322ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
323ff1e7120SSebastian Grimberg static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
324ff1e7120SSebastian Grimberg   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
325ff1e7120SSebastian Grimberg }
326ff1e7120SSebastian Grimberg 
327ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
328ff1e7120SSebastian Grimberg // Get read/write access to a vector via the specified mem_type
329ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
330ff1e7120SSebastian Grimberg static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
331ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
332ca735530SJeremy L Thompson 
333ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
334ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
335ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
336ff1e7120SSebastian Grimberg   switch (mem_type) {
337ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
338ff1e7120SSebastian Grimberg       impl->h_array = *array;
339ff1e7120SSebastian Grimberg       break;
340ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
341ff1e7120SSebastian Grimberg       impl->d_array = *array;
342ff1e7120SSebastian Grimberg       break;
343ff1e7120SSebastian Grimberg   }
344ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
345ff1e7120SSebastian Grimberg }
346ff1e7120SSebastian Grimberg 
347ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
348ff1e7120SSebastian Grimberg // Get write access to a vector via the specified mem_type
349ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
350ff1e7120SSebastian Grimberg static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
351ff1e7120SSebastian Grimberg   bool             has_array_of_type = true;
352ca735530SJeremy L Thompson   CeedVector_Cuda *impl;
353ca735530SJeremy L Thompson 
354ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
355ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
356ff1e7120SSebastian Grimberg   if (!has_array_of_type) {
357ff1e7120SSebastian Grimberg     // Allocate if array is not yet allocated
358ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
359ff1e7120SSebastian Grimberg   } else {
360ff1e7120SSebastian Grimberg     // Select dirty array
361ff1e7120SSebastian Grimberg     switch (mem_type) {
362ff1e7120SSebastian Grimberg       case CEED_MEM_HOST:
363ff1e7120SSebastian Grimberg         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
364ff1e7120SSebastian Grimberg         else impl->h_array = impl->h_array_owned;
365ff1e7120SSebastian Grimberg         break;
366ff1e7120SSebastian Grimberg       case CEED_MEM_DEVICE:
367ff1e7120SSebastian Grimberg         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
368ff1e7120SSebastian Grimberg         else impl->d_array = impl->d_array_owned;
369ff1e7120SSebastian Grimberg     }
370ff1e7120SSebastian Grimberg   }
371ff1e7120SSebastian Grimberg   return CeedVectorGetArray_Cuda(vec, mem_type, array);
372ff1e7120SSebastian Grimberg }
373ff1e7120SSebastian Grimberg 
374ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
375ff1e7120SSebastian Grimberg // Get the norm of a CeedVector
376ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
377ff1e7120SSebastian Grimberg static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
378ff1e7120SSebastian Grimberg   Ceed     ceed;
379ca735530SJeremy L Thompson   CeedSize length;
380672b0f2aSSebastian Grimberg #if CUDA_VERSION < 12000
381672b0f2aSSebastian Grimberg   CeedSize num_calls;
382672b0f2aSSebastian Grimberg #endif
383ca735530SJeremy L Thompson   const CeedScalar *d_array;
384ca735530SJeremy L Thompson   CeedVector_Cuda  *impl;
385672b0f2aSSebastian Grimberg   cublasHandle_t    handle;
386ca735530SJeremy L Thompson 
387ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
388ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
389ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
390ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
391ff1e7120SSebastian Grimberg 
392f7c1b517Snbeams #if CUDA_VERSION < 12000
393f7c1b517Snbeams   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
394f7c1b517Snbeams   // we need to check if the vector is too long to handle with int32,
395b2165e7aSSebastian Grimberg   // and if so, divide it into subsections for repeated cuBLAS calls.
396672b0f2aSSebastian Grimberg   num_calls = length / INT_MAX;
397f7c1b517Snbeams   if (length % INT_MAX > 0) num_calls += 1;
398f7c1b517Snbeams #endif
399f7c1b517Snbeams 
400ff1e7120SSebastian Grimberg   // Compute norm
401ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
402ff1e7120SSebastian Grimberg   switch (type) {
403ff1e7120SSebastian Grimberg     case CEED_NORM_1: {
404f6f49adbSnbeams       *norm = 0.0;
405ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
406f7c1b517Snbeams #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
407f7c1b517Snbeams         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
408f7c1b517Snbeams #else
409f7c1b517Snbeams         float  sub_norm = 0.0;
410f7c1b517Snbeams         float *d_array_start;
411ca735530SJeremy L Thompson 
412f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
413f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
414f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
415f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
416ca735530SJeremy L Thompson 
417f7c1b517Snbeams           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
418f7c1b517Snbeams           *norm += sub_norm;
419f7c1b517Snbeams         }
420f7c1b517Snbeams #endif
421ff1e7120SSebastian Grimberg       } else {
422f7c1b517Snbeams #if CUDA_VERSION >= 12000
423f7c1b517Snbeams         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
424f7c1b517Snbeams #else
425f7c1b517Snbeams         double  sub_norm = 0.0;
426f7c1b517Snbeams         double *d_array_start;
427ca735530SJeremy L Thompson 
428f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
429f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
430f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
431f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
432ca735530SJeremy L Thompson 
433f7c1b517Snbeams           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
434f7c1b517Snbeams           *norm += sub_norm;
435f7c1b517Snbeams         }
436f7c1b517Snbeams #endif
437ff1e7120SSebastian Grimberg       }
438ff1e7120SSebastian Grimberg       break;
439ff1e7120SSebastian Grimberg     }
440ff1e7120SSebastian Grimberg     case CEED_NORM_2: {
441ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
442f7c1b517Snbeams #if CUDA_VERSION >= 12000
443f7c1b517Snbeams         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
444f7c1b517Snbeams #else
445f7c1b517Snbeams         float  sub_norm = 0.0, norm_sum = 0.0;
446f7c1b517Snbeams         float *d_array_start;
447ca735530SJeremy L Thompson 
448f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
449f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
450f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
451f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
452ca735530SJeremy L Thompson 
453f7c1b517Snbeams           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
454f7c1b517Snbeams           norm_sum += sub_norm * sub_norm;
455f7c1b517Snbeams         }
456f7c1b517Snbeams         *norm = sqrt(norm_sum);
457f7c1b517Snbeams #endif
458ff1e7120SSebastian Grimberg       } else {
459f7c1b517Snbeams #if CUDA_VERSION >= 12000
460f7c1b517Snbeams         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
461f7c1b517Snbeams #else
462f7c1b517Snbeams         double  sub_norm = 0.0, norm_sum = 0.0;
463f7c1b517Snbeams         double *d_array_start;
464ca735530SJeremy L Thompson 
465f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
466f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
467f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
468f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
469ca735530SJeremy L Thompson 
470f7c1b517Snbeams           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
471f7c1b517Snbeams           norm_sum += sub_norm * sub_norm;
472f7c1b517Snbeams         }
473f7c1b517Snbeams         *norm = sqrt(norm_sum);
474f7c1b517Snbeams #endif
475ff1e7120SSebastian Grimberg       }
476ff1e7120SSebastian Grimberg       break;
477ff1e7120SSebastian Grimberg     }
478ff1e7120SSebastian Grimberg     case CEED_NORM_MAX: {
479ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
480f7c1b517Snbeams #if CUDA_VERSION >= 12000
481ca735530SJeremy L Thompson         int64_t    index;
482ca735530SJeremy L Thompson         CeedScalar norm_no_abs;
483ca735530SJeremy L Thompson 
484ca735530SJeremy L Thompson         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
485ca735530SJeremy L Thompson         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
486ca735530SJeremy L Thompson         *norm = fabs(norm_no_abs);
487f7c1b517Snbeams #else
488ca735530SJeremy L Thompson         CeedInt index;
489f7c1b517Snbeams         float   sub_max = 0.0, current_max = 0.0;
490f7c1b517Snbeams         float  *d_array_start;
491ca735530SJeremy L Thompson 
492f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
493f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
494f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
495f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
496ca735530SJeremy L Thompson 
497ca735530SJeremy L Thompson           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
498ca735530SJeremy L Thompson           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
499f7c1b517Snbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
500f7c1b517Snbeams         }
501f7c1b517Snbeams         *norm = current_max;
502f7c1b517Snbeams #endif
503f7c1b517Snbeams       } else {
504f7c1b517Snbeams #if CUDA_VERSION >= 12000
505ca735530SJeremy L Thompson         int64_t    index;
506ca735530SJeremy L Thompson         CeedScalar norm_no_abs;
507ca735530SJeremy L Thompson 
508ca735530SJeremy L Thompson         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
509ca735530SJeremy L Thompson         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
510ca735530SJeremy L Thompson         *norm = fabs(norm_no_abs);
511f7c1b517Snbeams #else
512ca735530SJeremy L Thompson         CeedInt index;
513f7c1b517Snbeams         double  sub_max = 0.0, current_max = 0.0;
514f7c1b517Snbeams         double *d_array_start;
515ca735530SJeremy L Thompson 
516f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
517f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
518f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
519f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
520ca735530SJeremy L Thompson 
521ca735530SJeremy L Thompson           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
522ca735530SJeremy L Thompson           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
523f7c1b517Snbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
524f7c1b517Snbeams         }
525f7c1b517Snbeams         *norm = current_max;
526f7c1b517Snbeams #endif
527f7c1b517Snbeams       }
528ff1e7120SSebastian Grimberg       break;
529ff1e7120SSebastian Grimberg     }
530ff1e7120SSebastian Grimberg   }
531ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
532ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
533ff1e7120SSebastian Grimberg }
534ff1e7120SSebastian Grimberg 
535ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
536ff1e7120SSebastian Grimberg // Take reciprocal of a vector on host
537ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
538f7c1b517Snbeams static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
539f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) {
540ff1e7120SSebastian Grimberg     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
541ff1e7120SSebastian Grimberg   }
542ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
543ff1e7120SSebastian Grimberg }
544ff1e7120SSebastian Grimberg 
545ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
546ff1e7120SSebastian Grimberg // Take reciprocal of a vector on device (impl in .cu file)
547ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
548f7c1b517Snbeams int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
549ff1e7120SSebastian Grimberg 
550ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
551ff1e7120SSebastian Grimberg // Take reciprocal of a vector
552ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
553ff1e7120SSebastian Grimberg static int CeedVectorReciprocal_Cuda(CeedVector vec) {
554ff1e7120SSebastian Grimberg   CeedSize         length;
555ca735530SJeremy L Thompson   CeedVector_Cuda *impl;
556ff1e7120SSebastian Grimberg 
557ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
558ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
559ff1e7120SSebastian Grimberg   // Set value for synced device/host array
560ff1e7120SSebastian Grimberg   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
561ff1e7120SSebastian Grimberg   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
562ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
563ff1e7120SSebastian Grimberg }
564ff1e7120SSebastian Grimberg 
565ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
566ff1e7120SSebastian Grimberg // Compute x = alpha x on the host
567ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
568f7c1b517Snbeams static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
569f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
570ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
571ff1e7120SSebastian Grimberg }
572ff1e7120SSebastian Grimberg 
573ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
574ff1e7120SSebastian Grimberg // Compute x = alpha x on device (impl in .cu file)
575ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
576f7c1b517Snbeams int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
577ff1e7120SSebastian Grimberg 
578ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
579ff1e7120SSebastian Grimberg // Compute x = alpha x
580ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
581ff1e7120SSebastian Grimberg static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
582ff1e7120SSebastian Grimberg   CeedSize         length;
583ca735530SJeremy L Thompson   CeedVector_Cuda *x_impl;
584ff1e7120SSebastian Grimberg 
585ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
586ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(x, &length));
587ff1e7120SSebastian Grimberg   // Set value for synced device/host array
588ff1e7120SSebastian Grimberg   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
589ff1e7120SSebastian Grimberg   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
590ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
591ff1e7120SSebastian Grimberg }
592ff1e7120SSebastian Grimberg 
593ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
594ff1e7120SSebastian Grimberg // Compute y = alpha x + y on the host
595ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
596f7c1b517Snbeams static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
597f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
598ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
599ff1e7120SSebastian Grimberg }
600ff1e7120SSebastian Grimberg 
601ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
602ff1e7120SSebastian Grimberg // Compute y = alpha x + y on device (impl in .cu file)
603ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
604f7c1b517Snbeams int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
605ff1e7120SSebastian Grimberg 
606ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
607ff1e7120SSebastian Grimberg // Compute y = alpha x + y
608ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
609ff1e7120SSebastian Grimberg static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
610ff1e7120SSebastian Grimberg   Ceed             ceed;
611ca735530SJeremy L Thompson   CeedSize         length;
612ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
613ca735530SJeremy L Thompson 
614ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
615ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
616ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
617ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
618ff1e7120SSebastian Grimberg   // Set value for synced device/host array
619ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
620ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
621ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
622ff1e7120SSebastian Grimberg   }
623ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
624ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
625ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
626ff1e7120SSebastian Grimberg   }
627ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
628ff1e7120SSebastian Grimberg }
629ff1e7120SSebastian Grimberg 
630ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
631ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on the host
632ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
633f7c1b517Snbeams static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
634aa67b842SZach Atkins   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
635ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
636ff1e7120SSebastian Grimberg }
637ff1e7120SSebastian Grimberg 
638ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
639ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on device (impl in .cu file)
640ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
641f7c1b517Snbeams int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
642ff1e7120SSebastian Grimberg 
643ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
644ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y
645ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
646ff1e7120SSebastian Grimberg static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
647ca735530SJeremy L Thompson   CeedSize         length;
648ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
649ca735530SJeremy L Thompson 
650ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
651ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
652ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
653ff1e7120SSebastian Grimberg   // Set value for synced device/host array
654ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
655ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
656ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
657ff1e7120SSebastian Grimberg   }
658ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
659ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
660ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
661ff1e7120SSebastian Grimberg   }
662ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
663ff1e7120SSebastian Grimberg }
664ff1e7120SSebastian Grimberg 
665ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
666ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on the host
667ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
668f7c1b517Snbeams static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
669f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
670ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
671ff1e7120SSebastian Grimberg }
672ff1e7120SSebastian Grimberg 
673ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
674ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
675ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
676f7c1b517Snbeams int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
677ff1e7120SSebastian Grimberg 
678ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
679ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y
680ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
681ff1e7120SSebastian Grimberg static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
682ca735530SJeremy L Thompson   CeedSize         length;
683ff1e7120SSebastian Grimberg   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
684ca735530SJeremy L Thompson 
685ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(w, &w_impl));
686ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
687ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
688ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(w, &length));
689ff1e7120SSebastian Grimberg   // Set value for synced device/host array
690ff1e7120SSebastian Grimberg   if (!w_impl->d_array && !w_impl->h_array) {
691ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetValue(w, 0.0));
692ff1e7120SSebastian Grimberg   }
693ff1e7120SSebastian Grimberg   if (w_impl->d_array) {
694ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
695ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
696ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
697ff1e7120SSebastian Grimberg   }
698ff1e7120SSebastian Grimberg   if (w_impl->h_array) {
699ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
700ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
701ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
702ff1e7120SSebastian Grimberg   }
703ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
704ff1e7120SSebastian Grimberg }
705ff1e7120SSebastian Grimberg 
706ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
707ff1e7120SSebastian Grimberg // Destroy the vector
708ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
709ff1e7120SSebastian Grimberg static int CeedVectorDestroy_Cuda(const CeedVector vec) {
710ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
711ff1e7120SSebastian Grimberg 
712ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
7136e536b99SJeremy L Thompson   CeedCallCuda(CeedVectorReturnCeed(vec), cudaFree(impl->d_array_owned));
714ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl->h_array_owned));
715ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl));
716ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
717ff1e7120SSebastian Grimberg }
718ff1e7120SSebastian Grimberg 
719ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
720ff1e7120SSebastian Grimberg // Create a vector of the specified length (does not allocate memory)
721ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
722ff1e7120SSebastian Grimberg int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
723ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
724ff1e7120SSebastian Grimberg   Ceed             ceed;
725ff1e7120SSebastian Grimberg 
726ca735530SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
727ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
728ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
729ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
730ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
731ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
732ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
733ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
734ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
735ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
736ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
737ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
738ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
739ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
740ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
741ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
742ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
743ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
744ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetData(vec, impl));
745ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
746ff1e7120SSebastian Grimberg }
747ff1e7120SSebastian Grimberg 
748ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
749