xref: /honee/tests/junit_common.py (revision a32db64d340db16914d4892be21e91c50f2a7cbd)
from abc import ABC, abstractmethod
import argparse
import csv
from dataclasses import dataclass, field
import difflib
from enum import Enum
from itertools import product
from math import isclose
import multiprocessing as mp
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
import time
from typing import List, Optional, Tuple
17
18sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
19from junit_xml import TestCase, TestSuite, to_xml_report_string  # nopep8
20
21
class CaseInsensitiveEnumAction(argparse.Action):
    """Action to convert input values to lower case prior to converting to an Enum type

    The provided `type` must be both a `str` subclass and an `Enum` subclass
    (i.e. a `StrEnum`-like class) so that values can be lower-cased before
    conversion to enum members.
    """

    def __init__(self, option_strings, dest, type, default, **kwargs):
        # `type` intentionally shadows the builtin: argparse passes the
        # converter under this keyword, so the parameter cannot be renamed
        if not (issubclass(type, Enum) and issubclass(type, str)):
            raise ValueError(f"{type} must be a StrEnum or str and Enum")
        # store provided enum type
        self.enum_type = type
        # normalize the default, which may be None (argparse's default),
        # a single string, or an iterable of strings (e.g. for nargs='*')
        if default is not None:
            if isinstance(default, str):
                default = self.enum_type(default.lower())
            else:
                default = [self.enum_type(v.lower()) for v in default]
        # prevent automatic type conversion by not forwarding `type` to argparse
        super().__init__(option_strings, dest, default=default, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert the parsed value(s) to the enum type, case-insensitively"""
        if isinstance(values, str):
            values = self.enum_type(values.lower())
        else:
            values = [self.enum_type(v.lower()) for v in values]
        setattr(namespace, self.dest, values)
43
44
@dataclass
class TestSpec:
    """Dataclass storing information about a single test case"""
    # Human-readable test case name, parsed from TESTARGS(name="...") (empty if unnamed)
    name: str
    # Constraint tags from TESTARGS(only="..."), e.g. 'serial', 'int32'
    only: List[str] = field(default_factory=list)
    # Command-line arguments to pass to the test executable
    args: List[str] = field(default_factory=list)
51
52
class RunMode(str, Enum):
    """Enumeration of run modes, either `RunMode.TAP` or `RunMode.JUNIT`"""
    # Use plain-string rendering so f-strings and format() print 'tap'/'junit'
    # rather than 'RunMode.TAP'/'RunMode.JUNIT'
    __str__ = str.__str__
    __format__ = str.__format__
    TAP: str = 'tap'
    JUNIT: str = 'junit'
59
60
class SuiteSpec(ABC):
    """Abstract Base Class defining the required interface for running a test suite

    Subclasses must implement the three path-resolution methods; the remaining
    hooks have no-op defaults and may be overridden to customize skipping,
    expected-failure handling, and per-test post-processing.
    """
    @abstractmethod
    def get_source_path(self, test: str) -> Path:
        """Compute path to test source file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to source file
        """
        raise NotImplementedError

    @abstractmethod
    def get_run_path(self, test: str) -> Path:
        """Compute path to built test executable file

        Args:
            test (str): Name of test

        Returns:
            Path: Path to test executable
        """
        raise NotImplementedError

    @abstractmethod
    def get_output_path(self, test: str, output_file: str) -> Path:
        """Compute path to expected output file

        Args:
            test (str): Name of test
            output_file (str): File name of output file

        Returns:
            Path: Path to expected output file
        """
        raise NotImplementedError

    @property
    def cgns_tol(self) -> float:
        """Absolute tolerance for CGNS diff"""
        # falls back to 1.0e-12 until a value is explicitly assigned via the setter
        return getattr(self, '_cgns_tol', 1.0e-12)

    @cgns_tol.setter
    def cgns_tol(self, val: float):
        self._cgns_tol = val

    def post_test_hook(self, test: str, spec: TestSpec) -> None:
        """Function callback ran after each test case

        Default implementation does nothing.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
        """
        pass

    def check_pre_skip(self, test: str, spec: TestSpec, resource: str, nproc: int) -> Optional[str]:
        """Check if a test case should be skipped prior to running, returning the reason for skipping

        Default implementation never skips.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            nproc (int): Number of MPI processes to use when running test case

        Returns:
            Optional[str]: Skip reason, or `None` if test case should not be skipped
        """
        return None

    def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Optional[str]:
        """Check if a test case should be allowed to fail, based on its stderr output

        Default implementation allows no failures.

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            Optional[str]: Skip reason, or `None` if unexpected error
        """
        return None

    def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
        """Check whether a test case is expected to fail and if it failed expectedly

        Default implementation reports no required failure (empty message, True).

        Args:
            test (str): Name of test
            spec (TestSpec): Test case specification
            resource (str): libCEED backend
            stderr (str): Standard error output from test case execution

        Returns:
            tuple[str, bool]: Tuple of the expected failure string and whether it was present in `stderr`
        """
        return '', True

    def check_allowed_stdout(self, test: str) -> bool:
        """Check whether a test is allowed to print console output

        Default implementation disallows console output.

        Args:
            test (str): Name of test

        Returns:
            bool: True if the test is allowed to print console output
        """
        return False
170
171
def has_cgnsdiff() -> bool:
    """Check whether `cgnsdiff` is an executable program in the current environment

    Returns:
        bool: True if `cgnsdiff` is found
    """
    # `shutil.which` performs the PATH lookup directly, avoiding the previous
    # fragile check for the shell- and locale-dependent string 'not found'
    # in the stderr of a `shell=True` subprocess
    return shutil.which('cgnsdiff') is not None
185
186
def contains_any(base: str, substrings: List[str]) -> bool:
    """Determine whether `base` contains at least one of the given substrings

    Args:
        base (str): String to be searched
        substrings (List[str]): Candidate substrings to look for

    Returns:
        bool: True if at least one candidate occurs within `base`
    """
    for candidate in substrings:
        if candidate in base:
            return True
    return False
198
199
def startswith_any(base: str, prefixes: List[str]) -> bool:
    """Helper function, checks if the base string is prefixed by any of `prefixes`

    Args:
        base (str): Base string to search
        prefixes (List[str]): List of potential prefixes

    Returns:
        bool: True if base string is prefixed by any of the prefixes
    """
    # str.startswith accepts a tuple of prefixes, testing all of them in a
    # single C-level call instead of a Python-level generator loop
    return base.startswith(tuple(prefixes))
211
212
def parse_test_line(line: str) -> TestSpec:
    """Parse a single line of TESTARGS and CLI arguments into a `TestSpec` object

    Args:
        line (str): String containing TESTARGS specification and CLI arguments

    Returns:
        TestSpec: Parsed specification of test case
    """
    # split on whitespace, keeping double-quoted runs together as one token
    tokens: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    head: str = tokens[0]
    cli_args: List[str] = tokens[1:]
    if head == 'TESTARGS':
        # bare TESTARGS with no (name=...,only=...) specification
        return TestSpec(name='', args=cli_args)
    # extract the text between 'TESTARGS(' and the final ')'
    raw_test_args: str = head[head.index('TESTARGS(') + 9:head.rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    spec_dict: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    test_name: str = spec_dict.get('name', '')
    constraints: List[str] = spec_dict['only'].split(',') if 'only' in spec_dict else []
    if cli_args:
        return TestSpec(name=test_name, only=constraints, args=cli_args)
    return TestSpec(name=test_name, only=constraints)
234
235
def get_test_args(source_file: Path) -> List[TestSpec]:
    """Parse all test cases from a given source file

    Args:
        source_file (Path): Path to source file

    Raises:
        RuntimeError: Errors if source file extension is unsupported

    Returns:
        List[TestSpec]: List of parsed `TestSpec` objects, or a list containing a single, default `TestSpec` if none were found
    """
    # comment prefix that marks a TESTARGS line, chosen by file extension
    comment_str: str = ''
    if source_file.suffix in ['.c', '.cc', '.cpp']:
        comment_str = '//'
    elif source_file.suffix in ['.py']:
        comment_str = '#'
    elif source_file.suffix in ['.usr']:
        comment_str = 'C_'
    elif source_file.suffix in ['.f90']:
        comment_str = '! '
    else:
        raise RuntimeError(f'Unrecognized extension for file: {source_file}')

    # BUGFIX: slice the prefix off instead of `line.strip(comment_str)` —
    # strip() removes a *character set* from BOTH ends, which could also eat
    # trailing characters of the final argument (e.g. a path ending in '/'
    # for C files, or a trailing '#' for Python files)
    return [parse_test_line(line[len(comment_str):])
            for line in source_file.read_text().splitlines()
            if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]
263
264
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
    """Compare CSV results against an expected CSV file with tolerances

    Args:
        test_csv (Path): Path to output CSV results
        true_csv (Path): Path to expected CSV results
        zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
        rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.

    Returns:
        str: Diff output between result and expected CSVs (empty string if they match)
    """
    test_lines: List[str] = test_csv.read_text().splitlines()
    true_lines: List[str] = true_csv.read_text().splitlines()
    # Files should not be empty
    if len(test_lines) == 0:
        return f'No lines found in test output {test_csv}'
    if len(true_lines) == 0:
        return f'No lines found in test source {true_csv}'

    test_reader: csv.DictReader = csv.DictReader(test_lines)
    true_reader: csv.DictReader = csv.DictReader(true_lines)
    if test_reader.fieldnames != true_reader.fieldnames:
        # BUGFIX: the first sequence passed to unified_diff receives the
        # `fromfile` label, so the expected header must come first — the
        # original passed the found header under the 'expected' label
        return ''.join(difflib.unified_diff([f'{true_lines[0]}\n'], [f'{test_lines[0]}\n'],
                       fromfile='expected CSV columns', tofile='found CSV columns'))

    if len(test_lines) != len(true_lines):
        return f'Number of lines in {test_csv} and {true_csv} do not match'
    diff_lines: List[str] = list()
    for test_line, true_line in zip(test_reader, true_reader):
        for key in test_reader.fieldnames:
            # Check if the value is numeric
            try:
                true_val: float = float(true_line[key])
                test_val: float = float(test_line[key])
                true_zero: bool = abs(true_val) < zero_tol
                test_zero: bool = abs(test_val) < zero_tol
                fail: bool = False
                if true_zero:
                    # expected value is effectively zero, so the result must be too
                    fail = not test_zero
                else:
                    fail = not isclose(test_val, true_val, rel_tol=rel_tol)
                if fail:
                    diff_lines.append(f'column: {key}, expected: {true_val}, got: {test_val}')
            except ValueError:
                # non-numeric entries must match exactly
                if test_line[key] != true_line[key]:
                    diff_lines.append(f'column: {key}, expected: {true_line[key]}, got: {test_line[key]}')

    return '\n'.join(diff_lines)
314
315
def diff_cgns(test_cgns: Path, true_cgns: Path, cgns_tol: float = 1e-12) -> str:
    """Compare CGNS results against an expected CGNS file with tolerance

    Args:
        test_cgns (Path): Path to output CGNS file
        true_cgns (Path): Path to expected CGNS file
        cgns_tol (float, optional): Tolerance for comparing floating-point values

    Returns:
        str: Diff output between result and expected CGNS files (empty if they match)
    """
    my_env: dict = os.environ.copy()

    # quote the paths so file names containing spaces or shell metacharacters
    # survive the shell=True invocation intact
    run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{cgns_tol}',
                           shlex.quote(str(test_cgns)), shlex.quote(str(true_cgns))]
    proc = subprocess.run(' '.join(run_args),
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          env=my_env)

    # cgnsdiff reports differences on stdout and errors on stderr; surface both
    return proc.stderr.decode('utf-8') + proc.stdout.decode('utf-8')
337
338
def test_case_output_string(test_case: TestCase, spec: TestSpec, mode: RunMode,
                            backend: str, test: str, index: int) -> str:
    """Format the result of a single completed test case for console output

    In TAP mode, emits a numbered subtest point followed by a YAML diagnostic
    block; in JUnit mode, emits details only for errored or failed cases
    (passing cases produce an empty string).

    Args:
        test_case (TestCase): Completed test case result (its `args` attribute must be set)
        spec (TestSpec): Test case specification
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        backend (str): libCEED backend used for this run
        test (str): Name of test
        index (int): 1-based test point index within the subtest

    Returns:
        str: Formatted output string (may be empty in JUnit mode)
    """
    output_str = ''
    if mode is RunMode.TAP:
        # print incremental output if TAP mode
        if test_case.is_skipped():
            output_str += f'    ok {index} - {spec.name}, {backend} # SKIP {test_case.skipped[0]["message"]}\n'
        elif test_case.is_failure() or test_case.is_error():
            output_str += f'    not ok {index} - {spec.name}, {backend}\n'
        else:
            output_str += f'    ok {index} - {spec.name}, {backend}\n'
        # YAML diagnostic block: '---' opens, '...' closes
        output_str += f'      ---\n'
        if spec.only:
            output_str += f'      only: {",".join(spec.only)}\n'
        output_str += f'      args: {test_case.args}\n'
        if test_case.is_error():
            output_str += f'      error: {test_case.errors[0]["message"]}\n'
        if test_case.is_failure():
            output_str += f'      num_failures: {len(test_case.failures)}\n'
            for i, failure in enumerate(test_case.failures):
                output_str += f'      failure_{i}: {failure["message"]}\n'
                output_str += f'        message: {failure["message"]}\n'
                if failure["output"]:
                    # re-indent captured output to form a YAML block scalar
                    out = failure["output"].strip().replace('\n', '\n          ')
                    output_str += f'        output: |\n          {out}\n'
        output_str += f'      ...\n'
    else:
        # print error or failure information if JUNIT mode
        if test_case.is_error() or test_case.is_failure():
            output_str += f'Test: {test} {spec.name}\n'
            output_str += f'  $ {test_case.args}\n'
            if test_case.is_error():
                output_str += 'ERROR: {}\n'.format((test_case.errors[0]['message'] or 'NO MESSAGE').strip())
                output_str += 'Output: \n{}\n'.format((test_case.errors[0]['output'] or 'NO MESSAGE').strip())
            if test_case.is_failure():
                for failure in test_case.failures:
                    output_str += 'FAIL: {}\n'.format((failure['message'] or 'NO MESSAGE').strip())
                    output_str += 'Output: \n{}\n'.format((failure['output'] or 'NO MESSAGE').strip())
    return output_str
378
379
def run_test(index: int, test: str, spec: TestSpec, backend: str,
             mode: RunMode, nproc: int, suite_spec: SuiteSpec) -> Tuple[TestCase, str]:
    """Run a single test case and backend combination

    NOTE: reads the module-global `my_env` created by `init_process`, so this
    function must execute in a pool worker initialized with `init_process`.

    Args:
        index (int): Index of backend for current spec
        test (str): Path to test
        spec (TestSpec): Specification of test case
        backend (str): CEED backend
        mode (RunMode): Output mode
        nproc (int): Number of MPI processes to use when running test case
        suite_spec (SuiteSpec): Specification of test suite

    Returns:
        Tuple[TestCase, str]: Test case result and its formatted console output
    """
    source_path: Path = suite_spec.get_source_path(test)
    run_args: List = [f'{suite_spec.get_run_path(test)}', *map(str, spec.args)]

    # replace a standalone '{ceed_resource}' argument with the backend string
    if '{ceed_resource}' in run_args:
        run_args[run_args.index('{ceed_resource}')] = backend
    # replace '{ceed_resource}' embedded within arguments (e.g. output file
    # names) with a filesystem-safe form of the backend name
    for i, arg in enumerate(run_args):
        if '{ceed_resource}' in arg:
            run_args[i] = arg.replace('{ceed_resource}', backend.replace('/', '-'))
    if '{nproc}' in run_args:
        run_args[run_args.index('{nproc}')] = f'{nproc}'
    elif nproc > 1 and source_path.suffix != '.py':
        # no explicit '{nproc}' placeholder: launch through mpiexec instead
        run_args = ['mpiexec', '-n', f'{nproc}', *run_args]

    # run test
    skip_reason: Optional[str] = suite_spec.check_pre_skip(test, spec, backend, nproc)
    if skip_reason:
        # pre-skipped: record an empty result without running the executable
        test_case: TestCase = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                                       elapsed_sec=0,
                                       timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime()),
                                       stdout='',
                                       stderr='',
                                       category=spec.name,)
        test_case.add_skipped_info(skip_reason)
    else:
        start: float = time.time()
        # `my_env` is this worker's environment copy, set up by `init_process`
        proc = subprocess.run(' '.join(str(arg) for arg in run_args),
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              env=my_env)

        # NOTE(review): `classname` receives a Path, not a str — presumably
        # junit_xml stringifies it; confirm against the junit_xml API
        test_case = TestCase(f'{test}, "{spec.name}", n{nproc}, {backend}',
                             classname=source_path.parent,
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime('%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=proc.stdout.decode('utf-8'),
                             stderr=proc.stderr.decode('utf-8'),
                             allow_multiple_subelements=True,
                             category=spec.name,)
        # collect reference files for any 'ascii:'/'cgns:' output arguments
        ref_csvs: List[Path] = []
        output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
        if output_files:
            ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
        ref_cgns: List[Path] = []
        output_files = [arg for arg in run_args if 'cgns:' in arg]
        if output_files:
            ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
        ref_stdout: Path = suite_spec.get_output_path(test, test + '.out')
        suite_spec.post_test_hook(test, spec)

    # check allowed failures
    # (`proc`/`ref_*` exist only on the non-skipped path; every use below is
    # guarded by `test_case.is_skipped()` checks)
    if not test_case.is_skipped() and test_case.stderr:
        skip_reason: Optional[str] = suite_spec.check_post_skip(test, spec, backend, test_case.stderr)
        if skip_reason:
            test_case.add_skipped_info(skip_reason)

    # check required failures
    if not test_case.is_skipped():
        required_message, did_fail = suite_spec.check_required_failure(
            test, spec, backend, test_case.stderr)
        if required_message and did_fail:
            # expected failure occurred: record status but no failure entry
            test_case.status = f'fails with required: {required_message}'
        elif required_message:
            test_case.add_failure_info(f'required failure missing: {required_message}')

    # classify other results
    if not test_case.is_skipped() and not test_case.status:
        if test_case.stderr:
            test_case.add_failure_info('stderr', test_case.stderr)
        if proc.returncode != 0:
            test_case.add_error_info(f'returncode = {proc.returncode}')
        if ref_stdout.is_file():
            # compare captured stdout against the expected-output file
            diff = list(difflib.unified_diff(ref_stdout.read_text().splitlines(keepends=True),
                                             test_case.stdout.splitlines(keepends=True),
                                             fromfile=str(ref_stdout),
                                             tofile='New'))
            if diff:
                test_case.add_failure_info('stdout', output=''.join(diff))
        elif test_case.stdout and not suite_spec.check_allowed_stdout(test):
            # unexpected console output with no reference file is a failure
            test_case.add_failure_info('stdout', output=test_case.stdout)
        # expected CSV output
        for ref_csv in ref_csvs:
            csv_name = ref_csv.name
            if not ref_csv.is_file():
                # remove _{ceed_backend} from path name
                ref_csv = (ref_csv.parent / ref_csv.name.rsplit('_', 1)[0]).with_suffix('.csv')
            if not ref_csv.is_file():
                test_case.add_failure_info('csv', output=f'{ref_csv} not found')
            else:
                diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
                if diff:
                    test_case.add_failure_info('csv', output=diff)
                else:
                    # matching output is deleted to keep the work directory clean
                    (Path.cwd() / csv_name).unlink()
        # expected CGNS output
        for ref_cgn in ref_cgns:
            cgn_name = ref_cgn.name
            if not ref_cgn.is_file():
                # remove _{ceed_backend} from path name
                ref_cgn = (ref_cgn.parent / ref_cgn.name.rsplit('_', 1)[0]).with_suffix('.cgns')
            if not ref_cgn.is_file():
                test_case.add_failure_info('cgns', output=f'{ref_cgn} not found')
            else:
                diff = diff_cgns(Path.cwd() / cgn_name, ref_cgn, cgns_tol=suite_spec.cgns_tol)
                if diff:
                    test_case.add_failure_info('cgns', output=diff)
                else:
                    (Path.cwd() / cgn_name).unlink()

    # store result
    test_case.args = ' '.join(str(arg) for arg in run_args)
    output_str = test_case_output_string(test_case, spec, mode, backend, test, index)

    return test_case, output_str
510
511
def init_process():
    """Initialize a multiprocessing pool worker

    Creates the module-global `my_env` used by `run_test`: a copy of the
    current environment with `CEED_ERROR_HANDLER` set to 'exit'.
    """
    global my_env
    my_env = {**os.environ, 'CEED_ERROR_HANDLER': 'exit'}
518
519
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
              suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
    """Run all test cases for `test` with each of the provided `ceed_backends`

    Args:
        test (str): Name of test
        ceed_backends (List[str]): List of libCEED backends
        mode (RunMode): Output mode, either `RunMode.TAP` or `RunMode.JUNIT`
        nproc (int): Number of MPI processes to use when running each test case
        suite_spec (SuiteSpec): Object defining required methods for running tests
        pool_size (int, optional): Number of processes to use when running tests in parallel. Defaults to 1.

    Returns:
        TestSuite: JUnit `TestSuite` containing results of all test cases
    """
    test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
    if mode is RunMode.TAP:
        # TAP preamble: protocol version and top-level test plan
        print('TAP version 13')
        print(f'1..{len(test_specs)}')

    # each worker sets up `my_env` via `init_process` before running any test
    with mp.Pool(processes=pool_size, initializer=init_process) as pool:
        # one async result per (spec, backend) pair, grouped by spec
        async_outputs: List[List] = [
            [pool.apply_async(run_test, (i, test, spec, backend, mode, nproc, suite_spec))
             for (i, backend) in enumerate(ceed_backends, start=1)]
            for spec in test_specs
        ]

        test_cases = []
        for (i, subtest) in enumerate(async_outputs, start=1):
            is_new_subtest = True
            subtest_ok = True
            for async_output in subtest:
                test_case, print_output = async_output.get()
                test_cases.append(test_case)
                if is_new_subtest and mode == RunMode.TAP:
                    # print the subtest header and plan once per spec
                    is_new_subtest = False
                    print(f'# Subtest: {test_case.category}')
                    print(f'    1..{len(ceed_backends)}')
                print(print_output, end='')
                if test_case.is_failure() or test_case.is_error():
                    subtest_ok = False
            if mode == RunMode.TAP:
                # `test_case` still holds the last result of this subtest;
                # its category names the spec for the top-level test point
                print(f'{"" if subtest_ok else "not "}ok {i} - {test_case.category}')

    return TestSuite(test, test_cases)
565
566
def write_junit_xml(test_suite: TestSuite, output_file: Optional[Path], batch: str = '') -> None:
    """Write a JUnit XML file containing the results of a `TestSuite`

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to write
        output_file (Optional[Path]): Path to output file, or `None` to generate automatically as `build/{test_suite.name}{batch}.junit`
        batch (str): Name of JUnit batch, defaults to empty string
    """
    # fall back to the conventional build-directory location when no path given
    target: Path = output_file if output_file else Path('build') / f'{test_suite.name}{batch}.junit'
    target.write_text(to_xml_report_string([test_suite]))
577
578
def has_failures(test_suite: TestSuite) -> bool:
    """Check whether any test cases in a `TestSuite` failed

    Args:
        test_suite (TestSuite): JUnit `TestSuite` to check

    Returns:
        bool: True if any test cases failed
    """
    for case in test_suite.test_cases:
        if case.is_failure() or case.is_error():
            return True
    return False
589