// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed
79b2a10adSJeremy L Thompson
849aac155SJeremy L Thompson #include <ceed.h>
99b2a10adSJeremy L Thompson #include <ceed/backend.h>
102b730f8bSJeremy L Thompson
119b2a10adSJeremy L Thompson #include "ceed-opt.h"
129b2a10adSJeremy L Thompson
//------------------------------------------------------------------------------
// Tensor Contract Core loop
//------------------------------------------------------------------------------
// Accumulate the contraction
//   v[(a * J + j) * C + c] += t(j, b) * u[(a * B + b) * C + c]
// over all a in [0, A), b in [0, B), j in [0, J), c in [0, C).
//
// t(j, b) is read from t[j * B + b] by default and from t[b * J + j] when
// t_mode is CEED_TRANSPOSE.
//
// This routine only ever adds into v; it never clears it.
// The `contract` and `add` arguments are not read here.
static inline int CeedTensorContractApply_Core_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v) {
  // Strides used to walk t along the j and b indices, respectively
  const int      is_transpose = (t_mode == CEED_TRANSPOSE);
  const CeedInt  stride_j     = is_transpose ? 1 : B;
  const CeedInt  stride_b     = is_transpose ? J : 1;

  for (CeedInt a = 0; a < A; a++) {
    for (CeedInt b = 0; b < B; b++) {
      // Contiguous run of C input values for this (a, b)
      const CeedScalar *u_row = &u[(a * B + b) * C];

      for (CeedInt j = 0; j < J; j++) {
        const CeedScalar weight = t[j * stride_j + b * stride_b];
        // Contiguous run of C output values for this (a, j)
        CeedScalar *v_row = &v[(a * J + j) * C];

        for (CeedInt c = 0; c < C; c++) v_row[c] += weight * u_row[c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
369b2a10adSJeremy L Thompson
//------------------------------------------------------------------------------
// Tensor Contract Apply
//------------------------------------------------------------------------------
// Apply the tensor contraction of t with u, writing the result to v.
//
// When `add` is zero, the first A * J * C entries of v are cleared before the
// contraction so the result overwrites v; otherwise the result is accumulated
// into the existing contents of v.
//
// Returns the error code produced by the core loop.
static int CeedTensorContractApply_Opt(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                       CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  // The core loop only accumulates, so clear the output first when overwriting
  if (!add) {
    const CeedInt v_size = A * J * C;

    for (CeedInt q = 0; q < v_size; q++) v[q] = (CeedScalar)0.0;
  }

  // Pass the literal 1 when C == 1; presumably this lets the inlined core loop
  // be specialized for that case by the compiler
  if (C == 1) return CeedTensorContractApply_Core_Opt(contract, A, B, 1, J, t, t_mode, add, u, v);
  return CeedTensorContractApply_Core_Opt(contract, A, B, C, J, t, t_mode, add, u, v);
}
509b2a10adSJeremy L Thompson
//------------------------------------------------------------------------------
// Tensor Contract Create
//------------------------------------------------------------------------------
// Set up a tensor contraction object for this backend by registering
// CeedTensorContractApply_Opt as its "Apply" backend function.
//
// Any error from the registration call is propagated by CeedCallBackend.
int CeedTensorContractCreate_Opt(CeedTensorContract contract) {
  CeedCallBackend(
      CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Opt));
  return CEED_ERROR_SUCCESS;
}
582a86cc9dSSebastian Grimberg
//------------------------------------------------------------------------------
60