import types
import asyncio
import socket
import warnings
import sys

from functools import partial
from collections import deque
from contextlib import contextmanager

from .util import (
    encode_command,
    wait_ok,
    _NOTSET,
    _set_result,
    _set_exception,
    coerced_keys_dict,
    decode,
    parse_url,
    get_event_loop,
)
from .parser import Reader
from .stream import open_connection, open_unix_connection
from .errors import (
    ConnectionClosedError,
    ConnectionForcedCloseError,
    RedisError,
    ProtocolError,
    ReplyError,
    WatchVariableError,
    ReadOnlyError,
    MaxClientsError
)
from .pubsub import Channel
from .abc import AbcChannel
from .abc import AbcConnection
from .log import logger


__all__ = ['create_connection', 'RedisConnection']

# Stream buffer limit passed to the stream openers.
MAX_CHUNK_SIZE = 65536

# Commands that must go through `execute_pubsub` (both str and bytes forms).
_PUBSUB_COMMANDS = (
    'SUBSCRIBE', b'SUBSCRIBE',
    'PSUBSCRIBE', b'PSUBSCRIBE',
    'UNSUBSCRIBE', b'UNSUBSCRIBE',
    'PUNSUBSCRIBE', b'PUNSUBSCRIBE',
)


async def create_connection(address, *, db=None, password=None, ssl=None,
                            encoding=None, parser=None, loop=None,
                            timeout=None, connection_cls=None):
    """Creates redis connection.

    Opens connection to Redis server specified by address argument.
    Address argument can be one of the following:

    * A tuple representing (host, port) pair for TCP connections;
    * A string representing either Redis URI or unix domain socket path.

    When a URI is given, options parsed from it (db, password, encoding,
    timeout, ssl) are used unless they are already present in the URI,
    in which case the URI values take precedence over the keyword
    arguments.

    SSL argument is passed through to the underlying stream opener.
    By default SSL/TLS is not used.

    By default no timeout is applied at the connection stage; you can
    limit the time spent trying to open a connection with the `timeout`
    keyword argument, which must be a number greater than 0.

    Encoding argument can be used to decode byte-replies to strings.
    By default no decoding is done.

    Parser parameter can be used to pass custom Redis protocol parser class.
    By default the `Reader` from this package's `parser` module is used.

    The `loop` argument is ignored; on Python 3.8+ passing it emits a
    DeprecationWarning.

    Return value is RedisConnection instance or a connection_cls if it is
    given.

    If AUTH or SELECT fails, the freshly opened connection is closed and
    the error is re-raised.

    Raises:
        ValueError: if `timeout` is not None and not greater than 0.
        asyncio.TimeoutError: if connecting takes longer than `timeout`.

    This function is a coroutine.
    """
    assert isinstance(address, (tuple, list, str)), "tuple or str expected"
    if isinstance(address, str):
        address, options = parse_url(address)
        logger.debug("Parsed Redis URI %r", address)
        # setdefault: values from the URI win over keyword arguments.
        db = options.setdefault('db', db)
        password = options.setdefault('password', password)
        encoding = options.setdefault('encoding', encoding)
        timeout = options.setdefault('timeout', timeout)
        if 'ssl' in options:
            assert options['ssl'] or (not options['ssl'] and not ssl), (
                "Conflicting ssl options are set", options['ssl'], ssl)
            ssl = ssl or options['ssl']

    if timeout is not None and timeout <= 0:
        raise ValueError("Timeout has to be None or a number greater than 0")

    if connection_cls:
        assert issubclass(connection_cls, AbcConnection),\
                "connection_class does not meet the AbcConnection contract"
        cls = connection_cls
    else:
        cls = RedisConnection

    if loop is not None and sys.version_info >= (3, 8, 0):
        warnings.warn("The loop argument is deprecated",
                      DeprecationWarning)

    if isinstance(address, (list, tuple)):
        host, port = address
        logger.debug("Creating tcp connection to %r", address)
        reader, writer = await asyncio.wait_for(open_connection(
            host, port, limit=MAX_CHUNK_SIZE, ssl=ssl),
            timeout)
        sock = writer.transport.get_extra_info('socket')
        if sock is not None:
            # Disable Nagle's algorithm: Redis traffic is many small
            # request/reply messages.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            address = sock.getpeername()
        # Keep only (host, port); IPv6 peer names carry extra fields.
        address = tuple(address[:2])
    else:
        logger.debug("Creating unix connection to %r", address)
        reader, writer = await asyncio.wait_for(open_unix_connection(
            address, ssl=ssl, limit=MAX_CHUNK_SIZE),
            timeout)
        sock = writer.transport.get_extra_info('socket')
        if sock is not None:
            address = sock.getpeername()

    conn = cls(reader, writer, encoding=encoding,
               address=address, parser=parser)

    try:
        if password is not None:
            await conn.auth(password)
        if db is not None:
            await conn.select(db)
    except Exception:
        conn.close()
        await conn.wait_closed()
        raise
    return conn


