xref: /petsc/src/mat/impls/shell/shellcnv.c (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 
3 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat)
4 {
5   Mat          mat;
6   Vec          in, out;
7   PetscScalar *array;
8   PetscInt    *dnnz, *onnz, *dnnzu, *onnzu;
9   PetscInt     cst, Nbs, mbs, nbs, rbs, cbs;
10   PetscInt     im, i, m, n, M, N, *rows, start;
11 
12   PetscFunctionBegin;
13   PetscCall(MatGetOwnershipRange(oldmat, &start, NULL));
14   PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, NULL));
15   PetscCall(MatCreateVecs(oldmat, &in, &out));
16   PetscCall(MatGetLocalSize(oldmat, &m, &n));
17   PetscCall(MatGetSize(oldmat, &M, &N));
18   PetscCall(PetscMalloc1(m, &rows));
19   if (reuse != MAT_REUSE_MATRIX) {
20     PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat));
21     PetscCall(MatSetSizes(mat, m, n, M, N));
22     PetscCall(MatSetType(mat, newtype));
23     PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat));
24     PetscCall(MatGetBlockSizes(mat, &rbs, &cbs));
25     mbs = m / rbs;
26     nbs = n / cbs;
27     Nbs = N / cbs;
28     cst = cst / cbs;
29     PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu));
30     for (i = 0; i < mbs; i++) {
31       dnnz[i]  = nbs;
32       onnz[i]  = Nbs - nbs;
33       dnnzu[i] = PetscMax(nbs - i, 0);
34       onnzu[i] = PetscMax(Nbs - (cst + nbs), 0);
35     }
36     PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu));
37     PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu));
38     PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE));
39     PetscCall(MatSetUp(mat));
40   } else {
41     mat = *newmat;
42     PetscCall(MatZeroEntries(mat));
43   }
44   for (i = 0; i < N; i++) {
45     PetscInt j;
46 
47     PetscCall(VecZeroEntries(in));
48     PetscCall(VecSetValue(in, i, 1., INSERT_VALUES));
49     PetscCall(VecAssemblyBegin(in));
50     PetscCall(VecAssemblyEnd(in));
51     PetscCall(MatMult(oldmat, in, out));
52     PetscCall(VecGetArray(out, &array));
53     for (j = 0, im = 0; j < m; j++) {
54       if (PetscAbsScalar(array[j]) == 0.0) continue;
55       rows[im]  = j + start;
56       array[im] = array[j];
57       im++;
58     }
59     PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES));
60     PetscCall(VecRestoreArray(out, &array));
61   }
62   PetscCall(PetscFree(rows));
63   PetscCall(VecDestroy(&in));
64   PetscCall(VecDestroy(&out));
65   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
66   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
67   if (reuse == MAT_INPLACE_MATRIX) {
68     PetscCall(MatHeaderReplace(oldmat, &mat));
69   } else {
70     *newmat = mat;
71   }
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X)
76 {
77   Mat B;
78 
79   PetscFunctionBegin;
80   PetscCall(MatShellGetContext(A, &B));
81   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
82   PetscCall(MatGetDiagonal(B, X));
83   PetscFunctionReturn(PETSC_SUCCESS);
84 }
85 
86 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y)
87 {
88   Mat B;
89 
90   PetscFunctionBegin;
91   PetscCall(MatShellGetContext(A, &B));
92   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
93   PetscCall(MatMult(B, X, Y));
94   PetscFunctionReturn(PETSC_SUCCESS);
95 }
96 
97 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y)
98 {
99   Mat B;
100 
101   PetscFunctionBegin;
102   PetscCall(MatShellGetContext(A, &B));
103   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
104   PetscCall(MatMultTranspose(B, X, Y));
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 static PetscErrorCode MatDestroy_CF(Mat A)
109 {
110   Mat B;
111 
112   PetscFunctionBegin;
113   PetscCall(MatShellGetContext(A, &B));
114   PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
115   PetscCall(MatDestroy(&B));
116   PetscCall(MatShellSetContext(A, NULL));
117   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL));
118   PetscFunctionReturn(PETSC_SUCCESS);
119 }
120 
121 typedef struct {
122   void *userdata;
123   PetscErrorCode (*userdestroy)(void *);
124   PetscErrorCode (*numeric)(Mat);
125   MatProductType ptype;
126   Mat            Dwork;
127 } MatMatCF;
128 
129 static PetscErrorCode MatProductDestroy_CF(void *data)
130 {
131   MatMatCF *mmcfdata = (MatMatCF *)data;
132 
133   PetscFunctionBegin;
134   if (mmcfdata->userdestroy) PetscCall((*mmcfdata->userdestroy)(mmcfdata->userdata));
135   PetscCall(MatDestroy(&mmcfdata->Dwork));
136   PetscCall(PetscFree(mmcfdata));
137   PetscFunctionReturn(PETSC_SUCCESS);
138 }
139 
140 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data)
141 {
142   MatMatCF *mmcfdata = (MatMatCF *)data;
143 
144   PetscFunctionBegin;
145   PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data");
146   PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation");
147   /* the MATSHELL interface allows us to play with the product data */
148   PetscCall(PetscNew(&C->product));
149   C->product->type  = mmcfdata->ptype;
150   C->product->data  = mmcfdata->userdata;
151   C->product->Dwork = mmcfdata->Dwork;
152   PetscCall(MatShellGetContext(A, &C->product->A));
153   C->product->B = B;
154   PetscCall((*mmcfdata->numeric)(C));
155   PetscCall(PetscFree(C->product));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data)
160 {
161   MatMatCF *mmcfdata;
162 
163   PetscFunctionBegin;
164   PetscCall(MatShellGetContext(A, &C->product->A));
165   PetscCall(MatProductSetFromOptions(C));
166   PetscCall(MatProductSymbolic(C));
167   /* the MATSHELL interface does not allow non-empty product data */
168   PetscCall(PetscNew(&mmcfdata));
169 
170   mmcfdata->numeric     = C->ops->productnumeric;
171   mmcfdata->ptype       = C->product->type;
172   mmcfdata->userdata    = C->product->data;
173   mmcfdata->userdestroy = C->product->destroy;
174   mmcfdata->Dwork       = C->product->Dwork;
175 
176   C->product->Dwork   = NULL;
177   C->product->data    = NULL;
178   C->product->destroy = NULL;
179   C->product->A       = A;
180 
181   *data = mmcfdata;
182   PetscFunctionReturn(PETSC_SUCCESS);
183 }
184 
185 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */
186 static PetscErrorCode MatProductSetFromOptions_CF(Mat D)
187 {
188   Mat A, B, Ain;
189   void (*Af)(void) = NULL;
190   PetscBool flg;
191 
192   PetscFunctionBegin;
193   MatCheckProduct(D, 1);
194   if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(PETSC_SUCCESS);
195   A = D->product->A;
196   B = D->product->B;
197   PetscCall(MatIsShell(A, &flg));
198   if (!flg) PetscFunctionReturn(PETSC_SUCCESS);
199   PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af));
200   if (Af == (void (*)(void))MatProductSetFromOptions_CF) {
201     PetscCall(MatShellGetContext(A, &Ain));
202   } else PetscFunctionReturn(PETSC_SUCCESS);
203   D->product->A = Ain;
204   PetscCall(MatProductSetFromOptions(D));
205   D->product->A = A;
206   if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */
207     PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL));
208     PetscCall(MatProductSetFromOptions(D));
209   }
210   PetscFunctionReturn(PETSC_SUCCESS);
211 }
212 
213 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B)
214 {
215   Mat       M;
216   PetscBool flg;
217 
218   PetscFunctionBegin;
219   PetscCall(PetscStrcmp(newtype, MATSHELL, &flg));
220   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL");
221   if (reuse == MAT_INITIAL_MATRIX) {
222     PetscCall(PetscObjectReference((PetscObject)A));
223     PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M));
224     PetscCall(MatSetBlockSizesFromMats(M, A, A));
225     PetscCall(MatShellSetOperation(M, MATOP_MULT, (void (*)(void))MatMult_CF));
226     PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF));
227     PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF));
228     PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (void (*)(void))MatDestroy_CF));
229     PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF));
230     PetscCall(PetscFree(M->defaultvectype));
231     PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype));
232 #if defined(PETSC_HAVE_DEVICE)
233     PetscCall(MatBindToCPU(M, A->boundtocpu));
234 #endif
235     *B = M;
236   } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented");
237   PetscFunctionReturn(PETSC_SUCCESS);
238 }
239