//! Hindley-Milner type inference (Algorithm W style) with let-polymorphism. //! //! Lambdas and application are fully inferred; list builtins carry polymorphic //! schemes. Nominal structs, their construction and `//` merge keep their concrete //! checks. `Type::Dyn` is a gradual top that unifies with anything - used for //! records-meeting-vars and effect builtins (`pkg`/`dotfile`/...), which take //! attrsets dynamically. Heterogeneous list literals (used as tuples, e.g. //! permission pairs) degrade to `[?]` rather than erroring. use std::collections::{BTreeMap, HashMap, HashSet}; use std::rc::Rc; use super::ast::*; use super::engine::{BuiltinScheme, Engine}; #[derive(Clone)] struct Scheme { vars: Vec, /// type-class constraints `(class, var)` (var is one of `vars`) constraints: Vec<(String, u32)>, ty: Type, } fn mono(ty: Type) -> Scheme { Scheme { vars: Vec::new(), constraints: Vec::new(), ty, } } fn fun(a: Type, b: Type) -> Type { Type::Fun(Box::new(a), Box::new(b)) } fn list(a: Type) -> Type { Type::List(Box::new(a)) } pub struct Checker { structs: BTreeMap>, enums: BTreeMap>, /// (type name, method name) inherent methods that exist method_names: HashSet<(String, String)>, /// class method name -> class name (for `x.method` class-method sugar) class_methods: HashMap, /// (class, type head) instances that exist (coherence + resolution) instances: HashSet<(String, String)>, /// unresolved type-class constraints (class, type) pending: Vec<(String, Type)>, subst: Vec>, env: Vec<(String, Scheme)>, pub errors: Vec, } impl Checker { /// Build a checker for `program` against `engine`'s registered surface /// (builtins, values, structs/enums, classes, instances). 
pub fn with_engine(program: &Program, engine: &Engine) -> Self { let mut structs: BTreeMap> = program .structs .iter() .map(|d| (d.name.clone(), d.clone())) .collect(); for d in &engine.structs { structs.entry(d.name.clone()).or_insert_with(|| d.clone()); } let mut enums: BTreeMap> = program .enums .iter() .map(|d| (d.name.clone(), d.clone())) .collect(); for d in &engine.enums { enums.entry(d.name.clone()).or_insert_with(|| d.clone()); } let mut c = Checker { structs, enums, method_names: HashSet::new(), class_methods: HashMap::new(), instances: HashSet::new(), pending: Vec::new(), subst: Vec::new(), env: Vec::new(), errors: Vec::new(), }; c.install_engine(engine); // built-in classes (engine) + user classes let mut classes = engine.classes.clone(); classes.extend(program.classes.iter().cloned()); c.install_classes(&classes, &program.impls); for (cl, head) in &engine.instances { c.instances.insert((cl.clone(), head.clone())); } c.check_methods(); c } /// Register method names and type-check each method body with `self` bound to /// its nominal type, so errors inside methods are caught. 
fn check_methods(&mut self) {
    // (nominal type, params, body) collected owned to avoid borrow conflicts
    let mut jobs: Vec<(Type, Vec, Rc)> = Vec::new();
    for d in self.structs.values() {
        for m in &d.methods {
            // every name is recorded before any body is checked, so a method
            // body may select sibling methods (`self.other`)
            self.method_names.insert((d.name.clone(), m.name.clone()));
            jobs.push((
                Type::Struct(d.name.clone()),
                m.params.clone(),
                m.body.clone(),
            ));
        }
    }
    for d in self.enums.values() {
        for m in &d.methods {
            self.method_names.insert((d.name.clone(), m.name.clone()));
            jobs.push((Type::Enum(d.name.clone()), m.params.clone(), m.body.clone()));
        }
    }
    for (nominal, params, body) in jobs {
        let mark = self.env.len();
        // the first parameter is the receiver: bind it to the nominal type
        if let Some(self_param) = params.first() {
            self.env.push((self_param.clone(), mono(nominal)));
        }
        // remaining parameters are unannotated: one fresh variable each
        for p in params.iter().skip(1) {
            let v = self.fresh();
            self.env.push((p.clone(), mono(v)));
        }
        // only the diagnostics matter; the method's type is not recorded
        let _ = self.infer(&body);
        self.env.truncate(mark);
    }
}

/// Bind every class method as a constrained polymorphic function. Then
/// register each `impl` as an instance and check its method bodies against
/// the class signatures, with the class parameter replaced by the instance
/// type.
///
/// Reports duplicate instances, impls of unknown classes, and impls missing a
/// class method.
fn install_classes(&mut self, classes: &[Rc], impls: &[Rc]) {
    // register each class method as a constrained polymorphic function
    for c in classes {
        // one quantified variable per class, shared by all of its methods
        let pid = match self.fresh() {
            Type::Var(i) => i,
            _ => unreachable!(),
        };
        for (mname, sig) in &c.methods {
            let ty = subst_param(sig, &c.param, &Type::Var(pid));
            let scheme = Scheme {
                vars: vec![pid],
                constraints: vec![(c.name.clone(), pid)],
                ty,
            };
            self.env.push((mname.clone(), scheme));
            // NOTE(review): if two classes declare the same method name, the
            // later class wins here and no error is reported.
            self.class_methods.insert(mname.clone(), c.name.clone());
        }
    }
    // register instances (coherence) and type-check their method bodies
    for im in impls {
        // NOTE(review): the instance is recorded even when its class turns
        // out to be unknown (checked just below).
        if !self
            .instances
            .insert((im.class.clone(), im.type_name.clone()))
        {
            self.errors.push(format!(
                "duplicate instance `{} for {}`",
                im.class, im.type_name
            ));
        }
        let class = match classes.iter().find(|c| c.name == im.class) {
            Some(c) => c.clone(),
            None => {
                self.errors.push(format!("unknown class `{}`", im.class));
                continue;
            }
        };
        let inst_ty = self.nominal_of(&im.type_name);
        // NOTE(review): only methods declared by the class are visited, so
        // extra methods defined in the impl are silently ignored.
        for (mname, sig) in &class.methods {
            match im.methods.iter().find(|(n, _)| n == mname) {
                Some((_, body)) => {
                    let expected = subst_param(sig, &class.param, &inst_ty);
                    // peel lambda params, binding each to its expected arg type,
                    // so the body sees `self`-like params at the instance type
                    let mark = self.env.len();
                    let mut e = body.as_ref();
                    let mut ty = expected.clone();
                    while let (Expr::Lam(p, inner), Type::Fun(arg, ret)) = (e, ty.clone()) {
                        self.env.push((p.clone(), mono(*arg)));
                        e = inner;
                        ty = *ret;
                    }
                    // whatever remains (the innermost body, or a point-free
                    // definition) must match the remaining expected type
                    let got = self.infer(e);
                    if self.unify(&got, &ty).is_err() {
                        self.errors.push(format!(
                            "impl `{} for {}`: `{mname}` : expected {}, got {}",
                            im.class,
                            im.type_name,
                            self.resolve(&ty).show(),
                            self.resolve(&got).show()
                        ));
                    }
                    self.env.truncate(mark);
                }
                None => self.errors.push(format!(
                    "impl `{} for {}` is missing method `{mname}`",
                    im.class, im.type_name
                )),
            }
        }
    }
}

/// `x.method` where `method` is a class method desugars to `method x`.
///
/// Returns `None` when `field` is not a class method. Otherwise it returns the
/// result type of applying the method to `recv`; the class constraint is left
/// pending.
fn class_method_select(&mut self, recv: Type, field: &str) -> Option {
    if !self.class_methods.contains_key(field) {
        return None;
    }
    // NOTE(review): `lookup` returns the innermost binding named `field`, so a
    // local variable with the same name shadows the class method here.
    let scheme = self.lookup(field)?;
    let mty = self.instantiate(&scheme); // Fun(arg, ret); pushes the constraint
    let ret = self.fresh();
    // NOTE(review): the unification result is discarded, so a receiver that
    // does not fit the method's first argument produces no error.
    let _ = self.unify(&mty, &fun(recv, ret.clone()));
    Some(ret)
}

/// Map an instance's type name to a type: the built-in scalars, a known enum,
/// or otherwise a struct. Unknown names are not rejected here.
fn nominal_of(&self, name: &str) -> Type {
    match name {
        "Int" => Type::Int,
        "Str" => Type::Str,
        "Bool" => Type::Bool,
        _ if self.enums.contains_key(name) => Type::Enum(name.to_string()),
        _ => Type::Struct(name.to_string()),
    }
}

/// Discharge collected constraints: a concrete type must have an instance.
/// Every pending constraint is consumed. Constraints whose type has no nominal
/// head are dropped without error: a still-unsolved type variable, `Dyn`, or a
/// function/record/task type.
fn resolve_pending(&mut self) {
    let pending = std::mem::take(&mut self.pending);
    for (class, ty) in pending {
        if let Some(head) = type_head(&self.resolve(&ty))
            && !self.instances.contains(&(class.clone(), head.clone()))
        {
            self.errors
                .push(format!("no instance `{class} for {head}`"));
        }
    }
}

/// Install the engine's builtins and global values into the environment,
/// each lowered to a scheme over freshly allocated type variables.
fn install_engine(&mut self, engine: &Engine) {
    // Builtins are pushed before global values, so on lookup (which searches
    // from the end) a value shadows a builtin with the same name.
    for builtin in &engine.builtins {
        let lowered = self.lower_scheme(&builtin.scheme);
        self.env.push((builtin.name.clone(), lowered));
    }
    for value in &engine.values {
        let lowered = self.lower_scheme(&value.scheme);
        self.env.push((value.name.clone(), lowered));
    }
}

/// Turn a [`BuiltinScheme`] (bound vars written as `Var(0..quantified)`) into
/// a real [`Scheme`]. Each local index is renamed to a newly allocated
/// inference variable, so quantified vars can never collide with inference.
///
/// Panics if the builtin scheme mentions a local index `>= quantified`.
fn lower_scheme(&mut self, bs: &BuiltinScheme) -> Scheme {
    let mut ids = Vec::new();
    for _ in 0..bs.quantified {
        let Type::Var(id) = self.fresh() else {
            unreachable!()
        };
        ids.push(id);
    }
    let constraints = bs
        .constraints
        .iter()
        .map(|(class, local)| (class.clone(), ids[*local as usize]))
        .collect();
    let ty = lower_type(&bs.ty, &ids);
    Scheme {
        vars: ids,
        constraints,
        ty,
    }
}

// ---- type-variable plumbing -------------------------------------------

/// Allocate a new, unbound type variable. Its id is the next index into the
/// substitution table.
fn fresh(&mut self) -> Type {
    self.subst.push(None);
    Type::Var((self.subst.len() - 1) as u32)
}

/// Follow variable bindings at the top level only, returning either an
/// unbound variable or a non-variable type (whose children may still be bound
/// variables).
fn prune(&self, t: &Type) -> Type {
    let mut current = t.clone();
    while let Type::Var(id) = current {
        match self.subst.get(id as usize) {
            Some(Some(bound)) => current = bound.clone(),
            _ => break,
        }
    }
    current
}

/// Deeply follow substitutions (for generalization and display).
fn resolve(&self, t: &Type) -> Type {
    match self.prune(t) {
        Type::List(x) => list(self.resolve(&x)),
        Type::Task(x) => Type::Task(Box::new(self.resolve(&x))),
        Type::Fun(x, y) => fun(self.resolve(&x), self.resolve(&y)),
        Type::Record(m) => Type::Record(
            m.iter()
                .map(|(k, v)| (k.clone(), self.resolve(v)))
                .collect(),
        ),
        other => other,
    }
}

/// Occurs check: does type variable `id` appear anywhere in `t`, following
/// the current substitution?
fn occurs(&self, id: u32, t: &Type) -> bool {
    match self.prune(t) {
        Type::Var(j) => id == j,
        Type::List(x) | Type::Task(x) => self.occurs(id, &x),
        Type::Fun(x, y) => self.occurs(id, &x) || self.occurs(id, &y),
        Type::Record(m) => m.values().any(|v| self.occurs(id, v)),
        _ => false,
    }
}

/// Bind the unbound variable `id` to `t`. Binding a variable to itself is a
/// no-op; binding it to a type that contains it is an infinite-type error.
fn bind(&mut self, id: u32, t: &Type) -> Result<(), String> {
    if let Type::Var(j) = t
        && *j == id
    {
        return Ok(());
    }
    if self.occurs(id, t) {
        return Err(format!(
            "infinite type: t{id} occurs in {}",
            self.resolve(t).show()
        ));
    }
    self.subst[id as usize] = Some(t.clone());
    Ok(())
}

/// Unify `a` with `b`, extending the substitution. `Dyn` unifies with
/// anything. This is not transactional: bindings made before a failure are
/// kept. On failure the message reads `expected <a>, got <b>`.
fn unify(&mut self, a: &Type, b: &Type) -> Result<(), String> {
    let a = self.prune(a);
    let b = self.prune(b);
    match (&a, &b) {
        (Type::Dyn, _) | (_, Type::Dyn) => Ok(()),
        (Type::Var(i), Type::Var(j)) if i == j => Ok(()),
        (Type::Var(i), _) => self.bind(*i, &b),
        (_, Type::Var(j)) => self.bind(*j, &a),
        (Type::Int, Type::Int) | (Type::Str, Type::Str) | (Type::Bool, Type::Bool) => Ok(()),
        (Type::List(x), Type::List(y)) => self.unify(x, y),
        (Type::Task(x), Type::Task(y)) => self.unify(x, y),
        (Type::Fun(a1, r1), Type::Fun(a2, r2)) => {
            self.unify(a1, a2)?;
            self.unify(r1, r2)
        }
        (Type::Struct(n), Type::Struct(m)) if n == m => Ok(()),
        (Type::Enum(n), Type::Enum(m)) if n == m => Ok(()),
        (Type::Record(m1), Type::Record(m2)) if m1.keys().eq(m2.keys()) => {
            for (k, v1) in m1 {
                self.unify(v1, &m2[k])?;
            }
            Ok(())
        }
        // resolve deeply so already-solved nested variables print as types
        _ => Err(format!(
            "expected {}, got {}",
            self.resolve(&a).show(),
            self.resolve(&b).show()
        )),
    }
}

/// Require `got` to unify with `expected`, recording (not returning) any
/// error as `expected <expected>, got <got>`.
fn want(&mut self, expected: &Type, got: &Type) {
    if let Err(e) = self.unify(expected, got) {
        self.errors.push(e);
    }
}

/// Instantiate a scheme with fresh variables. Each class constraint of the
/// scheme becomes a pending constraint on the corresponding fresh variable.
fn instantiate(&mut self, s: &Scheme) -> Type {
    let mapping: HashMap<u32, Type> = s.vars.iter().map(|v| (*v, self.fresh())).collect();
    // each instantiation of a constrained scheme adds a pending constraint
    for (class, v) in &s.constraints {
        if let Some(t) = mapping.get(v) {
            self.pending.push((class.clone(), t.clone()));
        }
    }
    fn go(t: &Type, m: &HashMap<u32, Type>) -> Type {
        match t {
            Type::Var(id) => m.get(id).cloned().unwrap_or(Type::Var(*id)),
            Type::List(x) => list(go(x, m)),
            Type::Task(x) => Type::Task(Box::new(go(x, m))),
            Type::Fun(x, y) => fun(go(x, m), go(y, m)),
            Type::Record(r) => {
                Type::Record(r.iter().map(|(k, v)| (k.clone(), go(v, m))).collect())
            }
            other => other.clone(),
        }
    }
    go(&s.ty, &mapping)
}

/// Generalize `t` over the variables that are not free in the environment.
///
/// Class constraints still pending on a quantified variable are copied into
/// the scheme, so every instantiation re-imposes them on its fresh copy. For
/// example, `let double = \x -> x + x` keeps `Add` on its argument type, and
/// `double true` is rejected.
fn generalize(&self, t: &Type) -> Scheme {
    let t = self.resolve(t);
    let mut env_fv: HashSet<u32> = HashSet::new();
    for (_, s) in &self.env {
        let rt = self.resolve(&s.ty);
        let mut fv = Vec::new();
        free_vars(&rt, &mut fv);
        for id in fv {
            if !s.vars.contains(&id) {
                env_fv.insert(id);
            }
        }
    }
    let mut tv = Vec::new();
    free_vars(&t, &mut tv);
    let mut vars = Vec::new();
    for id in tv {
        if !env_fv.contains(&id) && !vars.contains(&id) {
            vars.push(id);
        }
    }
    // Without this, such a constraint would stay on the original quantified
    // variable, which instantiation never binds. `resolve_pending` would then
    // skip it, so misuse at a concrete type went unreported.
    let mut constraints: Vec<(String, u32)> = Vec::new();
    for (class, ty) in &self.pending {
        if let Type::Var(id) = self.prune(ty) {
            if vars.contains(&id) && !constraints.iter().any(|(c, v)| c == class && *v == id) {
                constraints.push((class.clone(), id));
            }
        }
    }
    Scheme {
        vars,
        constraints,
        ty: t,
    }
}

/// The innermost (most recently pushed) binding of `n`, if any.
fn lookup(&self, n: &str) -> Option<Scheme> {
    self.env
        .iter()
        .rev()
        .find(|(k, _)| k == n)
        .map(|(_, s)| s.clone())
}

/// Declared field types of struct `name`, or `None` for an unknown struct.
fn struct_fields(&self, name: &str) -> Option<BTreeMap<String, Type>> {
    self.structs.get(name).map(|d| {
        d.fields
            .iter()
            .map(|f| (f.name.clone(), f.ty.clone()))
            .collect()
    })
}

// ---- inference ---------------------------------------------------------

/// Infer the type of `e`, then discharge all pending type-class constraints
/// collected so far.
pub fn check(&mut self, e: &Expr) -> Type {
    let t = self.infer(e);
    self.resolve_pending();
    t
}

/// Infer the type of `e`, recording problems in `errors`. Ill-typed
/// subexpressions generally yield `Dyn` so checking can continue.
fn infer(&mut self, e: &Expr) -> Type {
    match e {
        Expr::Int(..) => Type::Int,
        Expr::Str(_) => Type::Str,
        Expr::Bool(_) => Type::Bool,
        Expr::Var(n) => match self.lookup(n) {
            Some(s) => self.instantiate(&s),
            None => {
                self.errors.push(format!("unbound variable `{n}`"));
                Type::Dyn
            }
        },
        Expr::Lam(p, body) => {
            let pv = self.fresh();
            self.env.push((p.clone(), mono(pv.clone())));
            let bt = self.infer(body);
            self.env.pop();
            fun(pv, bt)
        }
        Expr::App(f, a) => {
            let ft = self.infer(f);
            let at = self.infer(a);
            let rv = self.fresh();
            let expected = fun(at, rv.clone());
            if let Err(e) = self.unify(&ft, &expected) {
                self.errors.push(format!("application: {e}"));
                return Type::Dyn;
            }
            rv
        }
        // list literal: homogeneous -> [t]; heterogeneous (tuple-like) -> [?]
        Expr::List(es) => {
            let ev = self.fresh();
            let mut homogeneous = true;
            for e in es {
                let t = self.infer(e);
                // Once the literal is known to be heterogeneous, stop unifying,
                // so later elements are not pinned to an unrelated element type.
                if homogeneous && self.unify(&ev, &t).is_err() {
                    homogeneous = false;
                }
            }
            if homogeneous {
                list(ev)
            } else {
                list(Type::Dyn)
            }
        }
        Expr::Record(fields) => {
            let mut m = BTreeMap::new();
            for (k, e) in fields {
                let t = self.infer(e);
                m.insert(k.clone(), t);
            }
            Type::Record(m)
        }
        Expr::Construct(name, fields) => self.check_construct(name, fields),
        Expr::EnumVariant(name, variant) => {
            match self.enums.get(name) {
                Some(d) if d.variants.iter().any(|v| v == variant) => {}
                Some(_) => self
                    .errors
                    .push(format!("enum `{name}` has no variant `{variant}`")),
                None => self.errors.push(format!("unknown enum `{name}`")),
            }
            Type::Enum(name.clone())
        }
        Expr::Select(obj, field) => {
            let ot = self.infer(obj);
            match self.prune(&ot) {
                Type::Record(m) => match m.get(field) {
                    Some(t) => t.clone(),
                    None => {
                        let shown = self.resolve(&ot).show();
                        self.errors
                            .push(format!("no field `{field}` on {shown}"));
                        Type::Dyn
                    }
                },
                // field, then inherent method, then class-method (`x.m` == `m x`)
                Type::Struct(n) => {
                    if let Some(ft) =
                        self.struct_fields(&n).and_then(|m| m.get(field).cloned())
                    {
                        ft
                    } else if self.method_names.contains(&(n.clone(), field.clone())) {
                        Type::Dyn
                    } else if let Some(t) =
                        self.class_method_select(Type::Struct(n.clone()), field)
                    {
                        t
                    } else {
                        self.errors
                            .push(format!("no field or method `{field}` on `{n}`"));
                        Type::Dyn
                    }
                }
                Type::Enum(n) => {
                    if self.method_names.contains(&(n.clone(), field.clone())) {
                        Type::Dyn
                    } else if let Some(t) =
                        self.class_method_select(Type::Enum(n.clone()), field)
                    {
                        t
                    } else {
                        self.errors.push(format!("no method `{field}` on `{n}`"));
                        Type::Dyn
                    }
                }
                _ => Type::Dyn, // var/dyn: cannot resolve statically
            }
        }
        Expr::Merge(l, r) => {
            let left = self.infer(l);
            let right = self.infer(r);
            self.infer_merge(left, right)
        }
        Expr::If(c, t, e) => {
            let ct = self.infer(c);
            // Bool is what is expected; the condition's type is what we got
            self.want(&Type::Bool, &ct);
            let tt = self.infer(t);
            let et = self.infer(e);
            self.want(&tt, &et);
            tt
        }
        Expr::Bin(op, l, r) => self.infer_bin(*op, l, r),
        Expr::Let(binds, body) => {
            let mark = self.env.len();
            // recursive: pre-bind each name to a fresh monomorphic var
            let mut vars = Vec::new();
            for b in binds {
                let v = self.fresh();
                vars.push(v.clone());
                self.env.push((b.name.clone(), mono(v)));
            }
            for (b, v) in binds.iter().zip(&vars) {
                let t = self.check_binding(b);
                self.want(v, &t);
            }
            // generalize for the body (let-polymorphism)
            self.env.truncate(mark);
            for (b, v) in binds.iter().zip(&vars) {
                let s = self.generalize(v);
                self.env.push((b.name.clone(), s));
            }
            let bt = self.infer(body);
            self.env.truncate(mark);
            bt
        }
    }
}

/// Infer a binary operator application.
///
/// Arithmetic and `/` dispatch through operator classes (Add/Sub/.../Div), so
/// `a op b : a` requires an instance for `a` (built-in for Int/Str). `++` is
/// string concatenation when either operand is already known to be `Str`, and
/// list append otherwise.
fn infer_bin(&mut self, op: BinOp, l: &Expr, r: &Expr) -> Type {
    if let Some((class, _)) = op_class(op) {
        let lhs = self.infer(l);
        let rhs = self.infer(r);
        self.want(&lhs, &rhs);
        let t = self.prune(&lhs);
        self.pending.push((class.to_string(), t.clone()));
        return t;
    }
    match op {
        BinOp::Eq => {
            let lhs = self.infer(l);
            let rhs = self.infer(r);
            self.want(&lhs, &rhs);
            Type::Bool
        }
        BinOp::And | BinOp::Or => {
            let lhs = self.infer(l);
            let rhs = self.infer(r);
            self.want(&Type::Bool, &lhs);
            self.want(&Type::Bool, &rhs);
            Type::Bool
        }
        BinOp::Concat => {
            let lhs = self.infer(l);
            let rhs = self.infer(r);
            // Check both sides: `s ++ "x"` with `s` still a type variable is
            // string concat, not a list append that then fails on the literal.
            let is_str = matches!(self.prune(&lhs), Type::Str)
                || matches!(self.prune(&rhs), Type::Str);
            if is_str {
                self.want(&Type::Str, &lhs);
                self.want(&Type::Str, &rhs);
                Type::Str
            } else {
                let ev = self.fresh();
                let list_ty = list(ev);
                self.want(&list_ty, &lhs);
                self.want(&list_ty, &rhs);
                list_ty
            }
        }
        _ => unreachable!("op-class operators handled above"),
    }
}

/// Type of `l // r`, where `r` must be a record of overrides.
///
/// Over a struct, each override must name an existing field with a compatible
/// type, and the result is the struct. Over a record, the result is the record
/// with the overrides applied. Operands whose type is unknown (`Dyn`, or a
/// still-unsolved variable such as a lambda parameter) are accepted without
/// static checking.
fn infer_merge(&mut self, lhs: Type, rhs: Type) -> Type {
    let overrides = match self.prune(&rhs) {
        Type::Record(m) => m,
        // nothing is known about the overrides: keep the left type
        Type::Dyn | Type::Var(_) => return self.prune(&lhs),
        other => {
            let shown = self.resolve(&other).show();
            self.errors
                .push(format!("right of `//` must be a record, got {shown}"));
            return self.prune(&lhs);
        }
    };
    match self.prune(&lhs) {
        Type::Struct(name) => {
            if let Some(schema) = self.struct_fields(&name) {
                for (k, vt) in &overrides {
                    match schema.get(k) {
                        Some(ft) => {
                            if self.unify(ft, vt).is_err() {
                                self.errors.push(format!(
                                    "`{name} // {{ {k} = .. }}` : `{name}.{k}` is {}, got {}",
                                    ft.show(),
                                    self.resolve(vt).show()
                                ));
                            }
                        }
                        None => self
                            .errors
                            .push(format!("`{name}` has no field `{k}` to override")),
                    }
                }
            }
            Type::Struct(name)
        }
        Type::Record(base) => {
            let mut m = base;
            for (k, v) in overrides {
                m.insert(k, v);
            }
            Type::Record(m)
        }
        // unknown left operand: cannot resolve statically (as for `Select`)
        Type::Dyn | Type::Var(_) => Type::Dyn,
        other => {
            self.errors.push(format!(
                "left of `//` must be a record/struct, got {}",
                self.resolve(&other).show()
            ));
            other
        }
    }
}

/// Check a struct construction `name { fields }`. Each given field must exist
/// and match its declared type, and every field without a default must be
/// given. Always returns `Struct(name)`.
fn check_construct(&mut self, name: &str, fields: &[(String, Rc<Expr>)]) -> Type {
    // Infer the field expressions first, so errors inside them are reported
    // even when the struct itself is unknown.
    let mut given: BTreeMap<String, Type> = BTreeMap::new();
    for (k, e) in fields {
        let t = self.infer(e);
        given.insert(k.clone(), t);
    }
    let Some(decl) = self.structs.get(name).cloned() else {
        self.errors.push(format!("unknown struct `{name}`"));
        return Type::Struct(name.into());
    };
    for f in &decl.fields {
        match given.get(&f.name) {
            Some(gt) => {
                if self.unify(gt, &f.ty).is_err() {
                    self.errors.push(format!(
                        "`{name}.{}` : expected {}, got {}",
                        f.name,
                        f.ty.show(),
                        self.resolve(gt).show()
                    ));
                }
            }
            None if f.default.is_some() => {}
            None => self
                .errors
                .push(format!("`{name}` missing required field `{}`", f.name)),
        }
    }
    for k in given.keys() {
        if !decl.fields.iter().any(|f| &f.name == k) {
            self.errors.push(format!("`{name}` has no field `{k}`"));
        }
    }
    Type::Struct(name.into())
}

/// Type of a `let` binding's value.
///
/// A record literal annotated with a struct type is checked as that struct's
/// construction. Any other annotation must match the inferred type, and the
/// annotation is what the binding gets.
fn check_binding(&mut self, b: &Binding) -> Type {
    match (&b.ann, &*b.value) {
        (Some(Type::Struct(name)), Expr::Record(fields)) => self.check_construct(name, fields),
        (Some(ann), _) => {
            let got = self.infer(&b.value);
            if self.unify(&got, ann).is_err() {
                self.errors.push(format!(
                    "`{}` : annotated {}, got {}",
                    b.name,
                    ann.show(),
                    self.resolve(&got).show()
                ));
            }
            ann.clone()
        }
        (None, _) => self.infer(&b.value),
    }
}
}
/// The operator class + method an arithmetic/`/` operator desugars to.
pub fn op_class(op: BinOp) -> Option<(&'static str, &'static str)> {
    // (class name, method name) for each overloadable operator; every other
    // operator is checked directly by the inference engine
    let pair = match op {
        BinOp::Add => ("Add", "add"),
        BinOp::Sub => ("Sub", "sub"),
        BinOp::Mul => ("Mul", "mul"),
        BinOp::Slash => ("Div", "div"),
        BinOp::Mod => ("Mod", "mod"),
        BinOp::Pow => ("Pow", "pow"),
        _ => return None,
    };
    Some(pair)
}

