diff --git a/src/main/java/artsploit/LdapServer.java b/src/main/java/artsploit/LdapServer.java index 4a19857..6260996 100644 --- a/src/main/java/artsploit/LdapServer.java +++ b/src/main/java/artsploit/LdapServer.java @@ -7,6 +7,10 @@ import com.unboundid.ldap.listener.InMemoryListenerConfig; import com.unboundid.ldap.listener.interceptor.InMemoryInterceptedSearchResult; import com.unboundid.ldap.listener.interceptor.InMemoryOperationInterceptor; +import com.unboundid.ldap.sdk.ReadOnlySearchRequest; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.Socket; import org.reflections.Reflections; import javax.net.ServerSocketFactory; @@ -71,20 +75,58 @@ public LdapServer() throws Exception { */ @Override public void processSearchResult(InMemoryInterceptedSearchResult result) { - String base = result.getRequest().getBaseDN(); + ReadOnlySearchRequest request = result.getRequest(); + System.out.println("request: from: " + getRemoteAddress(result) + " " + request); + String base = request.getBaseDN(); + System.out.println("base: " + base); LdapController controller = null; //find controller for(String key: routes.keySet()) { - //compare using wildcard at the end - if(key.equals(base) || key.endsWith("*") && base.startsWith(key.substring(0, key.length()-1))) { + // compare using contains + if (base.contains(key) && key.length() > 0 || key.equals(base)) { controller = routes.get(key); break; } } + if (controller == null) { + System.out.println("No controller for base '" + base + "', falling back to default."); + controller = routes.get(""); + } try { controller.sendResult(result, base); } catch (Exception e1) { e1.printStackTrace(); } } + + // uses reflection to get the remote address of the client + // since the required method isn't available on the public API + private String getRemoteAddress(InMemoryInterceptedSearchResult result) { + if (getSocketMethod == null || getClientConnectionMethod == null) { + return null; + } + try { + Socket clientConnection = (Socket) getSocketMethod.invoke(getClientConnectionMethod.invoke(result)); + return clientConnection.getRemoteSocketAddress().toString(); + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + private static Method getClientConnectionMethod; + private static Method getSocketMethod; + + static { + Class interceptedOperationClazz = null; + try { + interceptedOperationClazz = Class.forName("com.unboundid.ldap.listener.interceptor.InterceptedOperation"); + getClientConnectionMethod = interceptedOperationClazz.getDeclaredMethod("getClientConnection"); + getClientConnectionMethod.setAccessible(true); + getSocketMethod = getClientConnectionMethod.getReturnType().getDeclaredMethod("getSocket"); + getSocketMethod.setAccessible(true); + } catch (ClassNotFoundException | NoSuchMethodException e) { + e.printStackTrace(); + } + } }