module dimportsort;

import std.algorithm : cmp, copy, count, equal, map, setIntersection, sort, uniq;
import std.array : array, join;
import std.format : format;
import std.stdio : writeln;
import std.string : empty, strip;

import dparse.ast;
import dparse.lexer : getTokensForParser, LexerConfig, str, StringCache;
import dparse.parser : parseModule;
import dparse.rollback_allocator : RollbackAllocator;


/// AST visitor that collects runs of import declarations on consecutive
/// lines and renders a sorted, merged replacement for each run.
class ImportVisitor : ASTVisitor {

  /// Params:
  ///   sourceCode = full text of the module being visited; `diff` slices the
  ///                original lines of each import group out of it.
  this(string sourceCode) {
    this.cache = StringCache(StringCache.defaultBucketCount);
    this.sourceCode = sourceCode;
  }

  alias visit = ASTVisitor.visit;

  /**
     Appends an import declaration to the current group when it directly
     follows the previous import (see `isConsective`) with no other
     declaration visited in between; otherwise starts a new group.

     Syntax:

     declaration:
       | attribute* declaration2
       | attribute+ '{' declaration* '}'
       ;
     attribute:
       | public
       | private
       | protected
       | package
       | static
       | ...
       ;
     declaration2:
       | importDeclaration
       | ...
       ;

     importBind:
       Identifier ('=' Identifier)?
       ;
     importBindings:
       singleImport ':' importBind (',' importBind)*
       ;
     importDeclaration:
       | 'import' singleImport (',' singleImport)* (',' importBindings)? ';'
       | 'import' importBindings ';'
       ;
   */
  override void visit(const Declaration decl) {
    if (decl.importDeclaration is null) {
      // Any other declaration (an aggregate, a function, an attribute block,
      // ...) separates the imports before it from the imports inside it and
      // after it, even when they sit on adjacent lines.
      previousWasImport = false;
      decl.accept(this);
      previousWasImport = false;
      return;
    }
    decl.accept(this);
    const startNewGroup = importGroups.empty || !previousWasImport ||
        !isConsective(declGroups[$-1][$-1], decl);
    previousWasImport = true;
    if (startNewGroup) {
      declGroups ~= [decl];
      importGroups ~= toIdentifiers(decl);
      return;
    }
    declGroups[$-1] ~= decl;
    importGroups[$-1] ~= toIdentifiers(decl);
  }

  /**
     Renders a patch-like report for every import group whose sorted form
     differs from the source text.

     Each hunk is `<<<<file:begin-end`, the original lines, `----`, the
     replacement lines, then `>>>>`. `begin` is the 0-based index of the
     group's first line and `end` is exclusive (equivalently, the 1-based
     number of its last line). Whole lines are replaced, so any non-import
     text on those lines (e.g. trailing comments) is not carried over.
   */
  string diff() {
    import std.algorithm : joiner, maxElement, minElement, splitter;
    import std.range : drop, take;
    import std.string : stripLeft;

    string ret;
    foreach (i, decls; declGroups) {
      // Token line numbers are 1-based.
      auto lines = decls.map!(d => d.tokens.map!(t => t.line)).joiner;
      const begin = lines.minElement - 1;
      const end = lines.maxElement;
      auto input = sourceCode.splitter('\n').drop(begin).take(end - begin).join("\n");

      // Reuse only the leading whitespace of the first line. Cutting at the
      // first "import" would also capture attributes such as `public `.
      auto indent = input[0 .. $ - input.stripLeft.length];
      auto output = outputImports(importGroups[i], indent);
      if (input == output) continue;

      ret ~= format!"<<<<%s:%d-%d\n"(fileName, begin, end)
          ~ input ~ "\n"
          ~ "----\n"
          ~ output ~ "\n"
          ~ ">>>>\n";
    }
    return ret;
  }

 private:

  /// One rendered statement: `attrs import mod : binds;`.
  struct Output {
    string mod;
    string[] binds;
    string[] attrs;

