xref: /libCEED/backends/magma/tuning/generate_tuning.py (revision 5ebd836c59d60a2e5e1cb67f6731404c7da26f85)
1#!/usr/bin/env python3
2
3# Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
4# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
5#
6# SPDX-License-Identifier: BSD-2-Clause
7#
8# This file is part of CEED:  http://github.com/ceed
9
10import argparse
11import os
12import glob
13import re
14import shutil
15import subprocess
16import pandas as pd
17import time
18
# Absolute directory holding this script (symlinks resolved); every relative
# path below (backend sources, source root, tuning binary, outputs) hangs off it.
script_dir = os.path.split(os.path.realpath(__file__))[0]
20
21
def _patch_file(path, pattern, repl):
    """Rewrite ``path`` in place with ``re.sub(pattern, repl, contents)``."""
    with open(path, "r") as f:
        data = f.read()
    with open(path, "w") as f:
        f.write(re.sub(pattern, repl, data))


def benchmark(nb, build_cmd, backend, log):
    """Rebuild libCEED and the tuning driver for block size ``nb``, then run it.

    Two MAGMA backend sources are patched temporarily:
    ``#define ceed_magma_queue_sync(...)`` is redefined to call
    ``hipDeviceSynchronize()`` (when ``backend`` contains "hip") or
    ``cudaDeviceSynchronize()`` otherwise, and ``CEED_AUTOTUNE_RTC_NB`` is set
    to ``nb``. The original files are restored from ``.backup`` copies after the
    build, including when the build raises.

    Args:
        nb: Block size NB to build and benchmark.
        build_cmd: Command that builds libCEED, run from the source root. Either
            an argument list or a string; a string is split using shell-like
            quoting rules (so "make -j8" works).
        backend: Ceed resource specifier passed to the ``tuning`` executable.
        log: Path of the file receiving the benchmark's stdout and stderr.

    Returns:
        pandas.DataFrame with columns P, Q, N, Q_COMP, TRANS, MFLOPS parsed
        from ``log``.

    Raises:
        subprocess.CalledProcessError: If the libCEED build or ``make tuning``
            exits with a non-zero status.
    """
    import shlex

    ceed_magma_h = f"{script_dir}/../ceed-magma.h"
    gemm_selector_cpp = f"{script_dir}/../ceed-magma-gemm-selector.cpp"
    device_sync = "hipDeviceSynchronize()" if "hip" in backend else "cudaDeviceSynchronize()"

    # Without shell=True, subprocess treats a string as one program name, so a
    # command with arguments would fail; split it into an argument list.
    if isinstance(build_cmd, str):
        build_cmd = shlex.split(build_cmd)

    # Build for new NB; always put the pristine sources back afterwards
    shutil.copyfile(ceed_magma_h, ceed_magma_h + ".backup")
    shutil.copyfile(gemm_selector_cpp, gemm_selector_cpp + ".backup")
    try:
        _patch_file(ceed_magma_h,
                    r".*(#define ceed_magma_queue_sync\(\.\.\.\)).*",
                    r"\1 " + device_sync)
        _patch_file(gemm_selector_cpp,
                    r".*(#define CEED_AUTOTUNE_RTC_NB).*",
                    r"\1 " + f"{nb}")
        # check=True: a failed build would otherwise silently benchmark a stale
        # binary and attribute its timings to this NB
        subprocess.run(build_cmd, cwd=f"{script_dir}/../../..", check=True)
        subprocess.run(["make", "tuning", "OPT=-O0"], cwd=f"{script_dir}", check=True)
    finally:
        shutil.move(ceed_magma_h + ".backup", ceed_magma_h)
        shutil.move(gemm_selector_cpp + ".backup", gemm_selector_cpp)

    # Run the benchmark
    with open(log, "w") as f:
        subprocess.run([f"{script_dir}/tuning", f"{backend}"], stdout=f, stderr=f)

    # delim_whitespace is deprecated since pandas 2.2; sep=r"\s+" is equivalent
    return pd.read_csv(
        log,
        header=None,
        sep=r"\s+",
        names=["P", "Q", "N", "Q_COMP", "TRANS", "MFLOPS"])
