From 64ce092c273a20cc503697c3dfebb956c9025d46 Mon Sep 17 00:00:00 2001 From: Marijn Haverbeke Date: Fri, 2 Dec 2011 13:42:51 +0100 Subject: [PATCH] Allow literal patterns to contain arbitrary literal expressions This removes the need for the unary minus hacks, and allows some other neat things like matching on 1 >> 4. Issue #954 --- src/comp/driver/rustc.rs | 4 +- src/comp/middle/check_alt.rs | 12 ++-- src/comp/middle/check_const.rs | 81 ++++++++++++++----------- src/comp/middle/trans_alt.rs | 46 ++++++--------- src/comp/middle/ty.rs | 5 ++ src/comp/middle/typeck.rs | 28 ++++----- src/comp/syntax/ast.rs | 4 +- src/comp/syntax/ast_util.rs | 101 ++++++++++++++++++++++---------- src/comp/syntax/parse/parser.rs | 20 ++++--- src/comp/syntax/print/pprust.rs | 6 +- 10 files changed, 177 insertions(+), 130 deletions(-) diff --git a/src/comp/driver/rustc.rs b/src/comp/driver/rustc.rs index c976410554a6..7ed81be5eda2 100644 --- a/src/comp/driver/rustc.rs +++ b/src/comp/driver/rustc.rs @@ -137,6 +137,8 @@ fn compile_input(sess: session::session, cfg: ast::crate_cfg, input: str, let freevars = time(time_passes, "freevar finding", bind freevars::annotate_freevars(def_map, crate)); + time(time_passes, "const checking", + bind middle::check_const::check_crate(sess, crate)); let ty_cx = ty::mk_ctxt(sess, def_map, ext_map, ast_map, freevars); time(time_passes, "typechecking", bind typeck::check_crate(ty_cx, crate)); time(time_passes, "block-use checking", @@ -157,8 +159,6 @@ fn compile_input(sess: session::session, cfg: ast::crate_cfg, input: str, bind last_use::find_last_uses(crate, def_map, ref_map, ty_cx)); time(time_passes, "kind checking", bind kind::check_crate(ty_cx, last_uses, crate)); - time(time_passes, "const checking", - bind middle::check_const::check_crate(ty_cx, crate)); if sess.get_opts().no_trans { ret; } let llmod = time(time_passes, "translation", diff --git a/src/comp/middle/check_alt.rs b/src/comp/middle/check_alt.rs index 380fcff24b6a..59410cd3ecb5 100644 --- a/src/comp/middle/check_alt.rs +++ b/src/comp/middle/check_alt.rs @@ -1,5 +1,6 @@ import syntax::ast::*; -import syntax::ast_util::{variant_def_ids, dummy_sp, compare_lit, lit_eq}; +import syntax::ast_util::{variant_def_ids, dummy_sp, compare_lit_exprs, + lit_expr_eq}; import syntax::visit; fn check_crate(tcx: ty::ctxt, crate: @crate) { @@ -66,7 +67,7 @@ fn pattern_supersedes(tcx: ty::ctxt, a: @pat, b: @pat) -> bool { pat_wild. | pat_bind(_) { ret true; } pat_lit(la) { alt b.node { - pat_lit(lb) { ret lit_eq(la, lb); } + pat_lit(lb) { ret lit_expr_eq(la, lb); } _ { ret false; } } } @@ -106,11 +107,12 @@ fn pattern_supersedes(tcx: ty::ctxt, a: @pat, b: @pat) -> bool { pat_range(begina, enda) { alt b.node { pat_lit(lb) { - ret compare_lit(begina, lb) <= 0 && compare_lit(enda, lb) >= 0; + ret compare_lit_exprs(begina, lb) <= 0 && + compare_lit_exprs(enda, lb) >= 0; } pat_range(beginb, endb) { - ret compare_lit(begina, beginb) <= 0 && - compare_lit(enda, endb) >= 0; + ret compare_lit_exprs(begina, beginb) <= 0 && + compare_lit_exprs(enda, endb) >= 0; } _ { ret false; } } diff --git a/src/comp/middle/check_const.rs b/src/comp/middle/check_const.rs index d9e96b0cb736..f851033593a5 100644 --- a/src/comp/middle/check_const.rs +++ b/src/comp/middle/check_const.rs @@ -1,46 +1,61 @@ import syntax::ast::*; import syntax::visit; +import driver::session::session; -fn check_crate(tcx: ty::ctxt, crate: @crate) { - let v = - @{visit_item: bind check_item(tcx, _, _, _) - with *visit::default_visitor::<()>()}; - visit::visit_crate(*crate, (), visit::mk_vt(v)); - tcx.sess.abort_if_errors(); +fn check_crate(sess: session, crate: @crate) { + visit::visit_crate(*crate, false, visit::mk_vt(@{ + visit_item: check_item, + visit_pat: check_pat, + visit_expr: bind check_expr(sess, _, _, _) + with *visit::default_visitor() + })); + sess.abort_if_errors(); } -fn check_item(tcx: ty::ctxt, it: @item, &&s: (), v: visit::vt<()>) { - visit::visit_item(it, s, v); +fn check_item(it: @item, &&_is_const: bool, v: visit::vt) { alt it.node { - item_const(_ /* ty */, ex) { - let v = - @{visit_expr: bind check_const_expr(tcx, _, _, _) - with *visit::default_visitor::<()>()}; - check_const_expr(tcx, ex, (), visit::mk_vt(v)); - } - _ { } + item_const(_, ex) { v.visit_expr(ex, true, v); } + _ { visit::visit_item(it, false, v); } } } -fn check_const_expr(tcx: ty::ctxt, ex: @expr, &&s: (), v: visit::vt<()>) { - visit::visit_expr(ex, s, v); - alt ex.node { - expr_lit(_) { } - expr_binary(_, _, _) { /* subexps covered by visit */ } - expr_unary(u, _) { - alt u { - box(_) | - uniq(_) | - deref. { - tcx.sess.span_err(ex.span, - "disallowed operator in constant expression"); - } - _ { } - } - } - _ { tcx.sess.span_err(ex.span, - "constant contains unimplemented expression type"); } +fn check_pat(p: @pat, &&_is_const: bool, v: visit::vt) { + fn is_str(e: @expr) -> bool { + alt e.node { expr_lit(@{node: lit_str(_), _}) { true } _ { false } } } + alt p.node { + // Let through plain string literals here + pat_lit(a) { if !is_str(a) { v.visit_expr(a, true, v); } } + pat_range(a, b) { + if !is_str(a) { v.visit_expr(a, true, v); } + if !is_str(b) { v.visit_expr(b, true, v); } + } + _ { visit::visit_pat(p, false, v); } + } +} + +fn check_expr(sess: session, e: @expr, &&is_const: bool, v: visit::vt) { + if is_const { + alt e.node { + expr_unary(box(_), _) | expr_unary(uniq(_), _) | + expr_unary(deref., _){ + sess.span_err(e.span, + "disallowed operator in constant expression"); + ret; + } + expr_lit(@{node: lit_str(_), _}) { + sess.span_err(e.span, + "string constants are not supported"); + } + expr_lit(_) | expr_binary(_, _, _) | expr_unary(_, _) {} + _ { + sess.span_err(e.span, + "constant contains unimplemented expression type"); + ret; + } + } + } + visit::visit_expr(e, is_const, v); } // Local Variables: diff --git a/src/comp/middle/trans_alt.rs b/src/comp/middle/trans_alt.rs index 84eebc9b94db..efbad8d212ab 100644 --- a/src/comp/middle/trans_alt.rs +++ b/src/comp/middle/trans_alt.rs @@ -7,7 +7,7 @@ import trans_build::*; import trans::{new_sub_block_ctxt, new_scope_block_ctxt, load_if_immediate}; import syntax::ast; import syntax::ast_util; -import syntax::ast_util::{dummy_sp, lit_eq}; +import syntax::ast_util::{dummy_sp}; import syntax::ast::def_id; import syntax::codemap::span; @@ -15,24 +15,19 @@ import trans_common::*; // An option identifying a branch (either a literal, a tag variant or a range) tag opt { - lit(@ast::lit); + lit(@ast::expr); var(/* variant id */uint, /* variant dids */{tg: def_id, var: def_id}); - range(@ast::lit, @ast::lit); + range(@ast::expr, @ast::expr); } fn opt_eq(a: opt, b: opt) -> bool { - alt a { - lit(la) { - ret alt b { lit(lb) { lit_eq(la, lb) } _ { false } }; - } - var(ida, _) { - ret alt b { var(idb, _) { ida == idb } _ { false } }; - } - range(la1, la2) { - ret alt b { - range(lb1, lb2) { lit_eq(la1, lb1) && lit_eq(la2, lb2) } - _ { false } - }; + alt (a, b) { + (lit(a), lit(b)) { ast_util::compare_lit_exprs(a, b) == 0 } + (range(a1, a2), range(b1, b2)) { + ast_util::compare_lit_exprs(a1, b1) == 0 && + ast_util::compare_lit_exprs(a2, b2) == 0 } + (var(a, _), var(b, _)) { a == b } + _ { false } } } @@ -45,7 +40,7 @@ fn trans_opt(bcx: @block_ctxt, o: opt) -> opt_result { alt o { lit(l) { alt l.node { - ast::lit_str(s) { + ast::expr_lit(@{node: ast::lit_str(s), _}) { let strty = ty::mk_str(bcx_tcx(bcx)); let cell = trans::empty_dest_cell(); bcx = trans_vec::trans_str(bcx, s, trans::by_val(cell)); @@ -54,17 +49,14 @@ fn trans_opt(bcx: @block_ctxt, o: opt) -> opt_result { } _ { ret single_result( - rslt(bcx, trans::trans_crate_lit(ccx, *l))); + rslt(bcx, trans::trans_const_expr(ccx, l))); } } } var(id, _) { ret single_result(rslt(bcx, C_int(ccx, id as int))); } range(l1, l2) { - let cell1 = trans::empty_dest_cell(); - let cell2 = trans::empty_dest_cell(); - let bcx = trans::trans_lit(bcx, *l1, trans::by_val(cell1)); - let bcx = trans::trans_lit(bcx, *l2, trans::by_val(cell2)); - ret range_result(rslt(bcx, *cell1), rslt(bcx, *cell2)); + ret range_result(rslt(bcx, trans::trans_const_expr(ccx, l1)), + rslt(bcx, trans::trans_const_expr(ccx, l2))); } } } @@ -464,13 +456,9 @@ fn compile_submatch(bcx: @block_ctxt, m: match, vals: [ValueRef], f: mk_fail, } } lit(l) { - kind = alt l.node { - ast::lit_str(_) | ast::lit_nil. | ast::lit_float(_) | - ast::lit_mach_float(_, _) { - test_val = Load(bcx, val); compare - } - _ { test_val = Load(bcx, val); switch } - }; + test_val = Load(bcx, val); + let pty = ty::node_id_to_monotype(ccx.tcx, pat_id); + kind = ty::type_is_integral(ccx.tcx, pty) ? switch : compare; } range(_, _) { test_val = Load(bcx, val); diff --git a/src/comp/middle/ty.rs b/src/comp/middle/ty.rs index 5a290f7efc6d..a7dd6c81c7dc 100644 --- a/src/comp/middle/ty.rs +++ b/src/comp/middle/ty.rs @@ -155,6 +155,7 @@ export type_is_vec; export type_is_fp; export type_allows_implicit_copy; export type_is_integral; +export type_is_numeric; export type_is_native; export type_is_nil; export type_is_pod; @@ -1173,6 +1174,10 @@ fn type_is_fp(cx: ctxt, ty: t) -> bool { } } +fn type_is_numeric(cx: ctxt, ty: t) -> bool { + ret type_is_integral(cx, ty) || type_is_fp(cx, ty); +} + fn type_is_signed(cx: ctxt, ty: t) -> bool { alt struct(cx, ty) { ty_int. { ret true; } diff --git a/src/comp/middle/typeck.rs b/src/comp/middle/typeck.rs index fc0a4de5245d..ae26d10dce89 100644 --- a/src/comp/middle/typeck.rs +++ b/src/comp/middle/typeck.rs @@ -1,7 +1,6 @@ import syntax::{ast, ast_util}; import ast::spanned; -import syntax::ast_util::{local_def, respan, ty_param_kind, lit_is_numeric, - lit_types_match}; +import syntax::ast_util::{local_def, respan, ty_param_kind}; import syntax::visit; import metadata::csearch; import driver::session; @@ -1253,8 +1252,8 @@ fn lit_as_float(l: @ast::lit) -> str { } } -fn valid_range_bounds(l1: @ast::lit, l2: @ast::lit) -> bool { - ast_util::compare_lit(l1, l2) <= 0 +fn valid_range_bounds(from: @ast::expr, to: @ast::expr) -> bool { + ast_util::compare_lit_exprs(from, to) <= 0 } // Pattern checking is top-down rather than bottom-up so that bindings get @@ -1264,14 +1263,18 @@ fn check_pat(fcx: @fn_ctxt, map: ast_util::pat_id_map, pat: @ast::pat, alt pat.node { ast::pat_wild. { write::ty_only_fixup(fcx, pat.id, expected); } ast::pat_lit(lt) { - let typ = check_lit(fcx.ccx, lt); - typ = demand::simple(fcx, pat.span, expected, typ); - write::ty_only_fixup(fcx, pat.id, typ); + check_expr_with(fcx, lt, expected); + write::ty_only_fixup(fcx, pat.id, expr_ty(fcx.ccx.tcx, lt)); } ast::pat_range(begin, end) { - if !lit_types_match(begin, end) { + check_expr_with(fcx, begin, expected); + check_expr_with(fcx, end, expected); + let b_ty = resolve_type_vars_if_possible(fcx, expr_ty(fcx.ccx.tcx, + begin)); + if b_ty != resolve_type_vars_if_possible(fcx, expr_ty(fcx.ccx.tcx, + end)) { fcx.ccx.tcx.sess.span_err(pat.span, "mismatched types in range"); - } else if !lit_is_numeric(begin) || !lit_is_numeric(end) { + } else if !ty::type_is_numeric(fcx.ccx.tcx, b_ty) { fcx.ccx.tcx.sess.span_err(pat.span, "non-numeric type used in range"); } else if !valid_range_bounds(begin, end) { @@ -1279,12 +1282,7 @@ fn check_pat(fcx: @fn_ctxt, map: ast_util::pat_id_map, pat: @ast::pat, "lower range bound must be less \ than upper"); } - let typ1 = check_lit(fcx.ccx, begin); - typ1 = demand::simple(fcx, pat.span, expected, typ1); - write::ty_only_fixup(fcx, pat.id, typ1); - let typ2 = check_lit(fcx.ccx, end); - typ2 = demand::simple(fcx, pat.span, typ1, typ2); - write::ty_only_fixup(fcx, pat.id, typ2); + write::ty_only_fixup(fcx, pat.id, b_ty); } ast::pat_bind(name) { let vid = lookup_local(fcx, pat.span, pat.id); diff --git a/src/comp/syntax/ast.rs b/src/comp/syntax/ast.rs index 9c7b9ef422de..d9280a483b14 100644 --- a/src/comp/syntax/ast.rs +++ b/src/comp/syntax/ast.rs @@ -86,13 +86,13 @@ type field_pat = {ident: ident, pat: @pat}; tag pat_ { pat_wild; pat_bind(ident); - pat_lit(@lit); pat_tag(@path, [@pat]); pat_rec([field_pat], bool); pat_tup([@pat]); pat_box(@pat); pat_uniq(@pat); - pat_range(@lit, @lit); + pat_lit(@expr); + pat_range(@expr, @expr); } tag mutability { mut; imm; maybe_mut; } diff --git a/src/comp/syntax/ast_util.rs b/src/comp/syntax/ast_util.rs index da7b5a449534..d8b52300e916 100644 --- a/src/comp/syntax/ast_util.rs +++ b/src/comp/syntax/ast_util.rs @@ -225,45 +225,82 @@ fn ternary_to_if(e: @expr) -> @expr { fn ty_param_kind(tp: ty_param) -> kind { tp.kind } -fn compare_lit(a: @lit, b: @lit) -> int { - fn cmp(a: T, b: T) -> int { a == b ? 0 : a < b ? -1 : 1 } - alt (a.node, b.node) { - (lit_int(a), lit_int(b)) | - (lit_mach_int(_, a), lit_mach_int(_, b)) { cmp(a, b) } - (lit_uint(a), lit_uint(b)) { cmp(a, b) } - (lit_char(a), lit_char(b)) { cmp(a, b) } - (lit_float(a), lit_float(b)) | - (lit_mach_float(_, a), lit_mach_float(_, b)) { - cmp(std::float::from_str(a), std::float::from_str(b)) +// FIXME this doesn't handle big integer/float literals correctly (nor does +// the rest of our literal handling) +tag const_val { const_float(float); const_int(i64); const_str(str); } + +fn eval_const_expr(e: @expr) -> const_val { + fn fromb(b: bool) -> const_val { const_int(b as i64) } + alt e.node { + expr_unary(neg., inner) { + alt eval_const_expr(inner) { + const_float(f) { const_float(-f) } + const_int(i) { const_int(-i) } + } } - (lit_str(a), lit_str(b)) { cmp(a, b) } - (lit_nil., lit_nil.) { 0 } - (lit_bool(a), lit_bool(b)) { cmp(a, b) } + expr_unary(not., inner) { + alt eval_const_expr(inner) { + const_int(i) { const_int(!i) } + } + } + expr_binary(op, a, b) { + alt (eval_const_expr(a), eval_const_expr(b)) { + (const_float(a), const_float(b)) { + alt op { + add. { const_float(a + b) } sub. { const_float(a - b) } + mul. { const_float(a * b) } div. { const_float(a / b) } + rem. { const_float(a % b) } eq. { fromb(a == b) } + lt. { fromb(a < b) } le. { fromb(a <= b) } ne. { fromb(a != b) } + ge. { fromb(a >= b) } gt. { fromb(a > b) } + } + } + (const_int(a), const_int(b)) { + alt op { + add. { const_int(a + b) } sub. { const_int(a - b) } + mul. { const_int(a * b) } div. { const_int(a / b) } + rem. { const_int(a % b) } and. | bitand. { const_int(a & b) } + or. | bitor. { const_int(a | b) } bitxor. { const_int(a ^ b) } + eq. { fromb(a == b) } lt. { fromb(a < b) } + le. { fromb(a <= b) } ne. { fromb(a != b) } + ge. { fromb(a >= b) } gt. { fromb(a > b) } + } + } + } + } + expr_lit(lit) { lit_to_const(lit) } } } -fn lit_eq(a: @lit, b: @lit) -> bool { compare_lit(a, b) == 0 } - -fn lit_types_match(a: @lit, b: @lit) -> bool { - alt (a.node, b.node) { - (lit_int(_), lit_int(_)) | (lit_uint(_), lit_uint(_)) | - (lit_char(_), lit_char(_)) | (lit_float(_), lit_float(_)) | - (lit_str(_), lit_str(_)) | (lit_nil., lit_nil.) | - (lit_bool(_), lit_bool(_ )) { true } - (lit_mach_int(ta, _), lit_mach_int(tb, _)) | - (lit_mach_float(ta, _), lit_mach_float(tb, _)) { ta == tb } - _ { false } +fn lit_to_const(lit: @lit) -> const_val { + alt lit.node { + lit_str(s) { const_str(s) } + lit_char(ch) { const_int(ch as i64) } + lit_int(i) | lit_mach_int(_, i) { const_int(i as i64) } + lit_uint(ui) { const_int(ui as i64) } + lit_float(s) | lit_mach_float(_, s) { + const_float(std::float::from_str(s)) + } + lit_nil. { const_int(0i64) } + lit_bool(b) { const_int(b as i64) } } } -fn lit_is_numeric(l: @ast::lit) -> bool { - alt l.node { - ast::lit_int(_) | ast::lit_char(_) | ast::lit_uint(_) | - ast::lit_mach_int(_, _) | ast::lit_float(_) | ast::lit_mach_float(_,_) { - true - } - _ { false } - } +fn compare_const_vals(a: const_val, b: const_val) -> int { + alt (a, b) { + (const_int(a), const_int(b)) { a == b ? 0 : a < b ? -1 : 1 } + (const_float(a), const_float(b)) { a == b ? 0 : a < b ? -1 : 1 } + (const_str(a), const_str(b)) { a == b ? 0 : a < b ? -1 : 1 } + } +} + +fn compare_lit_exprs(a: @expr, b: @expr) -> int { + compare_const_vals(eval_const_expr(a), eval_const_expr(b)) +} + +fn lit_expr_eq(a: @expr, b: @expr) -> bool { compare_lit_exprs(a, b) == 0 } + +fn lit_eq(a: @lit, b: @lit) -> bool { + compare_const_vals(lit_to_const(a), lit_to_const(b)) == 0 } // Local Variables: diff --git a/src/comp/syntax/parse/parser.rs b/src/comp/syntax/parse/parser.rs index 6d751dff6455..9baf6cf3ccf5 100644 --- a/src/comp/syntax/parse/parser.rs +++ b/src/comp/syntax/parse/parser.rs @@ -9,7 +9,7 @@ import util::interner; import ast::{node_id, spanned}; import front::attr; -tag restriction { UNRESTRICTED; RESTRICT_NO_CALL_EXPRS; } +tag restriction { UNRESTRICTED; RESTRICT_NO_CALL_EXPRS; RESTRICT_NO_BAR_OP; } tag file_type { CRATE_FILE; SOURCE_FILE; } @@ -1189,6 +1189,8 @@ fn parse_more_binops(p: parser, lhs: @ast::expr, min_prec: int) -> } none. { none } }; + if peeked == token::BINOP(token::OR) && + p.get_restriction() == RESTRICT_NO_BAR_OP { ret lhs; } for cur: op_spec in *p.get_prec_table() { if cur.prec > min_prec && cur.tok == peeked { p.bump(); @@ -1462,9 +1464,9 @@ fn parse_pat(p: parser) -> @ast::pat { if p.peek() == token::RPAREN { hi = p.get_hi_pos(); p.bump(); - pat = - ast::pat_lit(@{node: ast::lit_nil, - span: ast_util::mk_sp(lo, hi)}); + let lit = @{node: ast::lit_nil, span: ast_util::mk_sp(lo, hi)}; + let expr = mk_expr(p, lo, hi, ast::expr_lit(lit)); + pat = ast::pat_lit(expr); } else { let fields = [parse_pat(p)]; while p.peek() == token::COMMA { @@ -1479,14 +1481,14 @@ fn parse_pat(p: parser) -> @ast::pat { } tok { if !is_ident(tok) || is_word(p, "true") || is_word(p, "false") { - let lit = parse_lit(p); + let val = parse_expr_res(p, RESTRICT_NO_BAR_OP); if eat_word(p, "to") { - let end = parse_lit(p); + let end = parse_expr_res(p, RESTRICT_NO_BAR_OP); hi = end.span.hi; - pat = ast::pat_range(@lit, @end); + pat = ast::pat_range(val, end); } else { - hi = lit.span.hi; - pat = ast::pat_lit(@lit); + hi = val.span.hi; + pat = ast::pat_lit(val); } } else if is_plain_ident(p) && alt p.look_ahead(1u) { diff --git a/src/comp/syntax/print/pprust.rs b/src/comp/syntax/print/pprust.rs index 9cd20cc87af5..612768c4137c 100644 --- a/src/comp/syntax/print/pprust.rs +++ b/src/comp/syntax/print/pprust.rs @@ -1061,7 +1061,6 @@ fn print_pat(s: ps, &&pat: @ast::pat) { alt pat.node { ast::pat_wild. { word(s.s, "_"); } ast::pat_bind(id) { word(s.s, id); } - ast::pat_lit(lit) { print_literal(s, lit); } ast::pat_tag(path, args) { print_path(s, path, true); if vec::len(args) > 0u { @@ -1094,11 +1093,12 @@ fn print_pat(s: ps, &&pat: @ast::pat) { } ast::pat_box(inner) { word(s.s, "@"); print_pat(s, inner); } ast::pat_uniq(inner) { word(s.s, "~"); print_pat(s, inner); } + ast::pat_lit(e) { print_expr(s, e); } ast::pat_range(begin, end) { - print_literal(s, begin); + print_expr(s, begin); space(s.s); word_space(s, "to"); - print_literal(s, end); + print_expr(s, end); } } s.ann.post(ann_node);