package org.nd4j.linalg.memory; import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.abstracts.DummyWorkspace; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; /** * @author [email protected] */ public abstract class BasicMemoryManager implements MemoryManager { protected AtomicInteger frequency = new AtomicInteger(0); protected AtomicLong freqCounter = new AtomicLong(0); protected AtomicLong lastGcTime = new AtomicLong(System.currentTimeMillis()); protected AtomicBoolean periodicEnabled = new AtomicBoolean(true); protected AtomicInteger averageLoopTime = new AtomicInteger(0); protected AtomicInteger noGcWindow = new AtomicInteger(100); protected AtomicBoolean averagingEnabled = new AtomicBoolean(false); protected static final int intervalTail = 100; protected Queue<Integer> intervals = new ConcurrentLinkedQueue<>(); private ThreadLocal<MemoryWorkspace> workspace = new ThreadLocal<>(); private ThreadLocal<MemoryWorkspace> tempWorkspace = new ThreadLocal<>(); /** * This method returns * PLEASE NOTE: Cache options * depend on specific implementations * * @param bytes * @param kind * @param initialize */ @Override public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) { throw new UnsupportedOperationException("This method isn't available for this backend"); } /** * This method detaches off-heap memory from passed INDArray instances, and optionally stores them in cache for future reuse * PLEASE NOTE: Cache options depend on specific implementations * * @param arrays */ @Override public void collect(INDArray... arrays) { throw new UnsupportedOperationException("This method isn't implemented yet"); } @Override public void toggleAveraging(boolean enabled) { averagingEnabled.set(enabled); } /** * This method purges all cached memory chunks * */ @Override public void purgeCaches() { throw new UnsupportedOperationException("This method isn't implemented yet"); } @Override public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), srcBuffer.length() * srcBuffer.getElementSize()); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, srcBuffer.length() * srcBuffer.getElementSize(), MemcpyDirection.HOST_TO_HOST); } @Override public void notifyScopeEntered() { // TODO: to be implemented } @Override public void notifyScopeLeft() { // TODO: to be implemented } @Override public void invokeGcOccasionally() { long currentTime = System.currentTimeMillis(); if (averagingEnabled.get()) intervals.add((int) (currentTime - lastGcTime.get())); // not sure if we want to conform autoGcWindow here... if (frequency.get() > 0) if (freqCounter.incrementAndGet() % frequency.get() == 0 && currentTime > getLastGcTime() + getAutoGcWindow()) { System.gc(); lastGcTime.set(System.currentTimeMillis()); } if (averagingEnabled.get()) if (intervals.size() > intervalTail) intervals.remove(); } @Override public void invokeGc() { System.gc(); lastGcTime.set(System.currentTimeMillis()); } @Override public boolean isPeriodicGcActive() { return periodicEnabled.get(); } @Override public void setOccasionalGcFrequency(int frequency) { this.frequency.set(frequency); } @Override public void setAutoGcWindow(int windowMillis) { noGcWindow.set(windowMillis); } @Override public int getAutoGcWindow() { return noGcWindow.get(); } @Override public int getOccasionalGcFrequency() { return frequency.get(); } @Override public long getLastGcTime() { return lastGcTime.get(); } @Override public void togglePeriodicGc(boolean enabled) { periodicEnabled.set(enabled); } @Override public int getAverageLoopTime() { if (averagingEnabled.get()) { int cnt = 0; for (Integer value : intervals) { cnt += value; } cnt /= intervals.size(); return cnt; } else return 0; } @Override public MemoryWorkspace getCurrentWorkspace() { return workspace.get(); } @Override public void setCurrentWorkspace(MemoryWorkspace workspace) { this.workspace.set(workspace); } @Override public MemoryWorkspace scopeOutOfWorkspaces() { MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); if (workspace == null) return new DummyWorkspace(); else { //Nd4j.getMemoryManager().setCurrentWorkspace(null); return new DummyWorkspace().notifyScopeEntered();//workspace.tagOutOfScopeUse(); } } }