diff --git a/src/Generator/Generators/CLI/CLIHeaders.cs b/src/Generator/Generators/CLI/CLIHeaders.cs index 327e5c4fdc..a7e9cf5b0f 100644 --- a/src/Generator/Generators/CLI/CLIHeaders.cs +++ b/src/Generator/Generators/CLI/CLIHeaders.cs @@ -222,10 +222,13 @@ public void GenerateFunctions(DeclarationContext decl) { PushBlock(BlockKind.FunctionsClass); - WriteLine("public ref class {0}", TranslationUnit.FileNameWithoutExtension); - WriteLine("{"); - WriteLine("public:"); - Indent(); + if (!(decl is Class)) + { + WriteLine("public ref class {0}", TranslationUnit.FileNameWithoutExtension); + WriteLine("{"); + WriteLine("public:"); + Indent(); + } // Generate all the function declarations for the module. foreach (var function in decl.Functions) @@ -233,8 +236,11 @@ public void GenerateFunctions(DeclarationContext decl) GenerateFunction(function); } - Unindent(); - WriteLine("};"); + if (!(decl is Class)) + { + Unindent(); + WriteLine("};"); + } PopBlock(NewLineKind.BeforeNextBlock); } diff --git a/src/Generator/Generators/CLI/CLISources.cs b/src/Generator/Generators/CLI/CLISources.cs index e633890c65..09c973a042 100644 --- a/src/Generator/Generators/CLI/CLISources.cs +++ b/src/Generator/Generators/CLI/CLISources.cs @@ -889,11 +889,9 @@ public void GenerateFunction(Function function, DeclarationContext @namespace) GenerateDeclarationCommon(function); - var classSig = string.Format("{0}::{1}", QualifiedIdentifier(@namespace), - TranslationUnit.FileNameWithoutExtension); - - Write("{0} {1}::{2}(", function.ReturnType, classSig, - function.Name); + Write($@"{function.ReturnType} {QualifiedIdentifier(@namespace)}::{ + (@namespace is Class ? string.Empty : $@"{ + TranslationUnit.FileNameWithoutExtension}::")}{function.Name}("); for (var i = 0; i < function.Parameters.Count; ++i) { diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 98ecb8f2ba..0715092fb1 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -239,15 +239,19 @@ private IEnumerable EnumerateClasses(DeclarationContext context) public virtual void GenerateNamespaceFunctionsAndVariables(DeclarationContext context) { + var hasGlobalFunctions = !(context is Class) && context.Functions.Any( + f => f.IsGenerated); + var hasGlobalVariables = !(context is Class) && context.Variables.Any( v => v.IsGenerated && v.Access == AccessSpecifier.Public); - if (!context.Functions.Any(f => f.IsGenerated) && !hasGlobalVariables) + if (!hasGlobalFunctions && !hasGlobalVariables) return; - PushBlock(BlockKind.Functions); var parentName = SafeIdentifier(context.TranslationUnit.FileNameWithoutExtension); + PushBlock(BlockKind.Functions); + var keyword = "class"; var classes = EnumerateClasses().ToList(); if (classes.FindAll(cls => cls.IsValueType && cls.Name == parentName && context.QualifiedLogicalName == cls.Namespace.QualifiedLogicalName).Any()) @@ -271,12 +275,8 @@ public virtual void GenerateNamespaceFunctionsAndVariables(DeclarationContext co UnindentAndWriteCloseBrace(); PopBlock(NewLineKind.BeforeNextBlock); - foreach (var function in context.Functions) - { - if (!function.IsGenerated) continue; - + foreach (Function function in context.Functions.Where(f => f.IsGenerated)) GenerateFunction(function, parentName); - } foreach (var variable in context.Variables.Where( v => v.IsGenerated && v.Access == AccessSpecifier.Public)) @@ -443,7 +443,8 @@ public override bool VisitClassDecl(Class @class) } GenerateClassConstructors(@class); - + foreach (Function function in @class.Functions.Where(f => f.IsGenerated)) + GenerateFunction(function, @class.Name); GenerateClassMethods(@class.Methods); GenerateClassVariables(@class); GenerateClassProperties(@class); @@ -649,6 +650,10 @@ private void GatherClassInternalFunctions(Class @class, bool includeCtors, && !functions.Contains(prop.SetMethod)) tryAddOverload(prop.SetMethod); } + + functions.AddRange(from function in @class.Functions + where function.IsGenerated && !function.IsSynthetized + select function); } private IEnumerable GatherInternalParams(Function function, out TypePrinterResult retType) @@ -2323,6 +2328,8 @@ public void GenerateFunction(Function function, string parentName) if (function.SynthKind == FunctionSynthKind.DefaultValueOverload) GenerateOverloadCall(function); + else if (function.IsOperator) + GenerateOperator(function, default(QualifiedType)); else GenerateInternalFunctionCall(function); @@ -2650,14 +2657,14 @@ private string GetVirtualCallDelegate(Method method) return delegateId; } - private void GenerateOperator(Method method, QualifiedType returnType) + private void GenerateOperator(Function function, QualifiedType returnType) { - if (method.SynthKind == FunctionSynthKind.ComplementOperator) + if (function.SynthKind == FunctionSynthKind.ComplementOperator) { - if (method.Kind == CXXMethodKind.Conversion) + if (function is Method method && method.Kind == CXXMethodKind.Conversion) { // To avoid ambiguity when having the multiple inheritance pass enabled - var paramType = method.Parameters[0].Type.SkipPointerRefs().Desugar(); + var paramType = function.Parameters[0].Type.SkipPointerRefs().Desugar(); paramType = (paramType.GetPointee() ?? paramType).Desugar(); Class paramClass; Class @interface = null; @@ -2665,9 +2672,9 @@ private void GenerateOperator(Method method, QualifiedType returnType) @interface = paramClass.GetInterface(); var paramName = string.Format("{0}{1}", - method.Parameters[0].Type.IsPrimitiveTypeConvertibleToRef() ? + function.Parameters[0].Type.IsPrimitiveTypeConvertibleToRef() ? "ref *" : string.Empty, - method.Parameters[0].Name); + function.Parameters[0].Name); var printedType = method.ConversionType.Visit(TypePrinter); if (@interface != null) { @@ -2679,30 +2686,45 @@ private void GenerateOperator(Method method, QualifiedType returnType) } else { - var @operator = Operators.GetOperatorOverloadPair(method.OperatorKind); + var @operator = Operators.GetOperatorOverloadPair(function.OperatorKind); - WriteLine("return !({0} {1} {2});", method.Parameters[0].Name, - @operator, method.Parameters[1].Name); + // handle operators for comparison which return int instead of bool + Type retType = function.OriginalReturnType.Type.Desugar(); + bool regular = retType.IsPrimitiveType(PrimitiveType.Bool); + if (regular) + { + WriteLine($@"return !({function.Parameters[0].Name} { + @operator} {function.Parameters[1].Name});"); + } + else + { + WriteLine($@"return global::System.Convert.ToInt32(({ + function.Parameters[0].Name} {@operator} { + function.Parameters[1].Name}) == 0);"); + } } return; } - if (method.OperatorKind == CXXOperatorKind.EqualEqual || - method.OperatorKind == CXXOperatorKind.ExclaimEqual) + if (function.OperatorKind == CXXOperatorKind.EqualEqual || + function.OperatorKind == CXXOperatorKind.ExclaimEqual) { WriteLine("bool {0}Null = ReferenceEquals({0}, null);", - method.Parameters[0].Name); + function.Parameters[0].Name); WriteLine("bool {0}Null = ReferenceEquals({0}, null);", - method.Parameters[1].Name); + function.Parameters[1].Name); WriteLine("if ({0}Null || {1}Null)", - method.Parameters[0].Name, method.Parameters[1].Name); - WriteLineIndent("return {0}{1}Null && {2}Null{3};", - method.OperatorKind == CXXOperatorKind.EqualEqual ? string.Empty : "!(", - method.Parameters[0].Name, method.Parameters[1].Name, - method.OperatorKind == CXXOperatorKind.EqualEqual ? string.Empty : ")"); + function.Parameters[0].Name, function.Parameters[1].Name); + Type retType = function.OriginalReturnType.Type.Desugar(); + bool regular = retType.IsPrimitiveType(PrimitiveType.Bool); + WriteLineIndent($@"return {(regular ? string.Empty : "global::System.Convert.ToInt32(")}{ + (function.OperatorKind == CXXOperatorKind.EqualEqual ? string.Empty : "!(")}{ + function.Parameters[0].Name}Null && {function.Parameters[1].Name}Null{ + (function.OperatorKind == CXXOperatorKind.EqualEqual ? string.Empty : ")")}{ + (regular ? string.Empty : ")")};"); } - GenerateInternalFunctionCall(method, returnType: returnType); + GenerateInternalFunctionCall(function, returnType: returnType); } private void GenerateClassConstructor(Method method, Class @class) diff --git a/src/Generator/Passes/CheckOperatorsOverloads.cs b/src/Generator/Passes/CheckOperatorsOverloads.cs index d78901ca0c..e3bcb22277 100644 --- a/src/Generator/Passes/CheckOperatorsOverloads.cs +++ b/src/Generator/Passes/CheckOperatorsOverloads.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System.Collections.Generic; +using System.Linq; using CppSharp.AST; using CppSharp.AST.Extensions; using CppSharp.Generators; @@ -65,6 +66,15 @@ private void CheckInvalidOperators(Class @class) else CreateOperator(@class, @operator); } + + foreach (var @operator in @class.Functions.Where( + f => f.IsGenerated && f.IsOperator && + !IsValidOperatorOverload(f) && !f.IsExplicitlyGenerated)) + { + Diagnostics.Debug("Invalid operator overload {0}::{1}", + @class.OriginalName, @operator.OperatorKind); + @operator.ExplicitlyIgnore(); + } } private static void CreateOperator(Class @class, Method @operator) @@ -128,64 +138,85 @@ private void CreateIndexer(Class @class, Method @operator) @operator.GenerationKind = GenerationKind.Internal; } - private static void HandleMissingOperatorOverloadPair(Class @class, CXXOperatorKind op1, - CXXOperatorKind op2) + private static void HandleMissingOperatorOverloadPair(Class @class, + CXXOperatorKind op1, CXXOperatorKind op2) { - foreach (var op in @class.Operators.Where( + List methods = HandleMissingOperatorOverloadPair( + @class, @class.Operators, op1, op2); + foreach (Method @operator in methods) + { + int index = @class.Methods.IndexOf( + (Method) @operator.OriginalFunction); + @class.Methods.Insert(index, @operator); + } + + List functions = HandleMissingOperatorOverloadPair( + @class, @class.Functions, op1, op2); + foreach (Method @operator in functions) + { + int index = @class.Declarations.IndexOf( + @operator.OriginalFunction); + @class.Methods.Insert(index, @operator); + } + } + + private static List HandleMissingOperatorOverloadPair(Class @class, + IEnumerable functions, CXXOperatorKind op1, + CXXOperatorKind op2) where T : Function, new() + { + List fs = new List(); + foreach (var op in functions.Where( o => o.OperatorKind == op1 || o.OperatorKind == op2).ToList()) { - int index; - var missingKind = CheckMissingOperatorOverloadPair(@class, out index, op1, op2, - op.Parameters.First().Type, op.Parameters.Last().Type); + var missingKind = CheckMissingOperatorOverloadPair(functions, + op1, op2, op.Parameters.First().Type, op.Parameters.Last().Type); if (missingKind == CXXOperatorKind.None || !op.IsGenerated) continue; - var method = new Method() - { - Name = Operators.GetOperatorIdentifier(missingKind), - Namespace = @class, - SynthKind = FunctionSynthKind.ComplementOperator, - Kind = CXXMethodKind.Operator, - OperatorKind = missingKind, - ReturnType = op.ReturnType - }; - - method.Parameters.AddRange(op.Parameters.Select( - p => new Parameter(p) { Namespace = method })); - - @class.Methods.Insert(index, method); + var function = new T() + { + Name = Operators.GetOperatorIdentifier(missingKind), + Namespace = @class, + SynthKind = FunctionSynthKind.ComplementOperator, + OperatorKind = missingKind, + ReturnType = op.ReturnType, + OriginalFunction = op + }; + + var method = function as Method; + if (method != null) + method.Kind = CXXMethodKind.Operator; + + function.Parameters.AddRange(op.Parameters.Select( + p => new Parameter(p) { Namespace = function })); + + fs.Add(function); } + return fs; } - - static CXXOperatorKind CheckMissingOperatorOverloadPair(Class @class, out int index, - CXXOperatorKind op1, CXXOperatorKind op2, Type typeLeft, Type typeRight) + + private static CXXOperatorKind CheckMissingOperatorOverloadPair( + IEnumerable functions, + CXXOperatorKind op1, CXXOperatorKind op2, + Type typeLeft, Type typeRight) { - var first = @class.Operators.FirstOrDefault(o => o.IsGenerated && o.OperatorKind == op1 && - o.Parameters.First().Type.Equals(typeLeft) && o.Parameters.Last().Type.Equals(typeRight)); - var second = @class.Operators.FirstOrDefault(o => o.IsGenerated && o.OperatorKind == op2 && - o.Parameters.First().Type.Equals(typeLeft) && o.Parameters.Last().Type.Equals(typeRight)); + var first = functions.FirstOrDefault( + o => o.IsGenerated && o.OperatorKind == op1 && + o.Parameters.First().Type.Equals(typeLeft) && + o.Parameters.Last().Type.Equals(typeRight)); + var second = functions.FirstOrDefault( + o => o.IsGenerated && o.OperatorKind == op2 && + o.Parameters.First().Type.Equals(typeLeft) && + o.Parameters.Last().Type.Equals(typeRight)); var hasFirst = first != null; var hasSecond = second != null; - if (hasFirst && !hasSecond) - { - index = @class.Methods.IndexOf(first); - return op2; - } - - if (hasSecond && !hasFirst) - { - index = @class.Methods.IndexOf(second); - return op1; - } - - index = 0; - return CXXOperatorKind.None; + return hasFirst && !hasSecond ? op2 : hasSecond && !hasFirst ? op1 : CXXOperatorKind.None; } - private bool IsValidOperatorOverload(Method @operator) + private bool IsValidOperatorOverload(Function @operator) { // These follow the order described in MSDN (Overloadable Operators). diff --git a/src/Generator/Passes/MoveFunctionToClassPass.cs b/src/Generator/Passes/MoveFunctionToClassPass.cs index 07fff6155d..0169f5fdb4 100644 --- a/src/Generator/Passes/MoveFunctionToClassPass.cs +++ b/src/Generator/Passes/MoveFunctionToClassPass.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System.Collections.Generic; +using System.Linq; using CppSharp.AST; namespace CppSharp.Passes @@ -17,6 +18,14 @@ public MoveFunctionToClassPass() VisitOptions.VisitTemplateArguments = false; } + public override bool VisitASTContext(ASTContext context) + { + bool result = base.VisitASTContext(context); + foreach (Function movedFunction in movedFunctions) + movedFunction.OriginalNamespace.Declarations.Remove(movedFunction); + return result; + } + public override bool VisitFunctionDecl(Function function) { if (!function.IsGenerated) @@ -28,23 +37,9 @@ public override bool VisitFunctionDecl(Function function) @class.TranslationUnit.Module != function.TranslationUnit.Module) return false; - // Create a new fake method so it acts as a static method. - var method = new Method(function) - { - Namespace = @class, - OperatorKind = function.OperatorKind, - OriginalFunction = null, - IsStatic = true - }; - if (method.IsOperator) - { - method.IsNonMemberOperator = true; - method.Kind = CXXMethodKind.Operator; - } - - function.ExplicitlyIgnore(); - - @class.Methods.Add(method); + function.Namespace = @class; + @class.Declarations.Add(function); + movedFunctions.Add(function); Diagnostics.Debug($"Function {function.Name} moved to class {@class.Name}"); @@ -76,5 +71,7 @@ private Class FindClassToMoveFunctionTo(Function function) return @class; } + + private HashSet movedFunctions = new HashSet(); } }