// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16*7d8d0e25Snbeams 17*7d8d0e25Snbeams #include "ceed-hip-gen.h" 18*7d8d0e25Snbeams #include "ceed-hip-gen-operator-build.h" 19*7d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 20*7d8d0e25Snbeams 21*7d8d0e25Snbeams //------------------------------------------------------------------------------ 22*7d8d0e25Snbeams // Destroy operator 23*7d8d0e25Snbeams //------------------------------------------------------------------------------ 24*7d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 25*7d8d0e25Snbeams int ierr; 26*7d8d0e25Snbeams CeedOperator_Hip_gen *impl; 27*7d8d0e25Snbeams ierr = CeedOperatorGetData(op, &impl); CeedChk(ierr); 28*7d8d0e25Snbeams ierr = CeedFree(&impl); CeedChk(ierr); 29*7d8d0e25Snbeams return 0; 30*7d8d0e25Snbeams } 31*7d8d0e25Snbeams 32*7d8d0e25Snbeams //------------------------------------------------------------------------------ 33*7d8d0e25Snbeams // Apply and add to output 34*7d8d0e25Snbeams //------------------------------------------------------------------------------ 35*7d8d0e25Snbeams static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector invec, 36*7d8d0e25Snbeams CeedVector outvec, CeedRequest *request) { 37*7d8d0e25Snbeams int ierr; 38*7d8d0e25Snbeams Ceed ceed; 39*7d8d0e25Snbeams ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr); 40*7d8d0e25Snbeams CeedOperator_Hip_gen *data; 41*7d8d0e25Snbeams ierr = CeedOperatorGetData(op, &data); CeedChk(ierr); 42*7d8d0e25Snbeams CeedQFunction qf; 43*7d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 44*7d8d0e25Snbeams ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr); 45*7d8d0e25Snbeams ierr = CeedQFunctionGetData(qf, &qf_data); CeedChk(ierr); 46*7d8d0e25Snbeams CeedInt nelem, numinputfields, numoutputfields; 47*7d8d0e25Snbeams ierr = CeedOperatorGetNumElements(op, &nelem); CeedChk(ierr); 48*7d8d0e25Snbeams ierr = CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields); 49*7d8d0e25Snbeams CeedChk(ierr); 50*7d8d0e25Snbeams CeedOperatorField 
*opinputfields, *opoutputfields; 51*7d8d0e25Snbeams ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields); 52*7d8d0e25Snbeams CeedChk(ierr); 53*7d8d0e25Snbeams CeedQFunctionField *qfinputfields, *qfoutputfields; 54*7d8d0e25Snbeams ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields); 55*7d8d0e25Snbeams CeedChk(ierr); 56*7d8d0e25Snbeams CeedEvalMode emode; 57*7d8d0e25Snbeams CeedVector vec, outvecs[16] = {}; 58*7d8d0e25Snbeams 59*7d8d0e25Snbeams //Creation of the operator 60*7d8d0e25Snbeams ierr = CeedHipGenOperatorBuild(op); CeedChk(ierr); 61*7d8d0e25Snbeams 62*7d8d0e25Snbeams // Input vectors 63*7d8d0e25Snbeams for (CeedInt i = 0; i < numinputfields; i++) { 64*7d8d0e25Snbeams ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode); 65*7d8d0e25Snbeams CeedChk(ierr); 66*7d8d0e25Snbeams if (emode == CEED_EVAL_WEIGHT) { // Skip 67*7d8d0e25Snbeams data->fields.in[i] = NULL; 68*7d8d0e25Snbeams } else { 69*7d8d0e25Snbeams // Get input vector 70*7d8d0e25Snbeams ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr); 71*7d8d0e25Snbeams if (vec == CEED_VECTOR_ACTIVE) vec = invec; 72*7d8d0e25Snbeams ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.in[i]); 73*7d8d0e25Snbeams CeedChk(ierr); 74*7d8d0e25Snbeams } 75*7d8d0e25Snbeams } 76*7d8d0e25Snbeams 77*7d8d0e25Snbeams // Output vectors 78*7d8d0e25Snbeams for (CeedInt i = 0; i < numoutputfields; i++) { 79*7d8d0e25Snbeams ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode); 80*7d8d0e25Snbeams CeedChk(ierr); 81*7d8d0e25Snbeams if (emode == CEED_EVAL_WEIGHT) { // Skip 82*7d8d0e25Snbeams data->fields.out[i] = NULL; 83*7d8d0e25Snbeams } else { 84*7d8d0e25Snbeams // Get output vector 85*7d8d0e25Snbeams ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr); 86*7d8d0e25Snbeams if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 87*7d8d0e25Snbeams outvecs[i] = vec; 88*7d8d0e25Snbeams // Check for multiple output modes 89*7d8d0e25Snbeams CeedInt 
index = -1; 90*7d8d0e25Snbeams for (CeedInt j = 0; j < i; j++) { 91*7d8d0e25Snbeams if (vec == outvecs[j]) { 92*7d8d0e25Snbeams index = j; 93*7d8d0e25Snbeams break; 94*7d8d0e25Snbeams } 95*7d8d0e25Snbeams } 96*7d8d0e25Snbeams if (index == -1) { 97*7d8d0e25Snbeams ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.out[i]); 98*7d8d0e25Snbeams CeedChk(ierr); 99*7d8d0e25Snbeams } else { 100*7d8d0e25Snbeams data->fields.out[i] = data->fields.out[index]; 101*7d8d0e25Snbeams } 102*7d8d0e25Snbeams } 103*7d8d0e25Snbeams } 104*7d8d0e25Snbeams 105*7d8d0e25Snbeams // Get context data 106*7d8d0e25Snbeams CeedQFunctionContext ctx; 107*7d8d0e25Snbeams ierr = CeedQFunctionGetInnerContext(qf, &ctx); CeedChk(ierr); 108*7d8d0e25Snbeams if (ctx) { 109*7d8d0e25Snbeams ierr = CeedQFunctionContextGetData(ctx, CEED_MEM_DEVICE, &qf_data->d_c); 110*7d8d0e25Snbeams CeedChk(ierr); 111*7d8d0e25Snbeams } 112*7d8d0e25Snbeams 113*7d8d0e25Snbeams // Apply operator 114*7d8d0e25Snbeams void *opargs[] = {(void *) &nelem, &qf_data->d_c, &data->indices, 115*7d8d0e25Snbeams &data->fields, &data->B, &data->G, &data->W 116*7d8d0e25Snbeams }; 117*7d8d0e25Snbeams const CeedInt dim = data->dim; 118*7d8d0e25Snbeams const CeedInt Q1d = data->Q1d; 119*7d8d0e25Snbeams const CeedInt P1d = data->maxP1d; 120*7d8d0e25Snbeams const CeedInt thread1d = CeedIntMax(Q1d, P1d); 121*7d8d0e25Snbeams if (dim==1) { 122*7d8d0e25Snbeams CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32; 123*7d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 124*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 125*7d8d0e25Snbeams ? 1 : 0 ); 126*7d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 127*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, 1, 128*7d8d0e25Snbeams elemsPerBlock, sharedMem, opargs); 129*7d8d0e25Snbeams } else if (dim==2) { 130*7d8d0e25Snbeams const CeedInt elemsPerBlock = thread1d<4? 
16 : 2; 131*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 132*7d8d0e25Snbeams ? 1 : 0 ); 133*7d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 134*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, thread1d, 135*7d8d0e25Snbeams elemsPerBlock, sharedMem, opargs); 136*7d8d0e25Snbeams } else if (dim==3) { 137*7d8d0e25Snbeams const CeedInt elemsPerBlock = thread1d<6? 4 : (thread1d<8? 2 : 1); 138*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 139*7d8d0e25Snbeams ? 1 : 0 ); 140*7d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 141*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, thread1d, thread1d, 142*7d8d0e25Snbeams elemsPerBlock, sharedMem, opargs); 143*7d8d0e25Snbeams } 144*7d8d0e25Snbeams CeedChk(ierr); 145*7d8d0e25Snbeams 146*7d8d0e25Snbeams // Restore input arrays 147*7d8d0e25Snbeams for (CeedInt i = 0; i < numinputfields; i++) { 148*7d8d0e25Snbeams ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode); 149*7d8d0e25Snbeams CeedChk(ierr); 150*7d8d0e25Snbeams if (emode == CEED_EVAL_WEIGHT) { // Skip 151*7d8d0e25Snbeams } else { 152*7d8d0e25Snbeams ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr); 153*7d8d0e25Snbeams if (vec == CEED_VECTOR_ACTIVE) vec = invec; 154*7d8d0e25Snbeams ierr = CeedVectorRestoreArrayRead(vec, &data->fields.in[i]); 155*7d8d0e25Snbeams CeedChk(ierr); 156*7d8d0e25Snbeams } 157*7d8d0e25Snbeams } 158*7d8d0e25Snbeams 159*7d8d0e25Snbeams // Restore output arrays 160*7d8d0e25Snbeams for (CeedInt i = 0; i < numoutputfields; i++) { 161*7d8d0e25Snbeams ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode); 162*7d8d0e25Snbeams CeedChk(ierr); 163*7d8d0e25Snbeams if (emode == CEED_EVAL_WEIGHT) { // Skip 164*7d8d0e25Snbeams } else { 165*7d8d0e25Snbeams ierr = 
CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr); 166*7d8d0e25Snbeams if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 167*7d8d0e25Snbeams // Check for multiple output modes 168*7d8d0e25Snbeams CeedInt index = -1; 169*7d8d0e25Snbeams for (CeedInt j = 0; j < i; j++) { 170*7d8d0e25Snbeams if (vec == outvecs[j]) { 171*7d8d0e25Snbeams index = j; 172*7d8d0e25Snbeams break; 173*7d8d0e25Snbeams } 174*7d8d0e25Snbeams } 175*7d8d0e25Snbeams if (index == -1) { 176*7d8d0e25Snbeams ierr = CeedVectorRestoreArray(vec, &data->fields.out[i]); 177*7d8d0e25Snbeams CeedChk(ierr); 178*7d8d0e25Snbeams } 179*7d8d0e25Snbeams } 180*7d8d0e25Snbeams } 181*7d8d0e25Snbeams 182*7d8d0e25Snbeams // Restore context data 183*7d8d0e25Snbeams if (ctx) { 184*7d8d0e25Snbeams ierr = CeedQFunctionContextRestoreData(ctx, &qf_data->d_c); 185*7d8d0e25Snbeams CeedChk(ierr); 186*7d8d0e25Snbeams } 187*7d8d0e25Snbeams return 0; 188*7d8d0e25Snbeams } 189*7d8d0e25Snbeams 190*7d8d0e25Snbeams //------------------------------------------------------------------------------ 191*7d8d0e25Snbeams // Create FDM element inverse not supported 192*7d8d0e25Snbeams //------------------------------------------------------------------------------ 193*7d8d0e25Snbeams static int CeedOperatorCreateFDMElementInverse_Hip(CeedOperator op) { 194*7d8d0e25Snbeams // LCOV_EXCL_START 195*7d8d0e25Snbeams int ierr; 196*7d8d0e25Snbeams Ceed ceed; 197*7d8d0e25Snbeams ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr); 198*7d8d0e25Snbeams return CeedError(ceed, 1, "Backend does not implement FDM inverse creation"); 199*7d8d0e25Snbeams // LCOV_EXCL_STOP 200*7d8d0e25Snbeams } 201*7d8d0e25Snbeams 202*7d8d0e25Snbeams //------------------------------------------------------------------------------ 203*7d8d0e25Snbeams // Create operator 204*7d8d0e25Snbeams //------------------------------------------------------------------------------ 205*7d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 206*7d8d0e25Snbeams int ierr; 
207*7d8d0e25Snbeams Ceed ceed; 208*7d8d0e25Snbeams ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr); 209*7d8d0e25Snbeams CeedOperator_Hip_gen *impl; 210*7d8d0e25Snbeams 211*7d8d0e25Snbeams ierr = CeedCalloc(1, &impl); CeedChk(ierr); 212*7d8d0e25Snbeams ierr = CeedOperatorSetData(op, impl); CeedChk(ierr); 213*7d8d0e25Snbeams 214*7d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Operator", op, "CreateFDMElementInverse", 215*7d8d0e25Snbeams CeedOperatorCreateFDMElementInverse_Hip); 216*7d8d0e25Snbeams CeedChk(ierr); 217*7d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 218*7d8d0e25Snbeams CeedOperatorApplyAdd_Hip_gen); CeedChk(ierr); 219*7d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 220*7d8d0e25Snbeams CeedOperatorDestroy_Hip_gen); CeedChk(ierr); 221*7d8d0e25Snbeams return 0; 222*7d8d0e25Snbeams } 223*7d8d0e25Snbeams //------------------------------------------------------------------------------ 224