如何在 .NET Core 3.0 中拦截以 IAsyncEnumerable<T> 作为返回类型的异步方法

时间:2019-10-18 14:14:19

标签: c# .net-core c#-8.0

我正在尝试实现一个代理，由它负责记录方法调用。对于异步方法，某些日志应在方法执行完成后才写入，但代理方法本身不应成为阻塞调用。

经过多次尝试,这是我想出的解决方案。

namespace ClassLibrary1
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Reflection;
    using System.Threading.Tasks;

    /// <summary>
    /// Reflection helpers that inspect the runtime type of an object.
    /// </summary>
    public static class ObjectExtender
    {
        /// <summary>
        /// Determines whether the runtime type of <paramref name="obj"/>, or one of its base classes,
        /// is constructed from the same generic type definition as <paramref name="check"/>.
        /// </summary>
        /// <param name="obj">The object whose runtime type is inspected.</param>
        /// <param name="check">
        /// A generic type, either open (e.g. <c>typeof(Task&lt;&gt;)</c>) or closed (e.g. <c>typeof(Task&lt;int&gt;)</c>);
        /// only its generic type definition is used for the comparison.
        /// </param>
        /// <param name="genericType">
        /// When this method returns <see langword="true"/>, the first matching constructed type found while
        /// walking from the runtime type towards <see cref="object"/>; otherwise <see langword="null"/>.
        /// </param>
        /// <returns><see langword="true"/> if a matching type was found in the class hierarchy.</returns>
        /// <remarks>Only base classes are searched; implemented interfaces are not considered.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="obj"/> or <paramref name="check"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="check"/> is not a generic type.</exception>
        internal static bool IsOfGenericType(this object obj, Type check, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out Type? genericType)
        {
            if (obj is null)
            {
                throw new ArgumentNullException(nameof(obj));
            }

            if (check is null)
            {
                throw new ArgumentNullException(nameof(check));
            }

            // Fail fast instead of depending on whether the hierarchy happens to contain a generic type.
            if (!check.IsGenericType)
            {
                throw new ArgumentException("The type to compare against must be a generic type.", nameof(check));
            }

            // Loop-invariant: compute the definition we compare against once.
            Type definition = check.GetGenericTypeDefinition();

            for (Type? current = obj.GetType(); current != null; current = current.BaseType)
            {
                if (current.IsGenericType && current.GetGenericTypeDefinition() == definition)
                {
                    genericType = current;
                    return true;
                }
            }

            genericType = null;
            return false;
        }
    }

    /// <summary>
    /// A <see cref="DispatchProxy"/> that writes a "Before" line to the console for every call made through
    /// <typeparamref name="T"/>, forwards the call to a decorated instance, and writes an "After" line once
    /// the call has finished.
    /// </summary>
    /// <remarks>
    /// The proxy never blocks. For <see cref="Task"/> results the "After" line is written when the task
    /// completes (and failures are logged and propagated unchanged). For <see cref="IEnumerable{T}"/> and
    /// <see cref="IAsyncEnumerable{T}"/> results an "After List" line is written once the returned sequence
    /// has been enumerated to the end.
    /// </remarks>
    /// <typeparam name="T">The interface type being proxied.</typeparam>
    public class Class1<T> : DispatchProxy
    {
        private const BindingFlags HelperFlags = BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.DeclaredOnly;

        // Open generic helpers, closed over the concrete result/element type on first use.
        private static readonly MethodInfo GenericTask = typeof(Class1<T>).GetMethod(nameof(HandleTaskGenericAsync), HelperFlags)!;
        private static readonly MethodInfo AsyncEnumeration = typeof(Class1<T>).GetMethod(nameof(Wrapper), HelperFlags)!;
        private static readonly MethodInfo SyncEnumeration = typeof(Class1<T>).GetMethod(nameof(SyncWrapper), HelperFlags)!;

        // Caches of closed helpers keyed by result/element type. ConcurrentDictionary because a proxy can be
        // called from several threads at once; a Dictionary filled with ContainsKey + Add would race.
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedGenericTaskMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedAsyncEnumerationMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedSyncEnumerationMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();

        private T _decorated = default!;

        /// <summary>
        /// Creates a logging proxy implementing <typeparamref name="T"/> that forwards every call to <paramref name="decorated"/>.
        /// </summary>
        /// <param name="decorated">The instance that actually handles the calls.</param>
        /// <returns>The proxy instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="decorated"/> is null (every call would fail).</exception>
        public static T Create(T decorated)
        {
            if (decorated is null)
            {
                throw new ArgumentNullException(nameof(decorated));
            }

            T proxy = Create<T, Class1<T>>();

            // DispatchProxy.Create<T, TProxy> always returns an instance deriving from TProxy.
            ((Class1<T>)(object)proxy!)._decorated = decorated;

            return proxy;
        }

        /// <summary>
        /// Logs the call, forwards it to the decorated instance and arranges for the completion log line.
        /// </summary>
        /// <param name="targetMethod">The interface method being called.</param>
        /// <param name="args">The call arguments.</param>
        /// <returns>The decorated result, possibly wrapped so completion can be observed.</returns>
        protected override object Invoke(MethodInfo targetMethod, object[] args)
        {
            Console.WriteLine($"Before: {targetMethod}");

            object result;
            try
            {
                result = targetMethod.Invoke(_decorated, args);
            }
            catch (TargetInvocationException ex)
            {
                Exception thrown = ex.InnerException ?? ex;
                Console.WriteLine($"{thrown}, {targetMethod}");

                // Rethrow the decorated object's own exception (keeping its stack trace) instead of the
                // reflection wrapper, so proxy callers can catch the same exception types as direct callers.
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(thrown).Throw();
                throw; // Unreachable: Throw() never returns, but the compiler cannot see that.
            }

            if (result is Task task)
            {
                Type? taskResultType = GetTaskResultType(task.GetType());
                return taskResultType is null
                    ? HandleTaskAsync(task, targetMethod)
                    : InvokeGenericHelper(CachedGenericTaskMethodInfos, GenericTask, taskResultType, task, targetMethod);
            }

            Type returnType = targetMethod.ReturnType;

            // A null sequence is passed through as-is; wrapping it would only defer the failure to enumeration.
            if (result != null && returnType.IsGenericType)
            {
                Type returnDefinition = returnType.GetGenericTypeDefinition();
                Type elementType = returnType.GetGenericArguments()[0];

                if (returnDefinition == typeof(IAsyncEnumerable<>))
                {
                    // Reflection does not apply parameter defaults, so the token must be passed explicitly.
                    // A token the consumer supplies via WithCancellation is combined with it by the
                    // compiler-generated GetAsyncEnumerator of the async iterator.
                    return InvokeGenericHelper(
                        CachedAsyncEnumerationMethodInfos,
                        AsyncEnumeration,
                        elementType,
                        result,
                        targetMethod,
                        System.Threading.CancellationToken.None);
                }

                if (returnDefinition == typeof(IEnumerable<>))
                {
                    return InvokeGenericHelper(CachedSyncEnumerationMethodInfos, SyncEnumeration, elementType, result, targetMethod);
                }
            }

            Console.WriteLine($"After: {targetMethod}");

            return result;
        }

        /// <summary>
        /// Returns <c>TResult</c> if <paramref name="taskType"/> is or derives from <see cref="Task{TResult}"/>;
        /// otherwise <see langword="null"/>.
        /// </summary>
        private static Type? GetTaskResultType(Type taskType)
        {
            for (Type? current = taskType; current != null; current = current.BaseType)
            {
                if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(Task<>))
                {
                    return current.GetGenericArguments()[0];
                }
            }

            return null;
        }

        /// <summary>
        /// Closes <paramref name="openHelper"/> over <paramref name="typeArgument"/> (cached in <paramref name="cache"/>)
        /// and invokes it as a static method with <paramref name="arguments"/>.
        /// </summary>
        private static object InvokeGenericHelper(
            System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> cache,
            MethodInfo openHelper,
            Type typeArgument,
            params object[] arguments)
        {
            MethodInfo closedHelper = cache.GetOrAdd(typeArgument, type => openHelper.MakeGenericMethod(type));
            return closedHelper.Invoke(null, arguments)!;
        }

        /// <summary>
        /// Awaits a non-generic task and logs its outcome. Faults and cancellation propagate to the returned task.
        /// </summary>
        private static async Task HandleTaskAsync(Task task, MethodInfo targetMethod)
        {
            try
            {
                await task.ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                Console.WriteLine($"{ex}, {targetMethod}");
                throw;
            }

            Console.WriteLine($"After: {targetMethod}");
        }

        /// <summary>
        /// Awaits a <see cref="Task{TResult}"/>, logs its outcome and returns its result. Faults surface as the
        /// original exception (not an <see cref="AggregateException"/>); cancellation stays cancellation.
        /// </summary>
        private static async Task<TResult> HandleTaskGenericAsync<TResult>(Task<TResult> task, MethodInfo targetMethod)
        {
            TResult value;
            try
            {
                value = await task.ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                Console.WriteLine($"{ex}, {targetMethod}");
                throw;
            }

            Console.WriteLine($"After: {targetMethod}");
            return value;
        }

        /// <summary>
        /// Re-yields <paramref name="inner"/> lazily and logs once it has been enumerated to the end.
        /// </summary>
        private static IEnumerable<TItem> SyncWrapper<TItem>(IEnumerable<TItem> inner, MethodInfo targetMethod)
        {
            foreach (TItem item in inner)
            {
                yield return item;
            }

            Console.WriteLine($"After List: {targetMethod}");
        }

        /// <summary>
        /// Re-yields <paramref name="inner"/> asynchronously, forwarding the consumer's cancellation token,
        /// and logs once it has been enumerated to the end.
        /// </summary>
        private static async IAsyncEnumerable<TItem> Wrapper<TItem>(
            IAsyncEnumerable<TItem> inner,
            MethodInfo targetMethod,
            [System.Runtime.CompilerServices.EnumeratorCancellation] System.Threading.CancellationToken cancellationToken = default)
        {
            await foreach (TItem item in inner.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return item;
            }

            Console.WriteLine($"After List: {targetMethod}");
        }
    }
}

