xref: /libCEED/tests/junit_common.py (revision 2fee3251344e325ffa4c1e5d9ff42b3fc5f53b0c)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import Optional
14
15sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
16from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
17
18
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        """Validate the enum type and lower-case the default value(s)

        Args:
            option_strings: Option strings passed through to `argparse.Action`
            dest: Destination attribute name on the namespace
            type: Enum class (must subclass both `str` and `Enum`)
            default: Default value -- `None`, a single string, or an iterable of strings

        Raises:
            ValueError: If `type` is not a str-based Enum
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # default may legitimately be None (argparse's usual default); leave it untouched
        if default is None:
            pass
        elif isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type` to argparse
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert parsed value(s) to the enum type, case-insensitively"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
40
41
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case

    Parsed from a TESTARGS line by `parse_test_line`.
    """
    # human-readable test case name (the TESTARGS `name="..."` option, may be empty)
    name: str
    # constraint tags from the TESTARGS `only="..."` option (e.g. 'serial'); empty means no constraint
    only: list[str] = field(default_factory=list)
    # command-line arguments for the test executable; may contain the
    # placeholders '{ceed_resource}' and '{nproc}' substituted by `run_tests`
    args: list[str] = field(default_factory=list)
48
49
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    TAP = 'tap'
    JUNIT = 'junit'

    # render members as their plain string value ('tap'), not 'RunMode.TAP',
    # in both str() and f-string formatting
    __str__ = str.__str__
    __format__ = str.__format__
