diff --git a/dimod/include/dimod/abc.h b/dimod/include/dimod/abc.h index 88f6f7560..193799260 100644 --- a/dimod/include/dimod/abc.h +++ b/dimod/include/dimod/abc.h @@ -707,7 +707,7 @@ void QuadraticModelBase::fix_variable(index_type v, T ass add_offset(assignment * linear(v)); // finally remove v - remove_variable(v); + QuadraticModelBase::remove_variable(v); } template diff --git a/dimod/include/dimod/expression.h b/dimod/include/dimod/expression.h index a4b71ec3e..7f79126b5 100644 --- a/dimod/include/dimod/expression.h +++ b/dimod/include/dimod/expression.h @@ -419,7 +419,20 @@ void Expression::clear() { template template void Expression::fix_variable(index_type v, T assignment) { - throw std::logic_error("not implemented - fix_variable"); + assert(v >= 0 && static_cast(v) < parent_->num_variables()); + + auto vit = indices_.find(v); + if (vit == indices_.end()) return; // nothing to remove + + // remove the biases + base_type::fix_variable(vit->second, assignment); + + // update the indices + auto it = variables_.erase(variables_.begin() + vit->second); + indices_.erase(vit); + for (; it != variables_.end(); ++it) { + indices_[*it] -= 1; + } } template @@ -592,6 +605,8 @@ bool Expression::remove_interaction(index_type u, index_t template void Expression::remove_variable(index_type v) { + assert(v >= 0 && static_cast(v) < parent_->num_variables()); + auto vit = indices_.find(v); if (vit == indices_.end()) return; // nothing to remove @@ -599,12 +614,10 @@ void Expression::remove_variable(index_type v) { base_type::remove_variable(vit->second); // update the indices - variables_.erase(variables_.begin() + vit->second); - - // indices is no longer valid, so remake - indices_.clear(); - for (size_type ui = 0; ui < variables_.size(); ++ui) { - indices_[variables_[ui]] = ui; + auto it = variables_.erase(variables_.begin() + vit->second); + indices_.erase(vit); + for (; it != variables_.end(); ++it) { + indices_[*it] -= 1; } } diff --git a/dimod/include/dimod/quadratic_model.h b/dimod/include/dimod/quadratic_model.h index df80b14a5..316bb2a1b 100644 --- a/dimod/include/dimod/quadratic_model.h +++ b/dimod/include/dimod/quadratic_model.h @@ -64,6 +64,15 @@ class QuadraticModel : public abc::QuadraticModelBase { /// Change the vartype of `v`, updating the biases appropriately. void change_vartype(Vartype vartype, index_type v); + /** + * Remove variable `v` from the model by fixing its value. + * + * Note that this causes a reindexing, where all variables above `v` have + * their index reduced by one. + */ + template + void fix_variable(index_type v, T assignment); + /// Return the lower bound on variable ``v``. bias_type lower_bound(index_type v) const; @@ -229,6 +238,13 @@ void QuadraticModel::change_vartype(Vartype vartype, inde } } +template +template +void QuadraticModel::fix_variable(index_type v, T assignment) { + base_type::fix_variable(v, assignment); + varinfo_.erase(varinfo_.begin() + v); +} + template bias_type QuadraticModel::lower_bound(index_type v) const { // even though v is unused, we need this to conform the the QuadraticModelBase API diff --git a/releasenotes/notes/expression-fix-variables-performance-f13a65b6a16fa484.yaml b/releasenotes/notes/expression-fix-variables-performance-f13a65b6a16fa484.yaml new file mode 100644 index 000000000..b63f9c6b7 --- /dev/null +++ b/releasenotes/notes/expression-fix-variables-performance-f13a65b6a16fa484.yaml @@ -0,0 +1,12 @@ +--- +features: + - | + Improve the performance of fixing and removing variables from constrained + quadratic model expressions. + - | + Implement the ``Expression::fix_variable()`` C++ method. Previously it would + throw ``std::logic_error("not implemented - fix_variable")``. +upgrade: + - | + Add an overload to the C++ ``QuadraticModel::remove_variable()`` method. + This is binary compatible, but it makes ``&remove_variable`` ambiguous. diff --git a/testscpp/tests/test_constrained_quadratic_model.cpp b/testscpp/tests/test_constrained_quadratic_model.cpp index 082c26053..a74a82e92 100644 --- a/testscpp/tests/test_constrained_quadratic_model.cpp +++ b/testscpp/tests/test_constrained_quadratic_model.cpp @@ -410,16 +410,36 @@ SCENARIO("ConstrainedQuadraticModel tests") { } } - // WHEN("we fix a variable") { - // cqm.fix_variable(x, 0); + WHEN("we fix a variable that is not used in the expression") { + const0.fix_variable(y, 2); - // THEN("everything is updated correctly") { - // REQUIRE(cqm.num_variables() == 3); + THEN("nothing changes") { + REQUIRE(const0.num_variables() == 3); + CHECK(const0.linear(x) == 0); + CHECK(const0.linear(y) == 0); + CHECK(const0.linear(i) == 3); + CHECK(const0.linear(j) == 0); + + CHECK(const0.num_interactions() == 2); + CHECK(const0.quadratic(x, j) == 2); + CHECK(const0.quadratic(i, j) == 5); + } + } - // REQUIRE(const0.num_variables() == 2); - // REQUIRE(const0.linear(i-1) == 3); - // } - // } + WHEN("we fix a variable that is used in the expression") { + const0.fix_variable(x, 2); + + THEN("the biases are updated") { + REQUIRE(const0.num_variables() == 2); + CHECK(const0.linear(x) == 0); + CHECK(const0.linear(y) == 0); + CHECK(const0.linear(i) == 3); + CHECK(const0.linear(j) == 4); + + CHECK(const0.num_interactions() == 1); + CHECK(const0.quadratic(i, j) == 5); + } + } } GIVEN("A constraint with one-hot constraints") {