xref: /libCEED/tests/junit_common.py (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import subprocess
import multiprocessing as mp
import multiprocessing.pool  # imported for the `mp.pool.AsyncResult` annotation
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Argparse action that lowercases input values before converting them to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
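
# Example usage (illustrative sketch; the parser and option names are hypothetical):
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--mode', type=RunMode, action=CaseInsensitiveEnumAction, default='tap')
#   parser.parse_args(['--mode', 'TAP']).mode is RunMode.TAP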


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
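
# Because RunMode mixes in `str` and reuses str.__str__/str.__format__, members render as
# their raw values, e.g. f'{RunMode.TAP}' == 'tap' (illustrative).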


class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback function run after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and whether it failed as expected

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    my_env: dict = os.environ.copy()
    # note: relies on the shell reporting 'not found' on stderr when the executable
    # is missing; `shutil.which('cgnsdiff')` would be a stricter check
    proc = subprocess.run('cgnsdiff',
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)
    return 'not found' not in proc.stderr.decode('utf-8')


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function that checks whether any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
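
    Example (illustrative; the backend string is hypothetical)::

        >>> contains_any('/cpu/self/ref/serial', ['occa', 'ref'])
        True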
    """
    return any((sub in base for sub in substrings))


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function that checks whether the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
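
    Example (illustrative; the backend strings are hypothetical)::

        >>> startswith_any('/gpu/cuda/shared', ['/gpu/cuda', '/gpu/hip'])
        True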
    """
    return any((base.startswith(prefix) for prefix in prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
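
    Example (illustrative; the test name and arguments are hypothetical)::

        >>> parse_test_line('TESTARGS(name="mass",only="int32") -ceed {ceed_resource}')
        TestSpec(name='mass', only=['int32'], args=['-ceed', '{ceed_resource}'])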
    """
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    # extract arguments from between the parentheses; 9 == len('TESTARGS(')
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)


def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # note: `str.strip(chars)` removes any leading/trailing characters in `chars`, not a
    # literal prefix; this suffices here because lines are pre-filtered by `startswith`
    return [parse_test_line(line.strip(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
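
# Example (illustrative; the file contents are hypothetical): a C source containing the line
#   //TESTARGS(name="quadrature") -ceed {ceed_resource} -test 1
# yields [TestSpec(name='quadrature', only=[], args=['-ceed', '{ceed_resource}', '-test', '1'])].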


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()

    if test_lines[0] != true_lines[0]:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: List[str] = list()
    column_names: List[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: List[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: List[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            true_zero: bool = abs(true_val) < zero_tol
            test_zero: bool = abs(test_val) < zero_tol
            fail: bool = False
            if true_zero:
                fail = not test_zero
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                # report the first column value (typically the step), not the line's first character
                diff_lines.append(f'step: {true_vals[0]}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
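
# Example usage (illustrative; the file names are hypothetical):
#   mismatch = diff_csv(Path('run_output.csv'), Path('expected_output.csv'))
#   if mismatch:
#       print(mismatch)  # one 'step/column/expected/got' line per out-of-tolerance value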


def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
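
# Example usage (illustrative; the file names are hypothetical, and `cgnsdiff` must be on PATH):
#   if has_cgnsdiff():
#       mismatch = diff_cgns(Path('run_output.cgns'), Path('expected_output.cgns'))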


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its rendered console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List[str] = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # `my_env` is a global initialized per worker process by `init_process`
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # fall back to a backend-independent reference: remove _{ceed_backend} from the file name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # fall back to a backend-independent reference: remove _{ceed_backend} from the file name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = ''
    # print output
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        output_str += '      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += '      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO OUTPUT').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO OUTPUT').strip())

    return test_case, output_str
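
# Illustrative TAP fragment produced for one spec (test name, backend, and args are hypothetical):
#     ok 1 - mass matrix, /cpu/self/ref
#       ---
#       args: ./t500 -ceed /cpu/self/ref
#       ...
#     not ok 2 - mass matrix, /gpu/cuda/gen
#       ---
#       args: ./t500 -ceed /gpu/cuda/gen
#       num_failures: 1
#       failure_0: stderr
#         message: stderr
#       ...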


def init_process():
    """Initialize a multiprocessing worker process"""
    # set up the worker environment, including the libCEED error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode is RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode is RunMode.TAP:
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
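
# Example driver (illustrative sketch; `MySuiteSpec` is a hypothetical SuiteSpec subclass):
#   suite = run_tests('t500', ['/cpu/self/ref'], RunMode.JUNIT, nproc=1,
#                     suite_spec=MySuiteSpec())
#   write_junit_xml(suite, None)
#   sys.exit(1 if has_failures(suite) else 0)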


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed or errored

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed or errored
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)
560