xref: /libCEED/backends/ref/ceed-ref-tensor.c (revision 93b6d8191bb649fce95198206c6bff32567615d1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include "ceed-ref.h"
11 
12 //------------------------------------------------------------------------------
13 // Tensor Contract Apply
14 //------------------------------------------------------------------------------
15 static int CeedTensorContractApply_Ref(CeedTensorContract contract, CeedInt A,
16                                        CeedInt B, CeedInt C, CeedInt J,
17                                        const CeedScalar *restrict t,
18                                        CeedTransposeMode t_mode, const CeedInt add,
19                                        const CeedScalar *restrict u,
20                                        CeedScalar *restrict v) {
21   CeedInt t_stride_0 = B, t_stride_1 = 1;
22   if (t_mode == CEED_TRANSPOSE) {
23     t_stride_0 = 1; t_stride_1 = J;
24   }
25 
26   if (!add)
27     for (CeedInt q=0; q<A*J*C; q++)
28       v[q] = (CeedScalar) 0.0;
29 
30   for (CeedInt a=0; a<A; a++)
31     for (CeedInt b=0; b<B; b++)
32       for (CeedInt j=0; j<J; j++) {
33         CeedScalar tq = t[j*t_stride_0 + b*t_stride_1];
34         for (CeedInt c=0; c<C; c++)
35           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
36       }
37   return CEED_ERROR_SUCCESS;
38 }
39 
40 //------------------------------------------------------------------------------
41 // Tensor Contract Destroy
42 //------------------------------------------------------------------------------
43 static int CeedTensorContractDestroy_Ref(CeedTensorContract contract) {
44   return CEED_ERROR_SUCCESS;
45 }
46 
47 //------------------------------------------------------------------------------
48 // Tensor Contract Create
49 //------------------------------------------------------------------------------
50 int CeedTensorContractCreate_Ref(CeedBasis basis, CeedTensorContract contract) {
51   int ierr;
52   Ceed ceed;
53   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr);
54 
55   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
56                                 CeedTensorContractApply_Ref); CeedChkBackend(ierr);
57   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
58                                 CeedTensorContractDestroy_Ref); CeedChkBackend(ierr);
59 
60   return CEED_ERROR_SUCCESS;
61 }
62 //------------------------------------------------------------------------------
63