-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Support adding pipeline component instances #12710
base: v4
Are you sure you want to change the base?
Changes from 7 commits
6f821ef
aa0d747
4332d12
b9730a6
afbdd82
9753484
77a0859
dcd8a76
4cc5bd3
9fcbc8e
8a79a71
0fd797e
42a373d
0c8393e
ef18829
bc3337d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,6 @@ | ||||||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable | ||||||
from typing import Union, Tuple, List, Set, Pattern, Sequence | ||||||
from typing import NoReturn, TypeVar, cast, overload | ||||||
from typing import Union, Tuple, List, Set, Pattern, Sequence, overload | ||||||
from typing import NoReturn, TypeVar, cast | ||||||
|
||||||
from dataclasses import dataclass | ||||||
import random | ||||||
|
@@ -52,6 +52,9 @@ | |||||
# This is the base config for the [pretraining] block and currently not included | ||||||
# in the main config and only added via the 'init fill-config' command | ||||||
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg" | ||||||
# Factory name indicating that the component wasn't constructed by a factory, | ||||||
# and was instead passed by instance | ||||||
INSTANCE_FACTORY_NAME = "__added_by_instance__" | ||||||
|
||||||
# Type variable for contexts piped with documents | ||||||
_AnyContext = TypeVar("_AnyContext") | ||||||
|
@@ -743,6 +746,9 @@ def add_pipe( | |||||
"""Add a component to the processing pipeline. Valid components are | ||||||
callables that take a `Doc` object, modify it and return it. Only one | ||||||
of before/after/first/last can be set. Default behaviour is "last". | ||||||
Components can be added either by factory name or by instance. If | ||||||
an instance is supplied and you serialize the pipeline, you'll need | ||||||
to also pass an instance into spacy.load() to construct the pipeline. | ||||||
|
||||||
factory_name (str): Name of the component factory. | ||||||
name (str): Name of pipeline component. Overwrites existing | ||||||
|
@@ -790,11 +796,61 @@ def add_pipe( | |||||
raw_config=raw_config, | ||||||
validate=validate, | ||||||
) | ||||||
pipe_index = self._get_pipe_index(before, after, first, last) | ||||||
self._pipe_meta[name] = self.get_factory_meta(factory_name) | ||||||
pipe_index = self._get_pipe_index(before, after, first, last) | ||||||
self._components.insert(pipe_index, (name, pipe_component)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can revert the changes to |
||||||
return pipe_component | ||||||
|
||||||
def add_pipe_instance( | ||||||
self, | ||||||
component: PipeCallable, | ||||||
/, | ||||||
name: Optional[str] = None, | ||||||
*, | ||||||
before: Optional[Union[str, int]] = None, | ||||||
after: Optional[Union[str, int]] = None, | ||||||
first: Optional[bool] = None, | ||||||
last: Optional[bool] = None, | ||||||
) -> PipeCallable: | ||||||
"""Add a component instance to the processing pipeline. Valid components | ||||||
are callables that take a `Doc` object, modify it and return it. Only one | ||||||
of before/after/first/last can be set. Default behaviour is "last". | ||||||
|
||||||
A limitation of this method is that spaCy will not know how to reconstruct | ||||||
your pipeline after you save it out (unlike the 'Language.add_pipe()' method, | ||||||
where you provide a config and let spaCy construct the instance). See 'spacy.load' | ||||||
for details of how to load back a pipeline with components added by instance. | ||||||
|
||||||
pipe_instance (Callable[[Doc], Doc]): The component to add. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
name (str): Name of pipeline component. Overwrites existing | ||||||
component.name attribute if available. If no name is set and | ||||||
the component exposes no name attribute, component.__name__ is | ||||||
used. An error is raised if a name already exists in the pipeline. | ||||||
before (Union[str, int]): Name or index of the component to insert new | ||||||
component directly before. | ||||||
after (Union[str, int]): Name or index of the component to insert new | ||||||
component directly after. | ||||||
first (bool): If True, insert component first in the pipeline. | ||||||
last (bool): If True, insert component last in the pipeline. | ||||||
RETURNS (Callable[[Doc], Doc]): The pipeline component. | ||||||
|
||||||
DOCS: https://spacy.io/api/language#add_pipe_instance | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: add documentation to API page |
||||||
""" | ||||||
name = name if name is not None else getattr(component, "name") | ||||||
if name is None: | ||||||
raise ValueError("TODO error") | ||||||
Comment on lines
+868
to
+870
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively, we could require |
||||||
if name in self.component_names: | ||||||
raise ValueError(Errors.E007.format(name=name, opts=self.component_names)) | ||||||
|
||||||
# It would be possible to take arguments for the FactoryMeta here, but we'll then have | ||||||
# a problem on deserialization: where will the data be coming from? | ||||||
# I think if someone wants that, they should register a component function. | ||||||
self._pipe_meta[name] = FactoryMeta(INSTANCE_FACTORY_NAME) | ||||||
self._pipe_configs[name] = Config() | ||||||
pipe_index = self._get_pipe_index(before, after, first, last) | ||||||
self._components.insert(pipe_index, (name, component)) | ||||||
return component | ||||||
|
||||||
def _get_pipe_index( | ||||||
self, | ||||||
before: Optional[Union[str, int]] = None, | ||||||
|
@@ -1696,6 +1752,7 @@ def from_config( | |||||
meta: Dict[str, Any] = SimpleFrozenDict(), | ||||||
auto_fill: bool = True, | ||||||
validate: bool = True, | ||||||
pipe_instances: Dict[str, Any] = SimpleFrozenDict(), | ||||||
) -> "Language": | ||||||
"""Create the nlp object from a loaded config. Will set up the tokenizer | ||||||
and language data, add pipeline components etc. If no config is provided, | ||||||
|
@@ -1771,6 +1828,11 @@ def from_config( | |||||
|
||||||
# Warn about require_gpu usage in jupyter notebook | ||||||
warn_if_jupyter_cupy() | ||||||
# If we've been passed pipe instances, check whether | ||||||
# they have a Vocab instance, and if they do, use | ||||||
# that one. This also performs some additional checks and | ||||||
# warns if there's a mismatch. | ||||||
vocab = _get_instantiated_vocab(vocab, pipe_instances) | ||||||
|
||||||
# Note that we don't load vectors here, instead they get loaded explicitly | ||||||
# inside stuff like the spacy train function. If we loaded them here, | ||||||
|
@@ -1787,6 +1849,11 @@ def from_config( | |||||
interpolated = filled.interpolate() if not filled.is_interpolated else filled | ||||||
pipeline = interpolated.get("components", {}) | ||||||
sourced = util.get_sourced_components(interpolated) | ||||||
# Check for components that aren't in the pipe_instances dict, aren't disabled, | ||||||
# and aren't built by factory. | ||||||
missing_components = _find_missing_components(pipeline, pipe_instances, exclude) | ||||||
if missing_components: | ||||||
raise ValueError(Errors.E1055.format(names=", ".join(missing_components))) | ||||||
# If components are loaded from a source (existing models), we cache | ||||||
# them here so they're only loaded once | ||||||
source_nlps = {} | ||||||
|
@@ -1796,6 +1863,16 @@ def from_config( | |||||
if pipe_name not in pipeline: | ||||||
opts = ", ".join(pipeline.keys()) | ||||||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) | ||||||
if pipe_name in pipe_instances: | ||||||
if pipe_name in exclude: | ||||||
continue | ||||||
else: | ||||||
nlp.add_pipe_instance(pipe_instances[pipe_name]) | ||||||
# Is it important that we instantiate pipes that | ||||||
# aren't excluded? It seems like we would want | ||||||
# the exclude check above. I've left it how it | ||||||
# is though, in case there's some sort of crazy | ||||||
# load-bearing side-effects someone is relying on? | ||||||
pipe_cfg = util.copy_config(pipeline[pipe_name]) | ||||||
raw_config = Config(filled["components"][pipe_name]) | ||||||
if pipe_name not in exclude: | ||||||
Comment on lines
1902
to
1917
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that it looks like we should move the Let's move that to a separate PR though? |
||||||
|
@@ -2312,3 +2389,46 @@ def step(self) -> None: | |||||
if self.count >= self.chunk_size: | ||||||
self.count = 0 | ||||||
self.send() | ||||||
|
||||||
|
||||||
def _get_instantiated_vocab( | ||||||
vocab: Union[bool, Vocab], pipe_instances: Dict[str, Any] | ||||||
) -> Union[bool, Vocab]: | ||||||
vocab_instances = {} | ||||||
for name, instance in pipe_instances.items(): | ||||||
if hasattr(instance, "vocab") and isinstance(instance.vocab, Vocab): | ||||||
vocab_instances[name] = instance.vocab | ||||||
if not vocab_instances: | ||||||
return vocab | ||||||
elif isinstance(vocab, Vocab): | ||||||
for name, inst_voc in vocab_instances.items(): | ||||||
if inst_voc is not vocab: | ||||||
warnings.warn(Warnings.W125.format(name=name)) | ||||||
return vocab | ||||||
else: | ||||||
resolved_vocab = None | ||||||
for name, inst_voc in vocab_instances.items(): | ||||||
if resolved_vocab is None: | ||||||
resolved_vocab = inst_voc | ||||||
elif inst_voc is not resolved_vocab: | ||||||
warnings.warn(Warnings.W125.format(name=name)) | ||||||
# This is supposed to only be for the type checker -- | ||||||
# it should be unreachable | ||||||
assert resolved_vocab is not None | ||||||
return resolved_vocab | ||||||
honnibal marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
def _find_missing_components( | ||||||
pipeline: Dict[str, Dict[str, Any]], | ||||||
pipe_instances: Dict[str, Any], | ||||||
exclude: Iterable[str], | ||||||
) -> List[str]: | ||||||
missing = [] | ||||||
for name, config in pipeline.items(): | ||||||
if ( | ||||||
config.get("factory") == INSTANCE_FACTORY_NAME | ||||||
and name not in pipe_instances | ||||||
and name not in exclude | ||||||
): | ||||||
missing.append(name) | ||||||
return missing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cf the type we're using in
add_pipe_instance()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment for the other methods, too.