1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 /// @file 9 /// Diffusion operator example using MFEM 10 11 #include <ceed.h> 12 #include <mfem.hpp> 13 #include "bp3.h" 14 15 /// Wrapper for a diffusion CeedOperator as an mfem::Operator 16 class CeedDiffusionOperator : public mfem::Operator { 17 protected: 18 const mfem::FiniteElementSpace *fes; 19 CeedOperator build_oper, oper; 20 CeedBasis basis, mesh_basis; 21 CeedElemRestriction restr, mesh_restr, restr_i, mesh_restr_i; 22 CeedQFunction apply_qfunc, build_qfunc; 23 CeedQFunctionContext build_ctx; 24 CeedVector node_coords, qdata; 25 26 BuildContext build_ctx_data; 27 28 CeedVector u, v; 29 30 static void FESpace2Ceed(const mfem::FiniteElementSpace *fes, 31 const mfem::IntegrationRule &ir, 32 Ceed ceed, CeedBasis *basis, 33 CeedElemRestriction *restr) { 34 mfem::Mesh *mesh = fes->GetMesh(); 35 const mfem::FiniteElement *fe = fes->GetFE(0); 36 const int order = fes->GetOrder(0); 37 mfem::Array<int> dof_map; 38 switch (mesh->Dimension()) { 39 case 1: { 40 const mfem::H1_SegmentElement *h1_fe = 41 dynamic_cast<const mfem::H1_SegmentElement *>(fe); 42 MFEM_VERIFY(h1_fe, "invalid FE"); 43 h1_fe->GetDofMap().Copy(dof_map); 44 break; 45 } 46 case 2: { 47 const mfem::H1_QuadrilateralElement *h1_fe = 48 dynamic_cast<const mfem::H1_QuadrilateralElement *>(fe); 49 MFEM_VERIFY(h1_fe, "invalid FE"); 50 h1_fe->GetDofMap().Copy(dof_map); 51 break; 52 } 53 case 3: { 54 const mfem::H1_HexahedronElement *h1_fe = 55 dynamic_cast<const mfem::H1_HexahedronElement *>(fe); 56 MFEM_VERIFY(h1_fe, "invalid FE"); 57 h1_fe->GetDofMap().Copy(dof_map); 58 break; 59 } 60 } 61 const mfem::FiniteElement *fe1d = 62 fes->FEColl()->FiniteElementForGeometry(mfem::Geometry::SEGMENT); 63 mfem::DenseMatrix shape1d(fe1d->GetDof(), ir.GetNPoints()); 64 mfem::DenseMatrix grad_1d(fe1d->GetDof(), ir.GetNPoints()); 65 mfem::Vector q_ref_1d(ir.GetNPoints()), q_weight_1d(ir.GetNPoints()); 66 mfem::Vector shape_i(shape1d.Height()); 67 mfem::DenseMatrix grad_i(grad_1d.Height(), 1); 68 const mfem::H1_SegmentElement *h1_fe1d = 69 dynamic_cast<const mfem::H1_SegmentElement *>(fe1d); 70 MFEM_VERIFY(h1_fe1d, "invalid FE"); 71 const mfem::Array<int> &dof_map_1d = h1_fe1d->GetDofMap(); 72 for (int i = 0; i < ir.GetNPoints(); i++) { 73 const mfem::IntegrationPoint &ip = ir.IntPoint(i); 74 q_ref_1d(i) = ip.x; 75 q_weight_1d(i) = ip.weight; 76 fe1d->CalcShape(ip, shape_i); 77 fe1d->CalcDShape(ip, grad_i); 78 for (int j = 0; j < shape1d.Height(); j++) { 79 shape1d(j,i) = shape_i(dof_map_1d[j]); 80 grad_1d(j,i) = grad_i(dof_map_1d[j],0); 81 } 82 } 83 CeedBasisCreateTensorH1(ceed, mesh->Dimension(), fes->GetVDim(), order+1, 84 ir.GetNPoints(), shape1d.GetData(), 85 grad_1d.GetData(), q_ref_1d.GetData(), 86 q_weight_1d.GetData(), basis); 87 88 const mfem::Table &el_dof = fes->GetElementToDofTable(); 89 mfem::Array<int> tp_el_dof(el_dof.Size_of_connections()); 90 for (int i = 0; i < mesh->GetNE(); i++) { 91 const int el_offset = fe->GetDof()*i; 92 for (int j = 0; j < fe->GetDof(); j++) { 93 tp_el_dof[j + el_offset] = el_dof.GetJ()[dof_map[j] + el_offset]; 94 } 95 } 96 CeedElemRestrictionCreate(ceed, mesh->GetNE(), fe->GetDof(), 97 fes->GetVDim(), fes->GetNDofs(), 98 (fes->GetVDim())*(fes->GetNDofs()), 99 CEED_MEM_HOST, CEED_COPY_VALUES, 100 tp_el_dof.GetData(), restr); 101 } 102 103 public: 104 /// Constructor. Assumes @a fes is a scalar FE space. 105 CeedDiffusionOperator(Ceed ceed, const mfem::FiniteElementSpace *fes) 106 : Operator(fes->GetNDofs()), 107 fes(fes) { 108 mfem::Mesh *mesh = fes->GetMesh(); 109 const int order = fes->GetOrder(0); 110 const int ir_order = 2*(order + 2) - 1; // <----- 111 const mfem::IntegrationRule &ir = 112 mfem::IntRules.Get(mfem::Geometry::SEGMENT, ir_order); 113 CeedInt num_elem = mesh->GetNE(), dim = mesh->SpaceDimension(), 114 ncompx = dim, nqpts; 115 116 FESpace2Ceed(fes, ir, ceed, &basis, &restr); 117 118 const mfem::FiniteElementSpace *mesh_fes = mesh->GetNodalFESpace(); 119 MFEM_VERIFY(mesh_fes, "the Mesh has no nodal FE space"); 120 FESpace2Ceed(mesh_fes, ir, ceed, &mesh_basis, &mesh_restr); 121 CeedBasisGetNumQuadraturePoints(basis, &nqpts); 122 123 CeedInt strides[3] = {1, nqpts, nqpts *dim *(dim+1)/2}; 124 CeedElemRestrictionCreateStrided(ceed, num_elem, nqpts, dim*(dim+1)/2, 125 dim*(dim+1)/2*nqpts*num_elem, strides, 126 &restr_i); 127 128 CeedVectorCreate(ceed, mesh->GetNodes()->Size(), &node_coords); 129 CeedVectorSetArray(node_coords, CEED_MEM_HOST, CEED_USE_POINTER, 130 mesh->GetNodes()->GetData()); 131 132 CeedVectorCreate(ceed, num_elem*nqpts*dim*(dim+1)/2, &qdata); 133 134 // Context data to be passed to the 'f_build_diff' Q-function. 135 build_ctx_data.dim = mesh->Dimension(); 136 build_ctx_data.space_dim = dim; 137 CeedQFunctionContextCreate(ceed, &build_ctx); 138 CeedQFunctionContextSetData(build_ctx, CEED_MEM_HOST, CEED_USE_POINTER, 139 sizeof(build_ctx_data), &build_ctx_data); 140 141 // Create the Q-function that builds the diff operator (i.e. computes its 142 // quadrature data) and set its context data. 143 CeedQFunctionCreateInterior(ceed, 1, f_build_diff, 144 f_build_diff_loc, &build_qfunc); 145 CeedQFunctionAddInput(build_qfunc, "dx", ncompx*dim, CEED_EVAL_GRAD); 146 CeedQFunctionAddInput(build_qfunc, "weights", 1, CEED_EVAL_WEIGHT); 147 CeedQFunctionAddOutput(build_qfunc, "qdata", dim*(dim+1)/2, CEED_EVAL_NONE); 148 CeedQFunctionSetContext(build_qfunc, build_ctx); 149 150 // Create the operator that builds the quadrature data for the diff operator. 151 CeedOperatorCreate(ceed, build_qfunc, CEED_QFUNCTION_NONE, 152 CEED_QFUNCTION_NONE, &build_oper); 153 CeedOperatorSetField(build_oper, "dx", mesh_restr, mesh_basis, 154 CEED_VECTOR_ACTIVE); 155 CeedOperatorSetField(build_oper, "weights", CEED_ELEMRESTRICTION_NONE, 156 mesh_basis, CEED_VECTOR_NONE); 157 CeedOperatorSetField(build_oper, "qdata", restr_i, CEED_BASIS_COLLOCATED, 158 CEED_VECTOR_ACTIVE); 159 160 // Compute the quadrature data for the diff operator. 161 CeedOperatorApply(build_oper, node_coords, qdata, 162 CEED_REQUEST_IMMEDIATE); 163 164 // Create the Q-function that defines the action of the diff operator. 165 CeedQFunctionCreateInterior(ceed, 1, f_apply_diff, 166 f_apply_diff_loc, &apply_qfunc); 167 CeedQFunctionAddInput(apply_qfunc, "u", dim, CEED_EVAL_GRAD); 168 CeedQFunctionAddInput(apply_qfunc, "qdata", dim*(dim+1)/2, CEED_EVAL_NONE); 169 CeedQFunctionAddOutput(apply_qfunc, "v", dim, CEED_EVAL_GRAD); 170 CeedQFunctionSetContext(apply_qfunc, build_ctx); 171 172 // Create the diff operator. 173 CeedOperatorCreate(ceed, apply_qfunc, CEED_QFUNCTION_NONE, 174 CEED_QFUNCTION_NONE, &oper); 175 CeedOperatorSetField(oper, "u", restr, basis, CEED_VECTOR_ACTIVE); 176 CeedOperatorSetField(oper, "qdata", restr_i, CEED_BASIS_COLLOCATED, qdata); 177 CeedOperatorSetField(oper, "v", restr, basis, CEED_VECTOR_ACTIVE); 178 179 CeedVectorCreate(ceed, fes->GetNDofs(), &u); 180 CeedVectorCreate(ceed, fes->GetNDofs(), &v); 181 } 182 183 /// Destructor 184 ~CeedDiffusionOperator() { 185 CeedVectorDestroy(&u); 186 CeedVectorDestroy(&v); 187 CeedVectorDestroy(&qdata); 188 CeedVectorDestroy(&node_coords); 189 CeedElemRestrictionDestroy(&restr); 190 CeedElemRestrictionDestroy(&mesh_restr); 191 CeedElemRestrictionDestroy(&restr_i); 192 CeedBasisDestroy(&basis); 193 CeedBasisDestroy(&mesh_basis); 194 CeedQFunctionDestroy(&build_qfunc); 195 CeedQFunctionContextDestroy(&build_ctx); 196 CeedOperatorDestroy(&build_oper); 197 CeedQFunctionDestroy(&apply_qfunc); 198 CeedOperatorDestroy(&oper); 199 } 200 201 /// Operator action 202 virtual void Mult(const mfem::Vector &x, mfem::Vector &y) const { 203 CeedVectorSetArray(u, CEED_MEM_HOST, CEED_USE_POINTER, x.GetData()); 204 CeedVectorSetArray(v, CEED_MEM_HOST, CEED_USE_POINTER, y.GetData()); 205 206 CeedOperatorApply(oper, u, v, CEED_REQUEST_IMMEDIATE); 207 CeedVectorSyncArray(v, CEED_MEM_HOST); 208 } 209 }; 210