/* * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://www.apache.org/licenses/LICENSE-2.0 * * or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing * permissions and limitations under the License. */ package com.amazon.opendistroforelasticsearch.ad.cluster; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.HashSet; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlocks; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.gateway.GatewayService; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import 
com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer;

/**
 * Unit tests for {@link ADClusterEventListener}.
 *
 * <p>Each test feeds the listener a synthetic {@link ClusterChangedEvent} and then checks which
 * log message it emitted through {@code testAppender}. {@link HashRing} and {@link ModelManager}
 * are Mockito mocks; the {@link ClusterService} is created with {@code createClusterService}
 * on the shared test thread pool.
 */
public class ADClusterEventListenerTests extends AbstractADTest {
    private final String masterNodeId = "masterNode";
    private final String dataNode1Id = "dataNode1";
    private final String clusterName = "multi-node-cluster";

    private ClusterService clusterService;
    private ADClusterEventListener listener;
    private HashRing hashRing;
    private ModelManager modelManager;
    /** Cluster containing only {@link #masterNode}; the master is the local node. */
    private ClusterState oldClusterState;
    /** Cluster containing {@link #masterNode} and {@link #dataNode1}; the data node is the local node. */
    private ClusterState newClusterState;
    /** Node created with an empty role set. */
    private DiscoveryNode masterNode;
    /** Node created with {@code BUILT_IN_ROLES} and no attributes. */
    private DiscoveryNode dataNode1;
    private DiscoveryNodeFilterer nodeFilter;

    @BeforeClass
    public static void setUpBeforeClass() {
        setUpThreadPool(ADClusterEventListenerTests.class.getSimpleName());
    }

    @AfterClass
    public static void tearDownAfterClass() {
        tearDownThreadPool();
    }

    /**
     * Builds fresh fixtures for every test: mocked hash ring (whose {@code build()} returns
     * {@code true}) and model manager, a real node filter, the two nodes, the "old" and "new"
     * cluster states, and the listener under test.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        super.setUpLog4jForJUnit(ADClusterEventListener.class);
        clusterService = createClusterService(threadPool);
        hashRing = mock(HashRing.class);
        when(hashRing.build()).thenReturn(true);
        modelManager = mock(ModelManager.class);

        nodeFilter = new DiscoveryNodeFilterer(clusterService);
        masterNode = new DiscoveryNode(masterNodeId, buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
        dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), emptyMap(), BUILT_IN_ROLES, Version.CURRENT);
        oldClusterState = ClusterState
            .builder(new ClusterName(clusterName))
            .nodes(new DiscoveryNodes.Builder().masterNodeId(masterNodeId).localNodeId(masterNodeId).add(masterNode))
            .build();
        newClusterState = ClusterState
            .builder(new ClusterName(clusterName))
            .nodes(new DiscoveryNodes.Builder().masterNodeId(masterNodeId).localNodeId(dataNode1Id).add(masterNode).add(dataNode1))
            .build();

        listener = new ADClusterEventListener(clusterService, hashRing, modelManager, nodeFilter);
    }

    /**
     * Releases the cluster service created in {@link #setUp()} and clears every fixture field.
     */
    @Override
    @After
    public void tearDown() throws Exception {
        // createClusterService() hands back a service that must be closed; dropping the
        // reference alone leaks it across tests.
        if (clusterService != null) {
            clusterService.close();
        }
        super.tearDown();
        super.tearDownLog4jForJUnit();
        clusterService = null;
        hashRing = null;
        modelManager = null;
        oldClusterState = null;
        newClusterState = null;
        masterNode = null;
        dataNode1 = null;
        nodeFilter = null;
        listener = null;
    }

