1 module dimportsort;
2 
3 import std.algorithm : cmp, copy, count, map, setIntersection, sort, uniq;
4 import std.array : array, join;
5 import std.format : format;
6 import std.stdio : writeln;
7 import std..string : empty;
8 
9 import dparse.ast;
10 import dparse.lexer;
11 import dparse.parser : parseModule;
12 import dparse.rollback_allocator : RollbackAllocator;
13 
14 
/// Collects import declarations from a parsed module, grouping declarations
/// that are consecutive in the source, and renders sorted replacements.
class ImportVisitor : ASTVisitor {

  /// Params:
  ///   sourceCode = the D source text that is parsed and later sliced by `diff`.
  this(string sourceCode) {
    this.cache = StringCache(StringCache.defaultBucketCount);
    this.sourceCode = sourceCode;
  }

  alias visit = ASTVisitor.visit;

  /** Visit import declaration.

   Appends `decl` to the current group when `isConsective` reports it follows
   the last visited declaration; otherwise starts a new group. `declGroups`
   and `importGroups` are always extended together, so they stay
   index-aligned.

   Params:
     decl = import declaration.

   Syntax:

     importBind:
       Identifier ('=' Identifier)?
       ;
     importBindings:
       singleImport ':' importBind (',' importBind)*
       ;
     importDeclaration:
       | 'import' singleImport (',' singleImport)* (',' importBindings)? ';'
       | 'import' importBindings ';'
       ;
   */
  override void visit(const ImportDeclaration decl) {
    const startsGroup = declGroups.empty
        || !isConsective(declGroups[$-1][$-1], decl);
    if (startsGroup) {
      declGroups ~= [decl];
      importGroups ~= toIdentifiers(decl);
    } else {
      declGroups[$-1] ~= decl;
      importGroups[$-1] ~= toIdentifiers(decl);
    }
    // Children are not traversed: per the grammar above they are only
    // single imports and binds, which cannot hold further import declarations.
  }

  /// One rendered module: the union of its selective binds, plus whether any
  /// merged declaration imported the whole module (i.e. had no binds).
  struct Output {
    string mod;
    string[] binds;
    bool wholeModule;
  }

  /** Renders `idents` as sorted import lines, merging repeated modules.

   Declarations of the same module are merged. Their binds are unioned, sorted
   and de-duplicated. If any of them imported the whole module, a plain
   `import mod;` line is emitted, followed by the selective line when binds
   exist. Both lines are kept so that no declaration's meaning is lost.

   Params:
     idents = identifiers to render; the caller's array is not reordered.
     indent = prefix written before every `import`.

   Returns: the lines joined by "\n" (no trailing newline); "" for no input.
   */
  pure @safe
  string outputImports(ImportIdentifiers[] idents, string indent = "") const {
    // TODO: support max line length.
    // Sort a copy: callers such as `diff` pass `importGroups[i]` directly.
    auto sorted = idents.dup;
    sort(sorted);

    // Merge redundant modules (sorting makes equal names adjacent).
    Output[] outputs;
    foreach (id; sorted) {
      if (outputs.empty || outputs[$-1].mod != id.name) {
        outputs ~= Output(id.name);
      }
      auto binds = id.bindNames;
      if (binds.empty) {
        outputs[$-1].wholeModule = true;
      } else {
        outputs[$-1].binds ~= binds;
      }
    }

    string[] lines;
    foreach (o; outputs) {
      if (o.wholeModule) {
        lines ~= indent ~ "import " ~ o.mod ~ ";";
      }
      if (!o.binds.empty) {
        sort(o.binds);
        lines ~= indent ~ "import " ~ o.mod ~ " : " ~ o.binds.uniq.join(", ") ~ ";";
      }
    }
    return lines.join("\n");
  }

  /** Renders one hunk for every import group whose sorted form differs from
   the source text.

   Each hunk is "<<<<fileName:start-end", the original lines, "----", the
   sorted lines, then ">>>>". `start` is the group's first token line minus
   one and `end` is its last token line. The indentation used for the sorted
   lines is the text preceding the first "import" in the group's source span.

   Returns: the concatenated hunks, or "" when every group is already sorted.
   */
  string diff() {
    import std.algorithm : find;
    import std.range : drop, take;
    import std.algorithm : maxElement, minElement, joiner, splitter;

    string ret;
    foreach (i, decls; declGroups) {
      auto lines = decls.map!(d => d.tokens.map!(t => t.line)).joiner;
      auto min = lines.minElement - 1;
      auto max = lines.maxElement;
      auto input = sourceCode.splitter('\n').drop(min).take(max - min).join("\n");

      auto indent = input[0 .. $ - input.find("import").length];
      auto output = outputImports(importGroups[i], indent);
      if (input == output) continue;

      ret ~= format!"<<<<%s:%d-%d\n"(fileName, min, max)
          ~ input ~ "\n"
          ~ "----\n"
          ~ output ~ "\n"
          ~ ">>>>\n";
    }
    return ret;
  }

