xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision 667e613fe678313d77f7966d97cc228a73b32933)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <string.h>
12 #include "ceed-hip-ref.h"
13 
14 //------------------------------------------------------------------------------
15 // Sync host to device
16 //------------------------------------------------------------------------------
17 static inline int CeedQFunctionContextSyncH2D_Hip(
18   const CeedQFunctionContext ctx) {
19   int ierr;
20   Ceed ceed;
21   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
22   CeedQFunctionContext_Hip *impl;
23   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
24 
25   if (!impl->h_data)
26     // LCOV_EXCL_START
27     return CeedError(ceed, CEED_ERROR_BACKEND,
28                      "No valid host data to sync to device");
29   // LCOV_EXCL_STOP
30 
31   size_t ctxsize;
32   ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
33 
34   if (impl->d_data_borrowed) {
35     impl->d_data = impl->d_data_borrowed;
36   } else if (impl->d_data_owned) {
37     impl->d_data = impl->d_data_owned;
38   } else {
39     ierr = hipMalloc((void **)&impl->d_data_owned, ctxsize);
40     CeedChk_Hip(ceed, ierr);
41     impl->d_data = impl->d_data_owned;
42   }
43 
44   ierr = hipMemcpy(impl->d_data, impl->h_data, ctxsize,
45                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
46 
47   return CEED_ERROR_SUCCESS;
48 }
49 
50 //------------------------------------------------------------------------------
51 // Sync device to host
52 //------------------------------------------------------------------------------
53 static inline int CeedQFunctionContextSyncD2H_Hip(
54   const CeedQFunctionContext ctx) {
55   int ierr;
56   Ceed ceed;
57   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
58   CeedQFunctionContext_Hip *impl;
59   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
60 
61   if (!impl->d_data)
62     // LCOV_EXCL_START
63     return CeedError(ceed, CEED_ERROR_BACKEND,
64                      "No valid device data to sync to host");
65   // LCOV_EXCL_STOP
66 
67   size_t ctxsize;
68   ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
69 
70   if (impl->h_data_borrowed) {
71     impl->h_data = impl->h_data_borrowed;
72   } else if (impl->h_data_owned) {
73     impl->h_data = impl->h_data_owned;
74   } else {
75     ierr = CeedMalloc(ctxsize, &impl->h_data_owned);
76     CeedChkBackend(ierr);
77     impl->h_data = impl->h_data_owned;
78   }
79 
80   ierr = hipMemcpy(impl->h_data, impl->d_data, ctxsize,
81                    hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
82 
83   return CEED_ERROR_SUCCESS;
84 }
85 
86 //------------------------------------------------------------------------------
87 // Sync data of type
88 //------------------------------------------------------------------------------
89 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx,
90     CeedMemType mem_type) {
91   switch (mem_type) {
92   case CEED_MEM_HOST: return CeedQFunctionContextSyncD2H_Hip(ctx);
93   case CEED_MEM_DEVICE: return CeedQFunctionContextSyncH2D_Hip(ctx);
94   }
95   return CEED_ERROR_UNSUPPORTED;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Set all pointers as invalid
100 //------------------------------------------------------------------------------
101 static inline int CeedQFunctionContextSetAllInvalid_Hip(
102   const CeedQFunctionContext ctx) {
103   int ierr;
104   CeedQFunctionContext_Hip *impl;
105   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
106 
107   impl->h_data = NULL;
108   impl->d_data = NULL;
109 
110   return CEED_ERROR_SUCCESS;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Check for valid data
115 //------------------------------------------------------------------------------
116 static inline int CeedQFunctionContextHasValidData_Hip(
117   const CeedQFunctionContext ctx, bool *has_valid_data) {
118   int ierr;
119   CeedQFunctionContext_Hip *impl;
120   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
121 
122   *has_valid_data = !!impl->h_data || !!impl->d_data;
123 
124   return CEED_ERROR_SUCCESS;
125 }
126 
127 //------------------------------------------------------------------------------
128 // Check if ctx has borrowed data
129 //------------------------------------------------------------------------------
130 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(
131   const CeedQFunctionContext ctx, CeedMemType mem_type,
132   bool *has_borrowed_data_of_type) {
133   int ierr;
134   CeedQFunctionContext_Hip *impl;
135   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
136 
137   switch (mem_type) {
138   case CEED_MEM_HOST:
139     *has_borrowed_data_of_type = !!impl->h_data_borrowed;
140     break;
141   case CEED_MEM_DEVICE:
142     *has_borrowed_data_of_type = !!impl->d_data_borrowed;
143     break;
144   }
145 
146   return CEED_ERROR_SUCCESS;
147 }
148 
149 //------------------------------------------------------------------------------
150 // Check if data of given type needs sync
151 //------------------------------------------------------------------------------
152 static inline int CeedQFunctionContextNeedSync_Hip(
153   const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
154   int ierr;
155   CeedQFunctionContext_Hip *impl;
156   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
157 
158   bool has_valid_data = true;
159   ierr = CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data);
160   CeedChkBackend(ierr);
161   switch (mem_type) {
162   case CEED_MEM_HOST:
163     *need_sync = has_valid_data && !impl->h_data;
164     break;
165   case CEED_MEM_DEVICE:
166     *need_sync = has_valid_data && !impl->d_data;
167     break;
168   }
169 
170   return CEED_ERROR_SUCCESS;
171 }
172 
173 //------------------------------------------------------------------------------
174 // Set data from host
175 //------------------------------------------------------------------------------
176 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx,
177     const CeedCopyMode copy_mode, void *data) {
178   int ierr;
179   CeedQFunctionContext_Hip *impl;
180   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
181 
182   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
183   switch (copy_mode) {
184   case CEED_COPY_VALUES: {
185     size_t ctxsize;
186     ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
187     ierr = CeedMalloc(ctxsize, &impl->h_data_owned); CeedChkBackend(ierr);
188     impl->h_data_borrowed = NULL;
189     impl->h_data = impl->h_data_owned;
190     memcpy(impl->h_data, data, ctxsize);
191   } break;
192   case CEED_OWN_POINTER:
193     impl->h_data_owned = data;
194     impl->h_data_borrowed = NULL;
195     impl->h_data = data;
196     break;
197   case CEED_USE_POINTER:
198     impl->h_data_borrowed = data;
199     impl->h_data = data;
200     break;
201   }
202 
203   return CEED_ERROR_SUCCESS;
204 }
205 
206 //------------------------------------------------------------------------------
207 // Set data from device
208 //------------------------------------------------------------------------------
209 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx,
210     const CeedCopyMode copy_mode, void *data) {
211   int ierr;
212   Ceed ceed;
213   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
214   CeedQFunctionContext_Hip *impl;
215   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
216 
217   ierr = hipFree(impl->d_data_owned); CeedChk_Hip(ceed, ierr);
218   impl->d_data_owned = NULL;
219   switch (copy_mode) {
220   case CEED_COPY_VALUES: {
221     size_t ctxsize;
222     ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
223     ierr = hipMalloc((void **)&impl->d_data_owned, ctxsize);
224     CeedChk_Hip(ceed, ierr);
225     impl->d_data_borrowed = NULL;
226     impl->d_data = impl->d_data_owned;
227     ierr = hipMemcpy(impl->d_data, data, ctxsize,
228                      hipMemcpyDeviceToDevice); CeedChk_Hip(ceed, ierr);
229   } break;
230   case CEED_OWN_POINTER:
231     impl->d_data_owned = data;
232     impl->d_data_borrowed = NULL;
233     impl->d_data = data;
234     break;
235   case CEED_USE_POINTER:
236     impl->d_data_owned = NULL;
237     impl->d_data_borrowed = data;
238     impl->d_data = data;
239     break;
240   }
241 
242   return CEED_ERROR_SUCCESS;
243 }
244 
245 //------------------------------------------------------------------------------
246 // Set the data used by a user context,
247 //   freeing any previously allocated data if applicable
248 //------------------------------------------------------------------------------
249 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx,
250     const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
251   int ierr;
252   Ceed ceed;
253   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
254 
255   ierr = CeedQFunctionContextSetAllInvalid_Hip(ctx); CeedChkBackend(ierr);
256   switch (mem_type) {
257   case CEED_MEM_HOST:
258     return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
259   case CEED_MEM_DEVICE:
260     return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
261   }
262 
263   return CEED_ERROR_UNSUPPORTED;
264 }
265 
266 //------------------------------------------------------------------------------
267 // Take data
268 //------------------------------------------------------------------------------
269 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx,
270     const CeedMemType mem_type, void *data) {
271   int ierr;
272   Ceed ceed;
273   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
274   CeedQFunctionContext_Hip *impl;
275   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
276 
277   // Sync data to requested mem_type
278   bool need_sync = false;
279   ierr = CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync);
280   CeedChkBackend(ierr);
281   if (need_sync) {
282     ierr = CeedQFunctionContextSync_Hip(ctx, mem_type); CeedChkBackend(ierr);
283   }
284 
285   // Update pointer
286   switch (mem_type) {
287   case CEED_MEM_HOST:
288     *(void **)data = impl->h_data_borrowed;
289     impl->h_data_borrowed = NULL;
290     impl->h_data = NULL;
291     break;
292   case CEED_MEM_DEVICE:
293     *(void **)data = impl->d_data_borrowed;
294     impl->d_data_borrowed = NULL;
295     impl->d_data = NULL;
296     break;
297   }
298 
299   return CEED_ERROR_SUCCESS;
300 }
301 
302 //------------------------------------------------------------------------------
303 // Core logic for GetData.
304 //   If a different memory type is most up to date, this will perform a copy
305 //------------------------------------------------------------------------------
306 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx,
307     const CeedMemType mem_type, void *data) {
308   int ierr;
309   Ceed ceed;
310   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
311   CeedQFunctionContext_Hip *impl;
312   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
313 
314   // Sync data to requested mem_type
315   bool need_sync = false;
316   ierr = CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync);
317   CeedChkBackend(ierr);
318   if (need_sync) {
319     ierr = CeedQFunctionContextSync_Hip(ctx, mem_type); CeedChkBackend(ierr);
320   }
321 
322   // Sync data to requested mem_type and update pointer
323   switch (mem_type) {
324   case CEED_MEM_HOST:
325     *(void **)data = impl->h_data;
326     break;
327   case CEED_MEM_DEVICE:
328     *(void **)data = impl->d_data;
329     break;
330   }
331 
332   return CEED_ERROR_SUCCESS;
333 }
334 
335 //------------------------------------------------------------------------------
336 // Get read-only access to the data
337 //------------------------------------------------------------------------------
338 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx,
339     const CeedMemType mem_type, void *data) {
340   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
341 }
342 
343 //------------------------------------------------------------------------------
344 // Get read/write access to the data
345 //------------------------------------------------------------------------------
346 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx,
347     const CeedMemType mem_type, void *data) {
348   int ierr;
349   CeedQFunctionContext_Hip *impl;
350   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
351 
352   ierr = CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
353   CeedChkBackend(ierr);
354 
355   // Mark only pointer for requested memory as valid
356   ierr = CeedQFunctionContextSetAllInvalid_Hip(ctx); CeedChkBackend(ierr);
357   switch (mem_type) {
358   case CEED_MEM_HOST:
359     impl->h_data = *(void **)data;
360     break;
361   case CEED_MEM_DEVICE:
362     impl->d_data = *(void **)data;
363     break;
364   }
365 
366   return CEED_ERROR_SUCCESS;
367 }
368 
369 //------------------------------------------------------------------------------
370 // Destroy the user context
371 //------------------------------------------------------------------------------
372 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
373   int ierr;
374   Ceed ceed;
375   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
376   CeedQFunctionContext_Hip *impl;
377   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
378 
379   ierr = hipFree(impl->d_data_owned); CeedChk_Hip(ceed, ierr);
380   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
381   ierr = CeedFree(&impl); CeedChkBackend(ierr);
382 
383   return CEED_ERROR_SUCCESS;
384 }
385 
386 //------------------------------------------------------------------------------
387 // QFunctionContext Create
388 //------------------------------------------------------------------------------
389 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
390   int ierr;
391   CeedQFunctionContext_Hip *impl;
392   Ceed ceed;
393   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
394 
395   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData",
396                                 CeedQFunctionContextHasValidData_Hip);
397   CeedChkBackend(ierr);
398   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx,
399                                 "HasBorrowedDataOfType",
400                                 CeedQFunctionContextHasBorrowedDataOfType_Hip);
401   CeedChkBackend(ierr);
402   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData",
403                                 CeedQFunctionContextSetData_Hip); CeedChkBackend(ierr);
404   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData",
405                                 CeedQFunctionContextTakeData_Hip); CeedChkBackend(ierr);
406   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData",
407                                 CeedQFunctionContextGetData_Hip); CeedChkBackend(ierr);
408   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead",
409                                 CeedQFunctionContextGetDataRead_Hip); CeedChkBackend(ierr);
410   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy",
411                                 CeedQFunctionContextDestroy_Hip); CeedChkBackend(ierr);
412 
413   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
414   ierr = CeedQFunctionContextSetBackendData(ctx, impl); CeedChkBackend(ierr);
415 
416   return CEED_ERROR_SUCCESS;
417 }
418 //------------------------------------------------------------------------------
419