1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-ref.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, CeedVector *e_vecs_full, CeedVector *e_vecs, 20 CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 21 Ceed ceed; 22 CeedSize e_size, q_size; 23 CeedInt num_comp, size, P; 24 CeedQFunctionField *qf_fields; 25 CeedOperatorField *op_fields; 26 27 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 28 if (is_input) { 29 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 30 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 31 } else { 32 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 33 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 34 } 35 36 // Loop over fields 37 for (CeedInt i = 0; i < num_fields; i++) { 38 CeedEvalMode eval_mode; 39 CeedElemRestriction elem_restr; 40 CeedBasis basis; 41 42 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 43 if (eval_mode != CEED_EVAL_WEIGHT) { 44 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_restr)); 45 CeedCallBackend(CeedElemRestrictionCreateVector(elem_restr, NULL, &e_vecs_full[i + start_e])); 46 } 47 48 switch (eval_mode) { 49 case CEED_EVAL_NONE: 50 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 51 q_size = (CeedSize)Q * size; 52 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 53 break; 54 case CEED_EVAL_INTERP: 55 case CEED_EVAL_GRAD: 56 case CEED_EVAL_DIV: 57 case CEED_EVAL_CURL: 58 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 59 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 60 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 61 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 62 e_size = (CeedSize)P * num_comp; 63 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 64 q_size = (CeedSize)Q * size; 65 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 66 break; 67 case CEED_EVAL_WEIGHT: // Only on input fields 68 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 69 q_size = (CeedSize)Q; 70 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 71 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 72 break; 73 } 74 } 75 return CEED_ERROR_SUCCESS; 76 } 77 78 //------------------------------------------------------------------------------ 79 // Setup Operator 80 //------------------------------------------------------------------------------/* 81 static int CeedOperatorSetup_Ref(CeedOperator op) { 82 bool is_setup_done; 83 Ceed ceed; 84 CeedInt Q, num_input_fields, num_output_fields; 85 CeedQFunctionField *qf_input_fields, *qf_output_fields; 86 CeedQFunction qf; 87 CeedOperatorField *op_input_fields, *op_output_fields; 88 CeedOperator_Ref *impl; 89 90 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 91 if (is_setup_done) return CEED_ERROR_SUCCESS; 92 93 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 94 CeedCallBackend(CeedOperatorGetData(op, &impl)); 95 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 96 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 97 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 98 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 99 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 100 101 // Allocate 102 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 103 104 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 105 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 106 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 107 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 108 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 109 110 impl->num_inputs = num_input_fields; 111 impl->num_outputs = num_output_fields; 112 113 // Set up infield and outfield e_vecs and q_vecs 114 // Infields 115 CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, true, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 116 // Outfields 117 CeedCallBackend( 118 CeedOperatorSetupFields_Ref(qf, op, false, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); 119 120 // Identity QFunctions 121 if (impl->is_identity_qf) { 122 CeedEvalMode in_mode, out_mode; 123 CeedQFunctionField *in_fields, *out_fields; 124 125 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 126 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 127 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 128 129 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 130 impl->is_identity_restr_op = true; 131 } else { 132 CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0])); 133 } 134 } 135 136 CeedCallBackend(CeedOperatorSetSetupDone(op)); 137 return CEED_ERROR_SUCCESS; 138 } 139 140 //------------------------------------------------------------------------------ 141 // Setup Operator Inputs 142 //------------------------------------------------------------------------------ 143 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 144 CeedVector in_vec, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 145 CeedOperator_Ref *impl, CeedRequest *request) { 146 for (CeedInt i = 0; i < num_input_fields; i++) { 147 uint64_t state; 148 CeedEvalMode eval_mode; 149 CeedVector vec; 150 CeedElemRestriction elem_restr; 151 152 // Get input vector 153 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 154 if (vec == CEED_VECTOR_ACTIVE) { 155 if (skip_active) continue; 156 else vec = in_vec; 157 } 158 159 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 160 // Restrict and Evec 161 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 162 } else { 163 // Restrict 164 CeedCallBackend(CeedVectorGetState(vec, &state)); 165 // Skip restriction if input is unchanged 166 if (state != impl->input_states[i] || vec == in_vec) { 167 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr)); 168 CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 169 impl->input_states[i] = state; 170 } 171 // Get evec 172 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 173 } 174 } 175 return CEED_ERROR_SUCCESS; 176 } 177 178 //------------------------------------------------------------------------------ 179 // Input Basis Action 180 //------------------------------------------------------------------------------ 181 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 182 CeedInt num_input_fields, const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 183 CeedOperator_Ref *impl) { 184 for (CeedInt i = 0; i < num_input_fields; i++) { 185 CeedInt elem_size, size, num_comp; 186 CeedEvalMode eval_mode; 187 CeedElemRestriction elem_restr; 188 CeedBasis basis; 189 190 // Skip active input 191 if (skip_active) { 192 CeedVector vec; 193 194 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 195 if (vec == CEED_VECTOR_ACTIVE) continue; 196 } 197 // Get elem_size, eval_mode, size 198 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr)); 199 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 200 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 201 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 202 // Basis action 203 switch (eval_mode) { 204 case CEED_EVAL_NONE: 205 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * Q * size])); 206 break; 207 case CEED_EVAL_INTERP: 208 case CEED_EVAL_GRAD: 209 case CEED_EVAL_DIV: 210 case CEED_EVAL_CURL: 211 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 212 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 213 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * elem_size * num_comp])); 214 CeedCallBackend(CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i])); 215 break; 216 case CEED_EVAL_WEIGHT: 217 break; // No action 218 } 219 } 220 return CEED_ERROR_SUCCESS; 221 } 222 223 //------------------------------------------------------------------------------ 224 // Output Basis Action 225 //------------------------------------------------------------------------------ 226 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 227 CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 228 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 229 for (CeedInt i = 0; i < num_output_fields; i++) { 230 CeedInt elem_size, num_comp; 231 CeedEvalMode eval_mode; 232 CeedElemRestriction elem_restr; 233 CeedBasis basis; 234 235 // Get elem_size, eval_mode 236 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr)); 237 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 238 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 239 // Basis action 240 switch (eval_mode) { 241 case CEED_EVAL_NONE: 242 break; // No action 243 case CEED_EVAL_INTERP: 244 case CEED_EVAL_GRAD: 245 case CEED_EVAL_DIV: 246 case CEED_EVAL_CURL: 247 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 248 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 249 CeedCallBackend( 250 CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * elem_size * num_comp])); 251 CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); 252 break; 253 // LCOV_EXCL_START 254 case CEED_EVAL_WEIGHT: { 255 Ceed ceed; 256 257 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 258 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 259 // LCOV_EXCL_STOP 260 } 261 } 262 } 263 return CEED_ERROR_SUCCESS; 264 } 265 266 //------------------------------------------------------------------------------ 267 // Restore Input Vectors 268 //------------------------------------------------------------------------------ 269 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 270 const bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { 271 for (CeedInt i = 0; i < num_input_fields; i++) { 272 CeedEvalMode eval_mode; 273 274 // Skip active inputs 275 if (skip_active) { 276 CeedVector vec; 277 278 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 279 if (vec == CEED_VECTOR_ACTIVE) continue; 280 } 281 // Restore input 282 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 283 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 284 } else { 285 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 286 } 287 } 288 return CEED_ERROR_SUCCESS; 289 } 290 291 //------------------------------------------------------------------------------ 292 // Operator Apply 293 //------------------------------------------------------------------------------ 294 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 295 CeedInt Q, num_elem, num_input_fields, num_output_fields, size; 296 CeedEvalMode eval_mode; 297 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 298 CeedQFunctionField *qf_input_fields, *qf_output_fields; 299 CeedQFunction qf; 300 CeedOperatorField *op_input_fields, *op_output_fields; 301 CeedOperator_Ref *impl; 302 303 CeedCallBackend(CeedOperatorGetData(op, &impl)); 304 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 305 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 306 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 307 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 308 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 309 310 // Setup 311 CeedCallBackend(CeedOperatorSetup_Ref(op)); 312 313 // Restriction only operator 314 if (impl->is_identity_restr_op) { 315 CeedElemRestriction elem_restr; 316 317 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_restr)); 318 CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 319 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_restr)); 320 CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 321 return CEED_ERROR_SUCCESS; 322 } 323 324 // Input Evecs and Restriction 325 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 326 327 // Output Evecs 328 for (CeedInt i = 0; i < num_output_fields; i++) { 329 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 330 } 331 332 // Loop through elements 333 for (CeedInt e = 0; e < num_elem; e++) { 334 // Output pointers 335 for (CeedInt i = 0; i < num_output_fields; i++) { 336 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 337 if (eval_mode == CEED_EVAL_NONE) { 338 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 339 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * Q * size])); 340 } 341 } 342 343 // Input basis apply 344 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, false, e_data_full, impl)); 345 346 // Q function 347 if (!impl->is_identity_qf) { 348 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 349 } 350 351 // Output basis apply 352 CeedCallBackend( 353 CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, op, e_data_full, impl)); 354 } 355 356 // Output restriction 357 for (CeedInt i = 0; i < num_output_fields; i++) { 358 CeedVector vec; 359 CeedElemRestriction elem_restr; 360 361 // Restore Evec 362 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 363 // Get output vector 364 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 365 // Active 366 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 367 // Restrict 368 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr)); 369 CeedCallBackend(CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 370 } 371 372 // Restore input arrays 373 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 374 return CEED_ERROR_SUCCESS; 375 } 376 377 //------------------------------------------------------------------------------ 378 // Core code for assembling linear QFunction 379 //------------------------------------------------------------------------------ 380 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 381 CeedRequest *request) { 382 Ceed ceed, ceed_parent; 383 CeedSize q_size; 384 CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; 385 CeedScalar *assembled_array, *e_data_full[2 * CEED_FIELD_MAX] = {NULL}; 386 CeedVector *active_in; 387 CeedQFunctionField *qf_input_fields, *qf_output_fields; 388 CeedQFunction qf; 389 CeedOperatorField *op_input_fields, *op_output_fields; 390 CeedOperator_Ref *impl; 391 392 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 393 CeedCallBackend(CeedOperatorGetFallbackParentCeed(op, &ceed_parent)); 394 CeedCallBackend(CeedOperatorGetData(op, &impl)); 395 active_in = impl->qf_active_in; 396 num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 397 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 398 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 399 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 400 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 401 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 402 403 // Setup 404 CeedCallBackend(CeedOperatorSetup_Ref(op)); 405 406 // Check for identity 407 CeedCheck(!impl->is_identity_qf, ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 408 409 // Input Evecs and Restriction 410 CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 411 412 // Count number of active input fields 413 if (!num_active_in) { 414 for (CeedInt i = 0; i < num_input_fields; i++) { 415 CeedScalar *q_vec_array; 416 CeedVector vec; 417 418 // Get input vector 419 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 420 // Check if active input 421 if (vec == CEED_VECTOR_ACTIVE) { 422 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 423 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 424 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &q_vec_array)); 425 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 426 for (CeedInt field = 0; field < size; field++) { 427 q_size = (CeedSize)Q; 428 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 429 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &q_vec_array[field * Q])); 430 } 431 num_active_in += size; 432 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &q_vec_array)); 433 } 434 } 435 impl->num_active_in = num_active_in; 436 impl->qf_active_in = active_in; 437 } 438 439 // Count number of active output fields 440 if (!num_active_out) { 441 for (CeedInt i = 0; i < num_output_fields; i++) { 442 CeedVector vec; 443 444 // Get output vector 445 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 446 // Check if active output 447 if (vec == CEED_VECTOR_ACTIVE) { 448 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 449 num_active_out += size; 450 } 451 } 452 impl->num_active_out = num_active_out; 453 } 454 455 // Check sizes 456 CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 457 458 // Build objects if needed 459 if (build_objects) { 460 const CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 461 CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; /* *NOPAD* */ 462 463 // Create output restriction 464 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, 465 num_active_in * num_active_out * num_elem * Q, strides, rstr)); 466 // Create assembled vector 467 CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled)); 468 } 469 // Clear output vector 470 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 471 CeedCallBackend(CeedVectorGetArray(*assembled, CEED_MEM_HOST, &assembled_array)); 472 473 // Loop through elements 474 for (CeedInt e = 0; e < num_elem; e++) { 475 // Input basis apply 476 CeedCallBackend(CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields, num_input_fields, true, e_data_full, impl)); 477 478 // Assemble QFunction 479 for (CeedInt in = 0; in < num_active_in; in++) { 480 // Set Inputs 481 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 482 if (num_active_in > 1) { 483 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 484 } 485 // Set Outputs 486 for (CeedInt out = 0; out < num_output_fields; out++) { 487 CeedVector vec; 488 489 // Get output vector 490 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 491 // Check if active output 492 if (vec == CEED_VECTOR_ACTIVE) { 493 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, assembled_array)); 494 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 495 assembled_array += size * Q; // Advance the pointer by the size of the output 496 } 497 } 498 // Apply QFunction 499 CeedCallBackend(CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out)); 500 } 501 } 502 503 // Un-set output Qvecs to prevent accidental overwrite of Assembled 504 for (CeedInt out = 0; out < num_output_fields; out++) { 505 CeedVector vec; 506 507 // Get output vector 508 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 509 // Check if active output 510 if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) { 511 CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL)); 512 } 513 } 514 515 // Restore input arrays 516 CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 517 518 // Restore output 519 CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 520 return CEED_ERROR_SUCCESS; 521 } 522 523 //------------------------------------------------------------------------------ 524 // Assemble Linear QFunction 525 //------------------------------------------------------------------------------ 526 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 527 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr, request); 528 } 529 530 //------------------------------------------------------------------------------ 531 // Update Assembled Linear QFunction 532 //------------------------------------------------------------------------------ 533 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 534 return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled, &rstr, request); 535 } 536 537 //------------------------------------------------------------------------------ 538 // Operator Destroy 539 //------------------------------------------------------------------------------ 540 static int CeedOperatorDestroy_Ref(CeedOperator op) { 541 CeedOperator_Ref *impl; 542 543 CeedCallBackend(CeedOperatorGetData(op, &impl)); 544 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 545 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 546 } 547 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 548 CeedCallBackend(CeedFree(&impl->input_states)); 549 550 for (CeedInt i = 0; i < impl->num_inputs; i++) { 551 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 552 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 553 } 554 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 555 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 556 557 for (CeedInt i = 0; i < impl->num_outputs; i++) { 558 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 559 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 560 } 561 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 562 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 563 564 // QFunction assembly 565 for (CeedInt i = 0; i < impl->num_active_in; i++) { 566 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 567 } 568 CeedCallBackend(CeedFree(&impl->qf_active_in)); 569 570 CeedCallBackend(CeedFree(&impl)); 571 return CEED_ERROR_SUCCESS; 572 } 573 574 //------------------------------------------------------------------------------ 575 // Operator Create 576 //------------------------------------------------------------------------------ 577 int CeedOperatorCreate_Ref(CeedOperator op) { 578 Ceed ceed; 579 CeedOperator_Ref *impl; 580 581 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 582 CeedCallBackend(CeedCalloc(1, &impl)); 583 CeedCallBackend(CeedOperatorSetData(op, impl)); 584 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Ref)); 585 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Ref)); 586 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Ref)); 587 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref)); 588 return CEED_ERROR_SUCCESS; 589 } 590 591 //------------------------------------------------------------------------------ 592