xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision 9ba83ac0e4b1fca39d6fa6737a318a9f0cbc172d)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <string.h>
12 #include <hip/hip_runtime.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21   Ceed                      ceed;
22   size_t                    ctx_size;
23   CeedQFunctionContext_Hip *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice));
40   CeedCallBackend(CeedDestroy(&ceed));
41   return CEED_ERROR_SUCCESS;
42 }
43 
44 //------------------------------------------------------------------------------
45 // Sync device to host
46 //------------------------------------------------------------------------------
47 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
48   Ceed                      ceed;
49   size_t                    ctx_size;
50   CeedQFunctionContext_Hip *impl;
51 
52   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
54 
55   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
56 
57   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
58   if (impl->h_data_borrowed) {
59     impl->h_data = impl->h_data_borrowed;
60   } else if (impl->h_data_owned) {
61     impl->h_data = impl->h_data_owned;
62   } else {
63     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
64     impl->h_data = impl->h_data_owned;
65   }
66   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost));
67   CeedCallBackend(CeedDestroy(&ceed));
68   return CEED_ERROR_SUCCESS;
69 }
70 
71 //------------------------------------------------------------------------------
72 // Sync data of type
73 //------------------------------------------------------------------------------
74 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
75   switch (mem_type) {
76     case CEED_MEM_HOST:
77       return CeedQFunctionContextSyncD2H_Hip(ctx);
78     case CEED_MEM_DEVICE:
79       return CeedQFunctionContextSyncH2D_Hip(ctx);
80   }
81   return CEED_ERROR_UNSUPPORTED;
82 }
83 
84 //------------------------------------------------------------------------------
85 // Set all pointers as invalid
86 //------------------------------------------------------------------------------
87 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
88   CeedQFunctionContext_Hip *impl;
89 
90   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
91   impl->h_data = NULL;
92   impl->d_data = NULL;
93   return CEED_ERROR_SUCCESS;
94 }
95 
96 //------------------------------------------------------------------------------
97 // Check for valid data
98 //------------------------------------------------------------------------------
99 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
100   CeedQFunctionContext_Hip *impl;
101 
102   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
103   *has_valid_data = impl && (impl->h_data || impl->d_data);
104   return CEED_ERROR_SUCCESS;
105 }
106 
107 //------------------------------------------------------------------------------
108 // Check if ctx has borrowed data
109 //------------------------------------------------------------------------------
110 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
111                                                                 bool *has_borrowed_data_of_type) {
112   CeedQFunctionContext_Hip *impl;
113 
114   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
115   switch (mem_type) {
116     case CEED_MEM_HOST:
117       *has_borrowed_data_of_type = impl->h_data_borrowed;
118       break;
119     case CEED_MEM_DEVICE:
120       *has_borrowed_data_of_type = impl->d_data_borrowed;
121       break;
122   }
123   return CEED_ERROR_SUCCESS;
124 }
125 
126 //------------------------------------------------------------------------------
127 // Check if data of given type needs sync
128 //------------------------------------------------------------------------------
129 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
130   bool                      has_valid_data = true;
131   CeedQFunctionContext_Hip *impl;
132 
133   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
134   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
135   switch (mem_type) {
136     case CEED_MEM_HOST:
137       *need_sync = has_valid_data && !impl->h_data;
138       break;
139     case CEED_MEM_DEVICE:
140       *need_sync = has_valid_data && !impl->d_data;
141       break;
142   }
143   return CEED_ERROR_SUCCESS;
144 }
145 
146 //------------------------------------------------------------------------------
147 // Set data from host
148 //------------------------------------------------------------------------------
149 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
150   CeedQFunctionContext_Hip *impl;
151 
152   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
153   CeedCallBackend(CeedFree(&impl->h_data_owned));
154   switch (copy_mode) {
155     case CEED_COPY_VALUES: {
156       size_t ctx_size;
157 
158       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
159       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
160       impl->h_data_borrowed = NULL;
161       impl->h_data          = impl->h_data_owned;
162       memcpy(impl->h_data, data, ctx_size);
163     } break;
164     case CEED_OWN_POINTER:
165       impl->h_data_owned    = data;
166       impl->h_data_borrowed = NULL;
167       impl->h_data          = data;
168       break;
169     case CEED_USE_POINTER:
170       impl->h_data_borrowed = data;
171       impl->h_data          = data;
172       break;
173   }
174   return CEED_ERROR_SUCCESS;
175 }
176 
177 //------------------------------------------------------------------------------
178 // Set data from device
179 //------------------------------------------------------------------------------
180 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
181   Ceed                      ceed;
182   CeedQFunctionContext_Hip *impl;
183 
184   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
185   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
186 
187   CeedCallHip(ceed, hipFree(impl->d_data_owned));
188   impl->d_data_owned = NULL;
189   switch (copy_mode) {
190     case CEED_COPY_VALUES: {
191       size_t ctx_size;
192       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
193       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
194       impl->d_data_borrowed = NULL;
195       impl->d_data          = impl->d_data_owned;
196       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice));
197     } break;
198     case CEED_OWN_POINTER:
199       impl->d_data_owned    = data;
200       impl->d_data_borrowed = NULL;
201       impl->d_data          = data;
202       break;
203     case CEED_USE_POINTER:
204       impl->d_data_owned    = NULL;
205       impl->d_data_borrowed = data;
206       impl->d_data          = data;
207       break;
208   }
209   CeedCallBackend(CeedDestroy(&ceed));
210   return CEED_ERROR_SUCCESS;
211 }
212 
213 //------------------------------------------------------------------------------
214 // Set the data used by a user context,
215 //    freeing any previously allocated data if applicable
216 //------------------------------------------------------------------------------
217 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
218   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
219   switch (mem_type) {
220     case CEED_MEM_HOST:
221       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
222     case CEED_MEM_DEVICE:
223       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
224   }
225   return CEED_ERROR_UNSUPPORTED;
226 }
227 
228 //------------------------------------------------------------------------------
229 // Take data
230 //------------------------------------------------------------------------------
231 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
232   bool                      need_sync = false;
233   CeedQFunctionContext_Hip *impl;
234 
235   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
236 
237   // Sync data to requested mem_type
238   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
239   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
240 
241   // Update pointer
242   switch (mem_type) {
243     case CEED_MEM_HOST:
244       *(void **)data        = impl->h_data_borrowed;
245       impl->h_data_borrowed = NULL;
246       impl->h_data          = NULL;
247       break;
248     case CEED_MEM_DEVICE:
249       *(void **)data        = impl->d_data_borrowed;
250       impl->d_data_borrowed = NULL;
251       impl->d_data          = NULL;
252       break;
253   }
254   return CEED_ERROR_SUCCESS;
255 }
256 
257 //------------------------------------------------------------------------------
258 // Core logic for GetData.
259 //   If a different memory type is most up to date, this will perform a copy
260 //------------------------------------------------------------------------------
261 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
262   bool                      need_sync = false;
263   CeedQFunctionContext_Hip *impl;
264 
265   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
266 
267   // Sync data to requested mem_type
268   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
269   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
270 
271   // Update pointer
272   switch (mem_type) {
273     case CEED_MEM_HOST:
274       *(void **)data = impl->h_data;
275       break;
276     case CEED_MEM_DEVICE:
277       *(void **)data = impl->d_data;
278       break;
279   }
280   return CEED_ERROR_SUCCESS;
281 }
282 
283 //------------------------------------------------------------------------------
284 // Get read-only access to the data
285 //------------------------------------------------------------------------------
286 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
287   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
288 }
289 
290 //------------------------------------------------------------------------------
291 // Get read/write access to the data
292 //------------------------------------------------------------------------------
293 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
294   CeedQFunctionContext_Hip *impl;
295 
296   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
297   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
298 
299   // Mark only pointer for requested memory as valid
300   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
301   switch (mem_type) {
302     case CEED_MEM_HOST:
303       impl->h_data = *(void **)data;
304       break;
305     case CEED_MEM_DEVICE:
306       impl->d_data = *(void **)data;
307       break;
308   }
309   return CEED_ERROR_SUCCESS;
310 }
311 
312 //------------------------------------------------------------------------------
313 // Destroy the user context
314 //------------------------------------------------------------------------------
315 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
316   CeedQFunctionContext_Hip *impl;
317 
318   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
319   CeedCallHip(CeedQFunctionContextReturnCeed(ctx), hipFree(impl->d_data_owned));
320   CeedCallBackend(CeedFree(&impl->h_data_owned));
321   CeedCallBackend(CeedFree(&impl));
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // QFunctionContext Create
327 //------------------------------------------------------------------------------
328 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
329   CeedQFunctionContext_Hip *impl;
330   Ceed                      ceed;
331 
332   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
333   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
334   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
335   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
336   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
337   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
338   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
339   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
340   CeedCallBackend(CeedDestroy(&ceed));
341   CeedCallBackend(CeedCalloc(1, &impl));
342   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
343   return CEED_ERROR_SUCCESS;
344 }
345 
346 //------------------------------------------------------------------------------
347