1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_3_inplace(Mat A, Vec bb, Vec xx) { 5 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 6 IS iscol = a->col, isrow = a->row; 7 const PetscInt *r, *c, *rout, *cout; 8 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 9 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 10 const MatScalar *aa = a->a, *v; 11 PetscScalar s1, s2, s3, x1, x2, x3, *x, *t; 12 const PetscScalar *b; 13 14 PetscFunctionBegin; 15 PetscCall(VecGetArrayRead(bb, &b)); 16 PetscCall(VecGetArray(xx, &x)); 17 t = a->solve_work; 18 19 PetscCall(ISGetIndices(isrow, &rout)); 20 r = rout; 21 PetscCall(ISGetIndices(iscol, &cout)); 22 c = cout; 23 24 /* copy the b into temp work space according to permutation */ 25 ii = 0; 26 for (i = 0; i < n; i++) { 27 ic = 3 * c[i]; 28 t[ii] = b[ic]; 29 t[ii + 1] = b[ic + 1]; 30 t[ii + 2] = b[ic + 2]; 31 ii += 3; 32 } 33 34 /* forward solve the U^T */ 35 idx = 0; 36 for (i = 0; i < n; i++) { 37 v = aa + 9 * diag[i]; 38 /* multiply by the inverse of the block diagonal */ 39 x1 = t[idx]; 40 x2 = t[1 + idx]; 41 x3 = t[2 + idx]; 42 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3; 43 s2 = v[3] * x1 + v[4] * x2 + v[5] * x3; 44 s3 = v[6] * x1 + v[7] * x2 + v[8] * x3; 45 v += 9; 46 47 vi = aj + diag[i] + 1; 48 nz = ai[i + 1] - diag[i] - 1; 49 while (nz--) { 50 oidx = 3 * (*vi++); 51 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 52 t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 53 t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 54 v += 9; 55 } 56 t[idx] = s1; 57 t[1 + idx] = s2; 58 t[2 + idx] = s3; 59 idx += 3; 60 } 61 /* backward solve the L^T */ 62 for (i = n - 1; i >= 0; i--) { 63 v = aa + 9 * diag[i] - 9; 64 vi = aj + diag[i] - 1; 65 nz = diag[i] - ai[i]; 66 idt = 3 * i; 67 s1 = t[idt]; 68 s2 = t[1 + idt]; 69 s3 = t[2 + idt]; 70 while (nz--) { 71 idx = 3 * (*vi--); 72 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 73 t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 74 t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 75 v -= 9; 76 } 77 } 78 79 /* copy t into x according to permutation */ 80 ii = 0; 81 for (i = 0; i < n; i++) { 82 ir = 3 * r[i]; 83 x[ir] = t[ii]; 84 x[ir + 1] = t[ii + 1]; 85 x[ir + 2] = t[ii + 2]; 86 ii += 3; 87 } 88 89 PetscCall(ISRestoreIndices(isrow, &rout)); 90 PetscCall(ISRestoreIndices(iscol, &cout)); 91 PetscCall(VecRestoreArrayRead(bb, &b)); 92 PetscCall(VecRestoreArray(xx, &x)); 93 PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n)); 94 PetscFunctionReturn(0); 95 } 96 97 PetscErrorCode MatSolveTranspose_SeqBAIJ_3(Mat A, Vec bb, Vec xx) { 98 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 99 IS iscol = a->col, isrow = a->row; 100 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 101 const PetscInt *r, *c, *rout, *cout; 102 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 103 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 104 const MatScalar *aa = a->a, *v; 105 PetscScalar s1, s2, s3, x1, x2, x3, *x, *t; 106 const PetscScalar *b; 107 108 PetscFunctionBegin; 109 PetscCall(VecGetArrayRead(bb, &b)); 110 PetscCall(VecGetArray(xx, &x)); 111 t = a->solve_work; 112 113 PetscCall(ISGetIndices(isrow, &rout)); 114 r = rout; 115 PetscCall(ISGetIndices(iscol, &cout)); 116 c = cout; 117 118 /* copy b into temp work space according to permutation */ 119 for (i = 0; i < n; i++) { 120 ii = bs * i; 121 ic = bs * c[i]; 122 t[ii] = b[ic]; 123 t[ii + 1] = b[ic + 1]; 124 t[ii + 2] = b[ic + 2]; 125 } 126 127 /* forward solve the U^T */ 128 idx = 0; 129 for (i = 0; i < n; i++) { 130 v = aa + bs2 * diag[i]; 131 /* multiply by the inverse of the block diagonal */ 132 x1 = t[idx]; 133 x2 = t[1 + idx]; 134 x3 = t[2 + idx]; 135 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3; 136 s2 = v[3] * x1 + v[4] * x2 + v[5] * x3; 137 s3 = v[6] * x1 + v[7] * x2 + v[8] * x3; 138 v -= bs2; 139 140 vi = aj + diag[i] - 1; 141 nz = diag[i] - diag[i + 1] - 1; 142 for (j = 0; j > -nz; j--) { 143 oidx = bs * vi[j]; 144 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 145 t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 146 t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 147 v -= bs2; 148 } 149 t[idx] = s1; 150 t[1 + idx] = s2; 151 t[2 + idx] = s3; 152 idx += bs; 153 } 154 /* backward solve the L^T */ 155 for (i = n - 1; i >= 0; i--) { 156 v = aa + bs2 * ai[i]; 157 vi = aj + ai[i]; 158 nz = ai[i + 1] - ai[i]; 159 idt = bs * i; 160 s1 = t[idt]; 161 s2 = t[1 + idt]; 162 s3 = t[2 + idt]; 163 for (j = 0; j < nz; j++) { 164 idx = bs * vi[j]; 165 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 166 t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 167 t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 168 v += bs2; 169 } 170 } 171 172 /* copy t into x according to permutation */ 173 for (i = 0; i < n; i++) { 174 ii = bs * i; 175 ir = bs * r[i]; 176 x[ir] = t[ii]; 177 x[ir + 1] = t[ii + 1]; 178 x[ir + 2] = t[ii + 2]; 179 } 180 181 PetscCall(ISRestoreIndices(isrow, &rout)); 182 PetscCall(ISRestoreIndices(iscol, &cout)); 183 PetscCall(VecRestoreArrayRead(bb, &b)); 184 PetscCall(VecRestoreArray(xx, &x)); 185 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 186 PetscFunctionReturn(0); 187 } 188