xref: /libCEED/backends/ref/ceed-ref-tensor.c (revision 05efb9567ae7c8fac3dbf1468da479ad6ed8c3a9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 
11 #include "ceed-ref.h"
12 
13 //------------------------------------------------------------------------------
14 // Tensor Contract Apply
15 //------------------------------------------------------------------------------
16 static int CeedTensorContractApply_Ref(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
17                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
18   CeedInt t_stride_0 = B, t_stride_1 = 1;
19   if (t_mode == CEED_TRANSPOSE) {
20     t_stride_0 = 1;
21     t_stride_1 = J;
22   }
23 
24   if (!add) {
25     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
26   }
27 
28   for (CeedInt a = 0; a < A; a++) {
29     for (CeedInt b = 0; b < B; b++) {
30       for (CeedInt j = 0; j < J; j++) {
31         CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
32         for (CeedInt c = 0; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
33       }
34     }
35   }
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Tensor Contract Destroy
41 //------------------------------------------------------------------------------
42 static int CeedTensorContractDestroy_Ref(CeedTensorContract contract) { return CEED_ERROR_SUCCESS; }
43 
44 //------------------------------------------------------------------------------
45 // Tensor Contract Create
46 //------------------------------------------------------------------------------
47 int CeedTensorContractCreate_Ref(CeedBasis basis, CeedTensorContract contract) {
48   Ceed ceed;
49   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
50 
51   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Ref));
52   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", CeedTensorContractDestroy_Ref));
53 
54   return CEED_ERROR_SUCCESS;
55 }
56 
57 //------------------------------------------------------------------------------
58