xref: /libCEED/tests/junit_common.py (revision a71faab149de599cd7784196409e851b67144f0a)
from abc import ABC, abstractmethod
import argparse
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
import time
from typing import Optional
14
15sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
16from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
17
18
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The `type` keyword must be a str-based Enum class (e.g. `class Mode(str, Enum)`) so
    that lower-cased user input maps directly onto member values. The `type` parameter
    name shadows the builtin, but is required: argparse forwards `type=...` by keyword.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # normalize the default: accept None, an existing Enum member, a single
        # string, or an iterable of strings/members (for nargs-style arguments)
        if default is None or isinstance(default, self.enum_type):
            pass  # already in final form
        elif isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [v if isinstance(v, self.enum_type) else self.enum_type(v.lower())
                       for v in default]
        # prevent automatic type conversion by not forwarding `type` to the base class
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # `values` is a single string for scalar arguments, a list for nargs arguments
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
40
41
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # display name of the test case (empty string when the TESTARGS line gives none)
    name: str
    # constraint tags parsed from a TESTARGS `only="..."` option (e.g. 'serial,int32')
    only: list[str] = field(default_factory=list)
    # command-line arguments, possibly containing '{ceed_resource}' / '{nproc}' placeholders
    args: list[str] = field(default_factory=list)
48
49
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # inherit the plain-string dunders so printing/formatting a member yields its
    # value ('tap' / 'junit') rather than the default Enum rendering 'RunMode.TAP'
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
56
57
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""

    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Resolve the location of a test's source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Resolve the location of a test's built executable

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Resolve the location of an expected-output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback run after each test case; the default implementation does nothing

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        return None

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Decide, before execution, whether a test case should be skipped

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Reason for skipping, or `None` to run the test case
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Decide, from its stderr output, whether a failed test case is allowed to fail

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Reason for skipping, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
        """Report whether a test case is expected to fail and whether the failure occurred

        The default implementation expects no failure at all.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Expected failure string and whether it was present in `stderr`
        """
        return ('', True)

    def check_allowed_stdout(self, test: str) -> bool:
        """Report whether a test may print console output; the default forbids it

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
158
159
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # `shutil.which` performs a locale-independent PATH lookup. The previous approach
    # ran the command through a shell and scanned stderr for the substring 'not found',
    # which depends on the shell's (possibly localized) error message and needlessly
    # executes the program when it exists.
    return shutil.which('cgnsdiff') is not None
173
174
def contains_any(base: str, substrings: list[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (list[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    # short-circuit on the first match; an empty candidate list yields False
    for candidate in substrings:
        if candidate in base:
            return True
    return False
186
187
def startswith_any(base: str, prefixes: list[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (list[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of prefixes, checking all of them in a single
    # call; an empty tuple yields False, matching the previous any(...) behavior
    return base.startswith(tuple(prefixes))
199
200
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, but keep double-quoted runs together as a single token
    args: list[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    # bare 'TESTARGS' token (no parenthesized options): everything after it is arguments
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    # otherwise the first token has the form 'TESTARGS(...)'; extract the text between
    # the parentheses (the offset 9 is len('TESTARGS('))
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    # 'only' carries a comma-separated list of constraint tags
    constraints: list[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)
222
223
def get_test_args(source_file: Path) -> list[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        list[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map the file extension to the comment prefix that marks TESTARGS lines
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # `removeprefix` strips the exact comment marker once from the front; the previous
    # `line.strip(comment_str)` treated the marker as a *character set* and would also
    # remove any of those characters from the end of the line
    return [parse_test_line(line.removeprefix(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
251
252
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty string if they match within tolerances
    """
    test_lines: list[str] = test_csv.read_text().splitlines()
    true_lines: list[str] = true_csv.read_text().splitlines()

    # column headers must match exactly; if not, report a unified diff with the
    # expected (reference) header as the 'from' side and the found header as 'to'
    # (the previous code passed the sequences and labels in mismatched order)
    if test_lines[0] != true_lines[0]:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: list[str] = list()
    column_names: list[str] = true_lines[0].strip().split(',')
    for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
        test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
        true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
        # the first CSV field identifies the step; the previous code used
        # `true_line[0]`, which is only the first *character* of the line
        step: str = true_line.strip().split(',')[0]
        for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
            # values below zero_tol are treated as exact zeros; otherwise compare
            # with a relative tolerance
            true_zero: bool = abs(true_val) < zero_tol
            test_zero: bool = abs(test_val) < zero_tol
            fail: bool = False
            if true_zero:
                fail = not test_zero
            else:
                fail = not isclose(test_val, true_val, rel_tol=rel_tol)
            if fail:
                diff_lines.append(f'step: {step}, column: {column_name}, expected: {true_val}, got: {test_val}')
    return '\n'.join(diff_lines)
