xref: /libCEED/tests/junit_common.py (revision 4bd6ffc97dc9a7688ef3a2d802aad5d41776eea1)
from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from itertools import product
from math import isclose
import multiprocessing as mp
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
import time
from typing import Optional, Tuple, List
17
18sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
19from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
20
21
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The enum supplied as `type` must also subclass `str` (a StrEnum-style enum) so
    lower-cased command-line strings can be converted to members by value.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        """Validate the enum type and normalize the default value(s)

        Args:
            option_strings: Option strings registered for this action
            dest: Namespace attribute to store the parsed value in
            type: A `str`-based `Enum` subclass used for conversion
            default: Default value — `None`, a single string, or a list of strings

        Raises:
            ValueError: If `type` is not both a `str` and an `Enum` subclass
        """
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # normalize defaults to enum members; a `None` default (argparse's usual
        # "no default" sentinel) is passed through unchanged
        if default is not None:
            if isinstance(default, str):
                default = self.enum_type(default.lower())
            else:
                default = [self.enum_type(v.lower()) for v in default]
        # intentionally do not forward `type`, preventing argparse's automatic
        # type conversion so __call__ can lower-case the raw strings first
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert raw string value(s) to enum members, case-insensitively"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
43
44
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # Test case name parsed from TESTARGS(name="...") (may be empty string)
    name: str
    # Constraint tags parsed from TESTARGS(only="...") limiting when the case runs
    only: List = field(default_factory=list)
    # Command-line arguments for the test executable (may contain placeholders
    # such as '{ceed_resource}' and '{nproc}')
    args: List = field(default_factory=list)
51
52
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # Borrow plain-string rendering so formatting a member yields its value
    # ('tap' / 'junit') rather than the default 'RunMode.TAP' representation
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
59
60
class SuiteSpec(ABC):
    """Abstract base class declaring the interface a test suite must provide to be run"""

    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Resolve the location of a test's source file

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Resolve the location of a test's built executable

        Args:
            test (str): Name of test

        Returns:
            Path: Location of the test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Resolve the location of an expected-output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Location of the expected output file
        """
        raise NotImplementedError

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback invoked after each test case; the default does nothing

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Decide, before execution, whether a test case should be skipped

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Reason for skipping, or `None` when the case should run
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Decide, from its stderr output, whether a failing test case may be ignored

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Reason for skipping, or `None` when the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Determine whether a test case is required to fail, and whether it did

        The default implementation reports no required failure.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Expected failure string and whether it appeared in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Determine whether a test may write to stdout; the default forbids it

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
161
162
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Searches the PATH directly via `shutil.which`, which is robust against
    non-English shell error messages and against `cgnsdiff` usage output that
    the previous stderr-scraping approach ('not found' substring check) could
    misinterpret.

    Returns:
        bool: True if `cgnsdiff` is found
    """
    return shutil.which('cgnsdiff') is not None
176
177
def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, reports whether `base` contains at least one of the substrings

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
189
190
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, reports whether `base` begins with any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidate prefixes directly
    return base.startswith(tuple(prefixes))
202
203
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    tokens: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    # a bare 'TESTARGS' token carries no name/constraints; the rest is the arg list
    if tokens[0] == 'TESTARGS':
        return TestSpec(name='', args=tokens[1:])
    # slice the key/value body out of 'TESTARGS(...)'
    raw_test_args: str = tokens[0][tokens[0].index('TESTARGS(') + 9:tokens[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    # tokens[1:] is simply empty when no CLI arguments follow the TESTARGS token
    return TestSpec(name=test_args.get('name', ''), only=constraints, args=tokens[1:])
225
226
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Test cases are declared on comment lines beginning with ``TESTARGS``; the
    comment prefix depends on the source language.

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map of file suffix -> comment prefix introducing a TESTARGS line
    comment_prefixes: dict = {
        '.c': '//',
        '.cc': '//',
        '.cpp': '//',
        '.py': '#',
        '.usr': 'C_',
        '.f90': '! ',
    }
    comment_str = comment_prefixes.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    prefix: str = f'{comment_str}TESTARGS'
    # NOTE: slice off the comment prefix rather than using `line.strip(comment_str)`;
    # str.strip treats its argument as a character SET and would also remove any
    # matching characters from the END of the line (e.g. trailing '_'/'C' for
    # '.usr' sources or a trailing '/' for C sources)
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(prefix)] or [TestSpec('', args=['{ceed_resource}'])]
254
255
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs, empty when they match
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # both files must contain at least a header row
    if not test_lines:
        return f'No lines found in test output {test_csv}'
    if not true_lines:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    # column headers must agree before any row-wise comparison
    if test_reader.fieldnames != true_reader.fieldnames:
        header_diff = difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
                                           fromfile='expected CSV columns', tofile='found CSV columns')
        return ''.join(header_diff)

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    mismatches: List[str] = []
    for test_row, true_row in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            try:
                # numeric entries are compared with tolerances
                expected: float = float(true_row[key])
                found: float = float(test_row[key])
                if abs(expected) < zero_tol:
                    # expected value is effectively zero: found must be too
                    ok = abs(found) < zero_tol
                else:
                    ok = isclose(found, expected, rel_tol=rel_tol)
                if not ok:
                    mismatches.append(f'column: {key}, expected: {expected}, got: {found}')
            except ValueError:
                # non-numeric entries must match exactly
                if test_row[key] != true_row[key]:
                    mismatches.append(f'column: {key}, expected: {true_row[key]}, got: {test_row[key]}')

    return '\n'.join(mismatches)
