Skip to content

Commit

Permalink
fix: Fail scan on InvalidMagicError in picklescan, update default for…
Browse files Browse the repository at this point in the history
… read_checkpoint_meta to scan unless explicitly told not to
  • Loading branch information
brandonrising committed Nov 26, 2024
1 parent 965cd76 commit 756008d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/services/model_load/model_load_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_model_from_path(

def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")


Expand Down
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/util/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
return checkpoint


def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
if str(path).endswith(".safetensors"):
try:
path_str = path.as_posix() if isinstance(path, Path) else path
Expand All @@ -55,7 +55,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
if str(path).endswith(".gguf"):
# The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
Expand Down

0 comments on commit 756008d

Please sign in to comment.