1
0
mirror of https://github.com/moparisthebest/SickRage synced 2024-12-12 11:02:21 -05:00

Fixed issues with editing/saving custom scene exceptions.

Fixed charmap issues for anime show names.

Fixed issues with display show page and epCat key errors.

Fixed duplicate log messages for clearing provider caches.

Fixed issues with email notifier ep names not properly being encoded to UTF-8.

TVDB<->TVRAGE Indexer ID mapping is now performed on demand to be used when needed such as newznab providers can be searched with tvrage_id's and some will return tvrage_id's that later can be used to create show objects from for faster and more accurate name parsing, mapping is done via Trakt API calls.

Added stop event signals to schedualed tasks, SR now waits indefinate till task has been fully stopped before completing a restart or shutdown event.

NameParserCache is now persistent and stores 200 parsed results at any given time for quicker lookups and better performance, this helps maintain results between updates or shutdown/startup events.

Black and White lists for anime now only get used for anime shows as intended, performance gain for non-anime shows that dont need to load these lists.

Internal name cache now builds it self on demand when needed per show request plus checks if show is already in cache and if true exits routine to save time.

Schedualer and QueueItems classes are now a sub-class of threading.Thread and a stop threading event signal has been added to each.

If I forgot to list something it doesn't mean its not fixed so please test and report back if anything is wrong or has been corrected by this new release.
This commit is contained in:
echel0n 2014-07-14 19:00:53 -07:00
parent 09f53d3537
commit d02c0bd6eb
304 changed files with 922 additions and 102786 deletions

View File

@ -455,7 +455,7 @@ class SickRage(object):
sickbeard.showList.append(curShow) sickbeard.showList.append(curShow)
except Exception, e: except Exception, e:
logger.log( logger.log(
u"There was an error creating the show in " + sqlShow["location"] + ": " + str(e).decode('utf-8'), u"There was an error creating the show in " + sqlShow["location"] + ": " + str(e).decode('utf-8', 'replace'),
logger.ERROR) logger.ERROR)
def restore(self, srcDir, dstDir): def restore(self, srcDir, dstDir):
@ -477,14 +477,14 @@ class SickRage(object):
# stop all tasks # stop all tasks
sickbeard.halt() sickbeard.halt()
# save all shows to DB
sickbeard.saveAll()
# shutdown web server # shutdown web server
if self.webserver: if self.webserver:
self.webserver.shutDown() self.webserver.shutDown()
self.webserver = None self.webserver = None
# save all shows to DB
sickbeard.saveAll()
# if run as daemon delete the pidfile # if run as daemon delete the pidfile
if self.runAsDaemon and self.CREATEPID: if self.runAsDaemon and self.CREATEPID:
self.remove_pid_file(self.PIDFILE) self.remove_pid_file(self.PIDFILE)

View File

@ -190,13 +190,13 @@
#if $show.rls_ignore_words: #if $show.rls_ignore_words:
<tr><td class="showLegend">Ignored Words: </td><td>#echo $show.rls_ignore_words#</td></tr> <tr><td class="showLegend">Ignored Words: </td><td>#echo $show.rls_ignore_words#</td></tr>
#end if #end if
#if $bwl.get_white_keywords_for("release_group"): #if $bwl and $bwl.get_white_keywords_for("release_group"):
<tr> <tr>
<td class="showLegend">Wanted Group#if len($bwl.get_white_keywords_for("release_group"))>1 then "s" else ""#:</td> <td class="showLegend">Wanted Group#if len($bwl.get_white_keywords_for("release_group"))>1 then "s" else ""#:</td>
<td>#echo ', '.join($bwl.get_white_keywords_for("release_group"))#</td> <td>#echo ', '.join($bwl.get_white_keywords_for("release_group"))#</td>
</tr> </tr>
#end if #end if
#if $bwl.get_black_keywords_for("release_group"): #if $bwl and $bwl.get_black_keywords_for("release_group"):
<tr> <tr>
<td class="showLegend">Unwanted Group#if len($bwl.get_black_keywords_for("release_group"))>1 then "s" else ""#:</td> <td class="showLegend">Unwanted Group#if len($bwl.get_black_keywords_for("release_group"))>1 then "s" else ""#:</td>
<td>#echo ', '.join($bwl.get_black_keywords_for("release_group"))#</td> <td>#echo ', '.join($bwl.get_black_keywords_for("release_group"))#</td>
@ -265,6 +265,11 @@
<table class="sickbeardTable" cellspacing="1" border="0" cellpadding="0"> <table class="sickbeardTable" cellspacing="1" border="0" cellpadding="0">
#for $epResult in $sqlResults: #for $epResult in $sqlResults:
#set $epStr = str($epResult["season"]) + "x" + str($epResult["episode"])
#if not $epStr in $epCats:
#continue
#end if
#if not $sickbeard.DISPLAY_SHOW_SPECIALS and int($epResult["season"]) == 0: #if not $sickbeard.DISPLAY_SHOW_SPECIALS and int($epResult["season"]) == 0:
#continue #continue
#end if #end if
@ -314,7 +319,6 @@
#set $curSeason = int($epResult["season"]) #set $curSeason = int($epResult["season"])
#end if #end if
#set $epStr = str($epResult["season"]) + "x" + str($epResult["episode"])
#set $epLoc = $epResult["location"] #set $epLoc = $epResult["location"]
<tr class="$Overview.overviewStrings[$epCats[$epStr]] season-$curSeason"> <tr class="$Overview.overviewStrings[$epCats[$epStr]] season-$curSeason">
<td width="1%"> <td width="1%">

View File

@ -1,3 +0,0 @@
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View File

@ -1,23 +0,0 @@
# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Execute computations asynchronously using threads or processes."""
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
from concurrent.futures._base import (FIRST_COMPLETED,
FIRST_EXCEPTION,
ALL_COMPLETED,
CancelledError,
TimeoutError,
Future,
Executor,
wait,
as_completed)
from concurrent.futures.thread import ThreadPoolExecutor
# Jython doesn't have multiprocessing
try:
from concurrent.futures.process import ProcessPoolExecutor
except ImportError:
pass

View File

@ -1,577 +0,0 @@
# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
from __future__ import with_statement
import logging
import threading
import time
try:
from collections import namedtuple
except ImportError:
from concurrent.futures._compat import namedtuple
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
FIRST_COMPLETED = 'FIRST_COMPLETED'
FIRST_EXCEPTION = 'FIRST_EXCEPTION'
ALL_COMPLETED = 'ALL_COMPLETED'
_AS_COMPLETED = '_AS_COMPLETED'
# Possible future states (for internal use by the futures package).
PENDING = 'PENDING'
RUNNING = 'RUNNING'
# The future was cancelled by the user...
CANCELLED = 'CANCELLED'
# ...and _Waiter.add_cancelled() was called by a worker.
CANCELLED_AND_NOTIFIED = 'CANCELLED_AND_NOTIFIED'
FINISHED = 'FINISHED'
_FUTURE_STATES = [
PENDING,
RUNNING,
CANCELLED,
CANCELLED_AND_NOTIFIED,
FINISHED
]
_STATE_TO_DESCRIPTION_MAP = {
PENDING: "pending",
RUNNING: "running",
CANCELLED: "cancelled",
CANCELLED_AND_NOTIFIED: "cancelled",
FINISHED: "finished"
}
# Logger for internal use by the futures package.
LOGGER = logging.getLogger("concurrent.futures")
class Error(Exception):
"""Base class for all future-related exceptions."""
pass
class CancelledError(Error):
"""The Future was cancelled."""
pass
class TimeoutError(Error):
"""The operation exceeded the given deadline."""
pass
class _Waiter(object):
"""Provides the event that wait() and as_completed() block on."""
def __init__(self):
self.event = threading.Event()
self.finished_futures = []
def add_result(self, future):
self.finished_futures.append(future)
def add_exception(self, future):
self.finished_futures.append(future)
def add_cancelled(self, future):
self.finished_futures.append(future)
class _AsCompletedWaiter(_Waiter):
"""Used by as_completed()."""
def __init__(self):
super(_AsCompletedWaiter, self).__init__()
self.lock = threading.Lock()
def add_result(self, future):
with self.lock:
super(_AsCompletedWaiter, self).add_result(future)
self.event.set()
def add_exception(self, future):
with self.lock:
super(_AsCompletedWaiter, self).add_exception(future)
self.event.set()
def add_cancelled(self, future):
with self.lock:
super(_AsCompletedWaiter, self).add_cancelled(future)
self.event.set()
class _FirstCompletedWaiter(_Waiter):
"""Used by wait(return_when=FIRST_COMPLETED)."""
def add_result(self, future):
super(_FirstCompletedWaiter, self).add_result(future)
self.event.set()
def add_exception(self, future):
super(_FirstCompletedWaiter, self).add_exception(future)
self.event.set()
def add_cancelled(self, future):
super(_FirstCompletedWaiter, self).add_cancelled(future)
self.event.set()
class _AllCompletedWaiter(_Waiter):
"""Used by wait(return_when=FIRST_EXCEPTION and ALL_COMPLETED)."""
def __init__(self, num_pending_calls, stop_on_exception):
self.num_pending_calls = num_pending_calls
self.stop_on_exception = stop_on_exception
self.lock = threading.Lock()
super(_AllCompletedWaiter, self).__init__()
def _decrement_pending_calls(self):
with self.lock:
self.num_pending_calls -= 1
if not self.num_pending_calls:
self.event.set()
def add_result(self, future):
super(_AllCompletedWaiter, self).add_result(future)
self._decrement_pending_calls()
def add_exception(self, future):
super(_AllCompletedWaiter, self).add_exception(future)
if self.stop_on_exception:
self.event.set()
else:
self._decrement_pending_calls()
def add_cancelled(self, future):
super(_AllCompletedWaiter, self).add_cancelled(future)
self._decrement_pending_calls()
class _AcquireFutures(object):
"""A context manager that does an ordered acquire of Future conditions."""
def __init__(self, futures):
self.futures = sorted(futures, key=id)
def __enter__(self):
for future in self.futures:
future._condition.acquire()
def __exit__(self, *args):
for future in self.futures:
future._condition.release()
def _create_and_install_waiters(fs, return_when):
if return_when == _AS_COMPLETED:
waiter = _AsCompletedWaiter()
elif return_when == FIRST_COMPLETED:
waiter = _FirstCompletedWaiter()
else:
pending_count = sum(
f._state not in [CANCELLED_AND_NOTIFIED, FINISHED] for f in fs)
if return_when == FIRST_EXCEPTION:
waiter = _AllCompletedWaiter(pending_count, stop_on_exception=True)
elif return_when == ALL_COMPLETED:
waiter = _AllCompletedWaiter(pending_count, stop_on_exception=False)
else:
raise ValueError("Invalid return condition: %r" % return_when)
for f in fs:
f._waiters.append(waiter)
return waiter
def as_completed(fs, timeout=None):
"""An iterator over the given futures that yields each as it completes.
Args:
fs: The sequence of Futures (possibly created by different Executors) to
iterate over.
timeout: The maximum number of seconds to wait. If None, then there
is no limit on the wait time.
Returns:
An iterator that yields the given Futures as they complete (finished or
cancelled).
Raises:
TimeoutError: If the entire result iterator could not be generated
before the given timeout.
"""
if timeout is not None:
end_time = timeout + time.time()
with _AcquireFutures(fs):
finished = set(
f for f in fs
if f._state in [CANCELLED_AND_NOTIFIED, FINISHED])
pending = set(fs) - finished
waiter = _create_and_install_waiters(fs, _AS_COMPLETED)
try:
for future in finished:
yield future
while pending:
if timeout is None:
wait_timeout = None
else:
wait_timeout = end_time - time.time()
if wait_timeout < 0:
raise TimeoutError(
'%d (of %d) futures unfinished' % (
len(pending), len(fs)))
waiter.event.wait(wait_timeout)
with waiter.lock:
finished = waiter.finished_futures
waiter.finished_futures = []
waiter.event.clear()
for future in finished:
yield future
pending.remove(future)
finally:
for f in fs:
f._waiters.remove(waiter)
DoneAndNotDoneFutures = namedtuple(
'DoneAndNotDoneFutures', 'done not_done')
def wait(fs, timeout=None, return_when=ALL_COMPLETED):
"""Wait for the futures in the given sequence to complete.
Args:
fs: The sequence of Futures (possibly created by different Executors) to
wait upon.
timeout: The maximum number of seconds to wait. If None, then there
is no limit on the wait time.
return_when: Indicates when this function should return. The options
are:
FIRST_COMPLETED - Return when any future finishes or is
cancelled.
FIRST_EXCEPTION - Return when any future finishes by raising an
exception. If no future raises an exception
then it is equivalent to ALL_COMPLETED.
ALL_COMPLETED - Return when all futures finish or are cancelled.
Returns:
A named 2-tuple of sets. The first set, named 'done', contains the
futures that completed (is finished or cancelled) before the wait
completed. The second set, named 'not_done', contains uncompleted
futures.
"""
with _AcquireFutures(fs):
done = set(f for f in fs
if f._state in [CANCELLED_AND_NOTIFIED, FINISHED])
not_done = set(fs) - done
if (return_when == FIRST_COMPLETED) and done:
return DoneAndNotDoneFutures(done, not_done)
elif (return_when == FIRST_EXCEPTION) and done:
if any(f for f in done
if not f.cancelled() and f.exception() is not None):
return DoneAndNotDoneFutures(done, not_done)
if len(done) == len(fs):
return DoneAndNotDoneFutures(done, not_done)
waiter = _create_and_install_waiters(fs, return_when)
waiter.event.wait(timeout)
for f in fs:
f._waiters.remove(waiter)
done.update(waiter.finished_futures)
return DoneAndNotDoneFutures(done, set(fs) - done)
class Future(object):
"""Represents the result of an asynchronous computation."""
def __init__(self):
"""Initializes the future. Should not be called by clients."""
self._condition = threading.Condition()
self._state = PENDING
self._result = None
self._exception = None
self._waiters = []
self._done_callbacks = []
def _invoke_callbacks(self):
for callback in self._done_callbacks:
try:
callback(self)
except Exception:
LOGGER.exception('exception calling callback for %r', self)
def __repr__(self):
with self._condition:
if self._state == FINISHED:
if self._exception:
return '<Future at %s state=%s raised %s>' % (
hex(id(self)),
_STATE_TO_DESCRIPTION_MAP[self._state],
self._exception.__class__.__name__)
else:
return '<Future at %s state=%s returned %s>' % (
hex(id(self)),
_STATE_TO_DESCRIPTION_MAP[self._state],
self._result.__class__.__name__)
return '<Future at %s state=%s>' % (
hex(id(self)),
_STATE_TO_DESCRIPTION_MAP[self._state])
def cancel(self):
"""Cancel the future if possible.
Returns True if the future was cancelled, False otherwise. A future
cannot be cancelled if it is running or has already completed.
"""
with self._condition:
if self._state in [RUNNING, FINISHED]:
return False
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
return True
self._state = CANCELLED
self._condition.notify_all()
self._invoke_callbacks()
return True
def cancelled(self):
"""Return True if the future has cancelled."""
with self._condition:
return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]
def isAlive(self):
return self.running()
def running(self):
"""Return True if the future is currently executing."""
with self._condition:
return self._state == RUNNING
def done(self):
"""Return True of the future was cancelled or finished executing."""
with self._condition:
return self._state in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]
def __get_result(self):
if self._exception:
raise self._exception
else:
return self._result
def add_done_callback(self, fn):
"""Attaches a callable that will be called when the future finishes.
Args:
fn: A callable that will be called with this future as its only
argument when the future completes or is cancelled. The callable
will always be called by a thread in the same process in which
it was added. If the future has already completed or been
cancelled then the callable will be called immediately. These
callables are called in the order that they were added.
"""
with self._condition:
if self._state not in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]:
self._done_callbacks.append(fn)
return
fn(self)
def result(self, timeout=None):
"""Return the result of the call that the future represents.
Args:
timeout: The number of seconds to wait for the result if the future
isn't done. If None, then there is no limit on the wait time.
Returns:
The result of the call that the future represents.
Raises:
CancelledError: If the future was cancelled.
TimeoutError: If the future didn't finish executing before the given
timeout.
Exception: If the call raised then that exception will be raised.
"""
with self._condition:
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
raise CancelledError()
elif self._state == FINISHED:
return self.__get_result()
self._condition.wait(timeout)
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
raise CancelledError()
elif self._state == FINISHED:
return self.__get_result()
else:
raise TimeoutError()
def exception(self, timeout=None):
"""Return the exception raised by the call that the future represents.
Args:
timeout: The number of seconds to wait for the exception if the
future isn't done. If None, then there is no limit on the wait
time.
Returns:
The exception raised by the call that the future represents or None
if the call completed without raising.
Raises:
CancelledError: If the future was cancelled.
TimeoutError: If the future didn't finish executing before the given
timeout.
"""
with self._condition:
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
raise CancelledError()
elif self._state == FINISHED:
return self._exception
self._condition.wait(timeout)
if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
raise CancelledError()
elif self._state == FINISHED:
return self._exception
else:
raise TimeoutError()
# The following methods should only be used by Executors and in tests.
def set_running_or_notify_cancel(self):
"""Mark the future as running or process any cancel notifications.
Should only be used by Executor implementations and unit tests.
If the future has been cancelled (cancel() was called and returned
True) then any threads waiting on the future completing (though calls
to as_completed() or wait()) are notified and False is returned.
If the future was not cancelled then it is put in the running state
(future calls to running() will return True) and True is returned.
This method should be called by Executor implementations before
executing the work associated with this future. If this method returns
False then the work should not be executed.
Returns:
False if the Future was cancelled, True otherwise.
Raises:
RuntimeError: if this method was already called or if set_result()
or set_exception() was called.
"""
with self._condition:
if self._state == CANCELLED:
self._state = CANCELLED_AND_NOTIFIED
for waiter in self._waiters:
waiter.add_cancelled(self)
# self._condition.notify_all() is not necessary because
# self.cancel() triggers a notification.
return False
elif self._state == PENDING:
self._state = RUNNING
return True
else:
LOGGER.critical('Future %s in unexpected state: %s',
id(self.future),
self.future._state)
raise RuntimeError('Future in unexpected state')
def set_result(self, result):
"""Sets the return value of work associated with the future.
Should only be used by Executor implementations and unit tests.
"""
with self._condition:
self._result = result
self._state = FINISHED
for waiter in self._waiters:
waiter.add_result(self)
self._condition.notify_all()
self._invoke_callbacks()
def set_exception(self, exception):
"""Sets the result of the future as being the given exception.
Should only be used by Executor implementations and unit tests.
"""
with self._condition:
self._exception = exception
self._state = FINISHED
for waiter in self._waiters:
waiter.add_exception(self)
self._condition.notify_all()
self._invoke_callbacks()
class Executor(object):
"""This is an abstract base class for concrete asynchronous executors."""
def submit(self, fn, *args, **kwargs):
"""Submits a callable to be executed with the given arguments.
Schedules the callable to be executed as fn(*args, **kwargs) and returns
a Future instance representing the execution of the callable.
Returns:
A Future representing the given call.
"""
raise NotImplementedError()
def map(self, fn, *iterables, **kwargs):
"""Returns a iterator equivalent to map(fn, iter).
Args:
fn: A callable that will take as many arguments as there are
passed iterables.
timeout: The maximum number of seconds to wait. If None, then there
is no limit on the wait time.
Returns:
An iterator equivalent to: map(func, *iterables) but the calls may
be evaluated out-of-order.
Raises:
TimeoutError: If the entire result iterator could not be generated
before the given timeout.
Exception: If fn(*args) raises for any values.
"""
timeout = kwargs.get('timeout')
if timeout is not None:
end_time = timeout + time.time()
fs = [self.submit(fn, *args) for args in zip(*iterables)]
try:
for future in fs:
if timeout is None:
yield future.result()
else:
yield future.result(end_time - time.time())
finally:
for future in fs:
future.cancel()
def shutdown(self, wait=True):
"""Clean-up the resources associated with the Executor.
It is safe to call this method several times. Otherwise, no other
methods can be called after this one.
Args:
wait: If True then shutdown will not return until all running
futures have finished executing and the resources used by the
executor have been reclaimed.
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown(wait=True)
return False

View File

@ -1,101 +0,0 @@
from keyword import iskeyword as _iskeyword
from operator import itemgetter as _itemgetter
import sys as _sys
def namedtuple(typename, field_names):
"""Returns a new subclass of tuple with named fields.
>>> Point = namedtuple('Point', 'x y')
>>> Point.__doc__ # docstring for the new class
'Point(x, y)'
>>> p = Point(11, y=22) # instantiate with positional args or keywords
>>> p[0] + p[1] # indexable like a plain tuple
33
>>> x, y = p # unpack like a regular tuple
>>> x, y
(11, 22)
>>> p.x + p.y # fields also accessable by name
33
>>> d = p._asdict() # convert to a dictionary
>>> d['x']
11
>>> Point(**d) # convert from a dictionary
Point(x=11, y=22)
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
Point(x=100, y=22)
"""
# Parse and validate the field names. Validation serves two purposes,
# generating informative error messages and preventing template injection attacks.
if isinstance(field_names, basestring):
field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas
field_names = tuple(map(str, field_names))
for name in (typename,) + field_names:
if not all(c.isalnum() or c=='_' for c in name):
raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
if _iskeyword(name):
raise ValueError('Type names and field names cannot be a keyword: %r' % name)
if name[0].isdigit():
raise ValueError('Type names and field names cannot start with a number: %r' % name)
seen_names = set()
for name in field_names:
if name.startswith('_'):
raise ValueError('Field names cannot start with an underscore: %r' % name)
if name in seen_names:
raise ValueError('Encountered duplicate field name: %r' % name)
seen_names.add(name)
# Create and fill-in the class template
numfields = len(field_names)
argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
reprtxt = ', '.join('%s=%%r' % name for name in field_names)
dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names))
template = '''class %(typename)s(tuple):
'%(typename)s(%(argtxt)s)' \n
__slots__ = () \n
_fields = %(field_names)r \n
def __new__(_cls, %(argtxt)s):
return _tuple.__new__(_cls, (%(argtxt)s)) \n
@classmethod
def _make(cls, iterable, new=tuple.__new__, len=len):
'Make a new %(typename)s object from a sequence or iterable'
result = new(cls, iterable)
if len(result) != %(numfields)d:
raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
return result \n
def __repr__(self):
return '%(typename)s(%(reprtxt)s)' %% self \n
def _asdict(t):
'Return a new dict which maps field names to their values'
return {%(dicttxt)s} \n
def _replace(_self, **kwds):
'Return a new %(typename)s object replacing specified fields with new values'
result = _self._make(map(kwds.pop, %(field_names)r, _self))
if kwds:
raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
return result \n
def __getnewargs__(self):
return tuple(self) \n\n''' % locals()
for i, name in enumerate(field_names):
template += ' %s = _property(_itemgetter(%d))\n' % (name, i)
# Execute the template string in a temporary namespace and
# support tracing utilities by setting a value for frame.f_globals['__name__']
namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
_property=property, _tuple=tuple)
try:
exec(template, namespace)
except SyntaxError:
e = _sys.exc_info()[1]
raise SyntaxError(e.message + ':\n' + template)
result = namespace[typename]
# For pickling to work, the __module__ variable needs to be set to the frame
# where the named tuple is created. Bypass this step in enviroments where
# sys._getframe is not defined (Jython for example).
if hasattr(_sys, '_getframe'):
result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
return result

View File

@ -1,363 +0,0 @@
# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Implements ProcessPoolExecutor.
The follow diagram and text describe the data-flow through the system:
|======================= In-process =====================|== Out-of-process ==|
+----------+ +----------+ +--------+ +-----------+ +---------+
| | => | Work Ids | => | | => | Call Q | => | |
| | +----------+ | | +-----------+ | |
| | | ... | | | | ... | | |
| | | 6 | | | | 5, call() | | |
| | | 7 | | | | ... | | |
| Process | | ... | | Local | +-----------+ | Process |
| Pool | +----------+ | Worker | | #1..n |
| Executor | | Thread | | |
| | +----------- + | | +-----------+ | |
| | <=> | Work Items | <=> | | <= | Result Q | <= | |
| | +------------+ | | +-----------+ | |
| | | 6: call() | | | | ... | | |
| | | future | | | | 4, result | | |
| | | ... | | | | 3, except | | |
+----------+ +------------+ +--------+ +-----------+ +---------+
Executor.submit() called:
- creates a uniquely numbered _WorkItem and adds it to the "Work Items" dict
- adds the id of the _WorkItem to the "Work Ids" queue
Local worker thread:
- reads work ids from the "Work Ids" queue and looks up the corresponding
WorkItem from the "Work Items" dict: if the work item has been cancelled then
it is simply removed from the dict, otherwise it is repackaged as a
_CallItem and put in the "Call Q". New _CallItems are put in the "Call Q"
until "Call Q" is full. NOTE: the size of the "Call Q" is kept small because
calls placed in the "Call Q" can no longer be cancelled with Future.cancel().
- reads _ResultItems from "Result Q", updates the future stored in the
"Work Items" dict and deletes the dict entry
Process #1..n:
- reads _CallItems from "Call Q", executes the calls, and puts the resulting
_ResultItems in "Request Q"
"""
from __future__ import with_statement
import atexit
import multiprocessing
import threading
import weakref
import sys
from concurrent.futures import _base
try:
import queue
except ImportError:
import Queue as queue
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
# Workers are created as daemon threads and processes. This is done to allow the
# interpreter to exit when there are still idle processes in a
# ProcessPoolExecutor's process pool (i.e. shutdown() was not called). However,
# allowing workers to die with the interpreter has two undesirable properties:
# - The workers would still be running during interpretor shutdown,
# meaning that they would fail in unpredictable ways.
# - The workers could be killed while evaluating a work item, which could
# be bad if the callable being evaluated has external side-effects e.g.
# writing to a file.
#
# To work around this problem, an exit handler is installed which tells the
# workers to exit when their work queues are empty and then waits until the
# threads/processes finish.
_threads_queues = weakref.WeakKeyDictionary()
_shutdown = False
def _python_exit():
global _shutdown
_shutdown = True
items = list(_threads_queues.items())
for t, q in items:
q.put(None)
for t, q in items:
t.join()
# Controls how many more calls than processes will be queued in the call queue.
# A smaller number will mean that processes spend more time idle waiting for
# work while a larger number will make Future.cancel() succeed less frequently
# (Futures in the call queue cannot be cancelled).
EXTRA_QUEUED_CALLS = 1
class _WorkItem(object):
def __init__(self, future, fn, args, kwargs):
self.future = future
self.fn = fn
self.args = args
self.kwargs = kwargs
class _ResultItem(object):
def __init__(self, work_id, exception=None, result=None):
self.work_id = work_id
self.exception = exception
self.result = result
class _CallItem(object):
def __init__(self, work_id, fn, args, kwargs):
self.work_id = work_id
self.fn = fn
self.args = args
self.kwargs = kwargs
def _process_worker(call_queue, result_queue):
"""Evaluates calls from call_queue and places the results in result_queue.
This worker is run in a separate process.
Args:
call_queue: A multiprocessing.Queue of _CallItems that will be read and
evaluated by the worker.
result_queue: A multiprocessing.Queue of _ResultItems that will written
to by the worker.
shutdown: A multiprocessing.Event that will be set as a signal to the
worker that it should exit when call_queue is empty.
"""
while True:
call_item = call_queue.get(block=True)
if call_item is None:
# Wake up queue management thread
result_queue.put(None)
return
try:
r = call_item.fn(*call_item.args, **call_item.kwargs)
except BaseException:
e = sys.exc_info()[1]
result_queue.put(_ResultItem(call_item.work_id,
exception=e))
else:
result_queue.put(_ResultItem(call_item.work_id,
result=r))
def _add_call_item_to_queue(pending_work_items,
work_ids,
call_queue):
"""Fills call_queue with _WorkItems from pending_work_items.
This function never blocks.
Args:
pending_work_items: A dict mapping work ids to _WorkItems e.g.
{5: <_WorkItem...>, 6: <_WorkItem...>, ...}
work_ids: A queue.Queue of work ids e.g. Queue([5, 6, ...]). Work ids
are consumed and the corresponding _WorkItems from
pending_work_items are transformed into _CallItems and put in
call_queue.
call_queue: A multiprocessing.Queue that will be filled with _CallItems
derived from _WorkItems.
"""
while True:
if call_queue.full():
return
try:
work_id = work_ids.get(block=False)
except queue.Empty:
return
else:
work_item = pending_work_items[work_id]
if work_item.future.set_running_or_notify_cancel():
call_queue.put(_CallItem(work_id,
work_item.fn,
work_item.args,
work_item.kwargs),
block=True)
else:
del pending_work_items[work_id]
continue
def _queue_management_worker(executor_reference,
processes,
pending_work_items,
work_ids_queue,
call_queue,
result_queue):
"""Manages the communication between this process and the worker processes.
This function is run in a local thread.
Args:
executor_reference: A weakref.ref to the ProcessPoolExecutor that owns
this thread. Used to determine if the ProcessPoolExecutor has been
garbage collected and that this function can exit.
process: A list of the multiprocessing.Process instances used as
workers.
pending_work_items: A dict mapping work ids to _WorkItems e.g.
{5: <_WorkItem...>, 6: <_WorkItem...>, ...}
work_ids_queue: A queue.Queue of work ids e.g. Queue([5, 6, ...]).
call_queue: A multiprocessing.Queue that will be filled with _CallItems
derived from _WorkItems for processing by the process workers.
result_queue: A multiprocessing.Queue of _ResultItems generated by the
process workers.
"""
nb_shutdown_processes = [0]
def shutdown_one_process():
"""Tell a worker to terminate, which will in turn wake us again"""
call_queue.put(None)
nb_shutdown_processes[0] += 1
while True:
_add_call_item_to_queue(pending_work_items,
work_ids_queue,
call_queue)
result_item = result_queue.get(block=True)
if result_item is not None:
work_item = pending_work_items[result_item.work_id]
del pending_work_items[result_item.work_id]
if result_item.exception:
work_item.future.set_exception(result_item.exception)
else:
work_item.future.set_result(result_item.result)
# Check whether we should start shutting down.
executor = executor_reference()
# No more work items can be added if:
# - The interpreter is shutting down OR
# - The executor that owns this worker has been collected OR
# - The executor that owns this worker has been shutdown.
if _shutdown or executor is None or executor._shutdown_thread:
# Since no new work items can be added, it is safe to shutdown
# this thread if there are no pending work items.
if not pending_work_items:
while nb_shutdown_processes[0] < len(processes):
shutdown_one_process()
# If .join() is not called on the created processes then
# some multiprocessing.Queue methods may deadlock on Mac OS
# X.
for p in processes:
p.join()
call_queue.close()
return
del executor
_system_limits_checked = False
_system_limited = None
def _check_system_limits():
global _system_limits_checked, _system_limited
if _system_limits_checked:
if _system_limited:
raise NotImplementedError(_system_limited)
_system_limits_checked = True
try:
import os
nsems_max = os.sysconf("SC_SEM_NSEMS_MAX")
except (AttributeError, ValueError):
# sysconf not available or setting not available
return
if nsems_max == -1:
# indetermine limit, assume that limit is determined
# by available memory only
return
if nsems_max >= 256:
# minimum number of semaphores available
# according to POSIX
return
_system_limited = "system provides too few semaphores (%d available, 256 necessary)" % nsems_max
raise NotImplementedError(_system_limited)
class ProcessPoolExecutor(_base.Executor):
def __init__(self, max_workers=None):
"""Initializes a new ProcessPoolExecutor instance.
Args:
max_workers: The maximum number of processes that can be used to
execute the given calls. If None or not given then as many
worker processes will be created as the machine has processors.
"""
_check_system_limits()
if max_workers is None:
self._max_workers = multiprocessing.cpu_count()
else:
self._max_workers = max_workers
# Make the call queue slightly larger than the number of processes to
# prevent the worker processes from idling. But don't make it too big
# because futures in the call queue cannot be cancelled.
self._call_queue = multiprocessing.Queue(self._max_workers +
EXTRA_QUEUED_CALLS)
self._result_queue = multiprocessing.Queue()
self._work_ids = queue.Queue()
self._queue_management_thread = None
self._processes = set()
# Shutdown is a two-step process.
self._shutdown_thread = False
self._shutdown_lock = threading.Lock()
self._queue_count = 0
self._pending_work_items = {}
def _start_queue_management_thread(self):
# When the executor gets lost, the weakref callback will wake up
# the queue management thread.
def weakref_cb(_, q=self._result_queue):
q.put(None)
if self._queue_management_thread is None:
self._queue_management_thread = threading.Thread(
target=_queue_management_worker,
args=(weakref.ref(self, weakref_cb),
self._processes,
self._pending_work_items,
self._work_ids,
self._call_queue,
self._result_queue))
self._queue_management_thread.daemon = True
self._queue_management_thread.start()
_threads_queues[self._queue_management_thread] = self._result_queue
def _adjust_process_count(self):
for _ in range(len(self._processes), self._max_workers):
p = multiprocessing.Process(
target=_process_worker,
args=(self._call_queue,
self._result_queue))
p.start()
self._processes.add(p)
def submit(self, fn, *args, **kwargs):
with self._shutdown_lock:
if self._shutdown_thread:
raise RuntimeError('cannot schedule new futures after shutdown')
f = _base.Future()
w = _WorkItem(f, fn, args, kwargs)
self._pending_work_items[self._queue_count] = w
self._work_ids.put(self._queue_count)
self._queue_count += 1
# Wake up queue management thread
self._result_queue.put(None)
self._start_queue_management_thread()
self._adjust_process_count()
return f
submit.__doc__ = _base.Executor.submit.__doc__
def shutdown(self, wait=True):
with self._shutdown_lock:
self._shutdown_thread = True
if self._queue_management_thread:
# Wake up queue management thread
self._result_queue.put(None)
if wait:
self._queue_management_thread.join()
# To reduce the risk of openning too many files, remove references to
# objects that use file descriptors.
self._queue_management_thread = None
self._call_queue = None
self._result_queue = None
self._processes = None
shutdown.__doc__ = _base.Executor.shutdown.__doc__
atexit.register(_python_exit)

