package com.alibaba.spring.boot.rsocket.broker.security;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTCreator;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.JWTVerifier;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.jetbrains.annotations.Nullable;
import org.springframework.stereotype.Service;
import org.springframework.util.StreamUtils;

import java.io.*;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Authentication service with an RS256 JWT implementation, please refer to https://github.com/auth0/java-jwt
 * <p>
 * The RSA key pair is stored in {@code ~/.rsocket/jwt_rsa.key} (PKCS#8-encoded private key) and
 * {@code ~/.rsocket/jwt_rsa.pub} (X.509-encoded public key). If the public key file is missing when the
 * service is constructed, a new 2048-bit key pair is generated and written there.
 *
 * @author leijuan
 */
@Service
public class AuthenticationServiceJwtImpl implements AuthenticationService {
    /**
     * Issuer claim written into generated tokens and required by the verifiers.
     */
    private static final String ISSUER = "RSocketBroker";
    private static final String KEYS_DIR_NAME = ".rsocket";
    private static final String PRIVATE_KEY_FILE_NAME = "jwt_rsa.key";
    private static final String PUBLIC_KEY_FILE_NAME = "jwt_rsa.pub";

    /**
     * Verifiers tried in order by {@link #auth(String, String)}; only populated in the constructor.
     */
    private final List<JWTVerifier> verifiers = new ArrayList<>();
    /**
     * Cache of verified principals, keyed by the full token text.
     * <p>
     * The key must be the token itself, not its hash code. {@code String.hashCode} collisions are
     * trivial to construct, so a hash-keyed cache would let a forged token resolve to another
     * token's principal.
     */
    final Cache<String, RSocketAppPrincipal> jwtVerifyCache = Caffeine.newBuilder()
            .maximumSize(100_000)
            .expireAfterWrite(30, TimeUnit.MINUTES)
            .build();

    /**
     * Loads (or, on first start, generates) the RSA key pair and builds the RS256 verifier.
     *
     * @throws Exception if the key directory cannot be created, or the keys cannot be generated, written or read
     */
    public AuthenticationServiceJwtImpl() throws Exception {
        File rsocketKeysDir = keysDir();
        File publicKeyFile = new File(rsocketKeysDir, PUBLIC_KEY_FILE_NAME);
        // generate RSA key pairs automatically
        if (!publicKeyFile.exists()) {
            // mkdirs() returns false when the directory already exists, so re-check before failing
            if (!rsocketKeysDir.mkdirs() && !rsocketKeysDir.isDirectory()) {
                throw new IOException("Failed to create RSocket keys directory: " + rsocketKeysDir.getAbsolutePath());
            }
            generateRSAKeyPairs(rsocketKeysDir);
        }
        Algorithm algorithmRSA256Public = Algorithm.RSA256(readPublicKey(), null);
        this.verifiers.add(JWT.require(algorithmRSA256Public).withIssuer(ISSUER).build());
    }

    /**
     * Verifies a JWT and returns the principal it represents.
     *
     * @param type        credential type (not used by this implementation)
     * @param credentials the encoded JWT
     * @return the principal, or {@code null} when the credentials are missing or no verifier accepts them
     */
    @Override
    @Nullable
    public RSocketAppPrincipal auth(String type, String credentials) {
        if (credentials == null || credentials.isEmpty()) {
            return null;
        }
        RSocketAppPrincipal cached = jwtVerifyCache.getIfPresent(credentials);
        if (cached != null) {
            return cached;
        }
        for (JWTVerifier verifier : verifiers) {
            try {
                DecodedJWT decodedJWT = verifier.verify(credentials);
                RSocketAppPrincipal principal = new JwtPrincipal(decodedJWT, credentials);
                // Only cache tokens without an "exp" claim. A cached entry is returned without
                // re-verification, so caching an expiring token could keep it valid past expiry.
                if (decodedJWT.getExpiresAt() == null) {
                    jwtVerifyCache.put(credentials, principal);
                }
                return principal;
            } catch (JWTVerificationException ignored) {
                // rejected by this verifier: try the next one
            }
        }
        return null;
    }

    /**
     * Creates a signed RS256 JWT with the broker as issuer.
     * <p>
     * {@code audience}, {@code organizations} and {@code roles} are written sorted. The caller's
     * arrays are not modified. {@code roles} and {@code authorities} are only added when non-empty.
     *
     * @return the encoded, signed token
     * @throws Exception if the private key cannot be read or the token cannot be signed
     */
    public String generateCredentials(String id, String[] organizations, String[] serviceAccounts, String[] roles, String[] authorities, String sub, String[] audience) throws Exception {
        Algorithm algorithmRSA256Private = Algorithm.RSA256(null, readPrivateKey());
        JWTCreator.Builder builder = JWT.create()
                .withIssuer(ISSUER)
                .withSubject(sub)
                .withAudience(sortedCopy(audience))
                .withIssuedAt(new Date())
                .withClaim("id", id)
                .withArrayClaim("sas", serviceAccounts)
                .withArrayClaim("orgs", sortedCopy(organizations));
        if (roles != null && roles.length > 0) {
            builder = builder.withArrayClaim("roles", sortedCopy(roles));
        }
        if (authorities != null && authorities.length > 0) {
            builder = builder.withArrayClaim("authorities", authorities);
        }
        return builder.sign(algorithmRSA256Private);
    }

    /**
     * Reads the PKCS#8-encoded RSA private key from {@code ~/.rsocket/jwt_rsa.key}.
     */
    public RSAPrivateKey readPrivateKey() throws Exception {
        byte[] keyBytes = readKeyFile(PRIVATE_KEY_FILE_NAME);
        PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
        return (RSAPrivateKey) KeyFactory.getInstance("RSA").generatePrivate(spec);
    }

    /**
     * Reads the X.509-encoded RSA public key from {@code ~/.rsocket/jwt_rsa.pub}.
     */
    public RSAPublicKey readPublicKey() throws Exception {
        byte[] keyBytes = readKeyFile(PUBLIC_KEY_FILE_NAME);
        X509EncodedKeySpec spec = new X509EncodedKeySpec(keyBytes);
        return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(spec);
    }

    /**
     * Reads the remaining content of the stream into a byte array.
     * <p>
     * The supplied stream is always closed, also when reading fails.
     *
     * @param inputStream stream to drain
     * @return all remaining bytes
     * @throws IOException if reading fails
     */
    public byte[] toBytes(InputStream inputStream) throws IOException {
        try (InputStream in = inputStream) {
            return StreamUtils.copyToByteArray(in);
        }
    }

    /**
     * Directory holding the JWT key pair: {@code ~/.rsocket}.
     */
    private static File keysDir() {
        return new File(System.getProperty("user.home"), KEYS_DIR_NAME);
    }

    /**
     * Reads the raw bytes of a file inside the keys directory.
     */
    private byte[] readKeyFile(String fileName) throws IOException {
        return toBytes(new FileInputStream(new File(keysDir(), fileName)));
    }

    /**
     * Returns a sorted copy of {@code values}, or {@code null} when {@code values} is {@code null}.
     */
    @Nullable
    private static String[] sortedCopy(@Nullable String[] values) {
        if (values == null) {
            return null;
        }
        String[] copy = values.clone();
        Arrays.sort(copy);
        return copy;
    }

    /**
     * Generates a 2048-bit RSA key pair and writes both halves into {@code rsocketKeysDir}.
     */
    private void generateRSAKeyPairs(File rsocketKeysDir) throws Exception {
        KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
        kpg.initialize(2048);
        KeyPair keyPair = kpg.generateKeyPair();
        // The private key is written first. The constructor treats an existing public key file as
        // "pair present", so an interrupted run just regenerates both files next time.
        writeKeyFile(new File(rsocketKeysDir, PRIVATE_KEY_FILE_NAME), keyPair.getPrivate(), true);
        writeKeyFile(new File(rsocketKeysDir, PUBLIC_KEY_FILE_NAME), keyPair.getPublic(), false);
    }

    /**
     * Writes the encoded key to {@code file}, replacing any existing content.
     *
     * @param ownerOnly when true, the file's permissions are narrowed to its owner before the key bytes are written
     */
    private static void writeKeyFile(File file, Key key, boolean ownerOnly) throws IOException {
        try (OutputStream out = new FileOutputStream(file)) {
            if (ownerOnly) {
                // the stream is already open, so narrowing permissions does not affect the write below
                restrictToOwner(file);
            }
            out.write(key.getEncoded());
        }
    }

    /**
     * Best-effort: makes {@code file} readable and writable by its owner only.
     */
    @SuppressWarnings("ResultOfMethodCallIgnored")
    private static void restrictToOwner(File file) {
        // Results are ignored on purpose: some file systems do not support these permission bits.
        file.setReadable(false, false);
        file.setReadable(true, true);
        file.setWritable(false, false);
        file.setWritable(true, true);
        file.setExecutable(false, false);
    }
}