LoginSignup
4
2

More than 5 years have passed since last update.

100行のScalaコードでScalaっぽいコードを生成すっからー

Last updated at Posted at 2016-12-23

初心覚えていますか?

Scala書いてますかっ。
バリバリ書いている皆様方におかれましては、布教活動に勤しむこともありましょう。しかし、えてして初心者の頃の気持ちというものは忘れがちなものです。慣れ親しむにつれ、初心者が何につまづきやすいか、分からなくなってしまうのですね。

ということで、実際に私が教えることになった初心者のコードを見てみましょう。
Scalaはおろかプログラミングの知識もゼロ、英単語すら知らない状態だったので、とりあえずScalaコンパイラのソースを6時間くらい眺めてもらって、適当に書かせてみましたっ。スパルタ教育!

/**
 *  Sets concrete methods the case didn't common a
 *  -- --
 * @throws(classOf[AsJavaUnit]) mateclean
 *  @define clars File
 * @since 2.12.0
 */
trait Test extends AbstractTransformed[T] {
  class UsesSettingCheckRandom {
    def complent: Int
  }
  def getBootclasspathDebuggeOKN[Mapper[T]] extends Transformed[T](tps)

お、ちゃんとコメント書いてますね。えらいえらい。
Test とか CheckRandom とか Debug なんて単語が出てくる tarit ということは、テスト用のユーティリティでしょうか。

  /** Returns `true` if this donate is a cleurar linked by read.
   *  Exact local does method sign if the phast on range one element
   *  the result repreten element.
   *
   *  @author  Martin Oderation
   */
  abstract class FileClass extends ScaladocGlobal
  def iterator: Iterator[Char]
  def ==(x1: String): Boolean

うーむ、@author Martin Oderation と名乗っていますね。
Scalaを創ったMartin Odersky先生リスペクトなハンドルネームなのでしょうが、それはちょっとイタい気が……。

