[WIP] Add Query Optimizer that selects fields from query AST #56

Draft · wants to merge 10 commits into main
4 changes: 2 additions & 2 deletions example/db.sqlite3
Git LFS file not shown
5 changes: 5 additions & 0 deletions example/example/settings/base.py
@@ -54,6 +54,7 @@
"grapple",
"graphene_django",
"channels",
"django_extensions",
]

MIDDLEWARE = [
@@ -178,3 +179,7 @@
"ROUTING": "grapple.urls.channel_routing",
}
}

# Query Optimisation helpers
SHELL_PLUS_PRINT_SQL = True
Member review comment: these should be enabled only locally, maybe?

RUNSERVER_PLUS_PRINT_SQL_TRUNCATE = 100000
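One way to address the review comment above is to keep these helpers out of the base settings. A minimal sketch, assuming the example project adds a separate local-only settings module (the example/example/settings/dev.py path is hypothetical, not part of this diff):

# example/example/settings/dev.py (hypothetical local-only settings module)
from .base import *  # noqa: F401,F403

# django-extensions query-debugging helpers, enabled only for local development
SHELL_PLUS_PRINT_SQL = True
RUNSERVER_PLUS_PRINT_SQL_TRUNCATE = 100000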
2 changes: 2 additions & 0 deletions example/home/models.py
@@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.db import models
from modelcluster.fields import ParentalKey

@@ -6,6 +7,7 @@
from wagtail.contrib.settings.models import BaseSetting, register_setting

from wagtail.core import blocks
from wagtail.core.fields import RichTextField, StreamField
from wagtail.admin.edit_handlers import FieldPanel, StreamFieldPanel, InlinePanel
from wagtail.images.blocks import ImageChooserBlock
from wagtail.documents.blocks import DocumentChooserBlock
14 changes: 11 additions & 3 deletions grapple/apps.py
@@ -5,13 +5,21 @@ class Grapple(AppConfig):
name = "grapple"

def ready(self):
"""
Import all the django apps defined in django settings then process each model
in these apps and create graphql node types from them.
"""
Import all the django apps defined in django settings then process each model
in these apps and create graphql node types from them.
"""
from .actions import import_apps, load_type_fields
from .types.streamfield import register_streamfield_blocks

self.preload_tasks()
import_apps()
load_type_fields()
register_streamfield_blocks()

def preload_tasks(self):
# Monkeypatch Wagtail's PageQuerySet.specific method with a more optimised one
from wagtail.core.query import PageQuerySet
from .db.query import specific

PageQuerySet.specific = specific
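A quick way to confirm the patch has been applied, as a minimal sketch assuming a Django shell with this app installed:

from wagtail.core.query import PageQuerySet
from grapple.db.query import specific

# Once Grapple.ready() has run, the class attribute points at the patched function.
assert PageQuerySet.specific is specific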
Empty file added grapple/db/__init__.py
Empty file.
243 changes: 243 additions & 0 deletions grapple/db/optimizer.py
@@ -0,0 +1,243 @@
import re
from collections.abc import Iterable
from django.db.models.query_utils import DeferredAttribute
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ForwardOneToOneDescriptor,
ReverseManyToOneDescriptor,
ReverseOneToOneDescriptor,
)
from graphene.types.definitions import GrapheneInterfaceType
from graphql.language.ast import (
Field,
InlineFragment,
FragmentSpread,
InterfaceTypeDefinition,
)

from modelcluster.fields import ParentalKey
from django.db.models.fields.related import ForeignKey

pascal_to_snake = re.compile(r"(?<!^)(?=[A-Z])")


class QueryOptimizer:
# Fields extracted from AST that the request wants
qs = None
schema = None
model = None
model_type_map = {}
query_fields = []

def __init__(self):
# Queryset specific lists (for this query)
self.only_fields = []
self.select_related_fields = []
self.prefetch_related_fields = []
# Future Optimisation maps
self.only_field_types = {}
self.select_related_types = {}
self.prefetch_related_types = {}

@staticmethod
def query(qs, info):
# Create new optimiser instance.
qs_optimizer = QueryOptimizer()
qs_optimizer.qs = qs
qs_optimizer.schema = info.schema
qs_optimizer.model = qs.model
# Extract what fields the user wants from the AST.
fields, types = AstExplorer(qs, info).parse_ast()
qs_optimizer.query_fields = fields
qs_optimizer.model_type_map = types
# Sort the desired fields by how we will preload them.
qs_optimizer.sort_fields()
# Add deferred fields to query.
qs_optimizer.optimise_fields()
# return the new optimised query
return qs_optimizer.qs

# Sort the requested fields according to their relation to the model.
def sort_fields(self):
for field_name in self.query_fields:
# Make sure the field name is snake_case, not camelCase (graphene exposes fields in camelCase)
field_name = pascal_to_snake.sub("_", field_name).lower()

# Support simple fields defined directly on the model
field = getattr(self.model, field_name, None)
if field:
self.select_field(field, field_name)
continue

# Support more complex sub-selectable fields
field_name_prefix, field_name = field_name.split("__", 1)
nested_field = getattr(self.model, field_name_prefix)
if nested_field:
# Add to query to save recomputing down the line
self.select_field(nested_field, field_name, field_name_prefix)
continue

def select_field(self, field, field_name, field_name_prefix=None):
model = self.model_type_map.get(field_name_prefix, None)
# If this field links to another model, try to parse its nested fields
if isinstance(field, ReverseOneToOneDescriptor):
# Recurse into nested fields so that specific page types are supported
nested_field_name = field_name.split("__")[0]
nested_field = getattr(model, nested_field_name, None)
if hasattr(nested_field, "field"):
return self.select_field(
nested_field.field, nested_field_name, field_name_prefix
)
else:
return self.select_field(
nested_field, nested_field_name, field_name_prefix
)

# Properties and functions can't be selected, so add "id" to ensure the query always runs
if hasattr(model, "id"):
self.only_fields.append("id")

if not getattr(field, "is_relation", False):
if model:
# Cache selection for future optimisation (query.py)
existing_fields = self.only_field_types.get(field_name_prefix, [])
self.only_field_types[field_name_prefix] = [
field_name,
*existing_fields,
]
else:
self.only_fields.append(field_name)

elif field.one_to_many or field.many_to_many or isinstance(field, ParentalKey):
if model:
# Cache selection for future optimisation (query.py)
existing_fields = self.prefetch_related_types.get(field_name_prefix, [])
self.prefetch_related_types[field_name_prefix] = [
field_name,
*existing_fields,
]
else:
self.prefetch_related_fields.append(field_name)

elif field.many_to_one or field.one_to_one:
if model:
# Cache selection for future optimisation (query.py)
existing_fields = self.select_related_types.get(field_name_prefix, [])
self.select_related_types[field_name_prefix] = [
field_name,
*existing_fields,
]
else:
self.only_fields.append(field_name)
self.select_related_fields.append(field_name)

# Apply the sorted fields to the queryset
def optimise_fields(self):
self.qs = self.qs.only(*self.only_fields)
self.qs = self.qs.select_related(*self.select_related_fields)
self.qs = self.qs.prefetch_related(*self.prefetch_related_fields)

# Add custom lists to query for use in Specific page optimizer.
setattr(self.qs.query, "only_field_types", self.only_field_types)
setattr(self.qs.query, "select_related_types", self.select_related_types)
setattr(self.qs.query, "prefetch_related_types", self.prefetch_related_types)


class AstExplorer:
schema = None
fragments = {}
model_type_map = {}
resolve_info = None

def __init__(self, qs, info):
self.schema = info.schema
self.resolve_info = info
self.fragments = info.fragments

# Parse AST to find fields
def parse_ast(self):
pages_interface = self.get_pages_interface()
return self.parse_field(pages_interface, None), self.model_type_map

# Return the page/pages field; we only optimise those for now
def get_pages_interface(self):
for field in self.resolve_info.field_asts:
if field.name.value == "page":
return field
if field.name.value == "pages":
return field

def parse_field(self, field, field_prefix):
field_name = field.name.value

# If field has subset fields
if field.selection_set:
field_prefix = field_prefix + "__" + field_name if field_prefix else ""
return self.parse_selection_set(field.selection_set, field_prefix)

# Prefix this field name with that of its parent
if field_prefix:
return field_prefix + "__" + field.name.value

# Return fields name
return field_name

def parse_selection_set(self, selection_set, field_prefix):
selections = []
if selection_set.selections:
for selection in selection_set.selections:
selection = self.parse_selection(selection, field_prefix)
if isinstance(selection, list):
selections.extend(selection)
else:
selections.append(selection)

return selections

def parse_selection(self, selection, field_prefix):
if isinstance(selection, InlineFragment):
return self.parse_inline_fragment(selection, field_prefix)

if isinstance(selection, FragmentSpread):
return self.parse_fragment_spread(selection, field_prefix)

if isinstance(selection, Field):
return self.parse_field(selection, field_prefix)

def parse_fragment_spread(self, fragment_spread, field_prefix):
# Get fragment name.
fragment_name = fragment_spread.name.value

# Get fragment type from name and return it parsed.
fragment_type = self.fragments[fragment_name]
return self.parse_inline_fragment(fragment_type, field_prefix)

def parse_inline_fragment(self, inline_fragment, field_prefix):
# Get type of inline fragment
gql_type_name = inline_fragment.type_condition.name.value
gql_type = self.schema.get_type(gql_type_name)

# Record which Django model this corresponds to
type_prefix = gql_type_name.lower()
self.model_type_map[type_prefix] = getattr(
gql_type.graphene_type._meta, "model", None
)

# Helper that prefixes field names with their parent prefix (skipped for interface types)
def prefix_type(field):
# Don't prefix if the type is an interface
if not isinstance(gql_type, GrapheneInterfaceType) and field_prefix:
return field_prefix + "__" + field

return field

# Get fields of inline fragment
selections = []
if inline_fragment.selection_set:
selections = self.parse_selection_set(
inline_fragment.selection_set, type_prefix
)
selections = list(map(prefix_type, selections))

return selections
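For context, a minimal sketch of how a resolver could hand its queryset to the optimizer; the resolve_pages function below is hypothetical and not part of this diff:

from wagtail.core.models import Page
from grapple.db.optimizer import QueryOptimizer


def resolve_pages(root, info, **kwargs):
    # The optimizer walks info.field_asts / info.fragments and applies
    # .only(), .select_related() and .prefetch_related() to the queryset.
    return QueryOptimizer.query(Page.objects.live().public(), info)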
88 changes: 88 additions & 0 deletions grapple/db/query.py
@@ -0,0 +1,88 @@
from collections import defaultdict

from wagtail.core.query import DeferredSpecificIterable
from django.contrib.contenttypes.models import ContentType
from django.db.models.query import BaseIterable


def specific(self, defer=False, *args, **kwargs):
"""
This efficiently gets all the specific pages for the queryset, using
the minimum number of queries.

When the "defer" keyword argument is set to True, only the basic page
fields will be loaded and all specific fields will be deferred. It
will still generate a query for each page type though (this may be
improved to generate only a single query in a future release).
"""
clone = self._clone()
if defer:
clone._iterable_class = GrappleSpecificIterable
else:
clone._iterable_class = DeferredSpecificIterable

return clone


class GrappleSpecificIterable(BaseIterable):
def __iter__(self):
return specific_iterator(self.queryset, defer=True)


# Custom version of code from https://github.com/wagtail/wagtail/blob/master/wagtail/core/query.py#L363
def specific_iterator(qs, defer=True):
"""
This efficiently iterates all the specific pages in a queryset, using
the minimum number of queries.

This should be called from ``PageQuerySet.specific``
"""
pks_and_types = qs.values_list("pk", "content_type")
pks_by_type = defaultdict(list)
for pk, content_type in pks_and_types:
pks_by_type[content_type].append(pk)

# Content types are cached by ID, so this will not run any queries.
content_types = {pk: ContentType.objects.get_for_id(pk) for _, pk in pks_and_types}

# Get the specific instances of all pages, one model class at a time.
pages_by_type = {}
for content_type, pks in pks_by_type.items():
# look up model class for this content type, falling back on the original
# model (i.e. Page) if the more specific one is missing
specific_model = content_types[content_type].model_class() or qs.model
specific_model_name = specific_model.__name__.lower()

# Get deferred fields (.only/.defer)
only_fields, _ = qs.query.deferred_loading
only_fields_specific = getattr(qs.query, "only_field_types", {}).get(
specific_model_name, []
)
select_related_fields = getattr(qs.query, "select_related_types", {}).get(
specific_model_name, []
)
prefetch_related_fields = getattr(qs.query, "prefetch_related_types", {}).get(
specific_model_name, []
)

# If no fields for this model were requested, skip the specific query
if not only_fields:
pages_by_type[content_type] = None
continue

# Query pages
pages = specific_model.objects.filter(pk__in=pks)
# Defer all fields apart from those required
pages = pages.only(*only_fields, *only_fields_specific)
# Apply select_related fields (passed down from optimizer.py)
pages = pages.select_related(*select_related_fields)
# Apply prefetch_related fields (passed down from optimizer.py)
pages = pages.prefetch_related(*prefetch_related_fields)

# Replace specific models in same sort order
pages_by_type[content_type] = {page.pk: page for page in pages}

# Yield all of the pages (specific + generic), in the order they occurred in the original query.
for pk, content_type in pks_and_types:
if pages_by_type[content_type]:
yield pages_by_type[content_type][pk]
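To make the hand-off between optimizer.py and query.py concrete, a hypothetical illustration of the state that optimise_fields() attaches to qs.query and that specific_iterator() reads back per specific model name (the model and field names shown are illustrative only):

# qs is a PageQuerySet that has already been through QueryOptimizer.query(qs, info)
qs.query.deferred_loading        # (frozenset({"id", "title"}), False) -> base .only() fields
qs.query.only_field_types        # {"blogpage": ["date", "id"]}
qs.query.select_related_types    # {"blogpage": ["author"]}
qs.query.prefetch_related_types  # {"blogpage": ["related_links"]}

# .specific(defer=True) then swaps in GrappleSpecificIterable, which issues one
# batched query per specific page type with those restrictions applied.
for page in qs.specific(defer=True):
    print(page.title)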