View File

@ -1,145 +0,0 @@
# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Implements ThreadPoolExecutor."""
from __future__ import with_statement
import atexit
import threading
import weakref
import sys
from concurrent.futures import _base
try:
import queue
except ImportError:
import Queue as queue
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
# Workers are created as daemon threads. This is done to allow the interpreter
# to exit when there are still idle threads in a ThreadPoolExecutor's thread
# pool (i.e. shutdown() was not called). However, allowing workers to die with
# the interpreter has two undesirable properties:
# - The workers would still be running during interpretor shutdown,
# meaning that they would fail in unpredictable ways.
# - The workers could be killed while evaluating a work item, which could
# be bad if the callable being evaluated has external side-effects e.g.
# writing to a file.
#
# To work around this problem, an exit handler is installed which tells the
# workers to exit when their work queues are empty and then waits until the
# threads finish.
_threads_queues = weakref.WeakKeyDictionary()
_shutdown = False
def _python_exit():
global _shutdown
_shutdown = True
items = list(_threads_queues.items())
for t, q in items:
q.put(None)
for t, q in items:
t.join()
atexit.register(_python_exit)
class _WorkItem(object):
def __init__(self, future, fn, args, kwargs):
self.future = future
self.fn = fn
self.args = args
self.kwargs = kwargs
def run(self):
if not self.future.set_running_or_notify_cancel():
return
try:
result = self.fn(*self.args, **self.kwargs)
except BaseException:
e = sys.exc_info()[1]
self.future.set_exception(e)
else:
self.future.set_result(result)
def _worker(executor_reference, work_queue):
try:
while True:
work_item = work_queue.get(block=True)
if work_item is not None:
work_item.run()
continue
executor = executor_reference()
# Exit if:
# - The interpreter is shutting down OR
# - The executor that owns the worker has been collected OR
# - The executor that owns the worker has been shutdown.
if _shutdown or executor is None or executor._shutdown:
# Notice other workers
work_queue.put(None)
return
del executor
except BaseException:
_base.LOGGER.critical('Exception in worker', exc_info=True)
class ThreadPoolExecutor(_base.Executor):
def __init__(self, max_workers):
"""Initializes a new ThreadPoolExecutor instance.
Args:
max_workers: The maximum number of threads that can be used to
execute the given calls.
"""
self._max_workers = max_workers
self._work_queue = queue.Queue()
self._threads = set()
self._shutdown = False
self._shutdown_lock = threading.Lock()
def submit(self, fn, *args, **kwargs):
with self._shutdown_lock:
if self._shutdown:
raise RuntimeError('cannot schedule new futures after shutdown')
f = _base.Future()
w = _WorkItem(f, fn, args, kwargs)
self._work_queue.put(w)
name = None
if kwargs.has_key('name'):
name = kwargs.pop('name')
self._adjust_thread_count(name)
return f
submit.__doc__ = _base.Executor.submit.__doc__
def _adjust_thread_count(self, name=None):
# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue):
q.put(None)
# TODO(bquinlan): Should avoid creating new threads if there are more
# idle threads than items in the work queue.
if len(self._threads) < self._max_workers:
t = threading.Thread(target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue),)
if name:
t.name = name
t.daemon = True
t.start()
self._threads.add(t)
_threads_queues[t] = self._work_queue
def shutdown(self, wait=True):
with self._shutdown_lock:
self._shutdown = True
self._work_queue.put(None)
if wait:
for t in self._threads:
t.join()
shutdown.__doc__ = _base.Executor.shutdown.__doc__

View File

@ -1,24 +0,0 @@
# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Execute computations asynchronously using threads or processes."""
import warnings
from concurrent.futures import (FIRST_COMPLETED,
FIRST_EXCEPTION,
ALL_COMPLETED,
CancelledError,
TimeoutError,
Future,
Executor,
wait,
as_completed,
ProcessPoolExecutor,
ThreadPoolExecutor)
__author__ = 'Brian Quinlan (brian@sweetapp.com)'
warnings.warn('The futures package has been deprecated. '
'Use the concurrent.futures package instead.',
DeprecationWarning)

View File

@ -1 +0,0 @@
from concurrent.futures import ProcessPoolExecutor

View File

@ -1 +0,0 @@
from concurrent.futures import ThreadPoolExecutor

View File

@ -1,519 +0,0 @@
# -*- coding: utf-8 -*-
'''Common object storage frontend.'''
import os
import zlib
import urllib
try:
import cPickle as pickle
except ImportError:
import pickle
from collections import deque
try:
# Import store and cache entry points if setuptools installed
import pkg_resources
stores = dict((_store.name, _store) for _store in
pkg_resources.iter_entry_points('shove.stores'))
caches = dict((_cache.name, _cache) for _cache in
pkg_resources.iter_entry_points('shove.caches'))
# Pass if nothing loaded
if not stores and not caches:
raise ImportError()
except ImportError:
# Static store backend registry
stores = dict(
bsddb='shove.store.bsdb:BsdStore',
cassandra='shove.store.cassandra:CassandraStore',
dbm='shove.store.dbm:DbmStore',
durus='shove.store.durusdb:DurusStore',
file='shove.store.file:FileStore',
firebird='shove.store.db:DbStore',
ftp='shove.store.ftp:FtpStore',
hdf5='shove.store.hdf5:HDF5Store',
leveldb='shove.store.leveldbstore:LevelDBStore',
memory='shove.store.memory:MemoryStore',
mssql='shove.store.db:DbStore',
mysql='shove.store.db:DbStore',
oracle='shove.store.db:DbStore',
postgres='shove.store.db:DbStore',
redis='shove.store.redisdb:RedisStore',
s3='shove.store.s3:S3Store',
simple='shove.store.simple:SimpleStore',
sqlite='shove.store.db:DbStore',
svn='shove.store.svn:SvnStore',
zodb='shove.store.zodb:ZodbStore',
)
# Static cache backend registry
caches = dict(
bsddb='shove.cache.bsdb:BsdCache',
file='shove.cache.file:FileCache',
filelru='shove.cache.filelru:FileLRUCache',
firebird='shove.cache.db:DbCache',
memcache='shove.cache.memcached:MemCached',
memlru='shove.cache.memlru:MemoryLRUCache',
memory='shove.cache.memory:MemoryCache',
mssql='shove.cache.db:DbCache',
mysql='shove.cache.db:DbCache',
oracle='shove.cache.db:DbCache',
postgres='shove.cache.db:DbCache',
redis='shove.cache.redisdb:RedisCache',
simple='shove.cache.simple:SimpleCache',
simplelru='shove.cache.simplelru:SimpleLRUCache',
sqlite='shove.cache.db:DbCache',
)
def getbackend(uri, engines, **kw):
'''
Loads the right backend based on a URI.
@param uri Instance or name string
@param engines A dictionary of scheme/class pairs
'''
if isinstance(uri, basestring):
mod = engines[uri.split('://', 1)[0]]
# Load module if setuptools not present
if isinstance(mod, basestring):
# Isolate classname from dot path
module, klass = mod.split(':')
# Load module
mod = getattr(__import__(module, '', '', ['']), klass)
# Load appropriate class from setuptools entry point
else:
mod = mod.load()
# Return instance
return mod(uri, **kw)
# No-op for existing instances
return uri
def synchronized(func):
'''
Decorator to lock and unlock a method (Phillip J. Eby).
@param func Method to decorate
'''
def wrapper(self, *__args, **__kw):
self._lock.acquire()
try:
return func(self, *__args, **__kw)
finally:
self._lock.release()
wrapper.__name__ = func.__name__
wrapper.__dict__ = func.__dict__
wrapper.__doc__ = func.__doc__
return wrapper
class Base(object):
'''Base Mapping class.'''
def __init__(self, engine, **kw):
'''
@keyword compress True, False, or an integer compression level (1-9).
'''
self._compress = kw.get('compress', False)
self._protocol = kw.get('protocol', pickle.HIGHEST_PROTOCOL)
def __getitem__(self, key):
raise NotImplementedError()
def __setitem__(self, key, value):
raise NotImplementedError()
def __delitem__(self, key):
raise NotImplementedError()
def __contains__(self, key):
try:
value = self[key]
except KeyError:
return False
return True
def get(self, key, default=None):
'''
Fetch a given key from the mapping. If the key does not exist,
return the default.
@param key Keyword of item in mapping.
@param default Default value (default: None)
'''
try:
return self[key]
except KeyError:
return default
def dumps(self, value):
'''Optionally serializes and compresses an object.'''
# Serialize everything but ASCII strings
value = pickle.dumps(value, protocol=self._protocol)
if self._compress:
level = 9 if self._compress is True else self._compress
value = zlib.compress(value, level)
return value
def loads(self, value):
'''Deserializes and optionally decompresses an object.'''
if self._compress:
try:
value = zlib.decompress(value)
except zlib.error:
pass
value = pickle.loads(value)
return value
class BaseStore(Base):
'''Base Store class (based on UserDict.DictMixin).'''
def __init__(self, engine, **kw):
super(BaseStore, self).__init__(engine, **kw)
self._store = None
def __cmp__(self, other):
if other is None:
return False
if isinstance(other, BaseStore):
return cmp(dict(self.iteritems()), dict(other.iteritems()))
def __del__(self):
# __init__ didn't succeed, so don't bother closing
if not hasattr(self, '_store'):
return
self.close()
def __iter__(self):
for k in self.keys():
yield k
def __len__(self):
return len(self.keys())
def __repr__(self):
return repr(dict(self.iteritems()))
def close(self):
'''Closes internal store and clears object references.'''
try:
self._store.close()
except AttributeError:
pass
self._store = None
def clear(self):
'''Removes all keys and values from a store.'''
for key in self.keys():
del self[key]
def items(self):
'''Returns a list with all key/value pairs in the store.'''
return list(self.iteritems())
def iteritems(self):
'''Lazily returns all key/value pairs in a store.'''
for k in self:
yield (k, self[k])
def iterkeys(self):
'''Lazy returns all keys in a store.'''
return self.__iter__()
def itervalues(self):
'''Lazily returns all values in a store.'''
for _, v in self.iteritems():
yield v
def keys(self):
'''Returns a list with all keys in a store.'''
raise NotImplementedError()
def pop(self, key, *args):
'''
Removes and returns a value from a store.
@param args Default to return if key not present.
'''
if len(args) > 1:
raise TypeError('pop expected at most 2 arguments, got ' + repr(
1 + len(args))
)
try:
value = self[key]
# Return default if key not in store
except KeyError:
if args:
return args[0]
del self[key]
return value
def popitem(self):
'''Removes and returns a key, value pair from a store.'''
try:
k, v = self.iteritems().next()
except StopIteration:
raise KeyError('Store is empty.')
del self[k]
return (k, v)
def setdefault(self, key, default=None):
'''
Returns the value corresponding to an existing key or sets the
to key to the default and returns the default.
@param default Default value (default: None)
'''
try:
return self[key]
except KeyError:
self[key] = default
return default
def update(self, other=None, **kw):
'''
Adds to or overwrites the values in this store with values from
another store.
other Another store
kw Additional keys and values to store
'''
if other is None:
pass
elif hasattr(other, 'iteritems'):
for k, v in other.iteritems():
self[k] = v
elif hasattr(other, 'keys'):
for k in other.keys():
self[k] = other[k]
else:
for k, v in other:
self[k] = v
if kw:
self.update(kw)
def values(self):
'''Returns a list with all values in a store.'''
return list(v for _, v in self.iteritems())
class Shove(BaseStore):
'''Common object frontend class.'''
def __init__(self, store='simple://', cache='simple://', **kw):
super(Shove, self).__init__(store, **kw)
# Load store
self._store = getbackend(store, stores, **kw)
# Load cache
self._cache = getbackend(cache, caches, **kw)
# Buffer for lazy writing and setting for syncing frequency
self._buffer, self._sync = dict(), kw.get('sync', 2)
def __getitem__(self, key):
'''Gets a item from shove.'''
try:
return self._cache[key]
except KeyError:
# Synchronize cache and store
self.sync()
value = self._store[key]
self._cache[key] = value
return value
def __setitem__(self, key, value):
'''Sets an item in shove.'''
self._cache[key] = self._buffer[key] = value
# When the buffer reaches self._limit, writes the buffer to the store
if len(self._buffer) >= self._sync:
self.sync()
def __delitem__(self, key):
'''Deletes an item from shove.'''
try:
del self._cache[key]
except KeyError:
pass
self.sync()
del self._store[key]
def keys(self):
'''Returns a list of keys in shove.'''
self.sync()
return self._store.keys()
def sync(self):
'''Writes buffer to store.'''
for k, v in self._buffer.iteritems():
self._store[k] = v
self._buffer.clear()
def close(self):
'''Finalizes and closes shove.'''
# If close has been called, pass
if self._store is not None:
try:
self.sync()
except AttributeError:
pass
self._store.close()
self._store = self._cache = self._buffer = None
class FileBase(Base):
'''Base class for file based storage.'''
def __init__(self, engine, **kw):
super(FileBase, self).__init__(engine, **kw)
if engine.startswith('file://'):
engine = urllib.url2pathname(engine.split('://')[1])
self._dir = engine
# Create directory
if not os.path.exists(self._dir):
self._createdir()
def __getitem__(self, key):
# (per Larry Meyn)
try:
item = open(self._key_to_file(key), 'rb')
data = item.read()
item.close()
return self.loads(data)
except:
raise KeyError(key)
def __setitem__(self, key, value):
# (per Larry Meyn)
try:
item = open(self._key_to_file(key), 'wb')
item.write(self.dumps(value))
item.close()
except (IOError, OSError):
raise KeyError(key)
def __delitem__(self, key):
try:
os.remove(self._key_to_file(key))
except (IOError, OSError):
raise KeyError(key)
def __contains__(self, key):
return os.path.exists(self._key_to_file(key))
def __len__(self):
return len(os.listdir(self._dir))
def _createdir(self):
'''Creates the store directory.'''
try:
os.makedirs(self._dir)
except OSError:
raise EnvironmentError(
'Cache directory "%s" does not exist and ' \
'could not be created' % self._dir
)
def _key_to_file(self, key):
'''Gives the filesystem path for a key.'''
return os.path.join(self._dir, urllib.quote_plus(key))
def keys(self):
'''Returns a list of keys in the store.'''
return [urllib.unquote_plus(name) for name in os.listdir(self._dir)]
class SimpleBase(Base):
'''Single-process in-memory store base class.'''
def __init__(self, engine, **kw):
super(SimpleBase, self).__init__(engine, **kw)
self._store = dict()
def __getitem__(self, key):
try:
return self._store[key]
except:
raise KeyError(key)
def __setitem__(self, key, value):
self._store[key] = value
def __delitem__(self, key):
try:
del self._store[key]
except:
raise KeyError(key)
def __len__(self):
return len(self._store)
def keys(self):
'''Returns a list of keys in the store.'''
return self._store.keys()
class LRUBase(SimpleBase):
def __init__(self, engine, **kw):
super(LRUBase, self).__init__(engine, **kw)
self._max_entries = kw.get('max_entries', 300)
self._hits = 0
self._misses = 0
self._queue = deque()
self._refcount = dict()
def __getitem__(self, key):
try:
value = super(LRUBase, self).__getitem__(key)
self._hits += 1
except KeyError:
self._misses += 1
raise
self._housekeep(key)
return value
def __setitem__(self, key, value):
super(LRUBase, self).__setitem__(key, value)
self._housekeep(key)
if len(self._store) > self._max_entries:
while len(self._store) > self._max_entries:
k = self._queue.popleft()
self._refcount[k] -= 1
if not self._refcount[k]:
super(LRUBase, self).__delitem__(k)
del self._refcount[k]
def _housekeep(self, key):
self._queue.append(key)
self._refcount[key] = self._refcount.get(key, 0) + 1
if len(self._queue) > self._max_entries * 4:
self._purge_queue()
def _purge_queue(self):
for i in [None] * len(self._queue):
k = self._queue.popleft()
if self._refcount[k] == 1:
self._queue.append(k)
else:
self._refcount[k] -= 1
class DbBase(Base):
'''Database common base class.'''
def __init__(self, engine, **kw):
super(DbBase, self).__init__(engine, **kw)
def __delitem__(self, key):
self._store.delete(self._store.c.key == key).execute()
def __len__(self):
return self._store.count().execute().fetchone()[0]
__all__ = ['Shove']

View File

@ -1 +0,0 @@
# -*- coding: utf-8 -*-

117
lib/shove/cache/db.py vendored
View File

@ -1,117 +0,0 @@
# -*- coding: utf-8 -*-
'''
Database object cache.
The shove psuedo-URL used for database object caches is the format used by
SQLAlchemy:
<driver>://<username>:<password>@<host>:<port>/<database>
<driver> is the database engine. The engines currently supported SQLAlchemy are
sqlite, mysql, postgres, oracle, mssql, and firebird.
<username> is the database account user name
<password> is the database accound password
<host> is the database location
<port> is the database port
<database> is the name of the specific database
For more information on specific databases see:
http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported
'''
import time
import random
from datetime import datetime
try:
from sqlalchemy import (
MetaData, Table, Column, String, Binary, DateTime, select, update,
insert, delete,
)
from shove import DbBase
except ImportError:
raise ImportError('Requires SQLAlchemy >= 0.4')
__all__ = ['DbCache']
class DbCache(DbBase):
'''database cache backend'''
def __init__(self, engine, **kw):
super(DbCache, self).__init__(engine, **kw)
# Get table name
tablename = kw.get('tablename', 'cache')
# Bind metadata
self._metadata = MetaData(engine)
# Make cache table
self._store = Table(tablename, self._metadata,
Column('key', String(60), primary_key=True, nullable=False),
Column('value', Binary, nullable=False),
Column('expires', DateTime, nullable=False),
)
# Create cache table if it does not exist
if not self._store.exists():
self._store.create()
# Set maximum entries
self._max_entries = kw.get('max_entries', 300)
# Maximum number of entries to cull per call if cache is full
self._maxcull = kw.get('maxcull', 10)
# Set timeout
self.timeout = kw.get('timeout', 300)
def __getitem__(self, key):
row = select(
[self._store.c.value, self._store.c.expires],
self._store.c.key == key
).execute().fetchone()
if row is not None:
# Remove if item expired
if row.expires < datetime.now().replace(microsecond=0):
del self[key]
raise KeyError(key)
return self.loads(str(row.value))
raise KeyError(key)
def __setitem__(self, key, value):
timeout, value, cache = self.timeout, self.dumps(value), self._store
# Cull if too many items
if len(self) >= self._max_entries:
self._cull()
# Generate expiration time
expires = datetime.fromtimestamp(
time.time() + timeout
).replace(microsecond=0)
# Update database if key already present
if key in self:
update(
cache,
cache.c.key == key,
dict(value=value, expires=expires),
).execute()
# Insert new key if key not present
else:
insert(
cache, dict(key=key, value=value, expires=expires)
).execute()
def _cull(self):
'''Remove items in cache to make more room.'''
cache, maxcull = self._store, self._maxcull
# Remove items that have timed out
now = datetime.now().replace(microsecond=0)
delete(cache, cache.c.expires < now).execute()
# Remove any items over the maximum allowed number in the cache
if len(self) >= self._max_entries:
# Upper limit for key query
ul = maxcull * 2
# Get list of keys
keys = [
i[0] for i in select(
[cache.c.key], limit=ul
).execute().fetchall()
]
# Get some keys at random
delkeys = list(random.choice(keys) for i in xrange(maxcull))
delete(cache, cache.c.key.in_(delkeys)).execute()

