#!/usr/bin/env python3
"""Regression test for the SmartSim/SmartRedis integration of the ``navierstokes`` example.

The script starts a local SmartSim database, runs the ``navierstokes`` executable
against it once per requested libCEED backend, and compares the tensors and
datasets the run leaves in the database with reference ``.npy`` files stored in
the ``output`` directory next to this file.
"""
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

from smartsim import Experiment
from smartsim.settings import RunSettings
from smartredis import Client
import numpy as np

# autopep8 off
# The vendored junit-xml checkout has to be on sys.path *before* junit_xml is
# imported, otherwise the insertion has no effect on which package is loaded.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8 on

logging.disable(logging.WARNING)

file_dir = Path(__file__).parent.absolute()
test_output_dir = Path(__file__).parent.absolute() / 'output'


def getOpenSocket():
    """Return a TCP port number that was unused at the time of the call.

    The socket is bound to port 0 so the operating system picks a free port.
    The socket is closed before returning, so the port is only *probably* still
    free when the caller uses it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


class NoError(Exception):
    """Placeholder exception returned by ``SmartSimTest.test`` when a test passed."""
    pass


def assert_np_all(test, truth):
    """Check that every element of ``test`` equals ``truth``, with better error reporting.

    Raises:
        Exception: if any element differs or the comparison itself fails. The
            message contains both the expected and the actual value.
    """
    # Explicit check instead of `assert` so the test is not disabled by `python -O`.
    try:
        all_equal = bool(np.all(test == truth))
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not all_equal:
        raise Exception(f"Expected {truth}, but got {test}")


def assert_equal(test, truth):
    """Check that ``test == truth``, with better error reporting.

    Raises:
        Exception: if the values differ or the comparison itself fails. The
            message contains both the expected and the actual value.
    """
    # Explicit check instead of `assert` so the test is not disabled by `python -O`.
    try:
        is_equal = bool(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not is_equal:
        raise Exception(f"Expected {truth}, but got {test}")


def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Args:
        database_array: array read back from the database (indexed as 2-D: the
            second index selects the component).
        correct_array: reference array to compare against.
        ceed_resource: libCEED resource string; only used to name the dump file.
        atol: absolute tolerance, as in ``np.allclose``.
        rtol: relative tolerance, as in ``np.allclose``.

    Raises:
        AssertionError: if the arrays differ by more than the tolerances in a way
            that is not explained by a sign flip of column 4. The database array
            is saved to ``./y0_database_values_<resource>.npy`` beforehand.
    """
    if np.allclose(database_array, correct_array, atol=atol, rtol=rtol):
        return

    # np.isclose applies the same tolerance formula as np.allclose and, unlike a
    # plain `abs(diff) > tolerance` comparison, also reports NaN entries as
    # mismatches, so NaNs in the database cannot slip through unnoticed.
    idx_notclose = np.where(~np.isclose(database_array, correct_array, atol=atol, rtol=rtol))
    if not np.all(idx_notclose[1] == 4):
        # values other than vorticity are not close
        test_fail = True
    else:
        database_vorticity = database_array[idx_notclose]
        correct_vorticity = correct_array[idx_notclose]
        # The mismatch is acceptable only if it is a pure sign flip.
        test_fail = not np.allclose(-database_vorticity, correct_vorticity, atol=atol, rtol=rtol)

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")


