LoginSignup
1
1

More than 1 year has passed since last update.

型推論入門

Last updated at Posted at 2021-11-02

型推論はなにもわからないけど、とりあえず動くものが書けたので備忘録です。
200行ちょっとですので、型推論の雰囲気を知るのにちょうど良いかもしれません。

'use strict'

// 便利関数
const print = (...a) => console.log(...a)
const str = o => JSON.stringify(o, null, '  ')
const eq = (x, y) => str(x) === str(y)
const fail = (msg, o) => { throw new Error(msg + ' ' + str(o)) }
const until = (f, g) => {
  const l = []
  while (f()) {
    l.push(g())
  }
  return l
}

// ソースコードをトークン毎に分解し、丸括弧毎に配列へ変換
// 処理イメージ: (def add a b (+ a b)) -> ["def" "add" "a" "b" ["+" "a" "b"]]
const parse = src => {
  const tokens = src.split(/([()]|\s+)/).filter(x => x.trim()).map(code => ({code}))
  let pos = 0
  const list = l => (t => t.code === ')' ? ({list: l}) : list(l.concat([t])))(unit())
  const unit = () => (t => t.code === '(' ? list([]) : t)(tokens[pos++])
  return until(() => pos < tokens.length, unit)
}

// 型推論する
const inference = nodes => {
  let tvarSequence = 0
  const tvar = () => (name => ({name,var:true}))((++tvarSequence).toString())
  const tlambda = (...types) => types.length === 1 ? types[0] : ({types})
  const ttype = (name) => ({name, types: []})
  const tint = ttype('int')
  const tbool = ttype('bool')

  // 確定したジェネリック型を、未確定に戻す
  // `(def f a a) (f 1) (f true)`のような場合の関数fの推論に必要
  const fresh = (type, nonGeneric) => {
    const d = {}
    const rec = t => {
      const p = prune(t)
      return p.var ?
        (nonGeneric.includes(p.name) ? p : d[p.name] ||= tvar()) :
        ({name: p.name, types: p.types.map(rec)})
    }
    return rec(type)
  }

  // 2つの型が同じになるよう更新
  // 更新できなければ型推論失敗し例外を出す
  const unify = (a, b) => {
    a = prune(a)
    b = prune(b)
    if (a.var) {
      if (a.name !== b.name) {
        a.instance = b
      }
    } else if (b.var) {
      unify(b, a)
    } else {
      if (a.name !== b.name || a.types.length !== b.types.length) {
        fail(`type miss match`, {a,b})
      }
      a.types.forEach((t,i) => unify(t, b.types[i]))
    }
  }

  // ひとつのジェネリック型を複数の変数が参照する場合があるので、一度に更新できるよう t.instance を用意する
  const prune = t => (t.var && t.instance) ? t.instance = prune(t.instance) : t

  // 木構造の中の要素をそれぞれ型推論する
  const analyse = (node, env, nonGeneric) => node.type = _analyse(node, env, nonGeneric)
  const _analyse = (node, env, nonGeneric) => {
    if (node.list) {
      const list = node.list
      let [head,...tail] = list
      if (head.code === 'def') {
        // 関数定義の推論
        const name = tail[0].code
        const args = tail.slice(1, -1).map(arg => (arg.type = tvar(), arg))
        const body = tail.slice(-1)[0]
        const newEnv = Object.assign({}, env)
        args.forEach(arg => newEnv[arg.code] = arg.type)
        const ret = analyse(body, newEnv, nonGeneric.concat(args.map(t => t.type.name)))
        const ft = tlambda(...args.map(t => t.type), ret)
        return env[name] = ft
      } else if (tail.length) {
        // 関数適用の推論
        const argv = tail.map(t => analyse(t, env, nonGeneric))
        const rt = (env.__cache[str(argv)] ||= tvar()) // fix tvar
        const ft = analyse(head, env, nonGeneric)
        unify(tlambda(...argv, rt), ft)
        return rt
      } else {
        // 変数参照の推論
        return analyse(head, env, nonGeneric)
      }
    } else {
      // 値の推論
      const v = node.code
      return v.match(/^[0-9]/) ? tint :
        env[v] ? fresh(env[v], nonGeneric) :
        fail(`Not found ${v} in env`, {v,node,env})
    }
  }
  // 組み込み型の定義
  const v1 = tvar()
  const topEnv = {
    __cache: {},
    'true': tbool,
    'false': tbool,
    '+': tlambda(tint, tint, tint),
    '<': tlambda(tint, tint, tbool),
    'if': tlambda(tbool, v1, v1, v1),
  }
  return nodes.map(node => analyse(node, topEnv, []))
}

