xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision ea61e9ac44808524e4667c1525a05976f536c19c)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <hip/hip_runtime.h>
11 #include <string.h>
12 
13 #include "ceed-hip-ref.h"
14 
15 //------------------------------------------------------------------------------
16 // Sync host to device
17 //------------------------------------------------------------------------------
18 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
19   Ceed ceed;
20   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
21   CeedQFunctionContext_Hip *impl;
22   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
23 
24   if (!impl->h_data) {
25     // LCOV_EXCL_START
26     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
27     // LCOV_EXCL_STOP
28   }
29 
30   size_t ctxsize;
31   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
32 
33   if (impl->d_data_borrowed) {
34     impl->d_data = impl->d_data_borrowed;
35   } else if (impl->d_data_owned) {
36     impl->d_data = impl->d_data_owned;
37   } else {
38     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
39     impl->d_data = impl->d_data_owned;
40   }
41 
42   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice));
43 
44   return CEED_ERROR_SUCCESS;
45 }
46 
47 //------------------------------------------------------------------------------
48 // Sync device to host
49 //------------------------------------------------------------------------------
50 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
51   Ceed ceed;
52   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53   CeedQFunctionContext_Hip *impl;
54   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
55 
56   if (!impl->d_data) {
57     // LCOV_EXCL_START
58     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
59     // LCOV_EXCL_STOP
60   }
61 
62   size_t ctxsize;
63   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
64 
65   if (impl->h_data_borrowed) {
66     impl->h_data = impl->h_data_borrowed;
67   } else if (impl->h_data_owned) {
68     impl->h_data = impl->h_data_owned;
69   } else {
70     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
71     impl->h_data = impl->h_data_owned;
72   }
73 
74   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost));
75 
76   return CEED_ERROR_SUCCESS;
77 }
78 
79 //------------------------------------------------------------------------------
80 // Sync data of type
81 //------------------------------------------------------------------------------
82 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
83   switch (mem_type) {
84     case CEED_MEM_HOST:
85       return CeedQFunctionContextSyncD2H_Hip(ctx);
86     case CEED_MEM_DEVICE:
87       return CeedQFunctionContextSyncH2D_Hip(ctx);
88   }
89   return CEED_ERROR_UNSUPPORTED;
90 }
91 
92 //------------------------------------------------------------------------------
93 // Set all pointers as invalid
94 //------------------------------------------------------------------------------
95 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
96   CeedQFunctionContext_Hip *impl;
97   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
98 
99   impl->h_data = NULL;
100   impl->d_data = NULL;
101 
102   return CEED_ERROR_SUCCESS;
103 }
104 
105 //------------------------------------------------------------------------------
106 // Check for valid data
107 //------------------------------------------------------------------------------
108 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
109   CeedQFunctionContext_Hip *impl;
110   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
111 
112   *has_valid_data = !!impl->h_data || !!impl->d_data;
113 
114   return CEED_ERROR_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------
// Check if ctx has borrowed data of the given memory type
119 //------------------------------------------------------------------------------
120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
121                                                                 bool *has_borrowed_data_of_type) {
122   CeedQFunctionContext_Hip *impl;
123   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
124 
125   switch (mem_type) {
126     case CEED_MEM_HOST:
127       *has_borrowed_data_of_type = !!impl->h_data_borrowed;
128       break;
129     case CEED_MEM_DEVICE:
130       *has_borrowed_data_of_type = !!impl->d_data_borrowed;
131       break;
132   }
133 
134   return CEED_ERROR_SUCCESS;
135 }
136 
137 //------------------------------------------------------------------------------
138 // Check if data of given type needs sync
139 //------------------------------------------------------------------------------
140 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
141   CeedQFunctionContext_Hip *impl;
142   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
143 
144   bool has_valid_data = true;
145   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
146   switch (mem_type) {
147     case CEED_MEM_HOST:
148       *need_sync = has_valid_data && !impl->h_data;
149       break;
150     case CEED_MEM_DEVICE:
151       *need_sync = has_valid_data && !impl->d_data;
152       break;
153   }
154 
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Set data from host
160 //------------------------------------------------------------------------------
161 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
162   CeedQFunctionContext_Hip *impl;
163   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
164 
165   CeedCallBackend(CeedFree(&impl->h_data_owned));
166   switch (copy_mode) {
167     case CEED_COPY_VALUES: {
168       size_t ctxsize;
169       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
170       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
171       impl->h_data_borrowed = NULL;
172       impl->h_data          = impl->h_data_owned;
173       memcpy(impl->h_data, data, ctxsize);
174     } break;
175     case CEED_OWN_POINTER:
176       impl->h_data_owned    = data;
177       impl->h_data_borrowed = NULL;
178       impl->h_data          = data;
179       break;
180     case CEED_USE_POINTER:
181       impl->h_data_borrowed = data;
182       impl->h_data          = data;
183       break;
184   }
185 
186   return CEED_ERROR_SUCCESS;
187 }
188 
189 //------------------------------------------------------------------------------
190 // Set data from device
191 //------------------------------------------------------------------------------
192 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
193   Ceed ceed;
194   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
195   CeedQFunctionContext_Hip *impl;
196   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
197 
198   CeedCallHip(ceed, hipFree(impl->d_data_owned));
199   impl->d_data_owned = NULL;
200   switch (copy_mode) {
201     case CEED_COPY_VALUES: {
202       size_t ctxsize;
203       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
204       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize));
205       impl->d_data_borrowed = NULL;
206       impl->d_data          = impl->d_data_owned;
207       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice));
208     } break;
209     case CEED_OWN_POINTER:
210       impl->d_data_owned    = data;
211       impl->d_data_borrowed = NULL;
212       impl->d_data          = data;
213       break;
214     case CEED_USE_POINTER:
215       impl->d_data_owned    = NULL;
216       impl->d_data_borrowed = data;
217       impl->d_data          = data;
218       break;
219   }
220 
221   return CEED_ERROR_SUCCESS;
222 }
223 
224 //------------------------------------------------------------------------------
225 // Set the data used by a user context, freeing any previously allocated data if applicable
226 //------------------------------------------------------------------------------
227 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
228   Ceed ceed;
229   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
230 
231   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
232   switch (mem_type) {
233     case CEED_MEM_HOST:
234       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
235     case CEED_MEM_DEVICE:
236       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
237   }
238 
239   return CEED_ERROR_UNSUPPORTED;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Take data
244 //------------------------------------------------------------------------------
245 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
246   Ceed ceed;
247   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
248   CeedQFunctionContext_Hip *impl;
249   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
250 
251   // Sync data to requested mem_type
252   bool need_sync = false;
253   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
254   if (need_sync) {
255     CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
256   }
257 
258   // Update pointer
259   switch (mem_type) {
260     case CEED_MEM_HOST:
261       *(void **)data        = impl->h_data_borrowed;
262       impl->h_data_borrowed = NULL;
263       impl->h_data          = NULL;
264       break;
265     case CEED_MEM_DEVICE:
266       *(void **)data        = impl->d_data_borrowed;
267       impl->d_data_borrowed = NULL;
268       impl->d_data          = NULL;
269       break;
270   }
271 
272   return CEED_ERROR_SUCCESS;
273 }
274 
275 //------------------------------------------------------------------------------
276 // Core logic for GetData.
277 //   If a different memory type is most up to date, this will perform a copy
278 //------------------------------------------------------------------------------
279 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
280   Ceed ceed;
281   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
282   CeedQFunctionContext_Hip *impl;
283   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
284 
285   // Sync data to requested mem_type
286   bool need_sync = false;
287   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
288   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
289 
290   // Sync data to requested mem_type and update pointer
291   switch (mem_type) {
292     case CEED_MEM_HOST:
293       *(void **)data = impl->h_data;
294       break;
295     case CEED_MEM_DEVICE:
296       *(void **)data = impl->d_data;
297       break;
298   }
299 
300   return CEED_ERROR_SUCCESS;
301 }
302 
303 //------------------------------------------------------------------------------
304 // Get read-only access to the data
305 //------------------------------------------------------------------------------
306 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
307   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
308 }
309 
310 //------------------------------------------------------------------------------
311 // Get read/write access to the data
312 //------------------------------------------------------------------------------
313 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
314   CeedQFunctionContext_Hip *impl;
315   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
316 
317   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
318 
319   // Mark only pointer for requested memory as valid
320   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
321   switch (mem_type) {
322     case CEED_MEM_HOST:
323       impl->h_data = *(void **)data;
324       break;
325     case CEED_MEM_DEVICE:
326       impl->d_data = *(void **)data;
327       break;
328   }
329 
330   return CEED_ERROR_SUCCESS;
331 }
332 
333 //------------------------------------------------------------------------------
334 // Destroy the user context
335 //------------------------------------------------------------------------------
336 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
337   Ceed ceed;
338   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
339   CeedQFunctionContext_Hip *impl;
340   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
341 
342   CeedCallHip(ceed, hipFree(impl->d_data_owned));
343   CeedCallBackend(CeedFree(&impl->h_data_owned));
344   CeedCallBackend(CeedFree(&impl));
345 
346   return CEED_ERROR_SUCCESS;
347 }
348 
349 //------------------------------------------------------------------------------
350 // QFunctionContext Create
351 //------------------------------------------------------------------------------
352 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
353   CeedQFunctionContext_Hip *impl;
354   Ceed                      ceed;
355   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
356 
357   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
358   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
359   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
360   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
361   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
362   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
363   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
364 
365   CeedCallBackend(CeedCalloc(1, &impl));
366   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
367 
368   return CEED_ERROR_SUCCESS;
369 }
370 //------------------------------------------------------------------------------
371