View File

@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
'''
File-based cache
shove's psuedo-URL for file caches follows the form:
file://<path>
Where the path is a URL path to a directory on a local filesystem.
Alternatively, a native pathname to the directory can be passed as the 'engine'
argument.
'''
import time
from shove import FileBase
from shove.cache.simple import SimpleCache
class FileCache(FileBase, SimpleCache):
'''File-based cache backend'''
def __init__(self, engine, **kw):
super(FileCache, self).__init__(engine, **kw)
def __getitem__(self, key):
try:
exp, value = super(FileCache, self).__getitem__(key)
# Remove item if time has expired.
if exp < time.time():
del self[key]
raise KeyError(key)
return value
except:
raise KeyError(key)
def __setitem__(self, key, value):
if len(self) >= self._max_entries:
self._cull()
super(FileCache, self).__setitem__(
key, (time.time() + self.timeout, value)
)
__all__ = ['FileCache']

View File

@ -1,23 +0,0 @@
# -*- coding: utf-8 -*-
'''
File-based LRU cache
shove's psuedo-URL for file caches follows the form:
file://<path>
Where the path is a URL path to a directory on a local filesystem.
Alternatively, a native pathname to the directory can be passed as the 'engine'
argument.
'''
from shove import FileBase
from shove.cache.simplelru import SimpleLRUCache
class FileCache(FileBase, SimpleLRUCache):
'''File-based LRU cache backend'''
__all__ = ['FileCache']

View File

@ -1,43 +0,0 @@
# -*- coding: utf-8 -*-
'''
"memcached" cache.
The shove psuedo-URL for a memcache cache is:
memcache://<memcache_server>
'''
try:
import memcache
except ImportError:
raise ImportError("Memcache cache requires the 'memcache' library")
from shove import Base
class MemCached(Base):
'''Memcached cache backend'''
def __init__(self, engine, **kw):
super(MemCached, self).__init__(engine, **kw)
if engine.startswith('memcache://'):
engine = engine.split('://')[1]
self._store = memcache.Client(engine.split(';'))
# Set timeout
self.timeout = kw.get('timeout', 300)
def __getitem__(self, key):
value = self._store.get(key)
if value is None:
raise KeyError(key)
return self.loads(value)
def __setitem__(self, key, value):
self._store.set(key, self.dumps(value), self.timeout)
def __delitem__(self, key):
self._store.delete(key)
__all__ = ['MemCached']

View File

@ -1,38 +0,0 @@
# -*- coding: utf-8 -*-
'''
Thread-safe in-memory cache using LRU.
The shove psuedo-URL for a memory cache is:
memlru://
'''
import copy
import threading
from shove import synchronized
from shove.cache.simplelru import SimpleLRUCache
class MemoryLRUCache(SimpleLRUCache):
'''Thread-safe in-memory cache backend using LRU.'''
def __init__(self, engine, **kw):
super(MemoryLRUCache, self).__init__(engine, **kw)
self._lock = threading.Condition()
@synchronized
def __setitem__(self, key, value):
super(MemoryLRUCache, self).__setitem__(key, value)
@synchronized
def __getitem__(self, key):
return copy.deepcopy(super(MemoryLRUCache, self).__getitem__(key))
@synchronized
def __delitem__(self, key):
super(MemoryLRUCache, self).__delitem__(key)
__all__ = ['MemoryLRUCache']

View File

@ -1,38 +0,0 @@
# -*- coding: utf-8 -*-
'''
Thread-safe in-memory cache.
The shove psuedo-URL for a memory cache is:
memory://
'''
import copy
import threading
from shove import synchronized
from shove.cache.simple import SimpleCache
class MemoryCache(SimpleCache):
'''Thread-safe in-memory cache backend.'''
def __init__(self, engine, **kw):
super(MemoryCache, self).__init__(engine, **kw)
self._lock = threading.Condition()
@synchronized
def __setitem__(self, key, value):
super(MemoryCache, self).__setitem__(key, value)
@synchronized
def __getitem__(self, key):
return copy.deepcopy(super(MemoryCache, self).__getitem__(key))
@synchronized
def __delitem__(self, key):
super(MemoryCache, self).__delitem__(key)
__all__ = ['MemoryCache']

View File

@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
'''
Redis-based object cache
The shove psuedo-URL for a redis cache is:
redis://<host>:<port>/<db>
'''
import urlparse
try:
import redis
except ImportError:
raise ImportError('This store requires the redis library')
from shove import Base
class RedisCache(Base):
'''Redis cache backend'''
init = 'redis://'
def __init__(self, engine, **kw):
super(RedisCache, self).__init__(engine, **kw)
spliturl = urlparse.urlsplit(engine)
host, port = spliturl[1].split(':')
db = spliturl[2].replace('/', '')
self._store = redis.Redis(host, int(port), db)
# Set timeout
self.timeout = kw.get('timeout', 300)
def __getitem__(self, key):
return self.loads(self._store[key])
def __setitem__(self, key, value):
self._store.setex(key, self.dumps(value), self.timeout)
def __delitem__(self, key):
self._store.delete(key)
__all__ = ['RedisCache']

View File

@ -1,68 +0,0 @@
# -*- coding: utf-8 -*-
'''
Single-process in-memory cache.
The shove psuedo-URL for a simple cache is:
simple://
'''
import time
import random
from shove import SimpleBase
class SimpleCache(SimpleBase):
'''Single-process in-memory cache.'''
def __init__(self, engine, **kw):
super(SimpleCache, self).__init__(engine, **kw)
# Get random seed
random.seed()
# Set maximum number of items to cull if over max
self._maxcull = kw.get('maxcull', 10)
# Set max entries
self._max_entries = kw.get('max_entries', 300)
# Set timeout
self.timeout = kw.get('timeout', 300)
def __getitem__(self, key):
exp, value = super(SimpleCache, self).__getitem__(key)
# Delete if item timed out.
if exp < time.time():
super(SimpleCache, self).__delitem__(key)
raise KeyError(key)
return value
def __setitem__(self, key, value):
# Cull values if over max # of entries
if len(self) >= self._max_entries:
self._cull()
# Set expiration time and value
exp = time.time() + self.timeout
super(SimpleCache, self).__setitem__(key, (exp, value))
def _cull(self):
'''Remove items in cache to make room.'''
num, maxcull = 0, self._maxcull
# Cull number of items allowed (set by self._maxcull)
for key in self.keys():
# Remove only maximum # of items allowed by maxcull
if num <= maxcull:
# Remove items if expired
try:
self[key]
except KeyError:
num += 1
else:
break
# Remove any additional items up to max # of items allowed by maxcull
while len(self) >= self._max_entries and num <= maxcull:
# Cull remainder of allowed quota at random
del self[random.choice(self.keys())]
num += 1
__all__ = ['SimpleCache']

View File

@ -1,18 +0,0 @@
# -*- coding: utf-8 -*-
'''
Single-process in-memory LRU cache.
The shove psuedo-URL for a simple cache is:
simplelru://
'''
from shove import LRUBase
class SimpleLRUCache(LRUBase):
'''In-memory cache that purges based on least recently used item.'''
__all__ = ['SimpleLRUCache']

View File

@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
from urllib import url2pathname
from shove.store.simple import SimpleStore
class ClientStore(SimpleStore):
'''Base class for stores where updates have to be committed.'''
def __init__(self, engine, **kw):
super(ClientStore, self).__init__(engine, **kw)
if engine.startswith(self.init):
self._engine = url2pathname(engine.split('://')[1])
def __getitem__(self, key):
return self.loads(super(ClientStore, self).__getitem__(key))
def __setitem__(self, key, value):
super(ClientStore, self).__setitem__(key, self.dumps(value))
class SyncStore(ClientStore):
'''Base class for stores where updates have to be committed.'''
def __getitem__(self, key):
return self.loads(super(SyncStore, self).__getitem__(key))
def __setitem__(self, key, value):
super(SyncStore, self).__setitem__(key, value)
try:
self.sync()
except AttributeError:
pass
def __delitem__(self, key):
super(SyncStore, self).__delitem__(key)
try:
self.sync()
except AttributeError:
pass
__all__ = [
'bsdb', 'db', 'dbm', 'durusdb', 'file', 'ftp', 'memory', 's3', 'simple',
'svn', 'zodb', 'redisdb', 'hdf5db', 'leveldbstore', 'cassandra',
]

View File

@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
'''
Berkeley Source Database Store.
shove's psuedo-URL for BSDDB stores follows the form:
bsddb://<path>
Where the path is a URL path to a Berkeley database. Alternatively, the native
pathname to a Berkeley database can be passed as the 'engine' parameter.
'''
try:
import bsddb
except ImportError:
raise ImportError('requires bsddb library')
import threading
from shove import synchronized
from shove.store import SyncStore
class BsdStore(SyncStore):
'''Class for Berkeley Source Database Store.'''
init = 'bsddb://'
def __init__(self, engine, **kw):
super(BsdStore, self).__init__(engine, **kw)
self._store = bsddb.hashopen(self._engine)
self._lock = threading.Condition()
self.sync = self._store.sync
@synchronized
def __getitem__(self, key):
return super(BsdStore, self).__getitem__(key)
@synchronized
def __setitem__(self, key, value):
super(BsdStore, self).__setitem__(key, value)
@synchronized
def __delitem__(self, key):
super(BsdStore, self).__delitem__(key)
__all__ = ['BsdStore']

View File

@ -1,72 +0,0 @@
# -*- coding: utf-8 -*-
'''
Cassandra-based object store
The shove psuedo-URL for a cassandra-based store is:
cassandra://<host>:<port>/<keyspace>/<columnFamily>
'''
import urlparse
try:
import pycassa
except ImportError:
raise ImportError('This store requires the pycassa library')
from shove import BaseStore
class CassandraStore(BaseStore):
'''Cassandra based store'''
init = 'cassandra://'
def __init__(self, engine, **kw):
super(CassandraStore, self).__init__(engine, **kw)
spliturl = urlparse.urlsplit(engine)
_, keyspace, column_family = spliturl[2].split('/')
try:
self._pool = pycassa.connect(keyspace, [spliturl[1]])
self._store = pycassa.ColumnFamily(self._pool, column_family)
except pycassa.InvalidRequestException:
from pycassa.system_manager import SystemManager
system_manager = SystemManager(spliturl[1])
system_manager.create_keyspace(
keyspace,
pycassa.system_manager.SIMPLE_STRATEGY,
{'replication_factor': str(kw.get('replication', 1))}
)
system_manager.create_column_family(keyspace, column_family)
self._pool = pycassa.connect(keyspace, [spliturl[1]])
self._store = pycassa.ColumnFamily(self._pool, column_family)
def __getitem__(self, key):
try:
item = self._store.get(key).get(key)
if item is not None:
return self.loads(item)
raise KeyError(key)
except pycassa.NotFoundException:
raise KeyError(key)
def __setitem__(self, key, value):
self._store.insert(key, dict(key=self.dumps(value)))
def __delitem__(self, key):
# beware eventual consistency
try:
self._store.remove(key)
except pycassa.NotFoundException:
raise KeyError(key)
def clear(self):
# beware eventual consistency
self._store.truncate()
def keys(self):
return list(i[0] for i in self._store.get_range())
__all__ = ['CassandraStore']

View File

@ -1,73 +0,0 @@
# -*- coding: utf-8 -*-
'''
Database object store.
The shove psuedo-URL used for database object stores is the format used by
SQLAlchemy:
<driver>://<username>:<password>@<host>:<port>/<database>
<driver> is the database engine. The engines currently supported SQLAlchemy are
sqlite, mysql, postgres, oracle, mssql, and firebird.
<username> is the database account user name
<password> is the database accound password
<host> is the database location
<port> is the database port
<database> is the name of the specific database
For more information on specific databases see:
http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported
'''
try:
from sqlalchemy import MetaData, Table, Column, String, Binary, select
from shove import BaseStore, DbBase
except ImportError, e:
raise ImportError('Error: ' + e + ' Requires SQLAlchemy >= 0.4')
class DbStore(BaseStore, DbBase):
'''Database cache backend.'''
def __init__(self, engine, **kw):
super(DbStore, self).__init__(engine, **kw)
# Get tablename
tablename = kw.get('tablename', 'store')
# Bind metadata
self._metadata = MetaData(engine)
# Make store table
self._store = Table(tablename, self._metadata,
Column('key', String(255), primary_key=True, nullable=False),
Column('value', Binary, nullable=False),
)
# Create store table if it does not exist
if not self._store.exists():
self._store.create()
def __getitem__(self, key):
row = select(
[self._store.c.value], self._store.c.key == key,
).execute().fetchone()
if row is not None:
return self.loads(str(row.value))
raise KeyError(key)
def __setitem__(self, k, v):
v, store = self.dumps(v), self._store
# Update database if key already present
if k in self:
store.update(store.c.key == k).execute(value=v)
# Insert new key if key not present
else:
store.insert().execute(key=k, value=v)
def keys(self):
'''Returns a list of keys in the store.'''
return list(i[0] for i in select(
[self._store.c.key]
).execute().fetchall())
__all__ = ['DbStore']

View File

@ -1,33 +0,0 @@
# -*- coding: utf-8 -*-
'''
DBM Database Store.
shove's psuedo-URL for DBM stores follows the form:
dbm://<path>
Where <path> is a URL path to a DBM database. Alternatively, the native
pathname to a DBM database can be passed as the 'engine' parameter.
'''
import anydbm
from shove.store import SyncStore
class DbmStore(SyncStore):
'''Class for variants of the DBM database.'''
init = 'dbm://'
def __init__(self, engine, **kw):
super(DbmStore, self).__init__(engine, **kw)
self._store = anydbm.open(self._engine, 'c')
try:
self.sync = self._store.sync
except AttributeError:
pass
__all__ = ['DbmStore']

View File

@ -1,43 +0,0 @@
# -*- coding: utf-8 -*-
'''
Durus object database frontend.
shove's psuedo-URL for Durus stores follows the form:
durus://<path>
Where the path is a URL path to a durus FileStorage database. Alternatively, a
native pathname to a durus database can be passed as the 'engine' parameter.
'''
try:
from durus.connection import Connection
from durus.file_storage import FileStorage
except ImportError:
raise ImportError('Requires Durus library')
from shove.store import SyncStore
class DurusStore(SyncStore):
'''Class for Durus object database frontend.'''
init = 'durus://'
def __init__(self, engine, **kw):
super(DurusStore, self).__init__(engine, **kw)
self._db = FileStorage(self._engine)
self._connection = Connection(self._db)
self.sync = self._connection.commit
self._store = self._connection.get_root()
def close(self):
'''Closes all open storage and connections.'''
self.sync()
self._db.close()
super(DurusStore, self).close()
__all__ = ['DurusStore']

View File

@ -1,25 +0,0 @@
# -*- coding: utf-8 -*-
'''
Filesystem-based object store
shove's psuedo-URL for filesystem-based stores follows the form:
file://<path>
Where the path is a URL path to a directory on a local filesystem.
Alternatively, a native pathname to the directory can be passed as the 'engine'
argument.
'''
from shove import BaseStore, FileBase
class FileStore(FileBase, BaseStore):
'''File-based store.'''
def __init__(self, engine, **kw):
super(FileStore, self).__init__(engine, **kw)
__all__ = ['FileStore']

View File

@ -1,88 +0,0 @@
# -*- coding: utf-8 -*-
'''
FTP-accessed stores
shove's URL for FTP accessed stores follows the standard form for FTP URLs
defined in RFC-1738:
ftp://<user>:<password>@<host>:<port>/<url-path>
'''
import urlparse
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
from ftplib import FTP, error_perm
from shove import BaseStore
class FtpStore(BaseStore):
def __init__(self, engine, **kw):
super(FtpStore, self).__init__(engine, **kw)
user = kw.get('user', 'anonymous')
password = kw.get('password', '')
spliturl = urlparse.urlsplit(engine)
# Set URL, path, and strip 'ftp://' off
base, path = spliturl[1], spliturl[2] + '/'
if '@' in base:
auth, base = base.split('@')
user, password = auth.split(':')
self._store = FTP(base, user, password)
# Change to remote path if it exits
try:
self._store.cwd(path)
except error_perm:
self._makedir(path)
self._base, self._user, self._password = base, user, password
self._updated, self ._keys = True, None
def __getitem__(self, key):
try:
local = StringIO()
# Download item
self._store.retrbinary('RETR %s' % key, local.write)
self._updated = False
return self.loads(local.getvalue())
except:
raise KeyError(key)
def __setitem__(self, key, value):
local = StringIO(self.dumps(value))
self._store.storbinary('STOR %s' % key, local)
self._updated = True
def __delitem__(self, key):
try:
self._store.delete(key)
self._updated = True
except:
raise KeyError(key)
def _makedir(self, path):
'''Makes remote paths on an FTP server.'''
paths = list(reversed([i for i in path.split('/') if i != '']))
while paths:
tpath = paths.pop()
self._store.mkd(tpath)
self._store.cwd(tpath)
def keys(self):
'''Returns a list of keys in a store.'''
if self._updated or self._keys is None:
rlist, nlist = list(), list()
# Remote directory listing
self._store.retrlines('LIST -a', rlist.append)
for rlisting in rlist:
# Split remote file based on whitespace
rfile = rlisting.split()
# Append tuple of remote item type & name
if rfile[-1] not in ('.', '..') and rfile[0].startswith('-'):
nlist.append(rfile[-1])
self._keys = nlist
return self._keys
__all__ = ['FtpStore']

View File

@ -1,34 +0,0 @@
# -*- coding: utf-8 -*-
'''
HDF5 Database Store.
shove's psuedo-URL for HDF5 stores follows the form:
hdf5://<path>/<group>
Where <path> is a URL path to a HDF5 database. Alternatively, the native
pathname to a HDF5 database can be passed as the 'engine' parameter.
<group> is the name of the database.
'''
try:
import h5py
except ImportError:
raise ImportError('This store requires h5py library')
from shove.store import ClientStore
class HDF5Store(ClientStore):
'''LevelDB based store'''
init = 'hdf5://'
def __init__(self, engine, **kw):
super(HDF5Store, self).__init__(engine, **kw)
engine, group = self._engine.rsplit('/')
self._store = h5py.File(engine).require_group(group).attrs
__all__ = ['HDF5Store']

View File

@ -1,47 +0,0 @@
# -*- coding: utf-8 -*-
'''
LevelDB Database Store.
shove's psuedo-URL for LevelDB stores follows the form:
leveldb://<path>
Where <path> is a URL path to a LevelDB database. Alternatively, the native
pathname to a LevelDB database can be passed as the 'engine' parameter.
'''
try:
import leveldb
except ImportError:
raise ImportError('This store requires py-leveldb library')
from shove.store import ClientStore
class LevelDBStore(ClientStore):
'''LevelDB based store'''
init = 'leveldb://'
def __init__(self, engine, **kw):
super(LevelDBStore, self).__init__(engine, **kw)
self._store = leveldb.LevelDB(self._engine)
def __getitem__(self, key):
item = self.loads(self._store.Get(key))
if item is not None:
return item
raise KeyError(key)
def __setitem__(self, key, value):
self._store.Put(key, self.dumps(value))
def __delitem__(self, key):
self._store.Delete(key)
def keys(self):
return list(k for k in self._store.RangeIter(include_value=False))
__all__ = ['LevelDBStore']

View File

@ -1,38 +0,0 @@
# -*- coding: utf-8 -*-
'''
Thread-safe in-memory store.
The shove psuedo-URL for a memory store is:
memory://
'''
import copy
import threading
from shove import synchronized
from shove.store.simple import SimpleStore
class MemoryStore(SimpleStore):
'''Thread-safe in-memory store.'''
def __init__(self, engine, **kw):
super(MemoryStore, self).__init__(engine, **kw)
self._lock = threading.Condition()
@synchronized
def __getitem__(self, key):
return copy.deepcopy(super(MemoryStore, self).__getitem__(key))
@synchronized
def __setitem__(self, key, value):
super(MemoryStore, self).__setitem__(key, value)
@synchronized
def __delitem__(self, key):
super(MemoryStore, self).__delitem__(key)
__all__ = ['MemoryStore']

View File

@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
'''
Redis-based object store
The shove psuedo-URL for a redis-based store is:
redis://<host>:<port>/<db>
'''
import urlparse
try:
import redis
except ImportError:
raise ImportError('This store requires the redis library')
from shove.store import ClientStore
class RedisStore(ClientStore):
'''Redis based store'''
init = 'redis://'
def __init__(self, engine, **kw):
super(RedisStore, self).__init__(engine, **kw)
spliturl = urlparse.urlsplit(engine)
host, port = spliturl[1].split(':')
db = spliturl[2].replace('/', '')
self._store = redis.Redis(host, int(port), db)
def __contains__(self, key):
return self._store.exists(key)
def clear(self):
self._store.flushdb()
def keys(self):
return self._store.keys()
def setdefault(self, key, default=None):
return self._store.getset(key, default)
def update(self, other=None, **kw):
args = kw if other is not None else other
self._store.mset(args)
__all__ = ['RedisStore']

View File

@ -1,91 +0,0 @@
# -*- coding: utf-8 -*-
'''
S3-accessed stores
shove's psuedo-URL for stores found on Amazon.com's S3 web service follows this
form:
s3://<s3_key>:<s3_secret>@<bucket>
<s3_key> is the Access Key issued by Amazon
<s3_secret> is the Secret Access Key issued by Amazon
<bucket> is the name of the bucket accessed through the S3 service
'''
try:
from boto.s3.connection import S3Connection
from boto.s3.key import Key
except ImportError:
raise ImportError('Requires boto library')
from shove import BaseStore
class S3Store(BaseStore):
def __init__(self, engine=None, **kw):
super(S3Store, self).__init__(engine, **kw)
# key = Access Key, secret=Secret Access Key, bucket=bucket name
key, secret, bucket = kw.get('key'), kw.get('secret'), kw.get('bucket')
if engine is not None:
auth, bucket = engine.split('://')[1].split('@')
key, secret = auth.split(':')
# kw 'secure' = (True or False, use HTTPS)
self._conn = S3Connection(key, secret, kw.get('secure', False))
buckets = self._conn.get_all_buckets()
# Use bucket if it exists
for b in buckets:
if b.name == bucket:
self._store = b
break
# Create bucket if it doesn't exist
else:
self._store = self._conn.create_bucket(bucket)
# Set bucket permission ('private', 'public-read',
# 'public-read-write', 'authenticated-read'
self._store.set_acl(kw.get('acl', 'private'))
# Updated flag used for avoiding network calls
self._updated, self._keys = True, None
def __getitem__(self, key):
rkey = self._store.lookup(key)
if rkey is None:
raise KeyError(key)
# Fetch string
value = self.loads(rkey.get_contents_as_string())
# Flag that the store has not been updated
self._updated = False
return value
def __setitem__(self, key, value):
rkey = Key(self._store)
rkey.key = key
rkey.set_contents_from_string(self.dumps(value))
# Flag that the store has been updated
self._updated = True
def __delitem__(self, key):
try:
self._store.delete_key(key)
# Flag that the store has been updated
self._updated = True
except:
raise KeyError(key)
def keys(self):
'''Returns a list of keys in the store.'''
return list(i[0] for i in self.items())
def items(self):
'''Returns a list of items from the store.'''
if self._updated or self._keys is None:
self._keys = self._store.get_all_keys()
return list((str(k.key), k) for k in self._keys)
def iteritems(self):
'''Lazily returns items from the store.'''
for k in self.items():
yield (k.key, k)
__all__ = ['S3Store']

View File

@ -1,21 +0,0 @@
# -*- coding: utf-8 -*-
'''
Single-process in-memory store.
The shove psuedo-URL for a simple store is:
simple://
'''
from shove import BaseStore, SimpleBase
class SimpleStore(SimpleBase, BaseStore):
'''Single-process in-memory store.'''
def __init__(self, engine, **kw):
super(SimpleStore, self).__init__(engine, **kw)
__all__ = ['SimpleStore']

View File

@ -1,110 +0,0 @@
# -*- coding: utf-8 -*-
'''
subversion managed store.
The shove psuedo-URL used for a subversion store that is password protected is:
svn:<username><password>:<path>?url=<url>
or for non-password protected repositories:
svn://<path>?url=<url>
<path> is the local repository copy
<url> is the URL of the subversion repository
'''
import os
import urllib
import threading
try:
import pysvn
except ImportError:
raise ImportError('Requires Python Subversion library')
from shove import BaseStore, synchronized
class SvnStore(BaseStore):
'''Class for subversion store.'''
def __init__(self, engine=None, **kw):
super(SvnStore, self).__init__(engine, **kw)
# Get path, url from keywords if used
path, url = kw.get('path'), kw.get('url')
# Get username. password from keywords if used
user, password = kw.get('user'), kw.get('password')
# Process psuedo URL if used
if engine is not None:
path, query = engine.split('n://')[1].split('?')
url = query.split('=')[1]
# Check for username, password
if '@' in path:
auth, path = path.split('@')
user, password = auth.split(':')
path = urllib.url2pathname(path)
# Create subversion client
self._client = pysvn.Client()
# Assign username, password
if user is not None:
self._client.set_username(user)
if password is not None:
self._client.set_password(password)
# Verify that store exists in repository
try:
self._client.info2(url)
# Create store in repository if it doesn't exist
except pysvn.ClientError:
self._client.mkdir(url, 'Adding directory')
# Verify that local copy exists
try:
if self._client.info(path) is None:
self._client.checkout(url, path)
# Check it out if it doesn't exist
except pysvn.ClientError:
self._client.checkout(url, path)
self._path, self._url = path, url
# Lock
self._lock = threading.Condition()
@synchronized
def __getitem__(self, key):
try:
return self.loads(self._client.cat(self._key_to_file(key)))
except:
raise KeyError(key)
@synchronized
def __setitem__(self, key, value):
fname = self._key_to_file(key)
# Write value to file
open(fname, 'wb').write(self.dumps(value))
# Add to repository
if key not in self:
self._client.add(fname)
self._client.checkin([fname], 'Adding %s' % fname)
@synchronized
def __delitem__(self, key):
try:
fname = self._key_to_file(key)
self._client.remove(fname)
# Remove deleted value from repository
self._client.checkin([fname], 'Removing %s' % fname)
except:
raise KeyError(key)
def _key_to_file(self, key):
'''Gives the filesystem path for a key.'''
return os.path.join(self._path, urllib.quote_plus(key))
@synchronized
def keys(self):
'''Returns a list of keys in the subversion repository.'''
return list(str(i.name.split('/')[-1]) for i
in self._client.ls(self._path))
__all__ = ['SvnStore']

View File

