Shaolinq.AsyncRewriter.Rewriter.RewriteMethodAsync C# (CSharp) Method

RewriteMethodAsync() private method

private RewriteMethodAsync ( Microsoft.CodeAnalysis.CSharp.Syntax.MethodDeclarationSyntax methodSyntax, Microsoft.CodeAnalysis.SemanticModel semanticModel, bool cancellationVersion ) : Microsoft.CodeAnalysis.CSharp.Syntax.MethodDeclarationSyntax
methodSyntax Microsoft.CodeAnalysis.CSharp.Syntax.MethodDeclarationSyntax
semanticModel Microsoft.CodeAnalysis.SemanticModel
cancellationVersion bool
return Microsoft.CodeAnalysis.CSharp.Syntax.MethodDeclarationSyntax
		/// <summary>
		/// Produces the <c>...Async</c> counterpart of a synchronous method declaration.
		/// </summary>
		/// <param name="methodSyntax">The synchronous method declaration to rewrite.</param>
		/// <param name="semanticModel">Semantic model for the syntax tree that contains <paramref name="methodSyntax"/>.</param>
		/// <param name="cancellationVersion">
		/// <c>true</c> to emit the overload that takes a <c>cancellationToken</c> parameter and carries the rewritten body
		/// (marked <c>async</c> unless abstract or on an interface); <c>false</c> to emit the overload without it, whose body
		/// (for non-abstract, non-interface methods) forwards to the cancellation overload passing <c>CancellationToken.None</c>.
		/// </param>
		/// <returns>
		/// The async method declaration, or <c>null</c> when <c>GetMethods</c> finds a matching async method on the base type
		/// that is neither virtual, abstract nor override (so the generated method could not override it).
		/// </returns>
		private MethodDeclarationSyntax RewriteMethodAsync(MethodDeclarationSyntax methodSyntax, SemanticModel semanticModel, bool cancellationVersion)
		{
			var methodSymbol = (IMethodSymbol)ModelExtensions.GetDeclaredSymbol(semanticModel, methodSyntax);
			var asyncMethodName = methodSymbol.Name + "Async";
			var isInterfaceMethod = methodSymbol.ContainingType.TypeKind == TypeKind.Interface;
			var containingTypeSyntax = methodSyntax.Parent as TypeDeclarationSyntax;

			// The generated members presumably live in a separate partial declaration, hence the error for non-partial types.
			if (containingTypeSyntax?.Modifiers.Any(c => c.IsKind(SyntaxKind.PartialKeyword)) != true)
			{
				// Fall back to the symbol's name instead of throwing when the parent is not a type declaration.
				var name = containingTypeSyntax?.Identifier.ToString() ?? methodSymbol.ContainingType.Name;

				if (!typesAlreadyWarnedAbout.Contains(name))
				{
					typesAlreadyWarnedAbout.Add(name);

					this.log.LogError($"Type '{name}' needs to be marked as partial");
				}
			}

			var newAsyncMethod = MethodInvocationAsyncRewriter.Rewrite(this.log, this.lookup, semanticModel, this.excludedTypes, this.cancellationTokenSymbol, methodSyntax);
			var returnTypeName = methodSyntax.ReturnType.ToString();

			// Semantic-model lookups need a position in the ORIGINAL tree. newAsyncMethod becomes a detached node once
			// the With* calls below are applied, so its SpanStart is not a valid lookup site; methodSyntax is used instead.
			var cancellationTokenTypeName = this.cancellationTokenSymbol.ToMinimalDisplayString(semanticModel, methodSyntax.SpanStart);

			// NOTE(review): "Task" is emitted unqualified; assumes the generated file imports System.Threading.Tasks.
			newAsyncMethod = newAsyncMethod
				.WithIdentifier(SyntaxFactory.Identifier(asyncMethodName))
				.WithAttributeLists(new SyntaxList<AttributeListSyntax>())
				.WithReturnType(SyntaxFactory.ParseTypeName(returnTypeName == "void" ? "Task" : $"Task<{returnTypeName}>"));

			if (cancellationVersion)
			{
				var parameters = methodSyntax.ParameterList.Parameters;

				// The token goes before the first optional or params parameter so the signature stays legal.
				var insertIndex = parameters
					.TakeWhile(p => p.Default == null && !p.Modifiers.Any(m => m.IsKind(SyntaxKind.ParamsKeyword)))
					.Count();

				var cancellationTokenParameter = SyntaxFactory.Parameter
				(
					SyntaxFactory.List<AttributeListSyntax>(),
					SyntaxFactory.TokenList(),
					SyntaxFactory.ParseTypeName(cancellationTokenTypeName),
					SyntaxFactory.Identifier("cancellationToken"),
					null
				);

				newAsyncMethod = newAsyncMethod.WithParameterList(SyntaxFactory.ParameterList(parameters.Insert(insertIndex, cancellationTokenParameter)));

				if (!(isInterfaceMethod || methodSymbol.IsAbstract))
				{
					newAsyncMethod = newAsyncMethod.WithModifiers(methodSyntax.Modifiers.Add(SyntaxFactory.Token(SyntaxKind.AsyncKeyword)));
				}
			}
			else
			{
				var methodName = asyncMethodName;

				if (methodSymbol.TypeParameters.Length > 0)
				{
					methodName += "<" + string.Join(", ", methodSymbol.TypeParameters.Select(c => c.ToString())) + ">";
				}

				// Argument order mirrors the cancellation overload's parameter order:
				// required parameters, then CancellationToken.None, then optional/params parameters.
				var cancellationTokenNone = SyntaxFactory.MemberAccessExpression
				(
					SyntaxKind.SimpleMemberAccessExpression,
					SyntaxFactory.ParseName(cancellationTokenTypeName),
					SyntaxFactory.IdentifierName("None")
				);

				var arguments = new SeparatedSyntaxList<ArgumentSyntax>()
					.AddRange(methodSymbol.Parameters.TakeWhile(c => !(c.HasExplicitDefaultValue || c.IsParams)).Select(ForwardParameterAsArgument))
					.Add(SyntaxFactory.Argument(cancellationTokenNone))
					.AddRange(methodSymbol.Parameters.SkipWhile(c => !(c.HasExplicitDefaultValue || c.IsParams)).Select(ForwardParameterAsArgument));

				var callAsyncWithCancellationToken = SyntaxFactory.InvocationExpression(SyntaxFactory.IdentifierName(methodName), SyntaxFactory.ArgumentList(arguments));

				if (!(isInterfaceMethod || methodSymbol.IsAbstract))
				{
					// Also drop any "=> expr" body (and its semicolon) so the declaration does not end up
					// with both a block body and an expression body.
					newAsyncMethod = newAsyncMethod
						.WithBody(SyntaxFactory.Block(SyntaxFactory.ReturnStatement(callAsyncWithCancellationToken)))
						.WithExpressionBody(null)
						.WithSemicolonToken(default(SyntaxToken));
				}
			}

			if (!(isInterfaceMethod || methodSymbol.IsAbstract))
			{
				var baseAsyncMethod = this
					.GetMethods(semanticModel, methodSymbol.ReceiverType.BaseType, asyncMethodName, newAsyncMethod, methodSyntax)
					.FirstOrDefault();

				var baseMethod = this
					.GetMethods(semanticModel, methodSymbol.ReceiverType.BaseType, methodSymbol.Name, methodSyntax, methodSyntax)
					.FirstOrDefault();

				var parentContainsAsyncMethod = baseAsyncMethod != null;
				var parentContainsMethodWithRewriteAsync =
					baseMethod?.GetAttributes().Any(c => c.AttributeClass.Name.StartsWith("RewriteAsync", StringComparison.Ordinal)) == true
					|| baseMethod?.ContainingType.GetAttributes().Any(c => c.AttributeClass.Name.StartsWith("RewriteAsync", StringComparison.Ordinal)) == true;
				var hadNew = newAsyncMethod.Modifiers.Any(c => c.Kind() == SyntaxKind.NewKeyword);
				var hadOverride = newAsyncMethod.Modifiers.Any(c => c.Kind() == SyntaxKind.OverrideKeyword);

				if (!parentContainsAsyncMethod && hadNew)
				{
					// No matching async method was found on the base type, so there is nothing for "new" to hide.
					newAsyncMethod = newAsyncMethod.WithModifiers(new SyntaxTokenList().AddRange(newAsyncMethod.Modifiers.Where(c => c.Kind() != SyntaxKind.NewKeyword)));
				}

				if (parentContainsAsyncMethod && !(baseAsyncMethod.IsVirtual || baseAsyncMethod.IsAbstract || baseAsyncMethod.IsOverride))
				{
					// The base async method cannot be overridden; do not generate this one.
					return null;
				}

				if (!(parentContainsAsyncMethod || parentContainsMethodWithRewriteAsync))
				{
					// No base async method is expected to exist, so "override" is removed. "sealed" is only legal
					// together with "override" (sealed virtual does not compile), so it is removed as well.
					newAsyncMethod = newAsyncMethod.WithModifiers(new SyntaxTokenList().AddRange(newAsyncMethod.Modifiers.Where(c => c.Kind() != SyntaxKind.OverrideKeyword && c.Kind() != SyntaxKind.SealedKeyword)));

					if (hadOverride)
					{
						newAsyncMethod = newAsyncMethod.WithModifiers(newAsyncMethod.Modifiers.Add(SyntaxFactory.Token(SyntaxKind.VirtualKeyword)));
					}
				}

				var baseMatchedMethod = baseMethod ?? baseAsyncMethod;

				if (methodSyntax.ConstraintClauses.Any())
				{
					newAsyncMethod = newAsyncMethod.WithConstraintClauses(methodSyntax.ConstraintClauses);
				}
				else if (!hadOverride && baseMatchedMethod != null)
				{
					newAsyncMethod = newAsyncMethod.WithConstraintClauses(SyntaxFactory.List(BuildConstraintClauses(baseMatchedMethod, semanticModel, methodSyntax.SpanStart)));
				}
			}

			var attribute = methodSymbol.GetAttributes().SingleOrDefault(a => a.AttributeClass.Name.EndsWith("RewriteAsyncAttribute", StringComparison.Ordinal))
							?? methodSymbol.ContainingType.GetAttributes().SingleOrDefault(a => a.AttributeClass.Name.EndsWith("RewriteAsyncAttribute", StringComparison.Ordinal));

			// A first constructor argument of the MethodAttributes type overrides the generated method's access modifiers.
			if (attribute?.ConstructorArguments.Length > 0)
			{
				var first = attribute.ConstructorArguments.First();

				if (first.Type.Equals(this.methodAttributesSymbol))
				{
					var methodAttributes = (MethodAttributes)Enum.ToObject(typeof(MethodAttributes), Convert.ToInt32(first.Value));

					newAsyncMethod = newAsyncMethod.WithAccessModifiers(methodAttributes);
				}
			}

			return newAsyncMethod;
		}

		/// <summary>
		/// Creates an argument that passes <paramref name="parameter"/> through by name. The name is escaped with
		/// '@' when it is a reserved C# keyword, for example a parameter declared as <c>@class</c>.
		/// </summary>
		private static ArgumentSyntax ForwardParameterAsArgument(IParameterSymbol parameter)
		{
			var name = parameter.Name;

			var identifier = SyntaxFacts.GetKeywordKind(name) == SyntaxKind.None
				? SyntaxFactory.Identifier(name)
				: SyntaxFactory.Identifier(SyntaxFactory.TriviaList(), SyntaxKind.IdentifierToken, "@" + name, name, SyntaxFactory.TriviaList());

			return SyntaxFactory.Argument(SyntaxFactory.IdentifierName(identifier));
		}

		/// <summary>
		/// Builds <c>where</c> clauses that restate the constraints declared on <paramref name="method"/>'s type parameters.
		/// </summary>
		/// <param name="method">Method whose type-parameter constraints are copied.</param>
		/// <param name="semanticModel">Semantic model used to minimize constraint type names.</param>
		/// <param name="position">Position in the semantic model's tree at which type names are minimized.</param>
		/// <returns>One clause per constrained type parameter. Unconstrained type parameters get no clause.</returns>
		private static List<TypeParameterConstraintClauseSyntax> BuildConstraintClauses(IMethodSymbol method, SemanticModel semanticModel, int position)
		{
			var constraintClauses = new List<TypeParameterConstraintClauseSyntax>();

			foreach (var typeParameter in method.TypeParameters)
			{
				var constraints = new List<TypeParameterConstraintSyntax>();

				// C# requires the primary constraint (class/struct) first...
				if (typeParameter.HasReferenceTypeConstraint)
				{
					constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.ClassConstraint));
				}

				if (typeParameter.HasValueTypeConstraint)
				{
					constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.StructConstraint));
				}

				constraints.AddRange(typeParameter.ConstraintTypes.Select(c => SyntaxFactory.TypeConstraint(SyntaxFactory.ParseName(c.ToMinimalDisplayString(semanticModel, position)))));

				// ...and new() last (error CS0401 otherwise).
				if (typeParameter.HasConstructorConstraint)
				{
					constraints.Add(SyntaxFactory.ConstructorConstraint());
				}

				// "where T :" with no constraints is not valid C#.
				if (constraints.Count == 0)
				{
					continue;
				}

				constraintClauses.Add(SyntaxFactory.TypeParameterConstraintClause(typeParameter.Name).WithConstraints(SyntaxFactory.SeparatedList(constraints)));
			}

			return constraintClauses;
		}
	}