1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_5_inplace(Mat A, Vec bb, Vec xx) { 5 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 6 IS iscol = a->col, isrow = a->row; 7 const PetscInt *r, *c, *rout, *cout; 8 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 9 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 10 const MatScalar *aa = a->a, *v; 11 PetscScalar s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *x, *t; 12 const PetscScalar *b; 13 14 PetscFunctionBegin; 15 PetscCall(VecGetArrayRead(bb, &b)); 16 PetscCall(VecGetArray(xx, &x)); 17 t = a->solve_work; 18 19 PetscCall(ISGetIndices(isrow, &rout)); 20 r = rout; 21 PetscCall(ISGetIndices(iscol, &cout)); 22 c = cout; 23 24 /* copy the b into temp work space according to permutation */ 25 ii = 0; 26 for (i = 0; i < n; i++) { 27 ic = 5 * c[i]; 28 t[ii] = b[ic]; 29 t[ii + 1] = b[ic + 1]; 30 t[ii + 2] = b[ic + 2]; 31 t[ii + 3] = b[ic + 3]; 32 t[ii + 4] = b[ic + 4]; 33 ii += 5; 34 } 35 36 /* forward solve the U^T */ 37 idx = 0; 38 for (i = 0; i < n; i++) { 39 v = aa + 25 * diag[i]; 40 /* multiply by the inverse of the block diagonal */ 41 x1 = t[idx]; 42 x2 = t[1 + idx]; 43 x3 = t[2 + idx]; 44 x4 = t[3 + idx]; 45 x5 = t[4 + idx]; 46 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5; 47 s2 = v[5] * x1 + v[6] * x2 + v[7] * x3 + v[8] * x4 + v[9] * x5; 48 s3 = v[10] * x1 + v[11] * x2 + v[12] * x3 + v[13] * x4 + v[14] * x5; 49 s4 = v[15] * x1 + v[16] * x2 + v[17] * x3 + v[18] * x4 + v[19] * x5; 50 s5 = v[20] * x1 + v[21] * x2 + v[22] * x3 + v[23] * x4 + v[24] * x5; 51 v += 25; 52 53 vi = aj + diag[i] + 1; 54 nz = ai[i + 1] - diag[i] - 1; 55 while (nz--) { 56 oidx = 5 * (*vi++); 57 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5; 58 t[oidx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5; 59 t[oidx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5; 60 t[oidx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5; 61 t[oidx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5; 62 v += 25; 63 } 64 t[idx] = s1; 65 t[1 + idx] = s2; 66 t[2 + idx] = s3; 67 t[3 + idx] = s4; 68 t[4 + idx] = s5; 69 idx += 5; 70 } 71 /* backward solve the L^T */ 72 for (i = n - 1; i >= 0; i--) { 73 v = aa + 25 * diag[i] - 25; 74 vi = aj + diag[i] - 1; 75 nz = diag[i] - ai[i]; 76 idt = 5 * i; 77 s1 = t[idt]; 78 s2 = t[1 + idt]; 79 s3 = t[2 + idt]; 80 s4 = t[3 + idt]; 81 s5 = t[4 + idt]; 82 while (nz--) { 83 idx = 5 * (*vi--); 84 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5; 85 t[idx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5; 86 t[idx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5; 87 t[idx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5; 88 t[idx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5; 89 v -= 25; 90 } 91 } 92 93 /* copy t into x according to permutation */ 94 ii = 0; 95 for (i = 0; i < n; i++) { 96 ir = 5 * r[i]; 97 x[ir] = t[ii]; 98 x[ir + 1] = t[ii + 1]; 99 x[ir + 2] = t[ii + 2]; 100 x[ir + 3] = t[ii + 3]; 101 x[ir + 4] = t[ii + 4]; 102 ii += 5; 103 } 104 105 PetscCall(ISRestoreIndices(isrow, &rout)); 106 PetscCall(ISRestoreIndices(iscol, &cout)); 107 PetscCall(VecRestoreArrayRead(bb, &b)); 108 PetscCall(VecRestoreArray(xx, &x)); 109 PetscCall(PetscLogFlops(2.0 * 25 * (a->nz) - 5.0 * A->cmap->n)); 110 PetscFunctionReturn(0); 111 } 112 113 PetscErrorCode MatSolveTranspose_SeqBAIJ_5(Mat A, Vec bb, Vec xx) { 114 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 115 IS iscol = a->col, isrow = a->row; 116 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 117 const PetscInt *r, *c, *rout, *cout; 118 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 119 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 120 const MatScalar *aa = a->a, *v; 121 PetscScalar s1, s2, s3, s4, s5, x1, x2, x3, x4, x5, *x, *t; 122 const PetscScalar *b; 123 124 PetscFunctionBegin; 125 PetscCall(VecGetArrayRead(bb, &b)); 126 PetscCall(VecGetArray(xx, &x)); 127 t = a->solve_work; 128 129 PetscCall(ISGetIndices(isrow, &rout)); 130 r = rout; 131 PetscCall(ISGetIndices(iscol, &cout)); 132 c = cout; 133 134 /* copy b into temp work space according to permutation */ 135 for (i = 0; i < n; i++) { 136 ii = bs * i; 137 ic = bs * c[i]; 138 t[ii] = b[ic]; 139 t[ii + 1] = b[ic + 1]; 140 t[ii + 2] = b[ic + 2]; 141 t[ii + 3] = b[ic + 3]; 142 t[ii + 4] = b[ic + 4]; 143 } 144 145 /* forward solve the U^T */ 146 idx = 0; 147 for (i = 0; i < n; i++) { 148 v = aa + bs2 * diag[i]; 149 /* multiply by the inverse of the block diagonal */ 150 x1 = t[idx]; 151 x2 = t[1 + idx]; 152 x3 = t[2 + idx]; 153 x4 = t[3 + idx]; 154 x5 = t[4 + idx]; 155 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5; 156 s2 = v[5] * x1 + v[6] * x2 + v[7] * x3 + v[8] * x4 + v[9] * x5; 157 s3 = v[10] * x1 + v[11] * x2 + v[12] * x3 + v[13] * x4 + v[14] * x5; 158 s4 = v[15] * x1 + v[16] * x2 + v[17] * x3 + v[18] * x4 + v[19] * x5; 159 s5 = v[20] * x1 + v[21] * x2 + v[22] * x3 + v[23] * x4 + v[24] * x5; 160 v -= bs2; 161 162 vi = aj + diag[i] - 1; 163 nz = diag[i] - diag[i + 1] - 1; 164 for (j = 0; j > -nz; j--) { 165 oidx = bs * vi[j]; 166 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5; 167 t[oidx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5; 168 t[oidx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5; 169 t[oidx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5; 170 t[oidx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5; 171 v -= bs2; 172 } 173 t[idx] = s1; 174 t[1 + idx] = s2; 175 t[2 + idx] = s3; 176 t[3 + idx] = s4; 177 t[4 + idx] = s5; 178 idx += bs; 179 } 180 /* backward solve the L^T */ 181 for (i = n - 1; i >= 0; i--) { 182 v = aa + bs2 * ai[i]; 183 vi = aj + ai[i]; 184 nz = ai[i + 1] - ai[i]; 185 idt = bs * i; 186 s1 = t[idt]; 187 s2 = t[1 + idt]; 188 s3 = t[2 + idt]; 189 s4 = t[3 + idt]; 190 s5 = t[4 + idt]; 191 for (j = 0; j < nz; j++) { 192 idx = bs * vi[j]; 193 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5; 194 t[idx + 1] -= v[5] * s1 + v[6] * s2 + v[7] * s3 + v[8] * s4 + v[9] * s5; 195 t[idx + 2] -= v[10] * s1 + v[11] * s2 + v[12] * s3 + v[13] * s4 + v[14] * s5; 196 t[idx + 3] -= v[15] * s1 + v[16] * s2 + v[17] * s3 + v[18] * s4 + v[19] * s5; 197 t[idx + 4] -= v[20] * s1 + v[21] * s2 + v[22] * s3 + v[23] * s4 + v[24] * s5; 198 v += bs2; 199 } 200 } 201 202 /* copy t into x according to permutation */ 203 for (i = 0; i < n; i++) { 204 ii = bs * i; 205 ir = bs * r[i]; 206 x[ir] = t[ii]; 207 x[ir + 1] = t[ii + 1]; 208 x[ir + 2] = t[ii + 2]; 209 x[ir + 3] = t[ii + 3]; 210 x[ir + 4] = t[ii + 4]; 211 } 212 213 PetscCall(ISRestoreIndices(isrow, &rout)); 214 PetscCall(ISRestoreIndices(iscol, &cout)); 215 PetscCall(VecRestoreArrayRead(bb, &b)); 216 PetscCall(VecRestoreArray(xx, &x)); 217 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 218 PetscFunctionReturn(0); 219 } 220