class SmartSimTest:
    """Run the SmartSim regression test inside a scratch directory.

    Typical use: ``setup()`` once, ``test()`` / ``test_junit()`` once per libCEED
    backend, then ``teardown()``.
    """

    def __init__(self, directory_path: Path):
        """Remember the scratch directory; nothing is created until ``setup()``."""
        self.exp: Experiment
        self.database = None
        self.directory_path: Path = directory_path
        self.original_path: Path

    def setup(self):
        """To create the test directory and start SmartRedis database"""
        self.original_path = Path(os.getcwd())

        # Always start from an empty scratch directory.
        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run ``navierstokes`` with ``ceed_resource`` and check what it wrote to the database.

        Args:
            ceed_resource: libCEED resource string passed to ``-ceed``.

        Returns:
            ``(passed, exception, command)``. ``exception`` is a ``NoError``
            instance when the test passed, otherwise the exception that made it
            fail. ``command`` is the executable path followed by its arguments.
        """
        client = None
        exe_path = "../../build/navierstokes"
        arguments = [
            '-ceed', ceed_resource,
            '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
            '-ts_max_steps', '2',
            '-diff_filter_grid_based_width',
            '-ts_monitor_wall_clock_time', '-snes_monitor', '-ts_view_pre',
            '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
            '-degree', '1',
            '-sgs_train_enable',
            '-sgs_train_write_data_interval', '2',
            '-sgs_train_filter_width_scales', '1.2,3.1',
            '-bc_symmetry_z',
            '-dm_plex_shape', 'zbox',
            '-dm_plex_box_bd', 'none,none,periodic',
            '-dm_plex_box_faces', '4,6,1',
            '-mesh_transform', '-ts_monitor_smartsim_solution',
        ]
        command = exe_path + ' ' + ' '.join(arguments)
        try:
            run_settings = RunSettings(exe_path, exe_args=arguments)

            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            # poll_* arguments: (key, poll frequency in ms, number of tries),
            # per the SmartRedis Client API.
            assert client.poll_tensor("sizeInfo", 250, 5)
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            assert client.poll_tensor("check-run", 250, 5)
            assert_equal(client.get_tensor("check-run")[0], 1)

            assert client.poll_tensor("tensor-ow", 250, 5)
            assert_equal(client.get_tensor("tensor-ow")[0], 1)

            assert client.poll_tensor("num_filter_widths", 250, 5)
            assert_equal(client.get_tensor("num_filter_widths")[0], 2)

            assert client.poll_tensor("step", 250, 10)
            assert_equal(client.get_tensor("step")[0], 2)

            # Flow solution dataset, including its metadata
            assert client.poll_dataset("y.0.flow_solution", 250, 5)
            test_data_path = test_output_dir / "y0flow_solution_output.npy"
            assert test_data_path.is_file()
            correct_value = np.load(test_data_path)
            dataset = client.get_dataset("y.0.flow_solution")
            assert "step" in dataset.get_metadata_field_names()
            assert "time" in dataset.get_metadata_field_names()
            assert_equal(dataset.get_meta_scalars("step")[0], 2)
            database_value = dataset.get_tensor("solution")
            verify_training_data(database_value, correct_value, ceed_resource)

            # Training data tensors, one per filter width scale
            assert client.poll_tensor("y.0.0", 250, 5)
            test_data_path = test_output_dir / "y00_output.npy"
            assert test_data_path.is_file()
            correct_value = np.load(test_data_path)
            database_value = client.get_tensor("y.0.0")
            verify_training_data(database_value, correct_value, ceed_resource)

            assert client.poll_tensor("y.0.1", 250, 5)
            test_data_path = test_output_dir / "y01_output.npy"
            assert test_data_path.is_file()
            correct_value = np.load(test_data_path)
            database_value = client.get_tensor("y.0.1")
            verify_training_data(database_value, correct_value, ceed_resource)

            output = (True, NoError(), command)
        except Exception as e:
            output = (False, e, command)

        finally:
            # Empty the database so the next backend starts from a clean state.
            # This runs on success and on failure, so no separate flush is
            # needed at the end of the try body.
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return output

    def test_junit(self, ceed_resource):
        """Run ``test()`` and wrap the outcome in a ``junit_xml.TestCase``.

        A failing run on a resource whose name contains ``occa`` is recorded as
        skipped; any other failing run is recorded as a failure carrying the
        formatted traceback.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args
        if not passTest and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passTest:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the directory that was current before ``setup()``."""
        try:
            self.exp.stop(self.database)
        finally:
            # Restore the working directory even if stopping the database fails.
            os.chdir(self.original_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend to use with convergence tests')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='')
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='')
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        # Always stop the database, even if a test run raised unexpectedly.
        print("Cleaning up database...", end='')
        test_framework.teardown()
        print(" Done!")