@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
'''
Zope Object Database store frontend.
shove's psuedo-URL for ZODB stores follows the form:
zodb:<path>
Where the path is a URL path to a ZODB FileStorage database. Alternatively, a
native pathname to a ZODB database can be passed as the 'engine' argument.
'''
try:
import transaction
from ZODB import FileStorage, DB
except ImportError:
raise ImportError('Requires ZODB library')
from shove.store import SyncStore
class ZodbStore(SyncStore):
'''ZODB store front end.'''
init = 'zodb://'
def __init__(self, engine, **kw):
super(ZodbStore, self).__init__(engine, **kw)
# Handle psuedo-URL
self._storage = FileStorage.FileStorage(self._engine)
self._db = DB(self._storage)
self._connection = self._db.open()
self._store = self._connection.root()
# Keeps DB in synch through commits of transactions
self.sync = transaction.commit
def close(self):
'''Closes all open storage and connections.'''
self.sync()
super(ZodbStore, self).close()
self._connection.close()
self._db.close()
self._storage.close()
__all__ = ['ZodbStore']

View File

@ -1 +0,0 @@
# -*- coding: utf-8 -*-

View File

@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestBsdbStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('bsddb://test.db', compress=True)
def tearDown(self):
import os
self.store.close()
os.remove('test.db')
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestCassandraStore(unittest.TestCase):
def setUp(self):
from shove import Shove
from pycassa.system_manager import SystemManager
system_manager = SystemManager('localhost:9160')
try:
system_manager.create_column_family('Foo', 'shove')
except:
pass
self.store = Shove('cassandra://localhost:9160/Foo/shove')
def tearDown(self):
self.store.clear()
self.store.close()
from pycassa.system_manager import SystemManager
system_manager = SystemManager('localhost:9160')
system_manager.drop_column_family('Foo', 'shove')
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
# def test_clear(self):
# self.store['max'] = 3
# self.store['min'] = 6
# self.store['pow'] = 7
# self.store.clear()
# self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
# def test_popitem(self):
# self.store['max'] = 3
# self.store['min'] = 6
# self.store['pow'] = 7
# item = self.store.popitem()
# self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
# self.store['pow'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store.setdefault('pow', 8), 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestDbCache(unittest.TestCase):
initstring = 'sqlite:///'
def setUp(self):
from shove.cache.db import DbCache
self.cache = DbCache(self.initstring)
def tearDown(self):
self.cache = None
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.db import DbCache
cache = DbCache(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(2)
def tmp():
cache['test']
self.assertRaises(KeyError, tmp)
def test_cull(self):
from shove.cache.db import DbCache
cache = DbCache(self.initstring, max_entries=1)
cache['test'] = 'test'
cache['test2'] = 'test'
cache['test2'] = 'test'
self.assertEquals(len(cache), 1)
if __name__ == '__main__':
unittest.main()

View File

@ -1,131 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestDbStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('sqlite://', compress=True)
def tearDown(self):
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,136 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestDbmStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('dbm://test.dbm', compress=True)
def tearDown(self):
import os
self.store.close()
try:
os.remove('test.dbm.db')
except OSError:
pass
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.setdefault('how', 8)
self.assertEqual(self.store['how'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestDurusStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('durus://test.durus', compress=True)
def tearDown(self):
import os
self.store.close()
os.remove('test.durus')
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestFileCache(unittest.TestCase):
initstring = 'file://test'
def setUp(self):
from shove.cache.file import FileCache
self.cache = FileCache(self.initstring)
def tearDown(self):
import os
self.cache = None
for x in os.listdir('test'):
os.remove(os.path.join('test', x))
os.rmdir('test')
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.file import FileCache
cache = FileCache(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(2)
def tmp():
cache['test']
self.assertRaises(KeyError, tmp)
def test_cull(self):
from shove.cache.file import FileCache
cache = FileCache(self.initstring, max_entries=1)
cache['test'] = 'test'
cache['test2'] = 'test'
num = len(cache)
self.assertEquals(num, 1)
if __name__ == '__main__':
unittest.main()

View File

@ -1,140 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestFileStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('file://test', compress=True)
def tearDown(self):
import os
self.store.close()
for x in os.listdir('test'):
os.remove(os.path.join('test', x))
os.rmdir('test')
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,149 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestFtpStore(unittest.TestCase):
ftpstring = 'put ftp string here'
def setUp(self):
from shove import Shove
self.store = Shove(self.ftpstring, compress=True)
def tearDown(self):
self.store.clear()
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
self.assertEqual(len(self.store), 2)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
item = self.store.popitem()
self.store.sync()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.store.sync()
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.sync()
self.store.update(tstore)
self.store.sync()
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,135 +0,0 @@
# -*- coding: utf-8 -*-
import unittest2
class TestHDF5Store(unittest2.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('hdf5://test.hdf5/test')
def tearDown(self):
import os
self.store.close()
try:
os.remove('test.hdf5')
except OSError:
pass
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.setdefault('bow', 8)
self.assertEqual(self.store['bow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest2.main()

View File

@ -1,132 +0,0 @@
# -*- coding: utf-8 -*-
import unittest2
class TestLevelDBStore(unittest2.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('leveldb://test', compress=True)
def tearDown(self):
import shutil
shutil.rmtree('test')
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.setdefault('bow', 8)
self.assertEqual(self.store['bow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest2.main()

View File

@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestMemcached(unittest.TestCase):
initstring = 'memcache://localhost:11211'
def setUp(self):
from shove.cache.memcached import MemCached
self.cache = MemCached(self.initstring)
def tearDown(self):
self.cache = None
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.memcached import MemCached
cache = MemCached(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(1)
def tmp():
cache['test']
self.assertRaises(KeyError, tmp)
if __name__ == '__main__':
unittest.main()

View File

@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestMemoryCache(unittest.TestCase):
initstring = 'memory://'
def setUp(self):
from shove.cache.memory import MemoryCache
self.cache = MemoryCache(self.initstring)
def tearDown(self):
self.cache = None
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.memory import MemoryCache
cache = MemoryCache(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(1)
def tmp():
cache['test']
self.assertRaises(KeyError, tmp)
def test_cull(self):
from shove.cache.memory import MemoryCache
cache = MemoryCache(self.initstring, max_entries=1)
cache['test'] = 'test'
cache['test2'] = 'test'
cache['test2'] = 'test'
self.assertEquals(len(cache), 1)
if __name__ == '__main__':
unittest.main()

View File

@ -1,135 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestMemoryStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('memory://', compress=True)
def tearDown(self):
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestRedisCache(unittest.TestCase):
initstring = 'redis://localhost:6379/0'
def setUp(self):
from shove.cache.redisdb import RedisCache
self.cache = RedisCache(self.initstring)
def tearDown(self):
self.cache = None
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.redisdb import RedisCache
cache = RedisCache(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(3)
def tmp(): #@IgnorePep8
return cache['test']
self.assertRaises(KeyError, tmp)
if __name__ == '__main__':
unittest.main()

View File

@ -1,128 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestRedisStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('redis://localhost:6379/0')
def tearDown(self):
self.store.clear()
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store.setdefault('pow', 8), 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,149 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestS3Store(unittest.TestCase):
s3string = 's3 test string here'
def setUp(self):
from shove import Shove
self.store = Shove(self.s3string, compress=True)
def tearDown(self):
self.store.clear()
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
self.assertEqual(len(self.store), 2)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
item = self.store.popitem()
self.store.sync()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.store.sync()
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.sync()
self.store.update(tstore)
self.store.sync()
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestSimpleCache(unittest.TestCase):
initstring = 'simple://'
def setUp(self):
from shove.cache.simple import SimpleCache
self.cache = SimpleCache(self.initstring)
def tearDown(self):
self.cache = None
def test_getitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_setitem(self):
self.cache['test'] = 'test'
self.assertEqual(self.cache['test'], 'test')
def test_delitem(self):
self.cache['test'] = 'test'
del self.cache['test']
self.assertEqual('test' in self.cache, False)
def test_get(self):
self.assertEqual(self.cache.get('min'), None)
def test_timeout(self):
import time
from shove.cache.simple import SimpleCache
cache = SimpleCache(self.initstring, timeout=1)
cache['test'] = 'test'
time.sleep(1)
def tmp():
cache['test']
self.assertRaises(KeyError, tmp)
def test_cull(self):
from shove.cache.simple import SimpleCache
cache = SimpleCache(self.initstring, max_entries=1)
cache['test'] = 'test'
cache['test2'] = 'test'
cache['test2'] = 'test'
self.assertEquals(len(cache), 1)
if __name__ == '__main__':
unittest.main()

View File

@ -1,135 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestSimpleStore(unittest.TestCase):
def setUp(self):
from shove import Shove
self.store = Shove('simple://', compress=True)
def tearDown(self):
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,148 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestSvnStore(unittest.TestCase):
svnstring = 'SVN test string here'
def setUp(self):
from shove import Shove
self.store = Shove(self.svnstring, compress=True)
def tearDown(self):
self.store.clear()
self.store.close()
def test__getitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.store.sync()
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.store.sync()
tstore.sync()
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
self.assertEqual(len(self.store), 2)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
self.store.sync()
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
item = self.store.popitem()
self.store.sync()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.store.sync()
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.sync()
self.store.update(tstore)
self.store.sync()
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.sync()
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,138 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
class TestZodbStore(unittest.TestCase):
init = 'zodb://test.db'
def setUp(self):
from shove import Shove
self.store = Shove(self.init, compress=True)
def tearDown(self):
self.store.close()
import os
os.remove('test.db')
os.remove('test.db.index')
os.remove('test.db.tmp')
os.remove('test.db.lock')
def test__getitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__setitem__(self):
self.store['max'] = 3
self.assertEqual(self.store['max'], 3)
def test__delitem__(self):
self.store['max'] = 3
del self.store['max']
self.assertEqual('max' in self.store, False)
def test_get(self):
self.store['max'] = 3
self.assertEqual(self.store.get('min'), None)
def test__cmp__(self):
from shove import Shove
tstore = Shove()
self.store['max'] = 3
tstore['max'] = 3
self.assertEqual(self.store, tstore)
def test__len__(self):
self.store['max'] = 3
self.store['min'] = 6
self.assertEqual(len(self.store), 2)
def test_close(self):
self.store.close()
self.assertEqual(self.store, None)
def test_clear(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
self.store.clear()
self.assertEqual(len(self.store), 0)
def test_items(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.items())
self.assertEqual(('min', 6) in slist, True)
def test_iteritems(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iteritems())
self.assertEqual(('min', 6) in slist, True)
def test_iterkeys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.iterkeys())
self.assertEqual('min' in slist, True)
def test_itervalues(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = list(self.store.itervalues())
self.assertEqual(6 in slist, True)
def test_pop(self):
self.store['max'] = 3
self.store['min'] = 6
item = self.store.pop('min')
self.assertEqual(item, 6)
def test_popitem(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
item = self.store.popitem()
self.assertEqual(len(item) + len(self.store), 4)
def test_setdefault(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['powl'] = 7
self.store.setdefault('pow', 8)
self.assertEqual(self.store['pow'], 8)
def test_update(self):
from shove import Shove
tstore = Shove()
tstore['max'] = 3
tstore['min'] = 6
tstore['pow'] = 7
self.store['max'] = 2
self.store['min'] = 3
self.store['pow'] = 7
self.store.update(tstore)
self.assertEqual(self.store['min'], 6)
def test_values(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.values()
self.assertEqual(6 in slist, True)
def test_keys(self):
self.store['max'] = 3
self.store['min'] = 6
self.store['pow'] = 7
slist = self.store.keys()
self.assertEqual('min' in slist, True)
if __name__ == '__main__':
unittest.main()

View File

@ -1,133 +0,0 @@
# sqlalchemy/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .sql import (
alias,
and_,
asc,
between,
bindparam,
case,
cast,
collate,
delete,
desc,
distinct,
except_,
except_all,
exists,
extract,
false,
func,
insert,
intersect,
intersect_all,
join,
literal,
literal_column,
modifier,
not_,
null,
or_,
outerjoin,
outparam,
over,
select,
subquery,
text,
true,
tuple_,
type_coerce,
union,
union_all,
update,
)
from .types import (
BIGINT,
BINARY,
BLOB,
BOOLEAN,
BigInteger,
Binary,
Boolean,
CHAR,
CLOB,
DATE,
DATETIME,
DECIMAL,
Date,
DateTime,
Enum,
FLOAT,
Float,
INT,
INTEGER,
Integer,
Interval,
LargeBinary,
NCHAR,
NVARCHAR,
NUMERIC,
Numeric,
PickleType,
REAL,
SMALLINT,
SmallInteger,
String,
TEXT,
TIME,
TIMESTAMP,
Text,
Time,
TypeDecorator,
Unicode,
UnicodeText,
VARBINARY,
VARCHAR,
)
from .schema import (
CheckConstraint,
Column,
ColumnDefault,
Constraint,
DefaultClause,
FetchedValue,
ForeignKey,
ForeignKeyConstraint,
Index,
MetaData,
PassiveDefault,
PrimaryKeyConstraint,
Sequence,
Table,
ThreadLocalMetaData,
UniqueConstraint,
DDL,
)
from .inspection import inspect
from .engine import create_engine, engine_from_config
__version__ = '0.9.4'
def __go(lcls):
global __all__
from . import events
from . import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
_sa_util.dependencies.resolve_all("sqlalchemy")
__go(locals())

View File

@ -1,706 +0,0 @@
/*
processors.c
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#include <datetime.h>
#define MODULE_NAME "cprocessors"
#define MODULE_DOC "Module containing C versions of data processing functions."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif
static PyObject *
int_to_boolean(PyObject *self, PyObject *arg)
{
long l = 0;
PyObject *res;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
l = PyLong_AsLong(arg);
#else
l = PyInt_AsLong(arg);
#endif
if (l == 0) {
res = Py_False;
} else if (l == 1) {
res = Py_True;
} else if ((l == -1) && PyErr_Occurred()) {
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else {
PyErr_SetString(PyExc_ValueError,
"int_to_boolean only accepts None, 0 or 1");
return NULL;
}
Py_INCREF(res);
return res;
}
static PyObject *
to_str(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyObject_Str(arg);
}
static PyObject *
to_float(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyNumber_Float(arg);
}
static PyObject *
str_to_datetime(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day, hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "2000-01-01 00:00:00.". I don't know which is better, but they
should be coherent.
*/
numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 6) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDateTime_FromDateAndTime(year, month, day,
hour, minute, second, microsecond);
}
static PyObject *
str_to_time(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "00:00:00.". I don't know which is better, but they should be
coherent.
*/
numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyTime_FromTime(hour, minute, second, microsecond);
}
static PyObject *
str_to_date(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed != 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDate_FromDate(year, month, day);
}
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *encoding;
PyObject *errors;
} UnicodeResultProcessor;
typedef struct {
PyObject_HEAD
PyObject *type;
PyObject *format;
} DecimalResultProcessor;
/**************************
* UnicodeResultProcessor *
**************************/
static int
UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *encoding, *errors = NULL;
static char *kwlist[] = {"encoding", "errors", NULL};
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist,
&encoding, &errors))
return -1;
#else
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
&encoding, &errors))
return -1;
#endif
#if PY_MAJOR_VERSION >= 3
encoding = PyUnicode_AsASCIIString(encoding);
#else
Py_INCREF(encoding);
#endif
self->encoding = encoding;
if (errors) {
#if PY_MAJOR_VERSION >= 3
errors = PyUnicode_AsASCIIString(errors);
#else
Py_INCREF(errors);
#endif
} else {
#if PY_MAJOR_VERSION >= 3
errors = PyBytes_FromString("strict");
#else
errors = PyString_FromString("strict");
#endif
if (errors == NULL)
return -1;
}
self->errors = errors;
return 0;
}
static PyObject *
UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;
if (value == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif
return PyUnicode_Decode(str, len, encoding, errors);
}
static PyObject *
UnicodeResultProcessor_conditional_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;
if (value == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyUnicode_Check(value) == 1) {
Py_INCREF(value);
return value;
}
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif
return PyUnicode_Decode(str, len, encoding, errors);
}
static void
UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self)
{
Py_XDECREF(self->encoding);
Py_XDECREF(self->errors);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}
static PyMethodDef UnicodeResultProcessor_methods[] = {
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
"The value processor itself."},
{"conditional_process", (PyCFunction)UnicodeResultProcessor_conditional_process, METH_O,
"Conditional version of the value processor."},
{NULL} /* Sentinel */
};
static PyTypeObject UnicodeResultProcessorType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
sizeof(UnicodeResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"UnicodeResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
UnicodeResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)UnicodeResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
/**************************
* DecimalResultProcessor *
**************************/
static int
DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *type, *format;
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "OU", &type, &format))
#else
if (!PyArg_ParseTuple(args, "OS", &type, &format))
#endif
return -1;
Py_INCREF(type);
self->type = type;
Py_INCREF(format);
self->format = format;
return 0;
}
static PyObject *
DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
{
PyObject *str, *result, *args;
if (value == Py_None)
Py_RETURN_NONE;
/* Decimal does not accept float values directly */
/* SQLite can also give us an integer here (see [ticket:2432]) */
/* XXX: starting with Python 3.1, we could use Decimal.from_float(f),
but the result wouldn't be the same */
args = PyTuple_Pack(1, value);
if (args == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
str = PyUnicode_Format(self->format, args);
#else
str = PyString_Format(self->format, args);
#endif
Py_DECREF(args);
if (str == NULL)
return NULL;
result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str);
return result;
}
static void
DecimalResultProcessor_dealloc(DecimalResultProcessor *self)
{
Py_XDECREF(self->type);
Py_XDECREF(self->format);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}
static PyMethodDef DecimalResultProcessor_methods[] = {
{"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
"The value processor itself."},
{NULL} /* Sentinel */
};
static PyTypeObject DecimalResultProcessorType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.DecimalResultProcessor", /* tp_name */
sizeof(DecimalResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"DecimalResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
DecimalResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)DecimalResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
static PyMethodDef module_methods[] = {
{"int_to_boolean", int_to_boolean, METH_O,
"Convert an integer to a boolean."},
{"to_str", to_str, METH_O,
"Convert any value to its string representation."},
{"to_float", to_float, METH_O,
"Convert any value to its floating point representation."},
{"str_to_datetime", str_to_datetime, METH_O,
"Convert an ISO string to a datetime.datetime object."},
{"str_to_time", str_to_time, METH_O,
"Convert an ISO string to a datetime.time object."},
{"str_to_date", str_to_date, METH_O,
"Convert an ISO string to a datetime.date object."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyMODINIT_FUNC
PyInit_cprocessors(void)
#else
#define INITERROR return
PyMODINIT_FUNC
initcprocessors(void)
#endif
{
PyObject *m;
UnicodeResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&UnicodeResultProcessorType) < 0)
INITERROR;
DecimalResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DecimalResultProcessorType) < 0)
INITERROR;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
INITERROR;
PyDateTime_IMPORT;
Py_INCREF(&UnicodeResultProcessorType);
PyModule_AddObject(m, "UnicodeResultProcessor",
(PyObject *)&UnicodeResultProcessorType);
Py_INCREF(&DecimalResultProcessorType);
PyModule_AddObject(m, "DecimalResultProcessor",
(PyObject *)&DecimalResultProcessorType);
#if PY_MAJOR_VERSION >= 3
return m;
#endif
}

View File

@ -1,718 +0,0 @@
/*
resultproxy.c
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#define MODULE_NAME "cresultproxy"
#define MODULE_DOC "Module containing C versions of core ResultProxy classes."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
typedef Py_ssize_t (*lenfunc)(PyObject *);
#define PyInt_FromSsize_t(x) PyInt_FromLong(x)
typedef intargfunc ssizeargfunc;
#endif
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *parent;
PyObject *row;
PyObject *processors;
PyObject *keymap;
} BaseRowProxy;
/****************
* BaseRowProxy *
****************/
static PyObject *
safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
{
PyObject *cls, *state, *tmp;
BaseRowProxy *obj;
if (!PyArg_ParseTuple(args, "OO", &cls, &state))
return NULL;
obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
if (obj == NULL)
return NULL;
tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
if (tmp == NULL) {
Py_DECREF(obj);
return NULL;
}
Py_DECREF(tmp);
if (obj->parent == NULL || obj->row == NULL ||
obj->processors == NULL || obj->keymap == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"__setstate__ for BaseRowProxy subclasses must set values "
"for parent, row, processors and keymap");
Py_DECREF(obj);
return NULL;
}
return (PyObject *)obj;
}
static int
BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
{
PyObject *parent, *row, *processors, *keymap;
if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
&parent, &row, &processors, &keymap))
return -1;
Py_INCREF(parent);
self->parent = parent;
if (!PySequence_Check(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a sequence");
return -1;
}
Py_INCREF(row);
self->row = row;
if (!PyList_CheckExact(processors)) {
PyErr_SetString(PyExc_TypeError, "processors must be a list");
return -1;
}
Py_INCREF(processors);
self->processors = processors;
if (!PyDict_CheckExact(keymap)) {
PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
return -1;
}
Py_INCREF(keymap);
self->keymap = keymap;
return 0;
}
/* We need the reduce method because otherwise the default implementation
* does very weird stuff for pickle protocol 0 and 1. It calls
* BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
*/
static PyObject *
BaseRowProxy_reduce(PyObject *self)
{
PyObject *method, *state;
PyObject *module, *reconstructor, *cls;
method = PyObject_GetAttrString(self, "__getstate__");
if (method == NULL)
return NULL;
state = PyObject_CallObject(method, NULL);
Py_DECREF(method);
if (state == NULL)
return NULL;
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return NULL;
reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
Py_DECREF(module);
if (reconstructor == NULL) {
Py_DECREF(state);
return NULL;
}
cls = PyObject_GetAttrString(self, "__class__");
if (cls == NULL) {
Py_DECREF(reconstructor);
Py_DECREF(state);
return NULL;
}
return Py_BuildValue("(N(NN))", reconstructor, cls, state);
}
static void
BaseRowProxy_dealloc(BaseRowProxy *self)
{
Py_XDECREF(self->parent);
Py_XDECREF(self->row);
Py_XDECREF(self->processors);
Py_XDECREF(self->keymap);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject *)self);
#else
self->ob_type->tp_free((PyObject *)self);
#endif
}
static PyObject *
BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{
Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value, *values_fastseq;
num_values = PySequence_Length(values);
num_processors = PyList_Size(processors);
if (num_values != num_processors) {
PyErr_Format(PyExc_RuntimeError,
"number of values in row (%d) differ from number of column "
"processors (%d)",
(int)num_values, (int)num_processors);
return NULL;
}
if (astuple) {
result = PyTuple_New(num_values);
} else {
result = PyList_New(num_values);
}
if (result == NULL)
return NULL;
values_fastseq = PySequence_Fast(values, "row must be a sequence");
if (values_fastseq == NULL)
return NULL;
valueptr = PySequence_Fast_ITEMS(values_fastseq);
funcptr = PySequence_Fast_ITEMS(processors);
resultptr = PySequence_Fast_ITEMS(result);
while (--num_values >= 0) {
func = *funcptr;
if (func != Py_None) {
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
NULL);
if (processed_value == NULL) {
Py_DECREF(values_fastseq);
Py_DECREF(result);
return NULL;
}
*resultptr = processed_value;
} else {
Py_INCREF(*valueptr);
*resultptr = *valueptr;
}
valueptr++;
funcptr++;
resultptr++;
}
Py_DECREF(values_fastseq);
return result;
}
static PyListObject *
BaseRowProxy_values(BaseRowProxy *self)
{
return (PyListObject *)BaseRowProxy_processvalues(self->row,
self->processors, 0);
}
static PyObject *
BaseRowProxy_iter(BaseRowProxy *self)
{
PyObject *values, *result;
values = BaseRowProxy_processvalues(self->row, self->processors, 1);
if (values == NULL)
return NULL;
result = PyObject_GetIter(values);
Py_DECREF(values);
if (result == NULL)
return NULL;
return result;
}
static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self)
{
return PySequence_Length(self->row);
}
static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{
PyObject *processors, *values;
PyObject *processor, *value, *processed_value;
PyObject *row, *record, *result, *indexobject;
PyObject *exc_module, *exception, *cstr_obj;
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
#endif
char *cstr_key;
long index;
int key_fallback = 0;
int tuple_check = 0;
#if PY_MAJOR_VERSION < 3
if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key);
}
#endif
if (PyLong_CheckExact(key)) {
index = PyLong_AsLong(key);
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else if (PySlice_Check(key)) {
values = PyObject_GetItem(self->row, key);
if (values == NULL)
return NULL;
processors = PyObject_GetItem(self->processors, key);
if (processors == NULL) {
Py_DECREF(values);
return NULL;
}
result = BaseRowProxy_processvalues(values, processors, 1);
Py_DECREF(values);
Py_DECREF(processors);
return result;
} else {
record = PyDict_GetItem((PyObject *)self->keymap, key);
if (record == NULL) {
record = PyObject_CallMethod(self->parent, "_key_fallback",
"O", key);
if (record == NULL)
return NULL;
key_fallback = 1;
}
indexobject = PyTuple_GetItem(record, 2);
if (indexobject == NULL)
return NULL;
if (key_fallback) {
Py_DECREF(record);
}
if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL)
return NULL;
exception = PyObject_GetAttrString(exc_module,
"InvalidRequestError");
Py_DECREF(exc_module);
if (exception == NULL)
return NULL;
// wow. this seems quite excessive.
cstr_obj = PyObject_Str(key);
if (cstr_obj == NULL)
return NULL;
/*
FIXME: raise encoding error exception (in both versions below)
if the key contains non-ascii chars, instead of an
InvalidRequestError without any message like in the
python version.
*/
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(cstr_obj);
if (bytes == NULL)
return NULL;
cstr_key = PyBytes_AS_STRING(bytes);
#else
cstr_key = PyString_AsString(cstr_obj);
#endif
if (cstr_key == NULL) {
Py_DECREF(cstr_obj);
return NULL;
}
Py_DECREF(cstr_obj);
PyErr_Format(exception,
"Ambiguous column name '%.200s' in result set! "
"try 'use_labels' option on select statement.", cstr_key);
return NULL;
}
#if PY_MAJOR_VERSION >= 3
index = PyLong_AsLong(indexobject);
#else
index = PyInt_AsLong(indexobject);
#endif
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
}
processor = PyList_GetItem(self->processors, index);
if (processor == NULL)
return NULL;
row = self->row;
if (PyTuple_CheckExact(row)) {
value = PyTuple_GetItem(row, index);
tuple_check = 1;
}
else {
value = PySequence_GetItem(row, index);
tuple_check = 0;
}
if (value == NULL)
return NULL;
if (processor != Py_None) {
processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL);
if (!tuple_check) {
Py_DECREF(value);
}
return processed_value;
} else {
if (tuple_check) {
Py_INCREF(value);
}
return value;
}
}
static PyObject *
BaseRowProxy_getitem(PyObject *self, Py_ssize_t i)
{
PyObject *index;
#if PY_MAJOR_VERSION >= 3
index = PyLong_FromSsize_t(i);
#else
index = PyInt_FromSsize_t(i);
#endif
return BaseRowProxy_subscript((BaseRowProxy*)self, index);
}
static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{
PyObject *tmp;
#if PY_MAJOR_VERSION >= 3
PyObject *err_bytes;
#endif
if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
else
return tmp;
tmp = BaseRowProxy_subscript(self, name);
if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(name);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyBytes_AS_STRING(err_bytes)
);
#else
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyString_AsString(name)
);
#endif
return NULL;
}
return tmp;
}
/***********************
* getters and setters *
***********************/
static PyObject *
BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->parent);
return self->parent;
}
static int
BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
{
PyObject *module, *cls;
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'parent' attribute");
return -1;
}
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return -1;
cls = PyObject_GetAttrString(module, "ResultMetaData");
Py_DECREF(module);
if (cls == NULL)
return -1;
if (PyObject_IsInstance(value, cls) != 1) {
PyErr_SetString(PyExc_TypeError,
"The 'parent' attribute value must be an instance of "
"ResultMetaData");
return -1;
}
Py_DECREF(cls);
Py_XDECREF(self->parent);
Py_INCREF(value);
self->parent = value;
return 0;
}
static PyObject *
BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->row);
return self->row;
}
static int
BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'row' attribute");
return -1;
}
if (!PySequence_Check(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'row' attribute value must be a sequence");
return -1;
}
Py_XDECREF(self->row);
Py_INCREF(value);
self->row = value;
return 0;
}
static PyObject *
BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->processors);
return self->processors;
}
static int
BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'processors' attribute");
return -1;
}
if (!PyList_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'processors' attribute value must be a list");
return -1;
}
Py_XDECREF(self->processors);
Py_INCREF(value);
self->processors = value;
return 0;
}
static PyObject *
BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->keymap);
return self->keymap;
}
static int
BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'keymap' attribute");
return -1;
}
if (!PyDict_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'keymap' attribute value must be a dict");
return -1;
}
Py_XDECREF(self->keymap);
Py_INCREF(value);
self->keymap = value;
return 0;
}
static PyGetSetDef BaseRowProxy_getseters[] = {
{"_parent",
(getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
"ResultMetaData",
NULL},
{"_row",
(getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
"Original row tuple",
NULL},
{"_processors",
(getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
"list of type processors",
NULL},
{"_keymap",
(getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
"Key to (processor, index) dict",
NULL},
{NULL}
};
static PyMethodDef BaseRowProxy_methods[] = {
{"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
"Return the values represented by this BaseRowProxy as a list."},
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
"Pickle support method."},
{NULL} /* Sentinel */
};
static PySequenceMethods BaseRowProxy_as_sequence = {
(lenfunc)BaseRowProxy_length, /* sq_length */
0, /* sq_concat */
0, /* sq_repeat */
(ssizeargfunc)BaseRowProxy_getitem, /* sq_item */
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
0, /* sq_contains */
0, /* sq_inplace_concat */
0, /* sq_inplace_repeat */
};
static PyMappingMethods BaseRowProxy_as_mapping = {
(lenfunc)BaseRowProxy_length, /* mp_length */
(binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
0 /* mp_ass_subscript */
};
static PyTypeObject BaseRowProxyType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
sizeof(BaseRowProxy), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)BaseRowProxy_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
&BaseRowProxy_as_sequence, /* tp_as_sequence */
&BaseRowProxy_as_mapping, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
(getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)BaseRowProxy_iter, /* tp_iter */
0, /* tp_iternext */
BaseRowProxy_methods, /* tp_methods */
0, /* tp_members */
BaseRowProxy_getseters, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)BaseRowProxy_init, /* tp_init */
0, /* tp_alloc */
0 /* tp_new */
};
static PyMethodDef module_methods[] = {
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
"reconstruct a RowProxy instance from its pickled form."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyMODINIT_FUNC
PyInit_cresultproxy(void)
#else
#define INITERROR return
PyMODINIT_FUNC
initcresultproxy(void)
#endif
{
PyObject *m;
BaseRowProxyType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BaseRowProxyType) < 0)
INITERROR;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
INITERROR;
Py_INCREF(&BaseRowProxyType);
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
#if PY_MAJOR_VERSION >= 3
return m;
#endif
}

