package org.asamk.signal.storage.protocol;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;

import org.asamk.signal.util.Util;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.libsignal.state.SessionStore;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.util.UuidUtil;
import org.whispersystems.util.Base64;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;

class JsonSessionStore implements SessionStore {

    private final List<SessionInfo> sessions = new ArrayList<>();

    private SignalServiceAddressResolver resolver;

    public JsonSessionStore() {
    }

    public void setResolver(final SignalServiceAddressResolver resolver) {
        this.resolver = resolver;
    }

    private SignalServiceAddress resolveSignalServiceAddress(String identifier) {
        if (resolver != null) {
            return resolver.resolveSignalServiceAddress(identifier);
        } else {
            return Util.getSignalServiceAddressFromIdentifier(identifier);
        }
    }

    @Override
    public synchronized SessionRecord loadSession(SignalProtocolAddress address) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(address.getName());
        for (SessionInfo info : sessions) {
            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
                try {
                    return new SessionRecord(info.sessionRecord);
                } catch (IOException e) {
                    System.err.println("Failed to load session, resetting session: " + e);
                    final SessionRecord sessionRecord = new SessionRecord();
                    info.sessionRecord = sessionRecord.serialize();
                    return sessionRecord;
                }
            }
        }

        return new SessionRecord();
    }

    public synchronized List<SessionInfo> getSessions() {
        return sessions;
    }

    @Override
    public synchronized List<Integer> getSubDeviceSessions(String name) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(name);

        List<Integer> deviceIds = new LinkedList<>();
        for (SessionInfo info : sessions) {
            if (info.address.matches(serviceAddress) && info.deviceId != 1) {
                deviceIds.add(info.deviceId);
            }
        }

        return deviceIds;
    }

    @Override
    public synchronized void storeSession(SignalProtocolAddress address, SessionRecord record) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(address.getName());
        for (SessionInfo info : sessions) {
            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
                if (!info.address.getUuid().isPresent() || !info.address.getNumber().isPresent()) {
                    info.address = serviceAddress;
                }
                info.sessionRecord = record.serialize();
                return;
            }
        }

        sessions.add(new SessionInfo(serviceAddress, address.getDeviceId(), record.serialize()));
    }

    @Override
    public synchronized boolean containsSession(SignalProtocolAddress address) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(address.getName());
        for (SessionInfo info : sessions) {
            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
                return true;
            }
        }
        return false;
    }

    @Override
    public synchronized void deleteSession(SignalProtocolAddress address) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(address.getName());
        sessions.removeIf(info -> info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId());
    }

    @Override
    public synchronized void deleteAllSessions(String name) {
        SignalServiceAddress serviceAddress = resolveSignalServiceAddress(name);
        deleteAllSessions(serviceAddress);
    }

    public synchronized void deleteAllSessions(SignalServiceAddress serviceAddress) {
        sessions.removeIf(info -> info.address.matches(serviceAddress));
    }

    public static class JsonSessionStoreDeserializer extends JsonDeserializer<JsonSessionStore> {

        @Override
        public JsonSessionStore deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException {
            JsonNode node = jsonParser.getCodec().readTree(jsonParser);

            JsonSessionStore sessionStore = new JsonSessionStore();

            if (node.isArray()) {
                for (JsonNode session : node) {
                    String sessionName = session.has("name")
                            ? session.get("name").asText()
                            : null;
                    if (UuidUtil.isUuid(sessionName)) {
                        // Ignore sessions that were incorrectly created with UUIDs as name
                        continue;
                    }

                    UUID uuid = session.hasNonNull("uuid")
                            ? UuidUtil.parseOrNull(session.get("uuid").asText())
                            : null;
                    final SignalServiceAddress serviceAddress = uuid == null
                            ? Util.getSignalServiceAddressFromIdentifier(sessionName)
                            : new SignalServiceAddress(uuid, sessionName);
                    final int deviceId = session.get("deviceId").asInt();
                    final String record = session.get("record").asText();
                    try {
                        SessionInfo sessionInfo = new SessionInfo(serviceAddress, deviceId, Base64.decode(record));
                        sessionStore.sessions.add(sessionInfo);
                    } catch (IOException e) {
                        System.out.println(String.format("Error while decoding session for: %s", sessionName));
                    }
                }
            }

            return sessionStore;
        }
    }

    public static class JsonSessionStoreSerializer extends JsonSerializer<JsonSessionStore> {

        @Override
        public void serialize(JsonSessionStore jsonSessionStore, JsonGenerator json, SerializerProvider serializerProvider) throws IOException {
            json.writeStartArray();
            for (SessionInfo sessionInfo : jsonSessionStore.sessions) {
                json.writeStartObject();
                if (sessionInfo.address.getNumber().isPresent()) {
                    json.writeStringField("name", sessionInfo.address.getNumber().get());
                }
                if (sessionInfo.address.getUuid().isPresent()) {
                    json.writeStringField("uuid", sessionInfo.address.getUuid().get().toString());
                }
                json.writeNumberField("deviceId", sessionInfo.deviceId);
                json.writeStringField("record", Base64.encodeBytes(sessionInfo.sessionRecord));
                json.writeEndObject();
            }
            json.writeEndArray();
        }
    }

}