全体像がわかりやすいように、簡単なデモを実装した。以下のように動作を確認できる。
ghci> inferExpr (App (Lam "x" (Var "x")) (LitInt 42))
Right TInt
module Main where
import Control.Monad.Except
import Control.Monad.State.Strict
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
-- | Abstract syntax of the small lambda-calculus accepted by the inferencer.
data Expr
  = Var String           -- ^ variable reference
  | Lam String Expr      -- ^ single-argument lambda abstraction
  | App Expr Expr        -- ^ function application
  | Let String Expr Expr -- ^ @let name = rhs in body@ (non-recursive)
  | LitInt Int           -- ^ integer literal
  | LitBool Bool         -- ^ boolean literal
  | If Expr Expr Expr    -- ^ conditional: condition, then-branch, else-branch
  deriving (Show, Eq, Ord)
-- | Type variables are identified by integers handed out by 'freshTypeVar'.
type TypeVar = Int

-- | Monomorphic types.
data Type
  = TVar TypeVar   -- ^ type variable
  | TInt           -- ^ the integer type
  | TBool          -- ^ the boolean type
  | TFun Type Type -- ^ function type: argument, result
  deriving (Show, Eq, Ord)
-- | A type scheme: the listed type variables are universally quantified
-- over the body type.
data Scheme
  = Forall [TypeVar] Type -- ^ bound variables, body type
  deriving (Show, Eq, Ord)
-- | A substitution: maps type variables to the types that replace them.
type Subst = Map TypeVar Type
-- | Typing environment: maps variable names to their type schemes.
type TypeEnv = Map String Scheme
-- | The inference monad: failure with a 'String' message, over a 'State'
-- holding the next 'TypeVar' (see 'freshTypeVar' and 'runInfer').
type Infer a = ExceptT String (State TypeVar) a
-- | Execute an inference computation, numbering fresh type variables from 0.
-- The final counter value is discarded; only the result or error is kept.
runInfer :: Infer a -> Either String a
runInfer = flip evalState 0 . runExceptT
-- | Return the current counter value as a new type variable and advance the
-- counter by one.
freshTypeVar :: Infer TypeVar
freshTypeVar = state (\next -> (next, next + 1))
-- | Like 'freshTypeVar', but wrapped as a 'Type'.
freshTVar :: Infer Type
freshTVar = do
  v <- freshTypeVar
  pure (TVar v)
-- | The identity substitution (binds no variables).
emptySubst :: Subst
emptySubst = mempty
-- | @composeSubst s1 s2@ is the substitution that first applies @s2@ and then
-- @s1@.
--
-- Every binding of @s2@ has @s1@ applied to its right-hand side. 'Map.union'
-- is left-biased, so on a shared key the (updated) binding from @s2@ wins
-- over the one from @s1@.
composeSubst :: Subst -> Subst -> Subst
composeSubst s1 s2 = Map.union s2Updated s1
  where
    s2Updated = Map.map (apply s1) s2
-- | Structures that mention type variables.
class Types a where
  -- | Replace type variables according to the substitution.
  apply :: Subst -> a -> a
  -- | The set of free type variables.
  ftv :: a -> Set TypeVar
-- | Types: variables are looked up in the substitution (left unchanged when
-- unbound), base types are constant, and function types are traversed on
-- both sides.
instance Types Type where
  ftv ty = case ty of
    TVar v -> Set.singleton v
    TFun arg res -> Set.union (ftv arg) (ftv res)
    TInt -> Set.empty
    TBool -> Set.empty

  apply sub ty = case ty of
    TVar v -> Map.findWithDefault ty v sub
    TFun arg res -> TFun (apply sub arg) (apply sub res)
    TInt -> TInt
    TBool -> TBool
-- | Schemes: quantified variables are neither free nor substitutable, so they
-- are removed from the free set and from the substitution's domain before the
-- body is touched.
instance Types Scheme where
  ftv (Forall bound body) = Set.difference (ftv body) (Set.fromList bound)

  apply sub (Forall bound body) =
    Forall bound (apply (Map.withoutKeys sub (Set.fromList bound)) body)
-- | Lists are handled elementwise: free variables are unioned, and the
-- substitution is applied to each element.
instance Types a => Types [a] where
  ftv xs = Set.unions (map ftv xs)
  apply s xs = [apply s x | x <- xs]
-- | Environments (and any other 'Map' whose values are 'Types') are handled
-- pointwise: the free variables are the union over all values, and a
-- substitution is applied to every value.
--
-- The instance is stated for @Map k v@ rather than for the 'TypeEnv' synonym
-- so that it is valid Haskell 2010. An instance head that is a type synonym,
-- or a concrete application such as @Map String Scheme@, would require the
-- TypeSynonymInstances and FlexibleInstances extensions, which this module
-- does not enable. 'TypeEnv' is still covered, because 'Scheme' has a
-- 'Types' instance.
instance Types v => Types (Map k v) where
  ftv env = ftv (Map.elems env)
  apply s = Map.map (apply s)
-- | Quantify over every type variable that is free in the type but not free
-- in the environment.
generalize :: TypeEnv -> Type -> Scheme
generalize env t =
  let quantified = Set.difference (ftv t) (ftv env)
   in Forall (Set.toList quantified) t
