1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 typedef struct { 5 IS isrow, iscol; /* rows and columns in submatrix, only used to check consistency */ 6 Vec lwork, rwork; /* work vectors inside the scatters */ 7 Vec lwork2, rwork2; /* work vectors inside the scatters */ 8 VecScatter lrestrict, rprolong; 9 Mat A; 10 } Mat_SubVirtual; 11 12 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a) { 13 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 14 15 PetscFunctionBegin; 16 PetscCall(MatScale(Na->A, a)); 17 PetscFunctionReturn(0); 18 } 19 20 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a) { 21 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 22 23 PetscFunctionBegin; 24 PetscCall(MatShift(Na->A, a)); 25 PetscFunctionReturn(0); 26 } 27 28 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right) { 29 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 30 31 PetscFunctionBegin; 32 if (right) { 33 PetscCall(VecZeroEntries(Na->rwork)); 34 PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 35 PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 36 } 37 if (left) { 38 PetscCall(VecZeroEntries(Na->lwork)); 39 PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 40 PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 41 } 42 PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL)); 43 PetscFunctionReturn(0); 44 } 45 46 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d) { 47 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 48 49 PetscFunctionBegin; 50 PetscCall(MatGetDiagonal(Na->A, Na->rwork)); 51 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 52 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE)); 53 PetscFunctionReturn(0); 54 } 55 56 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y) { 57 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 58 59 PetscFunctionBegin; 60 PetscCall(VecZeroEntries(Na->rwork)); 61 PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 62 PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 63 PetscCall(MatMult(Na->A, Na->rwork, Na->lwork)); 64 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 65 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD)); 66 PetscFunctionReturn(0); 67 } 68 69 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) { 70 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 71 72 PetscFunctionBegin; 73 PetscCall(VecZeroEntries(Na->rwork)); 74 PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 75 PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 76 if (v1 == v2) { 77 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork)); 78 } else if (v2 == v3) { 79 PetscCall(VecZeroEntries(Na->lwork)); 80 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 81 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 82 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork)); 83 } else { 84 if (!Na->lwork2) { 85 PetscCall(VecDuplicate(Na->lwork, &Na->lwork2)); 86 } else { 87 PetscCall(VecZeroEntries(Na->lwork2)); 88 } 89 PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 90 PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE)); 91 PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork)); 92 } 93 PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 94 PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD)); 95 PetscFunctionReturn(0); 96 } 97 98 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y) { 99 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 100 101 PetscFunctionBegin; 102 PetscCall(VecZeroEntries(Na->lwork)); 103 PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 104 PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 105 PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork)); 106 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 107 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE)); 108 PetscFunctionReturn(0); 109 } 110 111 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) { 112 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 113 114 PetscFunctionBegin; 115 PetscCall(VecZeroEntries(Na->lwork)); 116 PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 117 PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE)); 118 if (v1 == v2) { 119 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork)); 120 } else if (v2 == v3) { 121 PetscCall(VecZeroEntries(Na->rwork)); 122 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 123 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD)); 124 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork)); 125 } else { 126 if (!Na->rwork2) { 127 PetscCall(VecDuplicate(Na->rwork, &Na->rwork2)); 128 } else { 129 PetscCall(VecZeroEntries(Na->rwork2)); 130 } 131 PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 132 PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD)); 133 PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork)); 134 } 135 PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 136 PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE)); 137 PetscFunctionReturn(0); 138 } 139 140 static PetscErrorCode MatDestroy_SubMatrix(Mat N) { 141 Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data; 142 143 PetscFunctionBegin; 144 PetscCall(ISDestroy(&Na->isrow)); 145 PetscCall(ISDestroy(&Na->iscol)); 146 PetscCall(VecDestroy(&Na->lwork)); 147 PetscCall(VecDestroy(&Na->rwork)); 148 PetscCall(VecDestroy(&Na->lwork2)); 149 PetscCall(VecDestroy(&Na->rwork2)); 150 PetscCall(VecScatterDestroy(&Na->lrestrict)); 151 PetscCall(VecScatterDestroy(&Na->rprolong)); 152 PetscCall(MatDestroy(&Na->A)); 153 PetscCall(PetscFree(N->data)); 154 PetscFunctionReturn(0); 155 } 156 157 /*@ 158 MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix 159 160 Collective on Mat 161 162 Input Parameters: 163 + A - matrix that we will extract a submatrix of 164 . isrow - rows to be present in the submatrix 165 - iscol - columns to be present in the submatrix 166 167 Output Parameters: 168 . newmat - new matrix 169 170 Level: developer 171 172 Notes: 173 Most will use MatCreateSubMatrix which provides a more efficient representation if it is available. 174 175 .seealso: `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()` 176 @*/ 177 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat) { 178 Vec left, right; 179 PetscInt m, n; 180 Mat N; 181 Mat_SubVirtual *Na; 182 183 PetscFunctionBegin; 184 PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 185 PetscValidHeaderSpecific(isrow, IS_CLASSID, 2); 186 PetscValidHeaderSpecific(iscol, IS_CLASSID, 3); 187 PetscValidPointer(newmat, 4); 188 *newmat = NULL; 189 190 PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N)); 191 PetscCall(ISGetLocalSize(isrow, &m)); 192 PetscCall(ISGetLocalSize(iscol, &n)); 193 PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE)); 194 PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX)); 195 196 PetscCall(PetscNewLog(N, &Na)); 197 N->data = (void *)Na; 198 199 PetscCall(PetscObjectReference((PetscObject)isrow)); 200 PetscCall(PetscObjectReference((PetscObject)iscol)); 201 Na->isrow = isrow; 202 Na->iscol = iscol; 203 204 PetscCall(PetscFree(N->defaultvectype)); 205 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 206 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 207 the reference count of the context. This is a problem if A is already of type MATSHELL */ 208 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 209 210 N->ops->destroy = MatDestroy_SubMatrix; 211 N->ops->mult = MatMult_SubMatrix; 212 N->ops->multadd = MatMultAdd_SubMatrix; 213 N->ops->multtranspose = MatMultTranspose_SubMatrix; 214 N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix; 215 N->ops->scale = MatScale_SubMatrix; 216 N->ops->diagonalscale = MatDiagonalScale_SubMatrix; 217 N->ops->shift = MatShift_SubMatrix; 218 N->ops->convert = MatConvert_Shell; 219 N->ops->getdiagonal = MatGetDiagonal_SubMatrix; 220 221 PetscCall(MatSetBlockSizesFromMats(N, A, A)); 222 PetscCall(PetscLayoutSetUp(N->rmap)); 223 PetscCall(PetscLayoutSetUp(N->cmap)); 224 225 PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork)); 226 PetscCall(MatCreateVecs(N, &right, &left)); 227 PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict)); 228 PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong)); 229 PetscCall(VecDestroy(&left)); 230 PetscCall(VecDestroy(&right)); 231 PetscCall(MatSetUp(N)); 232 233 N->assembled = PETSC_TRUE; 234 *newmat = N; 235 PetscFunctionReturn(0); 236 } 237 238 /*@ 239 MatSubMatrixVirtualUpdate - Updates a submatrix 240 241 Collective on Mat 242 243 Input Parameters: 244 + N - submatrix to update 245 . A - full matrix in the submatrix 246 . isrow - rows in the update (same as the first time the submatrix was created) 247 - iscol - columns in the update (same as the first time the submatrix was created) 248 249 Level: developer 250 251 Notes: 252 Most will use MatCreateSubMatrix which provides a more efficient representation if it is available. 253 254 .seealso: `MatCreateSubMatrixVirtual()` 255 @*/ 256 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol) { 257 PetscBool flg; 258 Mat_SubVirtual *Na; 259 260 PetscFunctionBegin; 261 PetscValidHeaderSpecific(N, MAT_CLASSID, 1); 262 PetscValidHeaderSpecific(A, MAT_CLASSID, 2); 263 PetscValidHeaderSpecific(isrow, IS_CLASSID, 3); 264 PetscValidHeaderSpecific(iscol, IS_CLASSID, 4); 265 PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg)); 266 PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type"); 267 268 Na = (Mat_SubVirtual *)N->data; 269 PetscCall(ISEqual(isrow, Na->isrow, &flg)); 270 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices"); 271 PetscCall(ISEqual(iscol, Na->iscol, &flg)); 272 PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices"); 273 274 PetscCall(PetscFree(N->defaultvectype)); 275 PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype)); 276 PetscCall(MatDestroy(&Na->A)); 277 /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase 278 the reference count of the context. This is a problem if A is already of type MATSHELL */ 279 PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A)); 280 PetscFunctionReturn(0); 281 } 282