const showType = type => {
  const show = t => t.instance ? show(t.instance) :
    t.name || '(' + t.types.map(show).join(' ') + ')'
  const s = show(type)
  const o = {}
  const r = s.replace(/\d+/g, t => o[t] ||= Object.keys(o).length + 1)
  return r
}

const testType = () => {
  const reject = src => {
    try {
      inference(parse(src))
    } catch (e) {
      if (e.message.match(/^type miss match/)) {
        process.stdout.write('.')
        return
      }
    }
    print('Failed')
    print('src:', src)
  }
  const inf = (src, expect) => {
    try {
      let types = inference(parse(src))
      const actual = showType(types.slice(-1)[0])
      if (eq(actual, expect)) {
        process.stdout.write('.')
      } else {
        print('Failed')
        print('expect:', expect)
        print('actual:', actual)
        print('   src:', src)
      }
    } catch (e) {
      print('Failed')
      print('  src:', src)
      print('error:', e)
    }
  }
  // 値
  inf('(1)', 'int')
  inf('(true)', 'bool')
  inf('(false)', 'bool')

  // 式
  inf('(+ 1 1)', 'int')
  inf('(< 1 1)', 'bool')

  // if関数
  inf('(if true 1 2)', 'int')
  inf('(if true true true)', 'bool')

  // 引数なし関数
  inf('(def value 1)', 'int')

  // 引数あり関数
  inf('(def inc a (+ a 1))', '(int int)')
  inf('(def add a b (+ a b))', '(int int int)')

  // ジェネリック関数
  inf('(def f a a)', '(1 1)')
  inf('(def f a b a)', '(1 2 1)')
  inf('(def f a b b)', '(1 2 2)')
  inf('(def f a a) (f 1)', 'int')
  inf('(def f a a) (f 1) (f true)', 'bool')

  // 複雑なジェネリック関数
  inf('(def f x (+ x 1)) (def g x (+ x 2)) (+ (f 1) (g 1))', 'int')
  inf('(def _ f g x (g (f x)))', '((1 2) (2 3) 1 3)')
  inf('(def _ x y z (x z (y z)))', '((1 2 3) (1 2) 1 3)')
  inf('(def _ b x (if (x b) x (def _ x b)))', '(1 (1 bool) (1 1))')
  inf('(def _ x (if true x (if x true false)))', '(bool bool)')
  inf('(def _ x y (if x x y))', '(bool bool bool)')
  inf('(def _ n ((def _ x (x (def _ y y))) (def _ f (f n))))', '(1 1)')
  inf('(def _ x y (x y))', '((1 2) 1 2)')
  inf('(def _ x y (x (y x)))', '((1 2) ((1 2) 1) 2)')
  inf('(def _ h t f x (f h (t f x)))', '(1 ((1 2 3) 4 2) (1 2 3) 4 3)')
  inf('(def _ x y (x (y x) (y x)))', '((1 1 2) ((1 1 2) 1) 2)')
  inf('(def id x x) (def f y (id (y id)))', '(((1 1) 2) 2)')
  inf('(def id x x) (def f (if (id true) (id 1) (id 2)))', 'int')
  inf('(def f x (3)) (def g (+ (f true) (f 4)))', 'int')
  inf('(def f x x) (def g y y) (def h b (if b (f g) (g f)))', '(bool (1 1))')

  // 型推論が失敗すべきケース
  reject('(+ 1 true)')

  // テストの終了
  print('ok')
}

testType()

実行結果

...............................ok

また、以下のサイトが大変に参考になりました。この場を借りて作者にお礼申し上げます。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1