/* * Copyright 2019 DeNA Co., Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package packetproxy.common; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang3.ArrayUtils; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; public class Protobuf3 { public static class Key { public static enum Type { Variant, Bit64, LengthDelimited, StartGroup, EndGroup, Bit32, None } long fieldNumber = 0; Type wireType = Type.None; public Key(long fieldNumber, Type wireType) { this.fieldNumber = fieldNumber; this.wireType = wireType; } public Key(long keyData) { init(keyData); } public Key(ByteArrayInputStream data) { init(decodeVar(data)); } private void init(long keyData) { this.fieldNumber = keyData >> 3; this.wireType = Type.values()[(int)(keyData & 0x07)]; } public Type getWireType() { return this.wireType; } public long getFieldNumber() { return this.fieldNumber; } @Override public String toString() { return String.format("Key[FieldNum:%d, Type:%s]", fieldNumber, wireType); } public void writeTo(ByteArrayOutputStream output) { writeVar((fieldNumber << 3)|(wireType.ordinal() & 7), output); } } public static boolean validateVar(ByteArrayInputStream input) { byte[] raw = new byte[input.available()]; input.mark(input.available()); input.read(raw, 0, input.available()); boolean ret = validateVar(raw); input.reset(); return ret; } public static boolean validateVar(byte[] input) { return validateVar(input, null); } public static boolean validateVar(byte[] input, int[] outLength) { long var = 0; int i = 0; while (0 < input.length) { long nextB = input[i] & 0xff; var = var | ((nextB & 0x7f) << (7*i)); i++; if ((nextB & 0x80) == 0) break; if (i > 9) // max 64bit (long size) return false; if (i == input.length) return false; } if (outLength != null) outLength[0] = i; return true; } public static long decodeVar(ByteArrayInputStream input) { long var = 0; for (long i = 0; input.available() > 0; i++) { long nextB = (byte)(input.read() & 0xff) ; var = var | ((nextB & 0x7f) << (7*i)); if ((nextB & 0x80) == 0) break; } return var; } public static void writeVar(long var, ByteArrayOutputStream output) { for (int i = 1; i <= 10; ++i) { byte b = (byte)(var & 0x7f); if (i == 10) { var = 0; } else { var = (var >>> 7); } if (var == 0) { output.write(b); break; } else { output.write((byte)(b|0x80)); } } } public static boolean validateBit64(ByteArrayInputStream input) { return input.available() < 8 ? false : true; } public static boolean validateBit64(byte[] input) { return input.length < 8 ? false : true; } public static long decodeBit64(ByteArrayInputStream input) throws Exception { long bit64 = 0; for (int idx = 0; idx < 8; idx++) { long nextB = input.read(); bit64 = bit64 | (nextB << (8*idx)); } return bit64; } public static boolean validateBit32(ByteArrayInputStream input) { return input.available() < 4 ? false : true; } public static boolean validateBit32(byte[] input) { return input.length < 4 ? false : true; } public static int decodeBit32(ByteArrayInputStream input) throws Exception { int bit32 = 0; for (int idx = 0; idx < 4; idx++) { int nextB = input.read(); bit32 = bit32 | (nextB << (8*idx)); } return bit32; } /* 注意:repeatedデータが、inputバッファとぴったり合わないとfalse */ public static boolean validateRepeatedStrictly(byte[] input) { int i = 0; while (i < input.length) { byte[] subInput = ArrayUtils.subarray(input, i, input.length); int[] varLen = new int[1]; if (validateVar(subInput, varLen) == false) { return false; } i = i + varLen[0]; } return true; } public static List<Object> decodeRepeated(ByteArrayInputStream input) { List<Object> list = new LinkedList<>(); while (input.available() > 0) { long var = decodeVar(input); list.add(var); } return list; } public static String decodeBytes(byte[] rawSubData) { return IntStream.range(0, rawSubData.length).mapToObj(i->String.format("%02x", rawSubData[i])).collect(Collectors.joining(":")); } public static byte[] encodeBytes(String bytes) throws Exception { String hexStr = bytes.replace(":", ""); return new Binary(new Binary.HexString(hexStr)).toByteArray(); } public static String decode(byte[] input) throws Exception { ByteArrayInputStream data = new ByteArrayInputStream(input); Map<String,Object> messages = new TreeMap<>(); decodeData(data, messages); ObjectMapper mapper = new ObjectMapper(); return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(messages); } public static byte[] encode(String input) throws Exception { ObjectMapper mapper = new ObjectMapper(); HashMap<String,Object> messages = mapper.readValue(input, new TypeReference<HashMap<String,Object>>(){}); return encodeData(messages); } public static boolean decodeData(ByteArrayInputStream data, Map<String,Object> messages) throws Exception { int ordinary = 0; while (data.available() > 0) { Key key = new Key(data); switch (key.getWireType()) { case Variant: { if (validateVar(data) == false) { return false; } long variant = decodeVar(data); messages.put(String.format("%02x:%02x:Varint", key.getFieldNumber(), ordinary), variant); break; } case Bit32: { if (validateBit32(data) == false) { return false; } int bit32 = decodeBit32(data); messages.put(String.format("%02x:%02x:32-bit", key.getFieldNumber(), ordinary), bit32); break; } case Bit64: { if (validateBit64(data) == false) { return false; } long bit64 = decodeBit64(data); messages.put(String.format("%02x:%02x:64-bit", key.getFieldNumber(), ordinary), bit64); break; } case LengthDelimited: { if (validateVar(data) == false) { return false; } long length = decodeVar(data); if (length > data.available()) { return false; } byte[] rawSubData = new byte[(int)length]; data.read(rawSubData, 0, (int)length); /* String */ if (StringUtils.validatePrintableUTF8(rawSubData)) { messages.put(String.format("%02x:%02x:String", key.getFieldNumber(), ordinary), new String(rawSubData, "UTF-8")); break; } /* Data */ Map<String,Object> subMsg = new TreeMap<>(); if (decodeData(new ByteArrayInputStream(rawSubData), subMsg) == true) { messages.put(String.format("%02x:%02x:embedded message", key.getFieldNumber(), ordinary), subMsg); break; } /* Repeated */ if (validateRepeatedStrictly(rawSubData) == true) { List<Object> list = decodeRepeated(new ByteArrayInputStream(rawSubData)); messages.put(String.format("%02x:%02x:repeated", key.getFieldNumber(), ordinary), list); break; } /* Bytes */ String result = decodeBytes(rawSubData); messages.put(String.format("%02x:%02x:bytes", key.getFieldNumber(), ordinary), result); break; } default: return false; } ordinary++; } return true; } public static byte[] encodeData(Map<String,Object> messages) throws Exception { ByteArrayOutputStream output = new ByteArrayOutputStream(); String[] orderedKeys = new String[messages.keySet().size()]; messages.keySet().stream().forEach(key -> { String[] keyval = key.split(":"); int ordinary = Integer.parseInt(keyval[1], 16); orderedKeys[ordinary] = key; }); for (String key : orderedKeys) { String[] keyval = key.split(":"); long fieldNumber = Long.parseLong(keyval[0], 16); String type = keyval[2]; switch (type) { case "Varint": { new Key(fieldNumber, Key.Type.Variant).writeTo(output); Object d = messages.get(key); long var = 0; if (d instanceof Integer) { var = ((Integer) d).longValue(); } else if (d instanceof Long){ var = ((Long) d).longValue(); } writeVar(var, output); break; } case "String": { new Key(fieldNumber, Key.Type.LengthDelimited).writeTo(output); String str = (String)messages.get(key); writeVar(str.getBytes().length, output); output.write(str.getBytes()); break; } case "32-bit": { new Key(fieldNumber, Key.Type.Bit32).writeTo(output); int bit32 = (int)messages.get(key); output.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(bit32).array()); break; } case "64-bit": { new Key(fieldNumber, Key.Type.Bit64).writeTo(output); Object d = messages.get(key); long bit64 = 0; if (d instanceof Integer) { bit64 = ((Integer) d).longValue(); } else if (d instanceof Long){ bit64 = ((Long) d).longValue(); } output.write(ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(bit64).array()); break; } case "repeated": { new Key(fieldNumber, Key.Type.LengthDelimited).writeTo(output); List<Object> list = (List<Object>)messages.get(key); ByteArrayOutputStream tmp = new ByteArrayOutputStream(); list.stream().forEach(o -> { long var = 0; if (o instanceof Integer) { var = ((Integer) o).longValue(); } else if (o instanceof Long){ var = ((Long) o).longValue(); } else { System.err.println("Unknown object type"); } writeVar(var, tmp); }); writeVar(tmp.toByteArray().length, output); output.write(tmp.toByteArray()); break; } case "embedded message": { new Key(fieldNumber, Key.Type.LengthDelimited).writeTo(output); byte[] tmp = encodeData((Map<String,Object>)messages.get(key)); writeVar(tmp.length, output); output.write(tmp); break; } case "bytes": { new Key(fieldNumber, Key.Type.LengthDelimited).writeTo(output); byte[] bytes = encodeBytes((String)messages.get(key)); writeVar(bytes.length, output); output.write(bytes); break; } default: System.err.println(String.format("Unknown type: %s", type)); } } return output.toByteArray(); } }