xref: /libCEED/tests/junit_common.py (revision 1ce8139f46e7307779d69b9ddba9e4b0375a4d52)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from itertools import product
from math import isclose
import multiprocessing as mp
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
19
20
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        """Validate the enum type and lower-case any provided default value(s)

        Args:
            option_strings: Option strings registered by `argparse`
            dest: Attribute name on the namespace to store parsed values into
            type: Enum class used for conversion; must subclass both `str` and `Enum`
            default: Default value(s) to convert, a string, list of strings, or `None`

        Raises:
            ValueError: If `type` is not both a `str` and an `Enum` subclass
        """
        # `type` intentionally shadows the builtin: argparse passes it by this keyword
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # normalize default(s) to enum members; leave an absent (None) default untouched
        if default is None:
            pass
        elif isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type` to the base class
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert parsed value(s) to lower case, then to the stored enum type"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
42
43
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # Test case name, parsed from TESTARGS(name="...") (may be empty)
    name: str
    # Run constraints parsed from TESTARGS(only="..."), e.g. ['serial', 'int32']
    only: List[str] = field(default_factory=list)
    # Command-line arguments; may contain '{ceed_resource}'/'{nproc}' placeholders
    args: List[str] = field(default_factory=list)
50
51
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # Use plain-str formatting so members render as their value (e.g. 'tap')
    # in f-strings and str() rather than as 'RunMode.TAP'
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
58
59
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite

    Subclasses must implement the three path lookups (`get_source_path`,
    `get_run_path`, `get_output_path`); the remaining hooks have permissive
    defaults and may be overridden as needed.
    """
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        # default: no post-test action
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        # default: never skip
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        # default: no failures are allowed
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        # default: no failure required (empty message, vacuously satisfied)
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        # default: unexpected stdout is treated as a failure
        return False
160
161
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # `shutil.which` performs the PATH lookup directly; scraping the shell's
    # stderr for 'not found' is shell- and locale-dependent and needlessly
    # executes the program
    return shutil.which('cgnsdiff') is not None
175
176
def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, determines whether `base` contains at least one of the given substrings

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
188
189
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of prefixes, scanning them in one call
    return base.startswith(tuple(prefixes))
201
202
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, but keep double-quoted substrings intact within a token
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        # bare 'TESTARGS' with no (name=...,only=...) annotations
        return TestSpec(name='', args=args[1:])
    # extract the text between 'TESTARGS(' (9 characters) and the last ')' of the first token
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    # 'only' holds a comma-separated list of run constraints
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)
224
225
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map file extension to the comment prefix that marks a TESTARGS line
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # NOTE: slice off the exact prefix rather than `line.strip(comment_str)` --
    # `str.strip` treats its argument as a character *set* and would also strip
    # matching characters from the end of the line (e.g. a trailing '/' or '_')
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
253
254
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty if within tolerances
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()

    # header mismatch: report the differing column rows and stop
    if test_lines[0] != true_lines[0]:
        return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
                       tofile='found CSV columns', fromfile='expected CSV columns'))

    diff_lines: List[str] = list()
    column_names: List[str] = true_lines[0].strip().split(',')
    # NOTE(review): zip() silently ignores extra rows if the files have
    # different row counts -- confirm that is intended
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: List[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: List[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # first raw field labels the row (e.g. the time step); previously this
        # used `true_line[0]`, which is only the first *character* of the line
        step_label: str = true_line.strip().split(',')[0]
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            true_zero: bool = abs(true_val) < zero_tol
            test_zero: bool = abs(test_val) < zero_tol
            fail: bool = False
            if true_zero:
                # expected zero: the result must also be (numerically) zero
                fail = not test_zero
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step_label}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
290
291
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    env: dict = os.environ.copy()

    # invoke cgnsdiff with node data comparison (-d) and the given tolerance (-t)
    command: str = f'cgnsdiff -d -t {tolerance} {test_cgns} {true_cgns}'
    proc = subprocess.run(command,
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=env)

    # combine both streams: cgnsdiff reports differences and errors on either
    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
313
314
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of test case
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its accumulated console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [suite_spec.get_run_path(test), *spec.args]

    # substitute placeholder arguments with their concrete values
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        # no explicit {nproc} placeholder: launch the executable under mpiexec
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        # pre-skipped tests are recorded without being executed
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='')
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # `my_env` is a module global created by `init_process()` in each pool worker
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True)
        # collect expected-output files referenced by 'ascii:'/'cgns:' arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in spec.args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in spec.args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            # expected failure occurred: mark status but do not fail the case
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        # `proc` is only set on the executed (non-pre-skipped) path, which is
        # the only way to reach here un-skipped
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        # compare stdout against the expected-output file, if one exists
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output: compare and clean up on success
        for ref_csv in ref_csvs:
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / ref_csv.name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    (Path.cwd() / ref_csv.name).unlink()
        # expected CGNS output: compare and clean up on success
        for ref_cgn in ref_cgns:
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / ref_cgn.name, ref_cgn)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / ref_cgn.name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = ''
    # print output
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        output_str += f'# Test: {spec.name}\n'
        if spec.only:
            # NOTE(review): no trailing '\n' here, so the next '# $' marker is
            # appended to the same line -- confirm this is intended
            output_str += f'# Only: {",".join(spec.only)}'
        output_str += f'# $ {test_case.args}\n'
        if test_case.is_skipped():
            output_str += ('ok {} - SKIP: {}\n'.format(index,
                                                       (test_case.skipped[0]['message'] or 'NO MESSAGE').strip())) + '\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'not ok {index}\n'
            if test_case.is_error():
                output_str += f'  ERROR: {test_case.errors[0]["message"]}\n'
            if test_case.is_failure():
                for i, failure in enumerate(test_case.failures):
                    output_str += f'  FAILURE {i}: {failure["message"]}\n'
                    output_str += f'    Output: \n{failure["output"]}\n'
        else:
            output_str += f'ok {index} - PASS\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())

    return test_case, output_str
464
465
def init_process():
    """Initialize a multiprocessing pool worker's subprocess environment"""
    global my_env
    # copy the parent environment and force libCEED to exit on error
    env = os.environ.copy()
    env['CEED_ERROR_HANDLER'] = 'exit'
    my_env = env
472
473
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP plan line: total number of (spec, backend) combinations
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    # list of (index, test, spec, backend, mode, nproc, suite_spec) argument tuples
    # generated from the cross product of test specs and backends (1-based index)
    args: List[tuple] = [(i, test, spec, backend, mode, nproc, suite_spec)
                         for i, (spec, backend) in enumerate(product(test_specs, ceed_backends), start=1)]

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # dispatch all test cases to the worker pool
        async_outputs: list = [pool.apply_async(run_test, argv) for argv in args]

        test_cases = []
        # collect results in submission order so TAP output stays sequential
        for async_output in async_outputs:
            test_case, print_output = async_output.get()
            test_cases.append(test_case)
            print(print_output, end='')

    return TestSuite(test, test_cases)
507
508
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file: Path = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # ensure the target directory exists (e.g. 'build/' on a clean checkout)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(to_xml_report_string([test_suite]))
519
520
def has_failures(test_suite: TestSuite) -> bool:
    """Report whether the given `TestSuite` contains any failed or errored test cases

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
531