xref: /honee/tests/smartsim_regression_framework.py (revision 00359db47665a79ecb0241f6ccbf886b649022df)
#!/usr/bin/env python3
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

import numpy as np
from smartredis import Client
from smartsim import Experiment
from smartsim.settings import RunSettings

# autopep8: off
# The vendored junit-xml directory must be on sys.path *before* junit_xml is imported;
# inserting it afterwards (as previously done) cannot influence which junit_xml is loaded.
# The off/on markers keep autopep8 from hoisting the import above this line.
# NOTE(review): parents[3] looks inherited from a deeper, older location of this script;
# confirm it resolves to the directory that actually contains tests/junit-xml.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8: on

# Silence library log records at WARNING level and below
logging.disable(logging.WARNING)

# Directory containing this script, and the reference outputs stored next to it
file_dir = Path(__file__).parent.absolute()
test_output_dir = file_dir / 'output'
26aa34f9e7SJames Wright
27aa34f9e7SJames Wright
def getOpenSocket():
    """Return a TCP port number that the OS reported as free at call time.

    Binds a temporary socket to port 0 so the OS picks an unused port, then
    releases it. The port is not reserved after return, so another process could
    take it before the caller binds to it.

    Returns:
        int: the port number.
    """
    # Context manager guarantees the socket is closed even if bind() raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]
34aa34f9e7SJames Wright
35aa34f9e7SJames Wright
class NoError(Exception):
    """Sentinel exception object reported by SmartSimTest.test() when a run passes.

    It is never raised; it only fills the exception slot of the result tuple so
    callers can detect success with ``isinstance(exc, NoError)``.
    """
38aa34f9e7SJames Wright
39aa34f9e7SJames Wright
def assert_np_all(test, truth):
    """Raise unless every element of ``test == truth`` is true, reporting both values.

    Uses an explicit check instead of a bare ``assert`` so the comparison still
    runs when Python is started with ``-O`` (which strips assert statements).

    Raises:
        Exception: if the arrays differ, or if comparing them raises (e.g. shapes
            that cannot be broadcast together).
    """
    message = f"Expected {truth}, but got {test}"
    try:
        matches = bool(np.all(test == truth))
    except Exception as e:
        raise Exception(message) from e
    if not matches:
        raise Exception(message)
46aa34f9e7SJames Wright
47aa34f9e7SJames Wright
def assert_equal(test, truth):
    """Raise unless ``test == truth``, reporting both values.

    Uses an explicit check instead of a bare ``assert`` so the comparison still
    runs when Python is started with ``-O`` (which strips assert statements).

    Raises:
        Exception: if the values differ, or if evaluating the comparison's truth
            value raises (e.g. a multi-element numpy array).
    """
    message = f"Expected {truth}, but got {test}"
    try:
        equal = bool(test == truth)
    except Exception as e:
        raise Exception(message) from e
    if not equal:
        raise Exception(message)