此代理完全按照我想要的方式拦截方法调用。以下是我的测试输出：

---Test sync calls---
Before: Void Run()
Inside: Run()
After: Void Run()
Before: System.Collections.Generic.IEnumerable`1[System.Int32] RunEnumerator[Int32](Int32[])
Inside Start: RunEnumerator()
Erg: 1
Erg: 2
Erg: 3
Erg: 4
Inside Ende: RunEnumerator()
After List: System.Collections.Generic.IEnumerable`1[System.Int32] RunEnumerator[Int32](Int32[])
---Test async calls---
Before: System.Threading.Tasks.Task RunAsync()
Inside: RunAsync()
After: System.Threading.Tasks.Task RunAsync()
Before: System.Threading.Tasks.Task RunAwaitAsync()
Inside: RunAwaitAsync()
After: System.Threading.Tasks.Task RunAwaitAsync()
Before: System.Threading.Tasks.Task`1[System.String] RunAwaitGenericTask[String](System.String)
Inside: RunAwaitGenericTask()
After: System.Threading.Tasks.Task`1[System.String] RunAwaitGenericTask[String](System.String)
Before: System.Collections.Generic.IAsyncEnumerable`1[System.Int32] RunAwaitGenericEnumeratorTask[Int32](Int32[])
Inside Start: RunAwaitGenericEnumeratorTask()
Erg: 1
Erg: 2
Erg: 3
Erg: 4
Inside Ende: RunAwaitGenericEnumeratorTask()
After List: System.Collections.Generic.IAsyncEnumerable`1[System.Int32] RunAwaitGenericEnumeratorTask[Int32](Int32[])

