from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        """Create the action

        Args:
            option_strings: Option flags, forwarded to `argparse.Action`
            dest: Attribute name on the parsed namespace, forwarded to `argparse.Action`
            type: Enum class that also subclasses `str`; used for conversion by this action
            default: Default value (a string or an iterable of strings), converted like parsed values;
                `None` (what argparse passes when no default is given) is left as `None`

        Raises:
            ValueError: If `type` is not both an `Enum` and a `str` subclass
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        default = self._to_enum(default)
        # `type` is deliberately not forwarded, preventing argparse's own (case-sensitive) conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def _to_enum(self, values):
        """Convert a string, or each string of an iterable, to a member of `enum_type` after lower-casing"""
        if values is None:
            return None
        if isinstance(values, str):
            return self.enum_type(values.lower())
        return [self.enum_type(v.lower()) for v in values]

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, self._to_enum(values))


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    # constraints parsed from the comma-separated `only="..."` entry of TESTARGS
    only: List = field(default_factory=list)
    # command line arguments appended after the test executable
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def get_cgns_tol(self) -> float:
        """Retrieve CGNS test tolerance.

        Returns:
            tolerance (float): Value of the `cgns_tol` attribute if set, otherwise 1.0e-12
        """
        return getattr(self, 'cgns_tol', 1.0e-12)

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program on the current `PATH`

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # `shutil.which` avoids depending on the wording of a shell's "command not found" message
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return base.startswith(tuple(prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments, comment prefix already removed

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted substrings together
    args: List[str] = re.findall(r'(?:".*?"|\S)+', line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + len('TESTARGS('):args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'};
    # using the capture groups directly keeps values containing '=' intact
    test_args: dict = {key.strip(): value for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_by_suffix: dict = {
        '.c': '//',
        '.cc': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str: Optional[str] = comment_by_suffix.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker: str = f'{comment_str}TESTARGS'
    # Slice off the comment prefix; `str.strip(comment_str)` would also remove any of those
    # characters from the END of the line (e.g. a trailing 'C' or '_' in a `.usr` file)
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(marker)] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty if they match
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if not test_lines:
        return f'No lines found in test output {test_csv}'
    if not true_lines:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # `fromfile` labels the first sequence, so the expected header must be passed first
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    diff_lines: List[str] = []
    for test_row, true_row in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            try:
                true_val: float = float(true_row[key])
                test_val: float = float(test_row[key])
            except (TypeError, ValueError):
                # Non-numeric values, or fields missing from a short row (DictReader yields None),
                # are compared exactly
                if test_row[key] != true_row[key]:
                    diff_lines.append(f'column: {key}, expected: {true_row[key]}, got: {test_row[key]}')
                continue
            if abs(true_val) < zero_tol:
                fail: bool = not (abs(test_val) < zero_tol)
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')

    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files, empty if they match
    """
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    try:
        # argv list without a shell, so paths containing spaces are passed intact
        proc = subprocess.run(run_args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=os.environ.copy())
    except FileNotFoundError:
        return f'cgnsdiff not found; unable to compare {test_cgns} against {true_cgns}'

    return proc.stderr.decode('utf-8', errors='replace') + proc.stdout.decode('utf-8', errors='replace')


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the report text for a single finished test case

    In TAP mode this is the result line plus a YAML-style diagnostic block; in JUnit mode it is
    a human-readable summary, produced only for errored or failed test cases.

    NOTE(review): TAP subtest lines and YAML diagnostic content are normally indented by several
    spaces; confirm the single-space indentation of the literals below is intended.

    Args:
        test_case (TestCase): Finished test case; its `args` attribute must already be set
        spec (TestSpec): Specification of test case
        mode (RunMode): Output mode
        backend (str): CEED backend
        test (str): Name of test
        index (int): Index of backend for current spec

    Returns:
        str: Formatted report text, possibly empty
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f' ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f' not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f' ok {index} - {spec.name}, {backend}\n'
        output_str += ' ---\n'
        if spec.only:
            output_str += f' only: {",".join(spec.only)}\n'
        output_str += f' args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f' error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f' num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f' failure_{i}: {failure["message"]}\n'
                output_str += f' message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n ')
                    output_str += f' output: |\n {out}\n'
        output_str += ' ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f' $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


# Environment for test subprocesses; set in each pool worker by `init_process`
my_env: Optional[dict] = None


def _make_test_env() -> dict:
    """Return a copy of the current environment with `CEED_ERROR_HANDLER` set to `exit`"""
    env = os.environ.copy()
    env['CEED_ERROR_HANDLER'] = 'exit'
    return env


