xref: /libCEED/rust/libceed-sys/c-src/backends/hip-gen/ceed-hip-gen-operator.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1*3d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*3d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
37d8d0e25Snbeams //
4*3d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
57d8d0e25Snbeams //
6*3d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
77d8d0e25Snbeams 
8ec3da8bcSJed Brown #include <ceed/ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
103d576824SJeremy L Thompson #include <stddef.h>
117d8d0e25Snbeams #include "ceed-hip-gen.h"
127d8d0e25Snbeams #include "ceed-hip-gen-operator-build.h"
137d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
147d8d0e25Snbeams 
157d8d0e25Snbeams //------------------------------------------------------------------------------
167d8d0e25Snbeams // Destroy operator
177d8d0e25Snbeams //------------------------------------------------------------------------------
187d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
197d8d0e25Snbeams   int ierr;
207d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
21e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
22e15f9bd0SJeremy L Thompson   ierr = CeedFree(&impl); CeedChkBackend(ierr);
23e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
247d8d0e25Snbeams }
257d8d0e25Snbeams 
267d8d0e25Snbeams //------------------------------------------------------------------------------
277d8d0e25Snbeams // Apply and add to output
287d8d0e25Snbeams //------------------------------------------------------------------------------
297d8d0e25Snbeams static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector invec,
307d8d0e25Snbeams                                         CeedVector outvec, CeedRequest *request) {
317d8d0e25Snbeams   int ierr;
327d8d0e25Snbeams   Ceed ceed;
33e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
347d8d0e25Snbeams   CeedOperator_Hip_gen *data;
35e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr);
367d8d0e25Snbeams   CeedQFunction qf;
377d8d0e25Snbeams   CeedQFunction_Hip_gen *qf_data;
38e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
39e15f9bd0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr);
407d8d0e25Snbeams   CeedInt nelem, numinputfields, numoutputfields;
41e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetNumElements(op, &nelem); CeedChkBackend(ierr);
427d8d0e25Snbeams   CeedOperatorField *opinputfields, *opoutputfields;
437e7773b5SJeremy L Thompson   ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields,
447e7773b5SJeremy L Thompson                                &numoutputfields, &opoutputfields);
45e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
467d8d0e25Snbeams   CeedQFunctionField *qfinputfields, *qfoutputfields;
477e7773b5SJeremy L Thompson   ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields);
48e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
497d8d0e25Snbeams   CeedEvalMode emode;
507d8d0e25Snbeams   CeedVector vec, outvecs[16] = {};
517d8d0e25Snbeams 
527d8d0e25Snbeams   //Creation of the operator
53e15f9bd0SJeremy L Thompson   ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr);
547d8d0e25Snbeams 
557d8d0e25Snbeams   // Input vectors
567d8d0e25Snbeams   for (CeedInt i = 0; i < numinputfields; i++) {
577d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
58e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
597d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
607d8d0e25Snbeams       data->fields.in[i] = NULL;
617d8d0e25Snbeams     } else {
627d8d0e25Snbeams       // Get input vector
63e15f9bd0SJeremy L Thompson       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
647d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
657d8d0e25Snbeams       ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.in[i]);
66e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
677d8d0e25Snbeams     }
687d8d0e25Snbeams   }
697d8d0e25Snbeams 
707d8d0e25Snbeams   // Output vectors
717d8d0e25Snbeams   for (CeedInt i = 0; i < numoutputfields; i++) {
727d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
73e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
747d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
757d8d0e25Snbeams       data->fields.out[i] = NULL;
767d8d0e25Snbeams     } else {
777d8d0e25Snbeams       // Get output vector
78e15f9bd0SJeremy L Thompson       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
79e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
807d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
817d8d0e25Snbeams       outvecs[i] = vec;
827d8d0e25Snbeams       // Check for multiple output modes
837d8d0e25Snbeams       CeedInt index = -1;
847d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
857d8d0e25Snbeams         if (vec == outvecs[j]) {
867d8d0e25Snbeams           index = j;
877d8d0e25Snbeams           break;
887d8d0e25Snbeams         }
897d8d0e25Snbeams       }
907d8d0e25Snbeams       if (index == -1) {
917d8d0e25Snbeams         ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.out[i]);
92e15f9bd0SJeremy L Thompson         CeedChkBackend(ierr);
937d8d0e25Snbeams       } else {
947d8d0e25Snbeams         data->fields.out[i] = data->fields.out[index];
957d8d0e25Snbeams       }
967d8d0e25Snbeams     }
977d8d0e25Snbeams   }
987d8d0e25Snbeams 
997d8d0e25Snbeams   // Get context data
100441428dfSJeremy L Thompson   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c);
101e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1027d8d0e25Snbeams 
1037d8d0e25Snbeams   // Apply operator
1047d8d0e25Snbeams   void *opargs[] = {(void *) &nelem, &qf_data->d_c, &data->indices,
1057d8d0e25Snbeams                     &data->fields, &data->B, &data->G, &data->W
1067d8d0e25Snbeams                    };
1077d8d0e25Snbeams   const CeedInt dim = data->dim;
1087d8d0e25Snbeams   const CeedInt Q1d = data->Q1d;
1097d8d0e25Snbeams   const CeedInt P1d = data->maxP1d;
1107d8d0e25Snbeams   const CeedInt thread1d = CeedIntMax(Q1d, P1d);
111b3e1519bSnbeams   CeedInt block_sizes[3];
11237c3b1cfSnbeams   ierr = BlockGridCalculate_Hip_gen(dim, nelem, P1d, Q1d, block_sizes);
113b3e1519bSnbeams   CeedChkBackend(ierr);
1147d8d0e25Snbeams   if (dim==1) {
115b3e1519bSnbeams     CeedInt grid = nelem/block_sizes[2] + ( (
116b3e1519bSnbeams         nelem/block_sizes[2]*block_sizes[2]<nelem)
1177d8d0e25Snbeams                                             ? 1 : 0 );
118b3e1519bSnbeams     CeedInt sharedMem = block_sizes[2]*thread1d*sizeof(CeedScalar);
119b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
120b3e1519bSnbeams                                      block_sizes[1],
121b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1227d8d0e25Snbeams   } else if (dim==2) {
123b3e1519bSnbeams     CeedInt grid = nelem/block_sizes[2] + ( (
124b3e1519bSnbeams         nelem/block_sizes[2]*block_sizes[2]<nelem)
1257d8d0e25Snbeams                                             ? 1 : 0 );
126b3e1519bSnbeams     CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar);
127b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
128b3e1519bSnbeams                                      block_sizes[1],
129b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1307d8d0e25Snbeams   } else if (dim==3) {
131b3e1519bSnbeams     CeedInt grid = nelem/block_sizes[2] + ( (
132b3e1519bSnbeams         nelem/block_sizes[2]*block_sizes[2]<nelem)
1337d8d0e25Snbeams                                             ? 1 : 0 );
134b3e1519bSnbeams     CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar);
135b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
136b3e1519bSnbeams                                      block_sizes[1],
137b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1387d8d0e25Snbeams   }
139e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1407d8d0e25Snbeams 
1417d8d0e25Snbeams   // Restore input arrays
1427d8d0e25Snbeams   for (CeedInt i = 0; i < numinputfields; i++) {
1437d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
144e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
1457d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
1467d8d0e25Snbeams     } else {
147e15f9bd0SJeremy L Thompson       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr);
1487d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = invec;
1497d8d0e25Snbeams       ierr = CeedVectorRestoreArrayRead(vec, &data->fields.in[i]);
150e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
1517d8d0e25Snbeams     }
1527d8d0e25Snbeams   }
1537d8d0e25Snbeams 
1547d8d0e25Snbeams   // Restore output arrays
1557d8d0e25Snbeams   for (CeedInt i = 0; i < numoutputfields; i++) {
1567d8d0e25Snbeams     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
157e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
1587d8d0e25Snbeams     if (emode == CEED_EVAL_WEIGHT) { // Skip
1597d8d0e25Snbeams     } else {
160e15f9bd0SJeremy L Thompson       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec);
161e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
1627d8d0e25Snbeams       if (vec == CEED_VECTOR_ACTIVE) vec = outvec;
1637d8d0e25Snbeams       // Check for multiple output modes
1647d8d0e25Snbeams       CeedInt index = -1;
1657d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
1667d8d0e25Snbeams         if (vec == outvecs[j]) {
1677d8d0e25Snbeams           index = j;
1687d8d0e25Snbeams           break;
1697d8d0e25Snbeams         }
1707d8d0e25Snbeams       }
1717d8d0e25Snbeams       if (index == -1) {
1727d8d0e25Snbeams         ierr = CeedVectorRestoreArray(vec, &data->fields.out[i]);
173e15f9bd0SJeremy L Thompson         CeedChkBackend(ierr);
1747d8d0e25Snbeams       }
1757d8d0e25Snbeams     }
1767d8d0e25Snbeams   }
1777d8d0e25Snbeams 
1787d8d0e25Snbeams   // Restore context data
179441428dfSJeremy L Thompson   ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c);
180e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
181441428dfSJeremy L Thompson 
182e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1837d8d0e25Snbeams }
1847d8d0e25Snbeams 
1857d8d0e25Snbeams //------------------------------------------------------------------------------
1867d8d0e25Snbeams // Create operator
1877d8d0e25Snbeams //------------------------------------------------------------------------------
1887d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) {
1897d8d0e25Snbeams   int ierr;
1907d8d0e25Snbeams   Ceed ceed;
191e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
1927d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
1937d8d0e25Snbeams 
194e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
195e15f9bd0SJeremy L Thompson   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
1967d8d0e25Snbeams 
1977d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
198e15f9bd0SJeremy L Thompson                                 CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr);
1997d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
200e15f9bd0SJeremy L Thompson                                 CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr);
201e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2027d8d0e25Snbeams }
2037d8d0e25Snbeams //------------------------------------------------------------------------------
204