1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <string.h> 12 #include <hip/hip_runtime.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 23 CeedQFunctionContext_Hip *impl; 24 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 25 26 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 27 28 size_t ctxsize; 29 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 30 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 37 impl->d_data = impl->d_data_owned; 38 } 39 40 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice)); 41 42 return CEED_ERROR_SUCCESS; 43 } 44 45 //------------------------------------------------------------------------------ 46 // Sync device to host 47 //------------------------------------------------------------------------------ 48 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 49 Ceed ceed; 50 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 51 CeedQFunctionContext_Hip *impl; 52 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 53 54 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 55 56 size_t ctxsize; 57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 58 59 if (impl->h_data_borrowed) { 60 impl->h_data = impl->h_data_borrowed; 61 } else if (impl->h_data_owned) { 62 impl->h_data = impl->h_data_owned; 63 } else { 64 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 65 impl->h_data = impl->h_data_owned; 66 } 67 68 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost)); 69 70 return CEED_ERROR_SUCCESS; 71 } 72 73 //------------------------------------------------------------------------------ 74 // Sync data of type 75 //------------------------------------------------------------------------------ 76 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 77 switch (mem_type) { 78 case CEED_MEM_HOST: 79 return CeedQFunctionContextSyncD2H_Hip(ctx); 80 case CEED_MEM_DEVICE: 81 return CeedQFunctionContextSyncH2D_Hip(ctx); 82 } 83 return CEED_ERROR_UNSUPPORTED; 84 } 85 86 //------------------------------------------------------------------------------ 87 // Set all pointers as invalid 88 //------------------------------------------------------------------------------ 89 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 90 CeedQFunctionContext_Hip *impl; 91 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 92 93 impl->h_data = NULL; 94 impl->d_data = NULL; 95 96 return CEED_ERROR_SUCCESS; 97 } 98 99 //------------------------------------------------------------------------------ 100 // Check for valid data 101 //------------------------------------------------------------------------------ 102 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 103 CeedQFunctionContext_Hip *impl; 104 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 105 106 *has_valid_data = impl && (impl->h_data || impl->d_data); 107 108 return CEED_ERROR_SUCCESS; 109 } 110 111 //------------------------------------------------------------------------------ 112 // Check if ctx has borrowed data 113 //------------------------------------------------------------------------------ 114 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 115 bool *has_borrowed_data_of_type) { 116 CeedQFunctionContext_Hip *impl; 117 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 118 119 switch (mem_type) { 120 case CEED_MEM_HOST: 121 *has_borrowed_data_of_type = impl->h_data_borrowed; 122 break; 123 case CEED_MEM_DEVICE: 124 *has_borrowed_data_of_type = impl->d_data_borrowed; 125 break; 126 } 127 128 return CEED_ERROR_SUCCESS; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Check if data of given type needs sync 133 //------------------------------------------------------------------------------ 134 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 135 CeedQFunctionContext_Hip *impl; 136 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 137 138 bool has_valid_data = true; 139 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 140 switch (mem_type) { 141 case CEED_MEM_HOST: 142 *need_sync = has_valid_data && !impl->h_data; 143 break; 144 case CEED_MEM_DEVICE: 145 *need_sync = has_valid_data && !impl->d_data; 146 break; 147 } 148 149 return CEED_ERROR_SUCCESS; 150 } 151 152 //------------------------------------------------------------------------------ 153 // Set data from host 154 //------------------------------------------------------------------------------ 155 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 156 CeedQFunctionContext_Hip *impl; 157 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 158 159 CeedCallBackend(CeedFree(&impl->h_data_owned)); 160 switch (copy_mode) { 161 case CEED_COPY_VALUES: { 162 size_t ctxsize; 163 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 164 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 165 impl->h_data_borrowed = NULL; 166 impl->h_data = impl->h_data_owned; 167 memcpy(impl->h_data, data, ctxsize); 168 } break; 169 case CEED_OWN_POINTER: 170 impl->h_data_owned = data; 171 impl->h_data_borrowed = NULL; 172 impl->h_data = data; 173 break; 174 case CEED_USE_POINTER: 175 impl->h_data_borrowed = data; 176 impl->h_data = data; 177 break; 178 } 179 180 return CEED_ERROR_SUCCESS; 181 } 182 183 //------------------------------------------------------------------------------ 184 // Set data from device 185 //------------------------------------------------------------------------------ 186 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 187 Ceed ceed; 188 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 189 CeedQFunctionContext_Hip *impl; 190 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 191 192 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 193 impl->d_data_owned = NULL; 194 switch (copy_mode) { 195 case CEED_COPY_VALUES: { 196 size_t ctxsize; 197 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 198 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 199 impl->d_data_borrowed = NULL; 200 impl->d_data = impl->d_data_owned; 201 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice)); 202 } break; 203 case CEED_OWN_POINTER: 204 impl->d_data_owned = data; 205 impl->d_data_borrowed = NULL; 206 impl->d_data = data; 207 break; 208 case CEED_USE_POINTER: 209 impl->d_data_owned = NULL; 210 impl->d_data_borrowed = data; 211 impl->d_data = data; 212 break; 213 } 214 215 return CEED_ERROR_SUCCESS; 216 } 217 218 //------------------------------------------------------------------------------ 219 // Set the data used by a user context, 220 // freeing any previously allocated data if applicable 221 //------------------------------------------------------------------------------ 222 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 223 Ceed ceed; 224 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 225 226 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 227 switch (mem_type) { 228 case CEED_MEM_HOST: 229 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 230 case CEED_MEM_DEVICE: 231 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 232 } 233 234 return CEED_ERROR_UNSUPPORTED; 235 } 236 237 //------------------------------------------------------------------------------ 238 // Take data 239 //------------------------------------------------------------------------------ 240 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 241 Ceed ceed; 242 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 243 CeedQFunctionContext_Hip *impl; 244 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 245 246 // Sync data to requested mem_type 247 bool need_sync = false; 248 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 249 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 250 251 // Update pointer 252 switch (mem_type) { 253 case CEED_MEM_HOST: 254 *(void **)data = impl->h_data_borrowed; 255 impl->h_data_borrowed = NULL; 256 impl->h_data = NULL; 257 break; 258 case CEED_MEM_DEVICE: 259 *(void **)data = impl->d_data_borrowed; 260 impl->d_data_borrowed = NULL; 261 impl->d_data = NULL; 262 break; 263 } 264 265 return CEED_ERROR_SUCCESS; 266 } 267 268 //------------------------------------------------------------------------------ 269 // Core logic for GetData. 270 // If a different memory type is most up to date, this will perform a copy 271 //------------------------------------------------------------------------------ 272 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 273 Ceed ceed; 274 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 275 CeedQFunctionContext_Hip *impl; 276 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 277 278 // Sync data to requested mem_type 279 bool need_sync = false; 280 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 281 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 282 283 // Update pointer 284 switch (mem_type) { 285 case CEED_MEM_HOST: 286 *(void **)data = impl->h_data; 287 break; 288 case CEED_MEM_DEVICE: 289 *(void **)data = impl->d_data; 290 break; 291 } 292 293 return CEED_ERROR_SUCCESS; 294 } 295 296 //------------------------------------------------------------------------------ 297 // Get read-only access to the data 298 //------------------------------------------------------------------------------ 299 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 300 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 301 } 302 303 //------------------------------------------------------------------------------ 304 // Get read/write access to the data 305 //------------------------------------------------------------------------------ 306 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 307 CeedQFunctionContext_Hip *impl; 308 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 309 310 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 311 312 // Mark only pointer for requested memory as valid 313 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 314 switch (mem_type) { 315 case CEED_MEM_HOST: 316 impl->h_data = *(void **)data; 317 break; 318 case CEED_MEM_DEVICE: 319 impl->d_data = *(void **)data; 320 break; 321 } 322 323 return CEED_ERROR_SUCCESS; 324 } 325 326 //------------------------------------------------------------------------------ 327 // Destroy the user context 328 //------------------------------------------------------------------------------ 329 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 330 Ceed ceed; 331 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 332 CeedQFunctionContext_Hip *impl; 333 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 334 335 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 336 CeedCallBackend(CeedFree(&impl->h_data_owned)); 337 CeedCallBackend(CeedFree(&impl)); 338 339 return CEED_ERROR_SUCCESS; 340 } 341 342 //------------------------------------------------------------------------------ 343 // QFunctionContext Create 344 //------------------------------------------------------------------------------ 345 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 346 CeedQFunctionContext_Hip *impl; 347 Ceed ceed; 348 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 349 350 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 351 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 352 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 353 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 354 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 355 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 356 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 357 358 CeedCallBackend(CeedCalloc(1, &impl)); 359 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 360 361 return CEED_ERROR_SUCCESS; 362 } 363 364 //------------------------------------------------------------------------------ 365