1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <string.h> 12 #include <hip/hip_runtime.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 size_t ctx_size; 23 CeedQFunctionContext_Hip *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 27 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 37 impl->d_data = impl->d_data_owned; 38 } 39 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice)); 40 CeedCallBackend(CeedDestroy(&ceed)); 41 return CEED_ERROR_SUCCESS; 42 } 43 44 //------------------------------------------------------------------------------ 45 // Sync device to host 46 //------------------------------------------------------------------------------ 47 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 48 Ceed ceed; 49 size_t ctx_size; 50 CeedQFunctionContext_Hip *impl; 51 52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 53 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 54 55 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 56 57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 58 if (impl->h_data_borrowed) { 59 impl->h_data = impl->h_data_borrowed; 60 } else if (impl->h_data_owned) { 61 impl->h_data = impl->h_data_owned; 62 } else { 63 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 64 impl->h_data = impl->h_data_owned; 65 } 66 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost)); 67 CeedCallBackend(CeedDestroy(&ceed)); 68 return CEED_ERROR_SUCCESS; 69 } 70 71 //------------------------------------------------------------------------------ 72 // Sync data of type 73 //------------------------------------------------------------------------------ 74 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 75 switch (mem_type) { 76 case CEED_MEM_HOST: 77 return CeedQFunctionContextSyncD2H_Hip(ctx); 78 case CEED_MEM_DEVICE: 79 return CeedQFunctionContextSyncH2D_Hip(ctx); 80 } 81 return CEED_ERROR_UNSUPPORTED; 82 } 83 84 //------------------------------------------------------------------------------ 85 // Set all pointers as invalid 86 //------------------------------------------------------------------------------ 87 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 88 CeedQFunctionContext_Hip *impl; 89 90 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 91 impl->h_data = NULL; 92 impl->d_data = NULL; 93 return CEED_ERROR_SUCCESS; 94 } 95 96 //------------------------------------------------------------------------------ 97 // Check for valid data 98 //------------------------------------------------------------------------------ 99 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 100 CeedQFunctionContext_Hip *impl; 101 102 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 103 *has_valid_data = impl && (impl->h_data || impl->d_data); 104 return CEED_ERROR_SUCCESS; 105 } 106 107 //------------------------------------------------------------------------------ 108 // Check if ctx has borrowed data 109 //------------------------------------------------------------------------------ 110 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 111 bool *has_borrowed_data_of_type) { 112 CeedQFunctionContext_Hip *impl; 113 114 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 115 switch (mem_type) { 116 case CEED_MEM_HOST: 117 *has_borrowed_data_of_type = impl->h_data_borrowed; 118 break; 119 case CEED_MEM_DEVICE: 120 *has_borrowed_data_of_type = impl->d_data_borrowed; 121 break; 122 } 123 return CEED_ERROR_SUCCESS; 124 } 125 126 //------------------------------------------------------------------------------ 127 // Check if data of given type needs sync 128 //------------------------------------------------------------------------------ 129 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 130 bool has_valid_data = true; 131 CeedQFunctionContext_Hip *impl; 132 133 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 134 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 135 switch (mem_type) { 136 case CEED_MEM_HOST: 137 *need_sync = has_valid_data && !impl->h_data; 138 break; 139 case CEED_MEM_DEVICE: 140 *need_sync = has_valid_data && !impl->d_data; 141 break; 142 } 143 return CEED_ERROR_SUCCESS; 144 } 145 146 //------------------------------------------------------------------------------ 147 // Set data from host 148 //------------------------------------------------------------------------------ 149 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 150 CeedQFunctionContext_Hip *impl; 151 152 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 153 CeedCallBackend(CeedFree(&impl->h_data_owned)); 154 switch (copy_mode) { 155 case CEED_COPY_VALUES: { 156 size_t ctx_size; 157 158 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 159 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 160 impl->h_data_borrowed = NULL; 161 impl->h_data = impl->h_data_owned; 162 memcpy(impl->h_data, data, ctx_size); 163 } break; 164 case CEED_OWN_POINTER: 165 impl->h_data_owned = data; 166 impl->h_data_borrowed = NULL; 167 impl->h_data = data; 168 break; 169 case CEED_USE_POINTER: 170 impl->h_data_borrowed = data; 171 impl->h_data = data; 172 break; 173 } 174 return CEED_ERROR_SUCCESS; 175 } 176 177 //------------------------------------------------------------------------------ 178 // Set data from device 179 //------------------------------------------------------------------------------ 180 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 181 Ceed ceed; 182 CeedQFunctionContext_Hip *impl; 183 184 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 185 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 186 187 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 188 impl->d_data_owned = NULL; 189 switch (copy_mode) { 190 case CEED_COPY_VALUES: { 191 size_t ctx_size; 192 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 193 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 194 impl->d_data_borrowed = NULL; 195 impl->d_data = impl->d_data_owned; 196 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice)); 197 } break; 198 case CEED_OWN_POINTER: 199 impl->d_data_owned = data; 200 impl->d_data_borrowed = NULL; 201 impl->d_data = data; 202 break; 203 case CEED_USE_POINTER: 204 impl->d_data_owned = NULL; 205 impl->d_data_borrowed = data; 206 impl->d_data = data; 207 break; 208 } 209 CeedCallBackend(CeedDestroy(&ceed)); 210 return CEED_ERROR_SUCCESS; 211 } 212 213 //------------------------------------------------------------------------------ 214 // Set the data used by a user context, 215 // freeing any previously allocated data if applicable 216 //------------------------------------------------------------------------------ 217 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 218 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 219 switch (mem_type) { 220 case CEED_MEM_HOST: 221 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 222 case CEED_MEM_DEVICE: 223 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 224 } 225 return CEED_ERROR_UNSUPPORTED; 226 } 227 228 //------------------------------------------------------------------------------ 229 // Take data 230 //------------------------------------------------------------------------------ 231 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 232 bool need_sync = false; 233 CeedQFunctionContext_Hip *impl; 234 235 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 236 237 // Sync data to requested mem_type 238 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 239 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 240 241 // Update pointer 242 switch (mem_type) { 243 case CEED_MEM_HOST: 244 *(void **)data = impl->h_data_borrowed; 245 impl->h_data_borrowed = NULL; 246 impl->h_data = NULL; 247 break; 248 case CEED_MEM_DEVICE: 249 *(void **)data = impl->d_data_borrowed; 250 impl->d_data_borrowed = NULL; 251 impl->d_data = NULL; 252 break; 253 } 254 return CEED_ERROR_SUCCESS; 255 } 256 257 //------------------------------------------------------------------------------ 258 // Core logic for GetData. 259 // If a different memory type is most up to date, this will perform a copy 260 //------------------------------------------------------------------------------ 261 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 262 bool need_sync = false; 263 CeedQFunctionContext_Hip *impl; 264 265 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 266 267 // Sync data to requested mem_type 268 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 269 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 270 271 // Update pointer 272 switch (mem_type) { 273 case CEED_MEM_HOST: 274 *(void **)data = impl->h_data; 275 break; 276 case CEED_MEM_DEVICE: 277 *(void **)data = impl->d_data; 278 break; 279 } 280 return CEED_ERROR_SUCCESS; 281 } 282 283 //------------------------------------------------------------------------------ 284 // Get read-only access to the data 285 //------------------------------------------------------------------------------ 286 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 287 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 288 } 289 290 //------------------------------------------------------------------------------ 291 // Get read/write access to the data 292 //------------------------------------------------------------------------------ 293 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 294 CeedQFunctionContext_Hip *impl; 295 296 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 297 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 298 299 // Mark only pointer for requested memory as valid 300 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 301 switch (mem_type) { 302 case CEED_MEM_HOST: 303 impl->h_data = *(void **)data; 304 break; 305 case CEED_MEM_DEVICE: 306 impl->d_data = *(void **)data; 307 break; 308 } 309 return CEED_ERROR_SUCCESS; 310 } 311 312 //------------------------------------------------------------------------------ 313 // Destroy the user context 314 //------------------------------------------------------------------------------ 315 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 316 CeedQFunctionContext_Hip *impl; 317 318 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 319 CeedCallHip(CeedQFunctionContextReturnCeed(ctx), hipFree(impl->d_data_owned)); 320 CeedCallBackend(CeedFree(&impl->h_data_owned)); 321 CeedCallBackend(CeedFree(&impl)); 322 return CEED_ERROR_SUCCESS; 323 } 324 325 //------------------------------------------------------------------------------ 326 // QFunctionContext Create 327 //------------------------------------------------------------------------------ 328 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 329 CeedQFunctionContext_Hip *impl; 330 Ceed ceed; 331 332 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 333 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 334 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 335 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 336 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 337 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 338 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 339 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 340 CeedCallBackend(CeedDestroy(&ceed)); 341 CeedCallBackend(CeedCalloc(1, &impl)); 342 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 343 return CEED_ERROR_SUCCESS; 344 } 345 346 //------------------------------------------------------------------------------ 347