from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
import itertools
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Callable, Optional

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    Accepts a single string or an iterable of strings (e.g. when used with `nargs`); each value is lower-cased and
    converted to the provided `type`, which must be both a `str` and an `Enum` subclass.
    """

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # argparse only passes `default` when one is configured, so `None` is accepted and left unconverted
        if default is not None:
            default = self._to_enum(default)
        # `type` is deliberately not forwarded to argparse, preventing automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def _to_enum(self, values):
        """Lower-case and convert a string, or each string of an iterable, to `self.enum_type`"""
        if isinstance(values, str):
            return self.enum_type(values.lower())
        return [self.enum_type(v.lower()) for v in values]

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, self._to_enum(values))


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    # constraints parsed from the `only="..."` TESTARGS field, split on commas
    only: list[str] = field(default_factory=list)
    # command line arguments following the TESTARGS(...) token
    args: list[str] = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback run after each executed (not pre-skipped) test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # Search PATH directly instead of parsing a shell's (locale/shell dependent) "not found" message
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidate prefixes; an empty tuple yields False
    return base.startswith(tuple(prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Example: `TESTARGS(name="my test",only="cpu,int32") {ceed_resource} -foo` yields
    `TestSpec(name='my test', only=['cpu', 'int32'], args=['{ceed_resource}', '-foo'])`.

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted substrings (quotes included) inside a single token
    args: list[str] = re.findall(r'(?:".*?"|\S)+', line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    # (values may contain '=' and keys may be preceded by whitespace after a comma)
    test_args: dict[str, str] = {key.strip(): value
                                 for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: list[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_strs: dict[str, str] = {
        '.c': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str: Optional[str] = comment_strs.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker: str = f'{comment_str}TESTARGS'
    # Slice off exactly the leading comment marker; `str.strip(comment_str)` would treat it as a character set and
    # could also remove trailing characters that belong to the test arguments
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(marker)] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty if they match
    """
    test_lines: list[str] = test_csv.read_text().splitlines()
    true_lines: list[str] = true_csv.read_text().splitlines()

    if not true_lines:
        return f'No lines found in expected CSV {true_csv}'
    if not test_lines:
        return f'No lines found in test output CSV {test_csv}'

    if test_lines[0] != true_lines[0]:
        # `unified_diff(a, b)` labels `a` with `fromfile`, so the expected header must be passed first
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = []
    if len(test_lines) != len(true_lines):
        # `zip` below stops at the shorter file, so missing or extra rows must be reported explicitly
        diff_lines.append(f'row count: expected {len(true_lines) - 1}, got {len(test_lines) - 1}')
    column_names: list[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        # first field of the row (not its first character) identifies the step
        step: str = true_line.split(',')[0].strip()
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            if abs(true_val) < zero_tol:
                # written as `not <` (rather than `>=`) so a NaN result still counts as a failure
                fail: bool = not abs(test_val) < zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files, empty if they match
    """
    run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    try:
        # argument list without a shell, so paths containing spaces or shell metacharacters are passed verbatim
        proc = subprocess.run(run_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=os.environ.copy())
    except FileNotFoundError:
        return 'cgnsdiff executable not found'

    return proc.stderr.decode('utf-8', errors='replace') + proc.stdout.decode('utf-8', errors='replace')


def _build_run_args(test: str, spec: TestSpec, ceed_resource: str, nproc: int, suite_spec: SuiteSpec,
                    source_path: Path) -> list:
    """Assemble the command line for one test case, substituting `{ceed_resource}` and `{nproc}` placeholders"""
    run_args: list = [ceed_resource if arg == '{ceed_resource}' else arg
                      for arg in (suite_spec.get_run_path(test), *spec.args)]
    if '{nproc}' in run_args:
        run_args = [f'{nproc}' if arg == '{nproc}' else arg for arg in run_args]
    elif nproc > 1 and source_path.suffix != '.py':
        # without an explicit `{nproc}` placeholder, non-Python tests are launched through mpiexec
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]
    return run_args


def _expected_output_paths(test: str, spec: TestSpec, suite_spec: SuiteSpec, marker: str) -> list[Path]:
    """Map CLI args containing `marker` (e.g. `...ascii:out.csv`) to their expected-output file paths"""
    return [suite_spec.get_output_path(test, arg.split(marker)[-1]) for arg in spec.args if marker in arg]


def _compare_output_files(test_case: TestCase, kind: str, ref_files: list[Path],
                          diff_func: Callable[[Path, Path], str]) -> None:
    """Diff each output file written to the cwd against its reference, recording failures under `kind`

    Matching output files are deleted from the cwd; mismatching or missing ones are reported as failures.
    """
    for ref_file in ref_files:
        out_file: Path = Path.cwd() / ref_file.name
        if not ref_file.is_file():
            test_case.add_failure_info(kind, output=f'{ref_file} not found')
        elif not out_file.is_file():
            test_case.add_failure_info(kind, output=f'{out_file} not found')
        else:
            diff: str = diff_func(out_file, ref_file)
            if diff:
                test_case.add_failure_info(kind, output=diff)
            else:
                out_file.unlink()


def _run_test_case(test: str, spec: TestSpec, ceed_resource: str, run_args: list, case_name: str,
                   source_path: Path, env: dict, suite_spec: SuiteSpec) -> TestCase:
    """Execute one test case and classify its result (allowed failure, required failure, or output checks)"""
    start: float = time.time()
    # run through the shell: TESTARGS may carry shell-quoted arguments (e.g. "a b") that rely on shell parsing
    proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=env)

    test_case: TestCase = TestCase(case_name,
                                   classname=source_path.parent,
                                   elapsed_sec=time.time() - start,
                                   timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                   stdout=proc.stdout.decode('utf-8', errors='replace'),
                                   stderr=proc.stderr.decode('utf-8', errors='replace'),
                                   allow_multiple_subelements=True)
    ref_csvs: list[Path] = _expected_output_paths(test, spec, suite_spec, 'ascii:')
    ref_cgns: list[Path] = _expected_output_paths(test, spec, suite_spec, 'cgns:')
    ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
    suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, ceed_resource, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)
            return test_case

    # check required failures
    required_message, did_fail = suite_spec.check_required_failure(test, spec, ceed_resource, test_case.stderr)
    if required_message and did_fail:
        test_case.status = f'fails with required: {required_message}'
        return test_case
    if required_message:
        test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if test_case.stderr:
        test_case.add_failure_info('stderr', test_case.stderr)
    if proc.returncode != 0:
        test_case.add_error_info(f'returncode = {proc.returncode}')
    if ref_stdout.is_file():
        stdout_diff: list[str] = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                           test_case.stdout.splitlines(keepends=True),
                                                           fromfile=str(ref_stdout),
                                                           tofile='New'))
        if stdout_diff:
            test_case.add_failure_info('stdout', output=''.join(stdout_diff))
    elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
        test_case.add_failure_info('stdout', output=test_case.stdout)
    # expected CSV and CGNS output
    _compare_output_files(test_case, 'csv', ref_csvs, diff_csv)
    _compare_output_files(test_case, 'cgns', ref_cgns, diff_cgns)
    return test_case


def _report_tap(test_case: TestCase, spec: TestSpec, index: int) -> None:
    """Print the incremental TAP result line (plus diagnostics) for one test case"""
    print(f'# Test: {spec.name}')
    if spec.only:
        print('# Only: {}'.format(','.join(spec.only)))
    print(f'# $ {test_case.args}')
    if test_case.is_skipped():
        print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
    elif test_case.is_failure() or test_case.is_error():
        print(f'not ok {index}')
        if test_case.is_error():
            print(f' ERROR: {test_case.errors[0]["message"]}')
        if test_case.is_failure():
            for i, failure in enumerate(test_case.failures):
                print(f' FAILURE {i}: {failure["message"]}')
                print(f' Output: \n{failure["output"] or "NO MESSAGE"}')
    else:
        print(f'ok {index} - PASS')
    sys.stdout.flush()


def _report_junit(test_case: TestCase, test: str, spec: TestSpec) -> None:
    """Print error or failure details for one test case in JUNIT mode; passing/skipped cases print nothing"""
    if not (test_case.is_error() or test_case.is_failure()):
        return
    print(f'Test: {test} {spec.name}')
    print(f' $ {test_case.args}')
    if test_case.is_error():
        print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
        print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
    if test_case.is_failure():
        for failure in test_case.failures:
            print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
            print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
    sys.stdout.flush()


def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    # one case per (spec, backend) pair, spec-major, numbered from 1 for TAP output
    for index, (spec, ceed_resource) in enumerate(itertools.product(test_specs, ceed_backends), start=1):
        run_args: list = _build_run_args(test, spec, ceed_resource, nproc, suite_spec, source_path)
        case_name: str = f'{test}, "{spec.name}", n{nproc}, {ceed_resource}'

        skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
        if skip_reason:
            test_case: TestCase = TestCase(case_name,
                                           classname=source_path.parent,
                                           elapsed_sec=0,
                                           timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                           stdout='',
                                           stderr='')
            test_case.add_skipped_info(skip_reason)
        else:
            test_case = _run_test_case(test, spec, ceed_resource, run_args, case_name, source_path, my_env,
                                       suite_spec)

        # store result
        test_case.args = ' '.join(str(arg) for arg in run_args)
        test_cases.append(test_case)

        if mode is RunMode.TAP:
            _report_tap(test_case, spec, index)
        else:
            _report_junit(test_case, test, spec)

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    path: Path = Path(output_file) if output_file else Path('build') / f'{test_suite.name}{batch}.junit'
    # create the destination directory (e.g. `build/`) if it does not exist yet
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed or errored
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)