1 module dimportsort;
2 
3 import std.algorithm : cmp, copy, count, equal, map, setIntersection, sort, uniq;
4 import std.array : array, join;
5 import std.format : format;
6 import std.stdio : writeln;
7 import std.string : empty, strip;
8 
9 import dparse.ast;
10 import dparse.lexer : getTokensForParser, LexerConfig, str, StringCache;
11 import dparse.parser : parseModule;
12 import dparse.rollback_allocator : RollbackAllocator;
13 
14 
/// Collects the import declarations of a D module into groups of declarations
/// on consecutive source lines, and renders sorted, merged replacements.
class ImportVisitor : ASTVisitor {

  /// Creates a visitor for the module text `sourceCode`; `diff` later slices
  /// the original lines of each import group out of this text.
  this(string sourceCode) {
    this.cache = StringCache(StringCache.defaultBucketCount);
    this.sourceCode = sourceCode;
  }

  alias visit = ASTVisitor.visit;

  /**
     Records import declarations into `declGroups` and `importGroups`.

     An import extends the current group only when it is on the line right
     after the group's last import (see `isConsective`) and no other
     declaration was visited in between. Any non-import declaration
     (function, aggregate, `version` block, attribute block, ...) closes the
     current group both before and after its nested declarations are
     visited, so imports nested inside another declaration are not grouped
     with imports outside it, even on adjacent lines.

     Syntax:

     declaration:
       | attribute* declaration2
       | attribute+ '{' declaration* '}'
       ;
     attribute:
       | public
       | private
       | protected
       | package
       | static
       | ...
       ;
     declaration2:
       | importDeclaration
       | ...
       ;

     importBind:
       Identifier ('=' Identifier)?
       ;
     importBindings:
       singleImport ':' importBind (',' importBind)*
       ;
     importDeclaration:
       | 'import' singleImport (',' singleImport)* (',' importBindings)? ';'
       | 'import' importBindings ';'
       ;
   */
  override void visit(const Declaration decl) {
    if (decl.importDeclaration is null) {
      // Close the group before the children (an import inside this
      // declaration must not join an outer group) and after them (an outer
      // import following this declaration must not join an inner group).
      groupOpen = false;
      decl.accept(this);
      groupOpen = false;
      return;
    }
    decl.accept(this);
    if (groupOpen && isConsective(declGroups[$-1][$-1], decl)) {
      declGroups[$-1] ~= decl;
      importGroups[$-1] ~= toIdentifiers(decl);
    } else {
      declGroups ~= [decl];
      importGroups ~= toIdentifiers(decl);
    }
    groupOpen = true;
  }

  /**
     Renders every import group whose sorted form differs from the source.

     Each hunk has the form
       <<<<FILE:START-END
       original lines
       ----
       replacement lines
       >>>>
     where START-END is the group's half-open, 0-based line range
     [START, END) in `sourceCode` (i.e. 1-based lines START+1 .. END).

     Returns: the concatenated hunks, or "" when every group is already sorted.
   */
  string diff() {
    import std.algorithm : joiner, maxElement, min, minElement, splitter;
    import std.range : drop, take;
    import std.string : stripLeft;

    string ret;
    foreach (i, decls; declGroups) {
      auto lines = decls.map!(d => d.tokens.map!(t => t.line)).joiner;
      // Token lines are 1-based, so `first` is the 0-based index of the
      // group's first line and [first, last) covers the whole group.
      auto first = lines.minElement - 1;
      auto last = lines.maxElement;
      auto input = sourceCode.splitter('\n').drop(first).take(last - first).join("\n");

      // Text in front of the group's first token on its line (token columns
      // are 1-based). NOTE(review): assumes dparse starts `Declaration.tokens`
      // at the attributes (e.g. `public`), so they are not part of this
      // prefix and are not printed twice — confirm against the dparse version.
      auto firstLine = input.splitter('\n').front;
      auto prefix = firstLine[0 .. min(decls[0].tokens[0].column - 1, firstLine.length)];
      // Every replacement line gets the leading whitespace; any other code
      // in front of the first import (e.g. `version (X) `) is kept once, on
      // the first replacement line, instead of being copied onto every line.
      auto indent = prefix[0 .. $ - prefix.stripLeft.length];
      auto rendered = outputImports(importGroups[i], indent);
      auto output = prefix ~ rendered[min(indent.length, rendered.length) .. $];
      if (input == output) continue;

      ret ~= format!"<<<<%s:%d-%d\n"(fileName, first, last)
          ~ input ~ "\n"
          ~ "----\n"
          ~ output ~ "\n"
          ~ ">>>>\n";
    }
    return ret;
  }

