"""Helpers for discovering, running, and reporting libCEED-style test cases.

Test cases are declared in source files via `TESTARGS` comment lines, run (optionally in parallel and under MPI) for
each requested backend, compared against expected stdout/CSV/CGNS/ASCII output, and reported as TAP or JUnit XML.
"""
from abc import ABC, abstractmethod
from collections.abc import Iterable
import argparse
import csv
from dataclasses import dataclass, field, fields
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shlex
import subprocess
import multiprocessing as mp
import sys
import time
from typing import Optional, Tuple, List, Dict, Callable, Iterable, get_origin
import shutil

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class ParseError(RuntimeError):
    """A custom exception for failed parsing."""

    def __init__(self, message):
        super().__init__(message)


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The `type` given to `add_argument` must be an `Enum` subclass. It is intercepted here and NOT forwarded to
    `argparse.Action`, so argparse does not perform its own (case-sensitive) conversion.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not issubclass(type, Enum):
            raise ValueError(f"{type} must be an Enum")
        # store provided enum type
        self.enum_type = type
        default = self._to_enum(default)
        # prevent automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def _to_enum(self, value):
        """Convert an enum member, a string, or an iterable of strings to enum member(s), ignoring case

        Raises:
            ValueError: If a string does not correspond to a member of `self.enum_type`
            argparse.ArgumentTypeError: If `value` is none of the supported types
        """
        if isinstance(value, self.enum_type):
            return value
        if isinstance(value, str):
            return self.enum_type(value.lower())
        if isinstance(value, Iterable):
            return [self.enum_type(v.lower()) for v in value]
        raise argparse.ArgumentTypeError("Invalid value type, must be str or iterable")

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, self._to_enum(values))


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str = field(default_factory=str)
    # tolerances <= 0 mean "use the SuiteSpec default" (see run_test)
    csv_rtol: float = -1
    csv_ztol: float = -1
    cgns_tol: float = -1
    # constraints parsed from a comma-separated `only=...` value in TESTARGS
    only: List = field(default_factory=list)
    # command line arguments passed to the test executable
    args: List = field(default_factory=list)
    # TESTARGS key/value pairs that do not correspond to a known field
    key_values: Dict = field(default_factory=dict)


class RunMode(Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    TAP = 'tap'
    JUNIT = 'junit'

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.value


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def test_failure_artifacts_path(self) -> Path:
        """Path to test failure artifacts"""
        return Path('build') / 'test_failure_artifacts'

    # The tolerance properties fall back to a default via getattr so that subclasses are not required to
    # initialize the backing attributes.
    @property
    def cgns_tol(self):
        """Tolerance passed to `cgnsdiff -t` when comparing CGNS files; defaults to 1e-12"""
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val):
        self._cgns_tol = val

    @property
    def csv_ztol(self):
        """Tolerance below which CSV values are treated as zero by diff_csv(); defaults to 3e-10"""
        return getattr(self, '_csv_ztol', 3e-10)

    @csv_ztol.setter
    def csv_ztol(self, val):
        self._csv_ztol = val

    @property
    def csv_rtol(self):
        """Relative tolerance used by diff_csv() for nonzero CSV values; defaults to 1e-6"""
        return getattr(self, '_csv_rtol', 1e-6)

    @csv_rtol.setter
    def csv_rtol(self, val):
        self._csv_rtol = val

    def post_test_hook(self, test: str, spec: TestSpec, backend: str) -> None:
        """Function callback run after each test case that is not skipped before running

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            backend (str): libCEED backend
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`.
                An empty string means no failure is required.
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found on `PATH`
    """
    # Search PATH directly instead of launching a shell and matching its shell- and locale-dependent
    # "not found" message.
    return shutil.which('cgnsdiff') is not None


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any(sub in base for sub in substrings)


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return any(base.startswith(prefix) for prefix in prefixes)


