1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 typedef struct { 5 IS isrow, iscol; /* rows and columns in submatrix, only used to check consistency */ 6 Vec lwork, rwork; /* work vectors inside the scatters */ 7 Vec lwork2, rwork2; /* work vectors inside the scatters */ 8 VecScatter lrestrict, rprolong; 9 Mat A; 10 } Mat_SubVirtual; 11 12 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a) { 13 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatScale(Na->A, a)); 17 PetscFunctionReturn(0); 18 } 19 20 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a) { 21 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 22 23 PetscFunctionBegin; 24 PetscCall(MatShift(Na->A, a)); 25 PetscFunctionReturn(0); 26 } 27 28 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right) { 29 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 30 31 PetscFunctionBegin; 32 if (right) { 33 PetscCall(VecZeroEntries(Na->rwork)); 34 PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 35 PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 36 } 37 if (left) { 38 PetscCall(VecZeroEntries(Na->lwork)); 39 PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 40 PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 41 } 42 PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL)); 43 PetscFunctionReturn(0); 44 } 45 46 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d) { 47 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 48 49 PetscFunctionBegin; 50 PetscCall(MatGetDiagonal(Na->A, Na->rwork)); 51 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 52 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 53 PetscFunctionReturn(0); 54 } 55 56 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y) { 57 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 58 59 PetscFunctionBegin; 60 PetscCall(VecZeroEntries(Na->rwork)); 61 PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 62 PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 63 PetscCall(MatMult(Na->A, Na->rwork, Na->lwork)); 64 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 65 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 66 PetscFunctionReturn(0); 67 } 68 69 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) { 70 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 71 72 PetscFunctionBegin; 73 PetscCall(VecZeroEntries(Na->rwork)); 74 PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 75 PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 76 if (v1 == v2) { 77 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork)); 78 } else if (v2 == v3) { 79 PetscCall(VecZeroEntries(Na->lwork)); 80 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 81 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 82 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork)); 83 } else { 84 if (!Na->lwork2) { 85 PetscCall(VecDuplicate(Na->lwork, &Na->lwork2)); 86 } else { 87 PetscCall(VecZeroEntries(Na->lwork2)); 88 } 89 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 90 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 91 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork)); 92 } 93 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 94 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 95 PetscFunctionReturn(0); 96 } 97 98 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y) { 99 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 100 101 PetscFunctionBegin; 102 PetscCall(VecZeroEntries(Na->lwork)); 103 PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 104 PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 105 PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork)); 106 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 107 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 108 PetscFunctionReturn(0); 109 } 110 111 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) { 112 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 113 114 PetscFunctionBegin; 115 PetscCall(VecZeroEntries(Na->lwork)); 116 PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 117 PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 118 if (v1 == v2) { 119 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork)); 120 } else if (v2 == v3) { 121 PetscCall(VecZeroEntries(Na->rwork)); 122 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 123 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 124 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork)); 125 } else { 126 if (!Na->rwork2) { 127 PetscCall(VecDuplicate(Na->rwork, &Na->rwork2)); 128 } else { 129 PetscCall(VecZeroEntries(Na->rwork2)); 130 } 131 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 132 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 133 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork)); 134 } 135 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 136 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 137 PetscFunctionReturn(0); 138 } 139 140 static PetscErrorCode MatDestroy_SubMatrix(Mat N) { 141 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 142 143 PetscFunctionBegin; 144 PetscCall(ISDestroy(&Na->isrow)); 145 PetscCall(ISDestroy(&Na->iscol)); 146 PetscCall(VecDestroy(&Na->lwork)); 147 PetscCall(VecDestroy(&Na->rwork)); 148 PetscCall(VecDestroy(&Na->lwork2)); 149 PetscCall(VecDestroy(&Na->rwork2)); 150 PetscCall(VecScatterDestroy(&Na->lrestrict)); 151 PetscCall(VecScatterDestroy(&Na->rprolong)); 152 PetscCall(MatDestroy(&Na->A)); 153 PetscCall(PetscFree(N->data)); 154 PetscFunctionReturn(0); 155 } 156 157 /*@ 158 MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix 159 160 Collective on A 161 162 Input Parameters: 163 + A - matrix that we will extract a submatrix of 164 . isrow - rows to be present in the submatrix 165 - iscol - columns to be present in the submatrix 166 167 Output Parameters: 168 . newmat - new matrix 169 170 Level: developer 171 172 Note: 173 Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available. 174 175 Developer Note: 176 The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` should likely be changed 177 178 .seealso: `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()` 179 @*/ 180 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat) { 181 Vec left, right; 182 PetscInt m, n; 183 Mat N; 184 Mat_SubVirtual *Na; 185 186 PetscFunctionBegin; 187 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 188 PetscValidHeaderSpecific(isrow, IS_CLASSID, 2); 189 PetscValidHeaderSpecific(iscol, IS_CLASSID, 3); 190 PetscValidPointer(newmat, 4); 191 *newmat = NULL; 192 193 PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N)); 194 PetscCall(ISGetLocalSize(isrow, &m)); 195 PetscCall(ISGetLocalSize(iscol, &n)); 196 PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE)); 197 PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX)); 198 199 PetscCall(PetscNewLog(N, &Na)); 200 N->data = (void *)Na; 201 202 PetscCall(PetscObjectReference((PetscObject)isrow)); 203 PetscCall(PetscObjectReference((PetscObject)iscol)); 204 Na->isrow = isrow; 205 Na->iscol = iscol; 206 207 PetscCall(PetscFree(N->defaultvectype)); 208 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 209 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 210 the reference count of the context. This is a problem if A is already of type MATSHELL */ 211 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 212 213 N->ops->destroy = MatDestroy_SubMatrix; 214 N->ops->mult = MatMult_SubMatrix; 215 N->ops->multadd = MatMultAdd_SubMatrix; 216 N->ops->multtranspose = MatMultTranspose_SubMatrix; 217 N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix; 218 N->ops->scale = MatScale_SubMatrix; 219 N->ops->diagonalscale = MatDiagonalScale_SubMatrix; 220 N->ops->shift = MatShift_SubMatrix; 221 N->ops->convert = MatConvert_Shell; 222 N->ops->getdiagonal = MatGetDiagonal_SubMatrix; 223 224 PetscCall(MatSetBlockSizesFromMats(N, A, A)); 225 PetscCall(PetscLayoutSetUp(N->rmap)); 226 PetscCall(PetscLayoutSetUp(N->cmap)); 227 228 PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork)); 229 PetscCall(MatCreateVecs(N, &right, &left)); 230 PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict)); 231 PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong)); 232 PetscCall(VecDestroy(&left)); 233 PetscCall(VecDestroy(&right)); 234 PetscCall(MatSetUp(N)); 235 236 N->assembled = PETSC_TRUE; 237 *newmat = N; 238 PetscFunctionReturn(0); 239 } 240 241 /*MC 242 MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix 243 244 Level: advanced 245 246 Developer Note: 247 This should be named `MATSUBMATRIXVIRTUAL` 248 249 .seealso: `Mat`, `MatCreateSubMatrixVirtual()`, `MatCreateSubMatrixVirtual()`, `MatCreateSubMatrix()` 250 M*/ 251 252 /*@ 253 MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix 254 255 Collective on N 256 257 Input Parameters: 258 + N - submatrix to update 259 . A - full matrix in the submatrix 260 . isrow - rows in the update (same as the first time the submatrix was created) 261 - iscol - columns in the update (same as the first time the submatrix was created) 262 263 Level: developer 264 265 Note: 266 Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available. 267 268 .seealso: MATSUBMATRIX`, `MatCreateSubMatrixVirtual()` 269 @*/ 270 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol) { 271 PetscBool flg; 272 Mat_SubVirtual *Na; 273 274 PetscFunctionBegin; 275 PetscValidHeaderSpecific(N, MAT_CLASSID, 1); 276 PetscValidHeaderSpecific(A, MAT_CLASSID, 2); 277 PetscValidHeaderSpecific(isrow, IS_CLASSID, 3); 278 PetscValidHeaderSpecific(iscol, IS_CLASSID, 4); 279 PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg)); 280 PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type"); 281 282 Na = (Mat_SubVirtual *)N->data; 283 PetscCall(ISEqual(isrow, Na->isrow, &flg)); 284 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices"); 285 PetscCall(ISEqual(iscol, Na->iscol, &flg)); 286 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices"); 287 288 PetscCall(PetscFree(N->defaultvectype)); 289 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 290 PetscCall(MatDestroy(&Na->A)); 291 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 292 the reference count of the context. This is a problem if A is already of type MATSHELL */ 293 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 294 PetscFunctionReturn(0); 295 } 296