"""Runtime helper functions.

Provides tagged-variant values (``Variant``) with pattern matching, list,
option (``Some`` / ``None_variant``) and result (``Ok`` / ``Err``) helpers,
string/number/record utilities, printing, small function combinators, and the
boolean/comparison operators exposed as plain functions.
"""

import inspect
import sys
from dataclasses import dataclass
from functools import reduce
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union


# ---------------------------------------------------------------------------
# Tagged variants
# ---------------------------------------------------------------------------

@dataclass
class Variant:
    """A tagged value: a string ``tag`` plus an optional payload ``value``."""

    tag: str
    # None doubles as "no payload" (see match() and __repr__).
    value: Any = None

    def __repr__(self):
        if self.value is None:
            return self.tag
        return f"{self.tag}({self.value})"


def variant(tag: str, value: Any = None) -> Variant:
    """Construct a ``Variant`` with the given tag and optional payload."""
    return Variant(tag, value)


def _dispatch(tag: str, payload: Any, cases: dict) -> Any:
    """Call the handler for *tag* in *cases* (or the ``'_'`` wildcard)."""
    handler = cases.get(tag, cases.get('_'))
    if handler is None:
        raise ValueError(f"match: unhandled variant tag: {tag}")
    # A None payload is treated as "no payload": call with no arguments.
    return handler(payload) if payload is not None else handler()


def match(value: Variant, cases: dict) -> Any:
    """Dispatch on a variant's tag.

    Args:
        value: A ``Variant``, or a dict with a ``'tag'`` key and an optional
            ``'value'`` key.
        cases: Mapping from tag to handler. The key ``'_'`` is a wildcard
            used when the tag has no entry of its own.

    Returns:
        Whatever the selected handler returns. The handler receives the
        payload as its single argument, or no arguments when the payload is
        None. Note that this means ``Some(None)`` / ``Ok(None)`` also call
        the handler with no arguments.

    Raises:
        TypeError: *value* is neither a ``Variant`` nor a tagged dict.
        ValueError: No handler matches the tag and there is no ``'_'`` case.
    """
    if isinstance(value, Variant):
        return _dispatch(value.tag, value.value, cases)
    if isinstance(value, dict) and 'tag' in value:
        return _dispatch(value['tag'], value.get('value'), cases)
    raise TypeError('match: expected variant value')


# ---------------------------------------------------------------------------
# Lists
# ---------------------------------------------------------------------------

# Shared empty list. Do not mutate it: it is a single module-level object,
# and none of the helpers in this module modify their input lists.
nil: List = []


def cons(head: Any, tail: List) -> List:
    """Return a new list with *head* prepended to *tail* (copies *tail*)."""
    return [head] + tail


def head(lst: List) -> Any:
    """Return the first element of *lst*.

    Raises:
        ValueError: *lst* is empty.
    """
    if not lst:
        raise ValueError('head: empty list')
    return lst[0]


def tail(lst: List) -> List:
    """Return a new list holding every element of *lst* except the first.

    Raises:
        ValueError: *lst* is empty.
    """
    if not lst:
        raise ValueError('tail: empty list')
    return lst[1:]


def is_empty(lst: List) -> bool:
    """Return True when *lst* has no elements."""
    return len(lst) == 0


def length(lst: List) -> int:
    """Return the number of elements in *lst*."""
    return len(lst)


def map_list(f: Callable, lst: List) -> List:
    """Return a new list with *f* applied to each element of *lst*."""
    return [f(x) for x in lst]


def filter_list(pred: Callable, lst: List) -> List:
    """Return a new list of the elements of *lst* for which *pred* is truthy."""
    return [x for x in lst if pred(x)]


def foldl(f: Callable, acc: Any, lst: List) -> Any:
    """Left fold: ``f(...f(f(acc, x0), x1)..., xn)``; *f* takes ``(acc, x)``."""
    return reduce(f, lst, acc)


def foldr(f: Callable, acc: Any, lst: List) -> Any:
    """Right fold: ``f(x0, f(x1, ... f(xn, acc)))``; *f* takes ``(x, acc)``."""
    result = acc
    for x in reversed(lst):
        result = f(x, result)
    return result


def concat_lists(lst1: List, lst2: List) -> List:
    """Return a new list with the elements of *lst1* followed by *lst2*."""
    return lst1 + lst2


def reverse_list(lst: List) -> List:
    """Return a new list with the elements of *lst* in reverse order."""
    return list(reversed(lst))


