xref: /honee/tests/smartsim_regression_framework.py (revision 00359db47665a79ecb0241f6ccbf886b649022df)
1#!/usr/bin/env python3
2from junit_xml import TestCase
3from smartsim import Experiment
4from smartsim.settings import RunSettings
5from smartredis import Client
6import numpy as np
7from pathlib import Path
8import argparse
9import traceback
10import sys
11import time
12from typing import Tuple
13import os
14import shutil
15import logging
16import socket
17
# autopep8 off
# Prepend a junit-xml checkout to the module search path.
# NOTE(review): this statement runs after ``from junit_xml import TestCase`` at
# the top of the file, so it cannot influence that import. It only matters for
# imports of junit_xml performed later. Also confirm that ``parents[3]`` still
# resolves to the directory that contains ``tests/junit-xml`` for this file's
# current location.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
# autopep8 on

# Suppress all logging records at WARNING level and below (presumably to quiet
# SmartSim/SmartRedis output -- confirm).
logging.disable(logging.WARNING)

# Absolute path of the directory containing this script.
file_dir = Path(__file__).parent.absolute()
# Directory holding the reference ``.npy`` arrays the database contents are compared against.
test_output_dir = Path(__file__).parent.absolute() / 'output'
26
27
def getOpenSocket():
    """Return a TCP port number that was free at the time of the call.

    The OS chooses the port when an IPv4 socket is bound to port 0. The socket
    is closed before returning so the caller can bind that port itself. Another
    process may grab the port in the meantime, so this is best-effort only.

    Returns:
        int: The port number.
    """
    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]
34
35
class NoError(Exception):
    """Sentinel "exception" that signals a successful test.

    It fills the exception slot of a ``(passed, exception, ...)`` result tuple
    when nothing went wrong, so that slot always holds an ``Exception``
    instance instead of ``None``.
    """
38
39
def assert_np_all(test, truth):
    """Raise unless every element of ``test == truth`` is true.

    Unlike a bare ``assert``, this check still runs when Python is started
    with ``-O``.

    Args:
        test: Value produced by the code under test (typically a NumPy array).
        truth: Expected value, compared elementwise with NumPy broadcasting.

    Raises:
        AssertionError: If any element differs, or if the comparison itself
            fails (e.g. shapes that cannot be broadcast). The message shows
            both values.
    """
    message = f"Expected {truth}, but got {test}"
    try:
        matches = bool(np.all(test == truth))
    except Exception as e:
        raise AssertionError(message) from e
    if not matches:
        raise AssertionError(message)
46
47
def assert_equal(test, truth):
    """Raise unless ``test == truth`` is truthy.

    Unlike a bare ``assert``, this check still runs when Python is started
    with ``-O``.

    Args:
        test: Value produced by the code under test.
        truth: Expected value.

    Raises:
        AssertionError: If the values differ, or if the comparison result
            cannot be converted to a single bool (e.g. a multi-element NumPy
            array). The message shows both values.
    """
    message = f"Expected {truth}, but got {test}"
    try:
        matches = bool(test == truth)
    except Exception as e:
        raise AssertionError(message) from e
    if not matches:
        raise AssertionError(message)
