#lang plai-typed
(require plai-typed/s-exp-match)

;; Start with "subtype.rkt"

;; Add `if0` so that an expression like
;;
;;    {get {if0 ....
;;              {record {x 8}}
;;              {record {x 9} {y 10}}}
;;         x}
;;
;; has a type.

;; Run-time values: numbers, closures (parameter, body, captured
;; environment), and records (parallel lists of field names and values).
(define-type Value
  [numV (n : number)]
  [closV (arg : symbol)
         (body : ExprC)
         (env : Env)]
  [recordV (l1 : (listof symbol))
           (l2 : (listof Value))])

;; Core abstract syntax produced by `parse`.
(define-type ExprC
  [numC (n : number)]
  [idC (s : symbol)]
  [plusC (l : ExprC)
         (r : ExprC)]
  [multC (l : ExprC)
         (r : ExprC)]
  [lamC (n : symbol)
        (arg-type : Type)
        (body : ExprC)]
  [appC (fun : ExprC)
        (arg : ExprC)]
  [recordC (ns : (listof symbol))
           (args : (listof ExprC))]
  [getC (rec : ExprC)
        (n : symbol)]
  [setC (rec : ExprC)
        (n : symbol)
        (val : ExprC)]
  [if0C [tst : ExprC]
        [thn : ExprC]
        [els : ExprC]])

;; Static types.  Record types are parallel lists of field names and
;; field types.
(define-type Type
  [numT]
  [boolT]
  [arrowT (arg : Type)
          (result : Type)]
  [recordT (names : (listof symbol))
           (field : (listof Type))])

(define-type Binding
  [bind (name : symbol)
        (val : Value)])

(define-type-alias Env (listof Binding))

(define-type TypeBinding
  [tbind (name : symbol)
         (type : Type)])

(define-type-alias TypeEnv (listof TypeBinding))

;; Environments (value and type) are lists; the newest binding is first.
(define mt-env empty)
(define extend-env cons)

(module+ test
  (print-only-errors true))

;; parse ----------------------------------------

