#!/bin/python3

import time
import serial
import argparse
import binascii

import serial.tools.list_ports as list_ports

from clint.textui import colored, puts, progress

COMMAND_HELLO = '>'

COMMAND_BUFFER_CRC = 'h'
COMMAND_BUFFER_LOAD = 'l'
COMMAND_BUFFER_STORE = 's'

COMMAND_FLASH_READ = 'r'
COMMAND_FLASH_WRITE = 'w'
COMMAND_FLASH_ERASE_SECTOR = 'k'

COMMAND_WRITE_PROTECTION_ENABLE = 'p'
COMMAND_WRITE_PROTECTION_DISABLE = 'u'
COMMAND_WRITE_PROTECTION_CHECK = 'x'
COMMAND_STATUS_REGISTER_READ = 'y'
COMMAND_ID_REGISTER_READ = 'i'

COMMAND_SET_CS_IO = '*'
COMMAND_SET_OUTPUT = 'o'

WRITE_PROTECTION_NONE = 0x00
WRITE_PROTECTION_PARTIAL = 0x01
WRITE_PROTECTION_FULL = 0x02
WRITE_PROTECTION_UNKNOWN = 0x03

WRITE_PROTECTION_CONFIGURATION_NONE = 0x00
WRITE_PROTECTION_CONFIGURATION_PARTIAL = 0x01
WRITE_PROTECTION_CONFIGURATION_LOCKED = 0x02
WRITE_PROTECTION_CONFIGURATION_UNKNOWN = 0x03

DEFAULT_FLASH_SIZE = 4096 * 1024
DEFAULT_SECTOR_SIZE = 4096
DEFAULT_PAGE_SIZE = 256

ENCODING = 'iso-8859-1'

DEBUG_NORMAL = 1
DEBUG_VERBOSE = 2


def logMessage(text):
    puts(colored.white(text))


def logOk(text):
    puts(colored.green(text))


def logError(text):
    puts(colored.red(text))


def logDebug(text, type):
    if type == DEBUG_NORMAL:
        puts(colored.cyan(text))
    else:  # DEBUG_VERBOSE
        puts(colored.magenta(text))