def take(n: int, lst: List) -> List:
    """Return the first *n* elements of *lst*.

    A negative *n* is treated as 0 (returns ``[]``) rather than falling into
    Python's negative-slice semantics. An *n* past the end returns a copy of
    the whole list.
    """
    return lst[:max(n, 0)]


def drop(n: int, lst: List) -> List:
    """Return *lst* without its first *n* elements.

    A negative *n* is treated as 0 (returns a copy of the whole list) rather
    than falling into Python's negative-slice semantics.
    """
    return lst[max(n, 0):]


# ---------------------------------------------------------------------------
# Option and Result
# ---------------------------------------------------------------------------

None_variant = Variant('None')


def Some(value: Any) -> Variant:
    """Wrap *value* in a ``'Some'`` variant."""
    return Variant('Some', value)


def is_some(opt: Variant) -> bool:
    """Return True when *opt* is tagged ``'Some'``."""
    return opt.tag == 'Some'


def is_none(opt: Variant) -> bool:
    """Return True when *opt* is tagged ``'None'``."""
    return opt.tag == 'None'


def unwrap(opt: Variant) -> Any:
    """Return the payload of *opt*.

    Only a ``'None'`` tag raises; any other tag returns its payload as-is.

    Raises:
        ValueError: *opt* is tagged ``'None'``.
    """
    if opt.tag == 'None':
        raise ValueError('unwrap: None value')
    return opt.value


def unwrap_or(opt: Variant, default: Any) -> Any:
    """Return the payload when *opt* is ``'Some'``, otherwise *default*."""
    return opt.value if opt.tag == 'Some' else default


def Ok(value: Any) -> Variant:
    """Wrap a success value in an ``'Ok'`` variant."""
    return Variant('Ok', value)


def Err(error: Any) -> Variant:
    """Wrap an error value in an ``'Err'`` variant."""
    return Variant('Err', error)


def is_ok(result: Variant) -> bool:
    """Return True when *result* is tagged ``'Ok'``."""
    return result.tag == 'Ok'


def is_err(result: Variant) -> bool:
    """Return True when *result* is tagged ``'Err'``."""
    return result.tag == 'Err'


# ---------------------------------------------------------------------------
# Strings and numbers
# ---------------------------------------------------------------------------

def to_string(value: Any) -> str:
    """Render *value* as display text.

    ``None`` becomes ``'null'``. Lists render as ``[a, b]``, and variants
    (``Variant`` objects or tagged dicts) render as ``Tag`` or
    ``Tag(payload)``. Nested elements and payloads are rendered recursively
    with this function, so a ``Variant`` and the equivalent list/dict print
    consistently. Anything else falls back to ``str()``.
    """
    if value is None:
        return 'null'
    if isinstance(value, list):
        return '[' + ', '.join(to_string(v) for v in value) + ']'
    if isinstance(value, Variant):
        if value.value is None:
            return value.tag
        return f"{value.tag}({to_string(value.value)})"
    if isinstance(value, dict) and 'tag' in value:
        if 'value' in value:
            return f"{value['tag']}({to_string(value['value'])})"
        return value['tag']
    return str(value)


def str_length(s: str) -> int:
    """Return the number of characters in *s*."""
    return len(s)


def str_concat(s1: str, s2: str) -> str:
    """Return *s1* followed by *s2*."""
    return s1 + s2


def substring(s: str, start: int, length: int) -> str:
    """Return up to *length* characters of *s* starting at *start*.

    Uses Python slice semantics, so out-of-range bounds are clipped.
    """
    return s[start:start + length]


def int_div(a: int, b: int) -> int:
    """Integer division using Python ``//`` (rounds toward negative infinity)."""
    return a // b


def pow_int(base: int, exp: int) -> int:
    """Return ``base ** exp`` (a negative int *exp* yields a float in Python)."""
    return base ** exp


def abs_num(x: Union[int, float]) -> Union[int, float]:
    """Return the absolute value of *x*."""
    return abs(x)


def print_value(value: Any) -> None:
    """Print ``to_string(value)`` followed by a newline."""
    print(to_string(value))


def print_str(value: Any) -> None:
    """Write ``to_string(value)`` to stdout with no trailing newline, then flush."""
    sys.stdout.write(to_string(value))
    sys.stdout.flush()


# ---------------------------------------------------------------------------
# Function combinators
# ---------------------------------------------------------------------------

