From 76773545ca8e8d77c0b01eca6c013a16f7e47e10 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 26 Oct 2018 00:07:16 -0400 Subject: [PATCH] [IMPLICITS] Removed 'Strict' from some methods where it was not necessary. --- build.sbt | 4 ++-- .../implicits/helpers/DataTypeStructure.scala | 8 +++---- .../implicits/helpers/DataTypeToOutput.scala | 4 ++-- .../implicits/helpers/DataTypeToShape.scala | 6 ++--- .../api/implicits/helpers/OpStructure.scala | 4 ++-- .../implicits/helpers/OutputStructure.scala | 12 +++++----- .../implicits/helpers/OutputToDataType.scala | 10 ++++---- .../api/implicits/helpers/OutputToShape.scala | 14 +++++------ .../implicits/helpers/OutputToTensor.scala | 8 +++---- .../implicits/helpers/ShapeStructure.scala | 8 +++---- .../implicits/helpers/TensorStructure.scala | 4 ++-- .../implicits/helpers/TensorToDataType.scala | 4 ++-- .../implicits/helpers/TensorToOutput.scala | 6 ++--- .../api/implicits/helpers/TensorToShape.scala | 4 ++-- .../api/implicits/helpers/Zero.scala | 6 ++--- .../learn/estimators/InMemoryEstimator.scala | 16 ++++++------- .../tensorflow/api/ops/data/Dataset.scala | 24 ++++++++++++------- 17 files changed, 75 insertions(+), 67 deletions(-) diff --git a/build.sbt b/build.sbt index 83cb9c7f3..53521d0f3 100644 --- a/build.sbt +++ b/build.sbt @@ -24,11 +24,11 @@ crossScalaVersions in ThisBuild := Seq("2.11.12", "2.12.7") organization in ThisBuild := "org.platanios" +autoCompilerPlugins in ThisBuild := true + val tensorFlowVersion = "1.11.0" val circeVersion = "0.10.0" // Use for working with JSON. -autoCompilerPlugins in ThisBuild := true - // addCompilerPlugin(MetalsPlugin.semanticdbScalac) scalacOptions in ThisBuild ++= Seq( diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeStructure.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeStructure.scala index c84981c4b..088420335 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeStructure.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeStructure.scala @@ -167,15 +167,15 @@ object DataTypeStructure { implicit def fromHList[HD, TD <: HList](implicit evH: Strict[DataTypeStructure[HD]], - evT: Strict[DataTypeStructure[TD]] + evT: DataTypeStructure[TD] ): DataTypeStructure[HD :: TD] = { new DataTypeStructure[HD :: TD] { override def size(dataType: HD :: TD): Int = { - evH.value.size(dataType.head) + evT.value.size(dataType.tail) + evH.value.size(dataType.head) + evT.size(dataType.tail) } override def dataTypes(dataType: HD :: TD): Seq[DataType[Any]] = { - evH.value.dataTypes(dataType.head) ++ evT.value.dataTypes(dataType.tail) + evH.value.dataTypes(dataType.head) ++ evT.dataTypes(dataType.tail) } override def decodeDataType( @@ -183,7 +183,7 @@ object DataTypeStructure { dataTypes: Seq[DataType[Any]] ): (HD :: TD, Seq[DataType[Any]]) = { val (headOut, headRemaining) = evH.value.decodeDataType(dataType.head, dataTypes) - val (tailOut, tailRemaining) = evT.value.decodeDataType(dataType.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeDataType(dataType.tail, headRemaining) (headOut :: tailOut, tailRemaining) } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToOutput.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToOutput.scala index 2518d5e36..f0479bb1c 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToOutput.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToOutput.scala @@ -108,13 +108,13 @@ object DataTypeToOutput { implicit def fromHList[HD, HO, TD <: HList, TO <: HList](implicit evH: Strict[DataTypeToOutput.Aux[HD, HO]], - evT: Strict[DataTypeToOutput.Aux[TD, TO]] + evT: DataTypeToOutput.Aux[TD, TO] ): DataTypeToOutput.Aux[HD :: TD, HO :: TO] = { new DataTypeToOutput[HD :: TD] { override type O = HO :: TO override def dataTypeStructure: DataTypeStructure[HD :: TD] = { - DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.value.dataTypeStructure) + DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.dataTypeStructure) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToShape.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToShape.scala index b83037e06..85b8a56be 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToShape.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/DataTypeToShape.scala @@ -164,13 +164,13 @@ object DataTypeToShape { implicit def fromHList[HD, HS, TD <: HList, TS <: HList](implicit evH: Strict[DataTypeToShape.Aux[HD, HS]], - evT: Strict[DataTypeToShape.Aux[TD, TS]] + evT: DataTypeToShape.Aux[TD, TS] ): DataTypeToShape.Aux[HD :: TD, HS :: TS] = { new DataTypeToShape[HD :: TD] { override type S = HS :: TS override def sizeFromDataType(dataType: HD :: TD): Int = { - evH.value.sizeFromDataType(dataType.head) + evT.value.sizeFromDataType(dataType.tail) + evH.value.sizeFromDataType(dataType.head) + evT.sizeFromDataType(dataType.tail) } override def decodeShape( @@ -178,7 +178,7 @@ object DataTypeToShape { shapes: Seq[Shape] ): (HS :: TS, Seq[Shape]) = { val (headOut, headRemaining) = evH.value.decodeShape(dataType.head, shapes) - val (tailOut, tailRemaining) = evT.value.decodeShape(dataType.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeShape(dataType.tail, headRemaining) (headOut :: tailOut, tailRemaining) } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OpStructure.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OpStructure.scala index 2d187ae16..1970ac209 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OpStructure.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OpStructure.scala @@ -119,12 +119,12 @@ trait NestedStructureOpsLowPriority { implicit def fromHList[H, T <: HList](implicit evH: Strict[OpStructure[H]], - evT: Strict[OpStructure[T]] + evT: OpStructure[T] ): OpStructure[H :: T] = { new OpStructure[H :: T] { override def ops(executable: H :: T): Set[UntypedOp] = { evH.value.ops(executable.head) ++ - evT.value.ops(executable.tail) + evT.ops(executable.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputStructure.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputStructure.scala index 3504d1a1a..6eee72639 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputStructure.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputStructure.scala @@ -49,7 +49,7 @@ import scala.language.higherKinds * * @author Emmanouil Antonios Platanios */ -sealed trait OutputStructure[T] { +trait OutputStructure[T] { def size(output: T): Int def outputs(output: T): Seq[Output[Any]] def decodeOutput(output: T, outputs: Seq[Output[Any]]): (T, Seq[Output[Any]]) @@ -362,17 +362,17 @@ object OutputStructure { implicit def fromHList[HT, TT <: HList](implicit evH: Strict[OutputStructure[HT]], - evT: Strict[OutputStructure[TT]] + evT: OutputStructure[TT] ): OutputStructure[HT :: TT] = { new OutputStructure[HT :: TT] { override def size(output: HT :: TT): Int = { evH.value.size(output.head) + - evT.value.size(output.tail) + evT.size(output.tail) } override def outputs(output: HT :: TT): Seq[Output[Any]] = { evH.value.outputs(output.head) ++ - evT.value.outputs(output.tail) + evT.outputs(output.tail) } override def decodeOutput( @@ -380,7 +380,7 @@ object OutputStructure { outputs: Seq[Output[Any]] ): (HT :: TT, Seq[Output[Any]]) = { val (headOut, headRemaining) = evH.value.decodeOutput(output.head, outputs) - val (tailOut, tailRemaining) = evT.value.decodeOutput(output.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeOutput(output.tail, headRemaining) (headOut :: tailOut, tailRemaining) } @@ -389,7 +389,7 @@ object OutputStructure { converter: OutputStructure.Converter ): HT :: TT = { evH.value.map(value.head, converter) :: - evT.value.map(value.tail, converter) + evT.map(value.tail, converter) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToDataType.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToDataType.scala index 835dc2f05..2d3ae873b 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToDataType.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToDataType.scala @@ -326,21 +326,21 @@ object OutputToDataType { implicit def fromHList[HT, HD, TT <: HList, TD <: HList](implicit evH: Strict[OutputToDataType.Aux[HT, HD]], - evT: Strict[OutputToDataType.Aux[TT, TD]] + evT: OutputToDataType.Aux[TT, TD] ): OutputToDataType.Aux[HT :: TT, HD :: TD] = { new OutputToDataType[HT :: TT] { override type D = HD :: TD override def dataTypeStructure: DataTypeStructure[HD :: TD] = { - DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.value.dataTypeStructure) + DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.dataTypeStructure) } override def sizeFromDataType(dataType: HD :: TD): Int = { - evH.value.sizeFromDataType(dataType.head) + evT.value.sizeFromDataType(dataType.tail) + evH.value.sizeFromDataType(dataType.head) + evT.sizeFromDataType(dataType.tail) } override def dataType(output: HT :: TT): HD :: TD = { - evH.value.dataType(output.head) :: evT.value.dataType(output.tail) + evH.value.dataType(output.head) :: evT.dataType(output.tail) } override def decodeOutput( @@ -348,7 +348,7 @@ object OutputToDataType { outputs: Seq[Output[Any]] ): (HT :: TT, Seq[Output[Any]]) = { val (headOut, headRemaining) = evH.value.decodeOutput(dataType.head, outputs) - val (tailOut, tailRemaining) = evT.value.decodeOutput(dataType.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeOutput(dataType.tail, headRemaining) (headOut :: tailOut, tailRemaining) } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToShape.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToShape.scala index ef7471688..65894157b 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToShape.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToShape.scala @@ -452,27 +452,27 @@ object OutputToShape { implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit evH: Strict[OutputToShape.Aux[HT, HS]], - evT: Strict[OutputToShape.Aux[TT, TS]] + evT: OutputToShape.Aux[TT, TS] ): OutputToShape.Aux[HT :: TT, HS :: TS] = { new OutputToShape[HT :: TT] { override type S = HS :: TS override def outputStructure: OutputStructure[HT :: TT] = { implicit val evOutputToShapeH: OutputStructure[HT] = evH.value.outputStructure - implicit val evOutputToShapeT: OutputStructure[TT] = evT.value.outputStructure + implicit val evOutputToShapeT: OutputStructure[TT] = evT.outputStructure OutputStructure[HT :: TT] } override def shapeStructure: ShapeStructure[HS :: TS] = { - ShapeStructure.fromHList[HS, TS](evH.value.shapeStructure, evT.value.shapeStructure) + ShapeStructure.fromHList[HS, TS](evH.value.shapeStructure, evT.shapeStructure) } override def sizeFromOutput(output: HT :: TT): Int = { - evH.value.sizeFromOutput(output.head) + evT.value.sizeFromOutput(output.tail) + evH.value.sizeFromOutput(output.head) + evT.sizeFromOutput(output.tail) } override def shape(output: HT :: TT): HS :: TS = { - evH.value.shape(output.head) :: evT.value.shape(output.tail) + evH.value.shape(output.head) :: evT.shape(output.tail) } override def decodeShape( @@ -480,7 +480,7 @@ object OutputToShape { shapes: Seq[Shape] ): (HS :: TS, Seq[Shape]) = { val (headOut, headRemaining) = evH.value.decodeShape(output.head, shapes) - val (tailOut, tailRemaining) = evT.value.decodeShape(output.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeShape(output.tail, headRemaining) (headOut :: tailOut, tailRemaining) } @@ -490,7 +490,7 @@ object OutputToShape { converter: OutputStructure.Converter ): HT :: TT = { evH.value.map(value.head, shape.map(_.head), converter) :: - evT.value.map(value.tail, shape.map(_.tail), converter) + evT.map(value.tail, shape.map(_.tail), converter) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToTensor.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToTensor.scala index 070a6a467..d689776f9 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToTensor.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/OutputToTensor.scala @@ -283,17 +283,17 @@ object OutputToTensor { implicit def fromHList[HT, HV, TT <: HList, TV <: HList](implicit evH: Strict[OutputToTensor.Aux[HT, HV]], - evT: Strict[OutputToTensor.Aux[TT, TV]] + evT: OutputToTensor.Aux[TT, TV] ): OutputToTensor.Aux[HT :: TT, HV :: TV] = { new OutputToTensor[HT :: TT] { override type V = HV :: TV override def tensorStructure: TensorStructure[HV :: TV] = { - TensorStructure.fromHList[HV, TV](evH.value.tensorStructure, evT.value.tensorStructure) + TensorStructure.fromHList[HV, TV](evH.value.tensorStructure, evT.tensorStructure) } override def size(output: HT :: TT): Int = { - evH.value.size(output.head) + evT.value.size(output.tail) + evH.value.size(output.head) + evT.size(output.tail) } override def decodeTensor( @@ -301,7 +301,7 @@ object OutputToTensor { tensors: Seq[Tensor[Any]] ): (HV :: TV, Seq[Tensor[Any]]) = { val (headOut, headRemaining) = evH.value.decodeTensor(output.head, tensors) - val (tailOut, tailRemaining) = evT.value.decodeTensor(output.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeTensor(output.tail, headRemaining) (headOut :: tailOut, tailRemaining) } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/ShapeStructure.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/ShapeStructure.scala index ffef31207..574d4ebc2 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/ShapeStructure.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/ShapeStructure.scala @@ -163,15 +163,15 @@ object ShapeStructure { implicit def fromHList[HS, TS <: HList](implicit evH: Strict[ShapeStructure[HS]], - evT: Strict[ShapeStructure[TS]] + evT: ShapeStructure[TS] ): ShapeStructure[HS :: TS] = { new ShapeStructure[HS :: TS] { override def size(shape: HS :: TS): Int = { - evH.value.size(shape.head) + evT.value.size(shape.tail) + evH.value.size(shape.head) + evT.size(shape.tail) } override def shapes(shape: HS :: TS): Seq[Shape] = { - evH.value.shapes(shape.head) ++ evT.value.shapes(shape.tail) + evH.value.shapes(shape.head) ++ evT.shapes(shape.tail) } override def decodeShape( @@ -179,7 +179,7 @@ object ShapeStructure { shapes: Seq[Shape] ): (HS :: TS, Seq[Shape]) = { val (headOut, headRemaining) = evH.value.decodeShape(shape.head, shapes) - val (tailOut, tailRemaining) = evT.value.decodeShape(shape.tail, headRemaining) + val (tailOut, tailRemaining) = evT.decodeShape(shape.tail, headRemaining) (headOut :: tailOut, tailRemaining) } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorStructure.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorStructure.scala index 85d11686b..c89af7cbf 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorStructure.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorStructure.scala @@ -98,11 +98,11 @@ object TensorStructure { implicit def fromHList[HT, TT <: HList](implicit evH: Strict[TensorStructure[HT]], - evT: Strict[TensorStructure[TT]] + evT: TensorStructure[TT] ): TensorStructure[HT :: TT] = { new TensorStructure[HT :: TT] { override def tensors(tensor: HT :: TT): Seq[Tensor[Any]] = { - evH.value.tensors(tensor.head) ++ evT.value.tensors(tensor.tail) + evH.value.tensors(tensor.head) ++ evT.tensors(tensor.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToDataType.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToDataType.scala index d52ac82a0..ffca35ee7 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToDataType.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToDataType.scala @@ -146,13 +146,13 @@ object TensorToDataType { implicit def fromHList[HT, HD, TT <: HList, TD <: HList](implicit evH: Strict[TensorToDataType.Aux[HT, HD]], - evT: Strict[TensorToDataType.Aux[TT, TD]] + evT: TensorToDataType.Aux[TT, TD] ): TensorToDataType.Aux[HT :: TT, HD :: TD] = { new TensorToDataType[HT :: TT] { override type D = HD :: TD override def dataType(output: HT :: TT): HD :: TD = { - evH.value.dataType(output.head) :: evT.value.dataType(output.tail) + evH.value.dataType(output.head) :: evT.dataType(output.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToOutput.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToOutput.scala index 646e964de..3dd0b9dad 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToOutput.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToOutput.scala @@ -163,17 +163,17 @@ object TensorToOutput { implicit def fromHList[HT, HO, TT <: HList, TO <: HList](implicit evH: Strict[TensorToOutput.Aux[HT, HO]], - evT: Strict[TensorToOutput.Aux[TT, TO]] + evT: TensorToOutput.Aux[TT, TO] ): TensorToOutput.Aux[HT :: TT, HO :: TO] = { new TensorToOutput[HT :: TT] { override type O = HO :: TO override def tensorStructure: TensorStructure[HT :: TT] = { - TensorStructure.fromHList[HT, TT](evH.value.tensorStructure, evT.value.tensorStructure) + TensorStructure.fromHList[HT, TT](evH.value.tensorStructure, evT.tensorStructure) } override def output(tensor: HT :: TT): HO :: TO = { - evH.value.output(tensor.head) :: evT.value.output(tensor.tail) + evH.value.output(tensor.head) :: evT.output(tensor.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToShape.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToShape.scala index a17bd8143..abe4b0f52 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToShape.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/TensorToShape.scala @@ -146,13 +146,13 @@ object TensorToShape { implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit evH: Strict[TensorToShape.Aux[HT, HS]], - evT: Strict[TensorToShape.Aux[TT, TS]] + evT: TensorToShape.Aux[TT, TS] ): TensorToShape.Aux[HT :: TT, HS :: TS] = { new TensorToShape[HT :: TT] { override type S = HS :: TS override def shape(output: HT :: TT): HS :: TS = { - evH.value.shape(output.head) :: evT.value.shape(output.tail) + evH.value.shape(output.head) :: evT.shape(output.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/Zero.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/Zero.scala index a16ed7795..4d5c322f3 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/Zero.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/Zero.scala @@ -182,13 +182,13 @@ object Zero { implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit evH: Strict[Zero.Aux[HT, HS]], - evT: Strict[Zero.Aux[TT, TS]] + evT: Zero.Aux[TT, TS] ): Zero.Aux[HT :: TT, HS :: TS] = { new Zero[HT :: TT] { override type S = HS :: TS override def evOutputToShape: OutputToShape.Aux[HT :: TT, HS :: TS] = { - OutputToShape.fromHList[HT, HS, TT, TS](evH.value.evOutputToShape, evT.value.evOutputToShape) + OutputToShape.fromHList[HT, HS, TT, TS](evH.value.evOutputToShape, evT.evOutputToShape) } override def zero( @@ -198,7 +198,7 @@ object Zero { ): HT :: TT = { Op.nameScope(name) { evH.value.zero(batchSize, shape.head) :: - evT.value.zero(batchSize, shape.tail) + evT.zero(batchSize, shape.tail) } } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/InMemoryEstimator.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/InMemoryEstimator.scala index b847d5701..9d566b4af 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/InMemoryEstimator.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/InMemoryEstimator.scala @@ -24,7 +24,7 @@ import org.platanios.tensorflow.api.implicits.Implicits._ import org.platanios.tensorflow.api.implicits.helpers._ import org.platanios.tensorflow.api.learn._ import org.platanios.tensorflow.api.learn.hooks._ -import org.platanios.tensorflow.api.ops.{Op, OpSpecification, Output, UntypedOp} +import org.platanios.tensorflow.api.ops.{Op, Output, UntypedOp} import org.platanios.tensorflow.api.ops.control_flow.ControlFlow import org.platanios.tensorflow.api.ops.data.Dataset import org.platanios.tensorflow.api.ops.metrics.Metric @@ -61,7 +61,7 @@ import scala.collection.mutable * * @author Emmanouil Antonios Platanios */ -class InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss: TF : IsFloatOrDouble, EvalIn] private[estimators] ( +class InMemoryEstimator[In: OutputStructure, TrainIn: OutputStructure, Out: OutputStructure, TrainOut, Loss: TF : IsFloatOrDouble, EvalIn] private[estimators] ( override protected val modelFunction: Estimator.ModelFunction[In, TrainIn, Out, TrainOut, Loss, EvalIn], override protected val configurationBase: Configuration = null, val stopCriteria: StopCriteria = StopCriteria(), @@ -71,12 +71,6 @@ class InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss: TF : IsFloatOrDouble, val evaluateHooks: Set[Hook] = Set.empty, val tensorBoardConfig: TensorBoardConfig = null, val evaluationMetrics: Seq[Metric[EvalIn, Output[Float]]] = Seq.empty -)(implicit - evOutputStructureIn: OutputStructure[In], - evOutputStructureTrainIn: OutputStructure[TrainIn], - evOutputStructureOut: OutputStructure[Out], - // This implicit helps the Scala 2.11 compiler. - evOutputStructureInOut: OutputStructure[(In, Out)] ) extends Estimator[In, TrainIn, Out, TrainOut, Loss, EvalIn](modelFunction, configurationBase) { if (trainHooks.exists(_.isInstanceOf[Stopper]) || trainChiefOnlyHooks.exists(_.isInstanceOf[Stopper]) @@ -269,10 +263,16 @@ class InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss: TF : IsFloatOrDouble, stopHook.reset(session) session.enableHooks() session.resetShouldStop() + + // For some reason this is necessary when compiling for Scala 2.11. + val scala211Helper = implicitly[OutputStructure[(In, Out)]] + ev.convertFetched(new Iterator[(InV, OutV)] { override def hasNext: Boolean = !session.shouldStop override def next(): (InV, OutV) = { try { + implicit val evScala211Helper: OutputStructure[(In, Out)] = scala211Helper + // TODO: !!! There might be an issue with the stop criteria here. session.removeHooks(currentTrainHooks ++ evaluateHooks) val output = session.run(fetches = (inferenceOps.input, inferenceOps.output)) diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/Dataset.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/Dataset.scala index b157d0a2b..0f8ba2974 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/Dataset.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/Dataset.scala @@ -701,13 +701,17 @@ abstract class Dataset[T: OutputStructure] { outer => name: String = s"${this.name}/GroupByWindow" )(implicit evOutputToDataType: OutputToDataType.Aux[T, D], - evOutputToShape: OutputToShape.Aux[T, S], - // These implicit helpers is used for Scala 2.11 support. - evOutputToDataType211Helper: OutputToDataType.Aux[(Output[Long], Dataset[T]), (DataType[Long], DataType[Variant])], - evOutputToShape211Helper: OutputToShape.Aux[(Output[Long], Dataset[T]), (Shape, Shape)] + evOutputToShape: OutputToShape.Aux[T, S] ): Dataset[T] = { + // For some reason this is necessary when compiling for Scala 2.11. + val outputToDataType211Helper: OutputToDataType.Aux[(Output[Long], Dataset[T]), (DataType[Long], DataType[Variant])] = OutputToDataType[(Output[Long], Dataset[T])] + val outputToShape211Helper: OutputToShape.Aux[(Output[Long], Dataset[T]), (Shape, Shape)] = OutputToShape[(Output[Long], Dataset[T])] + val providedName = name new Dataset[T] { + implicit val evOutputToDataType211Helper: OutputToDataType.Aux[(Output[Long], Dataset[T]), (DataType[Long], DataType[Variant])] = outputToDataType211Helper + implicit val evOutputToShape211Helper: OutputToShape.Aux[(Output[Long], Dataset[T]), (Shape, Shape)] = outputToShape211Helper + override val name: String = providedName private var instantiatedKeyFunction : Option[InstantiatedFunction[T, Output[Long]]] = None @@ -1381,16 +1385,20 @@ abstract class Dataset[T: OutputStructure] { outer => shardIndex: Long )(implicit evOutputToDataType: OutputToDataType.Aux[T, D], - evOutputToShape: OutputToShape.Aux[T, S], - // These implicit helpers is used for Scala 2.11 support. - evOutputToDataType211Helper: OutputToDataType.Aux[(T, Output[Long]), (D, DataType[Long])], - evOutputToShape211Helper: OutputToShape.Aux[(T, Output[Long]), (S, Shape)] + evOutputToShape: OutputToShape.Aux[T, S] ): Dataset[T] = { + // For some reason this is necessary when compiling for Scala 2.11. + val outputToDataType211Helper: OutputToDataType.Aux[(T, Output[Long]), (D, DataType[Long])] = OutputToDataType[(T, Output[Long])] + val outputToShape211Helper: OutputToShape.Aux[(T, Output[Long]), (S, Shape)] = OutputToShape[(T, Output[Long])] + if (shardIndex >= numShards) throw InvalidArgumentException(s"'index' (= $shardIndex) must be smaller than 'numShards' (= $numShards).") if (numShards == 1) { this } else { + implicit val evOutputToDataType211Helper: OutputToDataType.Aux[(T, Output[Long]), (D, DataType[Long])] = outputToDataType211Helper + implicit val evOutputToShape211Helper: OutputToShape.Aux[(T, Output[Long]), (S, Shape)] = outputToShape211Helper + this.zip(Data.datasetFromRange(0L, Long.MaxValue)) .filter(t => Math.equal(Math.mod(t._2, numShards), shardIndex)) .map(o => o._1)