|
15 | 15 |
|
16 | 16 | import tree_sitter_devicetree as ts |
17 | 17 | from pcpp.preprocessor import Action, OutputDirective, Preprocessor # type: ignore |
18 | | -from tree_sitter import Language, Node, Parser, Tree |
| 18 | +from tree_sitter import Language, Node, Parser, Query, QueryCursor, Tree |
19 | 19 |
|
20 | 20 | logger = logging.getLogger(__name__) |
21 | 21 |
|
@@ -115,6 +115,10 @@ class DeviceTree: |
115 | 115 |
|
116 | 116 | _custom_data_header = "__keymap_drawer_data__" |
117 | 117 |
|
| 118 | + _root_query = Query(TS_LANG, '(document (node name: (identifier) @nodename (#eq? @nodename "/")) @rootnode)') |
| 119 | + _override_query = Query(TS_LANG, "(document (node name: (reference label: (identifier))) @overridenode)") |
| 120 | + _chosen_query = Query(TS_LANG, '(node name: (identifier) @nodename (#eq? @nodename "chosen")) @chosennode') |
| 121 | + |
118 | 122 | def __init__( |
119 | 123 | self, |
120 | 124 | in_str: str, |
@@ -144,55 +148,25 @@ def __init__( |
144 | 148 | self.override_nodes = [DTNode(node, self.ts_buffer) for node in self._find_override_ts_nodes(tree)] |
145 | 149 | self.chosen_nodes = [DTNode(node, self.ts_buffer) for node in self._find_chosen_ts_nodes(tree)] |
146 | 150 |
|
147 | | - @staticmethod |
148 | | - def _find_root_ts_nodes(tree: Tree) -> list[Node]: |
| 151 | + @classmethod |
| 152 | + def _find_root_ts_nodes(cls, tree: Tree) -> list[Node]: |
149 | 153 | return sorted( |
150 | | - TS_LANG.query( |
151 | | - """ |
152 | | - (document |
153 | | - (node |
154 | | - name: (identifier) @nodename |
155 | | - (#eq? @nodename "/") |
156 | | - ) @rootnode |
157 | | - ) |
158 | | - """ |
159 | | - ) |
160 | | - .captures(tree.root_node) |
161 | | - .get("rootnode", []), |
| 154 | + QueryCursor(cls._root_query).captures(tree.root_node).get("rootnode", []), |
162 | 155 | key=lambda node: node.start_byte, |
163 | 156 | ) |
164 | 157 |
|
165 | | - @staticmethod |
166 | | - def _find_override_ts_nodes(tree: Tree) -> list[Node]: |
| 158 | + @classmethod |
| 159 | + def _find_override_ts_nodes(cls, tree: Tree) -> list[Node]: |
167 | 160 | return sorted( |
168 | | - TS_LANG.query( |
169 | | - """ |
170 | | - (document |
171 | | - (node |
172 | | - name: (reference |
173 | | - label: (identifier) |
174 | | - ) |
175 | | - ) @overridenode |
176 | | - ) |
177 | | - """ |
178 | | - ) |
179 | | - .captures(tree.root_node) |
180 | | - .get("overridenode", []), |
| 161 | + QueryCursor(cls._override_query).captures(tree.root_node).get("overridenode", []), |
181 | 162 | key=lambda node: node.start_byte, |
182 | 163 | ) |
183 | 164 |
|
184 | | - @staticmethod |
185 | | - def _find_chosen_ts_nodes(tree: Tree) -> list[Node]: |
| 165 | + @classmethod |
| 166 | + def _find_chosen_ts_nodes(cls, tree: Tree) -> list[Node]: |
186 | 167 | return sorted( |
187 | | - TS_LANG.query( |
188 | | - """ |
189 | | - (node |
190 | | - name: (identifier) @nodename |
191 | | - (#eq? @nodename "chosen") |
192 | | - ) @chosennode |
193 | | - """ |
194 | | - ) |
195 | | - .set_max_start_depth(2) |
| 168 | + QueryCursor(cls._chosen_query) |
| 169 | + .set_max_start_depth(2) # type: ignore |
196 | 170 | .captures(tree.root_node) |
197 | 171 | .get("chosennode", []), |
198 | 172 | key=lambda node: node.start_byte, |
@@ -225,13 +199,15 @@ def on_error_handler(file, line, msg): # type: ignore |
225 | 199 |
|
226 | 200 | def get_compatible_nodes(self, compatible_value: str) -> list[DTNode]: |
227 | 201 | """Return a list of nodes that have the given compatible value.""" |
228 | | - query = TS_LANG.query( |
229 | | - rf""" |
230 | | - (node |
231 | | - (property name: (identifier) @prop value: (string_literal) @propval) |
232 | | - (#eq? @prop "compatible") (#eq? @propval "\"{compatible_value}\"") |
233 | | - ) @node |
234 | | - """ |
| 202 | + query = QueryCursor( |
| 203 | + Query( |
| 204 | + TS_LANG, |
| 205 | + rf""" |
| 206 | + (node (property name: (identifier) @prop value: (string_literal) @propval) |
| 207 | + (#eq? @prop "compatible") (#eq? @propval "\"{compatible_value}\"") |
| 208 | + ) @node |
| 209 | + """, |
| 210 | + ) |
235 | 211 | ) |
236 | 212 | nodes = chain.from_iterable(query.captures(node).get("node", []) for node in self.root_nodes) |
237 | 213 | return sorted( |
|
0 commit comments