Hindley-Milner型システムの簡単な実装

全体像がわかりやすいように、簡単なデモを実装した。以下のように動作を確認できる。

ghci> inferExpr (App (Lam "x" (Var "x")) (LitInt 42))
Right TInt
module Main where

import Control.Monad.Except
import Control.Monad.State.Strict
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set

-- | Untyped source expressions of the small lambda calculus being typed.
data Expr
  = Var String
    -- ^ Variable reference.
  | Lam String Expr
    -- ^ Single-parameter lambda: parameter name and body.
  | App Expr Expr
    -- ^ Application: function expression, argument expression.
  | Let String Expr Expr
    -- ^ Non-recursive let: name, right-hand side, body. The right-hand
    -- side is typed without the name in scope and is then generalized.
  | LitInt Int
    -- ^ Integer literal.
  | LitBool Bool
    -- ^ Boolean literal.
  | If Expr Expr Expr
    -- ^ Conditional: condition, then-branch, else-branch.
  deriving (Eq, Ord, Show)

-- | Type variables are plain integers, handed out in increasing order by
-- 'freshTypeVar'.
type TypeVar = Int

-- | Monomorphic types.
data Type
  = TVar TypeVar
    -- ^ Type variable, to be solved by unification.
  | TInt
    -- ^ Integer type (the type of 'LitInt').
  | TBool
    -- ^ Boolean type (the type of 'LitBool' and of an 'If' condition).
  | TFun Type Type
    -- ^ Function type: argument type, result type.
  deriving (Eq, Ord, Show)

-- | A type scheme @forall vars. t@. The listed variables are bound
-- (quantified); every other variable occurring in @t@ is free.
data Scheme = Forall [TypeVar] Type
  deriving (Eq, Ord, Show)

-- | A substitution from type variables to types. Variables absent from the
-- map are left unchanged by 'apply'.
type Subst = Map TypeVar Type

-- | Typing context: maps each variable name in scope to its type scheme.
type TypeEnv = Map String Scheme

-- | The inference monad. The 'State' layer holds the next unused type-variable
-- number (see 'freshTypeVar'), and 'ExceptT' carries a 'String' error message.
type Infer a = ExceptT String (State TypeVar) a

-- | Run an inference computation with the fresh-variable counter starting at
-- 0. The final counter value is discarded.
runInfer :: Infer a -> Either String a
runInfer = (`evalState` 0) . runExceptT

-- | Return the current counter value as a new type variable and advance the
-- counter by one.
freshTypeVar :: Infer TypeVar
freshTypeVar = state (\next -> (next, next + 1))

-- | A fresh type variable, already wrapped as a 'Type'.
freshTVar :: Infer Type
freshTVar = fmap TVar freshTypeVar

-- | The identity substitution: it maps no variables.
emptySubst :: Subst
emptySubst = mempty

-- | @composeSubst s1 s2@ behaves like applying @s2@ first, then @s1@.
--
-- The range of @s2@ is rewritten by @s1@. Where both substitutions bind the
-- same variable, the left-biased union keeps @s2@'s (rewritten) binding.
composeSubst :: Subst -> Subst -> Subst
composeSubst s1 s2 = Map.union rewritten s1
  where
    rewritten = Map.map (apply s1) s2

-- | Values that mention type variables: their free variables can be
-- collected, and a substitution can be applied to them.
class Types a where
  -- | The free type variables of the value.
  ftv :: a -> Set TypeVar
  -- | Replace free type variables according to the substitution.
  apply :: Subst -> a -> a

-- | Structural traversal over a monotype.
instance Types Type where
  ftv ty = case ty of
    TVar v -> Set.singleton v
    TFun arg res -> Set.union (ftv arg) (ftv res)
    TInt -> Set.empty
    TBool -> Set.empty

  apply sub ty = case ty of
    -- A variable not in the substitution's domain is returned unchanged.
    TVar v -> Map.findWithDefault ty v sub
    TFun arg res -> TFun (apply sub arg) (apply sub res)
    TInt -> TInt
    TBool -> TBool

-- | A scheme's quantified variables are not free, and substitution never
-- touches them. Only the scheme body is rewritten.
instance Types Scheme where
  ftv (Forall bound body) = Set.difference (ftv body) (Set.fromList bound)

  apply sub (Forall bound body) =
    Forall bound (apply (Map.withoutKeys sub (Set.fromList bound)) body)

-- | Lists: free variables are the union over all elements, and substitution
-- is applied element by element.
instance Types a => Types [a] where
  ftv xs = Set.unions (map ftv xs)
  apply sub xs = [apply sub x | x <- xs]

-- | Environments: the free variables are those of every scheme in scope, and
-- substitution is applied to each scheme.
--
-- NOTE(review): an instance on the synonym 'TypeEnv' relies on
-- TypeSynonymInstances/FlexibleInstances. GHC2021 enables both; under
-- Haskell2010 a LANGUAGE pragma is required. Confirm the build's language.
instance Types TypeEnv where
  ftv = ftv . Map.elems
  apply = Map.map . apply

-- | Quantify over every type variable free in the type but not free in the
-- environment. Those variables are unconstrained by the context, so it is
-- sound to make them polymorphic.
generalize :: TypeEnv -> Type -> Scheme
generalize env t = Forall (Set.toList quantified) t
  where
    quantified = Set.difference (ftv t) (ftv env)

-- | Replace each quantified variable of a scheme with a fresh type variable.
-- Fresh variables are allocated left to right over the quantifier list.
instantiate :: Scheme -> Infer Type
instantiate (Forall vars t) = rename <$> traverse (const freshTVar) vars
  where
    rename fresh = apply (Map.fromList (zip vars fresh)) t