class SerialProgrammer:

    def __init__(self, port, baud_rate, debug='off', sector_size=DEFAULT_SECTOR_SIZE, page_size=DEFAULT_PAGE_SIZE):
        self.sector_size = sector_size
        self.page_size = page_size
        self.pages_per_sector = self.sector_size // self.page_size

        if debug == 'normal':
            self.debug = DEBUG_NORMAL
        elif debug == 'verbose':
            self.debug = DEBUG_VERBOSE
        else:  # off
            self.debug = 0

        self._debug('Opening serial connection')
        self.sock = serial.Serial(port, baud_rate, timeout=1)
        self._debug('Serial connection opened successfully')

        time.sleep(2)  # Wait for the Arduino bootloader

    def _debug(self, message, level=DEBUG_NORMAL):
        if self.debug >= level:
            logDebug(message, level)

    def _readExactly(self, length, tries=3):
        """Read exactly n bytes or return None"""
        data = b''
        _try = 0
        while len(data) < length and _try < tries:
            new_data = self.sock.read(length - len(data))
            if new_data == b'':
                _try += 1

            data += new_data

        if len(data) != length:
            return None

        return data

    def _waitForMessage(self, text, tries=3, max_length=100):
        """Wait for the expected message and return True or return False"""
        self._debug('Waiting for \'%s\'' % text, DEBUG_VERBOSE)

        data = text.encode(ENCODING)
        return self._waitFor(len(data), lambda _data: data == _data, tries, max_length)

    def _waitFor(self, length, check, tries=3, max_length=100):
        """Wait for the expected message and return True or return False"""
        data = b''

        _try = 0
        while _try < tries:
            new_data = self.sock.read(max(length - len(data), 1))
            if new_data == b'':
                _try += 1

            max_length -= len(new_data)
            if max_length < 0:
                return False

            self._debug('Recv: \'%s\'' % new_data.decode(ENCODING), DEBUG_VERBOSE)

            data = (data + new_data)[-length:]
            if check(data):
                return True

        return False

    def _getUntilMessage(self, text, tries=3, max_length=100):
        """Wait for the expected message and return the data received"""
        self._debug('Reading until \'%s\'' % text, DEBUG_VERBOSE)

        data = text.encode(ENCODING)
        return self._getUntil(len(data), lambda _data: data == _data, tries, max_length)

    def _getUntil(self, length, check, tries=3, max_length=1000):
        """Wait for the expected message and return the data received"""
        data = b''
        message = b''

        _try = 0
        while _try < tries:
            new_data = self.sock.read(max(length - len(message), 1))
            if new_data == b'':
                _try += 1

            max_length -= len(new_data)
            if max_length < 0:
                return None

            self._debug('Recv: \'%s\'' % new_data.decode(ENCODING), DEBUG_VERBOSE)

            message = (message + new_data)[-length:]
            data += new_data
            if check(message):
                return data[:-len(message)]

        return None

    def _dump(self, data_str):
        for offset, data_row in [(i, data_str[i:i+16]) for i in range(0, len(data_str), 16)]:
            logMessage('%08x: %s' % (offset, ' '.join([data_row[i:i+2] for i in range(0, 16, 2)])))
        return

    def _sendCommand(self, command):
        self._debug('Send: \'%s\'' % command, DEBUG_VERBOSE)

        self.sock.write(command.encode(ENCODING))
        self.sock.flush()

    def _eraseSector(self, sector):
        self._debug('Command: ERASE_SECTOR %d' % sector)

        self._sendCommand('%s%08x' % (COMMAND_FLASH_ERASE_SECTOR, sector))
        return self._waitForMessage(COMMAND_FLASH_ERASE_SECTOR)

    def _readCRC(self):
        self._debug('Command: BUFFER_CRC')

        # Write crc check
        self._sendCommand(COMMAND_BUFFER_CRC)

        # Wait for crc start
        if not self._waitForMessage(COMMAND_BUFFER_CRC):
            self._debug('Invalid / no response for BUFFER_CRC command')
            return None

        crc = self._readExactly(8).decode(ENCODING)
        if crc is None:
            self._debug('Invalid / no CRC response')
            return None

        try:
            return int(crc, 16)
        except ValueError:
            self._debug('Could not decode CRC')
            return None

    def _loadPageOnce(self, page, tries=3):
        """Read a page into the internal buffer"""
        self._debug('Command: FLASH_READ %d' % page)

        # Reads page
        self._sendCommand('%s%08x' % (COMMAND_FLASH_READ, page))

        # Wait for read acknowledge
        if not self._waitForMessage(COMMAND_FLASH_READ):
            self._debug('Invalid / no response for FLASH_READ command')
            return None

        crc = self._readExactly(8).decode(ENCODING)
        if crc is None:
            self._debug('Invalid / no CRC response')
            return None

        try:
            return int(crc, 16)
        except ValueError:
            self._debug('Could not decode CRC')
            return None

    def _loadPageMultiple(self, page, tries=3):
        """Read a page into the internal buffer

        Keeps reading until we get two page reads the same checksum.
        """
        self._debug('Command: FLASH_READ_MULTIPLE %d' % page)

        crc_list = []
        _try = 0

        while _try < tries:
            crc = self._loadPageOnce(page, tries)
            if crc is None:
                _try += 1
                continue

            if len(crc_list) >= 1:
                if crc in crc_list:
                    self._debug('CRC is valid')
                    return crc
                else:
                    _try += 1

            crc_list.append(crc)

        self._debug('CRC reads did not match once')
        return None

    def _readPage(self, page, tries=3):
        """Read a page from the flash and receive it's contents"""
        self._debug('Command: FLASH_READ_PAGE %d' % page)

        # Load page into the buffer
        crc = self._loadPageMultiple(page, tries)

        for _ in range(tries):
            # Dump the buffer
            self._sendCommand(COMMAND_BUFFER_LOAD)

            # Wait for data start
            if not self._waitForMessage(COMMAND_BUFFER_LOAD):
                self._debug('Invalid / no response for BUFFER_LOAD command')
                continue

            # Load successful -> read sector with 2 nibbles per byte
            page_data = self._readExactly(self.page_size * 2)
            if page_data is None:
                self._debug('Invalid / no response for page data')
                continue

            try:
                data = binascii.a2b_hex(page_data.decode(ENCODING))
                if crc == binascii.crc32(data):
                    self._debug('CRC did match with read data')
                    return data
                else:
                    self._debug('CRC did not match with read data')
                    continue

            except TypeError:
                self._debug('CRC could not be parsed')
                continue

        self._debug('Page read tries exceeded')
        return None

    def _writePage(self, page, data):
        """Write a page into the buffer and instruct a page write operation

        This operation checks the written data with a generated checksum.
        """
        assert len(data) == self.page_size, (len(data), data)

        # Write the page and verify that it was written correctly.
        expected_crc = binascii.crc32(data)
        encoded_data = binascii.b2a_hex(data)

        self._sendCommand(COMMAND_BUFFER_STORE + encoded_data.decode(ENCODING))
        if not self._waitForMessage(COMMAND_BUFFER_STORE):
            self._debug('Invalid / no response for BUFFER_STORE command')
            return False

        # This shouldn't fail if we're using a reliable connection.
        crc = self._readExactly(8).decode(ENCODING)
        if crc is None:
            self._debug('Invalid / no CRC response for buffer write')
            return None

        try:
            if int(crc, 16) != expected_crc:
                return None
        except ValueError:
            self._debug('Could not decode CRC')
            return None

        # Write page
        self._sendCommand('%s%08x' % (COMMAND_FLASH_WRITE, page))
        time.sleep(.2)  # Sleep 200 ms

        if not self._waitForMessage(COMMAND_FLASH_WRITE):
            self._debug('Invalid / no response for FLASH_WRITE command')
            return False

        # Read back page
        # Fail if we can't read what we wrote
        read_crc = self._loadPageMultiple(page)
        if read_crc is None:
            self._debug('Invalid / no CRC response for flash write')
            return False

        return (read_crc == expected_crc)

    def _writeSectors(self, offset, data, tries=3):
        """Write one or more sectors with data

        This method clears the sectors before writing to them and checks
        for valid data via reading each page and comparing the checksum.
        """
        assert offset % self.sector_size == 0
        pages_offset = offset // self.page_size
        sectors_offset = offset // self.sector_size

        assert len(data) % self.sector_size == 0
        page_count = len(data) // self.page_size
        sector_count = len(data) // self.sector_size

        with progress.Bar(expected_size=page_count) as bar:
            sector_write_attempt = 0
            sector = 0
            while sector < sector_count:
                sector_index = sectors_offset + sector

                bar.show(sector * self.pages_per_sector)

                # Erase sector up to 'tries' times
                for _ in range(tries):
                    if self._eraseSector(sector_index):
                        break
                else:  # No erase was successful
                    logError('Could not erase sector 0x%08x' % sector_index)
                    return False

                for page in range(self.pages_per_sector):
                    page_data_index = sector * self.pages_per_sector + page
                    data_index = page_data_index * self.page_size
                    page_index = pages_offset + page_data_index

                    if self._writePage(page_index, data[data_index: data_index + self.page_size]):
                        bar.show(page_data_index + 1)
                        continue

                    sector_write_attempt += 1
                    if sector_write_attempt < tries:
                        break  # Retry sector

                    logError('Could not write page 0x%08x' % page_index)
                    return False

                else:  # All pages written normally -> next sector
                    sector += 1

        return True

    def _eraseSectors(self, offset, length, tries=3):
        """Clears one or more sectors"""
        assert offset % self.sector_size == 0
        sectors_offset = offset // self.sector_size

        assert length % self.sector_size == 0
        sector_count = length // self.sector_size

        with progress.Bar(expected_size=sector_count) as bar:
            for sector in range(sector_count):
                sector_index = sectors_offset + sector

                bar.show(sector)

                # Erase sector up to 'tries' times
                for _ in range(tries):
                    if self._eraseSector(sector_index):
                        break
                else:  # No erase was successful
                    logError('Could not erase sector %08x' % sector_index)
                    return False

            bar.show(sector_count)

        return True

    def _hello(self):
        """Send a hello message and expect a version string"""
        self._debug('Command: HELLO')

        # Write hello
        self._sendCommand(COMMAND_HELLO)

        # Wait for hello response start
        if not self._waitForMessage(COMMAND_HELLO):
            self._debug('Invalid / no response for HELLO command')
            return None

        message = self._getUntilMessage(COMMAND_HELLO)
        if message is None:
            self._debug('No termination for HELLO command')
            return None

        return message.decode(ENCODING)

    def _read_register(self, cmd, name):
        """Generic read register function, send cmd and read a <CMD><LEN><DATA> response"""
        self._sendCommand(cmd)
        if not self._waitForMessage(cmd):
            self._debug('Invalid / no response for %s command' % (name,))
            logError('Invalid response')
            return None

        length_str = self._readExactly(2).decode(ENCODING)
        if length_str is None:
            self._debug('Invalid / no response for %s length' % (name,))
            logError('Invalid response')
            return None

        try:
            length = int(length_str, 16)
        except ValueError:
            self._debug('Could not decode %s length' % (name,))
            logError('Invalid register length')
            return None

        data_str = self._readExactly(length * 2).decode(ENCODING)
        if data_str is None:
            self._debug('Invalid / no response for %s check' % (name,))
            logError('Invalid response')
            return None

        try:  # Check if valid data
            decoded_data = binascii.a2b_hex(data_str)
        except TypeError:
            self._debug('Could not decode %s content' % (name))
            logError('Invalid response')
            return None

        return data_str

    def hello(self):
        """Send a hello message and print the retrieved version string"""
        version = self._hello()
        if version is None:
            logError('Connected to unknown device')
            return False
        else:
            logMessage('Connected to \'%s\'' % version.strip())
            return True

    def writeFromFile(self, filename, flash_offset=0, file_offset=0, length=DEFAULT_SECTOR_SIZE, pad=None):
        """Write the data from file to the flash"""
        if pad == None:
            if (length != -1) and (length % self.sector_size != 0):
                logError('length must be a multiple of the sector size %d' % self.sector_size)
                return False

            if flash_offset % self.sector_size != 0:
                logError('flash_offset must be a multiple of the sector size %d' % self.sector_size)
                return False
        elif not ((0x0 <= pad) and (pad <= 0xff)):
            logError('pad must be in range 0x00--0xff')
            return False

        if file_offset < 0:
            logError('file_offset must be a positive value or 0')
            return False

        data = None
        try:
            with open(filename, 'rb') as file:
                file.seek(file_offset)
                data = file.read(length)
        except IOError:
            logError('Could not read from file \'%s\'' % filename)
            return True

        if (length != -1) and (len(data) != length):
            logError('File is not large enough to read %d bytes' % length)
            return True

        if pad != None:
            pad_value = b'%c' % (pad&0xff)
            self._debug("Length of data before padding 0x%x" % (len(data),))

            pad_pre = flash_offset % self.sector_size
            self._debug("Pad 0x%x bytes before data" % (pad_pre,))
            data = pad_value*(flash_offset % self.sector_size) + data

            post_pad = self.sector_size - (len(data) % self.sector_size)
            if post_pad == self.sector_size:
                post_pad = 0x0
            self._debug("Pad 0x%x bytes after data" % (post_pad,))
            data = data + pad_value*(post_pad)

            flash_offset = flash_offset & (self.sector_size-0x1)
        elif (length == -1) and (len(data) % self.sector_size != 0):
            logError('file size must be a multiple of the sector size %d, use --pad' % self.sector_size)
            return False

        if not self._writeSectors(flash_offset, data):
            logError('Aborting')
        else:
            logOk('Done')

        return True

    def readToFile(self, filename, flash_offset=0, length=DEFAULT_FLASH_SIZE):
        """Read the data from the flash into the file"""
        if length % self.page_size != 0:
            logError('length must be a multiple of the page size %d' % self.page_size)
            return False

        if flash_offset % self.page_size != 0:
            logError('flash_offset must be a multiple of the page size %d' % self.page_size)
            return False

        page_count = length // self.page_size
        pages_offset = flash_offset // self.page_size

        try:
            with open(filename, 'wb') as file:
                with progress.Bar(expected_size=page_count) as bar:
                    for page in range(page_count):
                        bar.show(page)

                        page_index = pages_offset + page
                        data = self._readPage(page_index)
                        if data is not None:
                            file.write(data)
                            continue

                        # Invalid data
                        logError('Could not read page 0x%08x' % page_index)
                        return True

                    bar.show(page_count)

            logOk('Done')
            return True
        except IOError:
            logError('Could not write to file \'%s\'' % filename)
            return True

    def verifyWithFile(self, filename, flash_offset=0, file_offset=0, length=DEFAULT_FLASH_SIZE):
        """Verify the flash content by checking against the file

        This method only uses checksums to verify the data integrity.
        """
        if length % self.page_size != 0:
            logError('length must be a multiple of the page size %d' % self.page_size)
            return False

        if flash_offset % self.page_size != 0:
            logError('flash_offset must be a multiple of the page size %d' % self.page_size)
            return False

        page_count = length // self.page_size
        pages_offset = flash_offset // self.page_size

        try:
            with open(filename, 'rb') as file:
                file.seek(file_offset)

                with progress.Bar(expected_size=page_count) as bar:
                    for page in range(page_count):
                        bar.show(page)

                        data = file.read(self.page_size)

                        page_index = pages_offset + page
                        crc = self._loadPageMultiple(page_index)
                        if crc is None:
                            logError('Could not read page 0x%08x' % page_index)
                            return True

                        if crc == binascii.crc32(data):
                            logOk('Page 0x%08x OK' % page_index)
                        else:
                            logError('Page 0x%08x invalid' % page_index)

                    bar.show(page_count)

            logOk('Done')
            return True
        except IOError:
            logError('Could not write to file \'%s\'' % filename)
            return True

    def erase(self, flash_offset=0, length=DEFAULT_FLASH_SIZE):
        """Write the data in the file to the flash"""
        if length % self.sector_size != 0:
            logError('length must be a multiple of the sector size %d' % self.sector_size)
            return False

        if flash_offset % self.sector_size != 0:
            logError('flash_offset must be a multiple of the sector size %d' % self.sector_size)
            return False

        if not self._eraseSectors(flash_offset, length):
            logError('Aborting')
        else:
            logOk('Done')

        return True

    def set_write_protection(self, enable=False):
        """Set or clear the write protection of the flash"""
        self._debug('Command: WITE_PROTECTION %s' % enable)

        # Write command
        if enable:  # Enable
            self._sendCommand(COMMAND_WRITE_PROTECTION_ENABLE)
            if not self._waitForMessage(COMMAND_WRITE_PROTECTION_ENABLE):
                self._debug('Invalid / no response for WITE_PROTECTION command')
                logError('Invalid response')
                return True
        else:  # Disable
            self._sendCommand(COMMAND_WRITE_PROTECTION_DISABLE)
            if not self._waitForMessage(COMMAND_WRITE_PROTECTION_DISABLE):
                self._debug('Invalid / no response for WITE_PROTECTION command')
                logError('Invalid response')
                return True

        logOk('Done')
        return True

    def check_write_protection(self):
        """Check the write protection of the flash"""
        self._debug('Command: WITE_PROTECTION_CHECK')

        self._sendCommand(COMMAND_WRITE_PROTECTION_CHECK)
        if not self._waitForMessage(COMMAND_WRITE_PROTECTION_CHECK):
            self._debug('Invalid / no response for WRITE_PROTECTION_CHECK command')
            logError('Invalid response')
            return True

        protection = self._readExactly(4).decode(ENCODING)
        if protection is None:
            self._debug('Invalid / no response for protection check')
            logError('Invalid response')
            return True

        try:
            configuration_protection = int(protection[0:2], 16)
            write_protection = int(protection[2:4], 16)

            if configuration_protection == WRITE_PROTECTION_CONFIGURATION_NONE:
                logMessage('Configuration is unprotected')
            elif configuration_protection == WRITE_PROTECTION_CONFIGURATION_PARTIAL:
                logMessage('Configuration is partially protected')
            elif configuration_protection == WRITE_PROTECTION_CONFIGURATION_FULL:
                logMessage('Configuration is fully protected')
            elif configuration_protection == WRITE_PROTECTION_CONFIGURATION_UNKNOWN:
                logMessage('Configuration protection is unknown')
            else:
                logError('Unknown configuration protection status')

            if write_protection == WRITE_PROTECTION_NONE:
                logMessage('Flash content is unprotected')
            elif write_protection == WRITE_PROTECTION_PARTIAL:
                logMessage('Flash content is partially protected')
            elif write_protection == WRITE_PROTECTION_FULL:
                logMessage('Flash content is fully protected')
            elif write_protection == WRITE_PROTECTION_UNKNOWN:
                logMessage('Flash content protection is UNKNOWN')
            else:
                logError('Unknown flash protection status')

        except ValueError:
            self._debug('Could not decode protection status')
            logError('Invalid protection status')
            return True

        logOk('Done')
        return True

    def read_status_register(self):
        """Reads the status register contents"""
        self._debug('Command: STATUS_REGISTER')
        data = self._read_register(COMMAND_STATUS_REGISTER_READ, 'STATUS_REGISTER')
        if data==None:
            return True

        self._dump(data)
        return True

    def read_id_register(self):
        """Reads the id register contents"""
        self._debug('Command: ID_REGISTER')
        data = self._read_register(COMMAND_ID_REGISTER_READ, 'ID_REGISTER')
        if data==None:
            return True

        self._dump(data)
        return True

    def set_cs_io(self, io):
        """Overrides the CS/SS IO of Arduino"""
        self._debug('Command: SET_CS_IO')

        self._sendCommand('%s%02x' % (COMMAND_SET_CS_IO, io))
        if not self._waitForMessage(COMMAND_SET_CS_IO):
          self._debug('Invalid / no response for SET_CS_IO command')
          logError('Invalid response')
          return True

        return True

    def set_output(self, io, value):
        """Set IO pin to OUTPUT"""
        self._debug('Command: SET_OUTPUT')

        if value==None:
          value=0x00
        else:
          value=value&0xf
          if value&0xf!=0x0:
            value=0x1
          value=value|0x10

        self._sendCommand('%s%02x%02x' % (COMMAND_SET_OUTPUT, io, value))
        if not self._waitForMessage(COMMAND_SET_OUTPUT):
          self._debug('Invalid / no response for SET_OUTPUT command')
          logError('Invalid response')
          return True

        return True