54aa34f9e7SJames Wright
55aa34f9e7SJames Wright
def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Args:
        database_array: values read back from the database.
        correct_array: reference values.
        ceed_resource: libCEED resource string; only used to name the dump file on failure.
        atol, rtol: tolerances with the same meaning as in ``np.isclose``/``np.allclose``.

    Raises:
        AssertionError: if any entry is not close, other than entries in column 4
            that match after a sign flip. NaNs always count as mismatches. On failure
            ``database_array`` is saved to ``./y0_database_values_<resource>.npy``
            in the current working directory.
    """
    database_array = np.asarray(database_array)
    correct_array = np.asarray(correct_array)

    # Column whose sign may legitimately flip (presumably the vorticity component
    # described above -- confirm against the data layout).
    vorticity_column = 4

    # np.isclose reports NaN as "not close"; a hand-written `abs(diff) > tol` mask does
    # not, which previously let NaNs in the database pass the sign-flip exemption.
    notclose = ~np.isclose(database_array, correct_array, atol=atol, rtol=rtol)
    if not np.any(notclose):
        return

    idx_notclose = np.nonzero(notclose)
    if notclose.ndim < 2 or not np.all(idx_notclose[1] == vorticity_column):
        # values other than vorticity are not close (or there is no column axis to exempt)
        test_fail = True
    else:
        database_vorticity = database_array[idx_notclose]
        correct_vorticity = correct_array[idx_notclose]
        test_fail = not np.allclose(-database_vorticity, correct_vorticity, atol=atol, rtol=rtol)

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")
84aa34f9e7SJames Wright
85aa34f9e7SJames Wright
class SmartSimTest(object):
    """Run ``navierstokes`` SGS-training output through a local SmartSim database and check it.

    Usage: call ``setup()`` once, then ``test()`` or ``test_junit()`` once per libCEED
    resource, then ``teardown()``.
    """

    def __init__(self, directory_path: Path):
        """
        Args:
            directory_path: Scratch directory for the experiment. ``setup()`` deletes it
                if it already exists and recreates it empty.
        """
        self.exp = None  # smartsim Experiment, created in setup()
        self.database = None  # database created and started in setup()
        self.directory_path: Path = directory_path
        self.original_path = None  # working directory before setup(); restored by teardown()

    def setup(self):
        """To create the test directory and start SmartRedis database

        Side effects: changes the process working directory to ``directory_path`` and
        sets the ``SR_LOG_FILE``/``SR_LOG_LEVEL`` environment variables.
        """
        self.original_path = Path(os.getcwd())

        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    @staticmethod
    def _require(condition, message: str):
        """Raise ``AssertionError(message)`` if ``condition`` is falsy.

        Used instead of bare ``assert`` so both the check and any call inside it
        (e.g. ``client.poll_tensor``) still execute under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    def _check_solution_dataset(self, client, ceed_resource):
        """Check the ``y.0.flow_solution`` dataset's metadata and solution against the reference file."""
        self._require(client.poll_dataset("y.0.flow_solution", 250, 5),
                      "Dataset 'y.0.flow_solution' not found in database")
        test_data_path = test_output_dir / "y0flow_solution_output.npy"
        self._require(test_data_path.is_file(), f"Reference file {test_data_path} not found")
        correct_value = np.load(test_data_path)
        dataset = client.get_dataset("y.0.flow_solution")
        field_names = dataset.get_metadata_field_names()
        self._require("step" in field_names, "Dataset metadata is missing 'step'")
        self._require("time" in field_names, "Dataset metadata is missing 'time'")
        assert_equal(dataset.get_meta_scalars("step")[0], 2)
        database_value = dataset.get_tensor("solution")
        verify_training_data(database_value, correct_value, ceed_resource)

    def _check_training_tensor(self, client, key, filename, ceed_resource):
        """Poll for tensor ``key`` and compare it with the reference array ``test_output_dir / filename``."""
        self._require(client.poll_tensor(key, 250, 5), f"Tensor '{key}' not found in database")
        test_data_path = test_output_dir / filename
        self._require(test_data_path.is_file(), f"Reference file {test_data_path} not found")
        correct_value = np.load(test_data_path)
        database_value = client.get_tensor(key)
        verify_training_data(database_value, correct_value, ceed_resource)

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run ``navierstokes`` with one libCEED resource and check what it wrote to the database.

        Args:
            ceed_resource: libCEED resource string passed via ``-ceed``.

        Returns:
            ``(passed, exception, command_line)``: ``exception`` is a ``NoError`` instance
            when the run passed, otherwise the exception that caused the failure;
            ``command_line`` is the executable followed by its arguments.
        """
        client = None
        arguments = []
        # Relative path, resolved against the working directory at launch time
        exe_path = "../../build/navierstokes"
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor_wall_clock_time', '-snes_monitor', '-ts_view_pre',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                # NOTE(review): '-meshtransform_type' is directly followed by another option,
                # so it receives no value; confirm this is intended.
                '-meshtransform_type', '-ts_monitor_smartsim_solution',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)

            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model and block until it finishes
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            self._require(client.poll_tensor("sizeInfo", 250, 5), "Tensor 'sizeInfo' not found in database")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            # (tensor key, expected first entry, number of polling attempts)
            scalar_checks = (
                ("check-run", 1, 5),
                ("tensor-ow", 1, 5),
                ("num_filter_widths", 2, 5),
                ("step", 2, 10),
            )
            for key, expected, num_tries in scalar_checks:
                self._require(client.poll_tensor(key, 250, num_tries), f"Tensor '{key}' not found in database")
                assert_equal(client.get_tensor(key)[0], expected)

            self._check_solution_dataset(client, ceed_resource)
            self._check_training_tensor(client, "y.0.0", "y00_output.npy", ceed_resource)
            self._check_training_tensor(client, "y.0.1", "y01_output.npy", ceed_resource)

            output = (True, NoError(), exe_path + ' ' + ' '.join(arguments))
        except Exception as e:
            output = (False, e, exe_path + ' ' + ' '.join(arguments))

        finally:
            # Flush the database exactly once, whether the checks passed or failed
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return output

    def test_junit(self, ceed_resource):
        """Run ``test()`` for one resource and wrap the result in a junit_xml ``TestCase``.

        Failures for resources containing ``'occa'`` are reported as skipped instead.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args
        if not passTest:
            if 'occa' in ceed_resource:
                test_case.add_skipped_info("OCCA mode not supported")
            else:
                test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database started by ``setup()`` and restore the original working directory.

        Safe to call even if ``setup()`` never ran; the directory is restored even if
        stopping the database raises.
        """
        try:
            if self.exp is not None and self.database is not None:
                self.exp.stop(self.database)
        finally:
            if self.original_path is not None:
                os.chdir(self.original_path)
223aa34f9e7SJames Wright
224aa34f9e7SJames Wright
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to run the SmartSim integration test with')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    # flush=True so progress text appears before the (potentially slow) work starts
    print("Setting up database...", end='', flush=True)
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")

    failed_resources = []
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='', flush=True)
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                failed_resources.append(ceed_resource)
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        # Always stop the database, even if the loop is interrupted
        print("Cleaning up database...", end='', flush=True)
        test_framework.teardown()
        print(" Done!")

    # Non-zero exit status so callers (e.g. CI) can detect failing backends
    sys.exit(1 if failed_resources else 0)
254