From cfe2caf1a914ac34fea61f5363a7833e201c06f4 Mon Sep 17 00:00:00 2001 From: Luke Murray Date: Sat, 28 Oct 2023 16:22:42 +1100 Subject: [PATCH] first pass at a Aggregate extension --- .../Compiler/Util/LinqRuntimeTypeBuilder.cs | 10 +- .../Aggregate/AggregateExtension.cs | 216 +++++++++++++++ .../Aggregate/UseAggregateExtension.cs | 98 +++++++ .../AggregateTests/AggregateExtensionTests.cs | 245 ++++++++++++++++++ .../EntityGraphQL.Tests/TestDataContext.cs | 9 +- 5 files changed, 571 insertions(+), 7 deletions(-) create mode 100644 src/EntityGraphQL/Schema/FieldExtensions/Aggregate/AggregateExtension.cs create mode 100644 src/EntityGraphQL/Schema/FieldExtensions/Aggregate/UseAggregateExtension.cs create mode 100644 src/tests/EntityGraphQL.Tests/AggregateTests/AggregateExtensionTests.cs diff --git a/src/EntityGraphQL/Compiler/Util/LinqRuntimeTypeBuilder.cs b/src/EntityGraphQL/Compiler/Util/LinqRuntimeTypeBuilder.cs index 58ab5b75..2212f610 100644 --- a/src/EntityGraphQL/Compiler/Util/LinqRuntimeTypeBuilder.cs +++ b/src/EntityGraphQL/Compiler/Util/LinqRuntimeTypeBuilder.cs @@ -34,12 +34,12 @@ private static string GetTypeKey(Dictionary fields) /// /// /// - public static Type GetDynamicType(Dictionary fields, string description, Type? parentType = null) + public static Type GetDynamicType(Dictionary fields, string description, Type? parentType = null, Type[]? interfaces = null, Action? build = null) { if (null == fields) throw new ArgumentNullException(nameof(fields)); - - string classFullName = GetTypeKey(fields) + parentType?.Name.GetHashCode(); + + string classFullName = GetTypeKey(fields) + parentType?.Name.GetHashCode() + (interfaces != null ? string.Join("_", interfaces.Select(i => i.Name.GetHashCode())) : ""); lock (typesByFullName) { if (!typesByFullName.ContainsKey(classFullName)) @@ -51,7 +51,7 @@ public static Type GetDynamicType(Dictionary fields, string descri if (builtTypes.ContainsKey(classId)) return builtTypes[classId]; - var typeBuilder = moduleBuilder.DefineType(classId.ToString(), TypeAttributes.Public | TypeAttributes.Class | TypeAttributes.Serializable, parentType); + var typeBuilder = moduleBuilder.DefineType(classId.ToString(), TypeAttributes.Public | TypeAttributes.Class | TypeAttributes.Serializable, parentType, interfaces); foreach (var field in fields) { @@ -61,6 +61,8 @@ public static Type GetDynamicType(Dictionary fields, string descri typeBuilder.DefineField(field.Key, field.Value, FieldAttributes.Public); } + build?.Invoke(typeBuilder); + builtTypes[classId] = typeBuilder.CreateTypeInfo()!.AsType(); return builtTypes[classId]; } diff --git a/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/AggregateExtension.cs b/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/AggregateExtension.cs new file mode 100644 index 00000000..d9be0ca0 --- /dev/null +++ b/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/AggregateExtension.cs @@ -0,0 +1,216 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Reflection.Emit; +using EntityGraphQL.Compiler; +using EntityGraphQL.Compiler.Util; +using EntityGraphQL.Extensions; + +namespace EntityGraphQL.Schema.FieldExtensions; + +/// +/// Builds a field to query aggregate data on a list field. +/// +/// E.g. Given a field that returns a list of people: people: [Person!]! it will add a new field peopleAggregate: PeopleAggregate +/// similar to the following: +/// +/// schema.AddType("PeopleAggregate", "Aggregate people", type => +/// { +/// type.AddField("count", (c) => c.Count(), "Count of people"); +/// type.AddField("heightMin", (c) => c.Min(p => p.Height), "Min height"); +/// type.AddField("heightMax", (c) => c.Max(p => p.Height), "Max height"); +/// type.AddField("heightAvg", (c) => c.Average(p => p.Height), "Average height"); +/// type.AddField("heightSum", (c) => c.Sum(p => p.Height), "Sum of height"); +/// }); +/// schema.Query().AddField("peopleAggregate", ctx => ctx.People, "Aggregate people") +/// .Returns("PeopleAggregate"); +/// +/// Where PeopleAggregate is a new dotnet type used to differentiate the return type from a normal list of people. +/// +/// public class PeopleAggregate : IEnumerable +/// { +/// public IEnumerator GetEnumerator() => throw new System.NotImplementedException(); +/// IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); +/// } +/// +/// If this is at the root level of the query it would build a LINQ expression query similar to the following: +/// (MyContext ctx) => new +/// { +/// count = ctx.People.Count(), +/// minHeight = ctx.People.Min(p => p.Height), +/// maxHeight = ctx.People.Max(p => p.Height), +/// avgHeight = ctx.People.Average(p => p.Height), +/// sumHeight = ctx.People.Sum(p => p.Height) +/// }; +/// +public class AggregateExtension : BaseFieldExtension +{ + private static readonly object addAggregateTypeLock = new(); + private readonly string? fieldName; + private readonly List? aggregateFieldList; + private readonly bool fieldListIsExclude; + private static int dupeCnt = 1; + + public AggregateExtension(string? fieldName, IEnumerable? aggregateFieldList, bool excludeFields) + { + this.fieldName = fieldName; + this.aggregateFieldList = aggregateFieldList?.ToList(); + fieldListIsExclude = excludeFields; + } + + public override void Configure(ISchemaProvider schema, IField field) + { + if (field.ResolveExpression == null) + throw new EntityGraphQLCompilerException($"ConnectionPagingExtension requires a Resolve function set on the field"); + + if (!field.ResolveExpression.Type.IsEnumerableOrArray()) + throw new ArgumentException($"Expression for field {field.Name} must be a collection to use ConnectionPagingExtension. Found type {field.ReturnType.TypeDotnet}"); + + var schemaTypeName = $"{field.FromType.Name}{field.Name.FirstCharToUpper()}Aggregate"; + var fieldName = this.fieldName ?? $"{field.Name}Aggregate"; + var listElementType = field.ReturnType.TypeDotnet.GetEnumerableOrArrayType()!; + + // likely not to happen in normal use but the unit tests will hit this as it tries to create the same type multiple times + lock (addAggregateTypeLock) + { + // first build type and fetch it from the dynamic types - this will reuse the type if it already exists + var aggregateDotnetType = GetDotnetType(field, listElementType); + + // Then check if we have that in the schema already + var aggregateSchemaType = schema.HasType(aggregateDotnetType) ? schema.GetSchemaType(aggregateDotnetType, null) : null; + if (aggregateSchemaType == null) + { + // The type may be different but it may have the same name, see test TestDifferentOptionsOnSameTypeDifferentFields + if (schema.HasType(schemaTypeName)) + schemaTypeName = $"{schemaTypeName}{dupeCnt++}"; // could do this better + + // use reflection to call AddAggregateTypeToSchema + var addTypeMethod = typeof(AggregateExtension).GetMethod(nameof(AddAggregateTypeToSchema), BindingFlags.NonPublic | BindingFlags.Static); + var genericAddTypeMethod = addTypeMethod!.MakeGenericMethod(aggregateDotnetType, listElementType); + aggregateSchemaType = (genericAddTypeMethod.Invoke(this, new object[] { schema, schemaTypeName, field }) as ISchemaType)!; + + var contextParam = Expression.Parameter(aggregateSchemaType.TypeDotnet, schemaTypeName); + // set up all the fields on the aggregate type + ForEachPossibleAggregateField(field.ReturnType.SchemaType.GetFields(), + (possibleAggregateField, returnFieldType) => + { + AddAggregateFieldByReflection("Average", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType); + AddAggregateFieldByReflection("Sum", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType); + AddAggregateFieldByReflection("Min", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType); + AddAggregateFieldByReflection("Max", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType); + }, + (possibleAggregateField, returnFieldType) => + { + AddAggregateFieldByReflection("Min", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType, true); + AddAggregateFieldByReflection("Max", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType, true); + } + ); + } + else + { + schemaTypeName = aggregateSchemaType.Name; + } + } + + var newField = new Field(schema, field.FromType, fieldName, Expression.Lambda(field.ResolveExpression, field.FieldParam!), $"Aggregate data for {field.Name}", null, new GqlTypeInfo(() => schema.Type(schemaTypeName), schema.Type(schemaTypeName).TypeDotnet), null); + field.FromType.AddField(newField); + } + + private Type GetDotnetType(IField field, Type listElementType) + { + var fields = new Dictionary { + { "count", typeof(int) } + }; + ForEachPossibleAggregateField(field.ReturnType.SchemaType.GetFields(), + (possibleAggregateField, returnFieldType) => + { + fields.Add($"{possibleAggregateField.Name}Average", returnFieldType); + fields.Add($"{possibleAggregateField.Name}Sum", returnFieldType); + fields.Add($"{possibleAggregateField.Name}Min", returnFieldType); + fields.Add($"{possibleAggregateField.Name}Max", returnFieldType); + }, + (possibleAggregateField, returnFieldType) => + { + fields.Add($"{possibleAggregateField.Name}Min", returnFieldType); + fields.Add($"{possibleAggregateField.Name}Max", returnFieldType); + } + ); + var aggregateDotnetType = LinqRuntimeTypeBuilder.GetDynamicType(fields, field.Name, null, + new[] { typeof(IEnumerable<>).MakeGenericType(listElementType) }, + aggregateDotnetTypeDef => + { + // define the IEnumerable implementation + var getEnumeratorMethod = typeof(IEnumerable<>).MakeGenericType(listElementType).GetMethod("GetEnumerator"); + var getEnumeratorIL = aggregateDotnetTypeDef.DefineMethod(getEnumeratorMethod!.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, getEnumeratorMethod.ReturnType, Type.EmptyTypes).GetILGenerator(); + getEnumeratorIL.Emit(OpCodes.Newobj, typeof(NotImplementedException).GetConstructor(Type.EmptyTypes)!); + var getEnumeratorMethod2 = typeof(System.Collections.IEnumerable).GetMethod("GetEnumerator"); + var getEnumeratorIL2 = aggregateDotnetTypeDef.DefineMethod(getEnumeratorMethod2!.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, getEnumeratorMethod2.ReturnType, Type.EmptyTypes).GetILGenerator(); + getEnumeratorIL2.Emit(OpCodes.Newobj, typeof(NotImplementedException).GetConstructor(Type.EmptyTypes)!); + })!; + return aggregateDotnetType; + } + + private void ForEachPossibleAggregateField(IEnumerable fields, Action aggregateFieldActionNumeric, Action aggregateFieldTyped) + { + foreach (IField possibleAggregateField in fields) + { + if (possibleAggregateField.Name.StartsWith("__", StringComparison.InvariantCulture) + || possibleAggregateField.ResolveExpression == null) + continue; + if (aggregateFieldList != null) + { + if (fieldListIsExclude && aggregateFieldList.Contains(possibleAggregateField.Name)) + continue; + if (!fieldListIsExclude && !aggregateFieldList.Contains(possibleAggregateField.Name)) + continue; + } + + // use reflection to build aggregateSchemaType.AddField("min", (c) => c.Min(fieldExp), "Min value"); ETC + // average & sum can only be done on numeric types from Queryable.Average/Sum + var returnFieldType = possibleAggregateField.ReturnType.TypeDotnet; + if (returnFieldType == typeof(int) || returnFieldType == typeof(int?) + || returnFieldType == typeof(long) || returnFieldType == typeof(long?) + || returnFieldType == typeof(double) || returnFieldType == typeof(double?) + || returnFieldType == typeof(decimal) || returnFieldType == typeof(decimal?) + || returnFieldType == typeof(float) || returnFieldType == typeof(float?)) + { + aggregateFieldActionNumeric(possibleAggregateField, returnFieldType); + } + else if (returnFieldType == typeof(DateTimeOffset) || returnFieldType == typeof(DateTimeOffset?) + || returnFieldType == typeof(DateTime) || returnFieldType == typeof(DateTime?) + ) + { + aggregateFieldTyped(possibleAggregateField, returnFieldType); + } + } + } + + private static void AddAggregateFieldByReflection(string method, Type aggregateDotnetType, ISchemaType aggregateSchemaType, IField field, ParameterExpression contextParam, Type listElementType, bool typedReturn = false) + { + var fieldName = $"{field.Name}{method}"; + var fieldDescription = $"{method} of {field.Name}"; + var genTypes = typedReturn ? new[] { listElementType, field.ReturnType.TypeDotnet } : new[] { listElementType }; + var call = Expression.Call(typeof(Enumerable), method, genTypes, contextParam, Expression.Lambda(field.ResolveExpression!, field.FieldParam!)); + var fieldExp = Expression.Lambda(call, contextParam); + // find public Field AddField(string name, Expression> fieldSelection, string? description) + var addFieldMethod = aggregateSchemaType.GetType().GetMethods() + .SingleOrDefault(m => m.Name == nameof(ISchemaType.AddField) + && m.IsGenericMethod + && m.ReturnType == typeof(Field) + && m.GetGenericArguments().Length == 1 + && m.GetParameters().Length == 3); + var genericAddFieldMethod = addFieldMethod!.MakeGenericMethod(fieldExp.ReturnType); + genericAddFieldMethod.Invoke(aggregateSchemaType, new object[] { fieldName, fieldExp, fieldDescription }); + } + + private static ISchemaType AddAggregateTypeToSchema(ISchemaProvider schema, string schemaTypeName, IField field) where TType : class, IEnumerable + { + var aggregateType = schema.AddType(schemaTypeName, $"Aggregate {field.Name}", type => + { + type.AddField("count", (c) => c.Count(), "Count of items"); + }); + return aggregateType; + } +} \ No newline at end of file diff --git a/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/UseAggregateExtension.cs b/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/UseAggregateExtension.cs new file mode 100644 index 00000000..8d0270d7 --- /dev/null +++ b/src/EntityGraphQL/Schema/FieldExtensions/Aggregate/UseAggregateExtension.cs @@ -0,0 +1,98 @@ +using System; +using System.Collections.Generic; + +namespace EntityGraphQL.Schema.FieldExtensions; + +public static class UseAggregateExtension +{ + /// + /// If the field is a list, add a new field at the same level with the name {field}Aggregate + /// Only call on a field that returns an IEnumerable + /// + /// + /// Use this for the name of the created field. Is null the field will be called Aggregate + /// + public static IField UseAggregate(this IField field, string? fieldName = null) + { + return field.AddExtension(new AggregateExtension(fieldName, null, false)); + } + + /// + /// If the field is a list, add a new field at the same level with the name {field}Aggregate + /// Only call on a field that returns an IEnumerable + /// + /// + /// + /// + /// + /// + /// GraphQL Schema field names to include or exclude in building the aggregation fields. By default numerical and date fields + /// will have the relevant aggregation fields created. If the field is not available on the type it will be ignored. + /// Field name in GraphQL Schema are case sensitive. + /// + /// If true, the fields in fieldSelection will be excluded from the aggregate fields instead + /// + /// + public static IField UseAggregate(this IField field, IEnumerable? fieldSelection, bool excludeFields = false, string? fieldName = null) + { + return field.AddExtension(new AggregateExtension(fieldName, fieldSelection, excludeFields)); + } +} + +public class UseAggregateAttribute : ExtensionAttribute +{ + /// + /// Overrides the default name of the aggregate field. If null the field will be called Aggregate + /// + public string? FieldName { get; set; } + /// + /// If false, you will need to use [IncludeAggregateField] to include fields in the aggregate + /// + public bool AutoAddFields { get; set; } = true; + + public UseAggregateAttribute() { } + + public override void ApplyExtension(IField field) + { + var fieldList = AutoAddFields ? FindExcludedFields(field) : FindIncludedFields(field); + field.UseAggregate(fieldList, AutoAddFields, FieldName); + } + + private static IEnumerable FindIncludedFields(IField field) + { + var includedFields = new List(); + foreach (var prop in field.ReturnType.SchemaType.TypeDotnet.GetProperties()) + { + if (prop.GetCustomAttributes(typeof(IncludeAggregateFieldAttribute), true).Length > 0) + { + var (name, _) = SchemaBuilder.GetNameAndDescription(prop, field.Schema); + includedFields.Add(name); + } + } + return includedFields; + } + + private static IEnumerable FindExcludedFields(IField field) + { + var excludedFields = new List(); + foreach (var prop in field.ReturnType.SchemaType.TypeDotnet.GetProperties()) + { + if (prop.GetCustomAttributes(typeof(ExcludeAggregateFieldAttribute), true).Length > 0) + { + var (name, _) = SchemaBuilder.GetNameAndDescription(prop, field.Schema); + excludedFields.Add(name); + } + } + return excludedFields; + } +} + +[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Method, AllowMultiple = false)] +public class IncludeAggregateFieldAttribute : Attribute +{ +} + +[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Method, AllowMultiple = false)] +public class ExcludeAggregateFieldAttribute : Attribute +{ +} \ No newline at end of file diff --git a/src/tests/EntityGraphQL.Tests/AggregateTests/AggregateExtensionTests.cs b/src/tests/EntityGraphQL.Tests/AggregateTests/AggregateExtensionTests.cs new file mode 100644 index 00000000..e7bf6f96 --- /dev/null +++ b/src/tests/EntityGraphQL.Tests/AggregateTests/AggregateExtensionTests.cs @@ -0,0 +1,245 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using EntityGraphQL.Schema; +using EntityGraphQL.Schema.FieldExtensions; +using Xunit; + +namespace EntityGraphQL.Tests.AggregateExtensionTests; + +public class AggregateExtensionTests +{ + [Fact] + public void TestGetsAllAtRoot() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + data.People.Add(new Person { Height = 184 }); + data.People.Add(new Person { Height = 175 }); + data.People.Add(new Person { Height = 163 }); + data.People.Add(new Person { Height = 167 }); + + schema.Query().ReplaceField("people", ctx => ctx.People, "Return list of people") + .UseAggregate(); + + var gql = new QueryRequest + { + Query = @"{ + peopleAggregate { + count + heightMin + heightMax + heightAverage + heightSum + } + }", + }; + + var result = schema.ExecuteRequestWithContext(gql, data, null, null); + Assert.Null(result.Errors); + + dynamic peopleAggregate = result.Data["peopleAggregate"]; + Assert.Equal(4, peopleAggregate.count); + Assert.Equal(163, peopleAggregate.heightMin); + Assert.Equal(184, peopleAggregate.heightMax); + Assert.Equal(172.25, peopleAggregate.heightAverage); + Assert.Equal(689, peopleAggregate.heightSum); + } + + [Fact] + public void TestGetsAllAtNonRoot() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + data.Projects.Add(new Project + { + Name = "Project 1", + Tasks = new List + { + new Task { Name = "Task 1", HoursEstimated = 1 }, + new Task { Name = "Task 2", HoursEstimated = 2 }, + new Task { Name = "Task 3", HoursEstimated = 3 }, + new Task { Name = "Task 4", HoursEstimated = 4 }, + } + }); + + // Project.Tasks has [UseAggregate] + + var gql = new QueryRequest + { + Query = @"{ + projects { + tasksAggregate { + count + hoursEstimatedMin + hoursEstimatedMax + hoursEstimatedAverage + hoursEstimatedSum + } + } + }", + }; + + var result = schema.ExecuteRequestWithContext(gql, data, null, null); + Assert.Null(result.Errors); + + dynamic taskAggregate = ((dynamic)result.Data["projects"])[0].tasksAggregate; + Assert.Equal(4, taskAggregate.count); + Assert.Equal(1, taskAggregate.hoursEstimatedMin); + Assert.Equal(4, taskAggregate.hoursEstimatedMax); + Assert.Equal(2.5, taskAggregate.hoursEstimatedAverage); + Assert.Equal(10, taskAggregate.hoursEstimatedSum); + } + + [Fact] + public void TestRenameField() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + data.People.Add(new Person { Height = 184 }); + data.People.Add(new Person { Height = 175 }); + data.People.Add(new Person { Height = 163 }); + data.People.Add(new Person { Height = 167 }); + + schema.Query().ReplaceField("people", ctx => ctx.People, "Return list of people") + .UseAggregate("aggregatePeeps"); + + var gql = new QueryRequest + { + Query = @"{ + aggregatePeeps { + count + heightMin + heightMax + heightAverage + heightSum + } + }", + }; + + var result = schema.ExecuteRequestWithContext(gql, data, null, null); + Assert.Null(result.Errors); + + dynamic peopleAggregate = result.Data["aggregatePeeps"]; + Assert.Equal(4, peopleAggregate.count); + Assert.Equal(163, peopleAggregate.heightMin); + Assert.Equal(184, peopleAggregate.heightMax); + Assert.Equal(172.25, peopleAggregate.heightAverage); + Assert.Equal(689, peopleAggregate.heightSum); + } + + [Fact] + public void TestOnlyIncludeCertainFields() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + + schema.Query().ReplaceField("people", ctx => ctx.People, "Return list of people") + // only include height field + .UseAggregate(new string[] { "height" }); + + Assert.Empty(schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "idSum" || f.Name == "idMin" || f.Name == "idMax" || f.Name == "idAverage")); + Assert.Empty(schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "birthdayMin" || f.Name == "birthdayMax")); + Assert.Equal(4, schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "heightSum" || f.Name == "heightMin" || f.Name == "heightMax" || f.Name == "heightAverage").Count()); + } + + [Fact] + public void TestOnlyExcludeCertainFields() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + + schema.Query().ReplaceField("people", ctx => ctx.People, "Return list of people") + // exclude height field + .UseAggregate(new string[] { "height" }, true); + + Assert.Equal(4, schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "idSum" || f.Name == "idMin" || f.Name == "idMax" || f.Name == "idAverage").Count()); + Assert.Empty(schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "heightSum" || f.Name == "heightMin" || f.Name == "heightMax" || f.Name == "heightAverage")); + Assert.Equal(2, schema.GetSchemaType("QueryPeopleAggregate", null).GetFields().Where(f => f.Name == "birthdayMin" || f.Name == "birthdayMax").Count()); + } + + [Fact] + public void TestDifferentOptionsOnSameTypeDifferentFields() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContext(); + + schema.Query().ReplaceField("tasks", ctx => ctx.Tasks, "Return list of tasks") + .UseAggregate(new string[] { "hoursEstimated" }); + + schema.UpdateType(type => + { + type.ReplaceField("tasks", ctx => ctx.Tasks, "Return list of tasks") + .UseAggregate(new string[] { "hoursCompleted" }); + }); + + Assert.Equal(6, schema.GetSchemaType("QueryTasksAggregate", null).GetFields().Count()); + Assert.Equal(4, schema.GetSchemaType("QueryTasksAggregate", null).GetFields().Where(f => f.Name == "hoursEstimatedSum" || f.Name == "hoursEstimatedMin" || f.Name == "hoursEstimatedMax" || f.Name == "hoursEstimatedAverage").Count()); + Assert.Equal(6, schema.GetSchemaType("PersonTasksAggregate", null).GetFields().Count()); + Assert.Equal(4, schema.GetSchemaType("PersonTasksAggregate", null).GetFields().Where(f => f.Name == "hoursCompletedSum" || f.Name == "hoursCompletedMin" || f.Name == "hoursCompletedMax" || f.Name == "hoursCompletedAverage").Count()); + } + + [Fact] + public void TestIncludeAttribute() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContextExtended(); + + schema.Query().ReplaceField("SomeEntities", ctx => ctx.SomeEntities, "Return list of SomeEntities"); + + Assert.Empty(schema.GetSchemaType("QuerySomeEntitiesAggregate", null).GetFields().Where(f => f.Name == "idSum" || f.Name == "idMin" || f.Name == "idMax" || f.Name == "idAverage")); + Assert.Empty(schema.GetSchemaType("QuerySomeEntitiesAggregate", null).GetFields().Where(f => f.Name == "birthdayMin" || f.Name == "birthdayMax")); + Assert.Equal(4, schema.GetSchemaType("QuerySomeEntitiesAggregate", null).GetFields().Where(f => f.Name == "heightSum" || f.Name == "heightMin" || f.Name == "heightMax" || f.Name == "heightAverage").Count()); + } + + [Fact] + public void TestExcludeAttribute() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContextExtended(); + + schema.Query().ReplaceField("OtherEntities", ctx => ctx.OtherEntities, "Return list of OtherEntities"); + + Assert.Empty(schema.GetSchemaType("QueryOtherEntitiesAggregate", null).GetFields().Where(f => f.Name == "idSum" || f.Name == "idMin" || f.Name == "idMax" || f.Name == "idAverage")); + Assert.Empty(schema.GetSchemaType("QueryOtherEntitiesAggregate", null).GetFields().Where(f => f.Name == "heightSum" || f.Name == "heightMin" || f.Name == "heightMax" || f.Name == "heightAverage")); + Assert.Equal(2, schema.GetSchemaType("QueryOtherEntitiesAggregate", null).GetFields().Where(f => f.Name == "birthdayMin" || f.Name == "birthdayMax").Count()); + } + + [Fact] + public void TestUsesSchemaFieldName() + { + var schema = SchemaBuilder.FromObject(); + var data = new TestDataContextExtended(); + + Assert.Single(schema.Query().GetFields().Where(f => f.Name == "renamedAggregate")); + } +} + +public class TestDataContextExtended : TestDataContext +{ + [UseAggregate(AutoAddFields = false)] // force use of [IncludeAggregateField] + public IEnumerable SomeEntities { get; set; } = new List(); + + [UseAggregate] + public IEnumerable OtherEntities { get; set; } = new List(); + [UseAggregate] + [GraphQLField("renamed")] + public IEnumerable OtherEntities2 { get; set; } = new List(); +} + +public class EntityWithExclude +{ + [ExcludeAggregateField] + public int Id { get; set; } + public DateTime Birthday { get; set; } + [ExcludeAggregateField] + public float Height { get; set; } +} + +public class EntityWithInclude +{ + public int Id { get; set; } + public DateTime Birthday { get; set; } + [IncludeAggregateField] + public float Height { get; set; } +} \ No newline at end of file diff --git a/src/tests/EntityGraphQL.Tests/TestDataContext.cs b/src/tests/EntityGraphQL.Tests/TestDataContext.cs index 701684f4..e51216d1 100644 --- a/src/tests/EntityGraphQL.Tests/TestDataContext.cs +++ b/src/tests/EntityGraphQL.Tests/TestDataContext.cs @@ -4,6 +4,7 @@ using EntityGraphQL.Schema; using Newtonsoft.Json; using Newtonsoft.Json.Converters; +using EntityGraphQL.Schema.FieldExtensions; namespace EntityGraphQL.Tests { @@ -15,13 +16,13 @@ namespace EntityGraphQL.Tests public class TestDataContext { [GraphQLIgnore] - private IEnumerable projects = new List(); + private List projects = new List(); public int TotalPeople => People.Count; [Obsolete("This is obsolete, use Projects instead")] public IEnumerable ProjectsOld { get; set; } - public IEnumerable Projects { get => projects; set => projects = value; } - public IQueryable QueryableProjects { get => projects.AsQueryable(); set => projects = value; } + public List Projects { get => projects; set => projects = value; } + public IQueryable QueryableProjects { get => projects.AsQueryable(); set => projects = value.ToList(); } public virtual IEnumerable Tasks { get; set; } = new List(); public List Locations { get; set; } = new List(); public virtual List People { get; set; } = new List(); @@ -131,6 +132,7 @@ public class Project public string Name { get; set; } public int Type { get; set; } public Location Location { get; set; } + [UseAggregate] public IEnumerable Tasks { get; set; } public Person Owner { get; set; } public int CreatedBy { get; set; } @@ -154,6 +156,7 @@ public class Task public bool IsActive { get; set; } public Person Assignee { get; set; } public Project Project { get; set; } + public int HoursEstimated { get; set; } } public class Location {