def printComPorts():
    logMessage('Available COM ports:')

    for i, port in enumerate(list_ports.comports()):
        logMessage('%d: %s' % (i+1, port.device))

    logOk('Done')


def main():
    def hex_dec(x):
        # use auto detect mode, supports 0bYYYY=binary, 0xYYYY=hex, YYYY=decimal
        return int(x,0)

    parser = argparse.ArgumentParser(description='Interface with an Arduino-based SPI flash programmer')
    parser.add_argument('-d', dest='device', default='COM1',
                        help='serial port to communicate with')
    parser.add_argument('-f', dest='filename', default='flash.bin',
                        help='file to read from / write to')
    parser.add_argument('-l', type=hex_dec, dest='length', default=DEFAULT_FLASH_SIZE,
                        help='length to read/write in bytes, use -1 to write entire file')

    parser.add_argument('--rate', type=int, dest='baud_rate', default=115200,
                        help='baud-rate of serial connection')
    parser.add_argument('--flash-offset', type=hex_dec, dest='flash_offset', default=0,
                        help='offset for flash read/write in bytes')
    parser.add_argument('--file-offset', type=hex_dec, dest='file_offset', default=0,
                        help='offset for file read/write in bytes')
    parser.add_argument('--pad', type=hex_dec, default=None,
                        help='pad value if file is not algined with SECTOR_SIZE')
    parser.add_argument('--debug', choices=('off', 'normal', 'verbose'), default='off',
                        help='enable debug output')
    parser.add_argument('--io', type=hex_dec, default=None,
                        help="IO pin used for set-cs-io and set-output")
    parser.add_argument('--value', type=hex_dec, default=None,
                        help="value used for set-output")

    parser.add_argument('command', choices=('ports', 'write', 'read', 'verify', 'erase',
                                            'enable-protection', 'disable-protection', 'check-protection',
                                            'status-register', 'id-register', 'set-cs-io', 'set-output'),
                        help='command to execute')

    args = parser.parse_args()
    if args.command == 'ports':
        printComPorts()
        return

    try:
        programmer = SerialProgrammer(args.device, args.baud_rate, args.debug)
    except serial.SerialException:
        logError('Could not connect to serial port %s' % args.device)
        return

    def write(args, prog):
        return prog.writeFromFile(args.filename, args.flash_offset, args.file_offset, args.length, args.pad)

    def read(args, prog):
        return prog.readToFile(args.filename, args.flash_offset, args.length)

    def verify(args, prog):
        return prog.verifyWithFile(args.filename, args.flash_offset, args.file_offset, args.length)

    def erase(args, prog):
        return prog.erase(args.flash_offset, args.length)

    def enable_protection(args, prog):
        return prog.set_write_protection(True)

    def disable_protection(args, prog):
        return prog.set_write_protection(False)

    def check_protection(args, prog):
        return prog.check_write_protection()

    def read_status_register(args, prog):
        return prog.read_status_register()

    def read_id_register(args, prog):
        return prog.read_id_register()

    def set_cs_io(args, prog):
        return prog.set_cs_io(args.io)

    def set_output(args, prog):
        return prog.set_output(args.io, args.value)

    commands = {
            'write': write,
            'read': read,
            'verify': verify,
            'erase': erase,
            'enable-protection': enable_protection,
            'disable-protection': disable_protection,
            'check-protection': check_protection,
            'status-register': read_status_register,
            'id-register': read_id_register,
            'set-cs-io': set_cs_io,
            'set-output': set_output
        }

    if args.command not in commands:
        logError('Invalid command \'%d\'' % args.command)
        parser.print_help()
        return

    if not programmer.hello():
        # Unrecognized device
        parser.print_help()
        return

    if not commands[args.command](args, programmer):
        # Command got invalid arguments
        parser.print_help()


if __name__ == '__main__':
    main()