from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List, Callable

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # argparse does not pass `default` when the caller omits it, so `None` must be accepted
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        elif default is not None:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self):
        """Absolute tolerance for CGNS diff"""
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val):
        self._cgns_tol = val

    @property
    def diff_csv_kwargs(self):
        """Keyword arguments to be passed to diff_csv()"""
        return getattr(self, '_diff_csv_kwargs', {})

    @diff_csv_kwargs.setter
    def diff_csv_kwargs(self, val):
        self._diff_csv_kwargs = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program on the current `PATH`

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # `shutil.which` avoids spawning a shell and does not depend on the shell's (locale-specific) error text
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return base.startswith(tuple(prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted segments (quotes included) inside a single token
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + len('TESTARGS('):args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    # (values are taken directly from the regex groups, so a value may itself contain '=')
    test_args: dict = {key.strip(): value
                       for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # remove only the comment prefix; `str.strip(comment_str)` would strip any of its *characters* from both ends
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2,
             comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.
        comment_str (str, optional): String denoting a commented line. Defaults to '#'.
        comment_func (Callable, optional): Function to determine if test and true line are different

    Returns:
        str: Diff output between result and expected CSVs
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'
    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    def is_comment(line: str) -> bool:
        # `startswith` tolerates blank lines and multi-character comment markers; an empty marker means no comments
        return bool(comment_str) and line.startswith(comment_str)

    # Process commented lines
    uncommented_lines: List[int] = []
    for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
        test_commented: bool = is_comment(test_line)
        true_commented: bool = is_comment(true_line)
        if test_commented and true_commented:
            if comment_func:
                output = comment_func(test_line, true_line)
                if output:
                    return output
        elif test_commented:
            return f'Commented line found in {test_csv} at line {n} but not in {true_csv}'
        elif true_commented:
            return f'Commented line found in {true_csv} at line {n} but not in {test_csv}'
        else:
            uncommented_lines.append(n)

    # Remove commented lines
    test_lines = [test_lines[line] for line in uncommented_lines]
    true_lines = [true_lines[line] for line in uncommented_lines]

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # `unified_diff(a, b)` labels `a` with `fromfile`, so the expected header must come first
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric; short rows yield `None` cells, which raise TypeError
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
            except (ValueError, TypeError):
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')
                continue
            true_zero: bool = abs(true_val) < zero_tol
            test_zero: bool = abs(test_val) < zero_tol
            if true_zero:
                fail: bool = not test_zero
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')

    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    # quote each argument so paths containing spaces survive the shell; the shell is kept so that a
    # missing `cgnsdiff` is reported through stderr (and thus the returned diff) rather than raising
    proc = subprocess.run(' '.join(shlex.quote(arg) for arg in run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8', errors='replace') + proc.stdout.decode('utf-8', errors='replace')


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the console output for a single completed test case

    Args:
        test_case (TestCase): Completed test case result (its `args` attribute must already be set)
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode
        backend (str): CEED backend
        test (str): Name of test
        index (int): 1-based index of the backend within the current subtest

    Returns:
        str: In TAP mode, a subtest test point with a YAML diagnostic block; in JUnit mode, a summary of
            errors/failures, or an empty string if the test case passed or was skipped
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode; subtest lines are indented 4 spaces and their YAML
        # diagnostic block 2 further spaces, following the TAP 14 subtest/YAML conventions
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        output_str += '      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += '      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f' $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output for `mode`
    """
    # `my_env` is normally created by `init_process` in each pool worker; create it when called directly
    if 'my_env' not in globals():
        init_process()

    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # shell=True is deliberate: TESTARGS tokens may carry shell quoting
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8', errors='replace'),
                             stderr=proc.stderr.decode('utf-8', errors='replace'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            stdout_diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                    test_case.stdout.splitlines(keepends=True),
                                                    fromfile=str(ref_stdout),
                                                    tofile='New'))
            if stdout_diff:
                test_case.add_failure_info('stdout', output=''.join(stdout_diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            elif not (Path.cwd() / csv_name).is_file():
                test_case.add_failure_info('csv', output=f'{csv_name} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            elif not (Path.cwd() / cgn_name).is_file():
                test_case.add_failure_info('cgns', output=f'{cgn_name} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str


def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for i, (spec, subtest) in enumerate(zip(test_specs, async_outputs), start=1):
            # name subtests from the spec (== `test_case.category`) so an empty backend list is handled
            if mode is RunMode.TAP:
                print(f'# Subtest: {spec.name}')
                print(f'    1..{len(ceed_backends)}')
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode is RunMode.TAP:
                print(f'{"" if subtest_ok else "not "}ok {i} - {spec.name}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file: Path = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # make sure the destination directory (e.g. `build/`) exists before writing
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)