1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 2 3 typedef struct { 4 IS isrow, iscol; /* rows and columns in submatrix, only used to check consistency */ 5 Vec lwork, rwork; /* work vectors inside the scatters */ 6 Vec lwork2, rwork2; /* work vectors inside the scatters */ 7 VecScatter lrestrict, rprolong; 8 Mat A; 9 } Mat_SubVirtual; 10 11 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a) 12 { 13 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatScale(Na->A, a)); 17 PetscFunctionReturn(PETSC_SUCCESS); 18 } 19 20 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a) 21 { 22 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 23 24 PetscFunctionBegin; 25 PetscCall(MatShift(Na->A, a)); 26 PetscFunctionReturn(PETSC_SUCCESS); 27 } 28 29 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right) 30 { 31 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 32 33 PetscFunctionBegin; 34 if (right) { 35 PetscCall(VecZeroEntries(Na->rwork)); 36 PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 37 PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 38 } 39 if (left) { 40 PetscCall(VecZeroEntries(Na->lwork)); 41 PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 42 PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 43 } 44 PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL)); 45 PetscFunctionReturn(PETSC_SUCCESS); 46 } 47 48 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d) 49 { 50 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 51 52 PetscFunctionBegin; 53 PetscCall(MatGetDiagonal(Na->A, Na->rwork)); 54 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 55 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 56 PetscFunctionReturn(PETSC_SUCCESS); 57 } 58 59 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y) 60 { 61 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 62 63 PetscFunctionBegin; 64 PetscCall(VecZeroEntries(Na->rwork)); 65 PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 66 PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 67 PetscCall(MatMult(Na->A, Na->rwork, Na->lwork)); 68 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 69 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 70 PetscFunctionReturn(PETSC_SUCCESS); 71 } 72 73 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) 74 { 75 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 76 77 PetscFunctionBegin; 78 PetscCall(VecZeroEntries(Na->rwork)); 79 PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 80 PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 81 if (v1 == v2) { 82 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork)); 83 } else if (v2 == v3) { 84 PetscCall(VecZeroEntries(Na->lwork)); 85 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 86 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 87 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork)); 88 } else { 89 if (!Na->lwork2) { 90 PetscCall(VecDuplicate(Na->lwork, &Na->lwork2)); 91 } else { 92 PetscCall(VecZeroEntries(Na->lwork2)); 93 } 94 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 95 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 96 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork)); 97 } 98 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 99 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 100 PetscFunctionReturn(PETSC_SUCCESS); 101 } 102 103 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y) 104 { 105 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 106 107 PetscFunctionBegin; 108 PetscCall(VecZeroEntries(Na->lwork)); 109 PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 110 PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 111 PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork)); 112 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 113 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 114 PetscFunctionReturn(PETSC_SUCCESS); 115 } 116 117 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) 118 { 119 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 120 121 PetscFunctionBegin; 122 PetscCall(VecZeroEntries(Na->lwork)); 123 PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 124 PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 125 if (v1 == v2) { 126 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork)); 127 } else if (v2 == v3) { 128 PetscCall(VecZeroEntries(Na->rwork)); 129 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 130 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 131 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork)); 132 } else { 133 if (!Na->rwork2) { 134 PetscCall(VecDuplicate(Na->rwork, &Na->rwork2)); 135 } else { 136 PetscCall(VecZeroEntries(Na->rwork2)); 137 } 138 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 139 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 140 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork)); 141 } 142 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 143 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 144 PetscFunctionReturn(PETSC_SUCCESS); 145 } 146 147 static PetscErrorCode MatDestroy_SubMatrix(Mat N) 148 { 149 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 150 151 PetscFunctionBegin; 152 PetscCall(ISDestroy(&Na->isrow)); 153 PetscCall(ISDestroy(&Na->iscol)); 154 PetscCall(VecDestroy(&Na->lwork)); 155 PetscCall(VecDestroy(&Na->rwork)); 156 PetscCall(VecDestroy(&Na->lwork2)); 157 PetscCall(VecDestroy(&Na->rwork2)); 158 PetscCall(VecScatterDestroy(&Na->lrestrict)); 159 PetscCall(VecScatterDestroy(&Na->rprolong)); 160 PetscCall(MatDestroy(&Na->A)); 161 PetscCall(PetscFree(N->data)); 162 PetscFunctionReturn(PETSC_SUCCESS); 163 } 164 165 /*@ 166 MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix 167 168 Collective 169 170 Input Parameters: 171 + A - matrix that we will extract a submatrix of 172 . isrow - rows to be present in the submatrix 173 - iscol - columns to be present in the submatrix 174 175 Output Parameter: 176 . newmat - new matrix 177 178 Level: developer 179 180 Note: 181 Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available. 182 183 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()` 184 @*/ 185 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat) 186 { 187 Vec left, right; 188 PetscInt m, n; 189 Mat N; 190 Mat_SubVirtual *Na; 191 192 PetscFunctionBegin; 193 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 194 PetscValidHeaderSpecific(isrow, IS_CLASSID, 2); 195 PetscValidHeaderSpecific(iscol, IS_CLASSID, 3); 196 PetscAssertPointer(newmat, 4); 197 *newmat = NULL; 198 199 PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N)); 200 PetscCall(ISGetLocalSize(isrow, &m)); 201 PetscCall(ISGetLocalSize(iscol, &n)); 202 PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE)); 203 PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX)); 204 205 PetscCall(PetscNew(&Na)); 206 N->data = (void *)Na; 207 208 PetscCall(PetscObjectReference((PetscObject)isrow)); 209 PetscCall(PetscObjectReference((PetscObject)iscol)); 210 Na->isrow = isrow; 211 Na->iscol = iscol; 212 213 PetscCall(PetscFree(N->defaultvectype)); 214 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 215 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 216 the reference count of the context. This is a problem if A is already of type MATSHELL */ 217 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 218 219 N->ops->destroy = MatDestroy_SubMatrix; 220 N->ops->mult = MatMult_SubMatrix; 221 N->ops->multadd = MatMultAdd_SubMatrix; 222 N->ops->multtranspose = MatMultTranspose_SubMatrix; 223 N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix; 224 N->ops->scale = MatScale_SubMatrix; 225 N->ops->diagonalscale = MatDiagonalScale_SubMatrix; 226 N->ops->shift = MatShift_SubMatrix; 227 N->ops->convert = MatConvert_Shell; 228 N->ops->getdiagonal = MatGetDiagonal_SubMatrix; 229 230 PetscCall(MatSetBlockSizesFromMats(N, A, A)); 231 PetscCall(PetscLayoutSetUp(N->rmap)); 232 PetscCall(PetscLayoutSetUp(N->cmap)); 233 234 PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork)); 235 PetscCall(MatCreateVecs(N, &right, &left)); 236 PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict)); 237 PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong)); 238 PetscCall(VecDestroy(&left)); 239 PetscCall(VecDestroy(&right)); 240 PetscCall(MatSetUp(N)); 241 242 N->assembled = PETSC_TRUE; 243 *newmat = N; 244 PetscFunctionReturn(PETSC_SUCCESS); 245 } 246 247 /*MC 248 MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix 249 250 Level: advanced 251 252 Developer Note: 253 The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` name should likely be changed to 254 `MATSUBMATRIXVIRTUAL` 255 256 .seealso: [](ch_matrices), `Mat`, `MatCreateSubMatrixVirtual()`, `MatCreateSubMatrixVirtual()`, `MatCreateSubMatrix()` 257 M*/ 258 259 /*@ 260 MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix 261 262 Collective 263 264 Input Parameters: 265 + N - submatrix to update 266 . A - full matrix in the submatrix 267 . isrow - rows in the update (same as the first time the submatrix was created) 268 - iscol - columns in the update (same as the first time the submatrix was created) 269 270 Level: developer 271 272 Note: 273 Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available. 274 275 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()` 276 @*/ 277 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol) 278 { 279 PetscBool flg; 280 Mat_SubVirtual *Na; 281 282 PetscFunctionBegin; 283 PetscValidHeaderSpecific(N, MAT_CLASSID, 1); 284 PetscValidHeaderSpecific(A, MAT_CLASSID, 2); 285 PetscValidHeaderSpecific(isrow, IS_CLASSID, 3); 286 PetscValidHeaderSpecific(iscol, IS_CLASSID, 4); 287 PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg)); 288 PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type"); 289 290 Na = (Mat_SubVirtual *)N->data; 291 PetscCall(ISEqual(isrow, Na->isrow, &flg)); 292 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices"); 293 PetscCall(ISEqual(iscol, Na->iscol, &flg)); 294 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices"); 295 296 PetscCall(PetscFree(N->defaultvectype)); 297 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 298 PetscCall(MatDestroy(&Na->A)); 299 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 300 the reference count of the context. This is a problem if A is already of type MATSHELL */ 301 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 302 PetscFunctionReturn(PETSC_SUCCESS); 303 } 304