xref: /honee/tests/junit_common.py (revision e941b1e9c029a850a20e06efe1e514154e2a4e67)
from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from itertools import product
from math import isclose
import multiprocessing as mp
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
import time
from typing import Optional, Tuple, List, Callable
17
18sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
19from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
20
21
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The provided `type` must be both a subclass of `str` and of `Enum` (i.e. a
    `StrEnum`-like class) whose member values are lower case, so that any casing
    of the value on the command line maps to the correct member.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        # `type` intentionally shadows the builtin: argparse passes it by keyword
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # convert the default eagerly; `None` (argparse's usual default) is passed through
        if default is None:
            pass
        elif isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type` to argparse
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # convert a single value or a list of values (nargs) to enum members
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        elif values is not None:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
43
44
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # Human-readable test name, parsed from TESTARGS(name="..."); may be empty
    name: str
    # Constraint keywords parsed from TESTARGS(only="..."), e.g. 'serial'
    only: List[str] = field(default_factory=list)
    # Command-line arguments for the test executable (may contain placeholders
    # such as '{ceed_resource}' and '{nproc}')
    args: List[str] = field(default_factory=list)
51
52
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # Use plain-string formatting so str()/f-strings yield 'tap'/'junit'
    # rather than 'RunMode.TAP'/'RunMode.JUNIT'
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
59
60
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite

    Concrete suites must implement the three path-lookup methods; the remaining
    hooks have harmless defaults and may be overridden to customize skipping,
    expected failures, and per-test post-processing.
    """

    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self):
        """Absolute tolerance for CGNS diff (defaults to 1.0e-12 until set)"""
        try:
            return self._cgns_tol
        except AttributeError:
            return 1.0e-12

    @cgns_tol.setter
    def cgns_tol(self, val):
        self._cgns_tol = val

    @property
    def diff_csv_kwargs(self):
        """Keyword arguments to be passed to diff_csv() (defaults to empty until set)"""
        try:
            return self._diff_csv_kwargs
        except AttributeError:
            return {}

    @diff_csv_kwargs.setter
    def diff_csv_kwargs(self, val):
        self._diff_csv_kwargs = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case; default is a no-op

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        return None

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output; default is no

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
179
180
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Uses `shutil.which()` to search PATH directly, rather than running the
    command in a shell and grepping stderr for a locale-dependent
    'not found' message.

    Returns:
        bool: True if `cgnsdiff` is found
    """
    return shutil.which('cgnsdiff') is not None
194
195
def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
207
208
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of candidate prefixes
    return base.startswith(tuple(prefixes))
220
221
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted substrings intact
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        # bare TESTARGS with no (name=...,only=...) annotations
        return TestSpec(name='', args=args[1:])
    # extract the text between 'TESTARGS(' and the final ')'
    prefix: str = 'TESTARGS('
    raw_test_args: str = args[0][args[0].index(prefix) + len(prefix):args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'};
    # capturing key and quoted value directly keeps values containing '=' intact
    test_args: dict = {key: value for key, value in re.findall(r'([^,=]+)="([^"]*)"', raw_test_args)}
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    # args[1:] is empty when no CLI arguments follow, matching the TestSpec default
    return TestSpec(name=name, only=constraints, args=args[1:])
243
244
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # map source extension to the comment prefix that introduces a TESTARGS line
    comment_strs: dict = {'.c': '//', '.cc': '//', '.cpp': '//',
                          '.py': '#', '.usr': 'C_', '.f90': '! '}
    comment_str = comment_strs.get(source_file.suffix)
    if comment_str is None:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # slice off exactly the comment prefix; note str.strip(comment_str) would
    # remove any leading/trailing characters in the set, not the prefix itself
    prefix: str = f'{comment_str}TESTARGS'
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(prefix)] or [TestSpec('', args=['{ceed_resource}'])]
272
273
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2,
             comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.
        comment_str (str, optional): String denoting a commented line
        comment_func (Callable, optional): Function comparing corresponding test and true
            comment lines, returning an error string or `None` if they agree

    Returns:
        str: Diff output between result and expected CSVs, empty when they match
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'
    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'

    # Process commented lines
    uncommented_lines: List[int] = []
    for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
        # startswith() handles multi-character comment strings and empty lines,
        # where indexing the first character would raise IndexError
        test_commented: bool = test_line.startswith(comment_str)
        true_commented: bool = true_line.startswith(comment_str)
        if test_commented and true_commented:
            if comment_func:
                output = comment_func(test_line, true_line)
                if output:
                    return output
        elif test_commented:
            return f'Commented line found in {test_csv} at line {n} but not in {true_csv}'
        elif true_commented:
            return f'Commented line found in {true_csv} at line {n} but not in {test_csv}'
        else:
            uncommented_lines.append(n)

    # Remove commented lines
    test_lines = [test_lines[line] for line in uncommented_lines]
    true_lines = [true_lines[line] for line in uncommented_lines]

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
                       tofile='found CSV columns', fromfile='expected CSV columns'))

    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Compare numerically when both values parse as floats, otherwise as strings
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
                true_zero: bool = abs(true_val) < zero_tol
                test_zero: bool = abs(test_val) < zero_tol
                fail: bool = False
                if true_zero:
                    # expected value is effectively zero, so the result must be too
                    fail = not test_zero
                else:
                    fail = not isclose(test_val, true_val, rel_tol=rel_tol)
                if fail:
                    diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')
            except ValueError:
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')

    return '\n'.join(diff_lines)
345
346
def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Absolute tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    # quote the paths so spaces or shell metacharacters cannot split or
    # inject into the shell command line
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}',
                           shlex.quote(str(test_cgns)), shlex.quote(str(true_cgns))]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    # cgnsdiff reports differences on stdout and errors on stderr; both matter
    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
368
369
def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format a completed test case's result for console output

    In TAP mode, produces the full incremental TAP test point ('ok'/'not ok'
    line plus a YAML diagnostic block with args and failure/error details).
    In JUnit mode, produces details only for errors and failures, and an
    empty string for passing or skipped cases.

    Args:
        test_case (TestCase): Completed test case; its `args` attribute must already be set
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend used for this case
        test (str): Name of test
        index (int): One-based index of the backend within the current subtest

    Returns:
        str: Formatted output string
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        # YAML diagnostic block between '---' and '...'
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    # indent captured output so it stays inside the YAML literal block
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str
409
410
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    NOTE(review): relies on the global `my_env` created by `init_process()`,
    so it must execute in a worker process started with that initializer
    (see `run_tests()`); calling it directly raises NameError.

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # a stand-alone '{ceed_resource}' argument becomes the backend string itself
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # embedded occurrences (e.g. inside output file names) replace '/' with '-'
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        # no explicit '{nproc}' placeholder: launch the executable through mpiexec
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        # skipped: record an empty test case without running anything
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # collect expected-output reference files implied by 'ascii:'/'cgns:' arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            # expected failure occurred: mark status but do not record a failure
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            # compare captured stdout against the expected .out file
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # matching output files are cleaned up
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    # matching output files are cleaned up
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str
541
542
def init_process():
    """Initialize a multiprocessing worker process

    Creates the worker-global `my_env` as a copy of the parent environment
    with CEED_ERROR_HANDLER set to 'exit'.
    """
    global my_env
    my_env = dict(os.environ)
    my_env['CEED_ERROR_HANDLER'] = 'exit'
548    my_env['CEED_ERROR_HANDLER'] = 'exit'
549
550
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP header: protocol version plus the plan for the top-level count
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    # each worker process gets `my_env` set up by init_process()
    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # queue every (spec, backend) combination; results stay grouped per spec
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                # blocks until the worker finishes this (spec, backend) case
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    # one TAP subtest per spec, with one test point per backend
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # summary line for the whole subtest
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
596
597
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_path: Path = output_file or Path('build') / f'{test_suite.name}{batch}.junit'
    # create the destination directory if needed (e.g. 'build/' on a clean tree)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(to_xml_report_string([test_suite]))
608
609
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for test_case in test_suite.test_cases:
        if test_case.is_failure() or test_case.is_error():
            return True
    return False
620