module dimportsort;

import std.algorithm : cmp, copy, count, map, setIntersection, sort, uniq;
import std.array : array, join;
import std.format : format;
import std.range : empty;
import std.stdio : writeln;

import dparse.ast;
import dparse.lexer;
import dparse.parser : parseModule;
import dparse.rollback_allocator : RollbackAllocator;


/// AST visitor that collects import declarations into groups of
/// consecutive lines and renders a sorted, merged version of each group.
class ImportVisitor : ASTVisitor {

    ///
    this(string sourceCode) {
        this.cache = StringCache(StringCache.defaultBucketCount);
        this.sourceCode = sourceCode;
    }

    alias visit = ASTVisitor.visit;

    /** Visit import declaration.

    Appends `decl` to the current group when it directly follows the last
    declaration of that group (see `isConsective`), otherwise starts a new
    group.

    Params:
        decl = import declaration.

    Syntax:

    importBind:
        Identifier ('=' Identifier)?
        ;
    importBindings:
        singleImport ':' importBind (',' importBind)*
        ;
    importDeclaration:
        | 'import' singleImport (',' singleImport)* (',' importBindings)? ';'
        | 'import' importBindings ';'
        ;
    */
    override void visit(const ImportDeclaration decl) {
        // declGroups and importGroups always grow together, so checking
        // declGroups also guards the declGroups[$-1][$-1] access below.
        if (declGroups.empty || !isConsective(declGroups[$-1][$-1], decl)) {
            declGroups ~= [decl];
            importGroups ~= toIdentifiers(decl);
        } else {
            declGroups[$-1] ~= decl;
            importGroups[$-1] ~= toIdentifiers(decl);
        }
        decl.accept(this);
    }

    /// One output declaration: module part (including any rename) and binds.
    struct Output {
        string mod;
        string[] binds;
    }

    /** Formats imports as sorted, merged `import` declarations.

    Params:
        idents = imports to format. The array is not modified; a sorted copy
            is used instead.
        indent = prefix written before every `import` keyword.

    Returns:
        one declaration per distinct (module, rename) pair, joined by '\n'
        without a trailing newline, or "" when `idents` is empty.
    */
    pure @safe
    string outputImports(ImportIdentifiers[] idents, string indent = "") const {
        // TODO: support max line length.
        if (idents.empty) return "";

        // Sort a copy so the caller's group (e.g. importGroups) keeps its order.
        auto sorted = idents.dup;
        sort(sorted);

        // Merge redundant modules; opCmp makes identical ones adjacent.
        Output[] outputs;
        foreach (id; sorted) {
            auto mod = id.importText;
            if (outputs.empty || outputs[$-1].mod != mod) {
                outputs ~= Output(mod, id.bindNames);
                continue;
            }
            outputs[$-1].binds ~= id.bindNames;
        }

        string[] lines;
        foreach (o; outputs) {
            auto line = indent ~ "import " ~ o.mod;
            if (!o.binds.empty) {
                sort(o.binds);
                line ~= " : " ~ o.binds.uniq.join(", ");
            }
            lines ~= line ~ ";";
        }
        return lines.join("\n");
    }

    /** Renders every group whose sorted form differs from the source.

    Each hunk has the form `<<<<file:min-max`, the original lines, `----`,
    the sorted lines and `>>>>`. Here `min-max` is the half-open range of
    0-based line indices [min, max) that the group occupies.
    */
    string diff() {
        import std.algorithm : find, joiner, maxElement, minElement, splitter;
        import std.range : drop, take;

        string ret;
        foreach (i, decls; declGroups) {
            auto lines = decls.map!(d => d.tokens.map!(t => t.line)).joiner;
            auto min = lines.minElement - 1;
            auto max = lines.maxElement;
            auto input = sourceCode.splitter('\n').drop(min).take(max - min).join("\n");

            // Reuse whatever precedes the first `import` as indentation.
            auto indent = input[0 .. $ - input.find("import").length];
            auto output = outputImports(importGroups[i], indent);
            if (input == output) continue;

            ret ~= format!"<<<<%s:%d-%d\n"(fileName, min, max)
                ~ input ~ "\n"
                ~ "----\n"
                ~ output ~ "\n"
                ~ ">>>>\n";
        }
        return ret;
    }

private:
    string sourceCode;
    string fileName;
    // Groups of consecutive declarations, parallel to importGroups.
    const(ImportDeclaration)[][] declGroups;
    ImportIdentifiers[][] importGroups;

    // For ownerships of tokens.
    RollbackAllocator rba;
    StringCache cache;
}

/// Checks whether import declaration `b` starts on the last line of `a` or
/// on the line right after it, i.e. whether the two are consecutive.
@nogc nothrow pure @safe
bool isConsective(const ImportDeclaration a, const ImportDeclaration b) {
    if (a.tokens.empty || b.tokens.empty) return false;
    return b.tokens[0].line <= a.tokens[$-1].line + 1;
}