    /** A local node with no roles (the master here) is expected to be skipped with NODE_NOT_APPLIED_MSG. */
    public void testIsMasterNode() {
        listener.clusterChanged(new ClusterChangedEvent("foo", oldClusterState, oldClusterState));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_NOT_APPLIED_MSG));
    }

    /** A local data node tagged with the warm box type is expected to be skipped with NODE_NOT_APPLIED_MSG. */
    public void testIsWarmNode() {
        HashMap<String, String> warmAttributes = new HashMap<>();
        warmAttributes.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE);
        // Local variable rather than reassigning the shared dataNode1 fixture.
        DiscoveryNode warmDataNode = new DiscoveryNode(
            dataNode1Id,
            buildNewFakeTransportAddress(),
            warmAttributes,
            BUILT_IN_ROLES,
            Version.CURRENT
        );

        // Deliberately no STATE_NOT_RECOVERED_BLOCK: the warm box type must be the only reason
        // the event is rejected, otherwise this test would also depend on check ordering.
        ClusterState warmNodeClusterState = ClusterState
            .builder(new ClusterName(clusterName))
            .nodes(new DiscoveryNodes.Builder().masterNodeId(masterNodeId).localNodeId(dataNode1Id).add(masterNode).add(warmDataNode))
            .build();
        listener.clusterChanged(new ClusterChangedEvent("foo", warmNodeClusterState, oldClusterState));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_NOT_APPLIED_MSG));
    }

    /** An eligible local node in a cluster carrying STATE_NOT_RECOVERED_BLOCK is expected to log NOT_RECOVERED_MSG. */
    public void testNotRecovered() {
        ClusterState blockedClusterState = ClusterState
            .builder(new ClusterName(clusterName))
            .nodes(new DiscoveryNodes.Builder().masterNodeId(masterNodeId).localNodeId(dataNode1Id).add(masterNode).add(dataNode1))
            .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK))
            .build();
        listener.clusterChanged(new ClusterChangedEvent("foo", blockedClusterState, oldClusterState));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG));
    }

    /** Delivers the old-to-new cluster change to the listener; used to drive it from another thread. */
    class ListenerRunnable implements Runnable {

        @Override
        public void run() {
            listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState));
        }
    }

    /**
     * While one cluster event is blocked inside {@code getAllModelIds()}, a second event is
     * expected to be rejected with IN_PROGRESS_MSG.
     */
    public void testInprogress() throws InterruptedException {
        final CountDownLatch inProgressLatch = new CountDownLatch(1);
        final CountDownLatch executionLatch = new CountDownLatch(1);
        doAnswer(invocation -> {
            executionLatch.countDown();
            inProgressLatch.await();
            return emptySet();
        }).when(modelManager).getAllModelIds();

        Thread worker = new Thread(new ListenerRunnable());
        worker.start();
        try {
            // Wait (bounded) until the first event is parked inside getAllModelIds().
            assertTrue(
                "first cluster event never reached getAllModelIds()",
                executionLatch.await(10, java.util.concurrent.TimeUnit.SECONDS)
            );
            listener.clusterChanged(new ClusterChangedEvent("bar", newClusterState, oldClusterState));
            assertTrue(testAppender.containsMessage(ADClusterEventListener.IN_PROGRESS_MSG));
        } finally {
            // Always unblock and reap the worker so a failed assertion cannot leave it parked
            // forever or still running while tearDown() closes the fixtures.
            inProgressLatch.countDown();
            worker.join(10_000);
        }
    }

    /**
     * Adding a data node is expected to log NODE_ADDED_MSG. Because the mocked hash ring reports
     * the master as owner of every model, the local data node is expected to log REMOVE_MODEL_MSG
     * for the model it currently holds.
     */
    public void testNodeAdded() {
        String modelId = "123-threshold";
        // Fresh mutable set per call, in case the caller modifies the returned collection.
        doAnswer(invocation -> {
            Set<String> res = new HashSet<>();
            res.add(modelId);
            return res;
        }).when(modelManager).getAllModelIds();

        when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(masterNode));

        listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_ADDED_MSG));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.REMOVE_MODEL_MSG + " " + modelId));
    }

    /** Going from two data nodes to one is expected to log NODE_REMOVED_MSG and not NODE_ADDED_MSG. */
    public void testNodeRemoved() {
        // Reuse the exact masterNode and dataNode1 instances so that only "dataNode2" differs
        // between the two states. A freshly constructed node with the same id may not compare
        // equal (NOTE(review): presumably due to a per-instance ephemeral id) and could appear in
        // the nodes delta.
        ClusterState twoDataNodeClusterState = ClusterState
            .builder(new ClusterName(clusterName))
            .nodes(
                new DiscoveryNodes.Builder()
                    .masterNodeId(masterNodeId)
                    .localNodeId(dataNode1Id)
                    .add(masterNode)
                    .add(dataNode1)
                    .add(new DiscoveryNode("dataNode2", buildNewFakeTransportAddress(), emptyMap(), BUILT_IN_ROLES, Version.CURRENT))
            )
            .build();

        listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, twoDataNodeClusterState));
        assertFalse(testAppender.containsMessage(ADClusterEventListener.NODE_ADDED_MSG));
        assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_REMOVED_MSG));
    }
}