1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A, Vec bb, Vec xx) 5 { 6 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 7 IS iscol = a->col, isrow = a->row; 8 const PetscInt *r, *c, *rout, *cout; 9 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 10 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 11 const MatScalar *aa = a->a, *v; 12 PetscScalar s1, s2, x1, x2, *x, *t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb, &b)); 17 PetscCall(VecGetArray(xx, &x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow, &rout)); 21 r = rout; 22 PetscCall(ISGetIndices(iscol, &cout)); 23 c = cout; 24 25 /* copy the b into temp work space according to permutation */ 26 ii = 0; 27 for (i = 0; i < n; i++) { 28 ic = 2 * c[i]; 29 t[ii] = b[ic]; 30 t[ii + 1] = b[ic + 1]; 31 ii += 2; 32 } 33 34 /* forward solve the U^T */ 35 idx = 0; 36 for (i = 0; i < n; i++) { 37 v = aa + 4 * diag[i]; 38 /* multiply by the inverse of the block diagonal */ 39 x1 = t[idx]; 40 x2 = t[1 + idx]; 41 s1 = v[0] * x1 + v[1] * x2; 42 s2 = v[2] * x1 + v[3] * x2; 43 v += 4; 44 45 vi = aj + diag[i] + 1; 46 nz = ai[i + 1] - diag[i] - 1; 47 while (nz--) { 48 oidx = 2 * (*vi++); 49 t[oidx] -= v[0] * s1 + v[1] * s2; 50 t[oidx + 1] -= v[2] * s1 + v[3] * s2; 51 v += 4; 52 } 53 t[idx] = s1; 54 t[1 + idx] = s2; 55 idx += 2; 56 } 57 /* backward solve the L^T */ 58 for (i = n - 1; i >= 0; i--) { 59 v = aa + 4 * diag[i] - 4; 60 vi = aj + diag[i] - 1; 61 nz = diag[i] - ai[i]; 62 idt = 2 * i; 63 s1 = t[idt]; 64 s2 = t[1 + idt]; 65 while (nz--) { 66 idx = 2 * (*vi--); 67 t[idx] -= v[0] * s1 + v[1] * s2; 68 t[idx + 1] -= v[2] * s1 + v[3] * s2; 69 v -= 4; 70 } 71 } 72 73 /* copy t into x according to permutation */ 74 ii = 0; 75 for (i = 0; i < n; i++) { 76 ir = 2 * r[i]; 77 x[ir] = t[ii]; 78 x[ir + 1] = t[ii + 1]; 79 ii += 2; 80 } 81 82 PetscCall(ISRestoreIndices(isrow, &rout)); 83 PetscCall(ISRestoreIndices(iscol, &cout)); 84 PetscCall(VecRestoreArrayRead(bb, &b)); 85 PetscCall(VecRestoreArray(xx, &x)); 86 PetscCall(PetscLogFlops(2.0 * 4 * (a->nz) - 2.0 * A->cmap->n)); 87 PetscFunctionReturn(PETSC_SUCCESS); 88 } 89 90 PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A, Vec bb, Vec xx) 91 { 92 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 93 IS iscol = a->col, isrow = a->row; 94 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 95 const PetscInt *r, *c, *rout, *cout; 96 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 97 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 98 const MatScalar *aa = a->a, *v; 99 PetscScalar s1, s2, x1, x2, *x, *t; 100 const PetscScalar *b; 101 102 PetscFunctionBegin; 103 PetscCall(VecGetArrayRead(bb, &b)); 104 PetscCall(VecGetArray(xx, &x)); 105 t = a->solve_work; 106 107 PetscCall(ISGetIndices(isrow, &rout)); 108 r = rout; 109 PetscCall(ISGetIndices(iscol, &cout)); 110 c = cout; 111 112 /* copy b into temp work space according to permutation */ 113 for (i = 0; i < n; i++) { 114 ii = bs * i; 115 ic = bs * c[i]; 116 t[ii] = b[ic]; 117 t[ii + 1] = b[ic + 1]; 118 } 119 120 /* forward solve the U^T */ 121 idx = 0; 122 for (i = 0; i < n; i++) { 123 v = aa + bs2 * diag[i]; 124 /* multiply by the inverse of the block diagonal */ 125 x1 = t[idx]; 126 x2 = t[1 + idx]; 127 s1 = v[0] * x1 + v[1] * x2; 128 s2 = v[2] * x1 + v[3] * x2; 129 v -= bs2; 130 131 vi = aj + diag[i] - 1; 132 nz = diag[i] - diag[i + 1] - 1; 133 for (j = 0; j > -nz; j--) { 134 oidx = bs * vi[j]; 135 t[oidx] -= v[0] * s1 + v[1] * s2; 136 t[oidx + 1] -= v[2] * s1 + v[3] * s2; 137 v -= bs2; 138 } 139 t[idx] = s1; 140 t[1 + idx] = s2; 141 idx += bs; 142 } 143 /* backward solve the L^T */ 144 for (i = n - 1; i >= 0; i--) { 145 v = aa + bs2 * ai[i]; 146 vi = aj + ai[i]; 147 nz = ai[i + 1] - ai[i]; 148 idt = bs * i; 149 s1 = t[idt]; 150 s2 = t[1 + idt]; 151 for (j = 0; j < nz; j++) { 152 idx = bs * vi[j]; 153 t[idx] -= v[0] * s1 + v[1] * s2; 154 t[idx + 1] -= v[2] * s1 + v[3] * s2; 155 v += bs2; 156 } 157 } 158 159 /* copy t into x according to permutation */ 160 for (i = 0; i < n; i++) { 161 ii = bs * i; 162 ir = bs * r[i]; 163 x[ir] = t[ii]; 164 x[ir + 1] = t[ii + 1]; 165 } 166 167 PetscCall(ISRestoreIndices(isrow, &rout)); 168 PetscCall(ISRestoreIndices(iscol, &cout)); 169 PetscCall(VecRestoreArrayRead(bb, &b)); 170 PetscCall(VecRestoreArray(xx, &x)); 171 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 172 PetscFunctionReturn(PETSC_SUCCESS); 173 } 174