xref: /libCEED/tests/junit_common.py (revision 40b22b27240d17420c09a578f3aa6802b0d55a72)
1from abc import ABC, abstractmethod
2import argparse
3import csv
4from dataclasses import dataclass, field
5import difflib
6from enum import Enum
7from math import isclose
8import os
9from pathlib import Path
10import re
11import subprocess
12import multiprocessing as mp
13from itertools import product
14import sys
15import time
16from typing import Optional, Tuple, List, Callable
17
18sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
19from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
20
21
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The enum class passed as `type` must be str-based (a `StrEnum`, or a class
    deriving from both `str` and `Enum`) so that lower-cased command-line
    strings convert directly via the enum's value lookup.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        # `type` shadows the builtin here because argparse passes it by that keyword
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # normalize the default: argparse commonly supplies None, a single
        # string, or an iterable of strings (for nargs-style options)
        if default is None:
            pass
        elif isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # `type` is deliberately not forwarded, preventing automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # convert parsed value(s) to the enum, lower-casing first
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
43
44
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # Human-readable test case name (empty string if unnamed)
    name: str
    # Constraint tags parsed from a TESTARGS `only="..."` clause (e.g. 'serial', 'int32')
    only: List[str] = field(default_factory=list)
    # Command-line arguments for the test executable; may contain the
    # '{ceed_resource}' and '{nproc}' placeholders substituted in run_test()
    args: List[str] = field(default_factory=list)
51
52
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # Rebind to plain-string behavior so f-strings and format() render the
    # value ('tap'/'junit') rather than the member repr ('RunMode.TAP')
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
59
60
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite

    Concrete suites must implement the three path-resolution methods; the
    remaining hooks have permissive defaults that subclasses may override to
    customize skipping, expected failures, and per-test cleanup.
    """
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self) -> float:
        """Absolute tolerance for CGNS diff (defaults to 1.0e-12 if never set)"""
        # fall back to the default when the setter was never called
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val: float) -> None:
        self._cgns_tol = val

    @property
    def diff_csv_kwargs(self) -> dict:
        """Keyword arguments to be passed to diff_csv() (defaults to {} if never set)"""
        # fall back to an empty kwargs dict when the setter was never called
        return getattr(self, '_diff_csv_kwargs', {})

    @diff_csv_kwargs.setter
    def diff_csv_kwargs(self, val: dict) -> None:
        self._diff_csv_kwargs = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        # default: never skip
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        # default: no allowed failures
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        # default: empty message means no failure is required
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        # default: unexpected stdout is treated as a failure
        return False
179
180
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # A PATH lookup via shutil.which avoids spawning a shell and avoids the
    # fragile, locale/shell-dependent check for 'not found' in stderr text
    import shutil
    return shutil.which('cgnsdiff') is not None
194
195
def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, reports whether `base` contains at least one of the given substrings

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
207
208
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, reports whether `base` starts with any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidate prefixes directly
    return base.startswith(tuple(prefixes))
220
221
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, but keep double-quoted spans intact as one token
    tokens: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    # a bare 'TESTARGS' token carries no name/constraint annotations
    if tokens[0] == 'TESTARGS':
        return TestSpec(name='', args=tokens[1:])
    # extract the 'key="value",...' payload between 'TESTARGS(' and the final ')'
    payload: str = tokens[0][tokens[0].index('TESTARGS(') + 9:tokens[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    options: dict = dict(''.join(triple).split('=')
                         for triple in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", payload))
    constraints: List[str] = options['only'].split(',') if 'only' in options else []
    return TestSpec(name=options.get('name', ''),
                    only=constraints,
                    args=tokens[1:] if len(tokens) > 1 else [])
243
244
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map file extension to the comment prefix that marks TESTARGS lines
    comment_strs: dict = {'.c': '//', '.cc': '//', '.cpp': '//',
                          '.py': '#', '.usr': 'C_', '.f90': '! '}
    comment_str: Optional[str] = comment_strs.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    prefix: str = f'{comment_str}TESTARGS'
    # Slice off exactly the comment prefix; the previous `line.strip(comment_str)`
    # stripped any leading/trailing run of those *characters*, which could eat
    # meaningful trailing characters (e.g. a '/' or '!' ending the arguments)
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(prefix)] or [TestSpec('', args=['{ceed_resource}'])]
272
273
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2,
             comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.
        comment_str (str, optional): String denoting a commented line. Defaults to '#'.
        comment_func (Callable, optional): Function receiving (test_line, true_line) for paired commented
            lines; returning a non-empty string reports that as the diff immediately

    Returns:
        str: Diff output between result and expected CSVs, empty string if they match
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'
    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    # Process commented lines; startswith() handles multi-character comment
    # strings (e.g. '//') and empty lines, unlike indexing the first character
    uncommented_lines: List[int] = []
    for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
        test_commented: bool = test_line.startswith(comment_str)
        true_commented: bool = true_line.startswith(comment_str)
        if test_commented and true_commented:
            if comment_func:
                output = comment_func(test_line, true_line)
                if output:
                    return output
        elif test_commented and not true_commented:
            return f'Commented line found in {test_csv} at line {n} but not in {true_csv}'
        elif not test_commented and true_commented:
            return f'Commented line found in {true_csv} at line {n} but not in {test_csv}'
        else:
            uncommented_lines.append(n)

    # Remove commented lines
    test_lines = [test_lines[line] for line in uncommented_lines]
    true_lines = [true_lines[line] for line in uncommented_lines]

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # expected (true) header goes first so the 'fromfile' label matches its content
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
                # magnitudes below zero_tol are treated as exact zeros
                true_zero: bool = abs(true_val) < zero_tol
                test_zero: bool = abs(test_val) < zero_tol
                fail: bool = not test_zero if true_zero else not isclose(test_val, true_val, rel_tol=rel_tol)
                if fail:
                    diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')
            except (ValueError, TypeError):
                # non-numeric (or missing, i.e. None) values must match exactly;
                # TypeError guards float(None) from ragged rows
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')

    return '\n'.join(diff_lines)
345
346
def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files, empty if they match
    """
    environment: dict = os.environ.copy()

    # invoke cgnsdiff through the shell; '-d' compares data, '-t' sets tolerance
    command: str = f'cgnsdiff -d -t {cgns_tol} {test_cgns} {true_cgns}'
    completed = subprocess.run(command,
                               shell=True,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               env=environment)

    # surface both streams; cgnsdiff reports differences on stdout, errors on stderr
    return completed.stderr.decode('utf-8') + completed.stdout.decode('utf-8')
