From 4a069979387c01c1136676c98b755265533a267b Mon Sep 17 00:00:00 2001
From: Daniel Falbel
Date: Mon, 2 Aug 2021 17:33:11 -0300
Subject: [PATCH 1/2] Correctly track `training` attribute for traced modules.

---
 R/script_module.R     | 13 +++++++++++++
 R/trace.R             |  1 +
 src/script_module.cpp |  8 ++++----
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/R/script_module.R b/R/script_module.R
index 60873d44f3..7481989e22 100644
--- a/R/script_module.R
+++ b/R/script_module.R
@@ -115,6 +115,15 @@ nn_ScriptModule <- R6::R6Class(
         env = private
       )
 
+      rm(list = "training", envir = self)
+      makeActiveBinding(
+        "training",
+        fun = function() {
+          cpp_jit_script_module_is_training(private$ptr)
+        },
+        env = self
+      )
+
     },
     register_parameter = function(name, param) {
       private$ptr$register_parameter(name, param)
@@ -130,6 +139,10 @@ nn_ScriptModule <- R6::R6Class(
     },
     ..ptr.. = function() {
       private$ptr
+    },
+    train = function(mode = TRUE) {
+      private$ptr$train(mode = mode)
+      invisible(create_nn_module_callable(self))
     }
   ),
   private = list(
diff --git a/R/trace.R b/R/trace.R
index f0489ab757..7c3256d34a 100644
--- a/R/trace.R
+++ b/R/trace.R
@@ -229,6 +229,7 @@ create_script_module <- function(mod) {
     module$register_module(name, create_script_module(child))
   })
 
+  module$train(mod$training)
 
   # Let's not keep the constants in the module right now as it might cause more
   # problems than benefits. In pytorch they are only added if their name is in
diff --git a/src/script_module.cpp b/src/script_module.cpp
index 213a7cbb1c..55077b1d17 100644
--- a/src/script_module.cpp
+++ b/src/script_module.cpp
@@ -16,25 +16,25 @@ XPtrTorchjit_named_buffer_list cpp_jit_script_module_buffers (XPtrTorchScriptMod
 // [[Rcpp::export]]
 void cpp_jit_script_module_train (XPtrTorchScriptModule self, bool on)
 {
-  _lantern_ScriptModule_train(self.get(), on);
+  lantern_ScriptModule_train(self.get(), on);
 }
 
 // [[Rcpp::export]]
 void cpp_jit_script_module_set_optimized (XPtrTorchScriptModule self, bool on)
 {
-  _lantern_ScriptModule_set_optimized(self.get(), on);
+  lantern_ScriptModule_set_optimized(self.get(), on);
 }
 
 // [[Rcpp::export]]
 bool cpp_jit_script_module_is_training (XPtrTorchScriptModule self)
 {
-  return _lantern_ScriptModule_is_training(self.get());
+  return lantern_ScriptModule_is_training(self.get());
 }
 
 // [[Rcpp::export]]
 bool cpp_jit_script_module_is_optimized (XPtrTorchScriptModule self)
 {
-  return _lantern_ScriptModule_is_optimized(self.get());
+  return lantern_ScriptModule_is_optimized(self.get());
 }
 
 // [[Rcpp::export]]

From 1b3eac20bedb5f0f341081fa2a668c0e54316a2d Mon Sep 17 00:00:00 2001
From: Daniel Falbel
Date: Mon, 2 Aug 2021 17:35:38 -0300
Subject: [PATCH 2/2] Add a regression test

---
 tests/testthat/test-script_module.R | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/testthat/test-script_module.R b/tests/testthat/test-script_module.R
index 7df6ad8ee8..2fe0830279 100644
--- a/tests/testthat/test-script_module.R
+++ b/tests/testthat/test-script_module.R
@@ -129,3 +129,19 @@ test_that("can print the graph", {
   })
 })
 
+
+test_that("training attribute is persisted", {
+  model <- nn_sequential(
+    nn_linear(10, 10),
+    nn_relu(),
+    nn_dropout(),
+    nn_linear(10, 1)
+  )
+
+  model$eval()
+  model$training
+  test_model <- jit_trace(model, torch_randn(10, 10))
+
+  expect_true(!test_model$training)
+  expect_true(!test_model[["0"]]$training)
+})
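
Note, not part of the patch series: a minimal usage sketch of the behaviour these patches target, assuming the torch R package API used above (`nn_linear()`, `jit_trace()`, the `$training` active binding, and the `train()` method added in PATCH 1/2). The layer sizes and input tensor are illustrative only, and calling `$train()` directly on the object returned by `jit_trace()` is an assumption based on the `train = function(mode = TRUE)` method added to `nn_ScriptModule`.

    library(torch)

    # Build a small module and switch it to eval mode before tracing.
    model <- nn_linear(10, 1)
    model$eval()

    # Trace it; with the patches applied the traced module should report the
    # same `training` state as the original module (FALSE here).
    traced <- jit_trace(model, torch_randn(2, 10))
    traced$training      # expected: FALSE

    # The new train() method forwards to the underlying ScriptModule, so the
    # active binding should reflect the change.
    traced$train(TRUE)   # assumption: $train() is reachable on the traced object
    traced$training      # expected: TRUE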