xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision ca94c3ddc8f82b7d93a79f9e4812e99b8be840ff)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <string.h>
12 #include <hip/hip_runtime.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21   Ceed                      ceed;
22   size_t                    ctx_size;
23   CeedQFunctionContext_Hip *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice));
40   return CEED_ERROR_SUCCESS;
41 }
42 
43 //------------------------------------------------------------------------------
44 // Sync device to host
45 //------------------------------------------------------------------------------
46 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
47   Ceed                      ceed;
48   size_t                    ctx_size;
49   CeedQFunctionContext_Hip *impl;
50 
51   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
52   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
53 
54   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
55 
56   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
57   if (impl->h_data_borrowed) {
58     impl->h_data = impl->h_data_borrowed;
59   } else if (impl->h_data_owned) {
60     impl->h_data = impl->h_data_owned;
61   } else {
62     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
63     impl->h_data = impl->h_data_owned;
64   }
65   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost));
66   return CEED_ERROR_SUCCESS;
67 }
68 
69 //------------------------------------------------------------------------------
70 // Sync data of type
71 //------------------------------------------------------------------------------
72 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
73   switch (mem_type) {
74     case CEED_MEM_HOST:
75       return CeedQFunctionContextSyncD2H_Hip(ctx);
76     case CEED_MEM_DEVICE:
77       return CeedQFunctionContextSyncH2D_Hip(ctx);
78   }
79   return CEED_ERROR_UNSUPPORTED;
80 }
81 
82 //------------------------------------------------------------------------------
83 // Set all pointers as invalid
84 //------------------------------------------------------------------------------
85 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
86   CeedQFunctionContext_Hip *impl;
87 
88   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
89   impl->h_data = NULL;
90   impl->d_data = NULL;
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Check for valid data
96 //------------------------------------------------------------------------------
97 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
98   CeedQFunctionContext_Hip *impl;
99 
100   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
101   *has_valid_data = impl && (impl->h_data || impl->d_data);
102   return CEED_ERROR_SUCCESS;
103 }
104 
105 //------------------------------------------------------------------------------
106 // Check if ctx has borrowed data
107 //------------------------------------------------------------------------------
108 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
109                                                                 bool *has_borrowed_data_of_type) {
110   CeedQFunctionContext_Hip *impl;
111 
112   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
113   switch (mem_type) {
114     case CEED_MEM_HOST:
115       *has_borrowed_data_of_type = impl->h_data_borrowed;
116       break;
117     case CEED_MEM_DEVICE:
118       *has_borrowed_data_of_type = impl->d_data_borrowed;
119       break;
120   }
121   return CEED_ERROR_SUCCESS;
122 }
123 
124 //------------------------------------------------------------------------------
125 // Check if data of given type needs sync
126 //------------------------------------------------------------------------------
127 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
128   bool                      has_valid_data = true;
129   CeedQFunctionContext_Hip *impl;
130 
131   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
132   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
133   switch (mem_type) {
134     case CEED_MEM_HOST:
135       *need_sync = has_valid_data && !impl->h_data;
136       break;
137     case CEED_MEM_DEVICE:
138       *need_sync = has_valid_data && !impl->d_data;
139       break;
140   }
141   return CEED_ERROR_SUCCESS;
142 }
143 
144 //------------------------------------------------------------------------------
145 // Set data from host
146 //------------------------------------------------------------------------------
147 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
148   CeedQFunctionContext_Hip *impl;
149 
150   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
151   CeedCallBackend(CeedFree(&impl->h_data_owned));
152   switch (copy_mode) {
153     case CEED_COPY_VALUES: {
154       size_t ctx_size;
155 
156       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
157       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
158       impl->h_data_borrowed = NULL;
159       impl->h_data          = impl->h_data_owned;
160       memcpy(impl->h_data, data, ctx_size);
161     } break;
162     case CEED_OWN_POINTER:
163       impl->h_data_owned    = data;
164       impl->h_data_borrowed = NULL;
165       impl->h_data          = data;
166       break;
167     case CEED_USE_POINTER:
168       impl->h_data_borrowed = data;
169       impl->h_data          = data;
170       break;
171   }
172   return CEED_ERROR_SUCCESS;
173 }
174 
175 //------------------------------------------------------------------------------
176 // Set data from device
177 //------------------------------------------------------------------------------
178 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
179   Ceed                      ceed;
180   CeedQFunctionContext_Hip *impl;
181 
182   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
183   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
184 
185   CeedCallHip(ceed, hipFree(impl->d_data_owned));
186   impl->d_data_owned = NULL;
187   switch (copy_mode) {
188     case CEED_COPY_VALUES: {
189       size_t ctx_size;
190       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
191       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
192       impl->d_data_borrowed = NULL;
193       impl->d_data          = impl->d_data_owned;
194       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice));
195     } break;
196     case CEED_OWN_POINTER:
197       impl->d_data_owned    = data;
198       impl->d_data_borrowed = NULL;
199       impl->d_data          = data;
200       break;
201     case CEED_USE_POINTER:
202       impl->d_data_owned    = NULL;
203       impl->d_data_borrowed = data;
204       impl->d_data          = data;
205       break;
206   }
207   return CEED_ERROR_SUCCESS;
208 }
209 
210 //------------------------------------------------------------------------------
211 // Set the data used by a user context,
212 //    freeing any previously allocated data if applicable
213 //------------------------------------------------------------------------------
214 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
215   Ceed ceed;
216 
217   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
218   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
219   switch (mem_type) {
220     case CEED_MEM_HOST:
221       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
222     case CEED_MEM_DEVICE:
223       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
224   }
225   return CEED_ERROR_UNSUPPORTED;
226 }
227 
228 //------------------------------------------------------------------------------
229 // Take data
230 //------------------------------------------------------------------------------
231 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
232   bool                      need_sync = false;
233   Ceed                      ceed;
234   CeedQFunctionContext_Hip *impl;
235 
236   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
237   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
238 
239   // Sync data to requested mem_type
240   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
241   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
242 
243   // Update pointer
244   switch (mem_type) {
245     case CEED_MEM_HOST:
246       *(void **)data        = impl->h_data_borrowed;
247       impl->h_data_borrowed = NULL;
248       impl->h_data          = NULL;
249       break;
250     case CEED_MEM_DEVICE:
251       *(void **)data        = impl->d_data_borrowed;
252       impl->d_data_borrowed = NULL;
253       impl->d_data          = NULL;
254       break;
255   }
256   return CEED_ERROR_SUCCESS;
257 }
258 
259 //------------------------------------------------------------------------------
260 // Core logic for GetData.
261 //   If a different memory type is most up to date, this will perform a copy
262 //------------------------------------------------------------------------------
263 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
264   bool                      need_sync = false;
265   Ceed                      ceed;
266   CeedQFunctionContext_Hip *impl;
267 
268   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
269   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
270 
271   // Sync data to requested mem_type
272   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
273   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
274 
275   // Update pointer
276   switch (mem_type) {
277     case CEED_MEM_HOST:
278       *(void **)data = impl->h_data;
279       break;
280     case CEED_MEM_DEVICE:
281       *(void **)data = impl->d_data;
282       break;
283   }
284   return CEED_ERROR_SUCCESS;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Get read-only access to the data
289 //------------------------------------------------------------------------------
290 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
291   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
292 }
293 
294 //------------------------------------------------------------------------------
295 // Get read/write access to the data
296 //------------------------------------------------------------------------------
297 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
298   CeedQFunctionContext_Hip *impl;
299 
300   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
301   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
302 
303   // Mark only pointer for requested memory as valid
304   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
305   switch (mem_type) {
306     case CEED_MEM_HOST:
307       impl->h_data = *(void **)data;
308       break;
309     case CEED_MEM_DEVICE:
310       impl->d_data = *(void **)data;
311       break;
312   }
313   return CEED_ERROR_SUCCESS;
314 }
315 
316 //------------------------------------------------------------------------------
317 // Destroy the user context
318 //------------------------------------------------------------------------------
319 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
320   Ceed                      ceed;
321   CeedQFunctionContext_Hip *impl;
322 
323   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
324   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
325   CeedCallHip(ceed, hipFree(impl->d_data_owned));
326   CeedCallBackend(CeedFree(&impl->h_data_owned));
327   CeedCallBackend(CeedFree(&impl));
328   return CEED_ERROR_SUCCESS;
329 }
330 
331 //------------------------------------------------------------------------------
332 // QFunctionContext Create
333 //------------------------------------------------------------------------------
334 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
335   CeedQFunctionContext_Hip *impl;
336   Ceed                      ceed;
337 
338   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
339   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
340   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
341   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
342   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
343   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
344   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
345   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
346   CeedCallBackend(CeedCalloc(1, &impl));
347   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
348   return CEED_ERROR_SUCCESS;
349 }
350 
351 //------------------------------------------------------------------------------
352