-- | Look up a variable and return a fresh instance of its scheme. Fails with
-- an "Unbound variable" error if the name is not in scope.
lookupEnv :: TypeEnv -> String -> Infer Type
lookupEnv env name =
  maybe unbound instantiate (Map.lookup name env)
  where
    unbound = throwError ("Unbound variable: " ++ name)

-- | Algorithm W: infer the type of an expression under a typing environment.
--
-- Returns the substitution accumulated while solving constraints, together
-- with the inferred type. Fails through 'lookupEnv' (unbound variable) or
-- 'unify' (type mismatch, occurs-check failure).
--
-- Fresh type variables are allocated in the order of the steps below, so the
-- exact variable numbers that appear in a result depend on that order.
infer :: TypeEnv -> Expr -> Infer (Subst, Type)
infer env expr = case expr of
  -- A variable gets a fresh instance of its scheme and produces no constraints.
  Var name -> do
    t <- lookupEnv env name
    pure (emptySubst, t)

  LitInt _ -> pure (emptySubst, TInt)
  LitBool _ -> pure (emptySubst, TBool)

  -- The parameter is bound monomorphically (empty quantifier list), so all
  -- uses inside the body must agree on a single type.
  Lam param body -> do
    tv <- freshTVar
    let env' = Map.insert param (Forall [] tv) env
    (s1, tBody) <- infer env' body
    pure (s1, TFun (apply s1 tv) tBody)

  -- Infer the function, then infer the argument under the function's
  -- substitution, then require the function type to be @tArg -> tv@ for a
  -- fresh result variable @tv@.
  App funExpr argExpr -> do
    (s1, tFun) <- infer env funExpr
    (s2, tArg) <- infer (apply s1 env) argExpr
    tv <- freshTVar
    s3 <- unify (apply s2 tFun) (TFun tArg tv)
    pure (s3 `composeSubst` s2 `composeSubst` s1, apply s3 tv)

  -- Non-recursive let: the right-hand side is inferred without 'name' in
  -- scope. Its type is then generalized over the variables not free in the
  -- substituted environment, which gives let-polymorphism in the body.
  Let name rhs body -> do
    (s1, t1) <- infer env rhs
    let env' = apply s1 env
        scheme = generalize env' (apply s1 t1)
        env'' = Map.insert name scheme env'
    (s2, t2) <- infer env'' body
    pure (s2 `composeSubst` s1, t2)

  -- The condition must be Bool, and both branches must have the same type.
  -- Each later step sees the environment updated by the substitutions found
  -- so far.
  If condExpr thenExpr elseExpr -> do
    (s1, tCond) <- infer env condExpr
    s2 <- unify (apply s1 tCond) TBool

    let s12 = s2 `composeSubst` s1
        env' = apply s12 env

    (s3, tThen) <- infer env' thenExpr
    let env'' = apply s3 env'

    (s4, tElse) <- infer env'' elseExpr

    -- Bring the then-branch type up to date with s4 before equating branches.
    s5 <- unify (apply s4 tThen) tElse

    let s = s5 `composeSubst` s4 `composeSubst` s3 `composeSubst` s12
    pure (s, apply s5 tElse)

-- | Compute a most general unifier of two types, or fail with a
-- "Type mismatch" error (or an occurs-check error from 'bind').
--
-- Function types are unified argument first. The argument's substitution is
-- pushed into the result types before they are unified.
unify :: Type -> Type -> Infer Subst
unify lhs rhs = case (lhs, rhs) of
  (TFun argL resL, TFun argR resR) -> do
    sArg <- unify argL argR
    sRes <- unify (apply sArg resL) (apply sArg resR)
    pure (sRes `composeSubst` sArg)
  (TVar v, other) -> bind v other
  (other, TVar v) -> bind v other
  (TInt, TInt) -> pure emptySubst
  (TBool, TBool) -> pure emptySubst
  _ -> throwError ("Type mismatch. Cannot unify " ++ show lhs ++ " with " ++ show rhs)

-- | Bind a type variable to a type.
--
-- Binding a variable to itself yields the empty substitution. Binding it to
-- a type that contains it fails the occurs check, because that would require
-- an infinite type.
bind :: TypeVar -> Type -> Infer Subst
bind v t = case t of
  TVar u | u == v -> pure emptySubst
  _
    | Set.member v (ftv t) ->
        throwError ("Occurs check failed: t" ++ show v ++ " occurs in " ++ show t)
    | otherwise -> pure (Map.singleton v t)

-- | Infer the type of a closed expression (empty environment). The final
-- substitution is applied to the result type. Any type variables left over
-- are reported as-is, with their raw numbers.
inferExpr :: Expr -> Either String Type
inferExpr = runInfer . fmap (uncurry apply) . infer Map.empty

-- | Demo entry point. Prints the inferred type, or the error message, for a
-- few sample expressions, including the example from the introduction.
main :: IO ()
main = mapM_ report examples
  where
    -- One line for the expression, one indented line for the outcome.
    report :: Expr -> IO ()
    report e = do
      print e
      putStrLn ("  => " ++ either ("error: " ++) show (inferExpr e))

    examples :: [Expr]
    examples =
      [ -- (\x -> x) 42
        App (Lam "x" (Var "x")) (LitInt 42)
      , -- let id = \x -> x in if id True then id 1 else 0   (let-polymorphism)
        Let "id" (Lam "x" (Var "x"))
          (If (App (Var "id") (LitBool True))
              (App (Var "id") (LitInt 1))
              (LitInt 0))
      , -- if 1 then 2 else 3   (condition is not Bool: type mismatch)
        If (LitInt 1) (LitInt 2) (LitInt 3)
      , -- \x -> x x   (occurs-check failure)
        Lam "x" (App (Var "x") (Var "x"))
      ]