/*
 * RegistryServer.java
 *
 * Copyright 2007 Eamonn McManus
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.mcmanus.eamonn.customregistry;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.net.Socket;
import java.rmi.NoSuchObjectException;
import java.rmi.Remote;
import java.rmi.ServerException;
import java.rmi.registry.Registry;
import java.rmi.server.ObjID;
import java.rmi.server.RemoteStub;
import java.rmi.server.UID;
import java.rmi.server.UnicastRemoteObject;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Serves the operations of a {@link Registry} over a plain server socket by
 * speaking the RMI transport protocol directly, rather than by exporting the
 * registry through the RMI runtime.
 *
 * <p>Constructing an instance starts a daemon thread that listens on the
 * given port. Each accepted connection is handled on a pooled thread (at most
 * {@link #NTHREADS} at a time); incoming calls are decoded, checked to be
 * addressed to the registry object id, and forwarded reflectively to the
 * wrapped {@code Registry}.
 */
public class RegistryServer {
    private static final Logger log =
            Logger.getLogger(RegistryServer.class.getName());

    /** Maximum number of connections that are served concurrently. */
    private static final int NTHREADS = 5;

    /** How long an idle connection thread is kept before it is discarded. */
    private static final int THREAD_IDLE_SECONDS = 10;

    /** Registry methods, indexed by operation number. */
    private static final Method[] methodArray;

    /** Maps a method hash to the index of that method in {@link #methodArray}. */
    private static final Map<Long, Integer> methodHashToMethodNum =
            new HashMap<Long, Integer>();

    static {
        // The second value in each list below is the method hash used by the
        // 1.2 variant of the RMI protocol.  I obtained it by examining the
        // output of rmic.
        // The position of each entry is the operation number that dispatch()
        // uses as an index when the caller does not identify the method by
        // its hash, so the order of this table must not change.
        Object[][] methods = {
            {"bind", 7583982177005850366L, String.class, Remote.class},
            {"list", 2571371476350237748L},
            {"lookup", -7538657168040752697L, String.class},
            {"rebind", -8381844669958460146L, String.class, Remote.class},
            {"unbind", 7305022919901907578L, String.class},
        };
        methodArray = new Method[methods.length];
        for (int i = 0; i < methods.length; i++) {
            try {
                Object[] method = methods[i];
                String name = (String) method[0];
                Long hash = (Long) method[1];
                methodHashToMethodNum.put(hash, i);
                // Everything after the name and the hash is a parameter type.
                Class<?>[] params = new Class<?>[method.length - 2];
                System.arraycopy(method, 2, params, 0, params.length);
                methodArray[i] = Registry.class.getMethod(name, params);
            } catch (Exception e) {
                throw new AssertionError(e);
            }
        }
    }

    /** Returned by {@link #dispatch} when the invoked method is void. */
    private static final Object NO_RETURN = new Object();

    private final Registry registry;
    private final int port;

    /**
     * Creates a server for the given registry and immediately starts a daemon
     * thread that accepts connections on the given port.
     *
     * @param registry the registry that receives the decoded calls
     * @param port the TCP port to listen on
     */
    public RegistryServer(Registry registry, int port) {
        this.registry = registry;
        this.port = port;
        Thread t = new Thread() {
            @Override
            public void run() {
                try {
                    runWithException();
                } catch (IOException e) {
                    log.log(Level.INFO, "Exception from server socket", e);
                }
            }
        };
        t.setDaemon(true);
        t.setName("Custom RMI Registry Server");
        t.start();
    }

    /**
     * Opens the server socket and accepts connections until accepting fails,
     * handing each connection to the connection thread pool.
     *
     * @throws IOException if the server socket cannot be opened or an accept
     *         fails
     */
    private void runWithException() throws IOException {
        final ServerSocket ss = new ServerSocket(port);
        try {
            ExecutorService executor = newConnectionExecutor();
            try {
                while (true) {
                    final Socket s = ss.accept();
                    try {
                        executor.execute(new Runnable() {
                            public void run() {
                                handle(s);
                            }
                        });
                    } catch (java.util.concurrent.RejectedExecutionException e) {
                        // The pool has no queue, so a connection arriving
                        // while all NTHREADS threads are busy is rejected.
                        // Drop that one connection instead of letting the
                        // exception terminate the accept loop.
                        log.log(Level.FINE,
                                "Connection refused, all threads busy", e);
                        closeQuietly(s);
                    }
                }
            } finally {
                executor.shutdownNow();
            }
        } finally {
            // Nested so that the socket is closed even if shutdownNow() throws.
            ss.close();
        }
    }