  /// A whole-module import survives merging with a selective import of the
  /// same module, the caller's array is not reordered, and an empty group
  /// renders as "".
  unittest {
    auto visitor = visitImports(q{
        import foo : bar;
        import foo;
        import aa;
      });
    auto group = visitor.importGroups[0];
    auto before = group.dup;
    assert(visitor.outputImports(group) ==
           "import aa;\nimport foo;\nimport foo : bar;");
    assert(group == before);
    assert(visitor.outputImports([]) == "");
  }

 private:
  /// Full source text that was parsed; sliced by `diff`.
  string sourceCode;
  /// File name written into `diff` hunk headers.
  string fileName;
  /// Groups of consecutive import declarations, in visiting order.
  const(ImportDeclaration)[][] declGroups;
  /// Identifiers of each group; index-aligned with `declGroups`.
  ImportIdentifiers[][] importGroups;

  // Storage handed to the parser (rba) and lexer (cache) by `visitImports`;
  // kept here so the visitor owns the tokens and nodes held in `declGroups`.
  RollbackAllocator rba;
  StringCache cache;
}
124 
/// Checks import declarations are consecutive: `b` starts on the line where
/// `a` ends or on the line right after it. Imports that share a line, such as
/// `import a; import b;`, therefore belong to the same group.
/// Returns false when either declaration carries no tokens.
@nogc nothrow pure @safe
bool isConsective(const ImportDeclaration a, const ImportDeclaration b) {
  if (a.tokens.length == 0 || b.tokens.length == 0) return false;
  const aLast = a.tokens[$-1].line;
  const bFirst = b.tokens[0].line;
  return aLast <= bFirst && bFirst <= aLast + 1;
}

/// Imports sharing one line form a single group.
unittest {
  auto visitor = visitImports(q{
      import b; import a;
    });
  assert(visitor.declGroups.length == 1);
  assert(visitor.importGroups[0].length == 2);
}
130 
/// Lexes and parses `sourceCode`, then collects its import groups.
///
/// Params:
///   sourceCode = D source text to analyse.
///   fileName = passed to the parser and stored for `diff` hunk headers.
/// Returns: the visitor holding the collected groups.
ImportVisitor visitImports(string sourceCode, string fileName = "unittest") {
  auto result = new ImportVisitor(sourceCode);
  LexerConfig lexerConfig;
  const tokenStream = getTokensForParser(sourceCode, lexerConfig, &result.cache);
  const parsed = parseModule(tokenStream, fileName, &result.rba);
  result.visit(parsed);
  result.fileName = fileName;
  return result;
}
140 
/// Test for diff outputs.
///
/// The source holds two import groups separated by a blank line. Because the
/// q{} string starts with a newline, `import cc;` is on line 2, so the first
/// hunk is labelled 1-5 (first token line minus one, last token line).
unittest {
  auto visitor = visitImports(q{
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;

import foo;
import bar, bar2;  // expands to two imports.

    void main() {}
    });
  // Two groups of declarations: 4 in the first, 2 in the second.
  assert(visitor.declGroups.length == 2);
  assert(visitor.declGroups[0].length == 4);
  assert(visitor.declGroups[1].length == 2);

  // `import bar, bar2;` yields two identifiers, so the second group has 3.
  assert(visitor.importGroups.length == 2);
  assert(visitor.importGroups[0].length == 4);
  assert(visitor.importGroups[1].length == 3);

  // Both groups are unsorted, so each one produces a hunk; the first keeps
  // its 4-space indentation.
  assert(visitor.diff ==
`<<<<unittest:1-5
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;
----
    import aa.bb;
    import aa.cc;
    import ab;
    import cc;
>>>>
<<<<unittest:6-8
import foo;
import bar, bar2;  // expands to two imports.
----
import bar;
import bar2;
import foo;
>>>>
`);
}
184 
/// Data type for identifiers in an import declaration.
/// import [alias =] mod [: bind [= symbol], ...];
class ImportIdentifiers {
  /// Params:
  ///   si = the imported module, possibly renamed.
  ///   binds = selective binds of `si`; empty for a whole-module import.
  this(const SingleImport si, const ImportBind[] binds = []) {
    this.singleImport = si;
    this.binds = binds;
  }

