1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <stddef.h> 11 #include "ceed-hip-gen.h" 12 #include "ceed-hip-gen-operator-build.h" 13 #include "../hip/ceed-hip-compile.h" 14 15 //------------------------------------------------------------------------------ 16 // Destroy operator 17 //------------------------------------------------------------------------------ 18 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 19 int ierr; 20 CeedOperator_Hip_gen *impl; 21 ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 22 ierr = CeedFree(&impl); CeedChkBackend(ierr); 23 return CEED_ERROR_SUCCESS; 24 } 25 26 //------------------------------------------------------------------------------ 27 // Apply and add to output 28 //------------------------------------------------------------------------------ 29 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector invec, 30 CeedVector outvec, CeedRequest *request) { 31 int ierr; 32 Ceed ceed; 33 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 34 CeedOperator_Hip_gen *data; 35 ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr); 36 CeedQFunction qf; 37 CeedQFunction_Hip_gen *qf_data; 38 ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 39 ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr); 40 CeedInt nelem, numinputfields, numoutputfields; 41 ierr = CeedOperatorGetNumElements(op, &nelem); CeedChkBackend(ierr); 42 CeedOperatorField *opinputfields, *opoutputfields; 43 ierr = CeedOperatorGetFields(op, &numinputfields, &opinputfields, 44 &numoutputfields, &opoutputfields); 45 CeedChkBackend(ierr); 46 CeedQFunctionField *qfinputfields, *qfoutputfields; 47 ierr = CeedQFunctionGetFields(qf, NULL, &qfinputfields, NULL, &qfoutputfields); 48 CeedChkBackend(ierr); 49 CeedEvalMode emode; 50 CeedVector vec, outvecs[16] = {}; 51 52 //Creation of the operator 53 ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr); 54 55 // Input vectors 56 for (CeedInt i = 0; i < numinputfields; i++) { 57 ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode); 58 CeedChkBackend(ierr); 59 if (emode == CEED_EVAL_WEIGHT) { // Skip 60 data->fields.in[i] = NULL; 61 } else { 62 // Get input vector 63 ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr); 64 if (vec == CEED_VECTOR_ACTIVE) vec = invec; 65 ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.in[i]); 66 CeedChkBackend(ierr); 67 } 68 } 69 70 // Output vectors 71 for (CeedInt i = 0; i < numoutputfields; i++) { 72 ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode); 73 CeedChkBackend(ierr); 74 if (emode == CEED_EVAL_WEIGHT) { // Skip 75 data->fields.out[i] = NULL; 76 } else { 77 // Get output vector 78 ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); 79 CeedChkBackend(ierr); 80 if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 81 outvecs[i] = vec; 82 // Check for multiple output modes 83 CeedInt index = -1; 84 for (CeedInt j = 0; j < i; j++) { 85 if (vec == outvecs[j]) { 86 index = j; 87 break; 88 } 89 } 90 if (index == -1) { 91 ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.out[i]); 92 CeedChkBackend(ierr); 93 } else { 94 data->fields.out[i] = data->fields.out[index]; 95 } 96 } 97 } 98 99 // Get context data 100 ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c); 101 CeedChkBackend(ierr); 102 103 // Apply operator 104 void *opargs[] = {(void *) &nelem, &qf_data->d_c, &data->indices, 105 &data->fields, &data->B, &data->G, &data->W 106 }; 107 const CeedInt dim = data->dim; 108 const CeedInt Q1d = data->Q1d; 109 const CeedInt P1d = data->maxP1d; 110 const CeedInt thread1d = CeedIntMax(Q1d, P1d); 111 CeedInt block_sizes[3]; 112 ierr = BlockGridCalculate_Hip_gen(dim, nelem, P1d, Q1d, block_sizes); 113 CeedChkBackend(ierr); 114 if (dim==1) { 115 CeedInt grid = nelem/block_sizes[2] + ( ( 116 nelem/block_sizes[2]*block_sizes[2]<nelem) 117 ? 1 : 0 ); 118 CeedInt sharedMem = block_sizes[2]*thread1d*sizeof(CeedScalar); 119 ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 120 block_sizes[1], 121 block_sizes[2], sharedMem, opargs); 122 } else if (dim==2) { 123 CeedInt grid = nelem/block_sizes[2] + ( ( 124 nelem/block_sizes[2]*block_sizes[2]<nelem) 125 ? 1 : 0 ); 126 CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar); 127 ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 128 block_sizes[1], 129 block_sizes[2], sharedMem, opargs); 130 } else if (dim==3) { 131 CeedInt grid = nelem/block_sizes[2] + ( ( 132 nelem/block_sizes[2]*block_sizes[2]<nelem) 133 ? 1 : 0 ); 134 CeedInt sharedMem = block_sizes[2]*thread1d*thread1d*sizeof(CeedScalar); 135 ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 136 block_sizes[1], 137 block_sizes[2], sharedMem, opargs); 138 } 139 CeedChkBackend(ierr); 140 141 // Restore input arrays 142 for (CeedInt i = 0; i < numinputfields; i++) { 143 ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode); 144 CeedChkBackend(ierr); 145 if (emode == CEED_EVAL_WEIGHT) { // Skip 146 } else { 147 ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChkBackend(ierr); 148 if (vec == CEED_VECTOR_ACTIVE) vec = invec; 149 ierr = CeedVectorRestoreArrayRead(vec, &data->fields.in[i]); 150 CeedChkBackend(ierr); 151 } 152 } 153 154 // Restore output arrays 155 for (CeedInt i = 0; i < numoutputfields; i++) { 156 ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode); 157 CeedChkBackend(ierr); 158 if (emode == CEED_EVAL_WEIGHT) { // Skip 159 } else { 160 ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); 161 CeedChkBackend(ierr); 162 if (vec == CEED_VECTOR_ACTIVE) vec = outvec; 163 // Check for multiple output modes 164 CeedInt index = -1; 165 for (CeedInt j = 0; j < i; j++) { 166 if (vec == outvecs[j]) { 167 index = j; 168 break; 169 } 170 } 171 if (index == -1) { 172 ierr = CeedVectorRestoreArray(vec, &data->fields.out[i]); 173 CeedChkBackend(ierr); 174 } 175 } 176 } 177 178 // Restore context data 179 ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c); 180 CeedChkBackend(ierr); 181 182 return CEED_ERROR_SUCCESS; 183 } 184 185 //------------------------------------------------------------------------------ 186 // Create operator 187 //------------------------------------------------------------------------------ 188 int CeedOperatorCreate_Hip_gen(CeedOperator op) { 189 int ierr; 190 Ceed ceed; 191 ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 192 CeedOperator_Hip_gen *impl; 193 194 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 195 ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr); 196 197 ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 198 CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr); 199 ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 200 CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr); 201 return CEED_ERROR_SUCCESS; 202 } 203 //------------------------------------------------------------------------------ 204