/// <summary>
/// Produces the <c>Async</c> counterpart of <paramref name="methodSyntax"/>.
/// </summary>
/// <remarks>
/// The generated method is named <c>{Name}Async</c>, has its attribute lists removed and
/// returns <c>Task</c> (for <c>void</c>) or <c>Task&lt;T&gt;</c>.
/// When <paramref name="cancellationVersion"/> is true a <c>cancellationToken</c> parameter is
/// inserted before the first optional/<c>params</c> parameter and the rewritten body is kept;
/// otherwise the body is replaced by a call that forwards to the cancellation overload with
/// <c>CancellationToken.None</c>.
/// </remarks>
/// <param name="methodSyntax">Declaration of the synchronous method to rewrite.</param>
/// <param name="semanticModel">Semantic model of the tree that contains <paramref name="methodSyntax"/>.</param>
/// <param name="cancellationVersion">True to emit the overload that accepts a cancellation token.</param>
/// <returns>
/// The generated declaration, or <c>null</c> when the method found on the base type under the
/// <c>Async</c> name is neither virtual, abstract nor an override (nothing can be emitted for it).
/// </returns>
private MethodDeclarationSyntax RewriteMethodAsync(MethodDeclarationSyntax methodSyntax, SemanticModel semanticModel, bool cancellationVersion)
{
    var methodSymbol = (IMethodSymbol)ModelExtensions.GetDeclaredSymbol(semanticModel, methodSyntax);
    var asyncMethodName = methodSymbol.Name + "Async";
    var isInterfaceMethod = methodSymbol.ContainingType.TypeKind == TypeKind.Interface;

    // Interface and abstract members get a signature only: no body, no async modifier,
    // and no base-type modifier/constraint reconciliation.
    var hasImplementation = !(isInterfaceMethod || methodSymbol.IsAbstract);

    // Warn once per type name when the containing type is not partial.
    // A parent that is not a type declaration is skipped instead of being hard-cast (which would throw).
    var containingTypeSyntax = methodSyntax.Parent as TypeDeclarationSyntax;

    if (containingTypeSyntax != null && !containingTypeSyntax.Modifiers.Any(c => c.Kind() == SyntaxKind.PartialKeyword))
    {
        var name = containingTypeSyntax.Identifier.ToString();

        if (!typesAlreadyWarnedAbout.Contains(name))
        {
            typesAlreadyWarnedAbout.Add(name);
            log.LogError($"Type '{name}' needs to be marked as partial");
        }
    }

    var newAsyncMethod = MethodInvocationAsyncRewriter.Rewrite(this.log, this.lookup, semanticModel, this.excludedTypes, this.cancellationTokenSymbol, methodSyntax);
    var returnTypeName = methodSyntax.ReturnType.ToString();

    newAsyncMethod = newAsyncMethod
        .WithIdentifier(SyntaxFactory.Identifier(asyncMethodName))
        .WithAttributeLists(new SyntaxList<AttributeListSyntax>())
        .WithReturnType(SyntaxFactory.ParseTypeName(returnTypeName == "void" ? "Task" : $"Task<{returnTypeName}>"));

    if (cancellationVersion)
    {
        var originalParameters = methodSyntax.ParameterList.Parameters;

        // The token goes after the last leading parameter that has no default value and is not 'params'.
        var insertIndex = originalParameters
            .TakeWhile(p => p.Default == null && !p.Modifiers.Any(m => m.IsKind(SyntaxKind.ParamsKeyword)))
            .Count();

        // The position must belong to the tree behind semanticModel. The rewritten node is
        // detached from that tree, so the original declaration's position is used instead.
        var cancellationTokenTypeName = this.cancellationTokenSymbol.ToMinimalDisplayString(semanticModel, methodSyntax.SpanStart);

        var cancellationTokenParameter = SyntaxFactory.Parameter
        (
            SyntaxFactory.List<AttributeListSyntax>(),
            SyntaxFactory.TokenList(),
            SyntaxFactory.ParseTypeName(cancellationTokenTypeName),
            SyntaxFactory.Identifier("cancellationToken"),
            null
        );

        newAsyncMethod = newAsyncMethod.WithParameterList(SyntaxFactory.ParameterList(originalParameters.Insert(insertIndex, cancellationTokenParameter)));

        if (hasImplementation)
        {
            newAsyncMethod = newAsyncMethod.WithModifiers(methodSyntax.Modifiers.Add(SyntaxFactory.Token(SyntaxKind.AsyncKeyword)));
        }
    }
    else
    {
        // Build "NameAsync<T1, T2>" so the forwarding call names the generic overload explicitly.
        var methodName = asyncMethodName;

        if (methodSymbol.TypeParameters.Length > 0)
        {
            var typeParams = string.Join(", ", methodSymbol.TypeParameters.Select(c => c.ToString()));

            methodName += "<" + typeParams + ">";
        }

        // Arguments mirror the parameter order of the cancellation overload:
        // leading required parameters, CancellationToken.None, then the optional/params tail.
        var callAsyncWithCancellationToken = SyntaxFactory.InvocationExpression
        (
            SyntaxFactory.IdentifierName(methodName),
            SyntaxFactory.ArgumentList
            (
                new SeparatedSyntaxList<ArgumentSyntax>()
                    .AddRange(methodSymbol.Parameters.TakeWhile(c => !(c.HasExplicitDefaultValue || c.IsParams)).Select(c => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(c.Name))))
                    .Add(SyntaxFactory.Argument(SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, SyntaxFactory.ParseName("CancellationToken"), SyntaxFactory.IdentifierName("None"))))
                    .AddRange(methodSymbol.Parameters.SkipWhile(c => !(c.HasExplicitDefaultValue || c.IsParams)).Select(c => SyntaxFactory.Argument(SyntaxFactory.IdentifierName(c.Name))))
            )
        );

        if (hasImplementation)
        {
            newAsyncMethod = newAsyncMethod.WithBody(SyntaxFactory.Block(SyntaxFactory.ReturnStatement(callAsyncWithCancellationToken)));
        }
    }

    if (hasImplementation)
    {
        var baseAsyncMethod = this
            .GetMethods(semanticModel, methodSymbol.ReceiverType.BaseType, methodSymbol.Name + "Async", newAsyncMethod, methodSyntax)
            .FirstOrDefault();

        var baseMethod = this
            .GetMethods(semanticModel, methodSymbol.ReceiverType.BaseType, methodSymbol.Name, methodSyntax, methodSyntax)
            .FirstOrDefault();

        var parentContainsAsyncMethod = baseAsyncMethod != null;

        // AttributeClass can be null for attributes that failed to bind, hence the null-conditional.
        var parentContainsMethodWithRewriteAsync =
            baseMethod?.GetAttributes().Any(c => c.AttributeClass?.Name.StartsWith("RewriteAsync", StringComparison.Ordinal) == true) == true
            || baseMethod?.ContainingType.GetAttributes().Any(c => c.AttributeClass?.Name.StartsWith("RewriteAsync", StringComparison.Ordinal) == true) == true;

        var hadNew = newAsyncMethod.Modifiers.Any(c => c.Kind() == SyntaxKind.NewKeyword);
        var hadOverride = newAsyncMethod.Modifiers.Any(c => c.Kind() == SyntaxKind.OverrideKeyword);

        // 'new' is dropped when no Async method was found on the base type.
        if (!parentContainsAsyncMethod && hadNew)
        {
            newAsyncMethod = newAsyncMethod.WithModifiers(new SyntaxTokenList().AddRange(newAsyncMethod.Modifiers.Where(c => c.Kind() != SyntaxKind.NewKeyword)));
        }

        // A base Async method that cannot be overridden: emit nothing.
        if (parentContainsAsyncMethod && !(baseAsyncMethod.IsVirtual || baseAsyncMethod.IsAbstract || baseAsyncMethod.IsOverride))
        {
            return null;
        }

        // No base Async method exists and none is marked for rewriting: 'override' is removed,
        // and replaced by 'virtual' if the original was an override.
        if (!(parentContainsAsyncMethod || parentContainsMethodWithRewriteAsync))
        {
            newAsyncMethod = newAsyncMethod.WithModifiers(new SyntaxTokenList().AddRange(newAsyncMethod.Modifiers.Where(c => c.Kind() != SyntaxKind.OverrideKeyword)));

            if (hadOverride)
            {
                newAsyncMethod = newAsyncMethod.WithModifiers(newAsyncMethod.Modifiers.Add(SyntaxFactory.Token(SyntaxKind.VirtualKeyword)));
            }
        }

        var baseMatchedMethod = baseMethod ?? baseAsyncMethod;

        if (methodSyntax.ConstraintClauses.Any())
        {
            newAsyncMethod = newAsyncMethod.WithConstraintClauses(methodSyntax.ConstraintClauses);
        }
        // NOTE(review): constraints are copied from the base method only when the original was NOT
        // an override; an override that was turned into 'virtual' above gets none - confirm this is intended.
        else if (!hadOverride && baseMatchedMethod != null)
        {
            var constraintClauses = new List<TypeParameterConstraintClauseSyntax>();

            foreach (var typeParameter in baseMatchedMethod.TypeParameters)
            {
                var constraints = new List<TypeParameterConstraintSyntax>();

                if (typeParameter.HasReferenceTypeConstraint)
                {
                    constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.ClassConstraint));
                }

                if (typeParameter.HasValueTypeConstraint)
                {
                    constraints.Add(SyntaxFactory.ClassOrStructConstraint(SyntaxKind.StructConstraint));
                }

                if (typeParameter.HasConstructorConstraint)
                {
                    constraints.Add(SyntaxFactory.ConstructorConstraint());
                }

                constraints.AddRange(typeParameter.ConstraintTypes.Select(c => SyntaxFactory.TypeConstraint(SyntaxFactory.ParseName(c.ToMinimalDisplayString(semanticModel, methodSyntax.SpanStart)))));

                constraintClauses.Add(SyntaxFactory.TypeParameterConstraintClause(typeParameter.Name).WithConstraints(SyntaxFactory.SeparatedList(constraints)));
            }

            newAsyncMethod = newAsyncMethod.WithConstraintClauses(SyntaxFactory.List(constraintClauses));
        }
    }

    // A RewriteAsync attribute on the method takes precedence over one on the containing type.
    var attribute = methodSymbol.GetAttributes().SingleOrDefault(a => a.AttributeClass?.Name.EndsWith("RewriteAsyncAttribute", StringComparison.Ordinal) == true)
        ?? methodSymbol.ContainingType.GetAttributes().SingleOrDefault(a => a.AttributeClass?.Name.EndsWith("RewriteAsyncAttribute", StringComparison.Ordinal) == true);

    if (attribute?.ConstructorArguments.Length > 0)
    {
        var first = attribute.ConstructorArguments[0];

        // Only a first argument typed as the MethodAttributes symbol overrides the access modifiers.
        // Type can be null for an erroneous constant, so guard before comparing.
        if (first.Type != null && first.Type.Equals(this.methodAttributesSymbol))
        {
            var methodAttributes = (MethodAttributes)Enum.ToObject(typeof(MethodAttributes), Convert.ToInt32(first.Value));

            newAsyncMethod = newAsyncMethod.WithAccessModifiers(methodAttributes);
        }
    }

    return newAsyncMethod;
}
}