diff --git a/dimod/constrained/constrained.py b/dimod/constrained/constrained.py index 86c4fc93c..e24c720db 100644 --- a/dimod/constrained/constrained.py +++ b/dimod/constrained/constrained.py @@ -961,7 +961,8 @@ def _from_file_legacy(cls, constraint_labels = set() for arch in zf.namelist(): # even on windows zip uses / - match = re.match("constraints/([^/]+)/", arch) + # rely on the fact that we have at least an lhs file + match = re.match("^constraints/(.+)/lhs$", arch) if match is not None: constraint_labels.add(match.group(1)) @@ -1070,11 +1071,11 @@ def from_file(cls, constraint_labels = set() for arch in zf.namelist(): # even on windows zip uses / - match = re.match("constraints/([^/]+)/", arch) + match = re.match("^constraints/(.+)/lhs$", arch) if match is not None: constraint_labels.add(match.group(1)) - for constraint in constraint_labels: + for constraint in constraint_labels: label = deserialize_variable(json.loads(constraint)) rhs = np.frombuffer(zf.read(f"constraints/{constraint}/rhs"), np.float64)[0] @@ -1783,7 +1784,23 @@ def to_file(self, *, for label, constraint in self.constraints.items(): # put everything in a constraints/label/ directory - lstr = json.dumps(serialize_variable(label)) + lstr = json.dumps(serialize_variable(label), ensure_ascii=False) + + if "/" in lstr: + # Because of the way we do the regex in .from_file(), we actually do + # support these. But it is inconsistent with the description of the file + # format so we do the simpler thing and just disallow it + raise ValueError("cannot serialize constraint labels containing '/'") + + if "\0" in lstr: + # Similarily, this actually works, but it's weird and confusing to support it + # so we disallow + raise ValueError("cannot serialize constraint labels containing the NULL character") + + if os.sep == '\\' and os.sep in lstr: + # Irritatingly, zipfile will automatically convert \ to / on windows, so we + # also don't allow that + raise ValueError("cannot serialize constraint labels containing '\\' on windows") with zf.open(f'constraints/{lstr}/lhs', "w", force_zip64=True) as fdst: constraint.lhs._into_file(fdst) diff --git a/releasenotes/notes/fix-constraint-labels-and-serialization-c0eb5e410293a63c.yaml b/releasenotes/notes/fix-constraint-labels-and-serialization-c0eb5e410293a63c.yaml new file mode 100644 index 000000000..995b0e1ef --- /dev/null +++ b/releasenotes/notes/fix-constraint-labels-and-serialization-c0eb5e410293a63c.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Raise an exception when serializing constrained quadratic models with constraint + labels containing ``"/"``. Previously they would be serialized but would subsequently + break deserialization. + See `#1358 `_. diff --git a/tests/test_constrained.py b/tests/test_constrained.py index 54941b012..46317260b 100644 --- a/tests/test_constrained.py +++ b/tests/test_constrained.py @@ -1334,6 +1334,81 @@ def test_unused_variable(self): self.assertEqual(new.lower_bound(v), cqm.lower_bound(v)) self.assertEqual(new.upper_bound(v), cqm.upper_bound(v)) + def test_unusual_constraint_labels(self): + import os + + x, y = dimod.Binaries("xy") + + unusual_characters = " .~;,>|😜+-&" # not exhaustive + + if os.sep == "\\": + with self.subTest("\\"): + label = "test\\test" + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with self.assertRaises(ValueError): + cqm.to_file() + else: + unusual_characters += "\\" + + for char in unusual_characters: + with self.subTest(f"leading {char}"): + label = f"{char}test" + + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with cqm.to_file() as f: + new = dimod.CQM.from_file(f) + + # best we can hope for is an equivalent after a json round trip + self.assertEqual(list(new.constraints), [label]) + + with self.subTest(f"trailing {char}"): + label = f"test{char}" + + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with cqm.to_file() as f: + new = dimod.CQM.from_file(f) + + # best we can hope for is an equivalent after a json round trip + self.assertEqual(list(new.constraints), [label]) + + with self.subTest(f"embedded {char}"): + label = f"te{char}st" + + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with cqm.to_file() as f: + new = dimod.CQM.from_file(f) + + # best we can hope for is an equivalent after a json round trip + self.assertEqual(list(new.constraints), [label]) + + with self.subTest("empty label"): + label = f"" + + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with cqm.to_file() as f: + new = dimod.CQM.from_file(f) + + self.assertEqual(list(new.constraints), [label]) + + with self.subTest("/"): + label = "test/test" + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with self.assertRaises(ValueError): + cqm.to_file() + + with self.subTest("NULL"): + label = "test\0test" + cqm = dimod.CQM() + cqm.add_constraint(x + y <= 5, label=label) + with self.assertRaises(ValueError): + cqm.to_file() + class TestSetObjective(unittest.TestCase): def test_bqm(self):