    /**
     * Creates the pool that runs connection handlers: no core threads, at
     * most {@link #NTHREADS} daemon threads, and direct hand-off of tasks.
     */
    private static ExecutorService newConnectionExecutor() {
        final ThreadFactory factory = new ThreadFactory() {
            public Thread newThread(Runnable runnable) {
                Thread t = Executors.defaultThreadFactory().newThread(runnable);
                t.setDaemon(true);
                t.setName("Custom RMI Registry Connection");
                return t;
            }
        };
        ExecutorService executor = new ThreadPoolExecutor(
                0, NTHREADS, THREAD_IDLE_SECONDS, TimeUnit.SECONDS,
                new SynchronousQueue<Runnable>(), factory);
        try {
            // Looked up reflectively because the method may not exist on the
            // running Java version.
            Method m = ThreadPoolExecutor.class.getMethod(
                    "allowCoreThreadTimeOut", boolean.class);
            m.invoke(executor, true);
        } catch (Exception e) {
            log.log(Level.FINEST, "Look up allowCoreThreadTimeout", e);
            // Ignore, assume Java < 1.6
        }
        return executor;
    }

    /**
     * Serves one connection, logging any I/O failure, and always closes the
     * socket afterwards.
     */
    private void handle(Socket s) {
        try {
            handleWithException(s);
        } catch (IOException e) {
            log.log(Level.FINE, "Exception handling connection", e);
        } finally {
            closeQuietly(s);
        }
    }

    /** Closes the socket, logging rather than propagating any failure. */
    private static void closeQuietly(Socket s) {
        try {
            s.close();
        } catch (IOException e) {
            log.log(Level.FINE, "Exception closing connection", e);
        }
    }

    /**
     * Performs the protocol handshake on a new connection and then processes
     * messages until {@link #handleMessage} reports that the connection is
     * finished.
     *
     * @throws IOException on a protocol error or socket failure
     */
    private void handleWithException(Socket s) throws IOException {
        InputStream in = new BufferedInputStream(s.getInputStream());
        DataInputStream din = new DataInputStream(in);
        OutputStream out = new BufferedOutputStream(s.getOutputStream());
        DataOutputStream dout = new DataOutputStream(out);

        expect(in, "JRMI\0\2");
        if (in.read() != 0x4b) { // StreamProtocol
            dout.write(0x4f); // ProtocolNotSupported
            dout.flush();
            return;
        }
        dout.write(0x4e); // ProtocolAck
        // EndpointIdentifier:
        dout.writeUTF(s.getInetAddress().getHostAddress());
        dout.writeInt(s.getPort());
        dout.flush();
        // peer's claimed EndpointIdentifer (read and discarded):
        din.readUTF();
        din.readInt();

        while (handleMessage(din, dout)) {
            // keep going until the connection is finished
        }
    }

    /**
     * Reads and answers a single transport message.
     *
     * @return {@code true} if the connection should be kept open for further
     *         messages, {@code false} if the client closed it or the call
     *         failed with an {@code IOException}
     * @throws IOException if the message type is unknown or the socket fails
     */
    private boolean handleMessage(DataInputStream din, DataOutputStream dout)
            throws IOException {
        int op = din.read();
        switch (op) {
            case -1:
                log.finer("Client closed connection");
                return false;
            case 0x52: // Ping
                log.finer("Ping received");
                dout.write(0x53); // PingAck
                dout.flush();
                return true;
            case 0x54: // DgcAck
                log.finer("DgcAck received");
                UID.read(din);
                return true;
            case 0x50: // Call
                break;
            default:
                throw new IOException("Unknown transport op " + op);
        }

        // NOTE(review): the call arguments are Java-deserialized straight
        // from the network (readObject in dispatch).  Anyone who can connect
        // to this port can supply arbitrary serialized objects, so only
        // expose it to trusted peers or add a deserialization filter.
        ObjectInputStream oin = new ObjectInputStream(din);
        ObjID id = ObjID.read(oin);
        int opnum = oin.readInt();
        long hash = oin.readLong();

        Object ret;
        int retCode;
        try {
            ret = dispatch(oin, id, opnum, hash);
            retCode = 1; // normal return
        } catch (Exception e) {
            ret = e;
            retCode = 2; // exceptional return; the exception is the payload
            log.log(Level.FINE, "Exception for this method", e);
        }

        dout.writeByte(0x51); // Return
        ObjectOutputStream oout = new MyObjectOutputStream(dout);
        oout.writeByte(retCode);
        new UID().write(oout);
        if (ret != NO_RETURN) {
            oout.writeObject(ret);
        }
        oout.flush();

        // If we got an IOException, we might be out of sync with the client,
        // so we close the connection.
        return !(ret instanceof IOException);
    }