;; parse : s-expression -> ExprC
;; Converts concrete syntax into ExprC.  `let` is desugared into the
;; application of an annotated `lambda`.  Errors with "invalid input"
;; on anything it does not recognize.
(define (parse [s : s-expression]) : ExprC
  (cond
    [(s-exp-match? `NUMBER s) (numC (s-exp->number s))]
    [(s-exp-match? `SYMBOL s) (idC (s-exp->symbol s))]
    [(s-exp-match? '{+ ANY ANY} s)
     (plusC (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? '{* ANY ANY} s)
     (multC (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? '{let {[SYMBOL : ANY ANY]} ANY} s)
     ;; bs = (name : type rhs)
     (let ([bs (s-exp->list (first
                             (s-exp->list (second
                                           (s-exp->list s)))))])
       (appC (lamC (s-exp->symbol (first bs))
                   (parse-type (third bs))
                   (parse (third (s-exp->list s))))
             (parse (fourth bs))))]
    [(s-exp-match? '{lambda {[SYMBOL : ANY]} ANY} s)
     ;; arg = (name : type)
     (let ([arg (s-exp->list
                 (first (s-exp->list
                         (second (s-exp->list s)))))])
       (lamC (s-exp->symbol (first arg))
             (parse-type (third arg))
             (parse (third (s-exp->list s)))))]
    [(s-exp-match? '{record {SYMBOL ANY} ...} s)
     (recordC (map (lambda (l) (s-exp->symbol (first (s-exp->list l))))
                   (rest (s-exp->list s)))
              (map (lambda (l) (parse (second (s-exp->list l))))
                   (rest (s-exp->list s))))]
    [(s-exp-match? '{get ANY SYMBOL} s)
     (getC (parse (second (s-exp->list s)))
           (s-exp->symbol (third (s-exp->list s))))]
    [(s-exp-match? '{set ANY SYMBOL ANY} s)
     (setC (parse (second (s-exp->list s)))
           (s-exp->symbol (third (s-exp->list s)))
           (parse (fourth (s-exp->list s))))]
    [(s-exp-match? '{if0 ANY ANY ANY} s)
     (if0C (parse (second (s-exp->list s)))
           (parse (third (s-exp->list s)))
           (parse (fourth (s-exp->list s))))]
    ;; Must come after the keyword forms above, since it matches any
    ;; two-element list.
    [(s-exp-match? '{ANY ANY} s)
     (appC (parse (first (s-exp->list s)))
           (parse (second (s-exp->list s))))]
    [else (error 'parse "invalid input")]))

;; parse-type : s-expression -> Type
;; Accepts `num`, `bool`, `(T -> T)`, and record types
;; `{[name : T] ...}`; errors with "invalid input" otherwise.
(define (parse-type [s : s-expression]) : Type
  (cond
    [(s-exp-match? `num s)
     (numT)]
    [(s-exp-match? `bool s)
     (boolT)]
    [(s-exp-match? `(ANY -> ANY) s)
     (arrowT (parse-type (first (s-exp->list s)))
             (parse-type (third (s-exp->list s))))]
    [(s-exp-match? `{[SYMBOL : ANY] ...} s)
     (let ([fields (map s-exp->list (s-exp->list s))])
       (recordT (map s-exp->symbol (map first fields))
                (map parse-type (map third fields))))]
    [else (error 'parse-type "invalid input")]))

(module+ test
  (test (parse '2)
        (numC 2))
  (test (parse `x) ; note: backquote instead of normal quote
        (idC 'x))
  (test (parse '{+ 2 1})
        (plusC (numC 2) (numC 1)))
  (test (parse '{* 3 4})
        (multC (numC 3) (numC 4)))
  (test (parse '{+ {* 3 4} 8})
        (plusC (multC (numC 3) (numC 4))
               (numC 8)))
  (test (parse '{let {[x : num {+ 1 2}]}
                  y})
        (appC (lamC 'x (numT) (idC 'y))
              (plusC (numC 1) (numC 2))))
  (test (parse '{lambda {[x : num]} 9})
        (lamC 'x (numT) (numC 9)))
  (test (parse '{double 9})
        (appC (idC 'double) (numC 9)))
  (test (parse '{record {x {+ 1 2}} {y 3}})
        (recordC (list 'x 'y)
                 (list (plusC (numC 1) (numC 2))
                       (numC 3))))
  (test (parse '{get {+ 1 2} a})
        (getC (plusC (numC 1) (numC 2))
              'a))
  (test (parse '{set {+ 1 2} a 7})
        (setC (plusC (numC 1) (numC 2))
              'a
              (numC 7)))
  (test (parse '{if0 0 1 2})
        (if0C (numC 0) (numC 1) (numC 2)))
  (test/exn (parse '{{+ 1 2}})
            "invalid input")

  (test (parse-type `num)
        (numT))
  (test (parse-type `bool)
        (boolT))
  (test (parse-type `(num -> bool))
        (arrowT (numT) (boolT)))
  (test (parse-type '{[x : num] [y : bool]})
        (recordT (list 'x 'y) (list (numT) (boolT))))
  (test/exn (parse-type '1)
            "invalid input"))

;; interp ----------------------------------------

;; interp : ExprC Env -> Value
;; Evaluates `a` in environment `env`.  Raises an 'interp error
;; ("not a function", "not a record", "not a number", "no such field")
;; for dynamically ill-formed programs; free identifiers error via `lookup`.
(define (interp [a : ExprC] [env : Env]) : Value
  (type-case ExprC a
    [numC (n) (numV n)]
    [idC (s) (lookup s env)]
    [plusC (l r) (num+ (interp l env) (interp r env))]
    [multC (l r) (num* (interp l env) (interp r env))]
    [lamC (n t body)
          ;; The argument type annotation is only used by the type checker.
          (closV n body env)]
    [appC (fun arg) (type-case Value (interp fun env)
                      [closV (n body c-env)
                             (interp body
                                     (extend-env
                                      (bind n
                                            (interp arg env))
                                      c-env))]
                      [else (error 'interp "not a function")])]
    [recordC (ns as)
             (recordV ns
                      (map (lambda (a) (interp a env))
                           as))]
    [getC (a n)
          (type-case Value (interp a env)
            [recordV (ns vs) (find n ns vs)]
            [else (error 'interp "not a record")])]
    [setC (a n v)
          ;; Functional update: produces a new record value.
          (type-case Value (interp a env)
            [recordV (ns vs)
                     (recordV ns (update n (interp v env) ns vs))]
            [else (error 'interp "not a record")])]
    [if0C (t z nz)
          ;; Check the variant instead of calling `numV-n` blindly, so a
          ;; non-number test reports the same error as `num-op`.  Numeric
          ;; `=` (not `equal?`) makes inexact zero such as 0.0 count as zero.
          (type-case Value (interp t env)
            [numV (n) (if (= n 0)
                          (interp z env)
                          (interp nz env))]
            [else (error 'interp "not a number")])]))

;; update : symbol Value (listof symbol) (listof Value) -> (listof Value)
;; Returns a copy of `vs` with the value at the first position where
;; `ns` holds `n` replaced by `v`.  Errors with "no such field" if `n`
;; does not occur in `ns`.
(define (update [n : symbol] [v : Value] [ns : (listof symbol)] [vs : (listof Value)])
  : (listof Value)
  (cond
    [(empty? ns) (error 'interp "no such field")]
    [else (if (symbol=? n (first ns))
              (cons v (rest vs))
              (cons (first vs) 
                    (update n v (rest ns) (rest vs))))]))

(module+ test
  (test (interp (parse '2) mt-env)
        (numV 2))
  (test/exn (interp (parse `x) mt-env)
            "free variable")
  (test (interp (parse `x) 
                (extend-env (bind 'x (numV 9)) mt-env))
        (numV 9))
  (test (interp (parse '{+ 2 1}) mt-env)
        (numV 3))
  (test (interp (parse '{* 2 1}) mt-env)
        (numV 2))
  (test (interp (parse '{+ {* 2 3} {+ 5 8}})
                mt-env)
        (numV 19))
  (test (interp (parse '{lambda {[x : num]} {+ x x}})
                mt-env)
        (closV 'x (plusC (idC 'x) (idC 'x)) mt-env))
  (test (interp (parse '{let {[x : num 5]}
                          {+ x x}})
                mt-env)
        (numV 10))
  (test (interp (parse '{let {[x : num 5]}
                          {let {[x : num {+ 1 x}]}
                            {+ x x}}})
                mt-env)
        (numV 12))
  (test (interp (parse '{let {[x : num 5]}
                          {let {[y : num 6]}
                            x}})
                mt-env)
        (numV 5))
  (test (interp (parse '{{lambda {[x : num]} {+ x x}} 8})
                mt-env)
        (numV 16))
  (test (interp (parse '{record {x {+ 1 3}}})
                mt-env)
        (recordV (list 'x) (list (numV 4))))
  (test (interp (parse '{get {record {x {+ 1 3}}} x})
                mt-env)
        (numV 4))
  (test (interp (parse '{set {record {a {+ 1 1}} {b {+ 2 2}}} a 5})
                mt-env)
        (recordV (list 'a 'b) (list (numV 5) (numV 4))))
  (test (interp (parse '{set {record {a {+ 1 1}} {b {+ 2 2}}} b 5})
                mt-env)
        (recordV (list 'a 'b) (list (numV 2) (numV 5))))
  (test (interp (parse '{let {[r1 : {[a : num] [b : num]}
                                  {record {a {+ 1 1}} {b {+ 2 2}}}]}
                          {let {[r2 : {[a : num] [b : num]}
                                    {set r1 a 5}]}
                            {+ {get r1 a} {get r2 a}}}})
                mt-env)
        (numV 7))
  (test (interp (parse '{if0 0 1 2})
                mt-env)
        (numV 1))
  (test (interp (parse '{if0 0.0 1 2})
                mt-env)
        (numV 1))
  (test (interp (parse '{get {if0 0
                                  {record {x 8}}
                                  {record {x 9} {y 10}}}
                             x})
                mt-env)
        (numV 8))
  (test (interp (parse '{get {if0 1
                                  {record {x 8}}
                                  {record {x 9} {y 10}}}
                             x})
                mt-env)
        (numV 9))

  (test/exn (interp (parse '{get 1 x}) mt-env)
            "not a record")
  (test/exn (interp (parse '{set 1 x 1}) mt-env)
            "not a record")
  (test/exn (interp (parse '{set {record} x 1}) mt-env)
            "no such field")
  (test/exn (interp (parse '{1 2}) mt-env)
            "not a function")
  (test/exn (interp (parse '{+ 1 {lambda {[x : num]} x}}) mt-env)
            "not a number")
  (test/exn (interp (parse '{if0 {lambda {[x : num]} x} 1 2}) mt-env)
            "not a number")
  (test/exn (interp (parse '{let {[bad : (num -> num) {lambda {[x : num]} {+ x y}}]}
                              {let {[y : num 5]}
                                {bad 2}}})
                    mt-env)
            "free variable"))

;; num+ and num* ----------------------------------------

;; num-op : (number number -> number) Value Value -> Value
;; Applies `op` to two numeric values; errors with "not a number" if
;; either operand is not a `numV`.
(define (num-op [op : (number number -> number)] [l : Value] [r : Value]) : Value
  (cond
    [(and (numV? l) (numV? r))
     (numV (op (numV-n l) (numV-n r)))]
    [else
     (error 'interp "not a number")]))
(define (num+ [l : Value] [r : Value]) : Value
  (num-op + l r))
(define (num* [l : Value] [r : Value]) : Value
  (num-op * l r))

(module+ test
  (test (num+ (numV 1) (numV 2))
        (numV 3))
  (test (num* (numV 2) (numV 3))
        (numV 6)))

;; lookup ----------------------------------------

;; make-lookup : ('a -> symbol) ('a -> 'b) -> (symbol (listof 'a) -> 'b)
;; Builds a lookup function over a binding list.  The first binding whose
;; name matches wins, which implements shadowing.  Errors with
;; "free variable" if no binding matches.
(define (make-lookup [name-of : ('a -> symbol)] [val-of : ('a -> 'b)])
  (lambda ([name : symbol] [vals : (listof 'a)]) : 'b
    (cond
      [(empty? vals)
       (error 'find "free variable")]
      [else (if (equal? name (name-of (first vals)))
                (val-of (first vals))
                ((make-lookup name-of val-of) name (rest vals)))])))

(define lookup
  (make-lookup bind-name bind-val))

(module+ test
  (test/exn (lookup 'x mt-env)
            "free variable")
  (test (lookup 'x (extend-env (bind 'x (numV 8)) mt-env))
        (numV 8))
  (test (lookup 'x (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'x (numV 8)) mt-env)))
        (numV 9))
  (test (lookup 'y (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'y (numV 8)) mt-env)))
        (numV 8)))

;; typecheck ----------------------------------------

;; typecheck : ExprC TypeEnv -> Type
;; Computes the type of `a` under `tenv`, using `subtype?` wherever a
;; value flows into a position with an expected type.  Raises a
;; 'typecheck error whose message starts with "no type: " if the
;; expression is ill-typed.
(define (typecheck [a : ExprC] [tenv : TypeEnv]) : Type
  (type-case ExprC a
    [numC (n) (numT)]
    [plusC (l r) (typecheck-nums l r tenv)]
    [multC (l r) (typecheck-nums l r tenv)]
    [idC (n) (type-lookup n tenv)]
    [lamC (n arg-type body)
          (arrowT arg-type
                  (typecheck body 
                             (extend-env (tbind n arg-type)
                                         tenv)))]
    [appC (fun arg)
          (type-case Type (typecheck fun tenv)
            [arrowT (arg-type result-type)
                    (if (subtype? (typecheck arg tenv)
                                  arg-type)
                        result-type
                        (type-error arg
                                    (to-string arg-type)))]
            [else (type-error fun "function")])]
    [recordC (ns as)
             (recordT ns (map (lambda (a) (typecheck a tenv)) as))]
    [getC (rec-expr field-name)
          (type-case Type (typecheck rec-expr tenv)
            [recordT (ns ts)
                     ;; `find` raises when the field is missing; turn that
                     ;; into a type error.
                     (try (find field-name ns ts)
                          (lambda ()
                            (type-error rec-expr "record with field")))]
            [else (type-error rec-expr "record")])]
    [setC (r n v)
          (let ([rec-type (typecheck r tenv)])
            (type-case Type rec-type
              [recordT (ns ts)
                       (let ([field-type (try (find n ns ts)
                                              (lambda ()
                                                (type-error r "record with field")))])
                         ;; The record keeps its static type; the new field
                         ;; value only has to be a subtype of the old one.
                         (if (subtype? (typecheck v tenv) field-type)
                             rec-type
                             (type-error v (to-string field-type))))]
              [else (type-error r "record")]))]
    [if0C (t z nz)
          (let ([t-type (typecheck t tenv)]
                [z-type (typecheck z tenv)]
                [nz-type (typecheck nz tenv)])
            (type-case Type t-type
              [numT ()
                    ;; The result is whichever branch type is a supertype
                    ;; of the other; incomparable branch types are rejected.
                    (if (subtype? z-type nz-type)
                        nz-type
                        (if (subtype? nz-type z-type)
                            z-type
                            (type-error z (to-string nz-type))))]
              [else (type-error t "number")]))]))

;; find : symbol (listof symbol) (listof 'a) -> 'a
;; Returns the element of `ts` at the position of the first occurrence
;; of `n` in `ns`.  Errors with "no such field" if `n` is absent.
;; Used both for record values (interp) and record field types (typecheck).
(define (find [n : symbol] [ns : (listof symbol)] [ts : (listof 'a)]) : 'a
  (cond
    [(empty? ns) (error 'interp "no such field")]
    [else (if (symbol=? n (first ns))
              (first ts)
              (find n (rest ns) (rest ts)))]))

;; typecheck-nums : checks that both operands of `+` or `*` have type
;; num and returns (numT); otherwise raises a type error on the
;; offending operand.
(define (typecheck-nums l r tenv)
  (type-case Type (typecheck l tenv)
    [numT ()
          (type-case Type (typecheck r tenv)
            [numT () (numT)]
            [else (type-error r "num")])]
    [else (type-error l "num")]))

;; type-error : raises a 'typecheck error "no type: <a> not <msg>".
(define (type-error a msg)
  (error 'typecheck (string-append
                     "no type: "
                     (string-append
                      (to-string a)
                      (string-append " not "
                                     msg)))))

(define type-lookup
  (make-lookup tbind-name tbind-type))

(module+ test
  (test (typecheck (parse '{record {x 1}}) mt-env)
        (recordT (list 'x) (list (numT))))
  (test (typecheck (parse '{get {record {x 1}} x}) mt-env)
        (numT))
  (test/exn (typecheck (parse '{{get {record {x 1}} x} 0}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{get {record {x 1}} y}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{get 1 y}) mt-env)
            "no type")
  (test/exn (typecheck (parse '{set 1 y 1}) mt-env)
            "no type")

  (test (typecheck (parse '10) mt-env)
        (numT))
  (test (typecheck (parse '{+ 10 17}) mt-env)
        (numT))
  (test (typecheck (parse '{* 10 17}) mt-env)
        (numT))
  (test (typecheck (parse '{lambda {[x : num]} 12}) mt-env)
        (arrowT (numT) (numT)))
  (test (typecheck (parse '{lambda {[x : num]} {lambda {[y : bool]} x}}) mt-env)
        (arrowT (numT) (arrowT (boolT) (numT))))
  (test (typecheck (parse '{{lambda {[x : num]} 12}
                            {+ 1 17}})
                   mt-env)
        (numT))
  (test (typecheck (parse '{let {[x : num 4]}
                             {let {[f : (num -> num)
                                      {lambda {[y : num]} {+ x y}}]}
                               {f x}}})
                   mt-env)
        (numT))

  (test (typecheck (parse '{lambda {[r : {[x : num]}]}
                             {get r x}})
                   mt-env)
        (arrowT (recordT (list 'x) (list (numT)))
                (numT)))
  (test (typecheck (parse '{{lambda {[r : {[x : num]}]}
                              {get r x}}
                            {record {x {+ 1 2}}}})
                   mt-env)
        (numT))
  (test (interp (parse '{{lambda {[r : {[x : num]}]}
                           {get r x}}
                         {record {x {+ 1 2}} {y 9}}})
                mt-env)
        (numV 3))
  (test (typecheck (parse '{{lambda {[r : {[x : num]}]}
                              {get r x}}
                            {record {x {+ 1 2}} {y 9}}})
                   mt-env)
        (numT))
  (test (typecheck (parse '{{lambda {[r : {[y : num] [x : num]}]}
                              {get r x}}
                            {record {x {+ 1 2}} {y 9}}})
                   mt-env)
        (numT))
  (test/exn (typecheck (parse '{{lambda {[r : {[x : num] [y : num]}]}
                                  {get r x}}
                                {record {x {+ 1 2}}}})
                       mt-env)
            "no type")
  (test (typecheck (parse '{lambda {[r : {[x : num]}]}
                             {set r x 2}})
                   mt-env)
        (arrowT (recordT (list 'x) (list (numT)))
                (recordT (list 'x) (list (numT)))))
  (test/exn (typecheck (parse '{lambda {[r : {[x : num]}]}
                                 {set r x r}})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{lambda {[r : {[x : num]}]}
                                 {set r y 1}})
                       mt-env)
            "no type")
  (test (typecheck (parse '{{lambda {[r : {[x : {[y : num]}]}]}
                              {set r x {record {y 1} {z 1}}}}
                            {record {x {record {y 7}}}}})
                   mt-env)
        (recordT (list 'x) (list (recordT (list 'y) (list (numT))))))

  (test (typecheck (parse '{if0 0 1 2})
                   mt-env)
        (numT))
  (test (typecheck (parse '{get {if0 0
                                     {record {x 8}}
                                     {record {x 9} {y 10}}}
                                x})
                   mt-env)
        (numT))
  (test (typecheck (parse '{get {if0 1
                                     {record {x 8}}
                                     {record {x 9} {y 10}}}
                                x})
                   mt-env)
        (numT))
  (test (typecheck (parse '{if0 1
                                {record {x 8}}
                                {record {x 9} {y 10}}})
                   mt-env)
        (recordT (list 'x) (list (numT))))
  ;; But it would be nice to have a type here: the branch types
  ;; {[x : num] [z : num]} and {[x : num] [y : num]} are incomparable,
  ;; so if0 rejects them even though both have an `x` field.
  (test/exn (typecheck (parse '{get {if0 1
                                         {record {x 8} {z 11}}
                                         {record {x 9} {y 10}}}
                                    x})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{if0 {0 0} 1 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{if0 {lambda {[x : num]} x} 1 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{if0 0 {record {x 8}} 2})
                       mt-env)
            "no type")

  (test/exn (typecheck (parse '{1 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{{lambda {[x : bool]} x} 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{+ 1 {lambda {[x : num]} x}})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse '{* {lambda {[x : num]} x} 1})
                       mt-env)
            "no type"))

;; subtype? ----------------------------------------

;; subtype? : Type Type -> boolean
;; Returns #t when a value of type `t1` can be used where `t2` is
;; expected.  Base types are related only to themselves; arrows are
;; contravariant in the argument and covariant in the result.  Records
;; use width subtyping (t1 may have extra fields) and depth subtyping
;; (shared fields are compared covariantly).
(define (subtype? [t1 : Type] [t2 : Type]) : boolean
  (type-case Type t1
    [numT () (type-case Type t2
               [numT () #t]
               [else #f])]
    [boolT () (type-case Type t2
                [boolT () #t]
                [else #f])]
    [arrowT (l1 r1) (type-case Type t2
                      [arrowT (l2 r2)
                              (and (subtype? l2 l1) ; contravariant
                                   (subtype? r1 r2))] ; covariant
                      [else #f])]
    [recordT (ns1 ts1) (type-case Type t2
                         [recordT (ns2 ts2)
                                  ;; Every field in ns2 must be in ns1,
                                  ;; and each corresponding type in ts1
                                  ;; should be a subtype of the one in ts2
                                  (foldl (lambda (n b)
                                           (and b
                                                ;; `and` short-circuits, so
                                                ;; `find` below never misses.
                                                (member n ns1)
                                                ;; covariant:
                                                (subtype? (find n ns1 ts1)
                                                          (find n ns2 ts2))))
                                         true
                                         ns2)]
                         [else #f])]))

(module+ test
  (test (subtype? (numT) (numT)) #t)
  (test (subtype? (numT) (boolT)) #f)
  (test (subtype? (boolT) (boolT)) #t)
  (test (subtype? (boolT) (numT)) #f)
  (test (subtype? (arrowT (numT) (numT)) (arrowT (numT) (numT))) #t)
  (test (subtype? (arrowT (numT) (boolT)) (arrowT (numT) (numT))) #f)
  (test (subtype? (arrowT (boolT) (numT)) (arrowT (numT) (numT))) #f)
  (test (subtype? (arrowT (boolT) (numT)) (numT)) #f)
  (test (subtype? (recordT (list 'x) (list (numT)))
                  (recordT empty empty))
        #t)
  (test (subtype? (recordT empty empty)
                  (recordT (list 'x) (list (numT))))
        #f)
  (test (subtype? (recordT (list 'x) (list (recordT (list 'y) (list (numT)))))
                  (recordT (list 'x) (list (recordT empty empty))))
        #t)
  (test (subtype? (arrowT (recordT empty empty)
                          (recordT (list 'x) (list (numT))))
                  (arrowT (recordT (list 'x) (list (numT)))
                          (recordT empty empty)))
        #t)
  (test (subtype? (recordT empty empty) (numT)) #f))