Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/EntityGraphQL/Compiler/Util/LinqRuntimeTypeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ private static string GetTypeKey(Dictionary<string, Type> fields)
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public static Type GetDynamicType(Dictionary<string, Type> fields, string description, Type? parentType = null)
public static Type GetDynamicType(Dictionary<string, Type> fields, string description, Type? parentType = null, Type[]? interfaces = null, Action<TypeBuilder>? build = null)
{
if (null == fields)
throw new ArgumentNullException(nameof(fields));
string classFullName = GetTypeKey(fields) + parentType?.Name.GetHashCode();

string classFullName = GetTypeKey(fields) + parentType?.Name.GetHashCode() + (interfaces != null ? string.Join("_", interfaces.Select(i => i.Name.GetHashCode())) : "");
lock (typesByFullName)
{
if (!typesByFullName.ContainsKey(classFullName))
Expand All @@ -51,7 +51,7 @@ public static Type GetDynamicType(Dictionary<string, Type> fields, string descri
if (builtTypes.ContainsKey(classId))
return builtTypes[classId];

var typeBuilder = moduleBuilder.DefineType(classId.ToString(), TypeAttributes.Public | TypeAttributes.Class | TypeAttributes.Serializable, parentType);
var typeBuilder = moduleBuilder.DefineType(classId.ToString(), TypeAttributes.Public | TypeAttributes.Class | TypeAttributes.Serializable, parentType, interfaces);

foreach (var field in fields)
{
Expand All @@ -61,6 +61,8 @@ public static Type GetDynamicType(Dictionary<string, Type> fields, string descri
typeBuilder.DefineField(field.Key, field.Value, FieldAttributes.Public);
}

build?.Invoke(typeBuilder);

builtTypes[classId] = typeBuilder.CreateTypeInfo()!.AsType();
return builtTypes[classId];
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using EntityGraphQL.Compiler;
using EntityGraphQL.Compiler.Util;
using EntityGraphQL.Extensions;

namespace EntityGraphQL.Schema.FieldExtensions;

/// <summary>
/// Builds a field to query aggregate data on a list field.
///
/// E.g. Given a field that returns a list of people: people: [Person!]! it will add a new field peopleAggregate: PeopleAggregate
/// similar to the following:
///
/// schema.AddType<PeopleAggregate>("PeopleAggregate", "Aggregate people", type =>
/// {
/// type.AddField("count", (c) => c.Count(), "Count of people");
/// type.AddField("heightMin", (c) => c.Min(p => p.Height), "Min height");
/// type.AddField("heightMax", (c) => c.Max(p => p.Height), "Max height");
/// type.AddField("heightAvg", (c) => c.Average(p => p.Height), "Average height");
/// type.AddField("heightSum", (c) => c.Sum(p => p.Height), "Sum of height");
/// });
/// schema.Query().AddField("peopleAggregate", ctx => ctx.People, "Aggregate people")
/// .Returns("PeopleAggregate");
///
/// Where PeopleAggregate is a new dotnet type used to differentiate the return type from a normal list of people.
///
/// public class PeopleAggregate : IEnumerable<Person>
/// {
/// public IEnumerator<Person> GetEnumerator() => throw new System.NotImplementedException();
/// IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
/// }
///
/// If this is at the root level of the query it would build a LINQ expression query similar to the following:
/// (MyContext ctx) => new
/// {
/// count = ctx.People.Count(),
/// minHeight = ctx.People.Min(p => p.Height),
/// maxHeight = ctx.People.Max(p => p.Height),
/// avgHeight = ctx.People.Average(p => p.Height),
/// sumHeight = ctx.People.Sum(p => p.Height)
/// };
/// </summary>
public class AggregateExtension : BaseFieldExtension
{
private static readonly object addAggregateTypeLock = new();
private readonly string? fieldName;
private readonly List<string>? aggregateFieldList;
private readonly bool fieldListIsExclude;
private static int dupeCnt = 1;

public AggregateExtension(string? fieldName, IEnumerable<string>? aggregateFieldList, bool excludeFields)
{
this.fieldName = fieldName;
this.aggregateFieldList = aggregateFieldList?.ToList();
fieldListIsExclude = excludeFields;
}

public override void Configure(ISchemaProvider schema, IField field)
{
if (field.ResolveExpression == null)
throw new EntityGraphQLCompilerException($"ConnectionPagingExtension requires a Resolve function set on the field");

if (!field.ResolveExpression.Type.IsEnumerableOrArray())
throw new ArgumentException($"Expression for field {field.Name} must be a collection to use ConnectionPagingExtension. Found type {field.ReturnType.TypeDotnet}");

var schemaTypeName = $"{field.FromType.Name}{field.Name.FirstCharToUpper()}Aggregate";
var fieldName = this.fieldName ?? $"{field.Name}Aggregate";
var listElementType = field.ReturnType.TypeDotnet.GetEnumerableOrArrayType()!;

// likely not to happen in normal use but the unit tests will hit this as it tries to create the same type multiple times
lock (addAggregateTypeLock)
{
// first build type and fetch it from the dynamic types - this will reuse the type if it already exists
var aggregateDotnetType = GetDotnetType(field, listElementType);

// Then check if we have that in the schema already
var aggregateSchemaType = schema.HasType(aggregateDotnetType) ? schema.GetSchemaType(aggregateDotnetType, null) : null;
if (aggregateSchemaType == null)
{
// The type may be different but it may have the same name, see test TestDifferentOptionsOnSameTypeDifferentFields
if (schema.HasType(schemaTypeName))
schemaTypeName = $"{schemaTypeName}{dupeCnt++}"; // could do this better

// use reflection to call AddAggregateTypeToSchema
var addTypeMethod = typeof(AggregateExtension).GetMethod(nameof(AddAggregateTypeToSchema), BindingFlags.NonPublic | BindingFlags.Static);
var genericAddTypeMethod = addTypeMethod!.MakeGenericMethod(aggregateDotnetType, listElementType);
aggregateSchemaType = (genericAddTypeMethod.Invoke(this, new object[] { schema, schemaTypeName, field }) as ISchemaType)!;

var contextParam = Expression.Parameter(aggregateSchemaType.TypeDotnet, schemaTypeName);
// set up all the fields on the aggregate type
ForEachPossibleAggregateField(field.ReturnType.SchemaType.GetFields(),
(possibleAggregateField, returnFieldType) =>
{
AddAggregateFieldByReflection("Average", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType);
AddAggregateFieldByReflection("Sum", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType);
AddAggregateFieldByReflection("Min", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType);
AddAggregateFieldByReflection("Max", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType);
},
(possibleAggregateField, returnFieldType) =>
{
AddAggregateFieldByReflection("Min", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType, true);
AddAggregateFieldByReflection("Max", aggregateDotnetType, aggregateSchemaType, possibleAggregateField, contextParam, listElementType, true);
}
);
}
else
{
schemaTypeName = aggregateSchemaType.Name;
}
}

var newField = new Field(schema, field.FromType, fieldName, Expression.Lambda(field.ResolveExpression, field.FieldParam!), $"Aggregate data for {field.Name}", null, new GqlTypeInfo(() => schema.Type(schemaTypeName), schema.Type(schemaTypeName).TypeDotnet), null);
field.FromType.AddField(newField);
}

private Type GetDotnetType(IField field, Type listElementType)
{
var fields = new Dictionary<string, Type> {
{ "count", typeof(int) }
};
ForEachPossibleAggregateField(field.ReturnType.SchemaType.GetFields(),
(possibleAggregateField, returnFieldType) =>
{
fields.Add($"{possibleAggregateField.Name}Average", returnFieldType);
fields.Add($"{possibleAggregateField.Name}Sum", returnFieldType);
fields.Add($"{possibleAggregateField.Name}Min", returnFieldType);
fields.Add($"{possibleAggregateField.Name}Max", returnFieldType);
},
(possibleAggregateField, returnFieldType) =>
{
fields.Add($"{possibleAggregateField.Name}Min", returnFieldType);
fields.Add($"{possibleAggregateField.Name}Max", returnFieldType);
}
);
var aggregateDotnetType = LinqRuntimeTypeBuilder.GetDynamicType(fields, field.Name, null,
new[] { typeof(IEnumerable<>).MakeGenericType(listElementType) },
aggregateDotnetTypeDef =>
{
// define the IEnumerable<T> implementation
var getEnumeratorMethod = typeof(IEnumerable<>).MakeGenericType(listElementType).GetMethod("GetEnumerator");
var getEnumeratorIL = aggregateDotnetTypeDef.DefineMethod(getEnumeratorMethod!.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, getEnumeratorMethod.ReturnType, Type.EmptyTypes).GetILGenerator();
getEnumeratorIL.Emit(OpCodes.Newobj, typeof(NotImplementedException).GetConstructor(Type.EmptyTypes)!);
var getEnumeratorMethod2 = typeof(System.Collections.IEnumerable).GetMethod("GetEnumerator");
var getEnumeratorIL2 = aggregateDotnetTypeDef.DefineMethod(getEnumeratorMethod2!.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, getEnumeratorMethod2.ReturnType, Type.EmptyTypes).GetILGenerator();
getEnumeratorIL2.Emit(OpCodes.Newobj, typeof(NotImplementedException).GetConstructor(Type.EmptyTypes)!);
})!;
return aggregateDotnetType;
}

private void ForEachPossibleAggregateField(IEnumerable<IField> fields, Action<IField, Type> aggregateFieldActionNumeric, Action<IField, Type> aggregateFieldTyped)
{
foreach (IField possibleAggregateField in fields)
{
if (possibleAggregateField.Name.StartsWith("__", StringComparison.InvariantCulture)
|| possibleAggregateField.ResolveExpression == null)
continue;
if (aggregateFieldList != null)
{
if (fieldListIsExclude && aggregateFieldList.Contains(possibleAggregateField.Name))
continue;
if (!fieldListIsExclude && !aggregateFieldList.Contains(possibleAggregateField.Name))
continue;
}

// use reflection to build aggregateSchemaType.AddField("min", (c) => c.Min(fieldExp), "Min value"); ETC
// average & sum can only be done on numeric types from Queryable.Average/Sum
var returnFieldType = possibleAggregateField.ReturnType.TypeDotnet;
if (returnFieldType == typeof(int) || returnFieldType == typeof(int?)
|| returnFieldType == typeof(long) || returnFieldType == typeof(long?)
|| returnFieldType == typeof(double) || returnFieldType == typeof(double?)
|| returnFieldType == typeof(decimal) || returnFieldType == typeof(decimal?)
|| returnFieldType == typeof(float) || returnFieldType == typeof(float?))
{
aggregateFieldActionNumeric(possibleAggregateField, returnFieldType);
}
else if (returnFieldType == typeof(DateTimeOffset) || returnFieldType == typeof(DateTimeOffset?)
|| returnFieldType == typeof(DateTime) || returnFieldType == typeof(DateTime?)
)
{
aggregateFieldTyped(possibleAggregateField, returnFieldType);
}
}
}

private static void AddAggregateFieldByReflection(string method, Type aggregateDotnetType, ISchemaType aggregateSchemaType, IField field, ParameterExpression contextParam, Type listElementType, bool typedReturn = false)
{
var fieldName = $"{field.Name}{method}";
var fieldDescription = $"{method} of {field.Name}";
var genTypes = typedReturn ? new[] { listElementType, field.ReturnType.TypeDotnet } : new[] { listElementType };
var call = Expression.Call(typeof(Enumerable), method, genTypes, contextParam, Expression.Lambda(field.ResolveExpression!, field.FieldParam!));
var fieldExp = Expression.Lambda(call, contextParam);
// find public Field AddField<TReturn>(string name, Expression<Func<TBaseType, TReturn>> fieldSelection, string? description)
var addFieldMethod = aggregateSchemaType.GetType().GetMethods()
.SingleOrDefault(m => m.Name == nameof(ISchemaType.AddField)
&& m.IsGenericMethod
&& m.ReturnType == typeof(Field)
&& m.GetGenericArguments().Length == 1
&& m.GetParameters().Length == 3);
var genericAddFieldMethod = addFieldMethod!.MakeGenericMethod(fieldExp.ReturnType);
genericAddFieldMethod.Invoke(aggregateSchemaType, new object[] { fieldName, fieldExp, fieldDescription });
}

private static ISchemaType AddAggregateTypeToSchema<TType, TListElement>(ISchemaProvider schema, string schemaTypeName, IField field) where TType : class, IEnumerable<TListElement>
{
var aggregateType = schema.AddType<TType>(schemaTypeName, $"Aggregate {field.Name}", type =>
{
type.AddField("count", (c) => c.Count(), "Count of items");
});
return aggregateType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using System;
using System.Collections.Generic;

namespace EntityGraphQL.Schema.FieldExtensions;

public static class UseAggregateExtension
{
/// <summary>
/// If the field is a list, add a new field at the same level with the name {field}Aggregate
/// Only call on a field that returns an IEnumerable
/// </summary>
/// <param name="field"></param>
/// <param name="fieldName">Use this for the name of the created field. Is null the field will be called <field-name>Aggregate</param>
/// <returns></returns>
public static IField UseAggregate(this IField field, string? fieldName = null)
{
return field.AddExtension(new AggregateExtension(fieldName, null, false));
}

/// <summary>
/// If the field is a list, add a new field at the same level with the name {field}Aggregate
/// Only call on a field that returns an IEnumerable
/// </summary>
/// <typeparam name="TElementType"></typeparam>
/// <typeparam name="TReturnType"></typeparam>
/// <typeparam name="TSort"></typeparam>
/// <param name="field"></param>
/// <param name="fieldSelection">
/// GraphQL Schema field names to include or exclude in building the aggregation fields. By default numerical and date fields
/// will have the relevant aggregation fields created. If the field is not available on the type it will be ignored.
/// Field name in GraphQL Schema are case sensitive.
/// </param>
/// <param name="excludeFields">If true, the fields in fieldSelection will be excluded from the aggregate fields instead</param>
/// <param name="fieldName"></param>
/// <returns></returns>
public static IField UseAggregate(this IField field, IEnumerable<string>? fieldSelection, bool excludeFields = false, string? fieldName = null)
{
return field.AddExtension(new AggregateExtension(fieldName, fieldSelection, excludeFields));
}
}

public class UseAggregateAttribute : ExtensionAttribute
{
/// <summary>
/// Overrides the default name of the aggregate field. If null the field will be called <field-name>Aggregate
/// </summary>
public string? FieldName { get; set; }
/// <summary>
/// If false, you will need to use [IncludeAggregateField] to include fields in the aggregate
/// </summary>
public bool AutoAddFields { get; set; } = true;

public UseAggregateAttribute() { }

public override void ApplyExtension(IField field)
{
var fieldList = AutoAddFields ? FindExcludedFields(field) : FindIncludedFields(field);
field.UseAggregate(fieldList, AutoAddFields, FieldName);
}

private static IEnumerable<string> FindIncludedFields(IField field)
{
var includedFields = new List<string>();
foreach (var prop in field.ReturnType.SchemaType.TypeDotnet.GetProperties())
{
if (prop.GetCustomAttributes(typeof(IncludeAggregateFieldAttribute), true).Length > 0)
{
var (name, _) = SchemaBuilder.GetNameAndDescription(prop, field.Schema);
includedFields.Add(name);
}
}
return includedFields;
}

private static IEnumerable<string> FindExcludedFields(IField field)
{
var excludedFields = new List<string>();
foreach (var prop in field.ReturnType.SchemaType.TypeDotnet.GetProperties())
{
if (prop.GetCustomAttributes(typeof(ExcludeAggregateFieldAttribute), true).Length > 0)
{
var (name, _) = SchemaBuilder.GetNameAndDescription(prop, field.Schema);
excludedFields.Add(name);
}
}
return excludedFields;
}
}

[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Method, AllowMultiple = false)]
public class IncludeAggregateFieldAttribute : Attribute
{
}

[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Method, AllowMultiple = false)]
public class ExcludeAggregateFieldAttribute : Attribute
{
}
Loading