def find_matching(line: str, open: str = '(', close: str = ')') -> Tuple[int, int]:
    """Find the start and end positions of the first outer paired delimiters

    Args:
        line (str): Line to search
        open (str, optional): Opening delimiter, must be different than `close`. Defaults to '('.
        close (str, optional): Closing delimiter, must be different than `open`. Defaults to ')'.

    Raises:
        RuntimeError: If open or close is not a single character
        RuntimeError: If open and close are the same characters

    Returns:
        Tuple[int, int]: `(start, end)` indices of the first opening delimiter and its matching closing delimiter.
            Returns `(-1, -1)` if no opening delimiter is found and `(start, -1)` if it is never closed.
    """
    if len(open) != 1 or len(close) != 1:
        raise RuntimeError("`open` and `close` must be single characters")
    if open == close:
        raise RuntimeError("`open` and `close` must be different characters")
    start: int = line.find(open)
    if start < 0:
        return -1, -1
    depth: int = 1
    for i in range(start + 1, len(line)):
        if line[i] == open:
            depth += 1
        elif line[i] == close:
            depth -= 1
            if depth == 0:
                return start, i
    return start, -1


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Raises:
        ParseError: If the leading `(...)` key/value group has unbalanced parentheses

    Returns:
        TestSpec: Parsed specification of test case
    """
    test_fields = fields(TestSpec)
    field_names = [f.name for f in test_fields]
    known: Dict = dict()
    other: Dict = dict()
    # startswith (rather than line[0]) tolerates an empty TESTARGS tail
    if line.startswith("("):
        # have key/value pairs to parse
        start, end = find_matching(line)
        if end < start:
            raise ParseError(f"Mismatched parentheses in TESTCASE: {line}")

        keyvalues_str = line[start:end + 1]
        keyvalues_pattern = re.compile(r'''
            (?:\(\s*|\s*,\s*)       # start with open parentheses or comma, no capture
            ([A-Za-z]+[\w\-]+)      # match key starting with alpha, containing alphanumeric, _, or -; captured as Group 1
            \s*=\s*                 # key is followed by = (whitespace ignored)
            (?:                     # uncaptured group for OR
              "((?:[^"]|\\")+)"     # match quoted value (any internal " must be escaped as \"); captured as Group 2
              | ([^=]+)             # OR match unquoted value (no equals signs allowed); captured as Group 3
            )                       # end uncaptured group for OR
            \s*(?=,|\))             # lookahead for either next comma or closing parentheses
        ''', re.VERBOSE)

        for match in keyvalues_pattern.finditer(keyvalues_str):
            key = match.group(1)
            value = match.group(2) or match.group(3)
            try:
                index = field_names.index(key)
                if key == "only":  # weird bc only is a list
                    value = [constraint.strip() for constraint in value.split(',')]
                try:
                    # TODO: stop supporting python <=3.8
                    known[key] = test_fields[index].type(value)  # type: ignore
                except TypeError:
                    # typing aliases such as `List` cannot be instantiated; fall back to their origin (e.g. `list`)
                    # TODO: this is still liable to fail for complex types
                    known[key] = get_origin(test_fields[index].type)(value)  # type: ignore
            except ValueError:
                # unknown key (or a value that failed conversion) is kept as an extra key/value pair
                other[key] = value

        line = line[end + 1:]

    args_pattern = re.compile(r'''
        \s+(                # remove leading space
            (?:"[^"]+")     # match quoted CLI option
            | (?:[\S]+)     # match anything else that is space separated
        )
    ''', re.VERBOSE)
    args: List[str] = re.findall(args_pattern, line)
    for k, v in other.items():
        # .get: the TESTARGS group may not define a name
        print(f"warning, unknown TESTCASE option for test '{known.get('name', '')}': {k}={v}")
    return TestSpec(**known, key_values=other, args=args)


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    marker = f'{comment_str}TESTARGS'
    # removeprefix (not strip) so that trailing characters of the arguments that happen to appear in the comment
    # marker (e.g. '/', '#', 'C', '_', '!') are preserved
    return [parse_test_line(line.removeprefix(comment_str).removeprefix("TESTARGS"))
            for line in source_file.read_text().splitlines()
            if line.startswith(marker)] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float, rel_tol: float,
             comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float): Tolerance below which values are considered to be zero.
        rel_tol (float): Relative tolerance for comparing non-zero values.
        comment_str (str, optional): String denoting the start of a commented line. Defaults to '#'.
        comment_func (Callable, optional): Called with each pair of commented (test, true) lines; a non-empty return
            value is reported as the difference

    Returns:
        str: Diff output between result and expected CSVs, empty if they match
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'
    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    def is_comment(text: str) -> bool:
        # startswith handles blank lines and multi-character comment markers; an empty marker disables comments
        return bool(comment_str) and text.startswith(comment_str)

    # Process commented lines
    uncommented_lines: List[int] = []
    for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
        test_commented = is_comment(test_line)
        true_commented = is_comment(true_line)
        if test_commented and true_commented:
            if comment_func:
                output = comment_func(test_line, true_line)
                if output:
                    return output
        elif test_commented:
            return f'Commented line found in {test_csv} at line {n + 1} but not in {true_csv}'
        elif true_commented:
            return f'Commented line found in {true_csv} at line {n + 1} but not in {test_csv}'
        else:
            uncommented_lines.append(n)

    # Remove commented lines
    test_lines = [test_lines[line] for line in uncommented_lines]
    true_lines = [true_lines[line] for line in uncommented_lines]

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if not test_reader.fieldnames:
        return f'No CSV columns found in test output {test_csv}'
    if not true_reader.fieldnames:
        return f'No CSV columns found in test source {true_csv}'
    if test_reader.fieldnames != true_reader.fieldnames:
        # expected header first so the labels match the content they describe
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                                            fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric; short rows yield None, which float() rejects with TypeError
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
            except (TypeError, ValueError):
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')
                continue
            if abs(true_val) < zero_tol:
                fail: bool = not abs(test_val) < zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')

    return '\n'.join(diff_lines)


def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float): Tolerance for comparing floating-point values

    Returns:
        str: Combined stderr and stdout of `cgnsdiff`, empty if the files match
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    # shell=True is kept so a missing `cgnsdiff` is reported through stderr rather than raising;
    # each argument is quoted so paths containing spaces survive the shell
    proc = subprocess.run(' '.join(shlex.quote(arg) for arg in run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')


def diff_ascii(test_file: Path, true_file: Path, backend: str) -> str:
    """Compare ASCII results against an expected ASCII file

    Args:
        test_file (Path): Path to output ASCII file
        true_file (Path): Path to expected ASCII file
        backend (str): Backend name; with '/' replaced by '-', it is substituted for `{ceed_resource}` in the
            expected file before comparing

    Returns:
        str: Unified diff between result and expected ASCII files, empty if they match
    """
    tmp_backend: str = backend.replace('/', '-')
    true_str: str = true_file.read_text().replace('{ceed_resource}', tmp_backend)
    diff = list(difflib.unified_diff(test_file.read_text().splitlines(keepends=True),
                                     true_str.splitlines(keepends=True),
                                     fromfile=str(test_file),
                                     tofile=str(true_file)))
    return ''.join(diff)


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int, verbose: bool) -> str:
    """Format the console report for a single test case result

    Args:
        test_case (TestCase): Completed test case; `test_case.args` must already be set
        spec (TestSpec): Specification of test case
        mode (RunMode): Output mode
        backend (str): libCEED backend
        test (str): Name of test
        index (int): TAP index of this backend within the current subtest
        verbose (bool): In TAP mode, include the diagnostic block for passing runs as well

    Returns:
        str: TAP lines (TAP mode), or failure/error details only (JUnit mode)
    """
    # NOTE(review): these TAP lines carry a single leading space as written; TAP 13 subtests are conventionally
    # indented by four spaces - confirm the intended indentation.
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f' ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f' not ok {index} - {spec.name}, {backend} ({test_case.elapsed_sec} s)\n'
        else:
            output_str += f' ok {index} - {spec.name}, {backend} ({test_case.elapsed_sec} s)\n'
        if test_case.is_failure() or test_case.is_error() or verbose:
            # YAML diagnostic block
            output_str += ' ---\n'
            if spec.only:
                output_str += f' only: {",".join(spec.only)}\n'
            output_str += f' args: {test_case.args}\n'
            if spec.csv_ztol > 0:
                output_str += f' csv_ztol: {spec.csv_ztol}\n'
            if spec.csv_rtol > 0:
                output_str += f' csv_rtol: {spec.csv_rtol}\n'
            if spec.cgns_tol > 0:
                output_str += f' cgns_tol: {spec.cgns_tol}\n'
            for k, v in spec.key_values.items():
                output_str += f' {k}: {v}\n'
            if test_case.is_error():
                output_str += f' error: {test_case.errors[0]["message"]}\n'
            if test_case.is_failure():
                output_str += ' failures:\n'
                for failure in test_case.failures:
                    output_str += ' -\n'
                    output_str += f' message: {failure["message"]}\n'
                    if failure["output"]:
                        out = failure["output"].strip().replace('\n', '\n ')
                        output_str += f' output: |\n {out}\n'
            output_str += ' ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f' $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


def save_failure_artifact(suite_spec: SuiteSpec, file: Path) -> Path:
    """Copy a file into the suite's test failure artifacts directory

    Args:
        suite_spec (SuiteSpec): Test suite specification providing `test_failure_artifacts_path`
        file (Path): Path to the file to save

    Returns:
        Path: Path of the saved copy
    """
    save_path: Path = suite_spec.test_failure_artifacts_path / file.name
    shutil.copyfile(file, save_path)
    return save_path


# Environment for test subprocesses; installed per worker process by `init_process`.
my_env: Optional[Dict[str, str]] = None


def _make_test_env() -> Dict[str, str]:
    """Build the environment used to run test executables (current environment with CEED_ERROR_HANDLER=exit)"""
    env = os.environ.copy()
    env['CEED_ERROR_HANDLER'] = 'exit'
    return env


def _check_output_files(test_case: TestCase, suite_spec: SuiteSpec, ref_files: List[Path], kind: str,
                        fallback_suffix: Optional[str], diff_func: Callable[[Path, Path], str]) -> None:
    """Compare each expected output file with the same-named file in the current directory, recording failures

    Matching output files are deleted; differing ones are moved to the failure artifacts directory.

    Args:
        test_case (TestCase): Test case that receives failure information
        suite_spec (SuiteSpec): Specification of test suite
        ref_files (List[Path]): Expected output files
        kind (str): Label used in failure messages ('csv', 'cgns', or 'ascii')
        fallback_suffix (Optional[str]): Suffix for the backend-independent fallback reference file, or `None` to
            keep the suffix of the original reference
        diff_func (Callable[[Path, Path], str]): Called as `diff_func(out_file, ref_file)`; non-empty means mismatch
    """
    for ref_file in ref_files:
        out_name = ref_file.name
        out_file = Path.cwd() / out_name
        if not ref_file.is_file():
            # remove _{ceed_backend} from path name
            suffix = fallback_suffix if fallback_suffix is not None else ref_file.suffix
            ref_file = (ref_file.parent / ref_file.name.rsplit('_', 1)[0]).with_suffix(suffix)
        if not ref_file.is_file():
            test_case.add_failure_info(kind, output=f'{ref_file} not found')
        elif not out_file.is_file():
            test_case.add_failure_info(kind, output=f'{out_file} not found')
        else:
            diff = diff_func(out_file, ref_file)
            if diff:
                save_path: Path = suite_spec.test_failure_artifacts_path / out_name
                shutil.move(out_file, save_path)
                test_case.add_failure_info(f'{kind}: {save_path}', output=diff)
            else:
                out_file.unlink()


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec, verbose: bool = False) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite
        verbose (bool, optional): Print detailed output for all runs, not just failures. Defaults to False.

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # a standalone `{ceed_resource}` argument becomes the backend itself; embedded occurrences (e.g. in file
    # names) use the backend with '/' replaced by '-'
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        # `my_env` is installed by `init_process` in pool workers; build it here when called directly
        env = my_env if my_env is not None else _make_test_env()
        start: float = time.time()
        # shell=True so that quoted TESTARGS options are interpreted by the shell
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # expected output files are inferred from run arguments containing `ascii:` or `cgns:`
        ascii_outputs: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        ref_csvs: List[Path] = [suite_spec.get_output_path(test, arg.split(':')[1])
                                for arg in ascii_outputs if arg.endswith('.csv')]
        ref_ascii: List[Path] = [suite_spec.get_output_path(test, arg.split(':')[1])
                                 for arg in ascii_outputs if not arg.endswith('.csv')]
        cgns_outputs: List[str] = [arg for arg in run_args if 'cgns:' in arg]
        ref_cgns: List[Path] = [suite_spec.get_output_path(test, arg.split('cgns:')[-1]) for arg in cgns_outputs]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec, backend)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results (only reachable when the test actually ran, so `proc` and the refs are defined)
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)

        # per-spec tolerances override suite defaults when positive
        csv_ztol: float = spec.csv_ztol if spec.csv_ztol > 0 else suite_spec.csv_ztol
        csv_rtol: float = spec.csv_rtol if spec.csv_rtol > 0 else suite_spec.csv_rtol
        cgns_tol: float = spec.cgns_tol if spec.cgns_tol > 0 else suite_spec.cgns_tol
        # expected CSV output
        _check_output_files(test_case, suite_spec, ref_csvs, 'csv', '.csv',
                            lambda out, ref: diff_csv(out, ref, zero_tol=csv_ztol, rel_tol=csv_rtol))
        # expected CGNS output
        _check_output_files(test_case, suite_spec, ref_cgns, 'cgns', '.cgns',
                            lambda out, ref: diff_cgns(out, ref, cgns_tol=cgns_tol))
        # expected ASCII output
        _check_output_files(test_case, suite_spec, ref_ascii, 'ascii', None,
                            lambda out, ref: diff_ascii(out, ref, backend))

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index, verbose)

    return test_case, output_str


def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = _make_test_env()


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1, search: str = ".*", verbose: bool = False) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.
        search (str, optional): Regular expression used to match tests. Defaults to ".*".
        verbose (bool, optional): Print detailed output for all runs, not just failures. Defaults to False.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = [
        t for t in get_test_args(suite_spec.get_source_path(test)) if re.search(search, t.name, re.IGNORECASE)
    ]
    suite_spec.test_failure_artifacts_path.mkdir(parents=True, exist_ok=True)
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec, verbose))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        # collect results in submission order while the pool is still alive
        test_cases = []
        for i, (spec, subtest) in enumerate(zip(test_specs, async_outputs), start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    # NOTE(review): single-space indent as written; TAP 13 subtests conventionally use four spaces
                    print(f' 1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # spec.name equals each test case's category and stays valid when there are no backends
                print(f'{"" if subtest_ok else "not "}ok {i} - {spec.name}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # create the destination directory so writing does not depend on an earlier step having made it
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)