Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/schema #17

Closed
wants to merge 8 commits into from
1 change: 1 addition & 0 deletions graphene_federation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .inaccessible import inaccessible
from .provides import provides
from .override import override
from .compose_directive import mark_composable, is_composable
64 changes: 64 additions & 0 deletions graphene_federation/compose_directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Optional

from graphql import GraphQLDirective


def is_composable(directive: GraphQLDirective) -> bool:
"""
Checks if the directive will be composed to supergraph.
Validates the presence of _compose_import_url attribute
"""
return hasattr(directive, "_compose_import_url")


def mark_composable(
directive: GraphQLDirective, import_url: str, import_as: Optional[str] = None
) -> GraphQLDirective:
"""
Marks directive with _compose_import_url and _compose_import_as
Enables Identification of directives which are to be composed to supergraph
"""
setattr(directive, "_compose_import_url", import_url)
if import_as:
setattr(directive, "_compose_import_as", import_as)
return directive


def compose_directive_schema_extensions(directives: list[GraphQLDirective]):
"""
Generates schema extends string for ComposeDirective
"""
link_schema = ""
compose_directive_schema = ""
# Using dictionary to generate cleaner schema when multiple directives imports from same URL.
links: dict = {}

for directive in directives:
# TODO: Replace with walrus operator when dropping Python 3.8 support
if hasattr(directive, "_compose_import_url"):
compose_import_url = getattr(directive, "_compose_import_url")
if hasattr(directive, "_compose_import_as"):
compose_import_as = getattr(directive, "_compose_import_as")
import_value = (
f'{{ name: "@{directive.name}, as: "@{compose_import_as}" }}'
)
imported_name = compose_import_as
else:
import_value = f'"@{directive.name}"'
imported_name = directive.name

import_url = compose_import_url

if links.get(import_url):
links[import_url] = links[import_url].append(import_value)
else:
links[import_url] = [import_value]

compose_directive_schema += (
f' @composeDirective(name: "@{imported_name}")\n'
)

for import_url in links:
link_schema += f' @link(url: "{import_url}", import: [{",".join(value for value in links[import_url])}])\n'

