1
0
mirror of https://github.com/moparisthebest/SickRage synced 2024-12-13 03:22:22 -05:00
SickRage/lib/sqlalchemy/sql/util.py

602 lines
19 KiB
Python
Raw Normal View History

2014-05-04 23:25:51 -04:00
# sql/util.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
2014-05-05 00:05:43 -04:00
"""High level utilities which build upon other modules here.

"""
2014-05-05 00:05:43 -04:00
from .. import exc, util
from .base import _from_objects, ColumnSet
from . import operators, visitors
from itertools import chain
from collections import deque
2014-05-05 00:05:43 -04:00
from .elements import BindParameter, ColumnClause, ColumnElement, \
Null, UnaryExpression, literal_column, Label
from .selectable import ScalarSelect, Join, FromClause, FromGrouping
from .schema import Column
2014-05-04 23:25:51 -04:00
2014-05-05 00:05:43 -04:00
# Module-level ``join_condition`` name, built by handing
# ``Join._join_condition`` and the dotted location string
# ".sql.util.join_condition" to ``util.langhelpers.public_factory``.
# NOTE(review): presumably this re-exports the Join classmethod as a
# public function under this module's path -- confirm against
# util.langhelpers.public_factory.
join_condition = util.langhelpers.public_factory(
    Join._join_condition,
    ".sql.util.join_condition")
2014-05-04 23:25:51 -04:00
2014-05-05 00:05:43 -04:00
# names that are still being imported from the outside
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
from .elements import _find_columns
from .ddl import sort_tables
2014-05-04 23:25:51 -04:00
def find_join_source(clauses, join_to):
    """Find the first FROM clause that can be joined against a selectable.

    The elements produced by ``_from_objects(join_to)`` are collected once;
    each entry of ``clauses`` is then tested, in order, against those
    elements using ``is_derived_from()``.

    :param clauses: sequence of FROM clauses to search.
    :param join_to: the selectable to be joined against.
    :return: a tuple ``(index, clause)`` for the first entry of ``clauses``
      that is derived from one of the elements of ``join_to``, or
      ``(None, None)`` if no entry matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == (0, clause1)

    """
    targets = list(_from_objects(join_to))

    for position, from_clause in enumerate(clauses):
        if any(from_clause.is_derived_from(target) for target in targets):
            return position, from_clause

    return None, None
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)``.
    :param expr: the expression to traverse.
    :return: ``None``; results are delivered only through ``fn``.

    """
    # innermost comparison currently being expanded sits at the end
    stack = []

    def visit(element):
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == 'binary' and \
                operators.is_comparison(element.operator):
            stack.append(element)
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[-1], l, r)
            stack.pop()
            # NOTE: a comparison yields nothing to its parent, so an
            # enclosing comparison never pairs against its columns.
            # The former trailing loop here called ``visit(elem)`` on each
            # child without iterating the result; since ``visit`` is a
            # generator function that had no effect and has been dropped.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    # exhaust the generator so that every comparison is delivered to fn
    list(visit(expr))
def find_tables(clause, check_columns=False,
                include_aliases=False, include_joins=False,
                include_selects=False, include_crud=False):
    """Locate Table objects within the given expression.

    ``clause`` is walked with ``visitors.traverse()`` (with the
    ``column_collections`` option turned off) and every element dispatched
    as ``'table'`` is collected.  The flags widen what is collected:

    :param clause: the expression to search.
    :param check_columns: also collect the ``.table`` of each element
      dispatched as ``'column'``.
    :param include_aliases: also collect elements dispatched as ``'alias'``.
    :param include_joins: also collect elements dispatched as ``'join'``.
    :param include_selects: also collect elements dispatched as
      ``'select'`` or ``'compound_select'``.
    :param include_crud: also collect the ``.table`` of elements dispatched
      as ``'insert'``, ``'update'`` or ``'delete'``.
    :return: list of the collected objects, in the order the visitors
      were invoked.

    """
    found = []
    collect = found.append
    dispatch = {}

    if include_selects:
        dispatch['select'] = collect
        dispatch['compound_select'] = collect

    if include_joins:
        dispatch['join'] = collect

    if include_aliases:
        dispatch['alias'] = collect

    if include_crud:
        def collect_target_table(stmt):
            found.append(stmt.table)

        for visit_name in ('insert', 'update', 'delete'):
            dispatch[visit_name] = collect_target_table

    if check_columns:
        def collect_column_table(column):
            found.append(column.table)

        dispatch['column'] = collect_column_table

    dispatch['table'] = collect

    visitors.traverse(clause, {'column_collections': False}, dispatch)
    return found
def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    Elements are processed breadth-first.  A ``ColumnElement`` is collected
    as-is unless it is a ``UnaryExpression`` whose modifier is an ordering
    modifier; such wrappers, and anything that is not a ``ColumnElement``,
    are replaced in the queue by their children.

    :param clause: the ORDER BY expression to unwrap.
    :return: a ``util.column_set()`` of the collected column expressions.

    """
    collected = util.column_set()
    pending = deque([clause])

    while pending:
        current = pending.popleft()

        must_descend = not isinstance(current, ColumnElement) or (
            isinstance(current, UnaryExpression) and
            operators.is_ordering_modifier(current.modifier)
        )
        if must_descend:
            pending.extend(current.get_children())
            continue

        collected.add(current)

    return collected
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins, by way of ``surface_selectables()``.

    :param clause: the clause being looked for.
    :param search: the clause to look inside.
    :return: ``True`` if some surface element of ``search`` compares equal
      to ``clause``, otherwise ``False``.

    """
    for candidate in surface_selectables(search):
        # equality rather than identity is deliberate; the original note
        # says this is so that Annotated objects compare
        if clause == candidate:
            return True
    return False
def surface_selectables(clause):
    """Yield ``clause`` and the selectables reachable from it through
    ``Join`` and ``FromGrouping`` wrappers.

    Each element is yielded before it is expanded.  A ``Join`` contributes
    its ``left`` and ``right`` sides, a ``FromGrouping`` contributes its
    ``element``; any other object is yielded and not descended into.
    Pending elements are kept on a LIFO stack, so the most recently added
    one is produced next.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current

        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