    /// Statements merge only when they import the same module with the same
    /// attributes and are both selective or both full imports: folding
    /// `import foo;` into `import foo : bar;` would drop the full import.
    bool canMerge(Output that) const {
      return this.mod == that.mod && equal(this.attrs, that.attrs)
          && this.binds.empty == that.binds.empty;
    }
  }

  /**
     Renders `idents` as sorted import statements, one per line, each
     prefixed with `indent`; no trailing newline is added.

     Statements with the same module, attributes and kind (full vs.
     selective) are merged, and selective bindings are sorted and
     de-duplicated. `idents` itself is not reordered. Returns "" when
     `idents` is empty.
   */
  string outputImports(ImportIdentifiers[] idents, string indent = "") const {
    // TODO: support max line length.
    Output[] outputs;
    foreach (id; idents) {
      // `.dup` yields a mutable string[] even when the range elements are const.
      auto attrs = id.attrs.array.dup;
      sort(attrs);
      outputs ~= Output(id.name, id.bindNames.array.dup, attrs);
    }
    // Sort by module, then by the normalized attributes, then full imports
    // before selective ones, so that mergeable statements become adjacent.
    sort!((a, b) {
        if (a.mod != b.mod) return a.mod < b.mod;
        const aAttrs = a.attrs.join(" ");
        const bAttrs = b.attrs.join(" ");
        if (aAttrs != bAttrs) return aAttrs < bAttrs;
        return a.binds.empty && !b.binds.empty;
      })(outputs);

    // Merge redundant modules.
    Output[] merged;
    foreach (o; outputs) {
      if (!merged.empty && merged[$-1].canMerge(o)) {
        merged[$-1].binds ~= o.binds;
      } else {
        merged ~= o;
      }
    }

    string[] statements;
    foreach (o; merged) {
      string line = indent;
      if (!o.attrs.empty) {
        line ~= o.attrs.join(" ") ~ " ";
      }
      line ~= "import " ~ o.mod;
      if (!o.binds.empty) {
        sort(o.binds);
        line ~= " : " ~ o.binds.uniq.join(", ");
      }
      statements ~= line ~ ";";
    }
    return statements.join("\n");
  }

  /// Full text of the visited module.
  string sourceCode;
  /// Name reported in `diff` hunk headers.
  string fileName;
  /// Runs of import declarations, in visiting order.
  const(Declaration)[][] declGroups;
  /// Per-module decomposition of each entry of `declGroups` (same indices).
  ImportIdentifiers[][] importGroups;
  /// Whether the most recently finished declaration was an import.
  bool previousWasImport;

  // Keep token/AST storage alive with the visitor: `visitImports` lexes into
  // `cache` and parses with `rba`.
  RollbackAllocator rba;
  StringCache cache;
}

/// Test that attributes on a group's first import do not leak into the
/// indentation, and that the original indentation is kept.
unittest {
  auto visitor = visitImports(q{
    public import foo;
    import bar;
});
  assert(visitor.diff ==
         "<<<<unittest:1-3\n" ~
         "    public import foo;\n" ~
         "    import bar;\n" ~
         "----\n" ~
         "    import bar;\n" ~
         "    public import foo;\n" ~
         ">>>>\n");
}

/// Test that full and selective imports of one module stay separate, that
/// outputImports leaves its argument untouched and accepts empty input.
unittest {
  auto visitor = visitImports(q{
import foo : bar;
import foo;
import foo : baz;
});
  assert(visitor.outputImports(visitor.importGroups[0]) ==
         "import foo;\nimport foo : bar, baz;");
  assert(visitor.importGroups[0][0].binds.length == 1);
  assert(visitor.importGroups[0][1].binds.length == 0);
  assert(visitor.outputImports(null) == "");
}

