/*
 * Copyright 2016 kay schluehr.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jupyterkernel.kernel;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Locale;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import org.json.JSONObject;
import org.jupyterkernel.json.messages.T_JSON;
import org.jupyterkernel.json.messages.T_header;
import org.jupyterkernel.json.messages.T_message;
import org.jupyterkernel.util.HexBinaryConverter;
import org.jupyterkernel.util.UUID;
import org.zeromq.ZFrame;
import org.zeromq.ZMQ.Socket;
import org.zeromq.ZMsg;

/**
 *
 * @author kay schluehr
 *
 * The MessageObject bundles a message together with the socket it arrived on.
 * This allows a socket to be used for a request/reply pair without keeping a
 * registry of outstanding messages.
 *
 */
public class MessageObject {

    /** Protocol versions accepted by {@link #checkAllowedProtocolVersion(String)}. */
    final String[] supportedProtocolVersions = {"5.0", "5.1", "5.2", "5.3"};
    /** Frame that separates the zmq identity from the signed message parts. */
    final String delimiter  = "<IDS|MSG>";
    // Explicit UTF-8: String.getBytes() without a charset uses the platform default encoding.
    final byte[] bDelimiter = delimiter.getBytes(StandardCharsets.UTF_8);

    /* Frame indices of a message according to the wire protocol
     *
     * http://jupyter-client.readthedocs.org/en/latest/messaging.html#the-wire-protocol
     *
     b'u-u-i-d',         # zmq identity(ies)
     b'<IDS|MSG>',       # delimiter
     b'baddad42',        # HMAC signature
     b'{header}',        # serialized header dict
     b'{parent_header}', # serialized parent header dict
     b'{metadata}',      # serialized metadata dict
     b'{content}',       # serialized content dict
     b'blob',            # extra raw data buffer(s)
     *
     * NOTE: these fixed indices assume exactly one zmq identity frame, so the
     * delimiter is expected at index 1.
     */
    static class MessageParts {

        public static final int UUID = 0;
        public static final int DELIM = 1;
        public static final int HMAC = 2;
        public static final int HEADER = 3;
        public static final int PARENT = 4;
        public static final int METADATA = 5;
        public static final int CONTENT = 6;
        public static final int BLOB = 7;
    }

    /** The decoded message; populated by {@link #read()} and serialized by {@link #send()}. */
    public T_message msg;
    // The socket from which the message was received; send() writes to this socket.
    public Socket socket;
    // zmq identity taken from frame 0 by read(); emitted as the first frame by send().
    byte[] uuid;
    // HMAC-SHA256 key. An empty key disables signing, per the Jupyter wire protocol.
    byte[] key;

    // Raw multipart message; consumed (destroyed) by read().
    ZMsg zmsg;

    /**
     * Wraps a received multipart message. Call {@link #read()} to decode it.
     *
     * @param zmsg   the raw multipart message
     * @param socket the socket the message came from (replies are sent here)
     * @param key    the HMAC key; an empty array disables signing
     */
    public MessageObject(ZMsg zmsg, Socket socket, byte[] key) {
        this.socket = socket;
        this.key = key;
        this.zmsg = zmsg;
        this.msg = new T_message();
    }

    /**
     * Creates a copy that shares socket, key, raw message and zmq identity with
     * {@code other}, but owns a cloned {@link T_message}.
     *
     * @param other the message object to copy
     */
    public MessageObject(MessageObject other) {
        this.socket = other.socket;
        this.key = other.key;
        this.zmsg = other.zmsg;
        // Without the identity a copied reply could not be routed back to the requester.
        this.uuid = other.uuid;
        this.msg = (T_message) other.msg.clone();
    }

    /**
     * Ensures {@code protocol} is one of {@link #supportedProtocolVersions}.
     *
     * @throws RuntimeException if the version is not supported
     */
    private void checkAllowedProtocolVersion(String protocol)
    {
        for (String version : supportedProtocolVersions)
        {
            if (version.equals(protocol))
                return;
        }
        throw new RuntimeException("[jupyter-kernel] Protocol version " + protocol + " not supported by this kernel");
    }

    /**
     * Returns true when the key is empty. The Jupyter wire protocol uses an
     * empty key to mean "signing disabled". A {@code null} key is not treated
     * as disabled and still fails in {@link #computeSignature}.
     */
    private boolean isSigningDisabled() {
        return key != null && key.length == 0;
    }

