package com.thimbleware.jmemcached.protocol.binary; import com.thimbleware.jmemcached.protocol.Command; import com.thimbleware.jmemcached.protocol.ResponseMessage; import com.thimbleware.jmemcached.protocol.exceptions.UnknownCommandException; import com.thimbleware.jmemcached.CacheElement; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteOrder; import java.util.Set; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; /** * */ // TODO refactor so this can be unit tested separate from netty? scalacheck? @ChannelHandler.Sharable public class MemcachedBinaryResponseEncoder<CACHE_ELEMENT extends CacheElement> extends SimpleChannelUpstreamHandler { private ConcurrentHashMap<Integer, ChannelBuffer> corkedBuffers = new ConcurrentHashMap<Integer, ChannelBuffer>(); final Logger logger = LoggerFactory.getLogger(MemcachedBinaryResponseEncoder.class); public static enum ResponseCode { OK(0x0000), KEYNF(0x0001), KEYEXISTS(0x0002), TOOLARGE(0x0003), INVARG(0x0004), NOT_STORED(0x0005), UNKNOWN(0x0081), OOM(0x00082); public short code; ResponseCode(int code) { this.code = (short)code; } } public ResponseCode getStatusCode(ResponseMessage command) { Command cmd = command.cmd.cmd; if (cmd == Command.GET || cmd == Command.GETS) { return ResponseCode.OK; } else if (cmd == Command.SET || cmd == Command.CAS || cmd == Command.ADD || cmd == Command.REPLACE || cmd == Command.APPEND || cmd == Command.PREPEND) { switch (command.response) { case EXISTS: return ResponseCode.KEYEXISTS; case NOT_FOUND: return ResponseCode.KEYNF; case NOT_STORED: return ResponseCode.NOT_STORED; case STORED: return ResponseCode.OK; } } else if (cmd == Command.INCR || cmd == Command.DECR) { return command.incrDecrResponse == null ? ResponseCode.KEYNF : ResponseCode.OK; } else if (cmd == Command.DELETE) { switch (command.deleteResponse) { case DELETED: return ResponseCode.OK; case NOT_FOUND: return ResponseCode.KEYNF; } } else if (cmd == Command.STATS) { return ResponseCode.OK; } else if (cmd == Command.VERSION) { return ResponseCode.OK; } else if (cmd == Command.FLUSH_ALL) { return ResponseCode.OK; } return ResponseCode.UNKNOWN; } public ChannelBuffer constructHeader(MemcachedBinaryCommandDecoder.BinaryCommand bcmd, ChannelBuffer extrasBuffer, ChannelBuffer keyBuffer, ChannelBuffer valueBuffer, short responseCode, int opaqueValue, long casUnique) { // take the ResponseMessage and turn it into a binary payload. ChannelBuffer header = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 24); header.writeByte((byte)0x81); // magic header.writeByte(bcmd.code); // opcode short keyLength = (short) (keyBuffer != null ? keyBuffer.capacity() :0); header.writeShort(keyLength); int extrasLength = extrasBuffer != null ? extrasBuffer.capacity() : 0; header.writeByte((byte) extrasLength); // extra length = flags + expiry header.writeByte((byte)0); // data type unused header.writeShort(responseCode); // status code int dataLength = valueBuffer != null ? valueBuffer.capacity() : 0; header.writeInt(dataLength + keyLength + extrasLength); // data length header.writeInt(opaqueValue); // opaque header.writeLong(casUnique); return header; } /** * Handle exceptions in protocol processing. Exceptions are either client or internal errors. Report accordingly. * * @param ctx * @param e * @throws Exception */ @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { try { throw e.getCause(); } catch (UnknownCommandException unknownCommand) { if (ctx.getChannel().isOpen()) ctx.getChannel().write(constructHeader(MemcachedBinaryCommandDecoder.BinaryCommand.Noop, null, null, null, (short)0x0081, 0, 0)); } catch (Throwable err) { logger.error("error", err); if (ctx.getChannel().isOpen()) ctx.getChannel().close(); } } @Override @SuppressWarnings("unchecked") public void messageReceived(ChannelHandlerContext channelHandlerContext, MessageEvent messageEvent) throws Exception { ResponseMessage<CACHE_ELEMENT> command = (ResponseMessage<CACHE_ELEMENT>) messageEvent.getMessage(); Object additional = messageEvent.getMessage(); MemcachedBinaryCommandDecoder.BinaryCommand bcmd = MemcachedBinaryCommandDecoder.BinaryCommand.forCommandMessage(command.cmd); // write extras == flags & expiry ChannelBuffer extrasBuffer = null; // write key if there is one ChannelBuffer keyBuffer = null; if (bcmd.addKeyToResponse && command.cmd.keys != null && command.cmd.keys.size() != 0) { keyBuffer = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, command.cmd.keys.get(0).getBytes()); } // write value if there is one ChannelBuffer valueBuffer = null; if (command.elements != null) { extrasBuffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 4); CacheElement element = command.elements[0]; extrasBuffer.writeShort((short) (element != null ? element.getExpire() : 0)); extrasBuffer.writeShort((short) (element != null ? element.getFlags() : 0)); if ((command.cmd.cmd == Command.GET || command.cmd.cmd == Command.GETS)) { if (element != null) { valueBuffer = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, element.getData()); } else { valueBuffer = ChannelBuffers.buffer(0); } } else if (command.cmd.cmd == Command.INCR || command.cmd.cmd == Command.DECR) { valueBuffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 8); valueBuffer.writeLong(command.incrDecrResponse); } } else if (command.cmd.cmd == Command.INCR || command.cmd.cmd == Command.DECR) { valueBuffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 8); valueBuffer.writeLong(command.incrDecrResponse); } long casUnique = 0; if (command.elements != null && command.elements.length != 0 && command.elements[0] != null) { casUnique = command.elements[0].getCasUnique(); } // stats is special -- with it, we write N times, one for each stat, then an empty payload if (command.cmd.cmd == Command.STATS) { // first uncork any corked buffers if (corkedBuffers.containsKey(command.cmd.opaque)) uncork(command.cmd.opaque, messageEvent.getChannel()); for (Map.Entry<String, Set<String>> statsEntries : command.stats.entrySet()) { for (String stat : statsEntries.getValue()) { keyBuffer = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, statsEntries.getKey().getBytes(MemcachedBinaryCommandDecoder.USASCII)); valueBuffer = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, stat.getBytes(MemcachedBinaryCommandDecoder.USASCII)); ChannelBuffer headerBuffer = constructHeader(bcmd, extrasBuffer, keyBuffer, valueBuffer, getStatusCode(command).code, command.cmd.opaque, casUnique); writePayload(messageEvent, extrasBuffer, keyBuffer, valueBuffer, headerBuffer); } } keyBuffer = null; valueBuffer = null; ChannelBuffer headerBuffer = constructHeader(bcmd, extrasBuffer, keyBuffer, valueBuffer, getStatusCode(command).code, command.cmd.opaque, casUnique); writePayload(messageEvent, extrasBuffer, keyBuffer, valueBuffer, headerBuffer); } else { ChannelBuffer headerBuffer = constructHeader(bcmd, extrasBuffer, keyBuffer, valueBuffer, getStatusCode(command).code, command.cmd.opaque, casUnique); // write everything // is the command 'quiet?' if so, then we append to our 'corked' buffer until a non-corked command comes along if (bcmd.noreply) { int totalCapacity = headerBuffer.capacity() + (extrasBuffer != null ? extrasBuffer.capacity() : 0) + (keyBuffer != null ? keyBuffer.capacity() : 0) + (valueBuffer != null ? valueBuffer.capacity() : 0); ChannelBuffer corkedResponse = cork(command.cmd.opaque, totalCapacity); corkedResponse.writeBytes(headerBuffer); if (extrasBuffer != null) corkedResponse.writeBytes(extrasBuffer); if (keyBuffer != null) corkedResponse.writeBytes(keyBuffer); if (valueBuffer != null) corkedResponse.writeBytes(valueBuffer); } else { // first write out any corked responses if (corkedBuffers.containsKey(command.cmd.opaque)) uncork(command.cmd.opaque, messageEvent.getChannel()); writePayload(messageEvent, extrasBuffer, keyBuffer, valueBuffer, headerBuffer); } } } private ChannelBuffer cork(int opaque, int totalCapacity) { if (corkedBuffers.containsKey(opaque)) { ChannelBuffer corkedResponse = corkedBuffers.get(opaque); ChannelBuffer oldBuffer = corkedResponse; corkedResponse = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, totalCapacity + corkedResponse.capacity()); corkedResponse.writeBytes(oldBuffer); oldBuffer.clear(); corkedBuffers.remove(opaque); corkedBuffers.put(opaque, corkedResponse); return corkedResponse; } else { ChannelBuffer buffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, totalCapacity); corkedBuffers.put(opaque, buffer); return buffer; } } private void uncork(int opaque, Channel channel) { ChannelBuffer corkedBuffer = corkedBuffers.get(opaque); assert corkedBuffer != null; channel.write(corkedBuffer); corkedBuffers.remove(opaque); } private void writePayload(MessageEvent messageEvent, ChannelBuffer extrasBuffer, ChannelBuffer keyBuffer, ChannelBuffer valueBuffer, ChannelBuffer headerBuffer) { if (messageEvent.getChannel().isOpen()) { messageEvent.getChannel().write(headerBuffer); if (extrasBuffer != null) messageEvent.getChannel().write(extrasBuffer); if (keyBuffer != null) messageEvent.getChannel().write(keyBuffer); if (valueBuffer != null) messageEvent.getChannel().write(valueBuffer); } } }