#!/usr/bin/env python3
from smartsim import Experiment
from smartsim.settings import RunSettings
from smartredis import Client
import numpy as np
from pathlib import Path
import argparse
import traceback
import sys
import time
from typing import Tuple
import os
import shutil
import logging
import socket

# autopep8 off
# Put <Path(__file__).parents[3]>/tests/junit-xml at the front of the module
# search path, then import junit_xml. The insert must happen *before* the
# import for it to have any effect on which junit_xml is loaded; the autopep8
# guard keeps the formatter from hoisting the import back above it.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8 on

# Disable logging calls of severity WARNING and below (logging.disable semantics).
logging.disable(logging.WARNING)

# Directory containing this script; reference arrays live in its "output" subdirectory.
file_dir = Path(__file__).parent.absolute()
test_output_dir = file_dir / 'output'


def getOpenSocket():
    """Return a TCP port number that the OS reported as free.

    Binding to port 0 makes the OS pick an unused port. The socket is closed
    before returning, so the port is only free at the moment of the call; another
    process could still claim it before the caller binds to it.
    """
    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


class NoError(Exception):
    """Placeholder "exception" returned by successful test runs."""
    pass


def assert_np_all(test, truth):
    """Assert with better error reporting.

    Raises ``Exception(f"Expected {truth}, but got {test}")`` unless every element
    of ``test == truth`` is true, or if that comparison itself fails. Implemented
    without ``assert`` so the check still runs under ``python -O``.
    """
    try:
        equal = np.all(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not equal:
        raise Exception(f"Expected {truth}, but got {test}")
def _require(condition, message: str) -> None:
    """Raise ``AssertionError(message)`` if ``condition`` is falsy.

    Used instead of a bare ``assert`` statement so that the check -- and any side
    effect in the condition expression, such as polling the database -- still
    runs when Python is started with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def assert_equal(test, truth):
    """Assert with better error reporting.

    Raises:
        Exception: ``f"Expected {truth}, but got {test}"`` when ``test == truth`` is
            false, or when it cannot be reduced to a single boolean.
    """
    try:
        equal = bool(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not equal:
        raise Exception(f"Expected {truth}, but got {test}")


def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Mismatching entries are tolerated only if all of them sit at index 4 along
    axis 1 (the column this check treats as that vorticity component) and they
    match the reference after a sign flip. NaN entries are never considered close.

    Args:
        database_array: Array read back from the database; must be at least 2-D
            because axis 1 is inspected. Expected to have the same shape as
            ``correct_array``.
        correct_array: Reference array.
        ceed_resource: libCEED resource string; only used to name the dump file.
        atol, rtol: Absolute/relative tolerances, applied as in ``np.isclose``.

    Raises:
        AssertionError: if the arrays do not match. ``database_array`` is first
            saved to ``./y0_database_values_<resource>.npy`` for inspection.
    """
    if np.allclose(database_array, correct_array, atol=atol, rtol=rtol):
        return

    # np.isclose uses the same |a - b| <= atol + rtol * |b| criterion as np.allclose,
    # but unlike a hand-written "difference > tolerance" comparison it flags NaN
    # entries as not close, so a NaN can no longer slip through as a pass.
    idx_notclose = np.nonzero(~np.isclose(database_array, correct_array, atol=atol, rtol=rtol))
    if np.all(idx_notclose[1] == 4):
        # Only the vorticity component differs: accept it if it is merely sign-flipped
        test_fail = not np.allclose(-database_array[idx_notclose], correct_array[idx_notclose],
                                    atol=atol, rtol=rtol)
    else:
        # values other than vorticity are not close
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")


class SmartSimTest(object):
    """Run the navierstokes example against a local SmartSim database and check its output.

    Intended call sequence: ``setup()`` once, ``test()`` or ``test_junit()`` per
    libCEED backend, then ``teardown()``.
    """

    def __init__(self, directory_path: Path):
        """Record the scratch directory; nothing is created until ``setup()``."""
        self.exp: Experiment  # assigned in setup()
        self.database = None  # result of Experiment.create_database, assigned in setup()
        self.directory_path: Path = directory_path  # deleted and recreated by setup()
        self.original_path: Path  # cwd before setup(); restored by teardown()

    def setup(self):
        """To create the test directory and start SmartRedis database

        Any existing ``directory_path`` directory is deleted and recreated, the
        process changes into it, and a local database is started on a free port
        on the ``lo`` interface.
        """
        self.original_path = Path(os.getcwd())

        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        PORT = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=PORT, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    @staticmethod
    def _poll_tensor(client, name, num_tries=5):
        """Poll for tensor ``name`` (250 ms interval); raise AssertionError if it never appears."""
        _require(client.poll_tensor(name, 250, num_tries),
                 f"Tensor '{name}' not found in database")

    @staticmethod
    def _check_training_tensor(client, name, reference_file, ceed_resource):
        """Compare database tensor ``name`` with the array in ``test_output_dir / reference_file``."""
        SmartSimTest._poll_tensor(client, name)
        test_data_path = test_output_dir / reference_file
        _require(test_data_path.is_file(), f"Reference file {test_data_path} not found")
        correct_value = np.load(test_data_path)
        database_value = client.get_tensor(name)
        verify_training_data(database_value, correct_value, ceed_resource)

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run navierstokes with ``ceed_resource`` and check the data it put in the database.

        Any exception raised by the run or by the checks is caught and reported
        through the returned tuple instead of being propagated.

        Returns:
            ``(passed, exception, command)``: ``passed`` is True when every check
            succeeded; ``exception`` is the exception that ended the run, or a
            ``NoError`` instance on success; ``command`` is the executable path and
            its arguments joined with spaces.
        """
        client = None
        arguments = []
        # Relative path; presumably resolved against the cwd, i.e. directory_path after setup()
        exe_path = "../../build/navierstokes"
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor_wall_clock_time', '-snes_monitor', '-ts_view_pre',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-meshtransform_type', '-ts_monitor_smartsim_solution',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)

            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model; block=True waits for it to finish
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            self._poll_tensor(client, "sizeInfo")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            self._poll_tensor(client, "check-run")
            assert_equal(client.get_tensor("check-run")[0], 1)

            self._poll_tensor(client, "tensor-ow")
            assert_equal(client.get_tensor("tensor-ow")[0], 1)

            self._poll_tensor(client, "num_filter_widths")
            assert_equal(client.get_tensor("num_filter_widths")[0], 2)

            self._poll_tensor(client, "step", num_tries=10)
            assert_equal(client.get_tensor("step")[0], 2)

            _require(client.poll_dataset("y.0.flow_solution", 250, 5),
                     "Dataset 'y.0.flow_solution' not found in database")
            test_data_path = test_output_dir / "y0flow_solution_output.npy"
            _require(test_data_path.is_file(), f"Reference file {test_data_path} not found")
            correct_value = np.load(test_data_path)
            dataset = client.get_dataset("y.0.flow_solution")
            metadata_fields = dataset.get_metadata_field_names()
            _require("step" in metadata_fields, "Dataset 'y.0.flow_solution' has no 'step' metadata")
            _require("time" in metadata_fields, "Dataset 'y.0.flow_solution' has no 'time' metadata")
            assert_equal(dataset.get_meta_scalars("step")[0], 2)
            database_value = dataset.get_tensor("solution")
            verify_training_data(database_value, correct_value, ceed_resource)

            self._check_training_tensor(client, "y.0.0", "y00_output.npy", ceed_resource)
            self._check_training_tensor(client, "y.0.1", "y01_output.npy", ceed_resource)
        except Exception as e:
            output = (False, e, exe_path + ' ' + ' '.join(arguments))
        else:
            output = (True, NoError(), exe_path + ' ' + ' '.join(arguments))
        finally:
            # Flush the database (pass or fail) so the next backend's run starts clean.
            # Address comes from the SSDB environment variable -- assumed to be set
            # once the SmartSim database is running (TODO confirm).
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return output

    def test_junit(self, ceed_resource):
        """Run ``test(ceed_resource)`` and wrap the outcome in a junit_xml ``TestCase``.

        Failures on resources containing ``'occa'`` are recorded as skipped rather
        than failed.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args
        if not passTest and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passTest:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database started by ``setup()`` and restore the original cwd."""
        self.exp.stop(self.database)
        os.chdir(self.original_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to run the SmartSim integration test with')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='')
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")
    # try/finally so the database is stopped even if the loop is interrupted (e.g. Ctrl-C)
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='')
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        print("Cleaning up database...", end='')
        test_framework.teardown()
        print(" Done!")