1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <stdbool.h> 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include "ceed-blocked.h" 15 16 //------------------------------------------------------------------------------ 17 // Setup Input/Output Fields 18 //------------------------------------------------------------------------------ 19 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, CeedElemRestriction *blk_restr, CeedVector *e_vecs_full, 20 CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { 21 CeedInt num_comp, size, P; 22 CeedSize e_size, q_size; 23 Ceed ceed; 24 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 25 CeedBasis basis; 26 CeedElemRestriction r; 27 CeedOperatorField *op_fields; 28 CeedQFunctionField *qf_fields; 29 if (is_input) { 30 CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL)); 31 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL)); 32 } else { 33 CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields)); 34 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields)); 35 } 36 const CeedInt blk_size = 8; 37 38 // Loop over fields 39 for (CeedInt i = 0; i < num_fields; i++) { 40 CeedEvalMode eval_mode; 41 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); 42 43 if (eval_mode != CEED_EVAL_WEIGHT) { 44 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &r)); 45 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 46 CeedSize l_size; 47 CeedInt num_elem, elem_size, comp_stride; 48 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 49 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 50 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 51 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 52 53 bool strided; 54 CeedCallBackend(CeedElemRestrictionIsStrided(r, &strided)); 55 if (strided) { 56 CeedInt strides[3]; 57 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 58 CeedCallBackend( 59 CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, elem_size, blk_size, num_comp, l_size, strides, &blk_restr[i + start_e])); 60 } else { 61 const CeedInt *offsets = NULL; 62 CeedCallBackend(CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets)); 63 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 64 CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, CEED_MEM_HOST, 65 CEED_COPY_VALUES, offsets, &blk_restr[i + start_e])); 66 CeedCallBackend(CeedElemRestrictionRestoreOffsets(r, &offsets)); 67 } 68 CeedCallBackend(CeedElemRestrictionCreateVector(blk_restr[i + start_e], NULL, &e_vecs_full[i + start_e])); 69 } 70 71 switch (eval_mode) { 72 case CEED_EVAL_NONE: 73 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 74 q_size = (CeedSize)Q * size * blk_size; 75 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 76 break; 77 case CEED_EVAL_INTERP: 78 case CEED_EVAL_GRAD: 79 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 80 CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); 81 CeedCallBackend(CeedBasisGetNumNodes(basis, &P)); 82 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 83 e_size = (CeedSize)P * num_comp * blk_size; 84 CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i])); 85 q_size = (CeedSize)Q * size * blk_size; 86 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 87 break; 88 case CEED_EVAL_WEIGHT: // Only on input fields 89 CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); 90 q_size = (CeedSize)Q * blk_size; 91 CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); 92 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); 93 94 break; 95 case CEED_EVAL_DIV: 96 break; // Not implemented 97 case CEED_EVAL_CURL: 98 break; // Not implemented 99 } 100 } 101 return CEED_ERROR_SUCCESS; 102 } 103 104 //------------------------------------------------------------------------------ 105 // Setup Operator 106 //------------------------------------------------------------------------------ 107 static int CeedOperatorSetup_Blocked(CeedOperator op) { 108 bool is_setup_done; 109 CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done)); 110 if (is_setup_done) return CEED_ERROR_SUCCESS; 111 Ceed ceed; 112 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 113 CeedOperator_Blocked *impl; 114 CeedCallBackend(CeedOperatorGetData(op, &impl)); 115 CeedQFunction qf; 116 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 117 CeedInt Q, num_input_fields, num_output_fields; 118 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 119 CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf)); 120 CeedOperatorField *op_input_fields, *op_output_fields; 121 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 122 CeedQFunctionField *qf_input_fields, *qf_output_fields; 123 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 124 125 // Allocate 126 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr)); 127 CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); 128 129 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); 130 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); 131 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); 132 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); 133 CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); 134 135 impl->num_inputs = num_input_fields; 136 impl->num_outputs = num_output_fields; 137 138 // Set up infield and outfield pointer arrays 139 // Infields 140 CeedCallBackend( 141 CeedOperatorSetupFields_Blocked(qf, op, true, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); 142 // Outfields 143 CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, 144 num_input_fields, num_output_fields, Q)); 145 146 // Identity QFunctions 147 if (impl->is_identity_qf) { 148 CeedEvalMode in_mode, out_mode; 149 CeedQFunctionField *in_fields, *out_fields; 150 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields)); 151 CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode)); 152 CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode)); 153 154 if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) { 155 impl->is_identity_restr_op = true; 156 } else { 157 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[0])); 158 impl->q_vecs_out[0] = impl->q_vecs_in[0]; 159 CeedCallBackend(CeedVectorAddReference(impl->q_vecs_in[0])); 160 } 161 } 162 163 CeedCallBackend(CeedOperatorSetSetupDone(op)); 164 165 return CEED_ERROR_SUCCESS; 166 } 167 168 //------------------------------------------------------------------------------ 169 // Setup Operator Inputs 170 //------------------------------------------------------------------------------ 171 static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 172 CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], 173 CeedOperator_Blocked *impl, CeedRequest *request) { 174 CeedEvalMode eval_mode; 175 CeedVector vec; 176 uint64_t state; 177 178 for (CeedInt i = 0; i < num_input_fields; i++) { 179 // Get input vector 180 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 181 if (vec == CEED_VECTOR_ACTIVE) { 182 if (skip_active) continue; 183 else vec = in_vec; 184 } 185 186 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 187 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 188 } else { 189 // Restrict 190 CeedCallBackend(CeedVectorGetState(vec, &state)); 191 if (state != impl->input_states[i] || vec == in_vec) { 192 CeedCallBackend(CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); 193 impl->input_states[i] = state; 194 } 195 // Get evec 196 CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i])); 197 } 198 } 199 return CEED_ERROR_SUCCESS; 200 } 201 202 //------------------------------------------------------------------------------ 203 // Input Basis Action 204 //------------------------------------------------------------------------------ 205 static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 206 CeedInt num_input_fields, CeedInt blk_size, bool skip_active, 207 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 208 CeedInt dim, elem_size, size; 209 CeedElemRestriction elem_restr; 210 CeedEvalMode eval_mode; 211 CeedBasis basis; 212 213 for (CeedInt i = 0; i < num_input_fields; i++) { 214 // Skip active input 215 if (skip_active) { 216 CeedVector vec; 217 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 218 if (vec == CEED_VECTOR_ACTIVE) continue; 219 } 220 221 // Get elem_size, eval_mode, size 222 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr)); 223 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 224 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 225 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 226 // Basis action 227 switch (eval_mode) { 228 case CEED_EVAL_NONE: 229 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * Q * size])); 230 break; 231 case CEED_EVAL_INTERP: 232 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 233 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * elem_size * size])); 234 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, impl->e_vecs_in[i], impl->q_vecs_in[i])); 235 break; 236 case CEED_EVAL_GRAD: 237 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 238 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 239 CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][e * elem_size * size / dim])); 240 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, impl->e_vecs_in[i], impl->q_vecs_in[i])); 241 break; 242 case CEED_EVAL_WEIGHT: 243 break; // No action 244 // LCOV_EXCL_START 245 case CEED_EVAL_DIV: 246 case CEED_EVAL_CURL: { 247 CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); 248 Ceed ceed; 249 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 250 return CeedError(ceed, CEED_ERROR_BACKEND, "Ceed evaluation mode not implemented"); 251 // LCOV_EXCL_STOP 252 } 253 } 254 } 255 return CEED_ERROR_SUCCESS; 256 } 257 258 //------------------------------------------------------------------------------ 259 // Output Basis Action 260 //------------------------------------------------------------------------------ 261 static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, 262 CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, 263 CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 264 CeedInt dim, elem_size, size; 265 CeedElemRestriction elem_restr; 266 CeedEvalMode eval_mode; 267 CeedBasis basis; 268 269 for (CeedInt i = 0; i < num_output_fields; i++) { 270 // Get elem_size, eval_mode, size 271 CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr)); 272 CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size)); 273 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 274 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 275 // Basis action 276 switch (eval_mode) { 277 case CEED_EVAL_NONE: 278 break; // No action 279 case CEED_EVAL_INTERP: 280 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 281 CeedCallBackend( 282 CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * elem_size * size])); 283 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, CEED_EVAL_INTERP, impl->q_vecs_out[i], impl->e_vecs_out[i])); 284 break; 285 case CEED_EVAL_GRAD: 286 CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); 287 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 288 CeedCallBackend( 289 CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * elem_size * size / dim])); 290 CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, CEED_EVAL_GRAD, impl->q_vecs_out[i], impl->e_vecs_out[i])); 291 break; 292 // LCOV_EXCL_START 293 case CEED_EVAL_WEIGHT: { 294 Ceed ceed; 295 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 296 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); 297 } 298 case CEED_EVAL_DIV: 299 case CEED_EVAL_CURL: { 300 Ceed ceed; 301 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 302 return CeedError(ceed, CEED_ERROR_BACKEND, "Ceed evaluation mode not implemented"); 303 // LCOV_EXCL_STOP 304 } 305 } 306 } 307 return CEED_ERROR_SUCCESS; 308 } 309 310 //------------------------------------------------------------------------------ 311 // Restore Input Vectors 312 //------------------------------------------------------------------------------ 313 static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields, 314 bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { 315 CeedEvalMode eval_mode; 316 317 for (CeedInt i = 0; i < num_input_fields; i++) { 318 // Skip active inputs 319 if (skip_active) { 320 CeedVector vec; 321 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 322 if (vec == CEED_VECTOR_ACTIVE) continue; 323 } 324 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 325 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 326 } else { 327 CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i])); 328 } 329 } 330 return CEED_ERROR_SUCCESS; 331 } 332 333 //------------------------------------------------------------------------------ 334 // Operator Apply 335 //------------------------------------------------------------------------------ 336 static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { 337 CeedOperator_Blocked *impl; 338 CeedCallBackend(CeedOperatorGetData(op, &impl)); 339 const CeedInt blk_size = 8; 340 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 341 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 342 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 343 CeedInt num_blks = (num_elem / blk_size) + !!(num_elem % blk_size); 344 CeedQFunction qf; 345 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 346 CeedOperatorField *op_input_fields, *op_output_fields; 347 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 348 CeedQFunctionField *qf_input_fields, *qf_output_fields; 349 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 350 CeedEvalMode eval_mode; 351 CeedVector vec; 352 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {0}; 353 354 // Setup 355 CeedCallBackend(CeedOperatorSetup_Blocked(op)); 356 357 // Restriction only operator 358 if (impl->is_identity_restr_op) { 359 CeedCallBackend(CeedElemRestrictionApply(impl->blk_restr[0], CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request)); 360 CeedCallBackend(CeedElemRestrictionApply(impl->blk_restr[1], CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request)); 361 return CEED_ERROR_SUCCESS; 362 } 363 364 // Input Evecs and Restriction 365 CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); 366 367 // Output Evecs 368 for (CeedInt i = 0; i < num_output_fields; i++) { 369 CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); 370 } 371 372 // Loop through elements 373 for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) { 374 // Output pointers 375 for (CeedInt i = 0; i < num_output_fields; i++) { 376 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 377 if (eval_mode == CEED_EVAL_NONE) { 378 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 379 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][e * Q * size])); 380 } 381 } 382 383 // Input basis apply 384 CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, false, e_data_full, impl)); 385 386 // Q function 387 if (!impl->is_identity_qf) { 388 CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out)); 389 } 390 391 // Output basis apply 392 CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, blk_size, num_input_fields, num_output_fields, op, 393 e_data_full, impl)); 394 } 395 396 // Output restriction 397 for (CeedInt i = 0; i < num_output_fields; i++) { 398 // Restore evec 399 CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); 400 // Get output vector 401 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 402 // Active 403 if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; 404 // Restrict 405 CeedCallBackend( 406 CeedElemRestrictionApply(impl->blk_restr[i + impl->num_inputs], CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request)); 407 } 408 409 // Restore input arrays 410 CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl)); 411 412 return CEED_ERROR_SUCCESS; 413 } 414 415 //------------------------------------------------------------------------------ 416 // Core code for assembling linear QFunction 417 //------------------------------------------------------------------------------ 418 static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator op, bool build_objects, CeedVector *assembled, 419 CeedElemRestriction *rstr, CeedRequest *request) { 420 CeedOperator_Blocked *impl; 421 CeedCallBackend(CeedOperatorGetData(op, &impl)); 422 const CeedInt blk_size = 8; 423 CeedInt Q, num_input_fields, num_output_fields, num_elem, size; 424 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 425 CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 426 CeedInt num_blks = (num_elem / blk_size) + !!(num_elem % blk_size); 427 CeedQFunction qf; 428 CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 429 CeedOperatorField *op_input_fields, *op_output_fields; 430 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 431 CeedQFunctionField *qf_input_fields, *qf_output_fields; 432 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 433 CeedVector vec, l_vec = impl->qf_l_vec; 434 CeedInt num_active_in = impl->num_active_in, num_active_out = impl->num_active_out; 435 CeedSize q_size; 436 CeedVector *active_in = impl->qf_active_in; 437 CeedScalar *a, *tmp; 438 Ceed ceed; 439 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 440 CeedScalar *e_data_full[2 * CEED_FIELD_MAX] = {0}; 441 442 // Setup 443 CeedCallBackend(CeedOperatorSetup_Blocked(op)); 444 445 // Check for identity 446 if (impl->is_identity_qf) { 447 // LCOV_EXCL_START 448 return CeedError(ceed, CEED_ERROR_BACKEND, "Assembling identity QFunctions not supported"); 449 // LCOV_EXCL_STOP 450 } 451 452 // Input Evecs and Restriction 453 CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request)); 454 455 // Count number of active input fields 456 if (!num_active_in) { 457 for (CeedInt i = 0; i < num_input_fields; i++) { 458 // Get input vector 459 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 460 // Check if active input 461 if (vec == CEED_VECTOR_ACTIVE) { 462 CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size)); 463 CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); 464 CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp)); 465 CeedCallBackend(CeedRealloc(num_active_in + size, &active_in)); 466 for (CeedInt field = 0; field < size; field++) { 467 q_size = (CeedSize)Q * blk_size; 468 CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field])); 469 CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &tmp[field * Q * blk_size])); 470 } 471 num_active_in += size; 472 CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp)); 473 } 474 } 475 impl->num_active_in = num_active_in; 476 impl->qf_active_in = active_in; 477 } 478 479 // Count number of active output fields 480 if (!num_active_out) { 481 for (CeedInt i = 0; i < num_output_fields; i++) { 482 // Get output vector 483 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 484 // Check if active output 485 if (vec == CEED_VECTOR_ACTIVE) { 486 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size)); 487 num_active_out += size; 488 } 489 } 490 impl->num_active_out = num_active_out; 491 } 492 493 // Check sizes 494 if (!num_active_in || !num_active_out) { 495 // LCOV_EXCL_START 496 return CeedError(ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 497 // LCOV_EXCL_STOP 498 } 499 500 // Setup Lvec 501 if (!l_vec) { 502 CeedSize l_size = (CeedSize)num_blks * blk_size * Q * num_active_in * num_active_out; 503 CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec)); 504 impl->qf_l_vec = l_vec; 505 } 506 CeedCallBackend(CeedVectorGetArrayWrite(l_vec, CEED_MEM_HOST, &a)); 507 508 // Build objects if needed 509 CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q}; 510 if (build_objects) { 511 // Create output restriction 512 CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, num_active_in * num_active_out, num_active_in * num_active_out * num_elem * Q, 513 strides, rstr)); 514 // Create assembled vector 515 CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out; 516 CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled)); 517 } 518 519 // Loop through elements 520 for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) { 521 // Input basis apply 522 CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, true, e_data_full, impl)); 523 524 // Assemble QFunction 525 for (CeedInt in = 0; in < num_active_in; in++) { 526 // Set Inputs 527 CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0)); 528 if (num_active_in > 1) { 529 CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0)); 530 } 531 // Set Outputs 532 for (CeedInt out = 0; out < num_output_fields; out++) { 533 // Get output vector 534 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 535 // Check if active output 536 if (vec == CEED_VECTOR_ACTIVE) { 537 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, a)); 538 CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size)); 539 a += size * Q * blk_size; // Advance the pointer by the size of the output 540 } 541 } 542 // Apply QFunction 543 CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out)); 544 } 545 } 546 547 // Un-set output Qvecs to prevent accidental overwrite of Assembled 548 for (CeedInt out = 0; out < num_output_fields; out++) { 549 // Get output vector 550 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec)); 551 // Check if active output 552 if (vec == CEED_VECTOR_ACTIVE) { 553 CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, NULL)); 554 } 555 } 556 557 // Restore input arrays 558 CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl)); 559 560 // Output blocked restriction 561 CeedCallBackend(CeedVectorRestoreArray(l_vec, &a)); 562 CeedCallBackend(CeedVectorSetValue(*assembled, 0.0)); 563 CeedElemRestriction blk_rstr = impl->qf_blk_rstr; 564 if (!impl->qf_blk_rstr) { 565 CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, num_active_in * num_active_out, 566 num_active_in * num_active_out * num_elem * Q, strides, &blk_rstr)); 567 impl->qf_blk_rstr = blk_rstr; 568 } 569 CeedCallBackend(CeedElemRestrictionApply(blk_rstr, CEED_TRANSPOSE, l_vec, *assembled, request)); 570 571 return CEED_ERROR_SUCCESS; 572 } 573 574 //------------------------------------------------------------------------------ 575 // Assemble Linear QFunction 576 //------------------------------------------------------------------------------ 577 static int CeedOperatorLinearAssembleQFunction_Blocked(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 578 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, true, assembled, rstr, request); 579 } 580 581 //------------------------------------------------------------------------------ 582 // Update Assembled Linear QFunction 583 //------------------------------------------------------------------------------ 584 static int CeedOperatorLinearAssembleQFunctionUpdate_Blocked(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 585 return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, false, &assembled, &rstr, request); 586 } 587 588 //------------------------------------------------------------------------------ 589 // Operator Destroy 590 //------------------------------------------------------------------------------ 591 static int CeedOperatorDestroy_Blocked(CeedOperator op) { 592 CeedOperator_Blocked *impl; 593 CeedCallBackend(CeedOperatorGetData(op, &impl)); 594 595 for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { 596 CeedCallBackend(CeedElemRestrictionDestroy(&impl->blk_restr[i])); 597 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); 598 } 599 CeedCallBackend(CeedFree(&impl->blk_restr)); 600 CeedCallBackend(CeedFree(&impl->e_vecs_full)); 601 CeedCallBackend(CeedFree(&impl->input_states)); 602 603 for (CeedInt i = 0; i < impl->num_inputs; i++) { 604 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); 605 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i])); 606 } 607 CeedCallBackend(CeedFree(&impl->e_vecs_in)); 608 CeedCallBackend(CeedFree(&impl->q_vecs_in)); 609 610 for (CeedInt i = 0; i < impl->num_outputs; i++) { 611 CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); 612 CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); 613 } 614 CeedCallBackend(CeedFree(&impl->e_vecs_out)); 615 CeedCallBackend(CeedFree(&impl->q_vecs_out)); 616 617 // QFunction assembly data 618 for (CeedInt i = 0; i < impl->num_active_in; i++) { 619 CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i])); 620 } 621 CeedCallBackend(CeedFree(&impl->qf_active_in)); 622 CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec)); 623 CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_blk_rstr)); 624 625 CeedCallBackend(CeedFree(&impl)); 626 return CEED_ERROR_SUCCESS; 627 } 628 629 //------------------------------------------------------------------------------ 630 // Operator Create 631 //------------------------------------------------------------------------------ 632 int CeedOperatorCreate_Blocked(CeedOperator op) { 633 Ceed ceed; 634 CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 635 CeedOperator_Blocked *impl; 636 637 CeedCallBackend(CeedCalloc(1, &impl)); 638 CeedCallBackend(CeedOperatorSetData(op, impl)); 639 640 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Blocked)); 641 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Blocked)); 642 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Blocked)); 643 CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Blocked)); 644 return CEED_ERROR_SUCCESS; 645 } 646 //------------------------------------------------------------------------------ 647