xref: /libCEED/tests/junit_common.py (revision 37eda346557d6c6a68044736ef6477e646c46425)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Optional
14
15sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
16from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
17
18
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default=None, **kwargs):
        """Validate the enum type and eagerly convert any default value

        Args:
            option_strings: Option strings registered for this argument
            dest: Namespace attribute the parsed value is stored under
            type: Enum class to convert values into; must subclass both `str` and `Enum`
            default: Default value(s) as string(s), converted through the enum;
                `None` (argparse's implicit default) is passed through unchanged

        Raises:
            ValueError: If `type` is not a str-based Enum
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # convert the default up front so the namespace default is an enum member
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        elif default is not None:
            default = [self.enum_type(v.lower()) for v in default]
        # deliberately do not forward `type` to argparse, preventing automatic
        # type conversion; __call__ performs the case-insensitive conversion instead
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Lower-case the raw value(s), convert to the enum, and store on the namespace"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
40
41
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case

    Attributes:
        name: Human-readable name of the test case ('' when the TESTARGS line is unnamed)
        only: Constraint tags parsed from a TESTARGS only="..." clause (e.g. 'serial', 'int32')
        args: Command-line arguments to pass when running the test
    """
    name: str
    only: list[str] = field(default_factory=list)  # constraint tags from only="..."
    args: list[str] = field(default_factory=list)  # CLI arguments from the TESTARGS line
48
49
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    TAP = 'tap'
    JUNIT = 'junit'

    # inherit plain-string rendering so printing/formatting a member yields its value
    __str__ = str.__str__
    __format__ = str.__format__
56
57
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Locate the source file of a test

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the test's source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Locate the built executable of a test

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the built test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Locate an expected-output reference file for a test

        Args:
            test (str): Name of test
            output_file (str): File name of the expected output file

        Returns:
            Path: Location of the expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback invoked after each test case runs; default does nothing

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Decide, before running, whether a test case should be skipped

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Reason for skipping, or `None` to run the test case
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Decide, from its stderr output, whether a failed test case is allowed to fail

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Reason for skipping, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Report whether a test case is expected to fail and whether it did fail as expected

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Expected failure string (empty if none required) and
                whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Report whether a test is allowed to print console output; default is no

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
158
159
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program on the current PATH

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # shutil.which searches PATH directly; the previous approach ran the
    # program via a shell and grepped stderr for 'not found', which is
    # shell- and locale-dependent and needlessly executes the tool
    return shutil.which('cgnsdiff') is not None
