diff --git a/src/librustc/middle/infer/freshen.rs b/src/librustc/middle/infer/freshen.rs index 29f74d12ea3e..d93d13beec8f 100644 --- a/src/librustc/middle/infer/freshen.rs +++ b/src/librustc/middle/infer/freshen.rs @@ -37,7 +37,7 @@ use middle::ty_fold::TypeFolder; use std::collections::hash_map::{self, Entry}; use super::InferCtxt; -use super::unify::ToType; +use super::unify_key::ToType; pub struct TypeFreshener<'a, 'tcx:'a> { infcx: &'a InferCtxt<'a, 'tcx>, diff --git a/src/librustc/middle/infer/mod.rs b/src/librustc/middle/infer/mod.rs index 0f62b440bf32..b0921a266f39 100644 --- a/src/librustc/middle/infer/mod.rs +++ b/src/librustc/middle/infer/mod.rs @@ -29,6 +29,7 @@ use middle::ty::replace_late_bound_regions; use middle::ty::{self, Ty}; use middle::ty_fold::{TypeFolder, TypeFoldable}; use middle::ty_relate::{Relate, RelateResult, TypeRelation}; +use rustc_data_structures::unify::{self, UnificationTable}; use std::cell::{RefCell}; use std::fmt; use std::rc::Rc; @@ -41,8 +42,8 @@ use util::ppaux::{Repr, UserString}; use self::combine::CombineFields; use self::region_inference::{RegionVarBindings, RegionSnapshot}; -use self::unify::{ToType, UnificationTable}; use self::error_reporting::ErrorReporting; +use self::unify_key::ToType; pub mod bivariate; pub mod combine; @@ -57,7 +58,7 @@ pub mod resolve; mod freshen; pub mod sub; pub mod type_variable; -pub mod unify; +pub mod unify_key; pub type Bound = Option; pub type UnitResult<'tcx> = RelateResult<'tcx, ()>; // "unify result" diff --git a/src/librustc/middle/infer/unify_key.rs b/src/librustc/middle/infer/unify_key.rs new file mode 100644 index 000000000000..6b23e2c5029b --- /dev/null +++ b/src/librustc/middle/infer/unify_key.rs @@ -0,0 +1,48 @@ +// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use middle::ty::{self, IntVarValue, Ty}; +use rustc_data_structures::unify::UnifyKey; +use syntax::ast; + +pub trait ToType<'tcx> { + fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx>; +} + +impl UnifyKey for ty::IntVid { + type Value = Option; + fn index(&self) -> u32 { self.index } + fn from_index(i: u32) -> ty::IntVid { ty::IntVid { index: i } } + fn tag(_: Option) -> &'static str { "IntVid" } +} + +impl<'tcx> ToType<'tcx> for IntVarValue { + fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> { + match *self { + ty::IntType(i) => ty::mk_mach_int(tcx, i), + ty::UintType(i) => ty::mk_mach_uint(tcx, i), + } + } +} + +// Floating point type keys + +impl UnifyKey for ty::FloatVid { + type Value = Option; + fn index(&self) -> u32 { self.index } + fn from_index(i: u32) -> ty::FloatVid { ty::FloatVid { index: i } } + fn tag(_: Option) -> &'static str { "FloatVid" } +} + +impl<'tcx> ToType<'tcx> for ast::FloatTy { + fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> { + ty::mk_mach_float(tcx, *self) + } +} diff --git a/src/librustc_data_structures/lib.rs b/src/librustc_data_structures/lib.rs index d90a40941cb2..6562a7488984 100644 --- a/src/librustc_data_structures/lib.rs +++ b/src/librustc_data_structures/lib.rs @@ -35,3 +35,4 @@ extern crate serialize as rustc_serialize; // used by deriving pub mod snapshot_vec; pub mod graph; pub mod bitvec; +pub mod unify; diff --git a/src/librustc/middle/infer/unify.rs b/src/librustc_data_structures/unify/mod.rs similarity index 64% rename from src/librustc/middle/infer/unify.rs rename to src/librustc_data_structures/unify/mod.rs index 5aec42271369..aff79e25956f 100644 --- a/src/librustc/middle/infer/unify.rs +++ b/src/librustc_data_structures/unify/mod.rs @@ -8,16 +8,13 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -pub use self::VarValue::*; - use std::marker; - -use middle::ty::{IntVarValue}; -use middle::ty::{self, Ty}; use std::fmt::Debug; use std::marker::PhantomData; -use syntax::ast; -use rustc_data_structures::snapshot_vec as sv; +use snapshot_vec as sv; + +#[cfg(test)] +mod test; /// This trait is implemented by any type that can serve as a type /// variable. We call such variables *unification keys*. For example, @@ -28,9 +25,10 @@ use rustc_data_structures::snapshot_vec as sv; /// `IntVid`, this is `Option`, representing some /// (possibly not yet known) sort of integer. /// -/// Implementations of this trait are at the end of this file. -pub trait UnifyKey : Clone + Debug + PartialEq { - type Value : UnifyValue; +/// Clients are expected to provide implementations of this trait; you +/// can see some examples in the `test` module. +pub trait UnifyKey : Copy + Clone + Debug + PartialEq { + type Value: Clone + PartialEq + Debug; fn index(&self) -> u32; @@ -39,15 +37,6 @@ pub trait UnifyKey : Clone + Debug + PartialEq { fn tag(k: Option) -> &'static str; } -/// Trait for valid types that a type variable can be set to. Note that -/// this is typically not the end type that the value will take on, but -/// rather an `Option` wrapper (where `None` represents a variable -/// whose value is not yet set). -/// -/// Implementations of this trait are at the end of this file. -pub trait UnifyValue : Clone + PartialEq + Debug { -} - /// Value of a unification key. We implement Tarjan's union-find /// algorithm: when two keys are unified, one of them is converted /// into a "redirect" pointing at the other. These redirects form a @@ -57,9 +46,10 @@ pub trait UnifyValue : Clone + PartialEq + Debug { /// time of the algorithm under control. For more information, see /// . #[derive(PartialEq,Clone,Debug)] -pub enum VarValue { - Redirect(K), - Root(K::Value, usize), +pub struct VarValue { + parent: K, // if equal to self, this is a root + value: K::Value, // value assigned (only relevant to root) + rank: u32, // max depth (only relevant to root) } /// Table of unification keys and their values. @@ -76,16 +66,46 @@ pub struct Snapshot { snapshot: sv::Snapshot, } -/// Internal type used to represent the result of a `get()` operation. -/// Conveys the current root and value of the key. -pub struct Node { - pub key: K, - pub value: K::Value, - pub rank: usize, -} - #[derive(Copy, Clone)] -pub struct Delegate(PhantomData); +struct Delegate(PhantomData); + +impl VarValue { + fn new_var(key: K, value: K::Value) -> VarValue { + VarValue::new(key, value, 0) + } + + fn new(parent: K, value: K::Value, rank: u32) -> VarValue { + VarValue { parent: parent, // this is a root + value: value, + rank: rank } + } + + fn redirect(self, to: K) -> VarValue { + VarValue { parent: to, ..self } + } + + fn root(self, rank: u32, value: K::Value) -> VarValue { + VarValue { rank: rank, value: value, ..self } + } + + /// Returns the key of this node. Only valid if this is a root + /// node, which you yourself must ensure. + fn key(&self) -> K { + self.parent + } + + fn parent(&self, self_key: K) -> Option { + self.if_not_self(self.parent, self_key) + } + + fn if_not_self(&self, key: K, self_key: K) -> Option { + if key == self_key { + None + } else { + Some(key) + } + } +} // We can't use V:LatticeValue, much as I would like to, // because frequently the pattern is that V=Option for some @@ -95,7 +115,7 @@ pub struct Delegate(PhantomData); impl UnificationTable { pub fn new() -> UnificationTable { UnificationTable { - values: sv::SnapshotVec::new(), + values: sv::SnapshotVec::new() } } @@ -121,12 +141,13 @@ impl UnificationTable { } pub fn new_key(&mut self, value: K::Value) -> K { - let index = self.values.push(Root(value, 0)); - let k = UnifyKey::from_index(index as u32); + let len = self.values.len(); + let key: K = UnifyKey::from_index(len as u32); + self.values.push(VarValue::new_var(key, value)); debug!("{}: created new key: {:?}", UnifyKey::tag(None::), - k); - k + key); + key } /// Find the root node for `vid`. This uses the standard @@ -135,36 +156,34 @@ impl UnificationTable { /// /// NB. This is a building-block operation and you would probably /// prefer to call `probe` below. - fn get(&mut self, vid: K) -> Node { + fn get(&mut self, vid: K) -> VarValue { let index = vid.index() as usize; - let value = (*self.values.get(index)).clone(); - match value { - Redirect(redirect) => { - let node: Node = self.get(redirect.clone()); - if node.key != redirect { + let mut value: VarValue = self.values.get(index).clone(); + match value.parent(vid) { + Some(redirect) => { + let root: VarValue = self.get(redirect); + if root.key() != redirect { // Path compression - self.values.set(index, Redirect(node.key.clone())); + value.parent = root.key(); + self.values.set(index, value); } - node + root } - Root(value, rank) => { - Node { key: vid, value: value, rank: rank } + None => { + value } } } - fn is_root(&self, key: &K) -> bool { + fn is_root(&self, key: K) -> bool { let index = key.index() as usize; - match *self.values.get(index) { - Redirect(..) => false, - Root(..) => true, - } + self.values.get(index).parent(key).is_none() } /// Sets the value for `vid` to `new_value`. `vid` MUST be a root /// node! This is an internal operation used to impl other things. fn set(&mut self, key: K, new_value: VarValue) { - assert!(self.is_root(&key)); + assert!(self.is_root(key)); debug!("Updating variable {:?} to {:?}", key, new_value); @@ -181,31 +200,36 @@ impl UnificationTable { /// really more of a building block. If the values associated with /// your key are non-trivial, you would probably prefer to call /// `unify_var_var` below. - fn unify(&mut self, node_a: &Node, node_b: &Node, new_value: K::Value) { - debug!("unify(node_a(id={:?}, rank={:?}), node_b(id={:?}, rank={:?}))", - node_a.key, - node_a.rank, - node_b.key, - node_b.rank); + fn unify(&mut self, root_a: VarValue, root_b: VarValue, new_value: K::Value) { + debug!("unify(root_a(id={:?}, rank={:?}), root_b(id={:?}, rank={:?}))", + root_a.key(), + root_a.rank, + root_b.key(), + root_b.rank); - let (new_root, new_rank) = if node_a.rank > node_b.rank { + if root_a.rank > root_b.rank { // a has greater rank, so a should become b's parent, // i.e., b should redirect to a. - self.set(node_b.key.clone(), Redirect(node_a.key.clone())); - (node_a.key.clone(), node_a.rank) - } else if node_a.rank < node_b.rank { + self.redirect_root(root_a.rank, root_b, root_a, new_value); + } else if root_a.rank < root_b.rank { // b has greater rank, so a should redirect to b. - self.set(node_a.key.clone(), Redirect(node_b.key.clone())); - (node_b.key.clone(), node_b.rank) + self.redirect_root(root_b.rank, root_a, root_b, new_value); } else { // If equal, redirect one to the other and increment the // other's rank. - assert_eq!(node_a.rank, node_b.rank); - self.set(node_b.key.clone(), Redirect(node_a.key.clone())); - (node_a.key.clone(), node_a.rank + 1) - }; + self.redirect_root(root_a.rank + 1, root_a, root_b, new_value); + } + } - self.set(new_root, Root(new_value, new_rank)); + fn redirect_root(&mut self, + new_rank: u32, + old_root: VarValue, + new_root: VarValue, + new_value: K::Value) { + let old_root_key = old_root.key(); + let new_root_key = new_root.key(); + self.set(old_root_key, old_root.redirect(new_root_key)); + self.set(new_root_key, new_root.root(new_rank, new_value)); } } @@ -213,8 +237,31 @@ impl sv::SnapshotVecDelegate for Delegate { type Value = VarValue; type Undo = (); - fn reverse(_: &mut Vec>, _: ()) { - panic!("Nothing to reverse"); + fn reverse(_: &mut Vec>, _: ()) {} +} + +/////////////////////////////////////////////////////////////////////////// +// Base union-find algorithm, where we are just making setes + +impl<'tcx,K> UnificationTable + where K : UnifyKey, +{ + pub fn union(&mut self, a_id: K, b_id: K) { + let node_a = self.get(a_id); + let node_b = self.get(b_id); + let a_id = node_a.key(); + let b_id = node_b.key(); + if a_id != b_id { + self.unify(node_a, node_b, ()); + } + } + + pub fn find(&mut self, id: K) -> K { + self.get(id).key() + } + + pub fn unioned(&mut self, a_id: K, b_id: K) -> bool { + self.find(a_id) == self.find(b_id) } } @@ -226,7 +273,6 @@ impl sv::SnapshotVecDelegate for Delegate { impl<'tcx,K,V> UnificationTable where K: UnifyKey>, V: Clone+PartialEq, - Option: UnifyValue, { pub fn unify_var_var(&mut self, a_id: K, @@ -235,8 +281,8 @@ impl<'tcx,K,V> UnificationTable { let node_a = self.get(a_id); let node_b = self.get(b_id); - let a_id = node_a.key.clone(); - let b_id = node_b.key.clone(); + let a_id = node_a.key(); + let b_id = node_b.key(); if a_id == b_id { return Ok(()); } @@ -257,7 +303,7 @@ impl<'tcx,K,V> UnificationTable } }; - Ok(self.unify(&node_a, &node_b, combined)) + Ok(self.unify(node_a, node_b, combined)) } /// Sets the value of the key `a_id` to `b`. Because simple keys do not have any subtyping @@ -267,12 +313,12 @@ impl<'tcx,K,V> UnificationTable b: V) -> Result<(),(V,V)> { - let node_a = self.get(a_id); - let a_id = node_a.key.clone(); + let mut node_a = self.get(a_id); match node_a.value { None => { - self.set(a_id, Root(Some(b), node_a.rank)); + node_a.value = Some(b); + self.set(node_a.key(), node_a); Ok(()) } @@ -295,46 +341,3 @@ impl<'tcx,K,V> UnificationTable } } -/////////////////////////////////////////////////////////////////////////// - -// Integral type keys - -pub trait ToType<'tcx> { - fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx>; -} - -impl UnifyKey for ty::IntVid { - type Value = Option; - fn index(&self) -> u32 { self.index } - fn from_index(i: u32) -> ty::IntVid { ty::IntVid { index: i } } - fn tag(_: Option) -> &'static str { "IntVid" } -} - -impl<'tcx> ToType<'tcx> for IntVarValue { - fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> { - match *self { - ty::IntType(i) => ty::mk_mach_int(tcx, i), - ty::UintType(i) => ty::mk_mach_uint(tcx, i), - } - } -} - -impl UnifyValue for Option { } - -// Floating point type keys - -impl UnifyKey for ty::FloatVid { - type Value = Option; - fn index(&self) -> u32 { self.index } - fn from_index(i: u32) -> ty::FloatVid { ty::FloatVid { index: i } } - fn tag(_: Option) -> &'static str { "FloatVid" } -} - -impl UnifyValue for Option { -} - -impl<'tcx> ToType<'tcx> for ast::FloatTy { - fn to_type(&self, tcx: &ty::ctxt<'tcx>) -> Ty<'tcx> { - ty::mk_mach_float(tcx, *self) - } -} diff --git a/src/librustc_data_structures/unify/test.rs b/src/librustc_data_structures/unify/test.rs new file mode 100644 index 000000000000..d662842a37af --- /dev/null +++ b/src/librustc_data_structures/unify/test.rs @@ -0,0 +1,185 @@ +#![allow(non_snake_case)] + +extern crate test; +use self::test::Bencher; +use std::collections::HashSet; +use unify::{UnifyKey, UnificationTable}; + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +struct UnitKey(u32); + +impl UnifyKey for UnitKey { + type Value = (); + fn index(&self) -> u32 { self.0 } + fn from_index(u: u32) -> UnitKey { UnitKey(u) } + fn tag(_: Option) -> &'static str { "UnitKey" } +} + +#[test] +fn basic() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(()); + let k2 = ut.new_key(()); + assert_eq!(ut.unioned(k1, k2), false); + ut.union(k1, k2); + assert_eq!(ut.unioned(k1, k2), true); +} + +#[test] +fn big_array() { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + for i in 1..MAX { + let l = keys[i-1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } +} + +#[bench] +fn big_array_bench(b: &mut Bencher) { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + + b.iter(|| { + for i in 1..MAX { + let l = keys[i-1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } + }) +} + +#[test] +fn even_odd() { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 10; + + for i in 0..MAX { + let key = ut.new_key(()); + keys.push(key); + + if i >= 2 { + ut.union(key, keys[i-2]); + } + } + + for i in 1..MAX { + assert!(!ut.unioned(keys[i-1], keys[i])); + } + + for i in 2..MAX { + assert!(ut.unioned(keys[i-2], keys[i])); + } +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +struct IntKey(u32); + +impl UnifyKey for IntKey { + type Value = Option; + fn index(&self) -> u32 { self.0 } + fn from_index(u: u32) -> IntKey { IntKey(u) } + fn tag(_: Option) -> &'static str { "IntKey" } +} + +/// Test unifying a key whose value is `Some(_)` with a key whose value is `None`. +/// Afterwards both should be `Some(_)`. +#[test] +fn unify_key_Some_key_None() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + let k2 = ut.new_key(None); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert_eq!(ut.probe(k2), Some(22)); + assert_eq!(ut.probe(k1), Some(22)); +} + +/// Test unifying a key whose value is `None` with a key whose value is `Some(_)`. +/// Afterwards both should be `Some(_)`. +#[test] +fn unify_key_None_key_Some() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + let k2 = ut.new_key(None); + assert!(ut.unify_var_var(k2, k1).is_ok()); + assert_eq!(ut.probe(k2), Some(22)); + assert_eq!(ut.probe(k1), Some(22)); +} + +/// Test unifying a key whose value is `Some(x)` with a key whose value is `Some(y)`. +/// This should yield an error. +#[test] +fn unify_key_Some_x_key_Some_y() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + let k2 = ut.new_key(Some(23)); + assert_eq!(ut.unify_var_var(k1, k2), Err((22, 23))); + assert_eq!(ut.unify_var_var(k2, k1), Err((23, 22))); + assert_eq!(ut.probe(k1), Some(22)); + assert_eq!(ut.probe(k2), Some(23)); +} + +/// Test unifying a key whose value is `Some(x)` with a key whose value is `Some(x)`. +/// This should be ok. +#[test] +fn unify_key_Some_x_key_Some_x() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + let k2 = ut.new_key(Some(22)); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert_eq!(ut.probe(k1), Some(22)); + assert_eq!(ut.probe(k2), Some(22)); +} + +/// Test unifying a key whose value is `None` with a value is `x`. +/// Afterwards key should be `x`. +#[test] +fn unify_key_None_val() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + assert!(ut.unify_var_value(k1, 22).is_ok()); + assert_eq!(ut.probe(k1), Some(22)); +} + +/// Test unifying a key whose value is `Some(x)` with the value `y`. +/// This should yield an error. +#[test] +fn unify_key_Some_x_val_y() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + assert_eq!(ut.unify_var_value(k1, 23), Err((22, 23))); + assert_eq!(ut.probe(k1), Some(22)); +} + +/// Test unifying a key whose value is `Some(x)` with the value `x`. +/// This should be ok. +#[test] +fn unify_key_Some_x_val_x() { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(Some(22)); + assert!(ut.unify_var_value(k1, 22).is_ok()); + assert_eq!(ut.probe(k1), Some(22)); +} +