xref: /libCEED/interface/ceed-jit-tools.c (revision b46df0d23d416892813aae9c232a5a88657bbf88)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed-impl.h>
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <ceed/jit-tools.h>
12 #include <stdbool.h>
13 #include <stdio.h>
14 #include <string.h>
15 
16 /**
17   @brief Check if valid file exists at path given
18 
19   @param[in]  ceed             `Ceed` object for error handling
20   @param[in]  source_file_path Absolute path to source file
21   @param[out] is_valid         Boolean flag indicating if file can be opened
22 
23   @return An error code: 0 - success, otherwise - failure
24 
25   @ref Backend
26 **/
27 int CeedCheckFilePath(Ceed ceed, const char *source_file_path, bool *is_valid) {
28   // Sometimes we have path/to/file.h:function_name
29   // Create temporary file path without name, if needed
30   char *source_file_path_only;
31   char *last_colon = strrchr(source_file_path, ':');
32 
33   if (last_colon) {
34     size_t source_file_path_length = (last_colon - source_file_path + 1);
35 
36     CeedCall(CeedCalloc(source_file_path_length, &source_file_path_only));
37     memcpy(source_file_path_only, source_file_path, source_file_path_length - 1);
38   } else {
39     source_file_path_only = (char *)source_file_path;
40   }
41 
42   // Debug
43   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Checking for source file: ");
44   CeedDebug(ceed, "%s\n", source_file_path_only);
45 
46   // Check for valid file path
47   FILE *source_file;
48   source_file = fopen(source_file_path_only, "rb");
49   *is_valid   = source_file;
50 
51   if (*is_valid) {
52     // Debug
53     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Found JiT source file: ");
54     CeedDebug(ceed, "%s\n", source_file_path_only);
55     fclose(source_file);
56   }
57 
58   // Free temp file path, if used
59   if (last_colon) CeedCall(CeedFree(&source_file_path_only));
60   return CEED_ERROR_SUCCESS;
61 }
62 
63 /**
64   @brief Normalize a filepath
65 
66   @param[in]   ceed                        `Ceed` object for error handling
67   @param[in]   source_file_path            Absolute path to source file
68   @param[out]  normalized_source_file_path Normalized filepath
69 
70   @return An error code: 0 - success, otherwise - failure
71 
72   @ref Backend
73 **/
74 static int CeedNormalizePath(Ceed ceed, const char *source_file_path, char **normalized_source_file_path) {
75   CeedCall(CeedStringAllocCopy(source_file_path, normalized_source_file_path));
76 
77   char *first_dot = strchr(*normalized_source_file_path, '.');
78 
79   while (first_dot) {
80     char *search_from = first_dot + 1;
81     char  keyword[5]  = "";
82 
83     // -- Check for /./ and covert to /
84     if (first_dot != *normalized_source_file_path && strlen(first_dot) > 2) memcpy(keyword, &first_dot[-1], 3);
85     bool is_here = !strcmp(keyword, "/./");
86 
87     if (is_here) {
88       for (CeedInt i = 0; first_dot[i - 1]; i++) first_dot[i] = first_dot[i + 2];
89       search_from = first_dot;
90     } else {
91       // -- Check for /foo/../ and convert to /
92       if (first_dot != *normalized_source_file_path && strlen(first_dot) > 3) memcpy(keyword, &first_dot[-1], 4);
93       bool is_up_one = !strcmp(keyword, "/../");
94 
95       if (is_up_one) {
96         char *last_slash = &first_dot[-2];
97 
98         while (last_slash[0] != '/' && last_slash != *normalized_source_file_path) last_slash--;
99         CeedCheck(last_slash != *normalized_source_file_path, ceed, CEED_ERROR_MAJOR, "Malformed source path %s", source_file_path);
100         for (CeedInt i = 0; first_dot[i + 1]; i++) last_slash[i] = first_dot[i + 2];
101         search_from = last_slash;
102       }
103     }
104     first_dot = strchr(search_from, '.');
105   }
106   return CEED_ERROR_SUCCESS;
107 }
108 
109 /**
110   @brief Load source file into initialized string buffer, including full text of local files in place of `#include "local.h"`.
111     This also updates the `num_file_paths` and `source_file_paths`.
112     Callers are responsible freeing all filepath strings and the string buffer with @ref CeedFree().
113 
114   @param[in]     ceed             `Ceed` object for error handling
115   @param[in]     source_file_path Absolute path to source file
116   @param[in,out] num_file_paths   Number of files already included
117   @param[in,out] file_paths       Paths of files already included
118   @param[out]    buffer           String buffer for source file contents
119 
120   @return An error code: 0 - success, otherwise - failure
121 
122   @ref Backend
123 **/
124 int CeedLoadSourceToInitializedBuffer(Ceed ceed, const char *source_file_path, CeedInt *num_file_paths, char ***file_paths, char **buffer) {
125   FILE *source_file;
126   long  file_size, file_offset = 0;
127   char *temp_buffer;
128 
129   // Debug
130   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "---------- Ceed JiT ----------\n");
131   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Current source file: ");
132   CeedDebug(ceed, "%s\n", source_file_path);
133 
134   // Read file to temporary buffer
135   source_file = fopen(source_file_path, "rb");
136   CeedCheck(source_file, ceed, CEED_ERROR_MAJOR, "Couldn't open source file: %s", source_file_path);
137   // -- Compute size of source
138   fseek(source_file, 0L, SEEK_END);
139   file_size = ftell(source_file);
140   fseek(source_file, 0L, SEEK_SET);
141   //  -- Allocate memory for entire source file
142   {
143     const int ierr = CeedCalloc(file_size + 1, &temp_buffer);
144 
145     // Close stream before error handling, if necessary
146     if (ierr != CEED_ERROR_SUCCESS) fclose(source_file);
147     CeedCall(ierr);
148   }
149   // -- Copy the file into the buffer
150   if (1 != fread(temp_buffer, file_size, 1, source_file)) {
151     // LCOV_EXCL_START
152     fclose(source_file);
153     CeedCall(CeedFree(&temp_buffer));
154     return CeedError(ceed, CEED_ERROR_MAJOR, "Couldn't read source file: %s", source_file_path);
155     // LCOV_EXCL_STOP
156   }
157   fclose(source_file);
158 
159   // Search for headers to include
160   const char *first_hash = strchr(temp_buffer, '#');
161 
162   while (first_hash) {
163     // -- Check for 'pragma' keyword
164     const char *next_m     = strchr(first_hash, 'm');
165     char        keyword[8] = "";
166 
167     if (next_m && next_m - first_hash >= 5) memcpy(keyword, &next_m[-4], 6);
168     bool is_hash_pragma = !strcmp(keyword, "pragma");
169 
170     // ---- Spaces allowed in '#  pragma'
171     if (next_m) {
172       for (CeedInt i = 1; first_hash - next_m + i < -5; i++) {
173         is_hash_pragma &= first_hash[i] == ' ';
174       }
175     }
176     if (is_hash_pragma) {
177       // -- Check if '#pragma once'
178       char *next_o         = strchr(first_hash, 'o');
179       char *next_new_line  = strchr(first_hash, '\n');
180       bool  is_pragma_once = next_o && (next_new_line - next_o > 0) && !strncmp(next_o, "once", 4);
181 
182       // -- Copy into buffer, omitting last line if #pragma once
183       long current_size = strlen(*buffer);
184       long copy_size    = first_hash - &temp_buffer[file_offset] + (is_pragma_once ? 0 : (next_new_line - first_hash + 1));
185 
186       CeedCall(CeedRealloc(current_size + copy_size + 2, buffer));
187       memcpy(&(*buffer)[current_size], "\n", 2);
188       memcpy(&(*buffer)[current_size + 1], &temp_buffer[file_offset], copy_size);
189       memcpy(&(*buffer)[current_size + copy_size], "", 1);
190 
191       file_offset = strchr(first_hash, '\n') - temp_buffer + 1;
192     }
193 
194     // -- Check for 'include' keyword
195     const char *next_e = strchr(first_hash, 'e');
196 
197     if (next_e && next_e - first_hash >= 7) memcpy(keyword, &next_e[-6], 7);
198     bool is_hash_include = !strcmp(keyword, "include");
199 
200     // ---- Spaces allowed in '#  include <header.h>'
201     if (next_e) {
202       for (CeedInt i = 1; first_hash - next_e + i < -6; i++) {
203         is_hash_include &= first_hash[i] == ' ';
204       }
205     }
206     if (is_hash_include) {
207       // -- Copy into buffer all preceding #
208       long current_size = strlen(*buffer);
209       long copy_size    = first_hash - &temp_buffer[file_offset];
210 
211       CeedCall(CeedRealloc(current_size + copy_size + 2, buffer));
212       memcpy(&(*buffer)[current_size], "\n", 2);
213       memcpy(&(*buffer)[current_size + 1], &temp_buffer[file_offset], copy_size);
214       memcpy(&(*buffer)[current_size + copy_size], "", 1);
215       // -- Load local "header.h"
216       char *next_quote        = strchr(first_hash, '"');
217       char *next_new_line     = strchr(first_hash, '\n');
218       bool  is_local_header   = is_hash_include && next_quote && (next_new_line - next_quote > 0);
219       char *next_left_chevron = strchr(first_hash, '<');
220       bool  is_ceed_header    = next_left_chevron && (next_new_line - next_left_chevron > 0) &&
221                             (!strncmp(next_left_chevron, "<ceed/jit-source/", 17) || !strncmp(next_left_chevron, "<ceed/types.h>", 14) ||
222                              !strncmp(next_left_chevron, "<ceed/ceed-f32.h>", 17) || !strncmp(next_left_chevron, "<ceed/ceed-f64.h>", 17));
223       bool is_std_header =
224           next_left_chevron && (next_new_line - next_left_chevron > 0) &&
225           (!strncmp(next_left_chevron, "<std", 4) || !strncmp(next_left_chevron, "<math.h>", 8) || !strncmp(next_left_chevron, "<ceed", 5));
226 
227       if (is_local_header || is_ceed_header) {
228         // ---- Build source path
229         bool  is_included = false;
230         char *include_source_path;
231 
232         if (is_local_header) {
233           long root_length           = strrchr(source_file_path, '/') - source_file_path;
234           long include_file_name_len = strchr(&next_quote[1], '"') - next_quote - 1;
235 
236           CeedCall(CeedCalloc(root_length + include_file_name_len + 2, &include_source_path));
237           memcpy(include_source_path, source_file_path, root_length + 1);
238           memcpy(&include_source_path[root_length + 1], &next_quote[1], include_file_name_len);
239           memcpy(&include_source_path[root_length + include_file_name_len + 1], "", 1);
240         } else {
241           char *next_right_chevron = strchr(first_hash, '>');
242           char *ceed_relative_path;
243           long  ceed_relative_path_length = next_right_chevron - next_left_chevron - 1;
244 
245           CeedCall(CeedCalloc(ceed_relative_path_length + 1, &ceed_relative_path));
246           memcpy(ceed_relative_path, &next_left_chevron[1], ceed_relative_path_length);
247           CeedCall(CeedGetJitAbsolutePath(ceed, ceed_relative_path, (const char **)&include_source_path));
248           CeedCall(CeedFree(&ceed_relative_path));
249         }
250         // ---- Recursive call to load source to buffer
251         char *normalized_include_source_path;
252 
253         CeedCall(CeedNormalizePath(ceed, include_source_path, &normalized_include_source_path));
254         for (CeedInt i = 0; i < *num_file_paths; i++) is_included |= !strcmp(normalized_include_source_path, (*file_paths)[i]);
255         if (!is_included) {
256           CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "JiT Including: %s\n", normalized_include_source_path);
257           CeedCall(CeedLoadSourceToInitializedBuffer(ceed, normalized_include_source_path, num_file_paths, file_paths, buffer));
258           CeedCall(CeedRealloc(*num_file_paths + 1, file_paths));
259           CeedCall(CeedStringAllocCopy(normalized_include_source_path, &(*file_paths)[*num_file_paths]));
260           (*num_file_paths)++;
261         }
262         CeedCall(CeedFree(&include_source_path));
263         CeedCall(CeedFree(&normalized_include_source_path));
264       } else if (!is_std_header) {
265         long header_copy_size = next_new_line - first_hash + 1;
266 
267         CeedCall(CeedRealloc(current_size + copy_size + header_copy_size + 2, buffer));
268         memcpy(&(*buffer)[current_size + copy_size], "\n", 2);
269         memcpy(&(*buffer)[current_size + copy_size + 1], first_hash, header_copy_size);
270         memcpy(&(*buffer)[current_size + copy_size + header_copy_size], "", 1);
271       }
272       file_offset = strchr(first_hash, '\n') - temp_buffer + 1;
273     }
274     // -- Next hash
275     first_hash = strchr(&first_hash[1], '#');
276   }
277   // Copy rest of source file into buffer
278   long current_size = strlen(*buffer);
279   long copy_size    = strlen(&temp_buffer[file_offset]);
280 
281   CeedCall(CeedRealloc(current_size + copy_size + 2, buffer));
282   memcpy(&(*buffer)[current_size], "\n", 2);
283   memcpy(&(*buffer)[current_size + 1], &temp_buffer[file_offset], copy_size);
284   memcpy(&(*buffer)[current_size + copy_size + 1], "", 1);
285 
286   // Cleanup
287   CeedCall(CeedFree(&temp_buffer));
288 
289   // Debug
290   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "---------- Ceed JiT ----------\n");
291   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Current source file: ");
292   CeedDebug(ceed, "%s\n", source_file_path);
293   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Final buffer:\n");
294   CeedDebug(ceed, "%s\n", *buffer);
295   return CEED_ERROR_SUCCESS;
296 }
297 
298 /**
299   @brief Load source file into initialized string buffer, including full text of local files in place of `#include "local.h"`.
300     This also initializes and populates the `num_file_paths` and `source_file_paths`.
301     Callers are responsible freeing all filepath strings and the string buffer with @ref CeedFree().
302 
303   @param[in]     ceed             `Ceed` object for error handling
304   @param[in]     source_file_path Absolute path to source file
305   @param[in,out] num_file_paths   Number of files already included
306   @param[in,out] file_paths       Paths of files already included
307   @param[out]    buffer           String buffer for source file contents
308 
309   @return An error code: 0 - success, otherwise - failure
310 
311   @ref Backend
312 **/
313 int CeedLoadSourceAndInitializeBuffer(Ceed ceed, const char *source_file_path, CeedInt *num_file_paths, char ***file_paths, char **buffer) {
314   // Ensure defaults were set
315   *num_file_paths = 0;
316   *file_paths     = NULL;
317 
318   // Initialize
319   CeedCall(CeedCalloc(1, buffer));
320 
321   // And load source
322   CeedCall(CeedLoadSourceToInitializedBuffer(ceed, source_file_path, num_file_paths, file_paths, buffer));
323   return CEED_ERROR_SUCCESS;
324 }
325 
326 /**
327   @brief Initialize and load source file into string buffer, including full text of local files in place of `#include "local.h"`.
328     User @ref CeedLoadSourceAndInitializeBuffer() and @ref CeedLoadSourceToInitializedBuffer() if loading multiple source files into the same buffer.
329     Caller is responsible for freeing the string buffer with @ref CeedFree().
330 
331   @param[in]  ceed             `Ceed` object for error handling
332   @param[in]  source_file_path Absolute path to source file
333   @param[out] buffer           String buffer for source file contents
334 
335   @return An error code: 0 - success, otherwise - failure
336 
337   @ref Backend
338 **/
339 int CeedLoadSourceToBuffer(Ceed ceed, const char *source_file_path, char **buffer) {
340   char  **file_paths     = NULL;
341   CeedInt num_file_paths = 0;
342 
343   // Load
344   CeedCall(CeedLoadSourceAndInitializeBuffer(ceed, source_file_path, &num_file_paths, &file_paths, buffer));
345 
346   // Cleanup
347   for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i]));
348   CeedCall(CeedFree(&file_paths));
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 /**
353   @brief Build an absolute filepath from a base filepath and an absolute filepath.
354 
355   This helps construct source file paths for @ref CeedLoadSourceToBuffer().
356 
357   Note: Caller is responsible for freeing the string buffer with @ref CeedFree().
358 
359   @param[in]  ceed               `Ceed` object for error handling
360   @param[in]  base_file_path     Absolute path to current file
361   @param[in]  relative_file_path Relative path to target file
362   @param[out] new_file_path      String buffer for absolute path to target file
363 
364   @return An error code: 0 - success, otherwise - failure
365 
366   @ref Backend
367 **/
368 int CeedPathConcatenate(Ceed ceed, const char *base_file_path, const char *relative_file_path, char **new_file_path) {
369   char  *last_slash  = strrchr(base_file_path, '/');
370   size_t base_length = (last_slash - base_file_path + 1), relative_length = strlen(relative_file_path),
371          new_file_path_length = base_length + relative_length + 1;
372 
373   CeedCall(CeedCalloc(new_file_path_length, new_file_path));
374   memcpy(*new_file_path, base_file_path, base_length);
375   memcpy(&((*new_file_path)[base_length]), relative_file_path, relative_length);
376   return CEED_ERROR_SUCCESS;
377 }
378 
379 /**
380   @brief Find the relative filepath to an installed JiT file
381 
382   @param[in]  absolute_file_path Absolute path to installed JiT file
383   @param[out] relative_file_path Relative path to installed JiT file, a substring of the absolute path
384 
385   @return An error code: 0 - success, otherwise - failure
386 
387   @ref Backend
388 **/
389 int CeedGetJitRelativePath(const char *absolute_file_path, const char **relative_file_path) {
390   *(relative_file_path) = strstr(absolute_file_path, "ceed/jit-source");
391   CeedCheck(*relative_file_path, NULL, CEED_ERROR_MAJOR, "Couldn't find relative path including 'ceed/jit-source' for: %s", absolute_file_path);
392   return CEED_ERROR_SUCCESS;
393 }
394 
395 /**
396   @brief Build an absolute filepath to a JiT file
397 
398   @param[in]  ceed               `Ceed` object for error handling
399   @param[in]  relative_file_path Relative path to installed JiT file
400   @param[out] absolute_file_path String buffer for absolute path to target file, to be freed by caller
401 
402   @return An error code: 0 - success, otherwise - failure
403 
404   @ref Backend
405 **/
406 int CeedGetJitAbsolutePath(Ceed ceed, const char *relative_file_path, const char **absolute_file_path) {
407   const char **jit_source_dirs;
408   CeedInt      num_source_dirs;
409 
410   // Debug
411   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "---------- Ceed JiT ----------\n");
412   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Relative JiT source file: ");
413   CeedDebug(ceed, "%s\n", relative_file_path);
414 
415   CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_source_dirs, &jit_source_dirs));
416   for (CeedInt i = 0; i < num_source_dirs; i++) {
417     bool is_valid;
418 
419     // Debug
420     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Checking JiT root: ");
421     CeedDebug(ceed, "%s\n", jit_source_dirs[i]);
422 
423     // Build and check absolute path with current root
424     CeedCall(CeedPathConcatenate(ceed, jit_source_dirs[i], relative_file_path, (char **)absolute_file_path));
425     CeedCall(CeedCheckFilePath(ceed, *absolute_file_path, &is_valid));
426 
427     if (is_valid) {
428       CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
429       return CEED_ERROR_SUCCESS;
430     }
431     // LCOV_EXCL_START
432     else
433       CeedCall(CeedFree(absolute_file_path));
434     // LCOV_EXCL_STOP
435   }
436   // LCOV_EXCL_START
437   return CeedError(ceed, CEED_ERROR_MAJOR, "Couldn't find matching JiT source file: %s", relative_file_path);
438   // LCOV_EXCL_STOP
439 }
440