from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        """Store the enum type and convert the default value(s), if any

        Args:
            option_strings: Option strings forwarded to `argparse.Action`
            dest: Destination attribute name forwarded to `argparse.Action`
            type: Enum class that also derives from `str`
            default: Default value, either a single string or an iterable of strings; `None` means no default
            **kwargs: Remaining keyword arguments forwarded to `argparse.Action`

        Raises:
            ValueError: If `type` is not a `str`-based `Enum`, or if a default is not a valid member value
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        if default is not None:
            default = self._convert(default)
        # `type` is deliberately not forwarded, which prevents automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def _convert(self, values):
        """Convert a string, or an iterable of strings, to lower-cased member(s) of the stored enum type"""
        if isinstance(values, str):
            return self.enum_type(values.lower())
        return [self.enum_type(v.lower()) for v in values]

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert the parsed value(s) and store them on `namespace`

        Raises:
            argparse.ArgumentError: If a value is not a valid member of the enum type
        """
        try:
            values = self._convert(values)
        except ValueError as err:
            # let argparse report the problem as a usage error instead of a raw traceback
            raise argparse.ArgumentError(self, str(err)) from err
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self):
        """Absolute tolerance for CGNS diff"""
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val):
        self._cgns_tol = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback run after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # look the program up on PATH instead of spawning a shell and matching its error text
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return any(base.startswith(prefix) for prefix in prefixes)


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted sections (which may contain spaces) in one token
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    # keys are stripped so that 'name="a", only="b"' is accepted; values may safely contain '='
    test_args: dict = {key.strip(): value
                       for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker: str = f'{comment_str}TESTARGS'
    # drop only the leading comment prefix; `str.strip(comment_str)` would also remove
    # matching characters from the end of the line and corrupt the final argument
    specs: List[TestSpec] = [parse_test_line(line[len(comment_str):])
                             for line in source_file.read_text().splitlines()
                             if line.startswith(marker)]
    return specs or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs
    """
    # a test that did not write its output should be reported as a diff, not crash the runner
    if not test_csv.is_file():
        return f'Test output {test_csv} not found'
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # expected header is the "from" side, found header is the "to" side
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'
    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
            except (TypeError, ValueError):
                # non-numeric (or missing, i.e. `None`) entries are compared verbatim
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')
                continue
            if abs(true_val) < zero_tol:
                # expected value is "zero", so the result must be "zero" as well
                fail: bool = not abs(test_val) < zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')

    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files
    """
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    try:
        # pass an argument list (no shell) so paths containing spaces or shell metacharacters work
        proc = subprocess.run(run_args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=os.environ.copy())
    except OSError as err:
        # report a missing or non-executable `cgnsdiff` as diff output, so the test case fails
        return f'Failed to run cgnsdiff: {err}'

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the result of a single test case for console output

    In TAP mode a result line and a diagnostics block are always produced; otherwise output is only
    produced for test cases with errors or failures.

    Args:
        test_case (TestCase): Test case result, with `args` already set to the command line
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode
        backend (str): CEED backend
        test (str): Name of test
        index (int): Index of backend for current spec

    Returns:
        str: Formatted output, possibly empty
    """
    output_str = ''
    if mode is RunMode.TAP:
        # NOTE(review): the leading whitespace inside these TAP literals determines the subtest/YAML
        # nesting of the report -- confirm the indentation widths against the project's reference output
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f' ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f' not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f' ok {index} - {spec.name}, {backend}\n'
        output_str += ' ---\n'
        if spec.only:
            output_str += f' only: {",".join(spec.only)}\n'
        output_str += f' args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f' error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f' num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f' failure_{i}: {failure["message"]}\n'
                output_str += f' message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n ')
                    output_str += f' output: |\n {out}\n'
        output_str += ' ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f' $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


def _get_test_env() -> dict:
    """Return the environment used for test subprocesses, creating it if `init_process` has not run yet"""
    if 'my_env' not in globals():
        init_process()
    return my_env


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # a stand-alone placeholder becomes the backend itself ...
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # ... while a placeholder embedded in a longer argument gets the backend with '/' replaced by '-'
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # the command is run through the shell, so quoting inside TESTARGS is interpreted by the shell
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=_get_test_env())

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # reference files are derived from the 'ascii:<file>' and 'cgns:<file>' arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # output matched the reference, so remove the generated file
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    # output matched the reference, so remove the generated file
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str


def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # one list of pending results per spec, one entry per backend
        async_outputs: List[list] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        # results must be collected while the pool is alive; leaving the `with` block terminates it
        test_cases = []
        for i, (spec, subtest) in enumerate(zip(test_specs, async_outputs), start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {spec.name}')
                    print(f' 1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # every test case is created with `category=spec.name`; using the spec directly
                # keeps this line valid even when `ceed_backends` is empty
                print(f'{"" if subtest_ok else "not "}ok {i} - {spec.name}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # make sure the destination directory exists before writing the report
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)