xref: /petsc/src/mat/impls/shell/shellcnv.c (revision 6dd63270497ad23dcf16ae500a87ff2b2a0b7474) !
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 #include <petsc/private/vecimpl.h> /* for Vec->ops->setvalues */
3 
4 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat)
5 {
6   Mat          mat;
7   Vec          in, out;
8   PetscScalar *array;
9   PetscInt    *dnnz, *onnz, *dnnzu, *onnzu;
10   PetscInt     cst, cen, Nbs, mbs, nbs, rbs, cbs;
11   PetscInt     im, i, m, n, M, N, *rows, start;
12 
13   PetscFunctionBegin;
14   PetscCall(MatGetOwnershipRange(oldmat, &start, NULL));
15   PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, &cen));
16   PetscCall(MatCreateVecs(oldmat, &in, &out));
17   PetscCall(MatGetLocalSize(oldmat, &m, &n));
18   PetscCall(MatGetSize(oldmat, &M, &N));
19   PetscCall(PetscMalloc1(m, &rows));
20   if (reuse != MAT_REUSE_MATRIX) {
21     PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat));
22     PetscCall(MatSetSizes(mat, m, n, M, N));
23     PetscCall(MatSetType(mat, newtype));
24     PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat));
25     PetscCall(MatGetBlockSizes(mat, &rbs, &cbs));
26     mbs = m / rbs;
27     nbs = n / cbs;
28     Nbs = N / cbs;
29     cst = cst / cbs;
30     PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu));
31     for (i = 0; i < mbs; i++) {
32       dnnz[i]  = nbs;
33       onnz[i]  = Nbs - nbs;
34       dnnzu[i] = PetscMax(nbs - i, 0);
35       onnzu[i] = PetscMax(Nbs - (cst + nbs), 0);
36     }
37     PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu));
38     PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu));
39     PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE));
40     PetscCall(MatSetUp(mat));
41   } else {
42     mat = *newmat;
43     PetscCall(MatZeroEntries(mat));
44   }
45   for (i = 0; i < N; i++) {
46     PetscInt j;
47 
48     PetscCall(VecZeroEntries(in));
49     if (in->ops->setvalues) {
50       PetscCall(VecSetValue(in, i, 1., INSERT_VALUES));
51     } else {
52       if (i >= cst && i < cen) {
53         PetscCall(VecGetArray(in, &array));
54         array[i - cst] = 1.0;
55         PetscCall(VecRestoreArray(in, &array));
56       }
57     }
58     PetscCall(VecAssemblyBegin(in));
59     PetscCall(VecAssemblyEnd(in));
60     PetscCall(MatMult(oldmat, in, out));
61     PetscCall(VecGetArray(out, &array));
62     for (j = 0, im = 0; j < m; j++) {
63       if (PetscAbsScalar(array[j]) == 0.0) continue;
64       rows[im]  = j + start;
65       array[im] = array[j];
66       im++;
67     }
68     PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES));
69     PetscCall(VecRestoreArray(out, &array));
70   }
71   PetscCall(PetscFree(rows));
72   PetscCall(VecDestroy(&in));
73   PetscCall(VecDestroy(&out));
74   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
75   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
76   if (reuse == MAT_INPLACE_MATRIX) {
77     PetscCall(MatHeaderReplace(oldmat, &mat));
78   } else {
79     *newmat = mat;
80   }
81   PetscFunctionReturn(PETSC_SUCCESS);
82 }
83 
84 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X)
85 {
86   Mat B;
87 
88   PetscFunctionBegin;
89   PetscCall(MatShellGetContext(A, &B));
90   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
91   PetscCall(MatGetDiagonal(B, X));
92   PetscFunctionReturn(PETSC_SUCCESS);
93 }
94 
95 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y)
96 {
97   Mat B;
98 
99   PetscFunctionBegin;
100   PetscCall(MatShellGetContext(A, &B));
101   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
102   PetscCall(MatMult(B, X, Y));
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105 
106 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y)
107 {
108   Mat B;
109 
110   PetscFunctionBegin;
111   PetscCall(MatShellGetContext(A, &B));
112   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
113   PetscCall(MatMultTranspose(B, X, Y));
114   PetscFunctionReturn(PETSC_SUCCESS);
115 }
116 
117 static PetscErrorCode MatDestroy_CF(Mat A)
118 {
119   Mat B;
120 
121   PetscFunctionBegin;
122   PetscCall(MatShellGetContext(A, &B));
123   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
124   PetscCall(MatDestroy(&B));
125   PetscCall(MatShellSetContext(A, NULL));
126   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL));
127   PetscFunctionReturn(PETSC_SUCCESS);
128 }
129 
130 typedef struct {
131   void *userdata;
132   PetscErrorCode (*ctxdestroy)(void *);
133   PetscErrorCode (*numeric)(Mat);
134   MatProductType ptype;
135   Mat            Dwork;
136 } MatMatCF;
137 
138 static PetscErrorCode MatProductDestroy_CF(void *data)
139 {
140   MatMatCF *mmcfdata = (MatMatCF *)data;
141 
142   PetscFunctionBegin;
143   if (mmcfdata->ctxdestroy) PetscCall((*mmcfdata->ctxdestroy)(mmcfdata->userdata));
144   PetscCall(MatDestroy(&mmcfdata->Dwork));
145   PetscCall(PetscFree(mmcfdata));
146   PetscFunctionReturn(PETSC_SUCCESS);
147 }
148 
149 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data)
150 {
151   MatMatCF *mmcfdata = (MatMatCF *)data;
152 
153   PetscFunctionBegin;
154   PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data");
155   PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation");
156   /* the MATSHELL interface allows us to play with the product data */
157   PetscCall(PetscNew(&C->product));
158   C->product->type  = mmcfdata->ptype;
159   C->product->data  = mmcfdata->userdata;
160   C->product->Dwork = mmcfdata->Dwork;
161   PetscCall(MatShellGetContext(A, &C->product->A));
162   C->product->B = B;
163   PetscCall((*mmcfdata->numeric)(C));
164   PetscCall(PetscFree(C->product));
165   PetscFunctionReturn(PETSC_SUCCESS);
166 }
167 
168 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data)
169 {
170   MatMatCF *mmcfdata;
171 
172   PetscFunctionBegin;
173   PetscCall(MatShellGetContext(A, &C->product->A));
174   PetscCall(MatProductSetFromOptions(C));
175   PetscCall(MatProductSymbolic(C));
176   /* the MATSHELL interface does not allow non-empty product data */
177   PetscCall(PetscNew(&mmcfdata));
178 
179   mmcfdata->numeric    = C->ops->productnumeric;
180   mmcfdata->ptype      = C->product->type;
181   mmcfdata->userdata   = C->product->data;
182   mmcfdata->ctxdestroy = C->product->destroy;
183   mmcfdata->Dwork      = C->product->Dwork;
184 
185   C->product->Dwork   = NULL;
186   C->product->data    = NULL;
187   C->product->destroy = NULL;
188   C->product->A       = A;
189 
190   *data = mmcfdata;
191   PetscFunctionReturn(PETSC_SUCCESS);
192 }
193 
194 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */
195 static PetscErrorCode MatProductSetFromOptions_CF(Mat D)
196 {
197   Mat A, B, Ain;
198   PetscErrorCode (*Af)(Mat) = NULL;
199   PetscBool flg;
200 
201   PetscFunctionBegin;
202   MatCheckProduct(D, 1);
203   if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(PETSC_SUCCESS);
204   A = D->product->A;
205   B = D->product->B;
206   PetscCall(MatIsShell(A, &flg));
207   if (!flg) PetscFunctionReturn(PETSC_SUCCESS);
208   PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af));
209   if (Af == MatProductSetFromOptions_CF) {
210     PetscCall(MatShellGetContext(A, &Ain));
211   } else PetscFunctionReturn(PETSC_SUCCESS);
212   D->product->A = Ain;
213   PetscCall(MatProductSetFromOptions(D));
214   D->product->A = A;
215   if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */
216     PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL));
217     PetscCall(MatProductSetFromOptions(D));
218   }
219   PetscFunctionReturn(PETSC_SUCCESS);
220 }
221 
222 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B)
223 {
224   Mat       M;
225   PetscBool flg;
226 
227   PetscFunctionBegin;
228   PetscCall(PetscStrcmp(newtype, MATSHELL, &flg));
229   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL");
230   if (reuse == MAT_INITIAL_MATRIX) {
231     PetscCall(PetscObjectReference((PetscObject)A));
232     PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M));
233     PetscCall(MatSetBlockSizesFromMats(M, A, A));
234     PetscCall(MatShellSetOperation(M, MATOP_MULT, (void (*)(void))MatMult_CF));
235     PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF));
236     PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF));
237     PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (void (*)(void))MatDestroy_CF));
238     PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF));
239     PetscCall(PetscFree(M->defaultvectype));
240     PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype));
241 #if defined(PETSC_HAVE_DEVICE)
242     PetscCall(MatBindToCPU(M, A->boundtocpu));
243 #endif
244     *B = M;
245   } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented");
246   PetscFunctionReturn(PETSC_SUCCESS);
247 }
248