/// Name used to look up instances for `t`: the scalar name, `List` for any
/// list, or the nominal struct/enum name. Every other type has no head.
fn type_head(t: &Type) -> Option<String> {
    let head: &str = match t {
        Type::Int => "Int",
        Type::Str => "Str",
        Type::Bool => "Bool",
        Type::List(_) => "List",
        Type::Struct(name) | Type::Enum(name) => name.as_str(),
        _ => return None,
    };
    Some(head.to_string())
}

/// Substitute `repl` for every occurrence of the class parameter `param`
/// (which the parser represents as `Struct(param)`) inside a method signature.
fn subst_param(t: &Type, param: &str, repl: &Type) -> Type {
    match t {
        Type::Struct(n) if n == param => repl.clone(),
        Type::List(elem) => {
            let elem = subst_param(elem, param, repl);
            Type::List(Box::new(elem))
        }
        Type::Task(res) => {
            let res = subst_param(res, param, repl);
            Type::Task(Box::new(res))
        }
        Type::Fun(arg, ret) => {
            let arg = subst_param(arg, param, repl);
            let ret = subst_param(ret, param, repl);
            Type::Fun(Box::new(arg), Box::new(ret))
        }
        Type::Record(fields) => Type::Record(
            fields
                .iter()
                .map(|(name, field)| (name.clone(), subst_param(field, param, repl)))
                .collect(),
        ),
        unchanged => unchanged.clone(),
    }
}

/// Rewrite a [`BuiltinScheme`]'s local bound vars `Var(0..)` to allocated `fresh` ids.
fn lower_type(t: &Type, fresh: &[u32]) -> Type { match t { Type::Var(id) => Type::Var(fresh[*id as usize]), Type::List(x) => Type::List(Box::new(lower_type(x, fresh))), Type::Task(x) => Type::Task(Box::new(lower_type(x, fresh))), Type::Fun(x, y) => Type::Fun( Box::new(lower_type(x, fresh)), Box::new(lower_type(y, fresh)), ), Type::Record(m) => Type::Record( m.iter() .map(|(k, v)| (k.clone(), lower_type(v, fresh))) .collect(), ), other => other.clone(), } } fn free_vars(t: &Type, out: &mut Vec) { match t { Type::Var(id) => { if !out.contains(id) { out.push(*id); } } Type::List(x) | Type::Task(x) => free_vars(x, out), Type::Fun(x, y) => { free_vars(x, out); free_vars(y, out); } Type::Record(m) => { for v in m.values() { free_vars(v, out); } } _ => {} } }