from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from math import isclose
import os
from pathlib import Path
import re
import subprocess
import multiprocessing as mp
from itertools import product
import sys
import time
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8


class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type"""

    def __init__(self, option_strings, dest, type, default, **kwargs):
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        if isinstance(default, str):
            default = self.enum_type(default.lower())
        else:
            default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)


@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    name: str
    only: List = field(default_factory=list)
    args: List = field(default_factory=list)


class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'

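# Illustrative usage sketch (comments only, not executed): a test driver could
# wire `CaseInsensitiveEnumAction` and `RunMode` into argparse so that
# `--mode TAP`, `--mode tap`, etc. all resolve to `RunMode.TAP`. The parser and
# option name below are hypothetical, not defined in this module.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--mode', type=RunMode, action=CaseInsensitiveEnumAction,
#                       default=RunMode.TAP, help='output mode')
#   args = parser.parse_args(['--mode', 'JUnit'])
#   assert args.mode is RunMode.JUNIT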

class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite"""
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self):
        """Absolute tolerance for CGNS diff"""
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val):
        self._cgns_tol = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Callback function run after each test case

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if the error is unexpected
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and whether it failed as expected

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False

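# Minimal sketch of a concrete suite (comments only, not executed), assuming a
# hypothetical layout with C sources in `tests/`, built executables in `build/`,
# and expected outputs in `tests/output/`; only the three abstract methods are
# required, and a driver may also override the CGNS tolerance, e.g.
# `suite.cgns_tol = 1e-10`.
#
#   class ExampleSuiteSpec(SuiteSpec):
#       def get_source_path(self, test: str) -> Path:
#           return Path('tests') / f'{test}.c'
#
#       def get_run_path(self, test: str) -> Path:
#           return Path('build') / test
#
#       def get_output_path(self, test: str, output_file: str) -> Path:
#           return Path('tests') / 'output' / output_file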

def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    my_env: dict = os.environ.copy()
    proc = subprocess.run('cgnsdiff',
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)
    return 'not found' not in proc.stderr.decode('utf-8')


def contains_any(base: str, substrings: List[str]) -> bool:
    """Helper function, checks if any of the substrings are included in the base string

    Args:
        base (str): Base string to search in
        substrings (List[str]): List of potential substrings

    Returns:
        bool: True if any substrings are included in base string
    """
    return any((sub in base for sub in substrings))


def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    return any((base.startswith(prefix) for prefix in prefixes))


def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        return TestSpec(name='', args=args[1:])
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)

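# Illustrative example (comments only, not executed): a TESTARGS line such as
#
#   TESTARGS(name="quadrature",only="cpu") -ceed {ceed_resource} -test 1
#
# would roughly parse to
#   TestSpec(name='quadrature', only=['cpu'],
#            args=['-ceed', '{ceed_resource}', '-test', '1']),
# while a bare `TESTARGS -ceed {ceed_resource}` yields a spec with an empty name
# and no constraints. The test name and arguments here are hypothetical.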

def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: If the source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    return [parse_test_line(line.strip(comment_str))
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]

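# Illustrative example (comments only, not executed): for a hypothetical C source
# `tests/t500.c` containing comment lines of the form
#
#   // TESTARGS(name="t500",only="cpu") -ceed {ceed_resource}
#   // TESTARGS -ceed {ceed_resource} -test 2
#
# get_test_args would return one TestSpec per TESTARGS line; a file with no
# TESTARGS comments yields the single default TestSpec('', args=['{ceed_resource}']).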

def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'
    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
                true_zero: bool = abs(true_val) < zero_tol
                test_zero: bool = abs(test_val) < zero_tol
                fail: bool = False
                if true_zero:
                    fail = not test_zero
                else:
                    fail = not isclose(test_val, true_val, rel_tol=rel_tol)
                if fail:
                    diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')
            except ValueError:
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')

    return '\n'.join(diff_lines)

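# Illustrative usage (comments only, not executed): comparing a freshly written
# CSV against its reference with the default tolerances; an empty string means
# the files agree within tolerance. The file names are hypothetical.
#
#   diff = diff_csv(Path('t500_output.csv'), Path('tests/output/t500_output.csv'))
#   if diff:
#       print(diff)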

def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values. Defaults to 1e-12.

    Returns:
        str: Diff output between result and expected CGNS files
    """
    my_env: dict = os.environ.copy()

    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}', str(test_cgns), str(true_cgns)]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')

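# Illustrative note (comments only): with the defaults this shells out to roughly
# `cgnsdiff -d -t 1e-12 <test.cgns> <true.cgns>`; any non-empty return string is
# treated as a CGNS mismatch by `run_test`. `cgnsdiff` must be available on PATH
# (see `has_cgnsdiff` above).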

def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Build the printable result string for one test case and backend

    Returns:
        str: TAP-formatted lines in TAP mode, or error/failure details in JUnit mode
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str

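# Illustrative sketch (comments only) of the TAP fragment produced for a single
# passing case; the test name, backend, and arguments shown are hypothetical:
#
#     ok 1 - name, /cpu/self/ref/serial
#       ---
#       args: build/t500 -ceed /cpu/self/ref/serial
#       ...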

def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its printable output string
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: str = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: str = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            elif not (Path.cwd() / csv_name).is_file():
                test_case.add_failure_info('csv', output=f'{csv_name} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            elif not (Path.cwd() / cgn_name).is_file():
                test_case.add_failure_info('cgns', output=f'{cgn_name} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str

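# Illustrative note (comments only): placeholders in TESTARGS arguments are
# substituted per backend before the command is launched. For a hypothetical
# spec with args ['-ceed', '{ceed_resource}', '-o', 'out_{ceed_resource}.csv']
# run against '/cpu/self/ref/serial', the exact '{ceed_resource}' argument
# becomes '/cpu/self/ref/serial', while the embedded occurrence becomes
# 'out_-cpu-self-ref-serial.csv' (slashes replaced so it forms a valid file name).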

def init_process():
    """Initialize multiprocessing process"""
    # set up error handler
    global my_env
    my_env = os.environ.copy()
    my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        async_outputs: List[List[mp.pool.AsyncResult]] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)


def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    output_file: Path = output_file or Path('build') / (f'{test_suite.name}{batch}.junit')
    output_file.write_text(to_xml_report_string([test_suite]))


def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    return any(c.is_failure() or c.is_error() for c in test_suite.test_cases)

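# Illustrative driver sketch (comments only, not executed), assuming the
# hypothetical ExampleSuiteSpec shown above; this is how a suite runner might
# combine run_tests, write_junit_xml, and has_failures:
#
#   suite = run_tests('t500', ['/cpu/self/ref/serial', '/cpu/self/ref/blocked'],
#                     RunMode.JUNIT, nproc=1, suite_spec=ExampleSuiteSpec())
#   write_junit_xml(suite, None)
#   sys.exit(1 if has_failures(suite) else 0)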