1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <hip/hip_runtime.h>
12
13 #include "../hip/ceed-hip-common.h"
14 #include "../hip/ceed-hip-compile.h"
15 #include "ceed-hip-ref-qfunction-load.h"
16 #include "ceed-hip-ref.h"
17
18 //------------------------------------------------------------------------------
19 // Apply QFunction
20 //------------------------------------------------------------------------------
// Evaluate `qf` at Q points with the HIP reference backend.
//
// Builds the kernel if that has not been done yet. Places device pointers for every input (read access) and
// output (write access) vector into the backend field table, and fetches the device copy of the inner context.
// Launches the kernel, then returns every array and the context to its owner.
//
//   qf - QFunction to evaluate
//   Q  - number of points; also used to size the launch grid
//   U  - input vectors, one per QFunction input field
//   V  - output vectors, one per QFunction output field
//
// Returns CEED_ERROR_SUCCESS, or the first error code reported by a libCEED/HIP call.
static int CeedQFunctionApply_Hip(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
  Ceed               ceed;
  Ceed_Hip          *hip_backend;
  CeedQFunction_Hip *impl;
  CeedInt            n_in, n_out;

  CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));

  // Compile the kernel on demand (no-op if it has already been built)
  CeedCallBackend(CeedQFunctionBuildKernel_Hip_ref(qf));

  CeedCallBackend(CeedQFunctionGetData(qf, &impl));
  CeedCallBackend(CeedGetData(ceed, &hip_backend));
  CeedCallBackend(CeedQFunctionGetNumArgs(qf, &n_in, &n_out));

  // Launch shape: backend-chosen threads per block, and enough blocks to cover all Q points
  const int threads_per_block = hip_backend->opt_block_size;
  const int num_blocks        = CeedDivUpInt(Q, threads_per_block);

  // Borrow device arrays and record them in the field table handed to the kernel
  for (CeedInt f = 0; f < n_in; f++) {
    CeedCallBackend(CeedVectorGetArrayRead(U[f], CEED_MEM_DEVICE, &impl->fields.inputs[f]));
  }
  for (CeedInt f = 0; f < n_out; f++) {
    CeedCallBackend(CeedVectorGetArrayWrite(V[f], CEED_MEM_DEVICE, &impl->fields.outputs[f]));
  }

  // Device-side pointer to the QFunction's inner context data
  CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &impl->d_c));

  // Kernel parameters, in order: context pointer, number of points, field table
  void *kernel_args[] = {&impl->d_c, (void *)&Q, &impl->fields};

  CeedCallBackend(CeedRunKernel_Hip(ceed, impl->QFunction, num_blocks, threads_per_block, kernel_args));

  // Return every borrowed array to its vector
  for (CeedInt f = 0; f < n_in; f++) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U[f], &impl->fields.inputs[f]));
  }
  for (CeedInt f = 0; f < n_out; f++) {
    CeedCallBackend(CeedVectorRestoreArray(V[f], &impl->fields.outputs[f]));
  }

  // Release the context and the Ceed reference obtained above
  CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &impl->d_c));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
66
67 //------------------------------------------------------------------------------
68 // Destroy QFunction
69 //------------------------------------------------------------------------------
// Tear down the HIP reference backend data attached to `qf`.
//
// Unloads the HIP module if one is recorded, then frees the CeedQFunction_Hip struct. The struct is
// zero-initialized at creation (CeedCalloc), so `module` stays unset unless a kernel module was loaded.
//
// Returns CEED_ERROR_SUCCESS, or an error code if retrieving the data, unloading the module, or freeing fails.
static int CeedQFunctionDestroy_Hip(CeedQFunction qf) {
  CeedQFunction_Hip *impl;

  CeedCallBackend(CeedQFunctionGetData(qf, &impl));
  if (impl->module) {
    // The owning Ceed is only needed to report a HIP error
    CeedCallHip(CeedQFunctionReturnCeed(qf), hipModuleUnload(impl->module));
  }
  CeedCallBackend(CeedFree(&impl));
  return CEED_ERROR_SUCCESS;
}
78
79 //------------------------------------------------------------------------------
80 // Create QFunction
81 //------------------------------------------------------------------------------
// Attach HIP reference backend data to `qf` and register its backend "Apply" and "Destroy" functions.
//
// Allocates a zero-initialized CeedQFunction_Hip and stores the QFunction's kernel name in it.
// No kernel is compiled here; CeedQFunctionApply_Hip builds it on demand.
//
//   qf - QFunction being created for this backend
//
// Returns CEED_ERROR_SUCCESS, or the first error code reported by a libCEED call.
int CeedQFunctionCreate_Hip(CeedQFunction qf) {
  Ceed               ceed;
  CeedQFunction_Hip *data;

  CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));

  // Zero-initialized so Destroy can tell whether a HIP module was ever loaded
  CeedCallBackend(CeedCalloc(1, &data));
  CeedCallBackend(CeedQFunctionSetData(qf, data));

  // Kernel name of the user QFunction; presumably consumed by the kernel build step (not visible here)
  CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name));

  // Register backend functions
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Hip));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
100
101 //------------------------------------------------------------------------------
102