return link_schema + compose_directive_schema
18 changes: 11 additions & 7 deletions graphene_federation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ def _get_query(schema: Schema, query_cls: Optional[ObjectType] = None) -> Object
def build_schema(
query: Optional[ObjectType] = None,
mutation: Optional[ObjectType] = None,
enable_federation_2=False,
federation_version: Optional[float] = None,
enable_federation_2: bool = False,
schema: Optional[Schema] = None,
**kwargs
) -> Schema:
schema = schema or Schema(query=query, mutation=mutation, **kwargs)
schema.auto_camelcase = kwargs.get("auto_camelcase", True)
schema.federation_version = 2 if enable_federation_2 else 1
federation_query = _get_query(schema, schema.query if schema else query)
return Schema(
query=federation_query,
mutation=schema.mutation if schema else mutation,
**kwargs
schema.federation_version = float(
(federation_version or 2) if (enable_federation_2 or federation_version) else 1
)
federation_query = _get_query(schema, schema.query)
# Use shallow copy to prevent recursion error
kwargs = schema.__dict__.copy()
kwargs.pop("query")
kwargs.pop("graphql_schema")
kwargs.pop("federation_version")
return type(schema)(query=federation_query, **kwargs)
93 changes: 51 additions & 42 deletions graphene_federation/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from graphene.types.union import UnionOptions
from graphql import GraphQLInterfaceType, GraphQLObjectType

from .compose_directive import is_composable, compose_directive_schema_extensions
from .external import get_external_fields
from .inaccessible import get_inaccessible_types, get_inaccessible_fields
from .override import get_override_fields
Expand Down Expand Up @@ -120,37 +121,50 @@ def get_sdl(schema: Schema) -> str:
external_fields = get_external_fields(schema)
override_fields = get_override_fields(schema)

_schema = ""
schema_extensions = []

if schema.federation_version == 2:
if schema.federation_version >= 2:
shareable_types = get_shareable_types(schema)
inaccessible_types = get_inaccessible_types(schema)
shareable_fields = get_shareable_fields(schema)
tagged_fields = get_tagged_fields(schema)
inaccessible_fields = get_inaccessible_fields(schema)

_schema_import = []
federation_spec_import = []

if extended_types:
_schema_import.append('"@extends"')
federation_spec_import.append('"@extends"')
if external_fields:
_schema_import.append('"@external"')
federation_spec_import.append('"@external"')
if entities:
_schema_import.append('"@key"')
federation_spec_import.append('"@key"')
if override_fields:
_schema_import.append('"@override"')
federation_spec_import.append('"@override"')
if provides_parent_types or provides_fields:
_schema_import.append('"@provides"')
federation_spec_import.append('"@provides"')
if required_fields:
_schema_import.append('"@requires"')
federation_spec_import.append('"@requires"')
if inaccessible_types or inaccessible_fields:
_schema_import.append('"@inaccessible"')
federation_spec_import.append('"@inaccessible"')
if shareable_types or shareable_fields:
_schema_import.append('"@shareable"')
federation_spec_import.append('"@shareable"')
if tagged_fields:
_schema_import.append('"@tag"')
schema_import = ", ".join(_schema_import)
_schema = f'extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: [{schema_import}])\n'
federation_spec_import.append('"@tag"')

if schema.federation_version >= 2.1:
preserved_directives = [
directive for directive in schema.directives if is_composable(directive)
]
if preserved_directives:
federation_spec_import.append('"@composeDirective"')
schema_extensions.append(
compose_directive_schema_extensions(preserved_directives)
)

schema_import = ", ".join(federation_spec_import)
schema_extensions = [
f'@link(url: "https://specs.apollo.dev/federation/v{schema.federation_version}", import: [{schema_import}])'
] + schema_extensions

# Add fields directives (@external, @provides, @requires, @shareable, @inaccessible)
entities_ = (
Expand All @@ -161,7 +175,7 @@ def get_sdl(schema: Schema) -> str:
| set(provides_fields.values())
)

if schema.federation_version == 2:
if schema.federation_version >= 2:
entities_ = (
entities_
| set(shareable_types.values())
Expand All @@ -183,62 +197,57 @@ def get_sdl(schema: Schema) -> str:
# Add entity keys declarations
get_field_name = type_attribute_to_field_name(schema)
for entity_name, entity in entities.items():
type_def_re = rf"(type {entity_name} [^\{{]*)" + " "
type_def_re = rf"(type {entity_name} [^\{{]*)"

# resolvable argument of @key directive is true by default. If false, we add 'resolvable: false' to sdl.
if (
schema.federation_version == 2
schema.federation_version >= 2
and hasattr(entity, "_resolvable")
and not entity._resolvable
):
type_annotation = (
(
" ".join(
[
f'@key(fields: "{get_field_name(key)}"'
for key in entity._keys
]
)
)
+ f", resolvable: {str(entity._resolvable).lower()})"
+ " "
)
else:
type_annotation = (
" ".join(
[f'@key(fields: "{get_field_name(key)}")' for key in entity._keys]
[f'@key(fields: "{get_field_name(key)}"' for key in entity._keys]
)
) + " "
repl_str = rf"\1{type_annotation}"
) + f", resolvable: {str(entity._resolvable).lower()})"
else:
type_annotation = " ".join(
[f'@key(fields: "{get_field_name(key)}")' for key in entity._keys]
)
repl_str = rf"\1 {type_annotation} "
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

if schema.federation_version == 2:
if schema.federation_version >= 2:
for type_name, type in shareable_types.items():
# noinspection PyProtectedMember
if isinstance(type._meta, UnionOptions):
type_def_re = rf"(union {type_name})"
else:
type_def_re = rf"(type {type_name} [^\{{]*)" + " "
type_annotation = " @shareable"
repl_str = rf"\1{type_annotation} "
type_def_re = rf"(type {type_name} [^\{{]*)"
type_annotation = "@shareable"
repl_str = rf"\1 {type_annotation} "
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

for type_name, type in inaccessible_types.items():
# noinspection PyProtectedMember
if isinstance(type._meta, InterfaceOptions):
type_def_re = rf"(interface {type_name}[^\{{]*)" + " "
type_def_re = rf"(interface {type_name}[^\{{]*)"
elif isinstance(type._meta, UnionOptions):
type_def_re = rf"(union {type_name})"
else:
type_def_re = rf"(type {type_name} [^\{{]*)" + " "
type_annotation = " @inaccessible"
repl_str = rf"\1{type_annotation} "
type_def_re = rf"(type {type_name} [^\{{]*)"
type_annotation = "@inaccessible"
repl_str = rf"\1 {type_annotation} "
pattern = re.compile(type_def_re)
string_schema = pattern.sub(repl_str, string_schema)

return _schema + string_schema
if schema_extensions:
string_schema = (
"extend schema\n " + "\n ".join(schema_extensions) + "\n" + string_schema
)
return re.sub(r"[ ]+", " ", re.sub(r"\n+", "\n", string_schema)) # noqa


def get_service_query(schema: Schema):
Expand Down
Loading
Loading