import argparse
import ast
import codecs
import collections
import contextlib
import keyword
import re
import string
import sys
import tokenize
import warnings
from typing import Any
from typing import cast
from typing import Container
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Match
from typing import NamedTuple
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from tokenize_rt import NON_CODING_TOKENS
from tokenize_rt import Offset
from tokenize_rt import parse_string_literal
from tokenize_rt import reversed_enumerate
from tokenize_rt import rfind_string_parts
from tokenize_rt import src_to_tokens
from tokenize_rt import Token
from tokenize_rt import tokens_to_src
from tokenize_rt import UNIMPORTANT_WS

MinVersion = Tuple[int, ...]
DotFormatPart = Tuple[str, Optional[str], Optional[str], Optional[str]]
PercentFormatPart = Tuple[
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[str],
    str,
]
PercentFormat = Tuple[str, Optional[PercentFormatPart]]

ListCompOrGeneratorExp = Union[ast.ListComp, ast.GeneratorExp]
ListOrTuple = Union[ast.List, ast.Tuple]
NameOrAttr = Union[ast.Name, ast.Attribute]
AnyFunctionDef = Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda]
SyncFunctionDef = Union[ast.FunctionDef, ast.Lambda]

_stdlib_parse_format = string.Formatter().parse

_KEYWORDS = frozenset(keyword.kwlist)


def parse_format(s: str) -> Tuple[DotFormatPart, ...]:
    """Makes the empty string not a special case.  In the stdlib, there's
    loss of information (the type) on the empty string.
    """
    parsed = tuple(_stdlib_parse_format(s))
    if not parsed:
        return ((s, None, None, None),)
    else:
        return parsed


def unparse_parsed_string(parsed: Sequence[DotFormatPart]) -> str:
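    # inverse of parse_format: re-escapes literal braces and rebuilds each
    # replacement field, e.g. ('', 'x', '>4', 'r') -> '{x!r:>4}'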
    def _convert_tup(tup: DotFormatPart) -> str:
        ret, field_name, format_spec, conversion = tup
        ret = ret.replace('{', '{{')
        ret = ret.replace('}', '}}')
        if field_name is not None:
            ret += '{' + field_name
            if conversion:
                ret += '!' + conversion
            if format_spec:
                ret += ':' + format_spec
            ret += '}'
        return ret

    return ''.join(_convert_tup(tup) for tup in parsed)


def _ast_to_offset(node: Union[ast.expr, ast.stmt]) -> Offset:
    return Offset(node.lineno, node.col_offset)


def ast_parse(contents_text: str) -> ast.Module:
    # intentionally ignore warnings, we might be fixing warning-ridden syntax
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        return ast.parse(contents_text.encode())


def inty(s: str) -> bool:
    try:
        int(s)
        return True
    except (ValueError, TypeError):
        return False


BRACES = {'(': ')', '[': ']', '{': '}'}
OPENING, CLOSING = frozenset(BRACES), frozenset(BRACES.values())
SET_TRANSFORM = (ast.List, ast.ListComp, ast.GeneratorExp, ast.Tuple)


def _is_wtf(func: str, tokens: List[Token], i: int) -> bool:
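    # sanity check: bail unless tokens[i] is the bare `func` name
    # immediately followed by an open paren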
    return tokens[i].src != func or tokens[i + 1].src != '('


def _process_set_empty_literal(tokens: List[Token], start: int) -> None:
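    # rewrites e.g. `set(())` or `set([])` -> `set()`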
    if _is_wtf('set', tokens, start):
        return

    i = start + 2
    brace_stack = ['(']
    while brace_stack:
        token = tokens[i].src
        if token == BRACES[brace_stack[-1]]:
            brace_stack.pop()
        elif token in BRACES:
            brace_stack.append(token)
        elif '\n' in token:
            # Contains a newline, could cause a SyntaxError, bail
            return
        i += 1

    # Remove the inner tokens
    del tokens[start + 2:i - 1]


def _search_until(tokens: List[Token], idx: int, arg: ast.expr) -> int:
    while (
            idx < len(tokens) and
            not (
                tokens[idx].line == arg.lineno and
                tokens[idx].utf8_byte_offset == arg.col_offset
            )
    ):
        idx += 1
    return idx


if sys.version_info >= (3, 8):  # pragma: no cover (py38+)
    # python 3.8 fixed the offsets of generators / tuples
    def _arg_token_index(tokens: List[Token], i: int, arg: ast.expr) -> int:
        idx = _search_until(tokens, i, arg) + 1
        while idx < len(tokens) and tokens[idx].name in NON_CODING_TOKENS:
            idx += 1
        return idx
else:  # pragma: no cover (<py38)
    def _arg_token_index(tokens: List[Token], i: int, arg: ast.expr) -> int:
        # lists containing non-tuples report the first element correctly
        if isinstance(arg, ast.List):
            # If the first element is a tuple, the ast lies to us about its col
            # offset.  We must find the first `(` token after the start of the
            # list element.
            if isinstance(arg.elts[0], ast.Tuple):
                i = _search_until(tokens, i, arg)
                return _find_open_paren(tokens, i)
            else:
                return _search_until(tokens, i, arg.elts[0])
            # others' start position points at their first child node already
        else:
            return _search_until(tokens, i, arg)


class Victims(NamedTuple):
    starts: List[int]
    ends: List[int]
    first_comma_index: Optional[int]
    arg_index: int
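    # roughly: starts/ends are indices of the brace (and adjacent
    # comma/space) tokens to remove when rewriting; arg_index points at the
    # first token of the argument itself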


def _victims(
        tokens: List[Token],
        start: int,
        arg: ast.expr,
        gen: bool,
) -> Victims:
    starts = [start]
    start_depths = [1]
    ends: List[int] = []
    first_comma_index = None
    arg_depth = None
    arg_index = _arg_token_index(tokens, start, arg)
    brace_stack = [tokens[start].src]
    i = start + 1

    while brace_stack:
        token = tokens[i].src
        is_start_brace = token in BRACES
        is_end_brace = token == BRACES[brace_stack[-1]]

        if i == arg_index:
            arg_depth = len(brace_stack)

        if is_start_brace:
            brace_stack.append(token)

        # Remove all braces before the first element of the inner
        # comprehension's target.
        if is_start_brace and arg_depth is None:
            start_depths.append(len(brace_stack))
            starts.append(i)

        if (
                token == ',' and
                len(brace_stack) == arg_depth and
                first_comma_index is None
        ):
            first_comma_index = i

        if is_end_brace and len(brace_stack) in start_depths:
            if tokens[i - 2].src == ',' and tokens[i - 1].src == ' ':
                ends.extend((i - 2, i - 1, i))
            elif tokens[i - 1].src == ',':
                ends.extend((i - 1, i))
            else:
                ends.append(i)
            if len(brace_stack) > 1 and tokens[i + 1].src == ',':
                ends.append(i + 1)

        if is_end_brace:
            brace_stack.pop()

        i += 1

    # May need to remove a trailing comma for a comprehension
    if gen:
        i -= 2
        while tokens[i].name in NON_CODING_TOKENS:
            i -= 1
        if tokens[i].src == ',':
            ends.append(i)

    return Victims(starts, sorted(set(ends)), first_comma_index, arg_index)


def _find_token(tokens: List[Token], i: int, src: str) -> int:
    while tokens[i].src != src:
        i += 1
    return i


def _find_open_paren(tokens: List[Token], i: int) -> int:
    return _find_token(tokens, i, '(')


def _is_on_a_line_by_self(tokens: List[Token], i: int) -> bool:
    return (
        tokens[i - 2].name == 'NL' and
        tokens[i - 1].name == UNIMPORTANT_WS and
        tokens[i + 1].name == 'NL'
    )


def _remove_brace(tokens: List[Token], i: int) -> None:
    if _is_on_a_line_by_self(tokens, i):
        del tokens[i - 1:i + 2]
    else:
        del tokens[i]


def _process_set_literal(
        tokens: List[Token],
        start: int,
        arg: ast.expr,
) -> None:
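    # rewrites e.g. `set((1, 2))` -> `{1, 2}` and
    # `set(x for x in y)` -> `{x for x in y}`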
    if _is_wtf('set', tokens, start):
        return

    gen = isinstance(arg, ast.GeneratorExp)
    set_victims = _victims(tokens, start + 1, arg, gen=gen)

    del set_victims.starts[0]
    end_index = set_victims.ends.pop()

    tokens[end_index] = Token('OP', '}')
    for index in reversed(set_victims.starts + set_victims.ends):
        _remove_brace(tokens, index)
    tokens[start:start + 2] = [Token('OP', '{')]


def _process_dict_comp(
        tokens: List[Token],
        start: int,
        arg: ListCompOrGeneratorExp,
) -> None:
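    # rewrites e.g. `dict((a, b) for a, b in y)` -> `{a: b for a, b in y}`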
    if _is_wtf('dict', tokens, start):
        return

    dict_victims = _victims(tokens, start + 1, arg, gen=True)
    elt_victims = _victims(tokens, dict_victims.arg_index, arg.elt, gen=True)

    del dict_victims.starts[0]
    end_index = dict_victims.ends.pop()

    tokens[end_index] = Token('OP', '}')
    for index in reversed(dict_victims.ends):
        _remove_brace(tokens, index)
    # See #6, Fix SyntaxError from rewriting dict((a, b)for a, b in y)
    if tokens[elt_victims.ends[-1] + 1].src == 'for':
        tokens.insert(elt_victims.ends[-1] + 1, Token(UNIMPORTANT_WS, ' '))
    for index in reversed(elt_victims.ends):
        _remove_brace(tokens, index)
    assert elt_victims.first_comma_index is not None
    tokens[elt_victims.first_comma_index] = Token('OP', ':')
    for index in reversed(dict_victims.starts + elt_victims.starts):
        _remove_brace(tokens, index)
    tokens[start:start + 2] = [Token('OP', '{')]


def _process_is_literal(
        tokens: List[Token],
        i: int,
        compare: Union[ast.Is, ast.IsNot],
) -> None:
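    # rewrites e.g. `x is 5` -> `x == 5` and `x is not 'foo'` -> `x != 'foo'`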
    while tokens[i].src != 'is':
        i -= 1
    if isinstance(compare, ast.Is):
        tokens[i] = tokens[i]._replace(src='==')
    else:
        tokens[i] = tokens[i]._replace(src='!=')
        # blank these out rather than deleting them so the token list keeps
        # its length (the caller is iterating over it in reverse)
        i += 1
        while tokens[i].src != 'not':
            tokens[i] = Token('DUMMY', '')
            i += 1
        tokens[i] = Token('DUMMY', '')


LITERAL_TYPES = (ast.Str, ast.Num, ast.Bytes)


class Py2CompatibleVisitor(ast.NodeVisitor):
    def __init__(self) -> None:
        self.dicts: Dict[Offset, ListCompOrGeneratorExp] = {}
        self.sets: Dict[Offset, ast.expr] = {}
        self.set_empty_literals: Dict[Offset, ListOrTuple] = {}
        self.is_literal: Dict[Offset, Union[ast.Is, ast.IsNot]] = {}

    def visit_Call(self, node: ast.Call) -> None:
        if (
                isinstance(node.func, ast.Name) and
                node.func.id == 'set' and
                len(node.args) == 1 and
                not node.keywords and
                isinstance(node.args[0], SET_TRANSFORM)
        ):
            arg, = node.args
            key = _ast_to_offset(node.func)
            if isinstance(arg, (ast.List, ast.Tuple)) and not arg.elts:
                self.set_empty_literals[key] = arg
            else:
                self.sets[key] = arg
        elif (
                isinstance(node.func, ast.Name) and
                node.func.id == 'dict' and
                len(node.args) == 1 and
                not node.keywords and
                isinstance(node.args[0], (ast.ListComp, ast.GeneratorExp)) and
                isinstance(node.args[0].elt, (ast.Tuple, ast.List)) and
                len(node.args[0].elt.elts) == 2
        ):
            self.dicts[_ast_to_offset(node.func)] = node.args[0]
        self.generic_visit(node)

    def visit_Compare(self, node: ast.Compare) -> None:
        left = node.left
        for op, right in zip(node.ops, node.comparators):
            if (
                    isinstance(op, (ast.Is, ast.IsNot)) and
                    (
                        isinstance(left, LITERAL_TYPES) or
                        isinstance(right, LITERAL_TYPES)
                    )
            ):
                self.is_literal[_ast_to_offset(right)] = op
            left = right

        self.generic_visit(node)


def _fix_py2_compatible(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text
    visitor = Py2CompatibleVisitor()
    visitor.visit(ast_obj)
    if not any((
            visitor.dicts,
            visitor.sets,
            visitor.set_empty_literals,
            visitor.is_literal,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.dicts:
            _process_dict_comp(tokens, i, visitor.dicts[token.offset])
        elif token.offset in visitor.set_empty_literals:
            _process_set_empty_literal(tokens, i)
        elif token.offset in visitor.sets:
            _process_set_literal(tokens, i, visitor.sets[token.offset])
        elif token.offset in visitor.is_literal:
            _process_is_literal(tokens, i, visitor.is_literal[token.offset])
    return tokens_to_src(tokens)


def _imports_unicode_literals(contents_text: str) -> bool:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return False

    for node in ast_obj.body:
        # Docstring
        if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str):
            continue
        elif isinstance(node, ast.ImportFrom):
            if (
                node.level == 0 and
                node.module == '__future__' and
                any(name.name == 'unicode_literals' for name in node.names)
            ):
                return True
            elif node.module == '__future__':
                continue
            else:
                return False
        else:
            return False

    return False


# https://docs.python.org/3/reference/lexical_analysis.html
ESCAPE_STARTS = frozenset((
    '\n', '\r', '\\', "'", '"', 'a', 'b', 'f', 'n', 'r', 't', 'v',
    '0', '1', '2', '3', '4', '5', '6', '7',  # octal escapes
    'x',  # hex escapes
))
ESCAPE_RE = re.compile(r'\\.', re.DOTALL)
NAMED_ESCAPE_NAME = re.compile(r'\{[^}]+\}')


def _fix_escape_sequences(token: Token) -> Token:
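    # e.g. '\d' (only invalid escapes) -> r'\d', while a mixed literal
    # such as '\n\d' -> '\n\\d'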
    prefix, rest = parse_string_literal(token.src)
    actual_prefix = prefix.lower()

    if 'r' in actual_prefix or '\\' not in rest:
        return token

    is_bytestring = 'b' in actual_prefix

    def _is_valid_escape(match: Match[str]) -> bool:
        c = match.group()[1]
        return (
            c in ESCAPE_STARTS or
            (not is_bytestring and c in 'uU') or
            (
                not is_bytestring and
                c == 'N' and
                bool(NAMED_ESCAPE_NAME.match(rest, match.end()))
            )
        )

    has_valid_escapes = False
    has_invalid_escapes = False
    for match in ESCAPE_RE.finditer(rest):
        if _is_valid_escape(match):
            has_valid_escapes = True
        else:
            has_invalid_escapes = True

    def cb(match: Match[str]) -> str:
        matched = match.group()
        if _is_valid_escape(match):
            return matched
        else:
            return fr'\{matched}'

    if has_invalid_escapes and (has_valid_escapes or 'u' in actual_prefix):
        return token._replace(src=prefix + ESCAPE_RE.sub(cb, rest))
    elif has_invalid_escapes and not has_valid_escapes:
        return token._replace(src=prefix + 'r' + rest)
    else:
        return token


def _remove_u_prefix(token: Token) -> Token:
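    # e.g. u'foo' -> 'foo'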
    prefix, rest = parse_string_literal(token.src)
    if 'u' not in prefix.lower():
        return token
    else:
        new_prefix = prefix.replace('u', '').replace('U', '')
        return token._replace(src=new_prefix + rest)


def _fix_ur_literals(token: Token) -> Token:
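    # ur'...' is a syntax error in python3: drop the raw prefix and
    # backslash-escape what it used to protect, e.g. ur'\d' -> u'\\d'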
    prefix, rest = parse_string_literal(token.src)
    if prefix.lower() != 'ur':
        return token
    else:
        def cb(match: Match[str]) -> str:
            escape = match.group()
            if escape[1].lower() == 'u':
                return escape
            else:
                return '\\' + match.group()

        rest = ESCAPE_RE.sub(cb, rest)
        prefix = prefix.replace('r', '').replace('R', '')
        return token._replace(src=prefix + rest)


def _fix_long(src: str) -> str:
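    # e.g. '123L' -> '123'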
    return src.rstrip('lL')


def _fix_octal(s: str) -> str:
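    # e.g. '0755' -> '0o755' and '07' -> '7'; '0' and '00' are left alone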
    if not s.startswith('0') or not s.isdigit() or s == len(s) * '0':
        return s
    elif len(s) == 2:
        return s[1:]
    else:
        return '0o' + s[1:]


def _fix_extraneous_parens(tokens: List[Token], i: int) -> None:
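    # removes a redundant inner paren pair, e.g. `print(('foo'))` ->
    # `print('foo')`, while leaving tuples and `yield` expressions alone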
    # search forward for another non-coding token
    i += 1
    while tokens[i].name in NON_CODING_TOKENS:
        i += 1
    # if we did not find another brace, return immediately
    if tokens[i].src != '(':
        return

    start = i
    depth = 1
    while depth:
        i += 1
        # found comma or yield at depth 1: this is a tuple / coroutine
        if depth == 1 and tokens[i].src in {',', 'yield'}:
            return
        elif tokens[i].src in OPENING:
            depth += 1
        elif tokens[i].src in CLOSING:
            depth -= 1
    end = i

    # empty tuple
    if all(t.name in NON_CODING_TOKENS for t in tokens[start + 1:i]):
        return

    # search forward for the next non-coding token
    i += 1
    while tokens[i].name in NON_CODING_TOKENS:
        i += 1

    if tokens[i].src == ')':
        _remove_brace(tokens, end)
        _remove_brace(tokens, start)


def _remove_fmt(tup: DotFormatPart) -> DotFormatPart:
    if tup[1] is None:
        return tup
    else:
        return (tup[0], '', tup[2], tup[3])


def _fix_format_literal(tokens: List[Token], end: int) -> None:
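    # rewrites sequential numeric indices, e.g. `'{0} {1}'.format(x, y)`
    # -> `'{} {}'.format(x, y)`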
    parts = rfind_string_parts(tokens, end)
    parsed_parts = []
    last_int = -1
    for i in parts:
        # an f-string part would become a SyntaxError when rewritten, e.g.
        # '{0}' f'{1}'.format(x, y) -- skip the whole literal
        prefix, _ = parse_string_literal(tokens[i].src)
        if 'f' in prefix.lower():
            return

        try:
            parsed = parse_format(tokens[i].src)
        except ValueError:
            # the format literal was malformed, skip it
            return

        # The last segment will always be the end of the string and not a
        # format, slice avoids the `None` format key
        for _, fmtkey, spec, _ in parsed[:-1]:
            if (
                    fmtkey is not None and inty(fmtkey) and
                    int(fmtkey) == last_int + 1 and
                    spec is not None and '{' not in spec
            ):
                last_int += 1
            else:
                return

        parsed_parts.append(tuple(_remove_fmt(tup) for tup in parsed))

    for i, parsed in zip(parts, parsed_parts):
        tokens[i] = tokens[i]._replace(src=unparse_parsed_string(parsed))


def _fix_encode_to_binary(tokens: List[Token], i: int) -> None:
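    # rewrites e.g. `'foo'.encode()` or `'foo'.encode('utf-8')` -> `b'foo'`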
    # .encode()
    if (
            i + 2 < len(tokens) and
            tokens[i + 1].src == '(' and
            tokens[i + 2].src == ')'
    ):
        victims = slice(i - 1, i + 3)
        latin1_ok = False
    # .encode('encoding')
    elif (
            i + 3 < len(tokens) and
            tokens[i + 1].src == '(' and
            tokens[i + 2].name == 'STRING' and
            tokens[i + 3].src == ')'
    ):
        victims = slice(i - 1, i + 4)
        prefix, rest = parse_string_literal(tokens[i + 2].src)
        if 'f' in prefix.lower():
            return
        encoding = ast.literal_eval(prefix + rest)
        if _is_codec(encoding, 'ascii') or _is_codec(encoding, 'utf-8'):
            latin1_ok = False
        elif _is_codec(encoding, 'iso8859-1'):
            latin1_ok = True
        else:
            return
    else:
        return

    parts = rfind_string_parts(tokens, i - 2)
    if not parts:
        return

    for part in parts:
        prefix, rest = parse_string_literal(tokens[part].src)
        escapes = set(ESCAPE_RE.findall(rest))
        if (
                not _is_ascii(rest) or
                '\\u' in escapes or
                '\\U' in escapes or
                '\\N' in escapes or
                ('\\x' in escapes and not latin1_ok) or
                'f' in prefix.lower()
        ):
            return

    for part in parts:
        prefix, rest = parse_string_literal(tokens[part].src)
        prefix = 'b' + prefix.replace('u', '').replace('U', '')
        tokens[part] = tokens[part]._replace(src=prefix + rest)
    del tokens[victims]


def _build_import_removals() -> Dict[MinVersion, Dict[str, Tuple[str, ...]]]:
    ret = {}
    future: Tuple[Tuple[MinVersion, Tuple[str, ...]], ...] = (
        ((2, 7), ('nested_scopes', 'generators', 'with_statement')),
        (
            (3,), (
                'absolute_import', 'division', 'print_function',
                'unicode_literals',
            ),
        ),
        ((3, 6), ()),
        ((3, 7), ('generator_stop',)),
        ((3, 8), ()),
    )

    prev: Tuple[str, ...] = ()
    for min_version, names in future:
        prev += names
        ret[min_version] = {'__future__': prev}
    # see reorder_python_imports
    for k, v in ret.items():
        if k >= (3,):
            v.update({
                'builtins': (
                    'ascii', 'bytes', 'chr', 'dict', 'filter', 'hex', 'input',
                    'int', 'list', 'map', 'max', 'min', 'next', 'object',
                    'oct', 'open', 'pow', 'range', 'round', 'str', 'super',
                    'zip', '*',
                ),
                'io': ('open',),
                'six': ('callable', 'next'),
                'six.moves': ('filter', 'input', 'map', 'range', 'zip'),
            })
    return ret


IMPORT_REMOVALS = _build_import_removals()


def _fix_import_removals(
        tokens: List[Token],
        start: int,
        min_version: MinVersion,
) -> None:
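    # removes names that are no-ops at min_version, e.g. with (3,):
    # `from __future__ import generators, annotations`
    # -> `from __future__ import annotations`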
    i = start + 1
    name_parts = []
    while tokens[i].src != 'import':
        if tokens[i].name in {'NAME', 'OP'}:
            name_parts.append(tokens[i].src)
        i += 1

    modname = ''.join(name_parts)
    if modname not in IMPORT_REMOVALS[min_version]:
        return

    found: List[Optional[int]] = []
    i += 1
    while tokens[i].name not in {'NEWLINE', 'ENDMARKER'}:
        if tokens[i].name == 'NAME' or tokens[i].src == '*':
            # don't touch aliases
            if (
                    found and found[-1] is not None and
                    tokens[found[-1]].src == 'as'
            ):
                found[-2:] = [None]
            else:
                found.append(i)
        i += 1
    # depending on the version of python, some will not emit NEWLINE('') at the
    # end of a file which does not end with a newline (for example 3.6.5)
    if tokens[i].name == 'ENDMARKER':  # pragma: no cover
        i -= 1

    remove_names = IMPORT_REMOVALS[min_version][modname]
    to_remove = [
        x for x in found if x is not None and tokens[x].src in remove_names
    ]
    if len(to_remove) == len(found):
        del tokens[start:i + 1]
    else:
        for idx in reversed(to_remove):
            if found[0] == idx:  # look forward until next name and del
                j = idx + 1
                while tokens[j].name != 'NAME':
                    j += 1
                del tokens[idx:j]
            else:  # look backward for comma and del
                j = idx
                while tokens[j].src != ',':
                    j -= 1
                del tokens[j:idx + 1]


def _fix_tokens(contents_text: str, min_version: MinVersion) -> str:
    remove_u = min_version >= (3,) or _imports_unicode_literals(contents_text)

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.name == 'NUMBER':
            tokens[i] = token._replace(src=_fix_long(_fix_octal(token.src)))
        elif token.name == 'STRING':
            tokens[i] = _fix_ur_literals(tokens[i])
            if remove_u:
                tokens[i] = _remove_u_prefix(tokens[i])
            tokens[i] = _fix_escape_sequences(tokens[i])
        elif token.src == '(':
            _fix_extraneous_parens(tokens, i)
        elif token.src == 'format' and i > 0 and tokens[i - 1].src == '.':
            _fix_format_literal(tokens, i - 2)
        elif token.src == 'encode' and i > 0 and tokens[i - 1].src == '.':
            _fix_encode_to_binary(tokens, i)
        elif (
                min_version >= (3,) and
                token.utf8_byte_offset == 0 and
                token.line < 3 and
                token.name == 'COMMENT' and
                tokenize.cookie_re.match(token.src)
        ):
            del tokens[i]
            assert tokens[i].name == 'NL', tokens[i].name
            del tokens[i]
        elif token.src == 'from' and token.utf8_byte_offset == 0:
            _fix_import_removals(tokens, i, min_version)
    return tokens_to_src(tokens).lstrip()


MAPPING_KEY_RE = re.compile(r'\(([^()]*)\)')
CONVERSION_FLAG_RE = re.compile('[#0+ -]*')
WIDTH_RE = re.compile(r'(?:\*|\d*)')
PRECISION_RE = re.compile(r'(?:\.(?:\*|\d*))?')
LENGTH_RE = re.compile('[hlL]?')


def _must_match(regex: Pattern[str], string: str, pos: int) -> Match[str]:
    match = regex.match(string, pos)
    assert match is not None
    return match


def parse_percent_format(s: str) -> Tuple[PercentFormat, ...]:
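    # e.g. parse_percent_format('%(name)s')
    #      == (('', ('name', None, None, None, 's')),)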
    def _parse_inner() -> Generator[PercentFormat, None, None]:
        string_start = 0
        string_end = 0
        in_fmt = False

        i = 0
        while i < len(s):
            if not in_fmt:
                try:
                    i = s.index('%', i)
                except ValueError:  # no more % fields!
                    yield s[string_start:], None
                    return
                else:
                    string_end = i
                    i += 1
                    in_fmt = True
            else:
                key_match = MAPPING_KEY_RE.match(s, i)
                if key_match:
                    key: Optional[str] = key_match.group(1)
                    i = key_match.end()
                else:
                    key = None

                conversion_flag_match = _must_match(CONVERSION_FLAG_RE, s, i)
                conversion_flag = conversion_flag_match.group() or None
                i = conversion_flag_match.end()

                width_match = _must_match(WIDTH_RE, s, i)
                width = width_match.group() or None
                i = width_match.end()

                precision_match = _must_match(PRECISION_RE, s, i)
                precision = precision_match.group() or None
                i = precision_match.end()

                # length modifier is ignored
                i = _must_match(LENGTH_RE, s, i).end()

                try:
                    conversion = s[i]
                except IndexError:
                    raise ValueError('end-of-string while parsing format')
                i += 1

                fmt = (key, conversion_flag, width, precision, conversion)
                yield s[string_start:string_end], fmt

                in_fmt = False
                string_start = i

        if in_fmt:
            raise ValueError('end-of-string while parsing format')

    return tuple(_parse_inner())


class FindPercentFormats(ast.NodeVisitor):
    def __init__(self) -> None:
        self.found: Dict[Offset, ast.BinOp] = {}

    def visit_BinOp(self, node: ast.BinOp) -> None:
        if isinstance(node.op, ast.Mod) and isinstance(node.left, ast.Str):
            try:
                parsed = parse_percent_format(node.left.s)
            except ValueError:
                pass
            else:
                for _, fmt in parsed:
                    if not fmt:
                        continue
                    key, conversion_flag, width, precision, conversion = fmt
                    # timid: these require out-of-order parameter consumption
                    if width == '*' or precision == '.*':
                        break
                    # these conversions require modification of parameters
                    if conversion in {'d', 'i', 'u', 'c'}:
                        break
                    # timid: py2: %#o formats different from {:#o} (--py3?)
                    if '#' in (conversion_flag or '') and conversion == 'o':
                        break
                    # no equivalent in format
                    if key == '':
                        break
                    # timid: py2: conversion is subject to modifiers (--py3?)
                    nontrivial_fmt = any((conversion_flag, width, precision))
                    if conversion == '%' and nontrivial_fmt:
                        break
                    # no equivalent in format
                    if conversion in {'a', 'r'} and nontrivial_fmt:
                        break
                    # all dict substitutions must be named
                    if isinstance(node.right, ast.Dict) and not key:
                        break
                else:
                    self.found[_ast_to_offset(node)] = node
        self.generic_visit(node)


def _simplify_conversion_flag(flag: str) -> str:
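    # e.g. '0-' -> '<' (left-align makes `0` redundant) and ' +' -> '+'
    # (`+` supersedes ` `); duplicate flags collapse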
    parts: List[str] = []
    for c in flag:
        if c in parts:
            continue
        c = c.replace('-', '<')
        parts.append(c)
        if c == '<' and '0' in parts:
            parts.remove('0')
        elif c == '+' and ' ' in parts:
            parts.remove(' ')
    return ''.join(parts)


def _percent_to_format(s: str) -> str:
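    # e.g. '%s' -> '{}', '%(name)s' -> '{name}', '%5s' -> '{:>5}',
    # '%.2f' -> '{:.2f}', '%%' -> '%'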
    def _handle_part(part: PercentFormat) -> str:
        s, fmt = part
        s = s.replace('{', '{{').replace('}', '}}')
        if fmt is None:
            return s
        else:
            key, conversion_flag, width, precision, conversion = fmt
            if conversion == '%':
                return s + '%'
            parts = [s, '{']
            if width and conversion == 's' and not conversion_flag:
                conversion_flag = '>'
            if conversion == 's':
                conversion = ''
            if key:
                parts.append(key)
            if conversion in {'r', 'a'}:
                converter = f'!{conversion}'
                conversion = ''
            else:
                converter = ''
            if any((conversion_flag, width, precision, conversion)):
                parts.append(':')
            if conversion_flag:
                parts.append(_simplify_conversion_flag(conversion_flag))
            parts.extend(x for x in (width, precision, conversion) if x)
            parts.extend(converter)
            parts.append('}')
            return ''.join(parts)

    return ''.join(_handle_part(part) for part in parse_percent_format(s))


def _is_ascii(s: str) -> bool:
    if sys.version_info >= (3, 7):  # pragma: no cover (py37+)
        return s.isascii()
    else:  # pragma: no cover (<py37)
        return all(c in string.printable for c in s)


def _fix_percent_format_tuple(
        tokens: List[Token],
        start: int,
        node: ast.BinOp,
) -> None:
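    # rewrites e.g. `'%s %s' % (a, b)` -> `'{} {}'.format(a, b)`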
    # TODO: this is overly timid
    paren = start + 4
    if tokens_to_src(tokens[start + 1:paren + 1]) != ' % (':
        return

    victims = _victims(tokens, paren, node.right, gen=False)
    victims.ends.pop()

    for index in reversed(victims.starts + victims.ends):
        _remove_brace(tokens, index)

    newsrc = _percent_to_format(tokens[start].src)
    tokens[start] = tokens[start]._replace(src=newsrc)
    tokens[start + 1:paren] = [Token('Format', '.format'), Token('OP', '(')]


def _fix_percent_format_dict(
        tokens: List[Token],
        start: int,
        node: ast.BinOp,
) -> None:
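    # rewrites e.g. `'%(k)s' % {'k': v}` -> `'{k}'.format(k=v)`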
    seen_keys: Set[str] = set()
    keys = {}

    # the caller has enforced this
    assert isinstance(node.right, ast.Dict)
    for k in node.right.keys:
        # not a string key
        if not isinstance(k, ast.Str):
            return
        # duplicate key
        elif k.s in seen_keys:
            return
        # not an identifier
        elif not k.s.isidentifier():
            return
        # a keyword
        elif k.s in _KEYWORDS:
            return
        seen_keys.add(k.s)
        keys[_ast_to_offset(k)] = k

    # TODO: this is overly timid
    brace = start + 4
    if tokens_to_src(tokens[start + 1:brace + 1]) != ' % {':
        return

    victims = _victims(tokens, brace, node.right, gen=False)
    brace_end = victims.ends[-1]

    key_indices = []
    for i, token in enumerate(tokens[brace:brace_end], brace):
        key = keys.pop(token.offset, None)
        if key is None:
            continue
        # we found the key, but the string didn't match (implicit join?)
        elif ast.literal_eval(token.src) != key.s:
            return
        # the map uses some strange syntax that's not `'key': value`
        elif tokens_to_src(tokens[i + 1:i + 3]) != ': ':
            return
        else:
            key_indices.append((i, key.s))
    assert not keys, keys

    tokens[brace_end] = tokens[brace_end]._replace(src=')')
    for (key_index, s) in reversed(key_indices):
        tokens[key_index:key_index + 3] = [Token('CODE', f'{s}=')]
    newsrc = _percent_to_format(tokens[start].src)
    tokens[start] = tokens[start]._replace(src=newsrc)
    tokens[start + 1:brace + 1] = [Token('CODE', '.format'), Token('OP', '(')]


def _fix_percent_format(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPercentFormats()
    visitor.visit(ast_obj)

    if not visitor.found:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    for i, token in reversed_enumerate(tokens):
        node = visitor.found.get(token.offset)
        if node is None:
            continue

        # TODO: handle \N escape sequences
        if r'\N' in token.src:
            continue

        if isinstance(node.right, ast.Tuple):
            _fix_percent_format_tuple(tokens, i, node)
        elif isinstance(node.right, ast.Dict):
            _fix_percent_format_dict(tokens, i, node)

    return tokens_to_src(tokens)


SIX_SIMPLE_ATTRS = {
    'text_type': 'str',
    'binary_type': 'bytes',
    'class_types': '(type,)',
    'string_types': '(str,)',
    'integer_types': '(int,)',
    'unichr': 'chr',
    'iterbytes': 'iter',
    'print_': 'print',
    'exec_': 'exec',
    'advance_iterator': 'next',
    'next': 'next',
    'callable': 'callable',
}
SIX_TYPE_CTX_ATTRS = {
    'class_types': 'type',
    'string_types': 'str',
    'integer_types': 'int',
}
SIX_CALLS = {
    'u': '{args[0]}',
    'byte2int': '{args[0]}[0]',
    'indexbytes': '{args[0]}[{rest}]',
    'int2byte': 'bytes(({args[0]},))',
    'iteritems': '{args[0]}.items()',
    'iterkeys': '{args[0]}.keys()',
    'itervalues': '{args[0]}.values()',
    'viewitems': '{args[0]}.items()',
    'viewkeys': '{args[0]}.keys()',
    'viewvalues': '{args[0]}.values()',
    'create_unbound_method': '{args[0]}',
    'get_unbound_function': '{args[0]}',
    'get_method_function': '{args[0]}.__func__',
    'get_method_self': '{args[0]}.__self__',
    'get_function_closure': '{args[0]}.__closure__',
    'get_function_code': '{args[0]}.__code__',
    'get_function_defaults': '{args[0]}.__defaults__',
    'get_function_globals': '{args[0]}.__globals__',
    'assertCountEqual': '{args[0]}.assertCountEqual({rest})',
    'assertRaisesRegex': '{args[0]}.assertRaisesRegex({rest})',
    'assertRegex': '{args[0]}.assertRegex({rest})',
}
SIX_B_TMPL = 'b{args[0]}'
WITH_METACLASS_NO_BASES_TMPL = 'metaclass={args[0]}'
WITH_METACLASS_BASES_TMPL = '{rest}, metaclass={args[0]}'
RAISE_FROM_TMPL = 'raise {args[0]} from {rest}'
RERAISE_2_TMPL = 'raise {args[1]}.with_traceback(None)'
RERAISE_3_TMPL = 'raise {args[1]}.with_traceback({args[2]})'
SIX_NATIVE_STR = frozenset(('ensure_str', 'ensure_text', 'text_type'))
U_MODE_REMOVE = frozenset(('U', 'Ur', 'rU', 'r', 'rt', 'tr'))
U_MODE_REPLACE_R = frozenset(('Ub', 'bU'))
U_MODE_REMOVE_U = frozenset(('rUb', 'Urb', 'rbU', 'Ubr', 'bUr', 'brU'))
U_MODE_REPLACE = U_MODE_REPLACE_R | U_MODE_REMOVE_U
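# these sets drive the open() `U`-mode fix (applied elsewhere in this
# file), e.g. open(f, 'U') -> open(f); open(f, 'Ub') -> open(f, 'rb');
# open(f, 'rUb') -> open(f, 'rb')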


def _all_isinstance(
        vals: Iterable[Any],
        tp: Union[Type[Any], Tuple[Type[Any], ...]],
) -> bool:
    return all(isinstance(v, tp) for v in vals)


def fields_same(n1: ast.AST, n2: ast.AST) -> bool:
    for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
        # ignore ast attributes, they'll be covered by walk
        if a1 != a2:
            return False
        elif _all_isinstance((v1, v2), ast.AST):
            continue
        elif _all_isinstance((v1, v2), (list, tuple)):
            if len(v1) != len(v2):
                return False
            # ignore sequences which are all-ast, they'll be covered by walk
            elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
                continue
            elif v1 != v2:
                return False
        elif v1 != v2:
            return False
    return True


def targets_same(target: ast.AST, yield_value: ast.AST) -> bool:
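    # e.g. true for the `x, y` in `for x, y in z: yield x, y`; used to
    # detect loops rewritable as `yield from z`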
    for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)):
        # ignore `ast.Load` / `ast.Store`
        if _all_isinstance((t1, t2), ast.expr_context):
            continue
        elif type(t1) != type(t2):
            return False
        elif not fields_same(t1, t2):
            return False
    else:
        return True


def _is_codec(encoding: str, name: str) -> bool:
    try:
        return codecs.lookup(encoding).name == name
    except LookupError:
        return False


class FindPy3Plus(ast.NodeVisitor):
    OS_ERROR_ALIASES = frozenset((
        'EnvironmentError',
        'IOError',
        'WindowsError',
    ))

    OS_ERROR_ALIAS_MODULES = frozenset((
        'mmap',
        'select',
        'socket',
    ))

    FROM_IMPORTED_MODULES = OS_ERROR_ALIAS_MODULES.union(('functools', 'six'))

    MOCK_MODULES = frozenset(('mock', 'mock.mock'))

    class ClassInfo:
        def __init__(self, name: str) -> None:
            self.name = name
            self.def_depth = 0
            self.first_arg_name = ''

    class Scope:
        def __init__(self) -> None:
            self.reads: Set[str] = set()
            self.writes: Set[str] = set()

            self.yield_from_fors: Set[Offset] = set()
            self.yield_from_names: Dict[str, Set[Offset]]
            self.yield_from_names = collections.defaultdict(set)

    def __init__(self, keep_mock: bool) -> None:
        self._find_mock = not keep_mock

        self.bases_to_remove: Set[Offset] = set()

        self.encode_calls: Dict[Offset, ast.Call] = {}

        self._version_info_imported = False
        self.if_py3_blocks: Set[Offset] = set()
        self.if_py2_blocks_else: Set[Offset] = set()
        self.if_py3_blocks_else: Set[Offset] = set()
        self.metaclass_type_assignments: Set[Offset] = set()

        self.native_literals: Set[Offset] = set()

        self._from_imports: Dict[str, Set[str]] = collections.defaultdict(set)
        self.io_open_calls: Set[Offset] = set()
        self.mock_mock: Set[Offset] = set()
        self.mock_absolute_imports: Set[Offset] = set()
        self.mock_relative_imports: Set[Offset] = set()
        self.open_mode_calls: Set[Offset] = set()
        self.os_error_alias_calls: Set[Offset] = set()
        self.os_error_alias_simple: Dict[Offset, NameOrAttr] = {}
        self.os_error_alias_excepts: Set[Offset] = set()

        self.six_add_metaclass: Set[Offset] = set()
        self.six_b: Set[Offset] = set()
        self.six_calls: Dict[Offset, ast.Call] = {}
        self.six_iter: Dict[Offset, ast.Call] = {}
        self._previous_node: Optional[ast.AST] = None
        self.six_raise_from: Set[Offset] = set()
        self.six_reraise: Set[Offset] = set()
        self.six_remove_decorators: Set[Offset] = set()
        self.six_simple: Dict[Offset, NameOrAttr] = {}
        self.six_type_ctx: Dict[Offset, NameOrAttr] = {}
        self.six_with_metaclass: Set[Offset] = set()

        self._class_info_stack: List[FindPy3Plus.ClassInfo] = []
        self._in_comp = 0
        self.super_calls: Dict[Offset, ast.Call] = {}
        self._in_async_def = False
        self._scope_stack: List[FindPy3Plus.Scope] = []
        self.yield_from_fors: Set[Offset] = set()

        self.no_arg_decorators: Set[Offset] = set()

    def _is_six(self, node: ast.expr, names: Container[str]) -> bool:
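        # matches e.g. `six.text_type` or a bare `text_type` imported via
        # `from six import text_type`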
        return (
            isinstance(node, ast.Name) and
            node.id in names and
            node.id in self._from_imports['six']
        ) or (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id == 'six' and
            node.attr in names
        )

    def _is_lru_cache(self, node: ast.expr) -> bool:
        return (
            isinstance(node, ast.Name) and
            node.id == 'lru_cache' and
            node.id in self._from_imports['functools']
        ) or (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id == 'functools' and
            node.attr == 'lru_cache'
        )

    def _is_mock_mock(self, node: ast.expr) -> bool:
        return (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id == 'mock' and
            node.attr == 'mock'
        )

    def _is_io_open(self, node: ast.expr) -> bool:
        return (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id == 'io' and
            node.attr == 'open'
        )

    def _is_os_error_alias(self, node: Optional[ast.expr]) -> bool:
        return (
            isinstance(node, ast.Name) and
            node.id in self.OS_ERROR_ALIASES
        ) or (
            isinstance(node, ast.Name) and
            node.id == 'error' and
            (
                node.id in self._from_imports['mmap'] or
                node.id in self._from_imports['select'] or
                node.id in self._from_imports['socket']
            )
        ) or (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id in self.OS_ERROR_ALIAS_MODULES and
            node.attr == 'error'
        )

    def _is_version_info(self, node: ast.expr) -> bool:
        return (
            isinstance(node, ast.Name) and
            node.id == 'version_info' and
            self._version_info_imported
        ) or (
            isinstance(node, ast.Attribute) and
            isinstance(node.value, ast.Name) and
            node.value.id == 'sys' and
            node.attr == 'version_info'
        )

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if not node.level:
            if node.module in self.FROM_IMPORTED_MODULES:
                for name in node.names:
                    if not name.asname:
                        self._from_imports[node.module].add(name.name)
            elif self._find_mock and node.module in self.MOCK_MODULES:
                self.mock_relative_imports.add(_ast_to_offset(node))
            elif node.module == 'sys' and any(
                name.name == 'version_info' and not name.asname
                for name in node.names
            ):
                self._version_info_imported = True
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        if (
                self._find_mock and
                len(node.names) == 1 and
                node.names[0].name in self.MOCK_MODULES
        ):
            self.mock_absolute_imports.add(_ast_to_offset(node))

        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        for decorator in node.decorator_list:
            if self._is_six(decorator, ('python_2_unicode_compatible',)):
                self.six_remove_decorators.add(_ast_to_offset(decorator))
            elif (
                    isinstance(decorator, ast.Call) and
                    self._is_six(decorator.func, ('add_metaclass',)) and
                    not _starargs(decorator)
            ):
                self.six_add_metaclass.add(_ast_to_offset(decorator))

        for base in node.bases:
            if isinstance(base, ast.Name) and base.id == 'object':
                self.bases_to_remove.add(_ast_to_offset(base))
            elif self._is_six(base, ('Iterator',)):
                self.bases_to_remove.add(_ast_to_offset(base))

        if (
                len(node.bases) == 1 and
                isinstance(node.bases[0], ast.Call) and
                self._is_six(node.bases[0].func, ('with_metaclass',)) and
                not _starargs(node.bases[0])
        ):
            self.six_with_metaclass.add(_ast_to_offset(node.bases[0]))

        self._class_info_stack.append(FindPy3Plus.ClassInfo(node.name))
        self.generic_visit(node)
        self._class_info_stack.pop()

    @contextlib.contextmanager
    def _track_def_depth(
            self,
            node: AnyFunctionDef,
    ) -> Generator[None, None, None]:
        class_info = self._class_info_stack[-1]
        class_info.def_depth += 1
        if class_info.def_depth == 1 and node.args.args:
            class_info.first_arg_name = node.args.args[0].arg
        try:
            yield
        finally:
            class_info.def_depth -= 1

    @contextlib.contextmanager
    def _scope(self) -> Generator[None, None, None]:
        self._scope_stack.append(FindPy3Plus.Scope())
        try:
            yield
        finally:
            info = self._scope_stack.pop()
            # discard any that were referenced outside of the loop
            for name in info.reads:
                offsets = info.yield_from_names[name]
                info.yield_from_fors.difference_update(offsets)
            self.yield_from_fors.update(info.yield_from_fors)
            if self._scope_stack:
                cell_reads = info.reads - info.writes
                self._scope_stack[-1].reads.update(cell_reads)

    def _visit_func(self, node: AnyFunctionDef) -> None:
        with contextlib.ExitStack() as ctx, self._scope():
            if self._class_info_stack:
                ctx.enter_context(self._track_def_depth(node))
            self.generic_visit(node)

    def _visit_sync_func(self, node: SyncFunctionDef) -> None:
        self._in_async_def, orig = False, self._in_async_def
        self._visit_func(node)
        self._in_async_def = orig

    visit_FunctionDef = visit_Lambda = _visit_sync_func

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._in_async_def, orig = True, self._in_async_def
        self._visit_func(node)
        self._in_async_def = orig

    def _visit_comp(self, node: ast.expr) -> None:
        self._in_comp += 1
        with self._scope():
            self.generic_visit(node)
        self._in_comp -= 1

    visit_ListComp = visit_SetComp = _visit_comp
    visit_DictComp = visit_GeneratorExp = _visit_comp

    def visit_Attribute(self, node: ast.Attribute) -> None:
        if self._is_six(node, SIX_SIMPLE_ATTRS):
            self.six_simple[_ast_to_offset(node)] = node
        elif self._find_mock and self._is_mock_mock(node):
            self.mock_mock.add(_ast_to_offset(node))
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        if self._is_six(node, SIX_SIMPLE_ATTRS):
            self.six_simple[_ast_to_offset(node)] = node

        if self._scope_stack:
            if isinstance(node.ctx, ast.Load):
                self._scope_stack[-1].reads.add(node.id)
            elif isinstance(node.ctx, (ast.Store, ast.Del)):
                self._scope_stack[-1].writes.add(node.id)
            else:
                raise AssertionError(node)

        self.generic_visit(node)

    def visit_Try(self, node: ast.Try) -> None:
        for handler in node.handlers:
            htype = handler.type
            if self._is_os_error_alias(htype):
                assert isinstance(htype, (ast.Name, ast.Attribute))
                self.os_error_alias_simple[_ast_to_offset(htype)] = htype
            elif (
                    isinstance(htype, ast.Tuple) and
                    any(
                        self._is_os_error_alias(elt)
                        for elt in htype.elts
                    )
            ):
                self.os_error_alias_excepts.add(_ast_to_offset(htype))

        self.generic_visit(node)

    def visit_Raise(self, node: ast.Raise) -> None:
        exc = node.exc

        if exc is not None and self._is_os_error_alias(exc):
            assert isinstance(exc, (ast.Name, ast.Attribute))
            self.os_error_alias_simple[_ast_to_offset(exc)] = exc
        elif (
                isinstance(exc, ast.Call) and
                self._is_os_error_alias(exc.func)
        ):
            self.os_error_alias_calls.add(_ast_to_offset(exc))

        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        if (
                isinstance(node.func, ast.Name) and
                node.func.id in {'isinstance', 'issubclass'} and
                len(node.args) == 2 and
                self._is_six(node.args[1], SIX_TYPE_CTX_ATTRS)
        ):
            arg = node.args[1]
            # _is_six() enforces this
            assert isinstance(arg, (ast.Name, ast.Attribute))
            self.six_type_ctx[_ast_to_offset(node.args[1])] = arg
        elif self._is_six(node.func, ('b', 'ensure_binary')):
            self.six_b.add(_ast_to_offset(node))
        elif self._is_six(node.func, SIX_CALLS) and not _starargs(node):
            self.six_calls[_ast_to_offset(node)] = node
        elif (
                isinstance(node.func, ast.Name) and
                node.func.id == 'next' and
                not _starargs(node) and
                len(node.args) == 1 and
                isinstance(node.args[0], ast.Call) and
                self._is_six(
                    node.args[0].func,
                    ('iteritems', 'iterkeys', 'itervalues'),
                ) and
                not _starargs(node.args[0])
        ):
            self.six_iter[_ast_to_offset(node.args[0])] = node.args[0]
        elif (
                isinstance(self._previous_node, ast.Expr) and
                self._is_six(node.func, ('raise_from',)) and
                not _starargs(node)
        ):
            self.six_raise_from.add(_ast_to_offset(node))
        elif (
                isinstance(self._previous_node, ast.Expr) and
                self._is_six(node.func, ('reraise',)) and
                not _starargs(node)
        ):
            self.six_reraise.add(_ast_to_offset(node))
        elif (
                not self._in_comp and
                self._class_info_stack and
                self._class_info_stack[-1].def_depth == 1 and
                isinstance(node.func, ast.Name) and
                node.func.id == 'super' and
                len(node.args) == 2 and
                isinstance(node.args[0], ast.Name) and
                isinstance(node.args[1], ast.Name) and
                node.args[0].id == self._class_info_stack[-1].name and
                node.args[1].id == self._class_info_stack[-1].first_arg_name
        ):
            self.super_calls[_ast_to_offset(node)] = node
        elif (
                (
                    self._is_six(node.func, SIX_NATIVE_STR) or
                    isinstance(node.func, ast.Name) and node.func.id == 'str'
                ) and
                not node.keywords and
                not _starargs(node) and
                (
                    len(node.args) == 0 or
                    (
                        len(node.args) == 1 and
                        isinstance(node.args[0], ast.Str)
                    )
                )
        ):
            self.native_literals.add(_ast_to_offset(node))
        elif (
                isinstance(node.func, ast.Attribute) and
                isinstance(node.func.value, ast.Str) and
                node.func.attr == 'encode' and
                not _starargs(node) and
                len(node.args) == 1 and
                isinstance(node.args[0], ast.Str) and
                _is_codec(node.args[0].s, 'utf-8')
        ):
            self.encode_calls[_ast_to_offset(node)] = node
        elif self._is_io_open(node.func):
            self.io_open_calls.add(_ast_to_offset(node))
        elif (
                isinstance(node.func, ast.Name) and
                node.func.id == 'open' and
                not _starargs(node) and
                len(node.args) >= 2 and
                isinstance(node.args[1], ast.Str) and (
                    node.args[1].s in U_MODE_REPLACE or
                    (len(node.args) == 2 and node.args[1].s in U_MODE_REMOVE)
                )
        ):
            self.open_mode_calls.add(_ast_to_offset(node))
        elif (
                not node.args and
                not node.keywords and
                self._is_lru_cache(node.func)
        ):
            self.no_arg_decorators.add(_ast_to_offset(node))

        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        if (
                len(node.targets) == 1 and
                isinstance(node.targets[0], ast.Name) and
                node.targets[0].col_offset == 0 and
                node.targets[0].id == '__metaclass__' and
                isinstance(node.value, ast.Name) and
                node.value.id == 'type'
        ):
            self.metaclass_type_assignments.add(_ast_to_offset(node))

        self.generic_visit(node)

    @staticmethod
    def _eq(test: ast.Compare, n: int) -> bool:
        return (
            isinstance(test.ops[0], ast.Eq) and
            isinstance(test.comparators[0], ast.Num) and
            test.comparators[0].n == n
        )

    @staticmethod
    def _compare_to_3(
        test: ast.Compare,
        op: Union[Type[ast.cmpop], Tuple[Type[ast.cmpop], ...]],
    ) -> bool:
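        # matches e.g. `sys.version_info < (3,)` with op=ast.Lt, or
        # `sys.version_info >= (3, 0)` with op=(ast.Gt, ast.GtE)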
        if not (
                isinstance(test.ops[0], op) and
                isinstance(test.comparators[0], ast.Tuple) and
                len(test.comparators[0].elts) >= 1 and
                all(isinstance(n, ast.Num) for n in test.comparators[0].elts)
        ):
            return False

        # checked above but mypy needs help
        elts = cast('List[ast.Num]', test.comparators[0].elts)

        return elts[0].n == 3 and all(n.n == 0 for n in elts[1:])

    def visit_If(self, node: ast.If) -> None:
        if (
                # if six.PY2:
                self._is_six(node.test, ('PY2',)) or
                # if not six.PY3:
                (
                    isinstance(node.test, ast.UnaryOp) and
                    isinstance(node.test.op, ast.Not) and
                    self._is_six(node.test.operand, ('PY3',))
                ) or
                # sys.version_info == 2 or < (3,)
                (
                    isinstance(node.test, ast.Compare) and
                    self._is_version_info(node.test.left) and
                    len(node.test.ops) == 1 and (
                        self._eq(node.test, 2) or
                        self._compare_to_3(node.test, ast.Lt)
                    )
                )
        ):
            if node.orelse and not isinstance(node.orelse[0], ast.If):
                self.if_py2_blocks_else.add(_ast_to_offset(node))
        elif (
                # if six.PY3:
                self._is_six(node.test, ('PY3',)) or
                # if not six.PY2:
                (
                    isinstance(node.test, ast.UnaryOp) and
                    isinstance(node.test.op, ast.Not) and
                    self._is_six(node.test.operand, ('PY2',))
                ) or
                # sys.version_info == 3 or >= (3,) or > (3,)
                (
                    isinstance(node.test, ast.Compare) and
                    self._is_version_info(node.test.left) and
                    len(node.test.ops) == 1 and (
                        self._eq(node.test, 3) or
                        self._compare_to_3(node.test, (ast.Gt, ast.GtE))
                    )
                )
        ):
            if node.orelse and not isinstance(node.orelse[0], ast.If):
                self.if_py3_blocks_else.add(_ast_to_offset(node))
            elif not node.orelse:
                self.if_py3_blocks.add(_ast_to_offset(node))
        self.generic_visit(node)

    def visit_For(self, node: ast.For) -> None:
        if (
            not self._in_async_def and
            len(node.body) == 1 and
            isinstance(node.body[0], ast.Expr) and
            isinstance(node.body[0].value, ast.Yield) and
            node.body[0].value.value is not None and
            targets_same(node.target, node.body[0].value.value) and
            not node.orelse
        ):
            offset = _ast_to_offset(node)
            func_info = self._scope_stack[-1]
            func_info.yield_from_fors.add(offset)
            for target_node in ast.walk(node.target):
                if (
                        isinstance(target_node, ast.Name) and
                        isinstance(target_node.ctx, ast.Store)
                ):
                    func_info.yield_from_names[target_node.id].add(offset)
            # manually visit, but with target+body as a separate scope
            self.visit(node.iter)
            with self._scope():
                self.visit(node.target)
                for stmt in node.body:
                    self.visit(stmt)
                assert not node.orelse
        else:
            self.generic_visit(node)

    def generic_visit(self, node: ast.AST) -> None:
        self._previous_node = node
        super().generic_visit(node)


def _fixup_dedent_tokens(tokens: List[Token]) -> None:
    """For whatever reason the DEDENT / UNIMPORTANT_WS tokens are misordered

    | if True:
    |     if True:
    |         pass
    |     else:
    |^    ^- DEDENT
    |+----UNIMPORTANT_WS
    """
    for i, token in enumerate(tokens):
        if token.name == UNIMPORTANT_WS and tokens[i + 1].name == 'DEDENT':
            tokens[i], tokens[i + 1] = tokens[i + 1], tokens[i]


def _find_block_start(tokens: List[Token], i: int) -> int:
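    """Find the index of the `:` that opens the block at `i`, skipping
    colons nested inside parens / brackets / braces (e.g. in a dict
    literal or slice appearing in the condition).
    """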
    depth = 0
    while depth or tokens[i].src != ':':
        if tokens[i].src in OPENING:
            depth += 1
        elif tokens[i].src in CLOSING:
            depth -= 1
        i += 1
    return i


class Block(NamedTuple):
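    """Token-index span of a compound statement (`if ...:` plus body).

    start: first token of the statement (its indentation, if present)
    colon: index of the `:` that introduces the body
    block: first token of the body
    end: one past the last token of the body
    line: True if the body is on the same line as the `:`
    """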
    start: int
    colon: int
    block: int
    end: int
    line: bool

    def _initial_indent(self, tokens: List[Token]) -> int:
        if tokens[self.start].src.isspace():
            return len(tokens[self.start].src)
        else:
            return 0

    def _minimum_indent(self, tokens: List[Token]) -> int:
        block_indent = None
        for i in range(self.block, self.end):
            if (
                    tokens[i - 1].name in ('NL', 'NEWLINE') and
                    tokens[i].name in ('INDENT', UNIMPORTANT_WS)
            ):
                token_indent = len(tokens[i].src)
                if block_indent is None:
                    block_indent = token_indent
                else:
                    block_indent = min(block_indent, token_indent)

        assert block_indent is not None
        return block_indent

    def dedent(self, tokens: List[Token]) -> None:
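        """Dedent the block's body to the indentation of the block
        statement itself (a no-op for single-line blocks).
        """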
        if self.line:
            return
        diff = self._minimum_indent(tokens) - self._initial_indent(tokens)
        for i in range(self.block, self.end):
            if (
                    tokens[i - 1].name in ('DEDENT', 'NL', 'NEWLINE') and
                    tokens[i].name in ('INDENT', UNIMPORTANT_WS)
            ):
                tokens[i] = tokens[i]._replace(src=tokens[i].src[diff:])

    def replace_condition(self, tokens: List[Token], new: List[Token]) -> None:
        tokens[self.start:self.colon] = new

    def _trim_end(self, tokens: List[Token]) -> 'Block':
        """the tokenizer reports the end of the block at the beginning of
        the next block
        """
        i = last_token = self.end - 1
        while tokens[i].name in NON_CODING_TOKENS | {'DEDENT', 'NEWLINE'}:
            # if we find an indented comment inside our block, keep it
            if (
                    tokens[i].name in {'NL', 'NEWLINE'} and
                    tokens[i + 1].name == UNIMPORTANT_WS and
                    len(tokens[i + 1].src) > self._initial_indent(tokens)
            ):
                break
            # otherwise we've found another line to remove
            elif tokens[i].name in {'NL', 'NEWLINE'}:
                last_token = i
            i -= 1
        return self._replace(end=last_token + 1)

    @classmethod
    def find(
            cls,
            tokens: List[Token],
            i: int,
            trim_end: bool = False,
    ) -> 'Block':
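        """Find the block for the statement whose keyword token is at `i`;
        `trim_end` trims the trailing newlines / dedents that the
        tokenizer attributes to the block (see _trim_end).
        """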
        if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}:
            i -= 1
        start = i
        colon = _find_block_start(tokens, i)

        j = colon + 1
        while (
                tokens[j].name != 'NEWLINE' and
                tokens[j].name in NON_CODING_TOKENS
        ):
            j += 1

        if tokens[j].name == 'NEWLINE':  # multi line block
            block = j + 1
            while tokens[j].name != 'INDENT':
                j += 1
            level = 1
            j += 1
            while level:
                level += {'INDENT': 1, 'DEDENT': -1}.get(tokens[j].name, 0)
                j += 1
            ret = cls(start, colon, block, j, line=False)
            if trim_end:
                return ret._trim_end(tokens)
            else:
                return ret
        else:  # single line block
            block = j
            j = _find_end(tokens, j)
            return cls(start, colon, block, j, line=True)


def _find_end(tokens: List[Token], i: int) -> int:
    while tokens[i].name not in {'NEWLINE', 'ENDMARKER'}:
        i += 1

    # depending on the version of python, some will not emit
    # NEWLINE('') at the end of a file which does not end with a
    # newline (for example 3.6.5)
    if tokens[i].name == 'ENDMARKER':  # pragma: no cover
        i -= 1
    else:
        i += 1

    return i


def _find_if_else_block(tokens: List[Token], i: int) -> Tuple[Block, Block]:
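    """Find the `if` block starting at `i` and its matching `else` block
    (with the else block's end trimmed).
    """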
    if_block = Block.find(tokens, i)
    i = if_block.end
    while tokens[i].src != 'else':
        i += 1
    else_block = Block.find(tokens, i, trim_end=True)
    return if_block, else_block


def _find_elif(tokens: List[Token], i: int) -> int:
    while tokens[i].src != 'elif':  # pragma: no cover (only for <3.8.1)
        i -= 1
    return i


def _remove_decorator(tokens: List[Token], i: int) -> None:
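    """Remove the decorator containing token `i`: everything from its `@`
    through the trailing NEWLINE.
    """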
    while tokens[i - 1].src != '@':
        i -= 1
    if i > 1 and tokens[i - 2].name not in {'NEWLINE', 'NL'}:
        i -= 1
    end = i + 1
    while tokens[end].name != 'NEWLINE':
        end += 1
    del tokens[i - 1:end + 1]


def _remove_base_class(tokens: List[Token], i: int) -> None:
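    """Remove the base class whose name starts at token `i` from a class
    definition, along with any now-redundant comma or parens.
    """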
    # look forward and backward to find commas / parens
    brace_stack = []
    j = i
    while tokens[j].src not in {',', ':'}:
        if tokens[j].src == ')':
            brace_stack.append(j)
        j += 1
    right = j

    if tokens[right].src == ':':
        brace_stack.pop()
    else:
        # if there's a close-paren after a trailing comma
        j = right + 1
        while tokens[j].name in NON_CODING_TOKENS:
            j += 1
        if tokens[j].src == ')':
            while tokens[j].src != ':':
                j += 1
            right = j

    if brace_stack:
        last_part = brace_stack[-1]
    else:
        last_part = i

    j = i
    while brace_stack:
        if tokens[j].src == '(':
            brace_stack.pop()
        j -= 1

    while tokens[j].src not in {',', '('}:
        j -= 1
    left = j

    # single base, remove the entire bases
    if tokens[left].src == '(' and tokens[right].src == ':':
        del tokens[left:right]
    # multiple bases, base is first
    elif tokens[left].src == '(' and tokens[right].src != ':':
        # if there's space / comment afterwards remove that too
        while tokens[right + 1].name in {UNIMPORTANT_WS, 'COMMENT'}:
            right += 1
        del tokens[left + 1:right + 1]
    # multiple bases, base is not first
    else:
        del tokens[left:last_part + 1]


def _parse_call_args(
        tokens: List[Token],
        i: int,
) -> Tuple[List[Tuple[int, int]], int]:
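    """Given `i` pointing at the opening paren of a call, return
    `(args, end)`: a `(start, stop)` token-index span for each argument
    (surrounding whitespace included -- see _arg_str) and the index one
    past the closing paren.
    """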
    args = []
    stack = [i]
    i += 1
    arg_start = i

    while stack:
        token = tokens[i]

        if len(stack) == 1 and token.src == ',':
            args.append((arg_start, i))
            arg_start = i + 1
        elif token.src in BRACES:
            stack.append(i)
        elif token.src == BRACES[tokens[stack[-1]].src]:
            stack.pop()
            # if we're at the end, append that argument
            if not stack and tokens_to_src(tokens[arg_start:i]).strip():
                args.append((arg_start, i))

        i += 1

    return args, i


def _get_tmpl(mapping: Dict[str, str], node: NameOrAttr) -> str:
    if isinstance(node, ast.Name):
        return mapping[node.id]
    else:
        return mapping[node.attr]


def _arg_str(tokens: List[Token], start: int, end: int) -> str:
    return tokens_to_src(tokens[start:end]).strip()


def _replace_call(
        tokens: List[Token],
        start: int,
        end: int,
        args: List[Tuple[int, int]],
        tmpl: str,
) -> None:
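    """Replace `tokens[start:end]` (a whole call) with `tmpl` rendered
    via str.format: `{args[n]}` is the nth argument's source and `{rest}`
    is the source after the first argument (up to the closing paren).
    For example, a template of `'next({args[0]})'` rewrites `f(x)` to
    `next(x)`.
    """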
    arg_strs = [_arg_str(tokens, *arg) for arg in args]

    start_rest = args[0][1] + 1
    while (
            start_rest < end and
            tokens[start_rest].name in {'COMMENT', UNIMPORTANT_WS}
    ):
        start_rest += 1

    rest = tokens_to_src(tokens[start_rest:end - 1])
    src = tmpl.format(args=arg_strs, rest=rest)
    tokens[start:end] = [Token('CODE', src)]


def _replace_yield(tokens: List[Token], i: int) -> None:
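    """Rewrite a `for x in y: yield x` block starting at `i` as
    `yield from y`.
    """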
    in_token = _find_token(tokens, i, 'in')
    colon = _find_block_start(tokens, i)
    block = Block.find(tokens, i, trim_end=True)
    container = tokens_to_src(tokens[in_token + 1:colon]).strip()
    tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')]


def _fix_py3_plus(
        contents_text: str,
        min_version: MinVersion,
        keep_mock: bool = False,
) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy3Plus(keep_mock)
    visitor.visit(ast_obj)

    if not any((
            visitor.bases_to_remove,
            visitor.encode_calls,
            visitor.if_py2_blocks_else,
            visitor.if_py3_blocks,
            visitor.if_py3_blocks_else,
            visitor.metaclass_type_assignments,
            visitor.native_literals,
            visitor.io_open_calls,
            visitor.open_mode_calls,
            visitor.mock_mock,
            visitor.mock_absolute_imports,
            visitor.mock_relative_imports,
            visitor.os_error_alias_calls,
            visitor.os_error_alias_simple,
            visitor.os_error_alias_excepts,
            visitor.no_arg_decorators,
            visitor.six_add_metaclass,
            visitor.six_b,
            visitor.six_calls,
            visitor.six_iter,
            visitor.six_raise_from,
            visitor.six_reraise,
            visitor.six_remove_decorators,
            visitor.six_simple,
            visitor.six_type_ctx,
            visitor.six_with_metaclass,
            visitor.super_calls,
            visitor.yield_from_fors,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text

    _fixup_dedent_tokens(tokens)

    def _replace(i: int, mapping: Dict[str, str], node: NameOrAttr) -> None:
        new_token = Token('CODE', _get_tmpl(mapping, node))
        if isinstance(node, ast.Name):
            tokens[i] = new_token
        else:
            j = i
            while tokens[j].src != node.attr:
                # timid: if we see a parenthesis here, skip it
                if tokens[j].src == ')':
                    return
                j += 1
            tokens[i:j + 1] = [new_token]

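    # iterate in reverse so edits don't invalidate the offsets of tokens
    # earlier in the file that are still to be processed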
    for i, token in reversed_enumerate(tokens):
        if not token.src:
            continue
        elif token.offset in visitor.bases_to_remove:
            _remove_base_class(tokens, i)
        elif token.offset in visitor.if_py3_blocks:
            if tokens[i].src == 'if':
                if_block = Block.find(tokens, i)
                if_block.dedent(tokens)
                del tokens[if_block.start:if_block.block]
            else:
                if_block = Block.find(tokens, _find_elif(tokens, i))
                if_block.replace_condition(tokens, [Token('NAME', 'else')])
        elif token.offset in visitor.if_py2_blocks_else:
            if tokens[i].src == 'if':
                if_block, else_block = _find_if_else_block(tokens, i)
                else_block.dedent(tokens)
                del tokens[if_block.start:else_block.block]
            else:
                j = _find_elif(tokens, i)
                if_block, else_block = _find_if_else_block(tokens, j)
                del tokens[if_block.start:else_block.start]
        elif token.offset in visitor.if_py3_blocks_else:
            if tokens[i].src == 'if':
                if_block, else_block = _find_if_else_block(tokens, i)
                if_block.dedent(tokens)
                del tokens[if_block.end:else_block.end]
                del tokens[if_block.start:if_block.block]
            else:
                j = _find_elif(tokens, i)
                if_block, else_block = _find_if_else_block(tokens, j)
                del tokens[if_block.end:else_block.end]
                if_block.replace_condition(tokens, [Token('NAME', 'else')])
        elif token.offset in visitor.metaclass_type_assignments:
            j = _find_end(tokens, i)
            del tokens[i:j + 1]
        elif token.offset in visitor.native_literals:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            if any(tok.name == 'NL' for tok in tokens[i:end]):
                continue
            if func_args:
                _replace_call(tokens, i, end, func_args, '{args[0]}')
            else:
                tokens[i:end] = [token._replace(name='STRING', src="''")]
        elif token.offset in visitor.six_type_ctx:
            _replace(i, SIX_TYPE_CTX_ATTRS, visitor.six_type_ctx[token.offset])
        elif token.offset in visitor.six_simple:
            _replace(i, SIX_SIMPLE_ATTRS, visitor.six_simple[token.offset])
        elif token.offset in visitor.six_remove_decorators:
            _remove_decorator(tokens, i)
        elif token.offset in visitor.six_b:
            j = _find_open_paren(tokens, i)
            if (
                    tokens[j + 1].name == 'STRING' and
                    _is_ascii(tokens[j + 1].src) and
                    tokens[j + 2].src == ')'
            ):
                func_args, end = _parse_call_args(tokens, j)
                _replace_call(tokens, i, end, func_args, SIX_B_TMPL)
        elif token.offset in visitor.six_iter:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            call = visitor.six_iter[token.offset]
            assert isinstance(call.func, (ast.Name, ast.Attribute))
            template = f'iter({_get_tmpl(SIX_CALLS, call.func)})'
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_calls:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            call = visitor.six_calls[token.offset]
            assert isinstance(call.func, (ast.Name, ast.Attribute))
            template = _get_tmpl(SIX_CALLS, call.func)
            _replace_call(tokens, i, end, func_args, template)
        elif token.offset in visitor.six_raise_from:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            _replace_call(tokens, i, end, func_args, RAISE_FROM_TMPL)
        elif token.offset in visitor.six_reraise:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            if len(func_args) == 2:
                _replace_call(tokens, i, end, func_args, RERAISE_2_TMPL)
            else:
                _replace_call(tokens, i, end, func_args, RERAISE_3_TMPL)
        elif token.offset in visitor.six_add_metaclass:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            metaclass = f'metaclass={_arg_str(tokens, *func_args[0])}'
            # insert `metaclass={args[0]}` into `class:`
            # search forward for the `class` token
            j = i + 1
            while tokens[j].src != 'class':
                j += 1
            class_token = j
            # then search forward for a `:` token, not inside a brace
            j = _find_block_start(tokens, j)
            last_paren = -1
            for k in range(class_token, j):
                if tokens[k].src == ')':
                    last_paren = k

            if last_paren == -1:
                tokens.insert(j, Token('CODE', f'({metaclass})'))
            else:
                insert = last_paren - 1
                while tokens[insert].name in NON_CODING_TOKENS:
                    insert -= 1
                if tokens[insert].src == '(':  # no bases
                    src = metaclass
                elif tokens[insert].src != ',':
                    src = f', {metaclass}'
                else:
                    src = f' {metaclass},'
                tokens.insert(insert + 1, Token('CODE', src))
            _remove_decorator(tokens, i)
        elif token.offset in visitor.six_with_metaclass:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            if len(func_args) == 1:
                tmpl = WITH_METACLASS_NO_BASES_TMPL
            elif len(func_args) == 2:
                base = _arg_str(tokens, *func_args[1])
                if base == 'object':
                    tmpl = WITH_METACLASS_NO_BASES_TMPL
                else:
                    tmpl = WITH_METACLASS_BASES_TMPL
            else:
                tmpl = WITH_METACLASS_BASES_TMPL
            _replace_call(tokens, i, end, func_args, tmpl)
        elif token.offset in visitor.super_calls:
            i = _find_open_paren(tokens, i)
            call = visitor.super_calls[token.offset]
            victims = _victims(tokens, i, call, gen=False)
            del tokens[victims.starts[0] + 1:victims.ends[-1]]
        elif token.offset in visitor.encode_calls:
            i = _find_open_paren(tokens, i)
            call = visitor.encode_calls[token.offset]
            victims = _victims(tokens, i, call, gen=False)
            del tokens[victims.starts[0] + 1:victims.ends[-1]]
        elif token.offset in visitor.io_open_calls:
            j = _find_open_paren(tokens, i)
            tokens[i:j] = [token._replace(name='NAME', src='open')]
        elif token.offset in visitor.mock_mock:
            j = _find_token(tokens, i + 1, 'mock')
            del tokens[i + 1:j + 1]
        elif token.offset in visitor.mock_absolute_imports:
            j = _find_token(tokens, i, 'mock')
            if (
                    j + 2 < len(tokens) and
                    tokens[j + 1].src == '.' and
                    tokens[j + 2].src == 'mock'
            ):
                j += 2
            src = 'from unittest import mock'
            tokens[i:j + 1] = [tokens[j]._replace(name='NAME', src=src)]
        elif token.offset in visitor.mock_relative_imports:
            j = _find_token(tokens, i, 'mock')
            if (
                    j + 2 < len(tokens) and
                    tokens[j + 1].src == '.' and
                    tokens[j + 2].src == 'mock'
            ):
                k = j + 2
            else:
                k = j
            src = 'unittest.mock'
            tokens[j:k + 1] = [tokens[j]._replace(name='NAME', src=src)]
        elif token.offset in visitor.open_mode_calls:
            j = _find_open_paren(tokens, i)
            func_args, end = _parse_call_args(tokens, j)
            mode = tokens_to_src(tokens[slice(*func_args[1])])
            mode_stripped = mode.strip().strip('"\'')
            if mode_stripped in U_MODE_REMOVE:
                del tokens[func_args[0][1]:func_args[1][1]]
            elif mode_stripped in U_MODE_REPLACE_R:
                new_mode = mode.replace('U', 'r')
                tokens[slice(*func_args[1])] = [Token('SRC', new_mode)]
            elif mode_stripped in U_MODE_REMOVE_U:
                new_mode = mode.replace('U', '')
                tokens[slice(*func_args[1])] = [Token('SRC', new_mode)]
            else:
                raise AssertionError(f'unreachable: {mode!r}')
        elif token.offset in visitor.os_error_alias_calls:
            j = _find_open_paren(tokens, i)
            tokens[i:j] = [token._replace(name='NAME', src='OSError')]
        elif token.offset in visitor.os_error_alias_simple:
            node = visitor.os_error_alias_simple[token.offset]
            _replace(i, collections.defaultdict(lambda: 'OSError'), node)
        elif token.offset in visitor.os_error_alias_excepts:
            line, utf8_byte_offset = token.line, token.utf8_byte_offset

            # find all the arg strs in the tuple
            except_index = i
            while tokens[except_index].src != 'except':
                except_index -= 1
            start = _find_open_paren(tokens, except_index)
            func_args, end = _parse_call_args(tokens, start)

            # save the exceptions and remove the block
            arg_strs = [_arg_str(tokens, *arg) for arg in func_args]
            del tokens[start:end]

            # rewrite the block without dupes
            args = []
            for arg in arg_strs:
                left, part, right = arg.partition('.')
                if (
                        left in visitor.OS_ERROR_ALIAS_MODULES and
                        part == '.' and
                        right == 'error'
                ):
                    args.append('OSError')
                elif (
                        left in visitor.OS_ERROR_ALIASES and
                        part == right == ''
                ):
                    args.append('OSError')
                elif (
                        left == 'error' and
                        part == right == '' and
                        (
                            'error' in visitor._from_imports['mmap'] or
                            'error' in visitor._from_imports['select'] or
                            'error' in visitor._from_imports['socket']
                        )
                ):
                    args.append('OSError')
                else:
                    args.append(arg)

            unique_args = tuple(collections.OrderedDict.fromkeys(args))

            if len(unique_args) > 1:
                joined = '({})'.format(', '.join(unique_args))
            elif tokens[start - 1].name != UNIMPORTANT_WS:
                joined = ' {}'.format(unique_args[0])
            else:
                joined = unique_args[0]

            new = Token('CODE', joined, line, utf8_byte_offset)
            tokens.insert(start, new)

            visitor.os_error_alias_excepts.discard(token.offset)
        elif token.offset in visitor.yield_from_fors:
            _replace_yield(tokens, i)
        elif (
                min_version >= (3, 8) and
                token.offset in visitor.no_arg_decorators
        ):
            i = _find_open_paren(tokens, i)
            j = _find_token(tokens, i, ')')
            del tokens[i:j + 1]

    return tokens_to_src(tokens)


def _simple_arg(arg: ast.expr) -> bool:
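    """Whether `arg` is simple enough to inline into an f-string: a name,
    a dotted name, or a zero-argument call of either.
    """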
    return (
        isinstance(arg, ast.Name) or
        (isinstance(arg, ast.Attribute) and _simple_arg(arg.value)) or
        (
            isinstance(arg, ast.Call) and
            _simple_arg(arg.func) and
            not arg.args and not arg.keywords
        )
    )


def _starargs(call: ast.Call) -> bool:
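    """Whether the call uses `*args` or `**kwargs` (which the
    .format -> f-string rewrite cannot handle).
    """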
    return (
        any(k.arg is None for k in call.keywords) or
        any(isinstance(a, ast.Starred) for a in call.args)
    )


def _format_params(call: ast.Call) -> Dict[str, str]:
    params = {}
    for i, arg in enumerate(call.args):
        params[str(i)] = _unparse(arg)
    for kwd in call.keywords:
        # kwd.arg can't be None here because we exclude starargs
        assert kwd.arg is not None
        params[kwd.arg] = _unparse(kwd.value)
    return params


class FindPy36Plus(ast.NodeVisitor):
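    """Find syntax that can only be rewritten on 3.6+: `.format(...)`
    calls convertible to f-strings, and NamedTuple / TypedDict call-style
    assignments convertible to class syntax.
    """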
    def __init__(self) -> None:
        self.fstrings: Dict[Offset, ast.Call] = {}
        self.named_tuples: Dict[Offset, ast.Call] = {}
        self.dict_typed_dicts: Dict[Offset, ast.Call] = {}
        self.kw_typed_dicts: Dict[Offset, ast.Call] = {}
        self._from_imports: Dict[str, Set[str]] = collections.defaultdict(set)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if node.level == 0 and node.module in {'typing', 'typing_extensions'}:
            for name in node.names:
                if not name.asname:
                    self._from_imports[node.module].add(name.name)
        self.generic_visit(node)

    def _is_attr(self, node: ast.AST, mods: Set[str], name: str) -> bool:
        return (
            (
                isinstance(node, ast.Name) and
                node.id == name and
                any(name in self._from_imports[mod] for mod in mods)
            ) or
            (
                isinstance(node, ast.Attribute) and
                node.attr == name and
                isinstance(node.value, ast.Name) and
                node.value.id in mods
            )
        )

    def _parse(self, node: ast.Call) -> Optional[Tuple[DotFormatPart, ...]]:
        if not (
                isinstance(node.func, ast.Attribute) and
                isinstance(node.func.value, ast.Str) and
                node.func.attr == 'format' and
                all(_simple_arg(arg) for arg in node.args) and
                all(_simple_arg(k.value) for k in node.keywords) and
                not _starargs(node)
        ):
            return None

        try:
            return parse_format(node.func.value.s)
        except ValueError:
            return None

    def visit_Call(self, node: ast.Call) -> None:
        parsed = self._parse(node)
        if parsed is not None:
            params = _format_params(node)
            seen: Set[str] = set()
            i = 0
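            # for/else: any `break` below means the call is not safely
            # rewritable; only the else (no break taken) records it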
            for _, name, spec, _ in parsed:
                # timid: difficult to rewrite correctly
                if spec is not None and '{' in spec:
                    break
                if name is not None:
                    candidate, _, _ = name.partition('.')
                    # timid: could make the f-string longer
                    if candidate and candidate in seen:
                        break
                    # timid: bracketed
                    elif '[' in name:
                        break
                    seen.add(candidate)

                    key = candidate or str(i)
                    # the .format() call references a missing argument --
                    # it's already broken, leave it alone
                    if key not in params:
                        break
                    if not candidate:
                        i += 1
            else:
                self.fstrings[_ast_to_offset(node)] = node

        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        if (
                # NT = ...("NT", ...)
                len(node.targets) == 1 and
                isinstance(node.targets[0], ast.Name) and
                isinstance(node.value, ast.Call) and
                len(node.value.args) >= 1 and
                isinstance(node.value.args[0], ast.Str) and
                node.targets[0].id == node.value.args[0].s and
                not _starargs(node.value)
        ):
            if (
                    self._is_attr(
                        node.value.func, {'typing'}, 'NamedTuple',
                    ) and
                    len(node.value.args) == 2 and
                    not node.value.keywords and
                    isinstance(node.value.args[1], (ast.List, ast.Tuple)) and
                    len(node.value.args[1].elts) > 0 and
                    all(
                        isinstance(tup, ast.Tuple) and
                        len(tup.elts) == 2 and
                        isinstance(tup.elts[0], ast.Str) and
                        tup.elts[0].s.isidentifier() and
                        tup.elts[0].s not in _KEYWORDS
                        for tup in node.value.args[1].elts
                    )
            ):
                self.named_tuples[_ast_to_offset(node)] = node.value
            elif (
                    self._is_attr(
                        node.value.func,
                        {'typing', 'typing_extensions'},
                        'TypedDict',
                    ) and
                    len(node.value.args) == 1 and
                    len(node.value.keywords) > 0
            ):
                self.kw_typed_dicts[_ast_to_offset(node)] = node.value
            elif (
                    self._is_attr(
                        node.value.func,
                        {'typing', 'typing_extensions'},
                        'TypedDict',
                    ) and
                    len(node.value.args) == 2 and
                    not node.value.keywords and
                    isinstance(node.value.args[1], ast.Dict) and
                    node.value.args[1].keys and
                    all(
                        isinstance(k, ast.Str) and
                        k.s.isidentifier() and
                        k.s not in _KEYWORDS
                        for k in node.value.args[1].keys
                    )
            ):
                self.dict_typed_dicts[_ast_to_offset(node)] = node.value

        self.generic_visit(node)


def _unparse(node: ast.expr) -> str:
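    """Render the limited expression subset handled by this module back
    to source (`ast.unparse` is only available on 3.9+).
    """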
    if isinstance(node, ast.Name):
        return node.id
    elif isinstance(node, ast.Attribute):
        return ''.join((_unparse(node.value), '.', node.attr))
    elif isinstance(node, ast.Call):
        return '{}()'.format(_unparse(node.func))
    elif isinstance(node, ast.Subscript):
        if sys.version_info >= (3, 9):  # pragma: no cover (py39+)
            # https://github.com/python/typeshed/pull/3950
            node_slice: ast.expr = node.slice  # type: ignore
        elif isinstance(node.slice, ast.Index):  # pragma: no cover (<py39)
            node_slice = node.slice.value
        else:
            raise AssertionError(f'expected Slice: {ast.dump(node)}')
        if isinstance(node_slice, ast.Tuple):
            if len(node_slice.elts) == 1:
                slice_s = f'{_unparse(node_slice.elts[0])},'
            else:
                slice_s = ', '.join(_unparse(elt) for elt in node_slice.elts)
        else:
            slice_s = _unparse(node_slice)
        return '{}[{}]'.format(_unparse(node.value), slice_s)
    elif isinstance(node, ast.Str):
        return repr(node.s)
    elif isinstance(node, ast.Ellipsis):
        return '...'
    elif isinstance(node, ast.List):
        return '[{}]'.format(', '.join(_unparse(elt) for elt in node.elts))
    elif isinstance(node, ast.NameConstant):
        return repr(node.value)
    else:
        raise NotImplementedError(ast.dump(node))


def _to_fstring(src: str, call: ast.Call) -> str:
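    """Convert the string literal of a `'...'.format(...)` call into an
    f-string, substituting each field with its argument's source,
    e.g. `'{} {name}'.format(a, name=b)` becomes `f'{a} {b}'`.
    """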
    params = _format_params(call)

    parts = []
    i = 0
    for s, name, spec, conv in parse_format('f' + src):
        if name is not None:
            k, dot, rest = name.partition('.')
            name = ''.join((params[k or str(i)], dot, rest))
            if not k:  # named and auto params can be in different orders
                i += 1
        parts.append((s, name, spec, conv))
    return unparse_parsed_string(parts)


def _replace_typed_class(
        tokens: List[Token],
        i: int,
        call: ast.Call,
        types: Dict[str, ast.expr],
) -> None:
    if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}:
        indent = f'{tokens[i - 1].src}{" " * 4}'
    else:
        indent = ' ' * 4

    # NT = NamedTuple("nt", [("a", int)])
    # ^i                                 ^end
    end = i + 1
    while end < len(tokens) and tokens[end].name != 'NEWLINE':
        end += 1

    attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items())
    src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
    tokens[i:end] = [Token('CODE', src)]


def _fix_py36_plus(contents_text: str) -> str:
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus()
    visitor.visit(ast_obj)

    if not any((
            visitor.fstrings,
            visitor.named_tuples,
            visitor.dict_typed_dicts,
            visitor.kw_typed_dicts,
    )):
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            node = visitor.fstrings[token.offset]

            # TODO: handle \N escape sequences
            if r'\N' in token.src:
                continue

            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            # we don't actually care about arg position, so we pass `node`
            victims = _victims(tokens, paren, node, gen=False)
            end = victims.ends[-1]
            # if it spans more than one line, bail
            if tokens[end].line != token.line:
                continue

            tokens[i] = token._replace(src=_to_fstring(token.src, node))
            del tokens[i + 1:end + 1]
        elif token.offset in visitor.named_tuples and token.name == 'NAME':
            call = visitor.named_tuples[token.offset]
            types: Dict[str, ast.expr] = {
                tup.elts[0].s: tup.elts[1]  # type: ignore  # (checked above)
                for tup in call.args[1].elts  # type: ignore  # (checked above)
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
            call = visitor.kw_typed_dicts[token.offset]
            types = {
                arg.arg: arg.value  # type: ignore  # (checked above)
                for arg in call.keywords
            }
            _replace_typed_class(tokens, i, call, types)
        elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
            call = visitor.dict_typed_dicts[token.offset]
            types = {
                k.s: v  # type: ignore  # (checked above)
                for k, v in zip(
                    call.args[1].keys,  # type: ignore  # (checked above)
                    call.args[1].values,  # type: ignore  # (checked above)
                )
            }
            _replace_typed_class(tokens, i, call, types)

    return tokens_to_src(tokens)


def _fix_file(filename: str, args: argparse.Namespace) -> int:
    if filename == '-':
        contents_bytes = sys.stdin.buffer.read()
    else:
        with open(filename, 'rb') as fb:
            contents_bytes = fb.read()

    try:
        contents_text_orig = contents_text = contents_bytes.decode()
    except UnicodeDecodeError:
        print(f'{filename} is non-utf-8 (not supported)')
        return 1

    contents_text = _fix_py2_compatible(contents_text)
    contents_text = _fix_tokens(contents_text, min_version=args.min_version)
    if not args.keep_percent_format:
        contents_text = _fix_percent_format(contents_text)
    if args.min_version >= (3,):
        contents_text = _fix_py3_plus(
            contents_text, args.min_version, args.keep_mock,
        )
    if args.min_version >= (3, 6):
        contents_text = _fix_py36_plus(contents_text)

    if filename == '-':
        print(contents_text, end='')
    elif contents_text != contents_text_orig:
        print(f'Rewriting {filename}', file=sys.stderr)
        with open(filename, 'w', encoding='UTF-8', newline='') as f:
            f.write(contents_text)

    if args.exit_zero_even_if_changed:
        return 0
    else:
        return contents_text != contents_text_orig


def main(argv: Optional[Sequence[str]] = None) -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument('filenames', nargs='*')
    parser.add_argument('--exit-zero-even-if-changed', action='store_true')
    parser.add_argument('--keep-percent-format', action='store_true')
    parser.add_argument('--keep-mock', action='store_true')
    parser.add_argument(
        '--py3-plus', '--py3-only',
        action='store_const', dest='min_version', default=(2, 7), const=(3,),
    )
    parser.add_argument(
        '--py36-plus',
        action='store_const', dest='min_version', const=(3, 6),
    )
    parser.add_argument(
        '--py37-plus',
        action='store_const', dest='min_version', const=(3, 7),
    )
    parser.add_argument(
        '--py38-plus',
        action='store_const', dest='min_version', const=(3, 8),
    )
    args = parser.parse_args(argv)

    ret = 0
    for filename in args.filenames:
        ret |= _fix_file(filename, args)
    return ret


if __name__ == '__main__':
    raise SystemExit(main())