/// Checks whether two declarations are consecutive in the source.
///
/// `b` counts as consecutive to `a` when its first token is on the line of
/// `a`'s last token or on the line right after it.
///
/// Params:
///   a = the earlier declaration.
///   b = the declaration visited after `a`.
/// Returns: `false` when either declaration has no tokens.
@nogc nothrow pure @safe
bool isConsective(const Declaration a, const Declaration b) {
  // Declarations without tokens carry no position information.
  if (a.tokens.length == 0 || b.tokens.length == 0) {
    return false;
  }
  const aLast = a.tokens[$ - 1].line;
  const bFirst = b.tokens[0].line;
  // Declarations sharing a line (`import foo; import bar;`) must count as
  // consecutive: otherwise `diff` emits two hunks for that same line, and
  // each of them drops the other group's imports.
  return aLast <= bFirst && bFirst <= aLast + 1;
}

/// Lexes and parses `sourceCode`, then collects its import groups.
///
/// Params:
///   sourceCode = D source text to analyze.
///   fileName = passed to the parser and used in `diff` hunk headers.
/// Returns: the visitor holding the collected groups.
ImportVisitor visitImports(string sourceCode, string fileName = "unittest") {
  auto visitor = new ImportVisitor(sourceCode);
  LexerConfig config;
  auto tokens = getTokensForParser(sourceCode, config, &visitor.cache);
  auto m = parseModule(tokens, fileName, &visitor.rba);
  visitor.visit(m);
  visitor.fileName = fileName;
  return visitor;
}

/// Test for diff outputs.
unittest {
  auto visitor = visitImports(q{
import cc;
import ab;
import aa.cc;
import aa.bb;

import foo;
import bar, bar2; // expands to two imports.

void main() {}
});
  assert(visitor.declGroups.length == 2);
  assert(visitor.declGroups[0].length == 4);
  assert(visitor.declGroups[1].length == 2);

  assert(visitor.importGroups.length == 2);
  assert(visitor.importGroups[0].length == 4);
  assert(visitor.importGroups[1].length == 3);

  assert(visitor.diff ==
`<<<<unittest:1-5
import cc;
import ab;
import aa.cc;
import aa.bb;
----
import aa.bb;
import aa.cc;
import ab;
import cc;
>>>>
<<<<unittest:6-8
import foo;
import bar, bar2; // expands to two imports.
----
import bar;
import bar2;
import foo;
>>>>
`);
}

/// Test for declarations sharing a line.
unittest {
  auto visitor = visitImports(q{
import foo; import bar;
import baz;

import qux;
});
  assert(visitor.declGroups.length == 2);
  assert(visitor.declGroups[0].length == 3);
  assert(visitor.declGroups[1].length == 1);
}

/// Data type for identifiers in an import declaration.
/// import mod : binds, ...;
///
/// One instance describes one imported module. `toIdentifiers` splits a
/// declaration that lists several modules into several instances.
class ImportIdentifiers {
  /// Params:
  ///   attributes = attributes of the declaration containing the import.
  ///   si = the imported module.
  ///   binds = selective bindings; empty for a full import.
  this(const Attribute[] attributes, const SingleImport si,
       const ImportBind[] binds = []) {
    this.attributes = attributes;
    this.singleImport = si;
    this.binds = binds;
  }

  const Attribute[] attributes;
  const SingleImport singleImport;
  const ImportBind[] binds;

  /// Returns: the dotted module name, e.g. "std.algorithm".
  pure nothrow @safe
  string name() const {
    return singleImport.identifierChain.identifiers.map!"a.text".join(".");
  }

  /// Returns: a lazy range with one string per binding: the bound name, or
  /// `alias = name` for a renamed binding (`import foo : x = bar;`). Renamed
  /// bindings must keep both sides, or the rendered import no longer compiles.
  pure nothrow @safe
  auto bindNames() const {
    return binds.map!(b => b.right.text.length == 0
        ? b.left.text
        : b.left.text ~ " = " ~ b.right.text);
  }

  /// Returns: a lazy range of the attribute keywords, e.g. ["public", "static"].
  // NOTE(review): only the attribute's keyword token is rendered; arguments
  // such as `package(foo)` or `deprecated("msg")` appear to be lost — confirm.
  auto attrs() const {
    return attributes.map!(a => str(a.attribute.type));
  }

