#include <petsc/private/sfimpl.h>
#include <../src/vec/is/sf/impls/basic/sfpack.h>
#include <../src/vec/is/sf/impls/basic/sfbasic.h>

/* This is a C file that contains packing facilities, with dispatches to device if enabled. */

/*
 * MPI_Reduce_local is not really useful because it can't handle sparse data and it vectorizes "in the wrong direction",
 * therefore we pack data types manually. This file defines packing routines for the standard data types.
 */

#define CPPJoin4(a, b, c, d) a##_##b##_##c##_##d

/* Operations working like s += t */
#define OP_BINARY(op, s, t) \
  do { \
    (s) = (s)op(t); \
  } while (0) /* binary ops in the middle such as +, *, && etc. */
#define OP_FUNCTION(op, s, t) \
  do { \
    (s) = op((s), (t)); \
  } while (0) /* ops like a function, such as PetscMax, PetscMin */
#define OP_LXOR(op, s, t) \
  do { \
    (s) = (!(s)) != (!(t)); \
  } while (0) /* logical exclusive OR */
#define OP_ASSIGN(op, s, t) \
  do { \
    (s) = (t); \
  } while (0)
/* Ref MPI MAXLOC */
#define OP_XLOC(op, s, t) \
  do { \
    if ((s).u == (t).u) (s).i = PetscMin((s).i, (t).i); \
    else if (!((s).u op(t).u)) s = t; \
  } while (0)

/* DEF_PackFunc - macro defining a Pack routine

   Arguments of the macro:
   +Type  Type of the basic data in an entry, i.e., int, PetscInt, PetscReal etc. It is not the type of an entry.
   .BS    Block size for vectorization. It is a factor of bs.
   -EQ    (bs == BS) ? 1 : 0. EQ is a compile-time const to help compiler optimizations. See below.

   Arguments of the Pack routine:
   +count     Number of indices in idx[].
   .start     When opt and idx are NULL, it means indices are contiguous & start is the first index; otherwise, not used.
   .opt       Per-pack optimization plan. NULL means no such plan.
   .idx       Indices of entries to be packed.
   .link      Provides a context for the current call, such as link->bs, the number of basic types in an entry. Ex. if unit is MPI_2INT, then bs=2 and the basic type is int.
   .unpacked  Address of the unpacked data. The entries to be packed are unpacked[idx[i]], for i in [0,count).
   -packed    Address of the packed data.
 */
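/* Illustrative example: if the unit is two contiguous MPIU_INTs (bs = 2), PetscSFLinkSetUp_Host() below
   selects the instantiation DEF_IntegerType(PetscInt, 2, 1), and the generated Pack_PetscInt_2_1() in its
   non-optimized branch reduces to

     for (i = 0; i < count; i++) {
       p[2 * i]     = u[2 * idx[i]];
       p[2 * i + 1] = u[2 * idx[i] + 1];
     }

   because with EQ = 1, M = 1 and MBS = BS = 2 are compile-time constants the compiler can unroll or vectorize. */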
#define DEF_PackFunc(Type, BS, EQ) \
  static PetscErrorCode CPPJoin4(Pack, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *unpacked, void *packed) \
  { \
    const Type *u = (const Type *)unpacked, *u2; \
    Type *p = (Type *)packed, *p2; \
    PetscInt i, j, k, X, Y, r, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */ \
    const PetscInt MBS = M * BS; /* MBS=bs. We turn MBS into a compile time const when EQ=1. */ \
    PetscFunctionBegin; \
    if (!idx) PetscCall(PetscArraycpy(p, u + start * MBS, MBS * count)); /* idx[] are contiguous */ \
    else if (opt) { /* has optimizations available */ p2 = p; \
      for (r = 0; r < opt->n; r++) { \
        u2 = u + opt->start[r] * MBS; \
        X = opt->X[r]; \
        Y = opt->Y[r]; \
        for (k = 0; k < opt->dz[r]; k++) \
          for (j = 0; j < opt->dy[r]; j++) { \
            PetscCall(PetscArraycpy(p2, u2 + (X * Y * k + X * j) * MBS, opt->dx[r] * MBS)); \
            p2 += opt->dx[r] * MBS; \
          } \
      } \
    } else { \
      for (i = 0; i < count; i++) \
        for (j = 0; j < M; j++)    /* Decent compilers should eliminate this loop when M = const 1 */ \
          for (k = 0; k < BS; k++) /* Compiler either unrolls (BS=1) or vectorizes (BS=2,4,8,etc) this loop */ \
            p[i * MBS + j * BS + k] = u[idx[i] * MBS + j * BS + k]; \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

/* DEF_UnpackFunc - macro defining a UnpackAndInsert routine that unpacks data from a contiguous buffer
   and inserts into a sparse array.

   Arguments:
   +Type  Type of the data
   .BS    Block size for vectorization
   -EQ    (bs == BS) ? 1 : 0. EQ is a compile-time const.

   Notes:
   This macro is not combined with DEF_UnpackAndOp because we want to use memcpy in this macro.
 */
#define DEF_UnpackFunc(Type, BS, EQ) \
  static PetscErrorCode CPPJoin4(UnpackAndInsert, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *unpacked, const void *packed) \
  { \
    Type *u = (Type *)unpacked, *u2; \
    const Type *p = (const Type *)packed; \
    PetscInt i, j, k, X, Y, r, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */ \
    const PetscInt MBS = M * BS; /* MBS=bs. We turn MBS into a compile time const when EQ=1. */ \
    PetscFunctionBegin; \
    if (!idx) { \
      u += start * MBS; \
      if (u != p) PetscCall(PetscArraycpy(u, p, count * MBS)); \
    } else if (opt) { /* has optimizations available */ \
      for (r = 0; r < opt->n; r++) { \
        u2 = u + opt->start[r] * MBS; \
        X = opt->X[r]; \
        Y = opt->Y[r]; \
        for (k = 0; k < opt->dz[r]; k++) \
          for (j = 0; j < opt->dy[r]; j++) { \
            PetscCall(PetscArraycpy(u2 + (X * Y * k + X * j) * MBS, p, opt->dx[r] * MBS)); \
            p += opt->dx[r] * MBS; \
          } \
      } \
    } else { \
      for (i = 0; i < count; i++) \
        for (j = 0; j < M; j++) \
          for (k = 0; k < BS; k++) u[idx[i] * MBS + j * BS + k] = p[i * MBS + j * BS + k]; \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

/* DEF_UnpackAndOp - macro defining a UnpackAndOp routine where Op should not be Insert

   Arguments:
   +Opname   Name of the Op, such as Add, Mult, LAND, etc.
   .Type     Type of the data
   .BS       Block size for vectorization
   .EQ       (bs == BS) ? 1 : 0. EQ is a compile-time const.
   .Op       Operator for the op, such as +, *, &&, ||, PetscMax, PetscMin, etc.
   -OpApply  Macro defining application of the op. Could be OP_BINARY, OP_FUNCTION, OP_LXOR
 */
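/* Illustrative example: instantiating DEF_UnpackAndOp(PetscReal, 1, 1, Add, +, OP_BINARY) (done below via
   DEF_Add/DEF_RealType) yields UnpackAndAdd_PetscReal_1_1(), whose non-optimized branch reduces to

     for (i = 0; i < count; i++) u[idx[i]] += p[i];

   since M = BS = MBS = 1. Other Opname/Op/OpApply combinations only change the per-entry update statement. */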
#define DEF_UnpackAndOp(Type, BS, EQ, Opname, Op, OpApply) \
  static PetscErrorCode CPPJoin4(UnpackAnd##Opname, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *unpacked, const void *packed) \
  { \
    Type *u = (Type *)unpacked, *u2; \
    const Type *p = (const Type *)packed; \
    PetscInt i, j, k, X, Y, r, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */ \
    const PetscInt MBS = M * BS; /* MBS=bs. We turn MBS into a compile time const when EQ=1. */ \
    PetscFunctionBegin; \
    if (!idx) { \
      u += start * MBS; \
      for (i = 0; i < count; i++) \
        for (j = 0; j < M; j++) \
          for (k = 0; k < BS; k++) OpApply(Op, u[i * MBS + j * BS + k], p[i * MBS + j * BS + k]); \
    } else if (opt) { /* idx[] has patterns */ \
      for (r = 0; r < opt->n; r++) { \
        u2 = u + opt->start[r] * MBS; \
        X = opt->X[r]; \
        Y = opt->Y[r]; \
        for (k = 0; k < opt->dz[r]; k++) \
          for (j = 0; j < opt->dy[r]; j++) { \
            for (i = 0; i < opt->dx[r] * MBS; i++) OpApply(Op, u2[(X * Y * k + X * j) * MBS + i], p[i]); \
            p += opt->dx[r] * MBS; \
          } \
      } \
    } else { \
      for (i = 0; i < count; i++) \
        for (j = 0; j < M; j++) \
          for (k = 0; k < BS; k++) OpApply(Op, u[idx[i] * MBS + j * BS + k], p[i * MBS + j * BS + k]); \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

#define DEF_FetchAndOp(Type, BS, EQ, Opname, Op, OpApply) \
  static PetscErrorCode CPPJoin4(FetchAnd##Opname, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *unpacked, void *packed) \
  { \
    Type *u = (Type *)unpacked, *p = (Type *)packed, tmp; \
    PetscInt i, j, k, r, l, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; \
    const PetscInt MBS = M * BS; \
    PetscFunctionBegin; \
    for (i = 0; i < count; i++) { \
      r = (!idx ? start + i : idx[i]) * MBS; \
      l = i * MBS; \
      for (j = 0; j < M; j++) \
        for (k = 0; k < BS; k++) { \
          tmp = u[r + j * BS + k]; \
          OpApply(Op, u[r + j * BS + k], p[l + j * BS + k]); \
          p[l + j * BS + k] = tmp; \
        } \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

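/* Note (illustrative): a FetchAndOp kernel generated above has fetch-and-add style semantics per entry:
   it first saves the old value of the unpacked array, then applies the op in place, and finally returns
   the old value in the packed buffer, i.e., roughly { tmp = u[r]; u[r] = u[r] op p[l]; p[l] = tmp; }. */
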
#define DEF_ScatterAndOp(Type, BS, EQ, Opname, Op, OpApply) \
  static PetscErrorCode CPPJoin4(ScatterAnd##Opname, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) \
  { \
    const Type *u = (const Type *)src; \
    Type *v = (Type *)dst; \
    PetscInt i, j, k, s, t, X, Y, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; \
    const PetscInt MBS = M * BS; \
    PetscFunctionBegin; \
    if (!srcIdx) { /* src is contiguous */ \
      u += srcStart * MBS; \
      PetscCall(CPPJoin4(UnpackAnd##Opname, Type, BS, EQ)(link, count, dstStart, dstOpt, dstIdx, dst, u)); \
    } else if (srcOpt && !dstIdx) { /* src is 3D, dst is contiguous */ \
      u += srcOpt->start[0] * MBS; \
      v += dstStart * MBS; \
      X = srcOpt->X[0]; \
      Y = srcOpt->Y[0]; \
      for (k = 0; k < srcOpt->dz[0]; k++) \
        for (j = 0; j < srcOpt->dy[0]; j++) { \
          for (i = 0; i < srcOpt->dx[0] * MBS; i++) OpApply(Op, v[i], u[(X * Y * k + X * j) * MBS + i]); \
          v += srcOpt->dx[0] * MBS; \
        } \
    } else { /* all other cases */ \
      for (i = 0; i < count; i++) { \
        s = (!srcIdx ? srcStart + i : srcIdx[i]) * MBS; \
        t = (!dstIdx ? dstStart + i : dstIdx[i]) * MBS; \
        for (j = 0; j < M; j++) \
          for (k = 0; k < BS; k++) OpApply(Op, v[t + j * BS + k], u[s + j * BS + k]); \
      } \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

#define DEF_FetchAndOpLocal(Type, BS, EQ, Opname, Op, OpApply) \
  static PetscErrorCode CPPJoin4(FetchAnd##Opname##Local, Type, BS, EQ)(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate) \
  { \
    Type *rdata = (Type *)rootdata, *lupdate = (Type *)leafupdate; \
    const Type *ldata = (const Type *)leafdata; \
    PetscInt i, j, k, r, l, bs = link->bs; \
    const PetscInt M = (EQ) ? 1 : bs / BS; \
    const PetscInt MBS = M * BS; \
    PetscFunctionBegin; \
    for (i = 0; i < count; i++) { \
      r = (rootidx ? rootidx[i] : rootstart + i) * MBS; \
      l = (leafidx ? leafidx[i] : leafstart + i) * MBS; \
      for (j = 0; j < M; j++) \
        for (k = 0; k < BS; k++) { \
          lupdate[l + j * BS + k] = rdata[r + j * BS + k]; \
          OpApply(Op, rdata[r + j * BS + k], ldata[l + j * BS + k]); \
        } \
    } \
    PetscFunctionReturn(PETSC_SUCCESS); \
  }

/* Pack, Unpack/Fetch ops */
#define DEF_Pack(Type, BS, EQ) \
  DEF_PackFunc(Type, BS, EQ) DEF_UnpackFunc(Type, BS, EQ) DEF_ScatterAndOp(Type, BS, EQ, Insert, =, OP_ASSIGN) static void CPPJoin4(PackInit_Pack, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_Pack = CPPJoin4(Pack, Type, BS, EQ); \
    link->h_UnpackAndInsert = CPPJoin4(UnpackAndInsert, Type, BS, EQ); \
    link->h_ScatterAndInsert = CPPJoin4(ScatterAndInsert, Type, BS, EQ); \
  }

/* Add, Mult ops */
#define DEF_Add(Type, BS, EQ) \
  DEF_UnpackAndOp(Type, BS, EQ, Add, +, OP_BINARY) DEF_UnpackAndOp(Type, BS, EQ, Mult, *, OP_BINARY) DEF_FetchAndOp(Type, BS, EQ, Add, +, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, Add, +, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, Mult, *, OP_BINARY) DEF_FetchAndOpLocal(Type, BS, EQ, Add, +, OP_BINARY) static void CPPJoin4(PackInit_Add, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_UnpackAndAdd = CPPJoin4(UnpackAndAdd, Type, BS, EQ); \
    link->h_UnpackAndMult = CPPJoin4(UnpackAndMult, Type, BS, EQ); \
    link->h_FetchAndAdd = CPPJoin4(FetchAndAdd, Type, BS, EQ); \
    link->h_ScatterAndAdd = CPPJoin4(ScatterAndAdd, Type, BS, EQ); \
    link->h_ScatterAndMult = CPPJoin4(ScatterAndMult, Type, BS, EQ); \
    link->h_FetchAndAddLocal = CPPJoin4(FetchAndAddLocal, Type, BS, EQ); \
  }

/* Max, Min ops */
#define DEF_Cmp(Type, BS, EQ) \
  DEF_UnpackAndOp(Type, BS, EQ, Max, PetscMax, OP_FUNCTION) DEF_UnpackAndOp(Type, BS, EQ, Min, PetscMin, OP_FUNCTION) DEF_ScatterAndOp(Type, BS, EQ, Max, PetscMax, OP_FUNCTION) DEF_ScatterAndOp(Type, BS, EQ, Min, PetscMin, OP_FUNCTION) static void CPPJoin4(PackInit_Compare, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_UnpackAndMax = CPPJoin4(UnpackAndMax, Type, BS, EQ); \
    link->h_UnpackAndMin = CPPJoin4(UnpackAndMin, Type, BS, EQ); \
    link->h_ScatterAndMax = CPPJoin4(ScatterAndMax, Type, BS, EQ); \
    link->h_ScatterAndMin = CPPJoin4(ScatterAndMin, Type, BS, EQ); \
  }

/* Logical ops.
   The operator passed to OP_LXOR is not used; ideally it would be empty, but we pass || to avoid
   the compilation warning "empty macro arguments are undefined in ISO C90".
 */
#define DEF_Log(Type, BS, EQ) \
  DEF_UnpackAndOp(Type, BS, EQ, LAND, &&, OP_BINARY) DEF_UnpackAndOp(Type, BS, EQ, LOR, ||, OP_BINARY) DEF_UnpackAndOp(Type, BS, EQ, LXOR, ||, OP_LXOR) DEF_ScatterAndOp(Type, BS, EQ, LAND, &&, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, LOR, ||, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, LXOR, ||, OP_LXOR) static void CPPJoin4(PackInit_Logical, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_UnpackAndLAND = CPPJoin4(UnpackAndLAND, Type, BS, EQ); \
    link->h_UnpackAndLOR = CPPJoin4(UnpackAndLOR, Type, BS, EQ); \
    link->h_UnpackAndLXOR = CPPJoin4(UnpackAndLXOR, Type, BS, EQ); \
    link->h_ScatterAndLAND = CPPJoin4(ScatterAndLAND, Type, BS, EQ); \
    link->h_ScatterAndLOR = CPPJoin4(ScatterAndLOR, Type, BS, EQ); \
    link->h_ScatterAndLXOR = CPPJoin4(ScatterAndLXOR, Type, BS, EQ); \
  }

/* Bitwise ops */
#define DEF_Bit(Type, BS, EQ) \
  DEF_UnpackAndOp(Type, BS, EQ, BAND, &, OP_BINARY) DEF_UnpackAndOp(Type, BS, EQ, BOR, |, OP_BINARY) DEF_UnpackAndOp(Type, BS, EQ, BXOR, ^, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, BAND, &, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, BOR, |, OP_BINARY) DEF_ScatterAndOp(Type, BS, EQ, BXOR, ^, OP_BINARY) static void CPPJoin4(PackInit_Bitwise, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_UnpackAndBAND = CPPJoin4(UnpackAndBAND, Type, BS, EQ); \
    link->h_UnpackAndBOR = CPPJoin4(UnpackAndBOR, Type, BS, EQ); \
    link->h_UnpackAndBXOR = CPPJoin4(UnpackAndBXOR, Type, BS, EQ); \
    link->h_ScatterAndBAND = CPPJoin4(ScatterAndBAND, Type, BS, EQ); \
    link->h_ScatterAndBOR = CPPJoin4(ScatterAndBOR, Type, BS, EQ); \
    link->h_ScatterAndBXOR = CPPJoin4(ScatterAndBXOR, Type, BS, EQ); \
  }

/* Maxloc, Minloc ops */
#define DEF_Xloc(Type, BS, EQ) \
  DEF_UnpackAndOp(Type, BS, EQ, Max, >, OP_XLOC) DEF_UnpackAndOp(Type, BS, EQ, Min, <, OP_XLOC) DEF_ScatterAndOp(Type, BS, EQ, Max, >, OP_XLOC) DEF_ScatterAndOp(Type, BS, EQ, Min, <, OP_XLOC) static void CPPJoin4(PackInit_Xloc, Type, BS, EQ)(PetscSFLink link) \
  { \
    link->h_UnpackAndMaxloc = CPPJoin4(UnpackAndMax, Type, BS, EQ); \
    link->h_UnpackAndMinloc = CPPJoin4(UnpackAndMin, Type, BS, EQ); \
    link->h_ScatterAndMaxloc = CPPJoin4(ScatterAndMax, Type, BS, EQ); \
    link->h_ScatterAndMinloc = CPPJoin4(ScatterAndMin, Type, BS, EQ); \
  }

#define DEF_IntegerType(Type, BS, EQ) \
  DEF_Pack(Type, BS, EQ) DEF_Add(Type, BS, EQ) DEF_Cmp(Type, BS, EQ) DEF_Log(Type, BS, EQ) DEF_Bit(Type, BS, EQ) static void CPPJoin4(PackInit_IntegerType, Type, BS, EQ)(PetscSFLink link) \
  { \
    CPPJoin4(PackInit_Pack, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Add, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Compare, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Logical, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Bitwise, Type, BS, EQ)(link); \
  }

#define DEF_RealType(Type, BS, EQ) \
  DEF_Pack(Type, BS, EQ) DEF_Add(Type, BS, EQ) DEF_Cmp(Type, BS, EQ) static void CPPJoin4(PackInit_RealType, Type, BS, EQ)(PetscSFLink link) \
  { \
    CPPJoin4(PackInit_Pack, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Add, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Compare, Type, BS, EQ)(link); \
  }

#if defined(PETSC_HAVE_COMPLEX)
#define DEF_ComplexType(Type, BS, EQ) \
  DEF_Pack(Type, BS, EQ) DEF_Add(Type, BS, EQ) static void CPPJoin4(PackInit_ComplexType, Type, BS, EQ)(PetscSFLink link) \
  { \
    CPPJoin4(PackInit_Pack, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Add, Type, BS, EQ)(link); \
  }
#endif

#define DEF_DumbType(Type, BS, EQ) \
  DEF_Pack(Type, BS, EQ) static void CPPJoin4(PackInit_DumbType, Type, BS, EQ)(PetscSFLink link) \
  { \
    CPPJoin4(PackInit_Pack, Type, BS, EQ)(link); \
  }

/* Maxloc, Minloc */
#define DEF_PairType(Type, BS, EQ) \
  DEF_Pack(Type, BS, EQ) DEF_Xloc(Type, BS, EQ) static void CPPJoin4(PackInit_PairType, Type, BS, EQ)(PetscSFLink link) \
  { \
    CPPJoin4(PackInit_Pack, Type, BS, EQ)(link); \
    CPPJoin4(PackInit_Xloc, Type, BS, EQ)(link); \
  }

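/* Note (illustrative): each instantiation below expands to a family of static kernels whose names are
   assembled by CPPJoin4 as <Routine>_<Type>_<BS>_<EQ>, e.g. DEF_IntegerType(PetscInt, 2, 1) defines
   Pack_PetscInt_2_1(), UnpackAndAdd_PetscInt_2_1(), ..., plus PackInit_IntegerType_PetscInt_2_1(),
   which PetscSFLinkSetUp_Host() calls to fill in the link's h_* function pointers for the matching MPI unit. */
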
DEF_IntegerType(PetscInt, 1, 1) /* unit = 1 MPIU_INT  */
DEF_IntegerType(PetscInt, 2, 1) /* unit = 2 MPIU_INTs */
DEF_IntegerType(PetscInt, 4, 1) /* unit = 4 MPIU_INTs */
DEF_IntegerType(PetscInt, 8, 1) /* unit = 8 MPIU_INTs */
DEF_IntegerType(PetscInt, 1, 0) /* unit = 1*n MPIU_INTs, n>1 */
DEF_IntegerType(PetscInt, 2, 0) /* unit = 2*n MPIU_INTs, n>1 */
DEF_IntegerType(PetscInt, 4, 0) /* unit = 4*n MPIU_INTs, n>1 */
DEF_IntegerType(PetscInt, 8, 0) /* unit = 8*n MPIU_INTs, n>1. Routines with bigger BS are tried first. */

#if defined(PETSC_USE_64BIT_INDICES) /* No need (though it is OK) to generate redundant functions if PetscInt is int */
DEF_IntegerType(int, 1, 1) DEF_IntegerType(int, 2, 1) DEF_IntegerType(int, 4, 1) DEF_IntegerType(int, 8, 1) DEF_IntegerType(int, 1, 0) DEF_IntegerType(int, 2, 0) DEF_IntegerType(int, 4, 0) DEF_IntegerType(int, 8, 0)
#endif

/* The typedefs are used to get type names without spaces that CPPJoin4 can handle */
typedef signed char SignedChar;
DEF_IntegerType(SignedChar, 1, 1) DEF_IntegerType(SignedChar, 2, 1) DEF_IntegerType(SignedChar, 4, 1) DEF_IntegerType(SignedChar, 8, 1) DEF_IntegerType(SignedChar, 1, 0) DEF_IntegerType(SignedChar, 2, 0) DEF_IntegerType(SignedChar, 4, 0) DEF_IntegerType(SignedChar, 8, 0)

typedef unsigned char UnsignedChar;
DEF_IntegerType(UnsignedChar, 1, 1) DEF_IntegerType(UnsignedChar, 2, 1) DEF_IntegerType(UnsignedChar, 4, 1) DEF_IntegerType(UnsignedChar, 8, 1) DEF_IntegerType(UnsignedChar, 1, 0) DEF_IntegerType(UnsignedChar, 2, 0) DEF_IntegerType(UnsignedChar, 4, 0) DEF_IntegerType(UnsignedChar, 8, 0)

DEF_RealType(PetscReal, 1, 1) DEF_RealType(PetscReal, 2, 1) DEF_RealType(PetscReal, 4, 1) DEF_RealType(PetscReal, 8, 1) DEF_RealType(PetscReal, 1, 0) DEF_RealType(PetscReal, 2, 0) DEF_RealType(PetscReal, 4, 0) DEF_RealType(PetscReal, 8, 0)
#if defined(PETSC_HAVE_COMPLEX)
DEF_ComplexType(PetscComplex, 1, 1) DEF_ComplexType(PetscComplex, 2, 1) DEF_ComplexType(PetscComplex, 4, 1) DEF_ComplexType(PetscComplex, 8, 1) DEF_ComplexType(PetscComplex, 1, 0) DEF_ComplexType(PetscComplex, 2, 0) DEF_ComplexType(PetscComplex, 4, 0) DEF_ComplexType(PetscComplex, 8, 0)
#endif

#define PairType(Type1, Type2) Type1##_##Type2
typedef struct {
  int u;
  int i;
} PairType(int, int);
typedef struct {
  PetscInt u;
  PetscInt i;
} PairType(PetscInt, PetscInt);
DEF_PairType(PairType(int, int), 1, 1) DEF_PairType(PairType(PetscInt, PetscInt), 1, 1)

/* If we don't know the basic type, we treat it as a stream of chars or ints */
DEF_DumbType(char, 1, 1) DEF_DumbType(char, 2, 1) DEF_DumbType(char, 4, 1) DEF_DumbType(char, 1, 0) DEF_DumbType(char, 2, 0) DEF_DumbType(char, 4, 0)

typedef int DumbInt; /* To have a different name than 'int' used above. The name is used to make routine names. */
DEF_DumbType(DumbInt, 1, 1) DEF_DumbType(DumbInt, 2, 1) DEF_DumbType(DumbInt, 4, 1) DEF_DumbType(DumbInt, 8, 1) DEF_DumbType(DumbInt, 1, 0) DEF_DumbType(DumbInt, 2, 0) DEF_DumbType(DumbInt, 4, 0) DEF_DumbType(DumbInt, 8, 0)

PetscErrorCode PetscSFLinkDestroy(PetscSF sf, PetscSFLink link)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt i, nreqs = (bas->nrootreqs + sf->nleafreqs) * 8;

  PetscFunctionBegin;
  /* Destroy device-specific fields */
  if (link->deviceinited) PetscCall((*link->Destroy)(sf, link));

  /* Destroy host related fields */
  if (!link->isbuiltin) PetscCallMPI(MPI_Type_free(&link->unit));
  if (!link->use_nvshmem) {
    for (i = 0; i < nreqs; i++) { /* Persistent reqs must be freed */
      if (link->reqs[i] != MPI_REQUEST_NULL) PetscCallMPI(MPI_Request_free(&link->reqs[i]));
    }
    PetscCall(PetscFree(link->reqs));
    for (i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
      PetscCall(PetscFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_HOST]));
      PetscCall(PetscFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_HOST]));
    }
  }
  PetscCall(PetscFree(link));
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkCreate(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
{
  PetscFunctionBegin;
  PetscCall(PetscSFSetErrorOnUnsupportedOverlap(sf, unit, rootdata, leafdata));
#if defined(PETSC_HAVE_NVSHMEM)
  {
    PetscBool use_nvshmem;
    PetscCall(PetscSFLinkNvshmemCheck(sf, rootmtype, rootdata, leafmtype, leafdata, &use_nvshmem));
    if (use_nvshmem) {
      PetscCall(PetscSFLinkCreate_NVSHMEM(sf, unit, rootmtype, rootdata, leafmtype, leafdata, op, sfop, mylink));
      PetscFunctionReturn(PETSC_SUCCESS);
    }
  }
#endif
  PetscCall(PetscSFLinkCreate_MPI(sf, unit, rootmtype, rootdata, leafmtype, leafdata, op, sfop, mylink));
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkGetInUse(PetscSF sf, MPI_Datatype unit, const void *rootdata, const void *leafdata, PetscCopyMode cmode, PetscSFLink *mylink)
{
  PetscSFLink link, *p;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  /* Look for types in cache */
  for (p = &bas->inuse; (link = *p); p = &link->next) {
    PetscBool match;
    PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
    if (match && (rootdata == link->rootdata) && (leafdata == link->leafdata)) {
      switch (cmode) {
      case PETSC_OWN_POINTER:
        *p = link->next;
        break; /* Remove from inuse list */
      case PETSC_USE_POINTER:
        break;
      default:
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "invalid cmode");
      }
      *mylink = link;
      PetscFunctionReturn(PETSC_SUCCESS);
    }
  }
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Could not find pack");
}

PetscErrorCode PetscSFLinkReclaim(PetscSF sf, PetscSFLink *mylink)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscSFLink link = *mylink;

  PetscFunctionBegin;
  link->rootdata = NULL;
  link->leafdata = NULL;
  link->next = bas->avail;
  bas->avail = link;
  *mylink = NULL;
  PetscFunctionReturn(PETSC_SUCCESS);
}

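/* Illustrative sketch of the link life cycle (the exact call sequence lives in the SF implementation files,
   so treat this as an orientation aid, not a specification): a communication Begin obtains a link with
   PetscSFLinkCreate() and packs with PetscSFLinkPackRootData()/PetscSFLinkPackLeafData() defined later in
   this file; the matching End finds the same link with PetscSFLinkGetInUse() (keyed on the MPI unit and the
   rootdata/leafdata pointers), unpacks, and returns the link to the bas->avail pool with PetscSFLinkReclaim()
   so its buffers and persistent MPI requests can be reused; PetscSFLinkDestroy() frees them for good. */
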
/* Error out on unsupported overlapped communications */
PetscErrorCode PetscSFSetErrorOnUnsupportedOverlap(PetscSF sf, MPI_Datatype unit, const void *rootdata, const void *leafdata)
{
  PetscSFLink link, *p;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscBool match;

  PetscFunctionBegin;
  if (PetscDefined(USE_DEBUG)) {
    /* Look up links in use and error out if there is a match. When both rootdata and leafdata are NULL, ignore
       the potential overlap since this process does not participate in communication; in that case overlapping is harmless.
     */
    if (rootdata || leafdata) {
      for (p = &bas->inuse; (link = *p); p = &link->next) {
        PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
        PetscCheck(!match || rootdata != link->rootdata || leafdata != link->leafdata, PETSC_COMM_SELF, PETSC_ERR_SUP, "Overlapped PetscSF with the same rootdata(%p), leafdata(%p) and data type. Undo the overlapping to avoid the error.", rootdata, leafdata);
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode PetscSFLinkMemcpy_Host(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
{
  PetscFunctionBegin;
  if (n) PetscCall(PetscMemcpy(dst, src, n));
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkSetUp_Host(PetscSF sf, PetscSFLink link, MPI_Datatype unit)
{
  PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
  PetscBool is2Int, is2PetscInt;
  PetscMPIInt ni, na, nd, combiner;
#if defined(PETSC_HAVE_COMPLEX)
  PetscInt nPetscComplex = 0;
#endif

  PetscFunctionBegin;
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
  /* MPI_CHAR is treated below as a dumb type that does not support reduction according to the MPI standard */
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
#if defined(PETSC_HAVE_COMPLEX)
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
#endif
  PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
  PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
  /* TODO: should we also handle Fortran MPI_2REAL? */
  PetscCallMPI(MPI_Type_get_envelope(unit, &ni, &na, &nd, &combiner));
  link->isbuiltin = (combiner == MPI_COMBINER_NAMED) ? PETSC_TRUE : PETSC_FALSE; /* unit is MPI builtin */
  link->bs = 1; /* default */

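  /* Illustrative example (assumed caller datatype): for a unit built as MPI_Type_contiguous(3, MPIU_REAL),
     nPetscReal = 3, so the chain below calls PackInit_RealType_PetscReal_1_0() (Type = PetscReal, BS = 1,
     EQ = 0) and sets link->bs = 3, link->unitbytes = 3 * sizeof(PetscReal), link->basicunit = MPIU_REAL;
     since the combiner is not MPI_COMBINER_NAMED, the unit is not builtin and is duplicated with
     MPI_Type_dup() at the end of this routine. */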
  if (is2Int) {
    PackInit_PairType_int_int_1_1(link);
    link->bs = 1;
    link->unitbytes = 2 * sizeof(int);
    link->isbuiltin = PETSC_TRUE; /* unit is MPI builtin */
    link->basicunit = MPI_2INT;
    link->unit = MPI_2INT;
  } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
    PackInit_PairType_PetscInt_PetscInt_1_1(link);
    link->bs = 1;
    link->unitbytes = 2 * sizeof(PetscInt);
    link->basicunit = MPIU_2INT;
    link->isbuiltin = PETSC_TRUE; /* unit is PETSc builtin */
    link->unit = MPIU_2INT;
  } else if (nPetscReal) {
    if (nPetscReal == 8) PackInit_RealType_PetscReal_8_1(link);
    else if (nPetscReal % 8 == 0) PackInit_RealType_PetscReal_8_0(link);
    else if (nPetscReal == 4) PackInit_RealType_PetscReal_4_1(link);
    else if (nPetscReal % 4 == 0) PackInit_RealType_PetscReal_4_0(link);
    else if (nPetscReal == 2) PackInit_RealType_PetscReal_2_1(link);
    else if (nPetscReal % 2 == 0) PackInit_RealType_PetscReal_2_0(link);
    else if (nPetscReal == 1) PackInit_RealType_PetscReal_1_1(link);
    else if (nPetscReal % 1 == 0) PackInit_RealType_PetscReal_1_0(link);
    link->bs = nPetscReal;
    link->unitbytes = nPetscReal * sizeof(PetscReal);
    link->basicunit = MPIU_REAL;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPIU_REAL;
    }
  } else if (nPetscInt) {
    if (nPetscInt == 8) PackInit_IntegerType_PetscInt_8_1(link);
    else if (nPetscInt % 8 == 0) PackInit_IntegerType_PetscInt_8_0(link);
    else if (nPetscInt == 4) PackInit_IntegerType_PetscInt_4_1(link);
    else if (nPetscInt % 4 == 0) PackInit_IntegerType_PetscInt_4_0(link);
    else if (nPetscInt == 2) PackInit_IntegerType_PetscInt_2_1(link);
    else if (nPetscInt % 2 == 0) PackInit_IntegerType_PetscInt_2_0(link);
    else if (nPetscInt == 1) PackInit_IntegerType_PetscInt_1_1(link);
    else if (nPetscInt % 1 == 0) PackInit_IntegerType_PetscInt_1_0(link);
    link->bs = nPetscInt;
    link->unitbytes = nPetscInt * sizeof(PetscInt);
    link->basicunit = MPIU_INT;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPIU_INT;
    }
#if defined(PETSC_USE_64BIT_INDICES)
  } else if (nInt) {
    if (nInt == 8) PackInit_IntegerType_int_8_1(link);
    else if (nInt % 8 == 0) PackInit_IntegerType_int_8_0(link);
    else if (nInt == 4) PackInit_IntegerType_int_4_1(link);
    else if (nInt % 4 == 0) PackInit_IntegerType_int_4_0(link);
    else if (nInt == 2) PackInit_IntegerType_int_2_1(link);
    else if (nInt % 2 == 0) PackInit_IntegerType_int_2_0(link);
    else if (nInt == 1) PackInit_IntegerType_int_1_1(link);
    else if (nInt % 1 == 0) PackInit_IntegerType_int_1_0(link);
    link->bs = nInt;
    link->unitbytes = nInt * sizeof(int);
    link->basicunit = MPI_INT;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPI_INT;
    }
#endif
  } else if (nSignedChar) {
    if (nSignedChar == 8) PackInit_IntegerType_SignedChar_8_1(link);
    else if (nSignedChar % 8 == 0) PackInit_IntegerType_SignedChar_8_0(link);
    else if (nSignedChar == 4) PackInit_IntegerType_SignedChar_4_1(link);
    else if (nSignedChar % 4 == 0) PackInit_IntegerType_SignedChar_4_0(link);
    else if (nSignedChar == 2) PackInit_IntegerType_SignedChar_2_1(link);
    else if (nSignedChar % 2 == 0) PackInit_IntegerType_SignedChar_2_0(link);
    else if (nSignedChar == 1) PackInit_IntegerType_SignedChar_1_1(link);
    else if (nSignedChar % 1 == 0) PackInit_IntegerType_SignedChar_1_0(link);
    link->bs = nSignedChar;
    link->unitbytes = nSignedChar * sizeof(SignedChar);
    link->basicunit = MPI_SIGNED_CHAR;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPI_SIGNED_CHAR;
    }
  } else if (nUnsignedChar) {
    if (nUnsignedChar == 8) PackInit_IntegerType_UnsignedChar_8_1(link);
    else if (nUnsignedChar % 8 == 0) PackInit_IntegerType_UnsignedChar_8_0(link);
    else if (nUnsignedChar == 4) PackInit_IntegerType_UnsignedChar_4_1(link);
    else if (nUnsignedChar % 4 == 0) PackInit_IntegerType_UnsignedChar_4_0(link);
    else if (nUnsignedChar == 2) PackInit_IntegerType_UnsignedChar_2_1(link);
    else if (nUnsignedChar % 2 == 0) PackInit_IntegerType_UnsignedChar_2_0(link);
    else if (nUnsignedChar == 1) PackInit_IntegerType_UnsignedChar_1_1(link);
    else if (nUnsignedChar % 1 == 0) PackInit_IntegerType_UnsignedChar_1_0(link);
    link->bs = nUnsignedChar;
    link->unitbytes = nUnsignedChar * sizeof(UnsignedChar);
    link->basicunit = MPI_UNSIGNED_CHAR;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPI_UNSIGNED_CHAR;
    }
#if defined(PETSC_HAVE_COMPLEX)
  } else if (nPetscComplex) {
    if (nPetscComplex == 8) PackInit_ComplexType_PetscComplex_8_1(link);
    else if (nPetscComplex % 8 == 0) PackInit_ComplexType_PetscComplex_8_0(link);
    else if (nPetscComplex == 4) PackInit_ComplexType_PetscComplex_4_1(link);
    else if (nPetscComplex % 4 == 0) PackInit_ComplexType_PetscComplex_4_0(link);
    else if (nPetscComplex == 2) PackInit_ComplexType_PetscComplex_2_1(link);
    else if (nPetscComplex % 2 == 0) PackInit_ComplexType_PetscComplex_2_0(link);
    else if (nPetscComplex == 1) PackInit_ComplexType_PetscComplex_1_1(link);
    else if (nPetscComplex % 1 == 0) PackInit_ComplexType_PetscComplex_1_0(link);
    link->bs = nPetscComplex;
    link->unitbytes = nPetscComplex * sizeof(PetscComplex);
    link->basicunit = MPIU_COMPLEX;
    if (link->bs == 1) {
      link->isbuiltin = PETSC_TRUE;
      link->unit = MPIU_COMPLEX;
    }
#endif
  } else {
    MPI_Aint lb, nbyte;
    PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
    PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
    if (nbyte % sizeof(int)) { /* If the type size is not a multiple of int */
      if (nbyte == 4) PackInit_DumbType_char_4_1(link);
      else if (nbyte % 4 == 0) PackInit_DumbType_char_4_0(link);
      else if (nbyte == 2) PackInit_DumbType_char_2_1(link);
      else if (nbyte % 2 == 0) PackInit_DumbType_char_2_0(link);
      else if (nbyte == 1) PackInit_DumbType_char_1_1(link);
      else if (nbyte % 1 == 0) PackInit_DumbType_char_1_0(link);
      link->bs = nbyte;
      link->unitbytes = nbyte;
      link->basicunit = MPI_BYTE;
    } else {
      nInt = nbyte / sizeof(int);
      if (nInt == 8) PackInit_DumbType_DumbInt_8_1(link);
      else if (nInt % 8 == 0) PackInit_DumbType_DumbInt_8_0(link);
      else if (nInt == 4) PackInit_DumbType_DumbInt_4_1(link);
      else if (nInt % 4 == 0) PackInit_DumbType_DumbInt_4_0(link);
      else if (nInt == 2) PackInit_DumbType_DumbInt_2_1(link);
      else if (nInt % 2 == 0) PackInit_DumbType_DumbInt_2_0(link);
      else if (nInt == 1) PackInit_DumbType_DumbInt_1_1(link);
      else if (nInt % 1 == 0) PackInit_DumbType_DumbInt_1_0(link);
      link->bs = nInt;
      link->unitbytes = nbyte;
      link->basicunit = MPI_INT;
    }
    if (link->isbuiltin) link->unit = unit;
  }

  if (!link->isbuiltin) PetscCallMPI(MPI_Type_dup(unit, &link->unit));

  link->Memcpy = PetscSFLinkMemcpy_Host;
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkGetUnpackAndOp(PetscSFLink link, PetscMemType mtype, MPI_Op op, PetscBool atomic, PetscErrorCode (**UnpackAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, const void *))
{
  PetscFunctionBegin;
  *UnpackAndOp = NULL;
  if (PetscMemTypeHost(mtype)) {
    if (op == MPI_REPLACE) *UnpackAndOp = link->h_UnpackAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *UnpackAndOp = link->h_UnpackAndAdd;
    else if (op == MPI_PROD) *UnpackAndOp = link->h_UnpackAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *UnpackAndOp = link->h_UnpackAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *UnpackAndOp = link->h_UnpackAndMin;
    else if (op == MPI_LAND) *UnpackAndOp = link->h_UnpackAndLAND;
    else if (op == MPI_BAND) *UnpackAndOp = link->h_UnpackAndBAND;
    else if (op == MPI_LOR) *UnpackAndOp = link->h_UnpackAndLOR;
    else if (op == MPI_BOR) *UnpackAndOp = link->h_UnpackAndBOR;
    else if (op == MPI_LXOR) *UnpackAndOp = link->h_UnpackAndLXOR;
    else if (op == MPI_BXOR) *UnpackAndOp = link->h_UnpackAndBXOR;
    else if (op == MPI_MAXLOC) *UnpackAndOp = link->h_UnpackAndMaxloc;
    else if (op == MPI_MINLOC) *UnpackAndOp = link->h_UnpackAndMinloc;
  }
#if defined(PETSC_HAVE_DEVICE)
  else if (PetscMemTypeDevice(mtype) && !atomic) {
    if (op == MPI_REPLACE) *UnpackAndOp = link->d_UnpackAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *UnpackAndOp = link->d_UnpackAndAdd;
    else if (op == MPI_PROD) *UnpackAndOp = link->d_UnpackAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *UnpackAndOp = link->d_UnpackAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *UnpackAndOp = link->d_UnpackAndMin;
    else if (op == MPI_LAND) *UnpackAndOp = link->d_UnpackAndLAND;
    else if (op == MPI_BAND) *UnpackAndOp = link->d_UnpackAndBAND;
    else if (op == MPI_LOR) *UnpackAndOp = link->d_UnpackAndLOR;
    else if (op == MPI_BOR) *UnpackAndOp = link->d_UnpackAndBOR;
    else if (op == MPI_LXOR) *UnpackAndOp = link->d_UnpackAndLXOR;
    else if (op == MPI_BXOR) *UnpackAndOp = link->d_UnpackAndBXOR;
    else if (op == MPI_MAXLOC) *UnpackAndOp = link->d_UnpackAndMaxloc;
    else if (op == MPI_MINLOC) *UnpackAndOp = link->d_UnpackAndMinloc;
  } else if (PetscMemTypeDevice(mtype) && atomic) {
    if (op == MPI_REPLACE) *UnpackAndOp = link->da_UnpackAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *UnpackAndOp = link->da_UnpackAndAdd;
    else if (op == MPI_PROD) *UnpackAndOp = link->da_UnpackAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *UnpackAndOp = link->da_UnpackAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *UnpackAndOp = link->da_UnpackAndMin;
    else if (op == MPI_LAND) *UnpackAndOp = link->da_UnpackAndLAND;
    else if (op == MPI_BAND) *UnpackAndOp = link->da_UnpackAndBAND;
    else if (op == MPI_LOR) *UnpackAndOp = link->da_UnpackAndLOR;
    else if (op == MPI_BOR) *UnpackAndOp = link->da_UnpackAndBOR;
    else if (op == MPI_LXOR) *UnpackAndOp = link->da_UnpackAndLXOR;
    else if (op == MPI_BXOR) *UnpackAndOp = link->da_UnpackAndBXOR;
    else if (op == MPI_MAXLOC) *UnpackAndOp = link->da_UnpackAndMaxloc;
    else if (op == MPI_MINLOC) *UnpackAndOp = link->da_UnpackAndMinloc;
  }
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

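/* Illustrative usage (mirrors PetscSFLinkUnpackRootData_Private() later in this file): callers look up a
   kernel and fall back to MPI_Reduce_local() when none is returned, e.g.

     PetscErrorCode (*UnpackAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, const void *) = NULL;
     PetscCall(PetscSFLinkGetUnpackAndOp(link, PETSC_MEMTYPE_HOST, MPI_SUM, PETSC_FALSE, &UnpackAndOp));
     if (UnpackAndOp) PetscCall((*UnpackAndOp)(link, count, start, opt, idx, rootdata, rootbuf));
*/
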
PetscErrorCode PetscSFLinkGetScatterAndOp(PetscSFLink link, PetscMemType mtype, MPI_Op op, PetscBool atomic, PetscErrorCode (**ScatterAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, PetscInt, PetscSFPackOpt, const PetscInt *, void *))
{
  PetscFunctionBegin;
  *ScatterAndOp = NULL;
  if (PetscMemTypeHost(mtype)) {
    if (op == MPI_REPLACE) *ScatterAndOp = link->h_ScatterAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *ScatterAndOp = link->h_ScatterAndAdd;
    else if (op == MPI_PROD) *ScatterAndOp = link->h_ScatterAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *ScatterAndOp = link->h_ScatterAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *ScatterAndOp = link->h_ScatterAndMin;
    else if (op == MPI_LAND) *ScatterAndOp = link->h_ScatterAndLAND;
    else if (op == MPI_BAND) *ScatterAndOp = link->h_ScatterAndBAND;
    else if (op == MPI_LOR) *ScatterAndOp = link->h_ScatterAndLOR;
    else if (op == MPI_BOR) *ScatterAndOp = link->h_ScatterAndBOR;
    else if (op == MPI_LXOR) *ScatterAndOp = link->h_ScatterAndLXOR;
    else if (op == MPI_BXOR) *ScatterAndOp = link->h_ScatterAndBXOR;
    else if (op == MPI_MAXLOC) *ScatterAndOp = link->h_ScatterAndMaxloc;
    else if (op == MPI_MINLOC) *ScatterAndOp = link->h_ScatterAndMinloc;
  }
#if defined(PETSC_HAVE_DEVICE)
  else if (PetscMemTypeDevice(mtype) && !atomic) {
    if (op == MPI_REPLACE) *ScatterAndOp = link->d_ScatterAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *ScatterAndOp = link->d_ScatterAndAdd;
    else if (op == MPI_PROD) *ScatterAndOp = link->d_ScatterAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *ScatterAndOp = link->d_ScatterAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *ScatterAndOp = link->d_ScatterAndMin;
    else if (op == MPI_LAND) *ScatterAndOp = link->d_ScatterAndLAND;
    else if (op == MPI_BAND) *ScatterAndOp = link->d_ScatterAndBAND;
    else if (op == MPI_LOR) *ScatterAndOp = link->d_ScatterAndLOR;
    else if (op == MPI_BOR) *ScatterAndOp = link->d_ScatterAndBOR;
    else if (op == MPI_LXOR) *ScatterAndOp = link->d_ScatterAndLXOR;
    else if (op == MPI_BXOR) *ScatterAndOp = link->d_ScatterAndBXOR;
    else if (op == MPI_MAXLOC) *ScatterAndOp = link->d_ScatterAndMaxloc;
    else if (op == MPI_MINLOC) *ScatterAndOp = link->d_ScatterAndMinloc;
  } else if (PetscMemTypeDevice(mtype) && atomic) {
    if (op == MPI_REPLACE) *ScatterAndOp = link->da_ScatterAndInsert;
    else if (op == MPI_SUM || op == MPIU_SUM) *ScatterAndOp = link->da_ScatterAndAdd;
    else if (op == MPI_PROD) *ScatterAndOp = link->da_ScatterAndMult;
    else if (op == MPI_MAX || op == MPIU_MAX) *ScatterAndOp = link->da_ScatterAndMax;
    else if (op == MPI_MIN || op == MPIU_MIN) *ScatterAndOp = link->da_ScatterAndMin;
    else if (op == MPI_LAND) *ScatterAndOp = link->da_ScatterAndLAND;
    else if (op == MPI_BAND) *ScatterAndOp = link->da_ScatterAndBAND;
    else if (op == MPI_LOR) *ScatterAndOp = link->da_ScatterAndLOR;
    else if (op == MPI_BOR) *ScatterAndOp = link->da_ScatterAndBOR;
    else if (op == MPI_LXOR) *ScatterAndOp = link->da_ScatterAndLXOR;
    else if (op == MPI_BXOR) *ScatterAndOp = link->da_ScatterAndBXOR;
    else if (op == MPI_MAXLOC) *ScatterAndOp = link->da_ScatterAndMaxloc;
    else if (op == MPI_MINLOC) *ScatterAndOp = link->da_ScatterAndMinloc;
  }
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkGetFetchAndOp(PetscSFLink link, PetscMemType mtype, MPI_Op op, PetscBool atomic, PetscErrorCode (**FetchAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, void *))
{
  PetscFunctionBegin;
  *FetchAndOp = NULL;
  PetscCheck(op == MPI_SUM || op == MPIU_SUM, PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for MPI_Op in FetchAndOp");
  if (PetscMemTypeHost(mtype)) *FetchAndOp = link->h_FetchAndAdd;
#if defined(PETSC_HAVE_DEVICE)
  else if (PetscMemTypeDevice(mtype) && !atomic) *FetchAndOp = link->d_FetchAndAdd;
  else if (PetscMemTypeDevice(mtype) && atomic) *FetchAndOp = link->da_FetchAndAdd;
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkGetFetchAndOpLocal(PetscSFLink link, PetscMemType mtype, MPI_Op op, PetscBool atomic, PetscErrorCode (**FetchAndOpLocal)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, void *))
{
  PetscFunctionBegin;
  *FetchAndOpLocal = NULL;
  PetscCheck(op == MPI_SUM || op == MPIU_SUM, PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for MPI_Op in FetchAndOp");
  if (PetscMemTypeHost(mtype)) *FetchAndOpLocal = link->h_FetchAndAddLocal;
#if defined(PETSC_HAVE_DEVICE)
  else if (PetscMemTypeDevice(mtype) && !atomic) *FetchAndOpLocal = link->d_FetchAndAddLocal;
  else if (PetscMemTypeDevice(mtype) && atomic) *FetchAndOpLocal = link->da_FetchAndAddLocal;
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

static inline PetscErrorCode PetscSFLinkLogFlopsAfterUnpackRootData(PetscSF sf, PetscSFLink link, PetscSFScope scope, MPI_Op op)
{
  PetscLogDouble flops;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  if (op != MPI_REPLACE && link->basicunit == MPIU_SCALAR) { /* op is a reduction on PetscScalars */
    flops = bas->rootbuflen[scope] * link->bs; /* # of roots in buffer x # of scalars in unit */
#if defined(PETSC_HAVE_DEVICE)
    if (PetscMemTypeDevice(link->rootmtype)) PetscCall(PetscLogGpuFlops(flops));
    else
#endif
      PetscCall(PetscLogFlops(flops));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

static inline PetscErrorCode PetscSFLinkLogFlopsAfterUnpackLeafData(PetscSF sf, PetscSFLink link, PetscSFScope scope, MPI_Op op)
{
  PetscLogDouble flops;

  PetscFunctionBegin;
  if (op != MPI_REPLACE && link->basicunit == MPIU_SCALAR) { /* op is a reduction on PetscScalars */
    flops = sf->leafbuflen[scope] * link->bs; /* # of leaves in buffer x # of scalars in unit */
#if defined(PETSC_HAVE_DEVICE)
    if (PetscMemTypeDevice(link->leafmtype)) PetscCall(PetscLogGpuFlops(flops));
    else
#endif
      PetscCall(PetscLogFlops(flops));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* When the SF cannot find a proper UnpackAndOp() in the link, it falls back to MPI_Reduce_local.

   Input Parameters:
   +sf      - The StarForest
   .link    - The link
   .count   - Number of entries to unpack
   .start   - The first index, significant when indices=NULL
   .indices - Indices of entries in <data>. If NULL, it means indices are contiguous and the first is given in <start>
   .buf     - A contiguous buffer to unpack from
   -op      - Operation after unpack

   Output Parameters:
   .data    - The data to unpack to
 */
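/* Background note (standard MPI semantics, for reference): MPI_Reduce_local(inbuf, inoutbuf, count, datatype, op)
   computes inoutbuf[i] = inbuf[i] op inoutbuf[i] for i in [0,count), so the fallback below can apply the
   user's (possibly user-defined) MPI_Op one root/leaf entry at a time when indices are non-contiguous. */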
static inline PetscErrorCode PetscSFLinkUnpackDataWithMPIReduceLocal(PetscSF sf, PetscSFLink link, PetscInt count, PetscInt start, const PetscInt *indices, void *data, const void *buf, MPI_Op op)
{
  PetscFunctionBegin;
#if defined(PETSC_HAVE_MPI_REDUCE_LOCAL)
  {
    PetscInt i;
    if (indices) {
      /* Note we use link->unit instead of link->basicunit. When op can be mapped to MPI_SUM etc, it operates on
         basic units of a root/leaf element-wise. Otherwise, it is meant to operate on a whole root/leaf.
       */
      for (i = 0; i < count; i++) PetscCallMPI(MPI_Reduce_local((const char *)buf + i * link->unitbytes, (char *)data + indices[i] * link->unitbytes, 1, link->unit, op));
    } else {
      PetscCallMPI(MPIU_Reduce_local(buf, (char *)data + start * link->unitbytes, count, link->unit, op));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
#else
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No unpacking reduction operation for this MPI_Op");
#endif
}

static inline PetscErrorCode PetscSFLinkScatterDataWithMPIReduceLocal(PetscSF sf, PetscSFLink link, PetscInt count, PetscInt srcStart, const PetscInt *srcIdx, const void *src, PetscInt dstStart, const PetscInt *dstIdx, void *dst, MPI_Op op)
{
  PetscFunctionBegin;
#if defined(PETSC_HAVE_MPI_REDUCE_LOCAL)
  {
    PetscInt i, disp;
    if (!srcIdx) {
      PetscCall(PetscSFLinkUnpackDataWithMPIReduceLocal(sf, link, count, dstStart, dstIdx, dst, (const char *)src + srcStart * link->unitbytes, op));
    } else {
      for (i = 0; i < count; i++) {
        disp = dstIdx ? dstIdx[i] : dstStart + i;
        PetscCallMPI(MPIU_Reduce_local((const char *)src + srcIdx[i] * link->unitbytes, (char *)dst + disp * link->unitbytes, 1, link->unit, op));
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
#else
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No unpacking reduction operation for this MPI_Op");
#endif
}

/*=============================================================================
              Pack/Unpack/Fetch/Scatter routines
 ============================================================================*/

/* Pack rootdata to rootbuf
   Input Parameters:
   + sf       - The SF this packing works on.
   . link     - It gives the memtype of the roots and also provides the root buffer.
   . scope    - PETSCSF_LOCAL or PETSCSF_REMOTE. Note SF has the ability to do local and remote communications separately.
   - rootdata - Where to read the roots.

   Notes:
   When rootdata can be directly used as the root buffer, the routine is almost a no-op. After the call, root data is
   in a place where the underlying MPI is ready to access (use_gpu_aware_mpi or not).
 */
static PetscErrorCode PetscSFLinkPackRootData_Private(PetscSF sf, PetscSFLink link, PetscSFScope scope, const void *rootdata)
{
  const PetscInt *rootindices = NULL;
  PetscInt count, start;
  PetscErrorCode (*Pack)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, void *) = NULL;
  PetscMemType rootmtype = link->rootmtype;
  PetscSFPackOpt opt = NULL;

  PetscFunctionBegin;
  if (!link->rootdirect[scope]) { /* If rootdata works directly as rootbuf, skip packing */
    PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, rootmtype, scope, &count, &start, &opt, &rootindices));
    PetscCall(PetscSFLinkGetPack(link, rootmtype, &Pack));
    PetscCall((*Pack)(link, count, start, opt, rootindices, rootdata, link->rootbuf[scope][rootmtype]));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Pack leafdata to leafbuf */
static PetscErrorCode PetscSFLinkPackLeafData_Private(PetscSF sf, PetscSFLink link, PetscSFScope scope, const void *leafdata)
{
  const PetscInt *leafindices = NULL;
  PetscInt count, start;
  PetscErrorCode (*Pack)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, void *) = NULL;
  PetscMemType leafmtype = link->leafmtype;
  PetscSFPackOpt opt = NULL;

  PetscFunctionBegin;
  if (!link->leafdirect[scope]) { /* If leafdata works directly as leafbuf, skip packing */
    PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, leafmtype, scope, &count, &start, &opt, &leafindices));
    PetscCall(PetscSFLinkGetPack(link, leafmtype, &Pack));
    PetscCall((*Pack)(link, count, start, opt, leafindices, leafdata, link->leafbuf[scope][leafmtype]));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Pack rootdata to rootbuf, which are in the same memory space */
PetscErrorCode PetscSFLinkPackRootData(PetscSF sf, PetscSFLink link, PetscSFScope scope, const void *rootdata)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  if (scope == PETSCSF_REMOTE) { /* Sync the device if rootdata is not on petsc's default stream */
    if (PetscMemTypeDevice(link->rootmtype) && link->SyncDevice && sf->unknown_input_stream) PetscCall((*link->SyncDevice)(link));
    if (link->PrePack) PetscCall((*link->PrePack)(sf, link, PETSCSF_ROOT2LEAF)); /* Used by SF nvshmem */
  }
  PetscCall(PetscLogEventBegin(PETSCSF_Pack, sf, 0, 0, 0));
  if (bas->rootbuflen[scope]) PetscCall(PetscSFLinkPackRootData_Private(sf, link, scope, rootdata));
  PetscCall(PetscLogEventEnd(PETSCSF_Pack, sf, 0, 0, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Pack leafdata to leafbuf, which are in the same memory space */
PetscErrorCode PetscSFLinkPackLeafData(PetscSF sf, PetscSFLink link, PetscSFScope scope, const void *leafdata)
{
  PetscFunctionBegin;
  if (scope == PETSCSF_REMOTE) {
    if (PetscMemTypeDevice(link->leafmtype) && link->SyncDevice && sf->unknown_input_stream) PetscCall((*link->SyncDevice)(link));
    if (link->PrePack) PetscCall((*link->PrePack)(sf, link, PETSCSF_LEAF2ROOT)); /* Used by SF nvshmem */
  }
  PetscCall(PetscLogEventBegin(PETSCSF_Pack, sf, 0, 0, 0));
  if (sf->leafbuflen[scope]) PetscCall(PetscSFLinkPackLeafData_Private(sf, link, scope, leafdata));
  PetscCall(PetscLogEventEnd(PETSCSF_Pack, sf, 0, 0, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode PetscSFLinkUnpackRootData_Private(PetscSF sf, PetscSFLink link, PetscSFScope scope, void *rootdata, MPI_Op op)
{
  const PetscInt *rootindices = NULL;
  PetscInt count, start;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscErrorCode (*UnpackAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, const void *) = NULL;
  PetscMemType rootmtype = link->rootmtype;
  PetscSFPackOpt opt = NULL;

  PetscFunctionBegin;
  if (!link->rootdirect[scope]) { /* If rootdata works directly as rootbuf, skip unpacking */
    PetscCall(PetscSFLinkGetUnpackAndOp(link, rootmtype, op, bas->rootdups[scope], &UnpackAndOp));
    if (UnpackAndOp) {
      PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, rootmtype, scope, &count, &start, &opt, &rootindices));
      PetscCall((*UnpackAndOp)(link, count, start, opt, rootindices, rootdata, link->rootbuf[scope][rootmtype]));
    } else {
      PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, PETSC_MEMTYPE_HOST, scope, &count, &start, &opt, &rootindices));
      PetscCall(PetscSFLinkUnpackDataWithMPIReduceLocal(sf, link, count, start, rootindices, rootdata, link->rootbuf[scope][rootmtype], op));
    }
  }
  PetscCall(PetscSFLinkLogFlopsAfterUnpackRootData(sf, link, scope, op));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode PetscSFLinkUnpackLeafData_Private(PetscSF sf, PetscSFLink link, PetscSFScope scope, void *leafdata, MPI_Op op)
{
  const PetscInt *leafindices = NULL;
  PetscInt count, start;
  PetscErrorCode (*UnpackAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, const void *) = NULL;
  PetscMemType leafmtype = link->leafmtype;
  PetscSFPackOpt opt = NULL;

  PetscFunctionBegin;
  if (!link->leafdirect[scope]) { /* If leafdata works directly as leafbuf, skip unpacking */
    PetscCall(PetscSFLinkGetUnpackAndOp(link, leafmtype, op, sf->leafdups[scope], &UnpackAndOp));
    if (UnpackAndOp) {
      PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, leafmtype, scope, &count, &start, &opt, &leafindices));
      PetscCall((*UnpackAndOp)(link, count, start, opt, leafindices, leafdata, link->leafbuf[scope][leafmtype]));
    } else {
      PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, PETSC_MEMTYPE_HOST, scope, &count, &start, &opt, &leafindices));
      PetscCall(PetscSFLinkUnpackDataWithMPIReduceLocal(sf, link, count, start, leafindices, leafdata, link->leafbuf[scope][leafmtype], op));
    }
  }
  PetscCall(PetscSFLinkLogFlopsAfterUnpackLeafData(sf, link, scope, op));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Unpack rootbuf to rootdata, which are in the same memory space */
PetscErrorCode PetscSFLinkUnpackRootData(PetscSF sf, PetscSFLink link, PetscSFScope scope, void *rootdata, MPI_Op op)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  PetscCall(PetscLogEventBegin(PETSCSF_Unpack, sf, 0, 0, 0));
  if (bas->rootbuflen[scope]) PetscCall(PetscSFLinkUnpackRootData_Private(sf, link, scope, rootdata, op));
  PetscCall(PetscLogEventEnd(PETSCSF_Unpack, sf, 0, 0, 0));
  if (scope == PETSCSF_REMOTE) {
    if (link->PostUnpack) PetscCall((*link->PostUnpack)(sf, link, PETSCSF_LEAF2ROOT)); /* Used by SF nvshmem */
    if (PetscMemTypeDevice(link->rootmtype) && link->SyncDevice && sf->unknown_input_stream) PetscCall((*link->SyncDevice)(link));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Unpack leafbuf to leafdata for remote (common case) or local (rare case when rootmtype != leafmtype) */
PetscErrorCode PetscSFLinkUnpackLeafData(PetscSF sf, PetscSFLink link, PetscSFScope scope, void *leafdata, MPI_Op op)
{
  PetscFunctionBegin;
  PetscCall(PetscLogEventBegin(PETSCSF_Unpack, sf, 0, 0, 0));
  if (sf->leafbuflen[scope]) PetscCall(PetscSFLinkUnpackLeafData_Private(sf, link, scope, leafdata, op));
  PetscCall(PetscLogEventEnd(PETSCSF_Unpack, sf, 0, 0, 0));
  if (scope == PETSCSF_REMOTE) {
    if (link->PostUnpack) PetscCall((*link->PostUnpack)(sf, link, PETSCSF_ROOT2LEAF)); /* Used by SF nvshmem */
    if (PetscMemTypeDevice(link->leafmtype) && link->SyncDevice && sf->unknown_input_stream) PetscCall((*link->SyncDevice)(link));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* FetchAndOp rootdata with rootbuf. It is a kind of Unpack on rootdata, except that it also updates rootbuf */
PetscErrorCode PetscSFLinkFetchAndOpRemote(PetscSF sf, PetscSFLink link, void *rootdata, MPI_Op op)
{
  const PetscInt *rootindices = NULL;
  PetscInt count, start;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscErrorCode (*FetchAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, void *) = NULL;
  PetscMemType rootmtype = link->rootmtype;
  PetscSFPackOpt opt = NULL;

  PetscFunctionBegin;
  PetscCall(PetscLogEventBegin(PETSCSF_Unpack, sf, 0, 0, 0));
  if (bas->rootbuflen[PETSCSF_REMOTE]) {
    /* Do FetchAndOp on rootdata with rootbuf */
    PetscCall(PetscSFLinkGetFetchAndOp(link, rootmtype, op, bas->rootdups[PETSCSF_REMOTE], &FetchAndOp));
    PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, rootmtype, PETSCSF_REMOTE, &count, &start, &opt, &rootindices));
    PetscCall((*FetchAndOp)(link, count, start, opt, rootindices, rootdata, link->rootbuf[PETSCSF_REMOTE][rootmtype]));
  }
  PetscCall(PetscSFLinkLogFlopsAfterUnpackRootData(sf, link, PETSCSF_REMOTE, op));
  PetscCall(PetscLogEventEnd(PETSCSF_Unpack, sf, 0, 0, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFLinkScatterLocal(PetscSF sf, PetscSFLink link, PetscSFDirection direction, void *rootdata, void *leafdata, MPI_Op op)
{
  const PetscInt *rootindices = NULL, *leafindices = NULL;
  PetscInt count, rootstart, leafstart;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscErrorCode (*ScatterAndOp)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, PetscInt, PetscSFPackOpt, const PetscInt *, void *) = NULL;
  PetscMemType rootmtype = link->rootmtype, leafmtype = link->leafmtype, srcmtype, dstmtype;
  PetscSFPackOpt leafopt = NULL, rootopt = NULL;
  PetscInt buflen = sf->leafbuflen[PETSCSF_LOCAL];
  char *srcbuf = NULL, *dstbuf = NULL;
  PetscBool dstdups;

  PetscFunctionBegin;
  if (!buflen) PetscFunctionReturn(PETSC_SUCCESS);
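  /* Illustrative note: the two branches below differ only in where the data lives. When root and leaf data
     are in different memory spaces (e.g. roots on the device, leaves on the host), the local scatter is staged
     as pack -> link->Memcpy across the spaces -> unpack; when they are in the same space, a fused ScatterAndOp
     kernel (or the MPI_Reduce_local fallback) moves entries directly between rootdata and leafdata. */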
  if (rootmtype != leafmtype) { /* The cross memory space local scatter is done by pack, copy and unpack */
    if (direction == PETSCSF_ROOT2LEAF) {
      PetscCall(PetscSFLinkPackRootData(sf, link, PETSCSF_LOCAL, rootdata));
      srcmtype = rootmtype;
      srcbuf = link->rootbuf[PETSCSF_LOCAL][rootmtype];
      dstmtype = leafmtype;
      dstbuf = link->leafbuf[PETSCSF_LOCAL][leafmtype];
    } else {
      PetscCall(PetscSFLinkPackLeafData(sf, link, PETSCSF_LOCAL, leafdata));
      srcmtype = leafmtype;
      srcbuf = link->leafbuf[PETSCSF_LOCAL][leafmtype];
      dstmtype = rootmtype;
      dstbuf = link->rootbuf[PETSCSF_LOCAL][rootmtype];
    }
    PetscCall((*link->Memcpy)(link, dstmtype, dstbuf, srcmtype, srcbuf, buflen * link->unitbytes));
    /* If the above is a device to host copy, we have to sync the stream before accessing the buffer on the host */
    if (PetscMemTypeHost(dstmtype)) PetscCall((*link->SyncStream)(link));
    if (direction == PETSCSF_ROOT2LEAF) {
      PetscCall(PetscSFLinkUnpackLeafData(sf, link, PETSCSF_LOCAL, leafdata, op));
    } else {
      PetscCall(PetscSFLinkUnpackRootData(sf, link, PETSCSF_LOCAL, rootdata, op));
    }
  } else {
    dstdups = (direction == PETSCSF_ROOT2LEAF) ? sf->leafdups[PETSCSF_LOCAL] : bas->rootdups[PETSCSF_LOCAL];
    dstmtype = (direction == PETSCSF_ROOT2LEAF) ? link->leafmtype : link->rootmtype;
    PetscCall(PetscSFLinkGetScatterAndOp(link, dstmtype, op, dstdups, &ScatterAndOp));
    if (ScatterAndOp) {
      PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, rootmtype, PETSCSF_LOCAL, &count, &rootstart, &rootopt, &rootindices));
      PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, leafmtype, PETSCSF_LOCAL, &count, &leafstart, &leafopt, &leafindices));
      if (direction == PETSCSF_ROOT2LEAF) {
        PetscCall((*ScatterAndOp)(link, count, rootstart, rootopt, rootindices, rootdata, leafstart, leafopt, leafindices, leafdata));
      } else {
        PetscCall((*ScatterAndOp)(link, count, leafstart, leafopt, leafindices, leafdata, rootstart, rootopt, rootindices, rootdata));
      }
    } else {
      PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, PETSC_MEMTYPE_HOST, PETSCSF_LOCAL, &count, &rootstart, &rootopt, &rootindices));
      PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, PETSC_MEMTYPE_HOST, PETSCSF_LOCAL, &count, &leafstart, &leafopt, &leafindices));
      if (direction == PETSCSF_ROOT2LEAF) {
        PetscCall(PetscSFLinkScatterDataWithMPIReduceLocal(sf, link, count, rootstart, rootindices, rootdata, leafstart, leafindices, leafdata, op));
      } else {
        PetscCall(PetscSFLinkScatterDataWithMPIReduceLocal(sf, link, count, leafstart, leafindices, leafdata, rootstart, rootindices, rootdata, op));
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Fetch rootdata to leafdata and leafupdate locally */
PetscErrorCode PetscSFLinkFetchAndOpLocal(PetscSF sf, PetscSFLink link, void *rootdata, const void *leafdata, void *leafupdate, MPI_Op op)
{
  const PetscInt *rootindices = NULL, *leafindices = NULL;
  PetscInt count, rootstart, leafstart;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscErrorCode (*FetchAndOpLocal)(PetscSFLink, PetscInt, PetscInt, PetscSFPackOpt, const PetscInt *, void *, PetscInt, PetscSFPackOpt, const PetscInt *, const void *, void *) = NULL;
  const PetscMemType rootmtype = link->rootmtype, leafmtype = link->leafmtype;
  PetscSFPackOpt leafopt = NULL, rootopt = NULL;

  PetscFunctionBegin;
  if (!bas->rootbuflen[PETSCSF_LOCAL]) PetscFunctionReturn(PETSC_SUCCESS);
  if (rootmtype != leafmtype) {
    /* The local communication has to go through pack and unpack */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Doing PetscSFFetchAndOp with rootdata and leafdata on opposite side of CPU and GPU");
  } else {
    PetscCall(PetscSFLinkGetRootPackOptAndIndices(sf, link, rootmtype, PETSCSF_LOCAL, &count, &rootstart, &rootopt, &rootindices));
    PetscCall(PetscSFLinkGetLeafPackOptAndIndices(sf, link, leafmtype, PETSCSF_LOCAL, &count, &leafstart, &leafopt, &leafindices));
    PetscCall(PetscSFLinkGetFetchAndOpLocal(link, rootmtype, op, bas->rootdups[PETSCSF_LOCAL], &FetchAndOpLocal));
    PetscCall((*FetchAndOpLocal)(link, count, rootstart, rootopt, rootindices, rootdata, leafstart, leafopt, leafindices, leafdata, leafupdate));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode PetscSFCreatePackOpt(PetscInt n, const PetscInt *offset, const PetscInt *idx, PetscSFPackOpt *out)
{
  PetscInt       r, p, start, i, j, k, dx, dy, dz, dydz, m, X, Y;
  PetscBool      optimizable = PETSC_TRUE;
  PetscSFPackOpt opt;

  PetscFunctionBegin;
  PetscCall(PetscMalloc1(1, &opt));
  PetscCall(PetscMalloc1(7 * n + 2, &opt->array));
  opt->n      = opt->array[0] = n;
  opt->offset = opt->array + 1;
  opt->start  = opt->array + n + 2;
  opt->dx     = opt->array + 2 * n + 2;
  opt->dy     = opt->array + 3 * n + 2;
  opt->dz     = opt->array + 4 * n + 2;
  opt->X      = opt->array + 5 * n + 2;
  opt->Y      = opt->array + 6 * n + 2;

  for (r = 0; r < n; r++) {            /* For each destination rank */
    m     = offset[r + 1] - offset[r]; /* Total number of indices for this rank. We want to see if m can be factored into dx*dy*dz */
    p     = offset[r];
    start = idx[p]; /* First index for this rank */
    p++;

    /* Search in X dimension */
    for (dx = 1; dx < m; dx++, p++) {
      if (start + dx != idx[p]) break;
    }

    dydz = m / dx;
    X    = dydz > 1 ? (idx[p] - start) : dx;
    /* Not optimizable if m is not a multiple of dx, or some unrecognized pattern is found */
    if (m % dx || X <= 0) {
      optimizable = PETSC_FALSE;
      goto finish;
    }
    for (dy = 1; dy < dydz; dy++) { /* Search in Y dimension */
      for (i = 0; i < dx; i++, p++) {
        if (start + X * dy + i != idx[p]) {
          if (i) {
            optimizable = PETSC_FALSE;
            goto finish;
          } /* The pattern is violated in the middle of an x-walk */
          else
            goto Z_dimension;
        }
      }
    }

  Z_dimension:
    dz = m / (dx * dy);
    Y  = dz > 1 ? (idx[p] - start) / X : dy;
    /* Not optimizable if m is not a multiple of dx*dy, or some unrecognized pattern is found */
    if (m % (dx * dy) || Y <= 0) {
      optimizable = PETSC_FALSE;
      goto finish;
    }
    for (k = 1; k < dz; k++) { /* Go through the Z dimension to see if the remaining indices follow the pattern */
      for (j = 0; j < dy; j++) {
        for (i = 0; i < dx; i++, p++) {
          if (start + X * Y * k + X * j + i != idx[p]) {
            optimizable = PETSC_FALSE;
            goto finish;
          }
        }
      }
    }
    opt->start[r] = start;
    opt->dx[r]    = dx;
    opt->dy[r]    = dy;
    opt->dz[r]    = dz;
    opt->X[r]     = X;
    opt->Y[r]     = Y;
  }

finish:
  /* If not optimizable, free the arrays to save memory */
  if (!n || !optimizable) {
    PetscCall(PetscFree(opt->array));
    PetscCall(PetscFree(opt));
    *out = NULL;
  } else {
    opt->offset[0] = 0;
    for (r = 0; r < n; r++) opt->offset[r + 1] = opt->offset[r] + opt->dx[r] * opt->dy[r] * opt->dz[r];
    *out = opt;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
static inline PetscErrorCode PetscSFDestroyPackOpt(PetscSF sf, PetscMemType mtype, PetscSFPackOpt *out)
{
  PetscSFPackOpt opt = *out;

  PetscFunctionBegin;
  if (opt) {
    PetscCall(PetscSFFree(sf, mtype, opt->array));
    PetscCall(PetscFree(opt));
    *out = NULL;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFSetUpPackFields(PetscSF sf)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       i, j;

  PetscFunctionBegin;
  /* [0] for PETSCSF_LOCAL and [1] for PETSCSF_REMOTE in the following */
  for (i = 0; i < 2; i++) { /* Set defaults */
    sf->leafstart[i]   = 0;
    sf->leafcontig[i]  = PETSC_TRUE;
    sf->leafdups[i]    = PETSC_FALSE;
    bas->rootstart[i]  = 0;
    bas->rootcontig[i] = PETSC_TRUE;
    bas->rootdups[i]   = PETSC_FALSE;
  }

  sf->leafbuflen[0] = sf->roffset[sf->ndranks];
  sf->leafbuflen[1] = sf->roffset[sf->nranks] - sf->roffset[sf->ndranks];

  if (sf->leafbuflen[0]) sf->leafstart[0] = sf->rmine[0];
  if (sf->leafbuflen[1]) sf->leafstart[1] = sf->rmine[sf->roffset[sf->ndranks]];

  /* Are leaf indices for self and remote contiguous? If yes, it is best for pack/unpack */
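  /*
    Illustrative note (made-up numbers): if the self part of sf->rmine[] is {4,5,6,7}, then leafstart[0] = 4 and
    the loop below keeps leafcontig[0] = PETSC_TRUE, so packing that part is a single contiguous copy. If it were
    {4,6,8,10} instead, leafcontig[0] would become PETSC_FALSE and we would fall back to PetscSFCreatePackOpt()
    (and, failing that, per-entry gathers driven by the index array).
  */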
  for (i = 0; i < sf->roffset[sf->ndranks]; i++) { /* self */
    if (sf->rmine[i] != sf->leafstart[0] + i) {
      sf->leafcontig[0] = PETSC_FALSE;
      break;
    }
  }
  for (i = sf->roffset[sf->ndranks], j = 0; i < sf->roffset[sf->nranks]; i++, j++) { /* remote */
    if (sf->rmine[i] != sf->leafstart[1] + j) {
      sf->leafcontig[1] = PETSC_FALSE;
      break;
    }
  }

  /* If not, see if we can have per-rank optimizations by doing index analysis */
  if (!sf->leafcontig[0]) PetscCall(PetscSFCreatePackOpt(sf->ndranks, sf->roffset, sf->rmine, &sf->leafpackopt[0]));
  if (!sf->leafcontig[1]) PetscCall(PetscSFCreatePackOpt(sf->nranks - sf->ndranks, sf->roffset + sf->ndranks, sf->rmine, &sf->leafpackopt[1]));

  /* Are root indices for self and remote contiguous? */
  bas->rootbuflen[0] = bas->ioffset[bas->ndiranks];
  bas->rootbuflen[1] = bas->ioffset[bas->niranks] - bas->ioffset[bas->ndiranks];

  if (bas->rootbuflen[0]) bas->rootstart[0] = bas->irootloc[0];
  if (bas->rootbuflen[1]) bas->rootstart[1] = bas->irootloc[bas->ioffset[bas->ndiranks]];

  for (i = 0; i < bas->ioffset[bas->ndiranks]; i++) {
    if (bas->irootloc[i] != bas->rootstart[0] + i) {
      bas->rootcontig[0] = PETSC_FALSE;
      break;
    }
  }
  for (i = bas->ioffset[bas->ndiranks], j = 0; i < bas->ioffset[bas->niranks]; i++, j++) {
    if (bas->irootloc[i] != bas->rootstart[1] + j) {
      bas->rootcontig[1] = PETSC_FALSE;
      break;
    }
  }

  if (!bas->rootcontig[0]) PetscCall(PetscSFCreatePackOpt(bas->ndiranks, bas->ioffset, bas->irootloc, &bas->rootpackopt[0]));
  if (!bas->rootcontig[1]) PetscCall(PetscSFCreatePackOpt(bas->niranks - bas->ndiranks, bas->ioffset + bas->ndiranks, bas->irootloc, &bas->rootpackopt[1]));

  /* Check for dups in indices so that CUDA unpacking kernels can use cheaper regular instructions instead of atomics when they know there is no chance of a data race */
  if (PetscDefined(HAVE_DEVICE)) {
    PetscBool ismulti = (sf->multi == sf) ? PETSC_TRUE : PETSC_FALSE;
    if (!sf->leafcontig[0] && !ismulti) PetscCall(PetscCheckDupsInt(sf->leafbuflen[0], sf->rmine, &sf->leafdups[0]));
    if (!sf->leafcontig[1] && !ismulti) PetscCall(PetscCheckDupsInt(sf->leafbuflen[1], sf->rmine + sf->roffset[sf->ndranks], &sf->leafdups[1]));
    if (!bas->rootcontig[0] && !ismulti) PetscCall(PetscCheckDupsInt(bas->rootbuflen[0], bas->irootloc, &bas->rootdups[0]));
    if (!bas->rootcontig[1] && !ismulti) PetscCall(PetscCheckDupsInt(bas->rootbuflen[1], bas->irootloc + bas->ioffset[bas->ndiranks], &bas->rootdups[1]));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscSFResetPackFields(PetscSF sf)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       i;

  PetscFunctionBegin;
  for (i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
    PetscCall(PetscSFDestroyPackOpt(sf, PETSC_MEMTYPE_HOST, &sf->leafpackopt[i]));
    PetscCall(PetscSFDestroyPackOpt(sf, PETSC_MEMTYPE_HOST, &bas->rootpackopt[i]));
#if defined(PETSC_HAVE_DEVICE)
    PetscCall(PetscSFDestroyPackOpt(sf, PETSC_MEMTYPE_DEVICE, &sf->leafpackopt_d[i]));
    PetscCall(PetscSFDestroyPackOpt(sf, PETSC_MEMTYPE_DEVICE, &bas->rootpackopt_d[i]));
#endif
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}