ScalaTestのDeeplearning4jのテストを簡単にしたい

Deeplearning4jのNDArray同士をアバウトに比較したい

scala> import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.Nd4j
scala> val a = Nd4j.create(Array(1.0, 2.0))
a: org.nd4j.linalg.api.ndarray.INDArray = [    1.0000,    2.0000]
scala> val b = Nd4j.create(Array(1.1, 1.9))
b: org.nd4j.linalg.api.ndarray.INDArray = [    1.1000,    1.9000]
scala> a == b
res0: Boolean = false

当然数値で比較するとfalseなわけですが,これをちょっとくらいの誤差だったらtrueでイコールとみなすようなことをしたいです.

Scalacticを使う

まさのそのためのライブラリがあります.

Scalactic

The Scalactic library is focused on constructs related to quality that are useful in both production code and tests. Although ScalaTest is tightly integrated with Scalactic, you can use Scalactic with any Scala project and any test framework. The Scalactic library has no dependencies other than Scala itself.


Scalaticライブラリは高品質なコードのためのもので,プロダクションコードとテストの両方にとって有益です.ScalaTestはScalaticと統合されているので,Scalaのプロジェクトとあらゆるテストフレームワークにおいて利用することができます.ScalaticライブラリはScala自身以外への依存はありません.

Tolerance

まず,Tolerance(耐性,許容)を使って数値をアバウトに比較できるようにします.
[ 1.0000, 2.0000]と[ 1.1000, 1.9000]が等しくなるようにしたい.まずは1.0と1.1が等しかったり,2.0と1.9が等しかったりするようにします.

scala> import org.scalactic._
import org.scalactic._
scala> import TripleEquals._
import TripleEquals._
scala> import Tolerance._
import Tolerance._

scala> val a = 1.0
a: Double = 1.0
scala> val b = 2.0
b: Double = 2.0

scala> a === 1.1 +- 0.1
res3: Boolean = true
scala> b === 1.9 +- 0.1
res4: Boolean = true

なんと簡単.===じゃないとダメです.==ならこうなります.

scala> a == 1.1 +- 0.1
<console>:23: warning: comparing values of types Double and org.scalactic.TripleEqualsSupport.Spread[Double] using `==' will always yield false
       a == 1.1 +- 0.1
         ^

型が違うので等しいとは判断されないということです.
この辺りの説明はオフィシャルドキュメントにあります.
https://www.scalactic.org/user_guide/Tolerance

ただ,いちいち+-を使って誤差許容範囲を書くのはめんどくさい.scalaのimplicitを使うとこう書けます.

scala> import org.scalactic.TolerantNumerics
import org.scalactic.TolerantNumerics
scala> val epsilon = 1e-1
scala> implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(epsilon)
doubleEq: org.scalactic.Equality[Double] = TolerantDoubleEquality(0.1)

scala> a === 1.1
res6: Boolean = true

この場合,誤差範囲が0.1より大きくなると等しくなくなります.

scala> a === 1.2
res9: Boolean = false

目的は数値同士の比較ではなく,NDArrayのベクトルなり行列なりを比較することなので,もう少し拡張します.

Equality

Equalityを使って,ベクトル同士をどう比較するかを定義します.通常の比較だとベクトルの各成分を==で比較しますが,それを変更し,ベクトルの各成分をTolerance付きの===で比較するように変更すれば良さそうです.

scala>  import org.scalactic._
import org.scalactic._

scala> import TripleEquals._
import TripleEquals._

scala>  import Tolerance._
import Tolerance._

scala> import org.scalactic.TolerantNumerics
import org.scalactic.TolerantNumerics

scala> val epsilon = 1e-1
epsilon: Double = 0.1

scala> implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(epsilon)
doubleEq: org.scalactic.Equality[Double] = TolerantDoubleEquality(0.1)

scala> import org.scalactic.Equality
import org.scalactic.Equality

scala> import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ndarray.INDArray

scala> implicit val ndarrayEq =
     |   new Equality[INDArray] {
     |     def areEqual(a: INDArray, b: Any): Boolean =
     |       b match {
     |         case p: INDArray =>  a.shape === p.shape &&
     |             (0L until a.shape()(0)).map(f =>
     |               a.getDouble(f) === p.getDouble(f))
     |             .foldLeft(true)((g,h) => g && h)
     |         case _ => false
     |       }
     |
     |   }
ndarrayEq: org.scalactic.Equality[org.nd4j.linalg.api.ndarray.INDArray] = $anon$1@5c5bf478
scala>  import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.Nd4j

scala> val a = Nd4j.create(Array(1.0, 2.0))
log4j:WARN No appenders could be found for logger (org.nd4j.linalg.factory.Nd4jBackend).
log4j:WARN Please initialize the log4j system properly.
log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.
a: org.nd4j.linalg.api.ndarray.INDArray = [    1.0000,    2.0000]

scala> val b = Nd4j.create(Array(1.1, 1.9))
b: org.nd4j.linalg.api.ndarray.INDArray = [    1.1000,    1.9000]

scala> a === b
res0: Boolean = true

Equalityについては,オフィシャルドキュメントに説明があります.
https://www.scalactic.org/user_guide/CustomEquality


一応イケたんですが,1次元ベクトルしか対応していないので,行列やテンソルに拡張するにはもう少し改良が必要です.