Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise an exception when serializing CQMs with bad labels #1359

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions dimod/constrained/constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,8 @@ def _from_file_legacy(cls,
constraint_labels = set()
for arch in zf.namelist():
# even on windows zip uses /
match = re.match("constraints/([^/]+)/", arch)
# rely on the fact that we have at least an lhs file
match = re.match("^constraints/(.+)/lhs$", arch)
if match is not None:
constraint_labels.add(match.group(1))

Expand Down Expand Up @@ -1070,11 +1071,11 @@ def from_file(cls,
constraint_labels = set()
for arch in zf.namelist():
# even on windows zip uses /
match = re.match("constraints/([^/]+)/", arch)
match = re.match("^constraints/(.+)/lhs$", arch)
if match is not None:
constraint_labels.add(match.group(1))

for constraint in constraint_labels:
for constraint in constraint_labels:
label = deserialize_variable(json.loads(constraint))

rhs = np.frombuffer(zf.read(f"constraints/{constraint}/rhs"), np.float64)[0]
Expand Down Expand Up @@ -1783,7 +1784,23 @@ def to_file(self, *,

for label, constraint in self.constraints.items():
# put everything in a constraints/label/ directory
lstr = json.dumps(serialize_variable(label))
lstr = json.dumps(serialize_variable(label), ensure_ascii=False)

if "/" in lstr:
# Because of the way we do the regex in .from_file(), we actually do
# support these. But it is inconsistent with the description of the file
# format so we do the simpler thing and just disallow it
raise ValueError("cannot serialize constraint labels containing '/'")

if "\0" in lstr:
# Similarily, this actually works, but it's weird and confusing to support it
# so we disallow
raise ValueError("cannot serialize constraint labels containing the NULL character")

if os.sep == '\\' and os.sep in lstr:
# Irritatingly, zipfile will automatically convert \ to / on windows, so we
# also don't allow that
raise ValueError("cannot serialize constraint labels containing '\\' on windows")

with zf.open(f'constraints/{lstr}/lhs', "w", force_zip64=True) as fdst:
constraint.lhs._into_file(fdst)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Raise an exception when serializing constrained quadratic models with constraint
labels containing ``"/"``. Previously they would be serialized but would subsequently
break deserialization.
See `#1358 <https://github.com/dwavesystems/dimod/issues/1358>`_.
75 changes: 75 additions & 0 deletions tests/test_constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,81 @@ def test_unused_variable(self):
self.assertEqual(new.lower_bound(v), cqm.lower_bound(v))
self.assertEqual(new.upper_bound(v), cqm.upper_bound(v))

def test_unusual_constraint_labels(self):
import os

x, y = dimod.Binaries("xy")

unusual_characters = " .~;,>|😜+-&" # not exhaustive

if os.sep == "\\":
with self.subTest("\\"):
label = "test\\test"
cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with self.assertRaises(ValueError):
cqm.to_file()
else:
unusual_characters += "\\"

for char in unusual_characters:
with self.subTest(f"leading {char}"):
label = f"{char}test"

cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with cqm.to_file() as f:
new = dimod.CQM.from_file(f)

# best we can hope for is an equivalent after a json round trip
self.assertEqual(list(new.constraints), [label])

with self.subTest(f"trailing {char}"):
label = f"test{char}"

cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with cqm.to_file() as f:
new = dimod.CQM.from_file(f)

# best we can hope for is an equivalent after a json round trip
self.assertEqual(list(new.constraints), [label])

with self.subTest(f"embedded {char}"):
label = f"te{char}st"

cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with cqm.to_file() as f:
new = dimod.CQM.from_file(f)

# best we can hope for is an equivalent after a json round trip
self.assertEqual(list(new.constraints), [label])

with self.subTest("empty label"):
label = f""

cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with cqm.to_file() as f:
new = dimod.CQM.from_file(f)

self.assertEqual(list(new.constraints), [label])

with self.subTest("/"):
label = "test/test"
cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with self.assertRaises(ValueError):
cqm.to_file()

with self.subTest("NULL"):
label = "test\0test"
cqm = dimod.CQM()
cqm.add_constraint(x + y <= 5, label=label)
with self.assertRaises(ValueError):
cqm.to_file()


class TestSetObjective(unittest.TestCase):
def test_bqm(self):
Expand Down