View File

@ -1,225 +0,0 @@
/*
utils.c
Copyright (C) 2012-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#define MODULE_NAME "cutils"
#define MODULE_DOC "Module containing C versions of utility functions."
/*
Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.
In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.
*/
static PyObject *
distill_params(PyObject *self, PyObject *args)
{
PyObject *multiparams, *params;
PyObject *enclosing_list, *double_enclosing_list;
PyObject *zero_element, *zero_element_item;
Py_ssize_t multiparam_size, zero_element_length;
if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, &params)) {
return NULL;
}
if (multiparams != Py_None) {
multiparam_size = PyTuple_Size(multiparams);
if (multiparam_size < 0) {
return NULL;
}
}
else {
multiparam_size = 0;
}
if (multiparam_size == 0) {
if (params != Py_None && PyDict_Size(params) != 0) {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(params);
if (PyList_SetItem(enclosing_list, 0, params) == -1) {
Py_DECREF(params);
Py_DECREF(enclosing_list);
return NULL;
}
}
else {
enclosing_list = PyList_New(0);
if (enclosing_list == NULL) {
return NULL;
}
}
return enclosing_list;
}
else if (multiparam_size == 1) {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
zero_element_length = PySequence_Length(zero_element);
if (zero_element_length != 0) {
zero_element_item = PySequence_GetItem(zero_element, 0);
if (zero_element_item == NULL) {
return NULL;
}
}
else {
zero_element_item = NULL;
}
if (zero_element_length == 0 ||
(
PyObject_HasAttrString(zero_element_item, "__iter__") &&
!PyObject_HasAttrString(zero_element_item, "strip")
)
) {
/*
* execute(stmt, [{}, {}, {}, ...])
* execute(stmt, [(), (), (), ...])
*/
Py_XDECREF(zero_element_item);
Py_INCREF(zero_element);
return zero_element;
}
else {
/*
* execute(stmt, ("value", "value"))
*/
Py_XDECREF(zero_element_item);
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
else if (PyObject_HasAttrString(zero_element, "keys")) {
/*
* execute(stmt, {"key":"value"})
*/
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
} else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
double_enclosing_list = PyList_New(1);
if (double_enclosing_list == NULL) {
Py_DECREF(enclosing_list);
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
return double_enclosing_list;
}
}
else {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyObject_HasAttrString(zero_element, "__iter__") &&
!PyObject_HasAttrString(zero_element, "strip")
) {
Py_INCREF(multiparams);
return multiparams;
}
else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(multiparams);
if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
Py_DECREF(multiparams);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
}
static PyMethodDef module_methods[] = {
{"_distill_params", distill_params, METH_VARARGS,
"Distill an execute() parameter structure."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#endif
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit_cutils(void)
#else
PyMODINIT_FUNC
initcutils(void)
#endif
{
PyObject *m;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
#if PY_MAJOR_VERSION >= 3
if (m == NULL)
return NULL;
return m;
#else
if (m == NULL)
return;
#endif
}

View File

@ -1,9 +0,0 @@
# connectors/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
class Connector(object):
pass

View File

@ -1,149 +0,0 @@
# connectors/mxodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide an SQLALchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration
testing.
This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
and 2008, using the SQL Server Native driver. However, it is
possible for this to be used on other database platforms.
For more info on mxODBC, see http://www.egenix.com/
"""
import sys
import re
import warnings
from . import Connector
class MxODBCConnector(Connector):
driver = 'mxodbc'
supports_sane_multi_rowcount = False
supports_unicode_statements = True
supports_unicode_binds = True
supports_native_decimal = True
@classmethod
def dbapi(cls):
# this classmethod will normally be replaced by an instance
# attribute of the same name, so this is normally only called once.
cls._load_mx_exceptions()
platform = sys.platform
if platform == 'win32':
from mx.ODBC import Windows as module
# this can be the string "linux2", and possibly others
elif 'linux' in platform:
from mx.ODBC import unixODBC as module
elif platform == 'darwin':
from mx.ODBC import iODBC as module
else:
raise ImportError("Unrecognized platform for mxODBC import")
return module
@classmethod
def _load_mx_exceptions(cls):
""" Import mxODBC exception classes into the module namespace,
as if they had been imported normally. This is done here
to avoid requiring all SQLAlchemy users to install mxODBC.
"""
global InterfaceError, ProgrammingError
from mx.ODBC import InterfaceError
from mx.ODBC import ProgrammingError
def on_connect(self):
def connect(conn):
conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler()
return connect
def _error_handler(self):
""" Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings.
"""
from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,)
warnings.warn(message=str(errorvalue),
category=errorclass,
stacklevel=2)
else:
raise errorclass(errorvalue)
return error_handler
def create_connect_args(self, url):
""" Return a tuple of *args,**kwargs for creating a connection.
The mxODBC 3.x connection constructor looks like this:
connect(dsn, user='', password='',
clear_auto_commit=1, errorhandler=None)
This method translates the values in the provided uri
into args and kwargs needed to instantiate an mxODBC Connection.
The arg 'errorhandler' is not used by SQLAlchemy and will
not be populated.
"""
opts = url.translate_connect_args(username='user')
opts.update(url.query)
args = opts.pop('host')
opts.pop('port', None)
opts.pop('database', None)
return (args,), opts
def is_disconnect(self, e, connection, cursor):
# TODO: eGenix recommends checking connection.closed here
# Does that detect dropped connections ?
if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def _get_server_version_info(self, connection):
# eGenix suggests using conn.dbms_version instead
# of what we're doing here
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
# 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _get_direct(self, context):
if context:
native_odbc_execute = context.execution_options.\
get('native_odbc_execute', 'auto')
# default to direct=True in all cases, is more generally
# compatible especially with SQL Server
return False if native_odbc_execute is True else True
else:
return True
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(
statement, parameters, direct=self._get_direct(context))
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters, direct=self._get_direct(context))

View File

@ -1,144 +0,0 @@
# connectors/mysqldb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Define behaviors common to MySQLdb dialects.
Currently includes MySQL and Drizzle.
"""
from . import Connector
from ..engine import base as engine_base, default
from ..sql import operators as sql_operators
from .. import exc, log, schema, sql, types as sqltypes, util, processors
import re
# the subclassing of Connector by all classes
# here is not strictly necessary
class MySQLDBExecutionContext(Connector):
@property
def rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLDBCompiler(Connector):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLDBIdentifierPreparer(Connector):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class MySQLDBConnector(Connector):
driver = 'mysqldb'
supports_unicode_statements = False
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
@classmethod
def dbapi(cls):
# is overridden when pymysql is used
return __import__('MySQLdb')
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'read_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int)
# Note: using either of the below will cause all strings to be returned
# as Unicode, both in raw SQL operations and with column types like
# String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
for key in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get('client_flag', 0)
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + '.constants.CLIENT'
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts['client_flag'] = client_flag
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.get_server_info()):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name = connection.connection.character_set_name
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
"this MySQL-Python version; "
"please upgrade to a recent version of MySQL-Python. "
"Assuming latin1.")
return 'latin1'
else:
return cset_name()

View File

@ -1,170 +0,0 @@
# connectors/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import Connector
from .. import util
import sys
import re
class PyODBCConnector(Connector):
driver = 'pyodbc'
supports_sane_multi_rowcount = False
if util.py2k:
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
supports_native_decimal = True
default_paramstyle = 'named'
# for non-DSN connections, this should
# hold the desired driver name
pyodbc_driver_name = None
# will be set to True after initialize()
# if the freetds.so is detected
freetds = False
# will be set to the string version of
# the FreeTDS driver if freetds is detected
freetds_driver_version = None
# will be set to True after initialize()
# if the libessqlsrv.so is detected
easysoft = False
def __init__(self, supports_unicode_binds=None, **kw):
super(PyODBCConnector, self).__init__(**kw)
self._user_supports_unicode_binds = supports_unicode_binds
@classmethod
def dbapi(cls):
return __import__('pyodbc')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
keys = opts
query = url.query
connect_args = {}
for param in ('ansi', 'unicode_results', 'autocommit'):
if param in keys:
connect_args[param] = util.asbool(keys.pop(param))
if 'odbc_connect' in keys:
connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
else:
dsn_connection = 'dsn' in keys or \
('host' in keys and 'database' not in keys)
if dsn_connection:
connectors = ['dsn=%s' % (keys.pop('host', '') or \
keys.pop('dsn', ''))]
else:
port = ''
if 'port' in keys and not 'port' in query:
port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" %
keys.pop('driver', self.pyodbc_driver_name),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '')]
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.pop('password', ''))
else:
connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically
# convert textual data from your database encoding to your
# client encoding. This should obviously be set to 'No' if
# you query a cp1253 encoded database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" %
keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
return [[";".join(connectors)], connect_args]
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or \
'Attempt to use a closed connection.' in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def initialize(self, connection):
# determine FreeTDS first. can't issue SQL easily
# without getting unicode_statements/binds set up.
pyodbc = self.dbapi
dbapi_con = connection.connection
_sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name
))
self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name
))
if self.freetds:
self.freetds_driver_version = dbapi_con.getinfo(
pyodbc.SQL_DRIVER_VER)
self.supports_unicode_statements = (
not util.py2k or
(not self.freetds and not self.easysoft)
)
if self._user_supports_unicode_binds is not None:
self.supports_unicode_binds = self._user_supports_unicode_binds
elif util.py2k:
self.supports_unicode_binds = (
not self.freetds or self.freetds_driver_version >= '0.91'
) and not self.easysoft
else:
self.supports_unicode_binds = True
# run other initialization which asks for user name, etc.
super(PyODBCConnector, self).initialize(connection)
def _dbapi_version(self):
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers):
m = re.match(
r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
vers
)
if not m:
return ()
vers = tuple([int(x) for x in m.group(1).split(".")])
if m.group(2):
vers += (m.group(2),)
return vers
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)

View File

@ -1,59 +0,0 @@
# connectors/zxJDBC.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sys
from . import Connector
class ZxJDBCConnector(Connector):
driver = 'zxjdbc'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_unicode_binds = True
supports_unicode_statements = sys.version > '2.5.0+'
description_encoding = None
default_paramstyle = 'qmark'
jdbc_db_name = None
jdbc_driver_name = None
@classmethod
def dbapi(cls):
from com.ziclix.python.sql import zxJDBC
return zxJDBC
def _driver_kwargs(self):
"""Return kw arg dict to be sent to connect()."""
return {}
def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
url.port is not None
and ':%s' % url.port or '',
url.database)
def create_connect_args(self, url):
opts = self._driver_kwargs()
opts.update(url.query)
return [
[self._create_jdbc_url(url),
url.username, url.password,
self.jdbc_driver_name],
opts]
def is_disconnect(self, e, connection, cursor):
if not isinstance(e, self.dbapi.ProgrammingError):
return False
e = str(e)
return 'connection is closed' in e or 'cursor is closed' in e
def _get_server_version_info(self, connection):
# use connection.connection.dbversion, and parse appropriately
# to get a tuple
raise NotImplementedError()

View File

