1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-ref.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedInt *e_data_out_indices, 20 bool *apply_add_basis, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, 21 CeedInt num_fields, CeedInt Q) { 22 Ceed ceed; 23 CeedSize e_size, q_size; 24 CeedInt num_comp, size, P; 25 CeedQFunctionField *qf_fields; 26 CeedOperatorField *op_fields; 27 28 { 29 Ceed ceed_parent; 30 31 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 32 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 33 CeedCallBackend(CeedReferenceCopy(ceed_parent, &ceed)); 34 CeedCallBackend(CeedDestroy(&ceed_parent)); 35 } 36 if (is_input) { 37 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 38 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 39 } else { 40 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 41 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 42 } 43 44 // Loop over fields 45 for (CeedInt i = 0; i < num_fields; i++) { 46 CeedEvalMode eval_mode; 47 CeedElemRestriction elem_rstr; 48 CeedBasis basis; 49 50 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 51 if (eval_mode != CEED_EVAL_WEIGHT) { 52 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 53 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 54 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 55 } 56 57 switch (eval_mode) { 58 case CEED_EVAL_NONE: 59 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 60 q_size = (CeedSize)Q * size; 61 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 62 break; 63 case CEED_EVAL_INTERP: 64 case CEED_EVAL_GRAD: 65 case CEED_EVAL_DIV: 66 case CEED_EVAL_CURL: 67 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 68 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 69 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 70 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 71 e_size = (CeedSize)P * num_comp; 72 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 73 q_size = (CeedSize)Q * size; 74 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 75 CeedCallBackend(CeedBasisDestroy(&basis)); 76 break; 77 case CEED_EVAL_WEIGHT: // Only on input fields 78 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 79 q_size = (CeedSize)Q; 80 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 81 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 82 CeedCallBackend(CeedBasisDestroy(&basis)); 83 break; 84 } 85 } 86 // Drop duplicate restrictions 87 if (is_input) { 88 for (CeedInt i = 0; i < num_fields; i++) { 89 CeedVector vec_i; 90 CeedElemRestriction rstr_i; 91 92 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 93 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 94 for (CeedInt j = i + 1; j < num_fields; j++) { 95 CeedVector vec_j; 96 CeedElemRestriction rstr_j; 97 98 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 99 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 100 if (vec_i == vec_j && rstr_i == rstr_j) { 101 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 102 CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); 103 skip_rstr[j] = true; 104 } 105 CeedCallBackend(CeedVectorDestroy(&vec_j)); 106 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 107 } 108 CeedCallBackend(CeedVectorDestroy(&vec_i)); 109 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 110 } 111 } else { 112 for (CeedInt i = num_fields - 1; i >= 0; i--) { 113 CeedVector vec_i; 114 CeedElemRestriction rstr_i; 115 116 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 117 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 118 for (CeedInt j = i - 1; j >= 0; j--) { 119 CeedVector vec_j; 120 CeedElemRestriction rstr_j; 121 122 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 123 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 124 if (vec_i == vec_j && rstr_i == rstr_j) { 125 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 126 CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); 127 skip_rstr[j] = true; 128 apply_add_basis[i] = true; 129 e_data_out_indices[j] = i; 130 } 131 CeedCallBackend(CeedVectorDestroy(&vec_j)); 132 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 133 } 134 CeedCallBackend(CeedVectorDestroy(&vec_i)); 135 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 136 } 137 } 138 CeedCallBackend(CeedDestroy(&ceed)); 139 return CEED_ERROR_SUCCESS; 140 } 141 142 //------------------------------------------------------------------------------ 143 // Setup Operator 144 //------------------------------------------------------------------------------/* 145 static int CeedOperatorSetup_Ref(CeedOperator op) { 146 bool is_setup_done; 147 CeedInt Q, num_input_fields, num_output_fields; 148 CeedQFunctionField *qf_input_fields, *qf_output_fields; 149 CeedQFunction qf; 150 CeedOperatorField *op_input_fields, *op_output_fields; 151 CeedOperator_Ref *impl; 152 153 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 154 if (is_setup_done) return CEED_ERROR_SUCCESS; 155 156 CeedCallBackend(CeedOperatorGetData(op, &impl)); 157 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 158 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 159 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 160 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 161 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 162 163 // Allocate 164 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 165 166 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 167 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 168 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_data_out_indices)); 169 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 170 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 171 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 172 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 173 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 174 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 175 176 impl->num_inputs = num_input_fields; 177 impl->num_outputs = num_output_fields; 178 179 // Set up infield and outfield e_vecs and q_vecs 180 // Infields 181 CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, true, impl->skip_rstr_in, NULL, NULL, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, 182 num_input_fields, Q)); 183 // Outfields 184 CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, false, impl->skip_rstr_out, impl->e_data_out_indices, impl->apply_add_basis_out, 185 impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); 186 187 // Identity QFunctions 188 if (impl->is_identity_qf) { 189 CeedEvalMode in_mode, out_mode; 190 CeedQFunctionField *in_fields, *out_fields; 191 192 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 193 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 194 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 195 196 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 197 impl->is_identity_rstr_op = true; 198 } else { 199 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 200 } 201 } 202 203 CeedCallBackend(CeedOperatorSetSetupDone(op)); 204 CeedCallBackend(CeedQFunctionDestroy(&qf)); 205 return CEED_ERROR_SUCCESS; 206 } 207 208 //------------------------------------------------------------------------------ 209 // Setup Operator Inputs 210 //------------------------------------------------------------------------------ 211 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 212 CeedVector in_vec, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 213 CeedOperator_Ref *impl, CeedRequest *request) { 214 for (CeedInt i = 0; i < num_input_fields; i++) { 215 bool is_active; 216 uint64_t state; 217 CeedEvalMode eval_mode; 218 CeedVector vec; 219 220 // Get input vector 221 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 222 is_active = vec == CEED_VECTOR_ACTIVE; 223 if (is_active) { 224 if (skip_active) continue; 225 else vec = in_vec; 226 } 227 228 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 229 // Restrict and Evec 230 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 231 } else { 232 // Restrict 233 CeedCallBackend(CeedVectorGetState(vec, &state)); 234 // Skip restriction if input is unchanged 235 if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) { 236 CeedElemRestriction elem_rstr; 237 238 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 239 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 240 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 241 } 242 impl->input_states[i] = state; 243 // Get evec 244 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 245 } 246 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec)); 247 } 248 return CEED_ERROR_SUCCESS; 249 } 250 251 //------------------------------------------------------------------------------ 252 // Input Basis Action 253 //------------------------------------------------------------------------------ 254 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 255 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 256 CeedOperator_Ref *impl) { 257 for (CeedInt i = 0; i < num_input_fields; i++) { 258 CeedInt elem_size, size, num_comp; 259 CeedEvalMode eval_mode; 260 CeedElemRestriction elem_rstr; 261 CeedBasis basis; 262 263 // Skip active input 264 if (skip_active) { 265 bool is_active; 266 CeedVector vec; 267 268 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 269 is_active = vec == CEED_VECTOR_ACTIVE; 270 CeedCallBackend(CeedVectorDestroy(&vec)); 271 if (is_active) continue; 272 } 273 // Get elem_size, eval_mode, size 274 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 275 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 276 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 277 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 278 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 279 // Basis action 280 switch (eval_mode) { 281 case CEED_EVAL_NONE: 282 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * Q * size])); 283 break; 284 case CEED_EVAL_INTERP: 285 case CEED_EVAL_GRAD: 286 case CEED_EVAL_DIV: 287 case CEED_EVAL_CURL: 288 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 289 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 290 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * elem_size * num_comp])); 291 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 292 CeedCallBackend(CeedBasisDestroy(&basis)); 293 break; 294 case CEED_EVAL_WEIGHT: 295 break; // No action 296 } 297 } 298 return CEED_ERROR_SUCCESS; 299 } 300 301 //------------------------------------------------------------------------------ 302 // Output Basis Action 303 //------------------------------------------------------------------------------ 304 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 305 CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis, CeedOperator op, 306 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 307 for (CeedInt i = 0; i < num_output_fields; i++) { 308 CeedInt elem_size, num_comp; 309 CeedEvalMode eval_mode; 310 CeedElemRestriction elem_rstr; 311 CeedBasis basis; 312 313 // Get elem_size, eval_mode 314 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 315 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 316 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 317 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 318 // Basis action 319 switch (eval_mode) { 320 case CEED_EVAL_NONE: 321 break; // No action 322 case CEED_EVAL_INTERP: 323 case CEED_EVAL_GRAD: 324 case CEED_EVAL_DIV: 325 case CEED_EVAL_CURL: 326 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 327 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 328 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, 329 &e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp])); 330 if (apply_add_basis[i]) { 331 CeedCallBackend(CeedBasisApplyAdd(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 332 } else { 333 CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 334 } 335 CeedCallBackend(CeedBasisDestroy(&basis)); 336 break; 337 // LCOV_EXCL_START 338 case CEED_EVAL_WEIGHT: { 339 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 340 // LCOV_EXCL_STOP 341 } 342 } 343 } 344 return CEED_ERROR_SUCCESS; 345 } 346 347 //------------------------------------------------------------------------------ 348 // Restore Input Vectors 349 //------------------------------------------------------------------------------ 350 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 351 const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 352 for (CeedInt i = 0; i < num_input_fields; i++) { 353 CeedEvalMode eval_mode; 354 355 // Skip active inputs 356 if (skip_active) { 357 bool is_active; 358 CeedVector vec; 359 360 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 361 is_active = vec == CEED_VECTOR_ACTIVE; 362 CeedCallBackend(CeedVectorDestroy(&vec)); 363 if (is_active) continue; 364 } 365 // Restore input 366 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 367 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 368 } else { 369 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 370 } 371 } 372 return CEED_ERROR_SUCCESS; 373 } 374 375 //------------------------------------------------------------------------------ 376 // Operator Apply 377 //------------------------------------------------------------------------------ 378 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 379 CeedInt Q, num_elem, num_input_fields, num_output_fields, size; 380 CeedEvalMode eval_mode; 381 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 382 CeedQFunctionField *qf_input_fields, *qf_output_fields; 383 CeedQFunction qf; 384 CeedOperatorField *op_input_fields, *op_output_fields; 385 CeedOperator_Ref *impl; 386 387 // Setup 388 CeedCallBackend(CeedOperatorSetup_Ref(op)); 389 390 CeedCallBackend(CeedOperatorGetData(op, &impl)); 391 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 392 393 // Restriction only operator 394 if (impl->is_identity_rstr_op) { 395 CeedElemRestriction elem_rstr; 396 397 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_rstr)); 398 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 399 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 400 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_rstr)); 401 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 402 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 403 return CEED_ERROR_SUCCESS; 404 } 405 406 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 407 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 408 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 409 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 410 411 // Input Evecs and Restriction 412 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 413 414 // Output Evecs 415 for (CeedInt i = num_output_fields - 1; i >= 0; i--) { 416 if (impl->skip_rstr_out[i]) { 417 e_data_full[i + num_input_fields] = e_data_full[impl->e_data_out_indices[i] + num_input_fields]; 418 } else { 419 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 420 } 421 } 422 423 // Loop through elements 424 for (CeedInt e = 0; e < num_elem; e++) { 425 // Output pointers 426 for (CeedInt i = 0; i < num_output_fields; i++) { 427 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 428 if (eval_mode == CEED_EVAL_NONE) { 429 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 430 CeedCallBackend( 431 CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][(CeedSize)e * Q * size])); 432 } 433 } 434 435 // Input basis apply 436 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, false, e_data_full, impl)); 437 438 // Q function 439 if (!impl->is_identity_qf) { 440 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 441 } 442 443 // Output basis apply 444 CeedCallBackend(CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, 445 impl->apply_add_basis_out, op, e_data_full, impl)); 446 } 447 448 // Output restriction 449 for (CeedInt i = 0; i < num_output_fields; i++) { 450 bool is_active; 451 CeedVector vec; 452 CeedElemRestriction elem_rstr; 453 454 if (impl->skip_rstr_out[i]) continue; 455 // Restore Evec 456 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 457 // Get output vector 458 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 459 // Active 460 is_active = vec == CEED_VECTOR_ACTIVE; 461 if (is_active) vec = out_vec; 462 // Restrict 463 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 464 CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 465 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec)); 466 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 467 } 468 469 // Restore input arrays 470 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 471 CeedCallBackend(CeedQFunctionDestroy(&qf)); 472 return CEED_ERROR_SUCCESS; 473 } 474 475 //------------------------------------------------------------------------------ 476 // Core code for assembling linear QFunction 477 //------------------------------------------------------------------------------ 478 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 479 CeedRequest *request) { 480 Ceed ceed_parent; 481 CeedInt qf_size_in, qf_size_out, Q, num_elem, num_input_fields, num_output_fields; 482 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 483 CeedQFunctionField *qf_input_fields, *qf_output_fields; 484 CeedQFunction qf; 485 CeedOperatorField *op_input_fields, *op_output_fields; 486 CeedOperator_Ref *impl; 487 488 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 489 CeedCallBackend(CeedOperatorGetData(op, &impl)); 490 qf_size_in = impl->qf_size_in; 491 qf_size_out = impl->qf_size_out; 492 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 493 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 494 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 495 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 496 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 497 498 // Setup 499 CeedCallBackend(CeedOperatorSetup_Ref(op)); 500 501 // Check for restriction only operator 502 CeedCheck(!impl->is_identity_rstr_op, CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 503 504 // Input Evecs and Restriction 505 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 506 507 // Count number of active input fields 508 if (qf_size_in == 0) { 509 for (CeedInt i = 0; i < num_input_fields; i++) { 510 CeedInt field_size; 511 CeedVector vec; 512 513 // Get input vector 514 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 515 // Check if active input 516 if (vec == CEED_VECTOR_ACTIVE) { 517 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 518 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 519 qf_size_in += field_size; 520 } 521 CeedCallBackend(CeedVectorDestroy(&vec)); 522 } 523 CeedCheck(qf_size_in > 0, CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 524 impl->qf_size_in = qf_size_in; 525 } 526 527 // Count number of active output fields 528 if (qf_size_out == 0) { 529 for (CeedInt i = 0; i < num_output_fields; i++) { 530 CeedInt field_size; 531 CeedVector vec; 532 533 // Get output vector 534 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 535 // Check if active output 536 if (vec == CEED_VECTOR_ACTIVE) { 537 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 538 qf_size_out += field_size; 539 } 540 CeedCallBackend(CeedVectorDestroy(&vec)); 541 } 542 CeedCheck(qf_size_out > 0, CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 543 impl->qf_size_out = qf_size_out; 544 } 545 546 // Build objects if needed 547 if (build_objects) { 548 const CeedSize l_size = (CeedSize)num_elem * Q * qf_size_in * qf_size_out; 549 CeedInt strides[3] = {1, Q, qf_size_in * qf_size_out * Q}; /* *NOPAD* */ 550 551 // Create output restriction 552 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, qf_size_in * qf_size_out, 553 (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides, rstr)); 554 // Create assembled vector 555 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 556 } 557 // Clear output vector 558 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 559 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 560 561 // Loop through elements 562 for (CeedInt e = 0; e < num_elem; e++) { 563 // Input basis apply 564 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, true, e_data_full, impl)); 565 566 // Assemble QFunction 567 568 for (CeedInt i = 0; i < num_input_fields; i++) { 569 bool is_active; 570 CeedInt field_size; 571 CeedVector vec; 572 573 // Set Inputs 574 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 575 is_active = vec == CEED_VECTOR_ACTIVE; 576 CeedCallBackend(CeedVectorDestroy(&vec)); 577 if (!is_active) continue; 578 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 579 for (CeedInt field = 0; field < field_size; field++) { 580 // Set current portion of input to 1.0 581 { 582 CeedScalar *array; 583 584 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array)); 585 for (CeedInt j = 0; j < Q; j++) array[field * Q + j] = 1.0; 586 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array)); 587 } 588 589 if (!impl->is_identity_qf) { 590 // Set Outputs 591 for (CeedInt out = 0; out < num_output_fields; out++) { 592 CeedVector vec; 593 594 // Get output vector 595 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 596 // Check if active output 597 if (vec == CEED_VECTOR_ACTIVE) { 598 CeedInt field_size; 599 600 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 601 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size)); 602 assembled_array += field_size * Q; // Advance the pointer by the size of the output 603 } 604 CeedCallBackend(CeedVectorDestroy(&vec)); 605 } 606 // Apply QFunction 607 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 608 } else { 609 CeedInt field_size; 610 const CeedScalar *array; 611 612 // Copy Identity Outputs 613 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size)); 614 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &array)); 615 for (CeedInt j = 0; j < field_size * Q; j++) assembled_array[j] = array[j]; 616 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &array)); 617 assembled_array += field_size * Q; 618 } 619 // Reset input to 0.0 620 { 621 CeedScalar *array; 622 623 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array)); 624 for (CeedInt j = 0; j < Q; j++) array[field * Q + j] = 0.0; 625 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array)); 626 } 627 } 628 } 629 } 630 631 // Un-set output Qvecs to prevent accidental overwrite of Assembled 632 if (!impl->is_identity_qf) { 633 for (CeedInt out = 0; out < num_output_fields; out++) { 634 CeedVector vec; 635 636 // Get output vector 637 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 638 // Check if active output 639 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 640 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 641 } 642 CeedCallBackend(CeedVectorDestroy(&vec)); 643 } 644 } 645 646 // Restore input arrays 647 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 648 649 // Restore output 650 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 651 CeedCallBackend(CeedDestroy(&ceed_parent)); 652 CeedCallBackend(CeedQFunctionDestroy(&qf)); 653 return CEED_ERROR_SUCCESS; 654 } 655 656 //------------------------------------------------------------------------------ 657 // Assemble Linear QFunction 658 //------------------------------------------------------------------------------ 659 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 660 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr, request); 661 } 662 663 //------------------------------------------------------------------------------ 664 // Update Assembled Linear QFunction 665 //------------------------------------------------------------------------------ 666 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 667 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled, &rstr, request); 668 } 669 670 //------------------------------------------------------------------------------ 671 // Setup Input/Output Fields 672 //------------------------------------------------------------------------------ 673 static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, bool *apply_add_basis, 674 CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, 675 CeedInt Q) { 676 Ceed ceed; 677 CeedSize e_size, q_size; 678 CeedInt max_num_points, num_comp, size, P; 679 CeedQFunctionField *qf_fields; 680 CeedOperatorField *op_fields; 681 682 { 683 Ceed ceed_parent; 684 685 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 686 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 687 CeedCallBackend(CeedReferenceCopy(ceed_parent, &ceed)); 688 CeedCallBackend(CeedDestroy(&ceed_parent)); 689 } 690 if (is_input) { 691 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 692 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 693 } else { 694 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 695 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 696 } 697 698 // Get max number of points 699 { 700 CeedInt dim; 701 CeedElemRestriction rstr_points = NULL; 702 CeedOperator_Ref *impl; 703 704 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 705 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 706 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_points, &dim)); 707 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 708 CeedCallBackend(CeedOperatorGetData(op, &impl)); 709 if (is_input) { 710 CeedCallBackend(CeedVectorCreate(ceed, dim * max_num_points, &impl->point_coords_elem)); 711 CeedCallBackend(CeedVectorSetValue(impl->point_coords_elem, 0.0)); 712 } 713 } 714 715 // Loop over fields 716 for (CeedInt i = 0; i < num_fields; i++) { 717 CeedEvalMode eval_mode; 718 CeedBasis basis; 719 720 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 721 if (eval_mode != CEED_EVAL_WEIGHT) { 722 CeedElemRestriction elem_rstr; 723 724 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); 725 CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs_full[i + start_e])); 726 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 727 CeedCallBackend(CeedVectorSetValue(e_vecs_full[i + start_e], 0.0)); 728 } 729 730 switch (eval_mode) { 731 case CEED_EVAL_NONE: { 732 CeedVector vec; 733 734 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 735 e_size = (CeedSize)max_num_points * size; 736 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 737 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec)); 738 if (vec == CEED_VECTOR_ACTIVE || !is_input) { 739 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &q_vecs[i])); 740 } else { 741 q_size = (CeedSize)max_num_points * size; 742 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 743 } 744 CeedCallBackend(CeedVectorDestroy(&vec)); 745 break; 746 } 747 case CEED_EVAL_INTERP: 748 case CEED_EVAL_GRAD: 749 case CEED_EVAL_DIV: 750 case CEED_EVAL_CURL: 751 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 752 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 753 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 754 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 755 e_size = (CeedSize)P * num_comp; 756 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 757 q_size = (CeedSize)max_num_points * size; 758 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 759 CeedCallBackend(CeedBasisDestroy(&basis)); 760 break; 761 case CEED_EVAL_WEIGHT: // Only on input fields 762 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 763 q_size = (CeedSize)max_num_points; 764 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 765 CeedCallBackend( 766 CeedBasisApplyAtPoints(basis, 1, &max_num_points, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, CEED_VECTOR_NONE, q_vecs[i])); 767 CeedCallBackend(CeedBasisDestroy(&basis)); 768 break; 769 } 770 // Initialize full arrays for E-vectors and Q-vectors 771 if (e_vecs[i]) CeedCallBackend(CeedVectorSetValue(e_vecs[i], 0.0)); 772 if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorSetValue(q_vecs[i], 0.0)); 773 } 774 // Drop duplicate restrictions 775 if (is_input) { 776 for (CeedInt i = 0; i < num_fields; i++) { 777 CeedVector vec_i; 778 CeedElemRestriction rstr_i; 779 780 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 781 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 782 for (CeedInt j = i + 1; j < num_fields; j++) { 783 CeedVector vec_j; 784 CeedElemRestriction rstr_j; 785 786 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 787 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 788 if (vec_i == vec_j && rstr_i == rstr_j) { 789 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 790 CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); 791 skip_rstr[j] = true; 792 } 793 CeedCallBackend(CeedVectorDestroy(&vec_j)); 794 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 795 } 796 CeedCallBackend(CeedVectorDestroy(&vec_i)); 797 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 798 } 799 } else { 800 for (CeedInt i = num_fields - 1; i >= 0; i--) { 801 CeedVector vec_i; 802 CeedElemRestriction rstr_i; 803 804 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); 805 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); 806 for (CeedInt j = i - 1; j >= 0; j--) { 807 CeedVector vec_j; 808 CeedElemRestriction rstr_j; 809 810 CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); 811 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); 812 if (vec_i == vec_j && rstr_i == rstr_j) { 813 CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); 814 CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); 815 skip_rstr[j] = true; 816 apply_add_basis[i] = true; 817 } 818 CeedCallBackend(CeedVectorDestroy(&vec_j)); 819 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j)); 820 } 821 CeedCallBackend(CeedVectorDestroy(&vec_i)); 822 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i)); 823 } 824 } 825 CeedCallBackend(CeedDestroy(&ceed)); 826 return CEED_ERROR_SUCCESS; 827 } 828 829 //------------------------------------------------------------------------------ 830 // Setup Operator 831 //------------------------------------------------------------------------------ 832 static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) { 833 bool is_setup_done; 834 CeedInt Q, num_input_fields, num_output_fields; 835 CeedQFunctionField *qf_input_fields, *qf_output_fields; 836 CeedQFunction qf; 837 CeedOperatorField *op_input_fields, *op_output_fields; 838 CeedOperator_Ref *impl; 839 840 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 841 if (is_setup_done) return CEED_ERROR_SUCCESS; 842 843 CeedCallBackend(CeedOperatorGetData(op, &impl)); 844 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 845 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 846 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 847 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 848 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 849 850 // Allocate 851 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 852 853 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); 854 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); 855 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); 856 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 857 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 858 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 859 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 860 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 861 862 impl->num_inputs = num_input_fields; 863 impl->num_outputs = num_output_fields; 864 865 // Set up infield and outfield pointer arrays 866 // Infields 867 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->skip_rstr_in, NULL, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, 868 num_input_fields, Q)); 869 // Outfields 870 CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_full, 871 impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); 872 873 // Identity QFunctions 874 if (impl->is_identity_qf) { 875 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 876 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->e_vecs_out[0])); 877 } 878 879 CeedCallBackend(CeedOperatorSetSetupDone(op)); 880 CeedCallBackend(CeedQFunctionDestroy(&qf)); 881 return CEED_ERROR_SUCCESS; 882 } 883 884 //------------------------------------------------------------------------------ 885 // Input Basis Action 886 //------------------------------------------------------------------------------ 887 static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_input_fields, 888 CeedOperatorField *op_input_fields, CeedInt num_input_fields, CeedVector in_vec, 889 CeedVector point_coords_elem, bool skip_active, bool skip_passive, 890 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl, CeedRequest *request) { 891 for (CeedInt i = 0; i < num_input_fields; i++) { 892 bool is_active; 893 CeedInt elem_size, size, num_comp; 894 CeedRestrictionType rstr_type; 895 CeedEvalMode eval_mode; 896 CeedVector vec; 897 CeedElemRestriction elem_rstr; 898 CeedBasis basis; 899 900 // Skip active input 901 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 902 is_active = vec == CEED_VECTOR_ACTIVE; 903 CeedCallBackend(CeedVectorDestroy(&vec)); 904 if (skip_active && is_active) continue; 905 if (skip_passive && !is_active) continue; 906 907 // Get elem_size, eval_mode, size 908 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 909 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 910 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 911 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 912 // Restrict block active input 913 // When skipping passive inputs, we're doing assembly and should not restrict 914 if (is_active && !impl->skip_rstr_in[i] && !skip_passive) { 915 if (rstr_type == CEED_RESTRICTION_POINTS) { 916 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 917 } else { 918 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request)); 919 } 920 } 921 // Basis action 922 switch (eval_mode) { 923 case CEED_EVAL_NONE: 924 if (!is_active) { 925 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][num_points_offset * size])); 926 } 927 break; 928 // Note - these basis eval modes require FEM fields 929 case CEED_EVAL_INTERP: 930 case CEED_EVAL_GRAD: 931 case CEED_EVAL_DIV: 932 case CEED_EVAL_CURL: 933 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 934 if (!is_active) { 935 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 936 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 937 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][(CeedSize)e * elem_size * num_comp])); 938 } 939 CeedCallBackend( 940 CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_NOTRANSPOSE, eval_mode, point_coords_elem, impl->e_vecs_in[i], impl->q_vecs_in[i])); 941 CeedCallBackend(CeedBasisDestroy(&basis)); 942 break; 943 case CEED_EVAL_WEIGHT: 944 break; // No action 945 } 946 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 947 } 948 return CEED_ERROR_SUCCESS; 949 } 950 951 //------------------------------------------------------------------------------ 952 // Output Basis Action 953 //------------------------------------------------------------------------------ 954 static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_output_fields, 955 CeedOperatorField *op_output_fields, CeedInt num_input_fields, CeedInt num_output_fields, 956 bool *apply_add_basis, bool *skip_rstr, CeedOperator op, CeedVector out_vec, 957 CeedVector point_coords_elem, bool skip_passive, CeedOperator_Ref *impl, CeedRequest *request) { 958 for (CeedInt i = 0; i < num_output_fields; i++) { 959 bool is_active; 960 CeedRestrictionType rstr_type; 961 CeedEvalMode eval_mode; 962 CeedVector vec; 963 CeedElemRestriction elem_rstr; 964 CeedBasis basis; 965 966 // Skip active input 967 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 968 is_active = vec == CEED_VECTOR_ACTIVE; 969 CeedCallBackend(CeedVectorDestroy(&vec)); 970 if (skip_passive && !is_active) continue; 971 972 // Get elem_size, eval_mode, size 973 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 974 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 975 // Basis action 976 switch (eval_mode) { 977 case CEED_EVAL_NONE: 978 break; // No action 979 case CEED_EVAL_INTERP: 980 case CEED_EVAL_GRAD: 981 case CEED_EVAL_DIV: 982 case CEED_EVAL_CURL: 983 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 984 if (apply_add_basis[i]) { 985 CeedCallBackend(CeedBasisApplyAddAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], 986 impl->e_vecs_out[i])); 987 } else { 988 CeedCallBackend( 989 CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], impl->e_vecs_out[i])); 990 } 991 CeedCallBackend(CeedBasisDestroy(&basis)); 992 break; 993 // LCOV_EXCL_START 994 case CEED_EVAL_WEIGHT: { 995 return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 996 // LCOV_EXCL_STOP 997 } 998 } 999 // Restrict output block 1000 // When skipping passive outputs, we're doing assembly and should not restrict 1001 if (skip_rstr[i] || skip_passive) { 1002 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1003 continue; 1004 } 1005 1006 // Get output vector 1007 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1008 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1009 if (is_active) vec = out_vec; 1010 // Restrict 1011 if (rstr_type == CEED_RESTRICTION_POINTS) { 1012 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 1013 } else { 1014 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request)); 1015 } 1016 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec)); 1017 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1018 } 1019 return CEED_ERROR_SUCCESS; 1020 } 1021 1022 //------------------------------------------------------------------------------ 1023 // Operator Apply 1024 //------------------------------------------------------------------------------ 1025 static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 1026 CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem; 1027 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 1028 CeedVector point_coords = NULL; 1029 CeedElemRestriction rstr_points = NULL; 1030 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1031 CeedQFunction qf; 1032 CeedOperatorField *op_input_fields, *op_output_fields; 1033 CeedOperator_Ref *impl; 1034 1035 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1036 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1037 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1038 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1039 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1040 1041 // Setup 1042 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 1043 1044 // Point coordinates 1045 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1046 1047 // Input Evecs and Restriction 1048 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1049 1050 // Loop through elements 1051 for (CeedInt e = 0; e < num_elem; e++) { 1052 CeedInt num_points; 1053 1054 // Setup points for element 1055 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1056 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 1057 1058 // Input basis apply 1059 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec, 1060 impl->point_coords_elem, false, false, e_data, impl, request)); 1061 1062 // Q function 1063 if (!impl->is_identity_qf) { 1064 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 1065 } 1066 1067 // Output basis apply and restriction 1068 CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, 1069 num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, 1070 impl->point_coords_elem, false, impl, request)); 1071 1072 num_points_offset += num_points; 1073 } 1074 1075 // Restore input arrays 1076 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1077 1078 // Cleanup point coordinates 1079 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1080 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1081 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1082 return CEED_ERROR_SUCCESS; 1083 } 1084 1085 //------------------------------------------------------------------------------ 1086 // Core code for assembling linear QFunction 1087 //------------------------------------------------------------------------------ 1088 static inline int CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, 1089 CeedElemRestriction *rstr, CeedRequest *request) { 1090 Ceed ceed; 1091 CeedInt qf_size_in, qf_size_out, max_num_points, num_elem, num_input_fields, num_output_fields, num_points_offset = 0; 1092 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 1093 CeedVector point_coords = NULL; 1094 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1095 CeedQFunction qf; 1096 CeedOperatorField *op_input_fields, *op_output_fields; 1097 CeedOperator_Ref *impl; 1098 CeedElemRestriction rstr_points = NULL; 1099 1100 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1101 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1102 qf_size_in = impl->qf_size_in; 1103 qf_size_out = impl->qf_size_out; 1104 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1105 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1106 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1107 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1108 1109 // Setup 1110 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 1111 1112 // Check for restriction only operator 1113 CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported"); 1114 1115 // Point coordinates 1116 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1117 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr_points, &max_num_points)); 1118 1119 // Input Evecs and Restriction 1120 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 1121 1122 // Count number of active input fields 1123 if (qf_size_in == 0) { 1124 for (CeedInt i = 0; i < num_input_fields; i++) { 1125 CeedInt field_size; 1126 CeedVector vec; 1127 1128 // Get input vector 1129 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1130 // Check if active input 1131 if (vec == CEED_VECTOR_ACTIVE) { 1132 // Check that all active inputs are nodal fields 1133 { 1134 CeedElemRestriction elem_rstr; 1135 bool is_at_points = false; 1136 1137 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1138 CeedCallBackend(CeedElemRestrictionIsAtPoints(elem_rstr, &is_at_points)); 1139 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1140 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 1141 } 1142 // Get size of active input 1143 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 1144 qf_size_in += field_size; 1145 } 1146 CeedCallBackend(CeedVectorDestroy(&vec)); 1147 } 1148 CeedCheck(qf_size_in, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 1149 impl->qf_size_in = qf_size_in; 1150 } 1151 1152 // Count number of active output fields 1153 if (qf_size_out == 0) { 1154 for (CeedInt i = 0; i < num_output_fields; i++) { 1155 CeedInt field_size; 1156 CeedVector vec; 1157 1158 // Get output vector 1159 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1160 // Check if active output 1161 if (vec == CEED_VECTOR_ACTIVE) { 1162 // Check that all active inputs are nodal fields 1163 { 1164 CeedElemRestriction elem_rstr; 1165 bool is_at_points = false; 1166 1167 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); 1168 CeedCallBackend(CeedElemRestrictionIsAtPoints(elem_rstr, &is_at_points)); 1169 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1170 CeedCheck(!is_at_points, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction with active input at points"); 1171 } 1172 // Get size of active output 1173 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 1174 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 1175 qf_size_out += field_size; 1176 } 1177 CeedCallBackend(CeedVectorDestroy(&vec)); 1178 } 1179 CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 1180 impl->qf_size_out = qf_size_out; 1181 } 1182 1183 // Build objects if needed 1184 if (build_objects) { 1185 CeedInt num_points_total; 1186 const CeedInt *offsets; 1187 1188 CeedCallBackend(CeedElemRestrictionGetNumPoints(rstr_points, &num_points_total)); 1189 1190 // Create output restriction (at points) 1191 CeedCallBackend(CeedElemRestrictionGetOffsets(rstr_points, CEED_MEM_HOST, &offsets)); 1192 CeedCallBackend(CeedElemRestrictionCreateAtPoints(ceed, num_elem, num_points_total, qf_size_in * qf_size_out, 1193 qf_size_in * qf_size_out * num_points_total, CEED_MEM_HOST, CEED_COPY_VALUES, offsets, rstr)); 1194 CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr_points, &offsets)); 1195 1196 // Create assembled vector 1197 CeedCallBackend(CeedElemRestrictionCreateVector(*rstr, assembled, NULL)); 1198 } 1199 // Clear output vector 1200 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 1201 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 1202 1203 // Loop through elements 1204 for (CeedInt e = 0; e < num_elem; e++) { 1205 CeedInt num_points; 1206 1207 // Setup points for element 1208 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1209 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 1210 1211 // Input basis apply 1212 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, NULL, 1213 impl->point_coords_elem, true, false, e_data_full, impl, request)); 1214 1215 // Assemble QFunction 1216 for (CeedInt i = 0; i < num_input_fields; i++) { 1217 bool is_active; 1218 CeedInt field_size; 1219 CeedVector vec; 1220 1221 // Get input vector 1222 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1223 is_active = vec == CEED_VECTOR_ACTIVE; 1224 CeedCallBackend(CeedVectorDestroy(&vec)); 1225 // Check if active input 1226 if (!is_active) continue; 1227 // Get size of active input 1228 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 1229 for (CeedInt field = 0; field < field_size; field++) { 1230 // Set current portion of input to 1.0 1231 { 1232 CeedScalar *array; 1233 1234 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array)); 1235 for (CeedInt j = 0; j < num_points; j++) array[field * num_points + j] = 1.0; 1236 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array)); 1237 } 1238 1239 if (!impl->is_identity_qf) { 1240 // Set Outputs 1241 for (CeedInt out = 0; out < num_output_fields; out++) { 1242 CeedVector vec; 1243 CeedInt field_size; 1244 1245 // Get output vector 1246 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1247 // Check if active output 1248 if (vec == CEED_VECTOR_ACTIVE) { 1249 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 1250 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size)); 1251 assembled_array += field_size * num_points; // Advance the pointer by the size of the output 1252 } 1253 CeedCallBackend(CeedVectorDestroy(&vec)); 1254 } 1255 // Apply QFunction 1256 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 1257 } else { 1258 const CeedScalar *array; 1259 CeedInt field_size; 1260 1261 // Copy Identity Outputs 1262 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size)); 1263 CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &array)); 1264 for (CeedInt j = 0; j < field_size * num_points; j++) assembled_array[j] = array[j]; 1265 CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &array)); 1266 assembled_array += field_size * num_points; 1267 } 1268 // Reset input to 0.0 1269 { 1270 CeedScalar *array; 1271 1272 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array)); 1273 for (CeedInt j = 0; j < num_points; j++) array[field * num_points + j] = 0.0; 1274 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array)); 1275 } 1276 } 1277 } 1278 num_points_offset += num_points; 1279 } 1280 1281 // Un-set output Qvecs to prevent accidental overwrite of Assembled 1282 if (!impl->is_identity_qf) { 1283 for (CeedInt out = 0; out < num_output_fields; out++) { 1284 CeedVector vec; 1285 1286 // Get output vector 1287 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 1288 // Check if active output 1289 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 1290 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 1291 } 1292 CeedCallBackend(CeedVectorDestroy(&vec)); 1293 } 1294 } 1295 1296 // Restore input arrays 1297 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 1298 1299 // Restore output 1300 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 1301 1302 // Cleanup 1303 CeedCallBackend(CeedDestroy(&ceed)); 1304 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1305 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1306 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1307 return CEED_ERROR_SUCCESS; 1308 } 1309 1310 //------------------------------------------------------------------------------ 1311 // Assemble Linear QFunction 1312 //------------------------------------------------------------------------------ 1313 static int CeedOperatorLinearAssembleQFunctionAtPoints_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 1314 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, true, assembled, rstr, request); 1315 } 1316 1317 //------------------------------------------------------------------------------ 1318 // Update Assembled Linear QFunction 1319 //------------------------------------------------------------------------------ 1320 static int CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, 1321 CeedRequest *request) { 1322 return CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(op, false, &assembled, &rstr, request); 1323 } 1324 1325 //------------------------------------------------------------------------------ 1326 // Assemble Operator Diagonal AtPoints 1327 //------------------------------------------------------------------------------ 1328 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, CeedVector assembled, CeedRequest *request) { 1329 CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem, num_comp_active = 1; 1330 CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}; 1331 Ceed ceed; 1332 CeedVector point_coords = NULL, in_vec, out_vec; 1333 CeedElemRestriction rstr_points = NULL; 1334 CeedQFunctionField *qf_input_fields, *qf_output_fields; 1335 CeedQFunction qf; 1336 CeedOperatorField *op_input_fields, *op_output_fields; 1337 CeedOperator_Ref *impl; 1338 1339 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1340 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 1341 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 1342 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 1343 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 1344 1345 // Setup 1346 CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op)); 1347 1348 // Ceed 1349 { 1350 Ceed ceed_parent; 1351 1352 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1353 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 1354 CeedCallBackend(CeedReferenceCopy(ceed_parent, &ceed)); 1355 CeedCallBackend(CeedDestroy(&ceed_parent)); 1356 } 1357 1358 // Point coordinates 1359 CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords)); 1360 1361 // Input and output vectors 1362 { 1363 CeedSize input_size, output_size; 1364 1365 CeedCallBackend(CeedOperatorGetActiveVectorLengths(op, &input_size, &output_size)); 1366 CeedCallBackend(CeedVectorCreate(ceed, input_size, &in_vec)); 1367 CeedCallBackend(CeedVectorCreate(ceed, output_size, &out_vec)); 1368 CeedCallBackend(CeedVectorSetValue(out_vec, 0.0)); 1369 } 1370 1371 // Clear input Evecs 1372 for (CeedInt i = 0; i < num_input_fields; i++) { 1373 bool is_active; 1374 CeedVector vec; 1375 1376 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1377 is_active = vec == CEED_VECTOR_ACTIVE; 1378 CeedCallBackend(CeedVectorDestroy(&vec)); 1379 if (!is_active || impl->skip_rstr_in[i]) continue; 1380 CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0)); 1381 } 1382 1383 // Input Evecs and Restriction 1384 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request)); 1385 1386 // Loop through elements 1387 for (CeedInt e = 0; e < num_elem; e++) { 1388 CeedInt num_points, e_vec_size = 0; 1389 1390 // Setup points for element 1391 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, request)); 1392 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points)); 1393 1394 // Input basis apply for non-active bases 1395 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec, 1396 impl->point_coords_elem, true, false, e_data, impl, request)); 1397 1398 // Loop over points on element 1399 for (CeedInt i = 0; i < num_input_fields; i++) { 1400 bool is_active_at_points = true, is_active; 1401 CeedInt elem_size_active = 1; 1402 CeedRestrictionType rstr_type; 1403 CeedVector vec; 1404 CeedElemRestriction elem_rstr; 1405 1406 // -- Skip non-active input 1407 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1408 is_active = vec == CEED_VECTOR_ACTIVE; 1409 CeedCallBackend(CeedVectorDestroy(&vec)); 1410 if (!is_active || impl->skip_rstr_in[i]) continue; 1411 1412 // -- Get active restriction type 1413 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); 1414 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1415 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS; 1416 if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size_active)); 1417 else elem_size_active = num_points; 1418 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active)); 1419 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1420 1421 e_vec_size = elem_size_active * num_comp_active; 1422 for (CeedInt s = 0; s < e_vec_size; s++) { 1423 // -- Update unit vector 1424 { 1425 CeedScalar *array; 1426 1427 CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array)); 1428 array[s] = 1.0; 1429 if (s > 0) array[s - 1] = 0.0; 1430 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array)); 1431 } 1432 // Input basis apply for active bases 1433 CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, 1434 in_vec, impl->point_coords_elem, false, true, e_data, impl, request)); 1435 1436 // -- Q function 1437 if (!impl->is_identity_qf) { 1438 CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out)); 1439 } 1440 1441 // -- Output basis apply and restriction 1442 CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, 1443 num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, 1444 impl->point_coords_elem, true, impl, request)); 1445 1446 // -- Grab diagonal value 1447 for (CeedInt j = 0; j < num_output_fields; j++) { 1448 bool is_active; 1449 CeedInt elem_size = 0; 1450 CeedRestrictionType rstr_type; 1451 CeedVector vec; 1452 CeedElemRestriction elem_rstr; 1453 1454 // ---- Skip non-active output 1455 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec)); 1456 is_active = vec == CEED_VECTOR_ACTIVE; 1457 CeedCallBackend(CeedVectorDestroy(&vec)); 1458 if (!is_active || impl->skip_rstr_out[j]) continue; 1459 1460 // ---- Check if elem size matches 1461 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); 1462 CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); 1463 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) { 1464 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1465 continue; 1466 } 1467 if (rstr_type == CEED_RESTRICTION_POINTS) { 1468 CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(elem_rstr, e, &elem_size)); 1469 } else { 1470 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); 1471 } 1472 { 1473 CeedInt num_comp = 0; 1474 1475 CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); 1476 if (e_vec_size != num_comp * elem_size) { 1477 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1478 continue; 1479 } 1480 } 1481 // ---- Update output vector 1482 { 1483 CeedScalar *array, current_value = 0.0; 1484 1485 CeedCallBackend(CeedVectorGetArray(impl->e_vecs_out[j], CEED_MEM_HOST, &array)); 1486 current_value = array[s]; 1487 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[j], &array)); 1488 CeedCallBackend(CeedVectorSetValue(impl->e_vecs_out[j], 0.0)); 1489 CeedCallBackend(CeedVectorGetArray(impl->e_vecs_out[j], CEED_MEM_HOST, &array)); 1490 array[s] = current_value; 1491 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[j], &array)); 1492 } 1493 // ---- Restrict output block 1494 if (rstr_type == CEED_RESTRICTION_POINTS) { 1495 CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[j], assembled, request)); 1496 } else { 1497 CeedCallBackend(CeedElemRestrictionApplyBlock(elem_rstr, e, CEED_TRANSPOSE, impl->e_vecs_out[j], assembled, request)); 1498 } 1499 CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); 1500 } 1501 // -- Reset unit vector 1502 if (s == e_vec_size - 1) { 1503 CeedScalar *array; 1504 1505 CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array)); 1506 array[s] = 0.0; 1507 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array)); 1508 } 1509 } 1510 } 1511 num_points_offset += num_points; 1512 } 1513 1514 // Restore input arrays 1515 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl)); 1516 1517 // Cleanup 1518 CeedCallBackend(CeedDestroy(&ceed)); 1519 CeedCallBackend(CeedVectorDestroy(&in_vec)); 1520 CeedCallBackend(CeedVectorDestroy(&out_vec)); 1521 CeedCallBackend(CeedVectorDestroy(&point_coords)); 1522 CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1523 CeedCallBackend(CeedQFunctionDestroy(&qf)); 1524 return CEED_ERROR_SUCCESS; 1525 } 1526 1527 //------------------------------------------------------------------------------ 1528 // Operator Destroy 1529 //------------------------------------------------------------------------------ 1530 static int CeedOperatorDestroy_Ref(CeedOperator op) { 1531 CeedOperator_Ref *impl; 1532 1533 CeedCallBackend(CeedOperatorGetData(op, &impl)); 1534 CeedCallBackend(CeedFree(&impl->skip_rstr_in)); 1535 CeedCallBackend(CeedFree(&impl->skip_rstr_out)); 1536 CeedCallBackend(CeedFree(&impl->e_data_out_indices)); 1537 CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); 1538 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 1539 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 1540 } 1541 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 1542 CeedCallBackend(CeedFree(&impl->input_states)); 1543 1544 for (CeedInt i = 0; i < impl->num_inputs; i++) { 1545 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 1546 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 1547 } 1548 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 1549 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 1550 1551 for (CeedInt i = 0; i < impl->num_outputs; i++) { 1552 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 1553 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 1554 } 1555 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 1556 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 1557 CeedCallBackend(CeedVectorDestroy(&impl->point_coords_elem)); 1558 1559 CeedCallBackend(CeedFree(&impl)); 1560 return CEED_ERROR_SUCCESS; 1561 } 1562 1563 //------------------------------------------------------------------------------ 1564 // Operator Create 1565 //------------------------------------------------------------------------------ 1566 int CeedOperatorCreate_Ref(CeedOperator op) { 1567 Ceed ceed; 1568 CeedOperator_Ref *impl; 1569 1570 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1571 CeedCallBackend(CeedCalloc(1, &impl)); 1572 CeedCallBackend(CeedOperatorSetData(op, impl)); 1573 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Ref)); 1574 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Ref)); 1575 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Ref)); 1576 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1577 CeedCallBackend(CeedDestroy(&ceed)); 1578 return CEED_ERROR_SUCCESS; 1579 } 1580 1581 //------------------------------------------------------------------------------ 1582 // Operator Create At Points 1583 //------------------------------------------------------------------------------ 1584 int CeedOperatorCreateAtPoints_Ref(CeedOperator op) { 1585 Ceed ceed; 1586 CeedOperator_Ref *impl; 1587 1588 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 1589 CeedCallBackend(CeedCalloc(1, &impl)); 1590 CeedCallBackend(CeedOperatorSetData(op, impl)); 1591 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunctionAtPoints_Ref)); 1592 CeedCallBackend( 1593 CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref)); 1594 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref)); 1595 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref)); 1596 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 1597 CeedCallBackend(CeedDestroy(&ceed)); 1598 return CEED_ERROR_SUCCESS; 1599 } 1600 1601 //------------------------------------------------------------------------------ 1602