package jwp.fuzz;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.util.TraceClassVisitor;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.IntConsumer;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/** Utility functions and classes */
public class Util {

  /** Returns the least significant byte of {@code val}. */
  public static byte byte0(short val) {
    return (byte) val;
  }

  /** Returns the most significant byte of {@code val}. */
  public static byte byte1(short val) {
    return (byte) (val >> 8);
  }

  /** Returns bits 0-7 (the least significant byte) of {@code val}. */
  public static byte byte0(int val) {
    return (byte) val;
  }

  /** Returns bits 8-15 of {@code val}. */
  public static byte byte1(int val) {
    return (byte) (val >> 8);
  }

  /** Returns bits 16-23 of {@code val}. */
  public static byte byte2(int val) {
    return (byte) (val >> 16);
  }

  /** Returns bits 24-31 (the most significant byte) of {@code val}. */
  public static byte byte3(int val) {
    return (byte) (val >> 24);
  }

  /**
   * Walks over every bit offset of {@code bytes}, flipping short runs of consecutive bits in place and asking
   * {@code pred} after each flip whether the resulting array is a match. The runs tried at each offset are 1, 2 and
   * 4 bits long, plus shorter runs where a longer one would fall off either end of the array.
   *
   * <p>Mutates the array, so only use it with temp arrays. When no match is found every flip has been undone by the
   * time this returns; when a match is found the array is left in the matching (flipped) state.
   *
   * @param bytes the array to flip bits in; bit index {@code n} is bit {@code n % 8} of byte {@code n / 8}
   * @param pred the check to run against the array after each flip
   * @return true as soon as {@code pred} accepts a flipped state, false if it never does
   */
  public static boolean checkConsecutiveBitsFlipped(byte[] bytes, Predicate<byte[]> pred) {
    // Indexes handed to flipBit are bit indexes, so the bound is the number of bits, not the number of bytes
    int bitCount = bytes.length * 8;
    // Start at -1 so the runs that begin at bit 0 without a preceding flipped bit are covered too
    for (int i = -1; i < bitCount; i++) {
      if (i >= 0) {
        flipBit(bytes, i);
        if (pred.test(bytes)) {
          return true;
        }
      }
      if (i < bitCount - 1) {
        flipBit(bytes, i + 1);
        if (pred.test(bytes)) {
          return true;
        }
        if (i < bitCount - 2) {
          flipBit(bytes, i + 2);
          // Only check this 3 spot if it's exactly 3 until the end
          if (i == bitCount - 3 && pred.test(bytes)) {
            return true;
          }
          if (i < bitCount - 3) {
            flipBit(bytes, i + 3);
            if (pred.test(bytes)) {
              return true;
            }
            flipBit(bytes, i + 3);
          }
          flipBit(bytes, i + 2);
        }
        flipBit(bytes, i + 1);
      }
      // Undo the flip made at the top of this iteration
      if (i >= 0) {
        flipBit(bytes, i);
      }
    }
    return false;
  }

  /**
   * Renders the given class file bytes as text by running them through an ASM {@link TraceClassVisitor}.
   *
   * @param bytes the class file contents
   * @return the textual output written by the trace visitor
   */
  public static String classBytesToString(byte[] bytes) {
    ClassReader reader = new ClassReader(bytes);
    StringWriter stringWriter = new StringWriter();
    reader.accept(new TraceClassVisitor(new PrintWriter(stringWriter)), 0);
    return stringWriter.toString();
  }

  /** Returns true if {@code item} is an element of {@code arr}. */
  public static boolean contains(byte[] arr, byte item) {
    for (byte arrItem : arr) {
      if (arrItem == item) {
        return true;
      }
    }
    return false;
  }

  /** Returns true if {@code item} is an element of {@code arr}. */
  public static boolean contains(short[] arr, short item) {
    for (short arrItem : arr) {
      if (arrItem == item) {
        return true;
      }
    }
    return false;
  }

  /** Returns true if {@code item} is an element of {@code arr}. */
  public static boolean contains(int[] arr, int item) {
    for (int arrItem : arr) {
      if (arrItem == item) {
        return true;
      }
    }
    return false;
  }

  /**
   * Checks whether any of {@code newBytes} equals {@code curr} or is one of the values
   * {@link #checkConsecutiveBitsFlipped(byte[], Predicate)} produces from {@code curr}.
   */
  public static boolean couldHaveBitFlippedTo(byte curr, byte... newBytes) {
    return contains(newBytes, curr)
        || checkConsecutiveBitsFlipped(new byte[] { curr }, arr -> contains(newBytes, arr[0]));
  }

  /**
   * Checks whether any of {@code newShorts} equals {@code curr} or is one of the values
   * {@link #checkConsecutiveBitsFlipped(byte[], Predicate)} produces from the bytes of {@code curr}.
   */
  public static boolean couldHaveBitFlippedTo(short curr, short... newShorts) {
    return contains(newShorts, curr)
        || checkConsecutiveBitsFlipped(toByteArray(curr), arr -> contains(newShorts, getShortLe(arr, 0)));
  }

