
Support call forwarding to a provided instance for types it implements that are also included in the substitute being created #841

Open
rbeurskens opened this issue Nov 5, 2024 · 0 comments


rbeurskens commented Nov 5, 2024

Is your feature request related to a problem? Please describe.
I often run into a scenario where I need to substitute a custom interface that is partially implemented by an existing type. In my case these are usually custom collection-like interfaces, part of a public API, that inherit only from IEnumerable<T>, IReadOnlyList<T>, etc.

Describe the solution you'd like
In my tests I would like to use a real collection (in this case; it could be any sort of type) as a substitute, without having to manually configure each member to return or call the corresponding member on, for example, List<T>, which does not implement the custom collection interface but does have interfaces in common with it.
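To make the overlap concrete, here is a minimal C# sketch. `IItemList<T>` and `InterfaceOverlap` are hypothetical names invented for this illustration, not NSubstitute APIs. The helper lists the interfaces that the substituted type shares with the forwarding target, and a guard rejects a target that shares none, which is the kind of check the workaround further down does not perform.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical custom collection interface from a public API:
// it only inherits from IReadOnlyList<T> and adds no members of its own.
public interface IItemList<T> : IReadOnlyList<T> { }

public static class InterfaceOverlap
{
    // Returns every interface the substituted type shares with the target type.
    // If the substituted type is itself an interface, it is included as a candidate.
    public static Type[] SharedInterfaces(Type substituteType, Type targetType)
    {
        IEnumerable<Type> candidates = substituteType.GetInterfaces();
        if (substituteType.IsInterface)
            candidates = candidates.Concat(new[] { substituteType });

        var targetInterfaces = new HashSet<Type>(targetType.GetInterfaces());
        return candidates.Where(t => targetInterfaces.Contains(t)).ToArray();
    }

    // Guard: forwarding is pointless if the instance cannot serve any call.
    public static void EnsureOverlap(Type substituteType, object instance)
    {
        if (SharedInterfaces(substituteType, instance.GetType()).Length == 0)
            throw new ArgumentException(
                $"{instance.GetType()} shares no interface with {substituteType}.",
                nameof(instance));
    }
}
```

With these definitions, `SharedInterfaces(typeof(IItemList<string>), typeof(List<string>))` contains `IReadOnlyList<string>` and `IEnumerable<string>` but not `IItemList<string>` itself, which matches the situation described here: `List<T>` can serve most calls, it just does not implement the custom interface.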

Describe alternatives you've considered

  • I tried the new Substitute.ForTypeForwardingTo<TInterface, TClass>() feature, but it throws an exception because List<T> does not directly implement the custom collection interface.
  • So far I use a workaround in the form of a custom ICallHandler, but I suspect all the reflection inside it executes on every call, which is not optimal. Also, there is no check that the target and the substitute have at least one interface in common, which makes it easy to introduce errors if the wrong interface is used to create the substitute.
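    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Reflection;
    using NSubstitute.Core;

    public static class SubstituteExtensions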
    {
        public static T ForwardCallsTo<T>(this T subst, object instance)
        {
            SubstitutionContext.Current.GetCallRouterFor(subst)
                .RegisterCustomCallHandlerFactory(state => new RedirectToInstanceHandler(instance));
            return subst;
        }


        private class RedirectToInstanceHandler : ICallHandler
        {
            private readonly object _instance;
            public RedirectToInstanceHandler(object instance)
            {
                _instance = instance;
            }
            public RouteAction Handle(ICall call)
            {
                // Look if the called method is implemented on the instance.
                var methodInfo = call.GetMethodInfo();
                MethodInfo methodOnInstance = null;
                
                if (methodInfo.ReflectedType.IsInterface)
                {
                    if (call.Target() != _instance // prevent a stack overflow caused by forwarding to itself
                        && methodInfo.ReflectedType.IsInstanceOfType(_instance))
                        methodOnInstance = methodInfo; // otherwise, forward to interface implementation on other object
                } else // find implementing method called from the class
                    methodOnInstance = GetInterfaceDeclarationsForMethod(methodInfo).FirstOrDefault(i => i.DeclaringType.IsInstanceOfType(_instance));

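                if (methodOnInstance != null) // If so, forward the call to the corresponding method on the instance.
                {
                    try
                    {
                        return RouteAction.Return(methodOnInstance.Invoke(_instance, call.GetArguments()));
                    }
                    catch (System.Reflection.TargetInvocationException ex) when (ex.InnerException != null)
                    {
                        // Rethrow the instance's own exception (with its stack trace) instead of the reflection wrapper.
                        System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                        throw; // Unreachable; satisfies the compiler's flow analysis.
                    }
                }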
                return RouteAction.Continue(); // If not, do nothing. (fallback to default)
            }
        }

        private static IEnumerable<InterfaceMapping> GetAllInterfaceMaps(Type aType) =>
            aType.GetTypeInfo()
                .ImplementedInterfaces
                .Select(aType.GetInterfaceMap);
        
        private static IEnumerable<MethodInfo> GetInterfaceDeclarationsForMethod(MethodInfo mi) =>
            GetAllInterfaceMaps(mi.ReflectedType)
                .SelectMany(map => Enumerable.Range(0, map.TargetMethods.Length)
                    .Where(n => map.TargetMethods[n] == mi)
                    .Select(n => map.InterfaceMethods[n]));
    }
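On the concern that the reflection runs on every call: one possible approach, sketched here under assumptions, is to memoize the method lookup per (called method, target type) pair. `ForwardingMethodCache` is a hypothetical helper, not an NSubstitute API. It mirrors the lookup in `RedirectToInstanceHandler` (interface methods are used as-is, class methods are mapped through `GetInterfaceMap`), but it tests the target's type instead of the instance so the result can be cached. The per-call guard against forwarding a call to the instance itself would still have to stay in `Handle`.

```csharp
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Reflection;

// Hypothetical helper: caches which interface method should be invoked on the
// forwarding target for a given called method, or null if there is none.
// The interface-map reflection then runs once per (method, target type) pair.
public static class ForwardingMethodCache
{
    private static readonly ConcurrentDictionary<(MethodInfo, Type), MethodInfo> Cache =
        new ConcurrentDictionary<(MethodInfo, Type), MethodInfo>();

    public static MethodInfo Resolve(MethodInfo calledMethod, Type targetType) =>
        Cache.GetOrAdd((calledMethod, targetType), key => Compute(key.Item1, key.Item2));

    private static MethodInfo Compute(MethodInfo calledMethod, Type targetType)
    {
        var reflectedType = calledMethod.ReflectedType;
        if (reflectedType == null)
            return null;

        // Interface method: usable directly if the target implements that interface.
        if (reflectedType.IsInterface)
            return reflectedType.IsAssignableFrom(targetType) ? calledMethod : null;

        // Class method: find an interface method it implements that the target also implements.
        return reflectedType.GetInterfaces()
            .Select(reflectedType.GetInterfaceMap)
            .SelectMany(map => map.TargetMethods
                .Select((target, i) => (Target: target, Declaration: map.InterfaceMethods[i])))
            .Where(pair => pair.Target == calledMethod)
            .Select(pair => pair.Declaration)
            .FirstOrDefault(declaration => declaration.DeclaringType.IsAssignableFrom(targetType));
    }
}
```

Inside `Handle`, the lookup would then reduce to something like `ForwardingMethodCache.Resolve(call.GetMethodInfo(), _instance.GetType())`, keeping the `call.Target() != _instance` check in place. The same cached result could also back the "at least one interface in common" check when the handler is registered.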

Additional context
