xref: /petsc/src/vec/is/sf/tests/ex23.c (revision cac3c07dbc4e95423e22cb699bb64807a71d0bfe)
/* Help/usage text handed to PetscInitialize() in main(). */
static const char help[] = "Test PetscSF with integers and MPIU_2INT "
                           "\n\n";
2 
3 #include <petscvec.h>
4 #include <petscsf.h>
5 #include <petscdevice.h>
6 
7 int main(int argc, char *argv[])
8 {
9   PetscInt           n, n2, N = 12;
10   PetscInt          *indices;
11   IS                 ix, iy;
12   VecScatter         vscat;
13   Vec                x, y;
14   PetscInt           rstart, rend;
15   PetscInt          *xh, *yh, *xd, *yd;
16   PetscDeviceContext dctx;
17 
18   PetscFunctionBeginUser;
19   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
20   PetscCall(VecCreateFromOptions(PETSC_COMM_WORLD, NULL, 1, PETSC_DECIDE, N, &x));
21   PetscCall(VecDuplicate(x, &y));
22   PetscCall(VecGetLocalSize(x, &n));
23 
24   PetscCall(VecGetOwnershipRange(x, &rstart, &rend));
25   PetscCall(ISCreateStride(PETSC_COMM_WORLD, n, rstart, 1, &ix));
26   PetscCall(PetscMalloc1(n, &indices));
27   for (int i = rstart; i < rend; i++) indices[i - rstart] = i / 2;
28   PetscCall(ISCreateGeneral(PETSC_COMM_WORLD, n, indices, PETSC_OWN_POINTER, &iy));
29   // connect y[0] to x[0..1], y[1] to x[2..3], etc
30   PetscCall(VecScatterCreate(y, iy, x, ix, &vscat)); // y has roots, x has leaves
31 
32   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
33 
34   // double the allocation since we will use MPIU_2INT later
35   n2 = 2 * n;
36   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_HOST, n2, &xh));
37   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_HOST, n2, &yh));
38   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_DEVICE, n2, &xd));
39   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_DEVICE, n2, &yd));
40 
41   for (PetscInt i = 0; i < n; i++) {
42     xh[i] = xh[i + n] = i + rstart;
43     yh[i] = yh[i + n] = i + rstart;
44   }
45   PetscCall(PetscDeviceMemcpy(dctx, xd, xh, sizeof(PetscInt) * n2));
46   PetscCall(PetscDeviceMemcpy(dctx, yd, yh, sizeof(PetscInt) * n2));
47 
48   PetscCall(PetscSFReduceWithMemTypeBegin(vscat, MPIU_INT, PETSC_MEMTYPE_DEVICE, xd, PETSC_MEMTYPE_DEVICE, yd, MPI_SUM));
49   PetscCall(PetscSFReduceEnd(vscat, MPIU_INT, xd, yd, MPI_SUM));
50   PetscCall(PetscDeviceMemcpy(dctx, yh, yd, sizeof(PetscInt) * n));
51   PetscCall(PetscDeviceContextSynchronize(dctx)); // finish the async memcpy
52   PetscCall(PetscIntView(n, yh, PETSC_VIEWER_STDOUT_WORLD));
53 
54   PetscCall(PetscSFBcastWithMemTypeBegin(vscat, MPIU_2INT, PETSC_MEMTYPE_DEVICE, yd, PETSC_MEMTYPE_DEVICE, xd, MPI_MINLOC));
55   PetscCall(PetscSFBcastEnd(vscat, MPIU_2INT, yd, xd, MPI_MINLOC));
56   PetscCall(PetscDeviceMemcpy(dctx, xh, xd, sizeof(PetscInt) * n2));
57   PetscCall(PetscDeviceContextSynchronize(dctx)); // finish the async memcpy
58   PetscCall(PetscIntView(n2, xh, PETSC_VIEWER_STDOUT_WORLD));
59 
60   PetscCall(PetscDeviceFree(dctx, xh));
61   PetscCall(PetscDeviceFree(dctx, yh));
62   PetscCall(PetscDeviceFree(dctx, xd));
63   PetscCall(PetscDeviceFree(dctx, yd));
64   PetscCall(ISDestroy(&ix));
65   PetscCall(ISDestroy(&iy));
66   PetscCall(VecDestroy(&x));
67   PetscCall(VecDestroy(&y));
68   PetscCall(VecScatterDestroy(&vscat));
69   PetscCall(PetscFinalize());
70 }
71 
72 /*TEST
73   testset:
74     output_file: output/ex23.out
75     nsize: 3
76 
77     test:
78       suffix: 1
79       requires: cuda
80 
81     test:
82       suffix: 2
83       requires: hip
84 
85     test:
86       suffix: 3
87       requires: sycl
88 
89 TEST*/
90