
Commit 3c322e4

Merge pull request #212 from Atry/EagerExecution
Add an Eager DSL for eager execution
2 parents 46c063b + c468364

9 files changed: +225 -3 lines changed

plugins-CumulativeFloatLayers/build.sbt

Lines changed: 4 additions & 0 deletions
@@ -10,3 +10,7 @@ exampleSuperTypes += "_root_.com.thoughtworks.deeplearning.scalatest.Thoughtwork
 libraryDependencies += "com.thoughtworks.each" %% "each" % "3.3.1" % Test
 
 scalacOptions += "-Ypartial-unification"
+
+addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "1.0.0-RC9")
+
+addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-reseteverywhere" % "1.0.0-RC9")

plugins-CumulativeFloatLayers/src/test/scala/com/thoughtworks/deeplearning/plugins/CumulativeFloatLayersSpec.scala

Lines changed: 32 additions & 0 deletions
@@ -378,4 +378,36 @@ final class CumulativeFloatLayersSpec
     }
 
   }
+
+  "EagerExecution" in {
+
+    val hyperparameters =
+      Factory[FloatTraining with Operators with FloatLiterals with CumulativeFloatLayers with ImplicitsSingleton with FixedLearningRate]
+        .newInstance(fixedLearningRate = 1.0f)
+
+    import hyperparameters.implicits._
+
+    val weight = hyperparameters.FloatWeight(1.0f)
+
+    def myNetwork(input: Float): hyperparameters.FloatLayer = {
+      // FIXME: inlining !-notation does not compile due to https://github.com/ThoughtWorksInc/Dsl.scala/issues/119
+      // 6.7f + !(input + weight) + weight + 5.5f
+
+      val f = !(input + weight)
+      6.7f + f + weight + 5.5f
+    }: @com.thoughtworks.dsl.Dsl.reset
+
+    def train(inputData: Float): Future[Float] = {
+      myNetwork(inputData).train
+    }
+
+    for {
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+    } yield weight.data should be(-4)
+
+  }
 }
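
Note on the expected value: the assertion `weight.data should be(-4)` is consistent with the following reading of the diff (an inference, not stated in the commit). `!(input + weight)` eagerly runs the forward pass and binds `f` to a plain `Float`, so backpropagation does not flow through `f`; only the remaining direct occurrence of `weight` contributes a gradient, of 1 per step. A minimal sketch of the arithmetic (plain Scala, runnable as-is; the per-step gradient of 1 is the assumption just described):

    object ExpectedWeight extends App {
      val initialWeight = 1.0f // FloatWeight(1.0f)
      val learningRate = 1.0f  // fixedLearningRate = 1.0f
      val gradient = 1.0f      // only the direct `+ weight` term backpropagates
      val steps = 5            // five train(1.0f) calls
      println(initialWeight - learningRate * gradient * steps) // -4.0
    }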

plugins-EagerExecution/build.sbt

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+libraryDependencies += "com.thoughtworks.dsl" %% "dsl" % "1.0.0-RC9"
+
+libraryDependencies += "com.thoughtworks.dsl" %% "domains-scalaz" % "1.0.0-RC9" % Test
+
+addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "1.0.0-RC9")
+
+addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-reseteverywhere" % "1.0.0-RC9")
+
+fork in Test := true
+
+enablePlugins(Example)

plugins-EagerExecution/src/main/scala/com/thoughtworks/deeplearning/plugins/EagerExecution.scala

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+package com.thoughtworks.deeplearning.plugins
+
+import com.thoughtworks.deeplearning.DeepLearning
+import com.thoughtworks.deeplearning.plugins.EagerExecution.Eager
+import com.thoughtworks.dsl.Dsl
+import com.thoughtworks.dsl.Dsl.{Keyword, shift}
+import com.thoughtworks.feature.Factory
+
+import scala.annotation.compileTimeOnly
+
+/**
+  * @author 杨博 (Yang Bo)
+  */
+trait EagerExecution extends Layers {
+
+//
+//  trait LayerApi extends super.LayerApi {
+//
+//
+//    @shift
+//    @compileTimeOnly(
+//      """This method requires the compiler plugin: `addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "latest.release")` and must only be called inside a code block annotated as `@reset`.""")
+//    final def data : Data = {
+//      throw new IllegalAccessException(
+//        """This method requires the compiler plugin: `addCompilerPlugin("com.thoughtworks.dsl" %% "compilerplugins-bangnotation" % "latest.release")` and must only be called inside a code block annotated as `@reset`."""
+//      )
+//    }
+//
+//    @inline
+//    final def cpsApply[Domain](handler: Data => Domain): Domain = {
+//      ???
+////      dsl.interpret(this, handler)
+//    }
+//
+//  }
+//
+//  type Layer <: LayerApi
+//  def Eager[A](a: A)(implicit deepLearning: DeepLearning[A]): Eager[A, deepLearning.Data, deepLearning.Delta] = {
+//    new Eager[A, deepLearning.Data, deepLearning.Delta](a, deepLearning)
+//  }
+
+}
+
+object EagerExecution {
+  final case class Eager[Differentiable, Data, Delta](differentiable: Differentiable)(
+      implicit
+      deepLearning: DeepLearning.Aux[Differentiable, Data, Delta])
+      extends Keyword[Eager[Differentiable, Data, Delta], Data]
+}
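
Background for readers new to Dsl.scala: `Eager` follows the library's `Keyword` pattern, where `!keyword` inside a delimited block is desugared by the bangnotation compiler plugin into a call to an implicit `Dsl` instance. A minimal, self-contained sketch of that protocol, using a hypothetical `Log` keyword (only the shapes match this commit; `Log` and `logDsl` are illustrative, not part of the codebase):

    import com.thoughtworks.dsl.Dsl
    import com.thoughtworks.dsl.Dsl.Keyword

    object KeywordSketch {
      // Hypothetical keyword: `!Log(message)` yields Unit inside the block.
      final case class Log(message: String) extends Keyword[Log, Unit]

      // The Dsl instance defines how `!Log(...)` runs in functions whose
      // result type (the Domain) is Stream[String]: prepend the message,
      // then continue with the rest of the block (the handler).
      implicit def logDsl: Dsl[Log, Stream[String], Unit] =
        new Dsl[Log, Stream[String], Unit] {
          def interpret(keyword: Log, handler: Unit => Stream[String]): Stream[String] =
            keyword.message #:: handler(())
        }
      // With the bangnotation plugin enabled, `!Log("hi")` in such a function
      // desugars to roughly `logDsl.interpret(Log("hi"), restOfTheBlock)`.
    }

`Eager` is likewise a `Keyword[Eager[...], Data]`; its `Dsl` instances are supplied by `FloatLayers` and `INDArrayLayers` below.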

plugins-EagerExecution/src/test/scala/com/thoughtworks/deeplearning/plugins/EagerExecutionSpec.scala

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+package com.thoughtworks.deeplearning
+package plugins
+
+import com.thoughtworks.deeplearning.plugins.EagerExecution.Eager
+import com.thoughtworks.feature.Factory
+import org.scalactic.ErrorMessage
+import org.scalatest._
+import com.thoughtworks.future._
+import com.thoughtworks.deeplearning.scalatest.ThoughtworksFutureToScalaFuture
+import scalaz.std.iterable._
+
+object EagerExecutionSpec {
+
+  trait FixedLearningRate extends LearningRate {
+    def fixedLearningRate: scala.Float
+    trait FloatOptimizerApi extends super.FloatOptimizerApi { this: FloatOptimizer =>
+      final def learningRate: scala.Float = fixedLearningRate
+    }
+    override type FloatOptimizer <: FloatOptimizerApi with Optimizer
+  }
+
+  trait LearningRate extends FloatWeights {
+    trait FloatOptimizerApi extends super.FloatOptimizerApi { this: FloatOptimizer =>
+      def learningRate: scala.Float
+      override def delta: scala.Float = super.delta * learningRate
+    }
+    override type FloatOptimizer <: FloatOptimizerApi with Optimizer
+  }
+
+}
+
+/**
+  * @author 杨博 (Yang Bo)
+  */
+final class EagerExecutionSpec extends AsyncFreeSpec with Matchers with Inside with ThoughtworksFutureToScalaFuture {
+
+  import EagerExecutionSpec._
+
+  "EagerExecution" in {
+
+    val hyperparameters =
+      Factory[FloatTraining with EagerExecution with Operators with FloatLiterals with CumulativeFloatLayers with ImplicitsSingleton with FixedLearningRate]
+        .newInstance(fixedLearningRate = 1.0f)
+
+    import hyperparameters.FloatLayer
+
+    import hyperparameters.implicits._
+
+    val weight = hyperparameters.FloatWeight(1.0f)
+
+    def myNetwork(input: Float): FloatLayer = {
+      // FIXME: inlining !-notation does not compile due to https://github.com/ThoughtWorksInc/Dsl.scala/issues/119
+      // 6.7f + !Eager(input + weight) + weight + 5.5f
+
+      val f = !Eager(input + weight)
+      6.7f + f + weight + 5.5f
+    }: @com.thoughtworks.dsl.Dsl.reset
+
+    def train(inputData: Float): Future[Float] = {
+      myNetwork(inputData).train
+    }
+
+    for {
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+      _ <- train(1.0f)
+    } yield weight.data should be(-4)
+
+  }
+
+}
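
The `LearningRate` / `FixedLearningRate` helpers above use Scala's stackable-trait pattern: each optimizer layer overrides `delta` and transforms the value from `super`. A plain-Scala reduction of the idea (illustrative names; the real traits are mixed in through `com.thoughtworks.feature.Factory`):

    object OptimizerSketch extends App {
      trait Optimizer { def delta: Float = 2.0f } // stand-in for the raw gradient
      trait LearningRate extends Optimizer {
        def learningRate: Float
        override def delta: Float = super.delta * learningRate // scale the update
      }
      final class Sgd(val learningRate: Float) extends LearningRate
      println(new Sgd(0.5f).delta) // 1.0 = 2.0 (super.delta) * 0.5 (learningRate)
    }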

plugins-FloatLayers/src/main/scala/com/thoughtworks/deeplearning/plugins/FloatLayers.scala

Lines changed: 19 additions & 1 deletion
@@ -4,13 +4,15 @@ import com.thoughtworks.deeplearning.DeepLearning.Tape
 import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
 import com.thoughtworks.feature.Factory.inject
 import com.thoughtworks.raii.asynchronous._
-
 import scalaz.syntax.all._
+
 import scala.annotation.meta.getter
 import scalaz.Apply
 import com.thoughtworks.continuation._
 import com.thoughtworks.future._
 import DeepLearning.ops._
+import com.thoughtworks.deeplearning.plugins.Layers.Eager
+import com.thoughtworks.dsl.Dsl
 
 /** A plugin that provides differentiable operators
   * on neural networks whose [[DeepLearning.Data Data]] and [[DeepLearning.Delta Delta]] is [[scala.Float]].
@@ -26,6 +28,18 @@ import DeepLearning.ops._
 trait FloatLayers extends Layers {
 
   trait ImplicitsApi extends super[Layers].ImplicitsApi {
+    implicit def eagerFloatDsl[Differentiable, Data, Delta, Constructor, Out <: FloatLayer](
+        implicit implicitApply: ImplicitApply.Aux[floatPartialApplyRawForward.Rest, Out]
+    ): Dsl[Eager[Differentiable, Data, Delta], FloatLayer, Data] = {
+      new Dsl[Eager[Differentiable, Data, Delta], FloatLayer, Data] {
+        def interpret(keyword: Eager[Differentiable, Data, Delta], handler: Data => FloatLayer): Out =
+          FloatLayer(
+            keyword.deepLearning.forward(keyword.operand0).flatMap { tape =>
+              handler(tape.data).internalForward
+            }
+          )
+      }
+    }
 
     /** An implicit wrapper that adds extension methods for differentiable float types
       * that support the [[DeepLearning]] type class.
@@ -275,8 +289,12 @@ trait FloatLayers extends Layers {
       */
     protected val rawForward: Do[Tape[Float, Float]]
 
+    /** A bridge for calling [[rawForward]] in [[FloatLayers]] */
+    private[FloatLayers] final def internalForward: Do[Tape[Float, Float]] = rawForward
+
     override def forward: Do[Tape[Float, Float]] = rawForward
   }
+
   object FloatLayer {
 
     /** @usecase def apply(forward: Do[Tape[Float, Float]]): FloatLayer = ???
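
`eagerFloatDsl` is the piece that makes the specs' `val f = !Eager(input + weight)` typecheck in a function returning `FloatLayer`: the bangnotation plugin CPS-transforms the remainder of the block into the `handler` passed to `interpret`. Roughly, the rewrite looks like the sketch below, written in the spec's scope (my approximation of the plugin's output, not actual compiler output):

    // Hand-desugared form of myNetwork from the specs (approximate):
    def myNetworkDesugared(input: Float): hyperparameters.FloatLayer =
      eagerFloatDsl.interpret(
        Eager(input + weight),      // forward the operand once, eagerly
        { f: Float =>               // f = tape.data, a plain Float
          6.7f + f + weight + 5.5f  // the rest of the block is the handler
        }
      )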

plugins-INDArrayLayers/src/main/scala-2.11/com/thoughtworks/deeplearning/plugins/INDArrayLayers.scala

Lines changed: 18 additions & 1 deletion
@@ -8,7 +8,6 @@ import com.thoughtworks.feature.Factory.inject
 import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
 import com.thoughtworks.raii.asynchronous._
 import org.nd4j.linalg.api.ndarray.INDArray
-
 import scalaz.syntax.all._
 import scalaz.Tags.Parallel
 import scalaz.Semigroup
@@ -19,6 +18,8 @@ import org.nd4j.linalg.factory.Nd4j
 
 import scala.concurrent.ExecutionContext
 import com.thoughtworks.continuation._
+import com.thoughtworks.deeplearning.plugins.Layers.Eager
+import com.thoughtworks.dsl.Dsl
 
 object INDArrayLayers {
 
@@ -176,6 +177,19 @@ trait INDArrayLayers extends DoubleLayers with DoubleLiterals with ImplicitsSing
   }
   trait ImplicitsApi extends super[DoubleLiterals].ImplicitsApi with super[DoubleLayers].ImplicitsApi {
 
+    implicit def eagerINDArrayDsl[Differentiable, Data, Delta, Constructor, Out <: INDArrayLayer](
+        implicit implicitApply: ImplicitApply.Aux[indArrayPartialApplyRawForward.Rest, Out]
+    ): Dsl[Eager[Differentiable, Data, Delta], INDArrayLayer, Data] = {
+      new Dsl[Eager[Differentiable, Data, Delta], INDArrayLayer, Data] {
+        def interpret(keyword: Eager[Differentiable, Data, Delta], handler: Data => INDArrayLayer): Out =
+          INDArrayLayer(
+            keyword.deepLearning.forward(keyword.operand0).flatMap { tape =>
+              handler(tape.data).internalForward
+            }
+          )
+      }
+    }
+
     /** An implicit wrapper that adds extension methods for differentiable n-dimensional array types
       * that support the [[DeepLearning]] type class.
       */
@@ -719,6 +733,9 @@ trait INDArrayLayers extends DoubleLayers with DoubleLiterals with ImplicitsSing
       */
     protected val rawForward: Do[Tape[INDArray, INDArray]]
 
+    /** A bridge for calling [[rawForward]] in [[INDArrayLayers]] */
+    private[INDArrayLayers] final def internalForward: Do[Tape[INDArray, INDArray]] = rawForward
+
     override def forward: Do[Tape[INDArray, INDArray]] = rawForward
   }
   object INDArrayLayer {

plugins-Layers/build.sbt

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ libraryDependencies += "com.thoughtworks.feature" %% "implicitapply" % "2.3.0-M8
 
 libraryDependencies += "com.thoughtworks.feature" %% "factory" % "2.3.0-M8"
 
+libraryDependencies += "com.thoughtworks.dsl" %% "dsl" % "1.0.0-RC9"

plugins-Layers/src/main/scala/com/thoughtworks/deeplearning/plugins/Layers.scala

Lines changed: 18 additions & 1 deletion
@@ -3,20 +3,24 @@ package plugins
 import java.util.logging.Logger
 
 import com.thoughtworks.deeplearning.DeepLearning.Tape
+import com.thoughtworks.dsl.Dsl.Keyword
 import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply, The}
 import com.thoughtworks.feature.Factory.inject
 import com.thoughtworks.raii.asynchronous.Do
 import com.thoughtworks.raii.asynchronous.Do._
 import com.thoughtworks.raii.shared._
 import shapeless.{Poly1, Poly2}
 import shapeless.poly.Case1
-
 import scalaz.syntax.all._
+import scala.language.implicitConversions
+
 import scala.annotation.meta.getter
 import com.thoughtworks.future.Future
 
 /** A plugin that enables [[Layer]] in neural networks. */
 trait Layers {
+  import com.thoughtworks.deeplearning.plugins.Layers._
+
   trait LayerApi {
     type Data
     type Delta
@@ -33,6 +37,13 @@ trait Layers {
   type Layer <: LayerApi
 
   trait ImplicitsApi {
+
+    @inline
+    implicit def implicitEager[Operand0](a: Operand0)(
+        implicit deepLearning: DeepLearning[Operand0]): Eager[Operand0, deepLearning.Data, deepLearning.Delta] = {
+      new Eager[Operand0, deepLearning.Data, deepLearning.Delta](a)(deepLearning)
+    }
+
     implicit def layerDeepLearning[From, Data0, Delta0](implicit asLayer: From <:< LayerApi {
       type Data = Data0
       type Delta = Delta0
@@ -51,3 +62,9 @@ trait Layers {
   type Implicits <: ImplicitsApi
 
 }
+
+object Layers {
+  final case class Eager[Operand0, Data, Delta](operand0: Operand0)(
+      implicit val deepLearning: DeepLearning.Aux[Operand0, Data, Delta])
+      extends Keyword[Eager[Operand0, Data, Delta], Data]
+}
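
`implicitEager` is what lets CumulativeFloatLayersSpec write `!(input + weight)` without naming the keyword: `!` requires a `Keyword`, and this conversion wraps any operand that has a `DeepLearning` instance. On that reading (my inference from the two specs), the following two spellings are equivalent inside a `@Dsl.reset` block with `hyperparameters.implicits._` in scope:

    val f1 = !Eager(input + weight) // explicit keyword (EagerExecutionSpec)
    val f2 = !(input + weight)      // implicitEager inserts the wrapper
                                    // (CumulativeFloatLayersSpec)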
