#lang plai-typed
(require plai-typed/s-exp-match)

;; A small language with numbers, first-class functions, and explicit
;; type abstraction (`LAMBDA`) / type application (`@`), together with
;; an interpreter and a type checker.

(define-type Value
  [numV (n : number)]
  [closV (arg : symbol)
         (body : ExprC)
         (env : Env)]
  ;; A type abstraction evaluates to its body plus the defining env;
  ;; the type argument is ignored at run time.
  [polyV (body : ExprC)
         (env : Env)])

(define-type ExprC
  [numC (n : number)]
  [idC (s : symbol)]
  [plusC (l : ExprC)
         (r : ExprC)]
  [multC (l : ExprC)
         (r : ExprC)]
  [lamC (n : symbol)
        (arg-type : Type)
        (body : ExprC)]
  [appC (fun : ExprC)
        (arg : ExprC)]
  [tylamC (n : symbol)
          (body : ExprC)]
  [tyappC (tyfun : ExprC)
          (tyarg : Type)])

(define-type Type
  [numT]
  [boolT]
  [arrowT (arg : Type)
          (result : Type)]
  [varT (n : symbol)]
  [forallT (n : symbol)
           (body : Type)])

(define-type Binding
  [bind (name : symbol)
        (val : Value)])

(define-type-alias Env (listof Binding))

;; A type environment records both the types of bound identifiers
;; (`tbind`) and which type variables are in scope (`tvar`).
(define-type TypeBinding
  [tbind (name : symbol)
         (type : Type)]
  [tvar (name : symbol)])

(define-type-alias TypeEnv (listof TypeBinding))

(define mt-env empty)
(define extend-env cons)

(module+ test
  (print-only-errors true))

