# sql/util.py # Copyright (C) 2005-2014 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """High level utilities which build upon other modules here. """ from .. import exc, util from .base import _from_objects, ColumnSet from . import operators, visitors from itertools import chain from collections import deque from .elements import BindParameter, ColumnClause, ColumnElement, \ Null, UnaryExpression, literal_column, Label from .selectable import ScalarSelect, Join, FromClause, FromGrouping from .schema import Column join_condition = util.langhelpers.public_factory( Join._join_condition, ".sql.util.join_condition") # names that are still being imported from the outside from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate from .elements import _find_columns from .ddl import sort_tables def find_join_source(clauses, join_to): """Given a list of FROM clauses and a selectable, return the first index and element from the list of clauses which can be joined against the selectable. returns None, None if no match is found. e.g.:: clause1 = table1.join(table2) clause2 = table4.join(table5) join_to = table2.join(table3) find_join_source([clause1, clause2], join_to) == clause1 """ selectables = list(_from_objects(join_to)) for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): return i, f else: return None, None def visit_binary_product(fn, expr): """Produce a traversal of the given expression, delivering column comparisons to the given function. The function is of the form:: def my_fn(binary, left, right) For each binary expression located which has a comparison operator, the product of "left" and "right" will be delivered to that function, in terms of that binary. Hence an expression like:: and_( (a + b) == q + func.sum(e + f), j == r ) would have the traversal:: a q a e a f b q b e b f j r That is, every combination of "left" and "right" that doesn't further contain a binary comparison is passed as pairs. """ stack = [] def visit(element): if isinstance(element, ScalarSelect): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element elif element.__visit_name__ == 'binary' and \ operators.is_comparison(element.operator): stack.insert(0, element) for l in visit(element.left): for r in visit(element.right): fn(stack[0], l, r) stack.pop(0) for elem in element.get_children(): visit(elem) else: if isinstance(element, ColumnClause): yield element for elem in element.get_children(): for e in visit(elem): yield e list(visit(expr)) def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): """locate Table objects within the given expression.""" tables = [] _visitors = {} if include_selects: _visitors['select'] = _visitors['compound_select'] = tables.append if include_joins: _visitors['join'] = tables.append if include_aliases: _visitors['alias'] = tables.append if include_crud: _visitors['insert'] = _visitors['update'] = \ _visitors['delete'] = lambda ent: tables.append(ent.table) if check_columns: def visit_column(column): tables.append(column.table) _visitors['column'] = visit_column _visitors['table'] = tables.append visitors.traverse(clause, {'column_collections': False}, _visitors) return tables def unwrap_order_by(clause): """Break up an 'order by' expression into individual column-expressions, without DESC/ASC/NULLS FIRST/NULLS LAST""" cols = util.column_set() stack = deque([clause]) while stack: t = stack.popleft() if isinstance(t, ColumnElement) and \ ( not isinstance(t, UnaryExpression) or \ not operators.is_ordering_modifier(t.modifier) ): cols.add(t) else: for c in t.get_children(): stack.append(c) return cols def clause_is_present(clause, search): """Given a target clause and a second to search within, return True if the target is plainly present in the search without any subqueries or aliases involved. Basically descends through Joins. """ for elem in surface_selectables(search): if clause == elem: # use == here so that Annotated's compare return True else: return False def surface_selectables(clause): stack = [clause] while stack: elem = stack.pop() yield elem if isinstance(elem, Join): stack.extend((elem.left, elem.right)) elif isinstance(elem, FromGrouping): stack.append(elem.element) def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" return bool( set(surface_selectables(left)).intersection( surface_selectables(right) ) ) def bind_values(clause): """Return an ordered list of "bound" values in the given clause. E.g.:: >>> expr = and_( ... table.c.foo==5, table.c.foo==7 ... ) >>> bind_values(expr) [5, 7] """ v = [] def visit_bindparam(bind): v.append(bind.effective_value) visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) return v def _quote_ddl_expr(element): if isinstance(element, util.string_types): element = element.replace("'", "''") return "'%s'" % element else: return repr(element) class _repr_params(object): """A string view of bound parameters, truncating display to the given number of 'multi' parameter sets. """ def __init__(self, params, batches): self.params = params self.batches = batches def __repr__(self): if isinstance(self.params, (list, tuple)) and \ len(self.params) > self.batches and \ isinstance(self.params[0], (list, dict, tuple)): msg = " ... displaying %i of %i total bound parameter sets ... " return ' '.join(( repr(self.params[:self.batches - 2])[0:-1], msg % (self.batches, len(self.params)), repr(self.params[-2:])[1:] )) else: return repr(self.params) def adapt_criterion_to_null(crit, nulls): """given criterion containing bind params, convert selected elements to IS NULL. """ def visit_binary(binary): if isinstance(binary.left, BindParameter) \ and binary.left._identifying_key in nulls: # reverse order if the NULL is on the left side binary.left = binary.right binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot elif isinstance(binary.right, BindParameter) \ and binary.right._identifying_key in nulls: binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) def splice_joins(left, right, stop_on=None): if left is None: return right stack = [(right, None)] adapter = ClauseAdapter(left) ret = None while stack: (right, prevright) = stack.pop() if isinstance(right, Join) and right is not stop_on: right = right._clone() right._reset_exported() right.onclause = adapter.traverse(right.onclause) stack.append((right.left, right)) else: right = adapter.traverse(right) if prevright is not None: prevright.left = right if ret is None: ret = right return ret def reduce_columns(columns, *clauses, **kw): """given a list of columns, return a 'reduced' set based on natural equivalents. the set is reduced to the smallest list of columns which have no natural equivalent present in the list. A "natural equivalent" means that two columns will ultimately represent the same value because they are related by a foreign key. \*clauses is an optional list of join clauses which will be traversed to further identify columns that are "equivalent". \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys whose tables are not yet configured, or columns that aren't yet present. This function is primarily used to determine the most minimal "primary key" from a selectable, by reducing the set of primary key columns present in the selectable to just those that are not repeated. """ ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) only_synonyms = kw.pop('only_synonyms', False) columns = util.ordered_column_set(columns) omit = util.column_set() for col in columns: for fk in chain(*[c.foreign_keys for c in col.proxy_set]): for c in columns: if c is col: continue try: fk_col = fk.column except exc.NoReferencedColumnError: # TODO: add specific coverage here # to test/sql/test_selectable ReduceTest if ignore_nonexistent_tables: continue else: raise except exc.NoReferencedTableError: # TODO: add specific coverage here # to test/sql/test_selectable ReduceTest if ignore_nonexistent_tables: continue else: raise if fk_col.shares_lineage(c) and \ (not only_synonyms or \ c.name == col.name): omit.add(col) break if clauses: def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)])) if binary.left in cols and binary.right in cols: for c in reversed(columns): if c.shares_lineage(binary.right) and \ (not only_synonyms or \ c.name == binary.left.name): omit.add(c) break for clause in clauses: if clause is not None: visitors.traverse(clause, {}, {'binary': visit_binary}) return ColumnSet(columns.difference(omit)) def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False): """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: raise exc.ArgumentError("Can only specify one of " "'consider_as_foreign_keys' or " "'consider_as_referenced_keys'") def col_is(a, b): #return a is b return a.compare(b) def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return if not isinstance(binary.left, ColumnElement) or \ not isinstance(binary.right, ColumnElement): return if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: if isinstance(binary.left, Column) and \ isinstance(binary.right, Column): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] visitors.traverse(expression, {}, {'binary': visit_binary}) return pairs class AliasedRow(object): """Wrap a RowProxy with a translation map. This object allows a set of keys to be translated to those present in a RowProxy. """ def __init__(self, row, map): # AliasedRow objects don't nest, so un-nest # if another AliasedRow was passed if isinstance(row, AliasedRow): self.row = row.row else: self.row = row self.map = map def __contains__(self, key): return self.map[key] in self.row def has_key(self, key): return key in self def __getitem__(self, key): return self.row[self.map[key]] def keys(self): return self.row.keys() class ClauseAdapter(visitors.ReplacingCloningVisitor): """Clones and modifies clauses based on column correspondence. E.g.:: table1 = Table('sometable', metadata, Column('col1', Integer), Column('col2', Integer) ) table2 = Table('someothertable', metadata, Column('col1', Integer), Column('col2', Integer) ) condition = table1.c.col1 == table2.c.col1 make an alias of table1:: s = table1.alias('foo') calling ``ClauseAdapter(s).traverse(condition)`` converts condition to read:: s.c.col1 == table2.c.col1 """ def __init__(self, selectable, equivalents=None, include=None, exclude=None, include_fn=None, exclude_fn=None, adapt_on_names=False): self.__traverse_options__ = {'stop_on': [selectable]} self.selectable = selectable if include: assert not include_fn self.include_fn = lambda e: e in include else: self.include_fn = include_fn if exclude: assert not exclude_fn self.exclude_fn = lambda e: e in exclude else: self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column( col, require_embedded=require_embedded) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col])) if newcol is not None: return newcol if self.adapt_on_names and newcol is None: newcol = self.selectable.c.get(col.name) return newcol magic_flag = False def replace(self, col): if not self.magic_flag and isinstance(col, FromClause) and \ self.selectable.is_derived_from(col): return self.selectable elif not isinstance(col, ColumnElement): return None elif self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None else: return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. Provides the ability to "wrap" this ClauseAdapter around another, a columns dictionary which returns adapted elements given an original, and an adapted_row() factory. """ def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None, adapt_required=False): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: self.chain(chain_to) self.columns = util.populate_column_dict(self._locate_col) self.adapt_required = adapt_required def wrap(self, adapter): ac = self.__class__.__new__(self.__class__) ac.__dict__ = self.__dict__.copy() ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) ac.columns = util.populate_column_dict(ac._locate_col) return ac adapt_clause = ClauseAdapter.traverse adapt_list = ClauseAdapter.copy_and_process def _wrap(self, local, wrapped): def locate(col): col = local(col) return wrapped(col) return locate def _locate_col(self, col): c = self._corresponding_column(col, True) if c is None: c = self.adapt_clause(col) # anonymize labels in case they have a hardcoded name if isinstance(c, Label): c = c.label(None) # adapt_required used by eager loading to indicate that # we don't trust a result row column that is not translated. # this is to prevent a column from being interpreted as that # of the child row in a self-referential scenario, see # inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency if self.adapt_required and c is col: return None return c def adapted_row(self, row): return AliasedRow(row, self.columns) def __getstate__(self): d = self.__dict__.copy() del d['columns'] return d def __setstate__(self, state): self.__dict__.update(state) self.columns = util.PopulateDict(self._locate_col)