1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <hip/hip_runtime.h> 11 #include <stdbool.h> 12 #include <string.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 23 CeedQFunctionContext_Hip *impl; 24 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 25 26 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 27 28 size_t ctxsize; 29 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 30 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 37 impl->d_data = impl->d_data_owned; 38 } 39 40 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctxsize, hipMemcpyHostToDevice)); 41 42 return CEED_ERROR_SUCCESS; 43 } 44 45 //------------------------------------------------------------------------------ 46 // Sync device to host 47 //------------------------------------------------------------------------------ 48 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 49 Ceed ceed; 50 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 51 CeedQFunctionContext_Hip *impl; 52 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 53 54 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 55 56 size_t ctxsize; 57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 58 59 if (impl->h_data_borrowed) { 60 impl->h_data = impl->h_data_borrowed; 61 } else if (impl->h_data_owned) { 62 impl->h_data = impl->h_data_owned; 63 } else { 64 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 65 impl->h_data = impl->h_data_owned; 66 } 67 68 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctxsize, hipMemcpyDeviceToHost)); 69 70 return CEED_ERROR_SUCCESS; 71 } 72 73 //------------------------------------------------------------------------------ 74 // Sync data of type 75 //------------------------------------------------------------------------------ 76 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 77 switch (mem_type) { 78 case CEED_MEM_HOST: 79 return CeedQFunctionContextSyncD2H_Hip(ctx); 80 case CEED_MEM_DEVICE: 81 return CeedQFunctionContextSyncH2D_Hip(ctx); 82 } 83 return CEED_ERROR_UNSUPPORTED; 84 } 85 86 //------------------------------------------------------------------------------ 87 // Set all pointers as invalid 88 //------------------------------------------------------------------------------ 89 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 90 CeedQFunctionContext_Hip *impl; 91 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 92 93 impl->h_data = NULL; 94 impl->d_data = NULL; 95 96 return CEED_ERROR_SUCCESS; 97 } 98 99 //------------------------------------------------------------------------------ 100 // Check for valid data 101 //------------------------------------------------------------------------------ 102 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 103 CeedQFunctionContext_Hip *impl; 104 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 105 106 *has_valid_data = !!impl->h_data || !!impl->d_data; 107 108 return CEED_ERROR_SUCCESS; 109 } 110 111 //------------------------------------------------------------------------------ 112 // Check if ctx has borrowed data 113 //------------------------------------------------------------------------------ 114 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 115 bool *has_borrowed_data_of_type) { 116 CeedQFunctionContext_Hip *impl; 117 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 118 119 switch (mem_type) { 120 case CEED_MEM_HOST: 121 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 122 break; 123 case CEED_MEM_DEVICE: 124 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 125 break; 126 } 127 128 return CEED_ERROR_SUCCESS; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Check if data of given type needs sync 133 //------------------------------------------------------------------------------ 134 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 135 CeedQFunctionContext_Hip *impl; 136 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 137 138 bool has_valid_data = true; 139 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 140 switch (mem_type) { 141 case CEED_MEM_HOST: 142 *need_sync = has_valid_data && !impl->h_data; 143 break; 144 case CEED_MEM_DEVICE: 145 *need_sync = has_valid_data && !impl->d_data; 146 break; 147 } 148 149 return CEED_ERROR_SUCCESS; 150 } 151 152 //------------------------------------------------------------------------------ 153 // Set data from host 154 //------------------------------------------------------------------------------ 155 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 156 CeedQFunctionContext_Hip *impl; 157 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 158 159 CeedCallBackend(CeedFree(&impl->h_data_owned)); 160 switch (copy_mode) { 161 case CEED_COPY_VALUES: { 162 size_t ctxsize; 163 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 164 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 165 impl->h_data_borrowed = NULL; 166 impl->h_data = impl->h_data_owned; 167 memcpy(impl->h_data, data, ctxsize); 168 } break; 169 case CEED_OWN_POINTER: 170 impl->h_data_owned = data; 171 impl->h_data_borrowed = NULL; 172 impl->h_data = data; 173 break; 174 case CEED_USE_POINTER: 175 impl->h_data_borrowed = data; 176 impl->h_data = data; 177 break; 178 } 179 180 return CEED_ERROR_SUCCESS; 181 } 182 183 //------------------------------------------------------------------------------ 184 // Set data from device 185 //------------------------------------------------------------------------------ 186 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 187 Ceed ceed; 188 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 189 CeedQFunctionContext_Hip *impl; 190 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 191 192 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 193 impl->d_data_owned = NULL; 194 switch (copy_mode) { 195 case CEED_COPY_VALUES: { 196 size_t ctxsize; 197 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 198 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctxsize)); 199 impl->d_data_borrowed = NULL; 200 impl->d_data = impl->d_data_owned; 201 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctxsize, hipMemcpyDeviceToDevice)); 202 } break; 203 case CEED_OWN_POINTER: 204 impl->d_data_owned = data; 205 impl->d_data_borrowed = NULL; 206 impl->d_data = data; 207 break; 208 case CEED_USE_POINTER: 209 impl->d_data_owned = NULL; 210 impl->d_data_borrowed = data; 211 impl->d_data = data; 212 break; 213 } 214 215 return CEED_ERROR_SUCCESS; 216 } 217 218 //------------------------------------------------------------------------------ 219 // Set the data used by a user context, freeing any previously allocated data if applicable 220 //------------------------------------------------------------------------------ 221 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 222 Ceed ceed; 223 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 224 225 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 226 switch (mem_type) { 227 case CEED_MEM_HOST: 228 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 229 case CEED_MEM_DEVICE: 230 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 231 } 232 233 return CEED_ERROR_UNSUPPORTED; 234 } 235 236 //------------------------------------------------------------------------------ 237 // Take data 238 //------------------------------------------------------------------------------ 239 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 240 Ceed ceed; 241 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 242 CeedQFunctionContext_Hip *impl; 243 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 244 245 // Sync data to requested mem_type 246 bool need_sync = false; 247 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 248 if (need_sync) { 249 CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 250 } 251 252 // Update pointer 253 switch (mem_type) { 254 case CEED_MEM_HOST: 255 *(void **)data = impl->h_data_borrowed; 256 impl->h_data_borrowed = NULL; 257 impl->h_data = NULL; 258 break; 259 case CEED_MEM_DEVICE: 260 *(void **)data = impl->d_data_borrowed; 261 impl->d_data_borrowed = NULL; 262 impl->d_data = NULL; 263 break; 264 } 265 266 return CEED_ERROR_SUCCESS; 267 } 268 269 //------------------------------------------------------------------------------ 270 // Core logic for GetData. 271 // If a different memory type is most up to date, this will perform a copy 272 //------------------------------------------------------------------------------ 273 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 274 Ceed ceed; 275 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 276 CeedQFunctionContext_Hip *impl; 277 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 278 279 // Sync data to requested mem_type 280 bool need_sync = false; 281 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 282 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 283 284 // Sync data to requested mem_type and update pointer 285 switch (mem_type) { 286 case CEED_MEM_HOST: 287 *(void **)data = impl->h_data; 288 break; 289 case CEED_MEM_DEVICE: 290 *(void **)data = impl->d_data; 291 break; 292 } 293 294 return CEED_ERROR_SUCCESS; 295 } 296 297 //------------------------------------------------------------------------------ 298 // Get read-only access to the data 299 //------------------------------------------------------------------------------ 300 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 301 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 302 } 303 304 //------------------------------------------------------------------------------ 305 // Get read/write access to the data 306 //------------------------------------------------------------------------------ 307 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 308 CeedQFunctionContext_Hip *impl; 309 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 310 311 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 312 313 // Mark only pointer for requested memory as valid 314 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 315 switch (mem_type) { 316 case CEED_MEM_HOST: 317 impl->h_data = *(void **)data; 318 break; 319 case CEED_MEM_DEVICE: 320 impl->d_data = *(void **)data; 321 break; 322 } 323 324 return CEED_ERROR_SUCCESS; 325 } 326 327 //------------------------------------------------------------------------------ 328 // Destroy the user context 329 //------------------------------------------------------------------------------ 330 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 331 Ceed ceed; 332 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 333 CeedQFunctionContext_Hip *impl; 334 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 335 336 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 337 CeedCallBackend(CeedFree(&impl->h_data_owned)); 338 CeedCallBackend(CeedFree(&impl)); 339 340 return CEED_ERROR_SUCCESS; 341 } 342 343 //------------------------------------------------------------------------------ 344 // QFunctionContext Create 345 //------------------------------------------------------------------------------ 346 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 347 CeedQFunctionContext_Hip *impl; 348 Ceed ceed; 349 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 350 351 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 352 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 353 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 354 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 355 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 356 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 357 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 358 359 CeedCallBackend(CeedCalloc(1, &impl)); 360 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 361 362 return CEED_ERROR_SUCCESS; 363 } 364 //------------------------------------------------------------------------------ 365