;; parse ----------------------------------------
;; Converts an s-expression into an ExprC; raises "invalid input" for
;; anything that does not match one of the forms below.
(define (parse [s : s-expression]) : ExprC
  (cond
    [(s-exp-match? `NUMBER s) (numC (s-exp->number s))]
    [(s-exp-match? `SYMBOL s) (idC (s-exp->symbol s))]
    [(s-exp-match? '{+ ANY ANY} s)
     (plusC (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? '{* ANY ANY} s)
     (multC (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? '{let {[SYMBOL : ANY ANY]} ANY} s)
     ;; `let` is sugar for applying an annotated `lambda`
     (let ([bs (s-exp->list (first
                             (s-exp->list (second
                                           (s-exp->list s)))))])
       (appC (lamC (s-exp->symbol (first bs))
                   (parse-type (third bs))
                   (parse (third (s-exp->list s))))
             (parse (fourth bs))))]
    [(s-exp-match? '{lambda {[SYMBOL : ANY]} ANY} s)
     (let ([arg (s-exp->list
                 (first (s-exp->list
                         (second (s-exp->list s)))))])
       (lamC (s-exp->symbol (first arg))
             (parse-type (third arg))
             (parse (third (s-exp->list s)))))]
    [(s-exp-match? '[LAMBDA ['SYMBOL] ANY] s)
     (tylamC (s-exp->symbol
              (second (s-exp->list
                       (first (s-exp->list
                               (second (s-exp->list s)))))))
             (parse (third (s-exp->list s))))]
    [(s-exp-match? '[@ ANY ANY] s)
     (tyappC (parse (second (s-exp->list s)))
             (parse-type (third (s-exp->list s))))]
    [(s-exp-match? '{ANY ANY} s)
     (appC (parse (first (s-exp->list s)))
           (parse (second (s-exp->list s))))]
    [else (error 'parse "invalid input")]))

;; Converts an s-expression into a Type; raises "invalid input" otherwise.
(define (parse-type [s : s-expression]) : Type
  (cond
    [(s-exp-match? `num s) (numT)]
    [(s-exp-match? `bool s) (boolT)]
    [(s-exp-match? `(ANY -> ANY) s)
     (arrowT (parse-type (first (s-exp->list s)))
             (parse-type (third (s-exp->list s))))]
    [(s-exp-match? ''SYMBOL s) ; ''SYMBOL is equivalent to '(quote SYMBOL)
     (varT (s-exp->symbol (second (s-exp->list s))))]
    [(s-exp-match? `(forall ('SYMBOL) ANY) s)
     (forallT (s-exp->symbol
               (second (s-exp->list
                        (first (s-exp->list
                                (second (s-exp->list s)))))))
              (parse-type (third (s-exp->list s))))]
    [else (error 'parse-type "invalid input")]))

(module+ test
  (test (parse '2)
        (numC 2))
  (test (parse `x) ; note: backquote instead of normal quote
        (idC 'x))
  (test (parse '{+ 2 1})
        (plusC (numC 2) (numC 1)))
  (test (parse '{* 3 4})
        (multC (numC 3) (numC 4)))
  (test (parse '{+ {* 3 4} 8})
        (plusC (multC (numC 3) (numC 4))
               (numC 8)))
  (test (parse '{let {[x : num {+ 1 2}]}
                  y})
        (appC (lamC 'x (numT) (idC 'y))
              (plusC (numC 1) (numC 2))))
  (test (parse '{lambda {[x : num]} 9})
        (lamC 'x (numT) (numC 9)))
  (test (parse '{double 9})
        (appC (idC 'double) (numC 9)))
  (test (parse '[LAMBDA ['a] {lambda {[x : 'a]} x}])
        (tylamC 'a (lamC 'x (varT 'a) (idC 'x))))
  (test (parse '[@ f num])
        (tyappC (idC 'f) (numT)))
  (test/exn (parse '{{+ 1 2}})
            "invalid input")

  (test (parse-type `num)
        (numT))
  (test (parse-type `bool)
        (boolT))
  (test (parse-type `(num -> bool))
        (arrowT (numT) (boolT)))
  (test (parse-type ''a)
        (varT 'a))
  (test (parse-type `(forall ('a) ('a -> 'a)))
        (forallT 'a (arrowT (varT 'a) (varT 'a))))
  (test/exn (parse-type '1)
            "invalid input"))

;; interp ----------------------------------------
;; Evaluates an expression in a run-time environment.  Type
;; annotations and type arguments are ignored at run time.
(define (interp [a : ExprC] [env : Env]) : Value
  (type-case ExprC a
    [numC (n) (numV n)]
    [idC (s) (lookup s env)]
    [plusC (l r) (num+ (interp l env) (interp r env))]
    [multC (l r) (num* (interp l env) (interp r env))]
    [lamC (n t body) (closV n body env)]
    [appC (fun arg)
          (type-case Value (interp fun env)
            [closV (n body c-env)
                   (interp body
                           (extend-env (bind n (interp arg env))
                                       c-env))]
            [else (error 'interp "not a function")])]
    [tylamC (n body) (polyV body env)]
    [tyappC (tyfun tyarg)
            (type-case Value (interp tyfun env)
              [polyV (body p-env) (interp body p-env)]
              [else (error 'interp "not a polymorphic value")])]))

(module+ test
  (test (interp (parse '2) mt-env)
        (numV 2))
  (test/exn (interp (parse `x) mt-env)
            "free variable")
  (test (interp (parse `x)
                (extend-env (bind 'x (numV 9)) mt-env))
        (numV 9))
  (test (interp (parse '{+ 2 1}) mt-env)
        (numV 3))
  (test (interp (parse '{* 2 1}) mt-env)
        (numV 2))
  (test (interp (parse '{+ {* 2 3} {+ 5 8}})
                mt-env)
        (numV 19))
  (test (interp (parse '{lambda {[x : num]} {+ x x}})
                mt-env)
        (closV 'x (plusC (idC 'x) (idC 'x)) mt-env))
  (test (interp (parse '{let {[x : num 5]}
                          {+ x x}})
                mt-env)
        (numV 10))
  (test (interp (parse '{let {[x : num 5]}
                          {let {[x : num {+ 1 x}]}
                            {+ x x}}})
                mt-env)
        (numV 12))
  (test (interp (parse '{let {[x : num 5]}
                          {let {[y : num 6]}
                            x}})
                mt-env)
        (numV 5))
  (test (interp (parse '{{lambda {[x : num]} {+ x x}} 8})
                mt-env)
        (numV 16))
  (test (interp (parse '[LAMBDA ['a] {lambda {[x : 'a]} x}])
                mt-env)
        (polyV (lamC 'x (varT 'a) (idC 'x)) mt-env))
  (test (interp (parse '[@ [LAMBDA ['a] {lambda {[x : 'a]} x}] num])
                mt-env)
        (closV 'x (idC 'x) mt-env))
  (test (interp (parse '{let {[f : (forall ('a) ('a -> 'a))
                                 [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                          {+ {[@ f num] 1}
                             {{[@ f (num -> num)] {lambda {[n : num]} {+ n 1}}} 2}}})
                mt-env)
        (numV 4))

  (test/exn (interp (parse '{1 2}) mt-env)
            "not a function")
  (test/exn (interp (parse '{+ 1 {lambda {[x : num]} x}}) mt-env)
            "not a number")
  (test/exn (interp (parse '{let {[bad : (num -> num) {lambda {[x : num]} {+ x y}}]}
                              {let {[y : num 5]}
                                {bad 2}}})
                    mt-env)
            "free variable")
  (test/exn (interp (parse '[@ 1 num]) mt-env)
            "not a polymorphic value"))

;; num+ and num* ----------------------------------------
;; Applies `op` to two numV values; raises "not a number" otherwise.
(define (num-op [op : (number number -> number)] [l : Value] [r : Value]) : Value
  (cond
    [(and (numV? l) (numV? r))
     (numV (op (numV-n l) (numV-n r)))]
    [else
     (error 'interp "not a number")]))
(define (num+ [l : Value] [r : Value]) : Value
  (num-op + l r))
(define (num* [l : Value] [r : Value]) : Value
  (num-op * l r))

(module+ test
  (test (num+ (numV 1) (numV 2))
        (numV 3))
  (test (num* (numV 2) (numV 3))
        (numV 6)))

;; lookup ----------------------------------------
;; Builds a lookup function over a list of bindings.  An entry matches
;; when `check?` accepts it and `name-of` returns the requested name;
;; the first such entry's `val-of` is returned.  Raises "free variable"
;; when no entry matches.
(define (make-lookup [check? : ('a -> boolean)] [name-of : ('a -> symbol)] [val-of : ('a -> 'b)])
  (lambda ([name : symbol] [vals : (listof 'a)]) : 'b
    (cond
      [(empty? vals)
       (error 'find "free variable")]
      [else (if (and (check? (first vals))
                     (equal? name (name-of (first vals))))
                (val-of (first vals))
                ((make-lookup check? name-of val-of) name (rest vals)))])))

(define lookup
  (make-lookup bind? bind-name bind-val))

(module+ test
  (test/exn (lookup 'x mt-env)
            "free variable")
  (test (lookup 'x (extend-env (bind 'x (numV 8)) mt-env))
        (numV 8))
  (test (lookup 'x (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'x (numV 8)) mt-env)))
        (numV 9))
  (test (lookup 'y (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'y (numV 8)) mt-env)))
        (numV 8)))

;; typecheck ----------------------------------------
;; Computes the type of an expression in a type environment, raising
;; a "no type" error (or "free variable" for unbound names / type
;; variables) when the expression is ill-typed.
(define (typecheck [a : ExprC] [tenv : TypeEnv])
  (type-case ExprC a
    [numC (n) (numT)]
    [plusC (l r) (typecheck-nums l r tenv)]
    [multC (l r) (typecheck-nums l r tenv)]
    [idC (n) (type-lookup n tenv)]
    [lamC (n arg-type body)
          (begin
            ;; the annotation may only mention type variables in scope
            (tvarcheck arg-type tenv)
            (arrowT arg-type
                    (typecheck body
                               (extend-env (tbind n arg-type)
                                           tenv))))]
    [appC (fun arg)
          (type-case Type (typecheck fun tenv)
            [arrowT (arg-type result-type)
                    ;; Compare up to renaming of bound type variables, so
                    ;; that `(forall ('a) ...)` matches `(forall ('b) ...)`
                    ;; and so that renamings introduced by `type-subst`
                    ;; do not cause spurious mismatches.
                    (if (type-equal? arg-type
                                     (typecheck arg tenv))
                        result-type
                        (type-error arg
                                    (to-string arg-type)))]
            [else (type-error fun "function")])]
    [tylamC (n body)
            (forallT n (typecheck body
                                  (extend-env (tvar n) tenv)))]
    [tyappC (tyfun tyarg)
            (begin
              (tvarcheck tyarg tenv)
              (type-case Type (typecheck tyfun tenv)
                [forallT (n body) (type-subst n tyarg body)]
                [else (type-error tyfun "polymorphic value")]))]))

;; Checks that both operands have type num; the result is num.
(define (typecheck-nums l r tenv)
  (type-case Type (typecheck l tenv)
    [numT ()
          (type-case Type (typecheck r tenv)
            [numT () (numT)]
            [else (type-error r "num")])]
    [else (type-error l "num")]))

;; Raises a typecheck error of the form "no type: <expr> not <msg>".
(define (type-error a msg)
  (error 'typecheck (string-append
                     "no type: "
                     (string-append
                      (to-string a)
                      (string-append " not "
                                     msg)))))

;; Looks up an identifier's type, skipping type-variable entries.
(define type-lookup
  (make-lookup tbind? tbind-name tbind-type))

;; Checks that a type variable is in scope, skipping identifier entries.
(define type-var-lookup
  (make-lookup tvar? tvar-name tvar-name))

(module+ test
  (test (typecheck (parse '10) mt-env)
        (numT))
  (test (typecheck (parse '{+ 10 17}) mt-env)
        (numT))
  (test (typecheck (parse '{* 10 17}) mt-env)
        (numT))
  (test (typecheck (parse '{lambda {[x : num]} 12}) mt-env)
        (arrowT (numT) (numT)))
  (test (typecheck (parse '{lambda {[x : num]} {lambda {[y : bool]} x}}) mt-env)
        (arrowT (numT) (arrowT (boolT) (numT))))
  (test (typecheck (parse '{{lambda {[x : num]} 12}
                            {+ 1 17}})
                   mt-env)
        (numT))
  (test (typecheck (parse '{let {[x : num 4]}
                             {let {[f : (num -> num)
                                      {lambda {[y : num]} {+ x y}}]}
                               {f x}}})
                   mt-env)
        (numT))
  (test (typecheck (parse '[LAMBDA ['a] {lambda {[x : 'a]} x}]) mt-env)
        (forallT 'a (arrowT (varT 'a) (varT 'a))))
  (test (typecheck (parse '[@ [LAMBDA ['a] {lambda {[x : 'a]} x}] num]) mt-env)
        (arrowT (numT) (numT)))
  (test (typecheck (parse '{let {[f : (forall ('a) ('a -> 'a))
                                    [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                             {+ {[@ f num] 1}
                                {{[@ f (num -> num)] {lambda {[n : num]} {+ n 1}}} 2}}})
                   mt-env)
        (numT))
  ;; bound type variables in an annotation may be named differently
  ;; from those in the annotated expression's type
  (test (typecheck (parse '{let {[f : (forall ('b) ('b -> 'b))
                                    [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                             {[@ f num] 1}})
                   mt-env)
        (numT))

  (test/exn (typecheck (parse '{1 2}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{{lambda {[x : bool]} x} 2}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{+ 1 {lambda {[x : num]} x}}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{* {lambda {[x : num]} x} 1}) mt-env)
            "no type")
  (test/exn (typecheck (parse '[@ 1 num]) mt-env)
            "no type")
  (test/exn (typecheck (parse '{let {[f : (forall ('b) ('b -> num))
                                        [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                                 f})
                       mt-env)
            "no type"))

;; tvarcheck ----------------------------------------
;; Ensures every type variable in `ty` is bound, either by an enclosing
;; `forall` inside `ty` or by a `tvar` entry in `tenv`; raises
;; "free variable" otherwise.
(define (tvarcheck ty tenv)
  (type-case Type ty
    [numT () (values)]
    [boolT () (values)]
    [arrowT (a b) (begin
                    (tvarcheck a tenv)
                    (tvarcheck b tenv))]
    [varT (id) (begin
                 (type-var-lookup id tenv)
                 (values))]
    [forallT (id t) (tvarcheck t (extend-env (tvar id) tenv))]))

(module+ test
  (test (tvarcheck (numT) mt-env)
        (values))
  (test (tvarcheck (boolT) mt-env)
        (values))
  (test (tvarcheck (arrowT (numT) (boolT)) mt-env)
        (values))
  (test (tvarcheck (varT 'a) (extend-env (tvar 'a) mt-env))
        (values))
  (test (tvarcheck (forallT 'a (varT 'a)) mt-env)
        (values))
  (test/exn (tvarcheck (varT 'a) mt-env)
            "free variable"))

;; type-subst ----------------------------------------
;; Capture-avoiding substitution of `for` for free occurrences of the
;; type variable `what` in `in`.
(define (type-subst [what : symbol] [for : Type] [in : Type])
  (type-case Type in
    [numT () (numT)]
    [boolT () (boolT)]
    [arrowT (l r) (arrowT (type-subst what for l)
                          (type-subst what for r))]
    [varT (n) (if (equal? what n)
                  for
                  (varT n))]
    [forallT (n body)
             (cond
               [(equal? what n) (forallT n body)]
               [(free-type-var? n for)
                ;; If we want to replace `a` in
                ;;   `{forall b (b -> a)}`
                ;; with `b`, the result
                ;;   `{forall b (b -> b)}`
                ;; would be wrong, since the `b` would get captured.
                ;; We instead need to produce
                ;;   `{forall b1 (b1 -> b)}`.
                (local [(define new-n (gen-name n 1 for body))
                        (define new-body (type-subst n (varT new-n) body))]
                  (type-subst what for (forallT new-n new-body)))]
               [else
                (forallT n (type-subst what for body))])]))

;; Helper function for substitution: generates a name like `n` that is
;; not currently used (as a free type variable) in `for` or `body`.
(define (gen-name [n : symbol] [i : number] [for : Type] [body : Type])
  (let ([new-n (string->symbol (string-append (symbol->string n) (to-string i)))])
    (if (or (free-type-var? new-n for)
            (free-type-var? new-n body))
        (gen-name n (+ i 1) for body)
        new-n)))

;; Helper function for substitution: check whether a name is used as a
;; free type variable in a type.
(define (free-type-var? [n : symbol] [t : Type])
  (type-case Type t
    [numT () #f]
    [boolT () #f]
    [arrowT (l r) (or (free-type-var? n l) (free-type-var? n r))]
    [varT (n-v) (equal? n-v n)]
    [forallT (n-f body) (cond
                          [(equal? n n-f) #f]
                          [else (free-type-var? n body)])]))

(module+ test
  (test (free-type-var? 'a (numT))
        #f)
  (test (free-type-var? 'a (boolT))
        #f)
  (test (free-type-var? 'a (arrowT (varT 'b) (varT 'b)))
        #f)
  (test (free-type-var? 'a (arrowT (varT 'a) (varT 'b)))
        #t)
  (test (free-type-var? 'a (varT 'a))
        #t)
  (test (free-type-var? 'a (varT 'b))
        #f)
  (test (free-type-var? 'a (forallT 'a (varT 'a)))
        #f)
  (test (free-type-var? 'a (forallT 'b (varT 'a)))
        #t)
  (test (free-type-var? 'a (forallT 'b (varT 'c)))
        #f)

  (test (gen-name 'a 1 (numT) (numT))
        'a1)
  (test (gen-name 'a 1 (numT) (varT 'a1))
        'a2)
  (test (gen-name 'a 1 (varT 'a1) (numT))
        'a2)
  (test (gen-name 'a 1 (arrowT (numT) (varT 'a1)) (numT))
        'a2)
  (test (gen-name 'a 1 (forallT 'a1 (varT 'a1)) (numT))
        'a1)
  (test (gen-name 'a 1 (forallT 'b (varT 'a1)) (numT))
        'a2)

  (test (type-subst 'a (boolT) (numT))
        (numT))
  (test (type-subst 'a (numT) (boolT))
        (boolT))
  (test (type-subst 'a (numT) (arrowT (varT 'a) (boolT)))
        (arrowT (numT) (boolT)))
  (test (type-subst 'a (numT) (varT 'a))
        (numT))
  (test (type-subst 'a (numT) (varT 'b))
        (varT 'b))
  (test (type-subst 'a (numT) (forallT 'a (varT 'a)))
        (forallT 'a (varT 'a)))
  (test (type-subst 'a (numT) (forallT 'b (varT 'a)))
        (forallT 'b (numT)))
  (test (type-subst 'a (varT 'b) (forallT 'b (varT 'a)))
        (forallT 'b1 (varT 'b)))
  (test (type-subst 'a (varT 'b) (forallT 'b (arrowT (varT 'a) (arrowT (varT 'b1) (varT 'b)))))
        (forallT 'b2 (arrowT (varT 'b) (arrowT (varT 'b1) (varT 'b2))))))

;; type-equal? ----------------------------------------
;; Type equality up to consistent renaming of `forall`-bound type
;; variables (alpha-equivalence).  Free type variables must match by name.
(define (type-equal? [t1 : Type] [t2 : Type]) : boolean
  (type-case Type t1
    [numT () (numT? t2)]
    [boolT () (boolT? t2)]
    [arrowT (a1 r1)
            (and (arrowT? t2)
                 (type-equal? a1 (arrowT-arg t2))
                 (type-equal? r1 (arrowT-result t2)))]
    [varT (n) (and (varT? t2)
                   (equal? n (varT-n t2)))]
    [forallT (n1 b1)
             (and (forallT? t2)
                  (let ([n2 (forallT-n t2)]
                        [b2 (forallT-body t2)])
                    (if (equal? n1 n2)
                        (type-equal? b1 b2)
                        ;; Rename t2's binder to n1.  If n1 is free in b2,
                        ;; then t2 mentions n1 freely while t1 binds it,
                        ;; so the two types cannot be equal.
                        (and (not (free-type-var? n1 b2))
                             (type-equal? b1 (type-subst n2 (varT n1) b2))))))]))

(module+ test
  (test (type-equal? (numT) (numT))
        #t)
  (test (type-equal? (numT) (boolT))
        #f)
  (test (type-equal? (arrowT (numT) (boolT)) (arrowT (numT) (boolT)))
        #t)
  (test (type-equal? (arrowT (numT) (boolT)) (arrowT (boolT) (numT)))
        #f)
  (test (type-equal? (varT 'a) (varT 'a))
        #t)
  (test (type-equal? (varT 'a) (varT 'b))
        #f)
  (test (type-equal? (forallT 'a (arrowT (varT 'a) (varT 'a)))
                     (forallT 'b (arrowT (varT 'b) (varT 'b))))
        #t)
  (test (type-equal? (forallT 'a (varT 'b))
                     (forallT 'b (varT 'b)))
        #f)
  (test (type-equal? (forallT 'a (varT 'b))
                     (forallT 'b (varT 'a)))
        #f)
  (test (type-equal? (forallT 'a (forallT 'b (arrowT (varT 'a) (varT 'b))))
                     (forallT 'b (forallT 'a (arrowT (varT 'b) (varT 'a)))))
        #t)
  (test (type-equal? (forallT 'a (forallT 'b (arrowT (varT 'a) (varT 'b))))
                     (forallT 'b (forallT 'a (arrowT (varT 'a) (varT 'b)))))
        #f))