    /**
     * Validates the target of a call, reads its arguments, and invokes the
     * corresponding method on the wrapped registry.
     *
     * @param oin stream positioned at the first argument of the call
     * @param id object id the call is addressed to; must be the registry id
     * @param opnum operation number, or -1 if the method is identified by
     *        {@code hash}
     * @param hash method hash when {@code opnum} is -1, otherwise the
     *        interface hash
     * @return the method's result, or {@link #NO_RETURN} for a void method
     * @throws Exception whatever the registry method throws, or an exception
     *         describing why the call could not be decoded
     */
    private Object dispatch(ObjectInputStream oin, ObjID id, int opnum, long hash)
            throws Exception {
        if (!id.equals(new ObjID(ObjID.REGISTRY_ID))) {
            throw new NoSuchObjectException("Not registry id: " + id);
        }
        log.finest("opnum=" + opnum + ", hash=" + hash);
        if (opnum == -1) {
            Integer methodNum = methodHashToMethodNum.get(Long.valueOf(hash));
            if (methodNum == null) {
                throw new ServerException("Bad method hash: " + hash);
            }
            opnum = methodNum.intValue();
        } else if (hash != 4905912898345647071L) { // RMI 1.1 hash for Registry
            throw new ServerException("Bad interface hash: " + hash);
        }
        if (opnum < 0 || opnum >= methodArray.length) {
            throw new ServerException("Bad operation number: " + opnum);
        }

        Method m = methodArray[opnum];
        Object[] params = new Object[m.getParameterTypes().length];
        // Following assumes no parameters are primitive, true for Registry.
        // If a param is int (e.g.) then it must be read with oin.readInt().
        // A ClassNotFoundException is remembered rather than thrown at once
        // so that every argument is still consumed from the stream.
        ClassNotFoundException cnfe = null;
        for (int i = 0; i < params.length; i++) {
            try {
                params[i] = oin.readObject();
            } catch (ClassNotFoundException e) {
                cnfe = e;
            }
        }
        if (cnfe != null) {
            throw cnfe;
        }

        try {
            Object ret = m.invoke(registry, params);
            if (m.getReturnType() == void.class) {
                return NO_RETURN;
            }
            return ret;
        } catch (InvocationTargetException e) {
            // Unwrap so the caller sees what the registry method threw.
            Throwable t = e.getCause();
            if (t instanceof Exception) {
                throw (Exception) t;
            } else if (t instanceof Error) {
                throw (Error) t;
            } else {
                throw e; // mutant Throwable
            }
        }
    }

    /**
     * Reads {@code what.length()} bytes from the stream and checks that each
     * one equals the corresponding character of {@code what}.
     *
     * @throws IOException if a byte differs or the stream ends early
     */
    private void expect(InputStream in, String what) throws IOException {
        for (int i = 0; i < what.length(); i++) {
            int expc = what.charAt(i);
            int c = in.read();
            if (c != expc) {
                throw new IOException("Protocol error: expected " + expc +
                        ", got " + c);
            }
        }
    }

    /**
     * Object output stream that writes a null annotation for every class and
     * proxy class descriptor.
     */
    private static class MyObjectOutputStream extends ObjectOutputStream {
        MyObjectOutputStream(OutputStream out) throws IOException {
            super(out);
        }

        @Override
        protected void annotateClass(Class<?> c) throws IOException {
            writeObject(null);
        }

        @Override
        protected void annotateProxyClass(Class<?> c) throws IOException {
            writeObject(null);
        }
    }
}