xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunctioncontext.c (revision 24a65d3da2f623912f26b42c0b9ba6f37de25307)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <string.h>
12 #include <hip/hip_runtime.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21   Ceed                      ceed;
22   size_t                    ctx_size;
23   CeedQFunctionContext_Hip *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice));
40   return CEED_ERROR_SUCCESS;
41 }
42 
43 //------------------------------------------------------------------------------
44 // Sync device to host
45 //------------------------------------------------------------------------------
46 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
47   Ceed                      ceed;
48   size_t                    ctx_size;
49   CeedQFunctionContext_Hip *impl;
50 
51   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
52   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
53 
54   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
55 
56   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
57   if (impl->h_data_borrowed) {
58     impl->h_data = impl->h_data_borrowed;
59   } else if (impl->h_data_owned) {
60     impl->h_data = impl->h_data_owned;
61   } else {
62     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
63     impl->h_data = impl->h_data_owned;
64   }
65   CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost));
66   return CEED_ERROR_SUCCESS;
67 }
68 
69 //------------------------------------------------------------------------------
70 // Sync data of type
71 //------------------------------------------------------------------------------
72 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
73   switch (mem_type) {
74     case CEED_MEM_HOST:
75       return CeedQFunctionContextSyncD2H_Hip(ctx);
76     case CEED_MEM_DEVICE:
77       return CeedQFunctionContextSyncH2D_Hip(ctx);
78   }
79   return CEED_ERROR_UNSUPPORTED;
80 }
81 
82 //------------------------------------------------------------------------------
83 // Set all pointers as invalid
84 //------------------------------------------------------------------------------
85 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
86   CeedQFunctionContext_Hip *impl;
87 
88   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
89   impl->h_data = NULL;
90   impl->d_data = NULL;
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Check for valid data
96 //------------------------------------------------------------------------------
97 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
98   CeedQFunctionContext_Hip *impl;
99 
100   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
101   *has_valid_data = impl && (impl->h_data || impl->d_data);
102   return CEED_ERROR_SUCCESS;
103 }
104 
105 //------------------------------------------------------------------------------
106 // Check if ctx has borrowed data
107 //------------------------------------------------------------------------------
108 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
109                                                                 bool *has_borrowed_data_of_type) {
110   CeedQFunctionContext_Hip *impl;
111 
112   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
113   switch (mem_type) {
114     case CEED_MEM_HOST:
115       *has_borrowed_data_of_type = impl->h_data_borrowed;
116       break;
117     case CEED_MEM_DEVICE:
118       *has_borrowed_data_of_type = impl->d_data_borrowed;
119       break;
120   }
121   return CEED_ERROR_SUCCESS;
122 }
123 
124 //------------------------------------------------------------------------------
125 // Check if data of given type needs sync
126 //------------------------------------------------------------------------------
127 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
128   bool                      has_valid_data = true;
129   CeedQFunctionContext_Hip *impl;
130 
131   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
132   CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
133   switch (mem_type) {
134     case CEED_MEM_HOST:
135       *need_sync = has_valid_data && !impl->h_data;
136       break;
137     case CEED_MEM_DEVICE:
138       *need_sync = has_valid_data && !impl->d_data;
139       break;
140   }
141   return CEED_ERROR_SUCCESS;
142 }
143 
144 //------------------------------------------------------------------------------
145 // Set data from host
146 //------------------------------------------------------------------------------
147 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
148   CeedQFunctionContext_Hip *impl;
149 
150   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
151   CeedCallBackend(CeedFree(&impl->h_data_owned));
152   switch (copy_mode) {
153     case CEED_COPY_VALUES: {
154       size_t ctx_size;
155 
156       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
157       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
158       impl->h_data_borrowed = NULL;
159       impl->h_data          = impl->h_data_owned;
160       memcpy(impl->h_data, data, ctx_size);
161     } break;
162     case CEED_OWN_POINTER:
163       impl->h_data_owned    = data;
164       impl->h_data_borrowed = NULL;
165       impl->h_data          = data;
166       break;
167     case CEED_USE_POINTER:
168       impl->h_data_borrowed = data;
169       impl->h_data          = data;
170       break;
171   }
172   return CEED_ERROR_SUCCESS;
173 }
174 
175 //------------------------------------------------------------------------------
176 // Set data from device
177 //------------------------------------------------------------------------------
178 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
179   Ceed                      ceed;
180   CeedQFunctionContext_Hip *impl;
181 
182   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
183   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
184 
185   CeedCallHip(ceed, hipFree(impl->d_data_owned));
186   impl->d_data_owned = NULL;
187   switch (copy_mode) {
188     case CEED_COPY_VALUES: {
189       size_t ctx_size;
190       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
191       CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
192       impl->d_data_borrowed = NULL;
193       impl->d_data          = impl->d_data_owned;
194       CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice));
195     } break;
196     case CEED_OWN_POINTER:
197       impl->d_data_owned    = data;
198       impl->d_data_borrowed = NULL;
199       impl->d_data          = data;
200       break;
201     case CEED_USE_POINTER:
202       impl->d_data_owned    = NULL;
203       impl->d_data_borrowed = data;
204       impl->d_data          = data;
205       break;
206   }
207   return CEED_ERROR_SUCCESS;
208 }
209 
210 //------------------------------------------------------------------------------
211 // Set the data used by a user context,
212 //    freeing any previously allocated data if applicable
213 //------------------------------------------------------------------------------
214 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
215   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
216   switch (mem_type) {
217     case CEED_MEM_HOST:
218       return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
219     case CEED_MEM_DEVICE:
220       return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
221   }
222   return CEED_ERROR_UNSUPPORTED;
223 }
224 
225 //------------------------------------------------------------------------------
226 // Take data
227 //------------------------------------------------------------------------------
228 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
229   bool                      need_sync = false;
230   CeedQFunctionContext_Hip *impl;
231 
232   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
233 
234   // Sync data to requested mem_type
235   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
236   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
237 
238   // Update pointer
239   switch (mem_type) {
240     case CEED_MEM_HOST:
241       *(void **)data        = impl->h_data_borrowed;
242       impl->h_data_borrowed = NULL;
243       impl->h_data          = NULL;
244       break;
245     case CEED_MEM_DEVICE:
246       *(void **)data        = impl->d_data_borrowed;
247       impl->d_data_borrowed = NULL;
248       impl->d_data          = NULL;
249       break;
250   }
251   return CEED_ERROR_SUCCESS;
252 }
253 
254 //------------------------------------------------------------------------------
255 // Core logic for GetData.
256 //   If a different memory type is most up to date, this will perform a copy
257 //------------------------------------------------------------------------------
258 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
259   bool                      need_sync = false;
260   CeedQFunctionContext_Hip *impl;
261 
262   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
263 
264   // Sync data to requested mem_type
265   CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
266   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
267 
268   // Update pointer
269   switch (mem_type) {
270     case CEED_MEM_HOST:
271       *(void **)data = impl->h_data;
272       break;
273     case CEED_MEM_DEVICE:
274       *(void **)data = impl->d_data;
275       break;
276   }
277   return CEED_ERROR_SUCCESS;
278 }
279 
280 //------------------------------------------------------------------------------
281 // Get read-only access to the data
282 //------------------------------------------------------------------------------
283 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
284   return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);
285 }
286 
287 //------------------------------------------------------------------------------
288 // Get read/write access to the data
289 //------------------------------------------------------------------------------
290 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
291   CeedQFunctionContext_Hip *impl;
292 
293   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
294   CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
295 
296   // Mark only pointer for requested memory as valid
297   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
298   switch (mem_type) {
299     case CEED_MEM_HOST:
300       impl->h_data = *(void **)data;
301       break;
302     case CEED_MEM_DEVICE:
303       impl->d_data = *(void **)data;
304       break;
305   }
306   return CEED_ERROR_SUCCESS;
307 }
308 
309 //------------------------------------------------------------------------------
310 // Destroy the user context
311 //------------------------------------------------------------------------------
312 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
313   CeedQFunctionContext_Hip *impl;
314 
315   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
316   CeedCallHip(CeedQFunctionContextReturnCeed(ctx), hipFree(impl->d_data_owned));
317   CeedCallBackend(CeedFree(&impl->h_data_owned));
318   CeedCallBackend(CeedFree(&impl));
319   return CEED_ERROR_SUCCESS;
320 }
321 
322 //------------------------------------------------------------------------------
323 // QFunctionContext Create
324 //------------------------------------------------------------------------------
325 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
326   CeedQFunctionContext_Hip *impl;
327   Ceed                      ceed;
328 
329   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
330   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
331   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
332   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
333   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
334   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
335   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
336   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
337   CeedCallBackend(CeedCalloc(1, &impl));
338   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
339   return CEED_ERROR_SUCCESS;
340 }
341 
342 //------------------------------------------------------------------------------
343