star/lib/rt.py

238 lines
6.2 KiB
Python

from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
from dataclasses import dataclass
from functools import reduce
@dataclass
class Variant:
    """A tagged value: a constructor name plus an optional payload."""

    # Constructor name, e.g. 'Some' or 'Err'.
    tag: str
    # Payload carried by the constructor; None is treated as "no payload".
    value: Any = None

    def __repr__(self):
        """Render as ``tag`` when there is no payload, else ``tag(value)``."""
        return self.tag if self.value is None else f"{self.tag}({self.value})"
def variant(tag: str, value: Any = None) -> Variant:
    """Build a ``Variant`` from a tag and an optional payload."""
    return Variant(tag=tag, value=value)
def match(value: Variant, cases: dict) -> Any:
    """Dispatch on a variant's tag.

    Args:
        value: A ``Variant`` instance, or a dict carrying a ``'tag'`` key and
            an optional ``'value'`` key.
        cases: Mapping from tag to handler; the ``'_'`` key is the fallback.

    Returns:
        The handler's result. The handler is called with the payload, or with
        no arguments when the payload is None.

    Raises:
        TypeError: If ``value`` is neither a ``Variant`` nor a tagged dict.
        ValueError: If no handler (and no ``'_'`` fallback) exists for the tag.
    """
    # Normalise both accepted representations to a (tag, payload) pair.
    if isinstance(value, Variant):
        tag, payload = value.tag, value.value
    elif isinstance(value, dict) and 'tag' in value:
        tag, payload = value['tag'], value.get('value')
    else:
        raise TypeError('match: expected variant value')

    handler = cases.get(tag, cases.get('_'))
    if handler is None:
        raise ValueError(f"match: unhandled variant tag: {tag}")
    if payload is None:
        return handler()
    return handler(payload)
# The empty list. This is a single shared object: treat it as read-only.
nil: List = list()
def cons(head: Any, tail: List) -> List:
    """Return a new list with ``head`` placed in front of ``tail``."""
    front = [head]
    return front + tail
def head(lst: List) -> Any:
    """Return the first element of ``lst``.

    Raises:
        ValueError: If ``lst`` is empty.
    """
    if lst:
        return lst[0]
    raise ValueError('head: empty list')
def tail(lst: List) -> List:
    """Return everything after the first element of ``lst``.

    Raises:
        ValueError: If ``lst`` is empty.
    """
    if lst:
        return lst[1:]
    raise ValueError('tail: empty list')
def is_empty(lst: List) -> bool:
    """Return True when ``lst`` has no elements."""
    return not len(lst)
def length(lst: List) -> int:
    """Return the number of elements in ``lst``."""
    count = len(lst)
    return count
def map_list(f: Callable, lst: List) -> List:
    """Apply ``f`` to every element of ``lst`` and collect the results in a new list."""
    return list(map(f, lst))
def filter_list(pred: Callable, lst: List) -> List:
    """Return a new list holding the elements of ``lst`` for which ``pred`` is truthy."""
    kept = []
    for item in lst:
        if pred(item):
            kept.append(item)
    return kept
def foldl(f: Callable, acc: Any, lst: List) -> Any:
    """Left fold: combine elements left to right as ``f(accumulator, element)``.

    Returns ``acc`` unchanged when ``lst`` is empty.
    """
    result = acc
    for item in lst:
        result = f(result, item)
    return result
def foldr(f: Callable, acc: Any, lst: List) -> Any:
    """Right fold: combine elements right to left as ``f(element, accumulator)``.

    Returns ``acc`` unchanged when ``lst`` is empty.
    """
    # reduce passes (accumulator, element); swap them to match f's contract.
    return reduce(lambda so_far, item: f(item, so_far), reversed(lst), acc)
def concat_lists(lst1: List, lst2: List) -> List:
    """Return a new list with the elements of ``lst1`` followed by those of ``lst2``."""
    combined = lst1 + lst2
    return combined
