1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 2 3 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat) { 4 Mat mat; 5 Vec in, out; 6 PetscScalar *array; 7 PetscInt *dnnz, *onnz, *dnnzu, *onnzu; 8 PetscInt cst, Nbs, mbs, nbs, rbs, cbs; 9 PetscInt im, i, m, n, M, N, *rows, start; 10 11 PetscFunctionBegin; 12 PetscCall(MatGetOwnershipRange(oldmat, &start, NULL)); 13 PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, NULL)); 14 PetscCall(MatCreateVecs(oldmat, &in, &out)); 15 PetscCall(MatGetLocalSize(oldmat, &m, &n)); 16 PetscCall(MatGetSize(oldmat, &M, &N)); 17 PetscCall(PetscMalloc1(m, &rows)); 18 if (reuse != MAT_REUSE_MATRIX) { 19 PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat)); 20 PetscCall(MatSetSizes(mat, m, n, M, N)); 21 PetscCall(MatSetType(mat, newtype)); 22 PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat)); 23 PetscCall(MatGetBlockSizes(mat, &rbs, &cbs)); 24 mbs = m / rbs; 25 nbs = n / cbs; 26 Nbs = N / cbs; 27 cst = cst / cbs; 28 PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu)); 29 for (i = 0; i < mbs; i++) { 30 dnnz[i] = nbs; 31 onnz[i] = Nbs - nbs; 32 dnnzu[i] = PetscMax(nbs - i, 0); 33 onnzu[i] = PetscMax(Nbs - (cst + nbs), 0); 34 } 35 PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu)); 36 PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu)); 37 PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE)); 38 PetscCall(MatSetUp(mat)); 39 } else { 40 mat = *newmat; 41 PetscCall(MatZeroEntries(mat)); 42 } 43 for (i = 0; i < N; i++) { 44 PetscInt j; 45 46 PetscCall(VecZeroEntries(in)); 47 PetscCall(VecSetValue(in, i, 1., INSERT_VALUES)); 48 PetscCall(VecAssemblyBegin(in)); 49 PetscCall(VecAssemblyEnd(in)); 50 PetscCall(MatMult(oldmat, in, out)); 51 PetscCall(VecGetArray(out, &array)); 52 for (j = 0, im = 0; j < m; j++) { 53 if (PetscAbsScalar(array[j]) == 0.0) continue; 54 rows[im] = j + start; 55 array[im] = array[j]; 56 im++; 57 } 58 PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES)); 59 PetscCall(VecRestoreArray(out, &array)); 60 } 61 PetscCall(PetscFree(rows)); 62 PetscCall(VecDestroy(&in)); 63 PetscCall(VecDestroy(&out)); 64 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 65 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 66 if (reuse == MAT_INPLACE_MATRIX) { 67 PetscCall(MatHeaderReplace(oldmat, &mat)); 68 } else { 69 *newmat = mat; 70 } 71 PetscFunctionReturn(0); 72 } 73 74 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X) { 75 Mat B; 76 77 PetscFunctionBegin; 78 PetscCall(MatShellGetContext(A, &B)); 79 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 80 PetscCall(MatGetDiagonal(B, X)); 81 PetscFunctionReturn(0); 82 } 83 84 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y) { 85 Mat B; 86 87 PetscFunctionBegin; 88 PetscCall(MatShellGetContext(A, &B)); 89 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 90 PetscCall(MatMult(B, X, Y)); 91 PetscFunctionReturn(0); 92 } 93 94 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y) { 95 Mat B; 96 97 PetscFunctionBegin; 98 PetscCall(MatShellGetContext(A, &B)); 99 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 100 PetscCall(MatMultTranspose(B, X, Y)); 101 PetscFunctionReturn(0); 102 } 103 104 static PetscErrorCode MatDestroy_CF(Mat A) { 105 Mat B; 106 107 PetscFunctionBegin; 108 PetscCall(MatShellGetContext(A, &B)); 109 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 110 PetscCall(MatDestroy(&B)); 111 PetscCall(MatShellSetContext(A, NULL)); 112 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL)); 113 PetscFunctionReturn(0); 114 } 115 116 typedef struct { 117 void *userdata; 118 PetscErrorCode (*userdestroy)(void *); 119 PetscErrorCode (*numeric)(Mat); 120 MatProductType ptype; 121 Mat Dwork; 122 } MatMatCF; 123 124 static PetscErrorCode MatProductDestroy_CF(void *data) { 125 MatMatCF *mmcfdata = (MatMatCF *)data; 126 127 PetscFunctionBegin; 128 if (mmcfdata->userdestroy) PetscCall((*mmcfdata->userdestroy)(mmcfdata->userdata)); 129 PetscCall(MatDestroy(&mmcfdata->Dwork)); 130 PetscCall(PetscFree(mmcfdata)); 131 PetscFunctionReturn(0); 132 } 133 134 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data) { 135 MatMatCF *mmcfdata = (MatMatCF *)data; 136 137 PetscFunctionBegin; 138 PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data"); 139 PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation"); 140 /* the MATSHELL interface allows us to play with the product data */ 141 PetscCall(PetscNew(&C->product)); 142 C->product->type = mmcfdata->ptype; 143 C->product->data = mmcfdata->userdata; 144 C->product->Dwork = mmcfdata->Dwork; 145 PetscCall(MatShellGetContext(A, &C->product->A)); 146 C->product->B = B; 147 PetscCall((*mmcfdata->numeric)(C)); 148 PetscCall(PetscFree(C->product)); 149 PetscFunctionReturn(0); 150 } 151 152 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data) { 153 MatMatCF *mmcfdata; 154 155 PetscFunctionBegin; 156 PetscCall(MatShellGetContext(A, &C->product->A)); 157 PetscCall(MatProductSetFromOptions(C)); 158 PetscCall(MatProductSymbolic(C)); 159 /* the MATSHELL interface does not allow non-empty product data */ 160 PetscCall(PetscNew(&mmcfdata)); 161 162 mmcfdata->numeric = C->ops->productnumeric; 163 mmcfdata->ptype = C->product->type; 164 mmcfdata->userdata = C->product->data; 165 mmcfdata->userdestroy = C->product->destroy; 166 mmcfdata->Dwork = C->product->Dwork; 167 168 C->product->Dwork = NULL; 169 C->product->data = NULL; 170 C->product->destroy = NULL; 171 C->product->A = A; 172 173 *data = mmcfdata; 174 PetscFunctionReturn(0); 175 } 176 177 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */ 178 static PetscErrorCode MatProductSetFromOptions_CF(Mat D) { 179 Mat A, B, Ain; 180 void (*Af)(void) = NULL; 181 PetscBool flg; 182 183 PetscFunctionBegin; 184 MatCheckProduct(D, 1); 185 if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(0); 186 A = D->product->A; 187 B = D->product->B; 188 PetscCall(MatIsShell(A, &flg)); 189 if (!flg) PetscFunctionReturn(0); 190 PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af)); 191 if (Af == (void (*)(void))MatProductSetFromOptions_CF) { 192 PetscCall(MatShellGetContext(A, &Ain)); 193 } else PetscFunctionReturn(0); 194 D->product->A = Ain; 195 PetscCall(MatProductSetFromOptions(D)); 196 D->product->A = A; 197 if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */ 198 PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL)); 199 PetscCall(MatProductSetFromOptions(D)); 200 } 201 PetscFunctionReturn(0); 202 } 203 204 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B) { 205 Mat M; 206 PetscBool flg; 207 208 PetscFunctionBegin; 209 PetscCall(PetscStrcmp(newtype, MATSHELL, &flg)); 210 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL"); 211 if (reuse == MAT_INITIAL_MATRIX) { 212 PetscCall(PetscObjectReference((PetscObject)A)); 213 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M)); 214 PetscCall(MatSetBlockSizesFromMats(M, A, A)); 215 PetscCall(MatShellSetOperation(M, MATOP_MULT, (void (*)(void))MatMult_CF)); 216 PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF)); 217 PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF)); 218 PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (void (*)(void))MatDestroy_CF)); 219 PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF)); 220 PetscCall(PetscFree(M->defaultvectype)); 221 PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype)); 222 #if defined(PETSC_HAVE_DEVICE) 223 PetscCall(MatBindToCPU(M, A->boundtocpu)); 224 #endif 225 *B = M; 226 } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented"); 227 PetscFunctionReturn(0); 228 } 229