1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stdint.h> 12 #include <string.h> 13 14 #include "ceed-opt.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, const CeedInt blk_size, CeedElemRestriction *blk_restr, 20 CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, 21 CeedInt Q) { 22 CeedInt num_comp, size, P; 23 CeedSize e_size, q_size; 24 Ceed ceed; 25 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 26 CeedBasis basis; 27 CeedElemRestriction r; 28 CeedOperatorField *op_fields; 29 CeedQFunctionField *qf_fields; 30 if (is_input) { 31 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 32 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 33 } else { 34 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 35 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 36 } 37 38 // Loop over fields 39 for (CeedInt i = 0; i < num_fields; i++) { 40 CeedEvalMode eval_mode; 41 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 42 43 if (eval_mode != CEED_EVAL_WEIGHT) { 44 Ceed ceed_rstr; 45 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &r)); 46 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed_rstr)); 47 CeedSize l_size; 48 CeedInt num_elem, elem_size, comp_stride; 49 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 50 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 51 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 52 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 53 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 54 55 CeedRestrictionType rstr_type; 56 CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 57 switch (rstr_type) { 58 case CEED_RESTRICTION_DEFAULT: { 59 const CeedInt *offsets = NULL; 60 CeedCallBackend(CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets)); 61 CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed_rstr, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, CEED_MEM_HOST, 62 CEED_COPY_VALUES, offsets, &blk_restr[i + start_e])); 63 CeedCallBackend(CeedElemRestrictionRestoreOffsets(r, &offsets)); 64 } break; 65 case CEED_RESTRICTION_ORIENTED: { 66 const CeedInt *offsets = NULL; 67 const bool *orients = NULL; 68 CeedCallBackend(CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets)); 69 CeedCallBackend(CeedElemRestrictionGetOrientations(r, CEED_MEM_HOST, &orients)); 70 CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed_rstr, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, 71 CEED_MEM_HOST, CEED_COPY_VALUES, offsets, orients, &blk_restr[i + start_e])); 72 CeedCallBackend(CeedElemRestrictionRestoreOffsets(r, &offsets)); 73 CeedCallBackend(CeedElemRestrictionRestoreOrientations(r, &orients)); 74 } break; 75 case CEED_RESTRICTION_CURL_ORIENTED: { 76 const CeedInt *offsets = NULL; 77 const CeedInt8 *curl_orients = NULL; 78 CeedCallBackend(CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets)); 79 CeedCallBackend(CeedElemRestrictionGetCurlOrientations(r, CEED_MEM_HOST, &curl_orients)); 80 CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed_rstr, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, 81 CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients, 82 &blk_restr[i + start_e])); 83 CeedCallBackend(CeedElemRestrictionRestoreOffsets(r, &offsets)); 84 CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(r, &curl_orients)); 85 } break; 86 case CEED_RESTRICTION_STRIDED: { 87 CeedInt strides[3]; 88 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 89 CeedCallBackend( 90 CeedElemRestrictionCreateBlockedStrided(ceed_rstr, num_elem, elem_size, blk_size, num_comp, l_size, strides, &blk_restr[i + start_e])); 91 } break; 92 } 93 CeedCallBackend(CeedElemRestrictionCreateVector(blk_restr[i + start_e], NULL, &e_vecs_full[i + start_e])); 94 } 95 96 switch (eval_mode) { 97 case CEED_EVAL_NONE: 98 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 99 e_size = (CeedSize)Q * size * blk_size; 100 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 101 q_size = (CeedSize)Q * size * blk_size; 102 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 103 break; 104 case CEED_EVAL_INTERP: 105 case CEED_EVAL_GRAD: 106 case CEED_EVAL_DIV: 107 case CEED_EVAL_CURL: 108 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 109 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 110 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 111 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 112 e_size = (CeedSize)P * num_comp * blk_size; 113 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 114 q_size = (CeedSize)Q * size * blk_size; 115 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 116 break; 117 case CEED_EVAL_WEIGHT: // Only on input fields 118 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 119 q_size = (CeedSize)Q * blk_size; 120 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 121 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 122 break; 123 } 124 if (is_input && !!e_vecs[i]) { 125 CeedCallBackend(CeedVectorSetArray(e_vecs[i], CEED_MEM_HOST, CEED_COPY_VALUES, NULL)); 126 } 127 } 128 return CEED_ERROR_SUCCESS; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Setup Operator 133 //------------------------------------------------------------------------------ 134 static int CeedOperatorSetup_Opt(CeedOperator op) { 135 bool is_setup_done; 136 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 137 if (is_setup_done) return CEED_ERROR_SUCCESS; 138 Ceed ceed; 139 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 140 Ceed_Opt *ceed_impl; 141 CeedCallBackend(CeedGetData(ceed, &ceed_impl)); 142 const CeedInt blk_size = ceed_impl->blk_size; 143 CeedOperator_Opt *impl; 144 CeedCallBackend(CeedOperatorGetData(op, &impl)); 145 CeedQFunction qf; 146 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 147 CeedInt Q, num_input_fields, num_output_fields; 148 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 149 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 150 CeedOperatorField *op_input_fields, *op_output_fields; 151 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 152 CeedQFunctionField *qf_input_fields, *qf_output_fields; 153 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 154 155 // Allocate 156 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr)); 157 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 158 159 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 160 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 161 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 162 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 163 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 164 165 impl->num_inputs = num_input_fields; 166 impl->num_outputs = num_output_fields; 167 168 // Set up infield and outfield pointer arrays 169 // Infields 170 CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, blk_size, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, 171 num_input_fields, Q)); 172 // Outfields 173 CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, blk_size, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, 174 num_input_fields, num_output_fields, Q)); 175 176 // Identity QFunctions 177 if (impl->is_identity_qf) { 178 CeedEvalMode in_mode, out_mode; 179 CeedQFunctionField *in_fields, *out_fields; 180 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 181 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 182 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 183 184 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 185 impl->is_identity_restr_op = true; 186 } else { 187 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 188 } 189 } 190 191 CeedCallBackend(CeedOperatorSetSetupDone(op)); 192 193 return CEED_ERROR_SUCCESS; 194 } 195 196 //------------------------------------------------------------------------------ 197 // Setup Input Fields 198 //------------------------------------------------------------------------------ 199 static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 200 CeedVector in_vec, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl, 201 CeedRequest *request) { 202 CeedEvalMode eval_mode; 203 CeedVector vec; 204 uint64_t state; 205 206 for (CeedInt i = 0; i < num_input_fields; i++) { 207 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 208 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 209 } else { 210 // Get input vector 211 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 212 if (vec != CEED_VECTOR_ACTIVE) { 213 // Restrict 214 CeedCallBackend(CeedVectorGetState(vec, &state)); 215 if (state != impl->input_states[i]) { 216 CeedCallBackend(CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 217 impl->input_states[i] = state; 218 } 219 // Get evec 220 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i])); 221 } else { 222 // Set Qvec for CEED_EVAL_NONE 223 if (eval_mode == CEED_EVAL_NONE) { 224 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_in[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i])); 225 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i])); 226 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_in[i], (const CeedScalar **)&e_data[i])); 227 } 228 } 229 } 230 } 231 return CEED_ERROR_SUCCESS; 232 } 233 234 //------------------------------------------------------------------------------ 235 // Input Basis Action 236 //------------------------------------------------------------------------------ 237 static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 238 CeedInt num_input_fields, CeedInt blk_size, CeedVector in_vec, bool skip_active, 239 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl, CeedRequest *request) { 240 CeedInt elem_size, size, num_comp; 241 CeedElemRestriction elem_restr; 242 CeedEvalMode eval_mode; 243 CeedBasis basis; 244 CeedVector vec; 245 246 for (CeedInt i = 0; i < num_input_fields; i++) { 247 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 248 // Skip active input 249 if (skip_active) { 250 if (vec == CEED_VECTOR_ACTIVE) continue; 251 } 252 253 CeedInt active_in = 0; 254 // Get elem_size, eval_mode, size 255 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr)); 256 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 257 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 258 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 259 // Restrict block active input 260 if (vec == CEED_VECTOR_ACTIVE) { 261 CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[i], e / blk_size, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 262 active_in = 1; 263 } 264 // Basis action 265 switch (eval_mode) { 266 case CEED_EVAL_NONE: 267 if (!active_in) { 268 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][e * Q * size])); 269 } 270 break; 271 case CEED_EVAL_INTERP: 272 case CEED_EVAL_GRAD: 273 case CEED_EVAL_DIV: 274 case CEED_EVAL_CURL: 275 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 276 if (!active_in) { 277 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 278 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][e * elem_size * num_comp])); 279 } 280 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 281 break; 282 case CEED_EVAL_WEIGHT: 283 break; // No action 284 } 285 } 286 return CEED_ERROR_SUCCESS; 287 } 288 289 //------------------------------------------------------------------------------ 290 // Output Basis Action 291 //------------------------------------------------------------------------------ 292 static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 293 CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 294 CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) { 295 CeedElemRestriction elem_restr; 296 CeedEvalMode eval_mode; 297 CeedBasis basis; 298 CeedVector vec; 299 300 for (CeedInt i = 0; i < num_output_fields; i++) { 301 // Get elem_size, eval_mode, size 302 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr)); 303 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 304 // Basis action 305 switch (eval_mode) { 306 case CEED_EVAL_NONE: 307 break; // No action 308 case CEED_EVAL_INTERP: 309 case CEED_EVAL_GRAD: 310 case CEED_EVAL_DIV: 311 case CEED_EVAL_CURL: 312 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 313 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 314 break; 315 // LCOV_EXCL_START 316 case CEED_EVAL_WEIGHT: { 317 Ceed ceed; 318 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 319 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 320 // LCOV_EXCL_STOP 321 } 322 } 323 // Restrict output block 324 // Get output vector 325 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 326 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 327 // Restrict 328 CeedCallBackend( 329 CeedElemRestrictionApplyBlock(impl->blk_restr[i + impl->num_inputs], e / blk_size, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 330 } 331 return CEED_ERROR_SUCCESS; 332 } 333 334 //------------------------------------------------------------------------------ 335 // Restore Input Vectors 336 //------------------------------------------------------------------------------ 337 static inline int CeedOperatorRestoreInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 338 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl) { 339 for (CeedInt i = 0; i < num_input_fields; i++) { 340 CeedEvalMode eval_mode; 341 CeedVector vec; 342 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 343 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 344 if (eval_mode != CEED_EVAL_WEIGHT && vec != CEED_VECTOR_ACTIVE) { 345 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data[i])); 346 } 347 } 348 return CEED_ERROR_SUCCESS; 349 } 350 351 //------------------------------------------------------------------------------ 352 // Operator Apply 353 //------------------------------------------------------------------------------ 354 static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 355 Ceed ceed; 356 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 357 Ceed_Opt *ceed_impl; 358 CeedCallBackend(CeedGetData(ceed, &ceed_impl)); 359 CeedInt blk_size = ceed_impl->blk_size; 360 CeedOperator_Opt *impl; 361 CeedCallBackend(CeedOperatorGetData(op, &impl)); 362 CeedInt Q, num_input_fields, num_output_fields, num_elem; 363 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 364 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 365 CeedInt num_blks = (num_elem / blk_size) + !!(num_elem % blk_size); 366 CeedQFunction qf; 367 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 368 CeedOperatorField *op_input_fields, *op_output_fields; 369 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 370 CeedQFunctionField *qf_input_fields, *qf_output_fields; 371 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 372 CeedEvalMode eval_mode; 373 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 374 375 // Setup 376 CeedCallBackend(CeedOperatorSetup_Opt(op)); 377 378 // Restriction only operator 379 if (impl->is_identity_restr_op) { 380 for (CeedInt b = 0; b < num_blks; b++) { 381 CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[0], b, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[0], request)); 382 CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[1], b, CEED_TRANSPOSE, impl->e_vecs_in[0], out_vec, request)); 383 } 384 return CEED_ERROR_SUCCESS; 385 } 386 387 // Input Evecs and Restriction 388 CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, in_vec, e_data, impl, request)); 389 390 // Output Lvecs, Evecs, and Qvecs 391 for (CeedInt i = 0; i < num_output_fields; i++) { 392 // Set Qvec if needed 393 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 394 if (eval_mode == CEED_EVAL_NONE) { 395 // Set qvec to single block evec 396 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_HOST, &e_data[i + num_input_fields])); 397 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i + num_input_fields])); 398 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[i], &e_data[i + num_input_fields])); 399 } 400 } 401 402 // Loop through elements 403 for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) { 404 // Input basis apply 405 CeedCallBackend( 406 CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, in_vec, false, e_data, impl, request)); 407 408 // Q function 409 if (!impl->is_identity_qf) { 410 CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out)); 411 } 412 413 // Output basis apply and restrict 414 CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, blk_size, num_input_fields, num_output_fields, op, out_vec, 415 impl, request)); 416 } 417 418 // Restore input arrays 419 CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl)); 420 421 return CEED_ERROR_SUCCESS; 422 } 423 424 //------------------------------------------------------------------------------ 425 // Core code for linear QFunction assembly 426 //------------------------------------------------------------------------------ 427 static inline int CeedOperatorLinearAssembleQFunctionCore_Opt(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 428 CeedRequest *request) { 429 Ceed ceed; 430 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 431 Ceed_Opt *ceed_impl; 432 CeedCallBackend(CeedGetData(ceed, &ceed_impl)); 433 const CeedInt blk_size = ceed_impl->blk_size; 434 CeedSize q_size; 435 CeedOperator_Opt *impl; 436 CeedCallBackend(CeedOperatorGetData(op, &impl)); 437 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 438 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 439 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 440 CeedInt num_blks = (num_elem / blk_size) + !!(num_elem % blk_size); 441 CeedQFunction qf; 442 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 443 CeedOperatorField *op_input_fields, *op_output_fields; 444 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 445 CeedQFunctionField *qf_input_fields, *qf_output_fields; 446 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 447 CeedVector vec, l_vec = impl->qf_l_vec; 448 CeedInt num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 449 CeedVector *active_in = impl->qf_active_in; 450 CeedScalar *a, *tmp; 451 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 452 453 // Setup 454 CeedCallBackend(CeedOperatorSetup_Opt(op)); 455 456 // Check for identity 457 CeedCheck(!impl->is_identity_qf, ceed, CEED_ERROR_BACKEND, "Assembling identity qfunctions not supported"); 458 459 // Input Evecs and Restriction 460 CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, NULL, e_data, impl, request)); 461 462 // Count number of active input fields 463 if (!num_active_in) { 464 for (CeedInt i = 0; i < num_input_fields; i++) { 465 // Get input vector 466 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 467 // Check if active input 468 if (vec == CEED_VECTOR_ACTIVE) { 469 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 470 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 471 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp)); 472 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 473 for (CeedInt field = 0; field < size; field++) { 474 q_size = (CeedSize)Q * blk_size; 475 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 476 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &tmp[field * Q * blk_size])); 477 } 478 num_active_in += size; 479 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp)); 480 } 481 } 482 impl->num_active_in = num_active_in; 483 impl->qf_active_in = active_in; 484 } 485 486 // Count number of active output fields 487 if (!num_active_out) { 488 for (CeedInt i = 0; i < num_output_fields; i++) { 489 // Get output vector 490 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 491 // Check if active output 492 if (vec == CEED_VECTOR_ACTIVE) { 493 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 494 num_active_out += size; 495 } 496 } 497 impl->num_active_out = num_active_out; 498 } 499 500 // Check sizes 501 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 502 503 // Setup l_vec 504 if (!l_vec) { 505 CeedSize l_size = (CeedSize)blk_size * Q * num_active_in * num_active_out; 506 CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec)); 507 CeedCallBackend(CeedVectorSetValue(l_vec, 0.0)); 508 impl->qf_l_vec = l_vec; 509 } 510 511 // Build objects if needed 512 CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; 513 if (build_objects) { 514 // Create output restriction 515 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, num_active_in * num_active_out, num_active_in * num_active_out * num_elem * Q, 516 strides, rstr)); 517 // Create assembled vector 518 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 519 CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled)); 520 } 521 522 // Output blocked restriction 523 CeedElemRestriction blk_rstr = impl->qf_blk_rstr; 524 if (!blk_rstr) { 525 CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, num_active_in * num_active_out, 526 num_active_in * num_active_out * num_elem * Q, strides, &blk_rstr)); 527 impl->qf_blk_rstr = blk_rstr; 528 } 529 530 // Loop through elements 531 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 532 for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) { 533 CeedCallBackend(CeedVectorGetArray(l_vec, CEED_MEM_HOST, &a)); 534 535 // Input basis apply 536 CeedCallBackend( 537 CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, NULL, true, e_data, impl, request)); 538 539 // Assemble QFunction 540 for (CeedInt in = 0; in < num_active_in; in++) { 541 // Set Inputs 542 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 543 if (num_active_in > 1) { 544 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 545 } 546 // Set Outputs 547 for (CeedInt out = 0; out < num_output_fields; out++) { 548 // Get output vector 549 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 550 // Check if active output 551 if (vec == CEED_VECTOR_ACTIVE) { 552 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, a)); 553 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 554 a += size * Q * blk_size; // Advance the pointer by the size of the output 555 } 556 } 557 // Apply QFunction 558 CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out)); 559 } 560 561 // Assemble into assembled vector 562 CeedCallBackend(CeedVectorRestoreArray(l_vec, &a)); 563 CeedCallBackend(CeedElemRestrictionApplyBlock(blk_rstr, e / blk_size, CEED_TRANSPOSE, l_vec, *assembled, request)); 564 } 565 566 // Un-set output Qvecs to prevent accidental overwrite of Assembled 567 for (CeedInt out = 0; out < num_output_fields; out++) { 568 // Get output vector 569 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 570 // Check if active output 571 if (vec == CEED_VECTOR_ACTIVE) { 572 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, NULL)); 573 } 574 } 575 576 // Restore input arrays 577 CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl)); 578 579 return CEED_ERROR_SUCCESS; 580 } 581 582 //------------------------------------------------------------------------------ 583 // Assemble Linear QFunction 584 //------------------------------------------------------------------------------ 585 static int CeedOperatorLinearAssembleQFunction_Opt(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 586 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, true, assembled, rstr, request); 587 } 588 589 //------------------------------------------------------------------------------ 590 // Update Assembled Linear QFunction 591 //------------------------------------------------------------------------------ 592 static int CeedOperatorLinearAssembleQFunctionUpdate_Opt(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 593 return CeedOperatorLinearAssembleQFunctionCore_Opt(op, false, &assembled, &rstr, request); 594 } 595 596 //------------------------------------------------------------------------------ 597 // Operator Destroy 598 //------------------------------------------------------------------------------ 599 static int CeedOperatorDestroy_Opt(CeedOperator op) { 600 CeedOperator_Opt *impl; 601 CeedCallBackend(CeedOperatorGetData(op, &impl)); 602 603 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 604 CeedCallBackend(CeedElemRestrictionDestroy(&impl->blk_restr[i])); 605 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 606 } 607 CeedCallBackend(CeedFree(&impl->blk_restr)); 608 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 609 CeedCallBackend(CeedFree(&impl->input_states)); 610 611 for (CeedInt i = 0; i < impl->num_inputs; i++) { 612 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 613 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 614 } 615 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 616 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 617 618 for (CeedInt i = 0; i < impl->num_outputs; i++) { 619 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 620 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 621 } 622 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 623 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 624 625 // QFunction assembly data 626 for (CeedInt i = 0; i < impl->num_active_in; i++) { 627 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 628 } 629 CeedCallBackend(CeedFree(&impl->qf_active_in)); 630 CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec)); 631 CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_blk_rstr)); 632 633 CeedCallBackend(CeedFree(&impl)); 634 return CEED_ERROR_SUCCESS; 635 } 636 637 //------------------------------------------------------------------------------ 638 // Operator Create 639 //------------------------------------------------------------------------------ 640 int CeedOperatorCreate_Opt(CeedOperator op) { 641 Ceed ceed; 642 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 643 Ceed_Opt *ceed_impl; 644 CeedCallBackend(CeedGetData(ceed, &ceed_impl)); 645 CeedInt blk_size = ceed_impl->blk_size; 646 CeedOperator_Opt *impl; 647 648 CeedCallBackend(CeedCalloc(1, &impl)); 649 CeedCallBackend(CeedOperatorSetData(op, impl)); 650 651 CeedCheck(blk_size == 1 || blk_size == 8, ceed, CEED_ERROR_BACKEND, "Opt backend cannot use blocksize: %" CeedInt_FMT, blk_size); 652 653 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Opt)); 654 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Opt)); 655 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Opt)); 656 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Opt)); 657 return CEED_ERROR_SUCCESS; 658 } 659 660 //------------------------------------------------------------------------------ 661