package com.linbit;

import com.linbit.linstor.LinStorException;
import com.linbit.linstor.logging.ErrorReporter;
import com.linbit.linstor.netcom.Peer;
import com.linbit.linstor.security.AccessContext;
import com.linbit.linstor.security.AccessDeniedException;

import java.nio.file.Path;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.event.Level;

/**
 * Tests for {@code WorkerPool}: construction parameters, worker thread naming and shutdown,
 * execution of submitted tasks, and delivery of task failures to the {@code ErrorReporter}.
 */
public class WorkerPoolTest
{
    private static final int DEFAULT_THREAD_COUNT = 3;
    private static final int DEFAULT_QUEUE_SIZE = 10;
    private static final boolean DEFAULT_FAIRNESS = true;
    private static final String DEFAULT_THREAD_PREFIX = "TestWorkerThread";
    private static final TestErrorReporter DEFAULT_ERROR_REPORTER = new TestErrorReporter();

    /** Maximum time to wait for a failing task's throwable to arrive at the test error reporter. */
    private static final long ERROR_REPORT_TIMEOUT_SECONDS = 10;
    /** Maximum time {@code pool.finish()} may take before the test fails with a TimeoutException. */
    private static final long POOL_FINISH_TIMEOUT_MILLIS = 15_000;
    /** Number of re-checks for lingering worker threads after shutdown. */
    private static final int SHUTDOWN_POLL_ATTEMPTS = 10;
    /** Delay between re-checks for lingering worker threads after shutdown. */
    private static final long SHUTDOWN_POLL_INTERVAL_MILLIS = 100;

    private WorkerPool pool;

    @Before
    public void setUp() throws Exception
    {
        // No shared setup: every test builds its own pool through WorkerPoolBuilder.
    }

    @After
    public void tearDown() throws Exception
    {
        if (pool != null)
        {
            pool.shutdown();
        }
    }

    @Test
    public void testCreatedThreadAmount()
    {
        pool = new WorkerPoolBuilder().build();
        Assert.assertEquals("Unexpected worker count", DEFAULT_THREAD_COUNT, pool.getThreadCount());
    }

    @Test
    public void testQueueSize()
    {
        pool = new WorkerPoolBuilder().build();
        Assert.assertEquals("Unexpected queue size", DEFAULT_QUEUE_SIZE, pool.getQueueSize());
    }

    @Test
    public void testFairness()
    {
        pool = new WorkerPoolBuilder().build();
        Assert.assertEquals("Unexpected fairness", DEFAULT_FAIRNESS, pool.isFairQueue());
    }

    @Test
    public void testThreadPrefix()
    {
        WorkerPoolBuilder workerPoolBuilder = new WorkerPoolBuilder();
        pool = workerPoolBuilder.build();
        int prefixedThreadCount = getPrefixedThreadCount(workerPoolBuilder.threadPrefix);
        Assert.assertEquals("Unexpected prefixed thread count", DEFAULT_THREAD_COUNT, prefixedThreadCount);
    }

    @Test
    public void testShutdown() throws InterruptedException, ExecutionException, TimeoutException
    {
        WorkerPoolBuilder workerPoolBuilder = new WorkerPoolBuilder();
        pool = workerPoolBuilder.build();
        int prefixedThreadCount = getPrefixedThreadCount(workerPoolBuilder.threadPrefix);
        Assert.assertEquals("Unexpected prefixed thread count", DEFAULT_THREAD_COUNT, prefixedThreadCount);

        pool.shutdown();
        waitUntilPoolFinishes();

        // Worker threads may still be alive briefly after finish() returns, so re-check for a
        // bounded time (SHUTDOWN_POLL_ATTEMPTS * SHUTDOWN_POLL_INTERVAL_MILLIS, i.e. max 1 sec).
        prefixedThreadCount = getPrefixedThreadCount(workerPoolBuilder.threadPrefix);
        for (int attempt = 0; prefixedThreadCount > 0 && attempt < SHUTDOWN_POLL_ATTEMPTS; attempt++)
        {
            Thread.sleep(SHUTDOWN_POLL_INTERVAL_MILLIS);
            prefixedThreadCount = getPrefixedThreadCount(workerPoolBuilder.threadPrefix);
        }
        Assert.assertEquals("Worker threads still running", 0, prefixedThreadCount);
    }