-- | Replace each quantified variable of a scheme with a fresh type variable.
-- Fresh variables are allocated in the order the bound variables are listed.
instantiate :: Scheme -> Infer Type
instantiate (Forall bound body) = do
  renaming <- traverse (\v -> (,) v <$> freshTVar) bound
  pure (apply (Map.fromList renaming) body)
-- | Look up a variable and instantiate its scheme. Fails with an
-- "Unbound variable" error when the name is not in the environment.
lookupEnv :: TypeEnv -> String -> Infer Type
lookupEnv env name =
  maybe (throwError ("Unbound variable: " ++ name)) instantiate (Map.lookup name env)
-- | Algorithm W: infer a type for an expression under an environment.
--
-- Returns the substitution accumulated during inference together with the
-- inferred type. Each rule threads the substitutions found so far into the
-- environment used for later sub-expressions.
infer :: TypeEnv -> Expr -> Infer (Subst, Type)
infer env (Var name) = (,) emptySubst <$> lookupEnv env name
infer _ (LitInt _) = pure (emptySubst, TInt)
infer _ (LitBool _) = pure (emptySubst, TBool)
infer env (Lam param body) = do
  -- The parameter gets a fresh, unquantified (monomorphic) type.
  paramTy <- freshTVar
  (sBody, bodyTy) <- infer (Map.insert param (Forall [] paramTy) env) body
  pure (sBody, TFun (apply sBody paramTy) bodyTy)
infer env (App fun arg) = do
  (sFun, funTy) <- infer env fun
  (sArg, argTy) <- infer (apply sFun env) arg
  -- The result variable is allocated only after both sides are inferred.
  resTy <- freshTVar
  sRes <- unify (apply sArg funTy) (TFun argTy resTy)
  pure (composeSubst (composeSubst sRes sArg) sFun, apply sRes resTy)
infer env (Let name rhs body) = do
  (sRhs, rhsTy) <- infer env rhs
  let rhsEnv = apply sRhs env
      -- Generalize against the updated environment to get let-polymorphism.
      bodyEnv = Map.insert name (generalize rhsEnv (apply sRhs rhsTy)) rhsEnv
  (sBody, bodyTy) <- infer bodyEnv body
  pure (composeSubst sBody sRhs, bodyTy)
infer env (If condExpr thenExpr elseExpr) = do
  (sCond, condTy) <- infer env condExpr
  sBool <- unify (apply sCond condTy) TBool
  let sPrefix = composeSubst sBool sCond
      thenEnv = apply sPrefix env
  (sThen, thenTy) <- infer thenEnv thenExpr
  (sElse, elseTy) <- infer (apply sThen thenEnv) elseExpr
  -- Both branches must have the same type.
  sJoin <- unify (apply sElse thenTy) elseTy
  let total = composeSubst (composeSubst (composeSubst sJoin sElse) sThen) sPrefix
  pure (total, apply sJoin elseTy)
-- | Compute a substitution that makes the two types equal.
--
-- Function types are unified argument-first, and the argument substitution is
-- applied before the results are unified. A type variable on either side is
-- bound via 'bind', which performs the occurs check. Any other combination of
-- distinct constructors is a type mismatch.
unify :: Type -> Type -> Infer Subst
unify t1 t2 = case (t1, t2) of
  (TFun l r, TFun l' r') -> do
    sArg <- unify l l'
    sRes <- unify (apply sArg r) (apply sArg r')
    pure (composeSubst sRes sArg)
  (TVar v, t) -> bind v t
  (t, TVar v) -> bind v t
  (TInt, TInt) -> pure emptySubst
  (TBool, TBool) -> pure emptySubst
  _ -> throwError ("Type mismatch. Cannot unify " ++ show t1 ++ " with " ++ show t2)
-- | Bind a type variable to a type.
--
-- Binding a variable to itself yields the empty substitution. If the variable
-- occurs inside the type, binding fails with an occurs-check error, because it
-- would create an infinite type.
bind :: TypeVar -> Type -> Infer Subst
bind v t =
  if t == TVar v
    then pure emptySubst
    else
      if Set.member v (ftv t)
        then throwError ("Occurs check failed: t" ++ show v ++ " occurs in " ++ show t)
        else pure (Map.singleton v t)
-- | Infer the type of a closed expression (empty environment). The final
-- substitution is applied to the inferred type before it is returned.
inferExpr :: Expr -> Either String Type
inferExpr expr = runInfer (uncurry apply <$> infer Map.empty expr)
-- | Demo entry point: print each sample expression followed by the result of
-- 'inferExpr' on it. The original 'main' was a no-op, so running the program
-- showed nothing.
--
-- The samples cover plain application, an unconstrained lambda,
-- let-polymorphism (@id@ used at both 'TBool' and 'TInt'), and a type error.
main :: IO ()
main = mapM_ report examples
  where
    report e = putStrLn (show e ++ "\n  => " ++ show (inferExpr e))
    examples =
      [ App (Lam "x" (Var "x")) (LitInt 42)
      , Lam "x" (Var "x")
      , Let "id" (Lam "x" (Var "x"))
          (If (App (Var "id") (LitBool True))
              (App (Var "id") (LitInt 1))
              (LitInt 2))
      , App (LitInt 1) (LitInt 2)
      ]