xref: /petsc/src/binding/petsc4py/src/petsc4py/PETSc/Regressor.pyx (revision cc936d47a49e2d9025998b90c561cf03e422e9df)
class RegressorType(object):
    """Available REGRESSOR solver types.

    The values are the type names accepted by `Regressor.setType`.

    See Also
    --------
    petsc.PetscRegressorType

    """
    # Type name of the linear regressor implementation.
    LINEAR = S_(PETSCREGRESSORLINEAR)
10
11
class RegressorLinearType(object):
    """Kinds of linear regression a linear regressor can perform.

    See Also
    --------
    petsc.PetscRegressorLinearType

    """
    # Ordinary least squares.
    OLS = REGRESSOR_LINEAR_OLS
    # Least squares with an L1 (lasso) regularizer.
    LASSO = REGRESSOR_LINEAR_LASSO
    # Least squares with an L2 (ridge) regularizer.
    RIDGE = REGRESSOR_LINEAR_RIDGE
22
23
cdef class Regressor(Object):
    """Regression solver.

    REGRESSOR is described in the `PETSc manual <petsc:manual/regressor>`.

    See Also
    --------
    petsc.PetscRegressor

    """

    Type = RegressorType
    LinearType = RegressorLinearType

    def __cinit__(self):
        self.obj = <PetscObject*> &self.regressor
        self.regressor = NULL

    def view(self, Viewer viewer=None) -> None:
        """View the solver.

        Collective.

        Parameters
        ----------
        viewer
            A `Viewer` instance or `None` for the default viewer.

        See Also
        --------
        petsc.PetscRegressorView

        """
        cdef PetscViewer cviewer = NULL
        if viewer is not None: cviewer = viewer.vwr
        CHKERR(PetscRegressorView(self.regressor, cviewer))

    def create(self, comm=None) -> Self:
        """Create a REGRESSOR solver.

        Collective.

        Parameters
        ----------
        comm
            MPI communicator, defaults to `Sys.getDefaultComm`.

        See Also
        --------
        Sys.getDefaultComm, petsc.PetscRegressorCreate

        """
        cdef MPI_Comm ccomm = def_Comm(comm, PETSC_COMM_DEFAULT)
        cdef PetscRegressor newregressor = NULL
        CHKERR(PetscRegressorCreate(ccomm, &newregressor))
        # Release any previously held object before adopting the new one.
        PetscCLEAR(self.obj); self.regressor = newregressor
        return self

    def setRegularizerWeight(self, weight: float) -> None:
        """Set the weight to be used for the regularizer.

        Logically collective.

        Parameters
        ----------
        weight
            The weight applied to the regularization term.

        See Also
        --------
        setType, petsc.PetscRegressorSetRegularizerWeight

        """
        # Convert explicitly through the project's real-conversion helper.
        cdef PetscReal cweight = asReal(weight)
        CHKERR(PetscRegressorSetRegularizerWeight(self.regressor, cweight))

    def setFromOptions(self) -> None:
        """Configure the solver from the options database.

        Collective.

        See Also
        --------
        petsc_options, petsc.PetscRegressorSetFromOptions

        """
        CHKERR(PetscRegressorSetFromOptions(self.regressor))

    def setUp(self) -> None:
        """Set up the internal data structures for using the solver.

        Collective.

        See Also
        --------
        petsc.PetscRegressorSetUp

        """
        CHKERR(PetscRegressorSetUp(self.regressor))

    def fit(self, Mat X, Vec y) -> None:
        """Fit the regression model to the training data.

        Collective.

        Parameters
        ----------
        X
            The matrix of training data.
        y
            The vector of target values from the training dataset.

        See Also
        --------
        predict, petsc.PetscRegressorFit

        """
        CHKERR(PetscRegressorFit(self.regressor, X.mat, y.vec))

    def predict(self, Mat X, Vec y) -> None:
        """Predict labels for a set of observations using the fitted model.

        Collective.

        Parameters
        ----------
        X
            The matrix of unlabeled observations.
        y
            The vector of predicted labels.

        See Also
        --------
        fit, petsc.PetscRegressorPredict

        """
        CHKERR(PetscRegressorPredict(self.regressor, X.mat, y.vec))

    def getTAO(self) -> TAO:
        """Return the underlying `TAO` object.

        Not collective.

        See Also
        --------
        getLinearKSP, petsc.PetscRegressorGetTao

        """
        cdef TAO tao = TAO()
        CHKERR(PetscRegressorGetTao(self.regressor, &tao.tao))
        # The C getter returns a borrowed reference; take our own.
        CHKERR(PetscINCREF(tao.obj))
        return tao

    def reset(self) -> None:
        """Destroy internal data structures of the solver.

        Collective.

        See Also
        --------
        petsc.PetscRegressorReset

        """
        CHKERR(PetscRegressorReset(self.regressor))

    def destroy(self) -> Self:
        """Destroy the solver.

        Collective.

        See Also
        --------
        petsc.PetscRegressorDestroy

        """
        CHKERR(PetscRegressorDestroy(&self.regressor))
        return self

    def setType(self, regressor_type: Type | str) -> None:
        """Set the type of the solver.

        Logically collective.

        Parameters
        ----------
        regressor_type
            The type of the solver.

        See Also
        --------
        getType, petsc.PetscRegressorSetType

        """
        cdef PetscRegressorType cval = NULL
        # Keep the bytes object alive while cval points into it.
        regressor_type = str2bytes(regressor_type, &cval)
        CHKERR(PetscRegressorSetType(self.regressor, cval))

    def getType(self) -> str:
        """Return the type of the solver.

        Not collective.

        See Also
        --------
        setType, petsc.PetscRegressorGetType

        """
        cdef PetscRegressorType ctype = NULL
        CHKERR(PetscRegressorGetType(self.regressor, &ctype))
        return bytes2str(ctype)

    # --- Linear ---

    def setLinearFitIntercept(self, flag: bool) -> None:
        """Set a flag to indicate that the intercept should be calculated.

        Logically collective.

        Parameters
        ----------
        flag
            Whether the intercept term should be fitted.

        See Also
        --------
        getLinearIntercept, petsc.PetscRegressorLinearSetFitIntercept

        """
        # Normalize any truthy/falsy object to a valid PetscBool value.
        cdef PetscBool fitintercept = PETSC_TRUE if flag else PETSC_FALSE
        CHKERR(PetscRegressorLinearSetFitIntercept(self.regressor, fitintercept))

    def setLinearUseKSP(self, flag: bool) -> None:
        """Set a flag to indicate that `KSP` instead of `TAO` solvers should be used.

        Logically collective.

        Parameters
        ----------
        flag
            Whether `KSP` solvers should be used instead of `TAO`.

        See Also
        --------
        getLinearKSP, petsc.PetscRegressorLinearSetUseKSP

        """
        # Normalize any truthy/falsy object to a valid PetscBool value.
        cdef PetscBool useksp = PETSC_TRUE if flag else PETSC_FALSE
        CHKERR(PetscRegressorLinearSetUseKSP(self.regressor, useksp))

    def getLinearKSP(self) -> KSP:
        """Return the `KSP` context used by the linear regressor.

        Not collective.

        See Also
        --------
        setLinearUseKSP, petsc.PetscRegressorLinearGetKSP

        """
        cdef KSP ksp = KSP()
        CHKERR(PetscRegressorLinearGetKSP(self.regressor, &ksp.ksp))
        # The C getter returns a borrowed reference; take our own.
        CHKERR(PetscINCREF(ksp.obj))
        return ksp

    def getLinearCoefficients(self) -> Vec:
        """Return the vector of fitted coefficients of a linear regression model.

        Not collective.

        See Also
        --------
        getLinearIntercept, petsc.PetscRegressorLinearGetCoefficients

        """
        cdef Vec coeffs = Vec()
        CHKERR(PetscRegressorLinearGetCoefficients(self.regressor, &coeffs.vec))
        # The C getter returns a borrowed reference; take our own.
        CHKERR(PetscINCREF(coeffs.obj))
        return coeffs

    def getLinearIntercept(self) -> Scalar:
        """Return the intercept of a linear regression model.

        Not collective.

        See Also
        --------
        setLinearFitIntercept, petsc.PetscRegressorLinearGetIntercept

        """
        cdef PetscScalar intercept = 0.0
        CHKERR(PetscRegressorLinearGetIntercept(self.regressor, &intercept))
        return toScalar(intercept)

    def setLinearType(self, lineartype: RegressorLinearType) -> None:
        """Set the type of linear regression to be performed.

        Logically collective.

        Parameters
        ----------
        lineartype
            The kind of linear regression, one of `RegressorLinearType`.

        See Also
        --------
        getLinearType, petsc.PetscRegressorLinearSetType

        """
        CHKERR(PetscRegressorLinearSetType(self.regressor, lineartype))

    def getLinearType(self) -> RegressorLinearType:
        """Return the type of the linear regressor.

        Not collective.

        See Also
        --------
        setLinearType, petsc.PetscRegressorLinearGetType

        """
        cdef PetscRegressorLinearType cval = REGRESSOR_LINEAR_OLS
        CHKERR(PetscRegressorLinearGetType(self.regressor, &cval))
        return cval
316
317del RegressorType
318