1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <string.h> 12 #include <hip/hip_runtime.h> 13 14 #include "../hip/ceed-hip-common.h" 15 #include "ceed-hip-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 size_t ctx_size; 23 CeedQFunctionContext_Hip *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 27 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 37 impl->d_data = impl->d_data_owned; 38 } 39 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice)); 40 return CEED_ERROR_SUCCESS; 41 } 42 43 //------------------------------------------------------------------------------ 44 // Sync device to host 45 //------------------------------------------------------------------------------ 46 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) { 47 Ceed ceed; 48 size_t ctx_size; 49 CeedQFunctionContext_Hip *impl; 50 51 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 52 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 53 54 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 55 56 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 57 if (impl->h_data_borrowed) { 58 impl->h_data = impl->h_data_borrowed; 59 } else if (impl->h_data_owned) { 60 impl->h_data = impl->h_data_owned; 61 } else { 62 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 63 impl->h_data = impl->h_data_owned; 64 } 65 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost)); 66 return CEED_ERROR_SUCCESS; 67 } 68 69 //------------------------------------------------------------------------------ 70 // Sync data of type 71 //------------------------------------------------------------------------------ 72 static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) { 73 switch (mem_type) { 74 case CEED_MEM_HOST: 75 return CeedQFunctionContextSyncD2H_Hip(ctx); 76 case CEED_MEM_DEVICE: 77 return CeedQFunctionContextSyncH2D_Hip(ctx); 78 } 79 return CEED_ERROR_UNSUPPORTED; 80 } 81 82 //------------------------------------------------------------------------------ 83 // Set all pointers as invalid 84 //------------------------------------------------------------------------------ 85 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) { 86 CeedQFunctionContext_Hip *impl; 87 88 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 89 impl->h_data = NULL; 90 impl->d_data = NULL; 91 return CEED_ERROR_SUCCESS; 92 } 93 94 //------------------------------------------------------------------------------ 95 // Check for valid data 96 //------------------------------------------------------------------------------ 97 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) { 98 CeedQFunctionContext_Hip *impl; 99 100 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 101 *has_valid_data = impl && (impl->h_data || impl->d_data); 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106 // Check if ctx has borrowed data 107 //------------------------------------------------------------------------------ 108 static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, 109 bool *has_borrowed_data_of_type) { 110 CeedQFunctionContext_Hip *impl; 111 112 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 113 switch (mem_type) { 114 case CEED_MEM_HOST: 115 *has_borrowed_data_of_type = impl->h_data_borrowed; 116 break; 117 case CEED_MEM_DEVICE: 118 *has_borrowed_data_of_type = impl->d_data_borrowed; 119 break; 120 } 121 return CEED_ERROR_SUCCESS; 122 } 123 124 //------------------------------------------------------------------------------ 125 // Check if data of given type needs sync 126 //------------------------------------------------------------------------------ 127 static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 128 bool has_valid_data = true; 129 CeedQFunctionContext_Hip *impl; 130 131 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 132 CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data)); 133 switch (mem_type) { 134 case CEED_MEM_HOST: 135 *need_sync = has_valid_data && !impl->h_data; 136 break; 137 case CEED_MEM_DEVICE: 138 *need_sync = has_valid_data && !impl->d_data; 139 break; 140 } 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Set data from host 146 //------------------------------------------------------------------------------ 147 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 148 CeedQFunctionContext_Hip *impl; 149 150 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 151 CeedCallBackend(CeedFree(&impl->h_data_owned)); 152 switch (copy_mode) { 153 case CEED_COPY_VALUES: { 154 size_t ctx_size; 155 156 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 157 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 158 impl->h_data_borrowed = NULL; 159 impl->h_data = impl->h_data_owned; 160 memcpy(impl->h_data, data, ctx_size); 161 } break; 162 case CEED_OWN_POINTER: 163 impl->h_data_owned = data; 164 impl->h_data_borrowed = NULL; 165 impl->h_data = data; 166 break; 167 case CEED_USE_POINTER: 168 impl->h_data_borrowed = data; 169 impl->h_data = data; 170 break; 171 } 172 return CEED_ERROR_SUCCESS; 173 } 174 175 //------------------------------------------------------------------------------ 176 // Set data from device 177 //------------------------------------------------------------------------------ 178 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 179 Ceed ceed; 180 CeedQFunctionContext_Hip *impl; 181 182 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 183 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 184 185 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 186 impl->d_data_owned = NULL; 187 switch (copy_mode) { 188 case CEED_COPY_VALUES: { 189 size_t ctx_size; 190 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 191 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size)); 192 impl->d_data_borrowed = NULL; 193 impl->d_data = impl->d_data_owned; 194 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice)); 195 } break; 196 case CEED_OWN_POINTER: 197 impl->d_data_owned = data; 198 impl->d_data_borrowed = NULL; 199 impl->d_data = data; 200 break; 201 case CEED_USE_POINTER: 202 impl->d_data_owned = NULL; 203 impl->d_data_borrowed = data; 204 impl->d_data = data; 205 break; 206 } 207 return CEED_ERROR_SUCCESS; 208 } 209 210 //------------------------------------------------------------------------------ 211 // Set the data used by a user context, 212 // freeing any previously allocated data if applicable 213 //------------------------------------------------------------------------------ 214 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 215 Ceed ceed; 216 217 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 218 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 219 switch (mem_type) { 220 case CEED_MEM_HOST: 221 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data); 222 case CEED_MEM_DEVICE: 223 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data); 224 } 225 return CEED_ERROR_UNSUPPORTED; 226 } 227 228 //------------------------------------------------------------------------------ 229 // Take data 230 //------------------------------------------------------------------------------ 231 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 232 bool need_sync = false; 233 Ceed ceed; 234 CeedQFunctionContext_Hip *impl; 235 236 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 237 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 238 239 // Sync data to requested mem_type 240 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 241 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 242 243 // Update pointer 244 switch (mem_type) { 245 case CEED_MEM_HOST: 246 *(void **)data = impl->h_data_borrowed; 247 impl->h_data_borrowed = NULL; 248 impl->h_data = NULL; 249 break; 250 case CEED_MEM_DEVICE: 251 *(void **)data = impl->d_data_borrowed; 252 impl->d_data_borrowed = NULL; 253 impl->d_data = NULL; 254 break; 255 } 256 return CEED_ERROR_SUCCESS; 257 } 258 259 //------------------------------------------------------------------------------ 260 // Core logic for GetData. 261 // If a different memory type is most up to date, this will perform a copy 262 //------------------------------------------------------------------------------ 263 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 264 bool need_sync = false; 265 Ceed ceed; 266 CeedQFunctionContext_Hip *impl; 267 268 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 269 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 270 271 // Sync data to requested mem_type 272 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync)); 273 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type)); 274 275 // Update pointer 276 switch (mem_type) { 277 case CEED_MEM_HOST: 278 *(void **)data = impl->h_data; 279 break; 280 case CEED_MEM_DEVICE: 281 *(void **)data = impl->d_data; 282 break; 283 } 284 return CEED_ERROR_SUCCESS; 285 } 286 287 //------------------------------------------------------------------------------ 288 // Get read-only access to the data 289 //------------------------------------------------------------------------------ 290 static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 291 return CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data); 292 } 293 294 //------------------------------------------------------------------------------ 295 // Get read/write access to the data 296 //------------------------------------------------------------------------------ 297 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 298 CeedQFunctionContext_Hip *impl; 299 300 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 301 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data)); 302 303 // Mark only pointer for requested memory as valid 304 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx)); 305 switch (mem_type) { 306 case CEED_MEM_HOST: 307 impl->h_data = *(void **)data; 308 break; 309 case CEED_MEM_DEVICE: 310 impl->d_data = *(void **)data; 311 break; 312 } 313 return CEED_ERROR_SUCCESS; 314 } 315 316 //------------------------------------------------------------------------------ 317 // Destroy the user context 318 //------------------------------------------------------------------------------ 319 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) { 320 Ceed ceed; 321 CeedQFunctionContext_Hip *impl; 322 323 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 324 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 325 CeedCallHip(ceed, hipFree(impl->d_data_owned)); 326 CeedCallBackend(CeedFree(&impl->h_data_owned)); 327 CeedCallBackend(CeedFree(&impl)); 328 return CEED_ERROR_SUCCESS; 329 } 330 331 //------------------------------------------------------------------------------ 332 // QFunctionContext Create 333 //------------------------------------------------------------------------------ 334 int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) { 335 CeedQFunctionContext_Hip *impl; 336 Ceed ceed; 337 338 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 339 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip)); 340 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip)); 341 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip)); 342 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip)); 343 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip)); 344 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip)); 345 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip)); 346 CeedCallBackend(CeedCalloc(1, &impl)); 347 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 348 return CEED_ERROR_SUCCESS; 349 } 350 351 //------------------------------------------------------------------------------ 352