1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_7_inplace(Mat A, Vec bb, Vec xx) { 5 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 6 IS iscol = a->col, isrow = a->row; 7 const PetscInt *r, *c, *rout, *cout; 8 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 9 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 10 const MatScalar *aa = a->a, *v; 11 PetscScalar s1, s2, s3, s4, s5, s6, s7, x1, x2, x3, x4, x5, x6, x7, *x, *t; 12 const PetscScalar *b; 13 14 PetscFunctionBegin; 15 PetscCall(VecGetArrayRead(bb, &b)); 16 PetscCall(VecGetArray(xx, &x)); 17 t = a->solve_work; 18 19 PetscCall(ISGetIndices(isrow, &rout)); 20 r = rout; 21 PetscCall(ISGetIndices(iscol, &cout)); 22 c = cout; 23 24 /* copy the b into temp work space according to permutation */ 25 ii = 0; 26 for (i = 0; i < n; i++) { 27 ic = 7 * c[i]; 28 t[ii] = b[ic]; 29 t[ii + 1] = b[ic + 1]; 30 t[ii + 2] = b[ic + 2]; 31 t[ii + 3] = b[ic + 3]; 32 t[ii + 4] = b[ic + 4]; 33 t[ii + 5] = b[ic + 5]; 34 t[ii + 6] = b[ic + 6]; 35 ii += 7; 36 } 37 38 /* forward solve the U^T */ 39 idx = 0; 40 for (i = 0; i < n; i++) { 41 v = aa + 49 * diag[i]; 42 /* multiply by the inverse of the block diagonal */ 43 x1 = t[idx]; 44 x2 = t[1 + idx]; 45 x3 = t[2 + idx]; 46 x4 = t[3 + idx]; 47 x5 = t[4 + idx]; 48 x6 = t[5 + idx]; 49 x7 = t[6 + idx]; 50 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6 + v[6] * x7; 51 s2 = v[7] * x1 + v[8] * x2 + v[9] * x3 + v[10] * x4 + v[11] * x5 + v[12] * x6 + v[13] * x7; 52 s3 = v[14] * x1 + v[15] * x2 + v[16] * x3 + v[17] * x4 + v[18] * x5 + v[19] * x6 + v[20] * x7; 53 s4 = v[21] * x1 + v[22] * x2 + v[23] * x3 + v[24] * x4 + v[25] * x5 + v[26] * x6 + v[27] * x7; 54 s5 = v[28] * x1 + v[29] * x2 + v[30] * x3 + v[31] * x4 + v[32] * x5 + v[33] * x6 + v[34] * x7; 55 s6 = v[35] * x1 + v[36] * x2 + v[37] * x3 + v[38] * x4 + v[39] * x5 + v[40] * x6 + v[41] * x7; 56 s7 = v[42] * x1 + v[43] * x2 + v[44] * x3 + v[45] * x4 + v[46] * x5 + v[47] * x6 + v[48] * x7; 57 v += 49; 58 59 vi = aj + diag[i] + 1; 60 nz = ai[i + 1] - diag[i] - 1; 61 while (nz--) { 62 oidx = 7 * (*vi++); 63 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6 + v[6] * s7; 64 t[oidx + 1] -= v[7] * s1 + v[8] * s2 + v[9] * s3 + v[10] * s4 + v[11] * s5 + v[12] * s6 + v[13] * s7; 65 t[oidx + 2] -= v[14] * s1 + v[15] * s2 + v[16] * s3 + v[17] * s4 + v[18] * s5 + v[19] * s6 + v[20] * s7; 66 t[oidx + 3] -= v[21] * s1 + v[22] * s2 + v[23] * s3 + v[24] * s4 + v[25] * s5 + v[26] * s6 + v[27] * s7; 67 t[oidx + 4] -= v[28] * s1 + v[29] * s2 + v[30] * s3 + v[31] * s4 + v[32] * s5 + v[33] * s6 + v[34] * s7; 68 t[oidx + 5] -= v[35] * s1 + v[36] * s2 + v[37] * s3 + v[38] * s4 + v[39] * s5 + v[40] * s6 + v[41] * s7; 69 t[oidx + 6] -= v[42] * s1 + v[43] * s2 + v[44] * s3 + v[45] * s4 + v[46] * s5 + v[47] * s6 + v[48] * s7; 70 v += 49; 71 } 72 t[idx] = s1; 73 t[1 + idx] = s2; 74 t[2 + idx] = s3; 75 t[3 + idx] = s4; 76 t[4 + idx] = s5; 77 t[5 + idx] = s6; 78 t[6 + idx] = s7; 79 idx += 7; 80 } 81 /* backward solve the L^T */ 82 for (i = n - 1; i >= 0; i--) { 83 v = aa + 49 * diag[i] - 49; 84 vi = aj + diag[i] - 1; 85 nz = diag[i] - ai[i]; 86 idt = 7 * i; 87 s1 = t[idt]; 88 s2 = t[1 + idt]; 89 s3 = t[2 + idt]; 90 s4 = t[3 + idt]; 91 s5 = t[4 + idt]; 92 s6 = t[5 + idt]; 93 s7 = t[6 + idt]; 94 while (nz--) { 95 idx = 7 * (*vi--); 96 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6 + v[6] * s7; 97 t[idx + 1] -= v[7] * s1 + v[8] * s2 + v[9] * s3 + v[10] * s4 + v[11] * s5 + v[12] * s6 + v[13] * s7; 98 t[idx + 2] -= v[14] * s1 + v[15] * s2 + v[16] * s3 + v[17] * s4 + v[18] * s5 + v[19] * s6 + v[20] * s7; 99 t[idx + 3] -= v[21] * s1 + v[22] * s2 + v[23] * s3 + v[24] * s4 + v[25] * s5 + v[26] * s6 + v[27] * s7; 100 t[idx + 4] -= v[28] * s1 + v[29] * s2 + v[30] * s3 + v[31] * s4 + v[32] * s5 + v[33] * s6 + v[34] * s7; 101 t[idx + 5] -= v[35] * s1 + v[36] * s2 + v[37] * s3 + v[38] * s4 + v[39] * s5 + v[40] * s6 + v[41] * s7; 102 t[idx + 6] -= v[42] * s1 + v[43] * s2 + v[44] * s3 + v[45] * s4 + v[46] * s5 + v[47] * s6 + v[48] * s7; 103 v -= 49; 104 } 105 } 106 107 /* copy t into x according to permutation */ 108 ii = 0; 109 for (i = 0; i < n; i++) { 110 ir = 7 * r[i]; 111 x[ir] = t[ii]; 112 x[ir + 1] = t[ii + 1]; 113 x[ir + 2] = t[ii + 2]; 114 x[ir + 3] = t[ii + 3]; 115 x[ir + 4] = t[ii + 4]; 116 x[ir + 5] = t[ii + 5]; 117 x[ir + 6] = t[ii + 6]; 118 ii += 7; 119 } 120 121 PetscCall(ISRestoreIndices(isrow, &rout)); 122 PetscCall(ISRestoreIndices(iscol, &cout)); 123 PetscCall(VecRestoreArrayRead(bb, &b)); 124 PetscCall(VecRestoreArray(xx, &x)); 125 PetscCall(PetscLogFlops(2.0 * 49 * (a->nz) - 7.0 * A->cmap->n)); 126 PetscFunctionReturn(0); 127 } 128 PetscErrorCode MatSolveTranspose_SeqBAIJ_7(Mat A, Vec bb, Vec xx) { 129 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 130 IS iscol = a->col, isrow = a->row; 131 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 132 const PetscInt *r, *c, *rout, *cout; 133 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 134 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 135 const MatScalar *aa = a->a, *v; 136 PetscScalar s1, s2, s3, s4, s5, s6, s7, x1, x2, x3, x4, x5, x6, x7, *x, *t; 137 const PetscScalar *b; 138 139 PetscFunctionBegin; 140 PetscCall(VecGetArrayRead(bb, &b)); 141 PetscCall(VecGetArray(xx, &x)); 142 t = a->solve_work; 143 144 PetscCall(ISGetIndices(isrow, &rout)); 145 r = rout; 146 PetscCall(ISGetIndices(iscol, &cout)); 147 c = cout; 148 149 /* copy b into temp work space according to permutation */ 150 for (i = 0; i < n; i++) { 151 ii = bs * i; 152 ic = bs * c[i]; 153 t[ii] = b[ic]; 154 t[ii + 1] = b[ic + 1]; 155 t[ii + 2] = b[ic + 2]; 156 t[ii + 3] = b[ic + 3]; 157 t[ii + 4] = b[ic + 4]; 158 t[ii + 5] = b[ic + 5]; 159 t[ii + 6] = b[ic + 6]; 160 } 161 162 /* forward solve the U^T */ 163 idx = 0; 164 for (i = 0; i < n; i++) { 165 v = aa + bs2 * diag[i]; 166 /* multiply by the inverse of the block diagonal */ 167 x1 = t[idx]; 168 x2 = t[1 + idx]; 169 x3 = t[2 + idx]; 170 x4 = t[3 + idx]; 171 x5 = t[4 + idx]; 172 x6 = t[5 + idx]; 173 x7 = t[6 + idx]; 174 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6 + v[6] * x7; 175 s2 = v[7] * x1 + v[8] * x2 + v[9] * x3 + v[10] * x4 + v[11] * x5 + v[12] * x6 + v[13] * x7; 176 s3 = v[14] * x1 + v[15] * x2 + v[16] * x3 + v[17] * x4 + v[18] * x5 + v[19] * x6 + v[20] * x7; 177 s4 = v[21] * x1 + v[22] * x2 + v[23] * x3 + v[24] * x4 + v[25] * x5 + v[26] * x6 + v[27] * x7; 178 s5 = v[28] * x1 + v[29] * x2 + v[30] * x3 + v[31] * x4 + v[32] * x5 + v[33] * x6 + v[34] * x7; 179 s6 = v[35] * x1 + v[36] * x2 + v[37] * x3 + v[38] * x4 + v[39] * x5 + v[40] * x6 + v[41] * x7; 180 s7 = v[42] * x1 + v[43] * x2 + v[44] * x3 + v[45] * x4 + v[46] * x5 + v[47] * x6 + v[48] * x7; 181 v -= bs2; 182 183 vi = aj + diag[i] - 1; 184 nz = diag[i] - diag[i + 1] - 1; 185 for (j = 0; j > -nz; j--) { 186 oidx = bs * vi[j]; 187 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6 + v[6] * s7; 188 t[oidx + 1] -= v[7] * s1 + v[8] * s2 + v[9] * s3 + v[10] * s4 + v[11] * s5 + v[12] * s6 + v[13] * s7; 189 t[oidx + 2] -= v[14] * s1 + v[15] * s2 + v[16] * s3 + v[17] * s4 + v[18] * s5 + v[19] * s6 + v[20] * s7; 190 t[oidx + 3] -= v[21] * s1 + v[22] * s2 + v[23] * s3 + v[24] * s4 + v[25] * s5 + v[26] * s6 + v[27] * s7; 191 t[oidx + 4] -= v[28] * s1 + v[29] * s2 + v[30] * s3 + v[31] * s4 + v[32] * s5 + v[33] * s6 + v[34] * s7; 192 t[oidx + 5] -= v[35] * s1 + v[36] * s2 + v[37] * s3 + v[38] * s4 + v[39] * s5 + v[40] * s6 + v[41] * s7; 193 t[oidx + 6] -= v[42] * s1 + v[43] * s2 + v[44] * s3 + v[45] * s4 + v[46] * s5 + v[47] * s6 + v[48] * s7; 194 v -= bs2; 195 } 196 t[idx] = s1; 197 t[1 + idx] = s2; 198 t[2 + idx] = s3; 199 t[3 + idx] = s4; 200 t[4 + idx] = s5; 201 t[5 + idx] = s6; 202 t[6 + idx] = s7; 203 idx += bs; 204 } 205 /* backward solve the L^T */ 206 for (i = n - 1; i >= 0; i--) { 207 v = aa + bs2 * ai[i]; 208 vi = aj + ai[i]; 209 nz = ai[i + 1] - ai[i]; 210 idt = bs * i; 211 s1 = t[idt]; 212 s2 = t[1 + idt]; 213 s3 = t[2 + idt]; 214 s4 = t[3 + idt]; 215 s5 = t[4 + idt]; 216 s6 = t[5 + idt]; 217 s7 = t[6 + idt]; 218 for (j = 0; j < nz; j++) { 219 idx = bs * vi[j]; 220 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6 + v[6] * s7; 221 t[idx + 1] -= v[7] * s1 + v[8] * s2 + v[9] * s3 + v[10] * s4 + v[11] * s5 + v[12] * s6 + v[13] * s7; 222 t[idx + 2] -= v[14] * s1 + v[15] * s2 + v[16] * s3 + v[17] * s4 + v[18] * s5 + v[19] * s6 + v[20] * s7; 223 t[idx + 3] -= v[21] * s1 + v[22] * s2 + v[23] * s3 + v[24] * s4 + v[25] * s5 + v[26] * s6 + v[27] * s7; 224 t[idx + 4] -= v[28] * s1 + v[29] * s2 + v[30] * s3 + v[31] * s4 + v[32] * s5 + v[33] * s6 + v[34] * s7; 225 t[idx + 5] -= v[35] * s1 + v[36] * s2 + v[37] * s3 + v[38] * s4 + v[39] * s5 + v[40] * s6 + v[41] * s7; 226 t[idx + 6] -= v[42] * s1 + v[43] * s2 + v[44] * s3 + v[45] * s4 + v[46] * s5 + v[47] * s6 + v[48] * s7; 227 v += bs2; 228 } 229 } 230 231 /* copy t into x according to permutation */ 232 for (i = 0; i < n; i++) { 233 ii = bs * i; 234 ir = bs * r[i]; 235 x[ir] = t[ii]; 236 x[ir + 1] = t[ii + 1]; 237 x[ir + 2] = t[ii + 2]; 238 x[ir + 3] = t[ii + 3]; 239 x[ir + 4] = t[ii + 4]; 240 x[ir + 5] = t[ii + 5]; 241 x[ir + 6] = t[ii + 6]; 242 } 243 244 PetscCall(ISRestoreIndices(isrow, &rout)); 245 PetscCall(ISRestoreIndices(iscol, &cout)); 246 PetscCall(VecRestoreArrayRead(bb, &b)); 247 PetscCall(VecRestoreArray(xx, &x)); 248 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 249 PetscFunctionReturn(0); 250 } 251