    @Test
    public void testSumbitSimpleTask() throws InterruptedException, ExecutionException, TimeoutException
    {
        pool = new WorkerPoolBuilder().build();
        final AtomicInteger finishedTasks = new AtomicInteger(0);
        Runnable task = new Runnable()
        {
            @Override
            public void run()
            {
                finishedTasks.incrementAndGet();
            }
        };

        final int taskCount = DEFAULT_QUEUE_SIZE;
        for (int idx = 0; idx < taskCount; idx++)
        {
            pool.submit(task);
        }
        waitUntilPoolFinishes();
        Assert.assertEquals("Not all tasks were executed", taskCount, finishedTasks.get());
    }

    @Test
    public void testSubmitTaskWithException() throws InterruptedException
    {
        TestErrorReporter errorReporter = new TestErrorReporter();
        pool = new WorkerPoolBuilder().errorReporter(errorReporter).build();
        final int exceptionId = 1;
        Runnable task = new Runnable()
        {
            @Override
            public void run()
            {
                throw new TestException(exceptionId);
            }
        };
        pool.submit(task);

        final Throwable throwable = awaitReportedError(errorReporter);
        Assert.assertEquals("Unexpected throwable received", TestException.class, throwable.getClass());
        Assert.assertEquals("Unexpected exception id received", exceptionId, ((TestException) throwable).id);
    }

    @Test
    public void testSubmitTaskWithImplementationError() throws InterruptedException
    {
        TestErrorReporter errorReporter = new TestErrorReporter();
        pool = new WorkerPoolBuilder().errorReporter(errorReporter).build();
        final int exceptionId = 1;
        Runnable task = new Runnable()
        {
            @Override
            public void run()
            {
                throw new ImplementationError(new TestException(exceptionId));
            }
        };
        pool.submit(task);

        final Throwable throwable = awaitReportedError(errorReporter);
        Assert.assertEquals("Unexpected throwable received", ImplementationError.class, throwable.getClass());
        final Throwable cause = throwable.getCause();
        Assert.assertNotNull("ImplementationError was reported without its cause", cause);
        Assert.assertEquals("Unexpected cause received", TestException.class, cause.getClass());
        Assert.assertEquals("Unexpected exception id received", exceptionId, ((TestException) cause).id);
    }