54
55
def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8,
                         vorticity_column=4):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Args:
        database_array: Values read back from the database.
        correct_array: Reference values. Mismatches are located by (row, column),
            so both arrays are expected to be 2D (presumably one column per
            feature -- confirm).
        ceed_resource: libCEED resource string. It is used only to name the
            dump file written on failure.
        atol: Absolute tolerance, with the same meaning as in ``np.allclose``.
        rtol: Relative tolerance, with the same meaning as in ``np.allclose``.
        vorticity_column: Column whose out-of-tolerance entries are still
            accepted if they match the reference after a sign flip. Defaults
            to 4.

    Raises:
        AssertionError: If any entry outside ``vorticity_column`` is out of
            tolerance (NaN counts as out of tolerance), or if a mismatching
            ``vorticity_column`` entry is not a sign flip. Before raising, the
            database array is saved to ``./y0_database_values_<resource>.npy``.
    """
    # Elementwise form of np.allclose. A hand-rolled "|a - b| > tol" mask would
    # be False for NaN (every comparison with NaN is False), so NaN entries
    # would silently slip through the vorticity check below.
    not_close = ~np.isclose(database_array, correct_array, atol=atol, rtol=rtol)
    if not np.any(not_close):
        return

    idx_notclose = np.nonzero(not_close)
    if np.all(idx_notclose[1] == vorticity_column):
        # Only the vorticity column disagrees: accept it if those entries are just sign-flipped.
        database_vorticity = database_array[idx_notclose]
        correct_vorticity = correct_array[idx_notclose]
        test_fail = not np.allclose(-database_vorticity, correct_vorticity, atol=atol, rtol=rtol)
    else:
        # Values other than vorticity are not close.
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")
84
85
class SmartSimTest(object):
    """Drive the SmartSim integration test for the ``navierstokes`` example.

    Call :meth:`setup` once, then :meth:`test` (or :meth:`test_junit`) for each
    libCEED backend, then :meth:`teardown`.
    """

    def __init__(self, directory_path: Path):
        """Record the scratch directory; nothing is created until :meth:`setup`.

        Args:
            directory_path: Scratch directory for the test. :meth:`setup`
                deletes and recreates it.
        """
        self.exp: Experiment  # assigned in setup()
        self.database = None  # SmartSim database handle, assigned in setup()
        self.directory_path: Path = directory_path
        self.original_path: Path  # cwd at setup() time, restored by teardown()

    def setup(self):
        """To create the test directory and start SmartRedis database"""
        self.original_path = Path(os.getcwd())

        # Start from an empty directory so output from a previous run cannot leak in.
        if self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

        Used instead of bare ``assert`` so the checks still run under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _poll_tensor(client, key, num_tries=5):
        """Wait for tensor ``key`` to appear in the database, polling every 250 ms."""
        SmartSimTest._require(client.poll_tensor(key, 250, num_tries),
                              f"Tensor '{key}' did not appear in the database")

    @staticmethod
    def _load_reference(filename):
        """Load the reference array ``test_output_dir / filename``."""
        path = test_output_dir / filename
        SmartSimTest._require(path.is_file(), f"Reference data file {path.as_posix()} not found")
        return np.load(path)

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run ``navierstokes`` on ``ceed_resource`` and check the data it wrote to the database.

        Returns:
            ``(passed, exception, command)``. On success ``exception`` is a
            ``NoError`` instance; otherwise it is the exception that failed the
            test. ``command`` is the executable path and its arguments joined
            by spaces.
        """
        client = None
        arguments = []
        # Relative to the scratch directory, which is the cwd after setup().
        exe_path = "../../build/navierstokes"
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor_wall_clock_time', '-snes_monitor', '-ts_view_pre',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-meshtransform_type', '-ts_monitor_smartsim_solution',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)
            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Run the solver to completion (block=True) before inspecting the database.
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            self._poll_tensor(client, "sizeInfo")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            for key, expected in (("check-run", 1), ("tensor-ow", 1), ("num_filter_widths", 2)):
                self._poll_tensor(client, key)
                assert_equal(client.get_tensor(key)[0], expected)

            self._poll_tensor(client, "step", num_tries=10)
            assert_equal(client.get_tensor("step")[0], 2)

            self._require(client.poll_dataset("y.0.flow_solution", 250, 5),
                          "Dataset 'y.0.flow_solution' did not appear in the database")
            correct_value = self._load_reference("y0flow_solution_output.npy")
            dataset = client.get_dataset("y.0.flow_solution")
            metadata_fields = dataset.get_metadata_field_names()
            for field in ("step", "time"):
                self._require(field in metadata_fields,
                              f"Metadata field '{field}' missing from dataset 'y.0.flow_solution'")
            assert_equal(dataset.get_meta_scalars("step")[0], 2)
            verify_training_data(dataset.get_tensor("solution"), correct_value, ceed_resource)

            for key, filename in (("y.0.0", "y00_output.npy"), ("y.0.1", "y01_output.npy")):
                self._poll_tensor(client, key)
                correct_value = self._load_reference(filename)
                verify_training_data(client.get_tensor(key), correct_value, ceed_resource)

            passed, error = True, NoError()
        except Exception as e:
            passed, error = False, e
        finally:
            # Single flush point for both outcomes, so the next backend starts from an empty database.
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return passed, error, exe_path + ' ' + ' '.join(arguments)

    def test_junit(self, ceed_resource):
        """Run :meth:`test` and wrap the outcome in a ``junit_xml.TestCase``.

        A failure on a backend whose name contains ``occa`` is reported as
        skipped rather than failed.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args  # command line that was run, kept for reporting
        if not passTest and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passTest:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the directory :meth:`setup` started from."""
        try:
            self.exp.stop(self.database)
        finally:
            # Restore the cwd even if stopping the database fails.
            os.chdir(self.original_path)
223
224
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to use with the SmartSim tests')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    # flush=True so the progress text appears before the (slow) work starts.
    print("Setting up database...", end='', flush=True)
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")

    num_failed = 0
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='', flush=True)
            passed, exception, _ = test_framework.test(ceed_resource)

            if passed:
                print("Passed!")
            else:
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
                # OCCA backends are reported as skipped (not failed) in the JUnit path, so they
                # do not count toward the exit status here either.
                if 'occa' not in ceed_resource:
                    num_failed += 1
    finally:
        # Always stop the database, even if a test raised past its own error handling.
        print("Cleaning up database...", end='', flush=True)
        test_framework.teardown()
        print(" Done!")

    # Non-zero exit status lets callers (CI, make) detect failing backends.
    sys.exit(1 if num_failed else 0)
254