def identity(x: Any) -> Any:
    """Return *x* unchanged."""
    return x


def constant(x: Any) -> Callable:
    """Return a function that ignores any arguments and always returns *x*."""
    return lambda *_args, **_kwargs: x


def compose(f: Callable, g: Callable) -> Callable:
    """Return the one-argument function ``lambda x: f(g(x))``."""
    return lambda x: f(g(x))


def pipe(value: Any, f: Callable) -> Any:
    """Apply *f* to *value*."""
    return f(value)


def _positional_arity(f: Callable) -> int:
    """Count the positional parameters a caller can pass to *f*.

    This uses ``inspect.signature``, so bound methods exclude ``self`` and
    ``functools.partial`` objects exclude pre-bound arguments. For a plain
    function the count equals ``f.__code__.co_argcount``.
    """
    try:
        params = inspect.signature(f).parameters.values()
    except (TypeError, ValueError) as err:
        raise TypeError(
            f"curry: cannot infer arity of {f!r}; pass arity explicitly"
        ) from err
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(1 for p in params if p.kind in positional)


def curry(f: Callable, arity: Optional[int] = None) -> Callable:
    """Return a curried version of *f*.

    Arguments accumulate across calls until at least *arity* have been
    supplied; then *f* is called with all of them.

    Args:
        f: The function to curry.
        arity: Number of arguments to wait for. Defaults to the number of
            positional parameters *f* accepts.

    Raises:
        TypeError: *arity* is omitted and cannot be inferred from *f*.
    """
    if arity is None:
        arity = _positional_arity(f)

    def curried(*args):
        if len(args) >= arity:
            return f(*args)
        return lambda *more_args: curried(*args, *more_args)

    return curried


# ---------------------------------------------------------------------------
# Records and tuples
# ---------------------------------------------------------------------------

def record_get(record: dict, field: str) -> Any:
    """Return ``record[field]`` (raises KeyError if absent)."""
    return record[field]


def record_set(record: dict, field: str, value: Any) -> dict:
    """Return a shallow copy of *record* with *field* set to *value*."""
    new_record = record.copy()
    new_record[field] = value
    return new_record


def record_update(record: dict, **fields) -> dict:
    """Return a shallow copy of *record* updated with the keyword *fields*."""
    new_record = record.copy()
    new_record.update(fields)
    return new_record


def fst(tup: Tuple) -> Any:
    """Return the first element of *tup*."""
    return tup[0]


def snd(tup: Tuple) -> Any:
    """Return the second element of *tup*."""
    return tup[1]


# ---------------------------------------------------------------------------
# Logic and comparison operators as functions
# ---------------------------------------------------------------------------

def logical_and(a: bool, b: bool) -> bool:
    """Return ``a and b``."""
    return a and b


def logical_or(a: bool, b: bool) -> bool:
    """Return ``a or b``."""
    return a or b


def logical_not(a: bool) -> bool:
    """Return ``not a``."""
    return not a


def equal(a: Any, b: Any) -> bool:
    """Return ``a == b``."""
    return a == b


def not_equal(a: Any, b: Any) -> bool:
    """Return ``a != b``."""
    return a != b


def less_than(a: Any, b: Any) -> bool:
    """Return ``a < b``."""
    return a < b


def less_equal(a: Any, b: Any) -> bool:
    """Return ``a <= b``."""
    return a <= b


def greater_than(a: Any, b: Any) -> bool:
    """Return ``a > b``."""
    return a > b


def greater_equal(a: Any, b: Any) -> bool:
    """Return ``a >= b``."""
    return a >= b


__all__ = [
    'Variant', 'variant', 'match',
    'nil', 'cons', 'head', 'tail', 'is_empty', 'length',
    'map_list', 'filter_list', 'foldl', 'foldr',
    'concat_lists', 'reverse_list', 'take', 'drop',
    'None_variant', 'Some', 'is_some', 'is_none', 'unwrap', 'unwrap_or',
    'Ok', 'Err', 'is_ok', 'is_err',
    'to_string', 'str_length', 'str_concat', 'substring',
    'int_div', 'pow_int', 'abs_num',
    'print_value', 'print_str',
    'identity', 'constant', 'compose', 'pipe', 'curry',
    'record_get', 'record_set', 'record_update',
    'fst', 'snd',
    'logical_and', 'logical_or', 'logical_not',
    'equal', 'not_equal', 'less_than', 'less_equal',
    'greater_than', 'greater_equal',
]