1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_4_inplace(Mat A, Vec bb, Vec xx) { 5 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 6 IS iscol = a->col, isrow = a->row; 7 const PetscInt *r, *c, *rout, *cout; 8 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 9 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 10 const MatScalar *aa = a->a, *v; 11 PetscScalar s1, s2, s3, s4, x1, x2, x3, x4, *x, *t; 12 const PetscScalar *b; 13 14 PetscFunctionBegin; 15 PetscCall(VecGetArrayRead(bb, &b)); 16 PetscCall(VecGetArray(xx, &x)); 17 t = a->solve_work; 18 19 PetscCall(ISGetIndices(isrow, &rout)); 20 r = rout; 21 PetscCall(ISGetIndices(iscol, &cout)); 22 c = cout; 23 24 /* copy the b into temp work space according to permutation */ 25 ii = 0; 26 for (i = 0; i < n; i++) { 27 ic = 4 * c[i]; 28 t[ii] = b[ic]; 29 t[ii + 1] = b[ic + 1]; 30 t[ii + 2] = b[ic + 2]; 31 t[ii + 3] = b[ic + 3]; 32 ii += 4; 33 } 34 35 /* forward solve the U^T */ 36 idx = 0; 37 for (i = 0; i < n; i++) { 38 v = aa + 16 * diag[i]; 39 /* multiply by the inverse of the block diagonal */ 40 x1 = t[idx]; 41 x2 = t[1 + idx]; 42 x3 = t[2 + idx]; 43 x4 = t[3 + idx]; 44 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4; 45 s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4; 46 s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4; 47 s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4; 48 v += 16; 49 50 vi = aj + diag[i] + 1; 51 nz = ai[i + 1] - diag[i] - 1; 52 while (nz--) { 53 oidx = 4 * (*vi++); 54 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 55 t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 56 t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 57 t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 58 v += 16; 59 } 60 t[idx] = s1; 61 t[1 + idx] = s2; 62 t[2 + idx] = s3; 63 t[3 + idx] = s4; 64 idx += 4; 65 } 66 /* backward solve the L^T */ 67 for (i = n - 1; i >= 0; i--) { 68 v = aa + 16 * diag[i] - 16; 69 vi = aj + diag[i] - 1; 70 nz = diag[i] - ai[i]; 71 idt = 4 * i; 72 s1 = t[idt]; 73 s2 = t[1 + idt]; 74 s3 = t[2 + idt]; 75 s4 = t[3 + idt]; 76 while (nz--) { 77 idx = 4 * (*vi--); 78 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 79 t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 80 t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 81 t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 82 v -= 16; 83 } 84 } 85 86 /* copy t into x according to permutation */ 87 ii = 0; 88 for (i = 0; i < n; i++) { 89 ir = 4 * r[i]; 90 x[ir] = t[ii]; 91 x[ir + 1] = t[ii + 1]; 92 x[ir + 2] = t[ii + 2]; 93 x[ir + 3] = t[ii + 3]; 94 ii += 4; 95 } 96 97 PetscCall(ISRestoreIndices(isrow, &rout)); 98 PetscCall(ISRestoreIndices(iscol, &cout)); 99 PetscCall(VecRestoreArrayRead(bb, &b)); 100 PetscCall(VecRestoreArray(xx, &x)); 101 PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n)); 102 PetscFunctionReturn(0); 103 } 104 105 PetscErrorCode MatSolveTranspose_SeqBAIJ_4(Mat A, Vec bb, Vec xx) { 106 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 107 IS iscol = a->col, isrow = a->row; 108 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 109 const PetscInt *r, *c, *rout, *cout; 110 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 111 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 112 const MatScalar *aa = a->a, *v; 113 PetscScalar s1, s2, s3, s4, x1, x2, x3, x4, *x, *t; 114 const PetscScalar *b; 115 116 PetscFunctionBegin; 117 PetscCall(VecGetArrayRead(bb, &b)); 118 PetscCall(VecGetArray(xx, &x)); 119 t = a->solve_work; 120 121 PetscCall(ISGetIndices(isrow, &rout)); 122 r = rout; 123 PetscCall(ISGetIndices(iscol, &cout)); 124 c = cout; 125 126 /* copy b into temp work space according to permutation */ 127 for (i = 0; i < n; i++) { 128 ii = bs * i; 129 ic = bs * c[i]; 130 t[ii] = b[ic]; 131 t[ii + 1] = b[ic + 1]; 132 t[ii + 2] = b[ic + 2]; 133 t[ii + 3] = b[ic + 3]; 134 } 135 136 /* forward solve the U^T */ 137 idx = 0; 138 for (i = 0; i < n; i++) { 139 v = aa + bs2 * diag[i]; 140 /* multiply by the inverse of the block diagonal */ 141 x1 = t[idx]; 142 x2 = t[1 + idx]; 143 x3 = t[2 + idx]; 144 x4 = t[3 + idx]; 145 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4; 146 s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4; 147 s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4; 148 s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4; 149 v -= bs2; 150 151 vi = aj + diag[i] - 1; 152 nz = diag[i] - diag[i + 1] - 1; 153 for (j = 0; j > -nz; j--) { 154 oidx = bs * vi[j]; 155 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 156 t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 157 t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 158 t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 159 v -= bs2; 160 } 161 t[idx] = s1; 162 t[1 + idx] = s2; 163 t[2 + idx] = s3; 164 t[3 + idx] = s4; 165 idx += bs; 166 } 167 /* backward solve the L^T */ 168 for (i = n - 1; i >= 0; i--) { 169 v = aa + bs2 * ai[i]; 170 vi = aj + ai[i]; 171 nz = ai[i + 1] - ai[i]; 172 idt = bs * i; 173 s1 = t[idt]; 174 s2 = t[1 + idt]; 175 s3 = t[2 + idt]; 176 s4 = t[3 + idt]; 177 for (j = 0; j < nz; j++) { 178 idx = bs * vi[j]; 179 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 180 t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 181 t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 182 t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 183 v += bs2; 184 } 185 } 186 187 /* copy t into x according to permutation */ 188 for (i = 0; i < n; i++) { 189 ii = bs * i; 190 ir = bs * r[i]; 191 x[ir] = t[ii]; 192 x[ir + 1] = t[ii + 1]; 193 x[ir + 2] = t[ii + 2]; 194 x[ir + 3] = t[ii + 3]; 195 } 196 197 PetscCall(ISRestoreIndices(isrow, &rout)); 198 PetscCall(ISRestoreIndices(iscol, &cout)); 199 PetscCall(VecRestoreArrayRead(bb, &b)); 200 PetscCall(VecRestoreArray(xx, &x)); 201 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 202 PetscFunctionReturn(0); 203 } 204