368
369
def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the result of a single test case for console output

    In TAP mode this emits one TAP test point plus a YAML diagnostic block;
    in JUnit mode it emits text only for errors/failures (otherwise empty).

    Args:
        test_case (TestCase): Completed JUnit test case result
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend used for this run
        test (str): Name of test
        index (int): 1-based index of the backend within the current TAP subtest

    Returns:
        str: Formatted output string (may be empty in JUnit mode)
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        # YAML diagnostic block ('---' ... '...') with run details
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    # indent captured output so it nests under the YAML literal block
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str
409
410
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Requires the module-level `my_env` environment dict, which is set up by
    `init_process()` in each worker process before this function runs.

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # substitute '{ceed_resource}' placeholders: whole-argument first, then embedded
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        # embedded occurrences (e.g. inside output file names) swap '/' for '-'
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        # no explicit '{nproc}' placeholder: launch through mpiexec instead
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # collect reference files for any 'ascii:'/'cgns:' output arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results; only reached when the subprocess actually ran,
    # since the pre-skip path above marks the case as skipped
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            elif not (Path.cwd() / csv_name).is_file():
                test_case.add_failure_info('csv', output=f'{csv_name} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # matching output is cleaned up
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            elif not (Path.cwd() / cgn_name).is_file():
                # categorized as 'cgns' (was mislabeled 'csv' by copy-paste)
                test_case.add_failure_info('cgns', output=f'{cgn_name} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    # matching output is cleaned up
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str
545
546
def init_process():
    """Initialize a worker process for running tests

    Copies the parent environment into the module-level `my_env` dict and
    forces libCEED's error handler to 'exit' so failing tests terminate
    instead of hanging.
    """
    global my_env
    my_env = dict(os.environ)
    my_env['CEED_ERROR_HANDLER'] = 'exit'
553
554
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP preamble: protocol version and top-level plan (one point per spec)
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    # each worker process gets `my_env` set up via init_process()
    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # launch every (spec, backend) combination asynchronously, grouped by spec
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                # blocks until this (spec, backend) run completes; results are
                # consumed in submission order so TAP output stays sequential
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    # TAP subtest header and plan, printed once per spec
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # top-level TAP point summarizing the whole subtest
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
600
601
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    if output_file is None:
        # default location under the build directory
        output_file = Path('build') / f'{test_suite.name}{batch}.junit'
    output_file.write_text(to_xml_report_string([test_suite]))
612
613
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
624