1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <string.h> 12 #include <hip/hip_runtime.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 23 CeedQFunctionContext_Hip *impl; 24 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 25 26 if (!impl->h_data) { 27 // LCOV_EXCL_START 28 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 // LCOV_EXCL_STOP 30 } 31 32 size_t ctxsize; 33 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 34 35 if (impl->d_data_borrowed) { 36 impl->d_data = impl->d_data_borrowed; 37 } else if (impl->d_data_owned) { 38 impl->d_data = impl->d_data_owned; 39 } else { 40 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 41 impl->d_data = impl->d_data_owned; 42 } 43 44 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice)); 45 46 return CEED_ERROR_SUCCESS; 47 } 48 49 //------------------------------------------------------------------------------ 50 // Sync device to host 51 //------------------------------------------------------------------------------ 52 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 53 Ceed ceed; 54 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 55 CeedQFunctionContext_Hip *impl; 56 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 57 58 if (!impl->d_data) { 59 // LCOV_EXCL_START 60 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 61 // LCOV_EXCL_STOP 62 } 63 64 size_t ctxsize; 65 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 66 67 if (impl->h_data_borrowed) { 68 impl->h_data = impl->h_data_borrowed; 69 } else if (impl->h_data_owned) { 70 impl->h_data = impl->h_data_owned; 71 } else { 72 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 73 impl->h_data = impl->h_data_owned; 74 } 75 76 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost)); 77 78 return CEED_ERROR_SUCCESS; 79 } 80 81 //------------------------------------------------------------------------------ 82 // Sync data of type 83 //------------------------------------------------------------------------------ 84 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 85 switch (mem_type) { 86 case CEED_MEM_HOST: 87 return CeedQFunctionContextSyncD2H_Hip(ctx); 88 case CEED_MEM_DEVICE: 89 return CeedQFunctionContextSyncH2D_Hip(ctx); 90 } 91 return CEED_ERROR_UNSUPPORTED; 92 } 93 94 //------------------------------------------------------------------------------ 95 // Set all pointers as invalid 96 //------------------------------------------------------------------------------ 97 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 98 CeedQFunctionContext_Hip *impl; 99 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 100 101 impl->h_data = NULL; 102 impl->d_data = NULL; 103 104 return CEED_ERROR_SUCCESS; 105 } 106 107 //------------------------------------------------------------------------------ 108 // Check for valid data 109 //------------------------------------------------------------------------------ 110 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 111 CeedQFunctionContext_Hip *impl; 112 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 113 114 *has_valid_data = !!impl->h_data || !!impl->d_data; 115 116 return CEED_ERROR_SUCCESS; 117 } 118 119 //------------------------------------------------------------------------------ 120 // Check if ctx has borrowed data 121 //------------------------------------------------------------------------------ 122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 123 bool *has_borrowed_data_of_type) { 124 CeedQFunctionContext_Hip *impl; 125 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 126 127 switch (mem_type) { 128 case CEED_MEM_HOST: 129 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 130 break; 131 case CEED_MEM_DEVICE: 132 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 133 break; 134 } 135 136 return CEED_ERROR_SUCCESS; 137 } 138 139 //------------------------------------------------------------------------------ 140 // Check if data of given type needs sync 141 //------------------------------------------------------------------------------ 142 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 143 CeedQFunctionContext_Hip *impl; 144 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 145 146 bool has_valid_data = true; 147 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 148 switch (mem_type) { 149 case CEED_MEM_HOST: 150 *need_sync = has_valid_data && !impl->h_data; 151 break; 152 case CEED_MEM_DEVICE: 153 *need_sync = has_valid_data && !impl->d_data; 154 break; 155 } 156 157 return CEED_ERROR_SUCCESS; 158 } 159 160 //------------------------------------------------------------------------------ 161 // Set data from host 162 //------------------------------------------------------------------------------ 163 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 164 CeedQFunctionContext_Hip *impl; 165 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 166 167 CeedCallBackend(CeedFree(&impl->h_data_owned)); 168 switch (copy_mode) { 169 case CEED_COPY_VALUES: { 170 size_t ctxsize; 171 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 172 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 173 impl->h_data_borrowed = NULL; 174 impl->h_data = impl->h_data_owned; 175 memcpy(impl->h_data, data, ctxsize); 176 } break; 177 case CEED_OWN_POINTER: 178 impl->h_data_owned = data; 179 impl->h_data_borrowed = NULL; 180 impl->h_data = data; 181 break; 182 case CEED_USE_POINTER: 183 impl->h_data_borrowed = data; 184 impl->h_data = data; 185 break; 186 } 187 188 return CEED_ERROR_SUCCESS; 189 } 190 191 //------------------------------------------------------------------------------ 192 // Set data from device 193 //------------------------------------------------------------------------------ 194 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 195 Ceed ceed; 196 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 197 CeedQFunctionContext_Hip *impl; 198 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 199 200 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 201 impl->d_data_owned = NULL; 202 switch (copy_mode) { 203 case CEED_COPY_VALUES: { 204 size_t ctxsize; 205 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 206 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 207 impl->d_data_borrowed = NULL; 208 impl->d_data = impl->d_data_owned; 209 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice)); 210 } break; 211 case CEED_OWN_POINTER: 212 impl->d_data_owned = data; 213 impl->d_data_borrowed = NULL; 214 impl->d_data = data; 215 break; 216 case CEED_USE_POINTER: 217 impl->d_data_owned = NULL; 218 impl->d_data_borrowed = data; 219 impl->d_data = data; 220 break; 221 } 222 223 return CEED_ERROR_SUCCESS; 224 } 225 226 //------------------------------------------------------------------------------ 227 // Set the data used by a user context, freeing any previously allocated data if applicable 228 //------------------------------------------------------------------------------ 229 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 230 Ceed ceed; 231 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 232 233 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 234 switch (mem_type) { 235 case CEED_MEM_HOST: 236 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 237 case CEED_MEM_DEVICE: 238 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 239 } 240 241 return CEED_ERROR_UNSUPPORTED; 242 } 243 244 //------------------------------------------------------------------------------ 245 // Take data 246 //------------------------------------------------------------------------------ 247 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 248 Ceed ceed; 249 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 250 CeedQFunctionContext_Hip *impl; 251 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 252 253 // Sync data to requested mem_type 254 bool need_sync = false; 255 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 256 if (need_sync) { 257 CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 258 } 259 260 // Update pointer 261 switch (mem_type) { 262 case CEED_MEM_HOST: 263 *(void **)data = impl->h_data_borrowed; 264 impl->h_data_borrowed = NULL; 265 impl->h_data = NULL; 266 break; 267 case CEED_MEM_DEVICE: 268 *(void **)data = impl->d_data_borrowed; 269 impl->d_data_borrowed = NULL; 270 impl->d_data = NULL; 271 break; 272 } 273 274 return CEED_ERROR_SUCCESS; 275 } 276 277 //------------------------------------------------------------------------------ 278 // Core logic for GetData. 279 // If a different memory type is most up to date, this will perform a copy 280 //------------------------------------------------------------------------------ 281 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 282 Ceed ceed; 283 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 284 CeedQFunctionContext_Hip *impl; 285 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 286 287 // Sync data to requested mem_type 288 bool need_sync = false; 289 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 290 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 291 292 // Sync data to requested mem_type and update pointer 293 switch (mem_type) { 294 case CEED_MEM_HOST: 295 *(void **)data = impl->h_data; 296 break; 297 case CEED_MEM_DEVICE: 298 *(void **)data = impl->d_data; 299 break; 300 } 301 302 return CEED_ERROR_SUCCESS; 303 } 304 305 //------------------------------------------------------------------------------ 306 // Get read-only access to the data 307 //------------------------------------------------------------------------------ 308 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 309 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 310 } 311 312 //------------------------------------------------------------------------------ 313 // Get read/write access to the data 314 //------------------------------------------------------------------------------ 315 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 316 CeedQFunctionContext_Hip *impl; 317 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 318 319 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 320 321 // Mark only pointer for requested memory as valid 322 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 323 switch (mem_type) { 324 case CEED_MEM_HOST: 325 impl->h_data = *(void **)data; 326 break; 327 case CEED_MEM_DEVICE: 328 impl->d_data = *(void **)data; 329 break; 330 } 331 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Destroy the user context 337 //------------------------------------------------------------------------------ 338 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 339 Ceed ceed; 340 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 341 CeedQFunctionContext_Hip *impl; 342 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 343 344 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 345 CeedCallBackend(CeedFree(&impl->h_data_owned)); 346 CeedCallBackend(CeedFree(&impl)); 347 348 return CEED_ERROR_SUCCESS; 349 } 350 351 //------------------------------------------------------------------------------ 352 // QFunctionContext Create 353 //------------------------------------------------------------------------------ 354 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 355 CeedQFunctionContext_Hip *impl; 356 Ceed ceed; 357 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 358 359 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 360 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 361 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 362 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 363 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 364 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 365 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 366 367 CeedCallBackend(CeedCalloc(1, &impl)); 368 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 369 370 return CEED_ERROR_SUCCESS; 371 } 372 373 //------------------------------------------------------------------------------ 374