// Copyright 2017 The Bazel Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package build.buildfarm.instance.shard; import static build.buildfarm.instance.shard.RedisShardBackplane.parseOperationChange; import static build.buildfarm.instance.shard.RedisShardBackplane.parseWorkerChange; import static java.lang.String.format; import build.buildfarm.instance.WatchFuture; import build.buildfarm.v1test.OperationChange; import build.buildfarm.v1test.WorkerChange; import com.google.common.collect.ImmutableList; import com.google.common.collect.ListMultimap; import com.google.common.util.concurrent.ListenableFuture; import com.google.longrunning.Operation; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Timestamp; import java.time.Instant; import java.util.List; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; import redis.clients.jedis.Client; import redis.clients.jedis.JedisPubSub; class RedisShardSubscriber extends JedisPubSub { private static final Logger logger = Logger.getLogger(RedisShardSubscriber.class.getName()); abstract static class TimedWatchFuture extends WatchFuture { private final TimedWatcher watcher; TimedWatchFuture(TimedWatcher watcher) { super(watcher::observe); this.watcher = watcher; } TimedWatcher getWatcher() { return watcher; } void complete() { super.set(null); } } private final ListMultimap<String, TimedWatchFuture> watchers; private final Set<String> workers; private final String workerChannel; private final Executor executor; RedisShardSubscriber( ListMultimap<String, TimedWatchFuture> watchers, Set<String> workers, String workerChannel, Executor executor) { this.watchers = watchers; this.workers = workers; this.workerChannel = workerChannel; this.executor = executor; } public List<String> watchedOperationChannels() { synchronized (watchers) { return ImmutableList.copyOf(watchers.keySet()); } } public List<String> subscribedChannels() { ImmutableList.Builder<String> channels = ImmutableList.builder(); synchronized (watchers) { channels.addAll(watchers.keySet()); } return channels.add(workerChannel).build(); } public List<String> expiredWatchedOperationChannels(Instant now) { ImmutableList.Builder<String> builder = ImmutableList.builder(); synchronized (watchers) { for (String channel : watchers.keySet()) { for (TimedWatchFuture watchFuture : watchers.get(channel)) { if (watchFuture.getWatcher().isExpiredAt(now)) { builder.add(channel); break; } } } } return builder.build(); } // synchronizing on these because the client has been observed to // cause protocol desynchronization for multiple concurrent calls @Override public synchronized void unsubscribe() { if (isSubscribed()) { super.unsubscribe(); } } @Override public synchronized void unsubscribe(String... channels) { super.unsubscribe(channels); } @Override public synchronized void subscribe(String... channels) { super.subscribe(channels); } @Override public synchronized void psubscribe(String... patterns) { super.psubscribe(patterns); } @Override public synchronized void punsubscribe() { super.punsubscribe(); } @Override public synchronized void punsubscribe(String... patterns) { super.punsubscribe(patterns); } @Override public synchronized void ping() { super.ping(); } public ListenableFuture<Void> watch(String channel, TimedWatcher watcher) { TimedWatchFuture watchFuture = new TimedWatchFuture(watcher) { @Override public void unwatch() { logger.log(Level.FINE, format("unwatching %s", channel)); RedisShardSubscriber.this.unwatch(channel, this); } }; boolean hasSubscribed; synchronized (watchers) { // use prefix hasSubscribed = watchers.containsKey(channel); watchers.put(channel, watchFuture); if (!hasSubscribed) { subscribe(channel); } } return watchFuture; } public void unwatch(String channel, TimedWatchFuture watchFuture) { synchronized (watchers) { if (watchers.remove(channel, watchFuture) && !watchers.containsKey(channel)) { unsubscribe(channel); } } } public void resetWatchers(String channel, Instant expiresAt) { List<TimedWatchFuture> operationWatchers = watchers.get(channel); synchronized (watchers) { for (TimedWatchFuture watchFuture : operationWatchers) { watchFuture.getWatcher().reset(expiresAt); } } } private void terminateExpiredWatchers(String channel, Instant now, boolean force) { onOperation( channel, /* operation=*/ null, (watcher) -> { boolean expired = force || watcher.isExpiredAt(now); if (expired) { logger.log( Level.SEVERE, format( "Terminating expired watcher of %s because: %s >= %s%s", channel, now, watcher.getExpiresAt(), force ? " with force" : "")); } return expired; }, /* expiresAt=*/ null); } public void onOperation(String channel, Operation operation, Instant expiresAt) { onOperation(channel, operation, (watcher) -> true, expiresAt); } private void onOperation( String channel, @Nullable Operation operation, Predicate<TimedWatcher> shouldObserve, @Nullable Instant expiresAt) { List<TimedWatchFuture> operationWatchers = watchers.get(channel); boolean observe = operation == null || operation.hasMetadata() || operation.getDone(); logger.log(Level.FINE, format("onOperation %s: %s", channel, operation)); synchronized (watchers) { ImmutableList.Builder<Consumer<Operation>> observers = ImmutableList.builder(); for (TimedWatchFuture watchFuture : operationWatchers) { TimedWatcher watcher = watchFuture.getWatcher(); if (expiresAt != null) { watcher.reset(expiresAt); } if (shouldObserve.test(watcher)) { observers.add(watchFuture::observe); } } for (Consumer<Operation> observer : observers.build()) { executor.execute( () -> { if (observe) { logger.log(Level.FINE, "observing " + operation); observer.accept(operation); } }); } } } @Override public void onMessage(String channel, String message) { if (channel.equals(workerChannel)) { onWorkerMessage(message); } else { onOperationMessage(channel, message); } } void onWorkerMessage(String message) { try { onWorkerChange(parseWorkerChange(message)); } catch (InvalidProtocolBufferException e) { logger.log(Level.INFO, format("invalid worker change message: %s", message), e); } } void onWorkerChange(WorkerChange workerChange) { switch (workerChange.getTypeCase()) { case TYPE_NOT_SET: logger.log( Level.SEVERE, format( "WorkerChange oneof type is not set from %s at %s", workerChange.getName(), workerChange.getEffectiveAt())); break; case ADD: addWorker(workerChange.getName()); break; case REMOVE: removeWorker(workerChange.getName()); break; } } void addWorker(String worker) { synchronized (workers) { workers.add(worker); } } boolean removeWorker(String worker) { synchronized (workers) { return workers.remove(worker); } } void onOperationMessage(String channel, String message) { try { onOperationChange(channel, parseOperationChange(message)); } catch (InvalidProtocolBufferException e) { logger.log( Level.INFO, format("invalid operation change message for %s: %s", channel, message), e); } } static Instant toInstant(Timestamp timestamp) { return Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()); } void resetOperation(String channel, OperationChange.Reset reset) { onOperation(channel, reset.getOperation(), toInstant(reset.getExpiresAt())); } void onOperationChange(String channel, OperationChange operationChange) { // FIXME indicate lag/clock skew for OOB timestamps switch (operationChange.getTypeCase()) { case TYPE_NOT_SET: // FIXME present nice timestamp logger.log( Level.SEVERE, format( "OperationChange oneof type is not set from %s at %s", operationChange.getSource(), operationChange.getEffectiveAt())); break; case RESET: resetOperation(channel, operationChange.getReset()); break; case EXPIRE: terminateExpiredWatchers( channel, toInstant(operationChange.getEffectiveAt()), operationChange.getExpire().getForce()); break; } } @Override public void onSubscribe(String channel, int subscribedChannels) {} @Override public void onUnsubscribe(String channel, int subscribedChannels) { List<TimedWatchFuture> operationWatchers; synchronized (watchers) { operationWatchers = watchers.removeAll(channel); } for (TimedWatchFuture watchFuture : operationWatchers) { watchFuture.complete(); } } private String[] placeholderChannel() { String[] channels = new String[1]; channels[0] = "placeholder-shard-subscription"; return channels; } @Override public void proceed(Client client, String... channels) { if (channels.length == 0) { channels = placeholderChannel(); } super.proceed(client, channels); } }