xref: /libCEED/julia/LibCEED.jl/.style/ceed_style.jl (revision 60801d19602b94955220fc3cc63a65b52bc34d1e)
1using JuliaFormatter, CSTParser, Tokenize
2
3for name in names(JuliaFormatter, all=true)
4    if name != :include && name != :eval && name != Base.Docs.META
5        @eval using JuliaFormatter: $name
6    end
7end
8
9# Same as DefaultStyle, but no space in between operators with precedence CSTParser.TimesOp
10struct CeedStyle <: AbstractStyle end
11@inline JuliaFormatter.getstyle(s::CeedStyle) = s
12
13function JuliaFormatter.p_binaryopcall(
14    ds::CeedStyle,
15    cst::CSTParser.EXPR,
16    s::State;
17    nonest=false,
18    nospace=false,
19)
20    style = getstyle(ds)
21    t = FST(Binary, cst, nspaces(s))
22    op = cst[2]
23
24    nonest = nonest || CSTParser.is_colon(op)
25
26    if CSTParser.iscurly(cst.parent) &&
27       (op.val == "<:" || op.val == ">:") &&
28       !s.opts.whitespace_typedefs
29        nospace = true
30    elseif CSTParser.is_colon(op)
31        nospace = true
32    end
33    nospace_args = s.opts.whitespace_ops_in_indices ? false : nospace
34
35    if is_opcall(cst[1])
36        n = pretty(style, cst[1], s, nonest=nonest, nospace=nospace_args)
37    else
38        n = pretty(style, cst[1], s)
39    end
40
41    if CSTParser.is_colon(op) &&
42       s.opts.whitespace_ops_in_indices &&
43       !is_leaf(cst[1]) &&
44       !is_iterable(cst[1])
45        paren = FST(PUNCTUATION, -1, n.startline, n.startline, "(")
46        add_node!(t, paren, s)
47        add_node!(t, n, s, join_lines=true)
48        paren = FST(PUNCTUATION, -1, n.startline, n.startline, ")")
49        add_node!(t, paren, s, join_lines=true)
50    else
51        add_node!(t, n, s)
52    end
53
54    nrhs = nest_rhs(cst)
55    nrhs && (t.nest_behavior = AlwaysNest)
56    nest = (is_binaryop_nestable(style, cst) && !nonest) || nrhs
57
58    if op.fullspan == 0
59        # Do nothing - represents a binary op with no textual representation.
60        # For example: `2a`, which is equivalent to `2 * a`.
61    elseif CSTParser.is_exor(op)
62        add_node!(t, pretty(style, op, s), s, join_lines=true)
63    elseif (CSTParser.isnumber(cst[1]) || is_circumflex_accent(op)) &&
64           CSTParser.isdotted(op)
65        add_node!(t, Whitespace(1), s)
66        add_node!(t, pretty(style, op, s), s, join_lines=true)
67        nest ? add_node!(t, Placeholder(1), s) : add_node!(t, Whitespace(1), s)
68    elseif !(CSTParser.is_in(op) || CSTParser.is_elof(op)) && (
69        nospace || (
70            !CSTParser.is_anon_func(op) && precedence(op) in (
71                CSTParser.PowerOp,
72                CSTParser.DeclarationOp,
73                CSTParser.DotOp,
74                CSTParser.TimesOp,
75            )
76        )
77    )
78        add_node!(t, pretty(style, op, s), s, join_lines=true)
79    elseif op.val in RADICAL_OPS
80        add_node!(t, pretty(style, op, s), s, join_lines=true)
81    else
82        add_node!(t, Whitespace(1), s)
83        add_node!(t, pretty(style, op, s), s, join_lines=true)
84        nest ? add_node!(t, Placeholder(1), s) : add_node!(t, Whitespace(1), s)
85    end
86
87    if is_opcall(cst[3])
88        n = pretty(style, cst[3], s, nonest=nonest, nospace=nospace_args)
89    else
90        n = pretty(style, cst[3], s)
91    end
92
93    if CSTParser.is_colon(op) &&
94       s.opts.whitespace_ops_in_indices &&
95       !is_leaf(cst[3]) &&
96       !is_iterable(cst[3])
97        paren = FST(PUNCTUATION, -1, n.startline, n.startline, "(")
98        add_node!(t, paren, s, join_lines=true)
99        add_node!(t, n, s, join_lines=true, override_join_lines_based_on_source=!nest)
100        paren = FST(PUNCTUATION, -1, n.startline, n.startline, ")")
101        add_node!(t, paren, s, join_lines=true)
102    else
103        add_node!(t, n, s, join_lines=true, override_join_lines_based_on_source=!nest)
104    end
105
106    if nest
107        # for indent, will be converted to `indent` if needed
108        insert!(t.nodes, length(t.nodes), Placeholder(0))
109    end
110
111    t
112end
113
114function JuliaFormatter.p_chainopcall(
115    ds::CeedStyle,
116    cst::CSTParser.EXPR,
117    s::State;
118    nonest=false,
119    nospace=false,
120)
121    style = getstyle(ds)
122    t = FST(Chain, cst, nspaces(s))
123
124    # Check if there's a number literal on the LHS of a dot operator.
125    # In this case we need to surround the dot operator with whitespace
126    # in order to avoid ambiguity.
127    for (i, a) in enumerate(cst)
128        if CSTParser.isoperator(a) && CSTParser.isdotted(a) && CSTParser.isnumber(cst[i-1])
129            nospace = false
130            break
131        end
132    end
133
134    nws = nospace ? 0 : 1
135    for (i, a) in enumerate(cst)
136        nws_op = precedence(a) == CSTParser.TimesOp ? 0 : nws
137        if CSTParser.isoperator(a)
138            add_node!(t, Whitespace(nws_op), s)
139            add_node!(t, pretty(style, a, s), s, join_lines=true)
140            if nonest
141                add_node!(t, Whitespace(nws_op), s)
142            else
143                add_node!(t, Placeholder(nws_op), s)
144            end
145        elseif is_opcall(a)
146            add_node!(
147                t,
148                pretty(style, a, s, nospace=nospace, nonest=nonest),
149                s,
150                join_lines=true,
151            )
152        elseif i == length(cst) - 1 && is_punc(a) && is_punc(cst[i+1])
153            add_node!(t, pretty(style, a, s), s, join_lines=true)
154        else
155            add_node!(t, pretty(style, a, s), s, join_lines=true)
156        end
157    end
158    t
159end
160
161prefix_path(fname) = joinpath(@__DIR__, "..", fname)
162format(
163    prefix_path.(["src", "test", "examples", ".style"]),
164    style=CeedStyle(),
165    indent=4,
166    margin=92,
167    remove_extra_newlines=true,
168    whitespace_in_kwargs=false,
169)
170