  /**
   * Checks whether any of {@code newInts} equals {@code curr} or is one of the values
   * {@link #checkConsecutiveBitsFlipped(byte[], Predicate)} produces from the bytes of {@code curr}.
   */
  public static boolean couldHaveBitFlippedTo(int curr, int... newInts) {
    return contains(newInts, curr)
        || checkConsecutiveBitsFlipped(toByteArray(curr), arr -> contains(newInts, getIntLe(arr, 0)));
  }

  /** Returns {@code val} with its two bytes in the opposite order. */
  public static short endianSwapped(short val) {
    return shortFromBytes(byte1(val), byte0(val));
  }

  /** Returns {@code val} with its four bytes in the opposite order. */
  public static int endianSwapped(int val) {
    return intFromBytes(byte3(val), byte2(val), byte1(val), byte0(val));
  }

  /**
   * Flips a single bit of {@code arr} in place.
   *
   * @param arr the array to mutate
   * @param bitIndex index across the whole array: bit {@code bitIndex % 8} of byte {@code bitIndex / 8}
   */
  public static void flipBit(byte[] arr, int bitIndex) {
    flipBit(arr, bitIndex / 8, bitIndex % 8);
  }

  /**
   * Flips a single bit of {@code arr} in place.
   *
   * @param arr the array to mutate
   * @param byteIndex index of the byte holding the bit
   * @param bitIndex index of the bit inside that byte, 0 being the least significant
   */
  public static void flipBit(byte[] arr, int byteIndex, int bitIndex) {
    arr[byteIndex] ^= (1 << bitIndex);
  }

  /** Returns {@code byt} with the bit at {@code bitIndex} (0 being the least significant) flipped. */
  public static byte flipBit(byte byt, int bitIndex) {
    return (byte) (byt ^ (1 << bitIndex));
  }

  /** Reads the four bytes starting at {@code byteIndex} as a big-endian int. */
  public static int getIntBe(byte[] arr, int byteIndex) {
    return intFromBytes(arr[byteIndex + 3], arr[byteIndex + 2], arr[byteIndex + 1], arr[byteIndex]);
  }

  /** Reads the four bytes starting at {@code byteIndex} as a little-endian int. */
  public static int getIntLe(byte[] arr, int byteIndex) {
    return intFromBytes(arr[byteIndex], arr[byteIndex + 1], arr[byteIndex + 2], arr[byteIndex + 3]);
  }

  /** Reads the two bytes starting at {@code byteIndex} as a big-endian short. */
  public static short getShortBe(byte[] arr, int byteIndex) {
    return shortFromBytes(arr[byteIndex + 1], arr[byteIndex]);
  }

  /** Reads the two bytes starting at {@code byteIndex} as a little-endian short. */
  public static short getShortLe(byte[] arr, int byteIndex) {
    return shortFromBytes(arr[byteIndex], arr[byteIndex + 1]);
  }

  /**
   * Assembles an int from four bytes.
   *
   * @param byte0 bits 0-7 (least significant)
   * @param byte1 bits 8-15
   * @param byte2 bits 16-23
   * @param byte3 bits 24-31 (most significant)
   * @return the combined value
   */
  public static int intFromBytes(byte byte0, byte byte1, byte byte2, byte byte3) {
    // The lower three bytes are masked so sign extension cannot bleed into the higher bits
    return (byte3 << 24) | ((byte2 & 0xFF) << 16) | ((byte1 & 0xFF) << 8) | (byte0 & 0xFF);
  }

  /** Writes {@code val} into the four bytes starting at {@code byteIndex}, most significant byte first. */
  public static void putIntBe(byte[] arr, int byteIndex, int val) {
    arr[byteIndex] = byte3(val);
    arr[byteIndex + 1] = byte2(val);
    arr[byteIndex + 2] = byte1(val);
    arr[byteIndex + 3] = byte0(val);
  }

  /** Writes {@code val} into the four bytes starting at {@code byteIndex}, least significant byte first. */
  public static void putIntLe(byte[] arr, int byteIndex, int val) {
    arr[byteIndex] = byte0(val);
    arr[byteIndex + 1] = byte1(val);
    arr[byteIndex + 2] = byte2(val);
    arr[byteIndex + 3] = byte3(val);
  }

  /** Writes {@code val} into the two bytes starting at {@code byteIndex}, most significant byte first. */
  public static void putShortBe(byte[] arr, int byteIndex, short val) {
    arr[byteIndex] = byte1(val);
    arr[byteIndex + 1] = byte0(val);
  }

  /** Writes {@code val} into the two bytes starting at {@code byteIndex}, least significant byte first. */
  public static void putShortLe(byte[] arr, int byteIndex, short val) {
    arr[byteIndex] = byte0(val);
    arr[byteIndex + 1] = byte1(val);
  }

