/**
 * Copyright (C) 2014 Stratio (http://stratio.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.stratio.ingestion.source.snmptraps;

import static com.stratio.ingestion.source.snmptraps.SNMPSourceConstants.*;

import java.io.IOException;
import java.util.Locale;

import org.apache.flume.Context;
import org.apache.flume.EventDrivenSource;
import org.apache.flume.channel.ChannelProcessor;
import org.apache.flume.conf.Configurable;
import org.apache.flume.event.EventBuilder;
import org.apache.flume.instrumentation.SourceCounter;
import org.apache.flume.source.AbstractSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.snmp4j.CommandResponder;
import org.snmp4j.CommandResponderEvent;
import org.snmp4j.MessageDispatcher;
import org.snmp4j.MessageDispatcherImpl;
import org.snmp4j.PDU;
import org.snmp4j.Snmp;
import org.snmp4j.mp.MPv1;
import org.snmp4j.mp.MPv2c;
import org.snmp4j.mp.MPv3;
import org.snmp4j.mp.MessageProcessingModel;
import org.snmp4j.mp.SnmpConstants;
import org.snmp4j.security.AuthMD5;
import org.snmp4j.security.AuthSHA;
import org.snmp4j.security.Priv3DES;
import org.snmp4j.security.PrivAES128;
import org.snmp4j.security.PrivAES192;
import org.snmp4j.security.PrivAES256;
import org.snmp4j.security.PrivDES;
import org.snmp4j.security.SecurityLevel;
import org.snmp4j.security.SecurityModels;
import org.snmp4j.security.SecurityProtocols;
import org.snmp4j.security.USM;
import org.snmp4j.security.UsmUser;
import org.snmp4j.security.nonstandard.PrivAES192With3DESKeyExtension;
import org.snmp4j.security.nonstandard.PrivAES256With3DESKeyExtension;
import org.snmp4j.smi.OID;
import org.snmp4j.smi.OctetString;
import org.snmp4j.smi.UdpAddress;
import org.snmp4j.transport.AbstractTransportMapping;
import org.snmp4j.transport.DefaultUdpTransportMapping;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;

public class SNMPSource extends AbstractSource implements EventDrivenSource, Configurable {

    private static final Logger log = LoggerFactory.getLogger(SNMPSource.class);

    private String address;
    private int snmpTrapPort;
    private String username;
    private String password;
    private int version;
    private int trapVersion;
    private OID authMethod;
    private int securityMethod;
    private String privacyProtocol;
    private String privacyPassword;
    private SourceCounter sourceCounter;

    /** SNMP4J stuff **/
    private AbstractTransportMapping<UdpAddress> transport;
    private Snmp snmp;

    /** SNMP4J trap stuff **/
    private AbstractTransportMapping<UdpAddress> trap_transport;
    private Snmp trap_snmp;

    @Override
    public void configure(Context context) {
        address = context.getString(CONF_ADDRESS, DEFAULT_ADDRESS);
        snmpTrapPort = context.getInteger(CONF_TRAP_PORT, DEFAULT_TRAP_PORT);

        switch (Integer.valueOf(context.getString(CONF_SNMP_VERSION, DEFAULT_SNMP_VERSION)
                .replaceAll("(?i)V", "").replaceAll("(?i)C", ""))) {
            case 1:
                version = SnmpConstants.version1;
                break;
            case 2:
                version = SnmpConstants.version2c;
                break;
            case 3:
                version = SnmpConstants.version3;
                break;
            default:
                version = SnmpConstants.version1;
                break;
        }

        switch (Integer.valueOf(context
                .getString(CONF_SNMP_TRAP_VERSION, DEFAULT_SNMP_TRAP_VERSION)
                .replaceAll("(?i)V", "").replaceAll("(?i)C", ""))) {
            case 1:
                trapVersion = SnmpConstants.version1;
                break;
            case 2:
                trapVersion = SnmpConstants.version2c;
                break;
            case 3:
                trapVersion = SnmpConstants.version3;
                break;
            default:
                trapVersion = SnmpConstants.version1;
                break;
        }

        final String hashAlgo = context.getString(CONF_ENCRYPTION, DEFAULT_ENCRYPTION);
        if ("SHA".equals(hashAlgo)) {
          authMethod = AuthSHA.ID;
        } else if ("MD5".equals(hashAlgo)) {
          authMethod = AuthMD5.ID;
        } else {
          authMethod = AuthMD5.ID;
        }

        final String auth = context.getString(CONF_AUTH, DEFAULT_AUTH);
        if ("AUTH_NOPRIV".equals(auth)) {
          securityMethod = SecurityLevel.AUTH_NOPRIV;
          username = context.getString(CONF_USERNAME);
          password = context.getString(CONF_PASSWD);
        } else if ("NOAUTH_NOPRIV".equals(auth)) {
          securityMethod = SecurityLevel.NOAUTH_NOPRIV;
        } else if ("AUTH_PRIV".equals(auth)) {
          securityMethod = SecurityLevel.AUTH_PRIV;
          username = context.getString(CONF_USERNAME);
          password = context.getString(CONF_PASSWD);
          privacyProtocol = context.getString(CONF_PRIV_PROTOCOL, DEFAULT_PRIV_PROTOCOL);
          privacyPassword = context.getString(CONF_PRIV_PASSPHRASE);
        } else {
          securityMethod = SecurityLevel.NOAUTH_NOPRIV;
        }

        if (sourceCounter == null) {
            sourceCounter = new SourceCounter(getName());
        }
    }

    @Override
    public void start() {
        try {

            // --------------------------------------------------
            // POLLING
            // --------------------------------------------------
            transport = new DefaultUdpTransportMapping();
            if (version == SnmpConstants.version3) {
                snmp = new Snmp(transport);

                // add security model
                @SuppressWarnings("static-access")
                byte[] localEngineID = ((MPv3) snmp
                        .getMessageProcessingModel(MessageProcessingModel.MPv3))
                        .createLocalEngineID();
                USM usm = new USM(SecurityProtocols.getInstance(), new OctetString(localEngineID),
                        0);
                SecurityModels.getInstance().addSecurityModel(usm);
                snmp.setLocalEngine(localEngineID, 0, 0);

                // add auth and privacy
                UsmUser user = createUser();
                snmp.getUSM().addUser(new OctetString(username), user);

            } else {
                MessageDispatcher mDispatcher = new MessageDispatcherImpl();
                mDispatcher.addMessageProcessingModel(new MPv1());
                mDispatcher.addMessageProcessingModel(new MPv2c());
                mDispatcher.addMessageProcessingModel(new MPv3());

                snmp = new Snmp(mDispatcher, transport);
            }

            transport.listen();

            // --------------------------------------------------
            // TRAPS
            // --------------------------------------------------

            UdpAddress trapUdpAddress = new UdpAddress("0.0.0.0/" + snmpTrapPort);
            trap_transport = new DefaultUdpTransportMapping(trapUdpAddress);

            if (trapVersion == SnmpConstants.version3) {

                trap_snmp = new Snmp(trap_transport);

                // add security model
                @SuppressWarnings("static-access")
                byte[] localEngineID = ((MPv3) trap_snmp
                        .getMessageProcessingModel(MessageProcessingModel.MPv3))
                        .createLocalEngineID();
                USM usm = new USM(SecurityProtocols.getInstance(), new OctetString(localEngineID),
                        0);
                SecurityModels.getInstance().addSecurityModel(usm);
                trap_snmp.setLocalEngine(localEngineID, 0, 0);

                // add auth and Privacy
                UsmUser user = createUser();
                trap_snmp.getUSM().addUser(new OctetString(username), user);

            } else {
                MessageDispatcher mDispatcher = new MessageDispatcherImpl();
                mDispatcher.addMessageProcessingModel(new MPv1());
                mDispatcher.addMessageProcessingModel(new MPv2c());
                mDispatcher.addMessageProcessingModel(new MPv3());

                trap_snmp = new Snmp(mDispatcher, trap_transport);
            }

            CommandResponder trapsCatch = new CommandResponder() {
                public synchronized void processPdu(CommandResponderEvent e) {
                    PDU command = e.getPDU();
                    if (command != null) {
                        ChannelProcessor channelProcessor = getChannelProcessor();
                        sourceCounter.addToEventReceivedCount(1);
                        sourceCounter.incrementAppendBatchReceivedCount();
                        channelProcessor.processEvent(EventBuilder.withBody(command.toString(),
                                Charsets.UTF_8));
                        sourceCounter.addToEventAcceptedCount(1);
                        sourceCounter.incrementAppendBatchAcceptedCount();
                    }
                }
            };

            trap_snmp.addCommandResponder(trapsCatch);
            trap_transport.listen();

            // --------------------------------------------------

            log.debug("[SNMP] SNMP Trap binding is listening on " + address);

        } catch (IOException e) {
            log.debug("couldn't listen to " + address + "----" + e);
            e.printStackTrace();
        }

        sourceCounter.start();

    }

    @Override
    public void stop() {
        log.debug("Closing SNMP connection with: " + this.address);
        try {
            snmp.close();
            trap_snmp.close();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @VisibleForTesting
    protected SourceCounter getSourceCounter() {
        return sourceCounter;
    }

    private UsmUser createUser() {
        OID privacyOID = null;
        OctetString privacyPasswd = null;
        if (securityMethod == SecurityLevel.AUTH_PRIV) {
            final String privacyProtocolUpper = privacyProtocol.toUpperCase(Locale.ENGLISH);
                if ("PRIVDES".equals(privacyProtocolUpper)) {
                  privacyOID = PrivDES.ID;
                } else if ("PRIV3DES".equals(privacyProtocolUpper)) {
                  privacyOID = Priv3DES.ID;
                } else if ("PRIVAES128".equals(privacyProtocolUpper)) {
                  privacyOID = PrivAES128.ID;
                } else if  ("PRIVAES192".equals(privacyProtocolUpper)) {
                  privacyOID = PrivAES192.ID;
                } else if ("PRIVAES256".equals(privacyProtocolUpper)) {
                  privacyOID = PrivAES256.ID;
                } else if  ("PRIVAES192WITH3DESKEYEXTENSION".equals(privacyProtocolUpper)) {
                  privacyOID = PrivAES192With3DESKeyExtension.ID;
                } else if ("PRIVAES256WITH3DESKEYEXTENSION".equals(privacyProtocolUpper)) {
                  privacyOID = PrivAES256With3DESKeyExtension.ID;
                } else {
                    log.debug("Privacy protocol " + privacyProtocolUpper + " unsupported or invalid.");
                }
            privacyPasswd = new OctetString(privacyPassword);
        }
        return new UsmUser(new OctetString(username), authMethod, new OctetString(password),
                privacyOID, privacyPasswd);
    }

}