# /libCEED/tests/junit_common.py (revision 83ebc4c489f36457f06b28184a0de964e4323c95)
from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'


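# Illustrative usage (not part of the test harness): `CaseInsensitiveEnumAction` lets a
# command-line flag accept 'tap', 'TAP', or 'Tap' and store the matching `RunMode`
# member. The parser below is hypothetical; only `RunMode` and the action class come
# from this module.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--mode', type=RunMode, action=CaseInsensitiveEnumAction, default='tap')
#   args = parser.parse_args(['--mode', 'JUnit'])
#   assert args.mode is RunMode.JUNIT
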
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    def get_cgns_tol(self) -> float:
        """Retrieve CGNS test tolerance

        Returns:
            float: Test tolerance
        """
        return self.cgns_tol if hasattr(self, 'cgns_tol') else 1.0e-12

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback run after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and whether it failed as expected

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False


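# A minimal sketch (illustrative, not part of this module) of a concrete `SuiteSpec`:
# only the three abstract path lookups must be implemented; everything else above has
# a default. The class name and directory layout below are hypothetical.
#
#   class ExampleSuiteSpec(SuiteSpec):
#       def get_source_path(self, test: str) -> Path:
#           return Path('tests') / f'{test}.c'
#
#       def get_run_path(self, test: str) -> Path:
#           return Path('build') / test
#
#       def get_output_path(self, test: str, output_file: str) -> Path:
#           return Path('tests') / 'output' / output_file
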
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    my_env: dict = os.environ.copy()
    proc = subprocess.run('cgnsdiff',
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)
    return 'not found' not in proc.stderr.decode('utf-8')


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function that checks whether any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any((sub in base for sub in substrings))


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function that checks whether the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return any((base.startswith(prefix) for prefix in prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)


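# Example (illustrative): after the comment prefix has been stripped, a line such as
#   TESTARGS(name="cpu only",only="serial") -ceed {ceed_resource} -test 1
# parses to
#   TestSpec(name='cpu only', only=['serial'], args=['-ceed', '{ceed_resource}', '-test', '1'])
# while a bare 'TESTARGS -ceed {ceed_resource}' line yields an unnamed, unconstrained spec.
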
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Raised if the source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    return [parse_test_line(line.strip(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]


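# Example (illustrative): for a C source file, lines starting with '//TESTARGS' are
# collected, one `TestSpec` per line; a file with no TESTARGS lines falls back to the
# single default spec `TestSpec('', args=['{ceed_resource}'])`, so every test runs at
# least once per requested backend.
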
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'
    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
                true_zero: bool = abs(true_val) < zero_tol
                test_zero: bool = abs(test_val) < zero_tol
                fail: bool = False
                if true_zero:
                    fail = not test_zero
                else:
                    fail = not isclose(test_val, true_val, rel_tol=rel_tol)
                if fail:
                    diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')
            except ValueError:
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')

    return '\n'.join(diff_lines)


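# Worked example (illustrative) of the tolerance rules above: with the defaults,
# an expected value of 1.0e-11 and a test value of 2.5e-10 both fall below
# zero_tol=3e-10 and compare equal, while an expected 0.100 and a test 0.1005 pass
# the rel_tol=1e-2 check; an expected 0.100 against a test 0.103 would be reported
# as a difference.
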
def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')


def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the result of a single test case for printing

    Args:
        test_case (TestCase): Completed test case result
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend
        test (str): Name of test
        index (int): Index of backend for current spec

    Returns:
        str: Formatted output string
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str


def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted output string
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: str = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: str = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.get_cgns_tol())
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str


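# Worked example (illustrative) of the argument substitution above: with
# spec.args = ['-ceed', '{ceed_resource}'], backend = '/cpu/self/ref', and nproc = 2,
# a compiled (non-Python) test runs roughly as
#   mpiexec -n 2 <run_path> -ceed /cpu/self/ref
# '{nproc}' is only substituted when it appears explicitly in the arguments, and
# '{ceed_resource}' embedded inside a larger argument (e.g. an output file name) is
# replaced by the backend string with '/' mapped to '-'.
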
def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file: Path = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)

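
# A minimal driver sketch (illustrative; `ExampleSuiteSpec`, the test name, and the CLI
# flags are hypothetical) showing how a suite-specific runner might wire these helpers
# together:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--mode', type=RunMode, action=CaseInsensitiveEnumAction, default='tap')
#   parser.add_argument('--ceed-backends', nargs='+', default=['/cpu/self'])
#   args = parser.parse_args()
#
#   suite = run_tests('t500-operator', args.ceed_backends, args.mode, nproc=1,
#                     suite_spec=ExampleSuiteSpec(), pool_size=4)
#   if args.mode is RunMode.JUNIT:
#       write_junit_xml(suite, output_file=None)
#   if has_failures(suite):
#       sys.exit(1)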