  /**
   * Assembles a short from two bytes.
   *
   * @param byte0 the least significant byte
   * @param byte1 the most significant byte
   * @return the combined value
   */
  public static short shortFromBytes(byte byte0, byte byte1) {
    return (short) ((byte1 << 8) | (byte0 & 0xFF));
  }

  /** Returns a stream of the given items with the null ones left out. */
  @SafeVarargs
  public static <T> Stream<T> streamOfNotNull(T... items) {
    return Stream.of(items).filter(Objects::nonNull);
  }

  /** Collects the stream into a byte array, narrowing each int with a {@code (byte)} cast. */
  public static byte[] streamToByteArray(IntStream stream) {
    int[] ints = stream.toArray();
    byte[] ret = new byte[ints.length];
    for (int i = 0; i < ints.length; i++) {
      ret[i] = (byte) ints[i];
    }
    return ret;
  }

  /** Collects the stream into a short array, narrowing each int with a {@code (short)} cast. */
  public static short[] streamToShortArray(IntStream stream) {
    int[] ints = stream.toArray();
    short[] ret = new short[ints.length];
    for (int i = 0; i < ints.length; i++) {
      ret[i] = (short) ints[i];
    }
    return ret;
  }

  /**
   * Reports the characteristics of the stream's spliterator to {@code consumer}.
   *
   * @param stream the stream to inspect; its spliterator is taken here, so use the returned stream afterwards
   * @param consumer receives the value of {@link Spliterator#characteristics()}
   * @return a new stream over the same spliterator with the same parallel setting as {@code stream}
   */
  public static <T> Stream<T> streamCharacteristics(Stream<T> stream, IntConsumer consumer) {
    Spliterator<T> spliterator = stream.spliterator();
    consumer.accept(spliterator.characteristics());
    return StreamSupport.stream(spliterator, stream.isParallel());
  }

  /** Returns the two bytes of {@code val}, least significant first. */
  public static byte[] toByteArray(short val) {
    return new byte[] { byte0(val), byte1(val) };
  }

  /** Returns the four bytes of {@code val}, least significant first. */
  public static byte[] toByteArray(int val) {
    return new byte[] { byte0(val), byte1(val), byte2(val), byte3(val) };
  }

  /**
   * Returns a copy of {@code bytes} without the {@code amount} bytes that begin at {@code start}.
   *
   * @param bytes the source array, left untouched
   * @param start index of the first byte to drop
   * @param amount number of bytes to drop
   * @return a new array of length {@code bytes.length - amount}
   */
  public static byte[] withBytesRemoved(byte[] bytes, int start, int amount) {
    byte[] newArr = new byte[bytes.length - amount];
    System.arraycopy(bytes, 0, newArr, 0, start);
    System.arraycopy(bytes, start + amount, newArr, start, newArr.length - start);
    return newArr;
  }

  /**
   * Copies {@code arr}, lets {@code fn} work on the copy, and returns the copy.
   *
   * @param arr the source array, left untouched
   * @param fn receives the copy and may mutate it
   * @return the copy after {@code fn} has run
   */
  public static byte[] withCopiedBytes(byte[] arr, Consumer<byte[]> fn) {
    byte[] ret = Arrays.copyOf(arr, arr.length);
    fn.accept(ret);
    return ret;
  }

  /** An executor service that only runs submissions on the current thread */
  public static class CurrentThreadExecutorService extends ThreadPoolExecutor {

    /** Creates the executor with a {@link ThreadPoolExecutor.CallerRunsPolicy} as its rejection handler. */
    public CurrentThreadExecutorService() {
      super(0, 1, 0, TimeUnit.SECONDS, new SynchronousQueue<>(), new ThreadPoolExecutor.CallerRunsPolicy());
    }

    /** Hands the command straight to the rejection handler instead of offering it to the pool. */
    @Override
    public void execute(Runnable command) {
      getRejectedExecutionHandler().rejectedExecution(command, this);
    }
  }

  /** Base iterator that is considered finished the first time it sees a null */
  public static abstract class NullMeansCompleteIterator<T> implements Iterator<T> {
    /** Item fetched by {@link #hasNext()} that has not been handed out by {@link #next()} yet, or null. */
    protected T prev;
    /** Set once {@link #doNext()} has returned null. */
    protected boolean finished = false;

    /** Provide the next item or null to finish this iterator */
    protected abstract T doNext();

    @Override
    public boolean hasNext() {
      // Fetch ahead and keep the item so the following next() call returns it
      return !finished && next(true) != null;
    }

    @Override
    public T next() {
      if (finished) {
        throw new NoSuchElementException();
      }
      return next(false);
    }

    /**
     * Returns the cached item if there is one, otherwise asks {@link #doNext()}.
     *
     * @param canCacheResult true to keep the item cached for the next call, false to clear the cache
     * @return the item, or null if {@link #doNext()} signalled the end
     */
    protected T next(boolean canCacheResult) {
      T ret = prev == null ? doNext() : prev;
      if (ret == null) {
        finished = true;
      }
      prev = canCacheResult ? ret : null;
      return ret;
    }
  }
}