xref: /libCEED/tests/junit_common.py (revision 3451ca54fe05d334f1044b007a59894b1a9bed54)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from itertools import product
from math import isclose
import multiprocessing as mp
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
import time
from typing import Optional, Tuple, List
16
17sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
18from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
19
20
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The `type` keyword must be a class that subclasses both `str` and `Enum`
    (i.e. a StrEnum-style enum). It is stored on the action but deliberately
    NOT forwarded to `argparse.Action`, which prevents argparse's automatic
    (case-sensitive) type conversion.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        """Validate the enum type and convert string default(s) to enum members

        Args:
            option_strings: Option strings registered for this action
            dest: Namespace attribute to store the parsed value under
            type: Enum class used for conversion (must subclass both `str` and `Enum`)
            default: Default value: a string, an iterable of strings, or `None`

        Raises:
            ValueError: If `type` is not both a `str` and an `Enum` subclass
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # convert default(s) case-insensitively; a None default (no default
        # value) is passed through untouched instead of raising TypeError
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        elif default is not None:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type`
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert the parsed value (or list of values) to enum members, ignoring case"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
42
43
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case

    Instances are produced by `parse_test_line` from 'TESTARGS(...)' lines
    found in test source files.
    """
    # display name of the test case ('' when the TESTARGS line provides none)
    name: str
    # constraint tags parsed from only="..." (e.g. 'serial,int32'); empty means unconstrained
    only: List[str] = field(default_factory=list)
    # CLI arguments, possibly containing '{ceed_resource}' / '{nproc}' placeholders
    args: List[str] = field(default_factory=list)
50
51
class RunMode(str, Enum):
    """Supported output modes: `RunMode.TAP` or `RunMode.JUNIT`"""
    # delegate to plain-str behavior so members render as their raw values
    # ('tap'/'junit') in str() and format() calls
    __str__ = str.__str__
    __format__ = str.__format__
    TAP = 'tap'
    JUNIT = 'junit'
58
59
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite

    Subclasses must implement the three path-resolution methods; the remaining
    hooks have no-op defaults and may be overridden to customize skipping,
    expected failures, and allowed console output.
    """
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case; default does nothing

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Default never skips.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Default treats no failure as allowed.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Default reports no required-failure message ('') and success (True).

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output; default False

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
160
161
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # shutil.which performs the PATH lookup directly; the previous approach of
    # running the command through a shell and searching stderr for 'not found'
    # was shell- and locale-dependent
    return shutil.which('cgnsdiff') is not None
175
176
def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
188
189
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidate prefixes; an empty tuple
    # yields False, matching the behavior of any() over an empty list
    return base.startswith(tuple(prefixes))
201
202
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, but keep double-quoted segments intact
    tokens: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if tokens[0] == 'TESTARGS':
        # bare TESTARGS keyword: no name or constraints, only CLI arguments
        return TestSpec(name='', args=tokens[1:])
    # extract the text between 'TESTARGS(' and the final ')'
    start = tokens[0].index('TESTARGS(') + 9
    stop = tokens[0].rindex(')')
    raw_test_args: str = tokens[0][start:stop]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    matches = re.findall(r'([^,=]+)(=)"([^"]*)"', raw_test_args)
    test_args: dict = dict(''.join(group).split('=') for group in matches)
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(tokens) > 1:
        return TestSpec(name=name, only=constraints, args=tokens[1:])
    return TestSpec(name=name, only=constraints)
224
225
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map file extension to the comment marker that prefixes TESTARGS lines
    comment_strs: dict = {'.c': '//', '.cpp': '//', '.py': '#', '.usr': 'C_', '.f90': '! '}
    comment_str: Optional[str] = comment_strs.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # NOTE: slice off the comment prefix rather than calling line.strip(comment_str):
    # strip() removes any of those CHARACTERS from both ends, which could also eat
    # trailing characters of the final argument (e.g. a trailing 'C' or '_' for
    # '.usr' sources, or a trailing '!' for '.f90' sources)
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
253
254
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty string if they match within tolerances
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()

    # column headers must match exactly before any values are compared;
    # the expected header is the 'from' side, the found header the 'to' side
    if test_lines[0] != true_lines[0]:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: List[str] = list()
    column_names: List[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: List[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: List[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # first CSV column identifies the step; report the whole field, not just
        # the first character of the line (which was wrong for multi-digit steps)
        step: str = true_line.split(',')[0].strip()
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            true_zero: bool = abs(true_val) < zero_tol
            test_zero: bool = abs(test_val) < zero_tol
            fail: bool = False
            if true_zero:
                # expected value is numerically zero, so the result must be too
                fail = not test_zero
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
290
291
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files, empty if they match
    """
    my_env: dict = os.environ.copy()

    # quote the paths so file names containing spaces or shell metacharacters
    # survive the shell invocation intact (the command is run through a shell)
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}',
                           shlex.quote(str(test_cgns)), shlex.quote(str(true_cgns))]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    # cgnsdiff reports differences on stdout and errors on stderr; include both
    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
