package yushijinhun.authlibagent.dao;

import java.io.UnsupportedEncodingException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.annotation.Resource;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.SetOperations;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;
import yushijinhun.authlibagent.model.Token;

import static java.util.stream.Collectors.toSet;
import static yushijinhun.authlibagent.util.HexUtils.bytesToHex;

@Component
public class TokenRepositoryImpl implements TokenRepository {

	/*
	 * Redis结构:
	 * 
	 * a) 对accessToken->Token的存储:
	 * "A"+accessToken -> <hash>{
	 * 	"c": clientToken
	 * 	"o": owner
	 *     "p": selectedProfile
	 *  	"t": lastRefreshTime
	 *  	"a": createTime // 即通过authenticate请求创建时的时间, 不会因为refresh重置
	 * }
	 * 
	 * b) 对clientToken->Tokens的存储:
	 * "C"+clientToken -> <set>[
	 * 	accessToken...
	 * ]
	 * 
	 * c) 对account->Tokens的存储:
	 * "O"+account -> <set>[
	 * 	accessToken...
	 * ]
	 * 
	 * d) 用于对accessToken进行过期处理的存储:
	 *  // expire: $tokenExpireTime
	 *  "X"+accessToken->""
	 * 
	 * 当redis通知一个d)的过期事件时, 则移除a),b), c)中的映射.
	 * 同时, 为了防止redis删除了d)但没有通知客户端, 每经过$tokenExpireScanCycle秒,
	 * 将自动遍历所有token, 测试过期.
	 */

	/* 
	* Token 生命周期:
	* 
	* |----1. 有效----|----2. 暂时失效----|3. 无效
	* |------------------------------------------------------>Time
	* 
	* 当一个token被创建时, 将处于有效状态.
	* 经过$tokenPreExpireTime秒后, token状态转为暂时失效,
	* 此时进行validate请求返回false, 但仍可以refresh, 经refresh之后分配新的token将处于有效状态.
	* 如果不进行refresh, 该token将在创建后$tokenExpireTime秒后失效.
	* 当一个token自通过authenticate创建后, 经过$tokenMaxLivingTime之秒后(refresh不会重置该值), 也将自动失效.
	*/

	private static final Logger LOGGER = LogManager.getFormatterLogger();

	private static final String PREFIX_ACCESS_TOKEN = "A";
	private static final String PREFIX_CLIENT_TOKEN = "C";
	private static final String PREFIX_ACCOUNT = "O";
	private static final String PREFIX_EXPIRE = "X";
	private static final String KEY_CLIENT_TOKEN = "c";
	private static final String KEY_OWNER = "o";
	private static final String KEY_SELECTED_PROFILE = "p";
	private static final String KEY_LAST_REFRESH_TIME = "t";
	private static final String KEY_CREATE_TIME = "a";

	@Autowired
	private RedisMessageListenerContainer container;

	@Autowired
	private RedisTemplate<String, String> template;

	@Resource(name = "redisTemplate")
	private HashOperations<String, String, String> hashOps;

	@Resource(name = "redisTemplate")
	private SetOperations<String, String> setOps;

	@Resource(name = "redisTemplate")
	private ValueOperations<String, String> valOps;

	// Unit: second
	@Value("#{config['expire.token.time']}")
	private long tokenExpireTime;

	// Unit: second
	@Value("#{config['expire.token.maxLiving']}")
	private long tokenMaxLivingTime;

	// Unit: second
	@Value("#{config['expire.token.scanCycle']}")
	private long tokenExpireScanCycle;

	@Value("#{config['security.maxTokensPerAccounts']}")
	private int maxTokensPerAccounts;

	@Value("#{config['security.extraTokensToDelete']}")
	private int extraTokensToDelete;

	private MessageListener expiredListener;

	private Thread expireScanThread;