56
57
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""

    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Resolve the path to the source file of a test

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the test's source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Resolve the path to the compiled executable of a test

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the built test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Resolve the path to a reference (expected) output file for a test

        Args:
            test (str): Name of test
            output_file (str): File name of the expected output file

        Returns:
            Path: Location of the expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback invoked after each test case runs; default is a no-op

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Decide, before running, whether a test case should be skipped

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Reason for skipping, or `None` to run the test case
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Decide, from its stderr output, whether a failed test case is an allowed failure

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Reason for allowing the failure, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Determine whether a test case is required to fail, and whether it failed as expected

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Expected-failure message (empty if no failure required) and
                whether that message was found in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Determine whether a test may print to stdout without an expected-output file

        Args:
            test (str): Name of test

        Returns:
            bool: True if console output is permitted for this test
        """
        return False
158
159
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # shutil.which performs the PATH lookup directly; the previous approach ran
    # `cgnsdiff` through a shell and searched stderr for 'not found', which is
    # shell- and locale-dependent
    return shutil.which('cgnsdiff') is not None
173
174
def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, reports whether `base` contains at least one of the given substrings

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
186
187
def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of prefixes, testing all of them in a
    # single C-level call; an empty tuple correctly yields False
    return base.startswith(tuple(prefixes))
199
200
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # tokenize, keeping double-quoted spans together as single tokens
    tokens: list[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if tokens[0] == 'TESTARGS':
        # bare TESTARGS keyword: no name or constraints, remainder is the arg list
        return TestSpec(name='', args=tokens[1:])
    # extract the option text between 'TESTARGS(' and the final ')'
    spec_body: str = tokens[0][tokens[0].index('TESTARGS(') + 9:tokens[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    options: dict = dict(''.join(group).split('=') for group in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", spec_body))
    # tokens[1:] is [] when there are no CLI arguments, matching the TestSpec default
    return TestSpec(name=options.get('name', ''),
                    only=options['only'].split(',') if 'only' in options else [],
                    args=tokens[1:])
222
223
def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # comment marker that introduces TESTARGS lines, per source language
    comment_prefixes: dict = {
        '.c': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str = comment_prefixes.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # removeprefix drops exactly the leading comment marker; the previous
    # strip(comment_str) treated the marker as a character *set* and could
    # also delete matching characters from the end of the line
    return [parse_test_line(line.removeprefix(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
251
252
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty string if files match
    """
    test_lines: list[str] = test_csv.read_text().splitlines()
    true_lines: list[str] = true_csv.read_text().splitlines()

    # header mismatch: report the differing column rows and stop
    # (expected lines are the 'from' side, found lines the 'to' side)
    if test_lines[0] != true_lines[0]:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = []
    # zip() below silently truncates to the shorter file, so report a row-count
    # mismatch explicitly
    if len(test_lines) != len(true_lines):
        diff_lines.append(f'expected {len(true_lines) - 1} data rows, found {len(test_lines) - 1}')
    column_names: list[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # first column labels the row (e.g. the time step); keep its raw text for reporting
        # (the previous `true_line[0]` reported only the first *character* of the row)
        step: str = true_line.strip().split(',')[0].strip()
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            if abs(true_val) < zero_tol:
                # expected value is numerically zero: result must also be zero
                fail: bool = abs(test_val) >= zero_tol
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
288
289
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files, empty string if files match
    """
    my_env: dict = os.environ.copy()

    run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    try:
        # argument-list invocation (no shell) is robust to spaces or shell
        # metacharacters in the file paths
        proc = subprocess.run(run_args,
                              shell=False,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)
    except FileNotFoundError:
        # preserve the string-returning contract when cgnsdiff is not installed,
        # mirroring the shell's diagnostic
        return 'cgnsdiff: command not found'

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
311
312
def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        # TAP plan line: '1..<total number of test points>'
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    # NOTE(review): presumably makes libCEED terminate on error instead of
    # continuing -- confirm against libCEED error-handler documentation
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    index: int = 1
    for spec in test_specs:
        for ceed_resource in ceed_backends:
            run_args: list = [suite_spec.get_run_path(test), *spec.args]

            # substitute placeholders from the TESTARGS arg list
            if '{ceed_resource}' in run_args:
                run_args[run_args.index('{ceed_resource}')] = ceed_resource
            if '{nproc}' in run_args:
                run_args[run_args.index('{nproc}')] = f'{nproc}'
            elif nproc > 1 and source_path.suffix != '.py':
                # no explicit '{nproc}' placeholder: launch non-Python tests under mpiexec
                run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

            # run test
            skip_reason: str = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
            if skip_reason:
                # skipped before execution: record an empty, zero-duration case
                test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                               elapsed_sec=0,
                                               timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                               stdout='',
                                               stderr='')
                test_case.add_skipped_info(skip_reason)
            else:
                start: float = time.time()
                proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                                      shell=True,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      env=my_env)

                test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                     classname=source_path.parent,
                                     elapsed_sec=time.time() - start,
                                     timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                     stdout=proc.stdout.decode('utf-8'),
                                     stderr=proc.stderr.decode('utf-8'),
                                     allow_multiple_subelements=True)
                # collect expected-output references: args containing 'ascii:' name CSV
                # outputs, args containing 'cgns:' name CGNS outputs
                ref_csvs: list[Path] = []
                output_files: list[str] = [arg for arg in spec.args if 'ascii:' in arg]
                if output_files:
                    ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
                ref_cgns: list[Path] = []
                output_files = [arg for arg in spec.args if 'cgns:' in arg]
                if output_files:
                    ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
                ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
                suite_spec.post_test_hook(test, spec)

            # check allowed failures
            # (proc/ref_* are only bound in the non-skipped branch; all uses below
            # are guarded by is_skipped() checks)
            if not test_case.is_skipped() and test_case.stderr:
                skip_reason: str = suite_spec.check_post_skip(test, spec, ceed_resource, test_case.stderr)
                if skip_reason:
                    test_case.add_skipped_info(skip_reason)

            # check required failures
            if not test_case.is_skipped():
                required_message, did_fail = suite_spec.check_required_failure(
                    test, spec, ceed_resource, test_case.stderr)
                if required_message and did_fail:
                    # expected failure occurred: mark as passing with a status note
                    test_case.status = f'fails with required: {required_message}'
                elif required_message:
                    # expected failure did NOT occur: that is itself a failure
                    test_case.add_failure_info(f'required failure missing: {required_message}')

            # classify other results
            if not test_case.is_skipped() and not test_case.status:
                if test_case.stderr:
                    test_case.add_failure_info('stderr', test_case.stderr)
                if proc.returncode != 0:
                    test_case.add_error_info(f'returncode = {proc.returncode}')
                if ref_stdout.is_file():
                    # compare captured stdout against the expected .out file
                    diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                     test_case.stdout.splitlines(keepends=True),
                                                     fromfile=str(ref_stdout),
                                                     tofile='New'))
                    if diff:
                        test_case.add_failure_info('stdout', output=''.join(diff))
                elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
                    # no reference file exists, so any stdout is unexpected
                    test_case.add_failure_info('stdout', output=test_case.stdout)
                # expected CSV output; generated files are looked up in the current
                # working directory and removed when they match
                for ref_csv in ref_csvs:
                    if not ref_csv.is_file():
                        test_case.add_failure_info('csv', output=f'{ref_csv} not found')
                    else:
                        diff: str = diff_csv(Path.cwd() / ref_csv.name, ref_csv)
                        if diff:
                            test_case.add_failure_info('csv', output=diff)
                        else:
                            (Path.cwd() / ref_csv.name).unlink()
                # expected CGNS output, same convention as CSV above
                for ref_cgn in ref_cgns:
                    if not ref_cgn.is_file():
                        test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
                    else:
                        diff = diff_cgns(Path.cwd() / ref_cgn.name, ref_cgn)
                        if diff:
                            test_case.add_failure_info('cgns', output=diff)
                        else:
                            (Path.cwd() / ref_cgn.name).unlink()

            # store result; `args` is attached dynamically for the printing below
            test_case.args = ' '.join(str(arg) for arg in run_args)
            test_cases.append(test_case)

            if mode is RunMode.TAP:
                # print incremental output if TAP mode
                print(f'# Test: {spec.name}')
                if spec.only:
                    print('# Only: {}'.format(','.join(spec.only)))
                print(f'# $ {test_case.args}')
                if test_case.is_skipped():
                    print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
                elif test_case.is_failure() or test_case.is_error():
                    print(f'not ok {index}')
                    if test_case.is_error():
                        print(f'  ERROR: {test_case.errors[0]["message"]}')
                    if test_case.is_failure():
                        for i, failure in enumerate(test_case.failures):
                            print(f'  FAILURE {i}: {failure["message"]}')
                            print(f'    Output: \n{failure["output"]}')
                else:
                    print(f'ok {index} - PASS')
                sys.stdout.flush()
            else:
                # print error or failure information if JUNIT mode
                if test_case.is_error() or test_case.is_failure():
                    print(f'Test: {test} {spec.name}')
                    print(f'  $ {test_case.args}')
                    if test_case.is_error():
                        print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
                        print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
                    if test_case.is_failure():
                        for failure in test_case.failures:
                            print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
                            print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
                sys.stdout.flush()
            index += 1

    return TestSuite(test, test_cases)
473
474
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_path: Path = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    # create the destination directory if missing (e.g. a fresh checkout with no build/ yet)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(to_xml_report_string([test_suite]))
485
486
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
497