    /**
     * Waits for the next throwable delivered to {@code errorReporter} and fails the test with a
     * clear message (instead of a later NullPointerException) if none arrives in time.
     *
     * @param errorReporter reporter the pool under test was built with
     * @return the reported throwable, never {@code null}
     * @throws InterruptedException if interrupted while waiting
     */
    private static Throwable awaitReportedError(TestErrorReporter errorReporter) throws InterruptedException
    {
        Throwable throwable = errorReporter.unexpected.poll(ERROR_REPORT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        Assert.assertNotNull(
            "No error was reported within " + ERROR_REPORT_TIMEOUT_SECONDS + " seconds",
            throwable
        );
        return throwable;
    }

    /**
     * Counts live JVM threads whose name starts with {@code threadPrefix}.
     *
     * @param threadPrefix name prefix given to the pool's worker threads
     * @return number of currently known threads carrying that prefix
     */
    private int getPrefixedThreadCount(String threadPrefix)
    {
        Set<Thread> threads = Thread.getAllStackTraces().keySet();
        int prefixedThreadCount = 0;
        for (Thread thread : threads)
        {
            if (thread.getName().startsWith(threadPrefix))
            {
                prefixedThreadCount++;
            }
        }
        return prefixedThreadCount;
    }

    /**
     * Calls {@code pool.finish()} on a helper thread so that a hanging pool fails the test with a
     * TimeoutException after {@link #POOL_FINISH_TIMEOUT_MILLIS} instead of blocking forever.
     */
    private void waitUntilPoolFinishes() throws InterruptedException, ExecutionException, TimeoutException
    {
        exec(
            new Runnable()
            {
                @Override
                public void run()
                {
                    pool.finish();
                }
            },
            POOL_FINISH_TIMEOUT_MILLIS
        );
    }

    /**
     * Runs {@code task} on a fresh single-thread executor and waits at most {@code millisec} for it.
     *
     * @param task the work to run
     * @param millisec maximum time to wait, in milliseconds
     * @throws TimeoutException if the task does not complete in time
     * @throws ExecutionException if the task throws
     * @throws InterruptedException if interrupted while waiting
     */
    private void exec(Runnable task, long millisec)
        throws InterruptedException, ExecutionException, TimeoutException
    {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try
        {
            executor.submit(task).get(millisec, TimeUnit.MILLISECONDS);
        }
        finally
        {
            // Always release the executor; shutdownNow() also interrupts a task that timed out,
            // so a failed wait does not leave a non-daemon thread behind.
            executor.shutdownNow();
        }
    }

    /**
     * ErrorReporter stub that queues every reported throwable for inspection by the tests and
     * prints log calls to stderr.
     */
    private static class TestErrorReporter implements ErrorReporter
    {
        private final BlockingQueue<Throwable> unexpected = new LinkedBlockingQueue<>();

        @Override
        public String getInstanceId()
        {
            // Hex instance ID of linstor's error reporter
            // Not significant for the test, just needs to return something to implement the interface
            return "CAFEAFFE";
        }

        @Override
        public String reportError(Throwable throwable)
        {
            unexpected.add(throwable);
            return null; // no error report, no logName
        }

        @Override
        public String reportError(Level logLevel, Throwable errorInfo)
        {
            unexpected.add(errorInfo);
            return null; // no error report, no logName
        }

        @Override
        public String reportError(
            Level logLevel,
            Throwable errorInfo,
            AccessContext accCtx,
            Peer client,
            String contextInfo
        )
        {
            unexpected.add(errorInfo);
            return null; // no error report, no logName
        }

        @Override
        public String reportError(
            Throwable errorInfo,
            AccessContext accCtx,
            Peer client,
            String contextInfo
        )
        {
            unexpected.add(errorInfo);
            return null; // no error report, no logName
        }

        @Override
        public String reportProblem(
            Level logLevel,
            LinStorException errorInfo,
            AccessContext accCtx,
            Peer client,
            String contextInfo
        )
        {
            unexpected.add(errorInfo);
            return null; // no error report, no logName
        }

        @Override
        public void logTrace(String format, Object... args)
        {
            log("TRACE", format, args);
        }

        @Override
        public void logDebug(String format, Object... args)
        {
            log("DEBUG", format, args);
        }

        @Override
        public void logInfo(String format, Object... args)
        {
            log("INFO ", format, args);
        }

        @Override
        public void logWarning(String format, Object... args)
        {
            log("WARN ", format, args);
        }

        @Override
        public void logError(String format, Object... args)
        {
            log("ERROR", format, args);
        }

        /** Prints "TYPE message" to stderr, formatting the message with String.format. */
        private void log(String type, String format, Object[] args)
        {
            // %n emits the platform line separator (a literal "\\n" would print backslash-n).
            System.err.printf("%s %s%n", type, String.format(format, args));
        }

        @Override
        public boolean setLogLevel(AccessContext accCtx, Level level, Level linstorLevel)
            throws AccessDeniedException
        {
            // Tracing on/off not implemented, no-op
            return false;
        }

        @Override
        public boolean hasAtLeastLogLevel(Level leveRef)
        {
            return true;
        }

        @Override
        public Level getCurrentLogLevel()
        {
            return Level.TRACE;
        }

        @Override
        public Path getLogDirectory()
        {
            return null;
        }
    }

    /**
     * Builds WorkerPools with test defaults; every builder gets a unique thread prefix so thread
     * counting is not disturbed by pools created in other tests.
     */
    private static class WorkerPoolBuilder
    {
        private static final AtomicInteger ID_GEN = new AtomicInteger(0);

        private int parallelism = DEFAULT_THREAD_COUNT;
        private int queueSize = DEFAULT_QUEUE_SIZE;
        private boolean fairness = DEFAULT_FAIRNESS;
        private String threadPrefix;
        private ErrorReporter errorReporter = DEFAULT_ERROR_REPORTER;

        private WorkerPoolBuilder()
        {
            threadPrefix = DEFAULT_THREAD_PREFIX + "_" + Integer.toString(ID_GEN.incrementAndGet());
        }

        public WorkerPool build()
        {
            return WorkerPool.initialize(parallelism, queueSize, fairness, threadPrefix, errorReporter, null);
        }

        public WorkerPoolBuilder parallelism(int parallelismRef)
        {
            parallelism = parallelismRef;
            return this;
        }

        public WorkerPoolBuilder queueSize(int queueSizeRef)
        {
            queueSize = queueSizeRef;
            return this;
        }

        public WorkerPoolBuilder fair(boolean fairnessRef)
        {
            fairness = fairnessRef;
            return this;
        }

        public WorkerPoolBuilder threadPrefix(String prefixRef)
        {
            threadPrefix = prefixRef;
            return this;
        }

        public WorkerPoolBuilder errorReporter(ErrorReporter reporter)
        {
            errorReporter = reporter;
            return this;
        }
    }

    /** Runtime exception carrying an id so tests can verify which failure was reported. */
    private static class TestException extends RuntimeException
    {
        private static final long serialVersionUID = 1308902008629518167L;

        public final int id;

        TestException(int idRef)
        {
            super(Integer.toString(idRef));
            id = idRef;
        }
    }
}