1 module dimportsort;
2 
3 import dparse.ast;
4 import dparse.lexer : getTokensForParser, LexerConfig, str, StringCache;
5 import dparse.parser : parseModule;
6 import dparse.rollback_allocator : RollbackAllocator;
7 import std.algorithm : cmp, copy, count, equal, map, setIntersection, sort, uniq;
8 import std.array : array, join;
9 import std.format : format;
10 import std.stdio : writeln;
11 import std..string : empty, strip;
12 import std.uni : sicmp;
13 
14 ///
/// AST visitor that records import declarations and groups runs of import
/// declarations that sit on the same or adjacent source lines.
class ImportVisitor : ASTVisitor {

  /// Params:
  ///   sourceCode = full source text; kept so that `diff` can slice the
  ///                original lines of each import group.
  @nogc nothrow pure
  this(string sourceCode) {
    this.cache = StringCache(StringCache.defaultBucketCount);
    this.sourceCode = sourceCode;
  }

  alias visit = ASTVisitor.visit;

  /**
     Syntax:

     declaration:
       | attribute* declaration2
       | attribute+ '{' declaration* '}'
       ;
     attribute:
       | public
       | private
       | protected
       | package
       | static
       | ...
       ;
     declaration2:
       | importDeclaration
       | ...
       ;

     importBind:
       Identifier ('=' Identifier)?
       ;
     importBindings:
       singleImport ':' importBind (',' importBind)*
       ;
     importDeclaration:
       | 'import' singleImport (',' singleImport)* (',' importBindings)? ';'
       | 'import' importBindings ';'
       ;
   */
  override void visit(const Declaration decl) {
    // Children are visited first, so imports nested inside this declaration
    // are recorded before the declaration itself is inspected.
    decl.accept(this);
    if (decl.importDeclaration !is null) {
      // Start a new group unless this import directly follows the last
      // recorded import declaration.
      if (importGroups.empty || !isConsective(declGroups[$-1][$-1], decl)) {
        declGroups ~= [decl];
        importGroups ~= toIdentifiers(decl);
        return;
      }
      declGroups[$-1] ~= decl;
      importGroups[$-1] ~= toIdentifiers(decl);
    }
  }

  /// Returns: diff patch to sort import declarations (WIP).
  ///
  /// One hunk is emitted per import group whose sorted form differs from the
  /// source text. The hunk header is `<<<<fileName:start-end`, where `start`
  /// is the 0-based index of the group's first line and `end` the 1-based
  /// number of its last line. Note: each group's identifiers are sorted in
  /// place by `sortedImports`.
  @safe pure
  string diff() {
    import std.algorithm : find, joiner, maxElement, minElement, splitter;
    import std.range : drop, take;

    string ret;
    foreach (i, decls; declGroups) {
      auto lines = decls.map!(d => d.tokens.map!(t => t.line)).joiner;
      auto min = minElement(lines) - 1;  // 0-based index of the first line.
      auto max = maxElement(lines);      // 1-based number of the last line.
      auto input = sourceCode.splitter('\n').drop(min).take(max - min)
          .join("\n");

      // Reuse only the leading whitespace of the first line as indentation.
      // The text before "import" may also hold attributes (e.g. "public "),
      // which formatSortedImports already emits for every import.
      auto indent =
          input[0 .. $ - input.find!(c => c != ' ' && c != '\t').length];
      auto output = formatSortedImports(sortedImports(importGroups[i]), indent);
      if (input == output) continue;

      ret ~= format!"<<<<%s:%d-%d\n"(fileName, min, max)
          ~ input ~ "\n"
          ~ "----\n"
          ~ output ~ "\n"
          ~ ">>>>\n";
    }
    return ret;
  }

 private:
  string sourceCode;
  string fileName;
  // Runs of consecutive import declarations, in visit order.
  const(Declaration)[][] declGroups;
  // importGroups[i] holds the identifiers of every declaration in declGroups[i].
  ImportIdentifiers[][] importGroups;