@ -1,31 +0,0 @@
# databases/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Include imports from the sqlalchemy.dialects package for backwards
compatibility with pre 0.6 versions.
"""
from ..dialects.sqlite import base as sqlite
from ..dialects.postgresql import base as postgresql
postgres = postgresql
from ..dialects.mysql import base as mysql
from ..dialects.drizzle import base as drizzle
from ..dialects.oracle import base as oracle
from ..dialects.firebird import base as firebird
from ..dialects.mssql import base as mssql
from ..dialects.sybase import base as sybase
__all__ = (
'drizzle',
'firebird',
'mssql',
'mysql',
'postgresql',
'sqlite',
'oracle',
'sybase',
)

View File

@ -1,44 +0,0 @@
# dialects/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = (
'drizzle',
'firebird',
'mssql',
'mysql',
'oracle',
'postgresql',
'sqlite',
'sybase',
)
from .. import util
def _auto_fn(name):
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
try:
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
except ImportError:
return None
module = getattr(module, dialect)
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)

View File

@ -1,22 +0,0 @@
from sqlalchemy.dialects.drizzle import base, mysqldb
base.dialect = mysqldb.dialect
from sqlalchemy.dialects.drizzle.base import \
BIGINT, BINARY, BLOB, \
BOOLEAN, CHAR, DATE, \
DATETIME, DECIMAL, DOUBLE, \
ENUM, FLOAT, INTEGER, \
NUMERIC, REAL, TEXT, \
TIME, TIMESTAMP, VARBINARY, \
VARCHAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BLOB',
'BOOLEAN', 'CHAR', 'DATE',
'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'FLOAT', 'INTEGER',
'NUMERIC', 'REAL', 'TEXT',
'TIME', 'TIMESTAMP', 'VARBINARY',
'VARCHAR', 'dialect'
)

View File

@ -1,498 +0,0 @@
# drizzle/base.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2010-2011 Monty Taylor <mordred@inaugust.com>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: drizzle
:name: Drizzle
Drizzle is a variant of MySQL. Unlike MySQL, Drizzle's default storage engine
is InnoDB (transactions, foreign-keys) rather than MyISAM. For more
`Notable Differences <http://docs.drizzle.org/mysql_differences.html>`_, visit
the `Drizzle Documentation <http://docs.drizzle.org/index.html>`_.
The SQLAlchemy Drizzle dialect leans heavily on the MySQL dialect, so much of
the :doc:`SQLAlchemy MySQL <mysql>` documentation is also relevant.
"""
from sqlalchemy import exc
from sqlalchemy import log
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import reflection
from sqlalchemy.dialects.mysql import base as mysql_dialect
from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \
BLOB, BINARY, VARBINARY
class _NumericType(object):
"""Base for Drizzle numeric types."""
def __init__(self, **kw):
super(_NumericType, self).__init__(**kw)
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and \
(
(precision is None and scale is not None) or
(precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
"both altogether.")
super(_FloatType, self).__init__(precision=precision,
asdecimal=asdecimal, **kw)
self.scale = scale
class _StringType(mysql_dialect._StringType):
"""Base for Drizzle string types."""
def __init__(self, collation=None, binary=False, **kw):
kw['national'] = False
super(_StringType, self).__init__(collation=collation, binary=binary,
**kw)
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""Drizzle NUMERIC type."""
__visit_name__ = 'NUMERIC'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(NUMERIC, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""Drizzle DECIMAL type."""
__visit_name__ = 'DECIMAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(DECIMAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class DOUBLE(_FloatType):
"""Drizzle DOUBLE type."""
__visit_name__ = 'DOUBLE'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(DOUBLE, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class REAL(_FloatType, sqltypes.REAL):
"""Drizzle REAL type."""
__visit_name__ = 'REAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(REAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class FLOAT(_FloatType, sqltypes.FLOAT):
"""Drizzle FLOAT type."""
__visit_name__ = 'FLOAT'
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(FLOAT, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
def bind_processor(self, dialect):
return None
class INTEGER(sqltypes.INTEGER):
"""Drizzle INTEGER type."""
__visit_name__ = 'INTEGER'
def __init__(self, **kw):
"""Construct an INTEGER."""
super(INTEGER, self).__init__(**kw)
class BIGINT(sqltypes.BIGINT):
"""Drizzle BIGINTEGER type."""
__visit_name__ = 'BIGINT'
def __init__(self, **kw):
"""Construct a BIGINTEGER."""
super(BIGINT, self).__init__(**kw)
class TIME(mysql_dialect.TIME):
"""Drizzle TIME type."""
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Drizzle TIMESTAMP type."""
__visit_name__ = 'TIMESTAMP'
class TEXT(_StringType, sqltypes.TEXT):
"""Drizzle TEXT type, for text up to 2^16 characters."""
__visit_name__ = 'TEXT'
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` characters.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TEXT, self).__init__(length=length, **kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Drizzle VARCHAR type, for variable-length character data."""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Drizzle CHAR type, for fixed-length character data."""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
super(CHAR, self).__init__(length=length, **kwargs)
class ENUM(mysql_dialect.ENUM):
"""Drizzle ENUM type."""
def __init__(self, *enums, **kw):
"""Construct an ENUM.
Example:
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values will be
quoted when generating the schema according to the quoting flag (see
below).
:param strict: Defaults to False: ensure that a given value is in this
ENUM's range of permissible values when inserting or updating rows.
Note that Drizzle will not raise a fatal error if you attempt to
store an out of range value- an alternate value will be stored
instead.
(See Drizzle ENUM documentation.)
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
:param quoting: Defaults to 'auto': automatically determine enum value
quoting. If all enum values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.
'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.
Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.
"""
super(ENUM, self).__init__(*enums, **kw)
class _DrizzleBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.NUMERIC
colspecs = {
sqltypes.Numeric: NUMERIC,
sqltypes.Float: FLOAT,
sqltypes.Time: TIME,
sqltypes.Enum: ENUM,
sqltypes.Boolean: _DrizzleBoolean,
}
# All the types we have in Drizzle
ischema_names = {
'BIGINT': BIGINT,
'BINARY': BINARY,
'BLOB': BLOB,
'BOOLEAN': BOOLEAN,
'CHAR': CHAR,
'DATE': DATE,
'DATETIME': DATETIME,
'DECIMAL': DECIMAL,
'DOUBLE': DOUBLE,
'ENUM': ENUM,
'FLOAT': FLOAT,
'INT': INTEGER,
'INTEGER': INTEGER,
'NUMERIC': NUMERIC,
'TEXT': TEXT,
'TIME': TIME,
'TIMESTAMP': TIMESTAMP,
'VARBINARY': VARBINARY,
'VARCHAR': VARCHAR,
}
class DrizzleCompiler(mysql_dialect.MySQLCompiler):
def visit_typeclause(self, typeclause):
type_ = typeclause.type.dialect_impl(self.dialect)
if isinstance(type_, sqltypes.Integer):
return 'INTEGER'
else:
return super(DrizzleCompiler, self).visit_typeclause(typeclause)
def visit_cast(self, cast, **kwargs):
type_ = self.process(cast.typeclause)
if type_ is None:
return self.process(cast.clause)
return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
class DrizzleDDLCompiler(mysql_dialect.MySQLDDLCompiler):
pass
class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler):
def _extend_numeric(self, type_, spec):
return spec
def _extend_string(self, type_, defaults, spec):
"""Extend a string-type declaration with standard SQL
COLLATE annotations and Drizzle specific extensions.
"""
def attr(name):
return getattr(type_, name, defaults.get(name))
if attr('collation'):
collation = 'COLLATE %s' % type_.collation
elif attr('binary'):
collation = 'BINARY'
else:
collation = None
return ' '.join([c for c in (spec, collation)
if c is not None])
def visit_NCHAR(self, type):
raise NotImplementedError("Drizzle does not support NCHAR")
def visit_NVARCHAR(self, type):
raise NotImplementedError("Drizzle does not support NVARCHAR")
def visit_FLOAT(self, type_):
if type_.scale is not None and type_.precision is not None:
return "FLOAT(%s, %s)" % (type_.precision, type_.scale)
else:
return "FLOAT"
def visit_BOOLEAN(self, type_):
return "BOOLEAN"
def visit_BLOB(self, type_):
return "BLOB"
class DrizzleExecutionContext(mysql_dialect.MySQLExecutionContext):
pass
class DrizzleIdentifierPreparer(mysql_dialect.MySQLIdentifierPreparer):
pass
@log.class_logger
class DrizzleDialect(mysql_dialect.MySQLDialect):
"""Details of the Drizzle dialect.
Not used directly in application code.
"""
name = 'drizzle'
_supports_cast = True
supports_sequences = False
supports_native_boolean = True
supports_views = False
default_paramstyle = 'format'
colspecs = colspecs
statement_compiler = DrizzleCompiler
ddl_compiler = DrizzleDDLCompiler
type_compiler = DrizzleTypeCompiler
ischema_names = ischema_names
preparer = DrizzleIdentifierPreparer
def on_connect(self):
"""Force autocommit - Drizzle Bug#707842 doesn't set this properly"""
def connect(conn):
conn.autocommit(False)
return connect
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
"""Return a Unicode SHOW TABLES from a given schema."""
if schema is not None:
current_schema = schema
else:
current_schema = self.default_schema_name
charset = 'utf8'
rp = connection.execute("SHOW TABLES FROM %s" %
self.identifier_preparer.quote_identifier(current_schema))
return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
raise NotImplementedError
def _detect_casing(self, connection):
"""Sniff out identifier case sensitivity.
Cached per-connection. This value can not change without a server
restart.
"""
return 0
def _detect_collations(self, connection):
"""Pull the active COLLATIONS list from the server.
Cached per-connection.
"""
collations = {}
charset = self._connection_charset
rs = connection.execute(
'SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM'
' data_dictionary.COLLATIONS')
for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
def _detect_ansiquotes(self, connection):
"""Detect and adjust for the ANSI_QUOTES sql mode."""
self._server_ansiquotes = False
self._backslash_escapes = False

View File

@ -1,48 +0,0 @@
"""
.. dialect:: drizzle+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: drizzle+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python
"""
from sqlalchemy.dialects.drizzle.base import (
DrizzleDialect,
DrizzleExecutionContext,
DrizzleCompiler,
DrizzleIdentifierPreparer)
from sqlalchemy.connectors.mysqldb import (
MySQLDBExecutionContext,
MySQLDBCompiler,
MySQLDBIdentifierPreparer,
MySQLDBConnector)
class DrizzleExecutionContext_mysqldb(MySQLDBExecutionContext,
DrizzleExecutionContext):
pass
class DrizzleCompiler_mysqldb(MySQLDBCompiler, DrizzleCompiler):
pass
class DrizzleIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer,
DrizzleIdentifierPreparer):
pass
class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect):
execution_ctx_cls = DrizzleExecutionContext_mysqldb
statement_compiler = DrizzleCompiler_mysqldb
preparer = DrizzleIdentifierPreparer_mysqldb
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return 'utf8'
dialect = DrizzleDialect_mysqldb

View File

@ -1,20 +0,0 @@
# firebird/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb
base.dialect = fdb.dialect
from sqlalchemy.dialects.firebird.base import \
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\
dialect
__all__ = (
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
'dialect'
)

View File

@ -1,738 +0,0 @@
# firebird/base.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird
:name: Firebird
Firebird Dialects
-----------------
Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``):
dialect 1
This is the old syntax and behaviour, inherited from Interbase pre-6.0.
dialect 3
This is the newer and supported syntax, introduced in Interbase 6.0.
The SQLAlchemy Firebird dialect detects these versions and
adjusts its representation of SQL accordingly. However,
support for dialect 1 is not well tested and probably has
incompatibilities.
Locking Behavior
----------------
Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best
to release transactions as quickly as possible. The most common cause
of hanging transactions is a non-fully consumed result set, i.e.::
result = engine.execute("select * from table")
row = result.fetchone()
return
Where above, the ``ResultProxy`` has not been fully consumed. The
connection will be returned to the pool and the transactional state
rolled back once the Python garbage collector reclaims the objects
which hold onto the connection, which often occurs asynchronously.
The above use case can be alleviated by calling ``first()`` on the
``ResultProxy`` which will fetch the first row and immediately close
all remaining cursor/connection resources.
RETURNING support
-----------------
Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as::
# INSERT..RETURNING
result = table.insert().returning(table.c.col1, table.c.col2).\\
values(name='foo')
print result.fetchall()
# UPDATE..RETURNING
raises = empl.update().returning(empl.c.id, empl.c.salary).\\
where(empl.c.sales>100).\\
values(dict(salary=empl.c.salary * 1.1))
print raises.fetchall()
.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
"""
import datetime
from sqlalchemy import schema as sa_schema
from sqlalchemy import exc, types as sqltypes, sql, util
from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler
from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
RESERVED_WORDS = set([
"active", "add", "admin", "after", "all", "alter", "and", "any", "as",
"asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
"bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
"character", "character_length", "char_length", "check", "close",
"collate", "column", "commit", "committed", "computed", "conditional",
"connect", "constraint", "containing", "count", "create", "cross",
"cstring", "current", "current_connection", "current_date",
"current_role", "current_time", "current_timestamp",
"current_transaction", "current_user", "cursor", "database", "date",
"day", "dec", "decimal", "declare", "default", "delete", "desc",
"descending", "disconnect", "distinct", "do", "domain", "double",
"drop", "else", "end", "entry_point", "escape", "exception",
"execute", "exists", "exit", "external", "extract", "fetch", "file",
"filter", "float", "for", "foreign", "from", "full", "function",
"gdscode", "generator", "gen_id", "global", "grant", "group",
"having", "hour", "if", "in", "inactive", "index", "inner",
"input_type", "insensitive", "insert", "int", "integer", "into", "is",
"isolation", "join", "key", "leading", "left", "length", "level",
"like", "long", "lower", "manual", "max", "maximum_segment", "merge",
"min", "minute", "module_name", "month", "names", "national",
"natural", "nchar", "no", "not", "null", "numeric", "octet_length",
"of", "on", "only", "open", "option", "or", "order", "outer",
"output_type", "overflow", "page", "pages", "page_size", "parameter",
"password", "plan", "position", "post_event", "precision", "primary",
"privileges", "procedure", "protected", "rdb$db_key", "read", "real",
"record_version", "recreate", "recursive", "references", "release",
"reserv", "reserving", "retain", "returning_values", "returns",
"revoke", "right", "rollback", "rows", "row_count", "savepoint",
"schema", "second", "segment", "select", "sensitive", "set", "shadow",
"shared", "singular", "size", "smallint", "snapshot", "some", "sort",
"sqlcode", "stability", "start", "starting", "starts", "statistics",
"sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
"to", "trailing", "transaction", "trigger", "trim", "uncommitted",
"union", "unique", "update", "upper", "user", "using", "value",
"values", "varchar", "variable", "varying", "view", "wait", "when",
"where", "while", "with", "work", "write", "year",
])
class _StringType(sqltypes.String):
"""Base for Firebird string types."""
def __init__(self, charset=None, **kw):
self.charset = charset
super(_StringType, self).__init__(**kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Firebird VARCHAR type"""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Firebird CHAR type"""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
super(CHAR, self).__init__(length=length, **kwargs)
class _FBDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
return process
colspecs = {
sqltypes.DateTime: _FBDateTime
}
ischema_names = {
'SHORT': SMALLINT,
'LONG': INTEGER,
'QUAD': FLOAT,
'FLOAT': FLOAT,
'DATE': DATE,
'TIME': TIME,
'TEXT': TEXT,
'INT64': BIGINT,
'DOUBLE': FLOAT,
'TIMESTAMP': TIMESTAMP,
'VARYING': VARCHAR,
'CSTRING': CHAR,
'BLOB': BLOB,
}
# TODO: date conversion types (should be implemented as _FBDateTime,
# _FBDate, etc. as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_):
return self.visit_SMALLINT(type_)
def visit_datetime(self, type_):
return self.visit_TIMESTAMP(type_)
def visit_TEXT(self, type_):
return "BLOB SUB_TYPE 1"
def visit_BLOB(self, type_):
return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
charset = getattr(type_, 'charset', None)
if charset is None:
return basic
else:
return '%s CHARACTER SET %s' % (basic, charset)
def visit_CHAR(self, type_):
basic = super(FBTypeCompiler, self).visit_CHAR(type_)
return self._extend_string(type_, basic)
def visit_VARCHAR(self, type_):
if not type_.length:
raise exc.CompileError(
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
return self._extend_string(type_, basic)
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosyncrasies"""
ansi_bind_rules = True
#def visit_contains_op_binary(self, binary, operator, **kw):
# cant use CONTAINING b.c. it's case insensitive.
#def visit_notcontains_op_binary(self, binary, operator, **kw):
# cant use NOT CONTAINING b.c. it's case insensitive.
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
def visit_startswith_op_binary(self, binary, operator, **kw):
return '%s STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_notstartswith_op_binary(self, binary, operator, **kw):
return '%s NOT STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw))
def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two:
return super(FBCompiler, self).\
visit_alias(alias, asfrom=asfrom, **kwargs)
else:
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
alias_name = isinstance(alias.name,
expression._truncated_label) and \
self._truncated_identifier("alias",
alias.name) or alias.name
return self.process(
alias.original, asfrom=asfrom, **kwargs) + \
" " + \
self.preparer.format_alias(alias, alias_name)
else:
return self.process(alias.original, **kwargs)
def visit_substring_func(self, func, **kw):
s = self.process(func.clauses.clauses[0])
start = self.process(func.clauses.clauses[1])
if len(func.clauses.clauses) > 2:
length = self.process(func.clauses.clauses[2])
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
else:
return "SUBSTRING(%s FROM %s)" % (s, start)
def visit_length_func(self, function, **kw):
if self.dialect._version_two:
return "char_length" + self.function_argspec(function)
else:
return "strlen" + self.function_argspec(function)
visit_char_length_func = visit_length_func
def function_argspec(self, func, **kw):
# TODO: this probably will need to be
# narrowed to a fixed list, some no-arg functions
# may require parens - see similar example in the oracle
# dialect
if func.clauses is not None and len(func.clauses):
return self.process(func.clause_expr, **kw)
else:
return ""
def default_from(self):
return " FROM rdb$database"
def visit_sequence(self, seq):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
if select._limit:
result += "FIRST %s " % self.process(sql.literal(select._limit))
if select._offset:
result += "SKIP %s " % self.process(sql.literal(select._offset))
if select._distinct:
result += "DISTINCT "
return result
def limit_clause(self, select):
"""Already taken care of in the `get_select_precolumns` method."""
return ""
def returning_clause(self, stmt, returning_cols):
columns = [
self._label_select_column(None, c, True, False, {})
for c in expression._select_iterables(returning_cols)
]
return 'RETURNING ' + ', '.join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
"""Firebird syntactic idiosyncrasies"""
def visit_create_sequence(self, create):
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
# no syntax for these
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None:
raise NotImplemented(
"Firebird SEQUENCE doesn't support START WITH")
if create.element.increment is not None:
raise NotImplemented(
"Firebird SEQUENCE doesn't support INCREMENT BY")
if self.dialect._version_two:
return "CREATE SEQUENCE %s" % \
self.preparer.format_sequence(create.element)
else:
return "CREATE GENERATOR %s" % \
self.preparer.format_sequence(create.element)
def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two:
return "DROP SEQUENCE %s" % \
self.preparer.format_sequence(drop.element)
else:
return "DROP GENERATOR %s" % \
self.preparer.format_sequence(drop.element)
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words."""
reserved_words = RESERVED_WORDS
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(['_'])
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
class FBExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
"""Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar(
"SELECT gen_id(%s, 1) FROM rdb$database" %
self.dialect.identifier_preparer.format_sequence(seq),
type_
)
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
name = 'firebird'
max_identifier_length = 31
supports_sequences = True
sequences_optional = False
supports_default_values = True
postfetch_lastrowid = False
supports_native_boolean = False
requires_name_normalize = True
supports_empty_insert = False
statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler
preparer = FBIdentifierPreparer
type_compiler = FBTypeCompiler
execution_ctx_cls = FBExecutionContext
colspecs = colspecs
ischema_names = ischema_names
construct_arguments = []
# defaults to dialect ver. 3,
# will be autodetected off upon
# first connect
_version_two = True
def initialize(self, connection):
super(FBDialect, self).initialize(connection)
self._version_two = ('firebird' in self.server_version_info and \
self.server_version_info >= (2, )
) or \
('interbase' in self.server_version_info and \
self.server_version_info >= (6, )
)
if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy()
self.ischema_names['TIMESTAMP'] = sqltypes.DATE
self.colspecs = {
sqltypes.DateTime: sqltypes.DATE
}
self.implicit_returning = self._version_two and \
self.__dict__.get('implicit_returning', True)
def normalize_name(self, name):
# Remove trailing spaces: FB uses a CHAR() type,
# that is padded with spaces
name = name and name.rstrip()
if name is None:
return None
elif name.upper() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.lower()
else:
return name
def denormalize_name(self, name):
if name is None:
return None
elif name.lower() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.upper()
else:
return name
def has_table(self, connection, table_name, schema=None):
"""Return ``True`` if the given table exists, ignoring
the `schema`."""
tblqry = """
SELECT 1 AS has_table FROM rdb$database
WHERE EXISTS (SELECT rdb$relation_name
FROM rdb$relations
WHERE rdb$relation_name=?)
"""
c = connection.execute(tblqry, [self.denormalize_name(table_name)])
return c.first() is not None
def has_sequence(self, connection, sequence_name, schema=None):
"""Return ``True`` if the given sequence (generator) exists."""
genqry = """
SELECT 1 AS has_sequence FROM rdb$database
WHERE EXISTS (SELECT rdb$generator_name
FROM rdb$generators
WHERE rdb$generator_name=?)
"""
c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
return c.first() is not None
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
# there are two queries commonly mentioned for this.
# this one, using view_blr, is at the Firebird FAQ among other places:
# http://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
# the other query is this one. It's not clear if there's really
# any difference between these two. This link:
# http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
# states them as interchangeable. Some discussion at [ticket:2898]
# SELECT DISTINCT rdb$relation_name
# FROM rdb$relation_fields
# WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
# see http://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is not null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
qry = """
SELECT rdb$view_source AS view_source
FROM rdb$relations
WHERE rdb$relation_name=?
"""
rp = connection.execute(qry, [self.denormalize_name(view_name)])
row = rp.first()
if row:
return row['view_source']
else:
return None
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Query to extract the PK/FK constrained fields of the given table
keyqry = """
SELECT se.rdb$field_name AS fname
FROM rdb$relation_constraints rc
JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
"""
tablename = self.denormalize_name(table_name)
# get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
return {'constrained_columns': pkfields, 'name': None}
@reflection.cache
def get_column_sequence(self, connection,
table_name, column_name,
schema=None, **kw):
tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name)
# Heuristic-query to determine the generator associated to a PK field
genqry = """
SELECT trigdep.rdb$depended_on_name AS fgenerator
FROM rdb$dependencies tabdep
JOIN rdb$dependencies trigdep
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
AND trigdep.rdb$depended_on_type=14
AND trigdep.rdb$dependent_type=2
JOIN rdb$triggers trig ON
trig.rdb$trigger_name=tabdep.rdb$dependent_name
WHERE tabdep.rdb$depended_on_name=?
AND tabdep.rdb$depended_on_type=0
AND trig.rdb$trigger_type=1
AND tabdep.rdb$field_name=?
AND (SELECT count(*)
FROM rdb$dependencies trigdep2
WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
"""
genr = connection.execute(genqry, [tablename, colname]).first()
if genr is not None:
return dict(name=self.normalize_name(genr['fgenerator']))
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
# Query to extract the details of all the fields of the given table
tblqry = """
SELECT r.rdb$field_name AS fname,
r.rdb$null_flag AS null_flag,
t.rdb$type_name AS ftype,
f.rdb$field_sub_type AS stype,
f.rdb$field_length/
COALESCE(cs.rdb$bytes_per_character,1) AS flen,
f.rdb$field_precision AS fprec,
f.rdb$field_scale AS fscale,
COALESCE(r.rdb$default_source,
f.rdb$default_source) AS fdefault
FROM rdb$relation_fields r
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
JOIN rdb$types t
ON t.rdb$type=f.rdb$field_type AND
t.rdb$field_name='RDB$FIELD_TYPE'
LEFT JOIN rdb$character_sets cs ON
f.rdb$character_set_id=cs.rdb$character_set_id
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
ORDER BY r.rdb$field_position
"""
# get the PK, used to determine the eventual associated sequence
pk_constraint = self.get_pk_constraint(connection, table_name)
pkey_cols = pk_constraint['constrained_columns']
tablename = self.denormalize_name(table_name)
# get all of the fields for this table
c = connection.execute(tblqry, [tablename])
cols = []
while True:
row = c.fetchone()
if row is None:
break
name = self.normalize_name(row['fname'])
orig_colname = row['fname']
# get the data type
colspec = row['ftype'].rstrip()
coltype = self.ischema_names.get(colspec)
if coltype is None:
util.warn("Did not recognize type '%s' of column '%s'" %
(colspec, name))
coltype = sqltypes.NULLTYPE
elif issubclass(coltype, Integer) and row['fprec'] != 0:
coltype = NUMERIC(
precision=row['fprec'],
scale=row['fscale'] * -1)
elif colspec in ('VARYING', 'CSTRING'):
coltype = coltype(row['flen'])
elif colspec == 'TEXT':
coltype = TEXT(row['flen'])
elif colspec == 'BLOB':
if row['stype'] == 1:
coltype = TEXT()
else:
coltype = BLOB()
else:
coltype = coltype()
# does it have a default value?
defvalue = None
if row['fdefault'] is not None:
# the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword
# and it may also be lower case
# (see also http://tracker.firebirdsql.org/browse/CORE-356)
defexpr = row['fdefault'].lstrip()
assert defexpr[:8].rstrip().upper() == \
'DEFAULT', "Unrecognized default value: %s" % \
defexpr
defvalue = defexpr[8:].strip()
if defvalue == 'NULL':
# Redundant
defvalue = None
col_d = {
'name': name,
'type': coltype,
'nullable': not bool(row['null_flag']),
'default': defvalue,
'autoincrement': defvalue is None
}
if orig_colname.lower() == orig_colname:
col_d['quote'] = True
# if the PK is a single field, try to see if its linked to
# a sequence thru a trigger
if len(pkey_cols) == 1 and name == pkey_cols[0]:
seq_d = self.get_column_sequence(connection, tablename, name)
if seq_d is not None:
col_d['sequence'] = seq_d
cols.append(col_d)
return cols
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Query to extract the details of each UK/FK of the given table
fkqry = """
SELECT rc.rdb$constraint_name AS cname,
cse.rdb$field_name AS fname,
ix2.rdb$relation_name AS targetrname,
se.rdb$field_name AS targetfname
FROM rdb$relation_constraints rc
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
JOIN rdb$index_segments cse ON
cse.rdb$index_name=ix1.rdb$index_name
JOIN rdb$index_segments se
ON se.rdb$index_name=ix2.rdb$index_name
AND se.rdb$field_position=cse.rdb$field_position
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
ORDER BY se.rdb$index_name, se.rdb$field_position
"""
tablename = self.denormalize_name(table_name)
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
fks = util.defaultdict(lambda: {
'name': None,
'constrained_columns': [],
'referred_schema': None,
'referred_table': None,
'referred_columns': []
})
for row in c:
cname = self.normalize_name(row['cname'])
fk = fks[cname]
if not fk['name']:
fk['name'] = cname
fk['referred_table'] = self.normalize_name(row['targetrname'])
fk['constrained_columns'].append(
self.normalize_name(row['fname']))
fk['referred_columns'].append(
self.normalize_name(row['targetfname']))
return list(fks.values())
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
qry = """
SELECT ix.rdb$index_name AS index_name,
ix.rdb$unique_flag AS unique_flag,
ic.rdb$field_name AS field_name
FROM rdb$indices ix
JOIN rdb$index_segments ic
ON ix.rdb$index_name=ic.rdb$index_name
LEFT OUTER JOIN rdb$relation_constraints
ON rdb$relation_constraints.rdb$index_name =
ic.rdb$index_name
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
AND rdb$relation_constraints.rdb$constraint_type IS NULL
ORDER BY index_name, ic.rdb$field_position
"""
c = connection.execute(qry, [self.denormalize_name(table_name)])
indexes = util.defaultdict(dict)
for row in c:
indexrec = indexes[row['index_name']]
if 'name' not in indexrec:
indexrec['name'] = self.normalize_name(row['index_name'])
indexrec['column_names'] = []
indexrec['unique'] = bool(row['unique_flag'])
indexrec['column_names'].append(
self.normalize_name(row['field_name']))
return list(indexes.values())

View File

@ -1,115 +0,0 @@
# firebird/fdb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+fdb
:name: fdb
:dbapi: pyodbc
:connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: http://pypi.python.org/pypi/fdb/
fdb is a kinterbasdb compatible DBAPI for Firebird.
.. versionadded:: 0.8 - Support for the fdb Firebird driver.
.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
under the ``firebird://`` URL space, as ``fdb`` is now the official
Python driver for Firebird.
Arguments
----------
The ``fdb`` dialect is based on the :mod:`sqlalchemy.dialects.firebird.kinterbasdb`
dialect, however does not accept every argument that Kinterbasdb does.
* ``enable_rowcount`` - True by default, setting this to False disables
the usage of "cursor.rowcount" with the
Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
ResultProxy will return -1 for result.rowcount. The rationale here is
that Kinterbasdb requires a second round trip to the database when
.rowcount is called - since SQLA's resultproxy automatically closes
the cursor after a non-result-returning statement, rowcount must be
called, if at all, before the result object is returned. Additionally,
cursor.rowcount may not return correct results with older versions
of Firebird, and setting this flag to False will also cause the
SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
per-execution basis using the ``enable_rowcount`` option with
:meth:`.Connection.execution_options`::
conn = engine.connect().execution_options(enable_rowcount=True)
r = conn.execute(stmt)
print r.rowcount
* ``retaining`` - False by default. Setting this to True will pass the
``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
methods of the DBAPI connection, which can improve performance in some
situations, but apparently with significant caveats.
Please read the fdb and/or kinterbasdb DBAPI documentation in order to
understand the implications of this flag.
.. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying
transaction retaining behavior - in 0.8 it defaults to ``True``
for backwards compatibility.
.. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
In 0.8 it defaulted to ``True``.
.. seealso::
http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions - information
on the "retaining" flag.
"""
from .kinterbasdb import FBDialect_kinterbasdb
from ... import util
class FBDialect_fdb(FBDialect_kinterbasdb):
def __init__(self, enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_fdb, self).__init__(
enable_rowcount=enable_rowcount,
retaining=retaining, **kwargs)
@classmethod
def dbapi(cls):
return __import__('fdb')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
isc_info_firebird_version = 103
fbconn = connection.connection
version = fbconn.db_info(isc_info_firebird_version)
return self._parse_version_info(version)
dialect = FBDialect_fdb

View File

@ -1,179 +0,0 @@
# firebird/kinterbasdb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+kinterbasdb
:name: kinterbasdb
:dbapi: kinterbasdb
:connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: http://firebirdsql.org/index.php?op=devel&sub=python
Arguments
----------
The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. In addition, it
also accepts the following:
* ``type_conv`` - select the kind of mapping done on the types: by default
SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
the linked documents below for further information.
* ``concurrency_level`` - set the backend policy with regards to threading
issues: by default SQLAlchemy uses policy 1. See the linked documents
below for futher information.
.. seealso::
http://sourceforge.net/projects/kinterbasdb
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
"""
from .base import FBDialect, FBExecutionContext
from ... import util, types as sqltypes
from re import match
import decimal
class _kinterbasdb_numeric(object):
def bind_processor(self, dialect):
def process(value):
if isinstance(value, decimal.Decimal):
return str(value)
else:
return value
return process
class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
pass
class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
pass
class FBExecutionContext_kinterbasdb(FBExecutionContext):
@property
def rowcount(self):
if self.execution_options.get('enable_rowcount',
self.dialect.enable_rowcount):
return self.cursor.rowcount
else:
return -1
class FBDialect_kinterbasdb(FBDialect):
driver = 'kinterbasdb'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
execution_ctx_cls = FBExecutionContext_kinterbasdb
supports_native_decimal = True
colspecs = util.update_copy(
FBDialect.colspecs,
{
sqltypes.Numeric: _FBNumeric_kinterbasdb,
sqltypes.Float: _FBFloat_kinterbasdb,
}
)
def __init__(self, type_conv=200, concurrency_level=1,
enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.enable_rowcount = enable_rowcount
self.type_conv = type_conv
self.concurrency_level = concurrency_level
self.retaining = retaining
if enable_rowcount:
self.supports_sane_rowcount = True
@classmethod
def dbapi(cls):
return __import__('kinterbasdb')
def do_execute(self, cursor, statement, parameters, context=None):
# kinterbase does not accept a None, but wants an empty list
# when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, dbapi_connection):
dbapi_connection.rollback(self.retaining)
def do_commit(self, dbapi_connection):
dbapi_connection.commit(self.retaining)
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
type_conv = opts.pop('type_conv', self.type_conv)
concurrency_level = opts.pop('concurrency_level',
self.concurrency_level)
if self.dbapi is not None:
initialized = getattr(self.dbapi, 'initialized', None)
if initialized is None:
# CVS rev 1.96 changed the name of the attribute:
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
initialized = getattr(self.dbapi, '_initialized', False)
if not initialized:
self.dbapi.init(type_conv=type_conv,
concurrency_level=concurrency_level)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
fbconn = connection.connection
version = fbconn.server_version
return self._parse_version_info(version)
def _parse_version_info(self, version):
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
if not m:
raise AssertionError(
"Could not determine version from string '%s'" % version)
if m.group(5) != None:
return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
else:
return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
def is_disconnect(self, e, connection, cursor):
if isinstance(e, (self.dbapi.OperationalError,
self.dbapi.ProgrammingError)):
msg = str(e)
return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or
'Invalid cursor state' in msg or
'connection shutdown' in msg)
else:
return False
dialect = FBDialect_kinterbasdb

View File

@ -1,26 +0,0 @@
# mssql/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \
pymssql, zxjdbc, mxodbc
base.dialect = pyodbc.dialect
from sqlalchemy.dialects.mssql.base import \
INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\
DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\
MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
)

View File

@ -1,79 +0,0 @@
# mssql/adodbapi.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+adodbapi
:name: adodbapi
:dbapi: adodbapi
:connectstring: mssql+adodbapi://<username>:<password>@<dsnname>
:url: http://adodbapi.sourceforge.net/
.. note::
The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and
above at this time.
"""
import datetime
from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys
class MSDateTime_adodbapi(MSDateTime):
def result_processor(self, dialect, coltype):
def process(value):
# adodbapi will return datetimes with empty time
# values as datetime.date() objects.
# Promote them back to full datetime.datetime()
if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day)
return value
return process
class MSDialect_adodbapi(MSDialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = True
driver = 'adodbapi'
@classmethod
def import_dbapi(cls):
import adodbapi as module
return module
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.DateTime: MSDateTime_adodbapi
}
)
def create_connect_args(self, url):
keys = url.query
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
connectors.append("Data Source=%s, %s" %
(keys.get("host"), keys.get("port")))
else:
connectors.append("Data Source=%s" % keys.get("host"))
connectors.append("Initial Catalog=%s" % keys.get("database"))
user = keys.get("user")
if user:
connectors.append("User Id=%s" % user)
connectors.append("Password=%s" % keys.get("password", ""))
else:
connectors.append("Integrated Security=SSPI")
return [[";".join(connectors)], {}]
def is_disconnect(self, e, connection, cursor):
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
"'connection failure'" in str(e)
dialect = MSDialect_adodbapi

File diff suppressed because it is too large Load Diff

View File

@ -1,114 +0,0 @@
# mssql/information_schema.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
# TODO: should be using the sys. catalog with SQL Server, not information schema
from ... import Table, MetaData, Column
from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator
from ... import cast
from ... import util
from ...sql import expression
from ...ext.compiler import compiles
ischema = MetaData()
class CoerceUnicode(TypeDecorator):
impl = Unicode
def process_bind_param(self, value, dialect):
if util.py2k and isinstance(value, util.binary_type):
value = value.decode(dialect.encoding)
return value
def bind_expression(self, bindvalue):
return _cast_on_2005(bindvalue)
class _cast_on_2005(expression.ColumnElement):
def __init__(self, bindvalue):
self.bindvalue = bindvalue
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
if compiler.dialect.server_version_info < base.MS_2005_VERSION:
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
schemata = Table("SCHEMATA", ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
schema="INFORMATION_SCHEMA")
tables = Table("TABLES", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
schema="INFORMATION_SCHEMA")
columns = Table("COLUMNS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="INFORMATION_SCHEMA")
constraints = Table("TABLE_CONSTRAINTS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
schema="INFORMATION_SCHEMA")
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
schema="INFORMATION_SCHEMA")
key_constraints = Table("KEY_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
schema="INFORMATION_SCHEMA")
ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
# TODO: is CATLOG misspelled ?
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
key="unique_constraint_catalog"),
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
key="unique_constraint_schema"),
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
key="unique_constraint_name"),
Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"),
schema="INFORMATION_SCHEMA")
views = Table("VIEWS", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA")

View File

@ -1,111 +0,0 @@
# mssql/mxodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+mxodbc
:name: mxODBC
:dbapi: mxodbc
:connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
:url: http://www.egenix.com/
Execution Modes
---------------
mxODBC features two styles of statement execution, using the
``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
an extension to the DBAPI specification). The former makes use of a particular
API call specific to the SQL Server Native Client ODBC driver known
SQLDescribeParam, while the latter does not.
mxODBC apparently only makes repeated use of a single prepared statement
when SQLDescribeParam is used. The advantage to prepared statement reuse is
one of performance. The disadvantage is that SQLDescribeParam has a limited
set of scenarios in which bind parameters are understood, including that they
cannot be placed within the argument lists of function calls, anywhere outside
the FROM, or even within subqueries within the FROM clause - making the usage
of bind parameters within SELECT statements impossible for all but the most
simplistic statements.
For this reason, the mxODBC dialect uses the "native" mode by default only for
INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
all other statements.
This behavior can be controlled via
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
value of ``True`` will unconditionally use native bind parameters and a value
of ``False`` will unconditionally use string-escaped parameters.
"""
from ... import types as sqltypes
from ...connectors.mxodbc import MxODBCConnector
from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
from .base import (MSDialect,
MSSQLStrictCompiler,
_MSDateTime, _MSDate, _MSTime)
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
"""Include pyodbc's numeric processor.
"""
class _MSDate_mxodbc(_MSDate):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s-%s-%s" % (value.year, value.month, value.day)
else:
return None
return process
class _MSTime_mxodbc(_MSTime):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s:%s:%s" % (value.hour, value.minute, value.second)
else:
return None
return process
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
"""
The pyodbc execution context is useful for enabling
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
does not work (tables with insert triggers).
"""
#todo - investigate whether the pyodbc execution context
# is really only being used in cases where OUTPUT
# won't work.
class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# this is only needed if "native ODBC" mode is used,
# which is now disabled by default.
#statement_compiler = MSSQLStrictCompiler
execution_ctx_cls = MSExecutionContext_mxodbc
# flag used by _MSNumeric_mxodbc
_need_decimal_fix = True
colspecs = {
sqltypes.Numeric: _MSNumeric_mxodbc,
sqltypes.DateTime: _MSDateTime,
sqltypes.Date: _MSDate_mxodbc,
sqltypes.Time: _MSTime_mxodbc,
}
def __init__(self, description_encoding=None, **params):
super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding
dialect = MSDialect_mxodbc