def _fallback_reference(ref: Path, suffix: str) -> Path:
    """Return `ref` if it exists, else the same path with its trailing `_{ceed_backend}` part removed"""
    if ref.is_file():
        return ref
    # remove _{ceed_backend} from path name
    return (ref.parent / ref.name.rsplit('_', 1)[0]).with_suffix(suffix)


def _check_stdout(test_case: TestCase, ref_stdout: Path, suite_spec: SuiteSpec, test: str) -> None:
    """Record a failure if stdout differs from the expected `.out` file, or if disallowed output was printed"""
    if ref_stdout.is_file():
        stdout_diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                test_case.stdout.splitlines(keepends=True),
                                                fromfile=str(ref_stdout),
                                                tofile='New'))
        if stdout_diff:
            test_case.add_failure_info('stdout', output=''.join(stdout_diff))
    elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
        test_case.add_failure_info('stdout', output=test_case.stdout)


def _check_csv_outputs(test_case: TestCase, ref_csvs: List[Path]) -> None:
    """Compare each CSV written to the working directory with its reference, deleting outputs that match"""
    for ref_csv in ref_csvs:
        out_csv = Path.cwd() / ref_csv.name
        ref_csv = _fallback_reference(ref_csv, '.csv')
        if not ref_csv.is_file():
            test_case.add_failure_info('csv', output=f'{ref_csv} not found')
        elif not out_csv.is_file():
            # the test did not produce its output; report instead of crashing the worker
            test_case.add_failure_info('csv', output=f'{out_csv} not found')
        else:
            diff = diff_csv(out_csv, ref_csv)
            if diff:
                test_case.add_failure_info('csv', output=diff)
            else:
                out_csv.unlink()


def _check_cgns_outputs(test_case: TestCase, ref_cgns: List[Path], suite_spec: SuiteSpec) -> None:
    """Compare each CGNS file written to the working directory with its reference, deleting outputs that match"""
    for ref_cgn in ref_cgns:
        out_cgn = Path.cwd() / ref_cgn.name
        ref_cgn = _fallback_reference(ref_cgn, '.cgns')
        if not ref_cgn.is_file():
            test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
        elif not out_cgn.is_file():
            test_case.add_failure_info('cgns', output=f'{out_cgn} not found')
        else:
            diff = diff_cgns(out_cgn, ref_cgn, cgns_tol=suite_spec.get_cgns_tol())
            if diff:
                test_case.add_failure_info('cgns', output=diff)
            else:
                out_cgn.unlink()


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted report text
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # the first argument that is exactly `{ceed_resource}` becomes the backend itself...
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # ...while embedded occurrences (e.g. in output file names) use a '/'-free spelling
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    test_name: str = f'{test}, "{spec.name}", n{nproc}, {backend}'
    proc: Optional[subprocess.CompletedProcess] = None
    ref_csvs: List[Path] = []
    ref_cgns: List[Path] = []
    ref_stdout: Optional[Path] = None

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(test_name,
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # a shell is used deliberately: TESTARGS may contain quoted arguments
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              # `my_env` is only set in pool workers; fall back when called directly
                              env=my_env if my_env is not None else _make_test_env())

        test_case = TestCase(test_name,
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             # undecodable bytes must not crash the worker
                             stdout=proc.stdout.decode('utf-8', errors='replace'),
                             stderr=proc.stderr.decode('utf-8', errors='replace'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs = [suite_spec.get_output_path(test, arg.split('ascii:')[-1])
                    for arg in run_args if 'ascii:' in arg]
        ref_cgns = [suite_spec.get_output_path(test, arg.split('cgns:')[-1])
                    for arg in run_args if 'cgns:' in arg]
        ref_stdout = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        post_skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if post_skip_reason:
            test_case.add_skipped_info(post_skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results; only reachable when the test actually ran, so `proc` is set
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        _check_stdout(test_case, ref_stdout, suite_spec, test)
        # expected CSV output
        _check_csv_outputs(test_case, ref_csvs)
        # expected CGNS output
        _check_cgns_outputs(test_case, ref_cgns, suite_spec)

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str


def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = _make_test_env()


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[list] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        # results must be collected inside the `with`, since leaving it terminates the pool
        test_cases: List[TestCase] = []
        for subtest_index, (spec, subtest) in enumerate(zip(test_specs, async_outputs), start=1):
            # each TestCase's category is spec.name, so the spec can label the subtest
            # even when there are no backends (and hence no results)
            if mode is RunMode.TAP:
                print(f'# Subtest: {spec.name}')
                print(f' 1..{len(ceed_backends)}')
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode is RunMode.TAP:
                print(f'{"" if subtest_ok else "not "}ok {subtest_index} - {spec.name}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_path: Path = Path(output_file) if output_file else Path('build') / f'{test_suite.name}{batch}.junit'
    # the default `build/` directory may not exist yet
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed or errored
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)