def reverse_list(lst: List) -> List:
    """Return a new list with the elements of ``lst`` in reverse order."""
    return [item for item in reversed(lst)]
def take(n: int, lst: List) -> List:
    """Return the first ``n`` elements of ``lst``.

    Args:
        n: Number of leading elements to keep. Zero or a negative count
            yields an empty result.
        lst: Source list; it is not modified.

    Returns:
        A new list with at most ``n`` elements.
    """
    # Clamp at zero: a raw negative slice bound would count from the end
    # and return all but the last |n| elements.
    return lst[:max(n, 0)]
def drop(n: int, lst: List) -> List:
    """Return ``lst`` without its first ``n`` elements.

    Args:
        n: Number of leading elements to skip. Zero or a negative count
            skips nothing.
        lst: Source list; it is not modified.

    Returns:
        A new list holding the remaining elements.
    """
    # Clamp at zero: a raw negative slice bound would count from the end
    # and keep only the last |n| elements.
    return lst[max(n, 0):]
# The payload-free 'None' case of the option type.
None_variant = Variant(tag='None')
def Some(value: Any) -> Variant:
    """Wrap ``value`` in a variant tagged ``'Some'``."""
    return Variant(tag='Some', value=value)
def is_some(opt: Variant) -> bool:
    """Return True when ``opt`` is tagged ``'Some'``."""
    tag = opt.tag
    return tag == 'Some'
def is_none(opt: Variant) -> bool:
    """Return True when ``opt`` is tagged ``'None'``."""
    tag = opt.tag
    return tag == 'None'
def unwrap(opt: Variant) -> Any:
    """Return the payload of ``opt``.

    Raises:
        ValueError: If ``opt`` is tagged ``'None'``.
    """
    if opt.tag != 'None':
        return opt.value
    raise ValueError('unwrap: None value')
def unwrap_or(opt: Variant, default: Any) -> Any:
    """Return the payload of ``opt`` if it is tagged ``'Some'``, else ``default``."""
    if opt.tag == 'Some':
        return opt.value
    return default
def Ok(value: Any) -> Variant:
    """Wrap ``value`` in a variant tagged ``'Ok'``."""
    return Variant(tag='Ok', value=value)
def Err(error: Any) -> Variant:
    """Wrap ``error`` in a variant tagged ``'Err'``."""
    return Variant(tag='Err', value=error)
def is_ok(result: Variant) -> bool:
    """Return True when ``result`` is tagged ``'Ok'``."""
    tag = result.tag
    return tag == 'Ok'
def is_err(result: Variant) -> bool:
    """Return True when ``result`` is tagged ``'Err'``."""
    tag = result.tag
    return tag == 'Err'
def to_string(value: Any) -> str:
    """Render a runtime value as text.

    Rendering rules, checked in this order:
      * ``None`` becomes ``'null'``.
      * A list becomes ``[a, b, ...]`` with each element rendered recursively.
      * A ``Variant`` becomes ``tag`` when its payload is None, otherwise
        ``tag(payload)`` with the payload rendered recursively.
      * A dict with a ``'tag'`` key becomes ``tag(payload)`` when it also has
        a ``'value'`` key (payload rendered recursively), otherwise ``tag``.
      * Anything else falls back to ``str(value)``.

    Args:
        value: The value to render.

    Returns:
        The textual form of ``value``.
    """
    if value is None:
        return 'null'
    if isinstance(value, list):
        return '[' + ', '.join(to_string(v) for v in value) + ']'
    if isinstance(value, Variant):
        if value.value is None:
            return value.tag
        # Recurse into the payload so nested lists, None and variants are
        # rendered the same way as in the list and dict branches, instead of
        # falling back to Python's own str()/repr() formatting.
        return f"{value.tag}({to_string(value.value)})"
    if isinstance(value, dict) and 'tag' in value:
        if 'value' in value:
            return f"{value['tag']}({to_string(value['value'])})"
        return value['tag']
    return str(value)
