xref: /libCEED/interface/ceed-tensor.c (revision a2e5d304d0c7d96eecfcbbd32f1ea5194beb84ca)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 /// @file
13 /// Implementation of CeedTensorContract interfaces
14 
15 /// ----------------------------------------------------------------------------
16 /// CeedTensorContract Backend API
17 /// ----------------------------------------------------------------------------
18 /// @addtogroup CeedBasisBackend
19 /// @{
20 
21 /**
22   @brief Create a CeedTensorContract object for a CeedBasis
23 
24   @param[in]  ceed     Ceed object where the CeedTensorContract will be created
25   @param[in]  basis    CeedBasis for which the tensor contraction will be used
26   @param[out] contract Address of the variable where the newly created CeedTensorContract will be stored.
27 
28   @return An error code: 0 - success, otherwise - failure
29 
30   @ref Backend
31 **/
32 int CeedTensorContractCreate(Ceed ceed, CeedBasis basis, CeedTensorContract *contract) {
33   if (!ceed->TensorContractCreate) {
34     Ceed delegate;
35     CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
36 
37     if (!delegate) {
38       // LCOV_EXCL_START
39       return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
40       // LCOV_EXCL_STOP
41     }
42 
43     CeedCall(CeedTensorContractCreate(delegate, basis, contract));
44     return CEED_ERROR_SUCCESS;
45   }
46 
47   CeedCall(CeedCalloc(1, contract));
48 
49   (*contract)->ceed = ceed;
50   CeedCall(CeedReference(ceed));
51   CeedCall(ceed->TensorContractCreate(basis, *contract));
52   return CEED_ERROR_SUCCESS;
53 }
54 
55 /**
56   @brief Apply tensor contraction
57 
58     Contracts on the middle index
59     NOTRANSPOSE: v_ajc = t_jb u_abc
60     TRANSPOSE:   v_ajc = t_bj u_abc
61     If add != 0, "=" is replaced by "+="
62 
63   @param[in]  contract CeedTensorContract to use
64   @param[in]  A        First index of u, v
65   @param[in]  B        Middle index of u, one index of t
66   @param[in]  C        Last index of u, v
67   @param[in]  J        Middle index of v, one index of t
68   @param[in]  t        Tensor array to contract against
69   @param[in]  t_mode   Transpose mode for t, \ref CEED_NOTRANSPOSE for t_jb \ref CEED_TRANSPOSE for t_bj
70   @param[in]  add      Add mode
71   @param[in]  u        Input array
72   @param[out] v        Output array
73 
74   @return An error code: 0 - success, otherwise - failure
75 
76   @ref Backend
77 **/
78 int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
79                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
80   CeedCall(contract->Apply(contract, A, B, C, J, t, t_mode, add, u, v));
81   return CEED_ERROR_SUCCESS;
82 }
83 
84 /**
85   @brief Get Ceed associated with a CeedTensorContract
86 
87   @param[in]  contract CeedTensorContract
88   @param[out] ceed     Variable to store Ceed
89 
90   @return An error code: 0 - success, otherwise - failure
91 
92   @ref Backend
93 **/
94 int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
95   *ceed = contract->ceed;
96   return CEED_ERROR_SUCCESS;
97 }
98 
99 /**
100   @brief Get backend data of a CeedTensorContract
101 
102   @param[in]  contract CeedTensorContract
103   @param[out] data     Variable to store data
104 
105   @return An error code: 0 - success, otherwise - failure
106 
107   @ref Backend
108 **/
109 int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
110   *(void **)data = contract->data;
111   return CEED_ERROR_SUCCESS;
112 }
113 
114 /**
115   @brief Set backend data of a CeedTensorContract
116 
117   @param[in,out] contract CeedTensorContract
118   @param[in]     data     Data to set
119 
120   @return An error code: 0 - success, otherwise - failure
121 
122   @ref Backend
123 **/
124 int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
125   contract->data = data;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 /**
130   @brief Increment the reference counter for a CeedTensorContract
131 
132   @param[in,out] contract CeedTensorContract to increment the reference counter
133 
134   @return An error code: 0 - success, otherwise - failure
135 
136   @ref Backend
137 **/
138 int CeedTensorContractReference(CeedTensorContract contract) {
139   contract->ref_count++;
140   return CEED_ERROR_SUCCESS;
141 }
142 
143 /**
144   @brief Destroy a CeedTensorContract
145 
146   @param[in,out] contract CeedTensorContract to destroy
147 
148   @return An error code: 0 - success, otherwise - failure
149 
150   @ref Backend
151 **/
152 int CeedTensorContractDestroy(CeedTensorContract *contract) {
153   if (!*contract || --(*contract)->ref_count > 0) {
154     *contract = NULL;
155     return CEED_ERROR_SUCCESS;
156   }
157   if ((*contract)->Destroy) {
158     CeedCall((*contract)->Destroy(*contract));
159   }
160   CeedCall(CeedDestroy(&(*contract)->ceed));
161   CeedCall(CeedFree(contract));
162   return CEED_ERROR_SUCCESS;
163 }
164 
165 /// @}
166