From 56fe4c2681f7ef467d8b3405279acbcfc6b0ebcf Mon Sep 17 00:00:00 2001 From: Marijn Haverbeke Date: Thu, 12 Jan 2012 16:57:30 +0100 Subject: [PATCH] Implement passing cast-to-vtable values as bounded params Closes #1492 --- src/comp/middle/trans.rs | 22 +--- src/comp/middle/trans_common.rs | 4 +- src/comp/middle/trans_impl.rs | 177 ++++++++++++++++++++++++-------- src/comp/middle/typeck.rs | 30 +++--- src/test/run-pass/iface-cast.rs | 6 +- 5 files changed, 161 insertions(+), 78 deletions(-) diff --git a/src/comp/middle/trans.rs b/src/comp/middle/trans.rs index 7fb21547589a..1ad1b4f50eff 100644 --- a/src/comp/middle/trans.rs +++ b/src/comp/middle/trans.rs @@ -5443,24 +5443,10 @@ fn trans_constant(ccx: @crate_ctxt, it: @ast::item, &&pt: [str], } ast::item_impl(tps, some(@{node: ast::ty_path(_, id), _}), _, ms) { let i_did = ast_util::def_id_of_def(ccx.tcx.def_map.get(id)); - let ty = ty::lookup_item_type(ccx.tcx, i_did).ty; - let new_pt = pt + [it.ident + int::str(it.id), "wrap"]; - let extra_tps = vec::map(tps, {|p| param_bounds(ccx, p)}); - let tbl = C_struct(vec::map(*ty::iface_methods(ccx.tcx, i_did), {|im| - alt vec::find(ms, {|m| m.ident == im.ident}) { - some(m) { - trans_impl::trans_wrapper(ccx, new_pt, extra_tps, m) - } - } - })); - let s = mangle_exported_name(ccx, new_pt + ["!vtable"], ty); - let vt_gvar = str::as_buf(s, {|buf| - llvm::LLVMAddGlobal(ccx.llmod, val_ty(tbl), buf) - }); - llvm::LLVMSetInitializer(vt_gvar, tbl); - llvm::LLVMSetGlobalConstant(vt_gvar, True); - ccx.item_ids.insert(it.id, vt_gvar); - ccx.item_symbols.insert(it.id, s); + trans_impl::trans_impl_vtable(ccx, pt, i_did, ms, tps, it); + } + ast::item_iface(_, _) { + trans_impl::trans_iface_vtable(ccx, pt, it); } _ { } } diff --git a/src/comp/middle/trans_common.rs b/src/comp/middle/trans_common.rs index 72803fd309a4..6b28f1483e7d 100644 --- a/src/comp/middle/trans_common.rs +++ b/src/comp/middle/trans_common.rs @@ -916,9 +916,9 @@ tag dict_param { dict_param_dict(dict_id); dict_param_ty(ty::t); } -type dict_id = @{impl_def: ast::def_id, params: [dict_param]}; +type dict_id = @{def: ast::def_id, params: [dict_param]}; fn hash_dict_id(&&dp: dict_id) -> uint { - let h = syntax::ast_util::hash_def_id(dp.impl_def); + let h = syntax::ast_util::hash_def_id(dp.def); for param in dp.params { h = h << 2u; alt param { diff --git a/src/comp/middle/trans_impl.rs b/src/comp/middle/trans_impl.rs index bb17553b2f03..0139148d728b 100644 --- a/src/comp/middle/trans_impl.rs +++ b/src/comp/middle/trans_impl.rs @@ -68,18 +68,23 @@ fn trans_static_callee(bcx: @block_ctxt, e: @ast::expr, base: @ast::expr, {env: obj_env(val) with lval_static_fn(bcx, did, e.id)} } +fn wrapper_fn_ty(ccx: @crate_ctxt, dict_ty: TypeRef, m: ty::method) + -> {ty: ty::t, llty: TypeRef} { + let fty = ty::mk_fn(ccx.tcx, m.fty); + let bare_fn_ty = type_of_fn_from_ty(ccx, ast_util::dummy_sp(), + fty, *m.tps); + let {inputs, output} = llfn_arg_tys(bare_fn_ty); + {ty: fty, llty: T_fn([dict_ty] + inputs, output)} +} + fn trans_vtable_callee(bcx: @block_ctxt, self: ValueRef, dict: ValueRef, fld_expr: @ast::expr, iface_id: ast::def_id, n_method: uint) -> lval_maybe_callee { let bcx = bcx, ccx = bcx_ccx(bcx), tcx = ccx.tcx; let method = ty::iface_methods(tcx, iface_id)[n_method]; - let fty = ty::mk_fn(tcx, method.fty); - let bare_fn_ty = type_of_fn_from_ty(ccx, ast_util::dummy_sp(), - fty, *method.tps); - let {inputs: bare_inputs, output} = llfn_arg_tys(bare_fn_ty); - let fn_ty = T_fn([val_ty(dict)] + bare_inputs, output); + let {ty: fty, llty: llfty} = wrapper_fn_ty(ccx, val_ty(dict), method); let vtable = PointerCast(bcx, Load(bcx, GEPi(bcx, dict, [0, 0])), - T_ptr(T_array(T_ptr(fn_ty), n_method + 1u))); + T_ptr(T_array(T_ptr(llfty), n_method + 1u))); let mptr = Load(bcx, GEPi(bcx, vtable, [0, n_method as int])); let generic = none; if vec::len(*method.tps) > 0u || ty::type_contains_params(tcx, fty) { @@ -138,9 +143,36 @@ fn llfn_arg_tys(ft: TypeRef) -> {inputs: [TypeRef], output: TypeRef} { {inputs: args, output: out_ty} } -fn trans_wrapper(ccx: @crate_ctxt, pt: [ast::ident], - extra_tps: [ty::param_bounds], m: @ast::method) -> ValueRef { - let real_fn = ccx.item_ids.get(m.id); +fn trans_vtable(ccx: @crate_ctxt, id: ast::node_id, name: str, + ptrs: [ValueRef]) { + let tbl = C_struct(ptrs); + let vt_gvar = str::as_buf(name, {|buf| + llvm::LLVMAddGlobal(ccx.llmod, val_ty(tbl), buf) + }); + llvm::LLVMSetInitializer(vt_gvar, tbl); + llvm::LLVMSetGlobalConstant(vt_gvar, lib::llvm::True); + ccx.item_ids.insert(id, vt_gvar); + ccx.item_symbols.insert(id, name); +} + +fn trans_wrapper(ccx: @crate_ctxt, pt: [ast::ident], llfty: TypeRef, + fill: block(ValueRef, @block_ctxt) -> @block_ctxt) + -> ValueRef { + let lcx = @{path: pt, module_path: [], + obj_typarams: [], obj_fields: [], ccx: ccx}; + let name = link::mangle_internal_name_by_path(ccx, pt); + let llfn = decl_internal_cdecl_fn(ccx.llmod, name, llfty); + let fcx = new_fn_ctxt(lcx, ast_util::dummy_sp(), llfn); + let bcx = new_top_block_ctxt(fcx), lltop = bcx.llbb; + let bcx = fill(llfn, bcx); + build_return(bcx); + finish_fn(fcx, lltop); + ret llfn; +} + +fn trans_impl_wrapper(ccx: @crate_ctxt, pt: [ast::ident], + extra_tps: [ty::param_bounds], real_fn: ValueRef) + -> ValueRef { let {inputs: real_args, output: real_ret} = llfn_arg_tys(llvm::LLVMGetElementType(val_ty(real_fn))); let extra_ptrs = []; @@ -159,32 +191,80 @@ fn trans_wrapper(ccx: @crate_ctxt, pt: [ast::ident], let wrap_args = [T_ptr(T_dict())] + vec::slice(real_args, 0u, 2u) + vec::slice(real_args, 2u + vec::len(extra_ptrs), vec::len(real_args)); let llfn_ty = T_fn(wrap_args, real_ret); + trans_wrapper(ccx, pt, llfn_ty, {|llfn, bcx| + let dict = PointerCast(bcx, LLVMGetParam(llfn, 0u), env_ty); + // retptr, self + let args = [LLVMGetParam(llfn, 1u), LLVMGetParam(llfn, 2u)], i = 0u; + // saved tydescs/dicts + while i < n_extra_ptrs { + i += 1u; + args += [load_inbounds(bcx, dict, [0, i as int])]; + } + // the rest of the parameters + let i = 3u, params_total = llvm::LLVMCountParamTypes(llfn_ty); + while i < params_total { + args += [LLVMGetParam(llfn, i)]; + i += 1u; + } + Call(bcx, real_fn, args); + bcx + }) +} - let lcx = @{path: pt + ["wrapper", m.ident], module_path: [], - obj_typarams: [], obj_fields: [], ccx: ccx}; - let name = link::mangle_internal_name_by_path_and_seq(ccx, pt, m.ident); - let llfn = decl_internal_cdecl_fn(ccx.llmod, name, llfn_ty); - let fcx = new_fn_ctxt(lcx, ast_util::dummy_sp(), llfn); - let bcx = new_top_block_ctxt(fcx), lltop = bcx.llbb; +fn trans_impl_vtable(ccx: @crate_ctxt, pt: [ast::ident], + iface_id: ast::def_id, ms: [@ast::method], + tps: [ast::ty_param], it: @ast::item) { + let new_pt = pt + [it.ident + int::str(it.id), "wrap"]; + let extra_tps = vec::map(tps, {|p| param_bounds(ccx, p)}); + let ptrs = vec::map(*ty::iface_methods(ccx.tcx, iface_id), {|im| + alt vec::find(ms, {|m| m.ident == im.ident}) { + some(m) { + let target = ccx.item_ids.get(m.id); + trans_impl_wrapper(ccx, new_pt + [m.ident], extra_tps, target) + } + } + }); + let s = link::mangle_internal_name_by_path(ccx, new_pt + ["!vtable"]); + trans_vtable(ccx, it.id, s, ptrs); +} - let dict = PointerCast(bcx, LLVMGetParam(llfn, 0u), env_ty); - // retptr, self - let args = [LLVMGetParam(llfn, 1u), LLVMGetParam(llfn, 2u)], i = 0u; - // saved tydescs/dicts - while i < n_extra_ptrs { +fn trans_iface_wrapper(ccx: @crate_ctxt, pt: [ast::ident], m: ty::method, + n: uint) -> ValueRef { + let {llty: llfty, _} = wrapper_fn_ty(ccx, T_ptr(T_i8()), m); + trans_wrapper(ccx, pt, llfty, {|llfn, bcx| + let self = Load(bcx, PointerCast(bcx, LLVMGetParam(llfn, 2u), + T_ptr(T_opaque_iface_ptr(ccx)))); + let boxed = GEPi(bcx, self, [0, abi::box_rc_field_body]); + let dict = Load(bcx, PointerCast(bcx, GEPi(bcx, boxed, [0, 1]), + T_ptr(T_ptr(T_dict())))); + let vtable = PointerCast(bcx, Load(bcx, GEPi(bcx, dict, [0, 0])), + T_ptr(T_array(T_ptr(llfty), n + 1u))); + let mptr = Load(bcx, GEPi(bcx, vtable, [0, n as int])); + // FIXME[impl] This doesn't account for more-than-ptr-sized alignment + let inner_self = GEPi(bcx, boxed, [0, 2]); + let args = [PointerCast(bcx, dict, T_ptr(T_i8())), + LLVMGetParam(llfn, 1u), + PointerCast(bcx, inner_self, T_opaque_cbox_ptr(ccx))]; + let i = 3u, total = llvm::LLVMCountParamTypes(llfty); + while i < total { + args += [LLVMGetParam(llfn, i)]; + i += 1u; + } + Call(bcx, mptr, args); + bcx + }) +} + +fn trans_iface_vtable(ccx: @crate_ctxt, pt: [ast::ident], it: @ast::item) { + let new_pt = pt + [it.ident + int::str(it.id)]; + let i_did = ast_util::local_def(it.id), i = 0u; + let ptrs = vec::map(*ty::iface_methods(ccx.tcx, i_did), {|m| + let w = trans_iface_wrapper(ccx, new_pt + [m.ident], m, i); i += 1u; - args += [load_inbounds(bcx, dict, [0, i as int])]; - } - // the rest of the parameters - let i = 3u, params_total = llvm::LLVMCountParamTypes(llfn_ty); - while i < params_total { - args += [LLVMGetParam(llfn, i)]; - i += 1u; - } - Call(bcx, ccx.item_ids.get(m.id), args); - build_return(bcx); - finish_fn(fcx, lltop); - ret llfn; + w + }); + let s = link::mangle_internal_name_by_path(ccx, new_pt + ["!vtable"]); + trans_vtable(ccx, it.id, s, ptrs); } fn dict_is_static(tcx: ty::ctxt, origin: typeck::dict_origin) -> bool { @@ -193,7 +273,8 @@ fn dict_is_static(tcx: ty::ctxt, origin: typeck::dict_origin) -> bool { vec::all(ts, {|t| !ty::type_contains_params(tcx, t)}) && vec::all(*origs, {|o| dict_is_static(tcx, o)}) } - typeck::dict_param(_, _) { false } + typeck::dict_iface(_) { true } + _ { false } } } @@ -219,6 +300,9 @@ fn get_dict(bcx: @block_ctxt, origin: typeck::dict_origin) -> result { typeck::dict_param(n_param, n_bound) { rslt(bcx, option::get(bcx.fcx.lltyparams[n_param].dicts)[n_bound]) } + typeck::dict_iface(did) { + ret rslt(bcx, get_static_dict(bcx, origin)); + } } } @@ -226,7 +310,7 @@ fn dict_id(tcx: ty::ctxt, origin: typeck::dict_origin) -> dict_id { alt origin { typeck::dict_static(did, ts, origs) { let d_params = [], orig = 0u; - if vec::len(ts) == 0u { ret @{impl_def: did, params: d_params}; } + if vec::len(ts) == 0u { ret @{def: did, params: d_params}; } let impl_params = ty::lookup_item_type(tcx, did).bounds; vec::iter2(ts, *impl_params) {|t, bounds| d_params += [dict_param_ty(t)]; @@ -239,7 +323,10 @@ fn dict_id(tcx: ty::ctxt, origin: typeck::dict_origin) -> dict_id { } } } - @{impl_def: did, params: d_params} + @{def: did, params: d_params} + } + typeck::dict_iface(did) { + @{def: did, params: []} } } } @@ -269,16 +356,19 @@ fn get_static_dict(bcx: @block_ctxt, origin: typeck::dict_origin) fn get_dict_ptrs(bcx: @block_ctxt, origin: typeck::dict_origin) -> {bcx: @block_ctxt, ptrs: [ValueRef]} { let ccx = bcx_ccx(bcx); + fn get_vtable(ccx: @crate_ctxt, did: ast::def_id) -> ValueRef { + if did.crate == ast::local_crate { + ccx.item_ids.get(did.node) + } else { + let name = csearch::get_symbol(ccx.sess.get_cstore(), did); + get_extern_const(ccx.externs, ccx.llmod, name, T_ptr(T_i8())) + } + } alt origin { typeck::dict_static(impl_did, tys, sub_origins) { - let vtable = if impl_did.crate == ast::local_crate { - ccx.item_ids.get(impl_did.node) - } else { - let name = csearch::get_symbol(ccx.sess.get_cstore(), impl_did); - get_extern_const(ccx.externs, ccx.llmod, name, T_ptr(T_i8())) - }; let impl_params = ty::lookup_item_type(ccx.tcx, impl_did).bounds; - let ptrs = [vtable], origin = 0u, ti = none, bcx = bcx; + let ptrs = [get_vtable(ccx, impl_did)]; + let origin = 0u, ti = none, bcx = bcx; vec::iter2(*impl_params, tys) {|param, ty| let rslt = get_tydesc(bcx, ty, true, tps_normal, ti).result; ptrs += [rslt.val]; @@ -297,6 +387,9 @@ fn get_dict_ptrs(bcx: @block_ctxt, origin: typeck::dict_origin) } {bcx: bcx, ptrs: ptrs} } + typeck::dict_iface(did) { + {bcx: bcx, ptrs: [get_vtable(ccx, did)]} + } } } diff --git a/src/comp/middle/typeck.rs b/src/comp/middle/typeck.rs index bde40bdde554..18b602e33aad 100644 --- a/src/comp/middle/typeck.rs +++ b/src/comp/middle/typeck.rs @@ -20,7 +20,7 @@ import syntax::print::pprust::*; export check_crate; export method_map, method_origin, method_static, method_param, method_iface; -export dict_map, dict_res, dict_origin, dict_static, dict_param; +export dict_map, dict_res, dict_origin, dict_static, dict_param, dict_iface; tag method_origin { method_static(ast::def_id); @@ -36,6 +36,7 @@ tag dict_origin { dict_static(ast::def_id, [ty::t], dict_res); // Param number, bound number dict_param(uint, uint); + dict_iface(ast::def_id); } type dict_map = hashmap; @@ -2197,24 +2198,20 @@ fn check_expr_with_unifier(fcx: @fn_ctxt, expr: @ast::expr, unify: unifier, let t_1 = ast_ty_to_ty_crate(fcx.ccx, t); let t_e = ty::expr_ty(tcx, e); - if ty::type_is_nil(tcx, t_e) { - tcx.sess.span_err(expr.span, - "cast from nil: " + - ty_to_str(tcx, t_e) + " as " + - ty_to_str(tcx, t_1)); - } - - if ty::type_is_nil(tcx, t_1) { - tcx.sess.span_err(expr.span, - "cast to nil: " + - ty_to_str(tcx, t_e) + " as " + - ty_to_str(tcx, t_1)); - } - alt ty::struct(tcx, t_1) { // This will be looked up later on ty::ty_iface(_, _) {} _ { + if ty::type_is_nil(tcx, t_e) { + tcx.sess.span_err(expr.span, "cast from nil: " + + ty_to_str(tcx, t_e) + " as " + + ty_to_str(tcx, t_1)); + } else if ty::type_is_nil(tcx, t_1) { + tcx.sess.span_err(expr.span, "cast to nil: " + + ty_to_str(tcx, t_e) + " as " + + ty_to_str(tcx, t_1)); + } + let t_1_is_scalar = type_is_scalar(fcx, expr.span, t_1); if type_is_c_like_enum(fcx,expr.span,t_e) && t_1_is_scalar { /* this case is allowed */ @@ -2942,6 +2939,9 @@ mod dict { } } } + ty::ty_iface(did, _) { + ret dict_iface(did); + } _ { let found = none; std::list::iter(isc) {|impls| diff --git a/src/test/run-pass/iface-cast.rs b/src/test/run-pass/iface-cast.rs index 05da479f8647..f252eef07bc5 100644 --- a/src/test/run-pass/iface-cast.rs +++ b/src/test/run-pass/iface-cast.rs @@ -33,6 +33,8 @@ impl of to_str for Tree { } } +fn foo(x: T) -> str { x.to_str() } + fn main() { let t1 = Tree(@{mutable left: none, mutable right: none, @@ -40,6 +42,8 @@ fn main() { let t2 = Tree(@{mutable left: some(t1), mutable right: some(t1), val: 2 as to_str }); - assert t2.to_str() == "[2, some([1, none, none]), some([1, none, none])]"; + let expected = "[2, some([1, none, none]), some([1, none, none])]"; + assert t2.to_str() == expected; + assert foo(t2 as to_str) == expected; t1.left = some(t2); // create cycle }