package net.unit8.wscl; import net.unit8.wscl.util.DigestUtils; import net.unit8.wscl.util.IOUtils; import net.unit8.wscl.util.PropertyUtils; import net.unit8.wscl.util.QueryStringDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.websocket.ClientEndpointConfig; import javax.websocket.ContainerProvider; import javax.websocket.DeploymentException; import javax.websocket.WebSocketContainer; import java.io.File; import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; import java.net.URL; import java.net.URLConnection; import java.util.Arrays; import java.util.Enumeration; import java.util.List; import java.util.Vector; /** * ClassLoader fetching classes via WebSocket. * * @author kawasima */ public class WebSocketClassLoader extends ClassLoader { private ClassLoaderEndpoint endpoint; private URL baseUrl; private File cacheDirectory; private static final Logger logger = LoggerFactory.getLogger(WebSocketClassLoader.class); public WebSocketClassLoader(String url) throws IOException, DeploymentException { this(url, Thread.currentThread().getContextClassLoader()); } public WebSocketClassLoader(String url, ClassLoader parent) throws DeploymentException, IOException { super(parent); logger.debug("Parent classloader=" + parent); cacheDirectory = PropertyUtils.getFileSystemProperty("wscl.cache.directory"); if (cacheDirectory != null && !cacheDirectory.exists() && !cacheDirectory.mkdirs()) { throw new IllegalArgumentException( "Can't create cache directory: " + cacheDirectory); } WebSocketContainer container = ContainerProvider.getWebSocketContainer(); endpoint = new ClassLoaderEndpoint(); container.connectToServer(endpoint, ClientEndpointConfig.Builder.create().build(), URI.create(url)); try { URL httpUrl = new URL(url.replaceFirst("ws://", "http://")); QueryStringDecoder decoder = new QueryStringDecoder(httpUrl.getQuery()); List<String> classLoaderIds = decoder.parameters().get("classLoaderId"); WebSocketURLStreamHandler urlStreamHandler = new WebSocketURLStreamHandler(endpoint, cacheDirectory); if (classLoaderIds != null && !classLoaderIds.isEmpty()) urlStreamHandler.setClassLoaderId(classLoaderIds.get(0)); baseUrl = new URL("ws", httpUrl.getHost(), httpUrl.getPort(), httpUrl.getFile(), urlStreamHandler); } catch (Exception e) { throw new RuntimeException("ClassProvider URL is invalid.", e); } } private URL findCache(URL url, byte[] digest) { File cacheFile = new File(cacheDirectory, url.getPath()); if (cacheFile.exists() && Arrays.equals(digest, DigestUtils.md5hash(cacheFile))) { try { return cacheFile.toURI().toURL(); } catch (MalformedURLException e) { return url; } } else { return url; } } @Override protected URL findResource(String name) { URL url; try { StringBuilder file = new StringBuilder(256); if (!name.startsWith("/")) { file.append("/"); } file.append(name); url = new URL(baseUrl, file.toString()); } catch (MalformedURLException ex) { throw new IllegalArgumentException("name"); } try { WebSocketURLConnection connection = (WebSocketURLConnection)url.openConnection(); byte[] digest = connection.getResourceDigest(); logger.debug("findResource:" + name + ":" + url.toString()); if (digest == null) return null; return cacheDirectory != null ? findCache(url, digest) : url; } catch(Exception e) { logger.warn("Exception at fetching.", e); return null; } } /** * Returns an enumeration of URL objects representing all the resources with th given name. * * Currently, WebSocketClassLoader returns only the first element. * * @param name The name of a resource. * @return All founded resources. */ @Override protected Enumeration<URL> findResources(String name) { URL url = findResource(name); Vector<URL> urls = new Vector<>(); if (url != null) { urls.add(url); } return urls.elements(); } @Override protected Class<?> loadClass(String className, boolean resolve) throws ClassNotFoundException { synchronized (getClassLoadingLock(className)) { Class<?> clazz = findLoadedClass(className); if (clazz == null) { try { clazz = getParent().loadClass(className); } catch (ClassNotFoundException ignored) { } if (clazz == null) clazz = findClass(className); } if (resolve) { resolveClass(clazz); } return clazz; } } @Override protected Class<?> findClass(String className) throws ClassNotFoundException { return defineClass(className); } private Class<?> defineClass(String className) throws ClassNotFoundException { String path = className.replace('.', '/').concat(".class"); URL url = findResource(path); if (url == null) throw new ClassNotFoundException(className); try { URLConnection connection = url.openConnection(); byte[] bytes = IOUtils.slurp(connection.getContent()); if (bytes != null) { int idx = className.lastIndexOf("."); if (idx > 0) { String packageName = className.substring(0, idx); Package pkg = getPackage(packageName); if (pkg == null) { definePackage(packageName, null, null, null, null, null, null, null); } } return defineClass(className, bytes, 0, bytes.length); } else { throw new ClassNotFoundException(className); } } catch (Exception ex) { throw new ClassNotFoundException(className, ex); } } @Override public void finalize() throws Throwable{ try { dispose(); } finally { super.finalize(); } } public void dispose() throws IOException { endpoint.close(); } }