xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision 356036fa84f714fa73ef64c9a80ce2028dde816f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <string.h>
12 #include <hip/hip_runtime.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21   Ceed ceed;
22   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
23   CeedQFunctionContext_Hip *impl;
24   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
25 
26   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
27 
28   size_t ctxsize;
29   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
30 
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
37     impl->d_data = impl->d_data_owned;
38   }
39 
40   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice));
41 
42   return CEED_ERROR_SUCCESS;
43 }
44 
45 //------------------------------------------------------------------------------
46 // Sync device to host
47 //------------------------------------------------------------------------------
48 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
49   Ceed ceed;
50   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
51   CeedQFunctionContext_Hip *impl;
52   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
53 
54   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
55 
56   size_t ctxsize;
57   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
58 
59   if (impl->h_data_borrowed) {
60     impl->h_data = impl->h_data_borrowed;
61   } else if (impl->h_data_owned) {
62     impl->h_data = impl->h_data_owned;
63   } else {
64     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
65     impl->h_data = impl->h_data_owned;
66   }
67 
68   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost));
69 
70   return CEED_ERROR_SUCCESS;
71 }
72 
73 //------------------------------------------------------------------------------
74 // Sync data of type
75 //------------------------------------------------------------------------------
76 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
77   switch (mem_type) {
78     case CEED_MEM_HOST:
79       return CeedQFunctionContextSyncD2H_Hip(ctx);
80     case CEED_MEM_DEVICE:
81       return CeedQFunctionContextSyncH2D_Hip(ctx);
82   }
83   return CEED_ERROR_UNSUPPORTED;
84 }
85 
86 //------------------------------------------------------------------------------
87 // Set all pointers as invalid
88 //------------------------------------------------------------------------------
89 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
90   CeedQFunctionContext_Hip *impl;
91   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
92 
93   impl->h_data = NULL;
94   impl->d_data = NULL;
95 
96   return CEED_ERROR_SUCCESS;
97 }
98 
99 //------------------------------------------------------------------------------
100 // Check for valid data
101 //------------------------------------------------------------------------------
102 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
103   CeedQFunctionContext_Hip *impl;
104   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
105 
106   *has_valid_data = impl && (impl->h_data || impl->d_data);
107 
108   return CEED_ERROR_SUCCESS;
109 }
110 
111 //------------------------------------------------------------------------------
112 // Check if ctx has borrowed data
113 //------------------------------------------------------------------------------
114 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
115                                                                 bool *has_borrowed_data_of_type) {
116   CeedQFunctionContext_Hip *impl;
117   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
118 
119   switch (mem_type) {
120     case CEED_MEM_HOST:
121       *has_borrowed_data_of_type = impl->h_data_borrowed;
122       break;
123     case CEED_MEM_DEVICE:
124       *has_borrowed_data_of_type = impl->d_data_borrowed;
125       break;
126   }
127 
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 //------------------------------------------------------------------------------
132 // Check if data of given type needs sync
133 //------------------------------------------------------------------------------
134 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
135   CeedQFunctionContext_Hip *impl;
136   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
137 
138   bool has_valid_data = true;
139   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
140   switch (mem_type) {
141     case CEED_MEM_HOST:
142       *need_sync = has_valid_data && !impl->h_data;
143       break;
144     case CEED_MEM_DEVICE:
145       *need_sync = has_valid_data && !impl->d_data;
146       break;
147   }
148 
149   return CEED_ERROR_SUCCESS;
150 }
151 
152 //------------------------------------------------------------------------------
153 // Set data from host
154 //------------------------------------------------------------------------------
155 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
156   CeedQFunctionContext_Hip *impl;
157   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
158 
159   CeedCallBackend(CeedFree(&impl->h_data_owned));
160   switch (copy_mode) {
161     case CEED_COPY_VALUES: {
162       size_t ctxsize;
163       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
164       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
165       impl->h_data_borrowed = NULL;
166       impl->h_data          = impl->h_data_owned;
167       memcpy(impl->h_data, data, ctxsize);
168     } break;
169     case CEED_OWN_POINTER:
170       impl->h_data_owned    = data;
171       impl->h_data_borrowed = NULL;
172       impl->h_data          = data;
173       break;
174     case CEED_USE_POINTER:
175       impl->h_data_borrowed = data;
176       impl->h_data          = data;
177       break;
178   }
179 
180   return CEED_ERROR_SUCCESS;
181 }
182 
183 //------------------------------------------------------------------------------
184 // Set data from device
185 //------------------------------------------------------------------------------
186 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
187   Ceed ceed;
188   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
189   CeedQFunctionContext_Hip *impl;
190   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
191 
192   CeedCallHip(ceed, hipFree(impl->d_data_owned));
193   impl->d_data_owned = NULL;
194   switch (copy_mode) {
195     case CEED_COPY_VALUES: {
196       size_t ctxsize;
197       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
198       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
199       impl->d_data_borrowed = NULL;
200       impl->d_data          = impl->d_data_owned;
201       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice));
202     } break;
203     case CEED_OWN_POINTER:
204       impl->d_data_owned    = data;
205       impl->d_data_borrowed = NULL;
206       impl->d_data          = data;
207       break;
208     case CEED_USE_POINTER:
209       impl->d_data_owned    = NULL;
210       impl->d_data_borrowed = data;
211       impl->d_data          = data;
212       break;
213   }
214 
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 //------------------------------------------------------------------------------
219 // Set the data used by a user context,
220 //    freeing any previously allocated data if applicable
221 //------------------------------------------------------------------------------
222 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
223   Ceed ceed;
224   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
225 
226   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
227   switch (mem_type) {
228     case CEED_MEM_HOST:
229       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
230     case CEED_MEM_DEVICE:
231       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
232   }
233 
234   return CEED_ERROR_UNSUPPORTED;
235 }
236 
237 //------------------------------------------------------------------------------
238 // Take data
239 //------------------------------------------------------------------------------
240 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
241   Ceed ceed;
242   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
243   CeedQFunctionContext_Hip *impl;
244   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
245 
246   // Sync data to requested mem_type
247   bool need_sync = false;
248   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
249   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
250 
251   // Update pointer
252   switch (mem_type) {
253     case CEED_MEM_HOST:
254       *(void **)data        = impl->h_data_borrowed;
255       impl->h_data_borrowed = NULL;
256       impl->h_data          = NULL;
257       break;
258     case CEED_MEM_DEVICE:
259       *(void **)data        = impl->d_data_borrowed;
260       impl->d_data_borrowed = NULL;
261       impl->d_data          = NULL;
262       break;
263   }
264 
265   return CEED_ERROR_SUCCESS;
266 }
267 
268 //------------------------------------------------------------------------------
269 // Core logic for GetData.
270 //   If a different memory type is most up to date, this will perform a copy
271 //------------------------------------------------------------------------------
272 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
273   Ceed ceed;
274   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
275   CeedQFunctionContext_Hip *impl;
276   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
277 
278   // Sync data to requested mem_type
279   bool need_sync = false;
280   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
281   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
282 
283   // Update pointer
284   switch (mem_type) {
285     case CEED_MEM_HOST:
286       *(void **)data = impl->h_data;
287       break;
288     case CEED_MEM_DEVICE:
289       *(void **)data = impl->d_data;
290       break;
291   }
292 
293   return CEED_ERROR_SUCCESS;
294 }
295 
296 //------------------------------------------------------------------------------
297 // Get read-only access to the data
298 //------------------------------------------------------------------------------
299 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
300   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
301 }
302 
303 //------------------------------------------------------------------------------
304 // Get read/write access to the data
305 //------------------------------------------------------------------------------
306 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
307   CeedQFunctionContext_Hip *impl;
308   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
309 
310   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
311 
312   // Mark only pointer for requested memory as valid
313   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
314   switch (mem_type) {
315     case CEED_MEM_HOST:
316       impl->h_data = *(void **)data;
317       break;
318     case CEED_MEM_DEVICE:
319       impl->d_data = *(void **)data;
320       break;
321   }
322 
323   return CEED_ERROR_SUCCESS;
324 }
325 
326 //------------------------------------------------------------------------------
327 // Destroy the user context
328 //------------------------------------------------------------------------------
329 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
330   Ceed ceed;
331   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
332   CeedQFunctionContext_Hip *impl;
333   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
334 
335   CeedCallHip(ceed, hipFree(impl->d_data_owned));
336   CeedCallBackend(CeedFree(&impl->h_data_owned));
337   CeedCallBackend(CeedFree(&impl));
338 
339   return CEED_ERROR_SUCCESS;
340 }
341 
342 //------------------------------------------------------------------------------
343 // QFunctionContext Create
344 //------------------------------------------------------------------------------
345 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
346   CeedQFunctionContext_Hip *impl;
347   Ceed                      ceed;
348   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
349 
350   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
351   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
352   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
353   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
354   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
355   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
356   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
357 
358   CeedCallBackend(CeedCalloc(1, &impl));
359   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
360 
361   return CEED_ERROR_SUCCESS;
362 }
363 
364 //------------------------------------------------------------------------------
365