 private:

  /// One rendered import line: module, selective binds and sorted attributes.
  struct Output {
    string mod;
    string[] binds;
    string[] attrs;

    /// Whether `that` can be folded into this line. A plain import and a
    /// selective import of the same module are kept apart: merging them
    /// would turn `import foo;` into the narrower `import foo : ...;`.
    bool canMerge(Output that) const {
      return this.mod == that.mod && equal(this.attrs, that.attrs)
          && this.binds.empty == that.binds.empty;
    }
  }

  /**
     Renders `idents` as sorted import declarations, one per line.

     Imports of the same module with equal attributes are merged when both
     are plain or both are selective; selective binds are sorted and
     deduplicated.

     Params:
       idents = the imports to render; the caller's array is not reordered.
       indent = text prepended to every output line.
     Returns: the lines joined by "\n" (no trailing newline); "" if `idents` is empty.
   */
  string outputImports(ImportIdentifiers[] idents, string indent = "") const {
    // TODO: support max line length.
    // Sort a copy so the caller's group (e.g. `importGroups[i]`) keeps source order.
    auto sorted = idents.dup;
    sort(sorted);

    // Merge redundant modules.
    Output[] outputs;
    foreach (id; sorted) {
      auto attrs = id.attrs.array;  // `array` already returns a fresh array
      sort(attrs);
      auto o = Output(id.name, id.bindNames.array.dup, attrs);
      if (!outputs.empty && outputs[$-1].canMerge(o)) {
        outputs[$-1].binds ~= o.binds;
      } else {
        outputs ~= o;
      }
    }

    string[] lines;
    foreach (o; outputs) {
      string line = indent;
      if (!o.attrs.empty) {
        line ~= o.attrs.join(" ") ~ " ";
      }
      line ~= "import " ~ o.mod;
      if (!o.binds.empty) {
        sort(o.binds);
        line ~= " : " ~ o.binds.uniq.join(", ");
      }
      lines ~= line ~ ";";
    }
    return lines.join("\n");
  }

  string sourceCode;                   // full module text sliced by `diff`
  string fileName;                     // shown in `diff` hunk headers
  const(Declaration)[][] declGroups;   // import declarations, one array per group
  ImportIdentifiers[][] importGroups;  // parallel to declGroups, one entry per imported module
  bool groupOpen;                      // whether the next import may extend the last group

