1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <hip/hip_runtime.h> 11 #include <string.h> 12 13 #include "ceed-hip-ref.h" 14 15 //------------------------------------------------------------------------------ 16 // Sync host to device 17 //------------------------------------------------------------------------------ 18 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 19 Ceed ceed; 20 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 21 CeedQFunctionContext_Hip *impl; 22 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 23 24 if (!impl->h_data) { 25 // LCOV_EXCL_START 26 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 27 // LCOV_EXCL_STOP 28 } 29 30 size_t ctxsize; 31 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 32 33 if (impl->d_data_borrowed) { 34 impl->d_data = impl->d_data_borrowed; 35 } else if (impl->d_data_owned) { 36 impl->d_data = impl->d_data_owned; 37 } else { 38 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 39 impl->d_data = impl->d_data_owned; 40 } 41 42 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice)); 43 44 return CEED_ERROR_SUCCESS; 45 } 46 47 //------------------------------------------------------------------------------ 48 // Sync device to host 49 //------------------------------------------------------------------------------ 50 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 51 Ceed ceed; 52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 53 CeedQFunctionContext_Hip *impl; 54 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 55 56 if (!impl->d_data) { 57 // LCOV_EXCL_START 58 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 59 // LCOV_EXCL_STOP 60 } 61 62 size_t ctxsize; 63 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 64 65 if (impl->h_data_borrowed) { 66 impl->h_data = impl->h_data_borrowed; 67 } else if (impl->h_data_owned) { 68 impl->h_data = impl->h_data_owned; 69 } else { 70 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 71 impl->h_data = impl->h_data_owned; 72 } 73 74 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost)); 75 76 return CEED_ERROR_SUCCESS; 77 } 78 79 //------------------------------------------------------------------------------ 80 // Sync data of type 81 //------------------------------------------------------------------------------ 82 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 83 switch (mem_type) { 84 case CEED_MEM_HOST: 85 return CeedQFunctionContextSyncD2H_Hip(ctx); 86 case CEED_MEM_DEVICE: 87 return CeedQFunctionContextSyncH2D_Hip(ctx); 88 } 89 return CEED_ERROR_UNSUPPORTED; 90 } 91 92 //------------------------------------------------------------------------------ 93 // Set all pointers as invalid 94 //------------------------------------------------------------------------------ 95 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 96 CeedQFunctionContext_Hip *impl; 97 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 98 99 impl->h_data = NULL; 100 impl->d_data = NULL; 101 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106 // Check for valid data 107 //------------------------------------------------------------------------------ 108 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 109 CeedQFunctionContext_Hip *impl; 110 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 111 112 *has_valid_data = !!impl->h_data || !!impl->d_data; 113 114 return CEED_ERROR_SUCCESS; 115 } 116 117 //------------------------------------------------------------------------------ 118 // Check if ctx has borrowed data 119 //------------------------------------------------------------------------------ 120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 121 bool *has_borrowed_data_of_type) { 122 CeedQFunctionContext_Hip *impl; 123 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 124 125 switch (mem_type) { 126 case CEED_MEM_HOST: 127 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 128 break; 129 case CEED_MEM_DEVICE: 130 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 131 break; 132 } 133 134 return CEED_ERROR_SUCCESS; 135 } 136 137 //------------------------------------------------------------------------------ 138 // Check if data of given type needs sync 139 //------------------------------------------------------------------------------ 140 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 141 CeedQFunctionContext_Hip *impl; 142 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 143 144 bool has_valid_data = true; 145 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 146 switch (mem_type) { 147 case CEED_MEM_HOST: 148 *need_sync = has_valid_data && !impl->h_data; 149 break; 150 case CEED_MEM_DEVICE: 151 *need_sync = has_valid_data && !impl->d_data; 152 break; 153 } 154 155 return CEED_ERROR_SUCCESS; 156 } 157 158 //------------------------------------------------------------------------------ 159 // Set data from host 160 //------------------------------------------------------------------------------ 161 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 162 CeedQFunctionContext_Hip *impl; 163 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 164 165 CeedCallBackend(CeedFree(&impl->h_data_owned)); 166 switch (copy_mode) { 167 case CEED_COPY_VALUES: { 168 size_t ctxsize; 169 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 170 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 171 impl->h_data_borrowed = NULL; 172 impl->h_data = impl->h_data_owned; 173 memcpy(impl->h_data, data, ctxsize); 174 } break; 175 case CEED_OWN_POINTER: 176 impl->h_data_owned = data; 177 impl->h_data_borrowed = NULL; 178 impl->h_data = data; 179 break; 180 case CEED_USE_POINTER: 181 impl->h_data_borrowed = data; 182 impl->h_data = data; 183 break; 184 } 185 186 return CEED_ERROR_SUCCESS; 187 } 188 189 //------------------------------------------------------------------------------ 190 // Set data from device 191 //------------------------------------------------------------------------------ 192 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 193 Ceed ceed; 194 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 195 CeedQFunctionContext_Hip *impl; 196 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 197 198 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 199 impl->d_data_owned = NULL; 200 switch (copy_mode) { 201 case CEED_COPY_VALUES: { 202 size_t ctxsize; 203 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 204 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 205 impl->d_data_borrowed = NULL; 206 impl->d_data = impl->d_data_owned; 207 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice)); 208 } break; 209 case CEED_OWN_POINTER: 210 impl->d_data_owned = data; 211 impl->d_data_borrowed = NULL; 212 impl->d_data = data; 213 break; 214 case CEED_USE_POINTER: 215 impl->d_data_owned = NULL; 216 impl->d_data_borrowed = data; 217 impl->d_data = data; 218 break; 219 } 220 221 return CEED_ERROR_SUCCESS; 222 } 223 224 //------------------------------------------------------------------------------ 225 // Set the data used by a user context, 226 // freeing any previously allocated data if applicable 227 //------------------------------------------------------------------------------ 228 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 229 Ceed ceed; 230 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 231 232 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 233 switch (mem_type) { 234 case CEED_MEM_HOST: 235 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 236 case CEED_MEM_DEVICE: 237 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 238 } 239 240 return CEED_ERROR_UNSUPPORTED; 241 } 242 243 //------------------------------------------------------------------------------ 244 // Take data 245 //------------------------------------------------------------------------------ 246 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 247 Ceed ceed; 248 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 249 CeedQFunctionContext_Hip *impl; 250 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 251 252 // Sync data to requested mem_type 253 bool need_sync = false; 254 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 255 if (need_sync) { 256 CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 257 } 258 259 // Update pointer 260 switch (mem_type) { 261 case CEED_MEM_HOST: 262 *(void **)data = impl->h_data_borrowed; 263 impl->h_data_borrowed = NULL; 264 impl->h_data = NULL; 265 break; 266 case CEED_MEM_DEVICE: 267 *(void **)data = impl->d_data_borrowed; 268 impl->d_data_borrowed = NULL; 269 impl->d_data = NULL; 270 break; 271 } 272 273 return CEED_ERROR_SUCCESS; 274 } 275 276 //------------------------------------------------------------------------------ 277 // Core logic for GetData. 278 // If a different memory type is most up to date, this will perform a copy 279 //------------------------------------------------------------------------------ 280 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 281 Ceed ceed; 282 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 283 CeedQFunctionContext_Hip *impl; 284 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 285 286 // Sync data to requested mem_type 287 bool need_sync = false; 288 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 289 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 290 291 // Sync data to requested mem_type and update pointer 292 switch (mem_type) { 293 case CEED_MEM_HOST: 294 *(void **)data = impl->h_data; 295 break; 296 case CEED_MEM_DEVICE: 297 *(void **)data = impl->d_data; 298 break; 299 } 300 301 return CEED_ERROR_SUCCESS; 302 } 303 304 //------------------------------------------------------------------------------ 305 // Get read-only access to the data 306 //------------------------------------------------------------------------------ 307 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 308 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 309 } 310 311 //------------------------------------------------------------------------------ 312 // Get read/write access to the data 313 //------------------------------------------------------------------------------ 314 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 315 CeedQFunctionContext_Hip *impl; 316 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 317 318 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 319 320 // Mark only pointer for requested memory as valid 321 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 322 switch (mem_type) { 323 case CEED_MEM_HOST: 324 impl->h_data = *(void **)data; 325 break; 326 case CEED_MEM_DEVICE: 327 impl->d_data = *(void **)data; 328 break; 329 } 330 331 return CEED_ERROR_SUCCESS; 332 } 333 334 //------------------------------------------------------------------------------ 335 // Destroy the user context 336 //------------------------------------------------------------------------------ 337 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 338 Ceed ceed; 339 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 340 CeedQFunctionContext_Hip *impl; 341 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 342 343 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 344 CeedCallBackend(CeedFree(&impl->h_data_owned)); 345 CeedCallBackend(CeedFree(&impl)); 346 347 return CEED_ERROR_SUCCESS; 348 } 349 350 //------------------------------------------------------------------------------ 351 // QFunctionContext Create 352 //------------------------------------------------------------------------------ 353 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 354 CeedQFunctionContext_Hip *impl; 355 Ceed ceed; 356 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 357 358 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 359 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 360 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 361 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 362 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 363 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 364 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 365 366 CeedCallBackend(CeedCalloc(1, &impl)); 367 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 368 369 return CEED_ERROR_SUCCESS; 370 } 371 //------------------------------------------------------------------------------ 372