xref: /petsc/src/mat/impls/shell/shellcnv.c (revision a69119a591a03a9d906b29c0a4e9802e4d7c9795)
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 
3 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat) {
4   Mat          mat;
5   Vec          in, out;
6   PetscScalar *array;
7   PetscInt    *dnnz, *onnz, *dnnzu, *onnzu;
8   PetscInt     cst, Nbs, mbs, nbs, rbs, cbs;
9   PetscInt     im, i, m, n, M, N, *rows, start;
10 
11   PetscFunctionBegin;
12   PetscCall(MatGetOwnershipRange(oldmat, &start, NULL));
13   PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, NULL));
14   PetscCall(MatCreateVecs(oldmat, &in, &out));
15   PetscCall(MatGetLocalSize(oldmat, &m, &n));
16   PetscCall(MatGetSize(oldmat, &M, &N));
17   PetscCall(PetscMalloc1(m, &rows));
18   if (reuse != MAT_REUSE_MATRIX) {
19     PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat));
20     PetscCall(MatSetSizes(mat, m, n, M, N));
21     PetscCall(MatSetType(mat, newtype));
22     PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat));
23     PetscCall(MatGetBlockSizes(mat, &rbs, &cbs));
24     mbs = m / rbs;
25     nbs = n / cbs;
26     Nbs = N / cbs;
27     cst = cst / cbs;
28     PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu));
29     for (i = 0; i < mbs; i++) {
30       dnnz[i]  = nbs;
31       onnz[i]  = Nbs - nbs;
32       dnnzu[i] = PetscMax(nbs - i, 0);
33       onnzu[i] = PetscMax(Nbs - (cst + nbs), 0);
34     }
35     PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu));
36     PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu));
37     PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE));
38     PetscCall(MatSetUp(mat));
39   } else {
40     mat = *newmat;
41     PetscCall(MatZeroEntries(mat));
42   }
43   for (i = 0; i < N; i++) {
44     PetscInt j;
45 
46     PetscCall(VecZeroEntries(in));
47     PetscCall(VecSetValue(in, i, 1., INSERT_VALUES));
48     PetscCall(VecAssemblyBegin(in));
49     PetscCall(VecAssemblyEnd(in));
50     PetscCall(MatMult(oldmat, in, out));
51     PetscCall(VecGetArray(out, &array));
52     for (j = 0, im = 0; j < m; j++) {
53       if (PetscAbsScalar(array[j]) == 0.0) continue;
54       rows[im]  = j + start;
55       array[im] = array[j];
56       im++;
57     }
58     PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES));
59     PetscCall(VecRestoreArray(out, &array));
60   }
61   PetscCall(PetscFree(rows));
62   PetscCall(VecDestroy(&in));
63   PetscCall(VecDestroy(&out));
64   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
65   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
66   if (reuse == MAT_INPLACE_MATRIX) {
67     PetscCall(MatHeaderReplace(oldmat, &mat));
68   } else {
69     *newmat = mat;
70   }
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X) {
75   Mat B;
76 
77   PetscFunctionBegin;
78   PetscCall(MatShellGetContext(A, &B));
79   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
80   PetscCall(MatGetDiagonal(B, X));
81   PetscFunctionReturn(0);
82 }
83 
84 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y) {
85   Mat B;
86 
87   PetscFunctionBegin;
88   PetscCall(MatShellGetContext(A, &B));
89   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
90   PetscCall(MatMult(B, X, Y));
91   PetscFunctionReturn(0);
92 }
93 
94 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y) {
95   Mat B;
96 
97   PetscFunctionBegin;
98   PetscCall(MatShellGetContext(A, &B));
99   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
100   PetscCall(MatMultTranspose(B, X, Y));
101   PetscFunctionReturn(0);
102 }
103 
104 static PetscErrorCode MatDestroy_CF(Mat A) {
105   Mat B;
106 
107   PetscFunctionBegin;
108   PetscCall(MatShellGetContext(A, &B));
109   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
110   PetscCall(MatDestroy(&B));
111   PetscCall(MatShellSetContext(A, NULL));
112   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL));
113   PetscFunctionReturn(0);
114 }
115 
116 typedef struct {
117   void *userdata;
118   PetscErrorCode (*userdestroy)(void *);
119   PetscErrorCode (*numeric)(Mat);
120   MatProductType ptype;
121   Mat            Dwork;
122 } MatMatCF;
123 
124 static PetscErrorCode MatProductDestroy_CF(void *data) {
125   MatMatCF *mmcfdata = (MatMatCF *)data;
126 
127   PetscFunctionBegin;
128   if (mmcfdata->userdestroy) PetscCall((*mmcfdata->userdestroy)(mmcfdata->userdata));
129   PetscCall(MatDestroy(&mmcfdata->Dwork));
130   PetscCall(PetscFree(mmcfdata));
131   PetscFunctionReturn(0);
132 }
133 
134 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data) {
135   MatMatCF *mmcfdata = (MatMatCF *)data;
136 
137   PetscFunctionBegin;
138   PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data");
139   PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation");
140   /* the MATSHELL interface allows us to play with the product data */
141   PetscCall(PetscNew(&C->product));
142   C->product->type  = mmcfdata->ptype;
143   C->product->data  = mmcfdata->userdata;
144   C->product->Dwork = mmcfdata->Dwork;
145   PetscCall(MatShellGetContext(A, &C->product->A));
146   C->product->B = B;
147   PetscCall((*mmcfdata->numeric)(C));
148   PetscCall(PetscFree(C->product));
149   PetscFunctionReturn(0);
150 }
151 
152 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data) {
153   MatMatCF *mmcfdata;
154 
155   PetscFunctionBegin;
156   PetscCall(MatShellGetContext(A, &C->product->A));
157   PetscCall(MatProductSetFromOptions(C));
158   PetscCall(MatProductSymbolic(C));
159   /* the MATSHELL interface does not allow non-empty product data */
160   PetscCall(PetscNew(&mmcfdata));
161 
162   mmcfdata->numeric     = C->ops->productnumeric;
163   mmcfdata->ptype       = C->product->type;
164   mmcfdata->userdata    = C->product->data;
165   mmcfdata->userdestroy = C->product->destroy;
166   mmcfdata->Dwork       = C->product->Dwork;
167 
168   C->product->Dwork   = NULL;
169   C->product->data    = NULL;
170   C->product->destroy = NULL;
171   C->product->A       = A;
172 
173   *data = mmcfdata;
174   PetscFunctionReturn(0);
175 }
176 
177 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */
178 static PetscErrorCode MatProductSetFromOptions_CF(Mat D) {
179   Mat A, B, Ain;
180   void (*Af)(void) = NULL;
181   PetscBool flg;
182 
183   PetscFunctionBegin;
184   MatCheckProduct(D, 1);
185   if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(0);
186   A = D->product->A;
187   B = D->product->B;
188   PetscCall(MatIsShell(A, &flg));
189   if (!flg) PetscFunctionReturn(0);
190   PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af));
191   if (Af == (void (*)(void))MatProductSetFromOptions_CF) {
192     PetscCall(MatShellGetContext(A, &Ain));
193   } else PetscFunctionReturn(0);
194   D->product->A = Ain;
195   PetscCall(MatProductSetFromOptions(D));
196   D->product->A = A;
197   if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */
198     PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL));
199     PetscCall(MatProductSetFromOptions(D));
200   }
201   PetscFunctionReturn(0);
202 }
203 
204 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B) {
205   Mat       M;
206   PetscBool flg;
207 
208   PetscFunctionBegin;
209   PetscCall(PetscStrcmp(newtype, MATSHELL, &flg));
210   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL");
211   if (reuse == MAT_INITIAL_MATRIX) {
212     PetscCall(PetscObjectReference((PetscObject)A));
213     PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M));
214     PetscCall(MatSetBlockSizesFromMats(M, A, A));
215     PetscCall(MatShellSetOperation(M, MATOP_MULT, (void (*)(void))MatMult_CF));
216     PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF));
217     PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF));
218     PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (void (*)(void))MatDestroy_CF));
219     PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF));
220     PetscCall(PetscFree(M->defaultvectype));
221     PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype));
222 #if defined(PETSC_HAVE_DEVICE)
223     PetscCall(MatBindToCPU(M, A->boundtocpu));
224 #endif
225     *B = M;
226   } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented");
227   PetscFunctionReturn(0);
228 }
229