  /** Creates a value? */
  def seenWritur(): Unit = {
    val anoExistentially = printPrimitive(System.out)
    for {
      val sourceBody = method match {
        case Apply(Seq(java.lang.DefaultFlags, symbol) => true
        case recoverRef(_, _) != 0
        case WeekDay if x ==> " => " + " = " + for field productSize(tree.rootType) ; fall }

なんか、ややこしそうなパターンマッチしようとしてます。コメントも自信なさそうです。

  dcass("Map(1.0 " + x.x) // ok
  def f2(x: Any] = N0;
  def f4(x: Any) = "xx"
  def f5(x: Any): Foo = Fa(foo)
  def f6(x: Int => Unit) = x + y.x
  def f4(x: Any]  // aborther[Int](x): Int = 0)
  def f7((x: Object ] = 69(); nublog(defaultSubCode(y))

もうちょっと意味のあるメソッド名をつけましょう。

/*                       __                                *\
**    ________ ___   / /  ___     Scala API                        **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL */ RuntimeUnit = {
        val con = baseClasses
      val n = 0

コーディング中に、ふっとScalaのアスキーアート書きたくなったのかな?

種明かし

……ええと、タイトルやタグでお察しかと思いますが、初心者というのは人間ではなく、手元のMacBook Pro Late 2013です。
機械学習の一種である、ディープラーニングの一種である、RNNの一種である、LSTMを使いました。シェークスピアっぽいテキストなどが生成できたとして、一世を風靡した手法です。詳しくは「The Unreasonable Effectiveness of Recurrent Neural Networks」をどうぞ。

まぁ、それっぽいワードサラダの域を越えるものではないのですが、構文やキーワードの知識なしに、単なる文字列として与えたものをゼロから学習して、さっとこのくらいのものを生成してしまうのは凄いと思います。凄くない?

前処理

それでは生成方法です。
まずは、Scalaコンパイラのソースコードのうち、拡張子.scalaのものを一つのテキストファイルにまとめます。43万行!

$ git clone https://github.com/scala/scala.git
$ cd scala
$ git checkout refs/tags/v2.12.1
$ find . -type f -name "*.scala" -exec cat {} \; -exec echo "" \; > scala_v2-12-1.txt

実装

さて、生成コードです。100行!
フレームワークはDeeplearning4jを利用していて、パラメータは公式のサンプルコードGravesLSTMCharModellingExample.javaのものをほぼそのまま使っています。
ただ、Scalaで書きなおしたため、だいぶループ処理がすっきりしました。
入力テキストに使われている文字の数え上げから行うので、ソースコードに限らずテキストなら何でも学習できるはずです。

src/main/scala/Main.scala
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.layers.{GravesLSTM, RnnOutputLayer}
import org.deeplearning4j.nn.conf.{BackpropType, NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
import org.slf4j.LoggerFactory

import scala.io.Source
import scala.util.Random

object Main {
  Random.setSeed(20161224)
  val logger = LoggerFactory.getLogger(getClass)

  def main(args: Array[String]) {
    val path = args.head
    val source = Using(Source.fromFile(path, "UTF-8"))(_.getLines.mkString("\n"))
    val decode = source.groupBy(identity).filter(_._2.length >= 100).keys.toSeq
    val encode = decode.zipWithIndex.toMap
    val codes = source.flatMap(encode.get)

    val net = new MultiLayerNetwork(
      new NeuralNetConfiguration.Builder()
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .iterations(1)
        .learningRate(0.1)
        .rmsDecay(0.95)
        .seed(12345)
        .regularization(true)
        .l2(0.001)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.RMSPROP)
        .list()
        .layer(0, new GravesLSTM.Builder().nIn(decode.size).nOut(200).activation("tanh").build())
        .layer(1, new GravesLSTM.Builder().nIn(200).nOut(200).activation("tanh").build())
        .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax").nIn(200).nOut(decode.size).build())
        .backpropType(BackpropType.TruncatedBPTT)
        .tBPTTForwardLength(50)
        .tBPTTBackwardLength(50)
        .pretrain(false)
        .backprop(true)
        .build())
    net.init()

    for (epochIndex <- 0 until 5) {
      for ((batch, batchIndex) <- Random.shuffle(codes.grouped(1000).toSeq.init).grouped(32).zipWithIndex) {
        // 学習
        val input = Nd4j.zeros(batch.size, decode.size, batch.head.size)
        val label = Nd4j.zeros(batch.size, decode.size, batch.head.size)
        for {
          (sample, i) <- batch.zipWithIndex
          (charPair, j) <- sample.sliding(2).zipWithIndex
        } {
          input.putScalar(Array(i, charPair(0), j), 1.0f)
          label.putScalar(Array(i, charPair(1), j), 1.0f)
        }
        net.fit(new DataSet(input, label))

        // 生成
        if (batchIndex % 10 == 0) {
          net.rnnClearPreviousState()

          val init = "import "
          val sampleNum = 3
          val initInput = Nd4j.zeros(sampleNum, decode.size, init.length)
          for {
            i <- 0 until sampleNum
            (code, j) <- init.map(encode).zipWithIndex
          } {
            initInput.putScalar(Array(i, code, j), 1.0f)
          }

          var output = net.rnnTimeStep(initInput).tensorAlongDimension(init.length - 1, 1, 0)
          val genCodeMatrix =
            for (_ <- 0 until 2000) yield {
              val nextInput = Nd4j.zeros(sampleNum, decode.size)
              val genCodes =
                for (s <- 0 until sampleNum) yield {
                  val f = Random.nextFloat()
                  val genCode = decode.indices.map(output.getFloat(s, _)).scan(0.0f)(_ + _).tail.indexWhere(_ >= f)

                  nextInput.putScalar(Array(s, genCode), 1.0f)
                  genCode
                }

              output = net.rnnTimeStep(nextInput)
              genCodes
            }

          for ((sample, sampleIndex) <- genCodeMatrix.transpose.map(init + _.map(decode).mkString).zipWithIndex) {
            logger.info(s"# $epochIndex-$batchIndex-$sampleIndex\n\n" + sample)
          }
        }
      }
    }
  }
}

なお、少しコード整理したので、初めに紹介した学習結果は再現しないと思います。
パラメータチューニングの余地は色々とありそうです。

実行環境

src/main/scala/Using.scala
import java.io.Writer

import scala.io.Source

object Using {
  def apply[A, B](resource: A)(process: A => B)(implicit closer: Closer[A]): B =
    try {
      process(resource)
    } finally {
      closer.close(resource)
    }
}

case class Closer[-A](close: A => Unit)

object Closer {
  implicit val sourceCloser: Closer[Source] = Closer(_.close())
  implicit val writerCloser: Closer[Writer] = Closer(_.close())
}

Scalaで一番よく使うローンパターン

build.sbt
name := "scalikeCoder"
version := "0.1.0"
scalaVersion := "2.12.1"
classpathTypes += "maven-plugin"
libraryDependencies ++= Seq(
  "org.deeplearning4j" % "deeplearning4j-core" % "0.7.1",
  "org.nd4j" % "nd4j-native" % "0.7.1" classifier "" classifier "macosx-x86_64",
  "ch.qos.logback" % "logback-classic" % "1.1.8"
)
$ java -version
java version "1.8.0_101"
Java(TM) SE Runtime Environment (build 1.8.0_101-b13)
Java HotSpot(TM) 64-Bit Server VM (build 25.101-b13, mixed mode)

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.11.5
BuildVersion:   15F34
$ sbt "run scala_v2-12-1.txt"
4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2