  // For ownerships of tokens.
  RollbackAllocator rba;
  StringCache cache;
}
151 
/// Checks whether two declarations sit on consecutive source lines.
/// Returns: true if some token of `b` is on the line right after a line
/// holding a token of `a`. Token lines come in source order, so both line
/// ranges are sorted, as `setIntersection` requires.
@nogc nothrow pure @safe
bool isConsective(const Declaration a, const Declaration b) {
  auto linesAfterA = a.tokens.map!(t => t.line + 1);
  auto linesOfB = b.tokens.map!(t => t.line);
  return !linesAfterA.setIntersection(linesOfB).empty;
}
157 
/// Lexes and parses `sourceCode`, then walks the parsed module with a new
/// `ImportVisitor`, which records every group of import declarations.
/// `fileName` is passed to the parser and is used in the visitor's `diff` headers.
ImportVisitor visitImports(string sourceCode, string fileName = "unittest") {
  auto result = new ImportVisitor(sourceCode);
  result.fileName = fileName;

  LexerConfig lexerConfig;
  // The visitor owns the string cache and the allocator, so the AST it
  // stores stays valid for as long as the visitor does.
  auto lexed = getTokensForParser(sourceCode, lexerConfig, &result.cache);
  auto parsed = parseModule(lexed, fileName, &result.rba);
  result.visit(parsed);
  return result;
}
167 
/// Test for diff outputs.
/// Only groups whose sorted form differs from the source produce a hunk.
/// The `START-END` in a hunk header is the group's 0-based, half-open line
/// range: line 1 of the token string below is the empty line right after
/// `q{`, so the first group occupies 1-based lines 2-5 and is reported as
/// `1-5`, and the second (lines 7-8) as `6-8`.
unittest {
  auto visitor = visitImports(q{
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;

import foo;
import bar, bar2;  // expands to two imports.

    void main() {}
    });
  // The blank line splits the imports into two groups of declarations.
  assert(visitor.declGroups.length == 2);
  assert(visitor.declGroups[0].length == 4);
  assert(visitor.declGroups[1].length == 2);

  // `import bar, bar2;` is one declaration but two identifiers entries.
  assert(visitor.importGroups.length == 2);
  assert(visitor.importGroups[0].length == 4);
  assert(visitor.importGroups[1].length == 3);

  // Each hunk keeps the group's indentation; trailing comments are not kept.
  assert(visitor.diff ==
`<<<<unittest:1-5
    import cc;
    import ab;
    import aa.cc;
    import aa.bb;
----
    import aa.bb;
    import aa.cc;
    import ab;
    import cc;
>>>>
<<<<unittest:6-8
import foo;
import bar, bar2;  // expands to two imports.
----
import bar;
import bar2;
import foo;
>>>>
`);
}
211 
/// Data type for identifiers in an import declaration.
/// import mod : binds, ...;
/// `toIdentifiers` creates one instance per imported module, so
/// `import a, b : x;` yields one for `a` and one for `b : x`.
class ImportIdentifiers {
  /// Params:
  ///   attributes = attributes of the declaration, e.g. `public static`.
  ///   si = the imported module, with its optional rename.
  ///   binds = selective imports from `si`; empty for a plain import.
  this(const Attribute[] attributes, const SingleImport si,
       const ImportBind[] binds = []) {
    this.attributes = attributes;
    this.singleImport = si;
    this.binds = binds;
  }

  const Attribute[] attributes;
  const SingleImport singleImport;
  const ImportBind[] binds;

  /// Returns: the dotted module path, e.g. "std.stdio", without any rename.
  pure nothrow @safe
  string modulePath() const {
    return singleImport.identifierChain.identifiers.map!"a.text".join(".");
  }

  /// Returns: the module as it must be written back: "std.stdio", or
  /// "io = std.stdio" for the renamed import `import io = std.stdio;`.
  pure nothrow @safe
  string name() const {
    auto path = modulePath;
    auto renamed = singleImport.rename.text;
    return renamed.length == 0 ? path : renamed ~ " = " ~ path;
  }

  /// Returns: a lazy range of the selective binds as written: "writeln", or
  /// "say = writeln" for the renamed bind `import std.stdio : say = writeln;`.
  pure nothrow @safe
  auto bindNames() const {
    return binds.map!(b => b.right.text.length == 0
                      ? b.left.text
                      : b.left.text ~ " = " ~ b.right.text);
  }

  /// Returns: a lazy range of the attribute keywords in source order.
  auto attrs() const {
    return attributes.map!(a => str(a.attribute.type));
  }

  /// Returns: `attrs` sorted, so `static public` and `public static` compare equal.
  string[] sortedAttrs() const {
    auto ret = attrs.array;
    sort(ret);
    return ret;
  }

