diff --git a/tests/error/test_graphql_error.py b/tests/error/test_graphql_error.py index c7db5d13..a206f673 100644 --- a/tests/error/test_graphql_error.py +++ b/tests/error/test_graphql_error.py @@ -4,7 +4,7 @@ from graphql.error import GraphQLError from graphql.language import ( - Node, + NameNode, ObjectTypeDefinitionNode, OperationDefinitionNode, Source, @@ -352,7 +352,7 @@ def formats_graphql_error(): extensions = {"ext": None} error = GraphQLError( "test message", - Node(), + NameNode(value="stub"), Source( """ query { diff --git a/tests/language/test_ast.py b/tests/language/test_ast.py index b0d965a1..9c1f5c84 100644 --- a/tests/language/test_ast.py +++ b/tests/language/test_ast.py @@ -10,15 +10,22 @@ class SampleTestNode(Node): __slots__ = "alpha", "beta" - alpha: int - beta: int + alpha: int | Node # Union with Node to support copy tests with nested nodes + beta: int | Node | None class SampleNamedNode(Node): __slots__ = "foo", "name" foo: str - name: str | None + name: NameNode | None + + +def make_loc(start: int = 1, end: int = 3) -> Location: + """Create a Location for testing with the given start/end offsets.""" + source = Source("test source") + start_token = Token(TokenKind.NAME, start, end, 1, start, "test") + return Location(start_token, start_token, source) def describe_token_class(): @@ -150,15 +157,21 @@ def can_hash(): def describe_node_class(): def initializes_with_keywords(): - node = SampleTestNode(alpha=1, beta=2, loc=0) + node = SampleTestNode(alpha=1, beta=2) assert node.alpha == 1 assert node.beta == 2 - assert node.loc == 0 - node = SampleTestNode(alpha=1, loc=None) assert node.loc is None + + def initializes_with_location(): + loc = make_loc() + node = SampleTestNode(alpha=1, beta=2, loc=loc) assert node.alpha == 1 - assert node.beta is None - node = SampleTestNode(alpha=1, beta=2, gamma=3) + assert node.beta == 2 + assert node.loc is loc + + def initializes_with_none_location(): + node = SampleTestNode(alpha=1, beta=2, loc=None) + assert node.loc is None assert node.alpha == 1 assert node.beta == 2 assert not hasattr(node, "gamma") @@ -174,27 +187,31 @@ def converts_list_to_tuple_on_init(): def has_representation_with_loc(): node = SampleTestNode(alpha=1, beta=2) assert repr(node) == "SampleTestNode" - node = SampleTestNode(alpha=1, beta=2, loc=3) - assert repr(node) == "SampleTestNode at 3" + loc = make_loc(start=3, end=5) + node = SampleTestNode(alpha=1, beta=2, loc=loc) + assert repr(node) == "SampleTestNode at 3:5" def has_representation_when_named(): name_node = NameNode(value="baz") node = SampleNamedNode(foo="bar", name=name_node) assert repr(node) == "SampleNamedNode(name='baz')" - node = SampleNamedNode(alpha=1, beta=2, name=name_node, loc=3) - assert repr(node) == "SampleNamedNode(name='baz') at 3" + loc = make_loc(start=3, end=5) + node = SampleNamedNode(foo="bar", name=name_node, loc=loc) + assert repr(node) == "SampleNamedNode(name='baz') at 3:5" def has_representation_when_named_but_name_is_none(): - node = SampleNamedNode(alpha=1, beta=2, name=None) + node = SampleNamedNode(foo="bar", name=None) assert repr(node) == "SampleNamedNode" - node = SampleNamedNode(alpha=1, beta=2, name=None, loc=3) - assert repr(node) == "SampleNamedNode at 3" + loc = make_loc(start=3, end=5) + node = SampleNamedNode(foo="bar", name=None, loc=loc) + assert repr(node) == "SampleNamedNode at 3:5" def has_special_representation_when_it_is_a_name_node(): node = NameNode(value="foo") assert repr(node) == "NameNode('foo')" - node = NameNode(value="foo", loc=3) - assert repr(node) == "NameNode('foo') at 3" + loc = make_loc(start=3, end=5) + node = NameNode(value="foo", loc=loc) + assert repr(node) == "NameNode('foo') at 3:5" def can_check_equality(): node = SampleTestNode(alpha=1, beta=2) diff --git a/tests/language/test_schema_parser.py b/tests/language/test_schema_parser.py index 3a0e6301..fd410c40 100644 --- a/tests/language/test_schema_parser.py +++ b/tests/language/test_schema_parser.py @@ -3,7 +3,6 @@ import pickle from copy import deepcopy from textwrap import dedent -from typing import Optional, Tuple import pytest @@ -11,6 +10,8 @@ from graphql.language import ( ArgumentNode, BooleanValueNode, + ConstDirectiveNode, + ConstValueNode, DirectiveDefinitionNode, DirectiveNode, DocumentNode, @@ -22,6 +23,7 @@ InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode, ListTypeNode, + Location, NamedTypeNode, NameNode, NonNullTypeNode, @@ -32,25 +34,30 @@ ScalarTypeDefinitionNode, SchemaDefinitionNode, SchemaExtensionNode, + Source, StringValueNode, + Token, + TokenKind, TypeNode, UnionTypeDefinitionNode, - ValueNode, parse, ) from ..fixtures import kitchen_sink_sdl # noqa: F401 -try: - from typing import TypeAlias -except ImportError: # Python < 3.10 - from typing_extensions import TypeAlias - -Location: TypeAlias = Optional[Tuple[int, int]] +def make_loc(position: tuple[int, int]) -> Location: + """Create a Location for testing with the given (start, end) offsets.""" + source = Source(body="") + token = Token( + kind=TokenKind.NAME, start=position[0], end=position[1], line=1, column=1 + ) + return Location(start_token=token, end_token=token, source=source) -def assert_syntax_error(text: str, message: str, location: Location) -> None: +def assert_syntax_error( + text: str, message: str, location: tuple[int, int] | None +) -> None: with pytest.raises(GraphQLSyntaxError) as exc_info: parse(text) error = exc_info.value @@ -59,85 +66,104 @@ def assert_syntax_error(text: str, message: str, location: Location) -> None: assert error.locations == [location] -def assert_definitions(body: str, loc: Location, num=1): +def assert_definitions(body: str, position: tuple[int, int] | None, num: int = 1): doc = parse(body) assert isinstance(doc, DocumentNode) - assert doc.loc == loc + assert doc.loc == position definitions = doc.definitions assert isinstance(definitions, tuple) assert len(definitions) == num return definitions[0] if num == 1 else definitions -def type_node(name: str, loc: Location): - return NamedTypeNode(name=name_node(name, loc), loc=loc) +def type_node(name: str, position: tuple[int, int]): + return NamedTypeNode(name=name_node(name, position), loc=make_loc(position)) -def name_node(name: str, loc: Location): - return NameNode(value=name, loc=loc) +def name_node(name: str, position: tuple[int, int]): + return NameNode(value=name, loc=make_loc(position)) -def field_node(name: NameNode, type_: TypeNode, loc: Location): - return field_node_with_args(name, type_, (), loc) +def field_node(name: NameNode, type_: TypeNode, position: tuple[int, int]): + return field_node_with_args(name, type_, (), position) -def field_node_with_args(name: NameNode, type_: TypeNode, args: tuple, loc: Location): +def field_node_with_args( + name: NameNode, type_: TypeNode, args: tuple, position: tuple[int, int] +): return FieldDefinitionNode( - name=name, arguments=args, type=type_, directives=(), loc=loc, description=None + name=name, + arguments=args, + type=type_, + directives=(), + loc=make_loc(position), + description=None, ) -def non_null_type(type_: TypeNode, loc: Location): - return NonNullTypeNode(type=type_, loc=loc) +def non_null_type(type_: NamedTypeNode | ListTypeNode, position: tuple[int, int]): + return NonNullTypeNode(type=type_, loc=make_loc(position)) -def enum_value_node(name: str, loc: Location): +def enum_value_node(name: str, position: tuple[int, int]): return EnumValueDefinitionNode( - name=name_node(name, loc), directives=(), loc=loc, description=None + name=name_node(name, position), + directives=(), + loc=make_loc(position), + description=None, ) def input_value_node( - name: NameNode, type_: TypeNode, default_value: ValueNode | None, loc: Location + name: NameNode, + type_: TypeNode, + default_value: ConstValueNode | None, + position: tuple[int, int], ): return InputValueDefinitionNode( name=name, type=type_, default_value=default_value, directives=(), - loc=loc, + loc=make_loc(position), description=None, ) -def boolean_value_node(value: bool, loc: Location): - return BooleanValueNode(value=value, loc=loc) +def boolean_value_node(value: bool, position: tuple[int, int]): + return BooleanValueNode(value=value, loc=make_loc(position)) -def string_value_node(value: str, block: bool | None, loc: Location): - return StringValueNode(value=value, block=block, loc=loc) +def string_value_node(value: str, block: bool | None, position: tuple[int, int]): + return StringValueNode(value=value, block=block, loc=make_loc(position)) -def list_type_node(type_: TypeNode, loc: Location): - return ListTypeNode(type=type_, loc=loc) +def list_type_node(type_: TypeNode, position: tuple[int, int]): + return ListTypeNode(type=type_, loc=make_loc(position)) def schema_extension_node( - directives: tuple[DirectiveNode, ...], + directives: tuple[ConstDirectiveNode, ...], operation_types: tuple[OperationTypeDefinitionNode, ...], - loc: Location, + position: tuple[int, int], ): return SchemaExtensionNode( - directives=directives, operation_types=operation_types, loc=loc + directives=directives, operation_types=operation_types, loc=make_loc(position) ) -def operation_type_definition(operation: OperationType, type_: TypeNode, loc: Location): - return OperationTypeDefinitionNode(operation=operation, type=type_, loc=loc) +def operation_type_definition( + operation: OperationType, type_: NamedTypeNode, position: tuple[int, int] +): + return OperationTypeDefinitionNode( + operation=operation, type=type_, loc=make_loc(position) + ) -def directive_node(name: NameNode, arguments: tuple[ArgumentNode, ...], loc: Location): - return DirectiveNode(name=name, arguments=arguments, loc=loc) +def directive_node( + name: NameNode, arguments: tuple[ArgumentNode, ...], position: tuple[int, int] +): + return DirectiveNode(name=name, arguments=arguments, loc=make_loc(position)) def describe_schema_parser(): diff --git a/tests/language/test_visitor.py b/tests/language/test_visitor.py index 2a8c2bab..b373dbfd 100644 --- a/tests/language/test_visitor.py +++ b/tests/language/test_visitor.py @@ -308,7 +308,7 @@ def allows_editing_a_node_both_on_enter_and_on_leave(): visited = [] class TestVisitor(Visitor): - selection_set = None + selection_set: SelectionSetNode | None = None def enter_operation_definition(self, *args): check_visitor_fn_args(ast, *args) @@ -330,6 +330,7 @@ def leave_operation_definition(self, *args): check_visitor_fn_args_edited(ast, *args) node = args[0] assert not node.selection_set.selections + assert self.selection_set is not None # Create new node with original selection set (immutable pattern) new_node = OperationDefinitionNode( operation=node.operation, diff --git a/tests/type/test_definition.py b/tests/type/test_definition.py index 8b93fe54..40e96867 100644 --- a/tests/type/test_definition.py +++ b/tests/type/test_definition.py @@ -25,11 +25,15 @@ InputValueDefinitionNode, InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode, + NamedTypeNode, + NameNode, ObjectTypeDefinitionNode, ObjectTypeExtensionNode, OperationDefinitionNode, + OperationType, ScalarTypeDefinitionNode, ScalarTypeExtensionNode, + SelectionSetNode, StringValueNode, UnionTypeDefinitionNode, UnionTypeExtensionNode, @@ -63,6 +67,16 @@ except ImportError: # Python < 3.10 from typing_extensions import TypeGuard + +# Helper functions to create stub AST nodes with required fields +def _stub_name(name: str = "Stub") -> NameNode: + return NameNode(value=name) + + +def _stub_type() -> NamedTypeNode: + return NamedTypeNode(name=_stub_name("StubType")) + + ScalarType = GraphQLScalarType("Scalar") ObjectType = GraphQLObjectType("Object", {}) InterfaceType = GraphQLInterfaceType("Interface", {}) @@ -165,8 +179,8 @@ def use_parse_value_for_parsing_literals_if_parse_literal_omitted(): ) def accepts_a_scalar_type_with_ast_node_and_extension_ast_nodes(): - ast_node = ScalarTypeDefinitionNode() - extension_ast_nodes = [ScalarTypeExtensionNode()] + ast_node = ScalarTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [ScalarTypeExtensionNode(name=_stub_name())] scalar = GraphQLScalarType( "SomeScalar", ast_node=ast_node, extension_ast_nodes=extension_ast_nodes ) @@ -435,8 +449,8 @@ def accepts_a_lambda_as_an_object_field_resolver(): assert obj_type.fields def accepts_an_object_type_with_ast_node_and_extension_ast_nodes(): - ast_node = ObjectTypeDefinitionNode() - extension_ast_nodes = [ObjectTypeExtensionNode()] + ast_node = ObjectTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [ObjectTypeExtensionNode(name=_stub_name())] object_type = GraphQLObjectType( "SomeObject", {"f": GraphQLField(ScalarType)}, @@ -601,8 +615,8 @@ def interfaces(): assert calls == 1 def accepts_an_interface_type_with_ast_node_and_extension_ast_nodes(): - ast_node = InterfaceTypeDefinitionNode() - extension_ast_nodes = [InterfaceTypeExtensionNode()] + ast_node = InterfaceTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [InterfaceTypeExtensionNode(name=_stub_name())] interface_type = GraphQLInterfaceType( "SomeInterface", {"f": GraphQLField(ScalarType)}, @@ -667,8 +681,8 @@ def accepts_a_union_type_without_types(): assert union_type.types == () def accepts_a_union_type_with_ast_node_and_extension_ast_nodes(): - ast_node = UnionTypeDefinitionNode() - extension_ast_nodes = [UnionTypeExtensionNode()] + ast_node = UnionTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [UnionTypeExtensionNode(name=_stub_name())] union_type = GraphQLUnionType( "SomeUnion", [ObjectType], @@ -894,8 +908,8 @@ def parses_an_enum(): ) def accepts_an_enum_type_with_ast_node_and_extension_ast_nodes(): - ast_node = EnumTypeDefinitionNode() - extension_ast_nodes = [EnumTypeExtensionNode()] + ast_node = EnumTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [EnumTypeExtensionNode(name=_stub_name())] enum_type = GraphQLEnumType( "SomeEnum", {}, @@ -1010,8 +1024,8 @@ def provides_default_out_type_if_omitted(): assert input_obj_type.to_kwargs()["out_type"] is None def accepts_an_input_object_type_with_ast_node_and_extension_ast_nodes(): - ast_node = InputObjectTypeDefinitionNode() - extension_ast_nodes = [InputObjectTypeExtensionNode()] + ast_node = InputObjectTypeDefinitionNode(name=_stub_name()) + extension_ast_nodes = [InputObjectTypeExtensionNode(name=_stub_name())] input_obj_type = GraphQLInputObjectType( "SomeInputObject", {}, @@ -1126,7 +1140,7 @@ def provides_no_out_name_if_omitted(): assert argument.to_kwargs()["out_name"] is None def accepts_an_argument_with_an_ast_node(): - ast_node = InputValueDefinitionNode() + ast_node = InputValueDefinitionNode(name=_stub_name(), type=_stub_type()) argument = GraphQLArgument(GraphQLString, ast_node=ast_node) assert argument.ast_node is ast_node assert argument.to_kwargs()["ast_node"] is ast_node @@ -1157,7 +1171,7 @@ def provides_no_out_name_if_omitted(): assert input_field.to_kwargs()["out_name"] is None def accepts_an_input_field_with_an_ast_node(): - ast_node = InputValueDefinitionNode() + ast_node = InputValueDefinitionNode(name=_stub_name(), type=_stub_type()) input_field = GraphQLArgument(GraphQLString, ast_node=ast_node) assert input_field.ast_node is ast_node assert input_field.to_kwargs()["ast_node"] is ast_node @@ -1299,7 +1313,9 @@ class InfoArgs(TypedDict): "schema": GraphQLSchema(), "fragments": {}, "root_value": None, - "operation": OperationDefinitionNode(), + "operation": OperationDefinitionNode( + operation=OperationType.QUERY, selection_set=SelectionSetNode() + ), "variable_values": {}, "is_awaitable": is_awaitable, } diff --git a/tests/type/test_directives.py b/tests/type/test_directives.py index 0da2a4c7..5e4bfffb 100644 --- a/tests/type/test_directives.py +++ b/tests/type/test_directives.py @@ -1,14 +1,18 @@ import pytest from graphql.error import GraphQLError -from graphql.language import DirectiveDefinitionNode, DirectiveLocation +from graphql.language import DirectiveDefinitionNode, DirectiveLocation, NameNode from graphql.type import GraphQLArgument, GraphQLDirective, GraphQLInt, GraphQLString def describe_type_system_directive(): def can_create_instance(): arg = GraphQLArgument(GraphQLString, description="arg description") - node = DirectiveDefinitionNode() + node = DirectiveDefinitionNode( + name=NameNode(value="test"), + repeatable=False, + locations=(), + ) locations = [DirectiveLocation.SCHEMA, DirectiveLocation.OBJECT] directive = GraphQLDirective( name="test", diff --git a/tests/type/test_schema.py b/tests/type/test_schema.py index 7c673a1e..6f69f701 100644 --- a/tests/type/test_schema.py +++ b/tests/type/test_schema.py @@ -425,7 +425,7 @@ def configures_the_schema_to_have_no_errors(): def describe_ast_nodes(): def accepts_a_scalar_type_with_ast_node_and_extension_ast_nodes(): - ast_node = SchemaDefinitionNode() + ast_node = SchemaDefinitionNode(operation_types=()) extension_ast_nodes = [SchemaExtensionNode()] schema = GraphQLSchema( GraphQLObjectType("Query", {}),