288
289
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    # '-t' sets cgnsdiff's floating-point comparison tolerance
    run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
    # shlex.join quotes each argument, so paths containing spaces or shell
    # metacharacters survive the shell invocation intact (a plain ' '.join did not)
    proc = subprocess.run(shlex.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
311
312
def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (list[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    source_path: Path = suite_spec.get_source_path(test)
    test_specs: list[TestSpec] = get_test_args(source_path)

    if mode is RunMode.TAP:
        # TAP plan line: total number of test points that will be reported
        print('1..' + str(len(test_specs) * len(ceed_backends)))

    test_cases: list[TestCase] = []
    my_env: dict = os.environ.copy()
    # make libCEED exit on error so failures surface as nonzero return codes
    my_env['CEED_ERROR_HANDLER'] = 'exit'

    index: int = 1
    for spec in test_specs:
        for ceed_resource in ceed_backends:
            run_args: list = [suite_spec.get_run_path(test), *spec.args]

            # substitute template placeholders from the TESTARGS specification;
            # without an explicit {nproc} placeholder, parallel runs are wrapped
            # in mpiexec instead (Python tests are never wrapped)
            if '{ceed_resource}' in run_args:
                run_args[run_args.index('{ceed_resource}')] = ceed_resource
            if '{nproc}' in run_args:
                run_args[run_args.index('{nproc}')] = f'{nproc}'
            elif nproc > 1 and source_path.suffix != '.py':
                run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

            # run test
            skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, ceed_resource, nproc)
            if skip_reason:
                # pre-skipped: record an empty, zero-duration case without running anything
                test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                               elapsed_sec=0,
                                               timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                               stdout='',
                                               stderr='')
                test_case.add_skipped_info(skip_reason)
            else:
                start: float = time.time()
                proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                                      shell=True,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      env=my_env)

                test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {ceed_resource}',
                                     classname=source_path.parent,
                                     elapsed_sec=time.time() - start,
                                     timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                                     stdout=proc.stdout.decode('utf-8'),
                                     stderr=proc.stderr.decode('utf-8'),
                                     allow_multiple_subelements=True)
                # collect reference output files named in the test's arguments:
                # 'ascii:<file>' arguments map to expected CSVs, 'cgns:<file>' to
                # expected CGNS files
                ref_csvs: list[Path] = []
                output_files: list[str] = [arg for arg in spec.args if 'ascii:' in arg]
                if output_files:
                    ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
                ref_cgns: list[Path] = []
                output_files = [arg for arg in spec.args if 'cgns:' in arg]
                if output_files:
                    ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
                ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
                suite_spec.post_test_hook(test, spec)

            # check allowed failures
            if not test_case.is_skipped() and test_case.stderr:
                skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, ceed_resource, test_case.stderr)
                if skip_reason:
                    test_case.add_skipped_info(skip_reason)

            # check required failures
            if not test_case.is_skipped():
                required_message, did_fail = suite_spec.check_required_failure(test, spec, ceed_resource, test_case.stderr)
                if required_message and did_fail:
                    # expected failure occurred: mark status but do not record a failure
                    test_case.status = f'fails with required: {required_message}'
                elif required_message:
                    test_case.add_failure_info(f'required failure missing: {required_message}')

            # classify other results
            # NOTE: proc, ref_csvs, ref_cgns, and ref_stdout are only bound on the
            # non-skipped path above; the is_skipped() guard ensures they exist here
            if not test_case.is_skipped() and not test_case.status:
                if test_case.stderr:
                    test_case.add_failure_info('stderr', test_case.stderr)
                if proc.returncode != 0:
                    test_case.add_error_info(f'returncode = {proc.returncode}')
                # compare captured stdout against the reference .out file if one exists;
                # otherwise any stdout is a failure unless explicitly allowed
                if ref_stdout.is_file():
                    diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                                     test_case.stdout.splitlines(keepends=True),
                                                     fromfile=str(ref_stdout),
                                                     tofile='New'))
                    if diff:
                        test_case.add_failure_info('stdout', output=''.join(diff))
                elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
                    test_case.add_failure_info('stdout', output=test_case.stdout)
                # expected CSV output; generated files are compared against the
                # reference and deleted from the working directory when they match
                for ref_csv in ref_csvs:
                    if not ref_csv.is_file():
                        test_case.add_failure_info('csv', output=f'{ref_csv} not found')
                    else:
                        diff: str = diff_csv(Path.cwd() / ref_csv.name, ref_csv)
                        if diff:
                            test_case.add_failure_info('csv', output=diff)
                        else:
                            (Path.cwd() / ref_csv.name).unlink()
                # expected CGNS output; same compare-then-clean-up pattern as CSVs
                for ref_cgn in ref_cgns:
                    if not ref_cgn.is_file():
                        test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
                    else:
                        diff = diff_cgns(Path.cwd() / ref_cgn.name, ref_cgn)
                        if diff:
                            test_case.add_failure_info('cgns', output=diff)
                        else:
                            (Path.cwd() / ref_cgn.name).unlink()

            # store result
            test_case.args = ' '.join(str(arg) for arg in run_args)
            test_cases.append(test_case)

            if mode is RunMode.TAP:
                # print incremental output if TAP mode
                print(f'# Test: {spec.name}')
                if spec.only:
                    print('# Only: {}'.format(','.join(spec.only)))
                print(f'# $ {test_case.args}')
                if test_case.is_skipped():
                    print('ok {} - SKIP: {}'.format(index, (test_case.skipped[0]['message'] or 'NO MESSAGE').strip()))
                elif test_case.is_failure() or test_case.is_error():
                    print(f'not ok {index}')
                    if test_case.is_error():
                        print(f'  ERROR: {test_case.errors[0]["message"]}')
                    if test_case.is_failure():
                        for i, failure in enumerate(test_case.failures):
                            print(f'  FAILURE {i}: {failure["message"]}')
                            print(f'    Output: \n{failure["output"]}')
                else:
                    print(f'ok {index} - PASS')
                sys.stdout.flush()
            else:
                # print error or failure information if JUNIT mode
                if test_case.is_error() or test_case.is_failure():
                    print(f'Test: {test} {spec.name}')
                    print(f'  $ {test_case.args}')
                    if test_case.is_error():
                        print('ERROR: {}'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
                        print('Output: \n{}'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
                    if test_case.is_failure():
                        for failure in test_case.failures:
                            print('FAIL: {}'.format((failure['message'] or 'NO MESSAGE').strip()))
                            print('Output: \n{}'.format((failure['output'] or 'NO MESSAGE').strip()))
                sys.stdout.flush()
            index += 1

    return TestSuite(test, test_cases)
472
473
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    # use a fresh local rather than rebinding (and re-annotating) the parameter
    target: Path = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    # ensure the destination directory exists (e.g. 'build/' on a clean checkout),
    # otherwise write_text raises FileNotFoundError
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(to_xml_report_string([test_suite]))
484
485
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    # short-circuit on the first failing or errored case
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
496