	@PostConstruct
	private void registerExpiredEventListener() {
		expiredListener = new MessageListener() {

			@Override
			public void onMessage(Message message, byte[] pattern) {
				byte[] body = message.getBody();
				if (body == null) {
					return;
				}

				String key;
				try {
					key = new String(message.getBody(), "UTF-8");
				} catch (UnsupportedEncodingException e) {
					LOGGER.debug(() -> "failed to decode message body: " + bytesToHex(body), e);
					return;
				}

				if (!key.startsWith(PREFIX_EXPIRE)) {
					return;
				}

				String accessToken = key.substring(PREFIX_EXPIRE.length());
				template.delete(keyAccessToken(accessToken));
				Map<String, String> values = hashOps.entries(keyAccessToken(accessToken));
				if (values != null && !values.isEmpty()) {
					setOps.remove(keyClientToken(values.get(KEY_CLIENT_TOKEN)), accessToken);
					setOps.remove(keyAccount(values.get(KEY_OWNER)), accessToken);
				}
			}
		};
		container.addMessageListener(expiredListener, new PatternTopic("__keyevent@*__:expired"));
	}

	@PreDestroy
	private void unregisterExpiredEventListener() {
		container.removeMessageListener(expiredListener);
	}

	@PostConstruct
	private void startExpireScanThread() {
		expireScanThread = new Thread(() -> {
			while (true) {
				try {
					Thread.sleep(tokenExpireScanCycle * 1000);
				} catch (InterruptedException e) {
					return;
				}

				template.keys(PREFIX_ACCESS_TOKEN + "*").stream().map(k -> k.substring(PREFIX_EXPIRE.length())).forEach(this::testExpire);
			}
		});
		expireScanThread.setDaemon(true);
		expireScanThread.start();
	}

	@PreDestroy
	private void stopExpireScanThread() {
		expireScanThread.interrupt();

		// wait the thread to terminate
		try {
			expireScanThread.join();
		} catch (InterruptedException e) {
			// restore interrupted status
			Thread.currentThread().interrupt();
		}
	}

	@Override
	public Token get(String accessToken) {
		// notify the accessToken to be expired
		valOps.get(keyExpire(accessToken));

		Map<String, String> values = hashOps.entries(keyAccessToken(accessToken));
		if (values == null || values.isEmpty() || testExpire(accessToken, values)) {
			return null;
		}

		Token token = new Token();
		token.setAccessToken(accessToken);
		token.setClientToken(values.get(KEY_CLIENT_TOKEN));
		token.setOwner(values.get(KEY_OWNER));
		token.setLastRefreshTime(Long.parseLong(values.get(KEY_LAST_REFRESH_TIME)));
		token.setCreateTime(Long.parseLong(values.get(KEY_CREATE_TIME)));
		String selectedProfile = values.get(KEY_SELECTED_PROFILE);
		token.setSelectedProfile(selectedProfile.isEmpty() ? null : UUID.fromString(selectedProfile));
		return token;
	}

	@Override
	public Set<Token> getByClientToken(String clientToken) {
		return getTokens(setOps.members(keyClientToken(clientToken)));
	}

	@Override
	public Set<Token> getByAccount(String account) {
		return getTokens(setOps.members(keyAccount(account)));
	}