View File

@ -1,92 +0,0 @@
# mssql/pymssql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+pymssql
:name: pymssql
:dbapi: pymssql
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>?charset=utf8
:url: http://pymssql.org/
pymssql is a Python module that provides a Python DBAPI interface around
`FreeTDS <http://www.freetds.org/>`_. Compatible builds are available for
Linux, MacOSX and Windows platforms.
"""
from .base import MSDialect
from ... import types as sqltypes, util, processors
import re
class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSDialect_pymssql(MSDialect):
supports_sane_rowcount = False
driver = 'pymssql'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pymssql,
sqltypes.Float: sqltypes.Float,
}
)
@classmethod
def dbapi(cls):
module = __import__('pymssql')
# pymmsql doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (1, ):
util.warn("The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI.")
return module
def __init__(self, **params):
super(MSDialect_pymssql, self).__init__(**params)
self.use_scope_identity = True
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version")
m = re.match(
r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
port = opts.pop('port', None)
if port and 'host' in opts:
opts['host'] = "%s:%s" % (opts['host'], port)
return [[], opts]
def is_disconnect(self, e, connection, cursor):
for msg in (
"Adaptive Server connection timed out",
"Net-Lib error during Connection reset by peer",
"message 20003", # connection timeout
"Error 10054",
"Not connected to any MS SQL server",
"Connection is closed"
):
if msg in str(e):
return True
else:
return False
dialect = MSDialect_pymssql

View File

@ -1,260 +0,0 @@
# mssql/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
Additional Connection Examples
-------------------------------
Examples of pyodbc connection string URLs:
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
dsn=mydsn;Trusted_Connection=Yes
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection
that would appear like::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
string which includes the port
information using the comma syntax. This will create the following
connection string::
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
string that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
* ``mssql+pyodbc://user:pass@host/db?driver=MyDriver`` - connects using a connection
string that includes a custom
ODBC driver name. This will create the following connection string::
DRIVER={MyDriver};Server=host;Database=db;UID=user;PWD=pass
If you require a connection string that is outside the options
presented above, use the ``odbc_connect`` keyword to pass in a
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
For example::
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string::
dsn=mydsn;Database=db
Encoding your connection string can be easily accomplished through
the python shell. For example::
>>> import urllib
>>> urllib.quote_plus('dsn=mydsn;Database=db')
'dsn%3Dmydsn%3BDatabase%3Ddb'
Unicode Binds
-------------
The current state of PyODBC on a unix backend with FreeTDS and/or
EasySoft is poor regarding unicode; different OS platforms and versions of UnixODBC
versus IODBC versus FreeTDS/EasySoft versus PyODBC itself dramatically
alter how strings are received. The PyODBC dialect attempts to use all the information
it knows to determine whether or not a Python unicode literal can be
passed directly to the PyODBC driver or not; while SQLAlchemy can encode
these to bytestrings first, some users have reported that PyODBC mis-handles
bytestrings for certain encodings and requires a Python unicode object,
while the author has observed widespread cases where a Python unicode
is completely misinterpreted by PyODBC, particularly when dealing with
the information schema tables used in table reflection, and the value
must first be encoded to a bytestring.
It is for this reason that whether or not unicode literals for bound
parameters be sent to PyODBC can be controlled using the
``supports_unicode_binds`` parameter to ``create_engine()``. When
left at its default of ``None``, the PyODBC dialect will use its
best guess as to whether or not the driver deals with unicode literals
well. When ``False``, unicode literals will be encoded first, and when
``True`` unicode literals will be passed straight through. This is an interim
flag that hopefully should not be needed when the unicode situation stabilizes
for unix + PyODBC.
.. versionadded:: 0.7.7
``supports_unicode_binds`` parameter to ``create_engine()``\ .
"""
from .base import MSExecutionContext, MSDialect
from ...connectors.pyodbc import PyODBCConnector
from ... import types as sqltypes, util
import decimal
class _ms_numeric_pyodbc(object):
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
The routines here are needed for older pyodbc versions
as well as current mxODBC versions.
"""
def bind_processor(self, dialect):
super_process = super(_ms_numeric_pyodbc, self).\
bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value):
if self.asdecimal and \
isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
elif adjusted > 7:
return self._large_dec_to_string(value)
if super_process:
return super_process(value)
else:
return value
return process
# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value.as_tuple()[1]]))
def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if 'E' in str(value):
result = "%s%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(_int) - 1)))
else:
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]),
"".join(
[str(s) for s in _int][value.adjusted() + 1:]))
else:
result = "%s%s" % (
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]))
return result
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass
class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same
statement.
Background on why "scope_identity()" is preferable to "@@identity":
http://msdn.microsoft.com/en-us/library/ms190315.aspx
Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
"""
super(MSExecutionContext_pyodbc, self).pre_exec()
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if self._select_lastrowid and \
self.dialect.use_scope_identity and \
len(self.parameters[0]):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error as e:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()
self._lastrowid = int(row[0])
else:
super(MSExecutionContext_pyodbc, self).post_exec()
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_pyodbc
pyodbc_driver_name = 'SQL Server'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc
}
)
def __init__(self, description_encoding=None, **params):
super(MSDialect_pyodbc, self).__init__(**params)
self.description_encoding = description_encoding
self.use_scope_identity = self.use_scope_identity and \
self.dbapi and \
hasattr(self.dbapi.Cursor, 'nextset')
self._need_decimal_fix = self.dbapi and \
self._dbapi_version() < (2, 1, 8)
dialect = MSDialect_pyodbc

View File

@ -1,65 +0,0 @@
# mssql/zxjdbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]
:driverurl: http://jtds.sourceforge.net/
"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import MSDialect, MSExecutionContext
from ... import engine
class MSExecutionContext_zxjdbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
super(MSExecutionContext_zxjdbc, self).pre_exec()
# scope_identity after the fact returns null in jTDS so we must
# embed it
if self._select_lastrowid and self.dialect.use_scope_identity:
self._embedded_scope_identity = True
self.statement += "; SELECT scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
while True:
try:
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error:
self.cursor.nextset()
self._lastrowid = int(row[0])
if (self.isinsert or self.isupdate or self.isdelete) and \
self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(
self.compiled.statement.table)
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
jdbc_db_name = 'jtds:sqlserver'
jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
execution_ctx_cls = MSExecutionContext_zxjdbc
def _get_server_version_info(self, connection):
return tuple(
int(x)
for x in connection.connection.dbversion.split('.')
)
dialect = MSDialect_zxjdbc

View File

@ -1,28 +0,0 @@
# mysql/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import base, mysqldb, oursql, \
pyodbc, zxjdbc, mysqlconnector, pymysql,\
gaerdbms, cymysql
# default dialect
base.dialect = mysqldb.dialect
from .base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
)

File diff suppressed because it is too large Load Diff

View File

@ -1,84 +0,0 @@
# mysql/cymysql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://github.com/nakagami/CyMySQL
"""
import re
from .mysqldb import MySQLDialect_mysqldb
from .base import (BIT, MySQLDialect)
from ... import util
class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
"""
def process(value):
if value is not None:
v = 0
for i in util.iterbytes(value):
v = v << 8 | i
return v
return value
return process
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = 'cymysql'
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _cymysqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('cymysql')
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.server_version):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in \
(2006, 2013, 2014, 2045, 2055)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False
dialect = MySQLDialect_cymysql

View File

