diff --git a/src/changes.rs b/src/changes.rs index dd8f220eb273..ce89c2625fc0 100644 --- a/src/changes.rs +++ b/src/changes.rs @@ -115,15 +115,19 @@ impl<'a> ChangeSet<'a> { } } - pub fn write_all_files(&self, mode: WriteMode) -> Result<(), ::std::io::Error> { + pub fn write_all_files(&self, mode: WriteMode) -> Result<(HashMap), ::std::io::Error> { + let mut result = HashMap::new(); for filename in self.file_map.keys() { - try!(self.write_file(filename, mode)); + let one_result = try!(self.write_file(filename, mode)); + if let Some(r) = one_result { + result.insert(filename.clone(), r); + } } - Ok(()) + Ok(result) } - pub fn write_file(&self, filename: &str, mode: WriteMode) -> Result<(), ::std::io::Error> { + pub fn write_file(&self, filename: &str, mode: WriteMode) -> Result, ::std::io::Error> { let text = &self.file_map[filename]; match mode { @@ -147,13 +151,16 @@ impl<'a> ChangeSet<'a> { let mut file = try!(File::create(&filename)); try!(write!(file, "{}", text)); } - _ => { + WriteMode::Display => { println!("{}:\n", filename); println!("{}", text); } + WriteMode::Return(_) => { + return Ok(Some(text.to_string())); + } } - Ok(()) + Ok(None) } } diff --git a/src/functions.rs b/src/functions.rs index 5b08fa797bcf..e7f80922e8af 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -14,12 +14,11 @@ use utils::make_indent; use lists::{write_list, ListFormatting, SeparatorTactic, ListTactic}; use visitor::FmtVisitor; use syntax::{ast, abi}; -use syntax::codemap::{self, Span}; +use syntax::codemap::{self, Span, BytePos}; use syntax::print::pprust; use syntax::parse::token; impl<'a> FmtVisitor<'a> { - // TODO extract methods for args and generics pub fn rewrite_fn(&mut self, indent: usize, ident: ast::Ident, @@ -28,7 +27,9 @@ impl<'a> FmtVisitor<'a> { generics: &ast::Generics, unsafety: &ast::Unsafety, abi: &abi::Abi, - vis: ast::Visibility) + vis: ast::Visibility, + next_span: Span) // next_span is a nasty hack, its a loose upper + // bound on any comments after the where clause. -> String { // FIXME we'll lose any comments in between parts of the function decl, but anyone @@ -56,7 +57,10 @@ impl<'a> FmtVisitor<'a> { result.push_str(&token::get_ident(ident)); // Generics. - result.push_str(&self.rewrite_generics(generics, indent)); + let generics_indent = indent + result.len(); + result.push_str(&self.rewrite_generics(generics, + generics_indent, + span_for_return(&fd.output))); let ret_str = self.rewrite_return(&fd.output); @@ -74,7 +78,7 @@ impl<'a> FmtVisitor<'a> { result.push(')'); // Where clause. - result.push_str(&self.rewrite_where_clause(where_clause, indent)); + result.push_str(&self.rewrite_where_clause(where_clause, indent, next_span)); // Return type. if ret_str.len() > 0 { @@ -98,6 +102,8 @@ impl<'a> FmtVisitor<'a> { result.push_str(&ret_str); } + // TODO any comments here? + // Prepare for the function body by possibly adding a newline and indent. // FIXME we'll miss anything between the end of the signature and the start // of the body, but we need more spans from the compiler to solve this. @@ -157,40 +163,28 @@ impl<'a> FmtVisitor<'a> { // spans for the comment or parens, there is no chance of getting it right. // You also don't get to put a comment on self, unless it is explicit. if args.len() >= min_args { - let mut prev_end = args[min_args-1].ty.span.hi; - for arg in &args[min_args..] { - let cur_start = arg.pat.span.lo; - let snippet = self.snippet(codemap::mk_sp(prev_end, cur_start)); - let mut snippet = snippet.trim(); - if snippet.starts_with(",") { - snippet = snippet[1..].trim(); - } else if snippet.ends_with(",") { - snippet = snippet[..snippet.len()-1].trim(); - } - arg_comments.push(snippet.to_string()); - prev_end = arg.ty.span.hi; - } - // Get the last commment. - // FIXME If you thought the crap with the commas was ugly, just wait. - // This is awful. We're going to look from the last arg span to the - // start of the return type span, then we drop everything after the - // first closing paren. Obviously, this will break if there is a - // closing paren in the comment. - // The fix is comments in the AST or a span for the closing paren. - let snippet = self.snippet(codemap::mk_sp(prev_end, ret_span.lo)); - let snippet = snippet.trim(); - let snippet = &snippet[..snippet.find(")").unwrap()]; - let snippet = snippet.trim(); - arg_comments.push(snippet.to_string()); + arg_comments = self.make_comments_for_list(arg_comments, + args[min_args-1..].iter(), + ",", + ")", + |arg| arg.pat.span.lo, + |arg| arg.ty.span.hi, + ret_span.lo); } debug!("comments: {:?}", arg_comments); + // If there are // comments, keep them multi-line. + let mut list_tactic = ListTactic::HorizontalVertical; + if arg_comments.iter().any(|c| c.contains("//")) { + list_tactic = ListTactic::Vertical; + } + assert_eq!(arg_item_strs.len(), arg_comments.len()); let arg_strs: Vec<_> = arg_item_strs.into_iter().zip(arg_comments.into_iter()).collect(); let fmt = ListFormatting { - tactic: ListTactic::HorizontalVertical, + tactic: list_tactic, separator: ",", trailing_separator: SeparatorTactic::Never, indent: arg_indent, @@ -201,6 +195,51 @@ impl<'a> FmtVisitor<'a> { write_list(&arg_strs, &fmt) } + // Gets comments in between items of a list. + fn make_comments_for_list(&self, + prefix: Vec, + mut it: I, + separator: &str, + terminator: &str, + get_lo: F1, + get_hi: F2, + next_span_start: BytePos) + -> Vec + where I: Iterator, + F1: Fn(&T) -> BytePos, + F2: Fn(&T) -> BytePos + { + let mut result = prefix; + + let mut prev_end = get_hi(&it.next().unwrap()); + for item in it { + let cur_start = get_lo(&item); + let snippet = self.snippet(codemap::mk_sp(prev_end, cur_start)); + let mut snippet = snippet.trim(); + if snippet.starts_with(separator) { + snippet = snippet[1..].trim(); + } else if snippet.ends_with(separator) { + snippet = snippet[..snippet.len()-1].trim(); + } + result.push(snippet.to_string()); + prev_end = get_hi(&item); + } + // Get the last commment. + // FIXME If you thought the crap with the commas was ugly, just wait. + // This is awful. We're going to look from the last item span to the + // start of the return type span, then we drop everything after the + // first closing paren. Obviously, this will break if there is a + // closing paren in the comment. + // The fix is comments in the AST or a span for the closing paren. + let snippet = self.snippet(codemap::mk_sp(prev_end, next_span_start)); + let snippet = snippet.trim(); + let snippet = &snippet[..snippet.find(terminator).unwrap()]; + let snippet = snippet.trim(); + result.push(snippet.to_string()); + + result + } + fn compute_budgets_for_args(&self, result: &mut String, indent: usize, @@ -260,37 +299,70 @@ impl<'a> FmtVisitor<'a> { } } - fn rewrite_generics(&self, generics: &ast::Generics, indent: usize) -> String { + fn rewrite_generics(&self, generics: &ast::Generics, indent: usize, ret_span: Span) -> String { // FIXME convert bounds to where clauses where they get too big or if // there is a where clause at all. let mut result = String::new(); let lifetimes: &[_] = &generics.lifetimes; let tys: &[_] = &generics.ty_params; - if lifetimes.len() + tys.len() > 0 { - let budget = MAX_WIDTH - indent - result.len() - 2; - // TODO might need to insert a newline if the generics are really long - result.push('<'); - - let lt_strs = lifetimes.iter().map(|l| self.rewrite_lifetime_def(l)); - let ty_strs = tys.iter().map(|ty| self.rewrite_ty_param(ty)); - let generics_strs: Vec<_> = lt_strs.chain(ty_strs).map(|s| (s, String::new())).collect(); - let fmt = ListFormatting { - tactic: ListTactic::HorizontalVertical, - separator: ",", - trailing_separator: SeparatorTactic::Never, - indent: indent + result.len() + 1, - h_width: budget, - v_width: budget, - }; - result.push_str(&write_list(&generics_strs, &fmt)); - - result.push('>'); + if lifetimes.len() + tys.len() == 0 { + return result; } + let budget = MAX_WIDTH - indent - 2; + // TODO might need to insert a newline if the generics are really long + result.push('<'); + + // Strings for the generics. + let lt_strs = lifetimes.iter().map(|l| self.rewrite_lifetime_def(l)); + let ty_strs = tys.iter().map(|ty| self.rewrite_ty_param(ty)); + + // Extract comments between generics. + let lt_spans = lifetimes.iter().map(|l| { + let hi = if l.bounds.len() == 0 { + l.lifetime.span.hi + } else { + l.bounds[l.bounds.len() - 1].span.hi + }; + codemap::mk_sp(l.lifetime.span.lo, hi) + }); + let ty_spans = tys.iter().map(span_for_ty_param); + let comments = self.make_comments_for_list(Vec::new(), + lt_spans.chain(ty_spans), + ",", + ">", + |sp| sp.lo, + |sp| sp.hi, + ret_span.lo); + + // If there are // comments, keep them multi-line. + let mut list_tactic = ListTactic::HorizontalVertical; + if comments.iter().any(|c| c.contains("//")) { + list_tactic = ListTactic::Vertical; + } + + let generics_strs: Vec<_> = lt_strs.chain(ty_strs).zip(comments.into_iter()).collect(); + let fmt = ListFormatting { + tactic: list_tactic, + separator: ",", + trailing_separator: SeparatorTactic::Never, + indent: indent + 1, + h_width: budget, + v_width: budget, + }; + result.push_str(&write_list(&generics_strs, &fmt)); + + result.push('>'); + result } - fn rewrite_where_clause(&self, where_clause: &ast::WhereClause, indent: usize) -> String { + fn rewrite_where_clause(&self, + where_clause: &ast::WhereClause, + indent: usize, + next_span: Span) + -> String + { let mut result = String::new(); if where_clause.predicates.len() == 0 { return result; @@ -300,6 +372,21 @@ impl<'a> FmtVisitor<'a> { result.push_str(&make_indent(indent + 4)); result.push_str("where "); + // TODO uncomment when spans are fixed + //println!("{:?} {:?}", where_clause.predicates.iter().map(|p| self.snippet(span_for_where_pred(p))).collect::>(), next_span.lo); + // let comments = self.make_comments_for_list(Vec::new(), + // where_clause.predicates.iter(), + // ",", + // "{", + // |pred| span_for_where_pred(pred).lo, + // |pred| span_for_where_pred(pred).hi, + // next_span.lo); + let comments = vec![String::new(); where_clause.predicates.len()]; + let where_strs: Vec<_> = where_clause.predicates.iter() + .map(|p| (self.rewrite_pred(p))) + .zip(comments.into_iter()) + .collect(); + let budget = IDEAL_WIDTH + LEEWAY - indent - 10; let fmt = ListFormatting { tactic: ListTactic::Vertical, @@ -309,7 +396,6 @@ impl<'a> FmtVisitor<'a> { h_width: budget, v_width: budget, }; - let where_strs: Vec<_> = where_clause.predicates.iter().map(|p| (self.rewrite_pred(p), String::new())).collect(); result.push_str(&write_list(&where_strs, &fmt)); result @@ -338,3 +424,27 @@ fn span_for_return(ret: &ast::FunctionRetTy) -> Span { ast::FunctionRetTy::Return(ref ty) => ty.span, } } + +fn span_for_ty_param(ty: &ast::TyParam) -> Span { + // Note that ty.span is the span for ty.ident, not the whole item. + let lo = ty.span.lo; + if let Some(ref def) = ty.default { + return codemap::mk_sp(lo, def.span.hi); + } + if ty.bounds.len() == 0 { + return ty.span; + } + let hi = match ty.bounds[ty.bounds.len() - 1] { + ast::TyParamBound::TraitTyParamBound(ref ptr, _) => ptr.span.hi, + ast::TyParamBound::RegionTyParamBound(ref l) => l.span.hi, + }; + codemap::mk_sp(lo, hi) +} + +fn span_for_where_pred(pred: &ast::WherePredicate) -> Span { + match *pred { + ast::WherePredicate::BoundPredicate(ref p) => p.span, + ast::WherePredicate::RegionPredicate(ref p) => p.span, + ast::WherePredicate::EqPredicate(ref p) => p.span, + } +} diff --git a/src/mod.rs b/src/mod.rs index 23fe92438c30..fd508a56b2fe 100644 --- a/src/mod.rs +++ b/src/mod.rs @@ -21,11 +21,10 @@ // TODO priorities // Fix fns and methods properly -// dead spans (comments) - in generics +// dead spans (comments) - in where clause (wait for fixed spans, test) // // Smoke testing till we can use it -// no newline at the end of doc.rs -// fmt_skip annotations +// ** no newline at the end of doc.rs // take config options from a file #[macro_use] @@ -48,6 +47,7 @@ use syntax::diagnostics; use syntax::visit; use std::path::PathBuf; +use std::collections::HashMap; use changes::ChangeSet; use visitor::FmtVisitor; @@ -69,14 +69,18 @@ const MIN_STRING: usize = 10; const TAB_SPACES: usize = 4; const FN_BRACE_STYLE: BraceStyle = BraceStyle::SameLineWhere; const FN_RETURN_INDENT: ReturnIndent = ReturnIndent::WithArgs; +// When we get scoped annotations, we should have rustfmt::skip. +const SKIP_ANNOTATION: &'static str = "rustfmt_skip"; -#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[derive(Copy, Clone)] pub enum WriteMode { Overwrite, // str is the extension of the new file NewFile(&'static str), // Write the output to stdout. Display, + // Return the result as a mapping from filenames to StringBuffers. + Return(&'static Fn(HashMap)), } #[derive(Copy, Clone, Eq, PartialEq, Debug)] @@ -157,6 +161,7 @@ fn fmt_lines(changes: &mut ChangeSet) { struct RustFmtCalls { input_path: Option, + write_mode: WriteMode, } impl<'a> CompilerCalls<'a> for RustFmtCalls { @@ -202,19 +207,25 @@ impl<'a> CompilerCalls<'a> for RustFmtCalls { } fn build_controller(&mut self, _: &Session) -> driver::CompileController<'a> { + let write_mode = self.write_mode; let mut control = driver::CompileController::basic(); control.after_parse.stop = Compilation::Stop; - control.after_parse.callback = box |state| { + control.after_parse.callback = box move |state| { let krate = state.krate.unwrap(); let codemap = state.session.codemap(); let mut changes = fmt_ast(krate, codemap); fmt_lines(&mut changes); // FIXME(#5) Should be user specified whether to show or replace. - let result = changes.write_all_files(WriteMode::Display); + let result = changes.write_all_files(write_mode); - if let Err(msg) = result { - println!("Error writing files: {}", msg); + match result { + Err(msg) => println!("Error writing files: {}", msg), + Ok(result) => { + if let WriteMode::Return(callback) = write_mode { + callback(result); + } + } } }; @@ -222,10 +233,14 @@ impl<'a> CompilerCalls<'a> for RustFmtCalls { } } +fn run(args: Vec, write_mode: WriteMode) { + let mut call_ctxt = RustFmtCalls { input_path: None, write_mode: write_mode }; + rustc_driver::run_compiler(&args, &mut call_ctxt); +} + fn main() { let args: Vec<_> = std::env::args().collect(); - let mut call_ctxt = RustFmtCalls { input_path: None }; - rustc_driver::run_compiler(&args, &mut call_ctxt); + run(args, WriteMode::Display); std::env::set_exit_status(0); // TODO unit tests @@ -262,3 +277,68 @@ fn main() { // the right kind. // Should also make sure comments have the right indent + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::fs; + use std::io::Read; + use super::*; + use super::run; + + // For now, the only supported regression tests are idempotent tests - the input and + // output must match exactly. + // TODO would be good to check for error messages and fail on them, or at least report. + #[test] + fn idempotent_tests() { + println!("Idempotent tests:"); + unsafe { FAILURES = 0; } + + // Get all files in the tests/idem directory + let files = fs::read_dir("tests/idem").unwrap(); + // For each file, run rustfmt and collect the output + let mut count = 0; + for entry in files { + let path = entry.unwrap().path(); + let file_name = path.to_str().unwrap(); + println!("Testing '{}'...", file_name); + run(vec!["rustfmt".to_string(), file_name.to_string()], WriteMode::Return(HANDLE_RESULT)); + count += 1; + } + // And also dogfood ourselves! + println!("Testing 'src/mod.rs'..."); + run(vec!["rustfmt".to_string(), "src/mod.rs".to_string()], WriteMode::Return(HANDLE_RESULT)); + count += 1; + + // Display results + let fails = unsafe { FAILURES }; + println!("Ran {} idempotent tests; {} failures.", count, fails); + assert!(fails == 0, "{} idempotent tests failed", fails); + } + + // 'global' used by sys_tests and handle_result. + static mut FAILURES: i32 = 0; + // Ick, just needed to get a &'static to handle_result. + static HANDLE_RESULT: &'static Fn(HashMap) = &handle_result; + + // Compare output to input. + fn handle_result(result: HashMap) { + let mut fails = 0; + + for file_name in result.keys() { + let mut f = fs::File::open(file_name).unwrap(); + let mut text = String::new(); + f.read_to_string(&mut text).unwrap(); + if result[file_name] != text { + fails += 1; + println!("Mismatch in {}.", file_name); + } + } + + if fails > 0 { + unsafe { + FAILURES += 1; + } + } + } +} diff --git a/src/visitor.rs b/src/visitor.rs index 7e8a885c657d..e76d87afe9d5 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -12,7 +12,7 @@ use syntax::ast; use syntax::codemap::{CodeMap, Span, BytePos}; use syntax::visit; -use {MAX_WIDTH, TAB_SPACES}; +use {MAX_WIDTH, TAB_SPACES, SKIP_ANNOTATION}; use changes::ChangeSet; pub struct FmtVisitor<'a> { @@ -86,7 +86,8 @@ impl<'a, 'v> visit::Visitor<'v> for FmtVisitor<'a> { generics, unsafety, abi, - vis); + vis, + b.span); self.changes.push_str_span(s, &new_fn); } visit::FkMethod(ident, ref sig, vis) => { @@ -97,7 +98,8 @@ impl<'a, 'v> visit::Visitor<'v> for FmtVisitor<'a> { &sig.generics, &sig.unsafety, &sig.abi, - vis.unwrap_or(ast::Visibility::Inherited)); + vis.unwrap_or(ast::Visibility::Inherited), + b.span); self.changes.push_str_span(s, &new_fn); } visit::FkFnBlock(..) => {} @@ -108,6 +110,10 @@ impl<'a, 'v> visit::Visitor<'v> for FmtVisitor<'a> { } fn visit_item(&mut self, item: &'v ast::Item) { + if item.attrs.iter().any(|a| is_skip(&a.node.value)) { + return; + } + match item.node { ast::Item_::ItemUse(ref vp) => { match vp.node { @@ -135,6 +141,20 @@ impl<'a, 'v> visit::Visitor<'v> for FmtVisitor<'a> { } } + fn visit_trait_item(&mut self, ti: &'v ast::TraitItem) { + if ti.attrs.iter().any(|a| is_skip(&a.node.value)) { + return; + } + visit::walk_trait_item(self, ti) + } + + fn visit_impl_item(&mut self, ii: &'v ast::ImplItem) { + if ii.attrs.iter().any(|a| is_skip(&a.node.value)) { + return; + } + visit::walk_impl_item(self, ii) + } + fn visit_mac(&mut self, mac: &'v ast::Mac) { visit::walk_mac(self, mac) } @@ -171,3 +191,10 @@ impl<'a> FmtVisitor<'a> { } } } + +fn is_skip(meta_item: &ast::MetaItem) -> bool { + match meta_item.node { + ast::MetaItem_::MetaWord(ref s) => *s == SKIP_ANNOTATION, + _ => false, + } +}