xref: /libCEED/backends/magma/tuning/tuning.cpp (revision ac8b7a1c891ad04d71a0a3b312ee8eb2f11e3b62)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
#include <ceed.h>
#include <algorithm>
#include <array>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <random>
#include <vector>
15 
// XX TODO WIP: Add other quadrature orders, prism/pyramid, ...

// Basis configurations to benchmark, one {P, Q, dim} triple per entry.
// main() passes P and Q as the 4th and 5th arguments of CeedBasisCreateH1() and
// selects CEED_TOPOLOGY_TRIANGLE when dim < 3, CEED_TOPOLOGY_TET otherwise.
// clang-format off
static constexpr std::array<std::array<int, 3>, 11> PQ_VALUES = {{
    // dim = 2
    {3, 1, 2}, {6, 3, 2}, {10, 6, 2}, {15, 12, 2}, {21, 16, 2}, {28, 25, 2}, {36, 33, 2},
    // dim = 3
    {4, 1, 3}, {10, 4, 3}, {20, 11, 3}, {35, 24, 3},
}};
// clang-format on

// Problem sizes to benchmark; main() passes each value as the element count of
// CeedBasisApply() and uses it to size the input/output vectors.
static constexpr std::array<int, 7> N_VALUES = {
    1024, 5120, 10240, 51200, 102400, 512000, 1024000,
};

// Number of trials the reported throughput is averaged over.
constexpr int NUM_TRIALS{25};
27 
// Monotonic clock used to time each basis application.
using Clock = std::chrono::steady_clock;
// Elapsed time expressed as a double-precision count of seconds.
using Duration = std::chrono::duration<double, std::ratio<1>>;
30 
31 int main(int argc, char **argv) {
32   Ceed ceed;
33 
34   std::random_device               rand_device;
35   std::default_random_engine       rand_engine(rand_device());
36   std::uniform_real_distribution<> rand_dist(0.0, 1.0);
37   auto                             generate_random = [&rand_dist, &rand_engine]() { return rand_dist(rand_engine); };
38 
39   CeedInit((argc < 2) ? "/gpu/cuda/magma" : argv[1], &ceed);
40   CeedSetErrorHandler(ceed, CeedErrorStore);
41 
42   for (const auto [P, Q, dim] : PQ_VALUES) {
43     CeedBasis  basis;
44     CeedVector u, v;
45 
46     std::vector<double> q_ref(dim * Q, 0.0), q_weight(Q, 0.0), interp(P * Q), grad(P * Q * dim);
47     std::generate(interp.begin(), interp.end(), generate_random);
48     std::generate(grad.begin(), grad.end(), generate_random);
49 
50     CeedBasisCreateH1(ceed, (dim < 3) ? CEED_TOPOLOGY_TRIANGLE : CEED_TOPOLOGY_TET, 1, P, Q, interp.data(), grad.data(), q_ref.data(),
51                       q_weight.data(), &basis);
52 
53     for (const auto N : N_VALUES) {
54       double data_interp_n = 0.0, data_interp_t = 0.0, data_grad_n = 0.0, data_grad_t = 0.0;
55 
56       // Interp
57       {
58         CeedVectorCreate(ceed, P * N, &u);
59         CeedVectorCreate(ceed, Q * N, &v);
60 
61         // NoTranspose
62         CeedVectorSetValue(u, 1.0);
63         for (int trial = 0; trial <= NUM_TRIALS; trial++) {
64           CeedVectorSetValue(v, 0.0);
65 
66           const auto start = Clock::now();
67           int        ierr  = CeedBasisApply(basis, N, CEED_NOTRANSPOSE, CEED_EVAL_INTERP, u, v);
68           if (ierr) {
69             break;
70           }
71           if (trial > 0) {
72             data_interp_n += std::chrono::duration_cast<Duration>(Clock::now() - start).count();
73           }
74         }
75 
76         // Transpose
77         CeedVectorSetValue(v, 1.0);
78         for (int trial = 0; trial <= NUM_TRIALS; trial++) {
79           CeedVectorSetValue(u, 0.0);
80 
81           const auto start = Clock::now();
82           int        ierr  = CeedBasisApply(basis, N, CEED_TRANSPOSE, CEED_EVAL_INTERP, v, u);
83           if (ierr) {
84             break;
85           }
86           if (trial > 0) {
87             data_interp_t += std::chrono::duration_cast<Duration>(Clock::now() - start).count();
88           }
89         }
90 
91         CeedVectorDestroy(&u);
92         CeedVectorDestroy(&v);
93       }
94 
95       // Grad
96       {
97         CeedVectorCreate(ceed, P * N, &u);
98         CeedVectorCreate(ceed, dim * Q * N, &v);
99 
100         // NoTranspose
101         CeedVectorSetValue(u, 1.0);
102         for (int trial = 0; trial < NUM_TRIALS; trial++) {
103           CeedVectorSetValue(v, 0.0);
104 
105           const auto start = Clock::now();
106           int        ierr  = CeedBasisApply(basis, N, CEED_NOTRANSPOSE, CEED_EVAL_GRAD, u, v);
107           if (ierr) {
108             break;
109           }
110           if (trial > 0) {
111             data_grad_n += std::chrono::duration_cast<Duration>(Clock::now() - start).count();
112           }
113         }
114 
115         // Transpose
116         CeedVectorSetValue(v, 1.0);
117         for (int trial = 0; trial < NUM_TRIALS; trial++) {
118           CeedVectorSetValue(u, 0.0);
119 
120           const auto start = Clock::now();
121           int        ierr  = CeedBasisApply(basis, N, CEED_TRANSPOSE, CEED_EVAL_GRAD, v, u);
122           if (ierr) {
123             break;
124           }
125           if (trial > 0) {
126             data_grad_t += std::chrono::duration_cast<Duration>(Clock::now() - start).count();
127           }
128         }
129 
130         CeedVectorDestroy(&u);
131         CeedVectorDestroy(&v);
132       }
133 
134       // Postprocess and log the data
135       const double  interp_flops = P * Q * (double)N;
136       const double  grad_flops   = P * Q * dim * (double)N;
137       constexpr int width = 12, precision = 2;
138       // clang-format off
139       std::printf("%-*d%-*d%-*d%-*d%-*d%*.*f\n",
140                   width, P, width, N, width, Q, width, 1, width, 0, width, precision,
141                   (data_interp_n > 0.0) ? NUM_TRIALS * interp_flops / data_interp_n * 1.0e-6 : 0.0);
142       std::printf("%-*d%-*d%-*d%-*d%-*d%*.*f\n",
143                   width, P, width, N, width, Q, width, 1, width, 1, width, precision,
144                   (data_interp_t > 0.0) ? NUM_TRIALS * interp_flops / data_interp_t * 1.0e-6 : 0.0);
145       std::printf("%-*d%-*d%-*d%-*d%-*d%*.*f\n",
146                   width, P, width, N, width, Q, width, dim, width, 0, width, precision,
147                   (data_grad_n > 0.0) ? NUM_TRIALS * grad_flops / data_grad_n * 1.0e-6 : 0.0);
148       std::printf("%-*d%-*d%-*d%-*d%-*d%*.*f\n",
149                   width, P, width, N, width, Q, width, dim, width, 1, width, precision,
150                   (data_grad_n > 0.0) ? NUM_TRIALS * grad_flops / data_grad_n * 1.0e-6 : 0.0);
151       // clang-format on
152     }
153 
154     CeedBasisDestroy(&basis);
155   }
156 
157   CeedDestroy(&ceed);
158   return 0;
159 }
160