1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 2 3 PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat) 4 { 5 Mat mat; 6 Vec in, out; 7 PetscScalar *array; 8 PetscInt *dnnz, *onnz, *dnnzu, *onnzu; 9 PetscInt cst, Nbs, mbs, nbs, rbs, cbs; 10 PetscInt im, i, m, n, M, N, *rows, start; 11 12 PetscFunctionBegin; 13 PetscCall(MatGetOwnershipRange(oldmat, &start, NULL)); 14 PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, NULL)); 15 PetscCall(MatCreateVecs(oldmat, &in, &out)); 16 PetscCall(MatGetLocalSize(oldmat, &m, &n)); 17 PetscCall(MatGetSize(oldmat, &M, &N)); 18 PetscCall(PetscMalloc1(m, &rows)); 19 if (reuse != MAT_REUSE_MATRIX) { 20 PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat)); 21 PetscCall(MatSetSizes(mat, m, n, M, N)); 22 PetscCall(MatSetType(mat, newtype)); 23 PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat)); 24 PetscCall(MatGetBlockSizes(mat, &rbs, &cbs)); 25 mbs = m / rbs; 26 nbs = n / cbs; 27 Nbs = N / cbs; 28 cst = cst / cbs; 29 PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu)); 30 for (i = 0; i < mbs; i++) { 31 dnnz[i] = nbs; 32 onnz[i] = Nbs - nbs; 33 dnnzu[i] = PetscMax(nbs - i, 0); 34 onnzu[i] = PetscMax(Nbs - (cst + nbs), 0); 35 } 36 PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu)); 37 PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu)); 38 PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE)); 39 PetscCall(MatSetUp(mat)); 40 } else { 41 mat = *newmat; 42 PetscCall(MatZeroEntries(mat)); 43 } 44 for (i = 0; i < N; i++) { 45 PetscInt j; 46 47 PetscCall(VecZeroEntries(in)); 48 PetscCall(VecSetValue(in, i, 1., INSERT_VALUES)); 49 PetscCall(VecAssemblyBegin(in)); 50 PetscCall(VecAssemblyEnd(in)); 51 PetscCall(MatMult(oldmat, in, out)); 52 PetscCall(VecGetArray(out, &array)); 53 for (j = 0, im = 0; j < m; j++) { 54 if (PetscAbsScalar(array[j]) == 0.0) continue; 55 rows[im] = j + start; 56 array[im] = array[j]; 57 im++; 58 } 59 PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES)); 60 PetscCall(VecRestoreArray(out, &array)); 61 } 62 PetscCall(PetscFree(rows)); 63 PetscCall(VecDestroy(&in)); 64 PetscCall(VecDestroy(&out)); 65 PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 66 PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 67 if (reuse == MAT_INPLACE_MATRIX) { 68 PetscCall(MatHeaderReplace(oldmat, &mat)); 69 } else { 70 *newmat = mat; 71 } 72 PetscFunctionReturn(PETSC_SUCCESS); 73 } 74 75 static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X) 76 { 77 Mat B; 78 79 PetscFunctionBegin; 80 PetscCall(MatShellGetContext(A, &B)); 81 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 82 PetscCall(MatGetDiagonal(B, X)); 83 PetscFunctionReturn(PETSC_SUCCESS); 84 } 85 86 static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y) 87 { 88 Mat B; 89 90 PetscFunctionBegin; 91 PetscCall(MatShellGetContext(A, &B)); 92 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 93 PetscCall(MatMult(B, X, Y)); 94 PetscFunctionReturn(PETSC_SUCCESS); 95 } 96 97 static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y) 98 { 99 Mat B; 100 101 PetscFunctionBegin; 102 PetscCall(MatShellGetContext(A, &B)); 103 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 104 PetscCall(MatMultTranspose(B, X, Y)); 105 PetscFunctionReturn(PETSC_SUCCESS); 106 } 107 108 static PetscErrorCode MatDestroy_CF(Mat A) 109 { 110 Mat B; 111 112 PetscFunctionBegin; 113 PetscCall(MatShellGetContext(A, &B)); 114 PetscCheck(B, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix"); 115 PetscCall(MatDestroy(&B)); 116 PetscCall(MatShellSetContext(A, NULL)); 117 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL)); 118 PetscFunctionReturn(PETSC_SUCCESS); 119 } 120 121 typedef struct { 122 void *userdata; 123 PetscErrorCode (*userdestroy)(void *); 124 PetscErrorCode (*numeric)(Mat); 125 MatProductType ptype; 126 Mat Dwork; 127 } MatMatCF; 128 129 static PetscErrorCode MatProductDestroy_CF(void *data) 130 { 131 MatMatCF *mmcfdata = (MatMatCF *)data; 132 133 PetscFunctionBegin; 134 if (mmcfdata->userdestroy) PetscCall((*mmcfdata->userdestroy)(mmcfdata->userdata)); 135 PetscCall(MatDestroy(&mmcfdata->Dwork)); 136 PetscCall(PetscFree(mmcfdata)); 137 PetscFunctionReturn(PETSC_SUCCESS); 138 } 139 140 static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data) 141 { 142 MatMatCF *mmcfdata = (MatMatCF *)data; 143 144 PetscFunctionBegin; 145 PetscCheck(mmcfdata, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data"); 146 PetscCheck(mmcfdata->numeric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation"); 147 /* the MATSHELL interface allows us to play with the product data */ 148 PetscCall(PetscNew(&C->product)); 149 C->product->type = mmcfdata->ptype; 150 C->product->data = mmcfdata->userdata; 151 C->product->Dwork = mmcfdata->Dwork; 152 PetscCall(MatShellGetContext(A, &C->product->A)); 153 C->product->B = B; 154 PetscCall((*mmcfdata->numeric)(C)); 155 PetscCall(PetscFree(C->product)); 156 PetscFunctionReturn(PETSC_SUCCESS); 157 } 158 159 static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data) 160 { 161 MatMatCF *mmcfdata; 162 163 PetscFunctionBegin; 164 PetscCall(MatShellGetContext(A, &C->product->A)); 165 PetscCall(MatProductSetFromOptions(C)); 166 PetscCall(MatProductSymbolic(C)); 167 /* the MATSHELL interface does not allow non-empty product data */ 168 PetscCall(PetscNew(&mmcfdata)); 169 170 mmcfdata->numeric = C->ops->productnumeric; 171 mmcfdata->ptype = C->product->type; 172 mmcfdata->userdata = C->product->data; 173 mmcfdata->userdestroy = C->product->destroy; 174 mmcfdata->Dwork = C->product->Dwork; 175 176 C->product->Dwork = NULL; 177 C->product->data = NULL; 178 C->product->destroy = NULL; 179 C->product->A = A; 180 181 *data = mmcfdata; 182 PetscFunctionReturn(PETSC_SUCCESS); 183 } 184 185 /* only for A of type shell, mainly used for MatMat operations of shells with AXPYs */ 186 static PetscErrorCode MatProductSetFromOptions_CF(Mat D) 187 { 188 Mat A, B, Ain; 189 void (*Af)(void) = NULL; 190 PetscBool flg; 191 192 PetscFunctionBegin; 193 MatCheckProduct(D, 1); 194 if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(PETSC_SUCCESS); 195 A = D->product->A; 196 B = D->product->B; 197 PetscCall(MatIsShell(A, &flg)); 198 if (!flg) PetscFunctionReturn(PETSC_SUCCESS); 199 PetscCall(PetscObjectQueryFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", &Af)); 200 if (Af == (void (*)(void))MatProductSetFromOptions_CF) { 201 PetscCall(MatShellGetContext(A, &Ain)); 202 } else PetscFunctionReturn(PETSC_SUCCESS); 203 D->product->A = Ain; 204 PetscCall(MatProductSetFromOptions(D)); 205 D->product->A = A; 206 if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */ 207 PetscCall(MatShellSetMatProductOperation(A, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)B)->type_name, NULL)); 208 PetscCall(MatProductSetFromOptions(D)); 209 } 210 PetscFunctionReturn(PETSC_SUCCESS); 211 } 212 213 PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B) 214 { 215 Mat M; 216 PetscBool flg; 217 218 PetscFunctionBegin; 219 PetscCall(PetscStrcmp(newtype, MATSHELL, &flg)); 220 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL"); 221 if (reuse == MAT_INITIAL_MATRIX) { 222 PetscCall(PetscObjectReference((PetscObject)A)); 223 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &M)); 224 PetscCall(MatSetBlockSizesFromMats(M, A, A)); 225 PetscCall(MatShellSetOperation(M, MATOP_MULT, (void (*)(void))MatMult_CF)); 226 PetscCall(MatShellSetOperation(M, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF)); 227 PetscCall(MatShellSetOperation(M, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF)); 228 PetscCall(MatShellSetOperation(M, MATOP_DESTROY, (void (*)(void))MatDestroy_CF)); 229 PetscCall(PetscObjectComposeFunction((PetscObject)M, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF)); 230 PetscCall(PetscFree(M->defaultvectype)); 231 PetscCall(PetscStrallocpy(A->defaultvectype, &M->defaultvectype)); 232 #if defined(PETSC_HAVE_DEVICE) 233 PetscCall(MatBindToCPU(M, A->boundtocpu)); 234 #endif 235 *B = M; 236 } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented"); 237 PetscFunctionReturn(PETSC_SUCCESS); 238 } 239