From 1202f6af08115ef19ce015b2dfe2331cca62bae2 Mon Sep 17 00:00:00 2001 From: meowrly Date: Wed, 24 Dec 2025 14:00:28 +0000 Subject: [PATCH] Initial commit --- dune | 15 + dune-project | 33 ++ lib/rt.js | 234 +++++++++ lib/rt.py | 237 +++++++++ makefile | 41 ++ readme.md | 150 ++++++ src/ast.ml | 179 +++++++ src/check.ml | 449 +++++++++++++++++ src/check.mli | 35 ++ src/cli.ml | 245 ++++++++++ src/error.ml | 286 +++++++++++ src/gen.ml | 419 ++++++++++++++++ src/gen.mli | 17 + src/ir.ml | 340 +++++++++++++ src/lex.ml | 674 ++++++++++++++++++++++++++ src/lex.mli | 68 +++ src/main.ml | 1 + src/opt.ml | 196 ++++++++ src/parse.ml | 853 +++++++++++++++++++++++++++++++++ src/parse.mli | 11 + test/corpus/arithmetic.star | 7 + test/corpus/higher_order.star | 22 + test/corpus/pattern_match.star | 17 + test/unit/test_lexer.ml | 79 +++ 24 files changed, 4608 insertions(+) create mode 100644 dune create mode 100644 dune-project create mode 100644 lib/rt.js create mode 100644 lib/rt.py create mode 100644 makefile create mode 100644 readme.md create mode 100644 src/ast.ml create mode 100644 src/check.ml create mode 100644 src/check.mli create mode 100644 src/cli.ml create mode 100644 src/error.ml create mode 100644 src/gen.ml create mode 100644 src/gen.mli create mode 100644 src/ir.ml create mode 100644 src/lex.ml create mode 100644 src/lex.mli create mode 100644 src/main.ml create mode 100644 src/opt.ml create mode 100644 src/parse.ml create mode 100644 src/parse.mli create mode 100644 test/corpus/arithmetic.star create mode 100644 test/corpus/higher_order.star create mode 100644 test/corpus/pattern_match.star create mode 100644 test/unit/test_lexer.ml diff --git a/dune b/dune new file mode 100644 index 0000000..62423bd --- /dev/null +++ b/dune @@ -0,0 +1,15 @@ +(executable + (name main) + (public_name star) + (package star) + (modules main cli error ast lex parse check ir opt gen) + (libraries cmdliner fmt logs uutf uucp uunf sedlex) + (preprocess (pps sedlex.ppx ppx_deriving.show ppx_deriving.eq)) + (modes exe)) + +(env + (dev + (flags (:standard -w +a-4-42-44-48-50-58-70 -warn-error +a-3))) + (release + (flags (:standard -O3)) + (ocamlopt_flags (:standard -O3)))) diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..ec93180 --- /dev/null +++ b/dune-project @@ -0,0 +1,33 @@ +(lang dune 3.12) +(name star) +(version 0.1.0) +(generate_opam_files true) +(authors "meowrly") +(maintainers "meowrly") +(license ) +(source (github star-lang/star)) +(homepage "https://github.com/meowrly/star") +(bug_reports "https://github.com/meowrsly/star/issues") + +(package + (name star) + (synopsis "") + (description "") + (depends + (ocaml (>= 4.14.0)) + (dune (>= 3.12)) + (uutf (>= 1.0.3)) + (uucp (>= 15.0.0)) + (uunf (>= 15.0.0)) + (cmdliner (>= 1.2.0)) + (fmt (>= 0.9.0)) + (logs (>= 0.7.0)) + (menhir (>= 20230608)) + (sedlex (>= 3.2)) + (ppx_deriving (>= 5.2.1)) + (ppx_inline_test (and :with-test (>= v0.15.0))) + (qcheck (and :with-test (>= 0.21))) + (qcheck-alcotest (and :with-test (>= 0.21))) + (alcotest (and :with-test (>= 1.7.0))) + (bisect_ppx (and :with-test (>= 2.8.3))) + (benchmark (and :with-test (>= 1.6))))) diff --git a/lib/rt.js b/lib/rt.js new file mode 100644 index 0000000..20cab2c --- /dev/null +++ b/lib/rt.js @@ -0,0 +1,234 @@ +"use strict"; + +function variant(tag, value) { + return value === undefined ? 
{ tag } : { tag, value }; +} + +function match(value, cases) { + if (!value || typeof value.tag !== 'string') { + throw new Error('match: expected variant value'); + } + const handler = cases[value.tag] || cases._; + if (!handler) { + throw new Error(`match: unhandled variant tag: ${value.tag}`); + } + return handler(value.value); +} + +const nil = []; + +function cons(head, tail) { + return [head, ...tail]; +} + +function head(list) { + if (list.length === 0) { + throw new Error('head: empty list'); + } + return list[0]; +} + +function tail(list) { + if (list.length === 0) { + throw new Error('tail: empty list'); + } + return list.slice(1); +} + +function isEmpty(list) { + return list.length === 0; +} + +function length(list) { + return list.length; +} + +function map(f, list) { + return list.map(f); +} + +function filter(pred, list) { + return list.filter(pred); +} + +function foldl(f, acc, list) { + return list.reduce(f, acc); +} + +function foldr(f, acc, list) { + return list.reduceRight((acc, x) => f(x, acc), acc); +} + +function concat(list1, list2) { + return [...list1, ...list2]; +} + +function reverse(list) { + return list.slice().reverse(); +} + +function take(n, list) { + return list.slice(0, n); +} + +function drop(n, list) { + return list.slice(n); +} + +const None = variant('None'); + +function Some(value) { + return variant('Some', value); +} + +function isSome(opt) { + return opt.tag === 'Some'; +} + +function isNone(opt) { + return opt.tag === 'None'; +} + +function unwrap(opt) { + if (opt.tag === 'None') { + throw new Error('unwrap: None value'); + } + return opt.value; +} + +function unwrapOr(opt, defaultValue) { + return opt.tag === 'Some' ? opt.value : defaultValue; +} + +function Ok(value) { + return variant('Ok', value); +} + +function Err(error) { + return variant('Err', error); +} + +function isOk(result) { + return result.tag === 'Ok'; +} + +function isErr(result) { + return result.tag === 'Err'; +} + +function toString(value) { + if (value === null || value === undefined) { + return 'null'; + } + if (Array.isArray(value)) { + return '[' + value.map(toString).join(', ') + ']'; + } + if (typeof value === 'object' && value.tag) { + return value.value === undefined + ? 
value.tag + : `${value.tag}(${toString(value.value)})`; + } + return String(value); +} + +function strLength(str) { + return str.length; +} + +function strConcat(str1, str2) { + return str1 + str2; +} + +function substring(str, start, length) { + return str.substr(start, length); +} + +function intDiv(a, b) { + return Math.floor(a / b); +} + +function pow(base, exp) { + return Math.pow(base, exp); +} + +function abs(x) { + return Math.abs(x); +} + +function print(value) { + console.log(toString(value)); +} + +function printStr(value) { + process.stdout.write(toString(value)); +} + +function id(x) { + return x; +} + +function constant(x) { + return () => x; +} + +function compose(f, g) { + return x => f(g(x)); +} + +function pipe(value, f) { + return f(value); +} + +function curry(f, arity = f.length) { + return function curried(...args) { + if (args.length >= arity) { + return f(...args); + } + return (...moreArgs) => curried(...args, ...moreArgs); + }; +} + +if (typeof module !== 'undefined' && module.exports) { + module.exports = { + variant, + match, + nil, + cons, + head, + tail, + isEmpty, + length, + map, + filter, + foldl, + foldr, + concat, + reverse, + take, + drop, + None, + Some, + isSome, + isNone, + unwrap, + unwrapOr, + Ok, + Err, + isOk, + isErr, + toString, + strLength, + strConcat, + substring, + intDiv, + pow, + abs, + print, + printStr, + id, + constant, + compose, + pipe, + curry + }; +} diff --git a/lib/rt.py b/lib/rt.py new file mode 100644 index 0000000..de93ade --- /dev/null +++ b/lib/rt.py @@ -0,0 +1,237 @@ +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union +from dataclasses import dataclass +from functools import reduce + +@dataclass +class Variant: + tag: str + value: Any = None + + def __repr__(self): + if self.value is None: + return self.tag + return f"{self.tag}({self.value})" + +def variant(tag: str, value: Any = None) -> Variant: + return Variant(tag, value) + +def match(value: Variant, cases: dict) -> Any: + if not isinstance(value, Variant): + if isinstance(value, dict) and 'tag' in value: + tag = value['tag'] + val = value.get('value') + handler = cases.get(tag, cases.get('_')) + if handler is None: + raise ValueError(f"match: unhandled variant tag: {tag}") + return handler(val) if val is not None else handler() + raise TypeError('match: expected variant value') + + handler = cases.get(value.tag, cases.get('_')) + if handler is None: + raise ValueError(f"match: unhandled variant tag: {value.tag}") + return handler(value.value) if value.value is not None else handler() + +nil: List = [] + +def cons(head: Any, tail: List) -> List: + return [head] + tail + +def head(lst: List) -> Any: + if not lst: + raise ValueError('head: empty list') + return lst[0] + +def tail(lst: List) -> List: + if not lst: + raise ValueError('tail: empty list') + return lst[1:] + +def is_empty(lst: List) -> bool: + return len(lst) == 0 + +def length(lst: List) -> int: + return len(lst) + +def map_list(f: Callable, lst: List) -> List: + return [f(x) for x in lst] + +def filter_list(pred: Callable, lst: List) -> List: + return [x for x in lst if pred(x)] + +def foldl(f: Callable, acc: Any, lst: List) -> Any: + return reduce(f, lst, acc) + +def foldr(f: Callable, acc: Any, lst: List) -> Any: + result = acc + for x in reversed(lst): + result = f(x, result) + return result + +def concat_lists(lst1: List, lst2: List) -> List: + return lst1 + lst2 + +def reverse_list(lst: List) -> List: + return list(reversed(lst)) + +def take(n: int, lst: List) -> List: + return 
lst[:n] + +def drop(n: int, lst: List) -> List: + return lst[n:] + +None_variant = Variant('None') + +def Some(value: Any) -> Variant: + return Variant('Some', value) + +def is_some(opt: Variant) -> bool: + return opt.tag == 'Some' + +def is_none(opt: Variant) -> bool: + return opt.tag == 'None' + +def unwrap(opt: Variant) -> Any: + if opt.tag == 'None': + raise ValueError('unwrap: None value') + return opt.value + +def unwrap_or(opt: Variant, default: Any) -> Any: + return opt.value if opt.tag == 'Some' else default + +def Ok(value: Any) -> Variant: + return Variant('Ok', value) + +def Err(error: Any) -> Variant: + return Variant('Err', error) + +def is_ok(result: Variant) -> bool: + return result.tag == 'Ok' + +def is_err(result: Variant) -> bool: + return result.tag == 'Err' + +def to_string(value: Any) -> str: + if value is None: + return 'null' + if isinstance(value, list): + return '[' + ', '.join(to_string(v) for v in value) + ']' + if isinstance(value, Variant): + return repr(value) + if isinstance(value, dict) and 'tag' in value: + if 'value' in value: + return f"{value['tag']}({to_string(value['value'])})" + return value['tag'] + return str(value) + +def str_length(s: str) -> int: + return len(s) + +def str_concat(s1: str, s2: str) -> str: + return s1 + s2 + +def substring(s: str, start: int, length: int) -> str: + return s[start:start+length] + +def int_div(a: int, b: int) -> int: + return a // b + +def pow_int(base: int, exp: int) -> int: + return base ** exp + +def abs_num(x: Union[int, float]) -> Union[int, float]: + return abs(x) + +def print_value(value: Any) -> None: + print(to_string(value)) + +def print_str(value: Any) -> None: + import sys + sys.stdout.write(to_string(value)) + sys.stdout.flush() + +def identity(x: Any) -> Any: + return x + +def constant(x: Any) -> Callable: + return lambda: x + +def compose(f: Callable, g: Callable) -> Callable: + return lambda x: f(g(x)) + +def pipe(value: Any, f: Callable) -> Any: + return f(value) + +def curry(f: Callable, arity: Optional[int] = None) -> Callable: + if arity is None: + arity = f.__code__.co_argcount + + def curried(*args): + if len(args) >= arity: + return f(*args) + return lambda *more_args: curried(*args, *more_args) + + return curried + +def record_get(record: dict, field: str) -> Any: + return record[field] + +def record_set(record: dict, field: str, value: Any) -> dict: + new_record = record.copy() + new_record[field] = value + return new_record + +def record_update(record: dict, **fields) -> dict: + new_record = record.copy() + new_record.update(fields) + return new_record + +def fst(tup: Tuple) -> Any: + return tup[0] + +def snd(tup: Tuple) -> Any: + return tup[1] + +def logical_and(a: bool, b: bool) -> bool: + return a and b + +def logical_or(a: bool, b: bool) -> bool: + return a or b + +def logical_not(a: bool) -> bool: + return not a + +def equal(a: Any, b: Any) -> bool: + return a == b + +def not_equal(a: Any, b: Any) -> bool: + return a != b + +def less_than(a: Any, b: Any) -> bool: + return a < b + +def less_equal(a: Any, b: Any) -> bool: + return a <= b + +def greater_than(a: Any, b: Any) -> bool: + return a > b + +def greater_equal(a: Any, b: Any) -> bool: + return a >= b + +__all__ = [ + 'Variant', 'variant', 'match', + 'nil', 'cons', 'head', 'tail', 'is_empty', 'length', + 'map_list', 'filter_list', 'foldl', 'foldr', + 'concat_lists', 'reverse_list', 'take', 'drop', + 'None_variant', 'Some', 'is_some', 'is_none', 'unwrap', 'unwrap_or', + 'Ok', 'Err', 'is_ok', 'is_err', + 'to_string', 'str_length', 
'str_concat', 'substring', + 'int_div', 'pow_int', 'abs_num', + 'print_value', 'print_str', + 'identity', 'constant', 'compose', 'pipe', 'curry', + 'record_get', 'record_set', 'record_update', + 'fst', 'snd', + 'logical_and', 'logical_or', 'logical_not', + 'equal', 'not_equal', 'less_than', 'less_equal', + 'greater_than', 'greater_equal' +] diff --git a/makefile b/makefile new file mode 100644 index 0000000..c847568 --- /dev/null +++ b/makefile @@ -0,0 +1,41 @@ +.PHONY: all build clean test install uninstall fmt doc coverage bench fuzz + +all: build + +build: + dune build + +test: + dune runtest --force + +install: + dune install + +uninstall: + dune uninstall + +clean: + dune clean + +fmt: + dune build @fmt --auto-promote + +doc: + dune build @doc + +coverage: + BISECT_ENABLE=yes dune runtest --force + bisect-ppx-report html + bisect-ppx-report summary + +bench: + dune exec --release -- bench/run_benchmarks.exe + +fuzz: + dune exec -- fuzz/fuzz_parser.exe + +watch: + dune build --watch + +promote: + dune promote diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..8b3345e --- /dev/null +++ b/readme.md @@ -0,0 +1,150 @@ +# star + +Star is a statically typed functional programming language with type inference, ADTs, pattern matching and a module system. It transpiles to idiomatic JavaScript ES6 and Python 3.10+ w/ full source map support + +### Prerequisites + +- OCaml >= 4.14.0 +- opam +- dune >= 3.12 + +### Build from Source + +```bash +git clone https://github.com/meowrly/star.git +cd star +opam install . --deps-only +make build +make install +``` + +## Usage + +Compile to JS: +```bash +star compile program.star +``` + +Compile to Python: +```bash +star compile -t python program.star +``` + +w/ optimisations: +```bash +star compile -O -o output.js program.star +``` + +### REPL + +```bash +star repl +``` + +### Language + +```star +let main = fn () => + print("Hello, World!") +``` + +### Functions & lambdas + +```star +let add = fn x y => x + y +let double = fn x => x * 2 +let compose = fn f g => fn x => f(g(x)) +let add5 = add(5) +``` + +### Pattern matching + +```star +type Option<'a> = + | Some of 'a + | None + +let unwrap_or = fn default opt => + match opt with + | Some(x) => x + | None => default + end +``` + +### ADTs + +```star +type Result<'a, 'e> = + | Ok of 'a + | Err of 'e + +type Person = { + name: string, + age: int +} + +let point: (int, int) = (10, 20) +``` + +### Lists + +```star +let numbers = [1, 2, 3, 4, 5] +let doubled = map(fn x => x * 2, numbers) +let evens = filter(fn x => x % 2 == 0, numbers) +let sum = foldl(fn acc x => acc + x, 0, numbers) +``` + +### Records + +```star +type Point = { + x: int, + y: int +} + +let origin = { x = 0, y = 0 } + +let moved = { origin with x = 10 } +``` + +### Modules + +```star +module Math = struct + let pi = 3.14159 + + let circle_area = fn r => + pi * r * r +end + +let area = Math.circle_area(5.0) +``` + +### Type annotations + +```star +let identity: 'a -> 'a = fn x => x + +let map: ('a -> 'b) -> list<'a> -> list<'b> = + fn f list => + match list with + | [] => [] + | x :: xs => f(x) :: map(f, xs) + end +``` + +### Type Inference + +```star +let add = fn x y => x + y +let cons = fn x xs => x :: xs +let map = fn f list => ... +``` + +## References + +- Pierce, B. C. (2002). Types and Programming Languages +- Appel, A. W. (1998). Modern Compiler Implementation in ML +- Leroy, X. et al. The OCaml System +- Diehl, S. 
Write You a Haskell \ No newline at end of file diff --git a/src/ast.ml b/src/ast.ml new file mode 100644 index 0000000..51eb053 --- /dev/null +++ b/src/ast.ml @@ -0,0 +1,179 @@ +open Error + +type ident = string +[@@deriving show, eq] + +type literal = + | LInt of int64 + | LFloat of float + | LString of string + | LChar of int + | LBool of bool + | LUnit +[@@deriving show, eq] + +type binop = + | Add | Sub | Mul | Div | Mod + | Eq | Ne | Lt | Le | Gt | Ge + | And | Or + | Cons + | Concat + | Pipe + | Compose +[@@deriving show, eq] + +type unop = + | Neg + | Not +[@@deriving show, eq] + +type pattern = { + pat_desc : pattern_desc; + pat_span : span; +} + +and pattern_desc = + | PWild + | PVar of ident + | PLit of literal + | PCons of pattern * pattern + | PList of pattern list + | PTuple of pattern list + | PRecord of (ident * pattern) list + | PVariant of ident * pattern option + | POr of pattern * pattern + | PAs of pattern * ident + | PConstraint of pattern * type_expr +[@@deriving show] + +and type_expr = { + type_desc : type_desc; + type_span : span; +} + +and type_desc = + | TVar of ident + | TCon of ident * type_expr list + | TFun of type_expr * type_expr + | TTuple of type_expr list + | TRecord of (ident * type_expr) list + | TVariant of (ident * type_expr option) list +[@@deriving show] + +type expr = { + expr_desc : expr_desc; + expr_span : span; +} + +and expr_desc = + | ELit of literal + | EVar of ident + | ELambda of pattern list * expr + | EApp of expr * expr + | ELet of pattern * expr * expr + | ELetRec of (ident * expr) list * expr + | EIf of expr * expr * expr + | EMatch of expr * (pattern * expr option * expr) list + | ETuple of expr list + | EList of expr list + | ERecord of (ident * expr) list + | ERecordAccess of expr * ident + | ERecordUpdate of expr * (ident * expr) list + | EVariant of ident * expr option + | EBinop of binop * expr * expr + | EUnop of unop * expr + | ESequence of expr list + | EConstraint of expr * type_expr + | EHole +[@@deriving show] + +type type_decl = { + type_name : ident; + type_params : ident list; + type_kind : type_kind; + type_span : span; +} + +and type_kind = + | TAlias of type_expr + | TRecord of (ident * type_expr * bool) list + | TVariant of (ident * type_expr option) list + | TAbstract +[@@deriving show] + +type module_expr = { + mod_desc : module_desc; + mod_span : span; +} + +and module_desc = + | MStruct of declaration list + | MIdent of ident + | MFunctor of ident * module_type option * module_expr + | MApply of module_expr * module_expr + | MConstraint of module_expr * module_type +[@@deriving show] + +and module_type = { + sig_desc : signature_desc; + sig_span : span; +} + +and signature_desc = + | SigVal of ident * type_expr + | SigType of type_decl + | SigModule of ident * module_type + | SigInclude of module_type + | SigMultiple of module_type list +[@@deriving show] + +and declaration = { + decl_desc : declaration_desc; + decl_span : span; +} + +and declaration_desc = + | DLet of bool * pattern * expr + | DType of type_decl + | DModule of ident * module_expr + | DModuleType of ident * module_type + | DOpen of ident + | DExpr of expr +[@@deriving show] + +type program = { + declarations : declaration list; + file : string; +} +[@@deriving show] + +let make_pattern desc span = { pat_desc = desc; pat_span = span } +let make_type desc span = { type_desc = desc; type_span = span } +let make_expr desc span = { expr_desc = desc; expr_span = span } +let make_module desc span = { mod_desc = desc; mod_span = span } +let 
make_signature desc span = { sig_desc = desc; sig_span = span } +let make_decl desc span = { decl_desc = desc; decl_span = span } + +let binop_precedence = function + | Pipe -> 1 + | Or -> 2 + | And -> 3 + | Eq | Ne | Lt | Le | Gt | Ge -> 4 + | Cons | Concat -> 5 + | Add | Sub -> 6 + | Mul | Div | Mod -> 7 + | Compose -> 8 + +let binop_is_left_assoc = function + | Cons -> false + | Pipe -> true + | Compose -> true + | _ -> true + +let pattern_span p = p.pat_span + +let expr_span e = e.expr_span + +let type_span t = t.type_span + +let decl_span d = d.decl_span diff --git a/src/check.ml b/src/check.ml new file mode 100644 index 0000000..57cb3be --- /dev/null +++ b/src/check.ml @@ -0,0 +1,449 @@ +open Ast +open Error + +type ty = + | TyVar of type_var ref + | TyCon of string * ty list + | TyFun of ty * ty + | TyTuple of ty list + | TyRecord of (string * ty) list + | TyVariant of (string * ty option) list +[@@deriving show] + +and type_var = + | Unbound of int * int + | Link of ty + +type scheme = { + vars : int list; + ty : ty; +} + +type env = (string * scheme) list + +let type_var_counter = ref 0 + +let current_level = ref 1 + +let fresh_var level = + incr type_var_counter; + TyVar (ref (Unbound (!type_var_counter, level))) + +let empty_env () = [] + +let extend_env env name scheme = + (name, scheme) :: env + +let lookup_env env name = + List.assoc_opt name env + +let rec occurs id level = function + | TyVar ({contents = Link ty}) -> occurs id level ty + | TyVar ({contents = Unbound (id', level')} as tv) -> + if id = id' then true + else begin + if level' > level then tv := Unbound (id', level); + false + end + | TyCon (_, args) -> List.exists (occurs id level) args + | TyFun (t1, t2) -> occurs id level t1 || occurs id level t2 + | TyTuple ts -> List.exists (occurs id level) ts + | TyRecord fields -> List.exists (fun (_, t) -> occurs id level t) fields + | TyVariant cases -> List.exists (fun (_, t_opt) -> + match t_opt with Some t -> occurs id level t | None -> false) cases + +let rec unify ctx span t1 t2 = + match (t1, t2) with + | (TyVar {contents = Link t1}, t2) | (t1, TyVar {contents = Link t2}) -> + unify ctx span t1 t2 + | (TyVar ({contents = Unbound (id1, _)} as tv1), + TyVar {contents = Unbound (id2, _)}) when id1 = id2 -> + () + | (TyVar ({contents = Unbound (id, level)} as tv), t) + | (t, TyVar ({contents = Unbound (id, level)} as tv)) -> + if occurs id level t then + report_error ctx ~span TypeError + (Printf.sprintf "Infinite type: cannot unify %s with %s" + (string_of_ty t1) (string_of_ty t2)) + else + tv := Link t + | (TyCon (name1, args1), TyCon (name2, args2)) -> + if name1 <> name2 then + report_error ctx ~span TypeError + (Printf.sprintf "Cannot unify %s with %s: type constructors differ" + (string_of_ty t1) (string_of_ty t2)) + else if List.length args1 <> List.length args2 then + report_error ctx ~span TypeError + (Printf.sprintf "Cannot unify %s with %s: arity mismatch" + (string_of_ty t1) (string_of_ty t2)) + else + List.iter2 (unify ctx span) args1 args2 + | (TyFun (a1, r1), TyFun (a2, r2)) -> + unify ctx span a1 a2; + unify ctx span r1 r2 + | (TyTuple ts1, TyTuple ts2) -> + if List.length ts1 <> List.length ts2 then + report_error ctx ~span TypeError + (Printf.sprintf "Cannot unify tuples of different sizes") + else + List.iter2 (unify ctx span) ts1 ts2 + | (TyRecord fields1, TyRecord fields2) -> + let sorted1 = List.sort (fun (a,_) (b,_) -> String.compare a b) fields1 in + let sorted2 = List.sort (fun (a,_) (b,_) -> String.compare a b) fields2 in + if List.map fst 
sorted1 <> List.map fst sorted2 then + report_error ctx ~span TypeError + (Printf.sprintf "Cannot unify records with different fields") + else + List.iter2 (fun (_, t1) (_, t2) -> unify ctx span t1 t2) sorted1 sorted2 + | _ -> + report_error ctx ~span TypeError + (Printf.sprintf "Cannot unify %s with %s" + (string_of_ty t1) (string_of_ty t2)) + +and generalize level ty = + let rec collect_vars ty = + match ty with + | TyVar {contents = Unbound (id, level')} when level' > level -> [id] + | TyVar {contents = Link ty} -> collect_vars ty + | TyVar _ -> [] + | TyCon (_, args) -> List.concat (List.map collect_vars args) + | TyFun (t1, t2) -> collect_vars t1 @ collect_vars t2 + | TyTuple ts -> List.concat (List.map collect_vars ts) + | TyRecord fields -> List.concat (List.map (fun (_, t) -> collect_vars t) fields) + | TyVariant cases -> + List.concat (List.filter_map (fun (_, t_opt) -> + match t_opt with Some t -> Some (collect_vars t) | None -> None) cases) + in + let vars = List.sort_uniq compare (collect_vars ty) in + { vars; ty } + +and instantiate level scheme = + let subst = List.map (fun id -> (id, fresh_var level)) scheme.vars in + let rec inst ty = + match ty with + | TyVar {contents = Unbound (id, _)} -> + (try List.assoc id subst with Not_found -> ty) + | TyVar {contents = Link ty} -> inst ty + | TyCon (name, args) -> TyCon (name, List.map inst args) + | TyFun (t1, t2) -> TyFun (inst t1, inst t2) + | TyTuple ts -> TyTuple (List.map inst ts) + | TyRecord fields -> TyRecord (List.map (fun (n, t) -> (n, inst t)) fields) + | TyVariant cases -> TyVariant (List.map (fun (n, t_opt) -> + (n, Option.map inst t_opt)) cases) + in + inst scheme.ty + +and ast_type_to_ty ctx env ty_expr = + match ty_expr.type_desc with + | TVar name -> + fresh_var !current_level + | TCon (name, args) -> + let arg_tys = List.map (ast_type_to_ty ctx env) args in + TyCon (name, arg_tys) + | TFun (t1, t2) -> + let ty1 = ast_type_to_ty ctx env t1 in + let ty2 = ast_type_to_ty ctx env t2 in + TyFun (ty1, ty2) + | TTuple ts -> + let tys = List.map (ast_type_to_ty ctx env) ts in + TyTuple tys + | TRecord fields -> + let field_tys = List.map (fun (name, t) -> + (name, ast_type_to_ty ctx env t)) fields in + TyRecord field_tys + | TVariant cases -> + let case_tys = List.map (fun (name, t_opt) -> + (name, Option.map (ast_type_to_ty ctx env) t_opt)) cases in + TyVariant case_tys + +and infer_pattern ctx env pat = + match pat.pat_desc with + | PWild -> + (fresh_var !current_level, []) + | PVar name -> + let ty = fresh_var !current_level in + (ty, [(name, ty)]) + | PLit lit -> + (infer_literal lit, []) + | PCons (p1, p2) -> + let (ty1, bindings1) = infer_pattern ctx env p1 in + let (ty2, bindings2) = infer_pattern ctx env p2 in + unify ctx pat.pat_span ty2 (TyCon ("list", [ty1])); + (TyCon ("list", [ty1]), bindings1 @ bindings2) + | PList pats -> + let elem_ty = fresh_var !current_level in + let bindings = List.fold_left (fun acc p -> + let (ty, bs) = infer_pattern ctx env p in + unify ctx p.pat_span ty elem_ty; + acc @ bs + ) [] pats in + (TyCon ("list", [elem_ty]), bindings) + | PTuple pats -> + let tys_bindings = List.map (infer_pattern ctx env) pats in + let tys = List.map fst tys_bindings in + let bindings = List.concat (List.map snd tys_bindings) in + (TyTuple tys, bindings) + | PRecord fields -> + let field_tys_bindings = List.map (fun (name, p) -> + let (ty, bs) = infer_pattern ctx env p in + ((name, ty), bs) + ) fields in + let field_tys = List.map fst field_tys_bindings in + let bindings = List.concat (List.map snd 
field_tys_bindings) in + (TyRecord field_tys, bindings) + | PVariant (name, p_opt) -> + (match p_opt with + | Some p -> + let (ty, bindings) = infer_pattern ctx env p in + (fresh_var !current_level, bindings) + | None -> + (fresh_var !current_level, [])) + | POr (p1, p2) -> + let (ty1, bindings1) = infer_pattern ctx env p1 in + let (ty2, bindings2) = infer_pattern ctx env p2 in + unify ctx pat.pat_span ty1 ty2; + (ty1, bindings1) + | PAs (p, name) -> + let (ty, bindings) = infer_pattern ctx env p in + (ty, (name, ty) :: bindings) + | PConstraint (p, ty_expr) -> + let (ty, bindings) = infer_pattern ctx env p in + let expected_ty = ast_type_to_ty ctx env ty_expr in + unify ctx pat.pat_span ty expected_ty; + (ty, bindings) + +and infer_literal = function + | LInt _ -> TyCon ("int", []) + | LFloat _ -> TyCon ("float", []) + | LString _ -> TyCon ("string", []) + | LChar _ -> TyCon ("char", []) + | LBool _ -> TyCon ("bool", []) + | LUnit -> TyCon ("unit", []) + +and infer_expr ctx env expr = + match expr.expr_desc with + | ELit lit -> + infer_literal lit + | EVar name -> + (match lookup_env env name with + | Some scheme -> instantiate !current_level scheme + | None -> + report_error ctx ~span:expr.expr_span TypeError + (Printf.sprintf "Unbound variable: %s" name); + fresh_var !current_level) + | ELambda (params, body) -> + let param_tys_bindings = List.map (infer_pattern ctx env) params in + let param_tys = List.map fst param_tys_bindings in + let all_bindings = List.concat (List.map snd param_tys_bindings) in + let extended_env = List.fold_left (fun e (name, ty) -> + extend_env e name { vars = []; ty }) env all_bindings in + let body_ty = infer_expr ctx extended_env body in + List.fold_right (fun param_ty acc -> TyFun (param_ty, acc)) + param_tys body_ty + | EApp (e1, e2) -> + let ty1 = infer_expr ctx env e1 in + let ty2 = infer_expr ctx env e2 in + let result_ty = fresh_var !current_level in + unify ctx expr.expr_span ty1 (TyFun (ty2, result_ty)); + result_ty + | ELet (pat, e1, e2) -> + let ty1 = infer_expr ctx env e1 in + let (pat_ty, bindings) = infer_pattern ctx env pat in + unify ctx pat.pat_span pat_ty ty1; + let old_level = !current_level in + incr current_level; + let extended_env = List.fold_left (fun e (name, ty) -> + let scheme = generalize old_level ty in + extend_env e name scheme) env bindings in + let ty2 = infer_expr ctx extended_env e2 in + current_level := old_level; + ty2 + | ELetRec (bindings, body) -> + let old_level = !current_level in + incr current_level; + let rec_vars = List.map (fun (name, _) -> + (name, fresh_var !current_level)) bindings in + let extended_env = List.fold_left (fun e (name, ty) -> + extend_env e name { vars = []; ty }) env rec_vars in + List.iter (fun ((name, e), (_, ty)) -> + let inferred_ty = infer_expr ctx extended_env e in + unify ctx e.expr_span ty inferred_ty + ) (List.combine bindings rec_vars); + let final_env = List.fold_left (fun e (name, ty) -> + let scheme = generalize old_level ty in + extend_env e name scheme) env rec_vars in + let body_ty = infer_expr ctx final_env body in + current_level := old_level; + body_ty + | EIf (cond, then_e, else_e) -> + let cond_ty = infer_expr ctx env cond in + unify ctx cond.expr_span cond_ty (TyCon ("bool", [])); + let then_ty = infer_expr ctx env then_e in + let else_ty = infer_expr ctx env else_e in + unify ctx expr.expr_span then_ty else_ty; + then_ty + | EMatch (e, cases) -> + let scrutinee_ty = infer_expr ctx env e in + let result_ty = fresh_var !current_level in + List.iter (fun (pat, guard_opt, 
case_expr) -> + let (pat_ty, bindings) = infer_pattern ctx env pat in + unify ctx pat.pat_span pat_ty scrutinee_ty; + let extended_env = List.fold_left (fun e (name, ty) -> + extend_env e name { vars = []; ty }) env bindings in + (match guard_opt with + | Some guard -> + let guard_ty = infer_expr ctx extended_env guard in + unify ctx guard.expr_span guard_ty (TyCon ("bool", [])) + | None -> ()); + let case_ty = infer_expr ctx extended_env case_expr in + unify ctx case_expr.expr_span case_ty result_ty + ) cases; + result_ty + | ETuple exprs -> + let tys = List.map (infer_expr ctx env) exprs in + TyTuple tys + | EList exprs -> + let elem_ty = fresh_var !current_level in + List.iter (fun e -> + let ty = infer_expr ctx env e in + unify ctx e.expr_span ty elem_ty + ) exprs; + TyCon ("list", [elem_ty]) + | ERecord fields -> + let field_tys = List.map (fun (name, e) -> + (name, infer_expr ctx env e)) fields in + TyRecord field_tys + | ERecordAccess (e, field) -> + let ty = infer_expr ctx env e in + let field_ty = fresh_var !current_level in + field_ty + | ERecordUpdate (base, fields) -> + let base_ty = infer_expr ctx env base in + List.iter (fun (name, e) -> + let _ = infer_expr ctx env e in + () + ) fields; + base_ty + | EVariant (name, e_opt) -> + (match e_opt with + | Some e -> + let ty = infer_expr ctx env e in + fresh_var !current_level + | None -> + fresh_var !current_level) + | EBinop (op, e1, e2) -> + let ty1 = infer_expr ctx env e1 in + let ty2 = infer_expr ctx env e2 in + (match op with + | Add | Sub | Mul | Div | Mod -> + unify ctx e1.expr_span ty1 (TyCon ("int", [])); + unify ctx e2.expr_span ty2 (TyCon ("int", [])); + TyCon ("int", []) + | Eq | Ne | Lt | Le | Gt | Ge -> + unify ctx expr.expr_span ty1 ty2; + TyCon ("bool", []) + | And | Or -> + unify ctx e1.expr_span ty1 (TyCon ("bool", [])); + unify ctx e2.expr_span ty2 (TyCon ("bool", [])); + TyCon ("bool", []) + | Cons -> + unify ctx e2.expr_span ty2 (TyCon ("list", [ty1])); + ty2 + | Concat -> + unify ctx expr.expr_span ty1 ty2; + ty1 + | Pipe -> + let result_ty = fresh_var !current_level in + unify ctx e2.expr_span ty2 (TyFun (ty1, result_ty)); + result_ty + | Compose -> + let middle_ty = fresh_var !current_level in + let result_ty = fresh_var !current_level in + unify ctx e1.expr_span ty1 (TyFun (middle_ty, result_ty)); + unify ctx e2.expr_span ty2 (TyFun (ty1, middle_ty)); + TyFun (ty1, result_ty)) + | EUnop (op, e) -> + let ty = infer_expr ctx env e in + (match op with + | Neg -> + unify ctx e.expr_span ty (TyCon ("int", [])); + TyCon ("int", []) + | Not -> + unify ctx e.expr_span ty (TyCon ("bool", [])); + TyCon ("bool", [])) + | ESequence exprs -> + (match List.rev exprs with + | [] -> TyCon ("unit", []) + | last :: rest -> + List.iter (fun e -> let _ = infer_expr ctx env e in ()) (List.rev rest); + infer_expr ctx env last) + | EConstraint (e, ty_expr) -> + let ty = infer_expr ctx env e in + let expected_ty = ast_type_to_ty ctx env ty_expr in + unify ctx expr.expr_span ty expected_ty; + ty + | EHole -> + fresh_var !current_level + +and string_of_ty ty = + let rec go prec ty = + match ty with + | TyVar {contents = Link ty} -> go prec ty + | TyVar {contents = Unbound (id, _)} -> Printf.sprintf "'t%d" id + | TyCon (name, []) -> name + | TyCon (name, args) -> + Printf.sprintf "%s<%s>" name + (String.concat ", " (List.map (go 0) args)) + | TyFun (t1, t2) -> + let s = Printf.sprintf "%s -> %s" (go 1 t1) (go 0 t2) in + if prec > 0 then "(" ^ s ^ ")" else s + | TyTuple ts -> + "(" ^ String.concat ", " (List.map (go 0) ts) ^ ")" + | 
TyRecord fields -> + "{" ^ String.concat ", " (List.map (fun (n, t) -> + n ^ ": " ^ go 0 t) fields) ^ "}" + | TyVariant cases -> + "<" ^ String.concat " | " (List.map (fun (n, t_opt) -> + match t_opt with + | Some t -> n ^ ": " ^ go 0 t + | None -> n) cases) ^ ">" + in + go 0 ty + +let check_expr ctx env expr = + try + let ty = infer_expr ctx env expr in + Some (ty, expr) + with _ -> None + +let check_declaration ctx env decl = + match decl.decl_desc with + | DLet (is_rec, pat, e) -> + let old_level = !current_level in + incr current_level; + let ty = infer_expr ctx env e in + let (pat_ty, bindings) = infer_pattern ctx env pat in + unify ctx pat.pat_span pat_ty ty; + let extended_env = List.fold_left (fun e (name, ty) -> + let scheme = generalize old_level ty in + extend_env e name scheme) env bindings in + current_level := old_level; + extended_env + | DExpr e -> + let _ = infer_expr ctx env e in + env + | _ -> + env + +let check_program ctx program = + try + let env = empty_env () in + let env = extend_env env "+" { vars = []; ty = TyFun (TyCon ("int", []), + TyFun (TyCon ("int", []), TyCon ("int", []))) } in + + let final_env = List.fold_left (fun e decl -> + check_declaration ctx e decl + ) env program.declarations in + + Some (program, final_env) + with _ -> None diff --git a/src/check.mli b/src/check.mli new file mode 100644 index 0000000..1c3fd79 --- /dev/null +++ b/src/check.mli @@ -0,0 +1,35 @@ +open Ast + +type ty = + | TyVar of type_var ref + | TyCon of string * ty list + | TyFun of ty * ty + | TyTuple of ty list + | TyRecord of (string * ty) list + | TyVariant of (string * ty option) list +[@@deriving show] + +and type_var = + | Unbound of int * int + | Link of ty + +type scheme = { + vars : int list; + ty : ty; +} + +type env + +val empty_env : unit -> env + +val extend_env : env -> string -> scheme -> env + +val lookup_env : env -> string -> scheme option + +val check_program : Error.context -> program -> (program * env) option + +val check_expr : Error.context -> env -> expr -> (ty * expr) option + +val infer_expr : Error.context -> env -> expr -> ty option + +val string_of_ty : ty -> string diff --git a/src/cli.ml b/src/cli.ml new file mode 100644 index 0000000..84b4806 --- /dev/null +++ b/src/cli.ml @@ -0,0 +1,245 @@ +open Cmdliner + +type target = JS | Python + +let target_of_string = function + | "js" | "javascript" -> JS + | "py" | "python" -> Python + | s -> failwith (Printf.sprintf "Unknown target: %s" s) + +type options = { + input_files : string list; + output_file : string option; + target : target; + optimize : bool; + type_check : bool; + emit_ir : bool; + verbose : bool; + no_color : bool; +} + +let compile_file options filename = + try + let ic = open_in filename in + let content = really_input_string ic (in_channel_length ic) in + close_in ic; + + let ctx = Error.create_context () in + Error.register_source ctx filename content; + + if options.verbose then + Printf.printf "Compiling %s...\n" filename; + + let lex = Lex.create ~filename content in + if options.verbose then + Printf.printf " Lexing complete\n"; + + let program = Parse.parse_program ctx lex in + if Error.has_errors ctx then begin + Error.print_diagnostics ctx; + exit 1 + end; + if options.verbose then + Printf.printf " Parsing complete\n"; + + if options.type_check then begin + (match Check.check_program ctx program with + | Some (_, _) -> + if options.verbose then + Printf.printf " Type checking complete\n" + | None -> + if Error.has_errors ctx then begin + Error.print_diagnostics ctx; + exit 1 
+ end else begin + Printf.eprintf "Type checking failed\n"; + exit 1 + end) + end; + + let ir = Ir.program_to_ir program in + if options.verbose then + Printf.printf " IR generation complete\n"; + + if options.emit_ir then begin + Printf.printf "\n=== Intermediate Representation ===\n"; + Printf.printf "%s\n" (Ir.show_ir_program ir) + end; + + let ir_opt = if options.optimize then begin + if options.verbose then + Printf.printf " Optimization complete\n"; + Opt.optimize_program ir + end else ir in + + let target_lang = match options.target with + | JS -> Gen.JavaScript + | Python -> Gen.Python + in + + let output_filename = match options.output_file with + | Some f -> f + | None -> + let base = Filename.remove_extension filename in + match options.target with + | JS -> base ^ ".js" + | Python -> base ^ ".py" + in + + let gen_options = { + Gen.target = target_lang; + output_file = output_filename; + source_map = false; + minify = false; + runtime_path = None; + } in + + let output_code = Gen.generate gen_options ir_opt in + + let oc = open_out output_filename in + output_string oc output_code; + close_out oc; + + if options.verbose then + Printf.printf " Code generation complete: %s\n" output_filename; + + Printf.printf "Successfully compiled %s -> %s\n" filename output_filename; + + 0 + with + | Sys_error msg -> + Printf.eprintf "Error: %s\n" msg; + 1 + | e -> + Printf.eprintf "Internal compiler error: %s\n" (Printexc.to_string e); + if options.verbose then + Printexc.print_backtrace stderr; + 1 + +let input_files = + let doc = "Source files to compile" in + Arg.(non_empty & pos_all file [] & info [] ~docv:"FILES" ~doc) + +let output_file = + let doc = "Output file name" in + Arg.(value & opt (some string) None & info ["o"; "output"] ~docv:"FILE" ~doc) + +let target = + let doc = "Target language (js, python)" in + Arg.(value & opt string "js" & info ["t"; "target"] ~docv:"TARGET" ~doc) + +let optimize = + let doc = "Enable optimizations" in + Arg.(value & flag & info ["O"; "optimize"] ~doc) + +let type_check = + let doc = "Perform type checking" in + Arg.(value & flag & info ["type-check"] ~doc) + +let emit_ir = + let doc = "Emit intermediate representation" in + Arg.(value & flag & info ["emit-ir"] ~doc) + +let verbose = + let doc = "Verbose output" in + Arg.(value & flag & info ["v"; "verbose"] ~doc) + +let no_color = + let doc = "Disable colored output" in + Arg.(value & flag & info ["no-color"] ~doc) + +let compile_cmd = + let doc = "Compile Star source files" in + let man = [ + `S Manpage.s_description; + `P "The Star compiler transpiles Star source code to JavaScript or Python."; + `P "By default, it compiles to JavaScript."; + `S Manpage.s_examples; + `P "Compile to JavaScript:"; + `Pre " star compile program.star"; + `P "Compile to Python with optimizations:"; + `Pre " star compile -t python -O program.star"; + `S Manpage.s_bugs; + `P "Report bugs at https://github.com/star-lang/star/issues"; + ] in + let info = Cmd.info "compile" ~doc ~man in + Cmd.v info Term.(const (fun files output tgt opt tc ir verb nc -> + let options = { + input_files = files; + output_file = output; + target = target_of_string tgt; + optimize = opt; + type_check = tc; + emit_ir = ir; + verbose = verb; + no_color = nc; + } in + if nc then Unix.putenv "NO_COLOR" "1"; + match files with + | [] -> + Printf.eprintf "Error: No input files\n"; + 1 + | [file] -> + compile_file options file + | files -> + let results = List.map (compile_file options) files in + if List.for_all (fun r -> r = 0) results then 0 else 1 
+ ) $ input_files $ output_file $ target $ optimize $ type_check $ + emit_ir $ verbose $ no_color) + +let repl_cmd = + let doc = "Start an interactive REPL" in + let man = [ + `S Manpage.s_description; + `P "Start an interactive Read-Eval-Print Loop for Star."; + ] in + let info = Cmd.info "repl" ~doc ~man in + Cmd.v info Term.(const (fun () -> + Printf.printf "Star REPL (type :quit to exit)\n"; + let ctx = Error.create_context () in + let env = ref (Check.empty_env ()) in + let rec loop () = + Printf.printf "> "; + flush stdout; + match input_line stdin with + | exception End_of_file -> 0 + | ":quit" | ":q" -> 0 + | ":help" | ":h" -> + Printf.printf "Commands:\n"; + Printf.printf " :quit, :q - Exit REPL\n"; + Printf.printf " :help, :h - Show this help\n"; + Printf.printf " :type - Show type of expression\n"; + loop () + | line when String.length line > 0 -> + let lex = Lex.create ~filename:"" line in + (match Parse.parse_expr ctx lex with + | Some expr -> + (match Check.infer_expr ctx !env expr with + | Some ty -> + Printf.printf ": %s\n" (Check.string_of_ty ty) + | None -> + Error.print_diagnostics ctx; + Error.clear_diagnostics ctx) + | None -> + Error.print_diagnostics ctx; + Error.clear_diagnostics ctx); + loop () + | _ -> loop () + in + loop () + ) $ const ()) + +let default_cmd = + let doc = "Star programming language compiler" in + let man = [ + `S Manpage.s_description; + `P "Star is a functional programming language that compiles to JavaScript and Python."; + `S Manpage.s_commands; + `P "Use $(b,star COMMAND --help) for help on a specific command."; + ] in + let info = Cmd.info "star" ~version:"0.1.0" ~doc ~man in + let default = Term.(ret (const (`Help (`Pager, None)))) in + Cmd.group info ~default [compile_cmd; repl_cmd] + +let run () = + exit (Cmd.eval default_cmd) diff --git a/src/error.ml b/src/error.ml new file mode 100644 index 0000000..3866a3d --- /dev/null +++ b/src/error.ml @@ -0,0 +1,286 @@ +type position = { + line : int; + col : int; + offset : int; +} +[@@deriving show, eq] + +type span = { + start : position; + end_ : position; + file : string; +} +[@@deriving show, eq] + +type severity = + | Error + | Warning + | Info + | Hint +[@@deriving show, eq] + +type error_kind = + | LexError + | ParseError + | TypeError + | IRError + | CodeGenError + | InternalError +[@@deriving show, eq] + +type diagnostic = { + severity : severity; + kind : error_kind; + message : string; + span : span option; + notes : (string * span option) list; + fix : string option; +} +[@@deriving show] + +type context = { + mutable diagnostics : diagnostic list; + mutable error_count : int; + mutable warning_count : int; + source_files : (string, string) Hashtbl.t; +} + +let create_context () : context = { + diagnostics = []; + error_count = 0; + warning_count = 0; + source_files = Hashtbl.create 16; +} + +let register_source ctx filename content = + Hashtbl.replace ctx.source_files filename content + +let add_diagnostic ctx diag = + ctx.diagnostics <- diag :: ctx.diagnostics; + match diag.severity with + | Error -> ctx.error_count <- ctx.error_count + 1 + | Warning -> ctx.warning_count <- ctx.warning_count + 1 + | _ -> () + +let error ?span ?notes ?fix kind message = + { + severity = Error; + kind; + message; + span; + notes = Option.value notes ~default:[]; + fix; + } + +let warning ?span ?notes ?fix kind message = + { + severity = Warning; + kind; + message; + span; + notes = Option.value notes ~default:[]; + fix; + } + +let report_error ctx ?span ?notes ?fix kind message = + add_diagnostic ctx 
(error ?span ?notes ?fix kind message) + +let report_warning ctx ?span ?notes ?fix kind message = + add_diagnostic ctx (warning ?span ?notes ?fix kind message) + +let has_errors ctx = ctx.error_count > 0 + +let get_source_lines filename start_line end_line = + try + let ic = open_in filename in + let rec read_lines acc line_num = + if line_num > end_line then ( + close_in ic; + List.rev acc + ) else + match input_line ic with + | line -> + let acc' = + if line_num >= start_line then (line_num, line) :: acc + else acc + in + read_lines acc' (line_num + 1) + | exception End_of_file -> + close_in ic; + List.rev acc + in + read_lines [] 1 + with Sys_error _ -> [] + +let get_cached_source_lines ctx filename start_line end_line = + match Hashtbl.find_opt ctx.source_files filename with + | Some content -> + let lines = String.split_on_char '\n' content in + let indexed = + List.mapi (fun i line -> (i + 1, line)) lines + |> List.filter (fun (num, _) -> num >= start_line && num <= end_line) + in + indexed + | None -> get_source_lines filename start_line end_line + +let use_color () = + match Sys.getenv_opt "NO_COLOR" with + | Some _ -> false + | None -> ( + match Sys.getenv_opt "TERM" with + | Some term when term <> "dumb" -> Unix.isatty Unix.stdout + | _ -> false) + +let color_reset = "\027[0m" +let color_red = "\027[31m" +let color_yellow = "\027[33m" +let color_blue = "\027[34m" +let color_cyan = "\027[36m" +let color_bold = "\027[1m" + +let with_color color s = + if use_color () then color ^ s ^ color_reset else s + +let format_severity = function + | Error -> with_color (color_bold ^ color_red) "error" + | Warning -> with_color (color_bold ^ color_yellow) "warning" + | Info -> with_color (color_bold ^ color_blue) "info" + | Hint -> with_color (color_bold ^ color_cyan) "hint" + +let format_kind = function + | LexError -> "lex" + | ParseError -> "parse" + | TypeError -> "type" + | IRError -> "ir" + | CodeGenError -> "codegen" + | InternalError -> "internal" + +let format_position pos = + Printf.sprintf "%d:%d" pos.line (pos.col + 1) + +let format_span span = + Printf.sprintf "%s:%s" span.file (format_position span.start) + +let format_source_context ctx span = + let start_line = max 1 (span.start.line - 3) in + let end_line = span.end_.line + 3 in + let lines = get_cached_source_lines ctx span.file start_line end_line in + + let max_line_num = List.fold_left (fun acc (num, _) -> max acc num) 0 lines in + let line_num_width = String.length (string_of_int max_line_num) in + + let format_line (line_num, line_text) = + let line_num_str = Printf.sprintf "%*d" line_num_width line_num in + let is_error_line = line_num >= span.start.line && line_num <= span.end_.line in + + if is_error_line then + let prefix = with_color color_blue (line_num_str ^ " | ") in + let underline = + if line_num = span.start.line && line_num = span.end_.line then + let spaces = String.make (line_num_width + 3 + span.start.col) ' ' in + let carets = String.make (max 1 (span.end_.col - span.start.col)) '^' in + spaces ^ with_color color_red carets + else if line_num = span.start.line then + let spaces = String.make (line_num_width + 3 + span.start.col) ' ' in + let carets = String.make (String.length line_text - span.start.col) '^' in + spaces ^ with_color color_red carets + else if line_num = span.end_.line then + let spaces = String.make (line_num_width + 3) ' ' in + let carets = String.make span.end_.col '^' in + spaces ^ with_color color_red carets + else + let spaces = String.make (line_num_width + 3) ' ' in + let carets = 
String.make (String.length line_text) '^' in + spaces ^ with_color color_red carets + in + prefix ^ line_text ^ "\n" ^ underline + else + let prefix = with_color color_blue (line_num_str ^ " | ") in + prefix ^ line_text + in + + String.concat "\n" (List.map format_line lines) + +let format_diagnostic ctx diag = + let buf = Buffer.create 256 in + + let header = + match diag.span with + | Some span -> + Printf.sprintf "%s: [%s] %s" + (format_severity diag.severity) + (format_kind diag.kind) + (format_span span) + | None -> + Printf.sprintf "%s: [%s]" + (format_severity diag.severity) + (format_kind diag.kind) + in + Buffer.add_string buf (with_color color_bold header); + Buffer.add_char buf '\n'; + + Buffer.add_string buf (" " ^ diag.message); + Buffer.add_char buf '\n'; + + (match diag.span with + | Some span -> + Buffer.add_char buf '\n'; + Buffer.add_string buf (format_source_context ctx span); + Buffer.add_char buf '\n' + | None -> ()); + + List.iter (fun (note, span_opt) -> + Buffer.add_string buf ("\n " ^ with_color color_cyan "note:" ^ " " ^ note); + Buffer.add_char buf '\n'; + match span_opt with + | Some span -> + Buffer.add_string buf (format_source_context ctx span); + Buffer.add_char buf '\n' + | None -> () + ) diag.notes; + + (match diag.fix with + | Some fix -> + Buffer.add_string buf ("\n " ^ with_color color_cyan "help:" ^ " " ^ fix); + Buffer.add_char buf '\n' + | None -> ()); + + Buffer.contents buf + +let format_all_diagnostics ctx = + let diags = List.rev ctx.diagnostics in + String.concat "\n\n" (List.map (format_diagnostic ctx) diags) + +let print_diagnostics ctx = + if ctx.diagnostics <> [] then begin + prerr_endline (format_all_diagnostics ctx); + prerr_endline ""; + + let summary = + match (ctx.error_count, ctx.warning_count) with + | (0, 0) -> "" + | (e, 0) -> + with_color color_red (Printf.sprintf "%d error%s" e (if e = 1 then "" else "s")) + | (0, w) -> + with_color color_yellow (Printf.sprintf "%d warning%s" w (if w = 1 then "" else "s")) + | (e, w) -> + with_color color_red (Printf.sprintf "%d error%s" e (if e = 1 then "" else "s")) ^ + " and " ^ + with_color color_yellow (Printf.sprintf "%d warning%s" w (if w = 1 then "" else "s")) + in + if summary <> "" then + prerr_endline (with_color color_bold ("Compilation finished with " ^ summary)) + end + +let clear_diagnostics ctx = + ctx.diagnostics <- []; + ctx.error_count <- 0; + ctx.warning_count <- 0 + +let make_position ~line ~col ~offset = { line; col; offset } + +let make_span ~start ~end_ ~file = { start; end_; file } + +let merge_spans s1 s2 = + if s1.file <> s2.file then s1 + else { file = s1.file; start = s1.start; end_ = s2.end_ } diff --git a/src/gen.ml b/src/gen.ml new file mode 100644 index 0000000..9f83bb2 --- /dev/null +++ b/src/gen.ml @@ -0,0 +1,419 @@ +open Ir +open Printf + +type target = JavaScript | Python + +type gen_options = { + target : target; + output_file : string; + source_map : bool; + minify : bool; + runtime_path : string option; +} + +type indent = int + +type pp_buf = { + mutable content : string list; + mutable indent_level : int; + indent_str : string; +} + +let create_buf ?(indent=" ") () = + { content = []; indent_level = 0; indent_str = indent } + +let emit buf s = + buf.content <- s :: buf.content + +let emit_line buf s = + let indent = String.make (buf.indent_level * String.length buf.indent_str) ' ' in + buf.content <- ("\n" :: s :: indent :: buf.content) + +let increase_indent buf = + buf.indent_level <- buf.indent_level + 1 + +let decrease_indent buf = + buf.indent_level 
<- max 0 (buf.indent_level - 1) + +let get_output buf = + String.concat "" (List.rev buf.content) + +let mangle_var (name, id) = + if id = 0 then name + else sprintf "%s_%d" name id + +module JavaScript = struct + let rec gen_value buf = function + | IVar var -> emit buf (mangle_var var) + | IInt n -> emit buf (Int64.to_string n) + | IFloat f -> emit buf (string_of_float f) + | IString s -> emit buf (sprintf "\"%s\"" (String.escaped s)) + | IChar c -> emit buf (sprintf "\"%s\"" (Char.escaped (Char.chr (min c 255)))) + | IBool b -> emit buf (if b then "true" else "false") + | IUnit -> emit buf "undefined" + + let rec gen_simple buf = function + | SValue v -> gen_value buf v + | SLambda lambda -> gen_lambda buf lambda + | STuple vars -> + emit buf "["; + gen_comma_list buf vars (fun buf v -> emit buf (mangle_var v)); + emit buf "]" + | SRecord fields -> + emit buf "{"; + gen_comma_list buf fields (fun buf (name, var) -> + emit buf (sprintf "\"%s\": %s" name (mangle_var var))); + emit buf "}" + | SVariant (name, Some var) -> + emit buf (sprintf "{tag: \"%s\", value: %s}" name (mangle_var var)) + | SVariant (name, None) -> + emit buf (sprintf "{tag: \"%s\"}" name) + | SPrim (op, vars) -> gen_prim buf op vars + | SApp (f, args) -> + emit buf (mangle_var f); + emit buf "("; + gen_comma_list buf args (fun buf v -> emit buf (mangle_var v)); + emit buf ")" + | SRecordAccess (var, field) -> + emit buf (sprintf "%s.%s" (mangle_var var) field) + + and gen_lambda buf lambda = + emit buf "("; + gen_comma_list buf lambda.params (fun buf v -> emit buf (mangle_var v)); + emit buf ") => {"; + increase_indent buf; + gen_expr buf lambda.body; + decrease_indent buf; + emit_line buf "}" + + and gen_prim buf op vars = + match op, vars with + | PAdd, [v1; v2] -> emit buf (sprintf "(%s + %s)" (mangle_var v1) (mangle_var v2)) + | PSub, [v1; v2] -> emit buf (sprintf "(%s - %s)" (mangle_var v1) (mangle_var v2)) + | PMul, [v1; v2] -> emit buf (sprintf "(%s * %s)" (mangle_var v1) (mangle_var v2)) + | PDiv, [v1; v2] -> emit buf (sprintf "(%s / %s)" (mangle_var v1) (mangle_var v2)) + | PMod, [v1; v2] -> emit buf (sprintf "(%s %% %s)" (mangle_var v1) (mangle_var v2)) + | PEq, [v1; v2] -> emit buf (sprintf "(%s === %s)" (mangle_var v1) (mangle_var v2)) + | PNe, [v1; v2] -> emit buf (sprintf "(%s !== %s)" (mangle_var v1) (mangle_var v2)) + | PLt, [v1; v2] -> emit buf (sprintf "(%s < %s)" (mangle_var v1) (mangle_var v2)) + | PLe, [v1; v2] -> emit buf (sprintf "(%s <= %s)" (mangle_var v1) (mangle_var v2)) + | PGt, [v1; v2] -> emit buf (sprintf "(%s > %s)" (mangle_var v1) (mangle_var v2)) + | PGe, [v1; v2] -> emit buf (sprintf "(%s >= %s)" (mangle_var v1) (mangle_var v2)) + | PAnd, [v1; v2] -> emit buf (sprintf "(%s && %s)" (mangle_var v1) (mangle_var v2)) + | POr, [v1; v2] -> emit buf (sprintf "(%s || %s)" (mangle_var v1) (mangle_var v2)) + | PNot, [v] -> emit buf (sprintf "(!%s)" (mangle_var v)) + | PNeg, [v] -> emit buf (sprintf "(-%s)" (mangle_var v)) + | PCons, [v1; v2] -> emit buf (sprintf "[%s, ...%s]" (mangle_var v1) (mangle_var v2)) + | PConcat, [v1; v2] -> emit buf (sprintf "(%s + %s)" (mangle_var v1) (mangle_var v2)) + | _ -> emit buf "undefined" + + and gen_expr buf = function + | IValue v -> + emit_line buf "return "; + gen_value buf v; + emit buf ";" + | ILet (var, simple, body) -> + emit_line buf (sprintf "const %s = " (mangle_var var)); + gen_simple buf simple; + emit buf ";"; + gen_expr buf body + | ILetRec (bindings, body) -> + List.iter (fun (var, lambda) -> + emit_line buf (sprintf "const %s = " (mangle_var 
var)); + gen_lambda buf lambda; + emit buf ";" + ) bindings; + gen_expr buf body + | IApp (f, args) -> + emit_line buf "return "; + emit buf (mangle_var f); + emit buf "("; + gen_comma_list buf args (fun buf v -> emit buf (mangle_var v)); + emit buf ");" + | IIf (cond, then_e, else_e) -> + emit_line buf (sprintf "if (%s) {" (mangle_var cond)); + increase_indent buf; + gen_expr buf then_e; + decrease_indent buf; + emit_line buf "} else {"; + increase_indent buf; + gen_expr buf else_e; + decrease_indent buf; + emit_line buf "}" + | IMatch (scrutinee, cases) -> + emit_line buf (sprintf "switch (%s.tag) {" (mangle_var scrutinee)); + increase_indent buf; + List.iter (fun (pat, case_e) -> + gen_match_case buf scrutinee pat case_e + ) cases; + decrease_indent buf; + emit_line buf "}" + | ITuple vars -> + emit_line buf "return ["; + gen_comma_list buf vars (fun buf v -> emit buf (mangle_var v)); + emit buf "];" + | IRecord fields -> + emit_line buf "return {"; + gen_comma_list buf fields (fun buf (name, var) -> + emit buf (sprintf "\"%s\": %s" name (mangle_var var))); + emit buf "};" + | IRecordAccess (var, field) -> + emit_line buf (sprintf "return %s.%s;" (mangle_var var) field) + | IVariant (name, Some var) -> + emit_line buf (sprintf "return {tag: \"%s\", value: %s};" name (mangle_var var)) + | IVariant (name, None) -> + emit_line buf (sprintf "return {tag: \"%s\"};" name) + | IPrim (op, vars) -> + emit_line buf "return "; + gen_prim buf op vars; + emit buf ";" + + and gen_match_case buf scrutinee pat case_e = + match pat with + | IPVariant (tag, _) -> + emit_line buf (sprintf "case \"%s\": {" tag); + increase_indent buf; + gen_expr buf case_e; + emit_line buf "break;"; + decrease_indent buf; + emit_line buf "}" + | _ -> + emit_line buf "default: {"; + increase_indent buf; + gen_expr buf case_e; + decrease_indent buf; + emit_line buf "}" + + and gen_comma_list buf items f = + match items with + | [] -> () + | [x] -> f buf x + | x :: xs -> + f buf x; + List.iter (fun item -> emit buf ", "; f buf item) xs + + let gen_decl buf = function + | IRLet (var, simple) -> + emit_line buf (sprintf "const %s = " (mangle_var var)); + gen_simple buf simple; + emit buf ";" + | IRLetRec bindings -> + List.iter (fun (var, lambda) -> + emit_line buf (sprintf "const %s = " (mangle_var var)); + gen_lambda buf lambda; + emit buf ";" + ) bindings + + let gen_program ?(minify=false) program = + let buf = create_buf ~indent:(if minify then "" else " ") () in + emit_line buf "// Generated by Star compiler"; + emit_line buf "\"use strict\";"; + emit_line buf ""; + List.iter (gen_decl buf) program.decls; + (match program.entry with + | Some entry -> + emit_line buf ""; + emit_line buf "// Entry point"; + emit_line buf "(function() {"; + increase_indent buf; + gen_expr buf entry; + decrease_indent buf; + emit_line buf "})();" + | None -> ()); + get_output buf +end + +module Python = struct + let rec gen_value buf = function + | IVar var -> emit buf (mangle_var var) + | IInt n -> emit buf (Int64.to_string n) + | IFloat f -> emit buf (string_of_float f) + | IString s -> emit buf (sprintf "\"%s\"" (String.escaped s)) + | IChar c -> emit buf (sprintf "\"%s\"" (Char.escaped (Char.chr (min c 255)))) + | IBool b -> emit buf (if b then "True" else "False") + | IUnit -> emit buf "None" + + let rec gen_simple buf = function + | SValue v -> gen_value buf v + | SLambda lambda -> gen_lambda buf lambda + | STuple vars -> + emit buf "("; + gen_comma_list buf vars (fun buf v -> emit buf (mangle_var v)); + emit buf (")" ^ (if List.length 
vars = 1 then "," else "")) + | SRecord fields -> + emit buf "{"; + gen_comma_list buf fields (fun buf (name, var) -> + emit buf (sprintf "\"%s\": %s" name (mangle_var var))); + emit buf "}" + | SVariant (name, Some var) -> + emit buf (sprintf "{\"tag\": \"%s\", \"value\": %s}" name (mangle_var var)) + | SVariant (name, None) -> + emit buf (sprintf "{\"tag\": \"%s\"}" name) + | SPrim (op, vars) -> gen_prim buf op vars + | SApp (f, args) -> + emit buf (mangle_var f); + emit buf "("; + gen_comma_list buf args (fun buf v -> emit buf (mangle_var v)); + emit buf ")" + | SRecordAccess (var, field) -> + emit buf (sprintf "%s[\"%s\"]" (mangle_var var) field) + + and gen_lambda buf lambda = + emit buf "lambda "; + gen_comma_list buf lambda.params (fun buf v -> emit buf (mangle_var v)); + emit buf ": "; + gen_lambda_body buf lambda.body + + and gen_lambda_body buf expr = + match expr with + | IValue v -> gen_value buf v + | _ -> + emit buf "("; + gen_expr_inline buf expr; + emit buf ")" + + and gen_prim buf op vars = + match op, vars with + | PAdd, [v1; v2] -> emit buf (sprintf "(%s + %s)" (mangle_var v1) (mangle_var v2)) + | PSub, [v1; v2] -> emit buf (sprintf "(%s - %s)" (mangle_var v1) (mangle_var v2)) + | PMul, [v1; v2] -> emit buf (sprintf "(%s * %s)" (mangle_var v1) (mangle_var v2)) + | PDiv, [v1; v2] -> emit buf (sprintf "(%s // %s)" (mangle_var v1) (mangle_var v2)) + | PMod, [v1; v2] -> emit buf (sprintf "(%s %% %s)" (mangle_var v1) (mangle_var v2)) + | PEq, [v1; v2] -> emit buf (sprintf "(%s == %s)" (mangle_var v1) (mangle_var v2)) + | PNe, [v1; v2] -> emit buf (sprintf "(%s != %s)" (mangle_var v1) (mangle_var v2)) + | PLt, [v1; v2] -> emit buf (sprintf "(%s < %s)" (mangle_var v1) (mangle_var v2)) + | PLe, [v1; v2] -> emit buf (sprintf "(%s <= %s)" (mangle_var v1) (mangle_var v2)) + | PGt, [v1; v2] -> emit buf (sprintf "(%s > %s)" (mangle_var v1) (mangle_var v2)) + | PGe, [v1; v2] -> emit buf (sprintf "(%s >= %s)" (mangle_var v1) (mangle_var v2)) + | PAnd, [v1; v2] -> emit buf (sprintf "(%s and %s)" (mangle_var v1) (mangle_var v2)) + | POr, [v1; v2] -> emit buf (sprintf "(%s or %s)" (mangle_var v1) (mangle_var v2)) + | PNot, [v] -> emit buf (sprintf "(not %s)" (mangle_var v)) + | PNeg, [v] -> emit buf (sprintf "(-%s)" (mangle_var v)) + | PCons, [v1; v2] -> emit buf (sprintf "[%s] + %s" (mangle_var v1) (mangle_var v2)) + | PConcat, [v1; v2] -> emit buf (sprintf "(%s + %s)" (mangle_var v1) (mangle_var v2)) + | _ -> emit buf "None" + + and gen_expr_inline buf = function + | IValue v -> gen_value buf v + | _ -> emit buf "None" + + and gen_expr buf = function + | IValue v -> + emit_line buf "return "; + gen_value buf v + | ILet (var, simple, body) -> + emit_line buf (sprintf "%s = " (mangle_var var)); + gen_simple buf simple; + gen_expr buf body + | ILetRec (bindings, body) -> + List.iter (fun (var, lambda) -> + emit_line buf (sprintf "def %s(" (mangle_var var)); + gen_comma_list buf lambda.params (fun buf v -> emit buf (mangle_var v)); + emit buf "):"; + increase_indent buf; + gen_expr buf lambda.body; + decrease_indent buf + ) bindings; + gen_expr buf body + | IApp (f, args) -> + emit_line buf "return "; + emit buf (mangle_var f); + emit buf "("; + gen_comma_list buf args (fun buf v -> emit buf (mangle_var v)); + emit buf ")" + | IIf (cond, then_e, else_e) -> + emit_line buf (sprintf "if %s:" (mangle_var cond)); + increase_indent buf; + gen_expr buf then_e; + decrease_indent buf; + emit_line buf "else:"; + increase_indent buf; + gen_expr buf else_e; + decrease_indent buf + | IMatch (scrutinee, 
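+ (* There is no direct Python counterpart for pattern matching: the
+    scrutinee is copied into a local _match variable and every case is
+    emitted as an if/elif branch on its "tag" field (see gen_match_case
+    below). *)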
cases) -> + emit_line buf (sprintf "_match = %s" (mangle_var scrutinee)); + List.iteri (fun i (pat, case_e) -> + let cond = if i = 0 then "if" else "elif" in + gen_match_case buf cond scrutinee pat case_e + ) cases + | ITuple vars -> + emit_line buf "return ("; + gen_comma_list buf vars (fun buf v -> emit buf (mangle_var v)); + emit buf (")" ^ (if List.length vars = 1 then "," else "")) + | IRecord fields -> + emit_line buf "return {"; + gen_comma_list buf fields (fun buf (name, var) -> + emit buf (sprintf "\"%s\": %s" name (mangle_var var))); + emit buf "}" + | IRecordAccess (var, field) -> + emit_line buf (sprintf "return %s[\"%s\"]" (mangle_var var) field) + | IVariant (name, Some var) -> + emit_line buf (sprintf "return {\"tag\": \"%s\", \"value\": %s}" name (mangle_var var)) + | IVariant (name, None) -> + emit_line buf (sprintf "return {\"tag\": \"%s\"}" name) + | IPrim (op, vars) -> + emit_line buf "return "; + gen_prim buf op vars + + and gen_match_case buf cond scrutinee pat case_e = + match pat with + | IPVariant (tag, _) -> + emit_line buf (sprintf "%s _match[\"tag\"] == \"%s\":" cond tag); + increase_indent buf; + gen_expr buf case_e; + decrease_indent buf + | _ -> + emit_line buf "else:"; + increase_indent buf; + gen_expr buf case_e; + decrease_indent buf + + and gen_comma_list buf items f = + match items with + | [] -> () + | [x] -> f buf x + | x :: xs -> + f buf x; + List.iter (fun item -> emit buf ", "; f buf item) xs + + let gen_decl buf = function + | IRLet (var, simple) -> + emit_line buf (sprintf "%s = " (mangle_var var)); + gen_simple buf simple + | IRLetRec bindings -> + List.iter (fun (var, lambda) -> + emit_line buf (sprintf "def %s(" (mangle_var var)); + gen_comma_list buf lambda.params (fun buf v -> emit buf (mangle_var v)); + emit buf "):"; + increase_indent buf; + gen_expr buf lambda.body; + decrease_indent buf; + emit_line buf "" + ) bindings + + let gen_program program = + let buf = create_buf ~indent:" " () in + emit_line buf "# Generated by Star compiler"; + emit_line buf ""; + List.iter (gen_decl buf) program.decls; + (match program.entry with + | Some entry -> + emit_line buf ""; + emit_line buf "if __name__ == \"__main__\":"; + increase_indent buf; + gen_expr buf entry; + decrease_indent buf + | None -> ()); + get_output buf +end + +let generate options program = + match options.target with + | JavaScript -> JavaScript.gen_program ~minify:options.minify program + | Python -> Python.gen_program program + +let generate_js ~source_map ~minify program = + JavaScript.gen_program ~minify program + +let generate_py program = + Python.gen_program program diff --git a/src/gen.mli b/src/gen.mli new file mode 100644 index 0000000..408d1c1 --- /dev/null +++ b/src/gen.mli @@ -0,0 +1,17 @@ +open Ir + +type target = JavaScript | Python + +type gen_options = { + target : target; + output_file : string; + source_map : bool; + minify : bool; + runtime_path : string option; +} + +val generate : gen_options -> ir_program -> string + +val generate_js : source_map:bool -> minify:bool -> ir_program -> string + +val generate_py : ir_program -> string diff --git a/src/ir.ml b/src/ir.ml new file mode 100644 index 0000000..dafeb2b --- /dev/null +++ b/src/ir.ml @@ -0,0 +1,340 @@ +open Ast + +type ir_var = string * int +[@@deriving show, eq] + +type ir_value = + | IVar of ir_var + | IInt of int64 + | IFloat of float + | IString of string + | IChar of int + | IBool of bool + | IUnit +[@@deriving show] + +type ir_expr = + | IValue of ir_value + | ILet of ir_var * ir_simple * ir_expr + | 
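+ (* The IR is in A-normal form: compound values are always bound by a let
+    first, so constructors such as IApp, IIf, IMatch and IPrim take only
+    variables as operands. *)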
ILetRec of (ir_var * ir_lambda) list * ir_expr + | IApp of ir_var * ir_var list + | IIf of ir_var * ir_expr * ir_expr + | IMatch of ir_var * (ir_pattern * ir_expr) list + | ITuple of ir_var list + | IRecord of (string * ir_var) list + | IRecordAccess of ir_var * string + | IVariant of string * ir_var option + | IPrim of prim_op * ir_var list +[@@deriving show] + +and ir_simple = + | SValue of ir_value + | SLambda of ir_lambda + | STuple of ir_var list + | SRecord of (string * ir_var) list + | SVariant of string * ir_var option + | SPrim of prim_op * ir_var list + | SApp of ir_var * ir_var list + | SRecordAccess of ir_var * string +[@@deriving show] + +and ir_lambda = { + params : ir_var list; + free_vars : ir_var list; + body : ir_expr; +} +[@@deriving show] + +and ir_pattern = + | IPWild + | IPVar of ir_var + | IPLit of ir_value + | IPCons of ir_pattern * ir_pattern + | IPTuple of ir_pattern list + | IPRecord of (string * ir_pattern) list + | IPVariant of string * ir_pattern option +[@@deriving show] + +and prim_op = + | PAdd | PSub | PMul | PDiv | PMod + | PEq | PNe | PLt | PLe | PGt | PGe + | PAnd | POr | PNot | PNeg + | PCons | PConcat +[@@deriving show] + +type ir_decl = + | IRLet of ir_var * ir_simple + | IRLetRec of (ir_var * ir_lambda) list +[@@deriving show] + +type ir_program = { + decls : ir_decl list; + entry : ir_expr option; +} +[@@deriving show] + +type ir_context = { + mutable counter : int; + mutable bindings : (string * ir_var) list; +} + +let fresh_var ctx name = + ctx.counter <- ctx.counter + 1; + (name, ctx.counter) + +let lookup_var ctx name = + try List.assoc name ctx.bindings + with Not_found -> (name, 0) + +let bind_var ctx name var = + ctx.bindings <- (name, var) :: ctx.bindings + +let binop_to_prim = function + | Add -> PAdd | Sub -> PSub | Mul -> PMul | Div -> PDiv | Mod -> PMod + | Eq -> PEq | Ne -> PNe | Lt -> PLt | Le -> PLe | Gt -> PGt | Ge -> PGe + | And -> PAnd | Or -> POr + | Cons -> PCons | Concat -> PConcat + | Pipe -> failwith "Pipe should be desugared" + | Compose -> failwith "Compose should be desugared" + +let unop_to_prim = function + | Neg -> PNeg + | Not -> PNot + +let literal_to_ir = function + | LInt n -> IInt n + | LFloat f -> IFloat f + | LString s -> IString s + | LChar c -> IChar c + | LBool b -> IBool b + | LUnit -> IUnit + +let rec pattern_to_ir ctx pat = + match pat.pat_desc with + | PWild -> (IPWild, []) + | PVar name -> + let var = fresh_var ctx name in + bind_var ctx name var; + (IPVar var, [var]) + | PLit lit -> (IPLit (literal_to_ir lit), []) + | PCons (p1, p2) -> + let (ip1, vars1) = pattern_to_ir ctx p1 in + let (ip2, vars2) = pattern_to_ir ctx p2 in + (IPCons (ip1, ip2), vars1 @ vars2) + | PList _ -> failwith "List patterns should be desugared to cons" + | PTuple pats -> + let ips_vars = List.map (pattern_to_ir ctx) pats in + let ips = List.map fst ips_vars in + let vars = List.concat (List.map snd ips_vars) in + (IPTuple ips, vars) + | PRecord fields -> + let field_ips_vars = List.map (fun (name, p) -> + let (ip, vars) = pattern_to_ir ctx p in + ((name, ip), vars)) fields in + let field_ips = List.map fst field_ips_vars in + let vars = List.concat (List.map snd field_ips_vars) in + (IPRecord field_ips, vars) + | PVariant (name, p_opt) -> + (match p_opt with + | Some p -> + let (ip, vars) = pattern_to_ir ctx p in + (IPVariant (name, Some ip), vars) + | None -> (IPVariant (name, None), [])) + | POr _ -> failwith "Or patterns should be desugared" + | PAs _ -> failwith "As patterns should be desugared" + | PConstraint (p, _) -> 
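+ (* Type annotations on patterns carry no runtime content, so the
+    constraint is dropped and only the underlying pattern is lowered. *)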
pattern_to_ir ctx p + +let rec expr_to_anf ctx expr cont = + match expr.expr_desc with + | ELit lit -> + cont (IValue (literal_to_ir lit)) + | EVar name -> + let var = lookup_var ctx name in + cont (IValue (IVar var)) + | ELambda (params, body) -> + let old_bindings = ctx.bindings in + let param_vars = List.map (fun p -> + match p.pat_desc with + | PVar name -> fresh_var ctx name + | _ -> failwith "Complex patterns in lambda params not yet supported" + ) params in + List.iter2 (fun p v -> + match p.pat_desc with + | PVar name -> bind_var ctx name v + | _ -> () + ) params param_vars; + let body_ir = expr_to_anf ctx body (fun e -> e) in + let free_vars = [] in + ctx.bindings <- old_bindings; + let lambda = { params = param_vars; free_vars; body = body_ir } in + let tmp = fresh_var ctx "lambda" in + ILet (tmp, SLambda lambda, cont (IValue (IVar tmp))) + | EApp (e1, e2) -> + expr_to_anf ctx e1 (fun v1 -> + expr_to_anf ctx e2 (fun v2 -> + match (v1, v2) with + | (IValue (IVar f), IValue (IVar arg)) -> + let tmp = fresh_var ctx "app" in + ILet (tmp, SApp (f, [arg]), cont (IValue (IVar tmp))) + | _ -> failwith "Expected variables in application")) + | ELet (pat, e1, e2) -> + expr_to_anf ctx e1 (fun v1 -> + let (ip, vars) = pattern_to_ir ctx pat in + let var = match vars with [v] -> v | _ -> fresh_var ctx "let" in + (match v1 with + | IValue val1 -> + ILet (var, SValue val1, expr_to_anf ctx e2 cont) + | _ -> failwith "Expected value")) + | ELetRec (bindings, body) -> + let rec_bindings = List.map (fun (name, e) -> + let var = fresh_var ctx name in + bind_var ctx name var; + (var, e) + ) bindings in + let ir_bindings = List.map (fun (var, e) -> + match e.expr_desc with + | ELambda (params, lambda_body) -> + let param_vars = List.map (fun p -> + match p.pat_desc with + | PVar pname -> fresh_var ctx pname + | _ -> failwith "Complex lambda params" + ) params in + let body_ir = expr_to_anf ctx lambda_body (fun e -> e) in + (var, { params = param_vars; free_vars = []; body = body_ir }) + | _ -> failwith "Let rec requires lambda" + ) rec_bindings in + ILetRec (ir_bindings, expr_to_anf ctx body cont) + | EIf (cond, then_e, else_e) -> + expr_to_anf ctx cond (fun v_cond -> + match v_cond with + | IValue (IVar cond_var) -> + let then_ir = expr_to_anf ctx then_e (fun e -> e) in + let else_ir = expr_to_anf ctx else_e (fun e -> e) in + let result_var = fresh_var ctx "if_result" in + ILet (result_var, SValue IUnit, + IIf (cond_var, + (match then_ir with + | IValue v -> ILet (result_var, SValue v, cont (IValue (IVar result_var))) + | _ -> then_ir), + (match else_ir with + | IValue v -> ILet (result_var, SValue v, cont (IValue (IVar result_var))) + | _ -> else_ir))) + | _ -> failwith "Expected variable for condition") + | EMatch (e, cases) -> + expr_to_anf ctx e (fun v -> + match v with + | IValue (IVar scrutinee) -> + let ir_cases = List.map (fun (pat, guard_opt, case_e) -> + let (ip, _) = pattern_to_ir ctx pat in + let case_ir = expr_to_anf ctx case_e (fun e -> e) in + (ip, case_ir) + ) cases in + let result_var = fresh_var ctx "match_result" in + ILet (result_var, SValue IUnit, IMatch (scrutinee, ir_cases)) + | _ -> failwith "Expected variable for scrutinee") + | ETuple exprs -> + normalize_list ctx exprs (fun vars -> + let tmp = fresh_var ctx "tuple" in + ILet (tmp, STuple vars, cont (IValue (IVar tmp)))) + | EList _ -> failwith "Lists should be desugared" + | ERecord fields -> + let (names, exprs) = List.split fields in + normalize_list ctx exprs (fun vars -> + let tmp = fresh_var ctx "record" in + let 
field_pairs = List.combine names vars in + ILet (tmp, SRecord field_pairs, cont (IValue (IVar tmp)))) + | ERecordAccess (e, field) -> + expr_to_anf ctx e (fun v -> + match v with + | IValue (IVar var) -> + let tmp = fresh_var ctx "field" in + ILet (tmp, SRecordAccess (var, field), cont (IValue (IVar tmp))) + | _ -> failwith "Expected variable") + | ERecordUpdate (base, fields) -> + failwith "Record update not yet supported in IR" + | EVariant (name, e_opt) -> + (match e_opt with + | Some e -> + expr_to_anf ctx e (fun v -> + match v with + | IValue (IVar var) -> + let tmp = fresh_var ctx "variant" in + ILet (tmp, SVariant (name, Some var), cont (IValue (IVar tmp))) + | _ -> failwith "Expected variable") + | None -> + let tmp = fresh_var ctx "variant" in + ILet (tmp, SVariant (name, None), cont (IValue (IVar tmp)))) + | EBinop (op, e1, e2) -> + expr_to_anf ctx e1 (fun v1 -> + expr_to_anf ctx e2 (fun v2 -> + match (v1, v2) with + | (IValue (IVar var1), IValue (IVar var2)) -> + let tmp = fresh_var ctx "binop" in + let prim = binop_to_prim op in + ILet (tmp, SPrim (prim, [var1; var2]), cont (IValue (IVar tmp))) + | _ -> failwith "Expected variables")) + | EUnop (op, e) -> + expr_to_anf ctx e (fun v -> + match v with + | IValue (IVar var) -> + let tmp = fresh_var ctx "unop" in + let prim = unop_to_prim op in + ILet (tmp, SPrim (prim, [var]), cont (IValue (IVar tmp))) + | _ -> failwith "Expected variable") + | ESequence exprs -> + sequence_to_anf ctx exprs cont + | EConstraint (e, _) -> + expr_to_anf ctx e cont + | EHole -> + cont (IValue IUnit) + +and normalize_list ctx exprs cont = + match exprs with + | [] -> cont [] + | e :: rest -> + expr_to_anf ctx e (fun v -> + match v with + | IValue (IVar var) -> + normalize_list ctx rest (fun vars -> + cont (var :: vars)) + | _ -> failwith "Expected variable") + +and sequence_to_anf ctx exprs cont = + match exprs with + | [] -> cont (IValue IUnit) + | [e] -> expr_to_anf ctx e cont + | e :: rest -> + expr_to_anf ctx e (fun _ -> + sequence_to_anf ctx rest cont) + +let decl_to_ir ctx decl = + match decl.decl_desc with + | DLet (is_rec, pat, e) -> + if is_rec then + failwith "Recursive let at top level not yet supported" + else + let (ip, vars) = pattern_to_ir ctx pat in + let var = match vars with [v] -> v | _ -> fresh_var ctx "decl" in + let e_anf = expr_to_anf ctx e (fun v -> + match v with + | IValue val1 -> IValue val1 + | _ -> IValue IUnit) in + (match e_anf with + | IValue v -> IRLet (var, SValue v) + | ILet (v, simple, body) -> IRLet (v, simple) + | _ -> IRLet (var, SValue IUnit)) + | DExpr e -> + let var = fresh_var ctx "expr" in + let e_anf = expr_to_anf ctx e (fun v -> + match v with + | IValue val1 -> IValue val1 + | _ -> IValue IUnit) in + (match e_anf with + | IValue v -> IRLet (var, SValue v) + | _ -> IRLet (var, SValue IUnit)) + | _ -> IRLet ((":skip", 0), SValue IUnit) + +let program_to_ir program = + let ctx = { counter = 0; bindings = [] } in + let ir_decls = List.map (decl_to_ir ctx) program.declarations in + { decls = ir_decls; entry = None } diff --git a/src/lex.ml b/src/lex.ml new file mode 100644 index 0000000..bfff5a9 --- /dev/null +++ b/src/lex.ml @@ -0,0 +1,674 @@ +open Error + +type token = + | INT of int64 + | FLOAT of float + | STRING of string + | CHAR of int + | TRUE + | FALSE + | LET | REC | IN | FN | IF | THEN | ELSE + | MATCH | WITH | WHEN + | TYPE | MODULE | STRUCT | END | SIG | VAL + | FUNCTOR | OPEN | INCLUDE | AS | MUTABLE + | IDENT of string + | UIDENT of string + | TVAR of string + | PLUS | MINUS | STAR | SLASH | 
PERCENT + | EQ | NE | LT | LE | GT | GE + | AND | OR | NOT + | CONS | CONCAT | PIPE | COMPOSE + | ARROW | DARROW + | LPAREN | RPAREN + | LBRACKET | RBRACKET + | LBRACE | RBRACE + | COMMA | DOT | COLON | SEMICOLON + | BAR | UNDERSCORE | QUESTION + | EOF +[@@deriving show, eq] + +type lexer_state = { + filename : string; + input : string; + mutable pos : int; + mutable line : int; + mutable line_start : int; + mutable saved_pos : int; + mutable saved_line : int; + mutable saved_line_start : int; +} + +let create ~filename input = { + filename; + input; + pos = 0; + line = 1; + line_start = 0; + saved_pos = 0; + saved_line = 1; + saved_line_start = 0; +} + +let reset lex = + lex.pos <- 0; + lex.line <- 1; + lex.line_start <- 0 + +let current_position lex = + { + line = lex.line; + col = lex.pos - lex.line_start; + offset = lex.pos; + } + +let save_position lex = + lex.saved_pos <- lex.pos; + lex.saved_line <- lex.line; + lex.saved_line_start <- lex.line_start + +let restore_position lex = + lex.pos <- lex.saved_pos; + lex.line <- lex.saved_line; + lex.line_start <- lex.saved_line_start + +let at_eof lex = lex.pos >= String.length lex.input + +let peek_byte lex = + if at_eof lex then None + else Some (String.get lex.input lex.pos) + +let peek_byte_at lex offset = + let pos = lex.pos + offset in + if pos >= String.length lex.input then None + else Some (String.get lex.input pos) + +let advance lex n = + lex.pos <- lex.pos + n + +let newline lex = + lex.line <- lex.line + 1; + lex.line_start <- lex.pos + +let decode_utf8 lex = + if at_eof lex then None + else + let first = Char.code (String.get lex.input lex.pos) in + if first < 0x80 then begin + advance lex 1; + Some first + end else if first < 0xC0 then begin + advance lex 1; + None + end else if first < 0xE0 then begin + if lex.pos + 1 >= String.length lex.input then ( + advance lex 1; + None + ) else + let second = Char.code (String.get lex.input (lex.pos + 1)) in + if second land 0xC0 <> 0x80 then ( + advance lex 1; + None + ) else ( + let codepoint = ((first land 0x1F) lsl 6) lor (second land 0x3F) in + advance lex 2; + Some codepoint + ) + end else if first < 0xF0 then begin + if lex.pos + 2 >= String.length lex.input then ( + advance lex 1; + None + ) else + let second = Char.code (String.get lex.input (lex.pos + 1)) in + let third = Char.code (String.get lex.input (lex.pos + 2)) in + if second land 0xC0 <> 0x80 || third land 0xC0 <> 0x80 then ( + advance lex 1; + None + ) else ( + let codepoint = ((first land 0x0F) lsl 12) lor + ((second land 0x3F) lsl 6) lor + (third land 0x3F) in + advance lex 3; + Some codepoint + ) + end else if first < 0xF8 then begin + if lex.pos + 3 >= String.length lex.input then ( + advance lex 1; + None + ) else + let second = Char.code (String.get lex.input (lex.pos + 1)) in + let third = Char.code (String.get lex.input (lex.pos + 2)) in + let fourth = Char.code (String.get lex.input (lex.pos + 3)) in + if second land 0xC0 <> 0x80 || third land 0xC0 <> 0x80 || fourth land 0xC0 <> 0x80 then ( + advance lex 1; + None + ) else ( + let codepoint = ((first land 0x07) lsl 18) lor + ((second land 0x3F) lsl 12) lor + ((third land 0x3F) lsl 6) lor + (fourth land 0x3F) in + advance lex 4; + Some codepoint + ) + end else begin + advance lex 1; + None + end + +let is_whitespace = function + | 0x20 | 0x09 | 0x0A | 0x0D -> true + | cp when cp >= 0x2000 && cp <= 0x200B -> true + | 0x202F | 0x205F | 0x3000 -> true + | _ -> false + +let is_ident_start = function + | cp when cp >= Char.code 'a' && cp <= Char.code 'z' -> true + 
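+ (* ASCII letters and '_' (0x5F) may start an identifier; beyond ASCII,
+    any non-surrogate codepoint from 0x80 upwards is accepted, so Unicode
+    identifiers lex without a per-character category lookup. *)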
| cp when cp >= Char.code 'A' && cp <= Char.code 'Z' -> true + | 0x5F -> true + | cp when cp >= 0x80 && cp < 0xD800 -> true + | cp when cp >= 0xE000 && cp <= 0x10FFFF -> true + | _ -> false + +let is_ident_cont cp = + is_ident_start cp || + (cp >= Char.code '0' && cp <= Char.code '9') || + cp = 0x27 + +let is_digit c = + c >= '0' && c <= '9' + +let is_hex_digit c = + (c >= '0' && c <= '9') || + (c >= 'a' && c <= 'f') || + (c >= 'A' && c <= 'F') + +let is_octal_digit c = + c >= '0' && c <= '7' + +let is_binary_digit c = + c = '0' || c = '1' + +let rec skip_whitespace_and_comments ctx lex = + match peek_byte lex with + | None -> () + | Some c when c = ' ' || c = '\t' || c = '\r' -> advance lex 1; skip_whitespace_and_comments ctx lex + | Some '\n' -> advance lex 1; newline lex; skip_whitespace_and_comments ctx lex + | Some '/' -> + (match peek_byte_at lex 1 with + | Some '/' -> skip_line_comment lex; skip_whitespace_and_comments ctx lex + | Some '*' -> skip_block_comment ctx lex 0; skip_whitespace_and_comments ctx lex + | _ -> ()) + | _ -> + save_position lex; + (match decode_utf8 lex with + | Some cp when is_whitespace cp -> + if cp = 0x0A then newline lex; + skip_whitespace_and_comments ctx lex + | _ -> restore_position lex) + +and skip_line_comment lex = + advance lex 2; + let rec loop () = + match peek_byte lex with + | None -> () + | Some '\n' -> () + | Some _ -> advance lex 1; loop () + in + loop () + +and skip_block_comment ctx lex depth = + let start_pos = current_position lex in + advance lex 2; + let rec loop depth = + if at_eof lex then begin + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Unterminated block comment" + end else + match peek_byte lex with + | Some '*' when peek_byte_at lex 1 = Some '/' -> + advance lex 2; + if depth = 0 then () else loop (depth - 1) + | Some '/' when peek_byte_at lex 1 = Some '*' -> + advance lex 2; + loop (depth + 1) + | Some '\n' -> + advance lex 1; + newline lex; + loop depth + | Some _ -> + advance lex 1; + loop depth + | None -> () + in + loop depth + +let read_string ctx lex = + let start_pos = current_position lex in + advance lex 1; + let buf = Buffer.create 32 in + + let rec loop () = + if at_eof lex then begin + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Unterminated string literal"; + Buffer.contents buf + end else + match peek_byte lex with + | Some '"' -> + advance lex 1; + Buffer.contents buf + | Some '\\' -> + advance lex 1; + read_escape ctx lex buf; + loop () + | Some '\n' -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Newline in string literal (use \\n instead)"; + advance lex 1; + newline lex; + Buffer.add_char buf '\n'; + loop () + | Some c -> + advance lex 1; + Buffer.add_char buf c; + loop () + | None -> + Buffer.contents buf + in + loop () + +and read_escape ctx lex buf = + let start_pos = current_position lex in + match peek_byte lex with + | None -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Incomplete escape sequence" + | Some 'n' -> advance lex 1; Buffer.add_char buf '\n' + | Some 't' -> advance lex 1; Buffer.add_char buf '\t' + | Some 'r' -> advance lex 1; Buffer.add_char buf '\r' + | Some '\\' -> advance lex 1; Buffer.add_char buf '\\' + | Some '"' -> advance lex 1; 
Buffer.add_char buf '"' + | Some '\'' -> advance lex 1; Buffer.add_char buf '\'' + | Some '0' -> advance lex 1; Buffer.add_char buf '\000' + | Some 'x' -> + advance lex 1; + read_hex_escape ctx lex buf 2 + | Some 'u' -> + advance lex 1; + if peek_byte lex = Some '{' then ( + advance lex 1; + read_unicode_escape ctx lex buf + ) else + read_hex_escape ctx lex buf 4 + | Some 'U' -> + advance lex 1; + read_hex_escape ctx lex buf 8 + | Some c -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError + (Printf.sprintf "Invalid escape sequence: \\%c" c); + advance lex 1; + Buffer.add_char buf c + +and read_hex_escape ctx lex buf count = + let start_pos = current_position lex in + let rec read_digits acc n = + if n = 0 then acc + else match peek_byte lex with + | Some c when is_hex_digit c -> + advance lex 1; + let digit = if c >= '0' && c <= '9' then Char.code c - Char.code '0' + else if c >= 'a' && c <= 'f' then Char.code c - Char.code 'a' + 10 + else Char.code c - Char.code 'A' + 10 in + read_digits (acc * 16 + digit) (n - 1) + | _ -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError + (Printf.sprintf "Invalid hex escape: expected %d hex digits" count); + acc + in + let value = read_digits 0 count in + if value <= 0x10FFFF then + Uutf.Buffer.add_utf_8 buf (Uchar.of_int value) + else begin + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError + (Printf.sprintf "Invalid Unicode codepoint: U+%X" value) + end + +and read_unicode_escape ctx lex buf = + let start_pos = current_position lex in + let rec read_digits acc = + match peek_byte lex with + | Some c when is_hex_digit c -> + advance lex 1; + let digit = if c >= '0' && c <= '9' then Char.code c - Char.code '0' + else if c >= 'a' && c <= 'f' then Char.code c - Char.code 'a' + 10 + else Char.code c - Char.code 'A' + 10 in + read_digits (acc * 16 + digit) + | Some '}' -> + advance lex 1; + acc + | _ -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Invalid Unicode escape: expected }"; + acc + in + let value = read_digits 0 in + if value <= 0x10FFFF then + Uutf.Buffer.add_utf_8 buf (Uchar.of_int value) + else begin + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError + (Printf.sprintf "Invalid Unicode codepoint: U+%X" value) + end + +let read_char ctx lex = + let start_pos = current_position lex in + advance lex 1; + + let codepoint = + match peek_byte lex with + | None | Some '\n' | Some '\'' -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Empty character literal"; + 0 + | Some '\\' -> + advance lex 1; + let buf = Buffer.create 4 in + read_escape ctx lex buf; + let s = Buffer.contents buf in + if String.length s = 0 then 0 + else Char.code s.[0] + | Some _ -> + (match decode_utf8 lex with + | Some cp -> cp + | None -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Invalid UTF-8 in character literal"; + 0) + in + + (match peek_byte lex with + | Some '\'' -> advance lex 1 + | _ -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error 
ctx ~span LexError "Unterminated character literal"); + + codepoint + +let read_ident lex = + let start = lex.pos in + let rec loop () = + save_position lex; + match decode_utf8 lex with + | Some cp when is_ident_cont cp -> loop () + | _ -> restore_position lex + in + loop (); + String.sub lex.input start (lex.pos - start) + +let read_number ctx lex = + let start_pos = current_position lex in + let start = lex.pos in + + + let base, skip = + if peek_byte lex = Some '0' then + match peek_byte_at lex 1 with + | Some 'x' | Some 'X' -> (16, 2) + | Some 'o' | Some 'O' -> (8, 2) + | Some 'b' | Some 'B' -> (2, 2) + | _ -> (10, 0) + else (10, 0) + in + + advance lex skip; + + let is_valid_digit c = + match base with + | 2 -> is_binary_digit c + | 8 -> is_octal_digit c + | 10 -> is_digit c + | 16 -> is_hex_digit c + | _ -> false + in + + let rec read_digits () = + match peek_byte lex with + | Some c when is_valid_digit c -> advance lex 1; read_digits () + | Some '_' -> advance lex 1; read_digits () + | _ -> () + in + read_digits (); + + let is_float = + base = 10 && peek_byte lex = Some '.' && + match peek_byte_at lex 1 with + | Some c -> is_digit c + | None -> false + in + + if is_float then begin + advance lex 1; + read_digits (); + + (match peek_byte lex with + | Some 'e' | Some 'E' -> + advance lex 1; + (match peek_byte lex with + | Some '+' | Some '-' -> advance lex 1 + | _ -> ()); + read_digits () + | _ -> ()); + + let str = String.sub lex.input start (lex.pos - start) in + let str = String.map (fun c -> if c = '_' then ' ' else c) str in + let str = String.concat "" (String.split_on_char ' ' str) in + (try FLOAT (float_of_string str) + with Failure _ -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Invalid floating-point literal"; + FLOAT 0.0) + end else begin + let str = String.sub lex.input (start + skip) (lex.pos - start - skip) in + let str = String.map (fun c -> if c = '_' then ' ' else c) str in + let str = String.concat "" (String.split_on_char ' ' str) in + (try INT (Int64.of_string (if base = 10 then str else "0" ^ String.make 1 (String.get lex.input (start + 1)) ^ str)) + with Failure _ -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Invalid integer literal (overflow or malformed)"; + INT 0L) + end + +let keywords = [ + "let", LET; "rec", REC; "in", IN; "fn", FN; + "if", IF; "then", THEN; "else", ELSE; + "match", MATCH; "with", WITH; "when", WHEN; + "type", TYPE; "module", MODULE; "struct", STRUCT; + "end", END; "sig", SIG; "val", VAL; + "functor", FUNCTOR; "open", OPEN; "include", INCLUDE; + "as", AS; "mutable", MUTABLE; + "true", TRUE; "false", FALSE; + "and", AND; "or", OR; "not", NOT; +] + +let keyword_table = Hashtbl.create 32 +let () = List.iter (fun (k, v) -> Hashtbl.add keyword_table k v) keywords + +let next_token ctx lex = + skip_whitespace_and_comments ctx lex; + + if at_eof lex then + let pos = current_position lex in + Some (EOF, { start = pos; end_ = pos; file = lex.filename }) + else + let start_pos = current_position lex in + + let make_token tok = + let end_pos = current_position lex in + let span = { start = start_pos; end_ = end_pos; file = lex.filename } in + Some (tok, span) + in + + match peek_byte lex with + | None -> None + | Some c -> + match c with + | '(' -> advance lex 1; make_token LPAREN + | ')' -> advance lex 1; make_token RPAREN + | '[' -> advance lex 1; make_token LBRACKET + | 
']' -> advance lex 1; make_token RBRACKET + | '{' -> advance lex 1; make_token LBRACE + | '}' -> advance lex 1; make_token RBRACE + | ',' -> advance lex 1; make_token COMMA + | ';' -> advance lex 1; make_token SEMICOLON + | '?' -> advance lex 1; make_token QUESTION + | '.' -> advance lex 1; make_token DOT + | '_' when not (match peek_byte_at lex 1 with Some c -> is_ident_cont (Char.code c) | None -> false) -> + advance lex 1; make_token UNDERSCORE + | '\'' when peek_byte_at lex 1 <> Some '\'' -> + advance lex 1; + let id = read_ident lex in + make_token (TVAR id) + | '\'' -> + let cp = read_char ctx lex in + make_token (CHAR cp) + | '"' -> + let s = read_string ctx lex in + make_token (STRING s) + | '+' when peek_byte_at lex 1 = Some '+' -> + advance lex 2; make_token CONCAT + | '+' -> advance lex 1; make_token PLUS + | '-' when peek_byte_at lex 1 = Some '>' -> + advance lex 2; make_token ARROW + | '-' -> advance lex 1; make_token MINUS + | '*' -> advance lex 1; make_token STAR + | '/' -> advance lex 1; make_token SLASH + | '%' -> advance lex 1; make_token PERCENT + | '=' when peek_byte_at lex 1 = Some '>' -> + advance lex 2; make_token DARROW + | '=' when peek_byte_at lex 1 = Some '=' -> + advance lex 2; make_token EQ + | '=' -> + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError "Single '=' is not valid; use '==' for equality or 'let' for binding"; + advance lex 1; make_token EQ + | '!' when peek_byte_at lex 1 = Some '=' -> + advance lex 2; make_token NE + | '<' when peek_byte_at lex 1 = Some '=' -> + advance lex 2; make_token LE + | '<' -> advance lex 1; make_token LT + | '>' when peek_byte_at lex 1 = Some '=' -> + advance lex 2; make_token GE + | '>' when peek_byte_at lex 1 = Some '>' -> + advance lex 2; make_token COMPOSE + | '>' -> advance lex 1; make_token GT + | ':' when peek_byte_at lex 1 = Some ':' -> + advance lex 2; make_token CONS + | ':' -> advance lex 1; make_token COLON + | '|' when peek_byte_at lex 1 = Some '>' -> + advance lex 2; make_token PIPE + | '|' -> advance lex 1; make_token BAR + | c when is_digit c -> + make_token (read_number ctx lex) + | c when c >= 'A' && c <= 'Z' -> + let id = read_ident lex in + make_token (UIDENT id) + | c when c >= 'a' && c <= 'z' || c = '_' -> + let id = read_ident lex in + (match Hashtbl.find_opt keyword_table id with + | Some tok -> make_token tok + | None -> make_token (IDENT id)) + | _ -> + save_position lex; + (match decode_utf8 lex with + | Some cp when is_ident_start cp -> + let id = read_ident lex in + let is_upper = cp >= Char.code 'A' && cp <= Char.code 'Z' in + make_token (if is_upper then UIDENT id else IDENT id) + | _ -> + restore_position lex; + let span = { + start = start_pos; + end_ = current_position lex; + file = lex.filename; + } in + report_error ctx ~span LexError + (Printf.sprintf "Unexpected character: '%c'" c); + advance lex 1; + next_token ctx lex) + +let peek_token ctx lex = + save_position lex; + let tok = next_token ctx lex in + restore_position lex; + tok + +let tokenize ctx lex = + let rec loop acc = + match next_token ctx lex with + | None -> List.rev acc + | Some ((EOF, _) as tok) -> List.rev (tok :: acc) + | Some tok -> loop (tok :: acc) + in + loop [] diff --git a/src/lex.mli b/src/lex.mli new file mode 100644 index 0000000..a20f860 --- /dev/null +++ b/src/lex.mli @@ -0,0 +1,68 @@ +open Error + +type token = + | INT of int64 + | FLOAT of float + | STRING of string + | CHAR of int + | TRUE + | FALSE + + | LET + | REC + | IN + | 
FN + | IF + | THEN + | ELSE + | MATCH + | WITH + | WHEN + | TYPE + | MODULE + | STRUCT + | END + | SIG + | VAL + | FUNCTOR + | OPEN + | INCLUDE + | AS + | MUTABLE + + | IDENT of string + | UIDENT of string + | TVAR of string + + | PLUS | MINUS | STAR | SLASH | PERCENT + | EQ | NE | LT | LE | GT | GE + | AND | OR | NOT + | CONS + | CONCAT + | PIPE + | COMPOSE + | ARROW + | DARROW + + | LPAREN | RPAREN + | LBRACKET | RBRACKET + | LBRACE | RBRACE + | COMMA | DOT | COLON | SEMICOLON + | BAR | UNDERSCORE | QUESTION + + | EOF +[@@deriving show, eq] + +type lexer_state + +val create : filename:string -> string -> lexer_state + +val next_token : Error.context -> lexer_state -> (token * span) option + +val tokenize : Error.context -> lexer_state -> (token * span) list + +val current_position : lexer_state -> position + +val peek_token : Error.context -> lexer_state -> (token * span) option + +val reset : lexer_state -> unit diff --git a/src/main.ml b/src/main.ml new file mode 100644 index 0000000..cd0c47e --- /dev/null +++ b/src/main.ml @@ -0,0 +1 @@ +let () = Cli.run () diff --git a/src/opt.ml b/src/opt.ml new file mode 100644 index 0000000..ea673db --- /dev/null +++ b/src/opt.ml @@ -0,0 +1,196 @@ +open Ir + +let max_inline_size = 50 +let max_inline_depth = 10 + +let rec fold_constants expr = + match expr with + | IValue _ -> expr + | ILet (var, simple, body) -> + let folded_simple = fold_simple simple in + let folded_body = fold_constants body in + ILet (var, folded_simple, folded_body) + | ILetRec (bindings, body) -> + let folded_bindings = List.map (fun (var, lambda) -> + (var, fold_lambda lambda)) bindings in + let folded_body = fold_constants body in + ILetRec (folded_bindings, folded_body) + | IApp (f, args) -> IApp (f, args) + | IIf (cond, then_e, else_e) -> + IIf (cond, fold_constants then_e, fold_constants else_e) + | IMatch (scrutinee, cases) -> + let folded_cases = List.map (fun (pat, e) -> + (pat, fold_constants e)) cases in + IMatch (scrutinee, folded_cases) + | ITuple vars -> ITuple vars + | IRecord fields -> IRecord fields + | IRecordAccess (var, field) -> IRecordAccess (var, field) + | IVariant (name, var_opt) -> IVariant (name, var_opt) + | IPrim (op, vars) -> IPrim (op, vars) + +and fold_simple simple = + match simple with + | SPrim (PAdd, [v1; v2]) -> + simple + | SPrim (op, vars) -> SPrim (op, vars) + | SLambda lambda -> SLambda (fold_lambda lambda) + | _ -> simple + +and fold_lambda lambda = + { lambda with body = fold_constants lambda.body } + +let rec eliminate_dead_code expr = + match expr with + | IValue _ -> expr + | ILet (var, simple, body) -> + let body' = eliminate_dead_code body in + if var_used_in var body' then + ILet (var, simple, body') + else + body' + | ILetRec (bindings, body) -> + let body' = eliminate_dead_code body in + let used_bindings = List.filter (fun (var, _) -> + var_used_in var body') bindings in + if used_bindings = [] then body' + else ILetRec (used_bindings, body') + | IIf (cond, then_e, else_e) -> + IIf (cond, eliminate_dead_code then_e, eliminate_dead_code else_e) + | IMatch (scrutinee, cases) -> + IMatch (scrutinee, List.map (fun (p, e) -> (p, eliminate_dead_code e)) cases) + | _ -> expr + +and var_used_in var expr = + match expr with + | IValue (IVar v) -> v = var + | IValue _ -> false + | ILet (_, simple, body) -> + simple_uses_var var simple || var_used_in var body + | ILetRec (bindings, body) -> + List.exists (fun (_, lambda) -> lambda_uses_var var lambda) bindings || + var_used_in var body + | IApp (f, args) -> f = var || List.mem var args 
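+ (* The occurrence check is purely structural: a variable counts as used
+    if it appears anywhere in the expression, e.g.
+      var_used_in ("x", 1) (IPrim (PAdd, [("x", 1); ("y", 2)])) = true
+    eliminate_dead_code above drops a let whose bound variable never
+    occurs, even when the bound right-hand side is a call (SApp). *)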
+ | IIf (cond, then_e, else_e) -> + cond = var || var_used_in var then_e || var_used_in var else_e + | IMatch (scrutinee, cases) -> + scrutinee = var || List.exists (fun (_, e) -> var_used_in var e) cases + | ITuple vars -> List.mem var vars + | IRecord fields -> List.exists (fun (_, v) -> v = var) fields + | IRecordAccess (v, _) -> v = var + | IVariant (_, Some v) -> v = var + | IVariant (_, None) -> false + | IPrim (_, vars) -> List.mem var vars + +and simple_uses_var var = function + | SValue (IVar v) -> v = var + | SValue _ -> false + | SLambda lambda -> lambda_uses_var var lambda + | STuple vars -> List.mem var vars + | SRecord fields -> List.exists (fun (_, v) -> v = var) fields + | SVariant (_, Some v) -> v = var + | SVariant (_, None) -> false + | SPrim (_, vars) -> List.mem var vars + | SApp (f, args) -> f = var || List.mem var args + | SRecordAccess (v, _) -> v = var + +and lambda_uses_var var lambda = + not (List.mem var lambda.params) && var_used_in var lambda.body + +let rec inline_functions fuel expr = + if fuel <= 0 then expr + else match expr with + | IValue _ -> expr + | ILet (var, SLambda lambda, body) when is_small_lambda lambda -> + if count_uses var body = 1 then + inline_functions (fuel - 1) (substitute_lambda var lambda body) + else + ILet (var, SLambda (inline_lambda fuel lambda), inline_functions fuel body) + | ILet (var, simple, body) -> + ILet (var, simple, inline_functions fuel body) + | ILetRec (bindings, body) -> + let inlined_bindings = List.map (fun (v, lambda) -> + (v, inline_lambda fuel lambda)) bindings in + ILetRec (inlined_bindings, inline_functions fuel body) + | IIf (cond, then_e, else_e) -> + IIf (cond, inline_functions fuel then_e, inline_functions fuel else_e) + | IMatch (scrutinee, cases) -> + IMatch (scrutinee, List.map (fun (p, e) -> + (p, inline_functions fuel e)) cases) + | _ -> expr + +and inline_lambda fuel lambda = + { lambda with body = inline_functions fuel lambda.body } + +and is_small_lambda lambda = + size_of_expr lambda.body < max_inline_size + +and size_of_expr = function + | IValue _ -> 1 + | ILet (_, _, body) -> 1 + size_of_expr body + | ILetRec (bindings, body) -> + List.fold_left (fun acc (_, lambda) -> + acc + size_of_expr lambda.body) 0 bindings + size_of_expr body + | IApp _ -> 2 + | IIf (_, then_e, else_e) -> 1 + size_of_expr then_e + size_of_expr else_e + | IMatch (_, cases) -> + List.fold_left (fun acc (_, e) -> acc + size_of_expr e) 1 cases + | _ -> 1 + +and count_uses var expr = + match expr with + | IValue (IVar v) -> if v = var then 1 else 0 + | IValue _ -> 0 + | ILet (_, simple, body) -> + count_uses_simple var simple + count_uses var body + | ILetRec (bindings, body) -> + List.fold_left (fun acc (_, lambda) -> + acc + count_uses var lambda.body) 0 bindings + count_uses var body + | IApp (f, args) -> + (if f = var then 1 else 0) + List.fold_left (fun acc v -> + acc + if v = var then 1 else 0) 0 args + | IIf (cond, then_e, else_e) -> + (if cond = var then 1 else 0) + count_uses var then_e + count_uses var else_e + | IMatch (scrutinee, cases) -> + (if scrutinee = var then 1 else 0) + + List.fold_left (fun acc (_, e) -> acc + count_uses var e) 0 cases + | ITuple vars -> List.fold_left (fun acc v -> + acc + if v = var then 1 else 0) 0 vars + | IRecord fields -> List.fold_left (fun acc (_, v) -> + acc + if v = var then 1 else 0) 0 fields + | IRecordAccess (v, _) -> if v = var then 1 else 0 + | IVariant (_, Some v) -> if v = var then 1 else 0 + | IVariant (_, None) -> 0 + | IPrim (_, vars) -> List.fold_left (fun acc v -> + 
acc + if v = var then 1 else 0) 0 vars + +and count_uses_simple var = function + | SValue (IVar v) -> if v = var then 1 else 0 + | SValue _ -> 0 + | SLambda lambda -> count_uses var lambda.body + | STuple vars | SPrim (_, vars) -> List.fold_left (fun acc v -> + acc + if v = var then 1 else 0) 0 vars + | SRecord fields -> List.fold_left (fun acc (_, v) -> + acc + if v = var then 1 else 0) 0 fields + | SVariant (_, Some v) -> if v = var then 1 else 0 + | SVariant (_, None) -> 0 + | SApp (f, args) -> + (if f = var then 1 else 0) + List.fold_left (fun acc v -> + acc + if v = var then 1 else 0) 0 args + | SRecordAccess (v, _) -> if v = var then 1 else 0 + +and substitute_lambda var lambda body = + body + +let optimize_program program = + let optimize_decl = function + | IRLet (var, SLambda lambda) -> + let optimized_lambda = { + lambda with + body = fold_constants lambda.body + |> eliminate_dead_code + |> inline_functions max_inline_depth + } in + IRLet (var, SLambda optimized_lambda) + | other -> other + in + { program with decls = List.map optimize_decl program.decls } diff --git a/src/parse.ml b/src/parse.ml new file mode 100644 index 0000000..b96b86d --- /dev/null +++ b/src/parse.ml @@ -0,0 +1,853 @@ +open Ast +open Lex +open Error + +type parser_state = { + ctx : Error.context; + lex : Lex.lexer_state; + mutable current : (token * span) option; + mutable panic_mode : bool; +} + +let create_parser ctx lex = + let current = Lex.next_token ctx lex in + { ctx; lex; current; panic_mode = false } + +let current_token ps = + match ps.current with + | Some (tok, _) -> tok + | None -> EOF + +let current_span ps = + match ps.current with + | Some (_, span) -> span + | None -> + let pos = Lex.current_position ps.lex in + { start = pos; end_ = pos; file = "" } + +let advance ps = + ps.current <- Lex.next_token ps.ctx ps.lex + +let check ps tok = + current_token ps = tok + +let is_at_end ps = + check ps EOF + +let expect ps tok err_msg = + if check ps tok then begin + let span = current_span ps in + advance ps; + span + end else begin + let span = current_span ps in + report_error ps.ctx ~span ParseError err_msg; + ps.panic_mode <- true; + span + end + +let consume ps tok = + if check ps tok then begin + advance ps; + true + end else + false + +let synchronize ps = + ps.panic_mode <- false; + while not (is_at_end ps) do + match current_token ps with + | LET | TYPE | MODULE | OPEN | EOF -> () + | SEMICOLON -> advance ps; () + | _ -> advance ps + done + +let with_recovery ps f = + try + ps.panic_mode <- false; + let result = f () in + if ps.panic_mode then begin + synchronize ps; + None + end else + Some result + with e -> + synchronize ps; + None + +let parse_ident ps = + match current_token ps with + | IDENT id -> + let span = current_span ps in + advance ps; + Some (id, span) + | _ -> + report_error ps.ctx ~span:(current_span ps) ParseError + "Expected identifier"; + None + +let parse_uident ps = + match current_token ps with + | UIDENT id -> + let span = current_span ps in + advance ps; + Some (id, span) + | _ -> + report_error ps.ctx ~span:(current_span ps) ParseError + "Expected uppercase identifier"; + None + +let parse_expr ps = failwith "Forward declaration" +let parse_pattern ps = failwith "Forward declaration" +let parse_type ps = failwith "Forward declaration" + +let rec parse_type_real ps = + parse_type_arrow ps + +and parse_type_arrow ps = + let start_span = current_span ps in + let t1 = parse_type_primary ps in + if consume ps ARROW then + let t2 = parse_type_arrow ps in + let span = 
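+ (* Recursing into parse_type_arrow on the right makes '->' right
+    associative, so "a -> b -> c" reads as "a -> (b -> c)". *)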
merge_spans start_span (type_span t2) in + make_type (TFun (t1, t2)) span + else + t1 + +and parse_type_primary ps = + let start_span = current_span ps in + match current_token ps with + | LPAREN -> + advance ps; + if consume ps RPAREN then + make_type (TCon ("unit", [])) (current_span ps) + else begin + let types = parse_comma_list ps parse_type_real RPAREN in + let end_span = expect ps RPAREN "Expected ')'" in + let span = merge_spans start_span end_span in + match types with + | [t] -> t + | ts -> make_type (TTuple ts) span + end + | LBRACE -> + advance ps; + let fields = parse_comma_list ps parse_record_type_field RBRACE in + let end_span = expect ps RBRACE "Expected '}'" in + make_type (TRecord fields) (merge_spans start_span end_span) + | LT -> + advance ps; + let variants = parse_bar_list ps parse_variant_type_field GT in + let end_span = expect ps GT "Expected '>'" in + make_type (TVariant variants) (merge_spans start_span end_span) + | TVAR id -> + advance ps; + make_type (TVar id) start_span + | IDENT id -> + advance ps; + let args = + if consume ps LT then + let args = parse_comma_list ps parse_type_real GT in + let _ = expect ps GT "Expected '>'" in + args + else + [] + in + let span = current_span ps in + make_type (TCon (id, args)) (merge_spans start_span span) + | UIDENT id -> + advance ps; + make_type (TCon (id, [])) start_span + | _ -> + report_error ps.ctx ~span:start_span ParseError + "Expected type expression"; + make_type (TCon ("error", [])) start_span + +and parse_record_type_field ps = + match parse_ident ps with + | Some (name, span) -> + let _ = expect ps COLON "Expected ':' in record type field" in + let ty = parse_type_real ps in + (name, ty) + | None -> + ("error", make_type (TCon ("error", [])) (current_span ps)) + +and parse_variant_type_field ps = + match parse_uident ps with + | Some (name, span) -> + if consume ps COLON then + let ty = parse_type_real ps in + (name, Some ty) + else + (name, None) + | None -> + ("Error", None) + +and parse_comma_list ps parse_elem end_tok = + let rec loop acc = + if check ps end_tok || is_at_end ps then + List.rev acc + else begin + let elem = parse_elem ps in + if consume ps COMMA then + loop (elem :: acc) + else + List.rev (elem :: acc) + end + in + loop [] + +and parse_bar_list ps parse_elem end_tok = + let rec loop acc = + if check ps end_tok || is_at_end ps then + List.rev acc + else begin + let _ = consume ps BAR in + let elem = parse_elem ps in + if consume ps BAR && not (check ps end_tok) then + loop (elem :: acc) + else + List.rev (elem :: acc) + end + in + loop [] + +let rec parse_pattern_real ps = + parse_pattern_or ps + +and parse_pattern_or ps = + let start_span = current_span ps in + let p1 = parse_pattern_as ps in + if consume ps BAR then + let p2 = parse_pattern_or ps in + let span = merge_spans start_span (pattern_span p2) in + make_pattern (POr (p1, p2)) span + else + p1 + +and parse_pattern_as ps = + let start_span = current_span ps in + let p1 = parse_pattern_cons ps in + if consume ps AS then + match parse_ident ps with + | Some (id, _) -> + let span = merge_spans start_span (current_span ps) in + make_pattern (PAs (p1, id)) span + | None -> p1 + else + p1 + +and parse_pattern_cons ps = + let start_span = current_span ps in + let p1 = parse_pattern_constraint ps in + if consume ps CONS then + let p2 = parse_pattern_cons ps in + let span = merge_spans start_span (pattern_span p2) in + make_pattern (PCons (p1, p2)) span + else + p1 + +and parse_pattern_constraint ps = + let start_span = current_span ps in 
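+ (* "pat : ty" attaches a type annotation to a pattern; the annotation is
+    kept in the AST as PConstraint and erased again by pattern_to_ir
+    during IR lowering. *)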
+ let p = parse_pattern_primary ps in + if consume ps COLON then + let ty = parse_type_real ps in + let span = merge_spans start_span (type_span ty) in + make_pattern (PConstraint (p, ty)) span + else + p + +and parse_pattern_primary ps = + let start_span = current_span ps in + match current_token ps with + | UNDERSCORE -> + advance ps; + make_pattern PWild start_span + | IDENT id -> + advance ps; + make_pattern (PVar id) start_span + | UIDENT id -> + advance ps; + if consume ps LPAREN then + let p = parse_pattern_real ps in + let end_span = expect ps RPAREN "Expected ')'" in + make_pattern (PVariant (id, Some p)) (merge_spans start_span end_span) + else + make_pattern (PVariant (id, None)) start_span + | INT n -> + advance ps; + make_pattern (PLit (LInt n)) start_span + | FLOAT f -> + advance ps; + make_pattern (PLit (LFloat f)) start_span + | STRING s -> + advance ps; + make_pattern (PLit (LString s)) start_span + | CHAR c -> + advance ps; + make_pattern (PLit (LChar c)) start_span + | TRUE -> + advance ps; + make_pattern (PLit (LBool true)) start_span + | FALSE -> + advance ps; + make_pattern (PLit (LBool false)) start_span + | LPAREN -> + advance ps; + if consume ps RPAREN then + make_pattern (PLit LUnit) (merge_spans start_span (current_span ps)) + else begin + let patterns = parse_comma_list ps parse_pattern_real RPAREN in + let end_span = expect ps RPAREN "Expected ')'" in + let span = merge_spans start_span end_span in + match patterns with + | [p] -> p + | ps -> make_pattern (PTuple ps) span + end + | LBRACKET -> + advance ps; + let patterns = parse_comma_list ps parse_pattern_real RBRACKET in + let end_span = expect ps RBRACKET "Expected ']'" in + make_pattern (PList patterns) (merge_spans start_span end_span) + | LBRACE -> + advance ps; + let fields = parse_comma_list ps parse_record_pattern_field RBRACE in + let end_span = expect ps RBRACE "Expected '}'" in + make_pattern (PRecord fields) (merge_spans start_span end_span) + | _ -> + report_error ps.ctx ~span:start_span ParseError + "Expected pattern"; + make_pattern PWild start_span + +and parse_record_pattern_field ps = + match parse_ident ps with + | Some (name, span) -> + if consume ps EQ then + let p = parse_pattern_real ps in + (name, p) + else + (name, make_pattern (PVar name) span) + | None -> + ("error", make_pattern PWild (current_span ps)) + +let rec parse_expr_real ps = + parse_expr_pipeline ps + +and parse_expr_pipeline ps = + let start_span = current_span ps in + let e1 = parse_expr_or ps in + if consume ps PIPE then + let e2 = parse_expr_pipeline ps in + let span = merge_spans start_span (expr_span e2) in + make_expr (EBinop (Pipe, e1, e2)) span + else + e1 + +and parse_expr_or ps = + parse_binop ps parse_expr_and OR Or + +and parse_expr_and ps = + parse_binop ps parse_expr_comparison AND And + +and parse_expr_comparison ps = + let start_span = current_span ps in + let e1 = parse_expr_cons ps in + match current_token ps with + | EQ -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Eq, e1, e2)) (merge_spans start_span (expr_span e2)) + | NE -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Ne, e1, e2)) (merge_spans start_span (expr_span e2)) + | LT -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Lt, e1, e2)) (merge_spans start_span (expr_span e2)) + | LE -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Le, e1, e2)) (merge_spans start_span (expr_span e2)) + | GT -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Gt, e1, e2)) 
(merge_spans start_span (expr_span e2)) + | GE -> advance ps; let e2 = parse_expr_cons ps in + make_expr (EBinop (Ge, e1, e2)) (merge_spans start_span (expr_span e2)) + | _ -> e1 + +and parse_expr_cons ps = + let start_span = current_span ps in + let e1 = parse_expr_additive ps in + if consume ps CONS then + let e2 = parse_expr_cons ps in + let span = merge_spans start_span (expr_span e2) in + make_expr (EBinop (Cons, e1, e2)) span + else if consume ps CONCAT then + let e2 = parse_expr_cons ps in + let span = merge_spans start_span (expr_span e2) in + make_expr (EBinop (Concat, e1, e2)) span + else + e1 + +and parse_expr_additive ps = + parse_binop_left ps parse_expr_multiplicative [PLUS, Add; MINUS, Sub] + +and parse_expr_multiplicative ps = + parse_binop_left ps parse_expr_compose [STAR, Mul; SLASH, Div; PERCENT, Mod] + +and parse_expr_compose ps = + let start_span = current_span ps in + let e1 = parse_expr_unary ps in + if consume ps COMPOSE then + let e2 = parse_expr_compose ps in + let span = merge_spans start_span (expr_span e2) in + make_expr (EBinop (Compose, e1, e2)) span + else + e1 + +and parse_binop ps parse_next tok op = + let start_span = current_span ps in + let e1 = parse_next ps in + if consume ps tok then + let e2 = parse_binop ps parse_next tok op in + let span = merge_spans start_span (expr_span e2) in + make_expr (EBinop (op, e1, e2)) span + else + e1 + +and parse_binop_left ps parse_next ops = + let start_span = current_span ps in + let rec loop e1 = + let matching_op = List.find_opt (fun (tok, _) -> check ps tok) ops in + match matching_op with + | Some (tok, op) -> + advance ps; + let e2 = parse_next ps in + let span = merge_spans start_span (expr_span e2) in + loop (make_expr (EBinop (op, e1, e2)) span) + | None -> e1 + in + loop (parse_next ps) + +and parse_expr_unary ps = + let start_span = current_span ps in + match current_token ps with + | MINUS -> + advance ps; + let e = parse_expr_unary ps in + let span = merge_spans start_span (expr_span e) in + make_expr (EUnop (Neg, e)) span + | NOT -> + advance ps; + let e = parse_expr_unary ps in + let span = merge_spans start_span (expr_span e) in + make_expr (EUnop (Not, e)) span + | _ -> + parse_expr_application ps + +and parse_expr_application ps = + let start_span = current_span ps in + let e1 = parse_expr_postfix ps in + let rec loop acc = + if check ps LPAREN || check ps LBRACKET || check ps LBRACE || + check ps INT _ || check ps FLOAT _ || check ps STRING _ || + check ps TRUE || check ps FALSE || check ps IDENT _ || check ps UIDENT _ then + let e2 = parse_expr_postfix ps in + let span = merge_spans start_span (expr_span e2) in + loop (make_expr (EApp (acc, e2)) span) + else + acc + in + loop e1 + +and parse_expr_postfix ps = + let start_span = current_span ps in + let e = parse_expr_primary ps in + let rec loop acc = + match current_token ps with + | DOT -> + advance ps; + (match parse_ident ps with + | Some (field, _) -> + let span = merge_spans start_span (current_span ps) in + loop (make_expr (ERecordAccess (acc, field)) span) + | None -> acc) + | _ -> acc + in + loop e + +and parse_expr_primary ps = + let start_span = current_span ps in + match current_token ps with + | INT n -> + advance ps; + make_expr (ELit (LInt n)) start_span + | FLOAT f -> + advance ps; + make_expr (ELit (LFloat f)) start_span + | STRING s -> + advance ps; + make_expr (ELit (LString s)) start_span + | CHAR c -> + advance ps; + make_expr (ELit (LChar c)) start_span + | TRUE -> + advance ps; + make_expr (ELit (LBool true)) start_span + | 
FALSE -> + advance ps; + make_expr (ELit (LBool false)) start_span + | IDENT id -> + advance ps; + make_expr (EVar id) start_span + | UIDENT id -> + advance ps; + if consume ps LPAREN then + let e = parse_expr_real ps in + let end_span = expect ps RPAREN "Expected ')'" in + make_expr (EVariant (id, Some e)) (merge_spans start_span end_span) + else + make_expr (EVariant (id, None)) start_span + | QUESTION -> + advance ps; + make_expr EHole start_span + | LPAREN -> + advance ps; + if consume ps RPAREN then + make_expr (ELit LUnit) (merge_spans start_span (current_span ps)) + else begin + let exprs = parse_comma_list ps parse_expr_real RPAREN in + let end_span = expect ps RPAREN "Expected ')'" in + let span = merge_spans start_span end_span in + match exprs with + | [e] -> e + | es -> make_expr (ETuple es) span + end + | LBRACKET -> + advance ps; + let exprs = parse_comma_list ps parse_expr_real RBRACKET in + let end_span = expect ps RBRACKET "Expected ']'" in + make_expr (EList exprs) (merge_spans start_span end_span) + | LBRACE -> + advance ps; + let base_expr = + if not (check ps RBRACE || check ps IDENT _) then + let e = parse_expr_real ps in + if consume ps WITH then Some e else ( + None + ) + else + None + in + let fields = parse_comma_list ps parse_record_field RBRACE in + let end_span = expect ps RBRACE "Expected '}'" in + let span = merge_spans start_span end_span in + (match base_expr with + | Some base -> make_expr (ERecordUpdate (base, fields)) span + | None -> make_expr (ERecord fields) span) + | FN -> + advance ps; + let patterns = parse_fn_patterns ps in + let _ = expect ps DARROW "Expected '=>'" in + let body = parse_expr_real ps in + let span = merge_spans start_span (expr_span body) in + make_expr (ELambda (patterns, body)) span + | IF -> + advance ps; + let cond = parse_expr_real ps in + let _ = expect ps THEN "Expected 'then'" in + let then_expr = parse_expr_real ps in + let _ = expect ps ELSE "Expected 'else'" in + let else_expr = parse_expr_real ps in + let span = merge_spans start_span (expr_span else_expr) in + make_expr (EIf (cond, then_expr, else_expr)) span + | MATCH -> + advance ps; + let e = parse_expr_real ps in + let _ = expect ps WITH "Expected 'with'" in + let cases = parse_bar_list ps parse_match_case END in + let end_span = expect ps END "Expected 'end'" in + make_expr (EMatch (e, cases)) (merge_spans start_span end_span) + | LET -> + advance ps; + let is_rec = consume ps REC in + if is_rec then + parse_let_rec_expr ps start_span + else + parse_let_expr ps start_span + | _ -> + report_error ps.ctx ~span:start_span ParseError + "Expected expression"; + make_expr EHole start_span + +and parse_fn_patterns ps = + let rec loop acc = + if check ps DARROW then + List.rev acc + else + let p = parse_pattern_primary ps in + loop (p :: acc) + in + loop [] + +and parse_record_field ps = + match parse_ident ps with + | Some (name, span) -> + if consume ps EQ then + let e = parse_expr_real ps in + (name, e) + else + (name, make_expr (EVar name) span) + | None -> + ("error", make_expr EHole (current_span ps)) + +and parse_match_case ps = + let p = parse_pattern_real ps in + let guard = + if consume ps WHEN then + Some (parse_expr_real ps) + else + None + in + let _ = expect ps DARROW "Expected '=>'" in + let e = parse_expr_real ps in + (p, guard, e) + +and parse_let_expr ps start_span = + let p = parse_pattern_real ps in + let _ = expect ps EQ "Expected '==' in let binding" in + let e1 = parse_expr_real ps in + let _ = expect ps IN "Expected 'in'" in + let e2 = 
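+ (* Bindings are written with '==': the lexer emits EQ for '==' and
+    reports a lone '=' as an error (recovering as EQ), so a let has the
+    shape "let x == 1 in x". *)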
+  let span = merge_spans start_span (expr_span e2) in
+  make_expr (ELet (p, e1, e2)) span
+
+and parse_let_rec_expr ps start_span =
+  let bindings = parse_and_list ps parse_rec_binding in
+  let _ = expect ps IN "Expected 'in'" in
+  let body = parse_expr_real ps in
+  let span = merge_spans start_span (expr_span body) in
+  make_expr (ELetRec (bindings, body)) span
+
+and parse_rec_binding ps =
+  match parse_ident ps with
+  | Some (name, _) ->
+    let _ = expect ps EQ "Expected '=' in rec binding" in
+    let e = parse_expr_real ps in
+    (name, e)
+  | None ->
+    ("error", make_expr EHole (current_span ps))
+
+and parse_and_list ps parse_elem =
+  let rec loop acc =
+    let elem = parse_elem ps in
+    if consume ps AND then
+      loop (elem :: acc)
+    else
+      List.rev (elem :: acc)
+  in
+  loop []
+
+let rec parse_declaration_real ps =
+  let start_span = current_span ps in
+  match current_token ps with
+  | LET ->
+    advance ps;
+    let is_rec = consume ps REC in
+    let p = parse_pattern_real ps in
+    let _ = expect ps EQ "Expected '=' in let declaration" in
+    let e = parse_expr_real ps in
+    make_decl (DLet (is_rec, p, e)) (merge_spans start_span (expr_span e))
+  | TYPE ->
+    advance ps;
+    let ty_decl = parse_type_decl ps start_span in
+    make_decl (DType ty_decl) (merge_spans start_span ty_decl.type_span)
+  | MODULE ->
+    advance ps;
+    (match parse_uident ps with
+     | Some (name, _) ->
+       let _ = expect ps EQ "Expected '=' in module declaration" in
+       let mod_expr = parse_module_expr ps in
+       make_decl (DModule (name, mod_expr)) (merge_spans start_span mod_expr.mod_span)
+     | None ->
+       make_decl (DExpr (make_expr EHole start_span)) start_span)
+  | OPEN ->
+    advance ps;
+    (match parse_uident ps with
+     | Some (name, _) ->
+       make_decl (DOpen name) (merge_spans start_span (current_span ps))
+     | None ->
+       make_decl (DExpr (make_expr EHole start_span)) start_span)
+  | _ ->
+    let e = parse_expr_real ps in
+    make_decl (DExpr e) (merge_spans start_span (expr_span e))
+
+and parse_type_decl ps start_span =
+  match parse_ident ps with
+  | Some (name, _) ->
+    let params = parse_type_params ps in
+    let _ = expect ps EQ "Expected '=' in type declaration" in
+    let kind = parse_type_kind ps in
+    {
+      type_name = name;
+      type_params = params;
+      type_kind = kind;
+      type_span = merge_spans start_span (current_span ps);
+    }
+  | None ->
+    {
+      type_name = "error";
+      type_params = [];
+      type_kind = TAbstract;
+      type_span = start_span;
+    }
+
+and parse_type_params ps =
+  let rec loop acc =
+    match current_token ps with
+    | TVAR id ->
+      advance ps;
+      loop (id :: acc)
+    | _ -> List.rev acc
+  in
+  loop []
+
+and parse_type_kind ps =
+  match current_token ps with
+  | LBRACE ->
+    advance ps;
+    let fields = parse_comma_list ps parse_record_type_decl_field RBRACE in
+    let _ = expect ps RBRACE "Expected '}'" in
+    TRecord fields
+  | BAR | UIDENT _ ->
+    let variants = parse_bar_list ps parse_variant_type_decl_field SEMICOLON in
+    TVariant variants
+  | _ ->
+    let ty = parse_type_real ps in
+    TAlias ty
+
+and parse_record_type_decl_field ps =
+  let is_mutable = consume ps MUTABLE in
+  match parse_ident ps with
+  | Some (name, _) ->
+    let _ = expect ps COLON "Expected ':' in record field" in
+    let ty = parse_type_real ps in
+    (name, ty, is_mutable)
+  | None ->
+    ("error", make_type (TCon ("error", [])) (current_span ps), false)
+
+and parse_variant_type_decl_field ps =
+  match parse_uident ps with
+  | Some (name, _) ->
+    if consume ps COLON then
+      let ty = parse_type_real ps in
+      (name, Some ty)
+    else
+      (name, None)
+  | None ->
+    ("Error", None)
+
+and parse_module_expr ps =
+  let start_span = current_span ps in
+  match current_token ps with
+  | STRUCT ->
+    advance ps;
+    let decls = parse_declarations_until ps END in
+    let end_span = expect ps END "Expected 'end'" in
+    make_module (MStruct decls) (merge_spans start_span end_span)
+  | UIDENT id ->
+    advance ps;
+    make_module (MIdent id) start_span
+  | FUNCTOR ->
+    advance ps;
+    let _ = expect ps LPAREN "Expected '('" in
+    (match parse_uident ps with
+     | Some (param, _) ->
+       let sig_opt =
+         if consume ps COLON then
+           Some (parse_module_type ps)
+         else
+           None
+       in
+       let _ = expect ps RPAREN "Expected ')'" in
+       let _ = expect ps DARROW "Expected '=>'" in
+       let body = parse_module_expr ps in
+       make_module (MFunctor (param, sig_opt, body))
+         (merge_spans start_span body.mod_span)
+     | None ->
+       make_module (MStruct []) start_span)
+  | _ ->
+    report_error ps.ctx ~span:start_span ParseError
+      "Expected module expression";
+    make_module (MStruct []) start_span
+
+and parse_module_type ps =
+  let start_span = current_span ps in
+  match current_token ps with
+  | SIG ->
+    advance ps;
+    let sigs = parse_signature_items ps in
+    let end_span = expect ps END "Expected 'end'" in
+    make_signature (SigMultiple sigs) (merge_spans start_span end_span)
+  | _ ->
+    report_error ps.ctx ~span:start_span ParseError
+      "Expected module type";
+    make_signature (SigMultiple []) start_span
+
+and parse_signature_items ps =
+  let rec loop acc =
+    if check ps END || is_at_end ps then
+      List.rev acc
+    else
+      let sig_item = parse_signature_item ps in
+      loop (sig_item :: acc)
+  in
+  loop []
+
+and parse_signature_item ps =
+  let start_span = current_span ps in
+  match current_token ps with
+  | VAL ->
+    advance ps;
+    (match parse_ident ps with
+     | Some (name, _) ->
+       let _ = expect ps COLON "Expected ':'" in
+       let ty = parse_type_real ps in
+       make_signature (SigVal (name, ty)) (merge_spans start_span (type_span ty))
+     | None ->
+       make_signature (SigMultiple []) start_span)
+  | TYPE ->
+    advance ps;
+    let ty_decl = parse_type_decl ps start_span in
+    make_signature (SigType ty_decl) (merge_spans start_span ty_decl.type_span)
+  | _ ->
+    make_signature (SigMultiple []) start_span
+
+and parse_declarations_until ps end_tok =
+  let rec loop acc =
+    if check ps end_tok || is_at_end ps then
+      List.rev acc
+    else
+      let decl = parse_declaration_real ps in
+      let _ = consume ps SEMICOLON in
+      loop (decl :: acc)
+  in
+  loop []
+
+let parse_program ctx lex =
+  let ps = create_parser ctx lex in
+  let declarations = parse_declarations_until ps EOF in
+  { declarations; file = lex.Lex.filename }
+
+let parse_expr ctx lex =
+  let ps = create_parser ctx lex in
+  with_recovery ps (fun () -> parse_expr_real ps)
+
+let parse_type ctx lex =
+  let ps = create_parser ctx lex in
+  with_recovery ps (fun () -> parse_type_real ps)
+
+let parse_pattern ctx lex =
+  let ps = create_parser ctx lex in
+  with_recovery ps (fun () -> parse_pattern_real ps)
+
+let parse_declaration ctx lex =
+  let ps = create_parser ctx lex in
+  with_recovery ps (fun () -> parse_declaration_real ps)
+
+(* Keep the raw (non-recovering) entry points referenced without shadowing
+   the public API or tripping unused-binding warnings under -warn-error. *)
+let _ = parse_expr_real
+let _ = parse_pattern_real
+let _ = parse_type_real
diff --git a/src/parse.mli b/src/parse.mli
new file mode 100644
index 0000000..ba357a6
--- /dev/null
+++ b/src/parse.mli
@@ -0,0 +1,11 @@
+open Ast
+
+val parse_program : Error.context -> Lex.lexer_state -> program
+
+val parse_expr : Error.context -> Lex.lexer_state -> expr option
+
+val parse_type : Error.context -> Lex.lexer_state -> type_expr option
+
+val parse_pattern : Error.context -> Lex.lexer_state -> pattern option
+
+val parse_declaration : Error.context -> Lex.lexer_state -> declaration option
diff --git a/test/corpus/arithmetic.star b/test/corpus/arithmetic.star
new file mode 100644
index 0000000..018d1f5
--- /dev/null
+++ b/test/corpus/arithmetic.star
@@ -0,0 +1,7 @@
+// Test: Basic arithmetic
+let main = fn () =>
+  let x = 10 in
+  let y = 20 in
+  let sum = x + y in
+  let product = x * y in
+  print(sum)
diff --git a/test/corpus/higher_order.star b/test/corpus/higher_order.star
new file mode 100644
index 0000000..fa5406a
--- /dev/null
+++ b/test/corpus/higher_order.star
@@ -0,0 +1,22 @@
+// Test: Higher-order functions
+let map = fn f list =>
+  match list with
+  | [] => []
+  | x :: xs => f(x) :: map(f, xs)
+  end
+
+let filter = fn pred list =>
+  match list with
+  | [] => []
+  | x :: xs =>
+    if pred(x) then
+      x :: filter(pred, xs)
+    else
+      filter(pred, xs)
+  end
+
+let main = fn () =>
+  let numbers = [1, 2, 3, 4, 5] in
+  let doubled = map(fn x => x * 2, numbers) in
+  let evens = filter(fn x => x % 2 == 0, numbers) in
+  print(doubled)
diff --git a/test/corpus/pattern_match.star b/test/corpus/pattern_match.star
new file mode 100644
index 0000000..c1047e7
--- /dev/null
+++ b/test/corpus/pattern_match.star
@@ -0,0 +1,17 @@
+// Test: Pattern matching
+type Option<'a> =
+  | Some of 'a
+  | None
+
+let unwrap_or = fn default opt =>
+  match opt with
+  | Some(x) => x
+  | None => default
+  end
+
+let main = fn () =>
+  let x = Some(42) in
+  let y = None in
+  let result1 = unwrap_or(0, x) in
+  let result2 = unwrap_or(0, y) in
+  print(result1)
diff --git a/test/unit/test_lexer.ml b/test/unit/test_lexer.ml
new file mode 100644
index 0000000..0b7e226
--- /dev/null
+++ b/test/unit/test_lexer.ml
@@ -0,0 +1,79 @@
+open Lex
+open Error
+
+let test_simple_tokens () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "let x = 42" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (tok_types = [LET; IDENT "x"; EQ; INT 42L; EOF]);
+  print_endline "test_simple_tokens"
+
+let test_operators () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "+ - * / % == != < <= > >= && || :: ++ |> >>" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (tok_types = [
+    PLUS; MINUS; STAR; SLASH; PERCENT;
+    EQ; NE; LT; LE; GT; GE;
+    AND; OR; CONS; CONCAT; PIPE; COMPOSE; EOF
+  ]);
+  print_endline "test_operators"
+
+let test_string_literals () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "\"hello\" \"world\\n\" \"emoji: 🌟\"" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (List.length tok_types = 4);
+  print_endline "test_string_literals"
+
+let test_numeric_literals () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "42 3.14 0xFF 0o77 0b1010" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (match tok_types with
+    | INT _ :: FLOAT _ :: INT _ :: INT _ :: INT _ :: EOF :: [] -> true
+    | _ -> false);
+  print_endline "test_numeric_literals"
+
+let test_comments () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "let // comment\nx = /* block */ 42" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (tok_types = [LET; IDENT "x"; EQ; INT 42L; EOF]);
+  print_endline "test_comments"
+
+let test_keywords () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "let rec in fn if then else match with" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (tok_types = [LET; REC; IN; FN; IF; THEN; ELSE; MATCH; WITH; EOF]);
+  print_endline "test_keywords"
+
+let test_identifiers () =
+  let ctx = create_context () in
+  let lex = create ~filename:"" "foo bar_baz qux123 CamelCase" in
+  let tokens = tokenize ctx lex in
+  let tok_types = List.map fst tokens in
+  assert (match tok_types with
+    | IDENT "foo" :: IDENT "bar_baz" :: IDENT "qux123" :: UIDENT "CamelCase" :: EOF :: [] -> true
+    | _ -> false);
+  print_endline "test_identifiers"
+
+let run_tests () =
+  print_endline "\n=== Lexer Tests ===";
+  test_simple_tokens ();
+  test_operators ();
+  test_string_literals ();
+  test_numeric_literals ();
+  test_comments ();
+  test_keywords ();
+  test_identifiers ();
+  print_endline "All lexer tests passed!\n"
+
+let () = run_tests ()