173
174
def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, reports whether `base` contains at least one of the given substrings

    Args:
        base (str): String to search in
        substrings (list[str]): Candidate substrings

    Returns:
        bool: True if at least one candidate occurs in `base`
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
186
187
def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, reports whether `base` starts with at least one of the given prefixes

    Args:
        base (str): String to test
        prefixes (list[str]): Candidate prefixes

    Returns:
        bool: True if `base` starts with any of the candidate prefixes
    """
    # str.startswith accepts a tuple of prefixes; an empty tuple yields False,
    # matching any() over an empty iterable
    return base.startswith(tuple(prefixes))
199
200
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted chunks intact
    tokens: list[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    header: str = tokens[0]
    # bare 'TESTARGS' with no parenthesized options: unnamed, unconstrained spec
    if header == 'TESTARGS':
        return TestSpec(name='', args=tokens[1:])
    # extract the content between 'TESTARGS(' and the final ')'
    raw_test_args: str = header[header.index('TESTARGS(') + 9:header.rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = {}
    for groups in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args):
        key, value = ''.join(groups).split('=')
        test_args[key] = value
    constraints: list[str] = test_args['only'].split(',') if 'only' in test_args else []
    return TestSpec(name=test_args['name'], only=constraints, args=tokens[1:])
221
222
def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map source extension -> comment prefix that marks TESTARGS lines
    comment_strs: dict = {'.c': '//', '.cpp': '//', '.py': '#', '.usr': 'C_', '.f90': '! '}
    comment_str = comment_strs.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    prefix: str = f'{comment_str}TESTARGS'
    # use removeprefix() rather than strip(comment_str): strip() treats its
    # argument as a character SET and also removed matching characters from the
    # END of the line, corrupting argument text ending in e.g. '/' or 'C';
    # startswith(prefix) above guarantees the exact leading comment marker
    return [parse_test_line(line.removeprefix(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(prefix)] or [TestSpec('', args=['{ceed_resource}'])]
250
251
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty if they match within tolerances
    """
    test_lines: list[str] = test_csv.read_text().splitlines()
    true_lines: list[str] = true_csv.read_text().splitlines()

    # header mismatch: report the header diff, value comparison is meaningless
    if test_lines[0] != true_lines[0]:
        # expected header is the baseline ('from'), test output is the new text ('to');
        # the labels were previously swapped relative to the argument order
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = list()
    column_names: list[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # first CSV field of the expected row identifies the step in diff output;
        # previously `true_line[0]` took only the first CHARACTER of the row
        step: str = true_line.split(',')[0].strip()
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            if abs(true_val) < zero_tol:
                # expected value is effectively zero, so the result must be too
                fail: bool = abs(test_val) >= zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
287
288
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files
    """
    env: dict = os.environ.copy()

    # -d compares data, -t sets the floating-point comparison tolerance
    command: str = f'cgnsdiff -d -t {tolerance} {test_cgns} {true_cgns}'
    result = subprocess.run(command,
                            shell=True,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            env=env)

    # any output (on either stream) indicates a difference or an error
    return result.stderr.decode('utf-8') + result.stdout.decode('utf-8')
310
311
def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        # TAP plan line '1..N': one test point per (spec, backend) pair
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    # make libCEED errors terminate the child process so they surface in returncode/stderr
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    index: int = 1  # 1-based TAP test point counter
    for spec in test_specs:
        for ceed_resource in ceed_backends:
            run_args: list = [suite_spec.get_run_path(test), *spec.args]

            # substitute placeholder arguments from the TESTARGS line
            if '{ceed_resource}' in run_args:
                run_args[run_args.index('{ceed_resource}')] = ceed_resource
            if '{nproc}' in run_args:
                run_args[run_args.index('{nproc}')] = f'{nproc}'
            elif nproc > 1 and source_path.suffix != '.py':
                # no explicit {nproc} placeholder: wrap non-Python tests in mpiexec
                run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

            # run test
            skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
            if skip_reason:
                # skipped before running: record an empty, zero-duration case
                test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                               elapsed_sec=0,
                                               timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                               stdout='',
                                               stderr='')
                test_case.add_skipped_info(skip_reason)
            else:
                start: float = time.time()
                proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                                      shell=True,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      env=my_env)

                # NOTE(review): classname is passed as a Path, not str; junit_xml
                # appears to accept it, but confirm before changing report handling
                test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                     classname=source_path.parent,
                                     elapsed_sec=time.time() - start,
                                     timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                     stdout=proc.stdout.decode('utf-8'),
                                     stderr=proc.stderr.decode('utf-8'),
                                     allow_multiple_subelements=True)
                # collect reference CSV outputs for any 'ascii:<file>' arguments
                ref_csvs: list[Path] = []
                output_files: list[str] = [arg for arg in spec.args if 'ascii:' in arg]
                if output_files:
                    ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
                # collect reference CGNS outputs for any 'cgns:<file>' arguments
                ref_cgns: list[Path] = []
                output_files = [arg for arg in spec.args if 'cgns:' in arg]
                if output_files:
                    ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
                # reference file for expected console output
                ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
                suite_spec.post_test_hook(test, spec)

            # check allowed failures
            if not test_case.is_skipped() and test_case.stderr:
                skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, ceed_resource, test_case.stderr)
                if skip_reason:
                    test_case.add_skipped_info(skip_reason)

            # check required failures
            if not test_case.is_skipped():
                required_message, did_fail = suite_spec.check_required_failure(test, spec, ceed_resource, test_case.stderr)
                if required_message and did_fail:
                    # expected failure observed: record via status, not as a failure
                    test_case.status = f'fails with required: {required_message}'
                elif required_message:
                    test_case.add_failure_info(f'required failure missing: {required_message}')

            # classify other results
            if not test_case.is_skipped() and not test_case.status:
                if test_case.stderr:
                    test_case.add_failure_info('stderr', test_case.stderr)
                if proc.returncode != 0:
                    test_case.add_error_info(f'returncode = {proc.returncode}')
                if ref_stdout.is_file():
                    # compare console output against the reference .out file
                    diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                     test_case.stdout.splitlines(keepends=True),
                                                     fromfile=str(ref_stdout),
                                                     tofile='New'))
                    if diff:
                        test_case.add_failure_info('stdout', output=''.join(diff))
                elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
                    # console output with no reference file and no allowance is a failure
                    test_case.add_failure_info('stdout', output=test_case.stdout)
                # expected CSV output
                for ref_csv in ref_csvs:
                    if not ref_csv.is_file():
                        test_case.add_failure_info('csv', output=f'{ref_csv} not found')
                    else:
                        # generated CSVs are written to the current working directory
                        diff: str = diff_csv(Path.cwd() / ref_csv.name, ref_csv)
                        if diff:
                            test_case.add_failure_info('csv', output=diff)
                        else:
                            # matching output: clean up the generated file
                            (Path.cwd() / ref_csv.name).unlink()
                # expected CGNS output
                for ref_cgn in ref_cgns:
                    if not ref_cgn.is_file():
                        test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
                    else:
                        diff = diff_cgns(Path.cwd() / ref_cgn.name, ref_cgn)
                        if diff:
                            test_case.add_failure_info('cgns', output=diff)
                        else:
                            (Path.cwd() / ref_cgn.name).unlink()

            # store result
            test_case.args = ' '.join(str(arg) for arg in run_args)
            test_cases.append(test_case)

            if mode is RunMode.TAP:
                # print incremental output if TAP mode
                print(f'# Test: {spec.name}')
                if spec.only:
                    print('# Only: {}'.format(','.join(spec.only)))
                print(f'# $ {test_case.args}')
                if test_case.is_skipped():
                    print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
                elif test_case.is_failure() or test_case.is_error():
                    print(f'not ok {index}')
                    if test_case.is_error():
                        print(f'  ERROR: {test_case.errors[0]["message"]}')
                    if test_case.is_failure():
                        for i, failure in enumerate(test_case.failures):
                            print(f'  FAILURE {i}: {failure["message"]}')
                            print(f'    Output: \n{failure["output"]}')
                else:
                    print(f'ok {index} - PASS')
                sys.stdout.flush()
            else:
                # print error or failure information if JUNIT mode
                if test_case.is_error() or test_case.is_failure():
                    print(f'Test: {test} {spec.name}')
                    print(f'  $ {test_case.args}')
                    if test_case.is_error():
                        print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
                        print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
                    if test_case.is_failure():
                        for failure in test_case.failures:
                            print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
                            print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
                sys.stdout.flush()
            index += 1

    return TestSuite(test, test_cases)
471
472
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    # fall back to the conventional build/ location when no path is given
    target: Path = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    target.write_text(to_xml_report_string([test_suite]))
483
484
485def has_failures(test_suite: TestSuite) -> bool:
486    """Check whether any test cases in a `TestSuite` failed
487
488    Args:
489        test_suite (TestSuite): JUnit `TestSuite` to check
490
491    Returns:
492        bool: True if any test cases failed
493    """
494    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)
495