  pure @safe
  override string toString() const {
    return format!"%s(name=%s, binds=%s)"(typeof(this).stringof, name, bindNames);
  }

  /// Orders by module path, then by `name` (separating different renames of
  /// one module), then by the attributes regardless of their written order.
  int opCmp(ImportIdentifiers that) const {
    // First sort by the module name w/o attrs.
    auto ret = cmp(this.modulePath, that.modulePath);
    if (ret != 0) {
      return ret;
    }
    ret = cmp(this.name, that.name);
    if (ret != 0) {
      return ret;
    }
    // Then sort by attrs; sorting them keeps equal attribute sets adjacent.
    return cmp(this.sortedAttrs.join(" "), that.sortedAttrs.join(" "));
  }
}
255 
/// Test for binding: a selective import exposes its module and bound names.
unittest {
  auto v = visitImports(q{
      import foo : aa, bb, cc;
    });
  auto selective = v.importGroups[0][0];
  assert(selective.name == "foo");
  assert(equal(selective.bindNames, ["aa", "bb", "cc"]));
}
264 
/// Decomposes a multi-module import declaration into one `ImportIdentifiers`
/// per module: first the plain modules in source order, then the module of
/// the trailing `importBindings`, if any, together with its binds. Every
/// entry shares the declaration's attributes.
ImportIdentifiers[] toIdentifiers(const Declaration decl) {
  auto importDecl = decl.importDeclaration;
  assert(importDecl !is null, "not import declaration.");

  ImportIdentifiers[] result;
  foreach (single; importDecl.singleImports) {
    result ~= new ImportIdentifiers(decl.attributes, single);
  }

  auto bindings = importDecl.importBindings;
  if (bindings !is null) {
    result ~= new ImportIdentifiers(
        decl.attributes, bindings.singleImport, bindings.importBinds);
  }
  return result;
}
277 
/// Test for import attributes: keywords are reported in source order.
unittest {
  auto v = visitImports(q{
      public import foo;
      public static import bar;
    });
  auto group = v.importGroups[0];
  auto fooImport = group[0];
  auto barImport = group[1];
  assert(fooImport.name == "foo");
  assert(equal(fooImport.attrs, ["public"]));
  assert(barImport.name == "bar");
  assert(equal(barImport.attrs, ["public", "static"]));
}
290 
/// Test for multiple modules and binding: `import foo, bar : ...;` splits into
/// a plain `foo` and a selective `bar`, and opCmp orders them by module name.
unittest {
  auto v = visitImports(q{
      import foo, bar : aa, bb, cc;
    });
  auto group = v.importGroups[0];

  auto plain = group[0];
  assert(plain.name == "foo");
  assert(plain.bindNames.empty);

  auto selective = group[1];
  assert(selective.name == "bar");
  assert(equal(selective.bindNames, ["aa", "bb", "cc"]));

  // Test opCmp in sort.
  group.sort();
  assert(group[0].name == "bar");
  assert(group[1].name == "foo");
}
307 
/// Test for merging redundant modules: binds of the same module are combined,
/// sorted and deduplicated into a single line.
unittest {
  auto v = visitImports(q{
      import foo : bar;
      import foo : baz, bar;
    });
  immutable expected = "import foo : bar, baz;";
  auto rendered = v.outputImports(v.importGroups[0]);
  assert(rendered == expected);
}
317 
/// Test for modules with attributes: imports are sorted by module, then by
/// attributes, and only imports with identical attributes are merged.
/// (A leftover debug `writeln` that printed the output on every test run,
/// and rendered it a second time, has been removed.)
unittest {
  auto visitor = visitImports(q{
      import foo : bar;
      static import foo;
      public import foo : bar;
      public import foo : baz;
      import bar;
    });
  assert(visitor.outputImports(visitor.importGroups[0]) == q{
import bar;
import foo : bar;
public import foo : bar, baz;
static import foo;
    }.strip);
}