  pure @safe
  override string toString() const {
    return format!"%s(name=%s, binds=%s)"(typeof(this).stringof, name, bindNames);
  }

  /// Orders by module name, then by attribute keywords, then by bindings
  /// (a full import, having none, sorts first), so that sorting is
  /// deterministic.
  int opCmp(ImportIdentifiers that) const {
    // First sort by the module name w/o attrs.
    auto ret = cmp(this.name, that.name);
    if (ret != 0) {
      return ret;
    }
    // Then sort by attrs.
    ret = cmp(this.attrs.join(" "), that.attrs.join(" "));
    if (ret != 0) {
      return ret;
    }
    // Finally by bindings, element by element.
    return cmp(this.bindNames, that.bindNames);
  }
}

/// Test for binding.
unittest {
  auto visitor = visitImports(q{
import foo : aa, bb, cc;
});
  assert(visitor.importGroups[0][0].name == "foo");
  assert(equal(visitor.importGroups[0][0].bindNames, ["aa", "bb", "cc"]));
}

/// Test for renamed bindings.
unittest {
  auto visitor = visitImports(q{
import foo : x = bar, baz;
});
  assert(equal(visitor.importGroups[0][0].bindNames, ["x = bar", "baz"]));
}

/// Test that sorting puts a full import before a selective one.
unittest {
  auto visitor = visitImports(q{
import foo : bar;
import foo;
});
  auto ids = visitor.importGroups[0];
  sort(ids);
  assert(ids[0].bindNames.empty);
  assert(equal(ids[1].bindNames, ["bar"]));
}

/// Decomposes a multi-module import declaration into one entry per module.
///
/// `import a, b : x;` yields `a` without bindings and `b` with binding `x`;
/// every entry carries the declaration's attributes.
///
/// Params:
///   decl = a declaration whose `importDeclaration` is non-null (asserted).
ImportIdentifiers[] toIdentifiers(const Declaration decl) {
  auto idecl = decl.importDeclaration;
  assert(idecl !is null, "not import declaration.");
  auto ret = idecl.singleImports.map!(
      x => new ImportIdentifiers(decl.attributes, x)).array;
  if (auto binds = idecl.importBindings) {
    ret ~= new ImportIdentifiers(
        decl.attributes, binds.singleImport, binds.importBinds);
  }
  return ret;
}

/// Test for import attributes.
unittest {
  // Attributes are reported in source order, per module.
  auto visitor = visitImports(q{
public import foo;
public static import bar;
});
  auto ids = visitor.importGroups[0];
  assert(ids[0].name == "foo");
  assert(equal(ids[0].attrs, ["public"]));
  assert(ids[1].name == "bar");
  assert(equal(ids[1].attrs, ["public", "static"]));
}

/// Test for multiple modules and binding.
unittest {
  auto visitor = visitImports(q{
import foo, bar : aa, bb, cc;
});
  auto ids = visitor.importGroups[0];
  assert(ids[0].name == "foo");
  assert(ids[0].bindNames.empty);
  assert(ids[1].name == "bar");
  assert(equal(ids[1].bindNames, ["aa", "bb", "cc"]));

  // Test opCmp in sort.
  sort(ids);
  assert(ids[0].name == "bar");
  assert(ids[1].name == "foo");
}

/// Test for merging redundant modules: bindings are merged, sorted and
/// de-duplicated.
unittest {
  auto visitor = visitImports(q{
import foo : bar;
import foo : baz, bar;
});
  assert(visitor.outputImports(visitor.importGroups[0]) ==
         "import foo : bar, baz;");
}

/// Test for modules with attributes: only imports with identical attributes
/// are merged.
unittest {
  auto visitor = visitImports(q{
import foo : bar;
static import foo;
public import foo : bar;
public import foo : baz;
import bar;
});
  assert(visitor.outputImports(visitor.importGroups[0]) == q{
import bar;
import foo : bar;
public import foo : bar, baz;
static import foo;
}.strip);
}