xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 5dfaedb85d2aa5da89951bb5d8f41d61be09bbf6)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <stddef.h>
11 #include "ceed-hip-gen.h"
12 #include "ceed-hip-gen-operator-build.h"
13 #include "../hip/ceed-hip-compile.h"
14 
15 //------------------------------------------------------------------------------
16 // Destroy operator
17 //------------------------------------------------------------------------------
18 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
19   int ierr;
20   CeedOperator_Hip_gen *impl;
21   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
22   ierr = CeedFree(&impl); CeedChkBackend(ierr);
23   return CEED_ERROR_SUCCESS;
24 }
25 
26 //------------------------------------------------------------------------------
27 // Apply and add to output
28 //------------------------------------------------------------------------------
29 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector invec,
30                                         CeedVector outvec, CeedRequest *request) {
31   int ierr;
32   Ceed ceed;
33   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
34   CeedOperator_Hip_gen *data;
35   ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr);
36   CeedQFunction qf;
37   CeedQFunction_Hip_gen *qf_data;
38   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
39   ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr);
40   CeedInt nelem, numinputfields, numoutputfields;
41   ierr = CeedOperatorGetNumElements(op, &nelem); CeedChkBackend(ierr);
42   CeedOperatorField *opinputfields, *opoutputfields;
43   ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields,
44                                &numoutputfields, &opoutputfields);
45   CeedChkBackend(ierr);
46   CeedQFunctionField *qfinputfields, *qfoutputfields;
47   ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields);
48   CeedChkBackend(ierr);
49   CeedEvalMode emode;
50   CeedVector vec, outvecs[16] = {};
51 
52   //Creation of the operator
53   ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr);
54 
55   // Input vectors
56   for (CeedInt i = 0; i < numinputfields; i++) {
57     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
58     CeedChkBackend(ierr);
59     if (emode == CEED_EVAL_WEIGHT) { // Skip
60       data->fields.in[i] = NULL;
61     } else {
62       // Get input vector
63       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
64       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
65       ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.in[i]);
66       CeedChkBackend(ierr);
67     }
68   }
69 
70   // Output vectors
71   for (CeedInt i = 0; i < numoutputfields; i++) {
72     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
73     CeedChkBackend(ierr);
74     if (emode == CEED_EVAL_WEIGHT) { // Skip
75       data->fields.out[i] = NULL;
76     } else {
77       // Get output vector
78       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
79       CeedChkBackend(ierr);
80       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
81       outvecs[i] = vec;
82       // Check for multiple output modes
83       CeedInt index = -1;
84       for (CeedInt j = 0; j < i; j++) {
85         if (vec == outvecs[j]) {
86           index = j;
87           break;
88         }
89       }
90       if (index == -1) {
91         ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.out[i]);
92         CeedChkBackend(ierr);
93       } else {
94         data->fields.out[i] = data->fields.out[index];
95       }
96     }
97   }
98 
99   // Get context data
100   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c);
101   CeedChkBackend(ierr);
102 
103   // Apply operator
104   void *opargs[] = {(void *) &nelem, &qf_data->d_c, &data->indices,
105                     &data->fields, &data->B, &data->G, &data->W
106                    };
107   const CeedInt dim = data->dim;
108   const CeedInt Q1d = data->Q1d;
109   const CeedInt P1d = data->maxP1d;
110   const CeedInt thread1d = CeedIntMax(Q1d, P1d);
111   CeedInt block_sizes[3];
112   ierr = BlockGridCalculate_Hip_gen(dim, nelem, P1d, Q1d, block_sizes);
113   CeedChkBackend(ierr);
114   if (dim==1) {
115     CeedInt grid = nelem/block_sizes[2] + ( (
116         nelem/block_sizes[2]*block_sizes[2]<nelem)
117                                             ? 1 : 0 );
118     CeedInt sharedMem = block_sizes[2]*thread1d*sizeof(CeedScalar);
119     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
120                                      block_sizes[1],
121                                      block_sizes[2], sharedMem, opargs);
122   } else if (dim==2) {
123     CeedInt grid = nelem/block_sizes[2] + ( (
124         nelem/block_sizes[2]*block_sizes[2]<nelem)
125                                             ? 1 : 0 );
126     CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar);
127     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
128                                      block_sizes[1],
129                                      block_sizes[2], sharedMem, opargs);
130   } else if (dim==3) {
131     CeedInt grid = nelem/block_sizes[2] + ( (
132         nelem/block_sizes[2]*block_sizes[2]<nelem)
133                                             ? 1 : 0 );
134     CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar);
135     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
136                                      block_sizes[1],
137                                      block_sizes[2], sharedMem, opargs);
138   }
139   CeedChkBackend(ierr);
140 
141   // Restore input arrays
142   for (CeedInt i = 0; i < numinputfields; i++) {
143     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
144     CeedChkBackend(ierr);
145     if (emode == CEED_EVAL_WEIGHT) { // Skip
146     } else {
147       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
148       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
149       ierr = CeedVectorRestoreArrayRead(vec, &data->fields.in[i]);
150       CeedChkBackend(ierr);
151     }
152   }
153 
154   // Restore output arrays
155   for (CeedInt i = 0; i < numoutputfields; i++) {
156     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
157     CeedChkBackend(ierr);
158     if (emode == CEED_EVAL_WEIGHT) { // Skip
159     } else {
160       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
161       CeedChkBackend(ierr);
162       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
163       // Check for multiple output modes
164       CeedInt index = -1;
165       for (CeedInt j = 0; j < i; j++) {
166         if (vec == outvecs[j]) {
167           index = j;
168           break;
169         }
170       }
171       if (index == -1) {
172         ierr = CeedVectorRestoreArray(vec, &data->fields.out[i]);
173         CeedChkBackend(ierr);
174       }
175     }
176   }
177 
178   // Restore context data
179   ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c);
180   CeedChkBackend(ierr);
181 
182   return CEED_ERROR_SUCCESS;
183 }
184 
185 //------------------------------------------------------------------------------
186 // Create operator
187 //------------------------------------------------------------------------------
188 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
189   int ierr;
190   Ceed ceed;
191   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
192   CeedOperator_Hip_gen *impl;
193 
194   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
195   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
196 
197   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
198                                 CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr);
199   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
200                                 CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr);
201   return CEED_ERROR_SUCCESS;
202 }
203 //------------------------------------------------------------------------------
204