  const SingleImport singleImport;
  const ImportBind[] binds;

  /// Returns the dotted module path (e.g. "std.stdio"), ignoring any rename.
  pure nothrow @safe
  string moduleName() const {
    return singleImport.identifierChain.identifiers.map!"a.text".join(".");
  }

  /// Returns the module as it is written after `import`. This is the dotted
  /// path, prefixed with "alias = " for a renamed import (e.g.
  /// "io = std.stdio"), so that rendering `import ` ~ name keeps the rename.
  pure nothrow @safe
  string name() const {
    string path = moduleName;
    string renamed = singleImport.rename.text;
    return renamed.length ? renamed ~ " = " ~ path : path;
  }

  /// Returns the selective binds, sorted. A renamed bind keeps its renaming
  /// as "alias = symbol" (e.g. "w = writeln").
  pure nothrow @safe
  string[] bindNames() const {
    auto ret = new string[binds.length];
    foreach (i, bind; binds) {
      string local = bind.left.text;
      string symbol = bind.right.text;
      ret[i] = symbol.length ? local ~ " = " ~ symbol : local;
    }
    sort(ret);
    return ret;
  }

  pure @safe
  override string toString() const {
    return format!"%s(name=%s, binds=%s)"(typeof(this).stringof, name, bindNames);
  }

  /// Orders by module path first. Ties are broken by the rendered `name`, so
  /// plain and renamed imports of one module stay adjacent.
  nothrow pure @safe
  int opCmp(const ImportIdentifiers that) const {
    const byModule = cmp(this.moduleName, that.moduleName);
    return byModule != 0 ? byModule : cmp(this.name, that.name);
  }

  /// Renamed modules and renamed binds are preserved.
  unittest {
    auto visitor = visitImports(q{
        import io = std.stdio;
        import std.algorithm : s = sort, map;
      });
    auto ids = visitor.importGroups[0];
    assert(ids[0].name == "io = std.stdio");
    assert(ids[0].moduleName == "std.stdio");
    assert(ids[1].bindNames == ["map", "s = sort"]);
  }
}
219 
/// Test for binding: selective binds are reported in sorted order.
unittest {
  auto visitor = visitImports(q{
      import foo : aa, cc, bb;
    });
  auto first = visitor.importGroups[0][0];
  assert(first.name == "foo");
  assert(first.bindNames == ["aa", "bb", "cc"]);
}
228 
/// Splits one import declaration into one `ImportIdentifiers` per module.
///
/// Plain modules (`import a, b;`) come first, in declaration order. A
/// trailing selective import (`import c : x, y;`) is appended last with its
/// binds.
ImportIdentifiers[] toIdentifiers(const ImportDeclaration decl) {
  ImportIdentifiers[] result;
  foreach (single; decl.singleImports) {
    result ~= new ImportIdentifiers(single);
  }
  const bindings = decl.importBindings;
  if (bindings !is null) {
    result ~= new ImportIdentifiers(bindings.singleImport, bindings.importBinds);
  }
  return result;
}
237 
/// Test for multiple modules and binding.
unittest {
  static void expectIds(const ImportIdentifiers id, string mod, string[] binds) {
    assert(id.name == mod);
    assert(id.bindNames == binds);
  }

  auto visitor = visitImports(q{
      import foo, bar : aa, cc, bb;
    });
  auto ids = visitor.importGroups[0];
  // Plain modules come first; the module with binds comes last.
  expectIds(ids[0], "foo", []);
  expectIds(ids[1], "bar", ["aa", "bb", "cc"]);

  // Sorting goes through ImportIdentifiers.opCmp.
  sort(ids);
  assert(ids[0].name == "bar");
  assert(ids[1].name == "foo");
}
254 
/// Test for merging redundant modules: repeated binds of one module collapse
/// into a single sorted, de-duplicated selective import.
unittest {
  auto visitor = visitImports(q{
      import foo : bar;
      import foo : baz, bar;
    });
  const merged = visitor.outputImports(visitor.importGroups[0]);
  assert(merged == "import foo : bar, baz;");
}