72
73
_RTC_BANNER = "////////////////////////////////////////////////////////////////////////////////\n"


def _format_rtc_table(data, trans, name):
    """Return a C++ ``std::vector`` definition of tuned records for one TRANS value.

    Each record is ``{P, Q, N, Q_COMP, NB}`` taken from the rows of ``data``
    whose ``TRANS`` column equals ``trans``. The initializer is empty when no
    row matches (instead of pandas' "Empty DataFrame" text).
    """
    subset = data.loc[data["TRANS"] == trans]
    rows = []
    if not subset.empty:
        rows = subset.to_string(header=False, index=False, justify="right", columns=[
                                "P", "Q", "N", "Q_COMP", "NB"]).split("\n")
    # Insert a comma after each number that is followed by whitespace
    records = ["    {" + re.sub(r"([0-9])(\s+)", r"\1,\2", row) + "}" for row in rows]
    initializer = (",\n".join(records) + "\n") if records else ""
    return (_RTC_BANNER
            + f"std::vector<std::array<int, RECORD_LENGTH_RTC> > {name}" + " = {\n"
            + initializer
            + "};\n")


def main():
    """Autotune NB for the MAGMA RTC backend and write ``<arch>_rtc.h`` next to this script."""
    # Command line arguments
    parser = argparse.ArgumentParser("MAGMA RTC autotuning")
    parser.add_argument(
        "-arch",
        help="Device architecture name for tuning data",
        required=True)
    parser.add_argument(
        "-max-nb",
        help="Maximum block size NB to consider for autotuning",
        default=32,
        type=int)
    parser.add_argument(
        "-build-cmd",
        help="Command used to build libCEED from the source root directory",
        default="make")
    parser.add_argument(
        "-ceed",
        help="Ceed resource specifier",
        default="/cpu/self")
    args = parser.parse_args()
    # Without at least one benchmark run there is no data to write
    if args.max_nb < 1:
        parser.error("-max-nb must be at least 1")

    data = None
    nb = 1
    while nb <= args.max_nb:
        # Run the benchmarks
        start = time.perf_counter()
        data_nb = benchmark(nb, args.build_cmd, args.ceed,
                            f"{script_dir}/output-nb-{nb}.txt")
        print(
            f"Finished benchmarks for NB = {nb}, backend = {args.ceed} ({time.perf_counter() - start} s)")

        # Save the data for the highest performing NB; a new NB must beat the
        # current best MFLOPS by more than 5% to replace it
        if data is None:
            data = pd.DataFrame(data_nb)
            data["NB"] = nb
        else:
            # Rows are compared positionally, so every run must report the same records
            if len(data_nb) != len(data):
                raise RuntimeError(
                    f"Benchmark for NB = {nb} produced {len(data_nb)} records, "
                    f"expected {len(data)}")
            idx = data_nb["MFLOPS"] > 1.05 * data["MFLOPS"]
            data.loc[idx, "NB"] = nb
            data.loc[idx, "MFLOPS"] = data_nb.loc[idx, "MFLOPS"]

        # Speed up the search by considering only some values of NB:
        # 1, 2, 4, 6, 8, 12, 16, 20, ...
        if nb < 2:
            nb *= 2
        elif nb < 8:
            nb += 2
        else:
            nb += 4

    # Print the results
    with open(f"{script_dir}/{args.arch}_rtc.h", "w") as f:
        f.write(_RTC_BANNER)
        f.write(f"// auto-generated from data on {args.arch}\n\n")
        f.write(_format_rtc_table(data, 1, f"drtc_t_{args.arch}"))
        f.write("\n")
        f.write(_format_rtc_table(data, 0, f"drtc_n_{args.arch}"))


if __name__ == "__main__":
    main()
155