xref: /libCEED/rust/libceed-sys/c-src/interface/ceed-tensor.c (revision 2b730f8b5a9c809740a0b3b302db43a719c636b1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 /// @file
13 /// Implementation of CeedTensorContract interfaces
14 
15 /// ----------------------------------------------------------------------------
16 /// CeedTensorContract Backend API
17 /// ----------------------------------------------------------------------------
18 /// @addtogroup CeedBasisBackend
19 /// @{
20 
21 /**
22   @brief Create a CeedTensorContract object for a CeedBasis
23 
24   @param ceed           A Ceed object where the CeedTensorContract will be created
25   @param basis          CeedBasis for which the tensor contraction will be used
26   @param[out] contract  Address of the variable where the newly created
27                           CeedTensorContract will be stored.
28 
29   @return An error code: 0 - success, otherwise - failure
30 
31   @ref Backend
32 **/
33 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
34   if (!ceed->TensorContractCreate) {
35     Ceed delegate;
36     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
37 
38     if (!delegate) {
39       // LCOV_EXCL_START
40       return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
41       // LCOV_EXCL_STOP
42     }
43 
44     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
45     return CEED_ERROR_SUCCESS;
46   }
47 
48   CeedCall(CeedCalloc(1, contract));
49 
50   (*contract)->ceed = ceed;
51   CeedCall(CeedReference(ceed));
52   CeedCall(ceed->TensorContractCreate(basis, *contract));
53   return CEED_ERROR_SUCCESS;
54 }
55 
56 /**
57   @brief Apply tensor contraction
58 
59     Contracts on the middle index
60     NOTRANSPOSE: v_ajc = t_jb u_abc
61     TRANSPOSE:   v_ajc = t_bj u_abc
62     If add != 0, "=" is replaced by "+="
63 
64   @param contract  CeedTensorContract to use
65   @param A         First index of u, v
66   @param B         Middle index of u, one index of t
67   @param C         Last index of u, v
68   @param J         Middle index of v, one index of t
69   @param[in] t     Tensor array to contract against
70   @param t_mode    Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb
71                      \ref CEED_TRANSPOSE for t_bj
72   @param add       Add mode
73   @param[in] u     Input array
74   @param[out] v    Output array
75 
76   @return An error code: 0 - success, otherwise - failure
77 
78   @ref Backend
79 **/
80 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
81                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
82   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
83   return CEED_ERROR_SUCCESS;
84 }
85 
86 /**
87   @brief Get Ceed associated with a CeedTensorContract
88 
89   @param contract   CeedTensorContract
90   @param[out] ceed  Variable to store Ceed
91 
92   @return An error code: 0 - success, otherwise - failure
93 
94   @ref Backend
95 **/
96 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
97   *ceed = contract->ceed;
98   return CEED_ERROR_SUCCESS;
99 }
100 
101 /**
102   @brief Get backend data of a CeedTensorContract
103 
104   @param contract   CeedTensorContract
105   @param[out] data  Variable to store data
106 
107   @return An error code: 0 - success, otherwise - failure
108 
109   @ref Backend
110 **/
111 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
112   *(void **)data = contract->data;
113   return CEED_ERROR_SUCCESS;
114 }
115 
116 /**
117   @brief Set backend data of a CeedTensorContract
118 
119   @param[out] contract  CeedTensorContract
120   @param data           Data to set
121 
122   @return An error code: 0 - success, otherwise - failure
123 
124   @ref Backend
125 **/
126 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
127   contract->data = data;
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 /**
132   @brief Increment the reference counter for a CeedTensorContract
133 
134   @param contract  CeedTensorContract to increment the reference counter
135 
136   @return An error code: 0 - success, otherwise - failure
137 
138   @ref Backend
139 **/
140 int CeedTensorContractReference(CeedTensorContract contract) {
141   contract->ref_count++;
142   return CEED_ERROR_SUCCESS;
143 }
144 
145 /**
146   @brief Destroy a CeedTensorContract
147 
148   @param contract  CeedTensorContract to destroy
149 
150   @return An error code: 0 - success, otherwise - failure
151 
152   @ref Backend
153 **/
154 int CeedTensorContractDestroy(CeedTensorContract *contract) {
155   if (!*contract || --(*contract)->ref_count > 0) return CEED_ERROR_SUCCESS;
156   if ((*contract)->Destroy) {
157     CeedCall((*contract)->Destroy(*contract));
158   }
159   CeedCall(CeedDestroy(&(*contract)->ceed));
160   CeedCall(CeedFree(contract));
161   return CEED_ERROR_SUCCESS;
162 }
163 
164 /// @}
165