1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <string.h> 12 #include <hip/hip_runtime.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 size_t ctx_size; 23 CeedQFunctionContext_Hip *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 27 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 37 impl->d_data = impl->d_data_owned; 38 } 39 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice)); 40 return CEED_ERROR_SUCCESS; 41 } 42 43 //------------------------------------------------------------------------------ 44 // Sync device to host 45 //------------------------------------------------------------------------------ 46 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 47 Ceed ceed; 48 size_t ctx_size; 49 CeedQFunctionContext_Hip *impl; 50 51 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 52 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 53 54 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 55 56 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 57 if (impl->h_data_borrowed) { 58 impl->h_data = impl->h_data_borrowed; 59 } else if (impl->h_data_owned) { 60 impl->h_data = impl->h_data_owned; 61 } else { 62 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 63 impl->h_data = impl->h_data_owned; 64 } 65 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost)); 66 return CEED_ERROR_SUCCESS; 67 } 68 69 //------------------------------------------------------------------------------ 70 // Sync data of type 71 //------------------------------------------------------------------------------ 72 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 73 switch (mem_type) { 74 case CEED_MEM_HOST: 75 return CeedQFunctionContextSyncD2H_Hip(ctx); 76 case CEED_MEM_DEVICE: 77 return CeedQFunctionContextSyncH2D_Hip(ctx); 78 } 79 return CEED_ERROR_UNSUPPORTED; 80 } 81 82 //------------------------------------------------------------------------------ 83 // Set all pointers as invalid 84 //------------------------------------------------------------------------------ 85 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 86 CeedQFunctionContext_Hip *impl; 87 88 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 89 impl->h_data = NULL; 90 impl->d_data = NULL; 91 return CEED_ERROR_SUCCESS; 92 } 93 94 //------------------------------------------------------------------------------ 95 // Check for valid data 96 //------------------------------------------------------------------------------ 97 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 98 CeedQFunctionContext_Hip *impl; 99 100 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 101 *has_valid_data = impl && (impl->h_data || impl->d_data); 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106 // Check if ctx has borrowed data 107 //------------------------------------------------------------------------------ 108 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 109 bool *has_borrowed_data_of_type) { 110 CeedQFunctionContext_Hip *impl; 111 112 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 113 switch (mem_type) { 114 case CEED_MEM_HOST: 115 *has_borrowed_data_of_type = impl->h_data_borrowed; 116 break; 117 case CEED_MEM_DEVICE: 118 *has_borrowed_data_of_type = impl->d_data_borrowed; 119 break; 120 } 121 return CEED_ERROR_SUCCESS; 122 } 123 124 //------------------------------------------------------------------------------ 125 // Check if data of given type needs sync 126 //------------------------------------------------------------------------------ 127 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 128 bool has_valid_data = true; 129 CeedQFunctionContext_Hip *impl; 130 131 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 132 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 133 switch (mem_type) { 134 case CEED_MEM_HOST: 135 *need_sync = has_valid_data && !impl->h_data; 136 break; 137 case CEED_MEM_DEVICE: 138 *need_sync = has_valid_data && !impl->d_data; 139 break; 140 } 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Set data from host 146 //------------------------------------------------------------------------------ 147 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 148 CeedQFunctionContext_Hip *impl; 149 150 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 151 CeedCallBackend(CeedFree(&impl->h_data_owned)); 152 switch (copy_mode) { 153 case CEED_COPY_VALUES: { 154 size_t ctx_size; 155 156 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 157 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 158 impl->h_data_borrowed = NULL; 159 impl->h_data = impl->h_data_owned; 160 memcpy(impl->h_data, data, ctx_size); 161 } break; 162 case CEED_OWN_POINTER: 163 impl->h_data_owned = data; 164 impl->h_data_borrowed = NULL; 165 impl->h_data = data; 166 break; 167 case CEED_USE_POINTER: 168 impl->h_data_borrowed = data; 169 impl->h_data = data; 170 break; 171 } 172 return CEED_ERROR_SUCCESS; 173 } 174 175 //------------------------------------------------------------------------------ 176 // Set data from device 177 //------------------------------------------------------------------------------ 178 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 179 Ceed ceed; 180 CeedQFunctionContext_Hip *impl; 181 182 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 183 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 184 185 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 186 impl->d_data_owned = NULL; 187 switch (copy_mode) { 188 case CEED_COPY_VALUES: { 189 size_t ctx_size; 190 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 191 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 192 impl->d_data_borrowed = NULL; 193 impl->d_data = impl->d_data_owned; 194 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice)); 195 } break; 196 case CEED_OWN_POINTER: 197 impl->d_data_owned = data; 198 impl->d_data_borrowed = NULL; 199 impl->d_data = data; 200 break; 201 case CEED_USE_POINTER: 202 impl->d_data_owned = NULL; 203 impl->d_data_borrowed = data; 204 impl->d_data = data; 205 break; 206 } 207 return CEED_ERROR_SUCCESS; 208 } 209 210 //------------------------------------------------------------------------------ 211 // Set the data used by a user context, 212 // freeing any previously allocated data if applicable 213 //------------------------------------------------------------------------------ 214 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 215 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 216 switch (mem_type) { 217 case CEED_MEM_HOST: 218 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 219 case CEED_MEM_DEVICE: 220 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 221 } 222 return CEED_ERROR_UNSUPPORTED; 223 } 224 225 //------------------------------------------------------------------------------ 226 // Take data 227 //------------------------------------------------------------------------------ 228 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 229 bool need_sync = false; 230 CeedQFunctionContext_Hip *impl; 231 232 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 233 234 // Sync data to requested mem_type 235 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 236 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 237 238 // Update pointer 239 switch (mem_type) { 240 case CEED_MEM_HOST: 241 *(void **)data = impl->h_data_borrowed; 242 impl->h_data_borrowed = NULL; 243 impl->h_data = NULL; 244 break; 245 case CEED_MEM_DEVICE: 246 *(void **)data = impl->d_data_borrowed; 247 impl->d_data_borrowed = NULL; 248 impl->d_data = NULL; 249 break; 250 } 251 return CEED_ERROR_SUCCESS; 252 } 253 254 //------------------------------------------------------------------------------ 255 // Core logic for GetData. 256 // If a different memory type is most up to date, this will perform a copy 257 //------------------------------------------------------------------------------ 258 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 259 bool need_sync = false; 260 CeedQFunctionContext_Hip *impl; 261 262 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 263 264 // Sync data to requested mem_type 265 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 266 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 267 268 // Update pointer 269 switch (mem_type) { 270 case CEED_MEM_HOST: 271 *(void **)data = impl->h_data; 272 break; 273 case CEED_MEM_DEVICE: 274 *(void **)data = impl->d_data; 275 break; 276 } 277 return CEED_ERROR_SUCCESS; 278 } 279 280 //------------------------------------------------------------------------------ 281 // Get read-only access to the data 282 //------------------------------------------------------------------------------ 283 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 284 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 285 } 286 287 //------------------------------------------------------------------------------ 288 // Get read/write access to the data 289 //------------------------------------------------------------------------------ 290 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 291 CeedQFunctionContext_Hip *impl; 292 293 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 294 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 295 296 // Mark only pointer for requested memory as valid 297 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 298 switch (mem_type) { 299 case CEED_MEM_HOST: 300 impl->h_data = *(void **)data; 301 break; 302 case CEED_MEM_DEVICE: 303 impl->d_data = *(void **)data; 304 break; 305 } 306 return CEED_ERROR_SUCCESS; 307 } 308 309 //------------------------------------------------------------------------------ 310 // Destroy the user context 311 //------------------------------------------------------------------------------ 312 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 313 CeedQFunctionContext_Hip *impl; 314 315 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 316 CeedCallHip(CeedQFunctionContextReturnCeed(ctx), hipFree(impl->d_data_owned)); 317 CeedCallBackend(CeedFree(&impl->h_data_owned)); 318 CeedCallBackend(CeedFree(&impl)); 319 return CEED_ERROR_SUCCESS; 320 } 321 322 //------------------------------------------------------------------------------ 323 // QFunctionContext Create 324 //------------------------------------------------------------------------------ 325 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 326 CeedQFunctionContext_Hip *impl; 327 Ceed ceed; 328 329 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 330 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 331 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 332 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 333 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 334 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 335 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 336 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 337 CeedCallBackend(CeedCalloc(1, &impl)); 338 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 339 return CEED_ERROR_SUCCESS; 340 } 341 342 //------------------------------------------------------------------------------ 343