Skip to content

Commit

Permalink
Fix compilation error and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Jun 12, 2024
1 parent 1cba7e3 commit 903e780
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 14 deletions.
8 changes: 3 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ModelArch, ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.ml.util._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.slf4j.{Logger, LoggerFactory}
import org.intel.openvino.Tensor
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -233,9 +233,7 @@ private[johnsnowlabs] class Bert(
.map(sentence => sentence.map(x => if (x == 0) 0L else 1L))
.toArray
val maskTensors =
OnnxTensor.createTensor(
env,
attentionMask)
OnnxTensor.createTensor(env, attentionMask)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)
Expand Down
7 changes: 2 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ModelArch, ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.ml.util._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.slf4j.{Logger, LoggerFactory}
Expand Down Expand Up @@ -243,10 +243,7 @@ private[johnsnowlabs] class XlmRoberta(
val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
attentionMask
)
OnnxTensor.createTensor(env, attentionMask)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ trait ReadMPNetForTokenDLModel extends ReadOnnxModel {
case TensorFlow.name =>
throw new NotImplementedError("Tensorflow models are not supported.")
case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
val onnxWrapper =
OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, Some(onnxWrapper))
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/


package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.nlp.annotators.Tokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class XlmRoBertaForQuestionAnsweringTestSpec extends AnyFlatSpec {
val loadedPipelineModel = PipelineModel.load("./tmp_xlmrobertaforquestion_pipeline")
loadedPipelineModel.transform(ddd).select("label.result").show(false)


}

"XlmRoBertaForQuestionAnswering" should "benchmark test" taggedAs SlowTest in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class XlmRoBertaForTokenClassificationTestSpec extends AnyFlatSpec {

}


"XlmRoBertaForTokenClassification" should "be saved and loaded correctly" taggedAs SlowTest in {

import ResourceHelper.spark.implicits._
Expand Down

0 comments on commit 903e780

Please sign in to comment.