def str_length(s: str) -> int:
    """Return the length of ``s`` as reported by ``len``."""
    size = len(s)
    return size
def str_concat(s1: str, s2: str) -> str:
    """Return ``s1`` followed by ``s2``."""
    joined = s1 + s2
    return joined
def substring(s: str, start: int, length: int) -> str:
    """Return up to ``length`` characters of ``s`` beginning at ``start``.

    Args:
        s: Source string.
        start: Index of the first character. A negative index counts from
            the end of the string, as in ordinary Python indexing.
        length: Maximum number of characters to return. Zero or a negative
            length yields an empty string.

    Returns:
        The selected slice; shorter than ``length`` when the string ends first.
    """
    if length <= 0:
        # A negative length would otherwise become a negative slice end and
        # select characters counted from the end of the string.
        return s[:0]
    end = start + length
    if start < 0 <= end:
        # The range starts relative to the end of the string and reaches (or
        # passes) it; a literal end of 0 or more would produce an empty slice.
        end = None
    return s[start:end]
def int_div(a: int, b: int) -> int:
    """Return the floor division ``a // b`` (rounds toward negative infinity)."""
    quotient = a // b
    return quotient
def pow_int(base: int, exp: int) -> int:
    """Return ``base`` raised to the power ``exp``."""
    return pow(base, exp)
def abs_num(x: Union[int, float]) -> Union[int, float]:
    """Return the absolute value of ``x``."""
    magnitude = abs(x)
    return magnitude
def print_value(value: Any) -> None:
    """Print the ``to_string`` rendering of ``value`` followed by a newline."""
    rendered = to_string(value)
    print(rendered)
def print_str(value: Any) -> None:
    """Write the ``to_string`` rendering of ``value`` to stdout without a newline.

    The stream is flushed right after the write.
    """
    import sys
    rendered = to_string(value)
    stream = sys.stdout
    stream.write(rendered)
    stream.flush()
def identity(x: Any) -> Any:
    """Return the argument unchanged."""
    same = x
    return same
def constant(x: Any) -> Callable:
    """Return a function that always yields ``x``.

    The returned function accepts and ignores any positional or keyword
    arguments, so it can be called with no arguments or handed to
    higher-order helpers that pass an argument (mapping, composition, piping).

    Args:
        x: The value the returned function produces on every call.

    Returns:
        A callable that returns ``x`` regardless of how it is called.
    """
    def always(*_args: Any, **_kwargs: Any) -> Any:
        return x

    return always
def compose(f: Callable, g: Callable) -> Callable:
    """Return the one-argument function ``x -> f(g(x))``."""
    def composed(x):
        inner = g(x)
        return f(inner)

    return composed
def pipe(value: Any, f: Callable) -> Any:
    """Apply ``f`` to ``value`` and return the result."""
    outcome = f(value)
    return outcome
def curry(f: Callable, arity: Optional[int] = None) -> Callable:
    """Return a curried version of ``f``.

    The result collects positional arguments over successive calls and
    invokes ``f`` once at least ``arity`` of them have been supplied.

    Args:
        f: The callable to curry.
        arity: Number of positional arguments to collect before calling
            ``f``. When omitted it is inferred from ``f``'s signature,
            counting positional-only and positional-or-keyword parameters
            (including those with defaults).

    Returns:
        A callable accepting any number of positional arguments per call.

    Raises:
        AttributeError: If ``arity`` is omitted and cannot be inferred because
            ``f`` exposes neither an introspectable signature nor a
            ``__code__`` attribute (e.g. some builtins).
    """
    if arity is None:
        import inspect
        try:
            # signature() handles bound methods (self is excluded), partial
            # objects and callable instances. follow_wrapped=False keeps the
            # count tied to f itself, matching a direct co_argcount lookup for
            # plain functions.
            params = inspect.signature(f, follow_wrapped=False).parameters.values()
        except (TypeError, ValueError):
            # No introspectable signature: fall back to the code object.
            arity = f.__code__.co_argcount
        else:
            arity = sum(
                1 for p in params
                if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            )

    def curried(*args):
        if len(args) >= arity:
            return f(*args)
        # Not enough arguments yet: remember them and wait for more.
        return lambda *more_args: curried(*args, *more_args)

    return curried
