diff --git a/src/System.Management.Automation/engine/parser/PSType.cs b/src/System.Management.Automation/engine/parser/PSType.cs index 7ffaab3ff0a..3ad7ced4fdc 100644 --- a/src/System.Management.Automation/engine/parser/PSType.cs +++ b/src/System.Management.Automation/engine/parser/PSType.cs @@ -265,6 +265,42 @@ internal static void DefineCustomAttributes(EnumBuilder member, ReadOnlyCollecti } } + private class InterfaceExpression + { + private TypeConstraintAst ast; + + internal bool IsGeneric => ast?.TypeName.IsGeneric ?? false; + + internal InterfaceExpression(TypeConstraintAst ast) + { + this.ast = ast; + } + + internal Type ResolveConcreteInterfaceType(TypeBuilder parameter) + => IsGeneric ? ResolveConcreteInterfaceTypeArguments(ast.TypeName, parameter) : ast.TypeName.GetReflectionType(); + + private Type ResolveConcreteInterfaceTypeArguments(ITypeName typeName, TypeBuilder parameter) + { + var typeArgs = new List(); + if (typeName.IsGeneric && typeName is GenericTypeName genericName) + { + foreach (var typeArg in genericName.GenericArguments) + { + typeArgs.Add(ResolveConcreteInterfaceTypeArguments(typeArg, parameter)); + } + + return genericName.TypeName.GetReflectionType().MakeGenericType(typeArgs.ToArray()); + } + + if (parameter.FullName == typeName.FullName) + { + return parameter; + } + + return typeName.GetReflectionType(); + } + } + private class DefineTypeHelper { private readonly Parser _parser; @@ -276,7 +312,7 @@ private class DefineTypeHelper internal readonly TypeBuilder _staticHelpersTypeBuilder; private readonly Dictionary _definedProperties; private readonly Dictionary>> _definedMethods; - private HashSet> _interfaceProperties; + private Dictionary _interfaceProperties; internal readonly List<(string fieldName, IParameterMetadataProvider bodyAst, bool isStatic)> _fieldsToInitForMemberFunctions; private bool _baseClassHasDefaultCtor; @@ -292,10 +328,15 @@ public DefineTypeHelper(Parser parser, ModuleBuilder module, TypeDefinitionAst t _parser = parser; _typeDefinitionAst = typeDefinitionAst; - List interfaces; + List interfaces; var baseClass = this.GetBaseTypes(parser, typeDefinitionAst, out interfaces); - _typeBuilder = module.DefineType(typeName, Reflection.TypeAttributes.Class | Reflection.TypeAttributes.Public, baseClass, interfaces.ToArray()); + _typeBuilder = module.DefineType(typeName, Reflection.TypeAttributes.Class | Reflection.TypeAttributes.Public, baseClass, null); + foreach (var interfaceExpression in interfaces) + { + _typeBuilder.AddInterfaceImplementation(interfaceExpression.ResolveConcreteInterfaceType(_typeBuilder)); + } + _staticHelpersTypeBuilder = module.DefineType(string.Format(CultureInfo.InvariantCulture, "{0}_", typeName), Reflection.TypeAttributes.Class); DefineCustomAttributes(_typeBuilder, typeDefinitionAst.Attributes, _parser, AttributeTargets.Class); _typeDefinitionAst.Type = _typeBuilder; @@ -314,12 +355,31 @@ public DefineTypeHelper(Parser parser, ModuleBuilder module, TypeDefinitionAst t /// /// /// Return declared interfaces. - /// - private Type GetBaseTypes(Parser parser, TypeDefinitionAst typeDefinitionAst, out List interfaces) + /// The base type + private Type GetBaseTypes(Parser parser, TypeDefinitionAst typeDefinitionAst, out List interfaces) { // Define base types and report errors. Type baseClass = null; - interfaces = new List(); + interfaces = new List(); + + bool TryGetInterface(TypeConstraintAst ast, out InterfaceExpression interfaceExpression) + { + interfaceExpression = new InterfaceExpression(ast); + if (ast.TypeName.IsGeneric && ast.TypeName is GenericTypeName genericTypeName) + { + if (genericTypeName.TypeName.GetReflectionType().IsInterface) + { + return true; + } + } + else if (ast.TypeName.GetReflectionType()?.IsInterface ?? false) + { + return true; + } + + interfaceExpression = null; + return false; + } // Default base class is System.Object and it has a default ctor. _baseClassHasDefaultCtor = true; @@ -339,40 +399,43 @@ private Type GetBaseTypes(Parser parser, TypeDefinitionAst typeDefinitionAst, ou } else { - baseClass = firstBaseTypeAst.TypeName.GetReflectionType(); - if (baseClass == null) + if (TryGetInterface(firstBaseTypeAst, out InterfaceExpression interfaceExpression)) { - parser.ReportError(firstBaseTypeAst.Extent, - nameof(ParserStrings.TypeNotFound), - ParserStrings.TypeNotFound, - firstBaseTypeAst.TypeName.FullName); - // fall to the default base type + // First Ast can represent interface as well as BaseClass. + interfaces.Add(interfaceExpression); + baseClass = null; } else { - if (baseClass.IsSealed) + baseClass = firstBaseTypeAst.TypeName.GetReflectionType(); + if (baseClass == null) { parser.ReportError(firstBaseTypeAst.Extent, - nameof(ParserStrings.SealedBaseClass), - ParserStrings.SealedBaseClass, - baseClass.Name); - // ignore base type if it's sealed. - baseClass = null; - } - else if (baseClass.IsGenericType && !baseClass.IsConstructedGenericType) - { - parser.ReportError(firstBaseTypeAst.Extent, - nameof(ParserStrings.SubtypeUnclosedGeneric), - ParserStrings.SubtypeUnclosedGeneric, - baseClass.Name); - // ignore base type, we cannot inherit from unclosed generic. - baseClass = null; + nameof(ParserStrings.TypeNotFound), + ParserStrings.TypeNotFound, + firstBaseTypeAst.TypeName.FullName); + // fall to the default base type } - else if (baseClass.IsInterface) + else { - // First Ast can represent interface as well as BaseClass. - interfaces.Add(baseClass); - baseClass = null; + if (baseClass.IsSealed) + { + parser.ReportError(firstBaseTypeAst.Extent, + nameof(ParserStrings.SealedBaseClass), + ParserStrings.SealedBaseClass, + baseClass.Name); + // ignore base type if it's sealed. + baseClass = null; + } + else if (baseClass.IsGenericType && !baseClass.IsConstructedGenericType) + { + parser.ReportError(firstBaseTypeAst.Extent, + nameof(ParserStrings.SubtypeUnclosedGeneric), + ParserStrings.SubtypeUnclosedGeneric, + baseClass.Name); + // ignore base type, we cannot inherit from unclosed generic. + baseClass = null; + } } } } @@ -416,26 +479,16 @@ private Type GetBaseTypes(Parser parser, TypeDefinitionAst typeDefinitionAst, ou else { Type interfaceType = baseTypeAsts[i].TypeName.GetReflectionType(); - if (interfaceType == null) + if (!TryGetInterface(baseTypeAsts[i], out InterfaceExpression interfaceExpression)) { parser.ReportError(baseTypeAsts[i].Extent, - nameof(ParserStrings.TypeNotFound), - ParserStrings.TypeNotFound, - baseTypeAsts[i].TypeName.FullName); + nameof(ParserStrings.InterfaceNameExpected), + ParserStrings.InterfaceNameExpected, + interfaceType.Name); } else { - if (interfaceType.IsInterface) - { - interfaces.Add(interfaceType); - } - else - { - parser.ReportError(baseTypeAsts[i].Extent, - nameof(ParserStrings.InterfaceNameExpected), - ParserStrings.InterfaceNameExpected, - interfaceType.Name); - } + interfaces.Add(interfaceExpression); } } } @@ -448,31 +501,37 @@ private bool ShouldImplementProperty(string name, Type type) { if (_interfaceProperties == null) { - _interfaceProperties = new HashSet>(); + _interfaceProperties = new Dictionary(); var allInterfaces = new HashSet(); // TypeBuilder.GetInterfaces() returns only the interfaces that was explicitly passed to its constructor. // During compilation the interface hierarchy is flattened, so we only need to resolve one level of ancestral interfaces. foreach (var interfaceType in _typeBuilder.GetInterfaces()) { - foreach (var parentInterface in interfaceType.GetInterfaces()) + var typeDefinition = interfaceType; + if (interfaceType.IsGenericType && interfaceType.GenericTypeArguments.Contains(_typeBuilder)) + { + typeDefinition = interfaceType.GetGenericTypeDefinition(); + } + + foreach (var parentInterface in typeDefinition.GetInterfaces()) { allInterfaces.Add(parentInterface); } - allInterfaces.Add(interfaceType); + allInterfaces.Add(typeDefinition); } foreach (var interfaceType in allInterfaces) { foreach (var property in interfaceType.GetProperties()) { - _interfaceProperties.Add(Tuple.Create(property.Name, property.PropertyType)); + _interfaceProperties.Add(property.Name, property.PropertyType); } } } - return _interfaceProperties.Contains(Tuple.Create(name, type)); + return _interfaceProperties.TryGetValue(name, out Type returnType) && (returnType == type || returnType.IsGenericParameter); } public void DefineMembers() diff --git a/test/powershell/Language/Classes/scripting.Classes.inheritance.tests.ps1 b/test/powershell/Language/Classes/scripting.Classes.inheritance.tests.ps1 index cf304d5918e..d86a0da56a9 100644 --- a/test/powershell/Language/Classes/scripting.Classes.inheritance.tests.ps1 +++ b/test/powershell/Language/Classes/scripting.Classes.inheritance.tests.ps1 @@ -87,6 +87,11 @@ Describe 'Classes inheritance syntax' -Tags "CI" { { [A]::b = "bla" } | Should -Throw -ErrorId 'ExceptionWhenSetting' } + It 'can implement generic interfaces referencing itself as a type parameter' { + $C1 = Invoke-Expression 'class ComparableClass : IComparable[ComparableClass] { [int]$Value; [int]CompareTo([ComparableClass]$obj){ return $this.Value.CompareTo($obj.Value) } } [ComparableClass]' + $C1.ImplementedInterfaces[0].TypeParameter | Should -Be $C1 + } + Context "Inheritance from abstract .NET classes" { BeforeAll { class TestHost : System.Management.Automation.Host.PSHost