xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision 11b88dda510d0aa70e79dc59ad165e2a5539c3c3)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21   Ceed ceed;
22   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
23   CeedQFunctionContext_Hip *impl;
24   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
25 
26   if (!impl->h_data) {
27     // LCOV_EXCL_START
28     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29     // LCOV_EXCL_STOP
30   }
31 
32   size_t ctxsize;
33   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
34 
35   if (impl->d_data_borrowed) {
36     impl->d_data = impl->d_data_borrowed;
37   } else if (impl->d_data_owned) {
38     impl->d_data = impl->d_data_owned;
39   } else {
40     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
41     impl->d_data = impl->d_data_owned;
42   }
43 
44   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice));
45 
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 //------------------------------------------------------------------------------
50 // Sync device to host
51 //------------------------------------------------------------------------------
52 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
53   Ceed ceed;
54   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
55   CeedQFunctionContext_Hip *impl;
56   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
57 
58   if (!impl->d_data) {
59     // LCOV_EXCL_START
60     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
61     // LCOV_EXCL_STOP
62   }
63 
64   size_t ctxsize;
65   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
66 
67   if (impl->h_data_borrowed) {
68     impl->h_data = impl->h_data_borrowed;
69   } else if (impl->h_data_owned) {
70     impl->h_data = impl->h_data_owned;
71   } else {
72     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
73     impl->h_data = impl->h_data_owned;
74   }
75 
76   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost));
77 
78   return CEED_ERROR_SUCCESS;
79 }
80 
81 //------------------------------------------------------------------------------
82 // Sync data of type
83 //------------------------------------------------------------------------------
84 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
85   switch (mem_type) {
86     case CEED_MEM_HOST:
87       return CeedQFunctionContextSyncD2H_Hip(ctx);
88     case CEED_MEM_DEVICE:
89       return CeedQFunctionContextSyncH2D_Hip(ctx);
90   }
91   return CEED_ERROR_UNSUPPORTED;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Set all pointers as invalid
96 //------------------------------------------------------------------------------
97 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
98   CeedQFunctionContext_Hip *impl;
99   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
100 
101   impl->h_data = NULL;
102   impl->d_data = NULL;
103 
104   return CEED_ERROR_SUCCESS;
105 }
106 
107 //------------------------------------------------------------------------------
108 // Check for valid data
109 //------------------------------------------------------------------------------
110 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
111   CeedQFunctionContext_Hip *impl;
112   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
113 
114   *has_valid_data = !!impl->h_data || !!impl->d_data;
115 
116   return CEED_ERROR_SUCCESS;
117 }
118 
119 //------------------------------------------------------------------------------
120 // Check if ctx has borrowed data
121 //------------------------------------------------------------------------------
122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
123                                                                 bool *has_borrowed_data_of_type) {
124   CeedQFunctionContext_Hip *impl;
125   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
126 
127   switch (mem_type) {
128     case CEED_MEM_HOST:
129       *has_borrowed_data_of_type = !!impl->h_data_borrowed;
130       break;
131     case CEED_MEM_DEVICE:
132       *has_borrowed_data_of_type = !!impl->d_data_borrowed;
133       break;
134   }
135 
136   return CEED_ERROR_SUCCESS;
137 }
138 
139 //------------------------------------------------------------------------------
140 // Check if data of given type needs sync
141 //------------------------------------------------------------------------------
142 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
143   CeedQFunctionContext_Hip *impl;
144   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
145 
146   bool has_valid_data = true;
147   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
148   switch (mem_type) {
149     case CEED_MEM_HOST:
150       *need_sync = has_valid_data && !impl->h_data;
151       break;
152     case CEED_MEM_DEVICE:
153       *need_sync = has_valid_data && !impl->d_data;
154       break;
155   }
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Set data from host
162 //------------------------------------------------------------------------------
163 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
164   CeedQFunctionContext_Hip *impl;
165   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
166 
167   CeedCallBackend(CeedFree(&impl->h_data_owned));
168   switch (copy_mode) {
169     case CEED_COPY_VALUES: {
170       size_t ctxsize;
171       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
172       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
173       impl->h_data_borrowed = NULL;
174       impl->h_data          = impl->h_data_owned;
175       memcpy(impl->h_data, data, ctxsize);
176     } break;
177     case CEED_OWN_POINTER:
178       impl->h_data_owned    = data;
179       impl->h_data_borrowed = NULL;
180       impl->h_data          = data;
181       break;
182     case CEED_USE_POINTER:
183       impl->h_data_borrowed = data;
184       impl->h_data          = data;
185       break;
186   }
187 
188   return CEED_ERROR_SUCCESS;
189 }
190 
191 //------------------------------------------------------------------------------
192 // Set data from device
193 //------------------------------------------------------------------------------
194 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
195   Ceed ceed;
196   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
197   CeedQFunctionContext_Hip *impl;
198   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
199 
200   CeedCallHip(ceed, hipFree(impl->d_data_owned));
201   impl->d_data_owned = NULL;
202   switch (copy_mode) {
203     case CEED_COPY_VALUES: {
204       size_t ctxsize;
205       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
206       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
207       impl->d_data_borrowed = NULL;
208       impl->d_data          = impl->d_data_owned;
209       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice));
210     } break;
211     case CEED_OWN_POINTER:
212       impl->d_data_owned    = data;
213       impl->d_data_borrowed = NULL;
214       impl->d_data          = data;
215       break;
216     case CEED_USE_POINTER:
217       impl->d_data_owned    = NULL;
218       impl->d_data_borrowed = data;
219       impl->d_data          = data;
220       break;
221   }
222 
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Set the data used by a user context, freeing any previously allocated data if applicable
228 //------------------------------------------------------------------------------
229 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
230   Ceed ceed;
231   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
232 
233   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
234   switch (mem_type) {
235     case CEED_MEM_HOST:
236       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
237     case CEED_MEM_DEVICE:
238       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
239   }
240 
241   return CEED_ERROR_UNSUPPORTED;
242 }
243 
244 //------------------------------------------------------------------------------
245 // Take data
246 //------------------------------------------------------------------------------
247 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
248   Ceed ceed;
249   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
250   CeedQFunctionContext_Hip *impl;
251   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
252 
253   // Sync data to requested mem_type
254   bool need_sync = false;
255   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
256   if (need_sync) {
257     CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
258   }
259 
260   // Update pointer
261   switch (mem_type) {
262     case CEED_MEM_HOST:
263       *(void **)data        = impl->h_data_borrowed;
264       impl->h_data_borrowed = NULL;
265       impl->h_data          = NULL;
266       break;
267     case CEED_MEM_DEVICE:
268       *(void **)data        = impl->d_data_borrowed;
269       impl->d_data_borrowed = NULL;
270       impl->d_data          = NULL;
271       break;
272   }
273 
274   return CEED_ERROR_SUCCESS;
275 }
276 
277 //------------------------------------------------------------------------------
278 // Core logic for GetData.
279 //   If a different memory type is most up to date, this will perform a copy
280 //------------------------------------------------------------------------------
281 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
282   Ceed ceed;
283   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
284   CeedQFunctionContext_Hip *impl;
285   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
286 
287   // Sync data to requested mem_type
288   bool need_sync = false;
289   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
290   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
291 
292   // Sync data to requested mem_type and update pointer
293   switch (mem_type) {
294     case CEED_MEM_HOST:
295       *(void **)data = impl->h_data;
296       break;
297     case CEED_MEM_DEVICE:
298       *(void **)data = impl->d_data;
299       break;
300   }
301 
302   return CEED_ERROR_SUCCESS;
303 }
304 
305 //------------------------------------------------------------------------------
306 // Get read-only access to the data
307 //------------------------------------------------------------------------------
308 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
309   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
310 }
311 
312 //------------------------------------------------------------------------------
313 // Get read/write access to the data
314 //------------------------------------------------------------------------------
315 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
316   CeedQFunctionContext_Hip *impl;
317   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
318 
319   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
320 
321   // Mark only pointer for requested memory as valid
322   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
323   switch (mem_type) {
324     case CEED_MEM_HOST:
325       impl->h_data = *(void **)data;
326       break;
327     case CEED_MEM_DEVICE:
328       impl->d_data = *(void **)data;
329       break;
330   }
331 
332   return CEED_ERROR_SUCCESS;
333 }
334 
335 //------------------------------------------------------------------------------
336 // Destroy the user context
337 //------------------------------------------------------------------------------
338 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
339   Ceed ceed;
340   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
341   CeedQFunctionContext_Hip *impl;
342   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
343 
344   CeedCallHip(ceed, hipFree(impl->d_data_owned));
345   CeedCallBackend(CeedFree(&impl->h_data_owned));
346   CeedCallBackend(CeedFree(&impl));
347 
348   return CEED_ERROR_SUCCESS;
349 }
350 
351 //------------------------------------------------------------------------------
352 // QFunctionContext Create
353 //------------------------------------------------------------------------------
354 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
355   CeedQFunctionContext_Hip *impl;
356   Ceed                      ceed;
357   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
358 
359   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
360   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
361   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
362   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
363   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
364   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
365   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
366 
367   CeedCallBackend(CeedCalloc(1, &impl));
368   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
369 
370   return CEED_ERROR_SUCCESS;
371 }
372 //------------------------------------------------------------------------------
373