# testing/assertsql.py # Copyright (C) 2005-2018 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php from ..engine.default import DefaultDialect from .. import util import re import collections import contextlib from .. import event from sqlalchemy.schema import _DDLCompiles from sqlalchemy.engine.util import _distill_params from sqlalchemy.engine import url class AssertRule(object): is_consumed = False errormessage = None consume_statement = True def process_statement(self, execute_observed): pass def no_more_statements(self): assert False, 'All statements are complete, but pending '\ 'assertion rules remain' class SQLMatchRule(AssertRule): pass class CursorSQL(SQLMatchRule): consume_statement = False def __init__(self, statement, params=None): self.statement = statement self.params = params def process_statement(self, execute_observed): stmt = execute_observed.statements[0] if self.statement != stmt.statement or ( self.params is not None and self.params != stmt.parameters): self.errormessage = \ "Testing for exact SQL %s parameters %s received %s %s" % ( self.statement, self.params, stmt.statement, stmt.parameters ) else: execute_observed.statements.pop(0) self.is_consumed = True if not execute_observed.statements: self.consume_statement = True class CompiledSQL(SQLMatchRule): def __init__(self, statement, params=None, dialect='default'): self.statement = statement self.params = params self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): if self.dialect == 'default': return DefaultDialect() else: # ugh if self.dialect == 'postgresql': params = {'implicit_returning': True} else: params = {} return url.URL(self.dialect).get_dialect()(**params) def _received_statement(self, execute_observed): """reconstruct the statement and params in terms of a target dialect, which for CompiledSQL is just DefaultDialect.""" context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): compiled = \ context.compiled.statement.compile( dialect=compare_dialect, schema_translate_map=context. execution_options.get('schema_translate_map')) else: compiled = ( context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, inline=context.compiled.inline, schema_translate_map=context. execution_options.get('schema_translate_map')) ) _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: _received_parameters = [compiled.construct_params()] else: _received_parameters = [ compiled.construct_params(m) for m in parameters] return _received_statement, _received_parameters def process_statement(self, execute_observed): context = execute_observed.context _received_statement, _received_parameters = \ self._received_statement(execute_observed) params = self._all_params(context) equivalent = self._compare_sql(execute_observed, _received_statement) if equivalent: if params is not None: all_params = list(params) all_received = list(_received_parameters) while all_params and all_received: param = dict(all_params.pop(0)) for idx, received in enumerate(list(all_received)): # do a positive compare only for param_key in param: # a key in param did not match current # 'received' if param_key not in received or \ received[param_key] != param[param_key]: break else: # all keys in param matched 'received'; # onto next param del all_received[idx] break else: # param did not match any entry # in all_received equivalent = False break if all_params or all_received: equivalent = False if equivalent: self.is_consumed = True self.errormessage = None else: self.errormessage = self._failure_message(params) % { 'received_statement': _received_statement, 'received_parameters': _received_parameters } def _all_params(self, context): if self.params: if util.callable(self.params): params = self.params(context) else: params = self.params if not isinstance(params, list): params = [params] return params else: return None def _failure_message(self, expected_params): return ( 'Testing for compiled statement %r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( self.statement.replace('%', '%%'), expected_params ) ) class RegexSQL(CompiledSQL): def __init__(self, regex, params=None): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params self.dialect = 'default' def _failure_message(self, expected_params): return ( 'Testing for compiled statement ~%r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( self.orig_regex, expected_params ) ) def _compare_sql(self, execute_observed, received_statement): return bool(self.regex.match(received_statement)) class DialectSQL(CompiledSQL): def _compile_dialect(self, execute_observed): return execute_observed.context.dialect def _compare_no_space(self, real_stmt, received_stmt): stmt = re.sub(r'[\n\t]', '', real_stmt) return received_stmt == stmt def _received_statement(self, execute_observed): received_stmt, received_params = super(DialectSQL, self).\ _received_statement(execute_observed) # TODO: why do we need this part? for real_stmt in execute_observed.statements: if self._compare_no_space(real_stmt.statement, received_stmt): break else: raise AssertionError( "Can't locate compiled statement %r in list of " "statements actually invoked" % received_stmt) return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle if paramstyle == 'pyformat': stmt = re.sub( r':([\w_]+)', r"%(\1)s", stmt) else: # positional params repl = None if paramstyle == 'qmark': repl = "?" elif paramstyle == 'format': repl = r"%s" elif paramstyle == 'numeric': repl = None stmt = re.sub(r':([\w_]+)', repl, stmt) return received_statement == stmt class CountStatements(AssertRule): def __init__(self, count): self.count = count self._statement_count = 0 def process_statement(self, execute_observed): self._statement_count += 1 def no_more_statements(self): if self.count != self._statement_count: assert False, 'desired statement count %d does not match %d' \ % (self.count, self._statement_count) class AllOf(AssertRule): def __init__(self, *rules): self.rules = set(rules) def process_statement(self, execute_observed): for rule in list(self.rules): rule.errormessage = None rule.process_statement(execute_observed) if rule.is_consumed: self.rules.discard(rule) if not self.rules: self.is_consumed = True break elif not rule.errormessage: # rule is not done yet self.errormessage = None break else: self.errormessage = list(self.rules)[0].errormessage class EachOf(AssertRule): def __init__(self, *rules): self.rules = list(rules) def process_statement(self, execute_observed): while self.rules: rule = self.rules[0] rule.process_statement(execute_observed) if rule.is_consumed: self.rules.pop(0) elif rule.errormessage: self.errormessage = rule.errormessage if rule.consume_statement: break if not self.rules: self.is_consumed = True def no_more_statements(self): if self.rules and not self.rules[0].is_consumed: self.rules[0].no_more_statements() elif self.rules: super(EachOf, self).no_more_statements() class Or(AllOf): def process_statement(self, execute_observed): for rule in self.rules: rule.process_statement(execute_observed) if rule.is_consumed: self.is_consumed = True break else: self.errormessage = list(self.rules)[0].errormessage class SQLExecuteObserved(object): def __init__(self, context, clauseelement, multiparams, params): self.context = context self.clauseelement = clauseelement self.parameters = _distill_params(multiparams, params) self.statements = [] class SQLCursorExecuteObserved( collections.namedtuple( "SQLCursorExecuteObserved", ["statement", "parameters", "context", "executemany"]) ): pass class SQLAsserter(object): def __init__(self): self.accumulated = [] def _close(self): self._final = self.accumulated del self.accumulated def assert_(self, *rules): rule = EachOf(*rules) observed = list(self._final) while observed: statement = observed.pop(0) rule.process_statement(statement) if rule.is_consumed: break elif rule.errormessage: assert False, rule.errormessage if observed: assert False, "Additional SQL statements remain" elif not rule.is_consumed: rule.no_more_statements() @contextlib.contextmanager def assert_engine(engine): asserter = SQLAsserter() orig = [] @event.listens_for(engine, "before_execute") def connection_execute(conn, clauseelement, multiparams, params): # grab the original statement + params before any cursor # execution orig[:] = clauseelement, multiparams, params @event.listens_for(engine, "after_cursor_execute") def cursor_execute(conn, cursor, statement, parameters, context, executemany): if not context: return # then grab real cursor statements and associate them all # around a single context if asserter.accumulated and \ asserter.accumulated[-1].context is context: obs = asserter.accumulated[-1] else: obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) asserter.accumulated.append(obs) obs.statements.append( SQLCursorExecuteObserved( statement, parameters, context, executemany) ) try: yield asserter finally: event.remove(engine, "after_cursor_execute", cursor_execute) event.remove(engine, "before_execute", connection_execute) asserter._close()