from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Optional

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The `type` passed to `add_argument` must be a string-valued Enum (a `StrEnum`, or a class deriving from both
    `str` and `Enum`). Inputs are lower-cased before the enum lookup, so enum values are expected to be lower case.
    Values that do not name an enum member are reported as a regular argparse usage error.
    """

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        # `type` is not a class at all (e.g. a conversion function): issubclass would raise TypeError
        try:
            is_str_enum = issubclass(type, Enum) and issubclass(type, str)
        except TypeError:
            is_str_enum = False
        if not is_str_enum:
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # argparse only forwards `default` when one was configured, so it may legitimately be None
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        elif default is not None:
            default = [self.enum_type(v.lower()) for v in default]
        # `type` is deliberately not forwarded, preventing argparse's automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            if isinstance(values, str):
                converted = self.enum_type(values.lower())
            else:
                converted = [self.enum_type(v.lower()) for v in values]
        except ValueError:
            # ArgumentError is turned into a clean usage message by argparse instead of a traceback
            choices = ', '.join(repr(member.value) for member in self.enum_type)
            raise argparse.ArgumentError(self, f'invalid choice: {values!r} (choose from {choices})') from None
        setattr(namespace, self.dest, converted)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    # build constraints parsed from `only="..."` in TESTARGS
    only: list[str] = field(default_factory=list)
    # command line arguments passed to the test executable (may contain `{ceed_resource}` / `{nproc}` placeholders)
    args: list[str] = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Looks the program up on `PATH` instead of running it, so the result does not depend on the shell's
    (locale-specific) "command not found" message.

    Returns:
        bool: True if `cgnsdiff` is found
    """
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidates; an empty tuple yields False, like any([])
    return base.startswith(tuple(prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Accepts either a bare `TESTARGS arg...` line or `TESTARGS(name="...",only="a,b") arg...`.
    Quoted strings are kept together as single arguments.

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case; `name` is empty when not given
    """
    args: list[str] = re.findall(r'(?:".*?"|\S)+', line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + len('TESTARGS('):args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'};
    # values may contain '=' or ',', and whitespace around keys is ignored
    test_args: dict[str, str] = {key.strip(): value
                                 for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    constraints: list[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=test_args.get('name', ''), only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    Test cases are lines starting with `<comment>TESTARGS`, where the comment prefix depends on the extension.

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_prefixes: dict[str, str] = {
        '.c': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str: Optional[str] = comment_prefixes.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker: str = f'{comment_str}TESTARGS'
    # removeprefix (not strip) so characters of the comment marker at the END of the line are preserved
    test_specs: list[TestSpec] = [parse_test_line(line.removeprefix(comment_str))
                                  for line in source_file.read_text().splitlines()
                                  if line.startswith(marker)]
    return test_specs or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    The first row is the column header and must match exactly. Every following row is compared value by value,
    and its first field is reported as the `step` in mismatch messages.

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty if they agree
    """
    test_lines: list[str] = test_csv.read_text().splitlines()
    true_lines: list[str] = true_csv.read_text().splitlines()

    # header rows must match exactly; slicing with [:1] also handles empty files
    if test_lines[:1] != true_lines[:1]:
        return ''.join(difflib.unified_diff([f'{line}\n' for line in true_lines[:1]],
                                            [f'{line}\n' for line in test_lines[:1]],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = []
    column_names: list[str] = true_lines[0].strip().split(',') if true_lines else []
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        true_fields: list[str] = [val.strip() for val in true_line.strip().split(',')]
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val) for val in true_fields]
        step: str = true_fields[0]
        if len(test_vals) != len(true_vals):
            diff_lines.append(f'step: {step}, expected {len(true_vals)} values, got {len(test_vals)}')
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            if abs(true_val) < zero_tol:
                # relative comparison against ~0 is ill-conditioned, so require the result to be ~0 too;
                # written as `not (... < ...)` so a NaN result still fails
                fail = not (abs(test_val) < zero_tol)
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    if len(test_lines) != len(true_lines):
        diff_lines.append(f'row count mismatch: expected {len(true_lines) - 1} data rows, got {len(test_lines) - 1}')
    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output (stderr followed by stdout) of `cgnsdiff`, or an error message if it cannot be run
    """
    run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    # list argv without a shell, so paths containing spaces or shell metacharacters are passed through intact
    try:
        proc = subprocess.run(run_args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=os.environ.copy())
    except FileNotFoundError as err:
        return f'unable to run cgnsdiff: {err}'

    return proc.stderr.decode('utf-8', errors='replace') + proc.stdout.decode('utf-8', errors='replace')


def _get_output_refs(suite_spec: SuiteSpec, test: str, spec: TestSpec, marker: str) -> list[Path]:
    """Resolve expected-output paths for every test argument containing `marker` (e.g. 'ascii:' or 'cgns:')"""
    return [suite_spec.get_output_path(test, arg.split(marker)[-1]) for arg in spec.args if marker in arg]


def _compare_output_files(test_case: TestCase, ref_files: list[Path], kind: str, diff_fn) -> None:
    """Diff each generated file in the working directory against its reference, recording failures as `kind`"""
    for ref_file in ref_files:
        out_file: Path = Path.cwd() / ref_file.name
        if not ref_file.is_file():
            test_case.add_failure_info(kind, output=f'{ref_file} not found')
        elif not out_file.is_file():
            test_case.add_failure_info(kind, output=f'{out_file} not found')
        else:
            diff: str = diff_fn(out_file, ref_file)
            if diff:
                test_case.add_failure_info(kind, output=diff)
            else:
                # generated output is only deleted when it matches; mismatching files are left in place
                out_file.unlink()


def _print_tap_result(index: int, spec: TestSpec, test_case: TestCase) -> None:
    """Print the TAP record for a single finished test case"""
    print(f'# Test: {spec.name}')
    if spec.only:
        print('# Only: {}'.format(','.join(spec.only)))
    print(f'# $ {test_case.args}')
    if test_case.is_skipped():
        print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
    elif test_case.is_failure() or test_case.is_error():
        print(f'not ok {index}')
        if test_case.is_error():
            print(f' ERROR: {test_case.errors[0]["message"]}')
        if test_case.is_failure():
            for i, failure in enumerate(test_case.failures):
                print(f' FAILURE {i}: {failure["message"]}')
                print(f' Output: \n{failure["output"] or ""}')
    else:
        print(f'ok {index} - PASS')
    sys.stdout.flush()


def _print_junit_result(test: str, spec: TestSpec, test_case: TestCase) -> None:
    """Print error/failure details of a single test case in JUnit mode (passing and skipped cases print nothing)"""
    if not (test_case.is_error() or test_case.is_failure()):
        return
    print(f'Test: {test} {spec.name}')
    print(f' $ {test_case.args}')
    if test_case.is_error():
        print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
        print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
    if test_case.is_failure():
        for failure in test_case.failures:
            print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
            print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
    sys.stdout.flush()


def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    # every test process is run with libCEED's 'exit' error handler selected
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    index: int = 1
    for spec in test_specs:
        for ceed_resource in ceed_backends:
            run_args: list = [suite_spec.get_run_path(test), *spec.args]
            # substitute every occurrence of the TESTARGS placeholders
            run_args = [ceed_resource if arg == '{ceed_resource}' else arg for arg in run_args]
            if '{nproc}' in run_args:
                run_args = [f'{nproc}' if arg == '{nproc}' else arg for arg in run_args]
            elif nproc > 1 and source_path.suffix != '.py':
                run_args = ['mpiexec', '-n', f'{nproc}', *run_args]
            # shell command string: TESTARGS may contain shell-quoted arguments
            command: str = ' '.join(str(arg) for arg in run_args)
            case_name: str = f'{test}, "{spec.name}", n{nproc}, {ceed_resource}'

            skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
            if skip_reason:
                test_case: TestCase = TestCase(case_name,
                                               classname=source_path.parent,
                                               elapsed_sec=0,
                                               timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                               stdout='',
                                               stderr='')
                test_case.add_skipped_info(skip_reason)
            else:
                # run test
                start: float = time.time()
                proc = subprocess.run(command,
                                      shell=True,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      env=my_env)

                test_case = TestCase(case_name,
                                     classname=source_path.parent,
                                     elapsed_sec=time.time() - start,
                                     timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                     stdout=proc.stdout.decode('utf-8', errors='replace'),
                                     stderr=proc.stderr.decode('utf-8', errors='replace'),
                                     allow_multiple_subelements=True)
                ref_csvs: list[Path] = _get_output_refs(suite_spec, test, spec, 'ascii:')
                ref_cgns: list[Path] = _get_output_refs(suite_spec, test, spec, 'cgns:')
                ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
                suite_spec.post_test_hook(test, spec)

                # check allowed failures
                if test_case.stderr:
                    post_skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, ceed_resource,
                                                                                 test_case.stderr)
                    if post_skip_reason:
                        test_case.add_skipped_info(post_skip_reason)

                # check required failures
                if not test_case.is_skipped():
                    required_message, did_fail = suite_spec.check_required_failure(test, spec, ceed_resource,
                                                                                   test_case.stderr)
                    if required_message and did_fail:
                        test_case.status = f'fails with required: {required_message}'
                    elif required_message:
                        test_case.add_failure_info(f'required failure missing: {required_message}')

                # classify other results
                if not test_case.is_skipped() and not test_case.status:
                    if test_case.stderr:
                        test_case.add_failure_info('stderr', test_case.stderr)
                    if proc.returncode != 0:
                        test_case.add_error_info(f'returncode = {proc.returncode}')
                    if ref_stdout.is_file():
                        diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                         test_case.stdout.splitlines(keepends=True),
                                                         fromfile=str(ref_stdout),
                                                         tofile='New'))
                        if diff:
                            test_case.add_failure_info('stdout', output=''.join(diff))
                    elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
                        test_case.add_failure_info('stdout', output=test_case.stdout)
                    # expected CSV and CGNS output
                    _compare_output_files(test_case, ref_csvs, 'csv', diff_csv)
                    _compare_output_files(test_case, ref_cgns, 'cgns', diff_cgns)

            # store result
            test_case.args = command
            test_cases.append(test_case)

            if mode is RunMode.TAP:
                # print incremental output if TAP mode
                _print_tap_result(index, spec, test_case)
            else:
                # print error or failure information if JUNIT mode
                _print_junit_result(test, spec, test_case)
            index += 1

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    target: Path = Path(output_file) if output_file else Path('build') / f'{test_suite.name}{batch}.junit'
    # create the destination directory (e.g. `build/`) if it does not exist yet
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed or errored
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)