package com.nosuchfield.geisha.mvc.server.nio; import com.google.common.primitives.Bytes; import com.nosuchfield.geisha.ioc.BeansPool; import com.nosuchfield.geisha.mvc.MethodDetail; import com.nosuchfield.geisha.mvc.UrlMappingPool; import com.nosuchfield.geisha.mvc.annotations.Param; import com.nosuchfield.geisha.mvc.enums.RequestMethod; import com.nosuchfield.geisha.utils.Constants; import lombok.extern.slf4j.Slf4j; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; import java.io.UnsupportedEncodingException; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Map; /** * @author hourui 2017/10/27 21:04 */ @Slf4j public class NioServer { private int port; private Selector selector; public static void start(int port) { new Thread(() -> { try { NioServer server = new NioServer(); server.port = port; log.info("NioServer is running on http://127.0.0.1:{}", port); server.start(); } catch (Exception e) { e.printStackTrace(); } }).start(); } private void start() throws Exception { ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.socket().bind(new InetSocketAddress(port)); serverSocketChannel.configureBlocking(false); selector = Selector.open(); serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); while (true) { selector.select(); // 此处的select方法是阻塞的 // 对所有的key做一次遍历,由key本身判断此事件是否与自己有关 selector.selectedKeys().forEach((this::handleKey)); } } private void handleKey(SelectionKey key) { try { ServerSocketChannel server = null; SocketChannel client = null; if (key.isAcceptable()) { server = (ServerSocketChannel) key.channel(); client = server.accept(); if (client != null) { client.configureBlocking(false); // 给新的链接注册读取事件 client.register(selector, SelectionKey.OP_READ); log.info("Open channel {}", client.getRemoteAddress()); } } else if (key.isReadable()) { client = (SocketChannel) key.channel(); read(client); // key.cancel(); } } catch (Exception e) { e.printStackTrace(); } } // 读取channel数据并且写回响应内容 private void read(SocketChannel channel) throws Exception { LinkedList<Byte> list = new LinkedList<>(); ByteBuffer buf = ByteBuffer.allocate(1024); int bytesRead = channel.read(buf); // 如果读取到-1,则说明客户端关闭了该链接 if (bytesRead == -1) { log.info("Close channel {}", channel.getRemoteAddress()); channel.close(); return; } // 非阻塞IO可以读取0个字节,这种数据应该手动丢弃 if (bytesRead == 0) return; // 读取所有的数据 while (bytesRead > 0) { buf.flip(); while (buf.hasRemaining()) { list.add(buf.get()); } buf.clear(); bytesRead = channel.read(buf); } String request = new String(Bytes.toArray(list), Constants.DEFAULT_ENCODING); try { // 写回响应 response(request, channel); } catch (Exception e) { e.printStackTrace(); // 返回错误信息 StringWriter stringWriter = new StringWriter(); PrintWriter printWriter = new PrintWriter(stringWriter); e.printStackTrace(printWriter); serverError(stringWriter.toString(), channel); } } /** * 解析请求并返回响应 */ private void response(String request, SocketChannel channel) throws Exception { HttpRequest httpRequest = ParseNioRequest.getRequest(request); String url = httpRequest.getUrl(); RequestMethod requestMethod = httpRequest.getRequestMethod(); log.info("{} {}", requestMethod, url); MethodDetail methodDetail = UrlMappingPool.getInstance().getMap(url, requestMethod); // 如果找不到对应的匹配规则 if (methodDetail == null) { notFound(channel); return; } Class clazz = methodDetail.getClazz(); Object object = BeansPool.getInstance().getObject(clazz); if (object == null) throw new RuntimeException("can't find bean for " + clazz); Map<String, String> requestParam = httpRequest.getParams(); // 请求参数 List<String> params = new ArrayList<>(); // 最终的方法参数 Method method = methodDetail.getMethod(); // 获取方法的所有的参数 Parameter[] parameters = method.getParameters(); for (Parameter parameter : parameters) { String name = null; // 获取参数上所有的注解 Annotation[] annotations = parameter.getAnnotations(); for (Annotation annotation : annotations) { if (annotation.annotationType() == Param.class) { Param param = (Param) annotation; name = param.value(); break; } } // 如果请求参数中存在这个参数就把该值赋给方法参数,否则赋值null params.add(requestParam.getOrDefault(name, null)); } Object result = method.invoke(object, params.toArray()); // 写回响应 String str = (String) result; String response = "HTTP/1.1 200 OK" + Constants.CRLF + "Content-Length: " + str.getBytes(Constants.DEFAULT_ENCODING).length + Constants.CRLF_2 + str; writeData(response, channel); } /** * 500 Internal Server Error */ private void serverError(String error, SocketChannel channel) throws UnsupportedEncodingException { String response = "HTTP/1.1 500 Internal Server Error" + Constants.CRLF + "Content-Length: " + error.getBytes(Constants.DEFAULT_ENCODING).length + Constants.CRLF_2 + error; writeData(response, channel); } /** * 404 Not Found */ private void notFound(SocketChannel channel) throws IOException { String str = Constants.NOT_FOUND; String response = "HTTP/1.1 404 Not Found" + Constants.CRLF + "Content-Length: " + str.getBytes(Constants.DEFAULT_ENCODING).length + Constants.CRLF_2 + str; writeData(response, channel); } /** * 向连接中写数据 * * @param data 数据 * @param channel 连接 */ private void writeData(String data, SocketChannel channel) throws UnsupportedEncodingException { ByteBuffer res = ByteBuffer.allocate(data.getBytes(Constants.DEFAULT_ENCODING).length); res.clear(); res.put(data.getBytes(Constants.DEFAULT_ENCODING)); res.flip(); while (res.hasRemaining()) { try { channel.write(res); } catch (IOException e) { log.error("error when writing data"); e.printStackTrace(); } } } }