	@Override
	public void put(Token token) {
		String accessToken = token.getAccessToken();
		String clientToken = token.getClientToken();
		String owner = token.getOwner();

		String expireKey = keyExpire(accessToken);
		String accountKey = keyAccount(owner);

		Map<String, String> values = new HashMap<>();
		values.put(KEY_CLIENT_TOKEN, clientToken);
		values.put(KEY_OWNER, owner);
		values.put(KEY_LAST_REFRESH_TIME, String.valueOf(token.getLastRefreshTime()));
		values.put(KEY_CREATE_TIME, String.valueOf(token.getCreateTime()));
		UUID selectedProfile = token.getSelectedProfile();
		values.put(KEY_SELECTED_PROFILE, selectedProfile == null ? "" : selectedProfile.toString());

		int accountTokens = setOps.size(accountKey).intValue();
		if (accountTokens > maxTokensPerAccounts) {
			// token limit reached

			// remove $extraTokensToDelete more tokens every time,
			int tokensToDelete = accountTokens - maxTokensPerAccounts + extraTokensToDelete;
			for (int i = 0; i < tokensToDelete; i++) {
				String expiredAccessToken = setOps.pop(accountKey);
				if (expiredAccessToken != null) {
					String expiredClientToken = hashOps.get(keyAccessToken(expiredAccessToken), KEY_CLIENT_TOKEN);
					if (expiredClientToken != null) {
						setOps.remove(keyClientToken(expiredClientToken), expiredAccessToken);
					}
					template.delete(keyExpire(expiredAccessToken));
					template.delete(keyAccessToken(expiredAccessToken));
				}
			}
		}

		valOps.set(expireKey, "");
		template.expire(expireKey, tokenExpireTime, TimeUnit.SECONDS);
		hashOps.putAll(keyAccessToken(accessToken), values);
		setOps.add(keyClientToken(clientToken), accessToken);
		setOps.add(accountKey, accessToken);
	}

	@Override
	public void delete(String accessToken) {
		Map<String, String> values = hashOps.entries(keyAccessToken(accessToken));
		if (values == null || values.isEmpty()) {
			return;
		}

		setOps.remove(keyAccount(values.get(KEY_OWNER)), accessToken);
		setOps.remove(keyClientToken(values.get(KEY_CLIENT_TOKEN)), accessToken);
		template.delete(keyExpire(accessToken));
		template.delete(keyAccessToken(accessToken));
	}

	@Override
	public void deleteByAccount(String account) {
		String accountKey = keyAccount(account);
		Set<String> accessTokens = setOps.members(accountKey);
		if (accessTokens == null || accessTokens.isEmpty()) {
			return;
		}

		// unlink account -> accessTokens
		template.delete(accountKey);
		// unlink expire keys
		template.delete(accessTokens.stream().map(this::keyExpire).collect(toSet()));
		// unlink clientToken -> accessTokens
		for (String accessToken : accessTokens) {
			String clientToken = hashOps.get(keyAccessToken(accessToken), KEY_CLIENT_TOKEN);
			if (clientToken != null) {
				setOps.remove(keyClientToken(clientToken), accessToken);
			}
		}
		// delete tokens
		template.delete(accessTokens);
	}

	private Set<Token> getTokens(Set<String> accessTokens) {
		if (accessTokens == null || accessTokens.isEmpty()) {
			return Collections.emptySet();
		}
		return accessTokens.stream().map(this::get).collect(toSet());
	}

	private boolean testExpire(String accessToken) {
		Map<String, String> values = hashOps.entries(keyAccessToken(accessToken));
		if (values == null || values.isEmpty()) {
			return false;
		}
		return testExpire(accessToken, values);
	}

	private boolean testExpire(String accessToken, Map<String, String> values) {
		long createTime = Long.parseLong(values.get(KEY_CREATE_TIME));
		long lastRefreshTime = Long.parseLong(values.get(KEY_LAST_REFRESH_TIME));
		long now = System.currentTimeMillis();
		if (createTime + tokenMaxLivingTime * 1000 < now || lastRefreshTime + tokenExpireTime * 1000 < now) {
			// reached max living time
			delete(accessToken);
			return true;
		}
		return false;
	}

	private String keyAccessToken(String accessToken) {
		return PREFIX_ACCESS_TOKEN + accessToken;
	}

	private String keyClientToken(String clientToken) {
		return PREFIX_CLIENT_TOKEN + clientToken;
	}

	private String keyAccount(String account) {
		return PREFIX_ACCOUNT + account;
	}

	private String keyExpire(String accessToken) {
		return PREFIX_EXPIRE + accessToken;
	}

}