13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 30d0321e0SJeremy L Thompson // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 50d0321e0SJeremy L Thompson // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 70d0321e0SJeremy L Thompson 849aac155SJeremy L Thompson #include <ceed.h> 90d0321e0SJeremy L Thompson #include <ceed/backend.h> 1049aac155SJeremy L Thompson #include <ceed/jit-source/cuda/cuda-types.h> 110d0321e0SJeremy L Thompson #include <cuda.h> 122b730f8bSJeremy L Thompson 1349aac155SJeremy L Thompson #include "../cuda/ceed-cuda-common.h" 140d0321e0SJeremy L Thompson #include "../cuda/ceed-cuda-compile.h" 152b730f8bSJeremy L Thompson #include "ceed-cuda-ref-qfunction-load.h" 162b730f8bSJeremy L Thompson #include "ceed-cuda-ref.h" 170d0321e0SJeremy L Thompson 180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 190d0321e0SJeremy L Thompson // Apply QFunction 200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 212b730f8bSJeremy L Thompson static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) { 220d0321e0SJeremy L Thompson Ceed ceed; 23*ca735530SJeremy L Thompson Ceed_Cuda *ceed_Cuda; 24*ca735530SJeremy L Thompson CeedInt num_input_fields, num_output_fields; 25*ca735530SJeremy L Thompson CeedQFunction_Cuda *data; 26*ca735530SJeremy L Thompson 272b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 280d0321e0SJeremy L Thompson 290d0321e0SJeremy L Thompson // Build and compile kernel, if not done 30eb7e6cafSJeremy L Thompson CeedCallBackend(CeedQFunctionBuildKernel_Cuda_ref(qf)); 310d0321e0SJeremy L Thompson 322b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &data)); 332b730f8bSJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &ceed_Cuda)); 342b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 350d0321e0SJeremy L Thompson 360d0321e0SJeremy L Thompson // Read vectors 37437930d1SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 382b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i])); 390d0321e0SJeremy L Thompson } 40437930d1SJeremy L Thompson for (CeedInt i = 0; i < num_output_fields; i++) { 412b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i])); 420d0321e0SJeremy L Thompson } 430d0321e0SJeremy L Thompson 440d0321e0SJeremy L Thompson // Get context data 452b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c)); 460d0321e0SJeremy L Thompson 470d0321e0SJeremy L Thompson // Run kernel 480d0321e0SJeremy L Thompson void *args[] = {&data->d_c, (void *)&Q, &data->fields}; 492b730f8bSJeremy L Thompson CeedCallBackend(CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args)); 500d0321e0SJeremy L Thompson 510d0321e0SJeremy L Thompson // Restore vectors 52437930d1SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 532b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i])); 540d0321e0SJeremy L Thompson } 55437930d1SJeremy L Thompson for (CeedInt i = 0; i < num_output_fields; i++) { 562b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(V[i], &data->fields.outputs[i])); 570d0321e0SJeremy L Thompson } 580d0321e0SJeremy L Thompson 590d0321e0SJeremy L Thompson // Restore context 602b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &data->d_c)); 610d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 620d0321e0SJeremy L Thompson } 630d0321e0SJeremy L Thompson 640d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 650d0321e0SJeremy L Thompson // Destroy QFunction 660d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 670d0321e0SJeremy L Thompson static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) { 680d0321e0SJeremy L Thompson Ceed ceed; 69*ca735530SJeremy L Thompson CeedQFunction_Cuda *data; 70*ca735530SJeremy L Thompson 71*ca735530SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &data)); 722b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 732b730f8bSJeremy L Thompson if (data->module) CeedCallCuda(ceed, cuModuleUnload(data->module)); 742b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&data)); 750d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 760d0321e0SJeremy L Thompson } 770d0321e0SJeremy L Thompson 780d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 790d0321e0SJeremy L Thompson // Set User QFunction 800d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 812b730f8bSJeremy L Thompson static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf, CUfunction f) { 820d0321e0SJeremy L Thompson CeedQFunction_Cuda *data; 83*ca735530SJeremy L Thompson 842b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &data)); 85437930d1SJeremy L Thompson data->QFunction = f; 860d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 870d0321e0SJeremy L Thompson } 880d0321e0SJeremy L Thompson 890d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 900d0321e0SJeremy L Thompson // Create QFunction 910d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 920d0321e0SJeremy L Thompson int CeedQFunctionCreate_Cuda(CeedQFunction qf) { 930d0321e0SJeremy L Thompson Ceed ceed; 940d0321e0SJeremy L Thompson CeedQFunction_Cuda *data; 95*ca735530SJeremy L Thompson 96*ca735530SJeremy L Thompson CeedQFunctionGetCeed(qf, &ceed); 972b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &data)); 982b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionSetData(qf, data)); 990d0321e0SJeremy L Thompson 1000d0321e0SJeremy L Thompson // Read QFunction source 1012b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name)); 10223d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n"); 1032b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source)); 10423d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n"); 1050d0321e0SJeremy L Thompson 1060d0321e0SJeremy L Thompson // Register backend functions 1072b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Cuda)); 1082b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Cuda)); 1092b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction", CeedQFunctionSetCUDAUserFunction_Cuda)); 1100d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1110d0321e0SJeremy L Thompson } 1122a86cc9dSSebastian Grimberg 1130d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 114