class RedisConnection(AbcConnection):
    """Redis connection."""

    def __init__(self, reader, writer, *, address, encoding=None,
                 parser=None, loop=None):
        """Wrap an already-opened reader/writer pair.

        Installs the protocol parser on `reader` and starts the background
        reply-reader task immediately.  `loop` is ignored (deprecated).
        """
        if loop is not None and sys.version_info >= (3, 8):
            warnings.warn("The loop argument is deprecated",
                          DeprecationWarning)
        if parser is None:
            parser = Reader
        assert callable(parser), (
            "Parser argument is not callable", parser)
        self._reader = reader
        self._writer = writer
        self._address = address
        # FIFO of (future, encoding, callback) for replies not yet received.
        self._waiters = deque()
        self._reader.set_parser(
            parser(protocolError=ProtocolError, replyError=ReplyError)
        )
        self._reader_task = asyncio.ensure_future(self._read_data())
        self._close_msg = None
        self._db = 0
        self._closing = False
        self._closed = False
        self._close_state = asyncio.Event()
        # `wait_closed()` resolves once the reader task has finished.
        self._reader_task.add_done_callback(lambda x: self._close_state.set())
        # deque of (encoding, callback) for commands queued in MULTI, or None.
        self._in_transaction = None
        self._transaction_error = None  # XXX: never used?
        # Number of active (p)subscriptions; non-zero means pub/sub mode.
        self._in_pubsub = 0
        self._pubsub_channels = coerced_keys_dict()
        self._pubsub_patterns = coerced_keys_dict()
        self._encoding = encoding
        # bytearray while inside `_buffered()`, otherwise None.
        self._pipeline_buffer = None

    def __repr__(self):
        return '<RedisConnection [db:{}]>'.format(self._db)

    async def _read_data(self):
        """Response reader task.

        Reads replies until EOF or an error and dispatches each one to
        either pub/sub processing or the pending-waiter queue.  On exit it
        schedules `_do_close` with the error that ended the loop.
        """
        last_error = ConnectionClosedError(
            "Connection has been closed by server")
        while not self._reader.at_eof():
            try:
                obj = await self._reader.readobj()
            except asyncio.CancelledError:
                # NOTE: reader can get cancelled from `close()` method only.
                last_error = RuntimeError('this is unexpected')
                break
            except ProtocolError as exc:
                # ProtocolError is fatal
                # so connection must be closed
                if self._in_transaction is not None:
                    self._transaction_error = exc
                last_error = exc
                break
            except Exception as exc:
                # NOTE: for QUIT command connection error can be received
                #       before response
                last_error = exc
                break
            else:
                if (obj == b'' or obj is None) and self._reader.at_eof():
                    logger.debug("Connection has been closed by server,"
                                 " response: %r", obj)
                    last_error = ConnectionClosedError("Reader at end of file")
                    break

                if isinstance(obj, MaxClientsError):
                    last_error = obj
                    break
                if self._in_pubsub:
                    self._process_pubsub(obj)
                else:
                    self._process_data(obj)
        self._closing = True
        get_event_loop().call_soon(self._do_close, last_error)

    def _is_transaction_end(self, cb):
        """Return True if `cb` is the EXEC/DISCARD end-of-transaction hook."""
        return isinstance(cb, partial) and cb.func == self._end_transaction

    def _process_data(self, obj):
        """Processes command results.

        Resolves the oldest pending waiter with `obj` (decoded and passed
        through the waiter's callback) or fails it when `obj` is an error.
        """
        assert len(self._waiters) > 0, (type(obj), obj)
        waiter, encoding, cb = self._waiters.popleft()
        if isinstance(obj, RedisError):
            if isinstance(obj, ReplyError):
                if obj.args[0].startswith('READONLY'):
                    obj = ReadOnlyError(obj.args[0])
            _set_exception(waiter, obj)
            if self._in_transaction is not None:
                self._transaction_error = obj
                if self._is_transaction_end(cb):
                    # An error reply to EXEC/DISCARD (e.g. EXECABORT) means
                    # the server side transaction is over, but the callback
                    # that normally clears client state is skipped on
                    # errors.  Clear it here so the connection does not stay
                    # stuck "in transaction".
                    self._in_transaction = None
        else:
            if encoding is not None:
                try:
                    obj = decode(obj, encoding)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            if cb is not None:
                try:
                    obj = cb(obj)
                except Exception as exc:
                    _set_exception(waiter, exc)
                    return
            _set_result(waiter, obj)
            if self._in_transaction is not None:
                # Remember how to post-process this command's EXEC result.
                self._in_transaction.append((encoding, cb))

    def _process_pubsub(self, obj, *, process_waiters=True):
        """Processes pubsub messages.

        (Un)subscribe confirmations resolve pending waiters (when
        `process_waiters` is true) and update the subscription count;
        `message`/`pmessage` payloads are pushed into the matching channel.
        """
        kind, *args, data = obj
        if kind in (b'subscribe', b'unsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'unsubscribe':
                ch = self._pubsub_channels.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind in (b'psubscribe', b'punsubscribe'):
            chan, = args
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(obj)
            if kind == b'punsubscribe':
                ch = self._pubsub_patterns.pop(chan, None)
                if ch:
                    ch.close()
            self._in_pubsub = data
        elif kind == b'message':
            chan, = args
            self._pubsub_channels[chan].put_nowait(data)
        elif kind == b'pmessage':
            pattern, chan = args
            self._pubsub_patterns[pattern].put_nowait((chan, data))
        elif kind == b'pong':
            if process_waiters and self._in_pubsub and self._waiters:
                self._process_data(data or b'PONG')
        else:
            logger.warning("Unknown pubsub message received %r", obj)

    @contextmanager
    def _buffered(self):
        """Buffer commands issued inside the block and write them at once.

        Nested use only yields; the outermost context owns the buffer.
        If the block raises, the buffered bytes are discarded.
        """
        # XXX: we must ensure that no await happens
        #   as long as we buffer commands.
        #   Probably we can set some error-raising callback on enter
        #   and remove it on exit
        #   if some await happens in between -> throw an error.
        #   This is a creepy solution, 'cause someone might want to await
        #   on some other source besides redis.
        #   So we must only raise an error when someone tries to await
        #   a pending aioredis future.
        #   One of the solutions is to return a coroutine instead of a future
        #   in the `execute` method.
        #   In a coroutine we can check if buffering is enabled and raise
        #   an error.

        # TODO: describe in docs difference in pipeline mode for
        #   conn.execute vs pipeline.execute()
        if self._pipeline_buffer is None:
            self._pipeline_buffer = bytearray()
            try:
                yield self
                buf = self._pipeline_buffer
                self._writer.write(buf)
            finally:
                self._pipeline_buffer = None
        else:
            yield self

    def execute(self, command, *args, encoding=_NOTSET):
        """Executes redis command and returns Future waiting for the answer.

        SELECT/MULTI/EXEC/DISCARD get internal callbacks that keep the
        connection's db and transaction state in sync.  Pub/sub commands
        are forwarded (deprecated) to `execute_pubsub`.

        Raises:
        * TypeError if any of args can not be encoded as bytes.
        * ReplyError on redis '-ERR' responses.
        * ProtocolError when response can not be decoded meaning connection
          is broken.
        * ConnectionClosedError when either client or server has closed the
          connection.
        """
        if self._reader is None or self._reader.at_eof():
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if command is None:
            raise TypeError("command must not be None")
        if None in args:
            raise TypeError("args must not contain None")
        command = command.upper().strip()
        is_pubsub = command in _PUBSUB_COMMANDS
        is_ping = command in ('PING', b'PING')
        if self._in_pubsub and not (is_pubsub or is_ping):
            raise RedisError("Connection in SUBSCRIBE mode")
        elif is_pubsub:
            logger.warning("Deprecated. Use `execute_pubsub` method directly")
            return self.execute_pubsub(command, *args)

        if command in ('SELECT', b'SELECT'):
            cb = partial(self._set_db, args=args)
        elif command in ('MULTI', b'MULTI'):
            cb = self._start_transaction
        elif command in ('EXEC', b'EXEC'):
            cb = partial(self._end_transaction, discard=False)
            # Per-command encodings are applied inside _end_transaction.
            encoding = None
        elif command in ('DISCARD', b'DISCARD'):
            cb = partial(self._end_transaction, discard=True)
        else:
            cb = None
        if encoding is _NOTSET:
            encoding = self._encoding
        fut = get_event_loop().create_future()
        if self._pipeline_buffer is None:
            self._writer.write(encode_command(command, *args))
        else:
            encode_command(command, *args, buf=self._pipeline_buffer)
        self._waiters.append((fut, encoding, cb))
        return fut

    def execute_pubsub(self, command, *channels):
        """Executes redis (p)subscribe/(p)unsubscribe commands.

        Returns asyncio.gather coroutine waiting for all channels/patterns
        to receive answers.

        Raises:
        * ConnectionClosedError if the connection is closed.
        * TypeError if no channels are given or any of them is None.
        * ValueError if a channel's pattern flag does not match the command.
        """
        command = command.upper().strip()
        assert command in _PUBSUB_COMMANDS, (
            "Pub/Sub command expected", command)
        if self._reader is None or self._reader.at_eof():
            # Report the recorded close reason, consistent with `execute`.
            msg = self._close_msg or "Connection closed or corrupted"
            raise ConnectionClosedError(msg)
        if None in set(channels):
            raise TypeError("args must not contain None")
        if not len(channels):
            raise TypeError("No channels/patterns supplied")
        # len('PSUBSCRIBE') == 10 and len('PUNSUBSCRIBE') == 12; the plain
        # (UN)SUBSCRIBE commands have lengths 9 and 11.
        is_pattern = len(command) in (10, 12)
        mkchannel = partial(Channel, is_pattern=is_pattern)
        channels = [ch if isinstance(ch, AbcChannel) else mkchannel(ch)
                    for ch in channels]
        if not all(ch.is_pattern == is_pattern for ch in channels):
            raise ValueError("Not all channels {} match command {}"
                             .format(channels, command))
        cmd = encode_command(command, *(ch.name for ch in channels))
        res = []
        # Redis sends one confirmation per channel, so queue one waiter each.
        for ch in channels:
            fut = get_event_loop().create_future()
            res.append(fut)
            cb = partial(self._update_pubsub, ch=ch)
            self._waiters.append((fut, None, cb))
        if self._pipeline_buffer is None:
            self._writer.write(cmd)
        else:
            self._pipeline_buffer.extend(cmd)
        return asyncio.gather(*res)

    def close(self):
        """Close connection."""
        self._do_close(ConnectionForcedCloseError())

    def _do_close(self, exc):
        """Tear the connection down once; later calls are no-ops.

        Closes the transport, cancels the reader task, fails every pending
        waiter with `exc` (or ConnectionForcedCloseError when `exc` is None)
        and closes all pub/sub channels and patterns.
        """
        if self._closed:
            return
        self._closed = True
        self._closing = False
        self._writer.transport.close()
        self._reader_task.cancel()
        self._reader_task = None
        self._writer = None
        self._reader = None
        self._pipeline_buffer = None

        if exc is not None:
            self._close_msg = str(exc)

        while self._waiters:
            waiter, *spam = self._waiters.popleft()
            logger.debug("Cancelling waiter %r", (waiter, spam))
            if exc is None:
                _set_exception(waiter, ConnectionForcedCloseError())
            else:
                _set_exception(waiter, exc)
        while self._pubsub_channels:
            _, ch = self._pubsub_channels.popitem()
            logger.debug("Closing pubsub channel %r", ch)
            ch.close(exc)
        while self._pubsub_patterns:
            _, ch = self._pubsub_patterns.popitem()
            logger.debug("Closing pubsub pattern %r", ch)
            ch.close(exc)

    @property
    def closed(self):
        """True if connection is closed.

        If the reader hit EOF but close has not started yet, this also
        schedules the close.
        """
        closed = self._closing or self._closed
        if not closed and self._reader and self._reader.at_eof():
            self._closing = closed = True
            get_event_loop().call_soon(self._do_close, None)
        return closed

    async def wait_closed(self):
        """Coroutine waiting until connection is closed."""
        await self._close_state.wait()

    @property
    def db(self):
        """Currently selected db index."""
        return self._db

    @property
    def encoding(self):
        """Current set codec or None."""
        return self._encoding

    @property
    def address(self):
        """Redis server address, either host-port tuple or str."""
        return self._address

    def select(self, db):
        """Change the selected database for the current connection.

        Raises TypeError for a non-int `db` and ValueError for a negative one.
        """
        if not isinstance(db, int):
            raise TypeError("DB must be of int type, not {!r}".format(db))
        if db < 0:
            raise ValueError("DB must be greater or equal 0, got {!r}"
                             .format(db))
        fut = self.execute('SELECT', db)
        return wait_ok(fut)

    def _set_db(self, ok, args):
        """SELECT reply callback: record the newly selected db index."""
        assert ok in {b'OK', 'OK'}, ("Unexpected result of SELECT", ok)
        self._db = args[0]
        return ok

    def _start_transaction(self, ok):
        """MULTI reply callback: start collecting queued-command hooks."""
        assert self._in_transaction is None, (
            "Connection is already in transaction", self._in_transaction)
        self._in_transaction = deque()
        self._transaction_error = None
        return ok

    def _end_transaction(self, obj, discard):
        """EXEC/DISCARD reply callback.

        For DISCARD returns the reply unchanged.  For EXEC decodes each
        result and applies the per-command callback recorded while the
        transaction was open; a None reply (WATCH failure) becomes a list
        of WatchVariableError.  Errors from decoding/callbacks are placed
        in the result list instead of being raised.
        """
        assert self._in_transaction is not None, (
            "Connection is not in transaction", obj)
        self._transaction_error = None
        recall, self._in_transaction = self._in_transaction, None
        recall.popleft()  # ignore first (its _start_transaction)
        if discard:
            return obj
        assert isinstance(obj, list) or (obj is None and not discard), (
            "Unexpected MULTI/EXEC result", obj, recall)
        # TODO: need to be able to re-try transaction
        if obj is None:
            err = WatchVariableError("WATCH variable has changed")
            obj = [err] * len(recall)
        assert len(obj) == len(recall), (
            "Wrong number of result items in mutli-exec", obj, recall)
        res = []
        for o, (encoding, cb) in zip(obj, recall):
            if not isinstance(o, RedisError):
                try:
                    if encoding:
                        o = decode(o, encoding)
                    if cb:
                        o = cb(o)
                except Exception as err:
                    res.append(err)
                    continue
            res.append(o)
        return res

    def _update_pubsub(self, obj, *, ch):
        """(P)SUBSCRIBE/(P)UNSUBSCRIBE reply callback.

        Updates the subscription count and registers `ch` on subscribe.
        When this is the first reply entering pub/sub mode, the message is
        also run through `_process_pubsub` (without touching waiters).
        """
        kind, *pattern, channel, subscriptions = obj
        self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub
        # XXX: the channels/patterns storage should be refactored.
        #   if code which supposed to read from channel/pattern
        #   failed (exception in reader or else) than
        #   the channel object will still reside in memory
        #   and leak memory (messages will be put in queue).
        if kind == b'subscribe' and channel not in self._pubsub_channels:
            self._pubsub_channels[channel] = ch
        elif kind == b'psubscribe' and channel not in self._pubsub_patterns:
            self._pubsub_patterns[channel] = ch
        if not was_in_pubsub:
            self._process_pubsub(obj, process_waiters=False)
        return obj

    @property
    def in_transaction(self):
        """Set to True when MULTI command was issued."""
        return self._in_transaction is not None

    @property
    def in_pubsub(self):
        """Indicates that connection is in PUB/SUB mode.

        Provides the number of subscribed channels.
        """
        return self._in_pubsub

    @property
    def pubsub_channels(self):
        """Returns read-only channels dict."""
        return types.MappingProxyType(self._pubsub_channels)

    @property
    def pubsub_patterns(self):
        """Returns read-only patterns dict."""
        return types.MappingProxyType(self._pubsub_patterns)

    def auth(self, password):
        """Authenticate to server."""
        fut = self.execute('AUTH', password)
        return wait_ok(fut)