1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <hip/hip_runtime.h> 11 #include <string.h> 12 13 #include "ceed-hip-ref.h" 14 15 //------------------------------------------------------------------------------ 16 // Sync host to device 17 //------------------------------------------------------------------------------ 18 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 19 Ceed ceed; 20 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 21 CeedQFunctionContext_Hip *impl; 22 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 23 24 if (!impl->h_data) { 25 // LCOV_EXCL_START 26 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 27 // LCOV_EXCL_STOP 28 } 29 30 size_t ctxsize; 31 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 32 33 if (impl->d_data_borrowed) { 34 impl->d_data = impl->d_data_borrowed; 35 } else if (impl->d_data_owned) { 36 impl->d_data = impl->d_data_owned; 37 } else { 38 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 39 impl->d_data = impl->d_data_owned; 40 } 41 42 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice)); 43 44 return CEED_ERROR_SUCCESS; 45 } 46 47 //------------------------------------------------------------------------------ 48 // Sync device to host 49 //------------------------------------------------------------------------------ 50 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 51 Ceed ceed; 52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 53 CeedQFunctionContext_Hip *impl; 54 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 55 56 if (!impl->d_data) { 57 // LCOV_EXCL_START 58 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 59 // LCOV_EXCL_STOP 60 } 61 62 size_t ctxsize; 63 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 64 65 if (impl->h_data_borrowed) { 66 impl->h_data = impl->h_data_borrowed; 67 } else if (impl->h_data_owned) { 68 impl->h_data = impl->h_data_owned; 69 } else { 70 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 71 impl->h_data = impl->h_data_owned; 72 } 73 74 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost)); 75 76 return CEED_ERROR_SUCCESS; 77 } 78 79 //------------------------------------------------------------------------------ 80 // Sync data of type 81 //------------------------------------------------------------------------------ 82 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 83 switch (mem_type) { 84 case CEED_MEM_HOST: 85 return CeedQFunctionContextSyncD2H_Hip(ctx); 86 case CEED_MEM_DEVICE: 87 return CeedQFunctionContextSyncH2D_Hip(ctx); 88 } 89 return CEED_ERROR_UNSUPPORTED; 90 } 91 92 //------------------------------------------------------------------------------ 93 // Set all pointers as invalid 94 //------------------------------------------------------------------------------ 95 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 96 CeedQFunctionContext_Hip *impl; 97 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 98 99 impl->h_data = NULL; 100 impl->d_data = NULL; 101 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106 // Check for valid data 107 //------------------------------------------------------------------------------ 108 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 109 CeedQFunctionContext_Hip *impl; 110 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 111 112 *has_valid_data = !!impl->h_data || !!impl->d_data; 113 114 return CEED_ERROR_SUCCESS; 115 } 116 117 //------------------------------------------------------------------------------ 118 // Check if ctx has borrowed data 119 //------------------------------------------------------------------------------ 120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 121 bool *has_borrowed_data_of_type) { 122 CeedQFunctionContext_Hip *impl; 123 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 124 125 switch (mem_type) { 126 case CEED_MEM_HOST: 127 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 128 break; 129 case CEED_MEM_DEVICE: 130 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 131 break; 132 } 133 134 return CEED_ERROR_SUCCESS; 135 } 136 137 //------------------------------------------------------------------------------ 138 // Check if data of given type needs sync 139 //------------------------------------------------------------------------------ 140 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 141 CeedQFunctionContext_Hip *impl; 142 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 143 144 bool has_valid_data = true; 145 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 146 switch (mem_type) { 147 case CEED_MEM_HOST: 148 *need_sync = has_valid_data && !impl->h_data; 149 break; 150 case CEED_MEM_DEVICE: 151 *need_sync = has_valid_data && !impl->d_data; 152 break; 153 } 154 155 return CEED_ERROR_SUCCESS; 156 } 157 158 //------------------------------------------------------------------------------ 159 // Set data from host 160 //------------------------------------------------------------------------------ 161 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 162 CeedQFunctionContext_Hip *impl; 163 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 164 165 CeedCallBackend(CeedFree(&impl->h_data_owned)); 166 switch (copy_mode) { 167 case CEED_COPY_VALUES: { 168 size_t ctxsize; 169 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 170 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 171 impl->h_data_borrowed = NULL; 172 impl->h_data = impl->h_data_owned; 173 memcpy(impl->h_data, data, ctxsize); 174 } break; 175 case CEED_OWN_POINTER: 176 impl->h_data_owned = data; 177 impl->h_data_borrowed = NULL; 178 impl->h_data = data; 179 break; 180 case CEED_USE_POINTER: 181 impl->h_data_borrowed = data; 182 impl->h_data = data; 183 break; 184 } 185 186 return CEED_ERROR_SUCCESS; 187 } 188 189 //------------------------------------------------------------------------------ 190 // Set data from device 191 //------------------------------------------------------------------------------ 192 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 193 Ceed ceed; 194 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 195 CeedQFunctionContext_Hip *impl; 196 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 197 198 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 199 impl->d_data_owned = NULL; 200 switch (copy_mode) { 201 case CEED_COPY_VALUES: { 202 size_t ctxsize; 203 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 204 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 205 impl->d_data_borrowed = NULL; 206 impl->d_data = impl->d_data_owned; 207 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice)); 208 } break; 209 case CEED_OWN_POINTER: 210 impl->d_data_owned = data; 211 impl->d_data_borrowed = NULL; 212 impl->d_data = data; 213 break; 214 case CEED_USE_POINTER: 215 impl->d_data_owned = NULL; 216 impl->d_data_borrowed = data; 217 impl->d_data = data; 218 break; 219 } 220 221 return CEED_ERROR_SUCCESS; 222 } 223 224 //------------------------------------------------------------------------------ 225 // Set the data used by a user context, freeing any previously allocated data if applicable 226 //------------------------------------------------------------------------------ 227 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 228 Ceed ceed; 229 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 230 231 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 232 switch (mem_type) { 233 case CEED_MEM_HOST: 234 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 235 case CEED_MEM_DEVICE: 236 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 237 } 238 239 return CEED_ERROR_UNSUPPORTED; 240 } 241 242 //------------------------------------------------------------------------------ 243 // Take data 244 //------------------------------------------------------------------------------ 245 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 246 Ceed ceed; 247 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 248 CeedQFunctionContext_Hip *impl; 249 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 250 251 // Sync data to requested mem_type 252 bool need_sync = false; 253 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 254 if (need_sync) { 255 CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 256 } 257 258 // Update pointer 259 switch (mem_type) { 260 case CEED_MEM_HOST: 261 *(void **)data = impl->h_data_borrowed; 262 impl->h_data_borrowed = NULL; 263 impl->h_data = NULL; 264 break; 265 case CEED_MEM_DEVICE: 266 *(void **)data = impl->d_data_borrowed; 267 impl->d_data_borrowed = NULL; 268 impl->d_data = NULL; 269 break; 270 } 271 272 return CEED_ERROR_SUCCESS; 273 } 274 275 //------------------------------------------------------------------------------ 276 // Core logic for GetData. 277 // If a different memory type is most up to date, this will perform a copy 278 //------------------------------------------------------------------------------ 279 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 280 Ceed ceed; 281 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 282 CeedQFunctionContext_Hip *impl; 283 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 284 285 // Sync data to requested mem_type 286 bool need_sync = false; 287 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 288 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 289 290 // Sync data to requested mem_type and update pointer 291 switch (mem_type) { 292 case CEED_MEM_HOST: 293 *(void **)data = impl->h_data; 294 break; 295 case CEED_MEM_DEVICE: 296 *(void **)data = impl->d_data; 297 break; 298 } 299 300 return CEED_ERROR_SUCCESS; 301 } 302 303 //------------------------------------------------------------------------------ 304 // Get read-only access to the data 305 //------------------------------------------------------------------------------ 306 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 307 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 308 } 309 310 //------------------------------------------------------------------------------ 311 // Get read/write access to the data 312 //------------------------------------------------------------------------------ 313 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 314 CeedQFunctionContext_Hip *impl; 315 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 316 317 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 318 319 // Mark only pointer for requested memory as valid 320 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 321 switch (mem_type) { 322 case CEED_MEM_HOST: 323 impl->h_data = *(void **)data; 324 break; 325 case CEED_MEM_DEVICE: 326 impl->d_data = *(void **)data; 327 break; 328 } 329 330 return CEED_ERROR_SUCCESS; 331 } 332 333 //------------------------------------------------------------------------------ 334 // Destroy the user context 335 //------------------------------------------------------------------------------ 336 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 337 Ceed ceed; 338 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 339 CeedQFunctionContext_Hip *impl; 340 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 341 342 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 343 CeedCallBackend(CeedFree(&impl->h_data_owned)); 344 CeedCallBackend(CeedFree(&impl)); 345 346 return CEED_ERROR_SUCCESS; 347 } 348 349 //------------------------------------------------------------------------------ 350 // QFunctionContext Create 351 //------------------------------------------------------------------------------ 352 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 353 CeedQFunctionContext_Hip *impl; 354 Ceed ceed; 355 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 356 357 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 358 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 359 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 360 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 361 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 362 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 363 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 364 365 CeedCallBackend(CeedCalloc(1, &impl)); 366 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 367 368 return CEED_ERROR_SUCCESS; 369 } 370 //------------------------------------------------------------------------------ 371