from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
import itertools
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Callable, Optional

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The `type` passed to `add_argument` must subclass both `str` and `Enum`. It is kept on the action instead of being
    forwarded to `argparse.Action`, so argparse does not perform its own (case-sensitive) conversion.
    """

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # argparse omits `default` entirely when the caller does not supply one, so `None` must be left untouched
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        elif default is not None:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type` to the base class
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # human-readable name from `TESTARGS(name="...")`, empty if not given
    name: str
    # constraints from `TESTARGS(only="a,b")`, split on commas
    only: list[str] = field(default_factory=list)
    # command-line arguments following the TESTARGS marker
    args: list[str] = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback run after each executed (not pre-skipped) test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Uses a PATH lookup rather than parsing a shell's "not found" message, which varies between shells and platforms.

    Returns:
        bool: True if `cgnsdiff` is found
    """
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidates; an empty tuple yields False
    return base.startswith(tuple(prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Accepted forms are `TESTARGS arg1 arg2 ...` and `TESTARGS(key="value",...) arg1 arg2 ...`, where the recognized
    keys are `name` and `only` (a comma-separated list of constraints).

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace while keeping double-quoted substrings intact
    args: list[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    marker: str = 'TESTARGS('
    raw_test_args: str = args[0][args[0].index(marker) + len(marker):args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    # values are taken verbatim from the quotes, so they may themselves contain '=' or ','
    test_args: dict[str, str] = {key.strip(): value
                                 for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: list[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    A test case is any line that starts with the language's comment marker immediately followed by `TESTARGS`.

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_markers: dict[str, str] = {
        '.c': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str: Optional[str] = comment_markers.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # removeprefix (not strip) so trailing characters of the arguments that happen to appear in the
    # comment marker (e.g. a trailing '/' in a path) are preserved
    specs: list[TestSpec] = [parse_test_line(line.removeprefix(comment_str))
                             for line in source_file.read_text().splitlines()
                             if line.startswith(f'{comment_str}TESTARGS')]
    return specs or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Blank lines are ignored. The first remaining line of each file is the column header and must match exactly.
    Every other row is compared value by value: when the expected value's magnitude is below `zero_tol` the found
    value must be too; otherwise the two must agree within `rel_tol` (`math.isclose`).

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Raises:
        ValueError: If a data entry cannot be parsed as a float

    Returns:
        str: Diff output between result and expected CSVs, or an empty string if they match
    """
    test_lines: list[str] = [line for line in test_csv.read_text().splitlines() if line.strip()]
    true_lines: list[str] = [line for line in true_csv.read_text().splitlines() if line.strip()]

    test_header: str = test_lines[0] if test_lines else ''
    true_header: str = true_lines[0] if true_lines else ''
    if test_header != true_header:
        return ''.join(difflib.unified_diff([f'{true_header}\n'], [f'{test_header}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = []
    test_rows: list[str] = test_lines[1:]
    true_rows: list[str] = true_lines[1:]
    # zip() below silently truncates, so a missing or extra row must be reported explicitly
    if len(test_rows) != len(true_rows):
        diff_lines.append(f'row count mismatch, expected: {len(true_rows)}, got: {len(test_rows)}')

    column_names: list[str] = true_header.strip().split(',')
    for test_line, true_line in zip(test_rows, true_rows):
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # the first field of the row (not the first character of the line)
        step: str = true_line.split(',')[0].strip()
        if len(test_vals) != len(true_vals):
            diff_lines.append(f'step: {step}, column count mismatch, expected: {len(true_vals)}, got: {len(test_vals)}')
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            if abs(true_val) < zero_tol:
                # written as `not <` so a NaN result still counts as a failure
                fail = not abs(test_val) < zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output (stderr followed by stdout of `cgnsdiff`) between result and expected CGNS files, or a
            message describing why `cgnsdiff` could not be launched
    """
    my_env: dict = os.environ.copy()

    run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    try:
        # pass an argv list (no shell) so paths containing spaces or shell metacharacters are not split
        proc = subprocess.run(run_args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)
    except OSError as err:
        return f'failed to run {run_args[0]}: {err}'

    return proc.stderr.decode('utf-8', errors='replace') + proc.stdout.decode('utf-8', errors='replace')


def _build_run_args(run_path: Path, spec: TestSpec, ceed_resource: str, nproc: int, source_path: Path) -> list:
    """Assemble the command line for one test case, substituting the `{ceed_resource}` and `{nproc}` placeholders

    When the arguments contain no `{nproc}` placeholder, non-Python tests with `nproc > 1` are wrapped in `mpiexec`.
    """
    run_args: list = [ceed_resource if arg == '{ceed_resource}' else arg for arg in [run_path, *spec.args]]
    if '{nproc}' in run_args:
        run_args = [f'{nproc}' if arg == '{nproc}' else arg for arg in run_args]
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]
    return run_args


def _reference_output_paths(test: str, spec: TestSpec, suite_spec: SuiteSpec, prefix: str) -> list[Path]:
    """Map each test argument containing `prefix` (e.g. `ascii:` or `cgns:`) to its expected output file path

    The output file name is the text following the last occurrence of `prefix` in the argument.
    """
    return [suite_spec.get_output_path(test, arg.split(prefix)[-1]) for arg in spec.args if prefix in arg]


def _check_output_files(test_case: TestCase, ref_files: list[Path], label: str,
                        differ: Callable[[Path, Path], str]) -> None:
    """Compare output files written to the working directory against their expected reference files

    Missing files and non-empty diffs are recorded as failures under `label`; matching outputs are deleted.
    """
    for ref_file in ref_files:
        found_file: Path = Path.cwd() / ref_file.name
        if not ref_file.is_file():
            test_case.add_failure_info(label, output=f'{ref_file} not found')
        elif not found_file.is_file():
            # the test did not produce its output; report it instead of crashing the whole run
            test_case.add_failure_info(label, output=f'{found_file} not found')
        else:
            try:
                diff: str = differ(found_file, ref_file)
            except ValueError as err:
                # e.g. a non-numeric CSV entry or undecodable file; attribute it to this test case only
                diff = f'unable to compare {found_file} with {ref_file}: {err}'
            if diff:
                test_case.add_failure_info(label, output=diff)
            else:
                found_file.unlink()


def _evaluate_test_case(test_case: TestCase, returncode: int, test: str, spec: TestSpec, ceed_resource: str,
                        suite_spec: SuiteSpec, ref_stdout: Path, ref_csvs: list[Path], ref_cgns: list[Path]) -> None:
    """Classify an executed test case, recording skips, required failures, failures, and errors on `test_case`"""
    # check allowed failures
    if test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, ceed_resource, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)
            return

    # check required failures
    required_message, did_fail = suite_spec.check_required_failure(test, spec, ceed_resource, test_case.stderr)
    if required_message and did_fail:
        test_case.status = f'fails with required: {required_message}'
        return
    if required_message:
        test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if test_case.stderr:
        test_case.add_failure_info('stderr', test_case.stderr)
    if returncode != 0:
        test_case.add_error_info(f'returncode = {returncode}')
    if ref_stdout.is_file():
        diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                         test_case.stdout.splitlines(keepends=True),
                                         fromfile=str(ref_stdout),
                                         tofile='New'))
        if diff:
            test_case.add_failure_info('stdout', output=''.join(diff))
    elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
        test_case.add_failure_info('stdout', output=test_case.stdout)
    # expected CSV output
    _check_output_files(test_case, ref_csvs, 'csv', diff_csv)
    # expected CGNS output
    _check_output_files(test_case, ref_cgns, 'cgns', diff_cgns)


def _print_tap_result(index: int, spec: TestSpec, test_case: TestCase) -> None:
    """Print the incremental TAP result lines for one test case"""
    print(f'# Test: {spec.name}')
    if spec.only:
        print('# Only: {}'.format(','.join(spec.only)))
    print(f'# $ {test_case.args}')
    if test_case.is_skipped():
        print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
    elif test_case.is_failure() or test_case.is_error():
        print(f'not ok {index}')
        if test_case.is_error():
            print(f' ERROR: {test_case.errors[0]["message"]}')
        if test_case.is_failure():
            for i, failure in enumerate(test_case.failures):
                print(f' FAILURE {i}: {failure["message"]}')
                # some failures carry only a message; avoid printing the literal "None"
                print(f' Output: \n{failure["output"] or ""}')
    else:
        print(f'ok {index} - PASS')
    sys.stdout.flush()


def _print_junit_result(test: str, spec: TestSpec, test_case: TestCase) -> None:
    """Print error and failure details for one test case in JUnit mode; passing and skipped cases print nothing"""
    if test_case.is_error() or test_case.is_failure():
        print(f'Test: {test} {spec.name}')
        print(f' $ {test_case.args}')
        if test_case.is_error():
            print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
            print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
        if test_case.is_failure():
            for failure in test_case.failures:
                print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
                print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
    sys.stdout.flush()


def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Each (test case, backend) pair is run through the shell with `CEED_ERROR_HANDLER=exit`, then classified and
    reported incrementally (TAP lines in TAP mode, failure details only in JUnit mode).

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    # TAP numbering is 1-based; specs form the outer loop and backends the inner loop
    for index, (spec, ceed_resource) in enumerate(itertools.product(test_specs, ceed_backends), start=1):
        run_args: list = _build_run_args(suite_spec.get_run_path(test), spec, ceed_resource, nproc, source_path)
        case_name: str = f'{test}, "{spec.name}", n{nproc}, {ceed_resource}'

        skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
        if skip_reason:
            test_case: TestCase = TestCase(case_name,
                                           classname=source_path.parent,
                                           elapsed_sec=0,
                                           timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                           stdout='',
                                           stderr='',
                                           allow_multiple_subelements=True)
            test_case.add_skipped_info(skip_reason)
        else:
            # run test; the shell interprets quoted TESTARGS arguments
            start: float = time.time()
            proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                                  shell=True,
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  env=my_env)

            test_case = TestCase(case_name,
                                 classname=source_path.parent,
                                 elapsed_sec=time.time() - start,
                                 timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                 # tolerate non-UTF-8 bytes from the test instead of crashing the runner
                                 stdout=proc.stdout.decode('utf-8', errors='replace'),
                                 stderr=proc.stderr.decode('utf-8', errors='replace'),
                                 allow_multiple_subelements=True)
            ref_csvs: list[Path] = _reference_output_paths(test, spec, suite_spec, 'ascii:')
            ref_cgns: list[Path] = _reference_output_paths(test, spec, suite_spec, 'cgns:')
            ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
            suite_spec.post_test_hook(test, spec)

            _evaluate_test_case(test_case, proc.returncode, test, spec, ceed_resource, suite_spec,
                                ref_stdout, ref_csvs, ref_cgns)

        # store result
        test_case.args = ' '.join(str(arg) for arg in run_args)
        test_cases.append(test_case)

        if mode is RunMode.TAP:
            # print incremental output if TAP mode
            _print_tap_result(index, spec, test_case)
        else:
            # print error or failure information if JUNIT mode
            _print_junit_result(test, spec, test_case)

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    The parent directory is created if needed, and the file is written as UTF-8 (the XML carries no encoding
    declaration, so readers assume UTF-8).

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]), encoding='utf-8')


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed or errored
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)