from petsc4py import PETSc
import unittest
import numpy as np


class TestDMShell(unittest.TestCase):
    """Exercise the user-callback and user-object hooks of ``PETSc.DMShell``."""

    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm.destroy()
        self.dm = None
        PETSc.garbage_cleanup()

    def _create_indexed_vec(self):
        """Return a vector on ``COMM`` (local size 10) whose entries equal their global indices."""
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        start, end = vec.getOwnershipRange()
        vec.setArray(np.arange(start, end, dtype=PETSc.ScalarType))
        return vec

    def testSetGlobalVector(self):
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        self.dm.setGlobalVector(vec)
        # createGlobalVector must produce a vector with the template's layout.
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        def create_vec(dm):
            v = PETSc.Vec().create(comm=dm.comm)
            v.setSizes((10, None))
            v.setUp()
            return v

        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        # Rank-dependent size so that a layout mix-up between ranks is visible.
        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((1 + 10 * self.COMM.rank, None))
        vec.setUp()
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        def create_vec(dm):
            v = PETSc.Vec().create(comm=PETSc.COMM_SELF)
            v.setSizes((1 + 10 * dm.comm.rank, None))
            v.setUp()
            return v

        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10 * self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetMatrix(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (2, None)))
        mat.setUp()
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        def create_mat(dm):
            # Use the DM's communicator, like the other factory callbacks.
            mat = PETSc.Mat().create(comm=dm.comm)
            mat.setSizes(((10, None), (2, None)))
            mat.setUp()
            return mat

        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), create_mat(self.dm).getSizes())

    def testSetCreateFieldDecomposition(self):
        # Each callback returns a (names, ises, dms) triple with a different
        # subset of the optional entries (names / dms) filled in.
        def create_field_decomposition_only_is(dm):
            return None, [PETSc.IS().createStride(1) for _ in range(2)], None

        def create_field_decomposition_only_is_names(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            names = [f'f_{i}' for i in range(len(ises))]
            return names, ises, None

        def create_field_decomposition_only_is_dms(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            # No names: return None explicitly rather than the throwaway '_'.
            return None, ises, dms

        def create_field_decomposition_only_full(dm):
            names, ises, _ = create_field_decomposition_only_is_names(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            return names, ises, dms

        tests = [
            (create_field_decomposition_only_is, False, False),
            (create_field_decomposition_only_is_names, True, False),
            (create_field_decomposition_only_is_dms, False, True),
            (create_field_decomposition_only_full, True, True),
        ]
        for test_f, has_names, has_dms in tests:
            # subTest reports every failing callback instead of stopping at the first.
            with self.subTest(callback=test_f.__name__):
                self.dm.setCreateFieldDecomposition(test_f)
                names, ises, dms = self.dm.createFieldDecomposition()
                # Missing optional entries still come back as per-field lists.
                self.assertEqual(len(names), len(ises))
                self.assertEqual(len(dms), len(ises))
                if has_names:
                    expected = [f'f_{i}' for i in range(len(ises))]
                    self.assertEqual(names, expected)
                else:
                    for name in names:
                        self.assertIsNone(name)
                for subdm in dms:
                    if has_dms:
                        self.assertIs(type(subdm), PETSc.DM)
                    else:
                        self.assertIsNone(subdm)

    def testGlobalToLocal(self):
        def begin(dm, ivec, mode, ovec):
            if mode == PETSc.InsertMode.INSERT_VALUES:
                ovec[...] = ivec[...]
            elif mode == PETSc.InsertMode.ADD_VALUES:
                ovec[...] += ivec[...]

        def end(dm, ivec, mode, ovec):
            pass

        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setGlobalToLocal(begin, end)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), vec.getArray())
        # ADD_VALUES on top of the inserted copy doubles the values.
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        np.testing.assert_allclose(ovec.getArray(), 2 * vec.getArray())

    def testLocalToGlobal(self):
        def begin(dm, ivec, mode, ovec):
            if mode == PETSc.InsertMode.INSERT_VALUES:
                ovec[...] = ivec[...]
            elif mode == PETSc.InsertMode.ADD_VALUES:
                ovec[...] += ivec[...]

        def end(dm, ivec, mode, ovec):
            pass

        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = PETSc.Vec().create(comm=self.COMM)
        ovec.setSizes((10, None))
        ovec.setUp()
        self.dm.setLocalToGlobal(begin, end)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), vec.getArray())
        # ADD_VALUES on top of the inserted copy doubles the values.
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        np.testing.assert_allclose(ovec.getArray(), 2 * vec.getArray())

    def testLocalToLocal(self):
        def begin(dm, ivec, mode, ovec):
            if mode == PETSc.InsertMode.INSERT_VALUES:
                ovec[...] = ivec[...]
            elif mode == PETSc.InsertMode.ADD_VALUES:
                ovec[...] += ivec[...]

        def end(dm, ivec, mode, ovec):
            pass

        vec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        vec.setSizes((10, None))
        vec.setUp()
        vec[...] = self.dm.comm.rank + 1
        ovec = vec.duplicate()
        self.dm.setLocalToLocal(begin, end)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), vec.getArray())
        # ADD_VALUES on top of the inserted copy doubles the values.
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        np.testing.assert_allclose(ovec.getArray(), 2 * vec.getArray())

    def testGlobalToLocalVecScatter(self):
        vec = self._create_indexed_vec()
        # Scatter.toAll gathers the whole parallel vector onto every rank,
        # so ovec must end up holding 0, 1, ..., N-1.
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), np.arange(vec.getSize()))

    def testLocalToGlobalVecScatter(self):
        vec = self._create_indexed_vec()
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), np.arange(vec.getSize()))

    def testLocalToLocalVecScatter(self):
        vec = self._create_indexed_vec()
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), np.arange(vec.getSize()))

    def testCoarsenRefine(self):
        cdm = PETSc.DMShell().create(comm=self.COMM)
        # Release the coarse DM explicitly instead of relying on GC.
        self.addCleanup(cdm.destroy)

        def coarsen(dm, comm):
            return cdm

        def refine(dm, comm):
            return self.dm

        cdm.setRefine(refine)
        self.dm.setCoarsen(coarsen)
        coarsened = self.dm.coarsen()
        self.assertEqual(coarsened, cdm)
        refined = coarsened.refine()
        self.assertEqual(refined, self.dm)

    def testCreateInterpolation(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        vec = PETSc.Vec().create(comm=self.COMM)
        vec.setSizes((10, None))
        vec.setUp()

        def create_interp(dm, dmf):
            return mat, vec

        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        # The callback's objects must be handed back unchanged.
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()

        def create_inject(dm, dmf):
            return mat

        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)


if __name__ == '__main__':
    unittest.main()