@ -1,84 +0,0 @@
# mysql/gaerdbms.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+gaerdbms
:name: Google Cloud SQL
:dbapi: rdbms
:connectstring: mysql+gaerdbms:///<dbname>?instance=<instancename>
:url: https://developers.google.com/appengine/docs/python/cloud-sql/developers-guide
This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with minimal
changes.
.. versionadded:: 0.7.8
Pooling
-------
Google App Engine connections appear to be randomly recycled,
so the dialect does not pool connections. The :class:`.NullPool`
implementation is installed within the :class:`.Engine` by
default.
"""
import os
from .mysqldb import MySQLDialect_mysqldb
from ...pool import NullPool
import re
def _is_dev_environment():
return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')
class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
@classmethod
def dbapi(cls):
# from django:
# http://code.google.com/p/googleappengine/source/
# browse/trunk/python/google/storage/speckle/
# python/django/backend/base.py#118
# see also [ticket:2649]
# see also http://stackoverflow.com/q/14224679/34549
from google.appengine.api import apiproxy_stub_map
if _is_dev_environment():
from google.appengine.api import rdbms_mysqldb
return rdbms_mysqldb
elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
from google.storage.speckle.python.api import rdbms_apiproxy
return rdbms_apiproxy
else:
from google.storage.speckle.python.api import rdbms_googleapi
return rdbms_googleapi
@classmethod
def get_pool_class(cls, url):
# Cloud SQL connections die at any moment
return NullPool
def create_connect_args(self, url):
opts = url.translate_connect_args()
if not _is_dev_environment():
# 'dsn' and 'instance' are because we are skipping
# the traditional google.api.rdbms wrapper
opts['dsn'] = ''
opts['instance'] = url.query['instance']
return [], opts
def _extract_error_code(self, exception):
match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception))
# The rdbms api will wrap then re-raise some types of errors
# making this regex return no matches.
code = match.group(1) or match.group(2) if match else None
if code:
return int(code)
dialect = MySQLDialect_gaerdbms

View File

@ -1,131 +0,0 @@
# mysql/mysqlconnector.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://dev.mysql.com/downloads/connector/python/
"""
from .base import (MySQLDialect,
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
BIT)
from ... import util
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def get_lastrowid(self):
return self.cursor.lastrowid
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector'
if util.py2k:
supports_unicode_statements = False
supports_unicode_binds = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _myconnpyBIT,
}
)
@classmethod
def dbapi(cls):
from mysql import connector
return connector
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
util.coerce_kw_type(opts, 'buffered', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
opts.setdefault('buffered', True)
opts.setdefault('raise_on_warnings', True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get('client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags
except:
pass
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = dbapi_con.get_server_version()
return tuple(version)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return e.errno in errnos or \
"MySQL Connection not available." in str(e)
else:
return False
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
dialect = MySQLDialect_mysqlconnector

View File

@ -1,94 +0,0 @@
# mysql/mysqldb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python
Unicode
-------
MySQLdb requires a "charset" parameter to be passed in order for it
to handle non-ASCII characters correctly. When this parameter is passed,
MySQLdb will also implicitly set the "use_unicode" flag to true, which means
that it will return Python unicode objects instead of bytestrings.
However, SQLAlchemy's decode process, when C extensions are enabled,
is orders of magnitude faster than that of MySQLdb as it does not call into
Python functions to do so. Therefore, the **recommended URL to use for
unicode** will include both charset and use_unicode=0::
create_engine("mysql+mysqldb://user:pass@host/dbname?charset=utf8&use_unicode=0")
As of this writing, MySQLdb only runs on Python 2. It is not known how
MySQLdb behaves on Python 3 as far as unicode decoding.
Known Issues
-------------
MySQL-python version 1.2.2 has a serious memory leak related
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
It is strongly advised to use the latest version of MySQL-Python.
"""
from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from ...connectors.mysqldb import (
MySQLDBExecutionContext,
MySQLDBCompiler,
MySQLDBIdentifierPreparer,
MySQLDBConnector
)
from .base import TEXT
from ... import sql
class MySQLExecutionContext_mysqldb(MySQLDBExecutionContext, MySQLExecutionContext):
pass
class MySQLCompiler_mysqldb(MySQLDBCompiler, MySQLCompiler):
pass
class MySQLIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, MySQLIdentifierPreparer):
pass
class MySQLDialect_mysqldb(MySQLDBConnector, MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8_bin collation and unicode returns
has_utf8_bin = connection.scalar(
"show collation where %s = 'utf8' and %s = 'utf8_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation")
))
if has_utf8_bin:
additional_tests = [
sql.collate(sql.cast(
sql.literal_column(
"'test collated returns'"),
TEXT(charset='utf8')), "utf8_bin")
]
else:
additional_tests = []
return super(MySQLDBConnector, self)._check_unicode_returns(
connection, additional_tests)
dialect = MySQLDialect_mysqldb

View File

@ -1,261 +0,0 @@
# mysql/oursql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+oursql
:name: OurSQL
:dbapi: oursql
:connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://packages.python.org/oursql/
Unicode
-------
oursql defaults to using ``utf8`` as the connection charset, but other
encodings may be used instead. Like the MySQL-Python driver, unicode support
can be completely disabled::
# oursql sets the connection charset to utf8 automatically; all strings come
# back as utf8 str
create_engine('mysql+oursql:///mydb?use_unicode=0')
To not automatically use ``utf8`` and instead use whatever the connection
defaults to, there is a separate parameter::
# use the default connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?default_charset=1')
# use latin1 as the connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?charset=latin1')
"""
import re
from .base import (BIT, MySQLDialect, MySQLExecutionContext)
from ... import types as sqltypes, util
class _oursqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""oursql already converts mysql bits, so."""
return None
class MySQLExecutionContext_oursql(MySQLExecutionContext):
@property
def plain_query(self):
return self.execution_options.get('_oursql_plain_query', False)
class MySQLDialect_oursql(MySQLDialect):
driver = 'oursql'
if util.py2k:
supports_unicode_binds = True
supports_unicode_statements = True
supports_native_decimal = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
execution_ctx_cls = MySQLExecutionContext_oursql
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _oursqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('oursql')
def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of *cursor.execute(statement, parameters)*."""
if context and context.plain_query:
cursor.execute(statement, plain_query=True)
else:
cursor.execute(statement, parameters)
def do_begin(self, connection):
connection.cursor().execute('BEGIN', plain_query=True)
def _xa_query(self, connection, query, xid):
if util.py2k:
arg = connection.connection._escape_string(xid)
else:
charset = self._connection_charset
arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
arg = "'%s'" % arg
connection.execution_options(_oursql_plain_query=True).execute(query % arg)
# Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries
# refuse to return any data if they're run through
# the parameterized query API, or refuse to be parameterized
# in the first place.
def do_begin_twophase(self, connection, xid):
self._xa_query(connection, 'XA BEGIN %s', xid)
def do_prepare_twophase(self, connection, xid):
self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA PREPARE %s', xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA ROLLBACK %s', xid)
def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
self._xa_query(connection, 'XA COMMIT %s', xid)
# Q: why didn't we need all these "plain_query" overrides earlier ?
# am i on a newer/older version of OurSQL ?
def has_table(self, connection, table_name, schema=None):
return MySQLDialect.has_table(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema
)
def get_table_options(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_table_options(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_columns(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_columns(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_view_names(self, connection, schema=None, **kw):
return MySQLDialect.get_view_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema=schema,
**kw
)
def get_table_names(self, connection, schema=None, **kw):
return MySQLDialect.get_table_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema
)
def get_schema_names(self, connection, **kw):
return MySQLDialect.get_schema_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
**kw
)
def initialize(self, connection):
return MySQLDialect.initialize(
self,
connection.execution_options(_oursql_plain_query=True)
)
def _show_create_table(self, connection, table, charset=None,
full_name=None):
return MySQLDialect._show_create_table(
self,
connection.contextual_connect(close_with_result=True).
execution_options(_oursql_plain_query=True),
table, charset, full_name
)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
else:
return e.errno in (2006, 2013, 2014, 2045, 2055)
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'port', int)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'autoping', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
util.coerce_kw_type(opts, 'default_charset', bool)
if opts.pop('default_charset', False):
opts['charset'] = None
else:
util.coerce_kw_type(opts, 'charset', str)
opts['use_unicode'] = opts.get('use_unicode', True)
util.coerce_kw_type(opts, 'use_unicode', bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
opts.setdefault('found_rows', True)
ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
'ssl_capath', 'ssl_cipher']:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.server_info):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
return exception.errno
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return connection.connection.charset
def _compat_fetchall(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchone()
def _compat_first(self, rp, charset=None):
return rp.first()
dialect = MySQLDialect_oursql

View File

@ -1,45 +0,0 @@
# mysql/pymysql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+pymysql
:name: PyMySQL
:dbapi: pymysql
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: http://code.google.com/p/pymysql/
MySQL-Python Compatibility
--------------------------
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility. Most behavioral notes for MySQL-python apply to
the pymysql driver as well.
"""
from .mysqldb import MySQLDialect_mysqldb
from ...util import py3k
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = 'pymysql'
description_encoding = None
if py3k:
supports_unicode_statements = True
@classmethod
def dbapi(cls):
return __import__('pymysql')
if py3k:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
dialect = MySQLDialect_pymysql

View File

@ -1,80 +0,0 @@
# mysql/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
Limitations
-----------
The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.
"""
from .base import MySQLDialect, MySQLExecutionContext
from ...connectors.pyodbc import PyODBCConnector
from ... import util
import re
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_unicode_statements = False
execution_ctx_cls = MySQLExecutionContext_pyodbc
pyodbc_driver_name = "MySQL"
def __init__(self, **kw):
# deal with http://code.google.com/p/pyodbc/issues/detail?id=25
kw.setdefault('convert_unicode', True)
super(MySQLDialect_pyodbc, self).__init__(**kw)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
c = m.group(1)
if c:
return int(c)
else:
return None
dialect = MySQLDialect_pyodbc

View File

@ -1,111 +0,0 @@
# mysql/zxjdbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+zxjdbc
:name: zxjdbc for Jython
:dbapi: zxjdbc
:connectstring: mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
:driverurl: http://dev.mysql.com/downloads/connector/j/
Character Sets
--------------
SQLAlchemy zxjdbc dialects pass unicode straight through to the
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
``characterEncoding`` connection property to ``UTF-8``. It may be
overriden via a ``create_engine`` URL parameter.
"""
import re
from ... import types as sqltypes, util
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import BIT, MySQLDialect, MySQLExecutionContext
class _ZxJDBCBit(BIT):
def result_processor(self, dialect, coltype):
"""Converts boolean or byte arrays from MySQL Connector/J to longs."""
def process(value):
if value is None:
return value
if isinstance(value, bool):
return int(value)
v = 0
for i in value:
v = v << 8 | (i & 0xff)
value = v
return value
return process
class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
jdbc_db_name = 'mysql'
jdbc_driver_name = 'com.mysql.jdbc.Driver'
execution_ctx_cls = MySQLExecutionContext_zxjdbc
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _ZxJDBCBit
}
)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _driver_kwargs(self):
"""return kw arg dict to be sent to connect()."""
return dict(characterEncoding='UTF-8', yearIsDateType='false')
def _extract_error_code(self, exception):
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
# [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args))
c = m.group(1)
if c:
return int(c)
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.dbversion):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
dialect = MySQLDialect_zxjdbc

View File

@ -1,23 +0,0 @@
# oracle/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
base.dialect = cx_oracle.dialect
from sqlalchemy.dialects.oracle.base import \
VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\
BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
VARCHAR2, NVARCHAR2, ROWID, dialect
__all__ = (
'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER',
'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
'VARCHAR2', 'NVARCHAR2', 'ROWID'
)

File diff suppressed because it is too large Load Diff

View File

@ -1,941 +0,0 @@
# oracle/cx_oracle.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: oracle+cx_oracle
:name: cx-Oracle
:dbapi: cx_oracle
:connectstring: oracle+cx_oracle://user:pass@host:port/dbname[?key=value&key=value...]
:url: http://cx-oracle.sourceforge.net/
Additional Connect Arguments
----------------------------
When connecting with ``dbname`` present, the host, port, and dbname tokens are
converted to a TNS name using
the cx_oracle ``makedsn()`` function. Otherwise, the host token is taken
directly as a TNS name.
Additional arguments which may be specified either as query string arguments
on the URL, or as keyword arguments to :func:`.create_engine()` are:
* ``allow_twophase`` - enable two-phase transactions. Defaults to ``True``.
* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted
to 50. This setting is significant with cx_Oracle as the contents of LOB
objects are only readable within a "live" row (e.g. within a batch of
50 rows).
* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.
* ``auto_setinputsizes`` - the cx_oracle.setinputsizes() call is issued for
all bind parameters. This is required for LOB datatypes but can be
disabled to reduce overhead. Defaults to ``True``. Specific types
can be excluded from this process using the ``exclude_setinputsizes``
parameter.
* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail.
* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail.
* ``exclude_setinputsizes`` - a tuple or list of string DBAPI type names to
be excluded from the "auto setinputsizes" feature. The type names here
must match DBAPI types that are found in the "cx_Oracle" module namespace,
such as cx_Oracle.UNICODE, cx_Oracle.NCLOB, etc. Defaults to
``(STRING, UNICODE)``.
.. versionadded:: 0.8 specific DBAPI types can be excluded from the
auto_setinputsizes feature via the exclude_setinputsizes attribute.
* ``mode`` - This is given the string value of SYSDBA or SYSOPER, or alternatively
an integer value. This value is only available as a URL query string
argument.
* ``threaded`` - enable multithreaded access to cx_oracle connections. Defaults
to ``True``. Note that this is the opposite default of the cx_Oracle DBAPI
itself.
.. _cx_oracle_unicode:
Unicode
-------
The cx_Oracle DBAPI as of version 5 fully supports unicode, and has the ability
to return string results as Python unicode objects natively.
When used in Python 3, cx_Oracle returns all strings as Python unicode objects
(that is, plain ``str`` in Python 3). In Python 2, it will return as Python
unicode those column values that are of type ``NVARCHAR`` or ``NCLOB``. For
column values that are of type ``VARCHAR`` or other non-unicode string types,
it will return values as Python strings (e.g. bytestrings).
The cx_Oracle SQLAlchemy dialect presents two different options for the use case of
returning ``VARCHAR`` column values as Python unicode objects under Python 2:
* the cx_Oracle DBAPI has the ability to coerce all string results to Python
unicode objects unconditionally using output type handlers. This has
the advantage that the unicode conversion is global to all statements
at the cx_Oracle driver level, meaning it works with raw textual SQL
statements that have no typing information associated. However, this system
has been observed to incur signfiicant performance overhead, not only because
it takes effect for all string values unconditionally, but also because cx_Oracle under
Python 2 seems to use a pure-Python function call in order to do the
decode operation, which under cPython can orders of magnitude slower
than doing it using C functions alone.
* SQLAlchemy has unicode-decoding services built in, and when using SQLAlchemy's
C extensions, these functions do not use any Python function calls and
are very fast. The disadvantage to this approach is that the unicode
conversion only takes effect for statements where the :class:`.Unicode` type
or :class:`.String` type with ``convert_unicode=True`` is explicitly
associated with the result column. This is the case for any ORM or Core
query or SQL expression as well as for a :func:`.text` construct that specifies
output column types, so in the vast majority of cases this is not an issue.
However, when sending a completely raw string to :meth:`.Connection.execute`,
this typing information isn't present, unless the string is handled
within a :func:`.text` construct that adds typing information.
As of version 0.9.2 of SQLAlchemy, the default approach is to use SQLAlchemy's
typing system. This keeps cx_Oracle's expensive Python 2 approach
disabled unless the user explicitly wants it. Under Python 3, SQLAlchemy detects
that cx_Oracle is returning unicode objects natively and cx_Oracle's system
is used.
To re-enable cx_Oracle's output type handler under Python 2, the
``coerce_to_unicode=True`` flag (new in 0.9.4) can be passed to
:func:`.create_engine`::
engine = create_engine("oracle+cx_oracle://dsn", coerce_to_unicode=True)
Alternatively, to run a pure string SQL statement and get ``VARCHAR`` results
as Python unicode under Python 2 without using cx_Oracle's native handlers,
the :func:`.text` feature can be used::
from sqlalchemy import text, Unicode
result = conn.execute(text("select username from user").columns(username=Unicode))
.. versionchanged:: 0.9.2 cx_Oracle's outputtypehandlers are no longer used for
unicode results of non-unicode datatypes in Python 2, after they were identified as a major
performance bottleneck. SQLAlchemy's own unicode facilities are used
instead.
.. versionadded:: 0.9.4 Added the ``coerce_to_unicode`` flag, to re-enable
cx_Oracle's outputtypehandler and revert to pre-0.9.2 behavior.
.. _cx_oracle_returning:
RETURNING Support
-----------------
The cx_oracle DBAPI supports a limited subset of Oracle's already limited RETURNING support.
Typically, results can only be guaranteed for at most one column being returned;
this is the typical case when SQLAlchemy uses RETURNING to get just the value of a
primary-key-associated sequence value. Additional column expressions will
cause problems in a non-determinative way, due to cx_oracle's lack of support for
the OCI_DATA_AT_EXEC API which is required for more complex RETURNING scenarios.
For this reason, stability may be enhanced by disabling RETURNING support completely;
SQLAlchemy otherwise will use RETURNING to fetch newly sequence-generated
primary keys. As illustrated in :ref:`oracle_returning`::
engine = create_engine("oracle://scott:tiger@dsn", implicit_returning=False)
.. seealso::
http://docs.oracle.com/cd/B10501_01/appdev.920/a96584/oci05bnd.htm#420693 - OCI documentation for RETURNING
http://sourceforge.net/mailarchive/message.php?msg_id=31338136 - cx_oracle developer commentary
.. _cx_oracle_lob:
LOB Objects
-----------
cx_oracle returns oracle LOBs using the cx_oracle.LOB object. SQLAlchemy converts
these to strings so that the interface of the Binary type is consistent with that of
other backends, and so that the linkage to a live cursor is not needed in scenarios
like result.fetchmany() and result.fetchall(). This means that by default, LOB
objects are fully fetched unconditionally by SQLAlchemy, and the linkage to a live
cursor is broken.
To disable this processing, pass ``auto_convert_lobs=False`` to :func:`.create_engine()`.
Two Phase Transaction Support
-----------------------------
Two Phase transactions are implemented using XA transactions, and are known
to work in a rudimental fashion with recent versions of cx_Oracle
as of SQLAlchemy 0.8.0b2, 0.7.10. However, the mechanism is not yet
considered to be robust and should still be regarded as experimental.
In particular, the cx_Oracle DBAPI as recently as 5.1.2 has a bug regarding
two phase which prevents
a particular DBAPI connection from being consistently usable in both
prepared transactions as well as traditional DBAPI usage patterns; therefore
once a particular connection is used via :meth:`.Connection.begin_prepared`,
all subsequent usages of the underlying DBAPI connection must be within
the context of prepared transactions.
The default behavior of :class:`.Engine` is to maintain a pool of DBAPI
connections. Therefore, due to the above glitch, a DBAPI connection that has
been used in a two-phase operation, and is then returned to the pool, will
not be usable in a non-two-phase context. To avoid this situation,
the application can make one of several choices:
* Disable connection pooling using :class:`.NullPool`
* Ensure that the particular :class:`.Engine` in use is only used
for two-phase operations. A :class:`.Engine` bound to an ORM
:class:`.Session` which includes ``twophase=True`` will consistently
use the two-phase transaction style.
* For ad-hoc two-phase operations without disabling pooling, the DBAPI
connection in use can be evicted from the connection pool using the
:meth:`.Connection.detach` method.
.. versionchanged:: 0.8.0b2,0.7.10
Support for cx_oracle prepared transactions has been implemented
and tested.
.. _cx_oracle_numeric:
Precision Numerics
------------------
The SQLAlchemy dialect goes through a lot of steps to ensure
that decimal numbers are sent and received with full accuracy.
An "outputtypehandler" callable is associated with each
cx_oracle connection object which detects numeric types and
receives them as string values, instead of receiving a Python
``float`` directly, which is then passed to the Python
``Decimal`` constructor. The :class:`.Numeric` and
:class:`.Float` types under the cx_oracle dialect are aware of
this behavior, and will coerce the ``Decimal`` to ``float`` if
the ``asdecimal`` flag is ``False`` (default on :class:`.Float`,
optional on :class:`.Numeric`).
Because the handler coerces to ``Decimal`` in all cases first,
the feature can detract significantly from performance.
If precision numerics aren't required, the decimal handling
can be disabled by passing the flag ``coerce_to_decimal=False``
to :func:`.create_engine`::
engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False)
.. versionadded:: 0.7.6
Add the ``coerce_to_decimal`` flag.
Another alternative to performance is to use the
`cdecimal <http://pypi.python.org/pypi/cdecimal/>`_ library;
see :class:`.Numeric` for additional notes.
The handler attempts to use the "precision" and "scale"
attributes of the result set column to best determine if
subsequent incoming values should be received as ``Decimal`` as
opposed to int (in which case no processing is added). There are
several scenarios where OCI_ does not provide unambiguous data
as to the numeric type, including some situations where
individual rows may return a combination of floating point and
integer values. Certain values for "precision" and "scale" have
been observed to determine this scenario. When it occurs, the
outputtypehandler receives as string and then passes off to a
processing function which detects, for each returned value, if a
decimal point is present, and if so converts to ``Decimal``,
otherwise to int. The intention is that simple int-based
statements like "SELECT my_seq.nextval() FROM DUAL" continue to
return ints and not ``Decimal`` objects, and that any kind of
floating point value is received as a string so that there is no
floating point loss of precision.
The "decimal point is present" logic itself is also sensitive to
locale. Under OCI_, this is controlled by the NLS_LANG
environment variable. Upon first connection, the dialect runs a
test to determine the current "decimal" character, which can be
a comma "," for european locales. From that point forward the
outputtypehandler uses that character to represent a decimal
point. Note that cx_oracle 5.0.3 or greater is required
when dealing with numerics with locale settings that don't use
a period "." as the decimal character.
.. versionchanged:: 0.6.6
The outputtypehandler supports the case where the locale uses a
comma "," character to represent a decimal point.
.. _OCI: http://www.oracle.com/technetwork/database/features/oci/index.html
"""
from __future__ import absolute_import
from .base import OracleCompiler, OracleDialect, OracleExecutionContext
from . import base as oracle
from ...engine import result as _result
from sqlalchemy import types as sqltypes, util, exc, processors
import random
import collections
import decimal
import re
class _OracleNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
# cx_oracle accepts Decimal objects and floats
return None
def result_processor(self, dialect, coltype):
# we apply a cx_oracle type handler to all connections
# that converts floating point strings to Decimal().
# However, in some subquery situations, Oracle doesn't
# give us enough information to determine int or Decimal.
# It could even be int/Decimal differently on each row,
# regardless of the scale given for the originating type.
# So we still need an old school isinstance() handler
# here for decimals.
if dialect.supports_native_decimal:
if self.asdecimal:
fstring = "%%.%df" % self._effective_decimal_return_scale
def to_decimal(value):
if value is None:
return None
elif isinstance(value, decimal.Decimal):
return value
else:
return decimal.Decimal(fstring % value)
return to_decimal
else:
if self.precision is None and self.scale is None:
return processors.to_float
elif not getattr(self, '_is_oracle_number', False) \
and self.scale is not None:
return processors.to_float
else:
return None
else:
# cx_oracle 4 behavior, will assume
# floats
return super(_OracleNumeric, self).\
result_processor(dialect, coltype)
class _OracleDate(sqltypes.Date):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return value.date()
else:
return value
return process
class _LOBMixin(object):
def result_processor(self, dialect, coltype):
if not dialect.auto_convert_lobs:
# return the cx_oracle.LOB directly.
return None
def process(value):
if value is not None:
return value.read()
else:
return value
return process
class _NativeUnicodeMixin(object):
if util.py2k:
def bind_processor(self, dialect):
if dialect._cx_oracle_with_unicode:
def process(value):
if value is None:
return value
else:
return unicode(value)
return process
else:
return super(_NativeUnicodeMixin, self).bind_processor(dialect)
# we apply a connection output handler that returns
# unicode in all cases, so the "native_unicode" flag
# will be set for the default String.result_processor.
class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR):
def get_dbapi_type(self, dbapi):
return dbapi.FIXED_CHAR
class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR):
def get_dbapi_type(self, dbapi):
return getattr(dbapi, 'UNICODE', dbapi.STRING)
class _OracleText(_LOBMixin, sqltypes.Text):
def get_dbapi_type(self, dbapi):
return dbapi.CLOB
class _OracleLong(oracle.LONG):
# a raw LONG is a text type, but does *not*
# get the LobMixin with cx_oracle.
def get_dbapi_type(self, dbapi):
return dbapi.LONG_STRING
class _OracleString(_NativeUnicodeMixin, sqltypes.String):
pass
class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText):
def get_dbapi_type(self, dbapi):
return dbapi.NCLOB
def result_processor(self, dialect, coltype):
lob_processor = _LOBMixin.result_processor(self, dialect, coltype)
if lob_processor is None:
return None
string_processor = sqltypes.UnicodeText.result_processor(self, dialect, coltype)
if string_processor is None:
return lob_processor
else:
def process(value):
return string_processor(lob_processor(value))
return process
class _OracleInteger(sqltypes.Integer):
def result_processor(self, dialect, coltype):
def to_int(val):
if val is not None:
val = int(val)
return val
return to_int
class _OracleBinary(_LOBMixin, sqltypes.LargeBinary):
def get_dbapi_type(self, dbapi):
return dbapi.BLOB
def bind_processor(self, dialect):
return None
class _OracleInterval(oracle.INTERVAL):
def get_dbapi_type(self, dbapi):
return dbapi.INTERVAL
class _OracleRaw(oracle.RAW):
pass
class _OracleRowid(oracle.ROWID):
def get_dbapi_type(self, dbapi):
return dbapi.ROWID
class OracleCompiler_cx_oracle(OracleCompiler):
def bindparam_string(self, name, **kw):
quote = getattr(name, 'quote', None)
if quote is True or quote is not False and \
self.preparer._bindparam_requires_quotes(name):
quoted_name = '"%s"' % name
self._quoted_bind_names[name] = quoted_name
return OracleCompiler.bindparam_string(self, quoted_name, **kw)
else:
return OracleCompiler.bindparam_string(self, name, **kw)
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
def pre_exec(self):
quoted_bind_names = \
getattr(self.compiled, '_quoted_bind_names', None)
if quoted_bind_names:
if not self.dialect.supports_unicode_statements:
# if DBAPI doesn't accept unicode statements,
# keys in self.parameters would have been encoded
# here. so convert names in quoted_bind_names
# to encoded as well.
quoted_bind_names = \
dict(
(fromname.encode(self.dialect.encoding),
toname.encode(self.dialect.encoding))
for fromname, toname in
quoted_bind_names.items()
)
for param in self.parameters:
for fromname, toname in quoted_bind_names.items():
param[toname] = param[fromname]
del param[fromname]
if self.dialect.auto_setinputsizes:
# cx_oracle really has issues when you setinputsizes
# on String, including that outparams/RETURNING
# breaks for varchars
self.set_input_sizes(quoted_bind_names,
exclude_types=self.dialect.exclude_setinputsizes
)
# if a single execute, check for outparams
if len(self.compiled_parameters) == 1:
for bindparam in self.compiled.binds.values():
if bindparam.isoutparam:
dbtype = bindparam.type.dialect_impl(self.dialect).\
get_dbapi_type(self.dialect.dbapi)
if not hasattr(self, 'out_parameters'):
self.out_parameters = {}
if dbtype is None:
raise exc.InvalidRequestError(
"Cannot create out parameter for parameter "
"%r - it's type %r is not supported by"
" cx_oracle" %
(bindparam.key, bindparam.type)
)
name = self.compiled.bind_names[bindparam]
self.out_parameters[name] = self.cursor.var(dbtype)
self.parameters[0][quoted_bind_names.get(name, name)] = \
self.out_parameters[name]
def create_cursor(self):
c = self._dbapi_connection.cursor()
if self.dialect.arraysize:
c.arraysize = self.dialect.arraysize
return c
def get_result_proxy(self):
if hasattr(self, 'out_parameters') and self.compiled.returning:
returning_params = dict(
(k, v.getvalue())
for k, v in self.out_parameters.items()
)
return ReturningResultProxy(self, returning_params)
result = None
if self.cursor.description is not None:
for column in self.cursor.description:
type_code = column[1]
if type_code in self.dialect._cx_oracle_binary_types:
result = _result.BufferedColumnResultProxy(self)
if result is None:
result = _result.ResultProxy(self)
if hasattr(self, 'out_parameters'):
if self.compiled_parameters is not None and \
len(self.compiled_parameters) == 1:
result.out_parameters = out_parameters = {}
for bind, name in self.compiled.bind_names.items():
if name in self.out_parameters:
type = bind.type
impl_type = type.dialect_impl(self.dialect)
dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
result_processor = impl_type.\
result_processor(self.dialect,
dbapi_type)
if result_processor is not None:
out_parameters[name] = \
result_processor(self.out_parameters[name].getvalue())
else:
out_parameters[name] = self.out_parameters[name].getvalue()
else:
result.out_parameters = dict(
(k, v.getvalue())
for k, v in self.out_parameters.items()
)
return result
class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle):
"""Support WITH_UNICODE in Python 2.xx.
WITH_UNICODE allows cx_Oracle's Python 3 unicode handling
behavior under Python 2.x. This mode in some cases disallows
and in other cases silently passes corrupted data when
non-Python-unicode strings (a.k.a. plain old Python strings)
are passed as arguments to connect(), the statement sent to execute(),
or any of the bind parameter keys or values sent to execute().
This optional context therefore ensures that all statements are
passed as Python unicode objects.
"""
def __init__(self, *arg, **kw):
OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw)
self.statement = util.text_type(self.statement)
def _execute_scalar(self, stmt):
return super(OracleExecutionContext_cx_oracle_with_unicode, self).\
_execute_scalar(util.text_type(stmt))
class ReturningResultProxy(_result.FullyBufferedResultProxy):
"""Result proxy which stuffs the _returning clause + outparams into the fetch."""
def __init__(self, context, returning_params):
self._returning_params = returning_params
super(ReturningResultProxy, self).__init__(context)
def _cursor_description(self):
returning = self.context.compiled.returning
return [
("ret_%d" % i, None)
for i, col in enumerate(returning)
]
def _buffer_rows(self):
return collections.deque([tuple(self._returning_params["ret_%d" % i]
for i, c in enumerate(self._returning_params))])
class OracleDialect_cx_oracle(OracleDialect):
execution_ctx_cls = OracleExecutionContext_cx_oracle
statement_compiler = OracleCompiler_cx_oracle
driver = "cx_oracle"
colspecs = colspecs = {
sqltypes.Numeric: _OracleNumeric,
sqltypes.Date: _OracleDate, # generic type, assume datetime.date is desired
sqltypes.LargeBinary: _OracleBinary,
sqltypes.Boolean: oracle._OracleBoolean,
sqltypes.Interval: _OracleInterval,
oracle.INTERVAL: _OracleInterval,
sqltypes.Text: _OracleText,
sqltypes.String: _OracleString,
sqltypes.UnicodeText: _OracleUnicodeText,
sqltypes.CHAR: _OracleChar,
# a raw LONG is a text type, but does *not*
# get the LobMixin with cx_oracle.
oracle.LONG: _OracleLong,
# this is only needed for OUT parameters.
# it would be nice if we could not use it otherwise.
sqltypes.Integer: _OracleInteger,
oracle.RAW: _OracleRaw,
sqltypes.Unicode: _OracleNVarChar,
sqltypes.NVARCHAR: _OracleNVarChar,
oracle.ROWID: _OracleRowid,
}
execute_sequence_format = list
def __init__(self,
auto_setinputsizes=True,
exclude_setinputsizes=("STRING", "UNICODE"),
auto_convert_lobs=True,
threaded=True,
allow_twophase=True,
coerce_to_decimal=True,
coerce_to_unicode=False,
arraysize=50, **kwargs):
OracleDialect.__init__(self, **kwargs)
self.threaded = threaded
self.arraysize = arraysize
self.allow_twophase = allow_twophase
self.supports_timestamp = self.dbapi is None or \
hasattr(self.dbapi, 'TIMESTAMP')
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
if hasattr(self.dbapi, 'version'):
self.cx_oracle_ver = tuple([int(x) for x in
self.dbapi.version.split('.')])
else:
self.cx_oracle_ver = (0, 0, 0)
def types(*names):
return set(
getattr(self.dbapi, name, None) for name in names
).difference([None])
self.exclude_setinputsizes = types(*(exclude_setinputsizes or ()))
self._cx_oracle_string_types = types("STRING", "UNICODE",
"NCLOB", "CLOB")
self._cx_oracle_unicode_types = types("UNICODE", "NCLOB")
self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB")
self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0)
self.coerce_to_unicode = (
self.cx_oracle_ver >= (5, 0) and
coerce_to_unicode
)
self.supports_native_decimal = (
self.cx_oracle_ver >= (5, 0) and
coerce_to_decimal
)
self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0)
if self.cx_oracle_ver is None:
# this occurs in tests with mock DBAPIs
self._cx_oracle_string_types = set()
self._cx_oracle_with_unicode = False
elif self.cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'):
# cx_Oracle WITH_UNICODE mode. *only* python
# unicode objects accepted for anything
self.supports_unicode_statements = True
self.supports_unicode_binds = True
self._cx_oracle_with_unicode = True
if util.py2k:
# There's really no reason to run with WITH_UNICODE under Python 2.x.
# Give the user a hint.
util.warn(
"cx_Oracle is compiled under Python 2.xx using the "
"WITH_UNICODE flag. Consider recompiling cx_Oracle "
"without this flag, which is in no way necessary for full "
"support of Unicode. Otherwise, all string-holding bind "
"parameters must be explicitly typed using SQLAlchemy's "
"String type or one of its subtypes,"
"or otherwise be passed as Python unicode. "
"Plain Python strings passed as bind parameters will be "
"silently corrupted by cx_Oracle."
)
self.execution_ctx_cls = \
OracleExecutionContext_cx_oracle_with_unicode
else:
self._cx_oracle_with_unicode = False
if self.cx_oracle_ver is None or \
not self.auto_convert_lobs or \
not hasattr(self.dbapi, 'CLOB'):
self.dbapi_type_map = {}
else:
# only use this for LOB objects. using it for strings, dates
# etc. leads to a little too much magic, reflection doesn't know if it should
# expect encoded strings or unicodes, etc.
self.dbapi_type_map = {
self.dbapi.CLOB: oracle.CLOB(),
self.dbapi.NCLOB: oracle.NCLOB(),
self.dbapi.BLOB: oracle.BLOB(),
self.dbapi.BINARY: oracle.RAW(),
}
@classmethod
def dbapi(cls):
import cx_Oracle
return cx_Oracle
def initialize(self, connection):
super(OracleDialect_cx_oracle, self).initialize(connection)
if self._is_oracle_8:
self.supports_unicode_binds = False
self._detect_decimal_char(connection)
def _detect_decimal_char(self, connection):
"""detect if the decimal separator character is not '.', as
is the case with european locale settings for NLS_LANG.
cx_oracle itself uses similar logic when it formats Python
Decimal objects to strings on the bind side (as of 5.0.3),
as Oracle sends/receives string numerics only in the
current locale.
"""
if self.cx_oracle_ver < (5,):
# no output type handlers before version 5
return
cx_Oracle = self.dbapi
conn = connection.connection
# override the output_type_handler that's
# on the cx_oracle connection with a plain
# one on the cursor
def output_type_handler(cursor, name, defaultType,
size, precision, scale):
return cursor.var(
cx_Oracle.STRING,
255, arraysize=cursor.arraysize)
cursor = conn.cursor()
cursor.outputtypehandler = output_type_handler
cursor.execute("SELECT 0.1 FROM DUAL")
val = cursor.fetchone()[0]
cursor.close()
char = re.match(r"([\.,])", val).group(1)
if char != '.':
_detect_decimal = self._detect_decimal
self._detect_decimal = \
lambda value: _detect_decimal(value.replace(char, '.'))
self._to_decimal = \
lambda value: decimal.Decimal(value.replace(char, '.'))
def _detect_decimal(self, value):
if "." in value:
return decimal.Decimal(value)
else:
return int(value)
_to_decimal = decimal.Decimal
def on_connect(self):
if self.cx_oracle_ver < (5,):
# no output type handlers before version 5
return
cx_Oracle = self.dbapi
def output_type_handler(cursor, name, defaultType,
size, precision, scale):
# convert all NUMBER with precision + positive scale to Decimal
# this almost allows "native decimal" mode.
if self.supports_native_decimal and \
defaultType == cx_Oracle.NUMBER and \
precision and scale > 0:
return cursor.var(
cx_Oracle.STRING,
255,
outconverter=self._to_decimal,
arraysize=cursor.arraysize)
# if NUMBER with zero precision and 0 or neg scale, this appears
# to indicate "ambiguous". Use a slower converter that will
# make a decision based on each value received - the type
# may change from row to row (!). This kills
# off "native decimal" mode, handlers still needed.
elif self.supports_native_decimal and \
defaultType == cx_Oracle.NUMBER \
and not precision and scale <= 0:
return cursor.var(
cx_Oracle.STRING,
255,
outconverter=self._detect_decimal,
arraysize=cursor.arraysize)
# allow all strings to come back natively as Unicode
elif self.coerce_to_unicode and \
defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR):
return cursor.var(util.text_type, size, cursor.arraysize)
def on_connect(conn):
conn.outputtypehandler = output_type_handler
return on_connect
def create_connect_args(self, url):
dialect_opts = dict(url.query)
for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
'threaded', 'allow_twophase'):
if opt in dialect_opts:
util.coerce_kw_type(dialect_opts, opt, bool)
setattr(self, opt, dialect_opts[opt])
if url.database:
# if we have a database, then we have a remote host
port = url.port
if port:
port = int(port)
else:
port = 1521
dsn = self.dbapi.makedsn(url.host, port, url.database)
else:
# we have a local tnsname
dsn = url.host
opts = dict(
user=url.username,
password=url.password,
dsn=dsn,
threaded=self.threaded,
twophase=self.allow_twophase,
)
if util.py2k:
if self._cx_oracle_with_unicode:
for k, v in opts.items():
if isinstance(v, str):
opts[k] = unicode(v)
else:
for k, v in opts.items():
if isinstance(v, unicode):
opts[k] = str(v)
if 'mode' in url.query:
opts['mode'] = url.query['mode']
if isinstance(opts['mode'], util.string_types):
mode = opts['mode'].upper()
if mode == 'SYSDBA':
opts['mode'] = self.dbapi.SYSDBA
elif mode == 'SYSOPER':
opts['mode'] = self.dbapi.SYSOPER
else:
util.coerce_kw_type(opts, 'mode', int)
return ([], opts)
def _get_server_version_info(self, connection):
return tuple(
int(x)
for x in connection.connection.version.split('.')
)
def is_disconnect(self, e, connection, cursor):
error, = e.args
if isinstance(e, self.dbapi.InterfaceError):
return "not connected" in str(e)
elif hasattr(error, 'code'):
# ORA-00028: your session has been killed
# ORA-03114: not connected to ORACLE
# ORA-03113: end-of-file on communication channel
# ORA-03135: connection lost contact
# ORA-01033: ORACLE initialization or shutdown in progress
# ORA-02396: exceeded maximum idle time, please connect again
# TODO: Others ?
return error.code in (28, 3114, 3113, 3135, 1033, 2396)
else:
return False
def create_xid(self):
"""create a two-phase transaction ID.
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). its format is unspecified."""
id = random.randint(0, 2 ** 128)
return (0x1234, "%032x" % id, "%032x" % 9)
def do_executemany(self, cursor, statement, parameters, context=None):
if isinstance(parameters, tuple):
parameters = list(parameters)
cursor.executemany(statement, parameters)
def do_begin_twophase(self, connection, xid):
connection.connection.begin(*xid)
def do_prepare_twophase(self, connection, xid):
result = connection.connection.prepare()
connection.info['cx_oracle_prepared'] = result
def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False):
self.do_rollback(connection.connection)
def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self.do_commit(connection.connection)
else:
oci_prepared = connection.info['cx_oracle_prepared']
if oci_prepared:
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
connection.info.pop('cx_oracle_prepared', None)
dialect = OracleDialect_cx_oracle

View File

@ -1,218 +0,0 @@
# oracle/zxjdbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: oracle+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: oracle+zxjdbc://user:pass@host/dbname
:driverurl: http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.
"""
import decimal
import re
from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import result as _result
from sqlalchemy.sql import expression
import collections
SQLException = zxJDBC = None
class _ZxJDBCDate(sqltypes.Date):
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
else:
return value.date()
return process
class _ZxJDBCNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
#XXX: does the dialect return Decimal or not???
# if it does (in all cases), we could use a None processor as well as
# the to_float generic processor
if self.asdecimal:
def process(value):
if isinstance(value, decimal.Decimal):
return value
else:
return decimal.Decimal(str(value))
else:
def process(value):
if isinstance(value, decimal.Decimal):
return float(value)
else:
return value
return process
class OracleCompiler_zxjdbc(OracleCompiler):
def returning_clause(self, stmt, returning_cols):
self.returning_cols = list(expression._select_iterables(returning_cols))
# within_columns_clause=False so that labels (foo AS bar) don't render
columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
for c in self.returning_cols]
if not hasattr(self, 'returning_parameters'):
self.returning_parameters = []
binds = []
for i, col in enumerate(self.returning_cols):
dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
self.returning_parameters.append((i + 1, dbtype))
bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
self.binds[bindparam.key] = bindparam
binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
class OracleExecutionContext_zxjdbc(OracleExecutionContext):
def pre_exec(self):
if hasattr(self.compiled, 'returning_parameters'):
# prepare a zxJDBC statement so we can grab its underlying
# OraclePreparedStatement's getReturnResultSet later
self.statement = self.cursor.prepare(self.statement)
def get_result_proxy(self):
if hasattr(self.compiled, 'returning_parameters'):
rrs = None
try:
try:
rrs = self.statement.__statement__.getReturnResultSet()
next(rrs)
except SQLException as sqle:
msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
if sqle.getSQLState() is not None:
msg += ' [SQLState: %s]' % sqle.getSQLState()
raise zxJDBC.Error(msg)
else:
row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
for index, dbtype in self.compiled.returning_parameters)
return ReturningResultProxy(self, row)
finally:
if rrs is not None:
try:
rrs.close()
except SQLException:
pass
self.statement.close()
return _result.ResultProxy(self)
def create_cursor(self):
cursor = self._dbapi_connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
return cursor
class ReturningResultProxy(_result.FullyBufferedResultProxy):
"""ResultProxy backed by the RETURNING ResultSet results."""
def __init__(self, context, returning_row):
self._returning_row = returning_row
super(ReturningResultProxy, self).__init__(context)
def _cursor_description(self):
ret = []
for c in self.context.compiled.returning_cols:
if hasattr(c, 'name'):
ret.append((c.name, c.type))
else:
ret.append((c.anon_label, c.type))
return ret
def _buffer_rows(self):
return collections.deque([self._returning_row])
class ReturningParam(object):
"""A bindparam value representing a RETURNING parameter.
Specially handled by OracleReturningDataHandler.
"""
def __init__(self, type):
self.type = type
def __eq__(self, other):
if isinstance(other, ReturningParam):
return self.type == other.type
return NotImplemented
def __ne__(self, other):
if isinstance(other, ReturningParam):
return self.type != other.type
return NotImplemented
def __repr__(self):
kls = self.__class__
return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
self.type)
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
jdbc_db_name = 'oracle'
jdbc_driver_name = 'oracle.jdbc.OracleDriver'
statement_compiler = OracleCompiler_zxjdbc
execution_ctx_cls = OracleExecutionContext_zxjdbc
colspecs = util.update_copy(
OracleDialect.colspecs,
{
sqltypes.Date: _ZxJDBCDate,
sqltypes.Numeric: _ZxJDBCNumeric
}
)
def __init__(self, *args, **kwargs):
super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
global SQLException, zxJDBC
from java.sql import SQLException
from com.ziclix.python.sql import zxJDBC
from com.ziclix.python.sql.handler import OracleDataHandler
class OracleReturningDataHandler(OracleDataHandler):
"""zxJDBC DataHandler that specially handles ReturningParam."""
def setJDBCObject(self, statement, index, object, dbtype=None):
if type(object) is ReturningParam:
statement.registerReturnParameter(index, object.type)
elif dbtype is None:
OracleDataHandler.setJDBCObject(
self, statement, index, object)
else:
OracleDataHandler.setJDBCObject(
self, statement, index, object, dbtype)
self.DataHandler = OracleReturningDataHandler
def initialize(self, connection):
super(OracleDialect_zxjdbc, self).initialize(connection)
self.implicit_returning = connection.connection.driverversion >= '10.2'
def _create_jdbc_url(self, url):
return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)
def _get_server_version_info(self, connection):
version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
return tuple(int(x) for x in version.split('.'))
dialect = OracleDialect_zxjdbc

Some files were not shown because too many files have changed in this diff Show More