305
306
def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        tolerance (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    # quote the file paths so the shell command survives spaces and other
    # shell metacharacters in the paths
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}',
                           shlex.quote(str(test_cgns)), shlex.quote(str(true_cgns))]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    # non-empty output (stderr first) signals a difference or an error
    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
328
329
def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the result of a single test case for console output

    In TAP mode, incremental TAP subtest output is produced for every case; in
    JUnit mode only error/failure details are produced (and the string is empty
    for passing cases).

    Args:
        test_case (TestCase): Completed test case result
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend used for this case
        test (str): Name of test
        index (int): Index of backend for current spec

    Returns:
        str: Formatted (possibly empty) console output for this test case
    """
    parts: List[str] = []
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            parts.append(f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n')
        elif test_case.is_failure() or test_case.is_error():
            parts.append(f'    not ok {index} - {spec.name}, {backend}\n')
        else:
            parts.append(f'    ok {index} - {spec.name}, {backend}\n')
        parts.append('      ---\n')
        if spec.only:
            parts.append(f'      only: {",".join(spec.only)}\n')
        parts.append(f'      args: {test_case.args}\n')
        if test_case.is_error():
            parts.append(f'      error: {test_case.errors[0]["message"]}\n')
        if test_case.is_failure():
            parts.append(f'      num_failures: {len(test_case.failures)}\n')
            for i, failure in enumerate(test_case.failures):
                parts.append(f'      failure_{i}: {failure["message"]}\n')
                parts.append(f'        message: {failure["message"]}\n')
                if failure["output"]:
                    indented = failure["output"].strip().replace('\n', '\n          ')
                    parts.append(f'        output: |\n          {indented}\n')
        parts.append('      ...\n')
    elif test_case.is_error() or test_case.is_failure():
        # print error or failure information if JUNIT mode
        parts.append(f'Test: {test} {spec.name}\n')
        parts.append(f'  $ {test_case.args}\n')
        if test_case.is_error():
            parts.append('ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip()))
            parts.append('Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip()))
        if test_case.is_failure():
            for failure in test_case.failures:
                parts.append('FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip()))
                parts.append('Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip()))
    return ''.join(parts)
369
370
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    NOTE(review): reads the process-global `my_env` created by `init_process()`,
    so this function is expected to execute in a pool worker initialized with
    that function (see `run_tests`).

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    # command line: test executable followed by the spec's (stringified) arguments
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # a standalone '{ceed_resource}' argument becomes the backend name itself
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # '{ceed_resource}' embedded inside a larger argument (e.g. an output file
    # name) gets a filesystem-safe form with '/' replaced by '-'
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    # substitute '{nproc}' if present; otherwise launch through mpiexec for
    # multi-process non-Python tests
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        # skipped before execution: record an empty, zero-duration result
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # collect expected-output reference files implied by 'ascii:'/'cgns:'
        # substrings in the arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    # (NB: `proc`/`ref_csvs`/`ref_cgns`/`ref_stdout` are only bound on the
    # non-skipped path above; every use below is guarded by `not is_skipped()`)
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            # expected failure occurred: record it in the case status
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            # expected failure did not occur: that itself is a failure
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        # compare captured stdout against the reference output file, if any
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # matching output files are cleaned up
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str
501
502
def init_process():
    """Initialize a multiprocessing worker process

    Creates the process-global environment `my_env`: a copy of `os.environ`
    with `CEED_ERROR_HANDLER` set to 'exit', used by `run_test` when spawning
    test executables.
    """
    global my_env
    environment = os.environ.copy()
    environment['CEED_ERROR_HANDLER'] = 'exit'
    my_env = environment
509
510
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP preamble: protocol version and top-level plan (one entry per spec)
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    # fan every (spec, backend) combination out to the pool; workers run
    # init_process so that run_test sees the global `my_env`
    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List['mp.pool.AsyncResult']] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        # collect results in submission order so TAP output stays deterministic
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    # emit the subtest header once, before its first result
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # NOTE(review): `test_case` here is the last backend's result;
                # assumes `ceed_backends` is non-empty — confirm with callers
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
556
557
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    if output_file is None:
        # default location under the build directory
        output_file = Path('build') / f'{test_suite.name}{batch}.junit'
    output_file.write_text(to_xml_report_string([test_suite]))
568
569
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed or errored

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
580