doot/crates/doot-lang/src/lang/check.rs

822 lines
30 KiB
Rust

//! Hindley-Milner type inference (Algorithm W style) with let-polymorphism.
//!
//! Lambdas and application are fully inferred; list builtins carry polymorphic
//! schemes. Nominal structs, their construction and `//` merge keep their concrete
//! checks. `Type::Dyn` is a gradual top that unifies with anything - used for
//! records-meeting-vars and effect builtins (`pkg`/`dotfile`/...), which take
//! attrsets dynamically. Heterogeneous list literals (used as tuples, e.g.
//! permission pairs) degrade to `[?]` rather than erroring.
use std::collections::{BTreeMap, HashMap, HashSet};
use std::rc::Rc;
use super::ast::*;
use super::engine::{BuiltinScheme, Engine};
/// A type scheme: the type `ty` together with the type variables that are
/// quantified over it and the class constraints attached to those variables.
///
/// `Checker::instantiate` replaces every id in `vars` with a fresh variable
/// and re-registers each constraint against the fresh copy.
#[derive(Clone)]
struct Scheme {
/// ids of the quantified type variables (empty for a monomorphic scheme)
vars: Vec<u32>,
/// type-class constraints `(class, var)` (var is one of `vars`)
constraints: Vec<(String, u32)>,
/// the body type in which `vars` may occur
ty: Type,
}
/// Wrap `ty` as a monomorphic scheme: nothing quantified, no constraints.
fn mono(ty: Type) -> Scheme {
    let (vars, constraints) = (Vec::new(), Vec::new());
    Scheme {
        vars,
        constraints,
        ty,
    }
}
/// Build the function type `a -> b`.
fn fun(a: Type, b: Type) -> Type {
    let (param, result) = (Box::new(a), Box::new(b));
    Type::Fun(param, result)
}
/// Build the list type `[a]`.
fn list(a: Type) -> Type {
    let elem = Box::new(a);
    Type::List(elem)
}
/// Type-checker state: known declarations, the typing environment, the
/// substitution built up by unification, and the diagnostics collected so far.
pub struct Checker {
/// struct declarations by name
structs: BTreeMap<String, Rc<StructDecl>>,
/// enum declarations by name
enums: BTreeMap<String, Rc<EnumDecl>>,
/// (type name, method name) inherent methods that exist
method_names: HashSet<(String, String)>,
/// class method name -> class name (for `x.method` class-method sugar)
class_methods: HashMap<String, String>,
/// (class, type head) instances that exist (coherence + resolution)
instances: HashSet<(String, String)>,
/// unresolved type-class constraints (class, type)
pending: Vec<(String, Type)>,
/// substitution, indexed by type-variable id; `None` means still unbound
subst: Vec<Option<Type>>,
/// typing environment used as a scope stack; lookups scan from the end, so
/// later entries shadow earlier ones
env: Vec<(String, Scheme)>,
/// diagnostics accumulated during checking
pub errors: Vec<String>,
}
impl Checker {
/// Build a checker for `program` against `engine`'s registered surface
/// (builtins, values, structs/enums, classes, instances).
pub fn with_engine(program: &Program, engine: &Engine) -> Self {
let mut structs: BTreeMap<String, Rc<StructDecl>> = program
.structs
.iter()
.map(|d| (d.name.clone(), d.clone()))
.collect();
for d in &engine.structs {
structs.entry(d.name.clone()).or_insert_with(|| d.clone());
}
let mut enums: BTreeMap<String, Rc<EnumDecl>> = program
.enums
.iter()
.map(|d| (d.name.clone(), d.clone()))
.collect();
for d in &engine.enums {
enums.entry(d.name.clone()).or_insert_with(|| d.clone());
}
let mut c = Checker {
structs,
enums,
method_names: HashSet::new(),
class_methods: HashMap::new(),
instances: HashSet::new(),
pending: Vec::new(),
subst: Vec::new(),
env: Vec::new(),
errors: Vec::new(),
};
c.install_engine(engine);
// built-in classes (engine) + user classes
let mut classes = engine.classes.clone();
classes.extend(program.classes.iter().cloned());
c.install_classes(&classes, &program.impls);
for (cl, head) in &engine.instances {
c.instances.insert((cl.clone(), head.clone()));
}
c.check_methods();
c
}
/// Record every inherent method of every struct and enum in `method_names`,
/// then type-check each method body. The first parameter is bound to the
/// owning nominal type and every further parameter to a fresh type variable,
/// so errors inside method bodies are reported.
fn check_methods(&mut self) {
    // Collect owned (receiver type, params, body) triples up front so the
    // declaration maps are no longer borrowed while inference mutates `self`.
    let mut queue: Vec<(Type, Vec<String>, Rc<Expr>)> = Vec::new();
    for decl in self.structs.values() {
        for method in &decl.methods {
            self.method_names
                .insert((decl.name.clone(), method.name.clone()));
            let receiver_ty = Type::Struct(decl.name.clone());
            queue.push((receiver_ty, method.params.clone(), method.body.clone()));
        }
    }
    for decl in self.enums.values() {
        for method in &decl.methods {
            self.method_names
                .insert((decl.name.clone(), method.name.clone()));
            let receiver_ty = Type::Enum(decl.name.clone());
            queue.push((receiver_ty, method.params.clone(), method.body.clone()));
        }
    }

    for (receiver_ty, params, body) in queue {
        let scope = self.env.len();
        let mut names = params.into_iter();
        if let Some(receiver) = names.next() {
            self.env.push((receiver, mono(receiver_ty)));
        }
        for name in names {
            let tv = self.fresh();
            self.env.push((name, mono(tv)));
        }
        // Only the diagnostics matter here; the inferred type is discarded.
        let _ = self.infer(&body);
        self.env.truncate(scope);
    }
}
/// Register type classes and check instance declarations.
///
/// For every class, each method is added to the environment as a scheme
/// quantified over one variable that stands for the class parameter and
/// carries the class constraint; the method name is also recorded in
/// `class_methods`. For every impl, the `(class, type)` pair is recorded
/// (reporting duplicates), and each method the class declares is checked
/// against the class signature specialised to the instance type. Missing
/// methods and unknown classes are reported through `errors`.
fn install_classes(&mut self, classes: &[Rc<ClassDecl>], impls: &[Rc<ImplDecl>]) {
// register each class method as a constrained polymorphic function
for c in classes {
// one type variable per class, shared by all of its method schemes
let pid = match self.fresh() {
Type::Var(i) => i,
_ => unreachable!(),
};
for (mname, sig) in &c.methods {
let ty = subst_param(sig, &c.param, &Type::Var(pid));
let scheme = Scheme {
vars: vec![pid],
constraints: vec![(c.name.clone(), pid)],
ty,
};
self.env.push((mname.clone(), scheme));
self.class_methods.insert(mname.clone(), c.name.clone());
}
}
// register instances (coherence) and type-check their method bodies
for im in impls {
// `insert` returns false when the pair was already present
if !self
.instances
.insert((im.class.clone(), im.type_name.clone()))
{
self.errors.push(format!(
"duplicate instance `{} for {}`",
im.class, im.type_name
));
}
let class = match classes.iter().find(|c| c.name == im.class) {
Some(c) => c.clone(),
None => {
self.errors.push(format!("unknown class `{}`", im.class));
continue;
}
};
let inst_ty = self.nominal_of(&im.type_name);
// only methods declared by the class are checked; the loop is driven by
// the class's method list, not the impl's
for (mname, sig) in &class.methods {
match im.methods.iter().find(|(n, _)| n == mname) {
Some((_, body)) => {
let expected = subst_param(sig, &class.param, &inst_ty);
// peel lambda params, binding each to its expected arg type,
// so the body sees `self`-like params at the instance type
let mark = self.env.len();
let mut e = body.as_ref();
let mut ty = expected.clone();
while let (Expr::Lam(p, inner), Type::Fun(arg, ret)) = (e, ty.clone()) {
self.env.push((p.clone(), mono(*arg)));
e = inner;
ty = *ret;
}
// whatever is left of the body must have the remaining result type
let got = self.infer(e);
if self.unify(&got, &ty).is_err() {
self.errors.push(format!(
"impl `{} for {}`: `{mname}` : expected {}, got {}",
im.class,
im.type_name,
self.resolve(&ty).show(),
self.resolve(&got).show()
));
}
// drop the parameter bindings pushed while peeling
self.env.truncate(mark);
}
None => self.errors.push(format!(
"impl `{} for {}` is missing method `{mname}`",
im.class, im.type_name
)),
}
}
}
}
/// `x.method` where `method` is a class method desugars to `method x`.
fn class_method_select(&mut self, recv: Type, field: &str) -> Option<Type> {
if !self.class_methods.contains_key(field) {
return None;
}
let scheme = self.lookup(field)?;
let mty = self.instantiate(&scheme); // Fun(arg, ret); pushes the constraint
let ret = self.fresh();
let _ = self.unify(&mty, &fun(recv, ret.clone()));
Some(ret)
}
/// Map a type name to its `Type`: the three primitive names map to the
/// primitives, a known enum name to `Type::Enum`, anything else to
/// `Type::Struct` (without checking that such a struct exists).
fn nominal_of(&self, name: &str) -> Type {
    if name == "Int" {
        return Type::Int;
    }
    if name == "Str" {
        return Type::Str;
    }
    if name == "Bool" {
        return Type::Bool;
    }
    if self.enums.contains_key(name) {
        Type::Enum(name.to_string())
    } else {
        Type::Struct(name.to_string())
    }
}
/// Discharge collected constraints: a concrete type must have an instance.
/// Constraints still on a type variable are left (polymorphic / unused).
fn resolve_pending(&mut self) {
let pending = std::mem::take(&mut self.pending);
for (class, ty) in pending {
if let Some(head) = type_head(&self.resolve(&ty))
&& !self.instances.contains(&(class.clone(), head.clone()))
{
self.errors
.push(format!("no instance `{class} for {head}`"));
}
}
}
/// Add the engine's builtins, then its global values, to the environment.
/// Each declared scheme is lowered onto freshly allocated type variables.
fn install_engine(&mut self, engine: &Engine) {
    for builtin in &engine.builtins {
        let scheme = self.lower_scheme(&builtin.scheme);
        let entry = (builtin.name.clone(), scheme);
        self.env.push(entry);
    }
    for value in &engine.values {
        let scheme = self.lower_scheme(&value.scheme);
        let entry = (value.name.clone(), scheme);
        self.env.push(entry);
    }
}
/// Convert a [`BuiltinScheme`], whose bound variables are written as
/// `Var(0..quantified)`, into a real [`Scheme`]: one fresh inference variable
/// is allocated per bound variable and substituted in, so the scheme's
/// quantified ids cannot clash with variables created during inference.
///
/// Panics (by indexing) if a constraint or the type mentions an index that is
/// not below `quantified`.
fn lower_scheme(&mut self, bs: &BuiltinScheme) -> Scheme {
    let mut vars: Vec<u32> = Vec::new();
    for _ in 0..bs.quantified {
        let Type::Var(id) = self.fresh() else {
            unreachable!()
        };
        vars.push(id);
    }
    let constraints = bs
        .constraints
        .iter()
        .map(|(class, idx)| (class.clone(), vars[*idx as usize]))
        .collect();
    let ty = lower_type(&bs.ty, &vars);
    Scheme {
        vars,
        constraints,
        ty,
    }
}
// ---- type-variable plumbing -------------------------------------------
/// Allocate a new, unbound type variable. Its id is its slot index in
/// `subst`.
fn fresh(&mut self) -> Type {
    self.subst.push(None);
    let id = (self.subst.len() - 1) as u32;
    Type::Var(id)
}
/// Follow the chain of bound type variables starting at `t` and return the
/// first type that is not a bound variable. Only the head is followed;
/// nested types are left as they are.
fn prune(&self, t: &Type) -> Type {
    let mut current = t;
    while let Type::Var(id) = current {
        match self.subst.get(*id as usize) {
            Some(Some(bound)) => current = bound,
            // unbound (or unknown) variable: this is the representative
            _ => break,
        }
    }
    current.clone()
}
/// Deeply follow substitutions (for generalization and display).
fn resolve(&self, t: &Type) -> Type {
match self.prune(t) {
Type::List(x) => list(self.resolve(&x)),
Type::Task(x) => Type::Task(Box::new(self.resolve(&x))),
Type::Fun(x, y) => fun(self.resolve(&x), self.resolve(&y)),
Type::Record(m) => Type::Record(
m.iter()
.map(|(k, v)| (k.clone(), self.resolve(v)))
.collect(),
),
other => other,
}
}
/// Occurs check: does type variable `id` appear anywhere inside `t` under
/// the current substitution?
fn occurs(&self, id: u32, t: &Type) -> bool {
    match self.prune(t) {
        Type::Var(other) => other == id,
        Type::List(inner) | Type::Task(inner) => self.occurs(id, &inner),
        Type::Fun(arg, ret) => [arg, ret].iter().any(|part| self.occurs(id, part)),
        Type::Record(fields) => {
            for field_ty in fields.values() {
                if self.occurs(id, field_ty) {
                    return true;
                }
            }
            false
        }
        _ => false,
    }
}
/// Bind type variable `id` to `t`.
///
/// Binding a variable to itself is a no-op. Binding it to a type that
/// contains it is rejected with an "infinite type" error.
fn bind(&mut self, id: u32, t: &Type) -> Result<(), String> {
    if matches!(t, Type::Var(same) if *same == id) {
        return Ok(());
    }
    if !self.occurs(id, t) {
        self.subst[id as usize] = Some(t.clone());
        return Ok(());
    }
    Err(format!(
        "infinite type: t{id} occurs in {}",
        self.resolve(t).show()
    ))
}
/// Unify `a` with `b`, extending the substitution as needed.
///
/// `Type::Dyn` unifies with anything and binds nothing. Variables are bound
/// to the other side. Constructors must match and their components unify;
/// records must have exactly the same keys. On mismatch the error text is
/// `expected <a>, got <b>` (heads only, as pruned).
fn unify(&mut self, a: &Type, b: &Type) -> Result<(), String> {
    let (lhs, rhs) = (self.prune(a), self.prune(b));
    match (&lhs, &rhs) {
        (Type::Dyn, _) | (_, Type::Dyn) => Ok(()),
        (Type::Var(i), Type::Var(j)) if i == j => Ok(()),
        (Type::Var(i), _) => self.bind(*i, &rhs),
        (_, Type::Var(j)) => self.bind(*j, &lhs),
        (Type::Int, Type::Int) | (Type::Str, Type::Str) | (Type::Bool, Type::Bool) => Ok(()),
        (Type::List(x), Type::List(y)) | (Type::Task(x), Type::Task(y)) => self.unify(x, y),
        (Type::Fun(arg1, ret1), Type::Fun(arg2, ret2)) => self
            .unify(arg1, arg2)
            .and_then(|()| self.unify(ret1, ret2)),
        (Type::Struct(n), Type::Struct(m)) | (Type::Enum(n), Type::Enum(m)) if n == m => Ok(()),
        (Type::Record(left), Type::Record(right)) if left.keys().eq(right.keys()) => left
            .iter()
            .try_for_each(|(key, left_ty)| self.unify(left_ty, &right[key])),
        _ => Err(format!("expected {}, got {}", lhs.show(), rhs.show())),
    }
}
/// Unify `a` with `b`; on failure record the message in `errors` rather than
/// returning it.
fn want(&mut self, a: &Type, b: &Type) {
    match self.unify(a, b) {
        Ok(()) => {}
        Err(msg) => self.errors.push(msg),
    }
}
/// Instantiate scheme `s`: every quantified variable is replaced by a fresh
/// one, and each class constraint on a quantified variable is queued in
/// `pending` against its fresh copy.
fn instantiate(&mut self, s: &Scheme) -> Type {
    /// Rewrite `t`, replacing variables found in `renaming`.
    fn rename(t: &Type, renaming: &HashMap<u32, Type>) -> Type {
        match t {
            Type::Var(id) => match renaming.get(id) {
                Some(replacement) => replacement.clone(),
                None => Type::Var(*id),
            },
            Type::List(elem) => Type::List(Box::new(rename(elem, renaming))),
            Type::Task(inner) => Type::Task(Box::new(rename(inner, renaming))),
            Type::Fun(arg, ret) => Type::Fun(
                Box::new(rename(arg, renaming)),
                Box::new(rename(ret, renaming)),
            ),
            Type::Record(fields) => Type::Record(
                fields
                    .iter()
                    .map(|(name, field_ty)| (name.clone(), rename(field_ty, renaming)))
                    .collect(),
            ),
            leaf => leaf.clone(),
        }
    }

    let mut renaming: HashMap<u32, Type> = HashMap::new();
    for var in &s.vars {
        let fresh_var = self.fresh();
        renaming.insert(*var, fresh_var);
    }
    // each instantiation of a constrained scheme adds a pending constraint
    let obligations = s.constraints.iter().filter_map(|(class, var)| {
        renaming
            .get(var)
            .map(|fresh_var| (class.clone(), fresh_var.clone()))
    });
    self.pending.extend(obligations);
    rename(&s.ty, &renaming)
}
fn generalize(&self, t: &Type) -> Scheme {
let t = self.resolve(t);
let mut env_fv: HashSet<u32> = HashSet::new();
for (_, s) in &self.env {
let rt = self.resolve(&s.ty);
let mut fv = Vec::new();
free_vars(&rt, &mut fv);
for id in fv {
if !s.vars.contains(&id) {
env_fv.insert(id);
}
}
}
let mut tv = Vec::new();
free_vars(&t, &mut tv);
let mut vars = Vec::new();
for id in tv {
if !env_fv.contains(&id) && !vars.contains(&id) {
vars.push(id);
}
}
Scheme {
vars,
constraints: Vec::new(),
ty: t,
}
}
/// Find the innermost binding of `n` in the environment (the entry pushed
/// most recently wins) and return a copy of its scheme.
fn lookup(&self, n: &str) -> Option<Scheme> {
    for (name, scheme) in self.env.iter().rev() {
        if name == n {
            return Some(scheme.clone());
        }
    }
    None
}
/// The declared field types of struct `name`, keyed by field name, or `None`
/// if no such struct is known.
fn struct_fields(&self, name: &str) -> Option<BTreeMap<String, Type>> {
    let decl = self.structs.get(name)?;
    let mut schema = BTreeMap::new();
    for field in decl.fields.iter() {
        schema.insert(field.name.clone(), field.ty.clone());
    }
    Some(schema)
}
// ---- inference ---------------------------------------------------------
/// Type-check the expression `e`: infer its type, then discharge the class
/// constraints collected so far. Problems are reported through `errors`; the
/// inferred type is returned either way.
pub fn check(&mut self, e: &Expr) -> Type {
    let inferred = self.infer(e);
    self.resolve_pending();
    inferred
}
/// Infer the type of `e` in the current environment.
///
/// Problems are pushed to `errors`; inference then continues, usually with
/// `Type::Dyn` so that one mistake does not cascade. Class constraints found
/// on the way are queued in `pending` and discharged by `resolve_pending`.
fn infer(&mut self, e: &Expr) -> Type {
    match e {
        Expr::Int(..) => Type::Int,
        Expr::Str(_) => Type::Str,
        Expr::Bool(_) => Type::Bool,
        Expr::Var(n) => match self.lookup(n) {
            Some(s) => self.instantiate(&s),
            None => {
                self.errors.push(format!("unbound variable `{n}`"));
                Type::Dyn
            }
        },
        Expr::Lam(p, body) => {
            // the parameter is monomorphic inside the body
            let pv = self.fresh();
            self.env.push((p.clone(), mono(pv.clone())));
            let bt = self.infer(body);
            self.env.pop();
            fun(pv, bt)
        }
        Expr::App(f, a) => {
            let ft = self.infer(f);
            let at = self.infer(a);
            let rv = self.fresh();
            let expected = fun(at, rv.clone());
            if let Err(e) = self.unify(&ft, &expected) {
                self.errors.push(format!("application: {e}"));
                return Type::Dyn;
            }
            rv
        }
        // list literal: homogeneous -> [t]; heterogeneous (tuple-like) -> [?]
        Expr::List(es) => {
            let ev = self.fresh();
            let mut homogeneous = true;
            for e in es {
                let t = self.infer(e);
                if self.unify(&ev, &t).is_err() {
                    homogeneous = false;
                }
            }
            if homogeneous {
                list(ev)
            } else {
                list(Type::Dyn)
            }
        }
        Expr::Record(fields) => {
            let mut m = BTreeMap::new();
            for (k, e) in fields {
                let t = self.infer(e);
                m.insert(k.clone(), t);
            }
            Type::Record(m)
        }
        Expr::Construct(name, fields) => self.check_construct(name, fields),
        Expr::EnumVariant(name, variant) => {
            match self.enums.get(name) {
                Some(d) if d.variants.iter().any(|v| v == variant) => {}
                Some(_) => self
                    .errors
                    .push(format!("enum `{name}` has no variant `{variant}`")),
                None => self.errors.push(format!("unknown enum `{name}`")),
            }
            Type::Enum(name.clone())
        }
        Expr::Select(obj, field) => {
            let ot = self.infer(obj);
            match self.prune(&ot) {
                Type::Record(m) => match m.get(field) {
                    Some(ft) => ft.clone(),
                    None => {
                        // Resolve before printing: `ot` is often still a bare
                        // variable (e.g. an application result) that would
                        // otherwise show up as `tN` instead of the record.
                        self.errors.push(format!(
                            "no field `{field}` on {}",
                            self.resolve(&ot).show()
                        ));
                        Type::Dyn
                    }
                },
                // field, then inherent method, then class-method (`x.m` == `m x`)
                Type::Struct(n) => {
                    if let Some(ft) = self.struct_fields(&n).and_then(|m| m.get(field).cloned())
                    {
                        ft
                    } else if self.method_names.contains(&(n.clone(), field.clone())) {
                        // inherent methods are not given a static type here
                        Type::Dyn
                    } else if let Some(t) =
                        self.class_method_select(Type::Struct(n.clone()), field)
                    {
                        t
                    } else {
                        self.errors
                            .push(format!("no field or method `{field}` on `{n}`"));
                        Type::Dyn
                    }
                }
                Type::Enum(n) => {
                    if self.method_names.contains(&(n.clone(), field.clone())) {
                        Type::Dyn
                    } else if let Some(t) =
                        self.class_method_select(Type::Enum(n.clone()), field)
                    {
                        t
                    } else {
                        self.errors.push(format!("no method `{field}` on `{n}`"));
                        Type::Dyn
                    }
                }
                _ => Type::Dyn, // var/dyn: cannot resolve statically
            }
        }
        Expr::Merge(l, r) => {
            let lt = self.infer(l);
            let rt = self.infer(r);
            self.infer_merge(lt, rt)
        }
        Expr::If(c, t, e) => {
            let ct = self.infer(c);
            self.want(&ct, &Type::Bool);
            let tt = self.infer(t);
            let et = self.infer(e);
            self.want(&tt, &et);
            tt
        }
        Expr::Bin(op, l, r) => self.infer_bin(*op, l, r),
        Expr::Let(binds, body) => {
            let mark = self.env.len();
            // recursive: pre-bind each name to a fresh monomorphic var
            let mut vars = Vec::new();
            for b in binds {
                let v = self.fresh();
                vars.push(v.clone());
                self.env.push((b.name.clone(), mono(v)));
            }
            for (i, b) in binds.iter().enumerate() {
                let t = self.check_binding(b);
                self.want(&vars[i], &t);
            }
            // generalize for the body (let-polymorphism): drop the monomorphic
            // pre-bindings first so they do not pin their own variables
            self.env.truncate(mark);
            for (i, b) in binds.iter().enumerate() {
                let s = self.generalize(&vars[i]);
                self.env.push((b.name.clone(), s));
            }
            let bt = self.infer(body);
            self.env.truncate(mark);
            bt
        }
    }
}
/// Infer the type of the binary operation `l op r`.
///
/// * Arithmetic operators and `/` go through their operator class: both
///   operands must have the same type, a constraint for that class is queued,
///   and the operand type is the result.
/// * `==` requires both operands to have the same type and yields `Bool`.
/// * `&&` / `||` require `Bool` operands and yield `Bool`.
/// * `++` is string concatenation when EITHER operand is already known to be
///   `Str` (the other is then required to be `Str` too); otherwise it is list
///   append. Looking at only the left operand would force `x ++ "a"`, with
///   `x` still an unbound variable, down the list path and report a bogus
///   list/string mismatch.
fn infer_bin(&mut self, op: BinOp, l: &Expr, r: &Expr) -> Type {
    // arithmetic and `/` dispatch through operator classes (Add/Sub/.../Div),
    // so `a op b : a` requires an instance for `a` (built-in for Int/Str).
    if let Some((class, _)) = op_class(op) {
        let lt = self.infer(l);
        let rt = self.infer(r);
        self.want(&lt, &rt);
        let t = self.prune(&lt);
        self.pending.push((class.to_string(), t.clone()));
        return t;
    }
    match op {
        BinOp::Eq => {
            let lt = self.infer(l);
            let rt = self.infer(r);
            self.want(&lt, &rt);
            Type::Bool
        }
        BinOp::And | BinOp::Or => {
            let lt = self.infer(l);
            let rt = self.infer(r);
            self.want(&lt, &Type::Bool);
            self.want(&rt, &Type::Bool);
            Type::Bool
        }
        // `++` is string concat for strings, else list append
        BinOp::Concat => {
            let lt = self.infer(l);
            let rt = self.infer(r);
            let left_is_str = matches!(self.prune(&lt), Type::Str);
            let right_is_str = matches!(self.prune(&rt), Type::Str);
            if left_is_str || right_is_str {
                // a no-op for the side that is already `Str`
                self.want(&lt, &Type::Str);
                self.want(&rt, &Type::Str);
                Type::Str
            } else {
                let ev = self.fresh();
                self.want(&lt, &list(ev.clone()));
                self.want(&rt, &list(ev.clone()));
                list(ev)
            }
        }
        _ => unreachable!("op-class operators handled above"),
    }
}
/// Type the merge `l // r` given the operand types `lt` and `rt`.
///
/// The right side must be a record of overrides (or `Dyn`, in which case the
/// left type is returned unchanged). For a struct on the left, every
/// override must name a declared field and match its type; the result is the
/// struct type. For a record on the left, overrides replace or add fields.
/// A `Dyn` left operand is accepted and yields `Dyn`: `Dyn` is the gradual
/// top, so it gets the same treatment as a `Dyn` right operand instead of a
/// "must be a record/struct" error. Any other left type is an error.
fn infer_merge(&mut self, lt: Type, rt: Type) -> Type {
    let overrides = match self.prune(&rt) {
        Type::Record(m) => m,
        Type::Dyn => return self.prune(&lt),
        other => {
            self.errors.push(format!(
                "right of `//` must be a record, got {}",
                other.show()
            ));
            return self.prune(&lt);
        }
    };
    match self.prune(&lt) {
        // nothing can be checked statically against a dynamic base
        Type::Dyn => Type::Dyn,
        Type::Struct(name) => {
            // an unknown struct name has no schema to check against here
            if let Some(schema) = self.struct_fields(&name) {
                for (k, vt) in &overrides {
                    match schema.get(k) {
                        Some(ft) => {
                            if self.unify(ft, vt).is_err() {
                                self.errors.push(format!(
                                    "`{name} // {{ {k} = .. }}` : `{name}.{k}` is {}, got {}",
                                    ft.show(),
                                    self.resolve(vt).show()
                                ));
                            }
                        }
                        None => self
                            .errors
                            .push(format!("`{name}` has no field `{k}` to override")),
                    }
                }
            }
            Type::Struct(name)
        }
        Type::Record(base) => {
            // right-hand fields win over the base record's fields
            let mut m = base;
            m.extend(overrides);
            Type::Record(m)
        }
        other => {
            self.errors.push(format!(
                "left of `//` must be a record/struct, got {}",
                other.show()
            ));
            other
        }
    }
}
/// Check a construction of struct `name` from the given field expressions.
///
/// Reports an unknown struct, a field whose value does not match the
/// declared type, a missing field that has no default, and a supplied field
/// the struct does not declare. Always returns `Type::Struct(name)`.
fn check_construct(&mut self, name: &str, fields: &[(String, Rc<Expr>)]) -> Type {
    let Some(decl) = self.structs.get(name).cloned() else {
        self.errors.push(format!("unknown struct `{name}`"));
        return Type::Struct(name.into());
    };

    // Infer every supplied value first, in source order.
    let mut given: BTreeMap<String, Type> = BTreeMap::new();
    for (label, value) in fields {
        let value_ty = self.infer(value);
        given.insert(label.clone(), value_ty);
    }

    for field in &decl.fields {
        let Some(got) = given.get(&field.name) else {
            // an omitted field is fine only when the declaration has a default
            if field.default.is_none() {
                self.errors
                    .push(format!("`{name}` missing required field `{}`", field.name));
            }
            continue;
        };
        if self.unify(got, &field.ty).is_err() {
            self.errors.push(format!(
                "`{name}.{}` : expected {}, got {}",
                field.name,
                field.ty.show(),
                self.resolve(got).show()
            ));
        }
    }

    let undeclared = given
        .keys()
        .filter(|label| !decl.fields.iter().any(|f| &f.name == *label));
    for label in undeclared {
        self.errors.push(format!("`{name}` has no field `{label}`"));
    }
    Type::Struct(name.into())
}
/// Type one `let` binding.
///
/// A record literal annotated with a struct type is checked as a
/// construction of that struct. Any other annotated binding is inferred and
/// unified with the annotation, and the annotation is the resulting type.
/// An unannotated binding simply gets its inferred type.
fn check_binding(&mut self, b: &Binding) -> Type {
    if let (Some(Type::Struct(name)), Expr::Record(fields)) = (&b.ann, &*b.value) {
        return self.check_construct(name, fields);
    }
    let got = self.infer(&b.value);
    let Some(ann) = &b.ann else {
        return got;
    };
    if self.unify(&got, ann).is_err() {
        self.errors.push(format!(
            "`{}` : annotated {}, got {}",
            b.name,
            ann.show(),
            self.resolve(&got).show()
        ));
    }
    ann.clone()
}
}
/// The `(class, method)` pair an arithmetic or `/` operator desugars to, or
/// `None` for operators that are not class-dispatched.
pub fn op_class(op: BinOp) -> Option<(&'static str, &'static str)> {
    let pair = match op {
        BinOp::Add => ("Add", "add"),
        BinOp::Sub => ("Sub", "sub"),
        BinOp::Mul => ("Mul", "mul"),
        BinOp::Slash => ("Div", "div"),
        BinOp::Mod => ("Mod", "mod"),
        BinOp::Pow => ("Pow", "pow"),
        _ => return None,
    };
    Some(pair)
}
/// The name used to look up instances for `t`: the primitive's name, `List`
/// for any list, or the struct/enum name. Other types have no head.
fn type_head(t: &Type) -> Option<String> {
    let head: &str = match t {
        Type::Int => "Int",
        Type::Str => "Str",
        Type::Bool => "Bool",
        Type::List(_) => "List",
        Type::Struct(name) | Type::Enum(name) => name.as_str(),
        _ => return None,
    };
    Some(head.to_string())
}
/// Replace the class parameter (parsed as `Struct(param)`) with `repl` in a sig.
fn subst_param(t: &Type, param: &str, repl: &Type) -> Type {
match t {
Type::Struct(n) if n == param => repl.clone(),
Type::List(x) => Type::List(Box::new(subst_param(x, param, repl))),
Type::Task(x) => Type::Task(Box::new(subst_param(x, param, repl))),
Type::Fun(x, y) => Type::Fun(
Box::new(subst_param(x, param, repl)),
Box::new(subst_param(y, param, repl)),
),
Type::Record(m) => Type::Record(
m.iter()
.map(|(k, v)| (k.clone(), subst_param(v, param, repl)))
.collect(),
),
other => other.clone(),
}
}
/// Rewrite a [`BuiltinScheme`]'s local bound vars `Var(0..)` to allocated `fresh` ids.
fn lower_type(t: &Type, fresh: &[u32]) -> Type {
match t {
Type::Var(id) => Type::Var(fresh[*id as usize]),
Type::List(x) => Type::List(Box::new(lower_type(x, fresh))),
Type::Task(x) => Type::Task(Box::new(lower_type(x, fresh))),
Type::Fun(x, y) => Type::Fun(
Box::new(lower_type(x, fresh)),
Box::new(lower_type(y, fresh)),
),
Type::Record(m) => Type::Record(
m.iter()
.map(|(k, v)| (k.clone(), lower_type(v, fresh)))
.collect(),
),
other => other.clone(),
}
}
/// Append the ids of the type variables occurring in `t` to `out`, in
/// left-to-right order of first occurrence, skipping ids already in `out`.
fn free_vars(t: &Type, out: &mut Vec<u32>) {
    // Explicit stack; children are pushed right-to-left so they pop
    // left-to-right, matching a depth-first pre-order walk.
    let mut stack: Vec<&Type> = vec![t];
    while let Some(current) = stack.pop() {
        match current {
            Type::Var(id) => {
                if !out.contains(id) {
                    out.push(*id);
                }
            }
            Type::List(inner) | Type::Task(inner) => stack.push(inner),
            Type::Fun(arg, ret) => {
                stack.push(ret);
                stack.push(arg);
            }
            Type::Record(fields) => stack.extend(fields.values().rev()),
            _ => {}
        }
    }
}