    /**
     * Computes the HMAC-SHA256 over header, parent header, metadata and
     * content, in that order.
     *
     * @return the raw digest bytes, or an empty array when signing is disabled
     * @throws IllegalStateException if the MAC cannot be initialised
     */
    private byte[] computeSignature(byte[] header,
                                    byte[] parent,
                                    byte[] meta,
                                    byte[] content) {
        if (isSigningDisabled()) {
            return new byte[0];
        }
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(key, "HmacSHA256"));
            for (byte[] part : new byte[][] {header, parent, meta, content}) {
                mac.update(part);
            }
            return mac.doFinal();
        } catch (InvalidKeyException | NoSuchAlgorithmException e) {
            // Previously this printed the message and returned null, which only
            // deferred the failure to an NPE; fail here with the cause attached.
            throw new IllegalStateException("[jupyter-kernel.jar] Cannot compute HMAC signature", e);
        }
    }

    /**
     * Verifies that two MACs are equal, using the JDK's time-constant comparison.
     *
     * @throws RuntimeException if either MAC is null or they differ
     */
    private void mildlySecureMACCompare(byte[] mac1, byte[] mac2) {
        boolean hmacValid = mac1 != null && mac2 != null && MessageDigest.isEqual(mac1, mac2);
        if (!hmacValid) {
            throw new RuntimeException("[jupyter-kernel.jar] HMAC verification failed");
        }
    }

    /**
     * Decodes {@link #zmsg} into {@link #msg}.
     *
     * Checks the frame count and the delimiter, and verifies the HMAC signature
     * unless signing is disabled. While {@code T_JSON} has no protocol version
     * set, it also validates the header version and registers it. The raw
     * message is destroyed afterwards, whether or not decoding succeeded.
     *
     * @throws RuntimeException if the message is malformed, fails HMAC
     *         verification, or uses an unsupported protocol version
     */
    public void read() {
        try {
            if (zmsg.size() < 7) {
                throw new RuntimeException("[jupyter-kernel.jar] Message incomplete. Didn't receive required message parts");
            }
            ZFrame[] zframes = new ZFrame[zmsg.size()];
            zmsg.toArray(zframes);

            uuid = zframes[MessageParts.UUID].getData();
            String delim = new String(zframes[MessageParts.DELIM].getData(),
                                      StandardCharsets.UTF_8);
            if (!delimiter.equals(delim)) {
                throw new RuntimeException("[jupyter-kernel.jar] Incorrectly formatted message. Delimiter <IDS|MSG> not found");
            }
            byte[] header = zframes[MessageParts.HEADER].getData();
            byte[] parent = zframes[MessageParts.PARENT].getData();
            byte[] meta = zframes[MessageParts.METADATA].getData();
            byte[] content = zframes[MessageParts.CONTENT].getData();

            if (!isSigningDisabled()) {
                byte[] digest = computeSignature(header, parent, meta, content);
                // The signature frame carries the digest as a hex string; decode it to raw bytes first.
                String hexSignature = new String(zframes[MessageParts.HMAC].getData(),
                                                 StandardCharsets.UTF_8);
                mildlySecureMACCompare(digest, HexBinaryConverter.parseHexBinary(hexSignature));
            }

            JSONObject jsonHeader = new JSONObject(new String(header, StandardCharsets.UTF_8));
            if (null == T_JSON.message_protocol_version)
            {
                String protocolVersion = (String) jsonHeader.get("version");
                checkAllowedProtocolVersion(protocolVersion);
                // set protocol version for protocol specific serialization / deserialization
                T_JSON.setProtocolVersion(protocolVersion);
            }
            msg.header = (T_header) T_JSON.fromJSON("T_header", jsonHeader);
            msg.parent_header = (T_header) T_JSON.fromJSON("T_header",
                    new JSONObject(new String(parent, StandardCharsets.UTF_8)));
            msg.metadata = new JSONObject(new String(meta, StandardCharsets.UTF_8));
            msg.content = T_JSON.fromJSON("T_" + msg.header.msg_type,
                    new JSONObject(new String(content, StandardCharsets.UTF_8)));

        } finally {
            zmsg.destroy();
        }
    }

    /** Serializes a JSON object to UTF-8 bytes, as the wire protocol requires. */
    private static byte[] toUtf8(JSONObject json) {
        return json.toString().getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Assigns a fresh message id to {@link #msg}, signs it and sends it on
     * {@link #socket}.
     *
     * Frames sent: identity, delimiter, lowercase hex HMAC (empty when signing
     * is disabled), header, parent header, metadata, content.
     */
    public void send() {
        msg.header.msg_id = UUID.newID();
        JSONObject jsonMsg = msg.toJSON();
        byte[] header  = toUtf8(jsonMsg.getJSONObject("header"));
        byte[] parent  = toUtf8(jsonMsg.getJSONObject("parent_header"));
        byte[] meta    = toUtf8(jsonMsg.getJSONObject("metadata"));
        byte[] content = toUtf8(jsonMsg.getJSONObject("content"));
        byte[] digest  = computeSignature(header, parent, meta, content);
        // An empty digest means signing is disabled; the protocol then expects an empty signature frame.
        byte[] signature = digest.length == 0
                ? digest
                : HexBinaryConverter.toHexBinary(digest)
                        .toLowerCase(Locale.ROOT)
                        .getBytes(StandardCharsets.UTF_8);

        ZMsg newZmsg = new ZMsg();
        newZmsg.add(uuid);
        newZmsg.add(bDelimiter);
        newZmsg.add(signature);
        newZmsg.add(header);
        newZmsg.add(parent);
        newZmsg.add(meta);
        newZmsg.add(content);
        newZmsg.send(socket);
    }

}