/// Lexes and parses `sourceCode`, then collects its import groups.
ImportVisitor visitImports(string sourceCode, string fileName = "unittest") {
    auto visitor = new ImportVisitor(sourceCode);
    LexerConfig config;
    auto tokens = getTokensForParser(sourceCode, config, &visitor.cache);
    auto m = parseModule(tokens, fileName, &visitor.rba);
    visitor.visit(m);
    visitor.fileName = fileName;
    return visitor;
}

/// Test for diff outputs.
unittest {
    auto visitor = visitImports(q{
import cc;
import ab;
import aa.cc;
import aa.bb;

import foo;
import bar, bar2; // expands to two imports.

void main() {}
});
    assert(visitor.declGroups.length == 2);
    assert(visitor.declGroups[0].length == 4);
    assert(visitor.declGroups[1].length == 2);

    assert(visitor.importGroups.length == 2);
    assert(visitor.importGroups[0].length == 4);
    assert(visitor.importGroups[1].length == 3);

    assert(visitor.diff ==
`<<<<unittest:1-5
import cc;
import ab;
import aa.cc;
import aa.bb;
----
import aa.bb;
import aa.cc;
import ab;
import cc;
>>>>
<<<<unittest:6-8
import foo;
import bar, bar2; // expands to two imports.
----
import bar;
import bar2;
import foo;
>>>>
`);
}

/// Data type for identifiers in an import declaration.
/// import [rename =] mod : binds, ...;
class ImportIdentifiers {
    ///
    this(const SingleImport si, const ImportBind[] binds = []) {
        this.singleImport = si;
        this.binds = binds;
    }

    const SingleImport singleImport;
    const ImportBind[] binds;

    /// Dotted module name, e.g. "std.stdio".
    pure nothrow @safe
    string name() const {
        return singleImport.identifierChain.identifiers.map!"a.text".join(".");
    }

    /// Module part as it must be written back, keeping a module rename,
    /// e.g. "io = std.stdio".
    pure nothrow @safe
    string importText() const {
        auto rename = singleImport.rename.text;
        return rename.empty ? name : rename ~ " = " ~ name;
    }

    /// Sorted selective bindings; a renamed binding is kept as "alias = symbol".
    pure nothrow @safe
    string[] bindNames() const {
        auto ret = binds.map!(b => b.right.text.empty
                                   ? b.left.text
                                   : b.left.text ~ " = " ~ b.right.text).array;
        sort(ret);
        return ret;
    }

    pure @safe
    override string toString() const {
        return format!"%s(name=%s, binds=%s)"(typeof(this).stringof, name, bindNames);
    }

    /// Orders by module name, then by rename, so identical imports are adjacent.
    nothrow pure @safe
    int opCmp(ImportIdentifiers that) const {
        if (auto c = cmp(this.name, that.name)) return c;
        return cmp(this.singleImport.rename.text, that.singleImport.rename.text);
    }
}

/// Test for binding.
unittest {
    auto visitor = visitImports(q{
import foo : aa, cc, bb;
});
    assert(visitor.importGroups[0][0].name == "foo");
    assert(visitor.importGroups[0][0].bindNames == ["aa", "bb", "cc"]);
}

/// Test for renamed modules/bindings and imports sharing a line.
unittest {
    auto visitor = visitImports(q{
import io = std.stdio : say = writeln; import std.stdio;
});
    assert(visitor.declGroups.length == 1);
    assert(visitor.outputImports(visitor.importGroups[0]) ==
           "import std.stdio;\nimport io = std.stdio : say = writeln;");
}

/// Decomposes a multi-module import declaration into single-module imports with their bindings.
ImportIdentifiers[] toIdentifiers(const ImportDeclaration decl) {
    // Plain module imports come first, in source order, with no bindings.
    ImportIdentifiers[] result;
    foreach (single; decl.singleImports)
        result ~= new ImportIdentifiers(single);

    // The optional selective import (`mod : a, b`) is always last.
    const bindings = decl.importBindings;
    if (bindings !is null)
        result ~= new ImportIdentifiers(bindings.singleImport, bindings.importBinds);
    return result;
}

/// Test for multiple modules and binding.
unittest {
    auto visitor = visitImports(q{
import foo, bar : aa, cc, bb;
});
    auto ids = visitor.importGroups[0];
    assert(ids[0].name == "foo");
    assert(ids[0].bindNames == []);
    assert(ids[1].name == "bar");
    assert(ids[1].bindNames == ["aa", "bb", "cc"]);

    // Test opCmp in sort.
    sort(ids);
    assert(ids[0].name == "bar");
    assert(ids[1].name == "foo");
}

/// Test for merging redundant modules.
unittest {
    auto visitor = visitImports(q{
import foo : bar;
import foo : baz, bar;
});
    auto merged = visitor.outputImports(visitor.importGroups[0]);
    assert(merged == "import foo : bar, baz;");
}