package com.github.netty.protocol.servlet; import com.github.netty.core.util.*; import com.github.netty.protocol.servlet.util.HttpConstants; import com.github.netty.protocol.servlet.util.MimeMappingsX; import com.github.netty.protocol.servlet.util.UrlMapper; import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory; import io.netty.handler.codec.http.multipart.DiskAttribute; import io.netty.handler.codec.http.multipart.DiskFileUpload; import io.netty.handler.codec.http.multipart.HttpDataFactory; import io.netty.util.concurrent.FastThreadLocal; import javax.servlet.*; import javax.servlet.descriptor.JspConfigDescriptor; import javax.servlet.http.HttpSessionAttributeListener; import javax.servlet.http.HttpSessionIdListener; import javax.servlet.http.HttpSessionListener; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.MalformedURLException; import java.net.URL; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.ExecutorService; /** * Servlet context (lifetime same as server) * @author wangzihao * 2018/7/14/014 */ public class ServletContext implements javax.servlet.ServletContext { private LoggerX logger = LoggerFactoryX.getLogger(getClass()); /** * Default: 20 minutes, */ private int sessionTimeout = 1200; /** * The maximum number of bytes written to the outputstream.writer () method of the servlet each time it is called is exceeded */ private int responseWriterChunkMaxHeapByteLength = 4096; /** * Minimum upload file length, in bytes (becomes temporary file storage if larger than 16KB) */ private long uploadMinSize = 4096 * 16; private Map<String,Object> attributeMap = new LinkedHashMap<>(16); private Map<String,String> initParamMap = new LinkedHashMap<>(16); private Map<String, ServletRegistration> servletRegistrationMap = new LinkedHashMap<>(8); private Map<String, ServletFilterRegistration> filterRegistrationMap = new LinkedHashMap<>(8); private FastThreadLocal<Map<Charset, HttpDataFactory>> httpDataFactoryThreadLocal = new FastThreadLocal<Map<Charset, HttpDataFactory>>(){ @Override protected Map<Charset, HttpDataFactory> initialValue() throws Exception { return new LinkedHashMap<>(5); } }; private Set<SessionTrackingMode> defaultSessionTrackingModeSet = new HashSet<>(Arrays.asList(SessionTrackingMode.COOKIE,SessionTrackingMode.URL)); // private final PropertyChangeSupport propertyChangeSupport = new PropertyChangeSupport(this); private ServletErrorPageManager servletErrorPageManager = new ServletErrorPageManager(); private MimeMappingsX mimeMappings = new MimeMappingsX(); private ServletEventListenerManager servletEventListenerManager = new ServletEventListenerManager(); private ServletSessionCookieConfig sessionCookieConfig = new ServletSessionCookieConfig(); private UrlMapper<ServletRegistration> servletUrlMapper = new UrlMapper<>(true); private UrlMapper<ServletFilterRegistration> filterUrlMapper = new UrlMapper<>(false); private ResourceManager resourceManager; private ExecutorService asyncExecutorService; private SessionService sessionService = new SessionLocalMemoryServiceImpl(); private Set<SessionTrackingMode> sessionTrackingModeSet; private boolean enableLookupFlag = false; private boolean asyncSwitchThread = true; private String serverHeader; private String contextPath = ""; private String requestCharacterEncoding; private String responseCharacterEncoding; private String servletContextName; private InetSocketAddress serverAddress; private ClassLoader classLoader; public ServletContext() { this(null); } public ServletContext(ClassLoader classLoader) { this.classLoader = classLoader == null ? getClass().getClassLoader(): classLoader; setDocBase(createTempDir("netty-docbase").getAbsolutePath()); } public void setAsyncSwitchThread(boolean asyncSwitchThread) { this.asyncSwitchThread = asyncSwitchThread; } public boolean isAsyncSwitchThread() { return asyncSwitchThread; } public boolean isEnableLookupFlag() { return enableLookupFlag; } public void setEnableLookupFlag(boolean enableLookupFlag) { this.enableLookupFlag = enableLookupFlag; } public void setServerAddress(InetSocketAddress serverAddress) { this.serverAddress = serverAddress; } public void setDocBase(String docBase){ String workspace = '/' + (serverAddress == null || HostUtil.isLocalhost(serverAddress.getHostName())? "localhost": serverAddress.getHostName()); this.resourceManager = new ResourceManager(docBase,workspace,classLoader); this.resourceManager.mkdirs("/"); DiskFileUpload.deleteOnExitTemporaryFile = true; DiskAttribute.deleteOnExitTemporaryFile = true; DiskFileUpload.baseDirectory = resourceManager.getRealPath("/"); DiskAttribute.baseDirectory = resourceManager.getRealPath("/"); } protected static File createTempDir(String prefix) { try { File tempDir = File.createTempFile(prefix + ".", ""); tempDir.delete(); tempDir.mkdir(); tempDir.deleteOnExit(); return tempDir; }catch (IOException ex) { throw new IllegalStateException( "Unable to create tempDir. java.io.tmpdir is set to " + System.getProperty("java.io.tmpdir"), ex); } } public ExecutorService getAsyncExecutorService() { if(asyncExecutorService == null) { synchronized (this){ if(asyncExecutorService == null) { asyncExecutorService = new ThreadPoolX("Async", Runtime.getRuntime().availableProcessors() * 2); // executorService = new DefaultEventExecutorGroup(15); } } } return asyncExecutorService; } public void setAsyncExecutorService(ExecutorService asyncExecutorService) { this.asyncExecutorService = asyncExecutorService; } public HttpDataFactory getHttpDataFactory(Charset charset){ Map<Charset, HttpDataFactory> httpDataFactoryMap = httpDataFactoryThreadLocal.get(); HttpDataFactory factory = httpDataFactoryMap.get(charset); if(factory == null){ factory = new DefaultHttpDataFactory(uploadMinSize,charset); httpDataFactoryMap.put(charset, factory); } return factory; } public String getServletPath(String absoluteUri) { return servletUrlMapper.getServletPath(absoluteUri); } public long getUploadMinSize() { return uploadMinSize; } public void setUploadMinSize(long uploadMinSize) { this.uploadMinSize = uploadMinSize; } public MimeMappingsX getMimeMappings() { return mimeMappings; } public ResourceManager getResourceManager() { return resourceManager; } public ServletErrorPageManager getErrorPageManager() { return servletErrorPageManager; } public void setServletContextName(String servletContextName) { this.servletContextName = servletContextName; } public void setServerHeader(String serverHeader) { this.serverHeader = serverHeader; } public String getServerHeader() { return serverHeader; } public void setContextPath(String contextPath) { this.contextPath = contextPath; this.filterUrlMapper.setRootPath(contextPath); this.servletUrlMapper.setRootPath(contextPath); } public ServletEventListenerManager getServletEventListenerManager() { return servletEventListenerManager; } public long getAsyncTimeout(){ String value = getInitParameter("asyncTimeout"); if(value == null || value.isEmpty()){ return 30000; } try { return Long.parseLong(value); }catch (NumberFormatException e){ return 30000; } } public int getResponseWriterChunkMaxHeapByteLength() { return responseWriterChunkMaxHeapByteLength; } public void setResponseWriterChunkMaxHeapByteLength(int responseWriterChunkMaxHeapByteLength) { this.responseWriterChunkMaxHeapByteLength = responseWriterChunkMaxHeapByteLength; } public InetSocketAddress getServerAddress() { return serverAddress; } public void setSessionService(SessionService sessionService) { this.sessionService = sessionService; } public SessionService getSessionService() { return sessionService; } public int getSessionTimeout() { return sessionTimeout; } public void setSessionTimeout(int sessionTimeout) { if(sessionTimeout <= 0){ return; } this.sessionTimeout = sessionTimeout; } @Override public String getContextPath() { return contextPath; } @Override public ServletContext getContext(String uripath) { return this; } @Override public int getMajorVersion() { return 3; } @Override public int getMinorVersion() { return 0; } @Override public int getEffectiveMajorVersion() { return 3; } @Override public int getEffectiveMinorVersion() { return 0; } @Override public String getMimeType(String file) { if (file == null) { return null; } int period = file.lastIndexOf('.'); if (period < 0) { return null; } String extension = file.substring(period + 1); if (extension.length() < 1) { return null; } return mimeMappings.get(extension); } @Override public Set<String> getResourcePaths(String path) { return resourceManager.getResourcePaths(path); } @Override public URL getResource(String path) throws MalformedURLException { return resourceManager.getResource(path); } @Override public InputStream getResourceAsStream(String path) { return resourceManager.getResourceAsStream(path); } @Override public String getRealPath(String path) { return resourceManager.getRealPath(path); } @Override public ServletRequestDispatcher getRequestDispatcher(String path) { UrlMapper.Element<ServletRegistration> element = servletUrlMapper.getMappingObjectByUri(path); if(element == null){ return null; } ServletRegistration servletRegistration = element.getObject(); if(servletRegistration == null){ return null; } ServletFilterChain filterChain = ServletFilterChain.newInstance(this,servletRegistration); filterUrlMapper.addMappingObjectsByUri(path,filterChain.getFilterRegistrationList()); ServletRequestDispatcher dispatcher = ServletRequestDispatcher.newInstance(filterChain); dispatcher.setMapperElement(element); dispatcher.setPath(path); return dispatcher; } @Override public ServletRequestDispatcher getNamedDispatcher(String name) { ServletRegistration servletRegistration = null == name ? null : getServletRegistration(name); if (servletRegistration == null) { return null; } ServletFilterChain filterChain = ServletFilterChain.newInstance(this,servletRegistration); List<UrlMapper.Element<ServletFilterRegistration>> filterList = filterChain.getFilterRegistrationList(); for (ServletFilterRegistration registration : filterRegistrationMap.values()) { for(String servletName : registration.getServletNameMappings()){ if(servletName.equals(name)){ filterList.add(new UrlMapper.Element<>(name,registration)); } } } ServletRequestDispatcher dispatcher = ServletRequestDispatcher.newInstance(filterChain); dispatcher.setName(name); return dispatcher; } @Override public Servlet getServlet(String name) throws ServletException { ServletRegistration registration = servletRegistrationMap.get(name); if(registration == null){ return null; } return registration.getServlet(); } @Override public Enumeration<Servlet> getServlets() { List<Servlet> list = new ArrayList<>(); for(ServletRegistration registration : servletRegistrationMap.values()){ list.add(registration.getServlet()); } return Collections.enumeration(list); } @Override public Enumeration<String> getServletNames() { List<String> list = new ArrayList<>(); for(ServletRegistration registration : servletRegistrationMap.values()){ list.add(registration.getName()); } return Collections.enumeration(list); } @Override public void log(String msg) { logger.debug(msg); } @Override public void log(Exception exception, String msg) { logger.debug(msg,exception); } @Override public void log(String message, Throwable throwable) { logger.debug(message,throwable); } @Override public String getServerInfo() { return ServerInfo.getServerInfo() .concat("(JDK ") .concat(ServerInfo.getJvmVersion()) .concat(";") .concat(ServerInfo.getOsName()) .concat(" ") .concat(ServerInfo.getArch()) .concat(")"); } @Override public String getInitParameter(String name) { return initParamMap.get(name); } public <T>T getInitParameter(String name,T def) { String value = getInitParameter(name); if(value == null){ return def; } Class<?> clazz = def.getClass(); Object valCast = TypeUtil.cast((Object) value,clazz); if(valCast != null && valCast.getClass().isAssignableFrom(clazz)){ return (T) valCast; } return def; } @Override public Enumeration<String> getInitParameterNames() { return Collections.enumeration(initParamMap.keySet()); } @Override public boolean setInitParameter(String name, String value) { return initParamMap.putIfAbsent(name,value) == null; } @Override public Object getAttribute(String name) { return attributeMap.get(name); } @Override public Enumeration<String> getAttributeNames() { return Collections.enumeration(attributeMap.keySet()); } @Override public void setAttribute(String name, Object object) { Objects.requireNonNull(name); if(object == null){ removeAttribute(name); return; } Object oldObject = attributeMap.put(name,object); ServletEventListenerManager listenerManager = getServletEventListenerManager(); if(listenerManager.hasServletContextAttributeListener()){ listenerManager.onServletContextAttributeAdded(new ServletContextAttributeEvent(this,name,object)); if(oldObject != null){ listenerManager.onServletContextAttributeReplaced(new ServletContextAttributeEvent(this,name,oldObject)); } } } @Override public void removeAttribute(String name) { Object oldObject = attributeMap.remove(name); ServletEventListenerManager listenerManager = getServletEventListenerManager(); if(listenerManager.hasServletContextAttributeListener()){ listenerManager.onServletContextAttributeRemoved(new ServletContextAttributeEvent(this,name,oldObject)); } } @Override public String getServletContextName() { return servletContextName; } @Override public ServletRegistration addServlet(String servletName, String className) { try { return addServlet(servletName, (Class<? extends Servlet>) Class.forName(className).newInstance()); } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { throw new IllegalStateException("addServlet error ="+e+",servletName="+servletName,e); } } @Override public ServletRegistration addServlet(String servletName, Servlet servlet) { Servlet newServlet = servletEventListenerManager.onServletAdded(servlet); ServletRegistration servletRegistration; if(newServlet == null){ servletRegistration = new ServletRegistration(servletName,servlet,this,servletUrlMapper); }else { servletRegistration = new ServletRegistration(servletName,newServlet,this,servletUrlMapper); } servletRegistrationMap.put(servletName,servletRegistration); return servletRegistration; } @Override public ServletRegistration addServlet(String servletName, Class<? extends Servlet> servletClass) { Servlet servlet = null; try { servlet = servletClass.getConstructor().newInstance(); } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { throw new IllegalStateException("createServlet error ="+e+",servletName="+servletName,e); } return addServlet(servletName,servlet); } @Override public <T extends Servlet> T createServlet(Class<T> clazz) throws ServletException { try { return clazz.getConstructor().newInstance(); } catch (NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) { throw new ServletException("createServlet error ="+e+",clazz="+clazz,e); } } @Override public ServletRegistration getServletRegistration(String servletName) { return servletRegistrationMap.get(servletName); } @Override public Map<String, ServletRegistration> getServletRegistrations() { return servletRegistrationMap; } @Override public ServletFilterRegistration addFilter(String filterName, String className) { try { return addFilter(filterName, (Class<? extends Filter>) Class.forName(className)); } catch (ClassNotFoundException e) { throw new IllegalStateException("addFilter error ="+e+",filterName="+filterName,e); } } @Override public ServletFilterRegistration addFilter(String filterName, Filter filter) { ServletFilterRegistration registration = new ServletFilterRegistration(filterName,filter,this,filterUrlMapper); filterRegistrationMap.put(filterName,registration); return registration; } @Override public ServletFilterRegistration addFilter(String filterName, Class<? extends Filter> filterClass) { try { return addFilter(filterName,filterClass.newInstance()); } catch (InstantiationException | IllegalAccessException e) { throw new IllegalStateException("addFilter error ="+e,e); } } @Override public <T extends Filter> T createFilter(Class<T> clazz) throws ServletException { try { return clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new ServletException("createFilter error ="+e,e); } } @Override public FilterRegistration getFilterRegistration(String filterName) { return filterRegistrationMap.get(filterName); } @Override public Map<String, ServletFilterRegistration> getFilterRegistrations() { return filterRegistrationMap; } @Override public ServletSessionCookieConfig getSessionCookieConfig() { return sessionCookieConfig; } @Override public void setSessionTrackingModes(Set<SessionTrackingMode> sessionTrackingModes) { sessionTrackingModeSet = sessionTrackingModes; } @Override public Set<SessionTrackingMode> getDefaultSessionTrackingModes() { return defaultSessionTrackingModeSet; } @Override public Set<SessionTrackingMode> getEffectiveSessionTrackingModes() { if(sessionTrackingModeSet == null){ return getDefaultSessionTrackingModes(); } return sessionTrackingModeSet; } @Override public void addListener(String className) { try { addListener((Class<? extends EventListener>) Class.forName(className)); } catch (ClassNotFoundException e) { throw new IllegalStateException("addListener error ="+e+",className="+className,e); } } @Override public <T extends EventListener> void addListener(T listener) { Objects.requireNonNull(listener); boolean addFlag = false; ServletEventListenerManager listenerManager = getServletEventListenerManager(); if(listener instanceof ServletContextAttributeListener){ listenerManager.addServletContextAttributeListener((ServletContextAttributeListener) listener); addFlag = true; } if(listener instanceof ServletRequestListener){ listenerManager.addServletRequestListener((ServletRequestListener) listener); addFlag = true; } if(listener instanceof ServletRequestAttributeListener){ listenerManager.addServletRequestAttributeListener((ServletRequestAttributeListener) listener); addFlag = true; } if(listener instanceof HttpSessionIdListener){ listenerManager.addHttpSessionIdListenerListener((HttpSessionIdListener) listener); addFlag = true; } if(listener instanceof HttpSessionAttributeListener){ listenerManager.addHttpSessionAttributeListener((HttpSessionAttributeListener) listener); addFlag = true; } if(listener instanceof HttpSessionListener){ listenerManager.addHttpSessionListener((HttpSessionListener) listener); addFlag = true; } if(listener instanceof ServletContextListener){ listenerManager.addServletContextListener((ServletContextListener) listener); addFlag = true; } if(!addFlag){ throw new IllegalArgumentException("applicationContext.addListener.iae.wrongType"+ listener.getClass().getName()); } } @Override public void addListener(Class<? extends EventListener> listenerClass) { try { addListener(listenerClass.newInstance()); } catch (InstantiationException | IllegalAccessException e) { throw new IllegalStateException("addListener listenerClass ="+listenerClass,e); } } @Override public <T extends EventListener> T createListener(Class<T> clazz) throws ServletException { try { return clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new ServletException("addListener clazz ="+clazz,e); } } @Override public JspConfigDescriptor getJspConfigDescriptor() { throw new UnsupportedOperationException("getJspConfigDescriptor"); } @Override public ClassLoader getClassLoader() { return resourceManager.getClassLoader(); } @Override public void declareRoles(String... roleNames) { throw new UnsupportedOperationException("declareRoles"); } @Override public String getVirtualServerName() { return ServerInfo.getServerInfo() .concat(" (") .concat(serverAddress.getHostName()) .concat(":") .concat(SystemPropertyUtil.get("user.name")) .concat(")"); } @Override public String getRequestCharacterEncoding() { if(requestCharacterEncoding == null){ return HttpConstants.DEFAULT_CHARSET.name(); } return requestCharacterEncoding; } @Override public void setRequestCharacterEncoding(String requestCharacterEncoding) { this.requestCharacterEncoding = requestCharacterEncoding; } @Override public String getResponseCharacterEncoding() { if(responseCharacterEncoding == null){ return HttpConstants.DEFAULT_CHARSET.name(); } return responseCharacterEncoding; } @Override public void setResponseCharacterEncoding(String responseCharacterEncoding) { this.responseCharacterEncoding = responseCharacterEncoding; } @Override public javax.servlet.ServletRegistration.Dynamic addJspFile(String jspName, String jspFile) { throw new UnsupportedOperationException("addJspFile"); } }