(* Type inference for the surface language: Hindley–Milner style unification
   over mutable type variables, with level-based let-generalization.

   A type variable is a mutable cell that is either still [Unbound (id, level)]
   -- [id] is unique (from [type_var_counter]) and [level] is the let-nesting
   depth at which it was created -- or a [Link] to the type it was unified
   with. *)

open Ast
open Error

type ty =
  | TyVar of type_var ref
  | TyCon of string * ty list (* named constructor, e.g. int or list<'a> *)
  | TyFun of ty * ty
  | TyTuple of ty list
  | TyRecord of (string * ty) list
  | TyVariant of (string * ty option) list
[@@deriving show]

and type_var =
  | Unbound of int * int (* (id, level) *)
  | Link of ty

(** A type scheme: [ty] with the variable ids in [vars] universally
    quantified. [vars = []] means a monomorphic binding. *)
type scheme = {
  vars : int list;
  ty : ty;
}

(** Typing environment; the most recent binding of a name comes first. *)
type env = (string * scheme) list

(** Source of unique ids for fresh type variables. *)
let type_var_counter = ref 0

(** Current let-nesting level. Variables created at a level deeper than a
    let's own level are candidates for generalization at that let. *)
let current_level = ref 1

(** [fresh_var level] creates a new unbound type variable at [level]. *)
let fresh_var level =
  incr type_var_counter;
  TyVar (ref (Unbound (!type_var_counter, level)))

let empty_env () = []

let extend_env env name scheme = (name, scheme) :: env

let lookup_env env name = List.assoc_opt name env

(* Frequently used ground types. [TyCon] values are never mutated (only the
   refs inside [TyVar] are), so sharing these is safe. *)
let int_ty = TyCon ("int", [])

let bool_ty = TyCon ("bool", [])

let unit_ty = TyCon ("unit", [])

(** [with_inner_level f] runs [f ()] with [current_level] incremented by one
    and restores the previous level afterwards, even when [f] raises. Without
    the restore, an exception escaping inference would leave the global level
    permanently raised for every later check. *)
let with_inner_level f =
  let outer_level = !current_level in
  incr current_level;
  Fun.protect ~finally:(fun () -> current_level := outer_level) f

(** Follows [Link]s until reaching an unbound variable or a non-variable
    type. *)
let rec repr = function
  | TyVar { contents = Link t } -> repr t
  | t -> t

(** Sorts an association list by its string key, so that record fields and
    variant cases can be compared independently of declaration order. *)
let sort_by_name entries =
  List.sort (fun (a, _) (b, _) -> String.compare a b) entries

(** Human-readable rendering of a type, used in error messages. Function
    types on the left of an arrow are parenthesized. *)
let string_of_ty ty =
  let rec go prec ty =
    match ty with
    | TyVar { contents = Link ty } -> go prec ty
    | TyVar { contents = Unbound (id, _) } -> Printf.sprintf "'t%d" id
    | TyCon (name, []) -> name
    | TyCon (name, args) ->
      Printf.sprintf "%s<%s>" name (String.concat ", " (List.map (go 0) args))
    | TyFun (t1, t2) ->
      let s = Printf.sprintf "%s -> %s" (go 1 t1) (go 0 t2) in
      if prec > 0 then "(" ^ s ^ ")" else s
    | TyTuple ts -> "(" ^ String.concat ", " (List.map (go 0) ts) ^ ")"
    | TyRecord fields ->
      "{"
      ^ String.concat ", " (List.map (fun (n, t) -> n ^ ": " ^ go 0 t) fields)
      ^ "}"
    | TyVariant cases ->
      "<"
      ^ String.concat " | "
          (List.map
             (fun (n, t_opt) ->
               match t_opt with
               | Some t -> n ^ ": " ^ go 0 t
               | None -> n)
             cases)
      ^ ">"
  in
  go 0 ty

(** [occurs id level t] is [true] when the unbound variable [id] occurs in
    [t] (the occurs check). As a side effect, every unbound variable visited
    whose level is deeper than [level] is lowered to [level], so that a
    variable escaping into an outer scope is no longer generalized there.
    The traversal stops at the first occurrence found. *)
let rec occurs id level = function
  | TyVar { contents = Link ty } -> occurs id level ty
  | TyVar ({ contents = Unbound (id', level') } as tv) ->
    if id = id' then true
    else begin
      if level' > level then tv := Unbound (id', level);
      false
    end
  | TyCon (_, args) -> List.exists (occurs id level) args
  | TyFun (t1, t2) -> occurs id level t1 || occurs id level t2
  | TyTuple ts -> List.exists (occurs id level) ts
  | TyRecord fields -> List.exists (fun (_, t) -> occurs id level t) fields
  | TyVariant cases ->
    List.exists
      (fun (_, t_opt) ->
        match t_opt with
        | Some t -> occurs id level t
        | None -> false)
      cases

(** [unify ctx span t1 t2] makes [t1] and [t2] equal by linking unbound
    variables. On failure it calls [report_error ctx ~span TypeError ...] and
    continues; it does not raise on its own. *)
let rec unify ctx span t1 t2 =
  match (t1, t2) with
  | (TyVar { contents = Link t1 }, t2) | (t1, TyVar { contents = Link t2 }) ->
    unify ctx span t1 t2
  | (TyVar { contents = Unbound (id1, _) }, TyVar { contents = Unbound (id2, _) })
    when id1 = id2 ->
    ()
  | (TyVar ({ contents = Unbound (id, level) } as tv), t)
  | (t, TyVar ({ contents = Unbound (id, level) } as tv)) ->
    if occurs id level t then
      report_error ctx ~span TypeError
        (Printf.sprintf "Infinite type: cannot unify %s with %s"
           (string_of_ty t1) (string_of_ty t2))
    else tv := Link t
  | (TyCon (name1, args1), TyCon (name2, args2)) ->
    if name1 <> name2 then
      report_error ctx ~span TypeError
        (Printf.sprintf "Cannot unify %s with %s: type constructors differ"
           (string_of_ty t1) (string_of_ty t2))
    else if List.length args1 <> List.length args2 then
      report_error ctx ~span TypeError
        (Printf.sprintf "Cannot unify %s with %s: arity mismatch"
           (string_of_ty t1) (string_of_ty t2))
    else List.iter2 (unify ctx span) args1 args2
  | (TyFun (a1, r1), TyFun (a2, r2)) ->
    unify ctx span a1 a2;
    unify ctx span r1 r2
  | (TyTuple ts1, TyTuple ts2) ->
    if List.length ts1 <> List.length ts2 then
      report_error ctx ~span TypeError "Cannot unify tuples of different sizes"
    else List.iter2 (unify ctx span) ts1 ts2
  | (TyRecord fields1, TyRecord fields2) ->
    let sorted1 = sort_by_name fields1 in
    let sorted2 = sort_by_name fields2 in
    if List.map fst sorted1 <> List.map fst sorted2 then
      report_error ctx ~span TypeError "Cannot unify records with different fields"
    else List.iter2 (fun (_, a) (_, b) -> unify ctx span a b) sorted1 sorted2
  | (TyVariant cases1, TyVariant cases2) ->
    (* Structural variants unify when they have the same constructors (in any
       order) and each constructor's payload unifies. *)
    let sorted1 = sort_by_name cases1 in
    let sorted2 = sort_by_name cases2 in
    if List.map fst sorted1 <> List.map fst sorted2 then
      report_error ctx ~span TypeError
        "Cannot unify variants with different constructors"
    else
      List.iter2
        (fun (name, a) (_, b) ->
          match (a, b) with
          | (Some a, Some b) -> unify ctx span a b
          | (None, None) -> ()
          | _ ->
            report_error ctx ~span TypeError
              (Printf.sprintf
                 "Cannot unify variants: constructor %s has mismatched arguments"
                 name))
        sorted1 sorted2
  | _ ->
    report_error ctx ~span TypeError
      (Printf.sprintf "Cannot unify %s with %s" (string_of_ty t1)
         (string_of_ty t2))

(** [generalize level ty] quantifies every unbound variable of [ty] whose
    level is deeper than [level]; ids are returned sorted and de-duplicated. *)
let generalize level ty =
  let rec collect acc ty =
    match ty with
    | TyVar { contents = Unbound (id, level') } when level' > level -> id :: acc
    | TyVar { contents = Link ty } -> collect acc ty
    | TyVar _ -> acc
    | TyCon (_, ts) | TyTuple ts -> List.fold_left collect acc ts
    | TyFun (t1, t2) -> collect (collect acc t1) t2
    | TyRecord fields -> List.fold_left (fun acc (_, t) -> collect acc t) acc fields
    | TyVariant cases ->
      List.fold_left
        (fun acc (_, t_opt) -> Option.fold ~none:acc ~some:(collect acc) t_opt)
        acc cases
  in
  let vars = List.sort_uniq Int.compare (collect [] ty) in
  { vars; ty }

(** [instantiate level scheme] replaces each quantified variable of [scheme]
    with a fresh variable at [level]. Monomorphic schemes are returned
    unchanged (there is nothing to substitute). *)
let instantiate level scheme =
  match scheme.vars with
  | [] -> scheme.ty
  | vars ->
    let subst = List.map (fun id -> (id, fresh_var level)) vars in
    let rec inst ty =
      match ty with
      | TyVar { contents = Unbound (id, _) } ->
        Option.value (List.assoc_opt id subst) ~default:ty
      | TyVar { contents = Link ty } -> inst ty
      | TyCon (name, args) -> TyCon (name, List.map inst args)
      | TyFun (t1, t2) -> TyFun (inst t1, inst t2)
      | TyTuple ts -> TyTuple (List.map inst ts)
      | TyRecord fields -> TyRecord (List.map (fun (n, t) -> (n, inst t)) fields)
      | TyVariant cases ->
        TyVariant (List.map (fun (n, t_opt) -> (n, Option.map inst t_opt)) cases)
    in
    inst scheme.ty

(** Converts a source type annotation to an internal type. All occurrences of
    the same named type variable (e.g. both ['a] in ['a -> 'a]) within one
    annotation map to the same fresh variable at the current level. *)
let ast_type_to_ty _ctx _env ty_expr =
  let named_vars = Hashtbl.create 8 in
  let rec convert t =
    match t.type_desc with
    | TVar name -> (
      match Hashtbl.find_opt named_vars name with
      | Some var -> var
      | None ->
        let var = fresh_var !current_level in
        Hashtbl.add named_vars name var;
        var)
    | TCon (name, args) -> TyCon (name, List.map convert args)
    | TFun (t1, t2) ->
      let arg_ty = convert t1 in
      let ret_ty = convert t2 in
      TyFun (arg_ty, ret_ty)
    | TTuple ts -> TyTuple (List.map convert ts)
    | TRecord fields -> TyRecord (List.map (fun (name, t) -> (name, convert t)) fields)
    | TVariant cases ->
      TyVariant (List.map (fun (name, t_opt) -> (name, Option.map convert t_opt)) cases)
  in
  convert ty_expr

let infer_literal = function
  | LInt _ -> int_ty
  | LFloat _ -> TyCon ("float", [])
  | LString _ -> TyCon ("string", [])
  | LChar _ -> TyCon ("char", [])
  | LBool _ -> bool_ty
  | LUnit -> unit_ty

(** Adds [bindings] to [env] as monomorphic schemes (lambda parameters, match
    cases, recursive occurrences). *)
let bind_monomorphic env bindings =
  List.fold_left (fun acc (name, ty) -> extend_env acc name { vars = []; ty }) env bindings

(** Adds [bindings] to [env], generalizing each type over variables deeper
    than [level]. *)
let bind_generalized level env bindings =
  List.fold_left
    (fun acc (name, ty) -> extend_env acc name (generalize level ty))
    env bindings

(** [infer_pattern ctx env pat] returns the type matched by [pat] together
    with the variables it binds, in left-to-right order. *)
let rec infer_pattern ctx env pat =
  match pat.pat_desc with
  | PWild -> (fresh_var !current_level, [])
  | PVar name ->
    let ty = fresh_var !current_level in
    (ty, [ (name, ty) ])
  | PLit lit -> (infer_literal lit, [])
  | PCons (p1, p2) ->
    let (head_ty, head_bindings) = infer_pattern ctx env p1 in
    let (tail_ty, tail_bindings) = infer_pattern ctx env p2 in
    let list_ty = TyCon ("list", [ head_ty ]) in
    unify ctx pat.pat_span tail_ty list_ty;
    (list_ty, head_bindings @ tail_bindings)
  | PList pats ->
    let elem_ty = fresh_var !current_level in
    (* Accumulate in reverse and flip once, avoiding quadratic [@]. *)
    let reversed =
      List.fold_left
        (fun acc p ->
          let (ty, bs) = infer_pattern ctx env p in
          unify ctx p.pat_span ty elem_ty;
          List.rev_append bs acc)
        [] pats
    in
    (TyCon ("list", [ elem_ty ]), List.rev reversed)
  | PTuple pats ->
    let typed = List.map (infer_pattern ctx env) pats in
    (TyTuple (List.map fst typed), List.concat (List.map snd typed))
  | PRecord fields ->
    let typed =
      List.map
        (fun (name, p) ->
          let (ty, bs) = infer_pattern ctx env p in
          ((name, ty), bs))
        fields
    in
    (TyRecord (List.map fst typed), List.concat (List.map snd typed))
  | PVariant (_name, p_opt) ->
    (* Constructors are not tracked in the environment, so the payload is
       inferred only for its bindings and the matched type stays
       unconstrained. *)
    let bindings =
      match p_opt with
      | Some p -> snd (infer_pattern ctx env p)
      | None -> []
    in
    (fresh_var !current_level, bindings)
  | POr (p1, p2) ->
    let (ty1, bindings1) = infer_pattern ctx env p1 in
    let (ty2, bindings2) = infer_pattern ctx env p2 in
    unify ctx pat.pat_span ty1 ty2;
    (* Both alternatives must bind the same variables at the same types. *)
    let missing name =
      report_error ctx ~span:pat.pat_span TypeError
        (Printf.sprintf "Variable %s must occur on both sides of this | pattern" name)
    in
    List.iter
      (fun (name, ty) ->
        match List.assoc_opt name bindings2 with
        | Some ty' -> unify ctx pat.pat_span ty ty'
        | None -> missing name)
      bindings1;
    List.iter
      (fun (name, _) -> if not (List.mem_assoc name bindings1) then missing name)
      bindings2;
    (ty1, bindings1)
  | PAs (p, name) ->
    let (ty, bindings) = infer_pattern ctx env p in
    (ty, (name, ty) :: bindings)
  | PConstraint (p, ty_expr) ->
    let (ty, bindings) = infer_pattern ctx env p in
    let expected_ty = ast_type_to_ty ctx env ty_expr in
    unify ctx pat.pat_span ty expected_ty;
    (ty, bindings)

(** [infer_expr ctx env expr] infers the type of [expr] under [env]. Type
    errors are reported through [report_error]; inference then continues
    with a fresh variable or the best available type. *)
let rec infer_expr ctx env expr =
  match expr.expr_desc with
  | ELit lit -> infer_literal lit
  | EVar name -> (
    match lookup_env env name with
    | Some scheme -> instantiate !current_level scheme
    | None ->
      report_error ctx ~span:expr.expr_span TypeError
        (Printf.sprintf "Unbound variable: %s" name);
      fresh_var !current_level)
  | ELambda (params, body) ->
    let typed_params = List.map (infer_pattern ctx env) params in
    let param_tys = List.map fst typed_params in
    let body_env = bind_monomorphic env (List.concat (List.map snd typed_params)) in
    let body_ty = infer_expr ctx body_env body in
    List.fold_right (fun param_ty acc -> TyFun (param_ty, acc)) param_tys body_ty
  | EApp (callee, arg) ->
    let callee_ty = infer_expr ctx env callee in
    let arg_ty = infer_expr ctx env arg in
    let result_ty = fresh_var !current_level in
    unify ctx expr.expr_span callee_ty (TyFun (arg_ty, result_ty));
    result_ty
  | ELet (pat, e1, e2) ->
    (* Infer the bound expression one level deeper so that the variables it
       introduces can be generalized at the outer level; the body is then
       checked back at the outer level.
       NOTE(review): every let is generalized (no value restriction); that is
       only sound if the language has no mutable references -- confirm. *)
    let bindings =
      with_inner_level (fun () ->
          let ty1 = infer_expr ctx env e1 in
          let (pat_ty, bindings) = infer_pattern ctx env pat in
          unify ctx pat.pat_span pat_ty ty1;
          bindings)
    in
    infer_expr ctx (bind_generalized !current_level env bindings) e2
  | ELetRec (bindings, body) ->
    (* Recursive occurrences are monomorphic while the definitions are being
       inferred; they are generalized only for the body. *)
    let rec_vars =
      with_inner_level (fun () ->
          let rec_vars =
            List.map (fun (name, _) -> (name, fresh_var !current_level)) bindings
          in
          let rec_env = bind_monomorphic env rec_vars in
          List.iter2
            (fun (_, e) (_, ty) ->
              let inferred_ty = infer_expr ctx rec_env e in
              unify ctx e.expr_span ty inferred_ty)
            bindings rec_vars;
          rec_vars)
    in
    infer_expr ctx (bind_generalized !current_level env rec_vars) body
  | EIf (cond, then_e, else_e) ->
    let cond_ty = infer_expr ctx env cond in
    unify ctx cond.expr_span cond_ty bool_ty;
    let then_ty = infer_expr ctx env then_e in
    let else_ty = infer_expr ctx env else_e in
    unify ctx expr.expr_span then_ty else_ty;
    then_ty
  | EMatch (scrutinee, cases) ->
    let scrutinee_ty = infer_expr ctx env scrutinee in
    let result_ty = fresh_var !current_level in
    List.iter
      (fun (pat, guard_opt, case_expr) ->
        let (pat_ty, bindings) = infer_pattern ctx env pat in
        unify ctx pat.pat_span pat_ty scrutinee_ty;
        let case_env = bind_monomorphic env bindings in
        (match guard_opt with
         | Some guard ->
           let guard_ty = infer_expr ctx case_env guard in
           unify ctx guard.expr_span guard_ty bool_ty
         | None -> ());
        let case_ty = infer_expr ctx case_env case_expr in
        unify ctx case_expr.expr_span case_ty result_ty)
      cases;
    result_ty
  | ETuple exprs -> TyTuple (List.map (infer_expr ctx env) exprs)
  | EList exprs ->
    let elem_ty = fresh_var !current_level in
    List.iter
      (fun e ->
        let ty = infer_expr ctx env e in
        unify ctx e.expr_span ty elem_ty)
      exprs;
    TyCon ("list", [ elem_ty ])
  | ERecord fields ->
    TyRecord (List.map (fun (name, e) -> (name, infer_expr ctx env e)) fields)
  | ERecordAccess (record, field) -> (
    let record_ty = infer_expr ctx env record in
    match repr record_ty with
    | TyRecord fields -> (
      match List.assoc_opt field fields with
      | Some field_ty -> field_ty
      | None ->
        report_error ctx ~span:expr.expr_span TypeError
          (Printf.sprintf "Record type %s has no field %s"
             (string_of_ty record_ty) field);
        fresh_var !current_level)
    | _ ->
      (* Without row polymorphism nothing can be said about the field when
         the record's type is not (yet) a known record. *)
      fresh_var !current_level)
  | ERecordUpdate (base, fields) ->
    let base_ty = infer_expr ctx env base in
    let known_fields =
      match repr base_ty with
      | TyRecord fs -> Some fs
      | _ -> None
    in
    List.iter
      (fun (name, e) ->
        let field_ty = infer_expr ctx env e in
        match known_fields with
        | None -> ()
        | Some fs -> (
          match List.assoc_opt name fs with
          | Some expected_ty -> unify ctx e.expr_span field_ty expected_ty
          | None ->
            report_error ctx ~span:e.expr_span TypeError
              (Printf.sprintf "Record type %s has no field %s"
                 (string_of_ty base_ty) name)))
      fields;
    base_ty
  | EVariant (_name, e_opt) ->
    (* Constructors are not tracked in the environment: the payload is checked
       on its own and the result type is left unconstrained. *)
    Option.iter (fun e -> ignore (infer_expr ctx env e : ty)) e_opt;
    fresh_var !current_level
  | EBinop (op, e1, e2) -> (
    let ty1 = infer_expr ctx env e1 in
    let ty2 = infer_expr ctx env e2 in
    match op with
    | Add | Sub | Mul | Div | Mod ->
      unify ctx e1.expr_span ty1 int_ty;
      unify ctx e2.expr_span ty2 int_ty;
      int_ty
    | Eq | Ne | Lt | Le | Gt | Ge ->
      unify ctx expr.expr_span ty1 ty2;
      bool_ty
    | And | Or ->
      unify ctx e1.expr_span ty1 bool_ty;
      unify ctx e2.expr_span ty2 bool_ty;
      bool_ty
    | Cons ->
      unify ctx e2.expr_span ty2 (TyCon ("list", [ ty1 ]));
      ty2
    | Concat ->
      unify ctx expr.expr_span ty1 ty2;
      ty1
    | Pipe ->
      let result_ty = fresh_var !current_level in
      unify ctx e2.expr_span ty2 (TyFun (ty1, result_ty));
      result_ty
    | Compose ->
      (* e1 is applied last: (e1 : b -> c), (e2 : a -> b), result a -> c.
         NOTE(review): direction kept from the previous implementation;
         confirm it matches the operator's surface syntax. *)
      let input_ty = fresh_var !current_level in
      let middle_ty = fresh_var !current_level in
      let result_ty = fresh_var !current_level in
      unify ctx e1.expr_span ty1 (TyFun (middle_ty, result_ty));
      unify ctx e2.expr_span ty2 (TyFun (input_ty, middle_ty));
      TyFun (input_ty, result_ty))
  | EUnop (op, e) -> (
    let ty = infer_expr ctx env e in
    match op with
    | Neg ->
      unify ctx e.expr_span ty int_ty;
      int_ty
    | Not ->
      unify ctx e.expr_span ty bool_ty;
      bool_ty)
  | ESequence exprs ->
    (* Every element is inferred in order; only the last one's type is kept. *)
    let rec go = function
      | [] -> unit_ty
      | [ last ] -> infer_expr ctx env last
      | e :: rest ->
        ignore (infer_expr ctx env e : ty);
        go rest
    in
    go exprs
  | EConstraint (e, ty_expr) ->
    let ty = infer_expr ctx env e in
    let expected_ty = ast_type_to_ty ctx env ty_expr in
    unify ctx expr.expr_span ty expected_ty;
    ty
  | EHole -> fresh_var !current_level

(** Infers [expr]; returns [Some (ty, expr)], or [None] if inference raised.
    NOTE(review): the catch-all also hides internal failures (e.g.
    [Stack_overflow]); narrow it once the exceptions raised by [Error] are
    known. *)
let check_expr ctx env expr =
  match infer_expr ctx env expr with
  | ty -> Some (ty, expr)
  | exception _ -> None

(** Checks one top-level declaration and returns the environment extended
    with its (generalized) bindings. For [let rec], the pattern's variables
    are in scope, monomorphically, while the right-hand side is inferred.
    Declarations other than [DLet] and [DExpr] leave [env] unchanged. *)
let check_declaration ctx env decl =
  match decl.decl_desc with
  | DLet (is_rec, pat, e) ->
    let bindings =
      with_inner_level (fun () ->
          if is_rec then begin
            let (pat_ty, bindings) = infer_pattern ctx env pat in
            let ty = infer_expr ctx (bind_monomorphic env bindings) e in
            unify ctx pat.pat_span pat_ty ty;
            bindings
          end
          else begin
            let ty = infer_expr ctx env e in
            let (pat_ty, bindings) = infer_pattern ctx env pat in
            unify ctx pat.pat_span pat_ty ty;
            bindings
          end)
    in
    bind_generalized !current_level env bindings
  | DExpr e ->
    ignore (infer_expr ctx env e : ty);
    env
  | _ -> env

(** Checks every declaration of [program] in order, starting from a prelude
    that binds ["+" : int -> int -> int]. Returns the program with the final
    environment, or [None] if any exception escaped (see [check_expr]). *)
let check_program ctx program =
  try
    let prelude =
      extend_env (empty_env ()) "+"
        { vars = []; ty = TyFun (int_ty, TyFun (int_ty, int_ty)) }
    in
    let final_env =
      List.fold_left (check_declaration ctx) prelude program.declarations
    in
    Some (program, final_env)
  with _ -> None