Clean up unification code

This commit is contained in:
Marijn Haverbeke 2012-02-21 17:02:02 +01:00
parent ff927f18f5
commit 9d20ed7bf9

View file

@ -258,6 +258,7 @@ enum type_err {
terr_mismatch,
terr_ret_style_mismatch(ast::ret_style, ast::ret_style),
terr_box_mutability,
terr_ptr_mutability,
terr_vec_mutability,
terr_tuple_size(uint, uint),
terr_record_size(uint, uint),
@ -1557,26 +1558,9 @@ mod unify {
}
}
fn record_var_binding_for_expected(
cx: @uctxt, key: int, typ: t, variance: variance) -> result {
record_var_binding(
cx, key, typ, variance_transform(variance, covariant))
}
fn record_var_binding_for_actual(
cx: @uctxt, key: int, typ: t, variance: variance) -> result {
// Unifying in 'the other direction' so flip the variance
record_var_binding(
cx, key, typ, variance_transform(variance, contravariant))
}
fn record_var_binding(
cx: @uctxt, key: int, typ: t, variance: variance) -> result {
let vb = alt cx.st { in_bindings(vb) { vb }
_ { cx.tcx.sess.bug("Someone forgot to document an invariant \
in record_var_binding"); }
};
fn record_var_binding(cx: @uctxt, key: int, typ: t, variance: variance)
-> result {
let vb = alt check cx.st { in_bindings(vb) { vb } };
ufind::grow(vb.sets, (key as uint) + 1u);
let root = ufind::find(vb.sets, key as uint);
let result_type = typ;
@ -1589,8 +1573,8 @@ mod unify {
}
none {/* fall through */ }
}
smallintmap::insert::<t>(vb.types, root, result_type);
ret ures_ok(typ);
smallintmap::insert(vb.types, root, result_type);
ret ures_ok(mk_var(cx.tcx, key));
}
// Simple structural type comparison.
@ -1837,7 +1821,7 @@ mod unify {
}
fn unify_tps(cx: @uctxt, expected_tps: [t], actual_tps: [t],
variance: variance, finish: fn([t]) -> result) -> result {
variance: variance, finish: fn([t]) -> t) -> result {
let result_tps = [], i = 0u;
for exp in expected_tps {
let act = actual_tps[i];
@ -1848,345 +1832,148 @@ mod unify {
_ { ret result; }
}
}
finish(result_tps)
ures_ok(finish(result_tps))
}
fn unify_mt(cx: @uctxt, e_mt: mt, a_mt: mt, variance: variance,
mut_err: type_err, finish: fn(ctxt, mt) -> t) -> result {
alt unify_mut(e_mt.mutbl, a_mt.mutbl, variance) {
none { ures_err(mut_err) }
some((mutt, var)) {
alt unify_step(cx, e_mt.ty, a_mt.ty, var) {
ures_ok(result_sub) {
ures_ok(finish(cx.tcx, {ty: result_sub, mutbl: mutt}))
}
err { err }
}
}
}
}
fn unify_step(cx: @uctxt, expected: t, actual: t,
variance: variance) -> result {
// FIXME: rewrite this using tuple pattern matching when available, to
// avoid all this rightward drift and spikiness.
// NOTE: we have tuple matching now, but that involves copying the
// matched elements into a tuple first, which is expensive, since sty
// holds vectors, which are currently unique
// Fast path.
if expected == actual { ret ures_ok(expected); }
// Stage 1: Handle the cases in which one side or another is a type
// variable
alt get(actual).struct {
// If the RHS is a variable type, then just do the
// appropriate binding.
ty_var(actual_id) {
let actual_n = actual_id as uint;
alt get(expected).struct {
ty_var(expected_id) {
let expected_n = expected_id as uint;
alt union(cx, expected_n, actual_n, variance) {
unres_ok {/* fall through */ }
unres_err(t_e) { ret ures_err(t_e); }
}
}
_ {
// Just bind the type variable to the expected type.
alt record_var_binding_for_actual(
cx, actual_id, expected, variance) {
ures_ok(_) {/* fall through */ }
rs { ret rs; }
}
}
}
ret ures_ok(mk_var(cx.tcx, actual_id));
}
_ {/* empty */ }
}
alt get(expected).struct {
ty_var(expected_id) {
// Add a binding. (`actual` can't actually be a var here.)
alt record_var_binding_for_expected(
cx, expected_id, actual,
variance) {
ures_ok(_) {/* fall through */ }
rs { ret rs; }
}
ret ures_ok(mk_var(cx.tcx, expected_id));
}
_ {/* fall through */ }
}
// Stage 2: Handle all other cases.
alt get(actual).struct {
ty_bot { ret ures_ok(expected); }
_ {/* fall through */ }
}
alt get(expected).struct {
ty_nil { ret struct_cmp(cx, expected, actual); }
// _|_ unifies with anything
ty_bot {
ret ures_ok(actual);
}
ty_bool | ty_int(_) | ty_uint(_) | ty_float(_) |
ty_str | ty_send_type {
ret struct_cmp(cx, expected, actual);
}
ty_param(expected_n, _) {
alt get(actual).struct {
ty_param(actual_n, _) if expected_n == actual_n {
ret ures_ok(expected);
}
_ { ret ures_err(terr_mismatch); }
alt (get(expected).struct, get(actual).struct) {
(ty_var(e_id), ty_var(a_id)) {
alt union(cx, e_id as uint, a_id as uint, variance) {
unres_ok { ures_ok(actual) }
unres_err(err) { ures_err(err) }
}
}
ty_enum(expected_id, expected_tps) {
alt get(actual).struct {
ty_enum(actual_id, actual_tps) {
if expected_id != actual_id {
ret ures_err(terr_mismatch);
}
ret unify_tps(cx, expected_tps, actual_tps, variance, {|tps|
ures_ok(mk_enum(cx.tcx, expected_id, tps))
});
}
_ {/* fall through */ }
}
ret ures_err(terr_mismatch);
(_, ty_var(a_id)) {
let v = variance_transform(variance, contravariant);
record_var_binding(cx, a_id, expected, v)
}
ty_iface(expected_id, expected_tps) {
alt get(actual).struct {
ty_iface(actual_id, actual_tps) {
if expected_id != actual_id {
ret ures_err(terr_mismatch);
}
ret unify_tps(cx, expected_tps, actual_tps, variance, {|tps|
ures_ok(mk_iface(cx.tcx, expected_id, tps))
});
}
_ {}
}
ret ures_err(terr_mismatch);
(ty_var(e_id), _) {
let v = variance_transform(variance, covariant);
record_var_binding(cx, e_id, actual, v)
}
ty_box(expected_mt) {
alt get(actual).struct {
ty_box(actual_mt) {
let (mutt, var) = alt unify_mut(
expected_mt.mutbl, actual_mt.mutbl, variance) {
none { ret ures_err(terr_box_mutability); }
some(mv) { mv }
};
let result = unify_step(
cx, expected_mt.ty, actual_mt.ty, var);
alt result {
ures_ok(result_sub) {
let mt = {ty: result_sub, mutbl: mutt};
ret ures_ok(mk_box(cx.tcx, mt));
}
_ { ret result; }
}
(_, ty_bot) { ures_ok(expected) }
(ty_bot, _) { ures_ok(actual) }
(ty_nil, _) | (ty_bool, _) | (ty_int(_), _) | (ty_uint(_), _) |
(ty_float(_), _) | (ty_str, _) | (ty_send_type, _) {
struct_cmp(cx, expected, actual)
}
(ty_param(e_n, _), ty_param(a_n, _)) if e_n == a_n {
ures_ok(expected)
}
(ty_enum(e_id, e_tps), ty_enum(a_id, a_tps)) if e_id == a_id {
unify_tps(cx, e_tps, a_tps, variance, {|tps|
mk_enum(cx.tcx, e_id, tps)
})
}
(ty_iface(e_id, e_tps), ty_iface(a_id, a_tps)) if e_id == a_id {
unify_tps(cx, e_tps, a_tps, variance, {|tps|
mk_iface(cx.tcx, e_id, tps)
})
}
(ty_class(e_id, e_tps), ty_class(a_id, a_tps)) if e_id == a_id {
unify_tps(cx, e_tps, a_tps, variance, {|tps|
mk_class(cx.tcx, e_id, tps)
})
}
(ty_box(e_mt), ty_box(a_mt)) {
unify_mt(cx, e_mt, a_mt, variance, terr_box_mutability, mk_box)
}
(ty_uniq(e_mt), ty_uniq(a_mt)) {
unify_mt(cx, e_mt, a_mt, variance, terr_box_mutability, mk_uniq)
}
(ty_vec(e_mt), ty_vec(a_mt)) {
unify_mt(cx, e_mt, a_mt, variance, terr_vec_mutability, mk_vec)
}
(ty_ptr(e_mt), ty_ptr(a_mt)) {
unify_mt(cx, e_mt, a_mt, variance, terr_ptr_mutability, mk_ptr)
}
(ty_res(e_id, e_inner, e_tps), ty_res(a_id, a_inner, a_tps))
if e_id == a_id {
alt unify_step(cx, e_inner, a_inner, variance) {
ures_ok(res_inner) {
unify_tps(cx, e_tps, a_tps, variance, {|tps|
mk_res(cx.tcx, a_id, res_inner, tps)
})
}
_ { ret ures_err(terr_mismatch); }
err { err }
}
}
ty_uniq(expected_mt) {
alt get(actual).struct {
ty_uniq(actual_mt) {
let (mutt, var) = alt unify_mut(
expected_mt.mutbl, actual_mt.mutbl, variance) {
none { ret ures_err(terr_box_mutability); }
some(mv) { mv }
};
let result = unify_step(
cx, expected_mt.ty, actual_mt.ty, var);
alt result {
ures_ok(result_mt) {
let mt = {ty: result_mt, mutbl: mutt};
ret ures_ok(mk_uniq(cx.tcx, mt));
}
_ { ret result; }
}
}
_ { ret ures_err(terr_mismatch); }
(ty_rec(e_fields), ty_rec(a_fields)) {
let e_len = e_fields.len(), a_len = a_fields.len();
if e_len != a_len {
ret ures_err(terr_record_size(e_len, a_len));
}
}
ty_vec(expected_mt) {
alt get(actual).struct {
ty_vec(actual_mt) {
let (mutt, var) = alt unify_mut(
expected_mt.mutbl, actual_mt.mutbl, variance) {
none { ret ures_err(terr_vec_mutability); }
some(mv) { mv }
};
let result = unify_step(
cx, expected_mt.ty, actual_mt.ty, var);
alt result {
ures_ok(result_sub) {
let mt = {ty: result_sub, mutbl: mutt};
ret ures_ok(mk_vec(cx.tcx, mt));
}
_ { ret result; }
let result_fields = [], i = 0u;
while i < a_len {
let e_field = e_fields[i], a_field = a_fields[i];
if e_field.ident != a_field.ident {
ret ures_err(terr_record_fields(e_field.ident,
a_field.ident));
}
}
_ { ret ures_err(terr_mismatch); }
alt unify_mt(cx, e_field.mt, a_field.mt, variance,
terr_record_mutability, {|cx, mt|
result_fields += [{mt: mt with e_field}];
mk_nil(cx)
}) {
ures_ok(_) {}
err { ret err; }
}
i += 1u;
}
ures_ok(mk_rec(cx.tcx, result_fields))
}
ty_ptr(expected_mt) {
alt get(actual).struct {
ty_ptr(actual_mt) {
let (mutt, var) = alt unify_mut(
expected_mt.mutbl, actual_mt.mutbl, variance) {
none { ret ures_err(terr_vec_mutability); }
some(mv) { mv }
};
let result = unify_step(
cx, expected_mt.ty, actual_mt.ty, var);
alt result {
ures_ok(result_sub) {
let mt = {ty: result_sub, mutbl: mutt};
ret ures_ok(mk_ptr(cx.tcx, mt));
}
_ { ret result; }
(ty_tup(e_elems), ty_tup(a_elems)) {
let e_len = e_elems.len(), a_len = a_elems.len();
if e_len != a_len { ret ures_err(terr_tuple_size(e_len, a_len)); }
let result_elems = [], i = 0u;
while i < a_len {
alt unify_step(cx, e_elems[i], a_elems[i], variance) {
ures_ok(rty) { result_elems += [rty]; }
err { ret err; }
}
}
_ { ret ures_err(terr_mismatch); }
i += 1u;
}
ures_ok(mk_tup(cx.tcx, result_elems))
}
ty_res(ex_id, ex_inner, ex_tps) {
alt get(actual).struct {
ty_res(act_id, act_inner, act_tps) {
if ex_id.crate != act_id.crate || ex_id.node != act_id.node {
ret ures_err(terr_mismatch);
}
let result = unify_step(
cx, ex_inner, act_inner, variance);
alt result {
ures_ok(res_inner) {
let i = 0u;
let res_tps = [];
for ex_tp: t in ex_tps {
let result = unify_step(
cx, ex_tp, act_tps[i], variance);
alt result {
ures_ok(rty) { res_tps += [rty]; }
_ { ret result; }
}
i += 1u;
}
ret ures_ok(mk_res(cx.tcx, act_id, res_inner, res_tps));
}
_ { ret result; }
}
}
_ { ret ures_err(terr_mismatch); }
}
(ty_fn(e_fty), ty_fn(a_fty)) {
unify_fn(cx, e_fty, a_fty, variance)
}
ty_rec(expected_fields) {
alt get(actual).struct {
ty_rec(actual_fields) {
let expected_len = vec::len::<field>(expected_fields);
let actual_len = vec::len::<field>(actual_fields);
if expected_len != actual_len {
let err = terr_record_size(expected_len, actual_len);
ret ures_err(err);
}
let result_fields = [], i = 0u;
while i < actual_len {
let expected_field = expected_fields[i],
actual_field = actual_fields[i];
let u_mut = unify_mut(expected_field.mt.mutbl,
actual_field.mt.mutbl,
variance);
let (mutt, var) = alt u_mut {
none { ret ures_err(terr_record_mutability); }
some(mv) { mv }
};
if !str::eq(expected_field.ident, actual_field.ident) {
let err =
terr_record_fields(expected_field.ident,
actual_field.ident);
ret ures_err(err);
}
let result =
unify_step(cx, expected_field.mt.ty,
actual_field.mt.ty, var);
alt result {
ures_ok(rty) {
let mt = {ty: rty, mutbl: mutt};
result_fields += [{mt: mt with expected_field}];
}
_ { ret result; }
}
i += 1u;
}
ret ures_ok(mk_rec(cx.tcx, result_fields));
}
_ { ret ures_err(terr_mismatch); }
}
}
ty_tup(expected_elems) {
alt get(actual).struct {
ty_tup(actual_elems) {
let expected_len = vec::len(expected_elems);
let actual_len = vec::len(actual_elems);
if expected_len != actual_len {
let err = terr_tuple_size(expected_len, actual_len);
ret ures_err(err);
}
let result_elems = [], i = 0u;
while i < actual_len {
alt unify_step(cx, expected_elems[i], actual_elems[i],
variance) {
ures_ok(rty) { result_elems += [rty]; }
r { ret r; }
}
i += 1u;
}
ret ures_ok(mk_tup(cx.tcx, result_elems));
}
_ { ret ures_err(terr_mismatch); }
}
}
ty_fn(expected_f) {
alt get(actual).struct {
ty_fn(actual_f) {
ret unify_fn(cx, expected_f, actual_f, variance);
}
_ { ret ures_err(terr_mismatch); }
}
}
ty_constr(expected_t, expected_constrs) {
(ty_constr(e_t, e_constrs), ty_constr(a_t, a_constrs)) {
// unify the base types...
alt get(actual).struct {
ty_constr(actual_t, actual_constrs) {
let rslt = unify_step(
cx, expected_t, actual_t, variance);
alt rslt {
ures_ok(rty) {
// FIXME: probably too restrictive --
// requires the constraints to be
// syntactically equal
ret unify_constrs(expected, expected_constrs,
actual_constrs);
}
_ { ret rslt; }
}
}
_ {
// If the actual type is *not* a constrained type,
// then we go ahead and just ignore the constraints on
// the expected type. typestate handles the rest.
ret unify_step(
cx, expected_t, actual, variance);
alt unify_step(cx, e_t, a_t, variance) {
ures_ok(rty) {
// FIXME: probably too restrictive --
// requires the constraints to be syntactically equal
unify_constrs(expected, e_constrs, a_constrs)
}
err { err }
}
}
ty_class(expected_class, expected_tys) {
alt get(actual).struct {
ty_class(actual_class, actual_tys) {
if expected_class != actual_class {
ret ures_err(terr_mismatch);
}
ret unify_tps(cx, expected_tys, actual_tys, variance,
{|tps|
ures_ok(mk_class(cx.tcx, expected_class, tps))});
}
_ {
ret ures_err(terr_mismatch);
}
}
(ty_constr(e_t, _), _) {
// If the actual type is *not* a constrained type,
// then we go ahead and just ignore the constraints on
// the expected type. typestate handles the rest.
unify_step(cx, e_t, actual, variance)
}
_ { cx.tcx.sess.bug("unify: unexpected type"); }
_ { ures_err(terr_mismatch) }
}
}
fn unify(expected: t, actual: t, st: unify_style,
@ -2293,6 +2080,7 @@ fn type_err_to_str(err: type_err) -> str {
}
terr_box_mutability { ret "boxed values differ in mutability"; }
terr_vec_mutability { ret "vectors differ in mutability"; }
terr_ptr_mutability { ret "pointers differ in mutability"; }
terr_tuple_size(e_sz, a_sz) {
ret "expected a tuple with " + uint::to_str(e_sz, 10u) +
" elements but found one with " + uint::to_str(a_sz, 10u) +