1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 2 #include <petsc/private/vecimpl.h> /* for Vec->ops->setvalues */ 3 4 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat) 5 { 6 Mat mat; 7 Vec in, out; 8 PetscScalar *array; 9 PetscInt *dnnz, *onnz, *dnnzu, *onnzu; 10 PetscInt cst, cen, Nbs, mbs, nbs, rbs, cbs; 11 PetscInt im, i, m, n, M, N, *rows, start; 12 13 PetscFunctionBegin; 14 PetscCall(MatGetOwnershipRange(oldmat, &start, NULL)); 15 PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, &cen)); 16 PetscCall(MatCreateVecs(oldmat, &in, &out)); 17 PetscCall(MatGetLocalSize(oldmat, &m, &n)); 18 PetscCall(MatGetSize(oldmat, &M, &N)); 19 PetscCall(PetscMalloc1(m, &rows)); 20 if (reuse != MAT_REUSE_MATRIX) { 21 PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat)); 22 PetscCall(MatSetSizes(mat, m, n, M, N)); 23 PetscCall(MatSetType(mat, newtype)); 24 PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat)); 25 PetscCall(MatGetBlockSizes(mat, &rbs, &cbs)); 26 mbs = m / rbs; 27 nbs = n / cbs; 28 Nbs = N / cbs; 29 cst = cst / cbs; 30 PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu)); 31 for (i = 0; i < mbs; i++) { 32 dnnz[i] = nbs; 33 onnz[i] = Nbs - nbs; 34 dnnzu[i] = PetscMax(nbs - i, 0); 35 onnzu[i] = PetscMax(Nbs - (cst + nbs), 0); 36 } 37 PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu)); 38 PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu)); 39 PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE)); 40 PetscCall(MatSetUp(mat)); 41 } else { 42 mat = *newmat; 43 PetscCall(MatZeroEntries(mat)); 44 } 45 for (i = 0; i < N; i++) { 46 PetscInt j; 47 48 PetscCall(VecZeroEntries(in)); 49 if (in->ops->setvalues) { 50 PetscCall(VecSetValue(in, i, 1., INSERT_VALUES)); 51 } else { 52 if (i >= cst && i < cen) { 53 PetscCall(VecGetArray(in, &array)); 54 array[i - cst] = 1.0; 55 PetscCall(VecRestoreArray(in, &array)); 56 } 57 } 58 PetscCall(VecAssemblyBegin(in)); 59 PetscCall(VecAssemblyEnd(in)); 60 PetscCall(MatMult(oldmat, in, out)); 61 PetscCall(VecGetArray(out, &array)); 62 for (j = 0, im = 0; j < m; j++) { 63 if (PetscAbsScalar(array[j]) == 0.0) continue; 64 rows[im] = j + start; 65 array[im] = array[j]; 66 im++; 67 } 68 PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES)); 69 PetscCall(VecRestoreArray(out, &array)); 70 } 71 PetscCall(PetscFree(rows)); 72 PetscCall(VecDestroy(&in)); 73 PetscCall(VecDestroy(&out)); 74 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 75 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 76 if (reuse == MAT_INPLACE_MATRIX) { 77 PetscCall(MatHeaderReplace(oldmat, &mat)); 78 } else { 79 *newmat = mat; 80 } 81 PetscFunctionReturn(PETSC_SUCCESS); 82 } 83 84 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X) 85 { 86 Mat B; 87 88 PetscFunctionBegin; 89 PetscCall(MatShellGetContext(A, &B)); 90 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 91 PetscCall(MatGetDiagonal(B, X)); 92 PetscFunctionReturn(PETSC_SUCCESS); 93 } 94 95 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y) 96 { 97 Mat B; 98 99 PetscFunctionBegin; 100 PetscCall(MatShellGetContext(A, &B)); 101 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 102 PetscCall(MatMult(B, X, Y)); 103 PetscFunctionReturn(PETSC_SUCCESS); 104 } 105 106 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y) 107 { 108 Mat B; 109 110 PetscFunctionBegin; 111 PetscCall(MatShellGetContext(A, &B)); 112 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 113 PetscCall(MatMultTranspose(B, X, Y)); 114 PetscFunctionReturn(PETSC_SUCCESS); 115 } 116 117 static PetscErrorCode MatDestroy_CF(Mat A) 118 { 119 Mat B; 120 121 PetscFunctionBegin; 122 PetscCall(MatShellGetContext(A, &B)); 123 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 124 PetscCall(MatDestroy(&B)); 125 PetscCall(MatShellSetContext(A, NULL)); 126 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL)); 127 PetscFunctionReturn(PETSC_SUCCESS); 128 } 129 130 typedef struct { 131 void *userdata; 132 PetscErrorCode (*ctxdestroy)(void *); 133 PetscErrorCode (*numeric)(Mat); 134 MatProductType ptype; 135 Mat Dwork; 136 } MatMatCF; 137 138 static PetscErrorCode MatProductDestroy_CF(void *data) 139 { 140 MatMatCF *mmcfdata = (MatMatCF *)data; 141 142 PetscFunctionBegin; 143 if (mmcfdata->ctxdestroy) PetscCall((*mmcfdata->ctxdestroy)(mmcfdata->userdata)); 144 PetscCall(MatDestroy(&mmcfdata->Dwork)); 145 PetscCall(PetscFree(mmcfdata)); 146 PetscFunctionReturn(PETSC_SUCCESS); 147 } 148 149 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data) 150 { 151 MatMatCF *mmcfdata = (MatMatCF *)data; 152 153 PetscFunctionBegin; 154 PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data"); 155 PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation"); 156 /* the MATSHELL interface allows us to play with the product data */ 157 PetscCall(PetscNew(&C->product)); 158 C->product->type = mmcfdata->ptype; 159 C->product->data = mmcfdata->userdata; 160 C->product->Dwork = mmcfdata->Dwork; 161 PetscCall(MatShellGetContext(A, &C->product->A)); 162 C->product->B = B; 163 PetscCall((*mmcfdata->numeric)(C)); 164 PetscCall(PetscFree(C->product)); 165 PetscFunctionReturn(PETSC_SUCCESS); 166 } 167 168 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data) 169 { 170 MatMatCF *mmcfdata; 171 172 PetscFunctionBegin; 173 PetscCall(MatShellGetContext(A, &C->product->A)); 174 PetscCall(MatProductSetFromOptions(C)); 175 PetscCall(MatProductSymbolic(C)); 176 /* the MATSHELL interface does not allow non-empty product data */ 177 PetscCall(PetscNew(&mmcfdata)); 178 179 mmcfdata->numeric = C->ops->productnumeric; 180 mmcfdata->ptype = C->product->type; 181 mmcfdata->userdata = C->product->data; 182 mmcfdata->ctxdestroy = C->product->destroy; 183 mmcfdata->Dwork = C->product->Dwork; 184 185 C->product->Dwork = NULL; 186 C->product->data = NULL; 187 C->product->destroy = NULL; 188 C->product->A = A; 189 190 *data = mmcfdata; 191 PetscFunctionReturn(PETSC_SUCCESS); 192 } 193 194 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */ 195 static PetscErrorCode MatProductSetFromOptions_CF(Mat D) 196 { 197 Mat A, B, Ain; 198 PetscErrorCode (*Af)(Mat) = NULL; 199 PetscBool flg; 200 201 PetscFunctionBegin; 202 MatCheckProduct(D, 1); 203 if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(PETSC_SUCCESS); 204 A = D->product->A; 205 B = D->product->B; 206 PetscCall(MatIsShell(A, &flg)); 207 if (!flg) PetscFunctionReturn(PETSC_SUCCESS); 208 PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af)); 209 if (Af == MatProductSetFromOptions_CF) PetscCall(MatShellGetContext(A, &Ain)); 210 else PetscFunctionReturn(PETSC_SUCCESS); 211 D->product->A = Ain; 212 PetscCall(MatProductSetFromOptions(D)); 213 D->product->A = A; 214 if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */ 215 PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL)); 216 PetscCall(MatProductSetFromOptions(D)); 217 } 218 PetscFunctionReturn(PETSC_SUCCESS); 219 } 220 221 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B) 222 { 223 Mat M; 224 PetscBool flg; 225 226 PetscFunctionBegin; 227 PetscCall(PetscStrcmp(newtype, MATSHELL, &flg)); 228 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL"); 229 if (reuse == MAT_INITIAL_MATRIX) { 230 PetscCall(PetscObjectReference((PetscObject)A)); 231 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M)); 232 PetscCall(MatSetBlockSizesFromMats(M, A, A)); 233 PetscCall(MatShellSetOperation(M, MATOP_MULT, (PetscErrorCodeFn *)MatMult_CF)); 234 PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_CF)); 235 PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_CF)); 236 PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (PetscErrorCodeFn *)MatDestroy_CF)); 237 PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF)); 238 PetscCall(PetscFree(M->defaultvectype)); 239 PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype)); 240 #if defined(PETSC_HAVE_DEVICE) 241 PetscCall(MatBindToCPU(M, A->boundtocpu)); 242 #endif 243 *B = M; 244 } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented"); 245 PetscFunctionReturn(PETSC_SUCCESS); 246 } 247