.NET WCF Custom Headers

I use server-side error logging to trap and record any exceptions an end user might encounter. It’s handy for proactive debugging, and it’s also useful for tracking potential intrusion attempts. To that end, I need the end user’s IP address, so I can see whether the intrusion attempts are all coming from a single IP address or range that could be blocked. In an N-tier SOA application, though, the service call that logs the exception sits in a different tier (and potentially on a different server) from the end user. That means the caller IP address seen by the service’s Log function is actually the web server’s IP address, rather than the IP address of the end user’s browser.

WCF supports custom message headers, which provide an ideal way to pass the end user’s IP address (or any other metadata) from the web layer to the service layer.
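To show the underlying API in isolation, here is a minimal sketch (not from the original code) that attaches a SOURCE_ADDRESS header to a standalone WCF Message and reads it back: the same write and read steps that the client inspector and the service helper below perform. It assumes a reference to System.ServiceModel (or the System.ServiceModel.Primitives package on .NET Core / .NET 5+); the class name and action URI are hypothetical placeholders.

```csharp
using System.ServiceModel.Channels;

// Hypothetical demo class: writes a custom header into a WCF message and reads it back.
// Assumes System.ServiceModel (or System.ServiceModel.Primitives) is referenced.
public static class SourceAddressHeaderDemo
{
    // Header name and namespace as used throughout this post.
    public const string HeaderName = "SOURCE_ADDRESS";
    public const string HeaderNamespace = "http://tempuri.org";

    public static string RoundTrip(string ipAddress)
    {
        // Empty-bodied SOAP 1.2 + WS-Addressing 1.0 message (the same envelope
        // namespaces as the SvcTraceViewer trace); the action URI is a placeholder.
        Message message = Message.CreateMessage(MessageVersion.Soap12WSAddressing10, "urn:example:Log");

        // Sending side: the value is serialized with DataContractSerializer
        // under the given element name and namespace.
        message.Headers.Add(MessageHeader.CreateHeader(HeaderName, HeaderNamespace, ipAddress));

        // Receiving side: locate the header by name and namespace (-1 if absent),
        // then deserialize its content as a string.
        int index = message.Headers.FindHeader(HeaderName, HeaderNamespace);
        return index > -1 ? message.Headers.GetHeader<string>(index) : null;
    }
}
```

Calling `SourceAddressHeaderDemo.RoundTrip("192.168.1.1")` returns `"192.168.1.1"`. Note that reading the header needs the generic `GetHeader<T>` overload; the header collection has no non-generic overload that returns a string.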

First, we need to add the user’s IP address to every WCF call. This is done with a custom IClientMessageInspector, whose BeforeSendRequest method adds a message header to each outgoing request.

using System;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Dispatcher;
using System.Web;

public class ClientMessageInspector : IClientMessageInspector
{
    private const string HEADER_URI_NAMESPACE = "http://tempuri.org";
    private const string HEADER_SOURCE_ADDRESS = "SOURCE_ADDRESS";

    public void AfterReceiveReply(ref Message reply, object correlationState)
    {
        // Nothing to inspect on replies.
    }

    public object BeforeSendRequest(ref Message request, IClientChannel channel)
    {
        if (HttpContext.Current != null)
        {
            // Called from the web tier: pass on the end user's (browser's) IP address.
            string remoteAddress = null;
            try
            {
                remoteAddress = HttpContext.Current.Request.UserHostAddress;
            }
            catch (Exception)
            {
                // Request can be unavailable (e.g. during Application_Start);
                // fall through and send the header with a null value.
            }
            request.Headers.Add(MessageHeader.CreateHeader(HEADER_SOURCE_ADDRESS, HEADER_URI_NAMESPACE, remoteAddress));
        }
        else if (OperationContext.Current != null)
        {
            // The service layer is making a nested call to another service method:
            // forward the original web caller's IP address from the incoming headers.
            string remoteAddress = null;
            int index = OperationContext.Current.IncomingMessageHeaders.FindHeader(HEADER_SOURCE_ADDRESS, HEADER_URI_NAMESPACE);
            if (index > -1)
            {
                // GetHeader is generic; read the header content as a string.
                remoteAddress = OperationContext.Current.IncomingMessageHeaders.GetHeader<string>(index);
            }
            request.Headers.Add(MessageHeader.CreateHeader(HEADER_SOURCE_ADDRESS, HEADER_URI_NAMESPACE, remoteAddress));
        }

        // No correlation state is needed in AfterReceiveReply.
        return null;
    }
}
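The nested-call branch of the inspector can also be pulled out into a small helper that works on `MessageHeaders` collections directly, which makes the forwarding logic testable without a live OperationContext. This is a sketch with a hypothetical class name (`SourceAddressForwarder`), assuming the same header name and namespace as the inspector and a reference to System.ServiceModel.

```csharp
using System.ServiceModel.Channels;

// Hypothetical helper (not from the original post) containing the logic of the
// inspector's nested-call branch. It takes MessageHeaders collections directly,
// so it can be exercised without an OperationContext.
// Assumes System.ServiceModel (or System.ServiceModel.Primitives) is referenced.
public static class SourceAddressForwarder
{
    // Must match the name/namespace used by ClientMessageInspector.
    public const string HeaderName = "SOURCE_ADDRESS";
    public const string HeaderNamespace = "http://tempuri.org";

    // Copies the original caller's IP address from an incoming message's headers
    // to an outgoing request. As in the inspector, the header is always added,
    // carrying a null value when the incoming message had no SOURCE_ADDRESS.
    public static void Forward(MessageHeaders incoming, MessageHeaders outgoing)
    {
        int index = incoming.FindHeader(HeaderName, HeaderNamespace);
        string address = index > -1 ? incoming.GetHeader<string>(index) : null;
        outgoing.Add(MessageHeader.CreateHeader(HeaderName, HeaderNamespace, address));
    }
}
```

With a helper like this, the `else if` branch of BeforeSendRequest could shrink to a single call: `SourceAddressForwarder.Forward(OperationContext.Current.IncomingMessageHeaders, request.Headers);`.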

To make WCF service calls use this inspector, we need an endpoint behavior and a behavior extension element:

using System;
using System.ServiceModel.Channels;
using System.ServiceModel.Configuration;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;

// Adds ClientMessageInspector to every client endpoint this behavior is applied to.
public class EndpointBehavior : IEndpointBehavior
{
    public void AddBindingParameters(ServiceEndpoint endpoint, BindingParameterCollection bindingParameters) { }

    public void ApplyClientBehavior(ServiceEndpoint endpoint, ClientRuntime clientRuntime)
    {
        clientRuntime.MessageInspectors.Add(new ClientMessageInspector());
    }

    // Client-side behavior only: nothing to do on the service (dispatch) side.
    public void ApplyDispatchBehavior(ServiceEndpoint endpoint, EndpointDispatcher endpointDispatcher) { }

    public void Validate(ServiceEndpoint endpoint) { }
}

// Lets the behavior be registered under <behaviorExtensions> in the config file.
public class BehaviorExtension : BehaviorExtensionElement
{
    public override Type BehaviorType
    {
        get { return typeof(EndpointBehavior); }
    }

    protected override object CreateBehavior()
    {
        return new EndpointBehavior();
    }
}

Now we can register the extension and apply it to the client endpoints in the config file:

<system.serviceModel>
    <extensions>
      <behaviorExtensions>
        <add name="CustomExtension" type="Example.Service.BehaviorExtension, Example.Service, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null" />
      </behaviorExtensions>
    </extensions>
    <behaviors>
      <endpointBehaviors>
        <behavior name="ClientEndpointBehavior">
          <CustomExtension/>
        </behavior>
      </endpointBehaviors>
    </behaviors>
    <bindings>
      <netTcpBinding>
        <binding name="ExampleServiceClientBinding" />
      </netTcpBinding>
    </bindings>
    <client>
      <endpoint address="net.tcp://localhost:8091/CustomExample/DataService" binding="netTcpBinding" bindingConfiguration="ExampleServiceClientBinding" contract="Example.Service.Contract.IDataService" name="ExampleDataServiceClientEndpoint" behaviorConfiguration="ClientEndpointBehavior">
      </endpoint>
    </client>
  </system.serviceModel>

Using SvcTraceViewer, we can see the new header being passed in the SOAP call:

<s:Envelope xmlns:a="http://www.w3.org/2005/08/addressing" xmlns:s="http://www.w3.org/2003/05/soap-envelope">
<s:Header>

[...]

<SOURCE_ADDRESS xmlns="http://tempuri.org">192.168.1.1</SOURCE_ADDRESS>
</s:Header>

Finally, to access this in the service code, I add a helper method to the service base class. A call to GetServiceCallerRemoteAddress() anywhere in service code then returns the IP address of the end user who originated the call, or null if no SOURCE_ADDRESS header was sent.

using System.ServiceModel;

public abstract class BaseDataService
{
    //[...]

    // Returns the end user's IP address from the SOURCE_ADDRESS header,
    // or null if the header is missing or there is no current operation.
    protected string GetServiceCallerRemoteAddress()
    {
        OperationContext context = OperationContext.Current;
        if (context == null)
        {
            return null;
        }

        // Name and namespace must match the values used by ClientMessageInspector.
        int index = context.IncomingMessageHeaders.FindHeader("SOURCE_ADDRESS", "http://tempuri.org");
        return index > -1 ? context.IncomingMessageHeaders.GetHeader<string>(index) : null;
    }
}