2014-05-04 23:25:51 -04:00
2014-05-05 00:05:43 -04:00
def selectables_overlap(left, right):
    """Return True if left/right have some overlapping selectable.

    Both arguments are expanded with ``surface_selectables()``; the result
    is ``True`` when the two expansions share at least one element.

    """
    left_surface = set(surface_selectables(left))
    shared = left_surface.intersection(surface_selectables(right))
    return bool(shared)
2014-05-04 23:25:51 -04:00
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    The clause is traversed with ``visitors.traverse()`` and the
    ``effective_value`` of each element dispatched as ``'bindparam'`` is
    appended to the result, in visiting order.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]

    """
    collected = []

    def record(bind):
        collected.append(bind.effective_value)

    visitors.traverse(clause, {}, {'bindparam': record})
    return collected
def _quote_ddl_expr(element):
    """Render ``element`` as a quoted literal string.

    Strings (per ``util.string_types``) have each single quote doubled and
    are wrapped in single quotes; any other object is rendered with
    ``repr()``.

    """
    if not isinstance(element, util.string_types):
        return repr(element)

    escaped = element.replace("'", "''")
    return "'%s'" % escaped
class _repr_params(object):
    """A string view of bound parameters, truncating
    display to the given number of 'multi' parameter sets.

    Truncation only applies when ``params`` is a list or tuple holding more
    than ``batches`` entries and its first entry is itself a list, dict or
    tuple.  In that case the repr shows the first ``batches - 2`` entries,
    a notice, and the final two entries; otherwise it is plain
    ``repr(params)``.

    """

    def __init__(self, params, batches):
        self.params = params
        self.batches = batches

    def __repr__(self):
        params, batches = self.params, self.batches

        is_long_multi = (
            isinstance(params, (list, tuple)) and
            len(params) > batches and
            isinstance(params[0], (list, dict, tuple))
        )
        if not is_long_multi:
            return repr(params)

        msg = " ... displaying %i of %i total bound parameter sets ... "
        # strip the closing bracket of the head and the opening bracket
        # of the tail so the pieces read as one sequence
        head = repr(params[:batches - 2])[0:-1]
        notice = msg % (batches, len(params))
        tail = repr(params[-2:])[1:]
        return ' '.join((head, notice, tail))
def adapt_criterion_to_null(crit, nulls):
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    ``crit`` is processed with ``visitors.cloned_traverse()``.  For each
    element dispatched as ``'binary'`` where one side is a
    ``BindParameter`` whose ``_identifying_key`` is in ``nulls``, that side
    is replaced by ``Null()`` on the right and the operator/negation become
    ``operators.is_`` / ``operators.isnot``.  The left side is checked
    first.

    :param crit: the criterion expression.
    :param nulls: collection of bind parameter identifying keys to be
      turned into NULL.
    :return: the result of ``visitors.cloned_traverse()``.

    """

    def is_null_bind(operand):
        return isinstance(operand, BindParameter) and \
            operand._identifying_key in nulls

    def visit_binary(binary):
        if is_null_bind(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not is_null_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
def splice_joins(left, right, stop_on=None):
    """Adapt ``right`` against ``left`` using a ``ClauseAdapter``.

    If ``left`` is ``None``, ``right`` is returned untouched.  Otherwise
    the chain of ``Join`` objects found by following ``.left`` from
    ``right`` is walked.  Each ``Join`` (other than ``stop_on``) is cloned,
    has ``_reset_exported()`` called and its ``onclause`` adapted; the
    first non-``Join`` element (or ``stop_on``) is adapted as a whole.
    Every processed element is re-attached as the ``.left`` of the clone
    that led to it.

    :param left: selectable handed to ``ClauseAdapter``, or ``None``.
    :param right: the join or selectable to adapt.
    :param stop_on: optional ``Join`` to be adapted whole, not descended
      into.
    :return: the first processed element, i.e. the clone or adaptation of
      ``right`` itself.

    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    pending = [(right, None)]
    result = None

    while pending:
        current, parent = pending.pop()

        if isinstance(current, Join) and current is not stop_on:
            current = current._clone()
            current._reset_exported()
            current.onclause = adapter.traverse(current.onclause)
            pending.append((current.left, current))
        else:
            current = adapter.traverse(current)

        if parent is not None:
            parent.left = current
        if result is None:
            result = current

    return result
def reduce_columns(columns, *clauses, **kw):
    r"""Given a list of columns, return a 'reduced' set based on natural
    equivalents.

    The set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case two columns are only
    treated as equivalent when their names match as well.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns
    present in the selectable to just those that are not repeated.

    :return: a ``ColumnSet`` of the columns that were not omitted.

    """
    ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
    only_synonyms = kw.pop('only_synonyms', False)

    columns = util.ordered_column_set(columns)
    omit = util.column_set()

    def names_agree(first, second):
        # name equality only matters in 'only_synonyms' mode
        return not only_synonyms or first.name == second.name

    for col in columns:
        foreign_keys = chain.from_iterable(
            [proxy.foreign_keys for proxy in col.proxy_set])
        for fk in foreign_keys:
            for other in columns:
                if other is col:
                    continue
                try:
                    fk_col = fk.column
                except (exc.NoReferencedColumnError,
                        exc.NoReferencedTableError):
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    raise
                if fk_col.shares_lineage(other) and names_agree(other, col):
                    omit.add(col)
                    break

    if clauses:
        def visit_binary(binary):
            if binary.operator == operators.eq:
                # proxies of every column that has not been omitted so far
                remaining = util.column_set(chain.from_iterable(
                    [c.proxy_set for c in columns.difference(omit)]))
                if binary.left in remaining and binary.right in remaining:
                    for candidate in reversed(columns):
                        if candidate.shares_lineage(binary.right) and \
                                names_agree(candidate, binary.left):
                            omit.add(candidate)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {'binary': visit_binary})

    return ColumnSet(columns.difference(omit))
2014-05-04 23:25:51 -04:00
def criterion_as_pairs(expression, consider_as_foreign_keys=None,
                       consider_as_referenced_keys=None, any_operator=False):
    """Traverse an expression and locate binary criterion pairs.

    Only elements dispatched as ``'binary'`` whose two sides are both
    ``ColumnElement`` objects are considered, and unless ``any_operator``
    is set only those whose operator is ``operators.eq``.

    * With ``consider_as_foreign_keys``, a pair is recorded when exactly one
      side is in that collection, or when the two sides ``compare()`` as
      the same column; the pair is ``(other side, listed side)``.
    * With ``consider_as_referenced_keys``, the same test is applied against
      that collection; the pair is ``(listed side, other side)``.
    * With neither, both sides must be ``Column`` objects and the pair is
      ``(referenced column, referencing column)`` according to
      ``Column.references()``.

    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
      ``consider_as_referenced_keys`` are given.
    :return: list of 2-tuples, in the order they were located.

    """
    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError("Can only specify one of "
                                "'consider_as_foreign_keys' or "
                                "'consider_as_referenced_keys'")

    pairs = []

    def col_is(a, b):
        return a.compare(b)

    def match_keys(keys, left, right):
        # returns (listed side, other side), trying the left side first,
        # or None when neither side qualifies
        if left in keys and \
                (col_is(right, left) or right not in keys):
            return left, right
        if right in keys and \
                (col_is(left, right) or left not in keys):
            return right, left
        return None

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not isinstance(left, ColumnElement) or \
                not isinstance(right, ColumnElement):
            return

        if consider_as_foreign_keys:
            match = match_keys(consider_as_foreign_keys, left, right)
            if match is not None:
                listed, other = match
                pairs.append((other, listed))
        elif consider_as_referenced_keys:
            match = match_keys(consider_as_referenced_keys, left, right)
            if match is not None:
                pairs.append(match)
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {'binary': visit_binary})
    return pairs
2014-05-05 00:05:43 -04:00
2014-05-04 23:25:51 -04:00
class AliasedRow(object):
    """Wrap a RowProxy with a translation map.

    This object allows a set of keys to be translated
    to those present in a RowProxy: every lookup first maps the given key
    through ``map`` and then consults the underlying row.

    """

    def __init__(self, row, map):
        # AliasedRow objects don't nest, so un-nest
        # if another AliasedRow was passed
        self.row = row.row if isinstance(row, AliasedRow) else row
        self.map = map

    def __contains__(self, key):
        translated = self.map[key]
        return translated in self.row

    def has_key(self, key):
        """Return whether the translated ``key`` is present in the row."""
        return key in self

    def __getitem__(self, key):
        translated = self.map[key]
        return self.row[translated]

    def keys(self):
        """Return the keys of the underlying row (untranslated)."""
        return self.row.keys()
class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(self, selectable, equivalents=None,
                 include=None, exclude=None,
                 include_fn=None, exclude_fn=None,
                 adapt_on_names=False):
        """Set up the adapter.

        :param selectable: the selectable whose columns replace the
          corresponding columns of traversed clauses; it is also placed
          in the ``stop_on`` traversal option.
        :param equivalents: optional mapping of a column to a collection of
          equivalent columns, tried when the column itself has no
          correspondence.
        :param include: optional collection; only elements contained in it
          are adapted.  Mutually exclusive with ``include_fn``.
        :param exclude: optional collection; elements contained in it are
          not adapted.  Mutually exclusive with ``exclude_fn``.
        :param include_fn: optional predicate; only elements for which it
          returns true are adapted.
        :param exclude_fn: optional predicate; elements for which it
          returns true are not adapted.
        :param adapt_on_names: if true, fall back to looking a column up
          by name in ``selectable.c`` when no correspondence is found.

        """
        self.__traverse_options__ = {'stop_on': [selectable]}
        self.selectable = selectable

        if not include:
            self.include_fn = include_fn
        else:
            assert not include_fn
            self.include_fn = lambda elem: elem in include

        if not exclude:
            self.exclude_fn = exclude_fn
        else:
            assert not exclude_fn
            self.exclude_fn = lambda elem: elem in exclude

        self.equivalents = util.column_dict(equivalents or {})
        self.adapt_on_names = adapt_on_names

    def _corresponding_column(self, col, require_embedded,
                              _seen=util.EMPTY_SET):
        """Return the column of ``self.selectable`` corresponding to
        ``col``, or ``None``.

        When there is no direct correspondence, the entries of
        ``self.equivalents[col]`` are tried recursively; ``_seen`` carries
        the columns already expanded so that they are not expanded again.
        As a last resort, if ``adapt_on_names`` is set, the column is
        looked up by name.

        """
        found = self.selectable.corresponding_column(
            col, require_embedded=require_embedded)

        if found is None and col in self.equivalents and col not in _seen:
            seen_with_col = _seen.union([col])
            for equiv in self.equivalents[col]:
                found = self._corresponding_column(
                    equiv,
                    require_embedded=require_embedded,
                    _seen=seen_with_col)
                if found is not None:
                    return found

        if self.adapt_on_names and found is None:
            found = self.selectable.c.get(col.name)
        return found

    # when true, replace() skips the FromClause substitution below;
    # nothing in this module sets it to anything other than False
    magic_flag = False

    def replace(self, col):
        """Return the replacement for ``col``, or ``None`` to leave it
        as it is.

        A ``FromClause`` from which ``self.selectable`` is derived is
        replaced by the selectable itself (unless ``magic_flag`` is set).
        Any other non-``ColumnElement``, and any column rejected by the
        include/exclude predicates, is left alone.  Remaining columns are
        resolved through ``_corresponding_column()`` with
        ``require_embedded=True``.

        """
        if not self.magic_flag and isinstance(col, FromClause) and \
                self.selectable.is_derived_from(col):
            return self.selectable

        if not isinstance(col, ColumnElement):
            return None
        if self.include_fn and not self.include_fn(col):
            return None
        if self.exclude_fn and self.exclude_fn(col):
            return None

        return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Provides the ability to "wrap" this ClauseAdapter
    around another, a columns dictionary which returns
    adapted elements given an original, and an
    adapted_row() factory.

    """

    def __init__(self, selectable, equivalents=None,
                 chain_to=None, include=None,
                 exclude=None, adapt_required=False):
        """Set up the adapter.

        :param selectable: passed on to ``ClauseAdapter``.
        :param equivalents: passed on to ``ClauseAdapter``.
        :param chain_to: optional visitor handed to ``self.chain()``.
        :param include: passed on to ``ClauseAdapter``.
        :param exclude: passed on to ``ClauseAdapter``.
        :param adapt_required: if true, ``_locate_col()`` returns ``None``
          for a column that comes back from adaptation unchanged.

        """
        ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
        if chain_to:
            self.chain(chain_to)
        # lazily populated mapping: original column -> adapted column
        self.columns = util.populate_column_dict(self._locate_col)
        self.adapt_required = adapt_required

    def wrap(self, adapter):
        """Return a copy of this adapter whose ``_locate_col``,
        ``adapt_clause`` and ``adapt_list`` feed their result through the
        same-named callable of ``adapter``.

        The copy shares this adapter's state (a shallow ``__dict__`` copy)
        but gets its own ``columns`` mapping.

        """
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__ = self.__dict__.copy()
        ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
        ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
        ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
        ac.columns = util.populate_column_dict(ac._locate_col)
        return ac

    adapt_clause = ClauseAdapter.traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _wrap(self, local, wrapped):
        """Compose two one-argument callables: ``wrapped(local(col))``."""
        def locate(col):
            col = local(col)
            return wrapped(col)
        return locate

    def _locate_col(self, col):
        """Return the adapted form of ``col``; used to populate
        ``self.columns``.

        """
        c = self._corresponding_column(col, True)
        if c is None:
            c = self.adapt_clause(col)

            # anonymize labels in case they have a hardcoded name
            if isinstance(c, Label):
                c = c.label(None)

        # adapt_required used by eager loading to indicate that
        # we don't trust a result row column that is not translated.
        # this is to prevent a column from being interpreted as that
        # of the child row in a self-referential scenario, see
        # inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
        if self.adapt_required and c is col:
            return None

        return c

    def adapted_row(self, row):
        """Return an ``AliasedRow`` over ``row`` keyed through
        ``self.columns``.

        """
        return AliasedRow(row, self.columns)

    def __getstate__(self):
        # 'columns' is left out of the pickled state; it is rebuilt
        # from _locate_col in __setstate__
        d = self.__dict__.copy()
        del d['columns']
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # use the same factory as __init__() and wrap() so that a restored
        # adapter gets the same kind of ``columns`` mapping as a new one
        self.columns = util.populate_column_dict(self._locate_col)
2014-05-05 00:05:43 -04:00