  // For ownerships of tokens.
  RollbackAllocator rba;
  StringCache cache;
}
107 
/// Checks whether declarations are consecutive: `b` must start on the line
/// where `a` ends, or on the line right after it.
///
/// Same-line declarations (`import a; import b;`) count as consecutive.
/// Otherwise each would form its own group, and `diff` would rewrite the
/// shared source line with only one of them, dropping the other.
/// Declarations without tokens are never consecutive.
@nogc nothrow pure @safe
bool isConsective(const Declaration a, const Declaration b) {
  if (a.tokens.length == 0 || b.tokens.length == 0) return false;
  const aLast = a.tokens[$ - 1].line;
  const bFirst = b.tokens[0].line;
  return bFirst == aLast || bFirst == aLast + 1;
}
114 
/// Parses `sourceCode` and collects its import declarations.
///
/// Params:
///   sourceCode = D source text to parse.
///   fileName   = name passed to the parser and stored for `diff` headers.
/// Returns: a visitor holding the grouped import declarations.
ImportVisitor visitImports(string sourceCode, string fileName = "unittest") {
  LexerConfig lexerConfig;
  auto importVisitor = new ImportVisitor(sourceCode);
  // The visitor owns the string cache and allocator backing tokens and AST.
  auto parsed = parseModule(
      getTokensForParser(sourceCode, lexerConfig, &importVisitor.cache),
      fileName, &importVisitor.rba);
  importVisitor.visit(parsed);
  importVisitor.fileName = fileName;
  return importVisitor;
}
124 
/// Test for diff outputs.
unittest {
  // Two runs of imports separated by a blank line form two groups, each
  // with its own indentation. `import bar, bar2;` is a single declaration
  // that yields two import identifiers.
  auto visitor = visitImports(q{
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;

import foo;
import bar, bar2;  // expands to two imports.

    void main() {}
    });
  assert(visitor.declGroups.length == 2);
  assert(visitor.declGroups[0].length == 4);
  assert(visitor.declGroups[1].length == 2);

  assert(visitor.importGroups.length == 2);
  assert(visitor.importGroups[0].length == 4);
  assert(visitor.importGroups[1].length == 3);

  // Line 1 of the source is the empty line right after `q{`. Each hunk
  // header therefore shows the 0-based index of the group's first line and
  // the 1-based number of its last line.
  assert(visitor.diff ==
`<<<<unittest:1-5
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;
----
    import aa.bb;
    import aa.cc;
    import ab;
    import cc;
>>>>
<<<<unittest:6-8
import foo;
import bar, bar2;  // expands to two imports.
----
import bar;
import bar2;
import foo;
>>>>
`);
}
168 
/// Renders an attribute of an import declaration as source text.
///
/// Returns: the attribute keyword (e.g. "public", "static"). A qualified
///   package attribute is rendered as "package(a.b)"; a bare `package`
///   is rendered as "package".
nothrow pure @safe
string attributeStringOf(const Attribute attr) {
  auto s = str(attr.attribute.type);
  if (s != "package") return s;
  // A bare `package` (e.g. `package import foo;`) has no identifier chain,
  // so it must not be dereferenced.
  const chain = attr.identifierChain;
  if (chain is null || chain.identifiers.length == 0) return s;
  return "package("
      ~ chain.identifiers.map!"a.text".join(".")
      ~ ")";
}
177 
/// Test for import attributes.
unittest {
  // Each module keeps the attributes of its declaration, and a qualified
  // `package(...)` keeps its identifier chain.
  auto ids = visitImports(q{
      public import foo;
      public static import bar;
      package(std.regex) import baz;
    }).importGroups[0];
  static immutable names = ["foo", "bar", "baz"];
  static immutable attrLists =
      [["public"], ["public", "static"], ["package(std.regex)"]];
  foreach (i; 0 .. names.length) {
    assert(ids[i].fullName == names[i]);
    assert(equal(ids[i].attrs, attrLists[i]));
  }
}
193 
/// Data type for identifiers in an import declaration.
/// import mod : binds, ...;
///
/// One instance describes a single imported module, the attributes of its
/// enclosing declaration, and its selective bindings (if any).
class ImportIdentifiers {
  /// Params:
  ///   attributes = attributes of the enclosing declaration.
  ///   si         = the imported module (possibly renamed).
  ///   binds      = selective bindings; empty for a non-selective import.
  @nogc nothrow pure @safe
  this(const Attribute[] attributes, const SingleImport si,
       const ImportBind[] binds = []) {
    this.attributes = attributes;
    this.singleImport = si;
    this.binds = binds;
  }

  const Attribute[] attributes;
  const SingleImport singleImport;
  const ImportBind[] binds;

  /// Returns: the dotted module name, prefixed with "rename = " when the
  /// import is renamed (e.g. "bar = foo").
  nothrow pure @safe
  string fullName() const {
    string prefix = singleImport.rename.text;
    if (!prefix.empty) {
      prefix ~= " = ";
    }
    return prefix ~ names().join(".");
  }

  /// Returns: lazy range over the identifiers of the module name.
  @nogc nothrow pure @safe
  auto names() const {
    return singleImport.identifierChain.identifiers.map!"a.text";
  }

  /// Returns: lazy range of bindings rendered as "name" or "name = orig".
  @nogc nothrow pure @safe
  auto bindNames() const {
    return binds.map!(b => b.left.text ~
                      (b.right.text.empty ? "" : " = " ~ b.right.text));
  }