1 个答案:

答案 0 :(得分:0)

最后我找到了解决方案。感谢Jeroen Mostert为我指出正确的方向。

namespace ClassLibrary1
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Reflection;
    using System.Threading.Tasks;

    /// <summary>
    /// Reflection helpers that inspect the runtime type of an object.
    /// </summary>
    public static class ObjectExtender
    {
        /// <summary>
        /// Determines whether the runtime type of <paramref name="obj"/>, or one of its base classes,
        /// is constructed from the same generic type definition as <paramref name="check"/>.
        /// </summary>
        /// <param name="obj">The object whose runtime type is inspected.</param>
        /// <param name="check">
        /// A generic type, either open (e.g. <c>typeof(Task&lt;&gt;)</c>) or closed (e.g. <c>typeof(Task&lt;int&gt;)</c>);
        /// only its generic type definition is used for the comparison.
        /// </param>
        /// <param name="genericType">
        /// When this method returns <see langword="true"/>, the first matching constructed type found while
        /// walking from the runtime type towards <see cref="object"/>; otherwise <see langword="null"/>.
        /// </param>
        /// <returns><see langword="true"/> if a matching type was found in the class hierarchy.</returns>
        /// <remarks>Only base classes are searched; implemented interfaces are not considered.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="obj"/> or <paramref name="check"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="check"/> is not a generic type.</exception>
        internal static bool IsOfGenericType(this object obj, Type check, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out Type? genericType)
        {
            if (obj is null)
            {
                throw new ArgumentNullException(nameof(obj));
            }

            if (check is null)
            {
                throw new ArgumentNullException(nameof(check));
            }

            // Fail fast instead of depending on whether the hierarchy happens to contain a generic type.
            if (!check.IsGenericType)
            {
                throw new ArgumentException("The type to compare against must be a generic type.", nameof(check));
            }

            // Loop-invariant: compute the definition we compare against once.
            Type definition = check.GetGenericTypeDefinition();

            for (Type? current = obj.GetType(); current != null; current = current.BaseType)
            {
                if (current.IsGenericType && current.GetGenericTypeDefinition() == definition)
                {
                    genericType = current;
                    return true;
                }
            }

            genericType = null;
            return false;
        }
    }

    /// <summary>
    /// A <see cref="DispatchProxy"/> that writes a "Before" line to the console for every call made through
    /// <typeparamref name="T"/>, forwards the call to a decorated instance, and writes an "After" line once
    /// the call has finished.
    /// </summary>
    /// <remarks>
    /// The proxy never blocks. For <see cref="Task"/> results the "After" line is written when the task
    /// completes (and failures are logged and propagated unchanged). For <see cref="IEnumerable{T}"/> and
    /// <see cref="IAsyncEnumerable{T}"/> results an "After List" line is written once the returned sequence
    /// has been enumerated to the end.
    /// </remarks>
    /// <typeparam name="T">The interface type being proxied.</typeparam>
    public class Class1<T> : DispatchProxy
    {
        private const BindingFlags HelperFlags = BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.DeclaredOnly;

        // Open generic helpers, closed over the concrete result/element type on first use.
        private static readonly MethodInfo GenericTask = typeof(Class1<T>).GetMethod(nameof(HandleTaskGenericAsync), HelperFlags)!;
        private static readonly MethodInfo AsyncEnumeration = typeof(Class1<T>).GetMethod(nameof(Wrapper), HelperFlags)!;
        private static readonly MethodInfo SyncEnumeration = typeof(Class1<T>).GetMethod(nameof(SyncWrapper), HelperFlags)!;

        // Caches of closed helpers keyed by result/element type. ConcurrentDictionary because a proxy can be
        // called from several threads at once; a Dictionary filled with ContainsKey + Add would race.
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedGenericTaskMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedAsyncEnumerationMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();
        private static readonly System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> CachedSyncEnumerationMethodInfos =
            new System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo>();

        private T _decorated = default!;

        /// <summary>
        /// Creates a logging proxy implementing <typeparamref name="T"/> that forwards every call to <paramref name="decorated"/>.
        /// </summary>
        /// <param name="decorated">The instance that actually handles the calls.</param>
        /// <returns>The proxy instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="decorated"/> is null (every call would fail).</exception>
        public static T Create(T decorated)
        {
            if (decorated is null)
            {
                throw new ArgumentNullException(nameof(decorated));
            }

            T proxy = Create<T, Class1<T>>();

            // DispatchProxy.Create<T, TProxy> always returns an instance deriving from TProxy.
            ((Class1<T>)(object)proxy!)._decorated = decorated;

            return proxy;
        }

        /// <summary>
        /// Logs the call, forwards it to the decorated instance and arranges for the completion log line.
        /// </summary>
        /// <param name="targetMethod">The interface method being called.</param>
        /// <param name="args">The call arguments.</param>
        /// <returns>The decorated result, possibly wrapped so completion can be observed.</returns>
        protected override object Invoke(MethodInfo targetMethod, object[] args)
        {
            Console.WriteLine($"Before: {targetMethod}");

            object result;
            try
            {
                result = targetMethod.Invoke(_decorated, args);
            }
            catch (TargetInvocationException ex)
            {
                Exception thrown = ex.InnerException ?? ex;
                Console.WriteLine($"{thrown}, {targetMethod}");

                // Rethrow the decorated object's own exception (keeping its stack trace) instead of the
                // reflection wrapper, so proxy callers can catch the same exception types as direct callers.
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(thrown).Throw();
                throw; // Unreachable: Throw() never returns, but the compiler cannot see that.
            }

            if (result is Task task)
            {
                Type? taskResultType = GetTaskResultType(task.GetType());
                return taskResultType is null
                    ? HandleTaskAsync(task, targetMethod)
                    : InvokeGenericHelper(CachedGenericTaskMethodInfos, GenericTask, taskResultType, task, targetMethod);
            }

            Type returnType = targetMethod.ReturnType;

            // A null sequence is passed through as-is; wrapping it would only defer the failure to enumeration.
            if (result != null && returnType.IsGenericType)
            {
                Type returnDefinition = returnType.GetGenericTypeDefinition();
                Type elementType = returnType.GetGenericArguments()[0];

                if (returnDefinition == typeof(IAsyncEnumerable<>))
                {
                    // Reflection does not apply parameter defaults, so the token must be passed explicitly.
                    // A token the consumer supplies via WithCancellation is combined with it by the
                    // compiler-generated GetAsyncEnumerator of the async iterator.
                    return InvokeGenericHelper(
                        CachedAsyncEnumerationMethodInfos,
                        AsyncEnumeration,
                        elementType,
                        result,
                        targetMethod,
                        System.Threading.CancellationToken.None);
                }

                if (returnDefinition == typeof(IEnumerable<>))
                {
                    return InvokeGenericHelper(CachedSyncEnumerationMethodInfos, SyncEnumeration, elementType, result, targetMethod);
                }
            }

            Console.WriteLine($"After: {targetMethod}");

            return result;
        }

        /// <summary>
        /// Returns <c>TResult</c> if <paramref name="taskType"/> is or derives from <see cref="Task{TResult}"/>;
        /// otherwise <see langword="null"/>.
        /// </summary>
        private static Type? GetTaskResultType(Type taskType)
        {
            for (Type? current = taskType; current != null; current = current.BaseType)
            {
                if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(Task<>))
                {
                    return current.GetGenericArguments()[0];
                }
            }

            return null;
        }

        /// <summary>
        /// Closes <paramref name="openHelper"/> over <paramref name="typeArgument"/> (cached in <paramref name="cache"/>)
        /// and invokes it as a static method with <paramref name="arguments"/>.
        /// </summary>
        private static object InvokeGenericHelper(
            System.Collections.Concurrent.ConcurrentDictionary<Type, MethodInfo> cache,
            MethodInfo openHelper,
            Type typeArgument,
            params object[] arguments)
        {
            MethodInfo closedHelper = cache.GetOrAdd(typeArgument, type => openHelper.MakeGenericMethod(type));
            return closedHelper.Invoke(null, arguments)!;
        }

        /// <summary>
        /// Awaits a non-generic task and logs its outcome. Faults and cancellation propagate to the returned task.
        /// </summary>
        private static async Task HandleTaskAsync(Task task, MethodInfo targetMethod)
        {
            try
            {
                await task.ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                Console.WriteLine($"{ex}, {targetMethod}");
                throw;
            }

            Console.WriteLine($"After: {targetMethod}");
        }

        /// <summary>
        /// Awaits a <see cref="Task{TResult}"/>, logs its outcome and returns its result. Faults surface as the
        /// original exception (not an <see cref="AggregateException"/>); cancellation stays cancellation.
        /// </summary>
        private static async Task<TResult> HandleTaskGenericAsync<TResult>(Task<TResult> task, MethodInfo targetMethod)
        {
            TResult value;
            try
            {
                value = await task.ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                Console.WriteLine($"{ex}, {targetMethod}");
                throw;
            }

            Console.WriteLine($"After: {targetMethod}");
            return value;
        }

        /// <summary>
        /// Re-yields <paramref name="inner"/> lazily and logs once it has been enumerated to the end.
        /// </summary>
        private static IEnumerable<TItem> SyncWrapper<TItem>(IEnumerable<TItem> inner, MethodInfo targetMethod)
        {
            foreach (TItem item in inner)
            {
                yield return item;
            }

            Console.WriteLine($"After List: {targetMethod}");
        }

        /// <summary>
        /// Re-yields <paramref name="inner"/> asynchronously, forwarding the consumer's cancellation token,
        /// and logs once it has been enumerated to the end.
        /// </summary>
        private static async IAsyncEnumerable<TItem> Wrapper<TItem>(
            IAsyncEnumerable<TItem> inner,
            MethodInfo targetMethod,
            [System.Runtime.CompilerServices.EnumeratorCancellation] System.Threading.CancellationToken cancellationToken = default)
        {
            await foreach (TItem item in inner.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return item;
            }

            Console.WriteLine($"After List: {targetMethod}");
        }
    }
}