313
314
def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format a completed test case result for console output

    In TAP mode, produces a full TAP result line plus a YAML-ish detail block
    for every case. In JUnit mode, produces output only for errored or failed
    cases (successes are silent).

    Args:
        test_case (TestCase): Completed JUnit test case (expects `test_case.args` to be set)
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend used for this run
        test (str): Name of test
        index (int): Index of this backend within the current subtest (1-based)

    Returns:
        str: Formatted console output (possibly empty in JUnit mode)
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        # detail block delimited by '---' ... '...'
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            # only the first recorded error is reported
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                # NOTE(review): the message is repeated under a nested 'message:'
                # key -- appears intentional for the structured detail block
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    # indent continuation lines to keep the YAML block literal valid
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str
354
355
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    NOTE: reads the module-global `my_env` set by `init_process`, so this must
    execute inside a worker process initialized with `init_process`.

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # substitute a stand-alone '{ceed_resource}' argument with the backend name
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # substitute '{ceed_resource}' embedded inside longer arguments (e.g. output
    # file names), replacing '/' so the result is a valid single path component
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        # no explicit '{nproc}' placeholder: launch through mpiexec instead
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        # skipped before running: record an empty, zero-duration result
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # `my_env` is the per-worker environment prepared by `init_process`
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # collect expected-output reference files implied by 'ascii:'/'cgns:' args
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            # failed as required: mark via status so later classification is skipped
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        # compare stdout against the expected output file, if one exists
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            # unexpected console output with no reference file is a failure
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # matching generated output is cleaned up from the working dir
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str
486
487
def init_process():
    """Initialize a multiprocessing worker process

    Publishes the module-global `my_env`: a copy of the parent environment with
    the libCEED error handler configured to exit on error, read by `run_test`.
    """
    global my_env
    worker_env = os.environ.copy()
    worker_env['CEED_ERROR_HANDLER'] = 'exit'
    my_env = worker_env
494
495
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Test cases run in a process pool; in TAP mode the results are printed
    incrementally as one TAP subtest per `TestSpec`, with one result line per
    backend inside each subtest.

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP header: protocol version plus the plan (one subtest per spec)
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    # each worker sets up its own `my_env` via the `init_process` initializer
    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # one async task per (spec, backend) pair; NOTE: local annotations are
        # never evaluated at runtime, so the `mp.AsyncResult` reference (actually
        # `multiprocessing.pool.AsyncResult`) is documentation only
        async_outputs: List[List[mp.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                # blocks until the corresponding worker task finishes
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    # print the subtest header exactly once, before its first result
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # summary line covering the whole subtest (all backends)
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
541
542
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_path: Path = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    # ensure the destination directory exists; the default 'build' directory
    # may not have been created yet, and write_text would otherwise fail
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(to_xml_report_string([test_suite]))
553
554
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
565