  /// Returns: lazy range of the attributes rendered by attributeStringOf.
  nothrow pure @safe
  auto attrs() const {
    return attributes.map!attributeStringOf;
  }

  /// Returns: string for debugging.
  pure @safe
  override string toString() const {
    return format!"%s(name=\"%s\", binds=\"%s\")"(
        typeof(this).stringof, fullName, bindNames);
  }

  // NOTE(review): not referenced by opCmp; kept in case other code in the
  // module uses it.
  nothrow pure @safe
  private string cmpName() const {
    const rename = singleImport.rename.text;
    return rename.empty ? fullName : rename;
  }

  /// Compares identifiers for sorting.
  ///
  /// Order:
  ///   1. module name, case-insensitively;
  ///   2. module name, case-sensitively, so case variants never interleave;
  ///   3. attributes;
  ///   4. non-selective before selective.
  /// The tie-breakers keep entries with the same name, attributes and
  /// selectiveness adjacent under the unstable std.algorithm.sort, which
  /// sortedImports relies on for merging.
  nothrow pure @safe
  int opCmp(ImportIdentifiers that) const {
    // First sort by the module name w/o attrs. Note that in D-Scanner,
    // dscanner/analysis/imports_sortedness.d uses sicmp instead of cmp.
    auto thisName = this.fullName;
    auto thatName = that.fullName;
    if (auto ret = sicmp(thisName, thatName)) return ret;
    if (auto ret = cmp(thisName, thatName)) return ret;
    // Then sort by attrs. cmp is OK because attributes are always lowercased.
    if (auto ret = cmp(this.attrs, that.attrs)) return ret;
    // Finally, plain imports before selective imports of the same module.
    return cast(int) !this.binds.empty - cast(int) !that.binds.empty;
  }
}
260 
/// Test for selective imports.
unittest {
  enum source = q{
      import foo : aa, bb, cc;
    };
  // The module name excludes the bindings, which are kept in source order.
  auto first = visitImports(source).importGroups[0][0];
  assert(first.fullName == "foo");
  assert(first.bindNames.equal(["aa", "bb", "cc"]));
}
269 
/// Test for renamed imports.
unittest {
  auto group = visitImports(q{
      import foo = bar;
      import bar = foo;
      import baz.foo;
      import zz : f = foo;
    }).importGroups[0];
  // Sorting is based on the renamed name if exists.
  sort(group);
  const expected = ["bar = foo", "baz.foo", "foo = bar", "zz"];
  foreach (i, name; expected) {
    assert(group[i].fullName == name);
  }
}
285 
/// Decomposes multi module import decl to a list of single module with binds.
///
/// Every plain module of the declaration becomes one entry without binds.
/// The selective part (`mod : a, b`), if present, becomes a final entry
/// carrying its bindings. All entries share the declaration's attributes.
ImportIdentifiers[] toIdentifiers(const Declaration decl) {
  const importDecl = decl.importDeclaration;
  assert(importDecl !is null, "not import declaration.");
  ImportIdentifiers[] result;
  foreach (single; importDecl.singleImports) {
    result ~= new ImportIdentifiers(decl.attributes, single);
  }
  const bindings = importDecl.importBindings;
  if (bindings !is null) {
    result ~= new ImportIdentifiers(
        decl.attributes, bindings.singleImport, bindings.importBinds);
  }
  return result;
}
298 
/// Test for multiple modules and binding.
unittest {
  auto ids = visitImports(q{
      import foo,
          bar : aa, bb,
          cc;
    }).importGroups[0];
  // `foo` is a plain import; `bar` carries every selective binding.
  auto plain = ids[0];
  auto selective = ids[1];
  assert(plain.fullName == "foo");
  assert(plain.bindNames.empty);
  assert(selective.fullName == "bar");
  assert(equal(selective.bindNames, ["aa", "bb", "cc"]));
}
312 
// Test opCmp in sort.
unittest {
  auto ids = visitImports(q{
      import foo.bar;
      import foo, bar : aa, bb, cc;
      static import foo.bar;
    }).importGroups[0];
  sort(ids);
  assert(ids[0 .. 4].map!(id => id.fullName)
         .equal(["bar", "foo", "foo.bar", "foo.bar"]));
  // Attributes break the tie between the two "foo.bar" imports.
  assert(equal(ids[3].attrs, ["static"]));
}
328 
/// Data type to store a sorted import declaration.
///
/// Fields are positional: callers construct it as
/// `SortedImport(mod, binds, attrs)`.
struct SortedImport {
  /// Module name written after `import` (includes "rename = " if renamed).
  string mod;
  // These must be sorted.
  /// Selective bindings; empty for a non-selective import.
  string[] binds;
  /// Attribute keywords written before `import`.
  string[] attrs;
}
336 
/// Checks if two sorted imports can be merged into one.
///
/// Merging requires the same module, identical attribute lists, and both
/// imports being selective or both being non-selective.
@nogc nothrow pure @safe
bool canMerge(SortedImport a, SortedImport b) {
  if (a.mod != b.mod) return false;
  // Imports with different attributes must stay separate.
  if (!equal(a.attrs, b.attrs)) return false;
  // A selective import cannot absorb a plain one, nor vice versa.
  return a.binds.empty == b.binds.empty;
}
344 
unittest {
  // Identical non-selective imports merge.
  assert(canMerge(SortedImport("a"), SortedImport("a")));
  // Selective imports of one module merge regardless of their binds.
  assert(canMerge(SortedImport("a", ["b"]), SortedImport("a", ["c"])));
  // Selective and non-selective imports never merge.
  assert(!canMerge(SortedImport("a", ["b"]), SortedImport("a", [])));
  // Matching attributes allow merging...
  assert(canMerge(SortedImport("a", [], ["public"]),
                  SortedImport("a", [], ["public"])));
  // ...differing ones do not.
  assert(!canMerge(SortedImport("a", [], []),
                   SortedImport("a", [], ["public"])));
}
354 
/// Merges and sorts import identifiers for outputs.
///
/// Params:
///   idents = identifiers of one import group; sorted in place.
/// Returns: one entry per output import declaration. Adjacent identifiers
///   that `canMerge` are combined. Each entry's attributes are sorted, and
///   its binds are sorted case-insensitively with duplicates removed.
pure @safe
SortedImport[] sortedImports(ImportIdentifiers[] idents) {
  // TODO: support max line length.
  sort(idents);
  // Merge redundant modules.
  SortedImport[] outputs;
  foreach (id; idents) {
    // `array` already allocates a fresh mutable copy; no `dup` needed.
    auto attrs = id.attrs.array;
    sort(attrs);
    auto o = SortedImport(id.fullName, id.bindNames.array, attrs);
    if (outputs.empty || !outputs[$-1].canMerge(o)) {
      outputs ~= o;
      continue;
    }
    outputs[$-1].binds ~= o.binds;
  }
  foreach (ref o; outputs) {
    // Break case-insensitive ties case-sensitively. This keeps identical
    // names adjacent, so `uniq` removes every duplicate.
    o.binds = sort!((a, b) {
      const c = sicmp(a, b);
      return c != 0 ? c < 0 : a < b;
    })(o.binds).uniq.array;
  }
  return outputs;
}
381 
// Test with renamed selective imports.
unittest {
  auto idents = visitImports(q{
      import foo : bar = foo;
      import foo : zoo = bar;
      import foo : baaz;
      import foo : aar;
    }).importGroups[0];
  // Renamed binds sort by their new name and merge into one declaration.
  const formatted = formatSortedImports(sortedImports(idents));
  assert(formatted == "import foo : aar, baaz, bar = foo, zoo = bar;");
}
393 
/// Formats output imports into a string.
///
/// Params:
///   outputs = imports to render, one declaration per entry.
///   indent  = prefix written at the start of every line.
/// Returns: one `[attrs ]import mod[ : binds];` line per entry, joined with
///   "\n" (no trailing newline). Returns "" for empty `outputs`; the previous
///   `ret[0 .. $-1]` raised a RangeError in that case.
nothrow pure @safe
string formatSortedImports(SortedImport[] outputs, string indent = "") {
  string[] lines;
  foreach (o; outputs) {
    string line = indent;
    if (!o.attrs.empty) {
      line ~= o.attrs.join(" ") ~ " ";
    }
    line ~= "import " ~ o.mod;
    if (!o.binds.empty) {
      line ~= " : " ~ o.binds.join(", ");
    }
    lines ~= line ~ ";";
  }
  return lines.join("\n");
}
412 
/// Test for merging redundant modules.
unittest {
  auto idents = visitImports(q{
      import foo : bar;
      import foo : baz, bar;
    }).importGroups[0];
  // Both selective imports of `foo` collapse into one, and `bar` is deduplicated.
  const formatted = formatSortedImports(sortedImports(idents));
  assert(formatted == "import foo : bar, baz;");
}
422 
/// Test for modules with attributes.
unittest {
  // Imports of one module merge only when their attributes and their
  // selectiveness both match. Entries with the same module name are
  // ordered by their attributes.
  auto visitor = visitImports(q{
      import foo : bar;
      static import foo;
      public import foo : bar;
      public import foo : baz;
      import bar;
    });
  assert(sortedImports(visitor.importGroups[0]).formatSortedImports == q{
import bar;
import foo : bar;
public import foo : bar, baz;
static import foo;
    }.strip);
}