package com.scaleunlimited.flinkcrawler.functions;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.scaleunlimited.flinkcrawler.config.CrawlTerminator;
import com.scaleunlimited.flinkcrawler.metrics.CounterUtils;
import com.scaleunlimited.flinkcrawler.metrics.CrawlerMetrics;
import com.scaleunlimited.flinkcrawler.pojos.CrawlStateUrl;
import com.scaleunlimited.flinkcrawler.pojos.DomainScore;
import com.scaleunlimited.flinkcrawler.pojos.FetchStatus;
import com.scaleunlimited.flinkcrawler.pojos.FetchUrl;
import com.scaleunlimited.flinkcrawler.urldb.BaseUrlStateMerger;
import com.scaleunlimited.flinkcrawler.urldb.BaseUrlStateMerger.MergeResult;
import com.scaleunlimited.flinkcrawler.utils.FetchQueue;

/**
 * The Flink operator that manages the URL portion of the "crawl DB". Incoming URLs are merged in
 * memory, using a map with key == URL hash, value == CrawlStateUrl. Whenever our timer fires, we
 * emit URLs from the fetch queue (if available).
 * 
 * The use of the fetch queue lets us apply some heuristics to fetching the "best" (approximately)
 * URLs, without having to scan every URL.
 */
@SuppressWarnings("serial")
public class UrlDBFunction extends BaseCoProcessFunction<CrawlStateUrl, DomainScore, FetchUrl>
        implements CheckpointedFunction {

    static final Logger LOGGER = LoggerFactory.getLogger(UrlDBFunction.class);

    // When we have to update the status of a URL in the fetch queue (because we're either fetching
    // it, or it's getting booted by a better URL) we output it via this side channel.
    public static final OutputTag<CrawlStateUrl> STATUS_OUTPUT_TAG = new OutputTag<CrawlStateUrl>("status"){};

    private static final int MAX_IN_FLIGHT_URLS = 100;

    // Max and average time between checks for URLs to emit for a domain
    protected static final long MAX_DOMAIN_CHECK_INTERVAL = 1000;
    protected static final long AVERAGE_DOMAIN_CHECK_INTERVAL = 200;

    private BaseUrlStateMerger _merger;
    private CrawlTerminator _terminator;

    // List of URLs that are available to be fetched.
    private final FetchQueue _fetchQueue;

    // TODO - Sync up the total active urls value with state whenever we restore from state.
    private int _totalActiveUrls;

    private transient AtomicInteger _numInFlightUrls;

    private transient MapState<Long, CrawlStateUrl> _activeUrls;
    private transient MapState<Integer, Long> _activeUrlsIndex;
    private transient ValueState<Integer> _numActiveUrls;
    private transient ValueState<Integer> _activeIndex;
    private transient ValueState<String> _pld;
    private transient MapState<Long, CrawlStateUrl> _archivedUrls;
    private transient ValueState<Float> _domainScore;

    private transient CrawlStateUrl _mergedUrlState;

    private transient Set<String> _scoredDomains;
    private transient float _averageDomainScore;

    // TODO(kkrugler) remove this debugging code.
    private transient Map<String, Long> _inFlightUrls;

    public UrlDBFunction(CrawlTerminator terminator, BaseUrlStateMerger merger, FetchQueue fetchQueue) {
        _terminator = terminator;
        _merger = merger;
        _fetchQueue = fetchQueue;
    }

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        // Create our keyed/managed states

        // 1. Active URLs: MapState (key = url hash, value = CrawlStateUrl)
        MapStateDescriptor<Long, CrawlStateUrl> urlStateDescriptor = new MapStateDescriptor<>(
                "active-urls", Long.class, CrawlStateUrl.class);
        _activeUrls = getRuntimeContext().getMapState(urlStateDescriptor);

        // 2. Active URL count: ValueState (value = number of entries in the Active URLs MapState)
        ValueStateDescriptor<Integer> urlCountDescriptor = new ValueStateDescriptor<>(
                "num-active-urls", TypeInformation.of(new TypeHint<Integer>() {
                }));
        _numActiveUrls = getRuntimeContext().getState(urlCountDescriptor);

        // 3. Index to next active URL: ValueState (value = index of next URL to try to queue)
        ValueStateDescriptor<Integer> urlIndexDescriptor = new ValueStateDescriptor<>(
                "cur-url-index", TypeInformation.of(new TypeHint<Integer>() {
                }));
        _activeIndex = getRuntimeContext().getState(urlIndexDescriptor);

        // 4. Archived URLs: MapState (key = url hash, value = CrawlStateUrl)
        MapStateDescriptor<Long, CrawlStateUrl> archivedUrlsStateDescriptor = new MapStateDescriptor<>(
                "archived-urls", Long.class, CrawlStateUrl.class);
        _archivedUrls = getRuntimeContext().getMapState(archivedUrlsStateDescriptor);

        // 5. Index to hash mapping for active URLs (key = index, value = url hash)
        //
        // So if we have an index from 0...<active url count>-1, we can look up the
        // hash, and then use that to find the actual CrawlStateUrl in the active
        // URLs MapState.
        MapStateDescriptor<Integer, Long> activeUrlsIndexStateDescriptor = new MapStateDescriptor<>(
                "active-urls-index", Integer.class, Long.class);
        _activeUrlsIndex = getRuntimeContext().getMapState(activeUrlsIndexStateDescriptor);

        // 6. PLD (key) for current state, so we can access it in the timer handler.
        ValueStateDescriptor<String> pldDescriptor = new ValueStateDescriptor<>(
                "pld", TypeInformation.of(new TypeHint<String>() {
                }));
        _pld = getRuntimeContext().getState(pldDescriptor);

        // 7. Domain score for current PLD, for checkpointing.
        ValueStateDescriptor<Float> domainScoreDescriptor = new ValueStateDescriptor<>(
                "domain-score", TypeInformation.of(new TypeHint<Float>() {
                }));
        _domainScore = getRuntimeContext().getState(domainScoreDescriptor);
    }

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);

        RuntimeContext context = getRuntimeContext();

        context.getMetricGroup().gauge(CrawlerMetrics.GAUGE_URLS_IN_FETCH_QUEUE.toString(),
                new Gauge<Integer>() {
                    @Override
                    public Integer getValue() {
                        return _fetchQueue.size();
                    }
                });

        // Track how many URLs we think are being processed.
        context.getMetricGroup().gauge(CrawlerMetrics.GAUGE_URLS_IN_FLIGHT.toString(),
                new Gauge<Integer>() {
                    @Override
                    public Integer getValue() {
                        return _numInFlightUrls.get();
                    }
                });

        // Track the number of active URLs.
        context.getMetricGroup().gauge(CrawlerMetrics.GAUGE_URLS_ACTIVE.toString(),
                new Gauge<Integer>() {
                    @Override
                    public Integer getValue() {
                        return _totalActiveUrls;
                    }
                });

        _mergedUrlState = new CrawlStateUrl();
        _numInFlightUrls = new AtomicInteger(0);

        _fetchQueue.open();
        _terminator.open();

        _scoredDomains = new HashSet<>();
        _averageDomainScore = 0.0f;

        _inFlightUrls = new HashMap<>();
    }

    @Override
    public void processElement1(CrawlStateUrl url, Context ctx, Collector<FetchUrl> collector)
            throws Exception {
        record(this.getClass(), url, FetchStatus.class.getSimpleName(), url.getStatus().toString());

        // See if we have this domain already. If not, create a timer.
        long processingTime = ctx.timerService().currentProcessingTime();
        Integer numUrls = _numActiveUrls.value();
        if (numUrls == null) {
            // Create entries for this domain in the various states.
            _numActiveUrls.update(0);
            _activeIndex.update(0);
            _pld.update(url.getPld());

            // And we want to create a timer, so we have one per domain.
            long nextTime = processingTime + 100;
            LOGGER.debug("Adding timer for domain {} at {}", url.getPld(), nextTime);
            ctx.timerService().registerProcessingTimeTimer(nextTime);
        }

        // Now update state for this URL, and potentially emit it if its status is 'fetching'.
        processUrl(url, collector);
    }

    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<FetchUrl> out) throws Exception {
        super.onTimer(timestamp, ctx, out);

        if (!_terminator.isTerminated()) {
            // See if we've got a URL that we want to add to the fetch queue.
            addUrlToFetchQueue(ctx);

            // See if we've got a URL in the fetch queue that we want to emit
            // (this sends it out via side channel back to us, so state is
            // properly updated).
            emitUrlFromFetchQueue(ctx);

            // Update our average domain score info if needed.
            String pld = _pld.value();
            updateAverageDomainScore(pld);

            // And re-register the timer.
            long fireAt = timestamp + checkIntervalForDomain(pld);
            LOGGER.debug("Resetting timer for domain {} to fire at {}", pld, fireAt);
            ctx.timerService().registerProcessingTimeTimer(fireAt);
        } else {
            LOGGER.info("Terminating timer for domain {}", _pld.value());
        }
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // Nothing special we need to do here, since we have keyed state that gets
        // handled automatically.
    }

    @Override
    public void close() throws Exception {
        long curTime = System.currentTimeMillis();
        for (String url : _inFlightUrls.keySet()) {
            LOGGER.debug("{}\t{}", curTime - _inFlightUrls.get(url), url);
        }

        super.close();
    }

    /**
     * See if we have a URL (for the current key/PLD) in our state that should be
     * added to the fetch queue.
     * 
     * @param context
     * @throws Exception
     */
    private void addUrlToFetchQueue(Context context) throws Exception {
        final boolean doTracing = LOGGER.isTraceEnabled();

        Integer numUrls = _numActiveUrls.value();
        if (numUrls == null) {
            // Since we never archive URLs currently, the number of active URLs for
            // the current (keyed) PLD will always be at least 1. And by "active"
            // I mean not archived; this doesn't have anything to do with in-flight
            // or fetching or queued or any other state.
LOGGER.error("Houston, we have a problem - null active URL count for domain"); return; } int index = _activeIndex.value(); Long urlHash = _activeUrlsIndex.get(index); CrawlStateUrl stateUrl = _activeUrls.get(urlHash); CrawlStateUrl rejectedUrl = _fetchQueue.add(stateUrl); if (doTracing) { if (rejectedUrl == null) { LOGGER.trace( "UrlDBFunction ({}/{}) added '{}' to fetch queue", _partition, _parallelism, stateUrl); } else if (rejectedUrl == stateUrl) { // Happens constantly due to state, so just ignore } else { LOGGER.trace( "UrlDBFunction ({}/{}) added '{}' to fetch queue, removing '{}'", _partition, _parallelism, stateUrl, rejectedUrl); } } // If we get back null (there's space in the queue), or some other // URL (our URL was better) then this URL was added to the queue. if (rejectedUrl != stateUrl) { LOGGER.trace( "UrlDBFunction ({}/{}) setting '{}' state status to QUEUED", _partition, _parallelism, stateUrl); stateUrl.setStatus(FetchStatus.QUEUED); _activeUrls.put(urlHash, stateUrl); // If we just replaced a (lower scoring) URL on the fetch queue, // then we need to restore the rejected URL's status in the URL DB. if (rejectedUrl != null) { rejectedUrl.restorePreviousStatus(); LOGGER.trace( "UrlDBFunction ({}/{}) restored '{}' to previous status via side output", _partition, _parallelism, rejectedUrl); context.output(STATUS_OUTPUT_TAG, rejectedUrl); // Otherwise we just added one URL to the queue (vs. replacing one already there) } else { CounterUtils.increment(getRuntimeContext(), FetchStatus.QUEUED); } } // Update index of the URL we should check the next time we get called. index = (index + 1) % numUrls; _activeIndex.update(index); } /** * Given the domain <pld>, return how long between each check. For better domains * we want to check more often. * * @param pld * @return * @throws IOException */ private long checkIntervalForDomain(String pld) throws IOException { Float averagePageScore = _domainScore.value(); if (averagePageScore == null) { LOGGER.debug("UrlDBFunction ({}/{}) no average page scores yet for '{}'", _partition, _parallelism, pld); return MAX_DOMAIN_CHECK_INTERVAL; } // see where this domain's average page score falls, relative to the // overall average page score. Then constrain to range [1...MAX_DOMAIN_CHECK_INTERVAL] float ratio = _averageDomainScore / averagePageScore; long domainCheckInterval = Math.round(AVERAGE_DOMAIN_CHECK_INTERVAL * ratio); domainCheckInterval = Math.min(domainCheckInterval, MAX_DOMAIN_CHECK_INTERVAL); domainCheckInterval = Math.max(domainCheckInterval, 1); LOGGER.debug("UrlDBFunction ({}/{}) average page score for '{}' is {} (vs. average domain score {})", _partition, _parallelism, pld, averagePageScore, _averageDomainScore); LOGGER.debug("UrlDBFunction ({}/{}) returning {}ms between URL checks for '{}'", _partition, _parallelism, domainCheckInterval, pld); return domainCheckInterval; } /** * See if this is a PLD that we don't have in our set of "known" * PLDs for average domain score, and process if so. * * @param pld * @throws IOException */ private void updateAverageDomainScore(String pld) throws IOException { if (!_scoredDomains.contains(pld)) { Float score = _domainScore.value(); // We might not have received a score (via processElement2()) yet for this // PLD. If that's the case, there's nothing to update. 
            if (score != null) {
                float summedScores = _averageDomainScore * _scoredDomains.size();
                summedScores += score;
                _scoredDomains.add(pld);
                _averageDomainScore = summedScores / _scoredDomains.size();
            }
        }
    }

    /**
     * See if there's a URL in the fetch queue that we should emit (via side output, so that
     * it loops around to the UrlDBFunction to update the status).
     * 
     * @param context
     */
    private void emitUrlFromFetchQueue(Context context) {
        final boolean doTracing = LOGGER.isTraceEnabled();

        int activeUrls = _numInFlightUrls.get();
        if (activeUrls > MAX_IN_FLIGHT_URLS) {
            if (doTracing) {
                LOGGER.trace("UrlDBFunction ({}/{}) skipping emit, too many active URLs ({})",
                        _partition, _parallelism, activeUrls);
            }

            return;
        }

        CrawlStateUrl crawlStateUrl = _fetchQueue.poll();
        if (crawlStateUrl != null) {
            // Update the state of the URL in the URL DB so we know it's no longer just queued,
            // but now about to be fetched. It now goes into our side channel where it will
            // come back around to processUrl(), which will save the state change
            // and then begin fetching it.
            LOGGER.trace("UrlDBFunction ({}/{}) setting '{}' status to FETCHING via side output",
                    _partition, _parallelism, crawlStateUrl);
            crawlStateUrl.setStatus(FetchStatus.FETCHING);
            crawlStateUrl.setStatusTime(System.currentTimeMillis());
            context.output(STATUS_OUTPUT_TAG, crawlStateUrl);
        }
    }

    /**
     * We received a URL that we need to add (or merge) into our crawl state.
     * 
     * @param url
     * @param collector
     * @throws Exception
     */
    private void processUrl(CrawlStateUrl url, Collector<FetchUrl> collector) throws Exception {
        // If it's not an unfetched URL, we can decrement our active URLs.
        FetchStatus newStatus = url.getStatus();
        if ((newStatus != FetchStatus.UNFETCHED) && (url.getStatusTime() == 0)) {
            throw new RuntimeException(
                    String.format("UrlDBFunction (%d/%d) got URL with invalid status time: %s",
                            _partition, _parallelism, url));
        }

        // If it's a URL just pulled from the fetch queue, then we need to emit it
        // for a robots check. Its status time should be newer than what's in the URL DB,
        // so it's guaranteed to win below and update the URL DB as well.
        if (newStatus == FetchStatus.FETCHING) {
            FetchUrl fetchUrl = new FetchUrl(url, url.getScore());
            collector.collect(fetchUrl);

            int nowActive = _numInFlightUrls.incrementAndGet();
            LOGGER.trace("UrlDBFunction ({}/{}) emitted URL '{}' ({} active)",
                    _partition, _parallelism, fetchUrl, nowActive);

            _inFlightUrls.put(fetchUrl.getUrl(), System.currentTimeMillis());
        // Otherwise, if it's the result of a fetch attempt then it's no longer active.
        } else if (newStatus != FetchStatus.UNFETCHED) {
            Long startTime = _inFlightUrls.remove(url.getUrl());
            if (startTime == null) {
                throw new RuntimeException(
                        String.format("UrlDBFunction (%d/%d) got URL not in active state: %s",
                                _partition, _parallelism, url));
            }

            LOGGER.trace("{}ms to process '{}'", System.currentTimeMillis() - startTime, url);

            int nowActive = _numInFlightUrls.decrementAndGet();
            LOGGER.trace("UrlDBFunction ({}/{}) receiving URL {} ({} active)",
                    _partition, _parallelism, url, nowActive);

            if (nowActive < 0) {
                throw new RuntimeException(
                        String.format("UrlDBFunction (%d/%d) has negative in-flight URLs",
                                _partition, _parallelism));
            }
        }

        Long urlHash = url.makeKey();
        if (_archivedUrls.contains(urlHash)) {
            // It's been archived, ignore...
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("UrlDBFunction ({}/{}) ignoring archived URL '{}'",
                        _partition, _parallelism, url);
            }

            // If the status is unfetched, we're all good, but if not then that's a logical
            // error, as we shouldn't be emitting URLs that are archived.
            if (newStatus != FetchStatus.UNFETCHED) {
                throw new RuntimeException(String.format(
                        "UrlDBFunction (%d/%d) got archived URL %s with active status %s",
                        _partition, _parallelism, url, newStatus));
            }
        } else {
            CrawlStateUrl stateUrl = _activeUrls.get(urlHash);
            if (stateUrl == null) {
                // We've never seen this URL before, so it had better be unfetched.
                if (newStatus != FetchStatus.UNFETCHED) {
                    throw new RuntimeException(String.format(
                            "UrlDBFunction (%d/%d) got new URL '%s' with active status %s",
                            _partition, _parallelism, url, newStatus));
                }

                CounterUtils.increment(getRuntimeContext(), FetchStatus.UNFETCHED);

                if (LOGGER.isTraceEnabled()) {
                    LOGGER.trace("UrlDBFunction ({}/{}) adding new URL '{}' to state",
                            _partition, _parallelism, url);
                }

                // TODO need to copy URL if object reuse enabled?
                _activeUrls.put(urlHash, url);

                int numActiveUrls = _numActiveUrls.value();
                _activeUrlsIndex.put(numActiveUrls, urlHash);
                _numActiveUrls.update(numActiveUrls + 1);

                _totalActiveUrls++;
            } else {
                if (LOGGER.isTraceEnabled()) {
                    LOGGER.trace("UrlDBFunction ({}/{}) needs to merge incoming URL '{}' with '{}' (hash {})",
                            _partition, _parallelism, url, stateUrl, urlHash);
                }

                FetchStatus oldStatus = stateUrl.getStatus();
                if (mergeUrls(stateUrl, url)) {
                    _activeUrls.put(urlHash, stateUrl);

                    if (LOGGER.isTraceEnabled()) {
                        LOGGER.trace("UrlDBFunction ({}/{}) updated state of URL '{}' (hash {})",
                                _partition, _parallelism, stateUrl, urlHash);
                    }

                    CounterUtils.decrement(getRuntimeContext(), oldStatus);
                    CounterUtils.increment(getRuntimeContext(), newStatus);
                }
            }
        }
    }

    private boolean mergeUrls(CrawlStateUrl stateUrl, CrawlStateUrl newUrl) {
        MergeResult result = _merger.doMerge(stateUrl, newUrl, _mergedUrlState);

        switch (result) {
            case USE_FIRST:
                // All set, stateUrl is what we want to use, so no update.
                return false;

            case USE_SECOND:
                stateUrl.setFrom(newUrl);
                return true;

            case USE_MERGED:
                stateUrl.setFrom(_mergedUrlState);
                return true;

            default:
                throw new RuntimeException("Unknown merge result: " + result);
        }
    }

    /*
     * (non-Javadoc)
     * 
     * @see org.apache.flink.streaming.api.functions.co.CoProcessFunction#processElement2(IN2,
     * org.apache.flink.streaming.api.functions.co.CoProcessFunction.Context, org.apache.flink.util.Collector)
     * 
     * When we get a domain score, just update our internal state, and never emit anything.
     */
    @Override
    public void processElement2(DomainScore domainScore, Context context, Collector<FetchUrl> out)
            throws Exception {
        // Ensure we don't wind up with divide-by-zero problems.
        float score = Math.max(0.01f, domainScore.getScore());
        String pld = domainScore.getPld();
        LOGGER.debug("UrlDBFunction ({}/{}) setting '{}' average score to {}",
                _partition, _parallelism, pld, score);

        // At this point we might be seeing this PLD for the first time, or we might have seen
        // it before in this method, or we might have seen it via the onTimer call. So it may
        // or may not have any state set up, and it may or may not be in _scoredDomains (non-state).
        float summedScores = _averageDomainScore * _scoredDomains.size();
        if (_scoredDomains.contains(pld)) {
            summedScores -= _domainScore.value();
        }

        _domainScore.update(score);
        _scoredDomains.add(pld);
        summedScores += score;
        _averageDomainScore = summedScores / _scoredDomains.size();
    }
}
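
/*
 * Wiring sketch (illustrative only, not part of this class): one way this operator might be
 * attached to its two input streams, keyed by PLD, with the STATUS_OUTPUT_TAG side output routed
 * back around as status updates. The stream names (crawlStateUrls, domainScores) and the
 * constructor arguments (terminator, merger, fetchQueue) are assumptions about the surrounding
 * topology, not something defined in this file.
 *
 *   SingleOutputStreamOperator<FetchUrl> fetchUrls = crawlStateUrls
 *           .connect(domainScores)
 *           .keyBy(CrawlStateUrl::getPld, DomainScore::getPld)
 *           .process(new UrlDBFunction(terminator, merger, fetchQueue));
 *
 *   // Status changes (QUEUED -> FETCHING, plus rejected URLs restored to their previous
 *   // status) come back via the side output and are fed back into processElement1().
 *   DataStream<CrawlStateUrl> statusUpdates = fetchUrls.getSideOutput(UrlDBFunction.STATUS_OUTPUT_TAG);
 */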