from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion: `type` is deliberately not forwarded to argparse
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # Look the program up on PATH instead of parsing the shell's "not found"
    # message, whose wording depends on the shell and locale.
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return any(base.startswith(prefix) for prefix in prefixes)


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted sections together
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    # The regex groups are used directly so that a quoted value containing '=' is kept intact.
    test_args: dict = {key.strip(): value
                       for key, _, value in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    # args[1:] is empty when the line carries no CLI arguments
    return TestSpec(name=name, only=constraints, args=args[1:])


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker: str = f'{comment_str}TESTARGS'
    # Slice off the comment prefix; `str.strip(comment_str)` would treat it as a
    # character set and also eat matching characters from the end of the line.
    specs: List[TestSpec] = [parse_test_line(line[len(comment_str):])
                             for line in source_file.read_text().splitlines()
                             if line.startswith(marker)]
    return specs or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # expected header is the "from" side so the '-'/'+' lines match the file labels
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'
    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
            except (TypeError, ValueError):
                # Non-numeric cell (ValueError) or cell missing from a short row
                # (None -> TypeError): fall back to exact comparison
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')
                continue
            if abs(true_val) < zero_tol:
                # expected value counts as zero, so the result must as well
                fail: bool = not abs(test_val) < zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')

    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    # Pass the argument list directly (no shell) so paths containing spaces or
    # shell metacharacters reach cgnsdiff unchanged.
    try:
        proc = subprocess.run(run_args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)
    except FileNotFoundError as err:
        # a missing executable is reported as diff text, i.e. as a test failure
        return f'cgnsdiff not found: {err}'

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the console output for a finished test case

    Args:
        test_case (TestCase): Test case result, with `args` already set
        spec (TestSpec): Specification of test case
        mode (RunMode): Output mode
        backend (str): CEED backend
        test (str): Name of test
        index (int): Index of backend for current spec

    Returns:
        str: Text to print; in JUNIT mode this is empty unless the case has an error or failure
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f' ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f' not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f' ok {index} - {spec.name}, {backend}\n'
        output_str += ' ---\n'
        if spec.only:
            output_str += f' only: {",".join(spec.only)}\n'
        output_str += f' args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f' error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f' num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f' failure_{i}: {failure["message"]}\n'
                output_str += f' message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n ')
                    output_str += f' output: |\n {out}\n'
        output_str += ' ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f' $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    The subprocess environment is read from the module-level `my_env`, which is
    set by `init_process`.

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and the console output string for it
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # the first stand-alone placeholder receives the backend verbatim
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # any remaining placeholder receives the backend with '/' replaced by '-'
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # the command is run through the shell as a single string
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    # (a pre-skipped case is skipped, so `proc` and the `ref_*` names are only read after a real run)
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # output matched the reference, remove the generated file
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    # output matched the reference, remove the generated file
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str


def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # one list of pending results per spec, one entry per backend
        async_outputs: List[list] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        # results are collected while the pool is still open
        test_cases = []
        for i, (spec, subtest) in enumerate(zip(test_specs, async_outputs), start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {spec.name}')
                    print(f' 1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # `spec.name` is the value stored as each case's category; using it
                # keeps this line valid when `ceed_backends` is empty
                print(f'{"" if subtest_ok else "not "}ok {i} - {spec.name}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # make sure the destination directory exists before writing
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)