def record_get(record: dict, field: str) -> Any:
    """Return the value stored under ``field`` in ``record``.

    Raises:
        KeyError: If ``field`` is not present.
    """
    found = record[field]
    return found
def record_set(record: dict, field: str, value: Any) -> dict:
    """Return a shallow copy of ``record`` with ``field`` set to ``value``.

    The input record is left untouched.
    """
    updated = record.copy()
    updated.update({field: value})
    return updated
def record_update(record: dict, **fields) -> dict:
    """Return a shallow copy of ``record`` with the keyword ``fields`` applied.

    The input record is left untouched.
    """
    updated = record.copy()
    for name, new_value in fields.items():
        updated[name] = new_value
    return updated
def fst(tup: Tuple) -> Any:
    """Return the element at index 0 of ``tup``."""
    first = tup[0]
    return first
def snd(tup: Tuple) -> Any:
    """Return the element at index 1 of ``tup``."""
    second = tup[1]
    return second
def logical_and(a: bool, b: bool) -> bool:
    """Return ``a and b``: ``a`` itself when it is falsy, otherwise ``b``."""
    return b if a else a
def logical_or(a: bool, b: bool) -> bool:
    """Return ``a or b``: ``a`` itself when it is truthy, otherwise ``b``."""
    return a if a else b
def logical_not(a: bool) -> bool:
    """Return the boolean negation of ``a``."""
    return False if a else True
def equal(a: Any, b: Any) -> bool:
    """Return the result of ``a == b``."""
    outcome = a == b
    return outcome
def not_equal(a: Any, b: Any) -> bool:
    """Return the result of ``a != b``."""
    outcome = a != b
    return outcome
def less_than(a: Any, b: Any) -> bool:
    """Return the result of ``a < b``."""
    outcome = a < b
    return outcome
def less_equal(a: Any, b: Any) -> bool:
    """Return the result of ``a <= b``."""
    outcome = a <= b
    return outcome
def greater_than(a: Any, b: Any) -> bool:
    """Return the result of ``a > b``."""
    outcome = a > b
    return outcome
def greater_equal(a: Any, b: Any) -> bool:
    """Return the result of ``a >= b``."""
    outcome = a >= b
    return outcome
# Names exported by a star-import of this module, grouped by area.
__all__ = [
    # Variants and pattern matching
    'Variant', 'variant', 'match',
    # Lists
    'nil', 'cons', 'head', 'tail', 'is_empty', 'length',
    'map_list', 'filter_list', 'foldl', 'foldr',
    'concat_lists', 'reverse_list', 'take', 'drop',
    # Option
    'None_variant', 'Some', 'is_some', 'is_none', 'unwrap', 'unwrap_or',
    # Result
    'Ok', 'Err', 'is_ok', 'is_err',
    # Strings
    'to_string', 'str_length', 'str_concat', 'substring',
    # Arithmetic
    'int_div', 'pow_int', 'abs_num',
    # Output
    'print_value', 'print_str',
    # Function combinators
    'identity', 'constant', 'compose', 'pipe', 'curry',
    # Records
    'record_get', 'record_set', 'record_update',
    # Tuples
    'fst', 'snd',
    # Boolean logic
    'logical_and', 'logical_or', 'logical_not',
    # Comparisons
    'equal', 'not_equal', 'less_than', 'less_equal',
    'greater_than', 'greater_equal',
]