From c577ff2887dfa58a1715aa61bee3297e67a7422f Mon Sep 17 00:00:00 2001 From: echel0n Date: Tue, 15 Jul 2014 11:40:40 -0700 Subject: [PATCH] Reverted back to using Shove+SQLAlchemy for storing persistent object data to avoid any more DB corruption errors. --- lib/shove/__init__.py | 519 +++ lib/shove/cache/__init__.py | 1 + lib/shove/cache/db.py | 117 + lib/shove/cache/file.py | 46 + lib/shove/cache/filelru.py | 23 + lib/shove/cache/memcached.py | 43 + lib/shove/cache/memlru.py | 38 + lib/shove/cache/memory.py | 38 + lib/shove/cache/redisdb.py | 45 + lib/shove/cache/simple.py | 68 + lib/shove/cache/simplelru.py | 18 + lib/shove/store/__init__.py | 48 + lib/shove/store/bsdb.py | 48 + lib/shove/store/cassandra.py | 72 + lib/shove/store/db.py | 73 + lib/shove/store/dbm.py | 33 + lib/shove/store/durusdb.py | 43 + lib/shove/store/file.py | 25 + lib/shove/store/ftp.py | 88 + lib/shove/store/hdf5.py | 34 + lib/shove/store/leveldbstore.py | 47 + lib/shove/store/memory.py | 38 + lib/shove/store/redisdb.py | 50 + lib/shove/store/s3.py | 91 + lib/shove/store/simple.py | 21 + lib/shove/store/svn.py | 110 + lib/shove/store/zodb.py | 48 + lib/shove/tests/__init__.py | 1 + lib/shove/tests/test_bsddb_store.py | 133 + lib/shove/tests/test_cassandra_store.py | 137 + lib/shove/tests/test_db_cache.py | 54 + lib/shove/tests/test_db_store.py | 131 + lib/shove/tests/test_dbm_store.py | 136 + lib/shove/tests/test_durus_store.py | 133 + lib/shove/tests/test_file_cache.py | 58 + lib/shove/tests/test_file_store.py | 140 + lib/shove/tests/test_ftp_store.py | 149 + lib/shove/tests/test_hdf5_store.py | 135 + lib/shove/tests/test_leveldb_store.py | 132 + lib/shove/tests/test_memcached_cache.py | 46 + lib/shove/tests/test_memory_cache.py | 54 + lib/shove/tests/test_memory_store.py | 135 + lib/shove/tests/test_redis_cache.py | 45 + lib/shove/tests/test_redis_store.py | 128 + lib/shove/tests/test_s3_store.py | 149 + lib/shove/tests/test_simple_cache.py | 54 + lib/shove/tests/test_simple_store.py | 135 + lib/shove/tests/test_svn_store.py | 148 + lib/shove/tests/test_zodb_store.py | 138 + lib/sqlalchemy/__init__.py | 133 + lib/sqlalchemy/cextension/processors.c | 706 ++++ lib/sqlalchemy/cextension/resultproxy.c | 718 ++++ lib/sqlalchemy/cextension/utils.c | 225 ++ lib/sqlalchemy/connectors/__init__.py | 9 + lib/sqlalchemy/connectors/mxodbc.py | 149 + lib/sqlalchemy/connectors/mysqldb.py | 144 + lib/sqlalchemy/connectors/pyodbc.py | 170 + lib/sqlalchemy/connectors/zxJDBC.py | 59 + lib/sqlalchemy/databases/__init__.py | 31 + lib/sqlalchemy/dialects/__init__.py | 44 + lib/sqlalchemy/dialects/drizzle/__init__.py | 22 + lib/sqlalchemy/dialects/drizzle/base.py | 498 +++ lib/sqlalchemy/dialects/drizzle/mysqldb.py | 48 + lib/sqlalchemy/dialects/firebird/__init__.py | 20 + lib/sqlalchemy/dialects/firebird/base.py | 738 ++++ lib/sqlalchemy/dialects/firebird/fdb.py | 115 + .../dialects/firebird/kinterbasdb.py | 179 + lib/sqlalchemy/dialects/mssql/__init__.py | 26 + lib/sqlalchemy/dialects/mssql/adodbapi.py | 79 + lib/sqlalchemy/dialects/mssql/base.py | 1550 +++++++ .../dialects/mssql/information_schema.py | 114 + lib/sqlalchemy/dialects/mssql/mxodbc.py | 111 + lib/sqlalchemy/dialects/mssql/pymssql.py | 92 + lib/sqlalchemy/dialects/mssql/pyodbc.py | 260 ++ lib/sqlalchemy/dialects/mssql/zxjdbc.py | 65 + lib/sqlalchemy/dialects/mysql/__init__.py | 28 + lib/sqlalchemy/dialects/mysql/base.py | 3078 ++++++++++++++ lib/sqlalchemy/dialects/mysql/cymysql.py | 84 + lib/sqlalchemy/dialects/mysql/gaerdbms.py | 84 + .../dialects/mysql/mysqlconnector.py | 131 + lib/sqlalchemy/dialects/mysql/mysqldb.py | 94 + lib/sqlalchemy/dialects/mysql/oursql.py | 261 ++ lib/sqlalchemy/dialects/mysql/pymysql.py | 45 + lib/sqlalchemy/dialects/mysql/pyodbc.py | 80 + lib/sqlalchemy/dialects/mysql/zxjdbc.py | 111 + lib/sqlalchemy/dialects/oracle/__init__.py | 23 + lib/sqlalchemy/dialects/oracle/base.py | 1291 ++++++ lib/sqlalchemy/dialects/oracle/cx_oracle.py | 941 +++++ lib/sqlalchemy/dialects/oracle/zxjdbc.py | 218 + lib/sqlalchemy/dialects/postgres.py | 16 + .../dialects/postgresql/__init__.py | 29 + lib/sqlalchemy/dialects/postgresql/base.py | 2367 +++++++++++ .../dialects/postgresql/constraints.py | 73 + lib/sqlalchemy/dialects/postgresql/hstore.py | 369 ++ lib/sqlalchemy/dialects/postgresql/json.py | 199 + lib/sqlalchemy/dialects/postgresql/pg8000.py | 126 + .../dialects/postgresql/psycopg2.py | 515 +++ .../dialects/postgresql/pypostgresql.py | 78 + lib/sqlalchemy/dialects/postgresql/ranges.py | 160 + lib/sqlalchemy/dialects/postgresql/zxjdbc.py | 45 + lib/sqlalchemy/dialects/sqlite/__init__.py | 19 + lib/sqlalchemy/dialects/sqlite/base.py | 1049 +++++ lib/sqlalchemy/dialects/sqlite/pysqlite.py | 335 ++ lib/sqlalchemy/dialects/sybase/__init__.py | 27 + lib/sqlalchemy/dialects/sybase/base.py | 816 ++++ lib/sqlalchemy/dialects/sybase/mxodbc.py | 32 + lib/sqlalchemy/dialects/sybase/pyodbc.py | 84 + lib/sqlalchemy/dialects/sybase/pysybase.py | 100 + .../dialects/type_migration_guidelines.txt | 145 + lib/sqlalchemy/engine/__init__.py | 372 ++ lib/sqlalchemy/engine/base.py | 1808 +++++++++ lib/sqlalchemy/engine/default.py | 957 +++++ lib/sqlalchemy/engine/interfaces.py | 849 ++++ lib/sqlalchemy/engine/reflection.py | 590 +++ lib/sqlalchemy/engine/result.py | 1033 +++++ lib/sqlalchemy/engine/strategies.py | 258 ++ lib/sqlalchemy/engine/threadlocal.py | 134 + lib/sqlalchemy/engine/url.py | 227 ++ lib/sqlalchemy/engine/util.py | 72 + lib/sqlalchemy/event/__init__.py | 10 + lib/sqlalchemy/event/api.py | 131 + lib/sqlalchemy/event/attr.py | 386 ++ lib/sqlalchemy/event/base.py | 217 + lib/sqlalchemy/event/legacy.py | 156 + lib/sqlalchemy/event/registry.py | 241 ++ lib/sqlalchemy/events.py | 924 +++++ lib/sqlalchemy/exc.py | 363 ++ lib/sqlalchemy/ext/__init__.py | 5 + lib/sqlalchemy/ext/associationproxy.py | 1053 +++++ lib/sqlalchemy/ext/automap.py | 907 +++++ lib/sqlalchemy/ext/compiler.py | 448 +++ lib/sqlalchemy/ext/declarative/__init__.py | 1316 ++++++ lib/sqlalchemy/ext/declarative/api.py | 513 +++ lib/sqlalchemy/ext/declarative/base.py | 532 +++ lib/sqlalchemy/ext/declarative/clsregistry.py | 309 ++ lib/sqlalchemy/ext/horizontal_shard.py | 128 + lib/sqlalchemy/ext/hybrid.py | 808 ++++ lib/sqlalchemy/ext/instrumentation.py | 407 ++ lib/sqlalchemy/ext/mutable.py | 636 +++ lib/sqlalchemy/ext/orderinglist.py | 376 ++ lib/sqlalchemy/ext/serializer.py | 156 + lib/sqlalchemy/inspection.py | 92 + lib/sqlalchemy/interfaces.py | 310 ++ lib/sqlalchemy/log.py | 214 + lib/sqlalchemy/orm/__init__.py | 267 ++ lib/sqlalchemy/orm/attributes.py | 1541 +++++++ lib/sqlalchemy/orm/base.py | 457 +++ lib/sqlalchemy/orm/collections.py | 1550 +++++++ lib/sqlalchemy/orm/dependency.py | 1165 ++++++ lib/sqlalchemy/orm/deprecated_interfaces.py | 590 +++ lib/sqlalchemy/orm/descriptor_props.py | 678 ++++ lib/sqlalchemy/orm/dynamic.py | 368 ++ lib/sqlalchemy/orm/evaluator.py | 123 + lib/sqlalchemy/orm/events.py | 1711 ++++++++ lib/sqlalchemy/orm/exc.py | 163 + lib/sqlalchemy/orm/identity.py | 240 ++ lib/sqlalchemy/orm/instrumentation.py | 499 +++ lib/sqlalchemy/orm/interfaces.py | 577 +++ lib/sqlalchemy/orm/loading.py | 611 +++ lib/sqlalchemy/orm/mapper.py | 2709 +++++++++++++ lib/sqlalchemy/orm/path_registry.py | 261 ++ lib/sqlalchemy/orm/persistence.py | 1107 +++++ lib/sqlalchemy/orm/properties.py | 259 ++ lib/sqlalchemy/orm/query.py | 3564 +++++++++++++++++ lib/sqlalchemy/orm/relationships.py | 2646 ++++++++++++ lib/sqlalchemy/orm/scoping.py | 176 + lib/sqlalchemy/orm/session.py | 2407 +++++++++++ lib/sqlalchemy/orm/state.py | 617 +++ lib/sqlalchemy/orm/strategies.py | 1459 +++++++ lib/sqlalchemy/orm/strategy_options.py | 948 +++++ lib/sqlalchemy/orm/sync.py | 118 + lib/sqlalchemy/orm/unitofwork.py | 646 +++ lib/sqlalchemy/orm/util.py | 953 +++++ lib/sqlalchemy/pool.py | 1250 ++++++ lib/sqlalchemy/processors.py | 152 + lib/sqlalchemy/schema.py | 61 + lib/sqlalchemy/sql/__init__.py | 90 + lib/sqlalchemy/sql/annotation.py | 191 + lib/sqlalchemy/sql/base.py | 621 +++ lib/sqlalchemy/sql/compiler.py | 2957 ++++++++++++++ lib/sqlalchemy/sql/ddl.py | 864 ++++ lib/sqlalchemy/sql/default_comparator.py | 282 ++ lib/sqlalchemy/sql/dml.py | 770 ++++ lib/sqlalchemy/sql/elements.py | 3451 ++++++++++++++++ lib/sqlalchemy/sql/expression.py | 127 + lib/sqlalchemy/sql/functions.py | 534 +++ lib/sqlalchemy/sql/naming.py | 165 + lib/sqlalchemy/sql/operators.py | 867 ++++ lib/sqlalchemy/sql/schema.py | 3386 ++++++++++++++++ lib/sqlalchemy/sql/selectable.py | 3115 ++++++++++++++ lib/sqlalchemy/sql/sqltypes.py | 1639 ++++++++ lib/sqlalchemy/sql/type_api.py | 1064 +++++ lib/sqlalchemy/sql/util.py | 601 +++ lib/sqlalchemy/sql/visitors.py | 314 ++ lib/sqlalchemy/testing/__init__.py | 32 + lib/sqlalchemy/testing/assertions.py | 439 ++ lib/sqlalchemy/testing/assertsql.py | 333 ++ lib/sqlalchemy/testing/config.py | 77 + lib/sqlalchemy/testing/engines.py | 448 +++ lib/sqlalchemy/testing/entities.py | 99 + lib/sqlalchemy/testing/exclusions.py | 363 ++ lib/sqlalchemy/testing/fixtures.py | 380 ++ lib/sqlalchemy/testing/mock.py | 21 + lib/sqlalchemy/testing/pickleable.py | 142 + lib/sqlalchemy/testing/plugin/__init__.py | 0 lib/sqlalchemy/testing/plugin/noseplugin.py | 89 + lib/sqlalchemy/testing/plugin/plugin_base.py | 455 +++ lib/sqlalchemy/testing/plugin/pytestplugin.py | 125 + lib/sqlalchemy/testing/profiling.py | 309 ++ lib/sqlalchemy/testing/requirements.py | 631 +++ lib/sqlalchemy/testing/runner.py | 48 + lib/sqlalchemy/testing/schema.py | 100 + lib/sqlalchemy/testing/suite/__init__.py | 9 + lib/sqlalchemy/testing/suite/test_ddl.py | 63 + lib/sqlalchemy/testing/suite/test_insert.py | 230 ++ .../testing/suite/test_reflection.py | 540 +++ lib/sqlalchemy/testing/suite/test_results.py | 220 + lib/sqlalchemy/testing/suite/test_select.py | 86 + lib/sqlalchemy/testing/suite/test_sequence.py | 128 + lib/sqlalchemy/testing/suite/test_types.py | 598 +++ .../testing/suite/test_update_delete.py | 63 + lib/sqlalchemy/testing/util.py | 205 + lib/sqlalchemy/testing/warnings.py | 56 + lib/sqlalchemy/types.py | 77 + lib/sqlalchemy/util/__init__.py | 45 + lib/sqlalchemy/util/_collections.py | 957 +++++ lib/sqlalchemy/util/compat.py | 215 + lib/sqlalchemy/util/deprecations.py | 143 + lib/sqlalchemy/util/langhelpers.py | 1231 ++++++ lib/sqlalchemy/util/queue.py | 199 + lib/sqlalchemy/util/topological.py | 96 + sickbeard/name_parser/parser.py | 42 +- sickbeard/rssfeeds.py | 34 +- 233 files changed, 100705 insertions(+), 26 deletions(-) create mode 100644 lib/shove/__init__.py create mode 100644 lib/shove/cache/__init__.py create mode 100644 lib/shove/cache/db.py create mode 100644 lib/shove/cache/file.py create mode 100644 lib/shove/cache/filelru.py create mode 100644 lib/shove/cache/memcached.py create mode 100644 lib/shove/cache/memlru.py create mode 100644 lib/shove/cache/memory.py create mode 100644 lib/shove/cache/redisdb.py create mode 100644 lib/shove/cache/simple.py create mode 100644 lib/shove/cache/simplelru.py create mode 100644 lib/shove/store/__init__.py create mode 100644 lib/shove/store/bsdb.py create mode 100644 lib/shove/store/cassandra.py create mode 100644 lib/shove/store/db.py create mode 100644 lib/shove/store/dbm.py create mode 100644 lib/shove/store/durusdb.py create mode 100644 lib/shove/store/file.py create mode 100644 lib/shove/store/ftp.py create mode 100644 lib/shove/store/hdf5.py create mode 100644 lib/shove/store/leveldbstore.py create mode 100644 lib/shove/store/memory.py create mode 100644 lib/shove/store/redisdb.py create mode 100644 lib/shove/store/s3.py create mode 100644 lib/shove/store/simple.py create mode 100644 lib/shove/store/svn.py create mode 100644 lib/shove/store/zodb.py create mode 100644 lib/shove/tests/__init__.py create mode 100644 lib/shove/tests/test_bsddb_store.py create mode 100644 lib/shove/tests/test_cassandra_store.py create mode 100644 lib/shove/tests/test_db_cache.py create mode 100644 lib/shove/tests/test_db_store.py create mode 100644 lib/shove/tests/test_dbm_store.py create mode 100644 lib/shove/tests/test_durus_store.py create mode 100644 lib/shove/tests/test_file_cache.py create mode 100644 lib/shove/tests/test_file_store.py create mode 100644 lib/shove/tests/test_ftp_store.py create mode 100644 lib/shove/tests/test_hdf5_store.py create mode 100644 lib/shove/tests/test_leveldb_store.py create mode 100644 lib/shove/tests/test_memcached_cache.py create mode 100644 lib/shove/tests/test_memory_cache.py create mode 100644 lib/shove/tests/test_memory_store.py create mode 100644 lib/shove/tests/test_redis_cache.py create mode 100644 lib/shove/tests/test_redis_store.py create mode 100644 lib/shove/tests/test_s3_store.py create mode 100644 lib/shove/tests/test_simple_cache.py create mode 100644 lib/shove/tests/test_simple_store.py create mode 100644 lib/shove/tests/test_svn_store.py create mode 100644 lib/shove/tests/test_zodb_store.py create mode 100644 lib/sqlalchemy/__init__.py create mode 100644 lib/sqlalchemy/cextension/processors.c create mode 100644 lib/sqlalchemy/cextension/resultproxy.c create mode 100644 lib/sqlalchemy/cextension/utils.c create mode 100644 lib/sqlalchemy/connectors/__init__.py create mode 100644 lib/sqlalchemy/connectors/mxodbc.py create mode 100644 lib/sqlalchemy/connectors/mysqldb.py create mode 100644 lib/sqlalchemy/connectors/pyodbc.py create mode 100644 lib/sqlalchemy/connectors/zxJDBC.py create mode 100644 lib/sqlalchemy/databases/__init__.py create mode 100644 lib/sqlalchemy/dialects/__init__.py create mode 100644 lib/sqlalchemy/dialects/drizzle/__init__.py create mode 100644 lib/sqlalchemy/dialects/drizzle/base.py create mode 100644 lib/sqlalchemy/dialects/drizzle/mysqldb.py create mode 100644 lib/sqlalchemy/dialects/firebird/__init__.py create mode 100644 lib/sqlalchemy/dialects/firebird/base.py create mode 100644 lib/sqlalchemy/dialects/firebird/fdb.py create mode 100644 lib/sqlalchemy/dialects/firebird/kinterbasdb.py create mode 100644 lib/sqlalchemy/dialects/mssql/__init__.py create mode 100644 lib/sqlalchemy/dialects/mssql/adodbapi.py create mode 100644 lib/sqlalchemy/dialects/mssql/base.py create mode 100644 lib/sqlalchemy/dialects/mssql/information_schema.py create mode 100644 lib/sqlalchemy/dialects/mssql/mxodbc.py create mode 100644 lib/sqlalchemy/dialects/mssql/pymssql.py create mode 100644 lib/sqlalchemy/dialects/mssql/pyodbc.py create mode 100644 lib/sqlalchemy/dialects/mssql/zxjdbc.py create mode 100644 lib/sqlalchemy/dialects/mysql/__init__.py create mode 100644 lib/sqlalchemy/dialects/mysql/base.py create mode 100644 lib/sqlalchemy/dialects/mysql/cymysql.py create mode 100644 lib/sqlalchemy/dialects/mysql/gaerdbms.py create mode 100644 lib/sqlalchemy/dialects/mysql/mysqlconnector.py create mode 100644 lib/sqlalchemy/dialects/mysql/mysqldb.py create mode 100644 lib/sqlalchemy/dialects/mysql/oursql.py create mode 100644 lib/sqlalchemy/dialects/mysql/pymysql.py create mode 100644 lib/sqlalchemy/dialects/mysql/pyodbc.py create mode 100644 lib/sqlalchemy/dialects/mysql/zxjdbc.py create mode 100644 lib/sqlalchemy/dialects/oracle/__init__.py create mode 100644 lib/sqlalchemy/dialects/oracle/base.py create mode 100644 lib/sqlalchemy/dialects/oracle/cx_oracle.py create mode 100644 lib/sqlalchemy/dialects/oracle/zxjdbc.py create mode 100644 lib/sqlalchemy/dialects/postgres.py create mode 100644 lib/sqlalchemy/dialects/postgresql/__init__.py create mode 100644 lib/sqlalchemy/dialects/postgresql/base.py create mode 100644 lib/sqlalchemy/dialects/postgresql/constraints.py create mode 100644 lib/sqlalchemy/dialects/postgresql/hstore.py create mode 100644 lib/sqlalchemy/dialects/postgresql/json.py create mode 100644 lib/sqlalchemy/dialects/postgresql/pg8000.py create mode 100644 lib/sqlalchemy/dialects/postgresql/psycopg2.py create mode 100644 lib/sqlalchemy/dialects/postgresql/pypostgresql.py create mode 100644 lib/sqlalchemy/dialects/postgresql/ranges.py create mode 100644 lib/sqlalchemy/dialects/postgresql/zxjdbc.py create mode 100644 lib/sqlalchemy/dialects/sqlite/__init__.py create mode 100644 lib/sqlalchemy/dialects/sqlite/base.py create mode 100644 lib/sqlalchemy/dialects/sqlite/pysqlite.py create mode 100644 lib/sqlalchemy/dialects/sybase/__init__.py create mode 100644 lib/sqlalchemy/dialects/sybase/base.py create mode 100644 lib/sqlalchemy/dialects/sybase/mxodbc.py create mode 100644 lib/sqlalchemy/dialects/sybase/pyodbc.py create mode 100644 lib/sqlalchemy/dialects/sybase/pysybase.py create mode 100644 lib/sqlalchemy/dialects/type_migration_guidelines.txt create mode 100644 lib/sqlalchemy/engine/__init__.py create mode 100644 lib/sqlalchemy/engine/base.py create mode 100644 lib/sqlalchemy/engine/default.py create mode 100644 lib/sqlalchemy/engine/interfaces.py create mode 100644 lib/sqlalchemy/engine/reflection.py create mode 100644 lib/sqlalchemy/engine/result.py create mode 100644 lib/sqlalchemy/engine/strategies.py create mode 100644 lib/sqlalchemy/engine/threadlocal.py create mode 100644 lib/sqlalchemy/engine/url.py create mode 100644 lib/sqlalchemy/engine/util.py create mode 100644 lib/sqlalchemy/event/__init__.py create mode 100644 lib/sqlalchemy/event/api.py create mode 100644 lib/sqlalchemy/event/attr.py create mode 100644 lib/sqlalchemy/event/base.py create mode 100644 lib/sqlalchemy/event/legacy.py create mode 100644 lib/sqlalchemy/event/registry.py create mode 100644 lib/sqlalchemy/events.py create mode 100644 lib/sqlalchemy/exc.py create mode 100644 lib/sqlalchemy/ext/__init__.py create mode 100644 lib/sqlalchemy/ext/associationproxy.py create mode 100644 lib/sqlalchemy/ext/automap.py create mode 100644 lib/sqlalchemy/ext/compiler.py create mode 100644 lib/sqlalchemy/ext/declarative/__init__.py create mode 100644 lib/sqlalchemy/ext/declarative/api.py create mode 100644 lib/sqlalchemy/ext/declarative/base.py create mode 100644 lib/sqlalchemy/ext/declarative/clsregistry.py create mode 100644 lib/sqlalchemy/ext/horizontal_shard.py create mode 100644 lib/sqlalchemy/ext/hybrid.py create mode 100644 lib/sqlalchemy/ext/instrumentation.py create mode 100644 lib/sqlalchemy/ext/mutable.py create mode 100644 lib/sqlalchemy/ext/orderinglist.py create mode 100644 lib/sqlalchemy/ext/serializer.py create mode 100644 lib/sqlalchemy/inspection.py create mode 100644 lib/sqlalchemy/interfaces.py create mode 100644 lib/sqlalchemy/log.py create mode 100644 lib/sqlalchemy/orm/__init__.py create mode 100644 lib/sqlalchemy/orm/attributes.py create mode 100644 lib/sqlalchemy/orm/base.py create mode 100644 lib/sqlalchemy/orm/collections.py create mode 100644 lib/sqlalchemy/orm/dependency.py create mode 100644 lib/sqlalchemy/orm/deprecated_interfaces.py create mode 100644 lib/sqlalchemy/orm/descriptor_props.py create mode 100644 lib/sqlalchemy/orm/dynamic.py create mode 100644 lib/sqlalchemy/orm/evaluator.py create mode 100644 lib/sqlalchemy/orm/events.py create mode 100644 lib/sqlalchemy/orm/exc.py create mode 100644 lib/sqlalchemy/orm/identity.py create mode 100644 lib/sqlalchemy/orm/instrumentation.py create mode 100644 lib/sqlalchemy/orm/interfaces.py create mode 100644 lib/sqlalchemy/orm/loading.py create mode 100644 lib/sqlalchemy/orm/mapper.py create mode 100644 lib/sqlalchemy/orm/path_registry.py create mode 100644 lib/sqlalchemy/orm/persistence.py create mode 100644 lib/sqlalchemy/orm/properties.py create mode 100644 lib/sqlalchemy/orm/query.py create mode 100644 lib/sqlalchemy/orm/relationships.py create mode 100644 lib/sqlalchemy/orm/scoping.py create mode 100644 lib/sqlalchemy/orm/session.py create mode 100644 lib/sqlalchemy/orm/state.py create mode 100644 lib/sqlalchemy/orm/strategies.py create mode 100644 lib/sqlalchemy/orm/strategy_options.py create mode 100644 lib/sqlalchemy/orm/sync.py create mode 100644 lib/sqlalchemy/orm/unitofwork.py create mode 100644 lib/sqlalchemy/orm/util.py create mode 100644 lib/sqlalchemy/pool.py create mode 100644 lib/sqlalchemy/processors.py create mode 100644 lib/sqlalchemy/schema.py create mode 100644 lib/sqlalchemy/sql/__init__.py create mode 100644 lib/sqlalchemy/sql/annotation.py create mode 100644 lib/sqlalchemy/sql/base.py create mode 100644 lib/sqlalchemy/sql/compiler.py create mode 100644 lib/sqlalchemy/sql/ddl.py create mode 100644 lib/sqlalchemy/sql/default_comparator.py create mode 100644 lib/sqlalchemy/sql/dml.py create mode 100644 lib/sqlalchemy/sql/elements.py create mode 100644 lib/sqlalchemy/sql/expression.py create mode 100644 lib/sqlalchemy/sql/functions.py create mode 100644 lib/sqlalchemy/sql/naming.py create mode 100644 lib/sqlalchemy/sql/operators.py create mode 100644 lib/sqlalchemy/sql/schema.py create mode 100644 lib/sqlalchemy/sql/selectable.py create mode 100644 lib/sqlalchemy/sql/sqltypes.py create mode 100644 lib/sqlalchemy/sql/type_api.py create mode 100644 lib/sqlalchemy/sql/util.py create mode 100644 lib/sqlalchemy/sql/visitors.py create mode 100644 lib/sqlalchemy/testing/__init__.py create mode 100644 lib/sqlalchemy/testing/assertions.py create mode 100644 lib/sqlalchemy/testing/assertsql.py create mode 100644 lib/sqlalchemy/testing/config.py create mode 100644 lib/sqlalchemy/testing/engines.py create mode 100644 lib/sqlalchemy/testing/entities.py create mode 100644 lib/sqlalchemy/testing/exclusions.py create mode 100644 lib/sqlalchemy/testing/fixtures.py create mode 100644 lib/sqlalchemy/testing/mock.py create mode 100644 lib/sqlalchemy/testing/pickleable.py create mode 100644 lib/sqlalchemy/testing/plugin/__init__.py create mode 100644 lib/sqlalchemy/testing/plugin/noseplugin.py create mode 100644 lib/sqlalchemy/testing/plugin/plugin_base.py create mode 100644 lib/sqlalchemy/testing/plugin/pytestplugin.py create mode 100644 lib/sqlalchemy/testing/profiling.py create mode 100644 lib/sqlalchemy/testing/requirements.py create mode 100644 lib/sqlalchemy/testing/runner.py create mode 100644 lib/sqlalchemy/testing/schema.py create mode 100644 lib/sqlalchemy/testing/suite/__init__.py create mode 100644 lib/sqlalchemy/testing/suite/test_ddl.py create mode 100644 lib/sqlalchemy/testing/suite/test_insert.py create mode 100644 lib/sqlalchemy/testing/suite/test_reflection.py create mode 100644 lib/sqlalchemy/testing/suite/test_results.py create mode 100644 lib/sqlalchemy/testing/suite/test_select.py create mode 100644 lib/sqlalchemy/testing/suite/test_sequence.py create mode 100644 lib/sqlalchemy/testing/suite/test_types.py create mode 100644 lib/sqlalchemy/testing/suite/test_update_delete.py create mode 100644 lib/sqlalchemy/testing/util.py create mode 100644 lib/sqlalchemy/testing/warnings.py create mode 100644 lib/sqlalchemy/types.py create mode 100644 lib/sqlalchemy/util/__init__.py create mode 100644 lib/sqlalchemy/util/_collections.py create mode 100644 lib/sqlalchemy/util/compat.py create mode 100644 lib/sqlalchemy/util/deprecations.py create mode 100644 lib/sqlalchemy/util/langhelpers.py create mode 100644 lib/sqlalchemy/util/queue.py create mode 100644 lib/sqlalchemy/util/topological.py diff --git a/lib/shove/__init__.py b/lib/shove/__init__.py new file mode 100644 index 00000000..3be119b4 --- /dev/null +++ b/lib/shove/__init__.py @@ -0,0 +1,519 @@ +# -*- coding: utf-8 -*- +'''Common object storage frontend.''' + +import os +import zlib +import urllib +try: + import cPickle as pickle +except ImportError: + import pickle +from collections import deque + +try: + # Import store and cache entry points if setuptools installed + import pkg_resources + stores = dict((_store.name, _store) for _store in + pkg_resources.iter_entry_points('shove.stores')) + caches = dict((_cache.name, _cache) for _cache in + pkg_resources.iter_entry_points('shove.caches')) + # Pass if nothing loaded + if not stores and not caches: + raise ImportError() +except ImportError: + # Static store backend registry + stores = dict( + bsddb='shove.store.bsdb:BsdStore', + cassandra='shove.store.cassandra:CassandraStore', + dbm='shove.store.dbm:DbmStore', + durus='shove.store.durusdb:DurusStore', + file='shove.store.file:FileStore', + firebird='shove.store.db:DbStore', + ftp='shove.store.ftp:FtpStore', + hdf5='shove.store.hdf5:HDF5Store', + leveldb='shove.store.leveldbstore:LevelDBStore', + memory='shove.store.memory:MemoryStore', + mssql='shove.store.db:DbStore', + mysql='shove.store.db:DbStore', + oracle='shove.store.db:DbStore', + postgres='shove.store.db:DbStore', + redis='shove.store.redisdb:RedisStore', + s3='shove.store.s3:S3Store', + simple='shove.store.simple:SimpleStore', + sqlite='shove.store.db:DbStore', + svn='shove.store.svn:SvnStore', + zodb='shove.store.zodb:ZodbStore', + ) + # Static cache backend registry + caches = dict( + bsddb='shove.cache.bsdb:BsdCache', + file='shove.cache.file:FileCache', + filelru='shove.cache.filelru:FileLRUCache', + firebird='shove.cache.db:DbCache', + memcache='shove.cache.memcached:MemCached', + memlru='shove.cache.memlru:MemoryLRUCache', + memory='shove.cache.memory:MemoryCache', + mssql='shove.cache.db:DbCache', + mysql='shove.cache.db:DbCache', + oracle='shove.cache.db:DbCache', + postgres='shove.cache.db:DbCache', + redis='shove.cache.redisdb:RedisCache', + simple='shove.cache.simple:SimpleCache', + simplelru='shove.cache.simplelru:SimpleLRUCache', + sqlite='shove.cache.db:DbCache', + ) + + +def getbackend(uri, engines, **kw): + ''' + Loads the right backend based on a URI. + + @param uri Instance or name string + @param engines A dictionary of scheme/class pairs + ''' + if isinstance(uri, basestring): + mod = engines[uri.split('://', 1)[0]] + # Load module if setuptools not present + if isinstance(mod, basestring): + # Isolate classname from dot path + module, klass = mod.split(':') + # Load module + mod = getattr(__import__(module, '', '', ['']), klass) + # Load appropriate class from setuptools entry point + else: + mod = mod.load() + # Return instance + return mod(uri, **kw) + # No-op for existing instances + return uri + + +def synchronized(func): + ''' + Decorator to lock and unlock a method (Phillip J. Eby). + + @param func Method to decorate + ''' + def wrapper(self, *__args, **__kw): + self._lock.acquire() + try: + return func(self, *__args, **__kw) + finally: + self._lock.release() + wrapper.__name__ = func.__name__ + wrapper.__dict__ = func.__dict__ + wrapper.__doc__ = func.__doc__ + return wrapper + + +class Base(object): + + '''Base Mapping class.''' + + def __init__(self, engine, **kw): + ''' + @keyword compress True, False, or an integer compression level (1-9). + ''' + self._compress = kw.get('compress', False) + self._protocol = kw.get('protocol', pickle.HIGHEST_PROTOCOL) + + def __getitem__(self, key): + raise NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __contains__(self, key): + try: + value = self[key] + except KeyError: + return False + return True + + def get(self, key, default=None): + ''' + Fetch a given key from the mapping. If the key does not exist, + return the default. + + @param key Keyword of item in mapping. + @param default Default value (default: None) + ''' + try: + return self[key] + except KeyError: + return default + + def dumps(self, value): + '''Optionally serializes and compresses an object.''' + # Serialize everything but ASCII strings + value = pickle.dumps(value, protocol=self._protocol) + if self._compress: + level = 9 if self._compress is True else self._compress + value = zlib.compress(value, level) + return value + + def loads(self, value): + '''Deserializes and optionally decompresses an object.''' + if self._compress: + try: + value = zlib.decompress(value) + except zlib.error: + pass + value = pickle.loads(value) + return value + + +class BaseStore(Base): + + '''Base Store class (based on UserDict.DictMixin).''' + + def __init__(self, engine, **kw): + super(BaseStore, self).__init__(engine, **kw) + self._store = None + + def __cmp__(self, other): + if other is None: + return False + if isinstance(other, BaseStore): + return cmp(dict(self.iteritems()), dict(other.iteritems())) + + def __del__(self): + # __init__ didn't succeed, so don't bother closing + if not hasattr(self, '_store'): + return + self.close() + + def __iter__(self): + for k in self.keys(): + yield k + + def __len__(self): + return len(self.keys()) + + def __repr__(self): + return repr(dict(self.iteritems())) + + def close(self): + '''Closes internal store and clears object references.''' + try: + self._store.close() + except AttributeError: + pass + self._store = None + + def clear(self): + '''Removes all keys and values from a store.''' + for key in self.keys(): + del self[key] + + def items(self): + '''Returns a list with all key/value pairs in the store.''' + return list(self.iteritems()) + + def iteritems(self): + '''Lazily returns all key/value pairs in a store.''' + for k in self: + yield (k, self[k]) + + def iterkeys(self): + '''Lazy returns all keys in a store.''' + return self.__iter__() + + def itervalues(self): + '''Lazily returns all values in a store.''' + for _, v in self.iteritems(): + yield v + + def keys(self): + '''Returns a list with all keys in a store.''' + raise NotImplementedError() + + def pop(self, key, *args): + ''' + Removes and returns a value from a store. + + @param args Default to return if key not present. + ''' + if len(args) > 1: + raise TypeError('pop expected at most 2 arguments, got ' + repr( + 1 + len(args)) + ) + try: + value = self[key] + # Return default if key not in store + except KeyError: + if args: + return args[0] + del self[key] + return value + + def popitem(self): + '''Removes and returns a key, value pair from a store.''' + try: + k, v = self.iteritems().next() + except StopIteration: + raise KeyError('Store is empty.') + del self[k] + return (k, v) + + def setdefault(self, key, default=None): + ''' + Returns the value corresponding to an existing key or sets the + to key to the default and returns the default. + + @param default Default value (default: None) + ''' + try: + return self[key] + except KeyError: + self[key] = default + return default + + def update(self, other=None, **kw): + ''' + Adds to or overwrites the values in this store with values from + another store. + + other Another store + kw Additional keys and values to store + ''' + if other is None: + pass + elif hasattr(other, 'iteritems'): + for k, v in other.iteritems(): + self[k] = v + elif hasattr(other, 'keys'): + for k in other.keys(): + self[k] = other[k] + else: + for k, v in other: + self[k] = v + if kw: + self.update(kw) + + def values(self): + '''Returns a list with all values in a store.''' + return list(v for _, v in self.iteritems()) + + +class Shove(BaseStore): + + '''Common object frontend class.''' + + def __init__(self, store='simple://', cache='simple://', **kw): + super(Shove, self).__init__(store, **kw) + # Load store + self._store = getbackend(store, stores, **kw) + # Load cache + self._cache = getbackend(cache, caches, **kw) + # Buffer for lazy writing and setting for syncing frequency + self._buffer, self._sync = dict(), kw.get('sync', 2) + + def __getitem__(self, key): + '''Gets a item from shove.''' + try: + return self._cache[key] + except KeyError: + # Synchronize cache and store + self.sync() + value = self._store[key] + self._cache[key] = value + return value + + def __setitem__(self, key, value): + '''Sets an item in shove.''' + self._cache[key] = self._buffer[key] = value + # When the buffer reaches self._limit, writes the buffer to the store + if len(self._buffer) >= self._sync: + self.sync() + + def __delitem__(self, key): + '''Deletes an item from shove.''' + try: + del self._cache[key] + except KeyError: + pass + self.sync() + del self._store[key] + + def keys(self): + '''Returns a list of keys in shove.''' + self.sync() + return self._store.keys() + + def sync(self): + '''Writes buffer to store.''' + for k, v in self._buffer.iteritems(): + self._store[k] = v + self._buffer.clear() + + def close(self): + '''Finalizes and closes shove.''' + # If close has been called, pass + if self._store is not None: + try: + self.sync() + except AttributeError: + pass + self._store.close() + self._store = self._cache = self._buffer = None + + +class FileBase(Base): + + '''Base class for file based storage.''' + + def __init__(self, engine, **kw): + super(FileBase, self).__init__(engine, **kw) + if engine.startswith('file://'): + engine = urllib.url2pathname(engine.split('://')[1]) + self._dir = engine + # Create directory + if not os.path.exists(self._dir): + self._createdir() + + def __getitem__(self, key): + # (per Larry Meyn) + try: + item = open(self._key_to_file(key), 'rb') + data = item.read() + item.close() + return self.loads(data) + except: + raise KeyError(key) + + def __setitem__(self, key, value): + # (per Larry Meyn) + try: + item = open(self._key_to_file(key), 'wb') + item.write(self.dumps(value)) + item.close() + except (IOError, OSError): + raise KeyError(key) + + def __delitem__(self, key): + try: + os.remove(self._key_to_file(key)) + except (IOError, OSError): + raise KeyError(key) + + def __contains__(self, key): + return os.path.exists(self._key_to_file(key)) + + def __len__(self): + return len(os.listdir(self._dir)) + + def _createdir(self): + '''Creates the store directory.''' + try: + os.makedirs(self._dir) + except OSError: + raise EnvironmentError( + 'Cache directory "%s" does not exist and ' \ + 'could not be created' % self._dir + ) + + def _key_to_file(self, key): + '''Gives the filesystem path for a key.''' + return os.path.join(self._dir, urllib.quote_plus(key)) + + def keys(self): + '''Returns a list of keys in the store.''' + return [urllib.unquote_plus(name) for name in os.listdir(self._dir)] + + +class SimpleBase(Base): + + '''Single-process in-memory store base class.''' + + def __init__(self, engine, **kw): + super(SimpleBase, self).__init__(engine, **kw) + self._store = dict() + + def __getitem__(self, key): + try: + return self._store[key] + except: + raise KeyError(key) + + def __setitem__(self, key, value): + self._store[key] = value + + def __delitem__(self, key): + try: + del self._store[key] + except: + raise KeyError(key) + + def __len__(self): + return len(self._store) + + def keys(self): + '''Returns a list of keys in the store.''' + return self._store.keys() + + +class LRUBase(SimpleBase): + + def __init__(self, engine, **kw): + super(LRUBase, self).__init__(engine, **kw) + self._max_entries = kw.get('max_entries', 300) + self._hits = 0 + self._misses = 0 + self._queue = deque() + self._refcount = dict() + + def __getitem__(self, key): + try: + value = super(LRUBase, self).__getitem__(key) + self._hits += 1 + except KeyError: + self._misses += 1 + raise + self._housekeep(key) + return value + + def __setitem__(self, key, value): + super(LRUBase, self).__setitem__(key, value) + self._housekeep(key) + if len(self._store) > self._max_entries: + while len(self._store) > self._max_entries: + k = self._queue.popleft() + self._refcount[k] -= 1 + if not self._refcount[k]: + super(LRUBase, self).__delitem__(k) + del self._refcount[k] + + def _housekeep(self, key): + self._queue.append(key) + self._refcount[key] = self._refcount.get(key, 0) + 1 + if len(self._queue) > self._max_entries * 4: + self._purge_queue() + + def _purge_queue(self): + for i in [None] * len(self._queue): + k = self._queue.popleft() + if self._refcount[k] == 1: + self._queue.append(k) + else: + self._refcount[k] -= 1 + + +class DbBase(Base): + + '''Database common base class.''' + + def __init__(self, engine, **kw): + super(DbBase, self).__init__(engine, **kw) + + def __delitem__(self, key): + self._store.delete(self._store.c.key == key).execute() + + def __len__(self): + return self._store.count().execute().fetchone()[0] + + +__all__ = ['Shove'] diff --git a/lib/shove/cache/__init__.py b/lib/shove/cache/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/lib/shove/cache/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/lib/shove/cache/db.py b/lib/shove/cache/db.py new file mode 100644 index 00000000..21fea01f --- /dev/null +++ b/lib/shove/cache/db.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +''' +Database object cache. + +The shove psuedo-URL used for database object caches is the format used by +SQLAlchemy: + +://:@:/ + + is the database engine. The engines currently supported SQLAlchemy are +sqlite, mysql, postgres, oracle, mssql, and firebird. + is the database account user name + is the database accound password + is the database location + is the database port + is the name of the specific database + +For more information on specific databases see: + +http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported +''' + +import time +import random +from datetime import datetime +try: + from sqlalchemy import ( + MetaData, Table, Column, String, Binary, DateTime, select, update, + insert, delete, + ) + from shove import DbBase +except ImportError: + raise ImportError('Requires SQLAlchemy >= 0.4') + +__all__ = ['DbCache'] + + +class DbCache(DbBase): + + '''database cache backend''' + + def __init__(self, engine, **kw): + super(DbCache, self).__init__(engine, **kw) + # Get table name + tablename = kw.get('tablename', 'cache') + # Bind metadata + self._metadata = MetaData(engine) + # Make cache table + self._store = Table(tablename, self._metadata, + Column('key', String(60), primary_key=True, nullable=False), + Column('value', Binary, nullable=False), + Column('expires', DateTime, nullable=False), + ) + # Create cache table if it does not exist + if not self._store.exists(): + self._store.create() + # Set maximum entries + self._max_entries = kw.get('max_entries', 300) + # Maximum number of entries to cull per call if cache is full + self._maxcull = kw.get('maxcull', 10) + # Set timeout + self.timeout = kw.get('timeout', 300) + + def __getitem__(self, key): + row = select( + [self._store.c.value, self._store.c.expires], + self._store.c.key == key + ).execute().fetchone() + if row is not None: + # Remove if item expired + if row.expires < datetime.now().replace(microsecond=0): + del self[key] + raise KeyError(key) + return self.loads(str(row.value)) + raise KeyError(key) + + def __setitem__(self, key, value): + timeout, value, cache = self.timeout, self.dumps(value), self._store + # Cull if too many items + if len(self) >= self._max_entries: + self._cull() + # Generate expiration time + expires = datetime.fromtimestamp( + time.time() + timeout + ).replace(microsecond=0) + # Update database if key already present + if key in self: + update( + cache, + cache.c.key == key, + dict(value=value, expires=expires), + ).execute() + # Insert new key if key not present + else: + insert( + cache, dict(key=key, value=value, expires=expires) + ).execute() + + def _cull(self): + '''Remove items in cache to make more room.''' + cache, maxcull = self._store, self._maxcull + # Remove items that have timed out + now = datetime.now().replace(microsecond=0) + delete(cache, cache.c.expires < now).execute() + # Remove any items over the maximum allowed number in the cache + if len(self) >= self._max_entries: + # Upper limit for key query + ul = maxcull * 2 + # Get list of keys + keys = [ + i[0] for i in select( + [cache.c.key], limit=ul + ).execute().fetchall() + ] + # Get some keys at random + delkeys = list(random.choice(keys) for i in xrange(maxcull)) + delete(cache, cache.c.key.in_(delkeys)).execute() diff --git a/lib/shove/cache/file.py b/lib/shove/cache/file.py new file mode 100644 index 00000000..7b9a4ae7 --- /dev/null +++ b/lib/shove/cache/file.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +''' +File-based cache + +shove's psuedo-URL for file caches follows the form: + +file:// + +Where the path is a URL path to a directory on a local filesystem. +Alternatively, a native pathname to the directory can be passed as the 'engine' +argument. +''' + +import time + +from shove import FileBase +from shove.cache.simple import SimpleCache + + +class FileCache(FileBase, SimpleCache): + + '''File-based cache backend''' + + def __init__(self, engine, **kw): + super(FileCache, self).__init__(engine, **kw) + + def __getitem__(self, key): + try: + exp, value = super(FileCache, self).__getitem__(key) + # Remove item if time has expired. + if exp < time.time(): + del self[key] + raise KeyError(key) + return value + except: + raise KeyError(key) + + def __setitem__(self, key, value): + if len(self) >= self._max_entries: + self._cull() + super(FileCache, self).__setitem__( + key, (time.time() + self.timeout, value) + ) + + +__all__ = ['FileCache'] diff --git a/lib/shove/cache/filelru.py b/lib/shove/cache/filelru.py new file mode 100644 index 00000000..de076613 --- /dev/null +++ b/lib/shove/cache/filelru.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +''' +File-based LRU cache + +shove's psuedo-URL for file caches follows the form: + +file:// + +Where the path is a URL path to a directory on a local filesystem. +Alternatively, a native pathname to the directory can be passed as the 'engine' +argument. +''' + +from shove import FileBase +from shove.cache.simplelru import SimpleLRUCache + + +class FileCache(FileBase, SimpleLRUCache): + + '''File-based LRU cache backend''' + + +__all__ = ['FileCache'] diff --git a/lib/shove/cache/memcached.py b/lib/shove/cache/memcached.py new file mode 100644 index 00000000..aedfe282 --- /dev/null +++ b/lib/shove/cache/memcached.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +''' +"memcached" cache. + +The shove psuedo-URL for a memcache cache is: + +memcache:// +''' + +try: + import memcache +except ImportError: + raise ImportError("Memcache cache requires the 'memcache' library") + +from shove import Base + + +class MemCached(Base): + + '''Memcached cache backend''' + + def __init__(self, engine, **kw): + super(MemCached, self).__init__(engine, **kw) + if engine.startswith('memcache://'): + engine = engine.split('://')[1] + self._store = memcache.Client(engine.split(';')) + # Set timeout + self.timeout = kw.get('timeout', 300) + + def __getitem__(self, key): + value = self._store.get(key) + if value is None: + raise KeyError(key) + return self.loads(value) + + def __setitem__(self, key, value): + self._store.set(key, self.dumps(value), self.timeout) + + def __delitem__(self, key): + self._store.delete(key) + + +__all__ = ['MemCached'] diff --git a/lib/shove/cache/memlru.py b/lib/shove/cache/memlru.py new file mode 100644 index 00000000..7db61ec5 --- /dev/null +++ b/lib/shove/cache/memlru.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +''' +Thread-safe in-memory cache using LRU. + +The shove psuedo-URL for a memory cache is: + +memlru:// +''' + +import copy +import threading + +from shove import synchronized +from shove.cache.simplelru import SimpleLRUCache + + +class MemoryLRUCache(SimpleLRUCache): + + '''Thread-safe in-memory cache backend using LRU.''' + + def __init__(self, engine, **kw): + super(MemoryLRUCache, self).__init__(engine, **kw) + self._lock = threading.Condition() + + @synchronized + def __setitem__(self, key, value): + super(MemoryLRUCache, self).__setitem__(key, value) + + @synchronized + def __getitem__(self, key): + return copy.deepcopy(super(MemoryLRUCache, self).__getitem__(key)) + + @synchronized + def __delitem__(self, key): + super(MemoryLRUCache, self).__delitem__(key) + + +__all__ = ['MemoryLRUCache'] diff --git a/lib/shove/cache/memory.py b/lib/shove/cache/memory.py new file mode 100644 index 00000000..e70f9bbb --- /dev/null +++ b/lib/shove/cache/memory.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +''' +Thread-safe in-memory cache. + +The shove psuedo-URL for a memory cache is: + +memory:// +''' + +import copy +import threading + +from shove import synchronized +from shove.cache.simple import SimpleCache + + +class MemoryCache(SimpleCache): + + '''Thread-safe in-memory cache backend.''' + + def __init__(self, engine, **kw): + super(MemoryCache, self).__init__(engine, **kw) + self._lock = threading.Condition() + + @synchronized + def __setitem__(self, key, value): + super(MemoryCache, self).__setitem__(key, value) + + @synchronized + def __getitem__(self, key): + return copy.deepcopy(super(MemoryCache, self).__getitem__(key)) + + @synchronized + def __delitem__(self, key): + super(MemoryCache, self).__delitem__(key) + + +__all__ = ['MemoryCache'] diff --git a/lib/shove/cache/redisdb.py b/lib/shove/cache/redisdb.py new file mode 100644 index 00000000..c53536c1 --- /dev/null +++ b/lib/shove/cache/redisdb.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +''' +Redis-based object cache + +The shove psuedo-URL for a redis cache is: + +redis://:/ +''' + +import urlparse + +try: + import redis +except ImportError: + raise ImportError('This store requires the redis library') + +from shove import Base + + +class RedisCache(Base): + + '''Redis cache backend''' + + init = 'redis://' + + def __init__(self, engine, **kw): + super(RedisCache, self).__init__(engine, **kw) + spliturl = urlparse.urlsplit(engine) + host, port = spliturl[1].split(':') + db = spliturl[2].replace('/', '') + self._store = redis.Redis(host, int(port), db) + # Set timeout + self.timeout = kw.get('timeout', 300) + + def __getitem__(self, key): + return self.loads(self._store[key]) + + def __setitem__(self, key, value): + self._store.setex(key, self.dumps(value), self.timeout) + + def __delitem__(self, key): + self._store.delete(key) + + +__all__ = ['RedisCache'] diff --git a/lib/shove/cache/simple.py b/lib/shove/cache/simple.py new file mode 100644 index 00000000..6855603e --- /dev/null +++ b/lib/shove/cache/simple.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +''' +Single-process in-memory cache. + +The shove psuedo-URL for a simple cache is: + +simple:// +''' + +import time +import random + +from shove import SimpleBase + + +class SimpleCache(SimpleBase): + + '''Single-process in-memory cache.''' + + def __init__(self, engine, **kw): + super(SimpleCache, self).__init__(engine, **kw) + # Get random seed + random.seed() + # Set maximum number of items to cull if over max + self._maxcull = kw.get('maxcull', 10) + # Set max entries + self._max_entries = kw.get('max_entries', 300) + # Set timeout + self.timeout = kw.get('timeout', 300) + + def __getitem__(self, key): + exp, value = super(SimpleCache, self).__getitem__(key) + # Delete if item timed out. + if exp < time.time(): + super(SimpleCache, self).__delitem__(key) + raise KeyError(key) + return value + + def __setitem__(self, key, value): + # Cull values if over max # of entries + if len(self) >= self._max_entries: + self._cull() + # Set expiration time and value + exp = time.time() + self.timeout + super(SimpleCache, self).__setitem__(key, (exp, value)) + + def _cull(self): + '''Remove items in cache to make room.''' + num, maxcull = 0, self._maxcull + # Cull number of items allowed (set by self._maxcull) + for key in self.keys(): + # Remove only maximum # of items allowed by maxcull + if num <= maxcull: + # Remove items if expired + try: + self[key] + except KeyError: + num += 1 + else: + break + # Remove any additional items up to max # of items allowed by maxcull + while len(self) >= self._max_entries and num <= maxcull: + # Cull remainder of allowed quota at random + del self[random.choice(self.keys())] + num += 1 + + +__all__ = ['SimpleCache'] diff --git a/lib/shove/cache/simplelru.py b/lib/shove/cache/simplelru.py new file mode 100644 index 00000000..fbb6e446 --- /dev/null +++ b/lib/shove/cache/simplelru.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +''' +Single-process in-memory LRU cache. + +The shove psuedo-URL for a simple cache is: + +simplelru:// +''' + +from shove import LRUBase + + +class SimpleLRUCache(LRUBase): + + '''In-memory cache that purges based on least recently used item.''' + + +__all__ = ['SimpleLRUCache'] diff --git a/lib/shove/store/__init__.py b/lib/shove/store/__init__.py new file mode 100644 index 00000000..5d639a07 --- /dev/null +++ b/lib/shove/store/__init__.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +from urllib import url2pathname +from shove.store.simple import SimpleStore + + +class ClientStore(SimpleStore): + + '''Base class for stores where updates have to be committed.''' + + def __init__(self, engine, **kw): + super(ClientStore, self).__init__(engine, **kw) + if engine.startswith(self.init): + self._engine = url2pathname(engine.split('://')[1]) + + def __getitem__(self, key): + return self.loads(super(ClientStore, self).__getitem__(key)) + + def __setitem__(self, key, value): + super(ClientStore, self).__setitem__(key, self.dumps(value)) + + +class SyncStore(ClientStore): + + '''Base class for stores where updates have to be committed.''' + + def __getitem__(self, key): + return self.loads(super(SyncStore, self).__getitem__(key)) + + def __setitem__(self, key, value): + super(SyncStore, self).__setitem__(key, value) + try: + self.sync() + except AttributeError: + pass + + def __delitem__(self, key): + super(SyncStore, self).__delitem__(key) + try: + self.sync() + except AttributeError: + pass + + +__all__ = [ + 'bsdb', 'db', 'dbm', 'durusdb', 'file', 'ftp', 'memory', 's3', 'simple', + 'svn', 'zodb', 'redisdb', 'hdf5db', 'leveldbstore', 'cassandra', +] diff --git a/lib/shove/store/bsdb.py b/lib/shove/store/bsdb.py new file mode 100644 index 00000000..d1f9c6dc --- /dev/null +++ b/lib/shove/store/bsdb.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +''' +Berkeley Source Database Store. + +shove's psuedo-URL for BSDDB stores follows the form: + +bsddb:// + +Where the path is a URL path to a Berkeley database. Alternatively, the native +pathname to a Berkeley database can be passed as the 'engine' parameter. +''' +try: + import bsddb +except ImportError: + raise ImportError('requires bsddb library') + +import threading + +from shove import synchronized +from shove.store import SyncStore + + +class BsdStore(SyncStore): + + '''Class for Berkeley Source Database Store.''' + + init = 'bsddb://' + + def __init__(self, engine, **kw): + super(BsdStore, self).__init__(engine, **kw) + self._store = bsddb.hashopen(self._engine) + self._lock = threading.Condition() + self.sync = self._store.sync + + @synchronized + def __getitem__(self, key): + return super(BsdStore, self).__getitem__(key) + + @synchronized + def __setitem__(self, key, value): + super(BsdStore, self).__setitem__(key, value) + + @synchronized + def __delitem__(self, key): + super(BsdStore, self).__delitem__(key) + + +__all__ = ['BsdStore'] diff --git a/lib/shove/store/cassandra.py b/lib/shove/store/cassandra.py new file mode 100644 index 00000000..1f6532ee --- /dev/null +++ b/lib/shove/store/cassandra.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +''' +Cassandra-based object store + +The shove psuedo-URL for a cassandra-based store is: + +cassandra://:// +''' + +import urlparse + +try: + import pycassa +except ImportError: + raise ImportError('This store requires the pycassa library') + +from shove import BaseStore + + +class CassandraStore(BaseStore): + + '''Cassandra based store''' + + init = 'cassandra://' + + def __init__(self, engine, **kw): + super(CassandraStore, self).__init__(engine, **kw) + spliturl = urlparse.urlsplit(engine) + _, keyspace, column_family = spliturl[2].split('/') + try: + self._pool = pycassa.connect(keyspace, [spliturl[1]]) + self._store = pycassa.ColumnFamily(self._pool, column_family) + except pycassa.InvalidRequestException: + from pycassa.system_manager import SystemManager + system_manager = SystemManager(spliturl[1]) + system_manager.create_keyspace( + keyspace, + pycassa.system_manager.SIMPLE_STRATEGY, + {'replication_factor': str(kw.get('replication', 1))} + ) + system_manager.create_column_family(keyspace, column_family) + self._pool = pycassa.connect(keyspace, [spliturl[1]]) + self._store = pycassa.ColumnFamily(self._pool, column_family) + + def __getitem__(self, key): + try: + item = self._store.get(key).get(key) + if item is not None: + return self.loads(item) + raise KeyError(key) + except pycassa.NotFoundException: + raise KeyError(key) + + def __setitem__(self, key, value): + self._store.insert(key, dict(key=self.dumps(value))) + + def __delitem__(self, key): + # beware eventual consistency + try: + self._store.remove(key) + except pycassa.NotFoundException: + raise KeyError(key) + + def clear(self): + # beware eventual consistency + self._store.truncate() + + def keys(self): + return list(i[0] for i in self._store.get_range()) + + +__all__ = ['CassandraStore'] diff --git a/lib/shove/store/db.py b/lib/shove/store/db.py new file mode 100644 index 00000000..0004e6f8 --- /dev/null +++ b/lib/shove/store/db.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +''' +Database object store. + +The shove psuedo-URL used for database object stores is the format used by +SQLAlchemy: + +://:@:/ + + is the database engine. The engines currently supported SQLAlchemy are +sqlite, mysql, postgres, oracle, mssql, and firebird. + is the database account user name + is the database accound password + is the database location + is the database port + is the name of the specific database + +For more information on specific databases see: + +http://www.sqlalchemy.org/docs/dbengine.myt#dbengine_supported +''' + +try: + from sqlalchemy import MetaData, Table, Column, String, Binary, select + from shove import BaseStore, DbBase +except ImportError, e: + raise ImportError('Error: ' + e + ' Requires SQLAlchemy >= 0.4') + + +class DbStore(BaseStore, DbBase): + + '''Database cache backend.''' + + def __init__(self, engine, **kw): + super(DbStore, self).__init__(engine, **kw) + # Get tablename + tablename = kw.get('tablename', 'store') + # Bind metadata + self._metadata = MetaData(engine) + # Make store table + self._store = Table(tablename, self._metadata, + Column('key', String(255), primary_key=True, nullable=False), + Column('value', Binary, nullable=False), + ) + # Create store table if it does not exist + if not self._store.exists(): + self._store.create() + + def __getitem__(self, key): + row = select( + [self._store.c.value], self._store.c.key == key, + ).execute().fetchone() + if row is not None: + return self.loads(str(row.value)) + raise KeyError(key) + + def __setitem__(self, k, v): + v, store = self.dumps(v), self._store + # Update database if key already present + if k in self: + store.update(store.c.key == k).execute(value=v) + # Insert new key if key not present + else: + store.insert().execute(key=k, value=v) + + def keys(self): + '''Returns a list of keys in the store.''' + return list(i[0] for i in select( + [self._store.c.key] + ).execute().fetchall()) + + +__all__ = ['DbStore'] diff --git a/lib/shove/store/dbm.py b/lib/shove/store/dbm.py new file mode 100644 index 00000000..323d2484 --- /dev/null +++ b/lib/shove/store/dbm.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +''' +DBM Database Store. + +shove's psuedo-URL for DBM stores follows the form: + +dbm:// + +Where is a URL path to a DBM database. Alternatively, the native +pathname to a DBM database can be passed as the 'engine' parameter. +''' + +import anydbm + +from shove.store import SyncStore + + +class DbmStore(SyncStore): + + '''Class for variants of the DBM database.''' + + init = 'dbm://' + + def __init__(self, engine, **kw): + super(DbmStore, self).__init__(engine, **kw) + self._store = anydbm.open(self._engine, 'c') + try: + self.sync = self._store.sync + except AttributeError: + pass + + +__all__ = ['DbmStore'] diff --git a/lib/shove/store/durusdb.py b/lib/shove/store/durusdb.py new file mode 100644 index 00000000..8e27670e --- /dev/null +++ b/lib/shove/store/durusdb.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +''' +Durus object database frontend. + +shove's psuedo-URL for Durus stores follows the form: + +durus:// + + +Where the path is a URL path to a durus FileStorage database. Alternatively, a +native pathname to a durus database can be passed as the 'engine' parameter. +''' + +try: + from durus.connection import Connection + from durus.file_storage import FileStorage +except ImportError: + raise ImportError('Requires Durus library') + +from shove.store import SyncStore + + +class DurusStore(SyncStore): + + '''Class for Durus object database frontend.''' + + init = 'durus://' + + def __init__(self, engine, **kw): + super(DurusStore, self).__init__(engine, **kw) + self._db = FileStorage(self._engine) + self._connection = Connection(self._db) + self.sync = self._connection.commit + self._store = self._connection.get_root() + + def close(self): + '''Closes all open storage and connections.''' + self.sync() + self._db.close() + super(DurusStore, self).close() + + +__all__ = ['DurusStore'] diff --git a/lib/shove/store/file.py b/lib/shove/store/file.py new file mode 100644 index 00000000..e66e9c4f --- /dev/null +++ b/lib/shove/store/file.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +''' +Filesystem-based object store + +shove's psuedo-URL for filesystem-based stores follows the form: + +file:// + +Where the path is a URL path to a directory on a local filesystem. +Alternatively, a native pathname to the directory can be passed as the 'engine' +argument. +''' + +from shove import BaseStore, FileBase + + +class FileStore(FileBase, BaseStore): + + '''File-based store.''' + + def __init__(self, engine, **kw): + super(FileStore, self).__init__(engine, **kw) + + +__all__ = ['FileStore'] diff --git a/lib/shove/store/ftp.py b/lib/shove/store/ftp.py new file mode 100644 index 00000000..c2d4aec6 --- /dev/null +++ b/lib/shove/store/ftp.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +''' +FTP-accessed stores + +shove's URL for FTP accessed stores follows the standard form for FTP URLs +defined in RFC-1738: + +ftp://:@:/ +''' + +import urlparse +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO +from ftplib import FTP, error_perm + +from shove import BaseStore + + +class FtpStore(BaseStore): + + def __init__(self, engine, **kw): + super(FtpStore, self).__init__(engine, **kw) + user = kw.get('user', 'anonymous') + password = kw.get('password', '') + spliturl = urlparse.urlsplit(engine) + # Set URL, path, and strip 'ftp://' off + base, path = spliturl[1], spliturl[2] + '/' + if '@' in base: + auth, base = base.split('@') + user, password = auth.split(':') + self._store = FTP(base, user, password) + # Change to remote path if it exits + try: + self._store.cwd(path) + except error_perm: + self._makedir(path) + self._base, self._user, self._password = base, user, password + self._updated, self ._keys = True, None + + def __getitem__(self, key): + try: + local = StringIO() + # Download item + self._store.retrbinary('RETR %s' % key, local.write) + self._updated = False + return self.loads(local.getvalue()) + except: + raise KeyError(key) + + def __setitem__(self, key, value): + local = StringIO(self.dumps(value)) + self._store.storbinary('STOR %s' % key, local) + self._updated = True + + def __delitem__(self, key): + try: + self._store.delete(key) + self._updated = True + except: + raise KeyError(key) + + def _makedir(self, path): + '''Makes remote paths on an FTP server.''' + paths = list(reversed([i for i in path.split('/') if i != ''])) + while paths: + tpath = paths.pop() + self._store.mkd(tpath) + self._store.cwd(tpath) + + def keys(self): + '''Returns a list of keys in a store.''' + if self._updated or self._keys is None: + rlist, nlist = list(), list() + # Remote directory listing + self._store.retrlines('LIST -a', rlist.append) + for rlisting in rlist: + # Split remote file based on whitespace + rfile = rlisting.split() + # Append tuple of remote item type & name + if rfile[-1] not in ('.', '..') and rfile[0].startswith('-'): + nlist.append(rfile[-1]) + self._keys = nlist + return self._keys + + +__all__ = ['FtpStore'] diff --git a/lib/shove/store/hdf5.py b/lib/shove/store/hdf5.py new file mode 100644 index 00000000..a9b618e5 --- /dev/null +++ b/lib/shove/store/hdf5.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +''' +HDF5 Database Store. + +shove's psuedo-URL for HDF5 stores follows the form: + +hdf5:/// + +Where is a URL path to a HDF5 database. Alternatively, the native +pathname to a HDF5 database can be passed as the 'engine' parameter. + is the name of the database. +''' + +try: + import h5py +except ImportError: + raise ImportError('This store requires h5py library') + +from shove.store import ClientStore + + +class HDF5Store(ClientStore): + + '''LevelDB based store''' + + init = 'hdf5://' + + def __init__(self, engine, **kw): + super(HDF5Store, self).__init__(engine, **kw) + engine, group = self._engine.rsplit('/') + self._store = h5py.File(engine).require_group(group).attrs + + +__all__ = ['HDF5Store'] diff --git a/lib/shove/store/leveldbstore.py b/lib/shove/store/leveldbstore.py new file mode 100644 index 00000000..ca73a494 --- /dev/null +++ b/lib/shove/store/leveldbstore.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +''' +LevelDB Database Store. + +shove's psuedo-URL for LevelDB stores follows the form: + +leveldb:// + +Where is a URL path to a LevelDB database. Alternatively, the native +pathname to a LevelDB database can be passed as the 'engine' parameter. +''' + +try: + import leveldb +except ImportError: + raise ImportError('This store requires py-leveldb library') + +from shove.store import ClientStore + + +class LevelDBStore(ClientStore): + + '''LevelDB based store''' + + init = 'leveldb://' + + def __init__(self, engine, **kw): + super(LevelDBStore, self).__init__(engine, **kw) + self._store = leveldb.LevelDB(self._engine) + + def __getitem__(self, key): + item = self.loads(self._store.Get(key)) + if item is not None: + return item + raise KeyError(key) + + def __setitem__(self, key, value): + self._store.Put(key, self.dumps(value)) + + def __delitem__(self, key): + self._store.Delete(key) + + def keys(self): + return list(k for k in self._store.RangeIter(include_value=False)) + + +__all__ = ['LevelDBStore'] diff --git a/lib/shove/store/memory.py b/lib/shove/store/memory.py new file mode 100644 index 00000000..525ae69e --- /dev/null +++ b/lib/shove/store/memory.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +''' +Thread-safe in-memory store. + +The shove psuedo-URL for a memory store is: + +memory:// +''' + +import copy +import threading + +from shove import synchronized +from shove.store.simple import SimpleStore + + +class MemoryStore(SimpleStore): + + '''Thread-safe in-memory store.''' + + def __init__(self, engine, **kw): + super(MemoryStore, self).__init__(engine, **kw) + self._lock = threading.Condition() + + @synchronized + def __getitem__(self, key): + return copy.deepcopy(super(MemoryStore, self).__getitem__(key)) + + @synchronized + def __setitem__(self, key, value): + super(MemoryStore, self).__setitem__(key, value) + + @synchronized + def __delitem__(self, key): + super(MemoryStore, self).__delitem__(key) + + +__all__ = ['MemoryStore'] diff --git a/lib/shove/store/redisdb.py b/lib/shove/store/redisdb.py new file mode 100644 index 00000000..67fa2ebd --- /dev/null +++ b/lib/shove/store/redisdb.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +''' +Redis-based object store + +The shove psuedo-URL for a redis-based store is: + +redis://:/ +''' + +import urlparse + +try: + import redis +except ImportError: + raise ImportError('This store requires the redis library') + +from shove.store import ClientStore + + +class RedisStore(ClientStore): + + '''Redis based store''' + + init = 'redis://' + + def __init__(self, engine, **kw): + super(RedisStore, self).__init__(engine, **kw) + spliturl = urlparse.urlsplit(engine) + host, port = spliturl[1].split(':') + db = spliturl[2].replace('/', '') + self._store = redis.Redis(host, int(port), db) + + def __contains__(self, key): + return self._store.exists(key) + + def clear(self): + self._store.flushdb() + + def keys(self): + return self._store.keys() + + def setdefault(self, key, default=None): + return self._store.getset(key, default) + + def update(self, other=None, **kw): + args = kw if other is not None else other + self._store.mset(args) + + +__all__ = ['RedisStore'] diff --git a/lib/shove/store/s3.py b/lib/shove/store/s3.py new file mode 100644 index 00000000..dbf12f21 --- /dev/null +++ b/lib/shove/store/s3.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +''' +S3-accessed stores + +shove's psuedo-URL for stores found on Amazon.com's S3 web service follows this +form: + +s3://:@ + + is the Access Key issued by Amazon + is the Secret Access Key issued by Amazon + is the name of the bucket accessed through the S3 service +''' + +try: + from boto.s3.connection import S3Connection + from boto.s3.key import Key +except ImportError: + raise ImportError('Requires boto library') + +from shove import BaseStore + + +class S3Store(BaseStore): + + def __init__(self, engine=None, **kw): + super(S3Store, self).__init__(engine, **kw) + # key = Access Key, secret=Secret Access Key, bucket=bucket name + key, secret, bucket = kw.get('key'), kw.get('secret'), kw.get('bucket') + if engine is not None: + auth, bucket = engine.split('://')[1].split('@') + key, secret = auth.split(':') + # kw 'secure' = (True or False, use HTTPS) + self._conn = S3Connection(key, secret, kw.get('secure', False)) + buckets = self._conn.get_all_buckets() + # Use bucket if it exists + for b in buckets: + if b.name == bucket: + self._store = b + break + # Create bucket if it doesn't exist + else: + self._store = self._conn.create_bucket(bucket) + # Set bucket permission ('private', 'public-read', + # 'public-read-write', 'authenticated-read' + self._store.set_acl(kw.get('acl', 'private')) + # Updated flag used for avoiding network calls + self._updated, self._keys = True, None + + def __getitem__(self, key): + rkey = self._store.lookup(key) + if rkey is None: + raise KeyError(key) + # Fetch string + value = self.loads(rkey.get_contents_as_string()) + # Flag that the store has not been updated + self._updated = False + return value + + def __setitem__(self, key, value): + rkey = Key(self._store) + rkey.key = key + rkey.set_contents_from_string(self.dumps(value)) + # Flag that the store has been updated + self._updated = True + + def __delitem__(self, key): + try: + self._store.delete_key(key) + # Flag that the store has been updated + self._updated = True + except: + raise KeyError(key) + + def keys(self): + '''Returns a list of keys in the store.''' + return list(i[0] for i in self.items()) + + def items(self): + '''Returns a list of items from the store.''' + if self._updated or self._keys is None: + self._keys = self._store.get_all_keys() + return list((str(k.key), k) for k in self._keys) + + def iteritems(self): + '''Lazily returns items from the store.''' + for k in self.items(): + yield (k.key, k) + + +__all__ = ['S3Store'] diff --git a/lib/shove/store/simple.py b/lib/shove/store/simple.py new file mode 100644 index 00000000..8f7ebb33 --- /dev/null +++ b/lib/shove/store/simple.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +''' +Single-process in-memory store. + +The shove psuedo-URL for a simple store is: + +simple:// +''' + +from shove import BaseStore, SimpleBase + + +class SimpleStore(SimpleBase, BaseStore): + + '''Single-process in-memory store.''' + + def __init__(self, engine, **kw): + super(SimpleStore, self).__init__(engine, **kw) + + +__all__ = ['SimpleStore'] diff --git a/lib/shove/store/svn.py b/lib/shove/store/svn.py new file mode 100644 index 00000000..5bb8c33e --- /dev/null +++ b/lib/shove/store/svn.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +''' +subversion managed store. + +The shove psuedo-URL used for a subversion store that is password protected is: + +svn::?url= + +or for non-password protected repositories: + +svn://?url= + + is the local repository copy + is the URL of the subversion repository +''' + +import os +import urllib +import threading + +try: + import pysvn +except ImportError: + raise ImportError('Requires Python Subversion library') + +from shove import BaseStore, synchronized + + +class SvnStore(BaseStore): + + '''Class for subversion store.''' + + def __init__(self, engine=None, **kw): + super(SvnStore, self).__init__(engine, **kw) + # Get path, url from keywords if used + path, url = kw.get('path'), kw.get('url') + # Get username. password from keywords if used + user, password = kw.get('user'), kw.get('password') + # Process psuedo URL if used + if engine is not None: + path, query = engine.split('n://')[1].split('?') + url = query.split('=')[1] + # Check for username, password + if '@' in path: + auth, path = path.split('@') + user, password = auth.split(':') + path = urllib.url2pathname(path) + # Create subversion client + self._client = pysvn.Client() + # Assign username, password + if user is not None: + self._client.set_username(user) + if password is not None: + self._client.set_password(password) + # Verify that store exists in repository + try: + self._client.info2(url) + # Create store in repository if it doesn't exist + except pysvn.ClientError: + self._client.mkdir(url, 'Adding directory') + # Verify that local copy exists + try: + if self._client.info(path) is None: + self._client.checkout(url, path) + # Check it out if it doesn't exist + except pysvn.ClientError: + self._client.checkout(url, path) + self._path, self._url = path, url + # Lock + self._lock = threading.Condition() + + @synchronized + def __getitem__(self, key): + try: + return self.loads(self._client.cat(self._key_to_file(key))) + except: + raise KeyError(key) + + @synchronized + def __setitem__(self, key, value): + fname = self._key_to_file(key) + # Write value to file + open(fname, 'wb').write(self.dumps(value)) + # Add to repository + if key not in self: + self._client.add(fname) + self._client.checkin([fname], 'Adding %s' % fname) + + @synchronized + def __delitem__(self, key): + try: + fname = self._key_to_file(key) + self._client.remove(fname) + # Remove deleted value from repository + self._client.checkin([fname], 'Removing %s' % fname) + except: + raise KeyError(key) + + def _key_to_file(self, key): + '''Gives the filesystem path for a key.''' + return os.path.join(self._path, urllib.quote_plus(key)) + + @synchronized + def keys(self): + '''Returns a list of keys in the subversion repository.''' + return list(str(i.name.split('/')[-1]) for i + in self._client.ls(self._path)) + + +__all__ = ['SvnStore'] diff --git a/lib/shove/store/zodb.py b/lib/shove/store/zodb.py new file mode 100644 index 00000000..43768dde --- /dev/null +++ b/lib/shove/store/zodb.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +''' +Zope Object Database store frontend. + +shove's psuedo-URL for ZODB stores follows the form: + +zodb: + + +Where the path is a URL path to a ZODB FileStorage database. Alternatively, a +native pathname to a ZODB database can be passed as the 'engine' argument. +''' + +try: + import transaction + from ZODB import FileStorage, DB +except ImportError: + raise ImportError('Requires ZODB library') + +from shove.store import SyncStore + + +class ZodbStore(SyncStore): + + '''ZODB store front end.''' + + init = 'zodb://' + + def __init__(self, engine, **kw): + super(ZodbStore, self).__init__(engine, **kw) + # Handle psuedo-URL + self._storage = FileStorage.FileStorage(self._engine) + self._db = DB(self._storage) + self._connection = self._db.open() + self._store = self._connection.root() + # Keeps DB in synch through commits of transactions + self.sync = transaction.commit + + def close(self): + '''Closes all open storage and connections.''' + self.sync() + super(ZodbStore, self).close() + self._connection.close() + self._db.close() + self._storage.close() + + +__all__ = ['ZodbStore'] diff --git a/lib/shove/tests/__init__.py b/lib/shove/tests/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/lib/shove/tests/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/lib/shove/tests/test_bsddb_store.py b/lib/shove/tests/test_bsddb_store.py new file mode 100644 index 00000000..3de7896e --- /dev/null +++ b/lib/shove/tests/test_bsddb_store.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestBsdbStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('bsddb://test.db', compress=True) + + def tearDown(self): + import os + self.store.close() + os.remove('test.db') + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_cassandra_store.py b/lib/shove/tests/test_cassandra_store.py new file mode 100644 index 00000000..a5c60f6a --- /dev/null +++ b/lib/shove/tests/test_cassandra_store.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestCassandraStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + from pycassa.system_manager import SystemManager + system_manager = SystemManager('localhost:9160') + try: + system_manager.create_column_family('Foo', 'shove') + except: + pass + self.store = Shove('cassandra://localhost:9160/Foo/shove') + + def tearDown(self): + self.store.clear() + self.store.close() + from pycassa.system_manager import SystemManager + system_manager = SystemManager('localhost:9160') + system_manager.drop_column_family('Foo', 'shove') + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + +# def test_clear(self): +# self.store['max'] = 3 +# self.store['min'] = 6 +# self.store['pow'] = 7 +# self.store.clear() +# self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + +# def test_popitem(self): +# self.store['max'] = 3 +# self.store['min'] = 6 +# self.store['pow'] = 7 +# item = self.store.popitem() +# self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 +# self.store['pow'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store.setdefault('pow', 8), 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_db_cache.py b/lib/shove/tests/test_db_cache.py new file mode 100644 index 00000000..9dd27a06 --- /dev/null +++ b/lib/shove/tests/test_db_cache.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestDbCache(unittest.TestCase): + + initstring = 'sqlite:///' + + def setUp(self): + from shove.cache.db import DbCache + self.cache = DbCache(self.initstring) + + def tearDown(self): + self.cache = None + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.db import DbCache + cache = DbCache(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(2) + + def tmp(): + cache['test'] + self.assertRaises(KeyError, tmp) + + def test_cull(self): + from shove.cache.db import DbCache + cache = DbCache(self.initstring, max_entries=1) + cache['test'] = 'test' + cache['test2'] = 'test' + cache['test2'] = 'test' + self.assertEquals(len(cache), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_db_store.py b/lib/shove/tests/test_db_store.py new file mode 100644 index 00000000..1d9ad616 --- /dev/null +++ b/lib/shove/tests/test_db_store.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestDbStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('sqlite://', compress=True) + + def tearDown(self): + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_dbm_store.py b/lib/shove/tests/test_dbm_store.py new file mode 100644 index 00000000..e64ac9e7 --- /dev/null +++ b/lib/shove/tests/test_dbm_store.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestDbmStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('dbm://test.dbm', compress=True) + + def tearDown(self): + import os + self.store.close() + try: + os.remove('test.dbm.db') + except OSError: + pass + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.setdefault('how', 8) + self.assertEqual(self.store['how'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_durus_store.py b/lib/shove/tests/test_durus_store.py new file mode 100644 index 00000000..006fcc41 --- /dev/null +++ b/lib/shove/tests/test_durus_store.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestDurusStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('durus://test.durus', compress=True) + + def tearDown(self): + import os + self.store.close() + os.remove('test.durus') + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_file_cache.py b/lib/shove/tests/test_file_cache.py new file mode 100644 index 00000000..b288ce82 --- /dev/null +++ b/lib/shove/tests/test_file_cache.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestFileCache(unittest.TestCase): + + initstring = 'file://test' + + def setUp(self): + from shove.cache.file import FileCache + self.cache = FileCache(self.initstring) + + def tearDown(self): + import os + self.cache = None + for x in os.listdir('test'): + os.remove(os.path.join('test', x)) + os.rmdir('test') + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.file import FileCache + cache = FileCache(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(2) + + def tmp(): + cache['test'] + self.assertRaises(KeyError, tmp) + + def test_cull(self): + from shove.cache.file import FileCache + cache = FileCache(self.initstring, max_entries=1) + cache['test'] = 'test' + cache['test2'] = 'test' + num = len(cache) + self.assertEquals(num, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_file_store.py b/lib/shove/tests/test_file_store.py new file mode 100644 index 00000000..35643ced --- /dev/null +++ b/lib/shove/tests/test_file_store.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestFileStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('file://test', compress=True) + + def tearDown(self): + import os + self.store.close() + for x in os.listdir('test'): + os.remove(os.path.join('test', x)) + os.rmdir('test') + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_ftp_store.py b/lib/shove/tests/test_ftp_store.py new file mode 100644 index 00000000..17679a2c --- /dev/null +++ b/lib/shove/tests/test_ftp_store.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestFtpStore(unittest.TestCase): + + ftpstring = 'put ftp string here' + + def setUp(self): + from shove import Shove + self.store = Shove(self.ftpstring, compress=True) + + def tearDown(self): + self.store.clear() + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + self.assertEqual(len(self.store), 2) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + item = self.store.popitem() + self.store.sync() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.store.sync() + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.sync() + self.store.update(tstore) + self.store.sync() + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_hdf5_store.py b/lib/shove/tests/test_hdf5_store.py new file mode 100644 index 00000000..b1342ecf --- /dev/null +++ b/lib/shove/tests/test_hdf5_store.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +import unittest2 + + +class TestHDF5Store(unittest2.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('hdf5://test.hdf5/test') + + def tearDown(self): + import os + self.store.close() + try: + os.remove('test.hdf5') + except OSError: + pass + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.setdefault('bow', 8) + self.assertEqual(self.store['bow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest2.main() diff --git a/lib/shove/tests/test_leveldb_store.py b/lib/shove/tests/test_leveldb_store.py new file mode 100644 index 00000000..b3a3d177 --- /dev/null +++ b/lib/shove/tests/test_leveldb_store.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- + +import unittest2 + + +class TestLevelDBStore(unittest2.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('leveldb://test', compress=True) + + def tearDown(self): + import shutil + shutil.rmtree('test') + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.setdefault('bow', 8) + self.assertEqual(self.store['bow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest2.main() diff --git a/lib/shove/tests/test_memcached_cache.py b/lib/shove/tests/test_memcached_cache.py new file mode 100644 index 00000000..98f0b96d --- /dev/null +++ b/lib/shove/tests/test_memcached_cache.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestMemcached(unittest.TestCase): + + initstring = 'memcache://localhost:11211' + + def setUp(self): + from shove.cache.memcached import MemCached + self.cache = MemCached(self.initstring) + + def tearDown(self): + self.cache = None + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.memcached import MemCached + cache = MemCached(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(1) + + def tmp(): + cache['test'] + self.assertRaises(KeyError, tmp) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_memory_cache.py b/lib/shove/tests/test_memory_cache.py new file mode 100644 index 00000000..87749cdb --- /dev/null +++ b/lib/shove/tests/test_memory_cache.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestMemoryCache(unittest.TestCase): + + initstring = 'memory://' + + def setUp(self): + from shove.cache.memory import MemoryCache + self.cache = MemoryCache(self.initstring) + + def tearDown(self): + self.cache = None + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.memory import MemoryCache + cache = MemoryCache(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(1) + + def tmp(): + cache['test'] + self.assertRaises(KeyError, tmp) + + def test_cull(self): + from shove.cache.memory import MemoryCache + cache = MemoryCache(self.initstring, max_entries=1) + cache['test'] = 'test' + cache['test2'] = 'test' + cache['test2'] = 'test' + self.assertEquals(len(cache), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_memory_store.py b/lib/shove/tests/test_memory_store.py new file mode 100644 index 00000000..12e505dd --- /dev/null +++ b/lib/shove/tests/test_memory_store.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestMemoryStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('memory://', compress=True) + + def tearDown(self): + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_redis_cache.py b/lib/shove/tests/test_redis_cache.py new file mode 100644 index 00000000..c8e9b8db --- /dev/null +++ b/lib/shove/tests/test_redis_cache.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestRedisCache(unittest.TestCase): + + initstring = 'redis://localhost:6379/0' + + def setUp(self): + from shove.cache.redisdb import RedisCache + self.cache = RedisCache(self.initstring) + + def tearDown(self): + self.cache = None + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.redisdb import RedisCache + cache = RedisCache(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(3) + def tmp(): #@IgnorePep8 + return cache['test'] + self.assertRaises(KeyError, tmp) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_redis_store.py b/lib/shove/tests/test_redis_store.py new file mode 100644 index 00000000..06b1e0e9 --- /dev/null +++ b/lib/shove/tests/test_redis_store.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestRedisStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('redis://localhost:6379/0') + + def tearDown(self): + self.store.clear() + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store.setdefault('pow', 8), 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_s3_store.py b/lib/shove/tests/test_s3_store.py new file mode 100644 index 00000000..8a0f08d7 --- /dev/null +++ b/lib/shove/tests/test_s3_store.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestS3Store(unittest.TestCase): + + s3string = 's3 test string here' + + def setUp(self): + from shove import Shove + self.store = Shove(self.s3string, compress=True) + + def tearDown(self): + self.store.clear() + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + self.assertEqual(len(self.store), 2) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + item = self.store.popitem() + self.store.sync() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.store.sync() + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.sync() + self.store.update(tstore) + self.store.sync() + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_simple_cache.py b/lib/shove/tests/test_simple_cache.py new file mode 100644 index 00000000..8cd1830c --- /dev/null +++ b/lib/shove/tests/test_simple_cache.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestSimpleCache(unittest.TestCase): + + initstring = 'simple://' + + def setUp(self): + from shove.cache.simple import SimpleCache + self.cache = SimpleCache(self.initstring) + + def tearDown(self): + self.cache = None + + def test_getitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_setitem(self): + self.cache['test'] = 'test' + self.assertEqual(self.cache['test'], 'test') + + def test_delitem(self): + self.cache['test'] = 'test' + del self.cache['test'] + self.assertEqual('test' in self.cache, False) + + def test_get(self): + self.assertEqual(self.cache.get('min'), None) + + def test_timeout(self): + import time + from shove.cache.simple import SimpleCache + cache = SimpleCache(self.initstring, timeout=1) + cache['test'] = 'test' + time.sleep(1) + + def tmp(): + cache['test'] + self.assertRaises(KeyError, tmp) + + def test_cull(self): + from shove.cache.simple import SimpleCache + cache = SimpleCache(self.initstring, max_entries=1) + cache['test'] = 'test' + cache['test2'] = 'test' + cache['test2'] = 'test' + self.assertEquals(len(cache), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_simple_store.py b/lib/shove/tests/test_simple_store.py new file mode 100644 index 00000000..d2431ec5 --- /dev/null +++ b/lib/shove/tests/test_simple_store.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestSimpleStore(unittest.TestCase): + + def setUp(self): + from shove import Shove + self.store = Shove('simple://', compress=True) + + def tearDown(self): + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_svn_store.py b/lib/shove/tests/test_svn_store.py new file mode 100644 index 00000000..b3103816 --- /dev/null +++ b/lib/shove/tests/test_svn_store.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestSvnStore(unittest.TestCase): + + svnstring = 'SVN test string here' + + def setUp(self): + from shove import Shove + self.store = Shove(self.svnstring, compress=True) + + def tearDown(self): + self.store.clear() + self.store.close() + + def test__getitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.store.sync() + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.store.sync() + tstore.sync() + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + self.assertEqual(len(self.store), 2) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store.sync() + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + item = self.store.popitem() + self.store.sync() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.store.sync() + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.sync() + self.store.update(tstore) + self.store.sync() + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.sync() + slist = self.store.keys() + self.assertEqual('min' in slist, True) + +if __name__ == '__main__': + unittest.main() diff --git a/lib/shove/tests/test_zodb_store.py b/lib/shove/tests/test_zodb_store.py new file mode 100644 index 00000000..9d979fea --- /dev/null +++ b/lib/shove/tests/test_zodb_store.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +import unittest + + +class TestZodbStore(unittest.TestCase): + + init = 'zodb://test.db' + + def setUp(self): + from shove import Shove + self.store = Shove(self.init, compress=True) + + def tearDown(self): + self.store.close() + import os + os.remove('test.db') + os.remove('test.db.index') + os.remove('test.db.tmp') + os.remove('test.db.lock') + + def test__getitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__setitem__(self): + self.store['max'] = 3 + self.assertEqual(self.store['max'], 3) + + def test__delitem__(self): + self.store['max'] = 3 + del self.store['max'] + self.assertEqual('max' in self.store, False) + + def test_get(self): + self.store['max'] = 3 + self.assertEqual(self.store.get('min'), None) + + def test__cmp__(self): + from shove import Shove + tstore = Shove() + self.store['max'] = 3 + tstore['max'] = 3 + self.assertEqual(self.store, tstore) + + def test__len__(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.assertEqual(len(self.store), 2) + + def test_close(self): + self.store.close() + self.assertEqual(self.store, None) + + def test_clear(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + self.store.clear() + self.assertEqual(len(self.store), 0) + + def test_items(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.items()) + self.assertEqual(('min', 6) in slist, True) + + def test_iteritems(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iteritems()) + self.assertEqual(('min', 6) in slist, True) + + def test_iterkeys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.iterkeys()) + self.assertEqual('min' in slist, True) + + def test_itervalues(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = list(self.store.itervalues()) + self.assertEqual(6 in slist, True) + + def test_pop(self): + self.store['max'] = 3 + self.store['min'] = 6 + item = self.store.pop('min') + self.assertEqual(item, 6) + + def test_popitem(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + item = self.store.popitem() + self.assertEqual(len(item) + len(self.store), 4) + + def test_setdefault(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['powl'] = 7 + self.store.setdefault('pow', 8) + self.assertEqual(self.store['pow'], 8) + + def test_update(self): + from shove import Shove + tstore = Shove() + tstore['max'] = 3 + tstore['min'] = 6 + tstore['pow'] = 7 + self.store['max'] = 2 + self.store['min'] = 3 + self.store['pow'] = 7 + self.store.update(tstore) + self.assertEqual(self.store['min'], 6) + + def test_values(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.values() + self.assertEqual(6 in slist, True) + + def test_keys(self): + self.store['max'] = 3 + self.store['min'] = 6 + self.store['pow'] = 7 + slist = self.store.keys() + self.assertEqual('min' in slist, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py new file mode 100644 index 00000000..67155e0f --- /dev/null +++ b/lib/sqlalchemy/__init__.py @@ -0,0 +1,133 @@ +# sqlalchemy/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +from .sql import ( + alias, + and_, + asc, + between, + bindparam, + case, + cast, + collate, + delete, + desc, + distinct, + except_, + except_all, + exists, + extract, + false, + func, + insert, + intersect, + intersect_all, + join, + literal, + literal_column, + modifier, + not_, + null, + or_, + outerjoin, + outparam, + over, + select, + subquery, + text, + true, + tuple_, + type_coerce, + union, + union_all, + update, + ) + +from .types import ( + BIGINT, + BINARY, + BLOB, + BOOLEAN, + BigInteger, + Binary, + Boolean, + CHAR, + CLOB, + DATE, + DATETIME, + DECIMAL, + Date, + DateTime, + Enum, + FLOAT, + Float, + INT, + INTEGER, + Integer, + Interval, + LargeBinary, + NCHAR, + NVARCHAR, + NUMERIC, + Numeric, + PickleType, + REAL, + SMALLINT, + SmallInteger, + String, + TEXT, + TIME, + TIMESTAMP, + Text, + Time, + TypeDecorator, + Unicode, + UnicodeText, + VARBINARY, + VARCHAR, + ) + + +from .schema import ( + CheckConstraint, + Column, + ColumnDefault, + Constraint, + DefaultClause, + FetchedValue, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + PassiveDefault, + PrimaryKeyConstraint, + Sequence, + Table, + ThreadLocalMetaData, + UniqueConstraint, + DDL, +) + + +from .inspection import inspect +from .engine import create_engine, engine_from_config + +__version__ = '0.9.4' + +def __go(lcls): + global __all__ + + from . import events + from . import util as _sa_util + + import inspect as _inspect + + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) + + _sa_util.dependencies.resolve_all("sqlalchemy") +__go(locals()) \ No newline at end of file diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c new file mode 100644 index 00000000..d5681776 --- /dev/null +++ b/lib/sqlalchemy/cextension/processors.c @@ -0,0 +1,706 @@ +/* +processors.c +Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include +#include + +#define MODULE_NAME "cprocessors" +#define MODULE_DOC "Module containing C versions of data processing functions." + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + +static PyObject * +int_to_boolean(PyObject *self, PyObject *arg) +{ + long l = 0; + PyObject *res; + + if (arg == Py_None) + Py_RETURN_NONE; + + +#if PY_MAJOR_VERSION >= 3 + l = PyLong_AsLong(arg); +#else + l = PyInt_AsLong(arg); +#endif + if (l == 0) { + res = Py_False; + } else if (l == 1) { + res = Py_True; + } else if ((l == -1) && PyErr_Occurred()) { + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } else { + PyErr_SetString(PyExc_ValueError, + "int_to_boolean only accepts None, 0 or 1"); + return NULL; + } + + Py_INCREF(res); + return res; +} + +static PyObject * +to_str(PyObject *self, PyObject *arg) +{ + if (arg == Py_None) + Py_RETURN_NONE; + + return PyObject_Str(arg); +} + +static PyObject * +to_float(PyObject *self, PyObject *arg) +{ + if (arg == Py_None) + Py_RETURN_NONE; + + return PyNumber_Float(arg); +} + +static PyObject * +str_to_datetime(PyObject *self, PyObject *arg) +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *bytes; + PyObject *err_bytes; +#endif + const char *str; + int numparsed; + unsigned int year, month, day, hour, minute, second, microsecond = 0; + PyObject *err_repr; + + if (arg == Py_None) + Py_RETURN_NONE; + +#if PY_MAJOR_VERSION >= 3 + bytes = PyUnicode_AsASCIIString(arg); + if (bytes == NULL) + str = NULL; + else + str = PyBytes_AS_STRING(bytes); +#else + str = PyString_AsString(arg); +#endif + if (str == NULL) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse datetime string '%.200s' " + "- value is not a string.", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse datetime string '%.200s' " + "- value is not a string.", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + + /* microseconds are optional */ + /* + TODO: this is slightly less picky than the Python version which would + not accept "2000-01-01 00:00:00.". I don't know which is better, but they + should be coherent. + */ + numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day, + &hour, &minute, &second, µsecond); +#if PY_MAJOR_VERSION >= 3 + Py_DECREF(bytes); +#endif + if (numparsed < 6) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse datetime string: %.200s", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse datetime string: %.200s", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + return PyDateTime_FromDateAndTime(year, month, day, + hour, minute, second, microsecond); +} + +static PyObject * +str_to_time(PyObject *self, PyObject *arg) +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *bytes; + PyObject *err_bytes; +#endif + const char *str; + int numparsed; + unsigned int hour, minute, second, microsecond = 0; + PyObject *err_repr; + + if (arg == Py_None) + Py_RETURN_NONE; + +#if PY_MAJOR_VERSION >= 3 + bytes = PyUnicode_AsASCIIString(arg); + if (bytes == NULL) + str = NULL; + else + str = PyBytes_AS_STRING(bytes); +#else + str = PyString_AsString(arg); +#endif + if (str == NULL) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; + +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse time string '%.200s' - value is not a string.", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse time string '%.200s' - value is not a string.", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + + /* microseconds are optional */ + /* + TODO: this is slightly less picky than the Python version which would + not accept "00:00:00.". I don't know which is better, but they should be + coherent. + */ + numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second, + µsecond); +#if PY_MAJOR_VERSION >= 3 + Py_DECREF(bytes); +#endif + if (numparsed < 3) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse time string: %.200s", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse time string: %.200s", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + return PyTime_FromTime(hour, minute, second, microsecond); +} + +static PyObject * +str_to_date(PyObject *self, PyObject *arg) +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *bytes; + PyObject *err_bytes; +#endif + const char *str; + int numparsed; + unsigned int year, month, day; + PyObject *err_repr; + + if (arg == Py_None) + Py_RETURN_NONE; + +#if PY_MAJOR_VERSION >= 3 + bytes = PyUnicode_AsASCIIString(arg); + if (bytes == NULL) + str = NULL; + else + str = PyBytes_AS_STRING(bytes); +#else + str = PyString_AsString(arg); +#endif + if (str == NULL) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse date string '%.200s' - value is not a string.", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse date string '%.200s' - value is not a string.", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + + numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day); +#if PY_MAJOR_VERSION >= 3 + Py_DECREF(bytes); +#endif + if (numparsed != 3) { + err_repr = PyObject_Repr(arg); + if (err_repr == NULL) + return NULL; +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(err_repr); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_ValueError, + "Couldn't parse date string: %.200s", + PyBytes_AS_STRING(err_bytes)); + Py_DECREF(err_bytes); +#else + PyErr_Format( + PyExc_ValueError, + "Couldn't parse date string: %.200s", + PyString_AsString(err_repr)); +#endif + Py_DECREF(err_repr); + return NULL; + } + return PyDate_FromDate(year, month, day); +} + + +/*********** + * Structs * + ***********/ + +typedef struct { + PyObject_HEAD + PyObject *encoding; + PyObject *errors; +} UnicodeResultProcessor; + +typedef struct { + PyObject_HEAD + PyObject *type; + PyObject *format; +} DecimalResultProcessor; + + + +/************************** + * UnicodeResultProcessor * + **************************/ + +static int +UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args, + PyObject *kwds) +{ + PyObject *encoding, *errors = NULL; + static char *kwlist[] = {"encoding", "errors", NULL}; + +#if PY_MAJOR_VERSION >= 3 + if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist, + &encoding, &errors)) + return -1; +#else + if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist, + &encoding, &errors)) + return -1; +#endif + +#if PY_MAJOR_VERSION >= 3 + encoding = PyUnicode_AsASCIIString(encoding); +#else + Py_INCREF(encoding); +#endif + self->encoding = encoding; + + if (errors) { +#if PY_MAJOR_VERSION >= 3 + errors = PyUnicode_AsASCIIString(errors); +#else + Py_INCREF(errors); +#endif + } else { +#if PY_MAJOR_VERSION >= 3 + errors = PyBytes_FromString("strict"); +#else + errors = PyString_FromString("strict"); +#endif + if (errors == NULL) + return -1; + } + self->errors = errors; + + return 0; +} + +static PyObject * +UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value) +{ + const char *encoding, *errors; + char *str; + Py_ssize_t len; + + if (value == Py_None) + Py_RETURN_NONE; + +#if PY_MAJOR_VERSION >= 3 + if (PyBytes_AsStringAndSize(value, &str, &len)) + return NULL; + + encoding = PyBytes_AS_STRING(self->encoding); + errors = PyBytes_AS_STRING(self->errors); +#else + if (PyString_AsStringAndSize(value, &str, &len)) + return NULL; + + encoding = PyString_AS_STRING(self->encoding); + errors = PyString_AS_STRING(self->errors); +#endif + + return PyUnicode_Decode(str, len, encoding, errors); +} + +static PyObject * +UnicodeResultProcessor_conditional_process(UnicodeResultProcessor *self, PyObject *value) +{ + const char *encoding, *errors; + char *str; + Py_ssize_t len; + + if (value == Py_None) + Py_RETURN_NONE; + +#if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(value) == 1) { + Py_INCREF(value); + return value; + } + + if (PyBytes_AsStringAndSize(value, &str, &len)) + return NULL; + + encoding = PyBytes_AS_STRING(self->encoding); + errors = PyBytes_AS_STRING(self->errors); +#else + + if (PyUnicode_Check(value) == 1) { + Py_INCREF(value); + return value; + } + + if (PyString_AsStringAndSize(value, &str, &len)) + return NULL; + + + encoding = PyString_AS_STRING(self->encoding); + errors = PyString_AS_STRING(self->errors); +#endif + + return PyUnicode_Decode(str, len, encoding, errors); +} + +static void +UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self) +{ + Py_XDECREF(self->encoding); + Py_XDECREF(self->errors); +#if PY_MAJOR_VERSION >= 3 + Py_TYPE(self)->tp_free((PyObject*)self); +#else + self->ob_type->tp_free((PyObject*)self); +#endif +} + +static PyMethodDef UnicodeResultProcessor_methods[] = { + {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O, + "The value processor itself."}, + {"conditional_process", (PyCFunction)UnicodeResultProcessor_conditional_process, METH_O, + "Conditional version of the value processor."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject UnicodeResultProcessorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */ + sizeof(UnicodeResultProcessor), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "UnicodeResultProcessor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + UnicodeResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)UnicodeResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +/************************** + * DecimalResultProcessor * + **************************/ + +static int +DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args, + PyObject *kwds) +{ + PyObject *type, *format; + +#if PY_MAJOR_VERSION >= 3 + if (!PyArg_ParseTuple(args, "OU", &type, &format)) +#else + if (!PyArg_ParseTuple(args, "OS", &type, &format)) +#endif + return -1; + + Py_INCREF(type); + self->type = type; + + Py_INCREF(format); + self->format = format; + + return 0; +} + +static PyObject * +DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) +{ + PyObject *str, *result, *args; + + if (value == Py_None) + Py_RETURN_NONE; + + /* Decimal does not accept float values directly */ + /* SQLite can also give us an integer here (see [ticket:2432]) */ + /* XXX: starting with Python 3.1, we could use Decimal.from_float(f), + but the result wouldn't be the same */ + + args = PyTuple_Pack(1, value); + if (args == NULL) + return NULL; + +#if PY_MAJOR_VERSION >= 3 + str = PyUnicode_Format(self->format, args); +#else + str = PyString_Format(self->format, args); +#endif + + Py_DECREF(args); + if (str == NULL) + return NULL; + + result = PyObject_CallFunctionObjArgs(self->type, str, NULL); + Py_DECREF(str); + return result; +} + +static void +DecimalResultProcessor_dealloc(DecimalResultProcessor *self) +{ + Py_XDECREF(self->type); + Py_XDECREF(self->format); +#if PY_MAJOR_VERSION >= 3 + Py_TYPE(self)->tp_free((PyObject*)self); +#else + self->ob_type->tp_free((PyObject*)self); +#endif +} + +static PyMethodDef DecimalResultProcessor_methods[] = { + {"process", (PyCFunction)DecimalResultProcessor_process, METH_O, + "The value processor itself."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject DecimalResultProcessorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "sqlalchemy.DecimalResultProcessor", /* tp_name */ + sizeof(DecimalResultProcessor), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "DecimalResultProcessor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DecimalResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DecimalResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +static PyMethodDef module_methods[] = { + {"int_to_boolean", int_to_boolean, METH_O, + "Convert an integer to a boolean."}, + {"to_str", to_str, METH_O, + "Convert any value to its string representation."}, + {"to_float", to_float, METH_O, + "Convert any value to its floating point representation."}, + {"str_to_datetime", str_to_datetime, METH_O, + "Convert an ISO string to a datetime.datetime object."}, + {"str_to_time", str_to_time, METH_O, + "Convert an ISO string to a datetime.time object."}, + {"str_to_date", str_to_date, METH_O, + "Convert an ISO string to a datetime.date object."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +#if PY_MAJOR_VERSION >= 3 + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, + MODULE_DOC, + -1, + module_methods +}; + +#define INITERROR return NULL + +PyMODINIT_FUNC +PyInit_cprocessors(void) + +#else + +#define INITERROR return + +PyMODINIT_FUNC +initcprocessors(void) + +#endif + +{ + PyObject *m; + + UnicodeResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&UnicodeResultProcessorType) < 0) + INITERROR; + + DecimalResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&DecimalResultProcessorType) < 0) + INITERROR; + +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&module_def); +#else + m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); +#endif + if (m == NULL) + INITERROR; + + PyDateTime_IMPORT; + + Py_INCREF(&UnicodeResultProcessorType); + PyModule_AddObject(m, "UnicodeResultProcessor", + (PyObject *)&UnicodeResultProcessorType); + + Py_INCREF(&DecimalResultProcessorType); + PyModule_AddObject(m, "DecimalResultProcessor", + (PyObject *)&DecimalResultProcessorType); + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif +} diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c new file mode 100644 index 00000000..218c7b80 --- /dev/null +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -0,0 +1,718 @@ +/* +resultproxy.c +Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include + +#define MODULE_NAME "cresultproxy" +#define MODULE_DOC "Module containing C versions of core ResultProxy classes." + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +typedef Py_ssize_t (*lenfunc)(PyObject *); +#define PyInt_FromSsize_t(x) PyInt_FromLong(x) +typedef intargfunc ssizeargfunc; +#endif + + +/*********** + * Structs * + ***********/ + +typedef struct { + PyObject_HEAD + PyObject *parent; + PyObject *row; + PyObject *processors; + PyObject *keymap; +} BaseRowProxy; + +/**************** + * BaseRowProxy * + ****************/ + +static PyObject * +safe_rowproxy_reconstructor(PyObject *self, PyObject *args) +{ + PyObject *cls, *state, *tmp; + BaseRowProxy *obj; + + if (!PyArg_ParseTuple(args, "OO", &cls, &state)) + return NULL; + + obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls); + if (obj == NULL) + return NULL; + + tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state); + if (tmp == NULL) { + Py_DECREF(obj); + return NULL; + } + Py_DECREF(tmp); + + if (obj->parent == NULL || obj->row == NULL || + obj->processors == NULL || obj->keymap == NULL) { + PyErr_SetString(PyExc_RuntimeError, + "__setstate__ for BaseRowProxy subclasses must set values " + "for parent, row, processors and keymap"); + Py_DECREF(obj); + return NULL; + } + + return (PyObject *)obj; +} + +static int +BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds) +{ + PyObject *parent, *row, *processors, *keymap; + + if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4, + &parent, &row, &processors, &keymap)) + return -1; + + Py_INCREF(parent); + self->parent = parent; + + if (!PySequence_Check(row)) { + PyErr_SetString(PyExc_TypeError, "row must be a sequence"); + return -1; + } + Py_INCREF(row); + self->row = row; + + if (!PyList_CheckExact(processors)) { + PyErr_SetString(PyExc_TypeError, "processors must be a list"); + return -1; + } + Py_INCREF(processors); + self->processors = processors; + + if (!PyDict_CheckExact(keymap)) { + PyErr_SetString(PyExc_TypeError, "keymap must be a dict"); + return -1; + } + Py_INCREF(keymap); + self->keymap = keymap; + + return 0; +} + +/* We need the reduce method because otherwise the default implementation + * does very weird stuff for pickle protocol 0 and 1. It calls + * BaseRowProxy.__new__(RowProxy_instance) upon *pickling*. + */ +static PyObject * +BaseRowProxy_reduce(PyObject *self) +{ + PyObject *method, *state; + PyObject *module, *reconstructor, *cls; + + method = PyObject_GetAttrString(self, "__getstate__"); + if (method == NULL) + return NULL; + + state = PyObject_CallObject(method, NULL); + Py_DECREF(method); + if (state == NULL) + return NULL; + + module = PyImport_ImportModule("sqlalchemy.engine.result"); + if (module == NULL) + return NULL; + + reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor"); + Py_DECREF(module); + if (reconstructor == NULL) { + Py_DECREF(state); + return NULL; + } + + cls = PyObject_GetAttrString(self, "__class__"); + if (cls == NULL) { + Py_DECREF(reconstructor); + Py_DECREF(state); + return NULL; + } + + return Py_BuildValue("(N(NN))", reconstructor, cls, state); +} + +static void +BaseRowProxy_dealloc(BaseRowProxy *self) +{ + Py_XDECREF(self->parent); + Py_XDECREF(self->row); + Py_XDECREF(self->processors); + Py_XDECREF(self->keymap); +#if PY_MAJOR_VERSION >= 3 + Py_TYPE(self)->tp_free((PyObject *)self); +#else + self->ob_type->tp_free((PyObject *)self); +#endif +} + +static PyObject * +BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple) +{ + Py_ssize_t num_values, num_processors; + PyObject **valueptr, **funcptr, **resultptr; + PyObject *func, *result, *processed_value, *values_fastseq; + + num_values = PySequence_Length(values); + num_processors = PyList_Size(processors); + if (num_values != num_processors) { + PyErr_Format(PyExc_RuntimeError, + "number of values in row (%d) differ from number of column " + "processors (%d)", + (int)num_values, (int)num_processors); + return NULL; + } + + if (astuple) { + result = PyTuple_New(num_values); + } else { + result = PyList_New(num_values); + } + if (result == NULL) + return NULL; + + values_fastseq = PySequence_Fast(values, "row must be a sequence"); + if (values_fastseq == NULL) + return NULL; + + valueptr = PySequence_Fast_ITEMS(values_fastseq); + funcptr = PySequence_Fast_ITEMS(processors); + resultptr = PySequence_Fast_ITEMS(result); + while (--num_values >= 0) { + func = *funcptr; + if (func != Py_None) { + processed_value = PyObject_CallFunctionObjArgs(func, *valueptr, + NULL); + if (processed_value == NULL) { + Py_DECREF(values_fastseq); + Py_DECREF(result); + return NULL; + } + *resultptr = processed_value; + } else { + Py_INCREF(*valueptr); + *resultptr = *valueptr; + } + valueptr++; + funcptr++; + resultptr++; + } + Py_DECREF(values_fastseq); + return result; +} + +static PyListObject * +BaseRowProxy_values(BaseRowProxy *self) +{ + return (PyListObject *)BaseRowProxy_processvalues(self->row, + self->processors, 0); +} + +static PyObject * +BaseRowProxy_iter(BaseRowProxy *self) +{ + PyObject *values, *result; + + values = BaseRowProxy_processvalues(self->row, self->processors, 1); + if (values == NULL) + return NULL; + + result = PyObject_GetIter(values); + Py_DECREF(values); + if (result == NULL) + return NULL; + + return result; +} + +static Py_ssize_t +BaseRowProxy_length(BaseRowProxy *self) +{ + return PySequence_Length(self->row); +} + +static PyObject * +BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) +{ + PyObject *processors, *values; + PyObject *processor, *value, *processed_value; + PyObject *row, *record, *result, *indexobject; + PyObject *exc_module, *exception, *cstr_obj; +#if PY_MAJOR_VERSION >= 3 + PyObject *bytes; +#endif + char *cstr_key; + long index; + int key_fallback = 0; + int tuple_check = 0; + +#if PY_MAJOR_VERSION < 3 + if (PyInt_CheckExact(key)) { + index = PyInt_AS_LONG(key); + } +#endif + + if (PyLong_CheckExact(key)) { + index = PyLong_AsLong(key); + if ((index == -1) && PyErr_Occurred()) + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } else if (PySlice_Check(key)) { + values = PyObject_GetItem(self->row, key); + if (values == NULL) + return NULL; + + processors = PyObject_GetItem(self->processors, key); + if (processors == NULL) { + Py_DECREF(values); + return NULL; + } + + result = BaseRowProxy_processvalues(values, processors, 1); + Py_DECREF(values); + Py_DECREF(processors); + return result; + } else { + record = PyDict_GetItem((PyObject *)self->keymap, key); + if (record == NULL) { + record = PyObject_CallMethod(self->parent, "_key_fallback", + "O", key); + if (record == NULL) + return NULL; + key_fallback = 1; + } + + indexobject = PyTuple_GetItem(record, 2); + if (indexobject == NULL) + return NULL; + + if (key_fallback) { + Py_DECREF(record); + } + + if (indexobject == Py_None) { + exc_module = PyImport_ImportModule("sqlalchemy.exc"); + if (exc_module == NULL) + return NULL; + + exception = PyObject_GetAttrString(exc_module, + "InvalidRequestError"); + Py_DECREF(exc_module); + if (exception == NULL) + return NULL; + + // wow. this seems quite excessive. + cstr_obj = PyObject_Str(key); + if (cstr_obj == NULL) + return NULL; + +/* + FIXME: raise encoding error exception (in both versions below) + if the key contains non-ascii chars, instead of an + InvalidRequestError without any message like in the + python version. +*/ +#if PY_MAJOR_VERSION >= 3 + bytes = PyUnicode_AsASCIIString(cstr_obj); + if (bytes == NULL) + return NULL; + cstr_key = PyBytes_AS_STRING(bytes); +#else + cstr_key = PyString_AsString(cstr_obj); +#endif + if (cstr_key == NULL) { + Py_DECREF(cstr_obj); + return NULL; + } + Py_DECREF(cstr_obj); + + PyErr_Format(exception, + "Ambiguous column name '%.200s' in result set! " + "try 'use_labels' option on select statement.", cstr_key); + return NULL; + } + +#if PY_MAJOR_VERSION >= 3 + index = PyLong_AsLong(indexobject); +#else + index = PyInt_AsLong(indexobject); +#endif + if ((index == -1) && PyErr_Occurred()) + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } + processor = PyList_GetItem(self->processors, index); + if (processor == NULL) + return NULL; + + row = self->row; + if (PyTuple_CheckExact(row)) { + value = PyTuple_GetItem(row, index); + tuple_check = 1; + } + else { + value = PySequence_GetItem(row, index); + tuple_check = 0; + } + + if (value == NULL) + return NULL; + + if (processor != Py_None) { + processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL); + if (!tuple_check) { + Py_DECREF(value); + } + return processed_value; + } else { + if (tuple_check) { + Py_INCREF(value); + } + return value; + } +} + +static PyObject * +BaseRowProxy_getitem(PyObject *self, Py_ssize_t i) +{ + PyObject *index; + +#if PY_MAJOR_VERSION >= 3 + index = PyLong_FromSsize_t(i); +#else + index = PyInt_FromSsize_t(i); +#endif + return BaseRowProxy_subscript((BaseRowProxy*)self, index); +} + +static PyObject * +BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) +{ + PyObject *tmp; +#if PY_MAJOR_VERSION >= 3 + PyObject *err_bytes; +#endif + + if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + return NULL; + PyErr_Clear(); + } + else + return tmp; + + tmp = BaseRowProxy_subscript(self, name); + if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) { + +#if PY_MAJOR_VERSION >= 3 + err_bytes = PyUnicode_AsASCIIString(name); + if (err_bytes == NULL) + return NULL; + PyErr_Format( + PyExc_AttributeError, + "Could not locate column in row for column '%.200s'", + PyBytes_AS_STRING(err_bytes) + ); +#else + PyErr_Format( + PyExc_AttributeError, + "Could not locate column in row for column '%.200s'", + PyString_AsString(name) + ); +#endif + return NULL; + } + return tmp; +} + +/*********************** + * getters and setters * + ***********************/ + +static PyObject * +BaseRowProxy_getparent(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->parent); + return self->parent; +} + +static int +BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure) +{ + PyObject *module, *cls; + + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'parent' attribute"); + return -1; + } + + module = PyImport_ImportModule("sqlalchemy.engine.result"); + if (module == NULL) + return -1; + + cls = PyObject_GetAttrString(module, "ResultMetaData"); + Py_DECREF(module); + if (cls == NULL) + return -1; + + if (PyObject_IsInstance(value, cls) != 1) { + PyErr_SetString(PyExc_TypeError, + "The 'parent' attribute value must be an instance of " + "ResultMetaData"); + return -1; + } + Py_DECREF(cls); + Py_XDECREF(self->parent); + Py_INCREF(value); + self->parent = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getrow(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->row); + return self->row; +} + +static int +BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'row' attribute"); + return -1; + } + + if (!PySequence_Check(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'row' attribute value must be a sequence"); + return -1; + } + + Py_XDECREF(self->row); + Py_INCREF(value); + self->row = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->processors); + return self->processors; +} + +static int +BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'processors' attribute"); + return -1; + } + + if (!PyList_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'processors' attribute value must be a list"); + return -1; + } + + Py_XDECREF(self->processors); + Py_INCREF(value); + self->processors = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->keymap); + return self->keymap; +} + +static int +BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'keymap' attribute"); + return -1; + } + + if (!PyDict_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'keymap' attribute value must be a dict"); + return -1; + } + + Py_XDECREF(self->keymap); + Py_INCREF(value); + self->keymap = value; + + return 0; +} + +static PyGetSetDef BaseRowProxy_getseters[] = { + {"_parent", + (getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent, + "ResultMetaData", + NULL}, + {"_row", + (getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow, + "Original row tuple", + NULL}, + {"_processors", + (getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors, + "list of type processors", + NULL}, + {"_keymap", + (getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap, + "Key to (processor, index) dict", + NULL}, + {NULL} +}; + +static PyMethodDef BaseRowProxy_methods[] = { + {"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS, + "Return the values represented by this BaseRowProxy as a list."}, + {"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS, + "Pickle support method."}, + {NULL} /* Sentinel */ +}; + +static PySequenceMethods BaseRowProxy_as_sequence = { + (lenfunc)BaseRowProxy_length, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)BaseRowProxy_getitem, /* sq_item */ + 0, /* sq_slice */ + 0, /* sq_ass_item */ + 0, /* sq_ass_slice */ + 0, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +static PyMappingMethods BaseRowProxy_as_mapping = { + (lenfunc)BaseRowProxy_length, /* mp_length */ + (binaryfunc)BaseRowProxy_subscript, /* mp_subscript */ + 0 /* mp_ass_subscript */ +}; + +static PyTypeObject BaseRowProxyType = { + PyVarObject_HEAD_INIT(NULL, 0) + "sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */ + sizeof(BaseRowProxy), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)BaseRowProxy_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + &BaseRowProxy_as_sequence, /* tp_as_sequence */ + &BaseRowProxy_as_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + (getattrofunc)BaseRowProxy_getattro,/* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)BaseRowProxy_iter, /* tp_iter */ + 0, /* tp_iternext */ + BaseRowProxy_methods, /* tp_methods */ + 0, /* tp_members */ + BaseRowProxy_getseters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)BaseRowProxy_init, /* tp_init */ + 0, /* tp_alloc */ + 0 /* tp_new */ +}; + +static PyMethodDef module_methods[] = { + {"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS, + "reconstruct a RowProxy instance from its pickled form."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +#if PY_MAJOR_VERSION >= 3 + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, + MODULE_DOC, + -1, + module_methods +}; + +#define INITERROR return NULL + +PyMODINIT_FUNC +PyInit_cresultproxy(void) + +#else + +#define INITERROR return + +PyMODINIT_FUNC +initcresultproxy(void) + +#endif + +{ + PyObject *m; + + BaseRowProxyType.tp_new = PyType_GenericNew; + if (PyType_Ready(&BaseRowProxyType) < 0) + INITERROR; + +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&module_def); +#else + m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); +#endif + if (m == NULL) + INITERROR; + + Py_INCREF(&BaseRowProxyType); + PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType); + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif +} diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c new file mode 100644 index 00000000..377ba8a8 --- /dev/null +++ b/lib/sqlalchemy/cextension/utils.c @@ -0,0 +1,225 @@ +/* +utils.c +Copyright (C) 2012-2014 the SQLAlchemy authors and contributors + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include + +#define MODULE_NAME "cutils" +#define MODULE_DOC "Module containing C versions of utility functions." + +/* + Given arguments from the calling form *multiparams, **params, + return a list of bind parameter structures, usually a list of + dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + */ +static PyObject * +distill_params(PyObject *self, PyObject *args) +{ + PyObject *multiparams, *params; + PyObject *enclosing_list, *double_enclosing_list; + PyObject *zero_element, *zero_element_item; + Py_ssize_t multiparam_size, zero_element_length; + + if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, ¶ms)) { + return NULL; + } + + if (multiparams != Py_None) { + multiparam_size = PyTuple_Size(multiparams); + if (multiparam_size < 0) { + return NULL; + } + } + else { + multiparam_size = 0; + } + + if (multiparam_size == 0) { + if (params != Py_None && PyDict_Size(params) != 0) { + enclosing_list = PyList_New(1); + if (enclosing_list == NULL) { + return NULL; + } + Py_INCREF(params); + if (PyList_SetItem(enclosing_list, 0, params) == -1) { + Py_DECREF(params); + Py_DECREF(enclosing_list); + return NULL; + } + } + else { + enclosing_list = PyList_New(0); + if (enclosing_list == NULL) { + return NULL; + } + } + return enclosing_list; + } + else if (multiparam_size == 1) { + zero_element = PyTuple_GetItem(multiparams, 0); + if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) { + zero_element_length = PySequence_Length(zero_element); + + if (zero_element_length != 0) { + zero_element_item = PySequence_GetItem(zero_element, 0); + if (zero_element_item == NULL) { + return NULL; + } + } + else { + zero_element_item = NULL; + } + + if (zero_element_length == 0 || + ( + PyObject_HasAttrString(zero_element_item, "__iter__") && + !PyObject_HasAttrString(zero_element_item, "strip") + ) + ) { + /* + * execute(stmt, [{}, {}, {}, ...]) + * execute(stmt, [(), (), (), ...]) + */ + Py_XDECREF(zero_element_item); + Py_INCREF(zero_element); + return zero_element; + } + else { + /* + * execute(stmt, ("value", "value")) + */ + Py_XDECREF(zero_element_item); + enclosing_list = PyList_New(1); + if (enclosing_list == NULL) { + return NULL; + } + Py_INCREF(zero_element); + if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { + Py_DECREF(zero_element); + Py_DECREF(enclosing_list); + return NULL; + } + return enclosing_list; + } + } + else if (PyObject_HasAttrString(zero_element, "keys")) { + /* + * execute(stmt, {"key":"value"}) + */ + enclosing_list = PyList_New(1); + if (enclosing_list == NULL) { + return NULL; + } + Py_INCREF(zero_element); + if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { + Py_DECREF(zero_element); + Py_DECREF(enclosing_list); + return NULL; + } + return enclosing_list; + } else { + enclosing_list = PyList_New(1); + if (enclosing_list == NULL) { + return NULL; + } + double_enclosing_list = PyList_New(1); + if (double_enclosing_list == NULL) { + Py_DECREF(enclosing_list); + return NULL; + } + Py_INCREF(zero_element); + if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { + Py_DECREF(zero_element); + Py_DECREF(enclosing_list); + Py_DECREF(double_enclosing_list); + return NULL; + } + if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) { + Py_DECREF(zero_element); + Py_DECREF(enclosing_list); + Py_DECREF(double_enclosing_list); + return NULL; + } + return double_enclosing_list; + } + } + else { + zero_element = PyTuple_GetItem(multiparams, 0); + if (PyObject_HasAttrString(zero_element, "__iter__") && + !PyObject_HasAttrString(zero_element, "strip") + ) { + Py_INCREF(multiparams); + return multiparams; + } + else { + enclosing_list = PyList_New(1); + if (enclosing_list == NULL) { + return NULL; + } + Py_INCREF(multiparams); + if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) { + Py_DECREF(multiparams); + Py_DECREF(enclosing_list); + return NULL; + } + return enclosing_list; + } + } +} + +static PyMethodDef module_methods[] = { + {"_distill_params", distill_params, METH_VARARGS, + "Distill an execute() parameter structure."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + +#if PY_MAJOR_VERSION >= 3 + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, + MODULE_DOC, + -1, + module_methods + }; +#endif + + +#if PY_MAJOR_VERSION >= 3 +PyMODINIT_FUNC +PyInit_cutils(void) +#else +PyMODINIT_FUNC +initcutils(void) +#endif +{ + PyObject *m; + +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&module_def); +#else + m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); +#endif + +#if PY_MAJOR_VERSION >= 3 + if (m == NULL) + return NULL; + return m; +#else + if (m == NULL) + return; +#endif +} + diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py new file mode 100644 index 00000000..761024fe --- /dev/null +++ b/lib/sqlalchemy/connectors/__init__.py @@ -0,0 +1,9 @@ +# connectors/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +class Connector(object): + pass diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py new file mode 100644 index 00000000..e5562a25 --- /dev/null +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -0,0 +1,149 @@ +# connectors/mxodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Provide an SQLALchemy connector for the eGenix mxODBC commercial +Python adapter for ODBC. This is not a free product, but eGenix +provides SQLAlchemy with a license for use in continuous integration +testing. + +This has been tested for use with mxODBC 3.1.2 on SQL Server 2005 +and 2008, using the SQL Server Native driver. However, it is +possible for this to be used on other database platforms. + +For more info on mxODBC, see http://www.egenix.com/ + +""" + +import sys +import re +import warnings + +from . import Connector + + +class MxODBCConnector(Connector): + driver = 'mxodbc' + + supports_sane_multi_rowcount = False + supports_unicode_statements = True + supports_unicode_binds = True + + supports_native_decimal = True + + @classmethod + def dbapi(cls): + # this classmethod will normally be replaced by an instance + # attribute of the same name, so this is normally only called once. + cls._load_mx_exceptions() + platform = sys.platform + if platform == 'win32': + from mx.ODBC import Windows as module + # this can be the string "linux2", and possibly others + elif 'linux' in platform: + from mx.ODBC import unixODBC as module + elif platform == 'darwin': + from mx.ODBC import iODBC as module + else: + raise ImportError("Unrecognized platform for mxODBC import") + return module + + @classmethod + def _load_mx_exceptions(cls): + """ Import mxODBC exception classes into the module namespace, + as if they had been imported normally. This is done here + to avoid requiring all SQLAlchemy users to install mxODBC. + """ + global InterfaceError, ProgrammingError + from mx.ODBC import InterfaceError + from mx.ODBC import ProgrammingError + + def on_connect(self): + def connect(conn): + conn.stringformat = self.dbapi.MIXED_STRINGFORMAT + conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT + conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT + conn.errorhandler = self._error_handler() + return connect + + def _error_handler(self): + """ Return a handler that adjusts mxODBC's raised Warnings to + emit Python standard warnings. + """ + from mx.ODBC.Error import Warning as MxOdbcWarning + + def error_handler(connection, cursor, errorclass, errorvalue): + if issubclass(errorclass, MxOdbcWarning): + errorclass.__bases__ = (Warning,) + warnings.warn(message=str(errorvalue), + category=errorclass, + stacklevel=2) + else: + raise errorclass(errorvalue) + return error_handler + + def create_connect_args(self, url): + """ Return a tuple of *args,**kwargs for creating a connection. + + The mxODBC 3.x connection constructor looks like this: + + connect(dsn, user='', password='', + clear_auto_commit=1, errorhandler=None) + + This method translates the values in the provided uri + into args and kwargs needed to instantiate an mxODBC Connection. + + The arg 'errorhandler' is not used by SQLAlchemy and will + not be populated. + + """ + opts = url.translate_connect_args(username='user') + opts.update(url.query) + args = opts.pop('host') + opts.pop('port', None) + opts.pop('database', None) + return (args,), opts + + def is_disconnect(self, e, connection, cursor): + # TODO: eGenix recommends checking connection.closed here + # Does that detect dropped connections ? + if isinstance(e, self.dbapi.ProgrammingError): + return "connection already closed" in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def _get_server_version_info(self, connection): + # eGenix suggests using conn.dbms_version instead + # of what we're doing here + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + # 18 == pyodbc.SQL_DBMS_VER + for n in r.split(dbapi_con.getinfo(18)[1]): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _get_direct(self, context): + if context: + native_odbc_execute = context.execution_options.\ + get('native_odbc_execute', 'auto') + # default to direct=True in all cases, is more generally + # compatible especially with SQL Server + return False if native_odbc_execute is True else True + else: + return True + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany( + statement, parameters, direct=self._get_direct(context)) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters, direct=self._get_direct(context)) diff --git a/lib/sqlalchemy/connectors/mysqldb.py b/lib/sqlalchemy/connectors/mysqldb.py new file mode 100644 index 00000000..e4efb220 --- /dev/null +++ b/lib/sqlalchemy/connectors/mysqldb.py @@ -0,0 +1,144 @@ +# connectors/mysqldb.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Define behaviors common to MySQLdb dialects. + +Currently includes MySQL and Drizzle. + +""" + +from . import Connector +from ..engine import base as engine_base, default +from ..sql import operators as sql_operators +from .. import exc, log, schema, sql, types as sqltypes, util, processors +import re + + +# the subclassing of Connector by all classes +# here is not strictly necessary + + +class MySQLDBExecutionContext(Connector): + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQLDBCompiler(Connector): + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class MySQLDBIdentifierPreparer(Connector): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace("%", "%%") + + +class MySQLDBConnector(Connector): + driver = 'mysqldb' + supports_unicode_statements = False + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = 'format' + + @classmethod + def dbapi(cls): + # is overridden when pymysql is used + return __import__('MySQLdb') + + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'connect_timeout', int) + util.coerce_kw_type(opts, 'read_timeout', int) + util.coerce_kw_type(opts, 'client_flag', int) + util.coerce_kw_type(opts, 'local_infile', int) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 'charset', str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher'] + for key in keys: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get('client_flag', 0) + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + '.constants.CLIENT' + ).constants.CLIENT + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except (AttributeError, ImportError): + self.supports_sane_rowcount = False + opts['client_flag'] = client_flag + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + return exception.args[0] + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + try: + # note: the SQL here would be + # "SHOW VARIABLES LIKE 'character_set%%'" + cset_name = connection.connection.character_set_name + except AttributeError: + util.warn( + "No 'character_set_name' can be detected with " + "this MySQL-Python version; " + "please upgrade to a recent version of MySQL-Python. " + "Assuming latin1.") + return 'latin1' + else: + return cset_name() + diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 00000000..284de288 --- /dev/null +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,170 @@ +# connectors/pyodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import Connector +from .. import util + + +import sys +import re + + +class PyODBCConnector(Connector): + driver = 'pyodbc' + + supports_sane_multi_rowcount = False + + if util.py2k: + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + + supports_native_decimal = True + default_paramstyle = 'named' + + # for non-DSN connections, this should + # hold the desired driver name + pyodbc_driver_name = None + + # will be set to True after initialize() + # if the freetds.so is detected + freetds = False + + # will be set to the string version of + # the FreeTDS driver if freetds is detected + freetds_driver_version = None + + # will be set to True after initialize() + # if the libessqlsrv.so is detected + easysoft = False + + def __init__(self, supports_unicode_binds=None, **kw): + super(PyODBCConnector, self).__init__(**kw) + self._user_supports_unicode_binds = supports_unicode_binds + + @classmethod + def dbapi(cls): + return __import__('pyodbc') + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + keys = opts + query = url.query + + connect_args = {} + for param in ('ansi', 'unicode_results', 'autocommit'): + if param in keys: + connect_args[param] = util.asbool(keys.pop(param)) + + if 'odbc_connect' in keys: + connectors = [util.unquote_plus(keys.pop('odbc_connect'))] + else: + dsn_connection = 'dsn' in keys or \ + ('host' in keys and 'database' not in keys) + if dsn_connection: + connectors = ['dsn=%s' % (keys.pop('host', '') or \ + keys.pop('dsn', ''))] + else: + port = '' + if 'port' in keys and not 'port' in query: + port = ',%d' % int(keys.pop('port')) + + connectors = ["DRIVER={%s}" % + keys.pop('driver', self.pyodbc_driver_name), + 'Server=%s%s' % (keys.pop('host', ''), port), + 'Database=%s' % keys.pop('database', '')] + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.pop('password', '')) + else: + connectors.append("Trusted_Connection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically + # convert textual data from your database encoding to your + # client encoding. This should obviously be set to 'No' if + # you query a cp1253 encoded database from a latin1 client... + if 'odbc_autotranslate' in keys: + connectors.append("AutoTranslate=%s" % + keys.pop("odbc_autotranslate")) + + connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()]) + return [[";".join(connectors)], connect_args] + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or \ + 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def initialize(self, connection): + # determine FreeTDS first. can't issue SQL easily + # without getting unicode_statements/binds set up. + + pyodbc = self.dbapi + + dbapi_con = connection.connection + + _sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME) + self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name + )) + self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name + )) + + if self.freetds: + self.freetds_driver_version = dbapi_con.getinfo( + pyodbc.SQL_DRIVER_VER) + + self.supports_unicode_statements = ( + not util.py2k or + (not self.freetds and not self.easysoft) + ) + + if self._user_supports_unicode_binds is not None: + self.supports_unicode_binds = self._user_supports_unicode_binds + elif util.py2k: + self.supports_unicode_binds = ( + not self.freetds or self.freetds_driver_version >= '0.91' + ) and not self.easysoft + else: + self.supports_unicode_binds = True + + # run other initialization which asks for user name, etc. + super(PyODBCConnector, self).initialize(connection) + + def _dbapi_version(self): + if not self.dbapi: + return () + return self._parse_dbapi_version(self.dbapi.version) + + def _parse_dbapi_version(self, vers): + m = re.match( + r'(?:py.*-)?([\d\.]+)(?:-(\w+))?', + vers + ) + if not m: + return () + vers = tuple([int(x) for x in m.group(1).split(".")]) + if m.group(2): + vers += (m.group(2),) + return vers + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py new file mode 100644 index 00000000..e0bbc573 --- /dev/null +++ b/lib/sqlalchemy/connectors/zxJDBC.py @@ -0,0 +1,59 @@ +# connectors/zxJDBC.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sys +from . import Connector + + +class ZxJDBCConnector(Connector): + driver = 'zxjdbc' + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_unicode_binds = True + supports_unicode_statements = sys.version > '2.5.0+' + description_encoding = None + default_paramstyle = 'qmark' + + jdbc_db_name = None + jdbc_driver_name = None + + @classmethod + def dbapi(cls): + from com.ziclix.python.sql import zxJDBC + return zxJDBC + + def _driver_kwargs(self): + """Return kw arg dict to be sent to connect().""" + return {} + + def _create_jdbc_url(self, url): + """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`""" + return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host, + url.port is not None + and ':%s' % url.port or '', + url.database) + + def create_connect_args(self, url): + opts = self._driver_kwargs() + opts.update(url.query) + return [ + [self._create_jdbc_url(url), + url.username, url.password, + self.jdbc_driver_name], + opts] + + def is_disconnect(self, e, connection, cursor): + if not isinstance(e, self.dbapi.ProgrammingError): + return False + e = str(e) + return 'connection is closed' in e or 'cursor is closed' in e + + def _get_server_version_info(self, connection): + # use connection.connection.dbversion, and parse appropriately + # to get a tuple + raise NotImplementedError() diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py new file mode 100644 index 00000000..915eefa4 --- /dev/null +++ b/lib/sqlalchemy/databases/__init__.py @@ -0,0 +1,31 @@ +# databases/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Include imports from the sqlalchemy.dialects package for backwards +compatibility with pre 0.6 versions. + +""" +from ..dialects.sqlite import base as sqlite +from ..dialects.postgresql import base as postgresql +postgres = postgresql +from ..dialects.mysql import base as mysql +from ..dialects.drizzle import base as drizzle +from ..dialects.oracle import base as oracle +from ..dialects.firebird import base as firebird +from ..dialects.mssql import base as mssql +from ..dialects.sybase import base as sybase + + +__all__ = ( + 'drizzle', + 'firebird', + 'mssql', + 'mysql', + 'postgresql', + 'sqlite', + 'oracle', + 'sybase', + ) diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py new file mode 100644 index 00000000..974d4f78 --- /dev/null +++ b/lib/sqlalchemy/dialects/__init__.py @@ -0,0 +1,44 @@ +# dialects/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +__all__ = ( + 'drizzle', + 'firebird', + 'mssql', + 'mysql', + 'oracle', + 'postgresql', + 'sqlite', + 'sybase', + ) + +from .. import util + +def _auto_fn(name): + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + try: + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + except ImportError: + return None + + module = getattr(module, dialect) + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) diff --git a/lib/sqlalchemy/dialects/drizzle/__init__.py b/lib/sqlalchemy/dialects/drizzle/__init__.py new file mode 100644 index 00000000..1392b8e2 --- /dev/null +++ b/lib/sqlalchemy/dialects/drizzle/__init__.py @@ -0,0 +1,22 @@ +from sqlalchemy.dialects.drizzle import base, mysqldb + +base.dialect = mysqldb.dialect + +from sqlalchemy.dialects.drizzle.base import \ + BIGINT, BINARY, BLOB, \ + BOOLEAN, CHAR, DATE, \ + DATETIME, DECIMAL, DOUBLE, \ + ENUM, FLOAT, INTEGER, \ + NUMERIC, REAL, TEXT, \ + TIME, TIMESTAMP, VARBINARY, \ + VARCHAR, dialect + +__all__ = ( + 'BIGINT', 'BINARY', 'BLOB', + 'BOOLEAN', 'CHAR', 'DATE', + 'DATETIME', 'DECIMAL', 'DOUBLE', + 'ENUM', 'FLOAT', 'INTEGER', + 'NUMERIC', 'REAL', 'TEXT', + 'TIME', 'TIMESTAMP', 'VARBINARY', + 'VARCHAR', 'dialect' +) diff --git a/lib/sqlalchemy/dialects/drizzle/base.py b/lib/sqlalchemy/dialects/drizzle/base.py new file mode 100644 index 00000000..b5addb42 --- /dev/null +++ b/lib/sqlalchemy/dialects/drizzle/base.py @@ -0,0 +1,498 @@ +# drizzle/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2010-2011 Monty Taylor +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +""" + +.. dialect:: drizzle + :name: Drizzle + +Drizzle is a variant of MySQL. Unlike MySQL, Drizzle's default storage engine +is InnoDB (transactions, foreign-keys) rather than MyISAM. For more +`Notable Differences `_, visit +the `Drizzle Documentation `_. + +The SQLAlchemy Drizzle dialect leans heavily on the MySQL dialect, so much of +the :doc:`SQLAlchemy MySQL ` documentation is also relevant. + + +""" + +from sqlalchemy import exc +from sqlalchemy import log +from sqlalchemy import types as sqltypes +from sqlalchemy.engine import reflection +from sqlalchemy.dialects.mysql import base as mysql_dialect +from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \ + BLOB, BINARY, VARBINARY + + +class _NumericType(object): + """Base for Drizzle numeric types.""" + + def __init__(self, **kw): + super(_NumericType, self).__init__(**kw) + + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and \ + ( + (precision is None and scale is not None) or + (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") + + super(_FloatType, self).__init__(precision=precision, + asdecimal=asdecimal, **kw) + self.scale = scale + + +class _StringType(mysql_dialect._StringType): + """Base for Drizzle string types.""" + + def __init__(self, collation=None, binary=False, **kw): + kw['national'] = False + super(_StringType, self).__init__(collation=collation, binary=binary, + **kw) + + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """Drizzle NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + """ + + super(NUMERIC, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """Drizzle DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + """ + super(DECIMAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class DOUBLE(_FloatType): + """Drizzle DOUBLE type.""" + + __visit_name__ = 'DOUBLE' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + """ + + super(DOUBLE, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class REAL(_FloatType, sqltypes.REAL): + """Drizzle REAL type.""" + + __visit_name__ = 'REAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + """ + + super(REAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class FLOAT(_FloatType, sqltypes.FLOAT): + """Drizzle FLOAT type.""" + + __visit_name__ = 'FLOAT' + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + """ + + super(FLOAT, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + def bind_processor(self, dialect): + return None + + +class INTEGER(sqltypes.INTEGER): + """Drizzle INTEGER type.""" + + __visit_name__ = 'INTEGER' + + def __init__(self, **kw): + """Construct an INTEGER.""" + + super(INTEGER, self).__init__(**kw) + + +class BIGINT(sqltypes.BIGINT): + """Drizzle BIGINTEGER type.""" + + __visit_name__ = 'BIGINT' + + def __init__(self, **kw): + """Construct a BIGINTEGER.""" + + super(BIGINT, self).__init__(**kw) + + +class TIME(mysql_dialect.TIME): + """Drizzle TIME type.""" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """Drizzle TIMESTAMP type.""" + + __visit_name__ = 'TIMESTAMP' + + +class TEXT(_StringType, sqltypes.TEXT): + """Drizzle TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = 'TEXT' + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + + super(TEXT, self).__init__(length=length, **kw) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """Drizzle VARCHAR type, for variable-length character data.""" + + __visit_name__ = 'VARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """Drizzle CHAR type, for fixed-length character data.""" + + __visit_name__ = 'CHAR' + + def __init__(self, length=None, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + + super(CHAR, self).__init__(length=length, **kwargs) + + +class ENUM(mysql_dialect.ENUM): + """Drizzle ENUM type.""" + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + Example: + + Column('myenum', ENUM("foo", "bar", "baz")) + + :param enums: The range of valid values for this ENUM. Values will be + quoted when generating the schema according to the quoting flag (see + below). + + :param strict: Defaults to False: ensure that a given value is in this + ENUM's range of permissible values when inserting or updating rows. + Note that Drizzle will not raise a fatal error if you attempt to + store an out of range value- an alternate value will be stored + instead. + (See Drizzle ENUM documentation.) + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. + + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema - this usage is deprecated. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. + + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + """ + + super(ENUM, self).__init__(*enums, **kw) + + +class _DrizzleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMERIC + + +colspecs = { + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Time: TIME, + sqltypes.Enum: ENUM, + sqltypes.Boolean: _DrizzleBoolean, +} + + +# All the types we have in Drizzle +ischema_names = { + 'BIGINT': BIGINT, + 'BINARY': BINARY, + 'BLOB': BLOB, + 'BOOLEAN': BOOLEAN, + 'CHAR': CHAR, + 'DATE': DATE, + 'DATETIME': DATETIME, + 'DECIMAL': DECIMAL, + 'DOUBLE': DOUBLE, + 'ENUM': ENUM, + 'FLOAT': FLOAT, + 'INT': INTEGER, + 'INTEGER': INTEGER, + 'NUMERIC': NUMERIC, + 'TEXT': TEXT, + 'TIME': TIME, + 'TIMESTAMP': TIMESTAMP, + 'VARBINARY': VARBINARY, + 'VARCHAR': VARCHAR, +} + + +class DrizzleCompiler(mysql_dialect.MySQLCompiler): + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.Integer): + return 'INTEGER' + else: + return super(DrizzleCompiler, self).visit_typeclause(typeclause) + + def visit_cast(self, cast, **kwargs): + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause) + + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) + + +class DrizzleDDLCompiler(mysql_dialect.MySQLDDLCompiler): + pass + + +class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler): + def _extend_numeric(self, type_, spec): + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL + COLLATE annotations and Drizzle specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr('collation'): + collation = 'COLLATE %s' % type_.collation + elif attr('binary'): + collation = 'BINARY' + else: + collation = None + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_NCHAR(self, type): + raise NotImplementedError("Drizzle does not support NCHAR") + + def visit_NVARCHAR(self, type): + raise NotImplementedError("Drizzle does not support NVARCHAR") + + def visit_FLOAT(self, type_): + if type_.scale is not None and type_.precision is not None: + return "FLOAT(%s, %s)" % (type_.precision, type_.scale) + else: + return "FLOAT" + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_BLOB(self, type_): + return "BLOB" + + +class DrizzleExecutionContext(mysql_dialect.MySQLExecutionContext): + pass + + +class DrizzleIdentifierPreparer(mysql_dialect.MySQLIdentifierPreparer): + pass + + +@log.class_logger +class DrizzleDialect(mysql_dialect.MySQLDialect): + """Details of the Drizzle dialect. + + Not used directly in application code. + """ + + name = 'drizzle' + + _supports_cast = True + supports_sequences = False + supports_native_boolean = True + supports_views = False + + default_paramstyle = 'format' + colspecs = colspecs + + statement_compiler = DrizzleCompiler + ddl_compiler = DrizzleDDLCompiler + type_compiler = DrizzleTypeCompiler + ischema_names = ischema_names + preparer = DrizzleIdentifierPreparer + + def on_connect(self): + """Force autocommit - Drizzle Bug#707842 doesn't set this properly""" + + def connect(conn): + conn.autocommit(False) + return connect + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = 'utf8' + rp = connection.execute("SHOW TABLES FROM %s" % + self.identifier_preparer.quote_identifier(current_schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + raise NotImplementedError + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + """ + + return 0 + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + charset = self._connection_charset + rs = connection.execute( + 'SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM' + ' data_dictionary.COLLATIONS') + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + self._server_ansiquotes = False + self._backslash_escapes = False + + diff --git a/lib/sqlalchemy/dialects/drizzle/mysqldb.py b/lib/sqlalchemy/dialects/drizzle/mysqldb.py new file mode 100644 index 00000000..7d91cc36 --- /dev/null +++ b/lib/sqlalchemy/dialects/drizzle/mysqldb.py @@ -0,0 +1,48 @@ +""" +.. dialect:: drizzle+mysqldb + :name: MySQL-Python + :dbapi: mysqldb + :connectstring: drizzle+mysqldb://:@[:]/ + :url: http://sourceforge.net/projects/mysql-python + + +""" + +from sqlalchemy.dialects.drizzle.base import ( + DrizzleDialect, + DrizzleExecutionContext, + DrizzleCompiler, + DrizzleIdentifierPreparer) +from sqlalchemy.connectors.mysqldb import ( + MySQLDBExecutionContext, + MySQLDBCompiler, + MySQLDBIdentifierPreparer, + MySQLDBConnector) + + +class DrizzleExecutionContext_mysqldb(MySQLDBExecutionContext, + DrizzleExecutionContext): + pass + + +class DrizzleCompiler_mysqldb(MySQLDBCompiler, DrizzleCompiler): + pass + + +class DrizzleIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, + DrizzleIdentifierPreparer): + pass + + +class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect): + execution_ctx_cls = DrizzleExecutionContext_mysqldb + statement_compiler = DrizzleCompiler_mysqldb + preparer = DrizzleIdentifierPreparer_mysqldb + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + return 'utf8' + + +dialect = DrizzleDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 00000000..094ac3e8 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,20 @@ +# firebird/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb + +base.dialect = fdb.dialect + +from sqlalchemy.dialects.firebird.base import \ + SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \ + TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\ + dialect + +__all__ = ( + 'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME', + 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB', + 'dialect' +) diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py new file mode 100644 index 00000000..21db57b6 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -0,0 +1,738 @@ +# firebird/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: firebird + :name: Firebird + +Firebird Dialects +----------------- + +Firebird offers two distinct dialects_ (not to be confused with a +SQLAlchemy ``Dialect``): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Locking Behavior +---------------- + +Firebird locks tables aggressively. For this reason, a DROP TABLE may +hang until other transactions are released. SQLAlchemy does its best +to release transactions as quickly as possible. The most common cause +of hanging transactions is a non-fully consumed result set, i.e.:: + + result = engine.execute("select * from table") + row = result.fetchone() + return + +Where above, the ``ResultProxy`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the objects +which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``ResultProxy`` which will fetch the first row and immediately close +all remaining cursor/connection resources. + +RETURNING support +----------------- + +Firebird 2.0 supports returning a result set from inserts, and 2.1 +extends that to deletes and updates. This is generically exposed by +the SQLAlchemy ``returning()`` method, such as:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + raises = empl.update().returning(empl.c.id, empl.c.salary).\\ + where(empl.c.sales>100).\\ + values(dict(salary=empl.c.salary * 1.1)) + print raises.fetchall() + + +.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html + +""" + +import datetime + +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, types as sqltypes, sql, util +from sqlalchemy.sql import expression +from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler + + +from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC, + SMALLINT, TEXT, TIME, TIMESTAMP, Integer) + + +RESERVED_WORDS = set([ + "active", "add", "admin", "after", "all", "alter", "and", "any", "as", + "asc", "ascending", "at", "auto", "avg", "before", "begin", "between", + "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char", + "character", "character_length", "char_length", "check", "close", + "collate", "column", "commit", "committed", "computed", "conditional", + "connect", "constraint", "containing", "count", "create", "cross", + "cstring", "current", "current_connection", "current_date", + "current_role", "current_time", "current_timestamp", + "current_transaction", "current_user", "cursor", "database", "date", + "day", "dec", "decimal", "declare", "default", "delete", "desc", + "descending", "disconnect", "distinct", "do", "domain", "double", + "drop", "else", "end", "entry_point", "escape", "exception", + "execute", "exists", "exit", "external", "extract", "fetch", "file", + "filter", "float", "for", "foreign", "from", "full", "function", + "gdscode", "generator", "gen_id", "global", "grant", "group", + "having", "hour", "if", "in", "inactive", "index", "inner", + "input_type", "insensitive", "insert", "int", "integer", "into", "is", + "isolation", "join", "key", "leading", "left", "length", "level", + "like", "long", "lower", "manual", "max", "maximum_segment", "merge", + "min", "minute", "module_name", "month", "names", "national", + "natural", "nchar", "no", "not", "null", "numeric", "octet_length", + "of", "on", "only", "open", "option", "or", "order", "outer", + "output_type", "overflow", "page", "pages", "page_size", "parameter", + "password", "plan", "position", "post_event", "precision", "primary", + "privileges", "procedure", "protected", "rdb$db_key", "read", "real", + "record_version", "recreate", "recursive", "references", "release", + "reserv", "reserving", "retain", "returning_values", "returns", + "revoke", "right", "rollback", "rows", "row_count", "savepoint", + "schema", "second", "segment", "select", "sensitive", "set", "shadow", + "shared", "singular", "size", "smallint", "snapshot", "some", "sort", + "sqlcode", "stability", "start", "starting", "starts", "statistics", + "sub_type", "sum", "suspend", "table", "then", "time", "timestamp", + "to", "trailing", "transaction", "trigger", "trim", "uncommitted", + "union", "unique", "update", "upper", "user", "using", "value", + "values", "varchar", "variable", "varying", "view", "wait", "when", + "where", "while", "with", "work", "write", "year", + ]) + + +class _StringType(sqltypes.String): + """Base for Firebird string types.""" + + def __init__(self, charset=None, **kw): + self.charset = charset + super(_StringType, self).__init__(**kw) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """Firebird VARCHAR type""" + __visit_name__ = 'VARCHAR' + + def __init__(self, length=None, **kwargs): + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """Firebird CHAR type""" + __visit_name__ = 'CHAR' + + def __init__(self, length=None, **kwargs): + super(CHAR, self).__init__(length=length, **kwargs) + + +class _FBDateTime(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +colspecs = { + sqltypes.DateTime: _FBDateTime +} + +ischema_names = { + 'SHORT': SMALLINT, + 'LONG': INTEGER, + 'QUAD': FLOAT, + 'FLOAT': FLOAT, + 'DATE': DATE, + 'TIME': TIME, + 'TEXT': TEXT, + 'INT64': BIGINT, + 'DOUBLE': FLOAT, + 'TIMESTAMP': TIMESTAMP, + 'VARYING': VARCHAR, + 'CSTRING': CHAR, + 'BLOB': BLOB, + } + + +# TODO: date conversion types (should be implemented as _FBDateTime, +# _FBDate, etc. as bind/result functionality is required) + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TEXT(self, type_): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_): + return "BLOB SUB_TYPE 0" + + def _extend_string(self, type_, basic): + charset = getattr(type_, 'charset', None) + if charset is None: + return basic + else: + return '%s CHARACTER SET %s' % (basic, charset) + + def visit_CHAR(self, type_): + basic = super(FBTypeCompiler, self).visit_CHAR(type_) + return self._extend_string(type_, basic) + + def visit_VARCHAR(self, type_): + if not type_.length: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % + self.dialect.name) + basic = super(FBTypeCompiler, self).visit_VARCHAR(type_) + return self._extend_string(type_, basic) + + +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosyncrasies""" + + ansi_bind_rules = True + + #def visit_contains_op_binary(self, binary, operator, **kw): + # cant use CONTAINING b.c. it's case insensitive. + + #def visit_notcontains_op_binary(self, binary, operator, **kw): + # cant use NOT CONTAINING b.c. it's case insensitive. + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_startswith_op_binary(self, binary, operator, **kw): + return '%s STARTING WITH %s' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) + + def visit_notstartswith_op_binary(self, binary, operator, **kw): + return '%s NOT STARTING WITH %s' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).\ + visit_alias(alias, asfrom=asfrom, **kwargs) + else: + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = isinstance(alias.name, + expression._truncated_label) and \ + self._truncated_identifier("alias", + alias.name) or alias.name + + return self.process( + alias.original, asfrom=asfrom, **kwargs) + \ + " " + \ + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) + + visit_char_length_func = visit_length_func + + def function_argspec(self, func, **kw): + # TODO: this probably will need to be + # narrowed to a fixed list, some no-arg functions + # may require parens - see similar example in the oracle + # dialect + if func.clauses is not None and len(func.clauses): + return self.process(func.clause_expr, **kw) + else: + return "" + + def default_from(self): + return " FROM rdb$database" + + def visit_sequence(self, seq): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + + result = "" + if select._limit: + result += "FIRST %s " % self.process(sql.literal(select._limit)) + if select._offset: + result += "SKIP %s " % self.process(sql.literal(select._offset)) + if select._distinct: + result += "DISTINCT " + return result + + def limit_clause(self, select): + """Already taken care of in the `get_select_precolumns` method.""" + + return "" + + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in expression._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosyncrasies""" + + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + # no syntax for these + # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html + if create.element.start is not None: + raise NotImplemented( + "Firebird SEQUENCE doesn't support START WITH") + if create.element.increment is not None: + raise NotImplemented( + "Firebird SEQUENCE doesn't support INCREMENT BY") + + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % \ + self.preparer.format_sequence(create.element) + else: + return "CREATE GENERATOR %s" % \ + self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "DROP SEQUENCE %s" % \ + self.preparer.format_sequence(drop.element) + else: + return "DROP GENERATOR %s" % \ + self.preparer.format_sequence(drop.element) + + +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(['_']) + + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) + + +class FBExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + """Get the next value from the sequence using ``gen_id()``.""" + + return self._execute_scalar( + "SELECT gen_id(%s, 1) FROM rdb$database" % + self.dialect.identifier_preparer.format_sequence(seq), + type_ + ) + + +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + name = 'firebird' + + max_identifier_length = 31 + + supports_sequences = True + sequences_optional = False + supports_default_values = True + postfetch_lastrowid = False + + supports_native_boolean = False + + requires_name_normalize = True + supports_empty_insert = False + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + execution_ctx_cls = FBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + construct_arguments = [] + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = ('firebird' in self.server_version_info and \ + self.server_version_info >= (2, ) + ) or \ + ('interbase' in self.server_version_info and \ + self.server_version_info >= (6, ) + ) + + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names['TIMESTAMP'] = sqltypes.DATE + self.colspecs = { + sqltypes.DateTime: sqltypes.DATE + } + + self.implicit_returning = self._version_two and \ + self.__dict__.get('implicit_returning', True) + + def normalize_name(self, name): + # Remove trailing spaces: FB uses a CHAR() type, + # that is padded with spaces + name = name and name.rstrip() + if name is None: + return None + elif name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper() + else: + return name + + def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring + the `schema`.""" + + tblqry = """ + SELECT 1 AS has_table FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.execute(tblqry, [self.denormalize_name(table_name)]) + return c.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 AS has_sequence FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) + """ + c = connection.execute(genqry, [self.denormalize_name(sequence_name)]) + return c.first() is not None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + # there are two queries commonly mentioned for this. + # this one, using view_blr, is at the Firebird FAQ among other places: + # http://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + + # the other query is this one. It's not clear if there's really + # any difference between these two. This link: + # http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8 + # states them as interchangeable. Some discussion at [ticket:2898] + # SELECT DISTINCT rdb$relation_name + # FROM rdb$relation_fields + # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL + + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + # see http://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is not null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=? + """ + rp = connection.execute(qry, [self.denormalize_name(view_name)]) + row = rp.first() + if row: + return row['view_source'] + else: + return None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + # Query to extract the PK/FK constrained fields of the given table + keyqry = """ + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + tablename = self.denormalize_name(table_name) + # get primary key fields + c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) + pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] + return {'constrained_columns': pkfields, 'name': None} + + @reflection.cache + def get_column_sequence(self, connection, + table_name, column_name, + schema=None, **kw): + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep + ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2 + JOIN rdb$triggers trig ON + trig.rdb$trigger_name=tabdep.rdb$dependent_name + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + genr = connection.execute(genqry, [tablename, colname]).first() + if genr is not None: + return dict(name=self.normalize_name(genr['fgenerator'])) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + # Query to extract the details of all the fields of the given table + tblqry = """ + SELECT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length/ + COALESCE(cs.rdb$bytes_per_character,1) AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, + f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t + ON t.rdb$type=f.rdb$field_type AND + t.rdb$field_name='RDB$FIELD_TYPE' + LEFT JOIN rdb$character_sets cs ON + f.rdb$character_set_id=cs.rdb$character_set_id + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # get the PK, used to determine the eventual associated sequence + pk_constraint = self.get_pk_constraint(connection, table_name) + pkey_cols = pk_constraint['constrained_columns'] + + tablename = self.denormalize_name(table_name) + # get all of the fields for this table + c = connection.execute(tblqry, [tablename]) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + name = self.normalize_name(row['fname']) + orig_colname = row['fname'] + + # get the data type + colspec = row['ftype'].rstrip() + coltype = self.ischema_names.get(colspec) + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (colspec, name)) + coltype = sqltypes.NULLTYPE + elif issubclass(coltype, Integer) and row['fprec'] != 0: + coltype = NUMERIC( + precision=row['fprec'], + scale=row['fscale'] * -1) + elif colspec in ('VARYING', 'CSTRING'): + coltype = coltype(row['flen']) + elif colspec == 'TEXT': + coltype = TEXT(row['flen']) + elif colspec == 'BLOB': + if row['stype'] == 1: + coltype = TEXT() + else: + coltype = BLOB() + else: + coltype = coltype() + + # does it have a default value? + defvalue = None + if row['fdefault'] is not None: + # the value comes down as "DEFAULT 'value'": there may be + # more than one whitespace around the "DEFAULT" keyword + # and it may also be lower case + # (see also http://tracker.firebirdsql.org/browse/CORE-356) + defexpr = row['fdefault'].lstrip() + assert defexpr[:8].rstrip().upper() == \ + 'DEFAULT', "Unrecognized default value: %s" % \ + defexpr + defvalue = defexpr[8:].strip() + if defvalue == 'NULL': + # Redundant + defvalue = None + col_d = { + 'name': name, + 'type': coltype, + 'nullable': not bool(row['null_flag']), + 'default': defvalue, + 'autoincrement': defvalue is None + } + + if orig_colname.lower() == orig_colname: + col_d['quote'] = True + + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkey_cols) == 1 and name == pkey_cols[0]: + seq_d = self.get_column_sequence(connection, tablename, name) + if seq_d is not None: + col_d['sequence'] = seq_d + + cols.append(col_d) + return cols + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the details of each UK/FK of the given table + fkqry = """ + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON + cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se + ON se.rdb$index_name=ix2.rdb$index_name + AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + tablename = self.denormalize_name(table_name) + + c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) + fks = util.defaultdict(lambda: { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': None, + 'referred_columns': [] + }) + + for row in c: + cname = self.normalize_name(row['cname']) + fk = fks[cname] + if not fk['name']: + fk['name'] = cname + fk['referred_table'] = self.normalize_name(row['targetrname']) + fk['constrained_columns'].append( + self.normalize_name(row['fname'])) + fk['referred_columns'].append( + self.normalize_name(row['targetfname'])) + return list(fks.values()) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + FROM rdb$indices ix + JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + LEFT OUTER JOIN rdb$relation_constraints + ON rdb$relation_constraints.rdb$index_name = + ic.rdb$index_name + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND rdb$relation_constraints.rdb$constraint_type IS NULL + ORDER BY index_name, ic.rdb$field_position + """ + c = connection.execute(qry, [self.denormalize_name(table_name)]) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row['index_name']] + if 'name' not in indexrec: + indexrec['name'] = self.normalize_name(row['index_name']) + indexrec['column_names'] = [] + indexrec['unique'] = bool(row['unique_flag']) + + indexrec['column_names'].append( + self.normalize_name(row['field_name'])) + + return list(indexes.values()) + diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py new file mode 100644 index 00000000..4d94ef0d --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -0,0 +1,115 @@ +# firebird/fdb.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+fdb + :name: fdb + :dbapi: pyodbc + :connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: http://pypi.python.org/pypi/fdb/ + + fdb is a kinterbasdb compatible DBAPI for Firebird. + + .. versionadded:: 0.8 - Support for the fdb Firebird driver. + + .. versionchanged:: 0.9 - The fdb dialect is now the default dialect + under the ``firebird://`` URL space, as ``fdb`` is now the official + Python driver for Firebird. + +Arguments +---------- + +The ``fdb`` dialect is based on the :mod:`sqlalchemy.dialects.firebird.kinterbasdb` +dialect, however does not accept every argument that Kinterbasdb does. + +* ``enable_rowcount`` - True by default, setting this to False disables + the usage of "cursor.rowcount" with the + Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically + after any UPDATE or DELETE statement. When disabled, SQLAlchemy's + ResultProxy will return -1 for result.rowcount. The rationale here is + that Kinterbasdb requires a second round trip to the database when + .rowcount is called - since SQLA's resultproxy automatically closes + the cursor after a non-result-returning statement, rowcount must be + called, if at all, before the result object is returned. Additionally, + cursor.rowcount may not return correct results with older versions + of Firebird, and setting this flag to False will also cause the + SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a + per-execution basis using the ``enable_rowcount`` option with + :meth:`.Connection.execution_options`:: + + conn = engine.connect().execution_options(enable_rowcount=True) + r = conn.execute(stmt) + print r.rowcount + +* ``retaining`` - False by default. Setting this to True will pass the + ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()`` + methods of the DBAPI connection, which can improve performance in some + situations, but apparently with significant caveats. + Please read the fdb and/or kinterbasdb DBAPI documentation in order to + understand the implications of this flag. + + .. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying + transaction retaining behavior - in 0.8 it defaults to ``True`` + for backwards compatibility. + + .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``. + In 0.8 it defaulted to ``True``. + + .. seealso:: + + http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions - information + on the "retaining" flag. + +""" + +from .kinterbasdb import FBDialect_kinterbasdb +from ... import util + + +class FBDialect_fdb(FBDialect_kinterbasdb): + + def __init__(self, enable_rowcount=True, + retaining=False, **kwargs): + super(FBDialect_fdb, self).__init__( + enable_rowcount=enable_rowcount, + retaining=retaining, **kwargs) + + @classmethod + def dbapi(cls): + return __import__('fdb') + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if opts.get('port'): + opts['host'] = "%s/%s" % (opts['host'], opts['port']) + del opts['port'] + opts.update(url.query) + + util.coerce_kw_type(opts, 'type_conv', int) + + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + isc_info_firebird_version = 103 + fbconn = connection.connection + + version = fbconn.db_info(isc_info_firebird_version) + + return self._parse_version_info(version) + +dialect = FBDialect_fdb diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 00000000..b8a83a07 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,179 @@ +# firebird/kinterbasdb.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+kinterbasdb + :name: kinterbasdb + :dbapi: kinterbasdb + :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: http://firebirdsql.org/index.php?op=devel&sub=python + +Arguments +---------- + +The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining`` +arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. In addition, it +also accepts the following: + +* ``type_conv`` - select the kind of mapping done on the types: by default + SQLAlchemy uses 200 with Unicode, datetime and decimal support. See + the linked documents below for further information. + +* ``concurrency_level`` - set the backend policy with regards to threading + issues: by default SQLAlchemy uses policy 1. See the linked documents + below for futher information. + +.. seealso:: + + http://sourceforge.net/projects/kinterbasdb + + http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation + + http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency + +""" + +from .base import FBDialect, FBExecutionContext +from ... import util, types as sqltypes +from re import match +import decimal + + +class _kinterbasdb_numeric(object): + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + return str(value) + else: + return value + return process + +class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric): + pass + +class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): + pass + + +class FBExecutionContext_kinterbasdb(FBExecutionContext): + @property + def rowcount(self): + if self.execution_options.get('enable_rowcount', + self.dialect.enable_rowcount): + return self.cursor.rowcount + else: + return -1 + + +class FBDialect_kinterbasdb(FBDialect): + driver = 'kinterbasdb' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + execution_ctx_cls = FBExecutionContext_kinterbasdb + + supports_native_decimal = True + + colspecs = util.update_copy( + FBDialect.colspecs, + { + sqltypes.Numeric: _FBNumeric_kinterbasdb, + sqltypes.Float: _FBFloat_kinterbasdb, + } + + ) + + def __init__(self, type_conv=200, concurrency_level=1, + enable_rowcount=True, + retaining=False, **kwargs): + super(FBDialect_kinterbasdb, self).__init__(**kwargs) + self.enable_rowcount = enable_rowcount + self.type_conv = type_conv + self.concurrency_level = concurrency_level + self.retaining = retaining + if enable_rowcount: + self.supports_sane_rowcount = True + + @classmethod + def dbapi(cls): + return __import__('kinterbasdb') + + def do_execute(self, cursor, statement, parameters, context=None): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. + cursor.execute(statement, parameters or []) + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback(self.retaining) + + def do_commit(self, dbapi_connection): + dbapi_connection.commit(self.retaining) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if opts.get('port'): + opts['host'] = "%s/%s" % (opts['host'], opts['port']) + del opts['port'] + opts.update(url.query) + + util.coerce_kw_type(opts, 'type_conv', int) + + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', + self.concurrency_level) + + if self.dbapi is not None: + initialized = getattr(self.dbapi, 'initialized', None) + if initialized is None: + # CVS rev 1.96 changed the name of the attribute: + # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/ + # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 + initialized = getattr(self.dbapi, '_initialized', False) + if not initialized: + self.dbapi.init(type_conv=type_conv, + concurrency_level=concurrency_level) + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + fbconn = connection.connection + version = fbconn.server_version + + return self._parse_version_info(version) + + def _parse_version_info(self, version): + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % version) + + if m.group(5) != None: + return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird']) + else: + return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase']) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, (self.dbapi.OperationalError, + self.dbapi.ProgrammingError)): + msg = str(e) + return ('Unable to complete network request to host' in msg or + 'Invalid connection state' in msg or + 'Invalid cursor state' in msg or + 'connection shutdown' in msg) + else: + return False + +dialect = FBDialect_kinterbasdb diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 00000000..7a2dfa60 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,26 @@ +# mssql/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \ + pymssql, zxjdbc, mxodbc + +base.dialect = pyodbc.dialect + +from sqlalchemy.dialects.mssql.base import \ + INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \ + NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\ + DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \ + BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\ + MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect + + +__all__ = ( + 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR', + 'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME', + 'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME', + 'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP', + 'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect' +) diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 00000000..95cf4242 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,79 @@ +# mssql/adodbapi.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+adodbapi + :name: adodbapi + :dbapi: adodbapi + :connectstring: mssql+adodbapi://:@ + :url: http://adodbapi.sourceforge.net/ + +.. note:: + + The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and + above at this time. + +""" +import datetime +from sqlalchemy import types as sqltypes, util +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import sys + + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect, coltype): + def process(value): + # adodbapi will return datetimes with empty time + # values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.DateTime: MSDateTime_adodbapi + } + ) + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append("Data Source=%s, %s" % + (keys.get("host"), keys.get("port"))) + else: + connectors.append("Data Source=%s" % keys.get("host")) + connectors.append("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join(connectors)], {}] + + def is_disconnect(self, e, connection, cursor): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \ + "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 00000000..522cb5ce --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1550 @@ +# mssql/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql + :name: Microsoft SQL Server + + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + from sqlalchemy import Table, Integer, Sequence, Column + + Table('test', metadata, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create(some_engine) + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +Implicit ``autoincrement`` behavior works the same in MSSQL as it +does in other dialects and results in an ``IDENTITY`` column. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT`` s) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +Character collations are supported by the base string types, +specified by the string argument "collation":: + + from sqlalchemy import VARCHAR + Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) + +When such a column is associated with a :class:`.Table`, the +CREATE TABLE statement for this column will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +.. versionadded:: 0.8 Character collations are now part of the base string + types. + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +supported directly through the ``TOP`` Transact SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +.. _mssql_indexes: + +Clustered Index Support +----------------------- + +The MSSQL dialect supports clustered indexes (and primary keys) via the +``mssql_clustered`` option. This option is available to :class:`.Index`, +:class:`.UniqueConstraint`. and :class:`.PrimaryKeyConstraint`. + +To generate a clustered index:: + + Index("my_index", table.c.x, mssql_clustered=True) + +which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``. + +.. versionadded:: 0.8 + +To generate a clustered primary key use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, PRIMARY KEY CLUSTERED (x, y)) + +Similarly, we can generate a clustered unique constraint using:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) + + .. versionadded:: 0.9.2 + +MSSQL-Specific Index Options +----------------------------- + +In addition to clustering, the MSSQL dialect supports other special options +for :class:`.Index`. + +INCLUDE +^^^^^^^ + +The ``mssql_include`` option renders INCLUDE(colname) for the given string names:: + + Index("my_index", table.c.x, mssql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +.. versionadded:: 0.8 + +Index ordering +^^^^^^^^^^^^^^ + +Index ordering is available via functional expressions, such as:: + + Index("my_index", table.c.x.desc()) + +would render the index as ``CREATE INDEX my_index ON table (x DESC)`` + +.. versionadded:: 0.8 + +.. seealso:: + + :ref:`schema_indexes_functional` + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatible with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always return the database +server version information (in this case SQL2005) and not the +compatibility level information. Because of this, if running under +a backwards compatibility mode SQAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Triggers +-------- + +SQLAlchemy by default uses OUTPUT INSERTED to get at newly +generated primary key values via IDENTITY columns or other +server side defaults. MS-SQL does not +allow the usage of OUTPUT INSERTED on tables that have triggers. +To disable the usage of OUTPUT INSERTED on a per-table basis, +specify ``implicit_returning=False`` for each :class:`.Table` +which has triggers:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + # ..., + implicit_returning=False + ) + +Declarative form:: + + class MyClass(Base): + # ... + __table_args__ = {'implicit_returning':False} + + +This option can also be specified engine-wide using the +``implicit_returning=False`` argument on :func:`.create_engine`. + +Enabling Snapshot Isolation +--------------------------- + +Not necessarily specific to SQLAlchemy, SQL Server has a default transaction +isolation mode that locks entire tables, and causes even mildly concurrent +applications to have long held locks and frequent deadlocks. +Enabling snapshot isolation for the database as a whole is recommended +for modern levels of concurrency support. This is accomplished via the +following ALTER DATABASE commands executed at the SQL prompt:: + + ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON + + ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON + +Background on SQL Server snapshot isolation is available at +http://msdn.microsoft.com/en-us/library/ms175095.aspx. + +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table +* reflection of indexes does not work with versions older than + SQL Server 2005 + +""" +import datetime +import operator +import re + +from ... import sql, schema as sa_schema, exc, util +from ...sql import compiler, expression, \ + util as sql_util, cast +from ... import engine +from ...engine import reflection, default +from ... import types as sqltypes +from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ + VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR + + +from ...util import update_wrapper +from . import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class REAL(sqltypes.REAL): + __visit_name__ = 'REAL' + + def __init__(self, **kw): + # REAL is a synonym for FLOAT(24) on SQL server + kw['precision'] = 24 + super(REAL, self).__init__(**kw) + + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, util.string_types): + return datetime.date(*[ + int(x or 0) + for x in self._reg.match(value).groups() + ]) + else: + return value + return process + + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine( + self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, util.string_types): + return datetime.time(*[ + int(x or 0) + for x in self._reg.match(value).groups()]) + else: + return value + return process +_MSTime = TIME + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kw): + super(DATETIME2, self).__init__(**kw) + self.precision = precision + + +# TODO: is this not an Interval ? +class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + super(_StringType, self).__init__(collation=collation) + + + + +class NTEXT(sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = 'IMAGE' + + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. +MSDateTime = _MSDateTime +MSDate = _MSDate +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +ischema_names = { + 'int': INTEGER, + 'bigint': BIGINT, + 'smallint': SMALLINT, + 'tinyint': TINYINT, + 'varchar': VARCHAR, + 'nvarchar': NVARCHAR, + 'char': CHAR, + 'nchar': NCHAR, + 'text': TEXT, + 'ntext': NTEXT, + 'decimal': DECIMAL, + 'numeric': NUMERIC, + 'float': FLOAT, + 'datetime': DATETIME, + 'datetime2': DATETIME2, + 'datetimeoffset': DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime': SMALLDATETIME, + 'binary': BINARY, + 'varbinary': VARBINARY, + 'bit': BIT, + 'real': REAL, + 'image': IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_, length=None): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if not length: + length = type_.length + + if length: + spec = spec + "(%s)" % length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_, length=type_.length or 'max') + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_, length=type_.length or 'max') + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_TIME(type_) + + def visit_large_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + return self._extend( + "VARBINARY", + type_, + length=type_.length or 'max') + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = \ + seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.root_connection._cursor_execute(self.cursor, + "SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl), + (), self) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + conn = self.root_connection + if self._select_lastrowid: + if self.dialect.use_scope_identity: + conn._cursor_execute(self.cursor, + "SELECT scope_identity() AS lastrowid", (), self) + else: + conn._cursor_execute(self.cursor, + "SELECT @@identity AS lastrowid", (), self) + # fetchall() ensures the cursor is consumed without closing it + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and \ + self.compiled.returning: + self._result_proxy = engine.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + conn._cursor_execute(self.cursor, + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table), + (), self) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute( + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer.\ + format_table(self.compiled.statement.table) + ) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return engine.ResultProxy(self) + + +class MSSQLCompiler(compiler.SQLCompiler): + returning_precedes_values = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super(MSSQLCompiler, self).__init__(*args, **kwargs) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "%s + %s" % \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def visit_true(self, expr, **kw): + return '1' + + def visit_false(self, expr, **kw): + return '0' + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit is not None: + s = select._distinct and "DISTINCT " or "" + + # ODBC drivers and possibly others + # don't support bind params in the SELECT clause on SQL Server. + # so have to use literal here. + if select._limit is not None: + if not select._offset: + s += "TOP %d " % select._limit + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def get_from_hint_text(self, table, text): + return text + + def get_crud_hint_text(self, table, text): + return text + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + + """ + if select._offset and not getattr(select, '_mssql_visit', None): + # to use ROW_NUMBER(), an ORDER BY is required. + if not select._order_by_clause.clauses: + raise exc.CompileError('MSSQL requires an order_by when ' + 'using an offset.') + + _offset = select._offset + _limit = select._limit + _order_by_clauses = select._order_by_clause.clauses + select = select._generate() + select._mssql_visit = True + select = select.column( + sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) + .label("mssql_rn") + ).order_by(None).alias() + + mssql_rn = sql.column('mssql_rn') + limitselect = sql.select([c for c in select.c if + c.key != 'mssql_rn']) + limitselect.append_whereclause(mssql_rn > _offset) + if _limit is not None: + limitselect.append_whereclause(mssql_rn <= (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): + if mssql_aliased is table or iscrud: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=table, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + kwargs['mssql_aliased'] = alias.original + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % \ + (field, self.process(extract.expr, **kw)) + + def visit_savepoint(self, savepoint_stmt): + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return ("ROLLBACK TRANSACTION %s" + % self.preparer.format_savepoint(savepoint_stmt)) + + def visit_column(self, column, add_to_result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error( + t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type + ) + + return super(MSSQLCompiler, self).\ + visit_column(converted, **kwargs) + + return super(MSSQLCompiler, self).visit_column( + column, add_to_result_map=add_to_result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if ( + isinstance(binary.left, expression.BindParameter) + and binary.operator == operator.eq + and not isinstance(binary.right, expression.BindParameter) + ): + return self.process( + expression.BinaryExpression(binary.right, + binary.left, + binary.operator), + **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + + columns = [ + self._label_select_column(None, adapter.traverse(c), + True, False, {}) + for c in expression._select_iterables(returning_cols) + ] + + return 'OUTPUT ' + ', '.join(columns) + + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).\ + label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which + # SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select, **kw): + order_by = self.process(select._order_by_clause, **kw) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. + + """ + return "FROM " + ', '.join( + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in [from_table] + extra_froms) + + +class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + ansi_bind_rules = True + + def visit_in_op_binary(self, binary, operator, **kw): + kw['literal_binds'] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_notin_op_binary(self, binary, operator, **kw): + kw['literal_binds'] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def render_literal_value(self, value, type_): + """ + For date and datetime values, convert to a string + format acceptable to MSSQL. That seems to be the + so-called ODBC canonical date format which looks + like this: + + yyyy-mm-dd hh:mi:ss.mmm(24h) + + For other data types, call the base class implementation. + """ + # datetime and date are both subclasses of datetime.date + if issubclass(type(value), datetime.date): + # SQL Server wants single quotes around the date string. + return "'" + str(value) + "'" + else: + return super(MSSQLStrictCompiler, self).\ + render_literal_value(value, type_) + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = (self.preparer.format_column(column) + " " + + self.dialect.type_compiler.process(column.type)) + + if column.nullable is not None: + if not column.nullable or column.primary_key or \ + isinstance(column.default, sa_schema.Sequence): + colspec += " NOT NULL" + else: + colspec += " NULL" + + if column.table is None: + raise exc.CompileError( + "mssql requires Table-bound columns " + "in order to generate DDL") + + # install an IDENTITY Sequence if we either a sequence or an implicit IDENTITY column + if isinstance(column.default, sa_schema.Sequence): + if column.default.start == 0: + start = 0 + else: + start = column.default.start or 1 + + colspec += " IDENTITY(%s,%s)" % (start, column.default.increment or 1) + elif column is column.table._autoincrement_column: + colspec += " IDENTITY(1,1)" + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_create_index(self, create, include_schema=False): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + # handle clustering option + if index.dialect_options['mssql']['clustered']: + text += "CLUSTERED " + + text += "INDEX %s ON %s (%s)" \ + % ( + self._prepared_index_name(index, + include_schema=include_schema), + preparer.format_table(index.table), + ', '.join( + self.sql_compiler.process(expr, + include_table=False, literal_binds=True) for + expr in index.expressions) + ) + + # handle other included columns + if index.dialect_options['mssql']['include']: + inclusions = [index.table.c[col] + if isinstance(col, util.string_types) else col + for col in index.dialect_options['mssql']['include']] + + text += " INCLUDE (%s)" \ + % ', '.join([preparer.quote(c.name) + for c in inclusions]) + + return text + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s ON %s" % ( + self._prepared_index_name(drop.element, include_schema=False), + self.preparer.format_table(drop.element.table) + ) + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return '' + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + + if constraint.dialect_options['mssql']['clustered']: + text += "CLUSTERED " + + text += "(%s)" % ', '.join(self.preparer.quote(c.name) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_unique_constraint(self, constraint): + if len(constraint) == 0: + return '' + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "UNIQUE " + + if constraint.dialect_options['mssql']['clustered']: + text += "CLUSTERED " + + text += "(%s)" % ', '.join(self.preparer.quote(c.name) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', + final_quote=']') + + def _escape_identifier(self, value): + return value + + def quote_schema(self, schema, force=None): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + + +def _db_plus_owner_listing(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db(dbname, connection, fn, dialect, connection, + dbname, owner, schema, **kw) + return update_wrapper(wrap, fn) + + +def _db_plus_owner(fn): + def wrap(dialect, connection, tablename, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db(dbname, connection, fn, dialect, connection, + tablename, dbname, owner, schema, **kw) + return update_wrapper(wrap, fn) + + +def _switch_db(dbname, connection, fn, *arg, **kw): + if dbname: + current_db = connection.scalar("select db_name()") + connection.execute("use %s" % dbname) + try: + return fn(*arg, **kw) + finally: + if dbname: + connection.execute("use %s" % current_db) + + +def _owner_plus_db(dialect, schema): + if not schema: + return None, dialect.default_schema_name + elif "." in schema: + return schema.split(".", 1) + else: + return None, schema + + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + + colspecs = { + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate, + sqltypes.Time: TIME, + } + + ischema_names = ischema_names + + supports_native_boolean = False + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + construct_arguments = [ + (sa_schema.PrimaryKeyConstraint, { + "clustered": False + }), + (sa_schema.UniqueConstraint, { + "clustered": False + }), + (sa_schema.Index, { + "clustered": False, + "include": None + }) + ] + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + # give the DBAPI a push + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + super(MSDialect, self).do_savepoint(connection, name) + + def do_release_savepoint(self, connection, name): + # SQL Server does not support RELEASE SAVEPOINT + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info[0] not in list(range(8, 17)): + # FreeTDS with version 4.2 seems to report here + # a number like "95.10.255". Don't know what + # that is. So emit warning. + util.warn( + "Unrecognized server version info '%s'. Version specific " + "behaviors may not function properly. If using ODBC " + "with FreeTDS, ensure server version 7.0 or 8.0, not 4.2, " + "is configured in the FreeTDS configuration." % + ".".join(str(x) for x in self.server_version_info)) + if self.server_version_info >= MS_2005_VERSION and \ + 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name()") + if user_name is not None: + # now, get the default schema + query = sql.text(""" + SELECT default_schema_name FROM + sys.database_principals + WHERE name = :name + AND type = 'S' + """) + try: + default_schema_name = connection.scalar(query, name=user_name) + if default_schema_name is not None: + return util.text_type(default_schema_name) + except: + pass + return self.schema_name + + @_db_plus_owner + def has_table(self, connection, tablename, dbname, owner, schema): + columns = ischema.columns + + whereclause = columns.c.table_name == tablename + + if owner: + whereclause = sql.and_(whereclause, + columns.c.table_schema == owner) + s = sql.select([columns], whereclause) + c = connection.execute(s) + return c.first() is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + @_db_plus_owner_listing + def get_table_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == 'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + @_db_plus_owner_listing + def get_view_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == 'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + @reflection.cache + @_db_plus_owner + def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): + # using system catalogs, don't support index reflection + # below MS 2005 + if self.server_version_info < MS_2005_VERSION: + return [] + + rp = connection.execute( + sql.text("select ind.index_id, ind.is_unique, ind.name " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0", + bindparams=[ + sql.bindparam('tabname', tablename, + sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', owner, + sqltypes.String(convert_unicode=True)) + ], + typemap={ + 'name': sqltypes.Unicode() + } + ) + ) + indexes = {} + for row in rp: + indexes[row['index_id']] = { + 'name': row['name'], + 'unique': row['is_unique'] == 1, + 'column_names': [] + } + rp = connection.execute( + sql.text( + "select ind_col.index_id, ind_col.object_id, col.name " + "from sys.columns as col " + "join sys.tables as tab on tab.object_id=col.object_id " + "join sys.index_columns as ind_col on " + "(ind_col.column_id=col.column_id and " + "ind_col.object_id=tab.object_id) " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name=:tabname " + "and sch.name=:schname", + bindparams=[ + sql.bindparam('tabname', tablename, + sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', owner, + sqltypes.String(convert_unicode=True)) + ], + typemap={'name': sqltypes.Unicode()} + ), + ) + for row in rp: + if row['index_id'] in indexes: + indexes[row['index_id']]['column_names'].append(row['name']) + + return list(indexes.values()) + + @reflection.cache + @_db_plus_owner + def get_view_definition(self, connection, viewname, dbname, owner, schema, **kw): + rp = connection.execute( + sql.text( + "select definition from sys.sql_modules as mod, " + "sys.views as views, " + "sys.schemas as sch" + " where " + "mod.object_id=views.object_id and " + "views.schema_id=sch.schema_id and " + "views.name=:viewname and sch.name=:schname", + bindparams=[ + sql.bindparam('viewname', viewname, + sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', owner, + sqltypes.String(convert_unicode=True)) + ] + ) + ) + + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + @_db_plus_owner + def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + # Get base columns + columns = ischema.columns + if owner: + whereclause = sql.and_(columns.c.table_name == tablename, + columns.c.table_schema == owner) + else: + whereclause = columns.c.table_name == tablename + s = sql.select([columns], whereclause, + order_by=[columns.c.ordinal_position]) + + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, + numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, + MSNText, MSBinary, MSVarBinary, + sqltypes.LargeBinary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or \ + (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" % + (type, name)) + coltype = sqltypes.NULLTYPE + else: + if issubclass(coltype, sqltypes.Numeric) and \ + coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name': name, + 'type': coltype, + 'nullable': nullable, + 'default': default, + 'autoincrement': False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', " + "@table_owner = '%s'" + % (tablename, owner)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + + if ic is not None and self.server_version_info >= MS_2005_VERSION: + table_fullname = "%s.%s" % (owner, tablename) + cursor = connection.execute( + "select ident_seed('%s'), ident_incr('%s')" + % (table_fullname, table_fullname) + ) + + row = cursor.first() + if row is not None and row[0] is not None: + colmap[ic]['sequence'].update({ + 'start': int(row[0]), + 'increment': int(row[1]) + }) + return cols + + @reflection.cache + @_db_plus_owner + def get_pk_constraint(self, connection, tablename, dbname, owner, schema, **kw): + pkeys = [] + TC = ischema.constraints + C = ischema.key_constraints.alias('C') + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type, C.c.constraint_name], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner) + ) + c = connection.execute(s) + constraint_name = None + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + if constraint_name is None: + constraint_name = row[C.c.constraint_name.name] + return {'constrained_columns': pkeys, 'name': constraint_name} + + @reflection.cache + @_db_plus_owner + def get_foreign_keys(self, connection, tablename, dbname, owner, schema, **kw): + RR = ischema.ref_constraints + C = ischema.key_constraints.alias('C') + R = ischema.key_constraints.alias('R') + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, + RR.c.update_rule, + RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == owner, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == + RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by=[RR.c.constraint_name, R.c.ordinal_position] + ) + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': None, + 'referred_columns': [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + if schema is not None or owner != rschema: + if dbname: + rschema = dbname + "." + rschema + rec['referred_schema'] = rschema + + local_cols, remote_cols = \ + rec['constrained_columns'],\ + rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return list(fkeys.values()) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 00000000..26e70f7f --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,114 @@ +# mssql/information_schema.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +# TODO: should be using the sys. catalog with SQL Server, not information schema + +from ... import Table, MetaData, Column +from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator +from ... import cast +from ... import util +from ...sql import expression +from ...ext.compiler import compiles + +ischema = MetaData() + +class CoerceUnicode(TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + if util.py2k and isinstance(value, util.binary_type): + value = value.decode(dialect.encoding) + return value + + def bind_expression(self, bindvalue): + return _cast_on_2005(bindvalue) + +class _cast_on_2005(expression.ColumnElement): + def __init__(self, bindvalue): + self.bindvalue = bindvalue + +@compiles(_cast_on_2005) +def _compile(element, compiler, **kw): + from . import base + if compiler.dialect.server_version_info < base.MS_2005_VERSION: + return compiler.process(element.bindvalue, **kw) + else: + return compiler.process(cast(element.bindvalue, Unicode), **kw) + +schemata = Table("SCHEMATA", ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") + +tables = Table("TABLES", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), + schema="INFORMATION_SCHEMA") + +columns = Table("COLUMNS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") + +constraints = Table("TABLE_CONSTRAINTS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"), + schema="INFORMATION_SCHEMA") + +column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") + +key_constraints = Table("KEY_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") + +ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, + key="unique_constraint_catalog"), + + Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, + key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, + key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") + +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py new file mode 100644 index 00000000..5b686c47 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -0,0 +1,111 @@ +# mssql/mxodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+mxodbc + :name: mxODBC + :dbapi: mxodbc + :connectstring: mssql+mxodbc://:@ + :url: http://www.egenix.com/ + +Execution Modes +--------------- + +mxODBC features two styles of statement execution, using the +``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being +an extension to the DBAPI specification). The former makes use of a particular +API call specific to the SQL Server Native Client ODBC driver known +SQLDescribeParam, while the latter does not. + +mxODBC apparently only makes repeated use of a single prepared statement +when SQLDescribeParam is used. The advantage to prepared statement reuse is +one of performance. The disadvantage is that SQLDescribeParam has a limited +set of scenarios in which bind parameters are understood, including that they +cannot be placed within the argument lists of function calls, anywhere outside +the FROM, or even within subqueries within the FROM clause - making the usage +of bind parameters within SELECT statements impossible for all but the most +simplistic statements. + +For this reason, the mxODBC dialect uses the "native" mode by default only for +INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for +all other statements. + +This behavior can be controlled via +:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the +``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a +value of ``True`` will unconditionally use native bind parameters and a value +of ``False`` will unconditionally use string-escaped parameters. + +""" + + +from ... import types as sqltypes +from ...connectors.mxodbc import MxODBCConnector +from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc +from .base import (MSDialect, + MSSQLStrictCompiler, + _MSDateTime, _MSDate, _MSTime) + + +class _MSNumeric_mxodbc(_MSNumeric_pyodbc): + """Include pyodbc's numeric processor. + """ + + +class _MSDate_mxodbc(_MSDate): + def bind_processor(self, dialect): + def process(value): + if value is not None: + return "%s-%s-%s" % (value.year, value.month, value.day) + else: + return None + return process + + +class _MSTime_mxodbc(_MSTime): + def bind_processor(self, dialect): + def process(value): + if value is not None: + return "%s:%s:%s" % (value.hour, value.minute, value.second) + else: + return None + return process + + +class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): + """ + The pyodbc execution context is useful for enabling + SELECT SCOPE_IDENTITY in cases where OUTPUT clause + does not work (tables with insert triggers). + """ + #todo - investigate whether the pyodbc execution context + # is really only being used in cases where OUTPUT + # won't work. + + +class MSDialect_mxodbc(MxODBCConnector, MSDialect): + + # this is only needed if "native ODBC" mode is used, + # which is now disabled by default. + #statement_compiler = MSSQLStrictCompiler + + execution_ctx_cls = MSExecutionContext_mxodbc + + # flag used by _MSNumeric_mxodbc + _need_decimal_fix = True + + colspecs = { + sqltypes.Numeric: _MSNumeric_mxodbc, + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate_mxodbc, + sqltypes.Time: _MSTime_mxodbc, + } + + def __init__(self, description_encoding=None, **params): + super(MSDialect_mxodbc, self).__init__(**params) + self.description_encoding = description_encoding + +dialect = MSDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 00000000..0182fee1 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,92 @@ +# mssql/pymssql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+pymssql + :name: pymssql + :dbapi: pymssql + :connectstring: mssql+pymssql://:@?charset=utf8 + :url: http://pymssql.org/ + +pymssql is a Python module that provides a Python DBAPI interface around +`FreeTDS `_. Compatible builds are available for +Linux, MacOSX and Windows platforms. + +""" +from .base import MSDialect +from ... import types as sqltypes, util, processors +import re + + +class _MSNumeric_pymssql(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + driver = 'pymssql' + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric: _MSNumeric_pymssql, + sqltypes.Float: sqltypes.Float, + } + ) + + @classmethod + def dbapi(cls): + module = __import__('pymssql') + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = lambda x: x if hasattr(x, 'decode') else str(x) + + client_ver = tuple(int(x) for x in module.__version__.split(".")) + if client_ver < (1, ): + util.warn("The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI.") + return module + + def __init__(self, **params): + super(MSDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + def _get_server_version_info(self, connection): + vers = connection.scalar("select @@version") + m = re.match( + r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers) + if m: + return tuple(int(x) for x in m.group(1, 2, 3, 4)) + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + port = opts.pop('port', None) + if port and 'host' in opts: + opts['host'] = "%s:%s" % (opts['host'], port) + return [[], opts] + + def is_disconnect(self, e, connection, cursor): + for msg in ( + "Adaptive Server connection timed out", + "Net-Lib error during Connection reset by peer", + "message 20003", # connection timeout + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed" + ): + if msg in str(e): + return True + else: + return False + +dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 00000000..8c43eb8a --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,260 @@ +# mssql/pyodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mssql+pyodbc://:@ + :url: http://pypi.python.org/pypi/pyodbc/ + +Additional Connection Examples +------------------------------- + +Examples of pyodbc connection string URLs: + +* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;Trusted_Connection=Yes + +* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information. The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection + that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection + string which includes the port + information using the comma syntax. This will create the following + connection string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection + string that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +* ``mssql+pyodbc://user:pass@host/db?driver=MyDriver`` - connects using a connection + string that includes a custom + ODBC driver name. This will create the following connection string:: + + DRIVER={MyDriver};Server=host;Database=db;UID=user;PWD=pass + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + +Unicode Binds +------------- + +The current state of PyODBC on a unix backend with FreeTDS and/or +EasySoft is poor regarding unicode; different OS platforms and versions of UnixODBC +versus IODBC versus FreeTDS/EasySoft versus PyODBC itself dramatically +alter how strings are received. The PyODBC dialect attempts to use all the information +it knows to determine whether or not a Python unicode literal can be +passed directly to the PyODBC driver or not; while SQLAlchemy can encode +these to bytestrings first, some users have reported that PyODBC mis-handles +bytestrings for certain encodings and requires a Python unicode object, +while the author has observed widespread cases where a Python unicode +is completely misinterpreted by PyODBC, particularly when dealing with +the information schema tables used in table reflection, and the value +must first be encoded to a bytestring. + +It is for this reason that whether or not unicode literals for bound +parameters be sent to PyODBC can be controlled using the +``supports_unicode_binds`` parameter to ``create_engine()``. When +left at its default of ``None``, the PyODBC dialect will use its +best guess as to whether or not the driver deals with unicode literals +well. When ``False``, unicode literals will be encoded first, and when +``True`` unicode literals will be passed straight through. This is an interim +flag that hopefully should not be needed when the unicode situation stabilizes +for unix + PyODBC. + +.. versionadded:: 0.7.7 + ``supports_unicode_binds`` parameter to ``create_engine()``\ . + +""" + +from .base import MSExecutionContext, MSDialect +from ...connectors.pyodbc import PyODBCConnector +from ... import types as sqltypes, util +import decimal + +class _ms_numeric_pyodbc(object): + + """Turns Decimals with adjusted() < 0 or > 7 into strings. + + The routines here are needed for older pyodbc versions + as well as current mxODBC versions. + + """ + + def bind_processor(self, dialect): + + super_process = super(_ms_numeric_pyodbc, self).\ + bind_processor(dialect) + + if not dialect._need_decimal_fix: + return super_process + + def process(value): + if self.asdecimal and \ + isinstance(value, decimal.Decimal): + + adjusted = value.adjusted() + if adjusted < 0: + return self._small_dec_to_string(value) + elif adjusted > 7: + return self._large_dec_to_string(value) + + if super_process: + return super_process(value) + else: + return value + return process + + # these routines needed for older versions of pyodbc. + # as of 2.1.8 this logic is integrated. + + def _small_dec_to_string(self, value): + return "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]])) + + def _large_dec_to_string(self, value): + _int = value.as_tuple()[1] + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int) - 1))) + else: + if (len(_int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join( + [str(s) for s in _int][0:value.adjusted() + 1]), + "".join( + [str(s) for s in _int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join( + [str(s) for s in _int][0:value.adjusted() + 1])) + return result + +class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric): + pass + +class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float): + pass + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same + statement. + + Background on why "scope_identity()" is preferable to "@@identity": + http://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an + # "INSERT .. DEFAULT VALUES" + if self._select_lastrowid and \ + self.dialect.use_scope_identity and \ + len(self.parameters[0]): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with + # no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed + # without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error as e: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric: _MSNumeric_pyodbc, + sqltypes.Float: _MSFloat_pyodbc + } + ) + + def __init__(self, description_encoding=None, **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.use_scope_identity and \ + self.dbapi and \ + hasattr(self.dbapi.Cursor, 'nextset') + self._need_decimal_fix = self.dbapi and \ + self._dbapi_version() < (2, 1, 8) + +dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py new file mode 100644 index 00000000..706eef3a --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py @@ -0,0 +1,65 @@ +# mssql/zxjdbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+zxjdbc + :name: zxJDBC for Jython + :dbapi: zxjdbc + :connectstring: mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...] + :driverurl: http://jtds.sourceforge.net/ + + +""" +from ...connectors.zxJDBC import ZxJDBCConnector +from .base import MSDialect, MSExecutionContext +from ... import engine + + +class MSExecutionContext_zxjdbc(MSExecutionContext): + + _embedded_scope_identity = False + + def pre_exec(self): + super(MSExecutionContext_zxjdbc, self).pre_exec() + # scope_identity after the fact returns null in jTDS so we must + # embed it + if self._select_lastrowid and self.dialect.use_scope_identity: + self._embedded_scope_identity = True + self.statement += "; SELECT scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + while True: + try: + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error: + self.cursor.nextset() + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and \ + self.compiled.returning: + self._result_proxy = engine.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + table = self.dialect.identifier_preparer.format_table( + self.compiled.statement.table) + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) + + +class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect): + jdbc_db_name = 'jtds:sqlserver' + jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver' + + execution_ctx_cls = MSExecutionContext_zxjdbc + + def _get_server_version_info(self, connection): + return tuple( + int(x) + for x in connection.connection.dbversion.split('.') + ) + +dialect = MSDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 00000000..4eb8cc6d --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,28 @@ +# mysql/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import base, mysqldb, oursql, \ + pyodbc, zxjdbc, mysqlconnector, pymysql,\ + gaerdbms, cymysql + +# default dialect +base.dialect = mysqldb.dialect + +from .base import \ + BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \ + DECIMAL, DOUBLE, ENUM, DECIMAL,\ + FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, \ + MEDIUMINT, MEDIUMTEXT, NCHAR, \ + NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \ + TINYBLOB, TINYINT, TINYTEXT,\ + VARBINARY, VARCHAR, YEAR, dialect + +__all__ = ( +'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', +'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', +'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', +'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect' +) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py new file mode 100644 index 00000000..ba6e7b62 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -0,0 +1,3078 @@ +# mysql/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql + :name: MySQL + +Supported Versions and Features +------------------------------- + +SQLAlchemy supports MySQL starting with version 4.1 through modern releases. +However, no heroic measures are taken to work around major missing +SQL features - if your server version does not support sub-selects, for +example, they won't work in SQLAlchemy either. + +See the official MySQL documentation for detailed information about features +supported in any given server release. + +.. _mysql_connection_timeouts: + +Connection Timeouts +------------------- + +MySQL features an automatic connection close behavior, for connections that have +been idle for eight hours or more. To circumvent having this issue, use the +``pool_recycle`` option which controls the maximum age of any connection:: + + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + +.. _mysql_storage_engines: + +CREATE TABLE arguments including Storage Engines +------------------------------------------------ + +MySQL's CREATE TABLE syntax includes a wide array of special options, +including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``, ``INSERT_METHOD``, and many more. +To accommodate the rendering of these arguments, specify the form +``mysql_argument_name="value"``. For example, to specify a table with +``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8``, and ``KEY_BLOCK_SIZE`` of ``1024``:: + + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8', + mysql_key_block_size="1024" + ) + +The MySQL dialect will normally transfer any keyword specified as ``mysql_keyword_name`` +to be rendered as ``KEYWORD_NAME`` in the ``CREATE TABLE`` statement. A handful +of these names will render with a space instead of an underscore; to support this, +the MySQL dialect has awareness of these particular names, which include +``DATA DIRECTORY`` (e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g. +``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g. ``mysql_index_directory``). + +The most common argument is ``mysql_engine``, which refers to the storage engine +for the table. Historically, MySQL server installations would default +to ``MyISAM`` for this value, although newer versions may be defaulting +to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support +of transactions and foreign keys. + +A :class:`.Table` that is created in a MySQL database with a storage engine +of ``MyISAM`` will be essentially non-transactional, meaning any INSERT/UPDATE/DELETE +statement referring to this table will be invoked as autocommit. It also will have no +support for foreign key constraints; while the ``CREATE TABLE`` statement +accepts foreign key options, when using the ``MyISAM`` storage engine these +arguments are discarded. Reflecting such a table will also produce no +foreign key constraint information. + +For fully atomic transactions as well as support for foreign key +constraints, all participating ``CREATE TABLE`` statements must specify a +transactional engine, which in the vast majority of cases is ``InnoDB``. + +.. seealso:: + + `The InnoDB Storage Engine + `_ - + on the MySQL website. + +Case Sensitivity and Table Reflection +------------------------------------- + +MySQL has inconsistent support for case-sensitive identifier +names, basing support on specific details of the underlying +operating system. However, it has been observed that no matter +what case sensitivity behavior is present, the names of tables in +foreign key declarations are *always* received from the database +as all-lower case, making it impossible to accurately reflect a +schema where inter-related tables use mixed-case identifier names. + +Therefore it is strongly advised that table names be declared as +all lower case both within SQLAlchemy as well as on the MySQL +database itself, especially if database reflection features are +to be used. + +Transaction Isolation Level +--------------------------- + +:func:`.create_engine` accepts an ``isolation_level`` +parameter which results in the command ``SET SESSION +TRANSACTION ISOLATION LEVEL `` being invoked for +every new connection. Valid values for this parameter are +``READ COMMITTED``, ``READ UNCOMMITTED``, +``REPEATABLE READ``, and ``SERIALIZABLE``:: + + engine = create_engine( + "mysql://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +.. versionadded:: 0.7.6 + +AUTO_INCREMENT Behavior +----------------------- + +When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on +the first :class:`.Integer` primary key column which is not marked as a foreign key:: + + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) + ... ) + >>> t.create() + CREATE TABLE mytable ( + id INTEGER NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) + +You can disable this behavior by passing ``False`` to the :paramref:`~.Column.autoincrement` +argument of :class:`.Column`. This flag can also be used to enable +auto-increment on a secondary column in a multi-column key for some storage +engines:: + + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) + +Ansi Quoting Style +------------------ + +MySQL features two varieties of identifier "quoting style", one using +backticks and the other using quotes, e.g. ```some_identifier``` vs. +``"some_identifier"``. All MySQL dialects detect which version +is in use by checking the value of ``sql_mode`` when a connection is first +established with a particular :class:`.Engine`. This quoting style comes +into play when rendering table and column names as well as when reflecting +existing database structures. The detection is entirely automatic and +no special configuration is needed to use either quoting style. + +.. versionchanged:: 0.6 detection of ANSI quoting style is entirely automatic, + there's no longer any end-user ``create_engine()`` options in this regard. + +MySQL SQL Extensions +-------------------- + +Many of the MySQL SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid MySQL statement can be executed as a string as well. + +Some limited direct support for MySQL extensions to SQL is currently +available. + +* SELECT pragma:: + + select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + +* UPDATE with LIMIT:: + + update(..., mysql_limit=10) + +rowcount Support +---------------- + +SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the +usual definition of "number of rows matched by an UPDATE or DELETE" statement. +This is in contradiction to the default setting on most MySQL DBAPI drivers, +which is "number of rows actually modified/deleted". For this reason, the +SQLAlchemy MySQL dialects always set the ``constants.CLIENT.FOUND_ROWS`` flag, +or whatever is equivalent for the DBAPI in use, on connect, unless the flag value +is overridden using DBAPI-specific options +(such as ``client_flag`` for the MySQL-Python driver, ``found_rows`` for the +OurSQL driver). + +See also: + +:attr:`.ResultProxy.rowcount` + + +CAST Support +------------ + +MySQL documents the CAST operator as available in version 4.0.2. When using the +SQLAlchemy :func:`.cast` function, SQLAlchemy +will not render the CAST token on MySQL before this version, based on server version +detection, instead rendering the internal expression directly. + +CAST may still not be desirable on an early MySQL version post-4.0.2, as it didn't +add all datatype support until 4.1.1. If your application falls into this +narrow area, the behavior of CAST can be controlled using the +:ref:`sqlalchemy.ext.compiler_toplevel` system, as per the recipe below:: + + from sqlalchemy.sql.expression import Cast + from sqlalchemy.ext.compiler import compiles + + @compiles(Cast, 'mysql') + def _check_mysql_version(element, compiler, **kw): + if compiler.dialect.server_version_info < (4, 1, 0): + return compiler.process(element.clause, **kw) + else: + return compiler.visit_cast(element, **kw) + +The above function, which only needs to be declared once +within an application, overrides the compilation of the +:func:`.cast` construct to check for version 4.1.0 before +fully rendering CAST; else the internal element of the +construct is rendered directly. + + +.. _mysql_indexes: + +MySQL Specific Index Options +---------------------------- + +MySQL-specific extensions to the :class:`.Index` construct are available. + +Index Length +~~~~~~~~~~~~~ + +MySQL provides an option to create index entries with a certain length, where +"length" refers to the number of characters or bytes in each value which will +become part of the index. SQLAlchemy provides this feature via the +``mysql_length`` parameter:: + + Index('my_index', my_table.c.data, mysql_length=10) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, 'b': 9}) + +Prefix lengths are given in characters for nonbinary string types and in bytes +for binary string types. The value passed to the keyword argument *must* be +either an integer (and, thus, specify the same prefix length value for all +columns of the index) or a dict in which keys are column names and values are +prefix length values for corresponding columns. MySQL only allows a length for +a column of an index if it is for a CHAR, VARCHAR, TEXT, BINARY, VARBINARY and +BLOB. + +.. versionadded:: 0.8.2 ``mysql_length`` may now be specified as a dictionary + for use with composite indexes. + +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. + +More information can be found at: + +http://dev.mysql.com/doc/refman/5.0/en/create-index.html + +http://dev.mysql.com/doc/refman/5.0/en/create-table.html + +.. _mysql_foreign_keys: + +MySQL Foreign Keys +------------------ + +MySQL's behavior regarding foreign keys has some important caveats. + +Foreign Key Arguments to Avoid +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL does not support the foreign key arguments "DEFERRABLE", "INITIALLY", +or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with +:class:`.ForeignKeyConstraint` or :class:`.ForeignKey` will have the effect of these keywords being +rendered in a DDL expression, which will then raise an error on MySQL. +In order to use these keywords on a foreign key while having them ignored +on a MySQL backend, use a custom compile rule:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.schema import ForeignKeyConstraint + + @compiles(ForeignKeyConstraint, "mysql") + def process(element, compiler, **kw): + element.deferrable = element.initially = None + return compiler.visit_foreign_key_constraint(element, **kw) + +.. versionchanged:: 0.9.0 - the MySQL backend no longer silently ignores + the ``deferrable`` or ``initially`` keyword arguments of :class:`.ForeignKeyConstraint` + and :class:`.ForeignKey`. + +The "MATCH" keyword is in fact more insidious, and is explicitly disallowed +by SQLAlchemy in conjunction with the MySQL backend. This argument is silently +ignored by MySQL, but in addition has the effect of ON UPDATE and ON DELETE options +also being ignored by the backend. Therefore MATCH should never be used with the +MySQL backend; as is the case with DEFERRABLE and INITIALLY, custom compilation +rules can be used to correct a MySQL ForeignKeyConstraint at DDL definition time. + +.. versionadded:: 0.9.0 - the MySQL backend will raise a :class:`.CompileError` + when the ``match`` keyword is used with :class:`.ForeignKeyConstraint` + or :class:`.ForeignKey`. + +Reflection of Foreign Key Constraints +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Not all MySQL storage engines support foreign keys. When using the +very common ``MyISAM`` MySQL storage engine, the information loaded by table +reflection will not include foreign keys. For these tables, you may supply a +:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: + + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload=True + ) + +.. seealso:: + + :ref:`mysql_storage_engines` + +""" + +import datetime +import re +import sys + +from ... import schema as sa_schema +from ... import exc, log, sql, util +from ...sql import compiler +from array import array as _array + +from ...engine import reflection +from ...engine import default +from ... import types as sqltypes +from ...util import topological +from ...types import DATE, BOOLEAN, \ + BLOB, BINARY, VARBINARY + +RESERVED_WORDS = set( + ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', + 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', + 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', + 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', + 'current_user', 'cursor', 'database', 'databases', 'day_hour', + 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', + 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', + 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', + 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', + 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', + 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having', + 'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if', + 'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive', + 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', 'integer', + 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', 'kill', + 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', 'load', + 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', 'longtext', + 'loop', 'low_priority', 'master_ssl_verify_server_cert', 'match', + 'mediumblob', 'mediumint', 'mediumtext', 'middleint', + 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', + 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', + 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', + 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', + 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', + 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', + 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', + 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', + 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', + 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', + 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', + 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', + 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', + 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', + 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', + + 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 + + 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 + + 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', + 'read_only', 'read_write', # 5.1 + + 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', + 'resignal', 'signal', 'slow', # 5.5 + + 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', + 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 + + ]) + +AUTOCOMMIT_RE = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', + re.I | re.UNICODE) +SET_RE = re.compile( + r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', + re.I | re.UNICODE) + + +class _NumericType(object): + """Base for MySQL numeric types. + + This is the base both for NUMERIC as well as INTEGER, hence + it's a mixin. + + """ + + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill + super(_NumericType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[_NumericType, sqltypes.Numeric]) + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and \ + ( + (precision is None and scale is not None) or + (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") + super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[_FloatType, _NumericType, sqltypes.Float]) + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super(_IntegerType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]) + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__(self, charset=None, collation=None, + ascii=False, binary=False, unicode=False, + national=False, **kw): + self.charset = charset + + # allow collate= or collation= + kw.setdefault('collation', kw.pop('collate', collation)) + + self.ascii = ascii + self.unicode = unicode + self.binary = binary + self.national = national + super(_StringType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[_StringType, sqltypes.String]) + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(NUMERIC, self).__init__(precision=precision, + scale=scale, asdecimal=asdecimal, **kw) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DECIMAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class DOUBLE(_FloatType): + """MySQL DOUBLE type.""" + + __visit_name__ = 'DOUBLE' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + .. note:: + + The :class:`.DOUBLE` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. Specify + either ``scale=n`` or ``decimal_return_scale=n`` in order to change + this scale, or ``asdecimal=False`` to return values directly as + Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DOUBLE, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class REAL(_FloatType, sqltypes.REAL): + """MySQL REAL type.""" + + __visit_name__ = 'REAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + .. note:: + + The :class:`.REAL` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. Specify + either ``scale=n`` or ``decimal_return_scale=n`` in order to change + this scale, or ``asdecimal=False`` to return values directly as + Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(REAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = 'FLOAT' + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(FLOAT, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) + + def bind_processor(self, dialect): + return None + + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = 'INTEGER' + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(INTEGER, self).__init__(display_width=display_width, **kw) + + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = 'BIGINT' + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(BIGINT, self).__init__(display_width=display_width, **kw) + + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = 'MEDIUMINT' + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = 'TINYINT' + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(TINYINT, self).__init__(display_width=display_width, **kw) + + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = 'SMALLINT' + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(SMALLINT, self).__init__(display_width=display_width, **kw) + + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for + MyISAM, MEMORY, InnoDB and BDB. For older versions, use a MSTinyInteger() + type. + + """ + + __visit_name__ = 'BIT' + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. + + """ + + def process(value): + if value is not None: + v = 0 + for i in map(ord, value): + v = v << 8 | i + return v + return value + return process + + +class TIME(sqltypes.TIME): + """MySQL TIME type. """ + + __visit_name__ = 'TIME' + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + .. versionadded:: 0.8 The MySQL-specific TIME + type as well as fractional seconds support. + + """ + super(TIME, self).__init__(timezone=timezone) + self.fsp = fsp + + def result_processor(self, dialect, coltype): + time = datetime.time + + def process(value): + # convert from a timedelta value + if value is not None: + microseconds = value.microseconds + seconds = value.seconds + minutes = seconds // 60 + return time(minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds) + else: + return None + return process + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type. + + """ + + __visit_name__ = 'TIMESTAMP' + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIMESTAMP type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIMESTAMP type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + .. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.TIMESTAMP` + with fractional seconds support. + + """ + super(TIMESTAMP, self).__init__(timezone=timezone) + self.fsp = fsp + + +class DATETIME(sqltypes.DATETIME): + """MySQL DATETIME type. + + """ + + __visit_name__ = 'DATETIME' + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL DATETIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the DATETIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + .. versionadded:: 0.8.5 Added MySQL-specific :class:`.mysql.DATETIME` + with fractional seconds support. + + """ + super(DATETIME, self).__init__(timezone=timezone) + self.fsp = fsp + + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = 'YEAR' + + def __init__(self, display_width=None): + self.display_width = display_width + + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = 'TEXT' + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TEXT, self).__init__(length=length, **kw) + + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for text up to 2^8 characters.""" + + __visit_name__ = 'TINYTEXT' + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TINYTEXT, self).__init__(**kwargs) + + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + + __visit_name__ = 'MEDIUMTEXT' + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(MEDIUMTEXT, self).__init__(**kwargs) + + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for text up to 2^32 characters.""" + + __visit_name__ = 'LONGTEXT' + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(LONGTEXT, self).__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = 'VARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = 'CHAR' + + def __init__(self, length=None, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super(CHAR, self).__init__(length=length, **kwargs) + + @classmethod + def _adapt_string_for_cast(self, type_): + # copy the given string type into a CHAR + # for the purposes of rendering a CAST expression + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.CHAR): + return type_ + elif isinstance(type_, _StringType): + return CHAR( + length=type_.length, + charset=type_.charset, + collation=type_.collation, + ascii=type_.ascii, + binary=type_.binary, + unicode=type_.unicode, + national=False # not supported in CAST + ) + else: + return CHAR(length=type_.length) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NVARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NVARCHAR, self).__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NCHAR, self).__init__(length=length, **kwargs) + + +class TINYBLOB(sqltypes._Binary): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = 'TINYBLOB' + + +class MEDIUMBLOB(sqltypes._Binary): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = 'MEDIUMBLOB' + + +class LONGBLOB(sqltypes._Binary): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = 'LONGBLOB' + +class _EnumeratedValues(_StringType): + def _init_values(self, values, kw): + self.quoting = kw.pop('quoting', 'auto') + + if self.quoting == 'auto' and len(values): + # What quoting character are we using? + q = None + for e in values: + if len(e) == 0: + self.quoting = 'unquoted' + break + elif q is None: + q = e[0] + + if len(e) == 1 or e[0] != q or e[-1] != q: + self.quoting = 'unquoted' + break + else: + self.quoting = 'quoted' + + if self.quoting == 'quoted': + util.warn_deprecated( + 'Manually quoting %s value literals is deprecated. Supply ' + 'unquoted values and use the quoting= option in cases of ' + 'ambiguity.' % self.__class__.__name__) + + values = self._strip_values(values) + + self._enumerated_values = values + length = max([len(v) for v in values] + [0]) + return values, length + + @classmethod + def _strip_values(cls, values): + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + return strip_values + +class ENUM(sqltypes.Enum, _EnumeratedValues): + """MySQL ENUM type.""" + + __visit_name__ = 'ENUM' + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + E.g.:: + + Column('myenum', ENUM("foo", "bar", "baz")) + + :param enums: The range of valid values for this ENUM. Values will be + quoted when generating the schema according to the quoting flag (see + below). + + :param strict: Defaults to False: ensure that a given value is in this + ENUM's range of permissible values when inserting or updating rows. + Note that MySQL will not raise a fatal error if you attempt to store + an out of range value- an alternate value will be stored instead. + (See MySQL ENUM documentation.) + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. + + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema - this usage is deprecated. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. + + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + """ + values, length = self._init_values(enums, kw) + self.strict = kw.pop('strict', False) + kw.pop('metadata', None) + kw.pop('schema', None) + kw.pop('name', None) + kw.pop('quote', None) + kw.pop('native_enum', None) + kw.pop('inherit_schema', None) + _StringType.__init__(self, length=length, **kw) + sqltypes.Enum.__init__(self, *values) + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[ENUM, _StringType, sqltypes.Enum]) + + def bind_processor(self, dialect): + super_convert = super(ENUM, self).bind_processor(dialect) + + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exc.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + + def adapt(self, cls, **kw): + if issubclass(cls, ENUM): + kw['strict'] = self.strict + return sqltypes.Enum.adapt(self, cls, **kw) + + +class SET(_EnumeratedValues): + """MySQL SET type.""" + + __visit_name__ = 'SET' + + def __init__(self, *values, **kw): + """Construct a SET. + + E.g.:: + + Column('myset', SET("foo", "bar", "baz")) + + :param values: The range of valid values for this SET. Values will be + quoted when generating the schema according to the quoting flag (see + below). + + .. versionchanged:: 0.9.0 quoting is applied automatically to + :class:`.mysql.SET` in the same way as for :class:`.mysql.ENUM`. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. + + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema - this usage is deprecated. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. + + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + .. versionadded:: 0.9.0 + + """ + values, length = self._init_values(values, kw) + self.values = tuple(values) + + kw.setdefault('length', length) + super(SET, self).__init__(**kw) + + def result_processor(self, dialect, coltype): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, set): + # ..some versions convert '' to an empty set + if not value: + value.add('') + return value + # ...and some versions return strings + if value is not None: + return set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(SET, self).bind_processor(dialect) + + def process(value): + if value is None or isinstance(value, util.int_types + util.string_types): + pass + else: + if None in value: + value = set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + +# old names +MSTime = TIME +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + _IntegerType: _IntegerType, + _NumericType: _NumericType, + _FloatType: _FloatType, + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Time: TIME, + sqltypes.Enum: ENUM, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. +ischema_names = { + 'bigint': BIGINT, + 'binary': BINARY, + 'bit': BIT, + 'blob': BLOB, + 'boolean': BOOLEAN, + 'char': CHAR, + 'date': DATE, + 'datetime': DATETIME, + 'decimal': DECIMAL, + 'double': DOUBLE, + 'enum': ENUM, + 'fixed': DECIMAL, + 'float': FLOAT, + 'int': INTEGER, + 'integer': INTEGER, + 'longblob': LONGBLOB, + 'longtext': LONGTEXT, + 'mediumblob': MEDIUMBLOB, + 'mediumint': MEDIUMINT, + 'mediumtext': MEDIUMTEXT, + 'nchar': NCHAR, + 'nvarchar': NVARCHAR, + 'numeric': NUMERIC, + 'set': SET, + 'smallint': SMALLINT, + 'text': TEXT, + 'time': TIME, + 'timestamp': TIMESTAMP, + 'tinyblob': TINYBLOB, + 'tinyint': TINYINT, + 'tinytext': TINYTEXT, + 'varbinary': VARBINARY, + 'varchar': VARCHAR, + 'year': YEAR, +} + + +class MySQLExecutionContext(default.DefaultExecutionContext): + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) + + +class MySQLCompiler(compiler.SQLCompiler): + + render_table_with_column_in_update_from = True + """Overridden from base SQLCompiler value""" + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({'milliseconds': 'millisecond'}) + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_utc_timestamp_func(self, fn, **kw): + return "UTC_TIMESTAMP" + + def visit_sysdate_func(self, fn, **kw): + return "SYSDATE()" + + def visit_concat_op_binary(self, binary, operator, **kw): + return "concat(%s, %s)" % (self.process(binary.left), + self.process(binary.right)) + + def visit_match_op_binary(self, binary, operator, **kw): + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ + (self.process(binary.left), self.process(binary.right)) + + def get_from_hint_text(self, table, text): + return text + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.Integer): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, sqltypes.TIMESTAMP): + return 'DATETIME' + elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, + sqltypes.Date, sqltypes.Time)): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, sqltypes.String) and not isinstance(type_, (ENUM, SET)): + adapted = CHAR._adapt_string_for_cast(type_) + return self.dialect.type_compiler.process(adapted) + elif isinstance(type_, sqltypes._Binary): + return 'BINARY' + elif isinstance(type_, sqltypes.NUMERIC): + return self.dialect.type_compiler.process( + type_).replace('NUMERIC', 'DECIMAL') + else: + return None + + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + if not self.dialect._supports_cast: + return self.process(cast.clause.self_group()) + + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause.self_group()) + + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) + + def render_literal_value(self, value, type_): + value = super(MySQLCompiler, self).render_literal_value(value, type_) + if self.dialect._backslash_escapes: + value = value.replace('\\', '\\\\') + return value + + def get_select_precolumns(self, select): + """Add special MySQL keywords in place of DISTINCT. + + .. note:: + + this usage is deprecated. :meth:`.Select.prefix_with` + should be used for special keywords at the start + of a SELECT. + + """ + if isinstance(select._distinct, util.string_types): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return ''.join( + (self.process(join.left, asfrom=True, **kwargs), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs))) + + def for_update_clause(self, select): + if select._for_update_arg.read: + return " LOCK IN SHARE MODE" + else: + return " FOR UPDATE" + + def limit_clause(self, select): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit, offset = select._limit, select._offset + + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + # http://dev.mysql.com/doc/refman/5.0/en/select.html + if limit is None: + # hardwire the upper limit. Currently + # needed by OurSQL with Python 3 + # (https://bugs.launchpad.net/oursql/+bug/686232), + # but also is consistent with the usage of the upper + # bound as part of MySQL's "syntax" for OFFSET with + # no LIMIT + return ' \n LIMIT %s, %s' % ( + self.process(sql.literal(offset)), + "18446744073709551615") + else: + return ' \n LIMIT %s, %s' % ( + self.process(sql.literal(offset)), + self.process(sql.literal(limit))) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (self.process(sql.literal(limit)),) + + def update_limit_clause(self, update_stmt): + limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) + if limit: + return "LIMIT %s" % limit + else: + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms)) + + def update_from_clause(self, update_stmt, from_table, + extra_froms, from_hints, **kw): + return None + + +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." + +class MySQLDDLCompiler(compiler.DDLCompiler): + def create_table_constraints(self, table): + """Get table constraints.""" + constraint_string = super( + MySQLDDLCompiler, self).create_table_constraints(table) + + # why self.dialect.name and not 'mysql'? because of drizzle + is_innodb = 'engine' in table.dialect_options[self.dialect.name] and \ + table.dialect_options[self.dialect.name]['engine'].lower() == 'innodb' + + auto_inc_column = table._autoincrement_column + + if is_innodb and \ + auto_inc_column is not None and \ + auto_inc_column is not list(table.primary_key)[0]: + if constraint_string: + constraint_string += ", \n\t" + constraint_string += "KEY %s (%s)" % ( + self.preparer.quote( + "idx_autoinc_%s" % auto_inc_column.name + ), + self.preparer.format_column(auto_inc_column) + ) + + return constraint_string + + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type) + ] + + default = self.get_column_default_string(column) + if default is not None: + colspec.append('DEFAULT ' + default) + + is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) + if not column.nullable and not is_timestamp: + colspec.append('NOT NULL') + + elif column.nullable and is_timestamp and default is None: + colspec.append('NULL') + + if column is column.table._autoincrement_column and \ + column.server_default is None: + colspec.append('AUTO_INCREMENT') + + return ' '.join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + + opts = dict( + ( + k[len(self.dialect.name) + 1:].upper(), + v + ) + for k, v in table.kwargs.items() + if k.startswith('%s_' % self.dialect.name) + ) + + for opt in topological.sort([ + ('DEFAULT_CHARSET', 'COLLATE'), + ('DEFAULT_CHARACTER_SET', 'COLLATE'), + ('PARTITION_BY', 'PARTITIONS'), # only for test consistency + ], opts): + arg = opts[opt] + if opt in _options_of_type_string: + arg = "'%s'" % arg.replace("\\", "\\\\").replace("'", "''") + + if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', + 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', + 'DEFAULT_CHARSET', + 'DEFAULT_COLLATE', 'PARTITION_BY'): + opt = opt.replace('_', ' ') + + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE', 'PARTITION BY', 'PARTITIONS'): + joiner = ' ' + + table_opts.append(joiner.join((opt, arg))) + return ' '.join(table_opts) + + def visit_create_index(self, create): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + table = preparer.format_table(index.table) + columns = [self.sql_compiler.process(expr, include_table=False, + literal_binds=True) + for expr in index.expressions] + + name = self._prepared_index_name(index) + + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s " % (name, table) + + length = index.dialect_options['mysql']['length'] + if length is not None: + + if isinstance(length, dict): + # length value can be a (column_name --> integer value) mapping + # specifying the prefix length for each column of the index + columns = ', '.join( + ('%s(%d)' % (col, length[col]) + if col in length else '%s' % col) + for col in columns + ) + else: + # or can be an integer value specifying the same + # prefix length for all columns of the index + columns = ', '.join( + '%s(%d)' % (col, length) + for col in columns + ) + else: + columns = ', '.join(columns) + text += '(%s)' % columns + + using = index.dialect_options['mysql']['using'] + if using is not None: + text += " USING %s" % (preparer.quote(using)) + + return text + + def visit_primary_key_constraint(self, constraint): + text = super(MySQLDDLCompiler, self).\ + visit_primary_key_constraint(constraint) + using = constraint.dialect_options['mysql']['using'] + if using: + text += " USING %s" % (self.preparer.quote(using)) + return text + + def visit_drop_index(self, drop): + index = drop.element + + return "\nDROP INDEX %s ON %s" % ( + self._prepared_index_name(index, + include_schema=False), + self.preparer.format_table(index.table)) + + def visit_drop_constraint(self, drop): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % \ + (self.preparer.format_table(constraint.table), + qual, const) + + def define_constraint_match(self, constraint): + if constraint.match is not None: + raise exc.CompileError( + "MySQL ignores the 'MATCH' keyword while at the same time " + "causes ON UPDATE/ON DELETE clauses to be ignored.") + return "" + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += ' UNSIGNED' + if type_.zerofill: + spec += ' ZEROFILL' + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr('charset'): + charset = 'CHARACTER SET %s' % attr('charset') + elif attr('ascii'): + charset = 'ASCII' + elif attr('unicode'): + charset = 'UNICODE' + else: + charset = None + + if attr('collation'): + collation = 'COLLATE %s' % type_.collation + elif attr('binary'): + collation = 'BINARY' + else: + collation = None + + if attr('national'): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return ' '.join([c for c in ('NATIONAL', spec, collation) + if c is not None]) + return ' '.join([c for c in (spec, charset, collation) + if c is not None]) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType)) + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric(type_, + "NUMERIC(%(precision)s)" % + {'precision': type_.precision}) + else: + return self._extend_numeric(type_, + "NUMERIC(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric(type_, + "DECIMAL(%(precision)s)" % + {'precision': type_.precision}) + else: + return self._extend_numeric(type_, + "DECIMAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) + + def visit_DOUBLE(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, + "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) + else: + return self._extend_numeric(type_, 'DOUBLE') + + def visit_REAL(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, + "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale': type_.scale}) + else: + return self._extend_numeric(type_, 'REAL') + + def visit_FLOAT(self, type_): + if self._mysql_type(type_) and \ + type_.scale is not None and \ + type_.precision is not None: + return self._extend_numeric(type_, + "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + elif type_.precision is not None: + return self._extend_numeric(type_, + "FLOAT(%s)" % (type_.precision,)) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, + "INTEGER(%(display_width)s)" % + {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, + "BIGINT(%(display_width)s)" % + {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, + "MEDIUMINT(%(display_width)s)" % + {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, + "TINYINT(%s)" % type_.display_width) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, + "SMALLINT(%(display_width)s)" % + {'display_width': type_.display_width} + ) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_): + if getattr(type_, 'fsp', None): + return "DATETIME(%d)" % type_.fsp + else: + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + if getattr(type_, 'fsp', None): + return "TIME(%d)" % type_.fsp + else: + return "TIME" + + def visit_TIMESTAMP(self, type_): + if getattr(type_, 'fsp', None): + return "TIMESTAMP(%d)" % type_.fsp + else: + return "TIMESTAMP" + + def visit_YEAR(self, type_): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % + self.dialect.name) + + def visit_CHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "CHAR(%(length)s)" % + {'length': type_.length}) + else: + return self._extend_string(type_, {}, "CHAR") + + def visit_NVARCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + if type_.length: + return self._extend_string(type_, {'national': True}, + "VARCHAR(%(length)s)" % {'length': type_.length}) + else: + raise exc.CompileError( + "NVARCHAR requires a length on dialect %s" % + self.dialect.name) + + def visit_NCHAR(self, type_): + # We'll actually generate the equiv. + # "NATIONAL CHAR" instead of "NCHAR". + if type_.length: + return self._extend_string(type_, {'national': True}, + "CHAR(%(length)s)" % {'length': type_.length}) + else: + return self._extend_string(type_, {'national': True}, "CHAR") + + def visit_VARBINARY(self, type_): + return "VARBINARY(%d)" % type_.length + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_enum(self, type_): + if not type_.native_enum: + return super(MySQLTypeCompiler, self).visit_enum(type_) + else: + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_BLOB(self, type_): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_): + return "LONGBLOB" + + def _visit_enumerated_values(self, name, type_, enumerated_values): + quoted_enums = [] + for e in enumerated_values: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string(type_, {}, "%s(%s)" % ( + name, ",".join(quoted_enums)) + ) + + def visit_ENUM(self, type_): + return self._visit_enumerated_values("ENUM", type_, + type_._enumerated_values) + + def visit_SET(self, type_): + return self._visit_enumerated_values("SET", type_, + type_._enumerated_values) + + def visit_BOOLEAN(self, type): + return "BOOL" + + +class MySQLIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect, server_ansiquotes=False, **kw): + if not server_ansiquotes: + quote = "`" + else: + quote = '"' + + super(MySQLIdentifierPreparer, self).__init__( + dialect, + initial_quote=quote, + escape_quote=quote) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +@log.class_logger +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + + name = 'mysql' + supports_alter = True + + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + max_index_name_length = 64 + + supports_native_enum = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_multivalues_insert = True + + default_paramstyle = 'format' + colspecs = colspecs + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + preparer = MySQLIdentifierPreparer + + # default SQL compilation settings - + # these are modified upon initialize(), + # i.e. first connect + _backslash_escapes = True + _server_ansiquotes = False + + construct_arguments = [ + (sa_schema.Table, { + "*": None + }), + (sql.Update, { + "limit": None + }), + (sa_schema.PrimaryKeyConstraint, { + "using": None + }), + (sa_schema.Index, { + "using": None, + "length": None, + }) + ] + + def __init__(self, isolation_level=None, **kwargs): + kwargs.pop('use_ansiquotes', None) # legacy + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + _isolation_lookup = set(['SERIALIZABLE', + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('SELECT @@tx_isolation') + val = cursor.fetchone()[0] + cursor.close() + if util.py3k and isinstance(val, bytes): + val = val.decode() + return val.upper().replace("-", " ") + + def do_commit(self, dbapi_connection): + """Execute a COMMIT.""" + + # COMMIT/ROLLBACK were introduced in 3.23.15. + # Yes, we have at least one user who has to talk to these old versions! + # + # Ignore commit/rollback if support isn't present, otherwise even basic + # operations via autocommit fail. + try: + dbapi_connection.commit() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_rollback(self, dbapi_connection): + """Execute a ROLLBACK.""" + + try: + dbapi_connection.rollback() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), xid=xid) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA PREPARE :xid"), xid=xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), xid=xid) + + def do_recover_twophase(self, connection): + resultset = connection.execute("XA RECOVER") + return [row['data'][0:row['gtrid_length']] for row in resultset] + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in \ + (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): + # if underlying connection is closed, + # this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver + inconsistencies.""" + + return [_DecodingRowProxy(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + return _DecodingRowProxy(rp.fetchone(), charset) + + def _compat_first(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + return _DecodingRowProxy(rp.first(), charset) + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + return connection.execute('SELECT DATABASE()').scalar() + + def has_table(self, connection, table_name, schema=None): + # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly + # on macosx (and maybe win?) with multibyte table names. + # + # TODO: if this is not a problem on win, make the strategy swappable + # based on platform. DESCRIBE is slower. + + # [ticket:726] + # full_name = self.identifier_preparer.format_table(table, + # use_schema=True) + + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + + st = "DESCRIBE %s" % full_name + rs = None + try: + try: + rs = connection.execute(st) + have = rs.fetchone() is not None + rs.close() + return have + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + return False + raise + finally: + if rs: + rs.close() + + def initialize(self, connection): + self._connection_charset = self._detect_charset(connection) + self._detect_ansiquotes(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer(self, + server_ansiquotes=self._server_ansiquotes) + + default.DefaultDialect.initialize(self, connection) + + @property + def _supports_cast(self): + return self.server_version_info is None or \ + self.server_version_info >= (4, 0, 2) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.execute("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = self._connection_charset + if self.server_version_info < (5, 0, 2): + rp = connection.execute("SHOW TABLES FROM %s" % + self.identifier_preparer.quote_identifier(current_schema)) + return [row[0] for + row in self._compat_fetchall(rp, charset=charset)] + else: + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(current_schema)) + + return [row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == 'BASE TABLE'] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if self.server_version_info < (5, 0, 2): + raise NotImplementedError + if schema is None: + schema = self.default_schema_name + if self.server_version_info < (5, 0, 2): + return self.get_table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ('VIEW', 'SYSTEM VIEW')] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + return parsed_state.table_options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + return parsed_state.columns + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + for key in parsed_state.keys: + if key['type'] == 'PRIMARY': + # There can be only one. + cols = [s[0] for s in key['columns']] + return {'constrained_columns': cols, 'name': None} + return {'constrained_columns': [], 'name': None} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + default_schema = None + + fkeys = [] + + for spec in parsed_state.constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = \ + connection.dialect.default_schema_name + if schema == default_schema: + ref_schema = schema + + loc_names = spec['local'] + ref_names = spec['foreign'] + + con_kw = {} + for opt in ('onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] + + fkey_d = { + 'name': spec['name'], + 'constrained_columns': loc_names, + 'referred_schema': ref_schema, + 'referred_table': ref_name, + 'referred_columns': ref_names, + 'options': con_kw + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + + indexes = [] + for spec in parsed_state.keys: + unique = False + flavor = spec['type'] + if flavor == 'PRIMARY': + continue + if flavor == 'UNIQUE': + unique = True + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + pass + index_d = {} + index_d['name'] = spec['name'] + index_d['column_names'] = [s[0] for s in spec['columns']] + index_d['unique'] = unique + index_d['type'] = flavor + indexes.append(index_d) + return indexes + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw) + + return [ + { + 'name': key['name'], + 'column_names': [col[0] for col in key['columns']] + } + for key in parsed_state.keys + if key['type'] == 'UNIQUE' + ] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, view_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + return sql + + def _parsed_state_or_create(self, connection, table_name, + schema=None, **kw): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get('info_cache', None) + ) + + @util.memoized_property + def _tabledef_parser(self): + """return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. + + """ + if (self.server_version_info < (4, 1) and self._server_ansiquotes): + # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 + preparer = self.preparer(self, server_ansiquotes=False) + else: + preparer = self.identifier_preparer + return MySQLTableDefinitionParser(self, preparer) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + parser = self._tabledef_parser + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + if sql.startswith('CREATE ALGORITHM'): + # Adapt views to something table-like. + columns = self._describe_table(connection, None, charset, + full_name=full_name) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html + + charset = self._connection_charset + row = self._compat_first(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if row[1] == 'OFF': + cs = 0 + elif row[1] == 'ON': + cs = 1 + else: + cs = int(row[1]) + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + if self.server_version_info < (4, 1, 0): + pass + else: + charset = self._connection_charset + rs = connection.execute('SHOW COLLATION') + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), + charset=self._connection_charset) + + if not row: + mode = '' + else: + mode = row[1] or '' + # 4.0 + if mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + + self._server_ansiquotes = 'ANSI_QUOTES' in mode + + # as of MySQL 5.0.1 + self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode + + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + rp = connection.execute(st) + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + row = self._compat_first(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + + return sql + + def _describe_table(self, connection, table, charset=None, + full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execute(st) + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + + +class ReflectedState(object): + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + + +@log.class_logger +class MySQLTableDefinitionParser(object): + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r'\r?\n', show_create): + if line.startswith(' ' + self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(') '): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ')': + pass + elif line.startswith('CREATE '): + self._parse_table_name(line, state) + # Not present in real reflection, but may be if + # loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == 'key': + state.keys.append(spec) + elif type_ == 'constraint': + state.constraints.append(spec) + else: + pass + return state + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + :param line: A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + :param line: The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group('name')) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + :param line: The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if not line or line == ')': + pass + + else: + rest_of_line = line[:] + for regex, cleanup in self._pr_options: + m = regex.search(rest_of_line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + if cleanup: + value = cleanup(value) + options[directive.lower()] = value + rest_of_line = regex.sub('', rest_of_line) + + for nope in ('auto_increment', 'data directory', 'index directory'): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options['%s_%s' % (self.dialect.name, opt)] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + :param line: Any column-bearing line from SHOW CREATE TABLE + """ + + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec['full']: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args, notnull = \ + spec['name'], spec['coltype'], spec['arg'], spec['notnull'] + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + col_type = sqltypes.NullType + + # Column type positional arguments eg. varchar(32) + if args is None or args == '': + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + for kw in ('unsigned', 'zerofill'): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ('charset', 'collate'): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + + if issubclass(col_type, _EnumeratedValues): + type_args = _EnumeratedValues._strip_values(type_args) + + type_instance = col_type(*type_args, **type_kw) + + col_args, col_kw = [], {} + + # NOT NULL + col_kw['nullable'] = True + if spec.get('notnull', False): + col_kw['nullable'] = False + + # AUTO_INCREMENT + if spec.get('autoincr', False): + col_kw['autoincrement'] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw['autoincrement'] = False + + # DEFAULT + default = spec.get('default', None) + + if default == 'NULL': + # eliminates the need to deal with this later. + default = None + + col_d = dict(name=name, type=type_instance, default=default) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. + """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] + + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(' '.join(line)) + + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table_name)), + ',\n'.join(buffer), + '\n) ']) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return self._re_keyexprs.findall(identifiers) + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + + _final = self.preparer.final_quote + + quotes = dict(zip(('iq', 'fq', 'esc_fq'), + [re.escape(s) for s in + (self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final))])) + + self._pr_name = _pr_compile( + r'^CREATE (?:\w+ +)?TABLE +' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, + self.preparer._unescape_identifier) + + # `col`,`col2`(32),`col3`(15) DESC + # + # Note: ASC and DESC aren't reflected, so we'll punt... + self._re_keyexprs = _re_compile( + r'(?:' + r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' + r'(?:\((\d+)\))?(?=\,|$))+' % quotes) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + + # 123 or 123,456 + self._re_csv_int = _re_compile(r'\d+') + + # `colname` [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r' ' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P\w+)' + r'(?:\((?P(?:\d+|\d+,\d+|' + r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' + r'(?: +(?PUNSIGNED))?' + r'(?: +(?PZEROFILL))?' + r'(?: +CHARACTER SET +(?P[\w_]+))?' + r'(?: +COLLATE +(?P[\w_]+))?' + r'(?: +(?PNOT NULL))?' + r'(?: +DEFAULT +(?P' + r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+' + r'(?: +ON UPDATE \w+)?)' + r'))?' + r'(?: +(?PAUTO_INCREMENT))?' + r'(?: +COMMENT +(P(?:\x27\x27|[^\x27])+))?' + r'(?: +COLUMN_FORMAT +(?P\w+))?' + r'(?: +STORAGE +(?P\w+))?' + r'(?: +(?P.*))?' + r',?$' + % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r' ' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P\w+)' + r'(?:\((?P(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' + r'.*?(?PNOT NULL)?' + % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name + self._re_key = _re_compile( + r' ' + r'(?:(?P\S+) )?KEY' + r'(?: +%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' + r'(?: +USING +(?P\S+))?' + r' +\((?P.+?)\)' + r'(?: +USING +(?P\S+))?' + r'(?: +KEY_BLOCK_SIZE +(?P\S+))?' + r'(?: +WITH PARSER +(?P\S+))?' + r',?$' + % quotes + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION' + self._re_constraint = _re_compile( + r' ' + r'CONSTRAINT +' + r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'FOREIGN KEY +' + r'\((?P[^\)]+?)\) REFERENCES +' + r'(?P%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' + r'\((?P[^\)]+?)\)' + r'(?: +(?PMATCH \w+))?' + r'(?: +ON DELETE (?P%(on)s))?' + r'(?: +ON UPDATE (?P%(on)s))?' + % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)') + + # Table-level options (COLLATE, ENGINE, etc.) + # Do the string options first, since they have quoted + # strings we need to get rid of. + for option in _options_of_type_string: + self._add_option_string(option) + + for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', + 'AVG_ROW_LENGTH', 'CHARACTER SET', + 'DEFAULT CHARSET', 'CHECKSUM', + 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', + 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', + 'KEY_BLOCK_SIZE'): + self._add_option_word(option) + + self._add_option_regex('UNION', r'\([^\)]+\)') + self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') + self._add_option_regex('RAID_TYPE', + r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + + _optional_equals = r'(?:\s*(?:=\s*)|\s+)' + + def _add_option_string(self, directive): + regex = (r'(?P%s)%s' + r"'(?P(?:[^']|'')*?)'(?!')" % + (re.escape(directive), self._optional_equals)) + self._pr_options.append(_pr_compile(regex, lambda v: + v.replace("\\\\", "\\").replace("''", "'"))) + + def _add_option_word(self, directive): + regex = (r'(?P%s)%s' + r'(?P\w+)' % + (re.escape(directive), self._optional_equals)) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = (r'(?P%s)%s' + r'(?P%s)' % + (re.escape(directive), self._optional_equals, regex)) + self._pr_options.append(_pr_compile(regex)) + +_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY', + 'PASSWORD', 'CONNECTION') + + + +class _DecodingRowProxy(object): + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + + if self.charset and isinstance(item, util.binary_type): + return item.decode(self.charset) + else: + return item + + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, util.binary_type): + return item.decode(self.charset) + else: + return item + + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py new file mode 100644 index 00000000..49728045 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -0,0 +1,84 @@ +# mysql/cymysql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+cymysql + :name: CyMySQL + :dbapi: cymysql + :connectstring: mysql+cymysql://:@/[?] + :url: https://github.com/nakagami/CyMySQL + +""" +import re + +from .mysqldb import MySQLDialect_mysqldb +from .base import (BIT, MySQLDialect) +from ... import util + +class _cymysqlBIT(BIT): + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + """ + + def process(value): + if value is not None: + v = 0 + for i in util.iterbytes(value): + v = v << 8 | i + return v + return value + return process + + +class MySQLDialect_cymysql(MySQLDialect_mysqldb): + driver = 'cymysql' + + description_encoding = None + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_unicode_statements = True + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + BIT: _cymysqlBIT, + } + ) + + @classmethod + def dbapi(cls): + return __import__('cymysql') + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.server_version): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in \ + (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): + # if underlying connection is closed, + # this is the error you get + return True + else: + return False + +dialect = MySQLDialect_cymysql diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py new file mode 100644 index 00000000..13203fce --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py @@ -0,0 +1,84 @@ +# mysql/gaerdbms.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: mysql+gaerdbms + :name: Google Cloud SQL + :dbapi: rdbms + :connectstring: mysql+gaerdbms:///?instance= + :url: https://developers.google.com/appengine/docs/python/cloud-sql/developers-guide + + This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with minimal + changes. + + .. versionadded:: 0.7.8 + + +Pooling +------- + +Google App Engine connections appear to be randomly recycled, +so the dialect does not pool connections. The :class:`.NullPool` +implementation is installed within the :class:`.Engine` by +default. + +""" + +import os + +from .mysqldb import MySQLDialect_mysqldb +from ...pool import NullPool +import re + + +def _is_dev_environment(): + return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/') + + +class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): + + @classmethod + def dbapi(cls): + # from django: + # http://code.google.com/p/googleappengine/source/ + # browse/trunk/python/google/storage/speckle/ + # python/django/backend/base.py#118 + # see also [ticket:2649] + # see also http://stackoverflow.com/q/14224679/34549 + from google.appengine.api import apiproxy_stub_map + + if _is_dev_environment(): + from google.appengine.api import rdbms_mysqldb + return rdbms_mysqldb + elif apiproxy_stub_map.apiproxy.GetStub('rdbms'): + from google.storage.speckle.python.api import rdbms_apiproxy + return rdbms_apiproxy + else: + from google.storage.speckle.python.api import rdbms_googleapi + return rdbms_googleapi + + @classmethod + def get_pool_class(cls, url): + # Cloud SQL connections die at any moment + return NullPool + + def create_connect_args(self, url): + opts = url.translate_connect_args() + if not _is_dev_environment(): + # 'dsn' and 'instance' are because we are skipping + # the traditional google.api.rdbms wrapper + opts['dsn'] = '' + opts['instance'] = url.query['instance'] + return [], opts + + def _extract_error_code(self, exception): + match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception)) + # The rdbms api will wrap then re-raise some types of errors + # making this regex return no matches. + code = match.group(1) or match.group(2) if match else None + if code: + return int(code) + +dialect = MySQLDialect_gaerdbms diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py new file mode 100644 index 00000000..3536c3ad --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -0,0 +1,131 @@ +# mysql/mysqlconnector.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mysql+mysqlconnector + :name: MySQL Connector/Python + :dbapi: myconnpy + :connectstring: mysql+mysqlconnector://:@[:]/ + :url: http://dev.mysql.com/downloads/connector/python/ + + +""" + +from .base import (MySQLDialect, + MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, + BIT) + +from ... import util + + +class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): + + def get_lastrowid(self): + return self.cursor.lastrowid + + +class MySQLCompiler_mysqlconnector(MySQLCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace("%", "%%") + + +class _myconnpyBIT(BIT): + def result_processor(self, dialect, coltype): + """MySQL-connector already converts mysql bits, so.""" + + return None + + +class MySQLDialect_mysqlconnector(MySQLDialect): + driver = 'mysqlconnector' + + if util.py2k: + supports_unicode_statements = False + supports_unicode_binds = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = 'format' + execution_ctx_cls = MySQLExecutionContext_mysqlconnector + statement_compiler = MySQLCompiler_mysqlconnector + + preparer = MySQLIdentifierPreparer_mysqlconnector + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + BIT: _myconnpyBIT, + } + ) + + @classmethod + def dbapi(cls): + from mysql import connector + return connector + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + + opts.update(url.query) + + util.coerce_kw_type(opts, 'buffered', bool) + util.coerce_kw_type(opts, 'raise_on_warnings', bool) + opts.setdefault('buffered', True) + opts.setdefault('raise_on_warnings', True) + + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + client_flags = opts.get('client_flags', ClientFlag.get_default()) + client_flags |= ClientFlag.FOUND_ROWS + opts['client_flags'] = client_flags + except: + pass + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = dbapi_con.get_server_version() + return tuple(version) + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + errnos = (2006, 2013, 2014, 2045, 2055, 2048) + exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) + if isinstance(e, exceptions): + return e.errno in errnos or \ + "MySQL Connection not available." in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + +dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 00000000..7fb63f13 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,94 @@ +# mysql/mysqldb.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+mysqldb + :name: MySQL-Python + :dbapi: mysqldb + :connectstring: mysql+mysqldb://:@[:]/ + :url: http://sourceforge.net/projects/mysql-python + + +Unicode +------- + +MySQLdb requires a "charset" parameter to be passed in order for it +to handle non-ASCII characters correctly. When this parameter is passed, +MySQLdb will also implicitly set the "use_unicode" flag to true, which means +that it will return Python unicode objects instead of bytestrings. +However, SQLAlchemy's decode process, when C extensions are enabled, +is orders of magnitude faster than that of MySQLdb as it does not call into +Python functions to do so. Therefore, the **recommended URL to use for +unicode** will include both charset and use_unicode=0:: + + create_engine("mysql+mysqldb://user:pass@host/dbname?charset=utf8&use_unicode=0") + +As of this writing, MySQLdb only runs on Python 2. It is not known how +MySQLdb behaves on Python 3 as far as unicode decoding. + + +Known Issues +------------- + +MySQL-python version 1.2.2 has a serious memory leak related +to unicode conversion, a feature which is disabled via ``use_unicode=0``. +It is strongly advised to use the latest version of MySQL-Python. + +""" + +from .base import (MySQLDialect, MySQLExecutionContext, + MySQLCompiler, MySQLIdentifierPreparer) +from ...connectors.mysqldb import ( + MySQLDBExecutionContext, + MySQLDBCompiler, + MySQLDBIdentifierPreparer, + MySQLDBConnector + ) +from .base import TEXT +from ... import sql + +class MySQLExecutionContext_mysqldb(MySQLDBExecutionContext, MySQLExecutionContext): + pass + + +class MySQLCompiler_mysqldb(MySQLDBCompiler, MySQLCompiler): + pass + + +class MySQLIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, MySQLIdentifierPreparer): + pass + + +class MySQLDialect_mysqldb(MySQLDBConnector, MySQLDialect): + execution_ctx_cls = MySQLExecutionContext_mysqldb + statement_compiler = MySQLCompiler_mysqldb + preparer = MySQLIdentifierPreparer_mysqldb + + def _check_unicode_returns(self, connection): + # work around issue fixed in + # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 + # specific issue w/ the utf8_bin collation and unicode returns + + has_utf8_bin = connection.scalar( + "show collation where %s = 'utf8' and %s = 'utf8_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation") + )) + if has_utf8_bin: + additional_tests = [ + sql.collate(sql.cast( + sql.literal_column( + "'test collated returns'"), + TEXT(charset='utf8')), "utf8_bin") + ] + else: + additional_tests = [] + return super(MySQLDBConnector, self)._check_unicode_returns( + connection, additional_tests) + +dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py new file mode 100644 index 00000000..e6b50f33 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -0,0 +1,261 @@ +# mysql/oursql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+oursql + :name: OurSQL + :dbapi: oursql + :connectstring: mysql+oursql://:@[:]/ + :url: http://packages.python.org/oursql/ + +Unicode +------- + +oursql defaults to using ``utf8`` as the connection charset, but other +encodings may be used instead. Like the MySQL-Python driver, unicode support +can be completely disabled:: + + # oursql sets the connection charset to utf8 automatically; all strings come + # back as utf8 str + create_engine('mysql+oursql:///mydb?use_unicode=0') + +To not automatically use ``utf8`` and instead use whatever the connection +defaults to, there is a separate parameter:: + + # use the default connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?default_charset=1') + + # use latin1 as the connection charset; all strings come back as unicode + create_engine('mysql+oursql:///mydb?charset=latin1') +""" + +import re + +from .base import (BIT, MySQLDialect, MySQLExecutionContext) +from ... import types as sqltypes, util + + +class _oursqlBIT(BIT): + def result_processor(self, dialect, coltype): + """oursql already converts mysql bits, so.""" + + return None + + +class MySQLExecutionContext_oursql(MySQLExecutionContext): + + @property + def plain_query(self): + return self.execution_options.get('_oursql_plain_query', False) + + +class MySQLDialect_oursql(MySQLDialect): + driver = 'oursql' + + if util.py2k: + supports_unicode_binds = True + supports_unicode_statements = True + + supports_native_decimal = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + execution_ctx_cls = MySQLExecutionContext_oursql + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _oursqlBIT, + } + ) + + @classmethod + def dbapi(cls): + return __import__('oursql') + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of *cursor.execute(statement, parameters)*.""" + + if context and context.plain_query: + cursor.execute(statement, plain_query=True) + else: + cursor.execute(statement, parameters) + + def do_begin(self, connection): + connection.cursor().execute('BEGIN', plain_query=True) + + def _xa_query(self, connection, query, xid): + if util.py2k: + arg = connection.connection._escape_string(xid) + else: + charset = self._connection_charset + arg = connection.connection._escape_string(xid.encode(charset)).decode(charset) + arg = "'%s'" % arg + connection.execution_options(_oursql_plain_query=True).execute(query % arg) + + # Because mysql is bad, these methods have to be + # reimplemented to use _PlainQuery. Basically, some queries + # refuse to return any data if they're run through + # the parameterized query API, or refuse to be parameterized + # in the first place. + def do_begin_twophase(self, connection, xid): + self._xa_query(connection, 'XA BEGIN %s', xid) + + def do_prepare_twophase(self, connection, xid): + self._xa_query(connection, 'XA END %s', xid) + self._xa_query(connection, 'XA PREPARE %s', xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self._xa_query(connection, 'XA END %s', xid) + self._xa_query(connection, 'XA ROLLBACK %s', xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + self._xa_query(connection, 'XA COMMIT %s', xid) + + # Q: why didn't we need all these "plain_query" overrides earlier ? + # am i on a newer/older version of OurSQL ? + def has_table(self, connection, table_name, schema=None): + return MySQLDialect.has_table( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema + ) + + def get_table_options(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_table_options( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema=schema, + **kw + ) + + def get_columns(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_columns( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema=schema, + **kw + ) + + def get_view_names(self, connection, schema=None, **kw): + return MySQLDialect.get_view_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + schema=schema, + **kw + ) + + def get_table_names(self, connection, schema=None, **kw): + return MySQLDialect.get_table_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + schema + ) + + def get_schema_names(self, connection, **kw): + return MySQLDialect.get_schema_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + **kw + ) + + def initialize(self, connection): + return MySQLDialect.initialize( + self, + connection.execution_options(_oursql_plain_query=True) + ) + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + return MySQLDialect._show_create_table( + self, + connection.contextual_connect(close_with_result=True). + execution_options(_oursql_plain_query=True), + table, charset, full_name + ) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.ProgrammingError): + return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed') + else: + return e.errno in (2006, 2013, 2014, 2045, 2055) + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'port', int) + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'autoping', bool) + util.coerce_kw_type(opts, 'raise_on_warnings', bool) + + util.coerce_kw_type(opts, 'default_charset', bool) + if opts.pop('default_charset', False): + opts['charset'] = None + else: + util.coerce_kw_type(opts, 'charset', str) + opts['use_unicode'] = opts.get('use_unicode', True) + util.coerce_kw_type(opts, 'use_unicode', bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + opts.setdefault('found_rows', True) + + ssl = {} + for key in ['ssl_ca', 'ssl_key', 'ssl_cert', + 'ssl_capath', 'ssl_cipher']: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + return [[], opts] + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.server_info): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + return exception.errno + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + return connection.connection.charset + + def _compat_fetchall(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchone() + + def _compat_first(self, rp, charset=None): + return rp.first() + + +dialect = MySQLDialect_oursql diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py new file mode 100644 index 00000000..7989203c --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -0,0 +1,45 @@ +# mysql/pymysql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+pymysql + :name: PyMySQL + :dbapi: pymysql + :connectstring: mysql+pymysql://:@/[?] + :url: http://code.google.com/p/pymysql/ + +MySQL-Python Compatibility +-------------------------- + +The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver, +and targets 100% compatibility. Most behavioral notes for MySQL-python apply to +the pymysql driver as well. + +""" + +from .mysqldb import MySQLDialect_mysqldb +from ...util import py3k + +class MySQLDialect_pymysql(MySQLDialect_mysqldb): + driver = 'pymysql' + + description_encoding = None + if py3k: + supports_unicode_statements = True + + + @classmethod + def dbapi(cls): + return __import__('pymysql') + + if py3k: + def _extract_error_code(self, exception): + if isinstance(exception.args[0], Exception): + exception = exception.args[0] + return exception.args[0] + +dialect = MySQLDialect_pymysql diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 00000000..e60e39ce --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,80 @@ +# mysql/pyodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + + +.. dialect:: mysql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mysql+pyodbc://:@ + :url: http://pypi.python.org/pypi/pyodbc/ + + +Limitations +----------- + +The mysql-pyodbc dialect is subject to unresolved character encoding issues +which exist within the current ODBC drivers available. +(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage +of OurSQL, MySQLdb, or MySQL-connector/Python. + +""" + +from .base import MySQLDialect, MySQLExecutionContext +from ...connectors.pyodbc import PyODBCConnector +from ... import util +import re + + +class MySQLExecutionContext_pyodbc(MySQLExecutionContext): + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): + supports_unicode_statements = False + execution_ctx_cls = MySQLExecutionContext_pyodbc + + pyodbc_driver_name = "MySQL" + + def __init__(self, **kw): + # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 + kw.setdefault('convert_unicode', True) + super(MySQLDialect_pyodbc, self).__init__(**kw) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + +dialect = MySQLDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py new file mode 100644 index 00000000..b5fcfbda --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -0,0 +1,111 @@ +# mysql/zxjdbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+zxjdbc + :name: zxjdbc for Jython + :dbapi: zxjdbc + :connectstring: mysql+zxjdbc://:@[:]/ + :driverurl: http://dev.mysql.com/downloads/connector/j/ + +Character Sets +-------------- + +SQLAlchemy zxjdbc dialects pass unicode straight through to the +zxjdbc/JDBC layer. To allow multiple character sets to be sent from the +MySQL Connector/J JDBC driver, by default SQLAlchemy sets its +``characterEncoding`` connection property to ``UTF-8``. It may be +overriden via a ``create_engine`` URL parameter. + +""" +import re + +from ... import types as sqltypes, util +from ...connectors.zxJDBC import ZxJDBCConnector +from .base import BIT, MySQLDialect, MySQLExecutionContext + + +class _ZxJDBCBit(BIT): + def result_processor(self, dialect, coltype): + """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): + if value is None: + return value + if isinstance(value, bool): + return int(value) + v = 0 + for i in value: + v = v << 8 | (i & 0xff) + value = v + return value + return process + + +class MySQLExecutionContext_zxjdbc(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): + jdbc_db_name = 'mysql' + jdbc_driver_name = 'com.mysql.jdbc.Driver' + + execution_ctx_cls = MySQLExecutionContext_zxjdbc + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _ZxJDBCBit + } + ) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs)) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return dict(characterEncoding='UTF-8', yearIsDateType='false') + + def _extract_error_code(self, exception): + # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist + # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () + m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args)) + c = m.group(1) + if c: + return int(c) + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.dbversion): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + +dialect = MySQLDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 00000000..b75762ab --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,23 @@ +# oracle/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc + +base.dialect = cx_oracle.dialect + +from sqlalchemy.dialects.oracle.base import \ + VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\ + BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\ + FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\ + VARCHAR2, NVARCHAR2, ROWID, dialect + + +__all__ = ( +'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER', +'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', +'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL', +'VARCHAR2', 'NVARCHAR2', 'ROWID' +) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 00000000..8bacb885 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,1291 @@ +# oracle/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: oracle + :name: Oracle + + Oracle version 8 through current (11g at the time of this writing) are supported. + +Connect Arguments +----------------- + +The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which +affect the behavior of the dialect regardless of driver in use. + +* ``use_ansi`` - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults + to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. + +* ``optimize_limits`` - defaults to ``False``. see the section on LIMIT/OFFSET. + +* ``use_binds_for_limits`` - defaults to ``True``. see the section on LIMIT/OFFSET. + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually assumed to have +"autoincrementing" behavior, meaning they can generate their own primary key values upon +INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences +to produce these values. With the Oracle dialect, *a sequence must always be explicitly +specified to enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To specify sequences, +use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload=True:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + autoload=True + ) + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier names +using UPPERCASE text. SQLAlchemy on the other hand considers an all-lower case identifier +name to be case insensitive. The Oracle dialect converts all case insensitive identifiers +to and from those two formats during schema level communication, such as reflection of +tables and indexes. Using an UPPERCASE name on the SQLAlchemy side indicates a +case sensitive identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names have been +truly created as case sensitive (i.e. using quoted names), all lowercase names should be +used on the SQLAlchemy side. + + +LIMIT/OFFSET Support +-------------------- + +Oracle has no support for the LIMIT or OFFSET keywords. SQLAlchemy uses +a wrapped subquery approach in conjunction with ROWNUM. The exact methodology +is taken from +http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . + +There are two options which affect its behavior: + +* the "FIRST ROWS()" optimization keyword is not used by default. To enable the usage of this + optimization directive, specify ``optimize_limits=True`` to :func:`.create_engine`. +* the values passed for the limit/offset are sent as bound parameters. Some users have observed + that Oracle produces a poor query plan when the values are sent as binds and not + rendered literally. To render the limit/offset values literally within the SQL + statement, specify ``use_binds_for_limits=False`` to :func:`.create_engine`. + +Some users have reported better performance when the entirely different approach of a +window query is used, i.e. ROW_NUMBER() OVER (ORDER BY), to provide LIMIT/OFFSET (note +that the majority of users don't observe this). To suit this case the +method used for LIMIT/OFFSET can be replaced entirely. See the recipe at +http://www.sqlalchemy.org/trac/wiki/UsageRecipes/WindowFunctionsByDefault +which installs a select compiler that overrides the generation of limit/offset with +a window function. + +.. _oracle_returning: + +RETURNING Support +----------------- + +The Oracle database supports a limited form of RETURNING, in order to retrieve result +sets of matched rows from INSERT, UPDATE and DELETE statements. Oracle's +RETURNING..INTO syntax only supports one row being returned, as it relies upon +OUT parameters in order to function. In addition, supported DBAPIs have further +limitations (see :ref:`cx_oracle_returning`). + +SQLAlchemy's "implicit returning" feature, which employs RETURNING within an INSERT +and sometimes an UPDATE statement in order to fetch newly generated primary key values +and other SQL defaults and expressions, is normally enabled on the Oracle +backend. By default, "implicit returning" typically only fetches the value of a +single ``nextval(some_seq)`` expression embedded into an INSERT in order to increment +a sequence within an INSERT statement and get the value back at the same time. +To disable this feature across the board, specify ``implicit_returning=False`` to +:func:`.create_engine`:: + + engine = create_engine("oracle://scott:tiger@dsn", implicit_returning=False) + +Implicit returning can also be disabled on a table-by-table basis as a table option:: + + # Core Table + my_table = Table("my_table", metadata, ..., implicit_returning=False) + + + # declarative + class MyClass(Base): + __tablename__ = 'my_table' + __table_args__ = {"implicit_returning": False} + +.. seealso:: + + :ref:`cx_oracle_returning` - additional cx_oracle-specific restrictions on implicit returning. + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based solution +is available at http://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relationship(). + +Oracle 8 Compatibility +---------------------- + +When Oracle 8 is detected, the dialect internally configures itself to the following +behaviors: + +* the use_ansi flag is set to False. This has the effect of converting all + JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN + makes use of Oracle's (+) operator. + +* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when + the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are issued + instead. This because these types don't seem to work correctly on Oracle 8 + even though they are available. The :class:`~sqlalchemy.types.NVARCHAR` + and :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate NVARCHAR2 and NCLOB. + +* the "native unicode" mode is disabled when using cx_oracle, i.e. SQLAlchemy + encodes all Python unicode objects to "string" before passing in as bind parameters. + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search for tables +indicated by synonyms, either in local or remote schemas or accessed over DBLINK, +by passing the flag ``oracle_resolve_synonyms=True`` as a +keyword argument to the :class:`.Table` construct:: + + some_table = Table('some_table', autoload=True, + autoload_with=some_engine, + oracle_resolve_synonyms=True) + +When this flag is set, the given name (such as ``some_table`` above) will +be searched not just in the ``ALL_TABLES`` view, but also within the +``ALL_SYNONYMS`` view to see if this name is actually a synonym to another name. +If the synonym is located and refers to a DBLINK, the oracle dialect knows +how to locate the table's information using DBLINK syntax (e.g. ``@dblink``). + +``oracle_resolve_synonyms`` is accepted wherever reflection arguments are +accepted, including methods such as :meth:`.MetaData.reflect` and +:meth:`.Inspector.get_columns`. + +If synonyms are not in use, this flag should be left disabled. + +DateTime Compatibility +---------------------- + +Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, +which can actually store a date and time value. For this reason, the Oracle +dialect provides a type :class:`.oracle.DATE` which is a subclass of +:class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column +is reflected and the type is reported as ``DATE``, the time-supporting +:class:`.oracle.DATE` type is used. + +.. versionchanged:: 0.9.4 Added :class:`.oracle.DATE` to subclass + :class:`.DateTime`. This is a change as previous versions + would reflect a ``DATE`` column as :class:`.types.DATE`, which subclasses + :class:`.Date`. The only significance here is for schemes that are + examining the type of column for use in special Python translations or + for migrating schemas to other database backends. + +""" + +import re + +from sqlalchemy import util, sql +from sqlalchemy.engine import default, base, reflection +from sqlalchemy.sql import compiler, visitors, expression +from sqlalchemy.sql import operators as sql_operators, functions as sql_functions +from sqlalchemy import types as sqltypes, schema as sa_schema +from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ + BLOB, CLOB, TIMESTAMP, FLOAT + +RESERVED_WORDS = \ + set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN '\ + 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED '\ + 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE '\ + 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE '\ + 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES '\ + 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '\ + 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '\ + 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '\ + 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) + +NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' + 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) + + +class RAW(sqltypes._Binary): + __visit_name__ = 'RAW' +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = 'NCLOB' + + +class VARCHAR2(VARCHAR): + __visit_name__ = 'VARCHAR2' + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = 'NUMBER' + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + def adapt(self, impltype): + ret = super(NUMBER, self).adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class DOUBLE_PRECISION(sqltypes.Numeric): + __visit_name__ = 'DOUBLE_PRECISION' + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = False + + super(DOUBLE_PRECISION, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = 'BFILE' + + +class LONG(sqltypes.Text): + __visit_name__ = 'LONG' + +class DATE(sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`.types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + __visit_name__ = 'DATE' + + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class INTERVAL(sqltypes.TypeEngine): + __visit_name__ = 'INTERVAL' + + def __init__(self, + day_precision=None, + second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs (cx_oracle and zxjdbc). + + :param day_precision: the day precision value. this is the number of digits + to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the number of digits + to store for the fractional seconds field. Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL(day_precision=interval.day_precision, + second_precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. + + """ + __visit_name__ = 'ROWID' + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + +colspecs = { + sqltypes.Boolean: _OracleBoolean, + sqltypes.Interval: INTERVAL, + sqltypes.DateTime: DATE +} + +ischema_names = { + 'VARCHAR2': VARCHAR, + 'NVARCHAR2': NVARCHAR, + 'CHAR': CHAR, + 'DATE': DATE, + 'NUMBER': NUMBER, + 'BLOB': BLOB, + 'BFILE': BFILE, + 'CLOB': CLOB, + 'NCLOB': NCLOB, + 'TIMESTAMP': TIMESTAMP, + 'TIMESTAMP WITH TIME ZONE': TIMESTAMP, + 'INTERVAL DAY TO SECOND': INTERVAL, + 'RAW': RAW, + 'FLOAT': FLOAT, + 'DOUBLE PRECISION': DOUBLE_PRECISION, + 'LONG': LONG, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_): + return self.visit_DATE(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_unicode(self, type_): + if self.dialect._supports_nchar: + return self.visit_NVARCHAR2(type_) + else: + return self.visit_VARCHAR2(type_) + + def visit_INTERVAL(self, type_): + return "INTERVAL DAY%s TO SECOND%s" % ( + type_.day_precision is not None and + "(%d)" % type_.day_precision or + "", + type_.second_precision is not None and + "(%d)" % type_.second_precision or + "", + ) + + def visit_LONG(self, type_): + return "LONG" + + def visit_TIMESTAMP(self, type_): + if type_.timezone: + return "TIMESTAMP WITH TIME ZONE" + else: + return "TIMESTAMP" + + def visit_DOUBLE_PRECISION(self, type_): + return self._generate_numeric(type_, "DOUBLE PRECISION") + + def visit_NUMBER(self, type_, **kw): + return self._generate_numeric(type_, "NUMBER", **kw) + + def _generate_numeric(self, type_, name, precision=None, scale=None): + if precision is None: + precision = type_.precision + + if scale is None: + scale = getattr(type_, 'scale', None) + + if precision is None: + return name + elif scale is None: + n = "%(name)s(%(precision)s)" + return n % {'name': name, 'precision': precision} + else: + n = "%(name)s(%(precision)s, %(scale)s)" + return n % {'name': name, 'precision': precision, 'scale': scale} + + def visit_string(self, type_): + return self.visit_VARCHAR2(type_) + + def visit_VARCHAR2(self, type_): + return self._visit_varchar(type_, '', '2') + + def visit_NVARCHAR2(self, type_): + return self._visit_varchar(type_, 'N', '2') + visit_NVARCHAR = visit_NVARCHAR2 + + def visit_VARCHAR(self, type_): + return self._visit_varchar(type_, '', '') + + def _visit_varchar(self, type_, n, num): + if not type_.length: + return "%(n)sVARCHAR%(two)s" % {'two': num, 'n': n} + elif not n and self.dialect._supports_char_length: + varchar = "VARCHAR%(two)s(%(length)s CHAR)" + return varchar % {'length': type_.length, 'two': num} + else: + varchar = "%(n)sVARCHAR%(two)s(%(length)s)" + return varchar % {'length': type_.length, 'two': num, 'n': n} + + def visit_text(self, type_): + return self.visit_CLOB(type_) + + def visit_unicode_text(self, type_): + if self.dialect._supports_nchar: + return self.visit_NCLOB(type_) + else: + return self.visit_CLOB(type_) + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_big_integer(self, type_): + return self.visit_NUMBER(type_, precision=19) + + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_RAW(self, type_): + if type_.length: + return "RAW(%(length)s)" % {'length': type_.length} + else: + return "RAW" + + def visit_ROWID(self, type_): + return "ROWID" + + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + compound_keywords = util.update_copy( + compiler.SQLCompiler.compound_keywords, + { + expression.CompoundSelect.EXCEPT: 'MINUS' + } + ) + + def __init__(self, *args, **kwargs): + self.__wheres = {} + self._quoted_bind_names = {} + super(OracleCompiler, self).__init__(*args, **kwargs) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % (self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % (self.process(binary.left), + self.process(binary.right)) + + def visit_true(self, expr, **kw): + return '1' + + def visit_false(self, expr, **kw): + return '0' + + def get_select_hint_text(self, byfroms): + return " ".join( + "/*+ %s */" % text for table, text in byfroms.items() + ) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join(self, join, **kwargs) + else: + kwargs['asfrom'] = True + if isinstance(join.right, expression.FromGrouping): + right = join.right.element + else: + right = join.right + return self.process(join.left, **kwargs) + \ + ", " + self.process(right, **kwargs) + + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if join.right.is_derived_from(binary.left.table): + binary.left = _OuterJoinColumn(binary.left) + elif join.right.is_derived_from(binary.right.table): + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, + {'binary': visit_binary})) + else: + clauses.append(join.onclause) + + for j in join.left, join.right: + if isinstance(j, expression.Join): + visit_join(j) + elif isinstance(j, expression.FromGrouping): + visit_join(j.element) + + for f in froms: + if isinstance(f, expression.Join): + visit_join(f) + + if not clauses: + return None + else: + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + + def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): + """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" + + if asfrom or ashint: + alias_name = isinstance(alias.name, expression._truncated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + if ashint: + return alias_name + elif asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + \ + " " + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def returning_clause(self, stmt, returning_cols): + columns = [] + binds = [] + for i, column in enumerate(expression._select_iterables(returning_cols)): + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + else: + col_expr = column + outparam = sql.outparam("ret_%d" % i, type_=column.type) + self.binds[outparam.key] = outparam + binds.append(self.bindparam_string(self._truncate_bindparam(outparam))) + columns.append(self.process(col_expr, within_columns_clause=False)) + self.result_map[outparam.key] = ( + outparam.key, + (column, getattr(column, 'name', None), + getattr(column, 'key', None)), + column.type + ) + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + def _TODO_visit_compound_select(self, select): + """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``rownum`` criterion. + """ + + if not getattr(select, '_oracle_visit', None): + if not self.dialect.use_ansi: + froms = self._display_froms_for_select( + select, kwargs.get('asfrom', False)) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) + select._oracle_visit = True + + if select._limit is not None or select._offset is not None: + # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html + # + # Generalized form of an Oracle pagination query: + # select ... from ( + # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( + # select distinct ... where ... order by ... + # ) where ROWNUM <= :limit+:offset + # ) where ora_rn > :offset + # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 + + # TODO: use annotations instead of clone + attr set ? + select = select._generate() + select._oracle_visit = True + + # Wrap the middle select and add the hint + limitselect = sql.select([c for c in select.c]) + if select._limit and self.dialect.optimize_limits: + limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # If needed, add the limiting clause + if select._limit is not None: + max_row = select._limit + if select._offset is not None: + max_row += select._offset + if not self.dialect.use_binds_for_limits: + max_row = sql.literal_column("%d" % max_row) + limitselect.append_whereclause( + sql.literal_column("ROWNUM") <= max_row) + + # If needed, add the ora_rn, and wrap again with offset. + if select._offset is None: + limitselect._for_update_arg = select._for_update_arg + select = limitselect + else: + limitselect = limitselect.column( + sql.literal_column("ROWNUM").label("ora_rn")) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + offsetselect = sql.select( + [c for c in limitselect.c if c.key != 'ora_rn']) + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + offset_value = select._offset + if not self.dialect.use_binds_for_limits: + offset_value = sql.literal_column("%d" % offset_value) + offsetselect.append_whereclause( + sql.literal_column("ora_rn") > offset_value) + + offsetselect._for_update_arg = select._for_update_arg + select = offsetselect + + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def limit_clause(self, select): + return "" + + def for_update_clause(self, select): + if self.is_subquery(): + return "" + + tmp = ' FOR UPDATE' + + if select._for_update_arg.of: + tmp += ' OF ' + ', '.join( + self.process(elem) for elem in + select._for_update_arg.of + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + return tmp + + +class OracleDDLCompiler(compiler.DDLCompiler): + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers http://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign keys. " + "Consider using deferrable=True, initially='deferred' or triggers.") + + return text + + def visit_create_index(self, create, **kw): + return super(OracleDDLCompiler, self).\ + visit_create_index(create, include_schema=True) + + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = set([x.lower() for x in RESERVED_WORDS]) + illegal_initial_characters = set(range(0, 10)).union(["_", "$"]) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + ) + + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + + +class OracleExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar("SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", type_) + + +class OracleDialect(default.DefaultDialect): + name = 'oracle' + supports_alter = True + supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 30 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = 'named' + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_default_values = False + supports_empty_insert = False + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler = OracleTypeCompiler + preparer = OracleIdentifierPreparer + execution_ctx_cls = OracleExecutionContext + + reflection_options = ('oracle_resolve_synonyms', ) + + construct_arguments = [ + (sa_schema.Table, {"resolve_synonyms": False}) + ] + + def __init__(self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=True, + **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + self.use_binds_for_limits = use_binds_for_limits + + def initialize(self, connection): + super(OracleDialect, self).initialize(connection) + self.implicit_returning = self.__dict__.get( + 'implicit_returning', + self.server_version_info > (10, ) + ) + + if self._is_oracle_8: + self.colspecs = self.colspecs.copy() + self.colspecs.pop(sqltypes.Interval) + self.use_ansi = False + + @property + def _is_oracle_8(self): + return self.server_version_info and \ + self.server_version_info < (9, ) + + @property + def _supports_char_length(self): + return not self._is_oracle_8 + + @property + def _supports_nchar(self): + return not self._is_oracle_8 + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def has_table(self, connection, table_name, schema=None): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text("SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name"), + name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema)) + return cursor.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text("SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND sequence_owner = :schema_name"), + name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema)) + return cursor.first() is not None + + def normalize_name(self, name): + if name is None: + return None + if util.py2k: + if isinstance(name, str): + name = name.decode(self.encoding) + if name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + name = name.upper() + if util.py2k: + if not self.supports_unicode_binds: + name = name.encode(self.encoding) + else: + name = unicode(name) + return name + + def _get_default_schema_name(self, connection): + return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) + + def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if found. + """ + + q = "SELECT owner, table_owner, table_name, db_link, "\ + "synonym_name FROM all_synonyms WHERE " + clauses = [] + params = {} + if desired_synonym: + clauses.append("synonym_name = :synonym_name") + params['synonym_name'] = desired_synonym + if desired_owner: + clauses.append("owner = :desired_owner") + params['desired_owner'] = desired_owner + if desired_table: + clauses.append("table_name = :tname") + params['tname'] = desired_table + + q += " AND ".join(clauses) + + result = connection.execute(sql.text(q), **params) + if desired_owner: + row = result.first() + if row: + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + else: + rows = result.fetchall() + if len(rows) > 1: + raise AssertionError("There are multiple tables visible to the schema, you must specify owner") + elif len(rows) == 1: + row = rows[0] + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + + @reflection.cache + def _prepare_reflection_args(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(schema), + desired_synonym=self.denormalize_name(table_name) + ) + else: + actual_name, owner, dblink, synonym = None, None, None, None + if not actual_name: + actual_name = self.denormalize_name(table_name) + + if dblink: + # using user_db_links here since all_db_links appears + # to have more restricted permissions. + # http://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm + # will need to hear from more users if we are doing + # the right thing here. See [ticket:2619] + owner = connection.scalar( + sql.text("SELECT username FROM user_db_links " + "WHERE db_link=:link"), link=dblink) + dblink = "@" + dblink + elif not owner: + owner = self.denormalize_name(schema or self.default_schema_name) + + return (actual_name, owner, dblink or '', synonym) + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.execute(s,) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + + # note that table_names() isnt loading DBLINKed or synonym'ed tables + if schema is None: + schema = self.default_schema_name + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner " + "AND IOT_NAME IS NULL") + cursor = connection.execute(s, owner=schema) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + columns = [] + if self._supports_char_length: + char_length_col = 'char_length' + else: + char_length_col = 'data_length' + + params = {"table_name": table_name} + text = "SELECT column_name, data_type, %(char_length_col)s, "\ + "data_precision, data_scale, "\ + "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "\ + "WHERE table_name = :table_name" + if schema is not None: + params['owner'] = schema + text += " AND owner = :owner " + text += " ORDER BY column_id" + text = text % {'dblink': dblink, 'char_length_col': char_length_col} + + c = connection.execute(sql.text(text), **params) + + for row in c: + (colname, orig_colname, coltype, length, precision, scale, nullable, default) = \ + (self.normalize_name(row[0]), row[0], row[1], row[2], row[3], row[4], row[5] == 'Y', row[6]) + + if coltype == 'NUMBER': + coltype = NUMBER(precision, scale) + elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'): + coltype = self.ischema_names.get(coltype)(length) + elif 'WITH TIME ZONE' in coltype: + coltype = TIMESTAMP(timezone=True) + else: + coltype = re.sub(r'\(\d+\)', '', coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, colname)) + coltype = sqltypes.NULLTYPE + + cdict = { + 'name': colname, + 'type': coltype, + 'nullable': nullable, + 'default': default, + 'autoincrement': default is None + } + if orig_colname.lower() == orig_colname: + cdict['quote'] = True + + columns.append(cdict) + return columns + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + info_cache = kw.get('info_cache') + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + indexes = [] + + params = {'table_name': table_name} + text = \ + "SELECT a.index_name, a.column_name, b.uniqueness "\ + "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ + "\nALL_INDEXES%(dblink)s b "\ + "\nWHERE "\ + "\na.index_name = b.index_name "\ + "\nAND a.table_owner = b.table_owner "\ + "\nAND a.table_name = b.table_name "\ + "\nAND a.table_name = :table_name " + + if schema is not None: + params['schema'] = schema + text += "AND a.table_owner = :schema " + + text += "ORDER BY a.index_name, a.column_position" + + text = text % {'dblink': dblink} + + q = sql.text(text) + rp = connection.execute(q, **params) + indexes = [] + last_index_name = None + pk_constraint = self.get_pk_constraint( + connection, table_name, schema, resolve_synonyms=resolve_synonyms, + dblink=dblink, info_cache=kw.get('info_cache')) + pkeys = pk_constraint['constrained_columns'] + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + + oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) + + def upper_name_set(names): + return set([i.upper() for i in names]) + + pk_names = upper_name_set(pkeys) + + def remove_if_primary_key(index): + # don't include the primary key index + if index is not None and \ + upper_name_set(index['column_names']) == pk_names: + indexes.pop() + + index = None + for rset in rp: + if rset.index_name != last_index_name: + remove_if_primary_key(index) + index = dict(name=self.normalize_name(rset.index_name), column_names=[]) + indexes.append(index) + index['unique'] = uniqueness.get(rset.uniqueness, False) + + # filter out Oracle SYS_NC names. could also do an outer join + # to the all_tab_columns table and check for real col names there. + if not oracle_sys_col.match(rset.column_name): + index['column_names'].append(self.normalize_name(rset.column_name)) + last_index_name = rset.index_name + remove_if_primary_key(index) + return indexes + + @reflection.cache + def _get_constraint_data(self, connection, table_name, schema=None, + dblink='', **kw): + + params = {'table_name': table_name} + + text = \ + "SELECT"\ + "\nac.constraint_name,"\ + "\nac.constraint_type,"\ + "\nloc.column_name AS local_column,"\ + "\nrem.table_name AS remote_table,"\ + "\nrem.column_name AS remote_column,"\ + "\nrem.owner AS remote_owner,"\ + "\nloc.position as loc_pos,"\ + "\nrem.position as rem_pos"\ + "\nFROM all_constraints%(dblink)s ac,"\ + "\nall_cons_columns%(dblink)s loc,"\ + "\nall_cons_columns%(dblink)s rem"\ + "\nWHERE ac.table_name = :table_name"\ + "\nAND ac.constraint_type IN ('R','P')" + + if schema is not None: + params['owner'] = schema + text += "\nAND ac.owner = :owner" + + text += \ + "\nAND ac.owner = loc.owner"\ + "\nAND ac.constraint_name = loc.constraint_name"\ + "\nAND ac.r_owner = rem.owner(+)"\ + "\nAND ac.r_constraint_name = rem.constraint_name(+)"\ + "\nAND (rem.position IS NULL or loc.position=rem.position)"\ + "\nORDER BY ac.constraint_name, loc.position" + + text = text % {'dblink': dblink} + rp = connection.execute(sql.text(text), **params) + constraint_data = rp.fetchall() + return constraint_data + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + pkeys = [] + constraint_name = None + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == 'P': + if constraint_name is None: + constraint_name = self.normalize_name(cons_name) + pkeys.append(local_column) + return {'constrained_columns': pkeys, 'name': constraint_name} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + requested_schema = schema # to check later on + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + def fkey_rec(): + return { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': None, + 'referred_columns': [] + } + + fkeys = util.defaultdict(fkey_rec) + + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + + if cons_type == 'R': + if remote_table is None: + # ticket 363 + util.warn( + ("Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?") % {'dblink': dblink}) + continue + + rec = fkeys[cons_name] + rec['name'] = cons_name + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + if not rec['referred_table']: + if resolve_synonyms: + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ + self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table) + ) + if ref_synonym: + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name(ref_remote_owner) + + rec['referred_table'] = remote_table + + if requested_schema is not None or self.denormalize_name(remote_owner) != schema: + rec['referred_schema'] = remote_owner + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return list(fkeys.values()) + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + info_cache = kw.get('info_cache') + (view_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, view_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + params = {'view_name': view_name} + text = "SELECT text FROM all_views WHERE view_name=:view_name" + + if schema is not None: + text += " AND owner = :schema" + params['schema'] = schema + + rp = connection.execute(sql.text(text), **params).scalar() + if rp: + if util.py2k: + rp = rp.decode(self.encoding) + return rp + else: + return None + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + + def __init__(self, column): + self.column = column diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 00000000..b8ee90b5 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,941 @@ +# oracle/cx_oracle.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: oracle+cx_oracle + :name: cx-Oracle + :dbapi: cx_oracle + :connectstring: oracle+cx_oracle://user:pass@host:port/dbname[?key=value&key=value...] + :url: http://cx-oracle.sourceforge.net/ + +Additional Connect Arguments +---------------------------- + +When connecting with ``dbname`` present, the host, port, and dbname tokens are +converted to a TNS name using +the cx_oracle ``makedsn()`` function. Otherwise, the host token is taken +directly as a TNS name. + +Additional arguments which may be specified either as query string arguments +on the URL, or as keyword arguments to :func:`.create_engine()` are: + +* ``allow_twophase`` - enable two-phase transactions. Defaults to ``True``. + +* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted + to 50. This setting is significant with cx_Oracle as the contents of LOB + objects are only readable within a "live" row (e.g. within a batch of + 50 rows). + +* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`. + +* ``auto_setinputsizes`` - the cx_oracle.setinputsizes() call is issued for + all bind parameters. This is required for LOB datatypes but can be + disabled to reduce overhead. Defaults to ``True``. Specific types + can be excluded from this process using the ``exclude_setinputsizes`` + parameter. + +* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail. + +* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail. + +* ``exclude_setinputsizes`` - a tuple or list of string DBAPI type names to + be excluded from the "auto setinputsizes" feature. The type names here + must match DBAPI types that are found in the "cx_Oracle" module namespace, + such as cx_Oracle.UNICODE, cx_Oracle.NCLOB, etc. Defaults to + ``(STRING, UNICODE)``. + + .. versionadded:: 0.8 specific DBAPI types can be excluded from the + auto_setinputsizes feature via the exclude_setinputsizes attribute. + +* ``mode`` - This is given the string value of SYSDBA or SYSOPER, or alternatively + an integer value. This value is only available as a URL query string + argument. + +* ``threaded`` - enable multithreaded access to cx_oracle connections. Defaults + to ``True``. Note that this is the opposite default of the cx_Oracle DBAPI + itself. + +.. _cx_oracle_unicode: + +Unicode +------- + +The cx_Oracle DBAPI as of version 5 fully supports unicode, and has the ability +to return string results as Python unicode objects natively. + +When used in Python 3, cx_Oracle returns all strings as Python unicode objects +(that is, plain ``str`` in Python 3). In Python 2, it will return as Python +unicode those column values that are of type ``NVARCHAR`` or ``NCLOB``. For +column values that are of type ``VARCHAR`` or other non-unicode string types, +it will return values as Python strings (e.g. bytestrings). + +The cx_Oracle SQLAlchemy dialect presents two different options for the use case of +returning ``VARCHAR`` column values as Python unicode objects under Python 2: + +* the cx_Oracle DBAPI has the ability to coerce all string results to Python + unicode objects unconditionally using output type handlers. This has + the advantage that the unicode conversion is global to all statements + at the cx_Oracle driver level, meaning it works with raw textual SQL + statements that have no typing information associated. However, this system + has been observed to incur signfiicant performance overhead, not only because + it takes effect for all string values unconditionally, but also because cx_Oracle under + Python 2 seems to use a pure-Python function call in order to do the + decode operation, which under cPython can orders of magnitude slower + than doing it using C functions alone. + +* SQLAlchemy has unicode-decoding services built in, and when using SQLAlchemy's + C extensions, these functions do not use any Python function calls and + are very fast. The disadvantage to this approach is that the unicode + conversion only takes effect for statements where the :class:`.Unicode` type + or :class:`.String` type with ``convert_unicode=True`` is explicitly + associated with the result column. This is the case for any ORM or Core + query or SQL expression as well as for a :func:`.text` construct that specifies + output column types, so in the vast majority of cases this is not an issue. + However, when sending a completely raw string to :meth:`.Connection.execute`, + this typing information isn't present, unless the string is handled + within a :func:`.text` construct that adds typing information. + +As of version 0.9.2 of SQLAlchemy, the default approach is to use SQLAlchemy's +typing system. This keeps cx_Oracle's expensive Python 2 approach +disabled unless the user explicitly wants it. Under Python 3, SQLAlchemy detects +that cx_Oracle is returning unicode objects natively and cx_Oracle's system +is used. + +To re-enable cx_Oracle's output type handler under Python 2, the +``coerce_to_unicode=True`` flag (new in 0.9.4) can be passed to +:func:`.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", coerce_to_unicode=True) + +Alternatively, to run a pure string SQL statement and get ``VARCHAR`` results +as Python unicode under Python 2 without using cx_Oracle's native handlers, +the :func:`.text` feature can be used:: + + from sqlalchemy import text, Unicode + result = conn.execute(text("select username from user").columns(username=Unicode)) + +.. versionchanged:: 0.9.2 cx_Oracle's outputtypehandlers are no longer used for + unicode results of non-unicode datatypes in Python 2, after they were identified as a major + performance bottleneck. SQLAlchemy's own unicode facilities are used + instead. + +.. versionadded:: 0.9.4 Added the ``coerce_to_unicode`` flag, to re-enable + cx_Oracle's outputtypehandler and revert to pre-0.9.2 behavior. + +.. _cx_oracle_returning: + +RETURNING Support +----------------- + +The cx_oracle DBAPI supports a limited subset of Oracle's already limited RETURNING support. +Typically, results can only be guaranteed for at most one column being returned; +this is the typical case when SQLAlchemy uses RETURNING to get just the value of a +primary-key-associated sequence value. Additional column expressions will +cause problems in a non-determinative way, due to cx_oracle's lack of support for +the OCI_DATA_AT_EXEC API which is required for more complex RETURNING scenarios. + +For this reason, stability may be enhanced by disabling RETURNING support completely; +SQLAlchemy otherwise will use RETURNING to fetch newly sequence-generated +primary keys. As illustrated in :ref:`oracle_returning`:: + + engine = create_engine("oracle://scott:tiger@dsn", implicit_returning=False) + +.. seealso:: + + http://docs.oracle.com/cd/B10501_01/appdev.920/a96584/oci05bnd.htm#420693 - OCI documentation for RETURNING + + http://sourceforge.net/mailarchive/message.php?msg_id=31338136 - cx_oracle developer commentary + +.. _cx_oracle_lob: + +LOB Objects +----------- + +cx_oracle returns oracle LOBs using the cx_oracle.LOB object. SQLAlchemy converts +these to strings so that the interface of the Binary type is consistent with that of +other backends, and so that the linkage to a live cursor is not needed in scenarios +like result.fetchmany() and result.fetchall(). This means that by default, LOB +objects are fully fetched unconditionally by SQLAlchemy, and the linkage to a live +cursor is broken. + +To disable this processing, pass ``auto_convert_lobs=False`` to :func:`.create_engine()`. + +Two Phase Transaction Support +----------------------------- + +Two Phase transactions are implemented using XA transactions, and are known +to work in a rudimental fashion with recent versions of cx_Oracle +as of SQLAlchemy 0.8.0b2, 0.7.10. However, the mechanism is not yet +considered to be robust and should still be regarded as experimental. + +In particular, the cx_Oracle DBAPI as recently as 5.1.2 has a bug regarding +two phase which prevents +a particular DBAPI connection from being consistently usable in both +prepared transactions as well as traditional DBAPI usage patterns; therefore +once a particular connection is used via :meth:`.Connection.begin_prepared`, +all subsequent usages of the underlying DBAPI connection must be within +the context of prepared transactions. + +The default behavior of :class:`.Engine` is to maintain a pool of DBAPI +connections. Therefore, due to the above glitch, a DBAPI connection that has +been used in a two-phase operation, and is then returned to the pool, will +not be usable in a non-two-phase context. To avoid this situation, +the application can make one of several choices: + +* Disable connection pooling using :class:`.NullPool` + +* Ensure that the particular :class:`.Engine` in use is only used + for two-phase operations. A :class:`.Engine` bound to an ORM + :class:`.Session` which includes ``twophase=True`` will consistently + use the two-phase transaction style. + +* For ad-hoc two-phase operations without disabling pooling, the DBAPI + connection in use can be evicted from the connection pool using the + :meth:`.Connection.detach` method. + +.. versionchanged:: 0.8.0b2,0.7.10 + Support for cx_oracle prepared transactions has been implemented + and tested. + +.. _cx_oracle_numeric: + +Precision Numerics +------------------ + +The SQLAlchemy dialect goes through a lot of steps to ensure +that decimal numbers are sent and received with full accuracy. +An "outputtypehandler" callable is associated with each +cx_oracle connection object which detects numeric types and +receives them as string values, instead of receiving a Python +``float`` directly, which is then passed to the Python +``Decimal`` constructor. The :class:`.Numeric` and +:class:`.Float` types under the cx_oracle dialect are aware of +this behavior, and will coerce the ``Decimal`` to ``float`` if +the ``asdecimal`` flag is ``False`` (default on :class:`.Float`, +optional on :class:`.Numeric`). + +Because the handler coerces to ``Decimal`` in all cases first, +the feature can detract significantly from performance. +If precision numerics aren't required, the decimal handling +can be disabled by passing the flag ``coerce_to_decimal=False`` +to :func:`.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False) + +.. versionadded:: 0.7.6 + Add the ``coerce_to_decimal`` flag. + +Another alternative to performance is to use the +`cdecimal `_ library; +see :class:`.Numeric` for additional notes. + +The handler attempts to use the "precision" and "scale" +attributes of the result set column to best determine if +subsequent incoming values should be received as ``Decimal`` as +opposed to int (in which case no processing is added). There are +several scenarios where OCI_ does not provide unambiguous data +as to the numeric type, including some situations where +individual rows may return a combination of floating point and +integer values. Certain values for "precision" and "scale" have +been observed to determine this scenario. When it occurs, the +outputtypehandler receives as string and then passes off to a +processing function which detects, for each returned value, if a +decimal point is present, and if so converts to ``Decimal``, +otherwise to int. The intention is that simple int-based +statements like "SELECT my_seq.nextval() FROM DUAL" continue to +return ints and not ``Decimal`` objects, and that any kind of +floating point value is received as a string so that there is no +floating point loss of precision. + +The "decimal point is present" logic itself is also sensitive to +locale. Under OCI_, this is controlled by the NLS_LANG +environment variable. Upon first connection, the dialect runs a +test to determine the current "decimal" character, which can be +a comma "," for european locales. From that point forward the +outputtypehandler uses that character to represent a decimal +point. Note that cx_oracle 5.0.3 or greater is required +when dealing with numerics with locale settings that don't use +a period "." as the decimal character. + +.. versionchanged:: 0.6.6 + The outputtypehandler supports the case where the locale uses a + comma "," character to represent a decimal point. + +.. _OCI: http://www.oracle.com/technetwork/database/features/oci/index.html + +""" + +from __future__ import absolute_import + +from .base import OracleCompiler, OracleDialect, OracleExecutionContext +from . import base as oracle +from ...engine import result as _result +from sqlalchemy import types as sqltypes, util, exc, processors +import random +import collections +import decimal +import re + + +class _OracleNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + # cx_oracle accepts Decimal objects and floats + return None + + def result_processor(self, dialect, coltype): + # we apply a cx_oracle type handler to all connections + # that converts floating point strings to Decimal(). + # However, in some subquery situations, Oracle doesn't + # give us enough information to determine int or Decimal. + # It could even be int/Decimal differently on each row, + # regardless of the scale given for the originating type. + # So we still need an old school isinstance() handler + # here for decimals. + + if dialect.supports_native_decimal: + if self.asdecimal: + fstring = "%%.%df" % self._effective_decimal_return_scale + + def to_decimal(value): + if value is None: + return None + elif isinstance(value, decimal.Decimal): + return value + else: + return decimal.Decimal(fstring % value) + + return to_decimal + else: + if self.precision is None and self.scale is None: + return processors.to_float + elif not getattr(self, '_is_oracle_number', False) \ + and self.scale is not None: + return processors.to_float + else: + return None + else: + # cx_oracle 4 behavior, will assume + # floats + return super(_OracleNumeric, self).\ + result_processor(dialect, coltype) + + +class _OracleDate(sqltypes.Date): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return value.date() + else: + return value + return process + + +class _LOBMixin(object): + def result_processor(self, dialect, coltype): + if not dialect.auto_convert_lobs: + # return the cx_oracle.LOB directly. + return None + + def process(value): + if value is not None: + return value.read() + else: + return value + return process + + +class _NativeUnicodeMixin(object): + if util.py2k: + def bind_processor(self, dialect): + if dialect._cx_oracle_with_unicode: + def process(value): + if value is None: + return value + else: + return unicode(value) + return process + else: + return super(_NativeUnicodeMixin, self).bind_processor(dialect) + + # we apply a connection output handler that returns + # unicode in all cases, so the "native_unicode" flag + # will be set for the default String.result_processor. + + +class _OracleChar(_NativeUnicodeMixin, sqltypes.CHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_CHAR + + +class _OracleNVarChar(_NativeUnicodeMixin, sqltypes.NVARCHAR): + def get_dbapi_type(self, dbapi): + return getattr(dbapi, 'UNICODE', dbapi.STRING) + + +class _OracleText(_LOBMixin, sqltypes.Text): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + + +class _OracleLong(oracle.LONG): + # a raw LONG is a text type, but does *not* + # get the LobMixin with cx_oracle. + + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + +class _OracleString(_NativeUnicodeMixin, sqltypes.String): + pass + + +class _OracleUnicodeText(_LOBMixin, _NativeUnicodeMixin, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + return dbapi.NCLOB + + def result_processor(self, dialect, coltype): + lob_processor = _LOBMixin.result_processor(self, dialect, coltype) + if lob_processor is None: + return None + + string_processor = sqltypes.UnicodeText.result_processor(self, dialect, coltype) + + if string_processor is None: + return lob_processor + else: + def process(value): + return string_processor(lob_processor(value)) + return process + + +class _OracleInteger(sqltypes.Integer): + def result_processor(self, dialect, coltype): + def to_int(val): + if val is not None: + val = int(val) + return val + return to_int + + +class _OracleBinary(_LOBMixin, sqltypes.LargeBinary): + def get_dbapi_type(self, dbapi): + return dbapi.BLOB + + def bind_processor(self, dialect): + return None + + +class _OracleInterval(oracle.INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + +class _OracleRaw(oracle.RAW): + pass + + +class _OracleRowid(oracle.ROWID): + def get_dbapi_type(self, dbapi): + return dbapi.ROWID + + +class OracleCompiler_cx_oracle(OracleCompiler): + def bindparam_string(self, name, **kw): + quote = getattr(name, 'quote', None) + if quote is True or quote is not False and \ + self.preparer._bindparam_requires_quotes(name): + quoted_name = '"%s"' % name + self._quoted_bind_names[name] = quoted_name + return OracleCompiler.bindparam_string(self, quoted_name, **kw) + else: + return OracleCompiler.bindparam_string(self, name, **kw) + + +class OracleExecutionContext_cx_oracle(OracleExecutionContext): + + def pre_exec(self): + quoted_bind_names = \ + getattr(self.compiled, '_quoted_bind_names', None) + if quoted_bind_names: + if not self.dialect.supports_unicode_statements: + # if DBAPI doesn't accept unicode statements, + # keys in self.parameters would have been encoded + # here. so convert names in quoted_bind_names + # to encoded as well. + quoted_bind_names = \ + dict( + (fromname.encode(self.dialect.encoding), + toname.encode(self.dialect.encoding)) + for fromname, toname in + quoted_bind_names.items() + ) + for param in self.parameters: + for fromname, toname in quoted_bind_names.items(): + param[toname] = param[fromname] + del param[fromname] + + if self.dialect.auto_setinputsizes: + # cx_oracle really has issues when you setinputsizes + # on String, including that outparams/RETURNING + # breaks for varchars + self.set_input_sizes(quoted_bind_names, + exclude_types=self.dialect.exclude_setinputsizes + ) + + # if a single execute, check for outparams + if len(self.compiled_parameters) == 1: + for bindparam in self.compiled.binds.values(): + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).\ + get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + if dbtype is None: + raise exc.InvalidRequestError( + "Cannot create out parameter for parameter " + "%r - it's type %r is not supported by" + " cx_oracle" % + (bindparam.key, bindparam.type) + ) + name = self.compiled.bind_names[bindparam] + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][quoted_bind_names.get(name, name)] = \ + self.out_parameters[name] + + def create_cursor(self): + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def get_result_proxy(self): + if hasattr(self, 'out_parameters') and self.compiled.returning: + returning_params = dict( + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) + return ReturningResultProxy(self, returning_params) + + result = None + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect._cx_oracle_binary_types: + result = _result.BufferedColumnResultProxy(self) + + if result is None: + result = _result.ResultProxy(self) + + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None and \ + len(self.compiled_parameters) == 1: + result.out_parameters = out_parameters = {} + + for bind, name in self.compiled.bind_names.items(): + if name in self.out_parameters: + type = bind.type + impl_type = type.dialect_impl(self.dialect) + dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + result_processor = impl_type.\ + result_processor(self.dialect, + dbapi_type) + if result_processor is not None: + out_parameters[name] = \ + result_processor(self.out_parameters[name].getvalue()) + else: + out_parameters[name] = self.out_parameters[name].getvalue() + else: + result.out_parameters = dict( + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) + + return result + + +class OracleExecutionContext_cx_oracle_with_unicode(OracleExecutionContext_cx_oracle): + """Support WITH_UNICODE in Python 2.xx. + + WITH_UNICODE allows cx_Oracle's Python 3 unicode handling + behavior under Python 2.x. This mode in some cases disallows + and in other cases silently passes corrupted data when + non-Python-unicode strings (a.k.a. plain old Python strings) + are passed as arguments to connect(), the statement sent to execute(), + or any of the bind parameter keys or values sent to execute(). + This optional context therefore ensures that all statements are + passed as Python unicode objects. + + """ + def __init__(self, *arg, **kw): + OracleExecutionContext_cx_oracle.__init__(self, *arg, **kw) + self.statement = util.text_type(self.statement) + + def _execute_scalar(self, stmt): + return super(OracleExecutionContext_cx_oracle_with_unicode, self).\ + _execute_scalar(util.text_type(stmt)) + + +class ReturningResultProxy(_result.FullyBufferedResultProxy): + """Result proxy which stuffs the _returning clause + outparams into the fetch.""" + + def __init__(self, context, returning_params): + self._returning_params = returning_params + super(ReturningResultProxy, self).__init__(context) + + def _cursor_description(self): + returning = self.context.compiled.returning + return [ + ("ret_%d" % i, None) + for i, col in enumerate(returning) + ] + + def _buffer_rows(self): + return collections.deque([tuple(self._returning_params["ret_%d" % i] + for i, c in enumerate(self._returning_params))]) + + +class OracleDialect_cx_oracle(OracleDialect): + execution_ctx_cls = OracleExecutionContext_cx_oracle + statement_compiler = OracleCompiler_cx_oracle + + driver = "cx_oracle" + + colspecs = colspecs = { + sqltypes.Numeric: _OracleNumeric, + sqltypes.Date: _OracleDate, # generic type, assume datetime.date is desired + sqltypes.LargeBinary: _OracleBinary, + sqltypes.Boolean: oracle._OracleBoolean, + sqltypes.Interval: _OracleInterval, + oracle.INTERVAL: _OracleInterval, + sqltypes.Text: _OracleText, + sqltypes.String: _OracleString, + sqltypes.UnicodeText: _OracleUnicodeText, + sqltypes.CHAR: _OracleChar, + + # a raw LONG is a text type, but does *not* + # get the LobMixin with cx_oracle. + oracle.LONG: _OracleLong, + + # this is only needed for OUT parameters. + # it would be nice if we could not use it otherwise. + sqltypes.Integer: _OracleInteger, + + oracle.RAW: _OracleRaw, + sqltypes.Unicode: _OracleNVarChar, + sqltypes.NVARCHAR: _OracleNVarChar, + oracle.ROWID: _OracleRowid, + } + + execute_sequence_format = list + + def __init__(self, + auto_setinputsizes=True, + exclude_setinputsizes=("STRING", "UNICODE"), + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + coerce_to_decimal=True, + coerce_to_unicode=False, + arraysize=50, **kwargs): + OracleDialect.__init__(self, **kwargs) + self.threaded = threaded + self.arraysize = arraysize + self.allow_twophase = allow_twophase + self.supports_timestamp = self.dbapi is None or \ + hasattr(self.dbapi, 'TIMESTAMP') + self.auto_setinputsizes = auto_setinputsizes + self.auto_convert_lobs = auto_convert_lobs + + if hasattr(self.dbapi, 'version'): + self.cx_oracle_ver = tuple([int(x) for x in + self.dbapi.version.split('.')]) + else: + self.cx_oracle_ver = (0, 0, 0) + + def types(*names): + return set( + getattr(self.dbapi, name, None) for name in names + ).difference([None]) + + self.exclude_setinputsizes = types(*(exclude_setinputsizes or ())) + self._cx_oracle_string_types = types("STRING", "UNICODE", + "NCLOB", "CLOB") + self._cx_oracle_unicode_types = types("UNICODE", "NCLOB") + self._cx_oracle_binary_types = types("BFILE", "CLOB", "NCLOB", "BLOB") + self.supports_unicode_binds = self.cx_oracle_ver >= (5, 0) + + self.coerce_to_unicode = ( + self.cx_oracle_ver >= (5, 0) and + coerce_to_unicode + ) + + self.supports_native_decimal = ( + self.cx_oracle_ver >= (5, 0) and + coerce_to_decimal + ) + + self._cx_oracle_native_nvarchar = self.cx_oracle_ver >= (5, 0) + + if self.cx_oracle_ver is None: + # this occurs in tests with mock DBAPIs + self._cx_oracle_string_types = set() + self._cx_oracle_with_unicode = False + elif self.cx_oracle_ver >= (5,) and not hasattr(self.dbapi, 'UNICODE'): + # cx_Oracle WITH_UNICODE mode. *only* python + # unicode objects accepted for anything + self.supports_unicode_statements = True + self.supports_unicode_binds = True + self._cx_oracle_with_unicode = True + + if util.py2k: + # There's really no reason to run with WITH_UNICODE under Python 2.x. + # Give the user a hint. + util.warn( + "cx_Oracle is compiled under Python 2.xx using the " + "WITH_UNICODE flag. Consider recompiling cx_Oracle " + "without this flag, which is in no way necessary for full " + "support of Unicode. Otherwise, all string-holding bind " + "parameters must be explicitly typed using SQLAlchemy's " + "String type or one of its subtypes," + "or otherwise be passed as Python unicode. " + "Plain Python strings passed as bind parameters will be " + "silently corrupted by cx_Oracle." + ) + self.execution_ctx_cls = \ + OracleExecutionContext_cx_oracle_with_unicode + else: + self._cx_oracle_with_unicode = False + + if self.cx_oracle_ver is None or \ + not self.auto_convert_lobs or \ + not hasattr(self.dbapi, 'CLOB'): + self.dbapi_type_map = {} + else: + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. + self.dbapi_type_map = { + self.dbapi.CLOB: oracle.CLOB(), + self.dbapi.NCLOB: oracle.NCLOB(), + self.dbapi.BLOB: oracle.BLOB(), + self.dbapi.BINARY: oracle.RAW(), + } + + @classmethod + def dbapi(cls): + import cx_Oracle + return cx_Oracle + + def initialize(self, connection): + super(OracleDialect_cx_oracle, self).initialize(connection) + if self._is_oracle_8: + self.supports_unicode_binds = False + self._detect_decimal_char(connection) + + def _detect_decimal_char(self, connection): + """detect if the decimal separator character is not '.', as + is the case with european locale settings for NLS_LANG. + + cx_oracle itself uses similar logic when it formats Python + Decimal objects to strings on the bind side (as of 5.0.3), + as Oracle sends/receives string numerics only in the + current locale. + + """ + if self.cx_oracle_ver < (5,): + # no output type handlers before version 5 + return + + cx_Oracle = self.dbapi + conn = connection.connection + + # override the output_type_handler that's + # on the cx_oracle connection with a plain + # one on the cursor + + def output_type_handler(cursor, name, defaultType, + size, precision, scale): + return cursor.var( + cx_Oracle.STRING, + 255, arraysize=cursor.arraysize) + + cursor = conn.cursor() + cursor.outputtypehandler = output_type_handler + cursor.execute("SELECT 0.1 FROM DUAL") + val = cursor.fetchone()[0] + cursor.close() + char = re.match(r"([\.,])", val).group(1) + if char != '.': + _detect_decimal = self._detect_decimal + self._detect_decimal = \ + lambda value: _detect_decimal(value.replace(char, '.')) + self._to_decimal = \ + lambda value: decimal.Decimal(value.replace(char, '.')) + + def _detect_decimal(self, value): + if "." in value: + return decimal.Decimal(value) + else: + return int(value) + + _to_decimal = decimal.Decimal + + def on_connect(self): + if self.cx_oracle_ver < (5,): + # no output type handlers before version 5 + return + + cx_Oracle = self.dbapi + + def output_type_handler(cursor, name, defaultType, + size, precision, scale): + # convert all NUMBER with precision + positive scale to Decimal + # this almost allows "native decimal" mode. + if self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER and \ + precision and scale > 0: + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=self._to_decimal, + arraysize=cursor.arraysize) + # if NUMBER with zero precision and 0 or neg scale, this appears + # to indicate "ambiguous". Use a slower converter that will + # make a decision based on each value received - the type + # may change from row to row (!). This kills + # off "native decimal" mode, handlers still needed. + elif self.supports_native_decimal and \ + defaultType == cx_Oracle.NUMBER \ + and not precision and scale <= 0: + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=self._detect_decimal, + arraysize=cursor.arraysize) + # allow all strings to come back natively as Unicode + elif self.coerce_to_unicode and \ + defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR): + return cursor.var(util.text_type, size, cursor.arraysize) + + def on_connect(conn): + conn.outputtypehandler = output_type_handler + + return on_connect + + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + + if url.database: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + dsn = self.dbapi.makedsn(url.host, port, url.database) + else: + # we have a local tnsname + dsn = url.host + + opts = dict( + user=url.username, + password=url.password, + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, + ) + + if util.py2k: + if self._cx_oracle_with_unicode: + for k, v in opts.items(): + if isinstance(v, str): + opts[k] = unicode(v) + else: + for k, v in opts.items(): + if isinstance(v, unicode): + opts[k] = str(v) + + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], util.string_types): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple( + int(x) + for x in connection.connection.version.split('.') + ) + + def is_disconnect(self, e, connection, cursor): + error, = e.args + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + elif hasattr(error, 'code'): + # ORA-00028: your session has been killed + # ORA-03114: not connected to ORACLE + # ORA-03113: end-of-file on communication channel + # ORA-03135: connection lost contact + # ORA-01033: ORACLE initialization or shutdown in progress + # ORA-02396: exceeded maximum idle time, please connect again + # TODO: Others ? + return error.code in (28, 3114, 3113, 3135, 1033, 2396) + else: + return False + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified.""" + + id = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id, "%032x" % 9) + + def do_executemany(self, cursor, statement, parameters, context=None): + if isinstance(parameters, tuple): + parameters = list(parameters) + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + + def do_prepare_twophase(self, connection, xid): + result = connection.connection.prepare() + connection.info['cx_oracle_prepared'] = result + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_commit(connection.connection) + else: + oci_prepared = connection.info['cx_oracle_prepared'] + if oci_prepared: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + connection.info.pop('cx_oracle_prepared', None) + +dialect = OracleDialect_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py new file mode 100644 index 00000000..710645b2 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -0,0 +1,218 @@ +# oracle/zxjdbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: oracle+zxjdbc + :name: zxJDBC for Jython + :dbapi: zxjdbc + :connectstring: oracle+zxjdbc://user:pass@host/dbname + :driverurl: http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html. + +""" +import decimal +import re + +from sqlalchemy import sql, types as sqltypes, util +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext +from sqlalchemy.engine import result as _result +from sqlalchemy.sql import expression +import collections + +SQLException = zxJDBC = None + + +class _ZxJDBCDate(sqltypes.Date): + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + else: + return value.date() + return process + + +class _ZxJDBCNumeric(sqltypes.Numeric): + + def result_processor(self, dialect, coltype): + #XXX: does the dialect return Decimal or not??? + # if it does (in all cases), we could use a None processor as well as + # the to_float generic processor + if self.asdecimal: + def process(value): + if isinstance(value, decimal.Decimal): + return value + else: + return decimal.Decimal(str(value)) + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +class OracleCompiler_zxjdbc(OracleCompiler): + + def returning_clause(self, stmt, returning_cols): + self.returning_cols = list(expression._select_iterables(returning_cols)) + + # within_columns_clause=False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) + for c in self.returning_cols] + + if not hasattr(self, 'returning_parameters'): + self.returning_parameters = [] + + binds = [] + for i, col in enumerate(self.returning_cols): + dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + self.returning_parameters.append((i + 1, dbtype)) + + bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype)) + self.binds[bindparam.key] = bindparam + binds.append(self.bindparam_string(self._truncate_bindparam(bindparam))) + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + +class OracleExecutionContext_zxjdbc(OracleExecutionContext): + + def pre_exec(self): + if hasattr(self.compiled, 'returning_parameters'): + # prepare a zxJDBC statement so we can grab its underlying + # OraclePreparedStatement's getReturnResultSet later + self.statement = self.cursor.prepare(self.statement) + + def get_result_proxy(self): + if hasattr(self.compiled, 'returning_parameters'): + rrs = None + try: + try: + rrs = self.statement.__statement__.getReturnResultSet() + next(rrs) + except SQLException as sqle: + msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode()) + if sqle.getSQLState() is not None: + msg += ' [SQLState: %s]' % sqle.getSQLState() + raise zxJDBC.Error(msg) + else: + row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype) + for index, dbtype in self.compiled.returning_parameters) + return ReturningResultProxy(self, row) + finally: + if rrs is not None: + try: + rrs.close() + except SQLException: + pass + self.statement.close() + + return _result.ResultProxy(self) + + def create_cursor(self): + cursor = self._dbapi_connection.cursor() + cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) + return cursor + + +class ReturningResultProxy(_result.FullyBufferedResultProxy): + + """ResultProxy backed by the RETURNING ResultSet results.""" + + def __init__(self, context, returning_row): + self._returning_row = returning_row + super(ReturningResultProxy, self).__init__(context) + + def _cursor_description(self): + ret = [] + for c in self.context.compiled.returning_cols: + if hasattr(c, 'name'): + ret.append((c.name, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + return collections.deque([self._returning_row]) + + +class ReturningParam(object): + + """A bindparam value representing a RETURNING parameter. + + Specially handled by OracleReturningDataHandler. + """ + + def __init__(self, type): + self.type = type + + def __eq__(self, other): + if isinstance(other, ReturningParam): + return self.type == other.type + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ReturningParam): + return self.type != other.type + return NotImplemented + + def __repr__(self): + kls = self.__class__ + return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self), + self.type) + + +class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): + jdbc_db_name = 'oracle' + jdbc_driver_name = 'oracle.jdbc.OracleDriver' + + statement_compiler = OracleCompiler_zxjdbc + execution_ctx_cls = OracleExecutionContext_zxjdbc + + colspecs = util.update_copy( + OracleDialect.colspecs, + { + sqltypes.Date: _ZxJDBCDate, + sqltypes.Numeric: _ZxJDBCNumeric + } + ) + + def __init__(self, *args, **kwargs): + super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs) + global SQLException, zxJDBC + from java.sql import SQLException + from com.ziclix.python.sql import zxJDBC + from com.ziclix.python.sql.handler import OracleDataHandler + + class OracleReturningDataHandler(OracleDataHandler): + """zxJDBC DataHandler that specially handles ReturningParam.""" + + def setJDBCObject(self, statement, index, object, dbtype=None): + if type(object) is ReturningParam: + statement.registerReturnParameter(index, object.type) + elif dbtype is None: + OracleDataHandler.setJDBCObject( + self, statement, index, object) + else: + OracleDataHandler.setJDBCObject( + self, statement, index, object, dbtype) + self.DataHandler = OracleReturningDataHandler + + def initialize(self, connection): + super(OracleDialect_zxjdbc, self).initialize(connection) + self.implicit_returning = connection.connection.driverversion >= '10.2' + + def _create_jdbc_url(self, url): + return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database) + + def _get_server_version_info(self, connection): + version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) + return tuple(int(x) for x in version.split('.')) + +dialect = OracleDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py new file mode 100644 index 00000000..6ed7e18b --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres.py @@ -0,0 +1,16 @@ +# dialects/postgres.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +# backwards compat with the old name +from sqlalchemy.util import warn_deprecated + +warn_deprecated( + "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. " + "The new URL format is postgresql[+driver]://:@/" + ) + +from sqlalchemy.dialects.postgresql import * +from sqlalchemy.dialects.postgresql import base diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 00000000..180e9fc7 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,29 @@ +# postgresql/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import base, psycopg2, pg8000, pypostgresql, zxjdbc + +base.dialect = psycopg2.dialect + +from .base import \ + INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ + INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \ + DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All, \ + TSVECTOR +from .constraints import ExcludeConstraint +from .hstore import HSTORE, hstore +from .json import JSON, JSONElement +from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ + TSTZRANGE + +__all__ = ( + 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', + 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', + 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', + 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE', + 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', + 'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONElement' +) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 00000000..f69a6e01 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,2367 @@ +# postgresql/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql + :name: PostgreSQL + + +Sequences/SERIAL +---------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means +of creating new primary key values for integer-based primary key columns. When +creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for +integer-based primary key columns, which generates a sequence and server side +default corresponding to the column. + +To specify a specific named sequence to be used for primary key generation, +use the :func:`~sqlalchemy.schema.Sequence` construct:: + + Table('sometable', metadata, + Column('id', Integer, Sequence('some_id_seq'), primary_key=True) + ) + +When SQLAlchemy issues a single INSERT statement, to fulfill the contract of +having the "last insert identifier" available, a RETURNING clause is added to +the INSERT statement which specifies the primary key columns should be +returned after the statement completes. The RETURNING functionality only takes +place if Postgresql 8.2 or later is in use. As a fallback approach, the +sequence, whether specified explicitly or implicitly via ``SERIAL``, is +executed independently beforehand, the returned value to be used in the +subsequent insert. Note that when an +:func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the "last inserted identifier" functionality does not +apply; no RETURNING clause is emitted nor is the sequence pre-executed in this +case. + +To force the usage of RETURNING by default off, specify the flag +``implicit_returning=False`` to :func:`.create_engine`. + +.. _postgresql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All Postgresql dialects support setting of transaction isolation level +both via a dialect-specific parameter ``isolation_level`` +accepted by :func:`.create_engine`, +as well as the ``isolation_level`` argument as passed to :meth:`.Connection.execution_options`. +When using a non-psycopg2 dialect, this feature works by issuing the +command ``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL +`` for each new connection. + +To set isolation level using :func:`.create_engine`:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options(isolation_level="READ COMMITTED") + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` + +The :mod:`~sqlalchemy.dialects.postgresql.psycopg2` dialect also offers the special level ``AUTOCOMMIT``. See +:ref:`psycopg2_isolation_level` for details. + +.. _postgresql_schema_reflection: + +Remote-Schema Table Introspection and Postgresql search_path +------------------------------------------------------------ + +The Postgresql dialect can reflect tables from any schema. The +:paramref:`.Table.schema` argument, or alternatively the +:paramref:`.MetaData.reflect.schema` argument determines which schema will +be searched for the table or tables. The reflected :class:`.Table` objects +will in all cases retain this ``.schema`` attribute as was specified. However, +with regards to tables which these :class:`.Table` objects refer to via +foreign key constraint, a decision must be made as to how the ``.schema`` +is represented in those remote tables, in the case where that remote +schema name is also a member of the current +`Postgresql search path `_. + +By default, the Postgresql dialect mimics the behavior encouraged by +Postgresql's own ``pg_get_constraintdef()`` builtin procedure. This function +returns a sample definition for a particular foreign key constraint, +omitting the referenced schema name from that definition when the name is +also in the Postgresql schema search path. The interaction below +illustrates this behavior:: + + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); + CREATE TABLE + test=> CREATE TABLE referring( + test(> id INTEGER PRIMARY KEY, + test(> referred_id INTEGER REFERENCES test_schema.referred(id)); + CREATE TABLE + test=> SET search_path TO public, test_schema; + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f' + test-> ; + pg_get_constraintdef + --------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES referred(id) + (1 row) + +Above, we created a table ``referred`` as a member of the remote schema ``test_schema``, however +when we added ``test_schema`` to the PG ``search_path`` and then asked ``pg_get_constraintdef()`` +for the ``FOREIGN KEY`` syntax, ``test_schema`` was not included in the +output of the function. + +On the other hand, if we set the search path back to the typical default +of ``public``:: + + test=> SET search_path TO public; + SET + +The same query against ``pg_get_constraintdef()`` now returns the fully +schema-qualified name for us:: + + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f'; + pg_get_constraintdef + --------------------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id) + (1 row) + +SQLAlchemy will by default use the return value of ``pg_get_constraintdef()`` +in order to determine the remote schema name. That is, if our ``search_path`` +were set to include ``test_schema``, and we invoked a table +reflection process as follows:: + + >>> from sqlalchemy import Table, MetaData, create_engine + >>> engine = create_engine("postgresql://scott:tiger@localhost/test") + >>> with engine.connect() as conn: + ... conn.execute("SET search_path TO test_schema, public") + ... meta = MetaData() + ... referring = Table('referring', meta, autoload=True, autoload_with=conn) + ... + + +The above process would deliver to the :attr:`.MetaData.tables` collection +``referred`` table named **without** the schema:: + + >>> meta.tables['referred'].schema is None + True + +To alter the behavior of reflection such that the referred schema is maintained +regardless of the ``search_path`` setting, use the ``postgresql_ignore_search_path`` +option, which can be specified as a dialect-specific argument to both +:class:`.Table` as well as :meth:`.MetaData.reflect`:: + + >>> with engine.connect() as conn: + ... conn.execute("SET search_path TO test_schema, public") + ... meta = MetaData() + ... referring = Table('referring', meta, autoload=True, autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... + + +We will now have ``test_schema.referred`` stored as schema-qualified:: + + >>> meta.tables['test_schema.referred'].schema + 'test_schema' + +.. sidebar:: Best Practices for Postgresql Schema reflection + + The description of Postgresql schema reflection behavior is complex, and is + the product of many years of dealing with widely varied use cases and user preferences. + But in fact, there's no need to understand any of it if you just stick to the simplest + use pattern: leave the ``search_path`` set to its default of ``public`` only, never refer + to the name ``public`` as an explicit schema name otherwise, and + refer to all other schema names explicitly when building + up a :class:`.Table` object. The options described here are only for those users + who can't, or prefer not to, stay within these guidelines. + +Note that **in all cases**, the "default" schema is always reflected as ``None``. +The "default" schema on Postgresql is that which is returned by the +Postgresql ``current_schema()`` function. On a typical Postgresql installation, +this is the name ``public``. So a table that refers to another which is +in the ``public`` (i.e. default) schema will always have the ``.schema`` attribute +set to ``None``. + +.. versionadded:: 0.9.2 Added the ``postgresql_ignore_search_path`` + dialect-level option accepted by :class:`.Table` and :meth:`.MetaData.reflect`. + + +.. seealso:: + + `The Schema Search Path `_ - on the Postgresql website. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and +``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default +for single-row INSERT statements in order to fetch newly generated +primary key identifiers. To specify an explicit ``RETURNING`` clause, +use the :meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + result = table.update().returning(table.c.col1, table.c.col2).\\ + where(table.c.name=='foo').values(name='bar') + print result.fetchall() + + # DELETE..RETURNING + result = table.delete().returning(table.c.col1, table.c.col2).\\ + where(table.c.name=='foo') + print result.fetchall() + +.. _postgresql_match: + +Full Text Search +---------------- + +SQLAlchemy makes available the Postgresql ``@@`` operator via the +:meth:`.ColumnElement.match` method on any textual column expression. +On a Postgresql dialect, an expression like the following:: + + select([sometable.c.text.match("search string")]) + +will emit to the database:: + + SELECT text @@ to_tsquery('search string') FROM table + +The Postgresql text search functions such as ``to_tsquery()`` +and ``to_tsvector()`` are available +explicitly using the standard :attr:`.func` construct. For example:: + + select([ + func.to_tsvector('fat cats ate rats').match('cat & rat') + ]) + +Emits the equivalent of:: + + SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat') + +The :class:`.postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select([cast("some text", TSVECTOR)]) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + + +FROM ONLY ... +------------------------ + +The dialect supports PostgreSQL's ONLY keyword for targeting only a particular +table in an inheritance hierarchy. This can be used to produce the +``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...`` +syntaxes. It uses SQLAlchemy's hints mechanism:: + + # SELECT ... FROM ONLY ... + result = table.select().with_hint(table, 'ONLY', 'postgresql') + print result.fetchall() + + # UPDATE ONLY ... + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') + + # DELETE FROM ONLY ... + table.delete().with_hint('ONLY', dialect_name='postgresql') + +.. _postgresql_indexes: + +Postgresql-Specific Index Options +--------------------------------- + +Several extensions to the :class:`.Index` construct are available, specific +to the PostgreSQL dialect. + +Partial Indexes +^^^^^^^^^^^^^^^^ + +Partial indexes add criterion to the index definition so that the index is +applied to a subset of rows. These can be specified on :class:`.Index` +using the ``postgresql_where`` keyword argument:: + + Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10) + +Operator Classes +^^^^^^^^^^^^^^^^^ + +PostgreSQL allows the specification of an *operator class* for each column of +an index (see +http://www.postgresql.org/docs/8.3/interactive/indexes-opclass.html). +The :class:`.Index` construct allows these to be specified via the +``postgresql_ops`` keyword argument:: + + Index('my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +.. versionadded:: 0.7.2 + ``postgresql_ops`` keyword argument to :class:`.Index` construct. + +Note that the keys in the ``postgresql_ops`` dictionary are the "key" name of +the :class:`.Column`, i.e. the name used to access it from the ``.c`` +collection of :class:`.Table`, which can be configured to be different than +the actual name of the column as expressed in the database. + +Index Types +^^^^^^^^^^^^ + +PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well +as the ability for users to create their own (see +http://www.postgresql.org/docs/8.3/static/indexes-types.html). These can be +specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_using='gin') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX command, so it *must* be a valid index type for your +version of PostgreSQL. + +""" +from collections import defaultdict +import re + +from ... import sql, schema, exc, util +from ...engine import default, reflection +from ...sql import compiler, expression, operators +from ... import types as sqltypes + +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None + +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ + CHAR, TEXT, FLOAT, NUMERIC, \ + DATE, BOOLEAN, REAL + +RESERVED_WORDS = set( + ["all", "analyse", "analyze", "and", "any", "array", "as", "asc", + "asymmetric", "both", "case", "cast", "check", "collate", "column", + "constraint", "create", "current_catalog", "current_date", + "current_role", "current_time", "current_timestamp", "current_user", + "default", "deferrable", "desc", "distinct", "do", "else", "end", + "except", "false", "fetch", "for", "foreign", "from", "grant", "group", + "having", "in", "initially", "intersect", "into", "leading", "limit", + "localtime", "localtimestamp", "new", "not", "null", "of", "off", "offset", + "old", "on", "only", "or", "order", "placing", "primary", "references", + "returning", "select", "session_user", "some", "symmetric", "table", + "then", "to", "trailing", "true", "union", "unique", "user", "using", + "variadic", "when", "where", "window", "with", "authorization", + "between", "binary", "cross", "current_schema", "freeze", "full", + "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", + "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" + ]) + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class BYTEA(sqltypes.LargeBinary): + __visit_name__ = 'BYTEA' + + +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = 'DOUBLE_PRECISION' + + +class INET(sqltypes.TypeEngine): + __visit_name__ = "INET" +PGInet = INET + + +class CIDR(sqltypes.TypeEngine): + __visit_name__ = "CIDR" +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine): + __visit_name__ = "MACADDR" +PGMacAddr = MACADDR + + +class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.TypeEngine): + """Postgresql INTERVAL type. + + The INTERVAL type may not be supported on all DBAPIs. + It is known to work on psycopg2 and not pg8000 or zxjdbc. + + """ + __visit_name__ = 'INTERVAL' + + def __init__(self, precision=None): + self.precision = precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + +PGBit = BIT + + +class UUID(sqltypes.TypeEngine): + """Postgresql UUID type. + + Represents the UUID column type, interpreting + data either as natively returned by the DBAPI + or as Python uuid objects. + + The UUID type may not be supported on all DBAPIs. + It is known to work on psycopg2 and not pg8000. + + """ + __visit_name__ = 'UUID' + + def __init__(self, as_uuid=False): + """Construct a UUID type. + + + :param as_uuid=False: if True, values will be interpreted + as Python uuid objects, converting to/from string via the + DBAPI. + + """ + if as_uuid and _python_UUID is None: + raise NotImplementedError( + "This version of Python does not support the native UUID type." + ) + self.as_uuid = as_uuid + + def bind_processor(self, dialect): + if self.as_uuid: + def process(value): + if value is not None: + value = util.text_type(value) + return value + return process + else: + return None + + def result_processor(self, dialect, coltype): + if self.as_uuid: + def process(value): + if value is not None: + value = _python_UUID(value) + return value + return process + else: + return None + +PGUuid = UUID + +class TSVECTOR(sqltypes.TypeEngine): + """The :class:`.postgresql.TSVECTOR` type implements the Postgresql + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + __visit_name__ = 'TSVECTOR' + + + +class _Slice(expression.ColumnElement): + __visit_name__ = 'slice' + type = sqltypes.NULLTYPE + + def __init__(self, slice_, source_comparator): + self.start = source_comparator._check_literal( + source_comparator.expr, + operators.getitem, slice_.start) + self.stop = source_comparator._check_literal( + source_comparator.expr, + operators.getitem, slice_.stop) + + +class Any(expression.ColumnElement): + """Represent the clause ``left operator ANY (right)``. ``right`` must be + an array expression. + + .. seealso:: + + :class:`.postgresql.ARRAY` + + :meth:`.postgresql.ARRAY.Comparator.any` - ARRAY-bound method + + """ + __visit_name__ = 'any' + + def __init__(self, left, right, operator=operators.eq): + self.type = sqltypes.Boolean() + self.left = expression._literal_as_binds(left) + self.right = right + self.operator = operator + + +class All(expression.ColumnElement): + """Represent the clause ``left operator ALL (right)``. ``right`` must be + an array expression. + + .. seealso:: + + :class:`.postgresql.ARRAY` + + :meth:`.postgresql.ARRAY.Comparator.all` - ARRAY-bound method + + """ + __visit_name__ = 'all' + + def __init__(self, left, right, operator=operators.eq): + self.type = sqltypes.Boolean() + self.left = expression._literal_as_binds(left) + self.right = right + self.operator = operator + + +class array(expression.Tuple): + """A Postgresql ARRAY literal. + + This is used to produce ARRAY literals in SQL expressions, e.g.:: + + from sqlalchemy.dialects.postgresql import array + from sqlalchemy.dialects import postgresql + from sqlalchemy import select, func + + stmt = select([ + array([1,2]) + array([3,4,5]) + ]) + + print stmt.compile(dialect=postgresql.dialect()) + + Produces the SQL:: + + SELECT ARRAY[%(param_1)s, %(param_2)s] || + ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 + + An instance of :class:`.array` will always have the datatype + :class:`.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: + + array(['foo', 'bar'], type_=CHAR) + + .. versionadded:: 0.8 Added the :class:`~.postgresql.array` literal type. + + See also: + + :class:`.postgresql.ARRAY` + + """ + __visit_name__ = 'array' + + def __init__(self, clauses, **kw): + super(array, self).__init__(*clauses, **kw) + self.type = ARRAY(self.type) + + def _bind_param(self, operator, obj): + return array(*[ + expression.BindParameter(None, o, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + for o in obj + ]) + + def self_group(self, against=None): + return self + + +class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): + """Postgresql ARRAY type. + + Represents values as Python lists. + + An :class:`.ARRAY` type is constructed given the "type" + of element:: + + mytable = Table("mytable", metadata, + Column("data", ARRAY(Integer)) + ) + + The above type represents an N-dimensional array, + meaning Postgresql will interpret values with any number + of dimensions automatically. To produce an INSERT + construct that passes in a 1-dimensional array of integers:: + + connection.execute( + mytable.insert(), + data=[1,2,3] + ) + + The :class:`.ARRAY` type can be constructed given a fixed number + of dimensions:: + + mytable = Table("mytable", metadata, + Column("data", ARRAY(Integer, dimensions=2)) + ) + + This has the effect of the :class:`.ARRAY` type + specifying that number of bracketed blocks when a :class:`.Table` + is used in a CREATE TABLE statement, or when the type is used + within a :func:`.expression.cast` construct; it also causes + the bind parameter and result set processing of the type + to optimize itself to expect exactly that number of dimensions. + Note that Postgresql itself still allows N dimensions with such a type. + + SQL expressions of type :class:`.ARRAY` have support for "index" and + "slice" behavior. The Python ``[]`` operator works normally here, given + integer indexes or slices. Note that Postgresql arrays default + to 1-based indexing. The operator produces binary expression + constructs which will produce the appropriate SQL, both for + SELECT statements:: + + select([mytable.c.data[5], mytable.c.data[2:7]]) + + as well as UPDATE statements when the :meth:`.Update.values` method + is used:: + + mytable.update().values({ + mytable.c.data[5]: 7, + mytable.c.data[2:7]: [1, 2, 3] + }) + + :class:`.ARRAY` provides special methods for containment operations, + e.g.:: + + mytable.c.data.contains([1, 2]) + + For a full list of special methods see :class:`.ARRAY.Comparator`. + + .. versionadded:: 0.8 Added support for index and slice operations + to the :class:`.ARRAY` type, including support for UPDATE + statements, and special array containment operations. + + The :class:`.ARRAY` type may not be supported on all DBAPIs. + It is known to work on psycopg2 and not pg8000. + + See also: + + :class:`.postgresql.array` - produce a literal array value. + + """ + __visit_name__ = 'ARRAY' + + class Comparator(sqltypes.Concatenable.Comparator): + """Define comparison operations for :class:`.ARRAY`.""" + + def __getitem__(self, index): + if isinstance(index, slice): + index = _Slice(index, self) + return_type = self.type + else: + return_type = self.type.item_type + return self._binary_operate(self.expr, operators.getitem, index, + result_type=return_type) + + def any(self, other, operator=operators.eq): + """Return ``other operator ANY (array)`` clause. + + Argument places are switched, because ANY requires array + expression to be on the right hand-side. + + E.g.:: + + from sqlalchemy.sql import operators + + conn.execute( + select([table.c.data]).where( + table.c.data.any(7, operator=operators.lt) + ) + ) + + :param other: expression to be compared + :param operator: an operator object from the + :mod:`sqlalchemy.sql.operators` + package, defaults to :func:`.operators.eq`. + + .. seealso:: + + :class:`.postgresql.Any` + + :meth:`.postgresql.ARRAY.Comparator.all` + + """ + return Any(other, self.expr, operator=operator) + + def all(self, other, operator=operators.eq): + """Return ``other operator ALL (array)`` clause. + + Argument places are switched, because ALL requires array + expression to be on the right hand-side. + + E.g.:: + + from sqlalchemy.sql import operators + + conn.execute( + select([table.c.data]).where( + table.c.data.all(7, operator=operators.lt) + ) + ) + + :param other: expression to be compared + :param operator: an operator object from the + :mod:`sqlalchemy.sql.operators` + package, defaults to :func:`.operators.eq`. + + .. seealso:: + + :class:`.postgresql.All` + + :meth:`.postgresql.ARRAY.Comparator.any` + + """ + return All(other, self.expr, operator=operator) + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.expr.op('<@')(other) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.expr.op('&&')(other) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, operators.custom_op): + if op.opstring in ['@>', '<@', '&&']: + return op, sqltypes.Boolean + return sqltypes.Concatenable.Comparator.\ + _adapt_expression(self, op, other_comparator) + + comparator_factory = Comparator + + def __init__(self, item_type, as_tuple=False, dimensions=None): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. DBAPIs such + as psycopg2 return lists by default. When tuples are + returned, the results are hashable. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This will cause the DDL emitted for this + ARRAY to include the exact number of bracket clauses ``[]``, + and will also optimize the performance of the type overall. + Note that PG arrays are always implicitly "non-dimensioned", + meaning they can store any number of dimensions no matter how + they were declared. + + """ + if isinstance(item_type, ARRAY): + raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype") + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + def _proc_array(self, arr, itemproc, dim, collection): + if dim is None: + arr = list(arr) + if dim == 1 or dim is None and ( + # this has to be (list, tuple), or at least + # not hasattr('__iter__'), since Py3K strings + # etc. have __iter__ + not arr or not isinstance(arr[0], (list, tuple))): + if itemproc: + return collection(itemproc(x) for x in arr) + else: + return collection(arr) + else: + return collection( + self._proc_array( + x, itemproc, + dim - 1 if dim is not None else None, + collection) + for x in arr + ) + + def bind_processor(self, dialect): + item_proc = self.item_type.\ + dialect_impl(dialect).\ + bind_processor(dialect) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, + item_proc, + self.dimensions, + list) + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.\ + dialect_impl(dialect).\ + result_processor(dialect, coltype) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list) + return process + +PGArray = ARRAY + + +class ENUM(sqltypes.Enum): + """Postgresql ENUM type. + + This is a subclass of :class:`.types.Enum` which includes + support for PG's ``CREATE TYPE``. + + :class:`~.postgresql.ENUM` is used automatically when + using the :class:`.types.Enum` type on PG assuming + the ``native_enum`` is left as ``True``. However, the + :class:`~.postgresql.ENUM` class can also be instantiated + directly in order to access some additional Postgresql-specific + options, namely finer control over whether or not + ``CREATE TYPE`` should be emitted. + + Note that both :class:`.types.Enum` as well as + :class:`~.postgresql.ENUM` feature create/drop + methods; the base :class:`.types.Enum` type ultimately + delegates to the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods present here. + + """ + + def __init__(self, *enums, **kw): + """Construct an :class:`~.postgresql.ENUM`. + + Arguments are the same as that of + :class:`.types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. + Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + .. versionadded:: 0.7.4 + + """ + self.create_type = kw.pop("create_type", True) + super(ENUM, self).__init__(*enums, **kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`~.postgresql.ENUM`. + + If the underlying dialect does not support + Postgresql CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`.Engine`, + :class:`.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + if not checkfirst or \ + not bind.dialect.has_type(bind, self.name, schema=self.schema): + bind.execute(CreateEnumType(self)) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`~.postgresql.ENUM`. + + If the underlying dialect does not support + Postgresql DROP TYPE, no action is taken. + + :param bind: a connectable :class:`.Engine`, + :class:`.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + if not bind.dialect.supports_native_enum: + return + + if not checkfirst or \ + bind.dialect.has_type(bind, self.name, schema=self.schema): + bind.execute(DropEnumType(self)) + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named enum is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if '_ddl_runner' in kw: + ddl_runner = kw['_ddl_runner'] + if '_pg_enums' in ddl_runner.memo: + pg_enums = ddl_runner.memo['_pg_enums'] + else: + pg_enums = ddl_runner.memo['_pg_enums'] = set() + present = self.name in pg_enums + pg_enums.add(self.name) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst, **kw): + if self.metadata is not None and \ + not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + +colspecs = { + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, +} + +ischema_names = { + 'integer': INTEGER, + 'bigint': BIGINT, + 'smallint': SMALLINT, + 'character varying': VARCHAR, + 'character': CHAR, + '"char"': sqltypes.String, + 'name': sqltypes.String, + 'text': TEXT, + 'numeric': NUMERIC, + 'float': FLOAT, + 'real': REAL, + 'inet': INET, + 'cidr': CIDR, + 'uuid': UUID, + 'bit': BIT, + 'bit varying': BIT, + 'macaddr': MACADDR, + 'double precision': DOUBLE_PRECISION, + 'timestamp': TIMESTAMP, + 'timestamp with time zone': TIMESTAMP, + 'timestamp without time zone': TIMESTAMP, + 'time with time zone': TIME, + 'time without time zone': TIME, + 'date': DATE, + 'time': TIME, + 'bytea': BYTEA, + 'boolean': BOOLEAN, + 'interval': INTERVAL, + 'interval year to month': INTERVAL, + 'interval day to second': INTERVAL, + 'tsvector' : TSVECTOR +} + + +class PGCompiler(compiler.SQLCompiler): + + def visit_array(self, element, **kw): + return "ARRAY[%s]" % self.visit_clauselist(element, **kw) + + def visit_slice(self, element, **kw): + return "%s:%s" % ( + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) + + def visit_any(self, element, **kw): + return "%s%sANY (%s)" % ( + self.process(element.left, **kw), + compiler.OPERATORS[element.operator], + self.process(element.right, **kw) + ) + + def visit_all(self, element, **kw): + return "%s%sALL (%s)" % ( + self.process(element.left, **kw), + compiler.OPERATORS[element.operator], + self.process(element.right, **kw) + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_match_op_binary(self, binary, operator, **kw): + return "%s @@ to_tsquery(%s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return '%s ILIKE %s' % \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def render_literal_value(self, value, type_): + value = super(PGCompiler, self).render_literal_value(value, type_) + + if self.dialect._backslash_escapes: + value = value.replace('\\', '\\\\') + return value + + def visit_sequence(self, seq): + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + self.process(sql.literal(select._limit)) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + self.process(sql.literal(select._offset)) + return text + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + if hint.upper() != 'ONLY': + raise exc.CompileError("Unrecognized hint: %r" % hint) + return "ONLY " + sqltext + + def get_select_precolumns(self, select): + if select._distinct is not False: + if select._distinct is True: + return "DISTINCT " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [self.process(col) for col in select._distinct] + ) + ") " + else: + return "DISTINCT ON (" + self.process(select._distinct) + ") " + else: + return "" + + def for_update_clause(self, select): + + if select._for_update_arg.read: + tmp = " FOR SHARE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tables = util.OrderedSet( + c.table if isinstance(c, expression.ColumnClause) + else c for c in select._for_update_arg.of) + tmp += " OF " + ", ".join( + self.process(table, ashint=True) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + return tmp + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in expression._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0], **kw) + start = self.process(func.clauses.clauses[1], **kw) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2], **kw) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + impl_type = column.type.dialect_impl(self.dialect) + if column.primary_key and \ + column is column.table._autoincrement_column and \ + ( + self.dialect.supports_smallserial or + not isinstance(impl_type, sqltypes.SmallInteger) + ) and ( + column.default is None or + ( + isinstance(column.default, schema.Sequence) and + column.default.optional + )): + if isinstance(impl_type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_enum_type(self, create): + type_ = create.element + + return "CREATE TYPE %s AS ENUM (%s)" % ( + self.preparer.format_type(type_), + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums) + ) + + def visit_drop_enum_type(self, drop): + type_ = drop.element + + return "DROP TYPE %s" % ( + self.preparer.format_type(type_) + ) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + self._verify_index_table(index) + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s " % ( + self._prepared_index_name(index, + include_schema=False), + preparer.format_table(index.table) + ) + + using = index.dialect_options['postgresql']['using'] + if using: + text += "USING %s " % preparer.quote(using) + + ops = index.dialect_options["postgresql"]["ops"] + text += "(%s)" \ + % ( + ', '.join([ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, literal_binds=True) + + (c.key in ops and (' ' + ops[c.key]) or '') + for expr, c in zip(index.expressions, index.columns)]) + ) + + whereclause = index.dialect_options["postgresql"]["where"] + + if whereclause is not None: + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, + literal_binds=True) + text += " WHERE " + where_compiled + return text + + def visit_exclude_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + elements = [] + for c in constraint.columns: + op = constraint.operators[c.name] + elements.append(self.preparer.quote(c.name) + ' WITH '+op) + text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements)) + if constraint.where is not None: + text += ' WHERE (%s)' % self.sql_compiler.process( + constraint.where, + literal_binds=True) + text += self.define_constraint_deferrability(constraint) + return text + + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_TSVECTOR(self, type): + return "TSVECTOR" + + def visit_INET(self, type_): + return "INET" + + def visit_CIDR(self, type_): + return "CIDR" + + def visit_MACADDR(self, type_): + return "MACADDR" + + def visit_FLOAT(self, type_): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': type_.precision} + + def visit_DOUBLE_PRECISION(self, type_): + return "DOUBLE PRECISION" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_HSTORE(self, type_): + return "HSTORE" + + def visit_JSON(self, type_): + return "JSON" + + def visit_INT4RANGE(self, type_): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_): + return "NUMRANGE" + + def visit_DATERANGE(self, type_): + return "DATERANGE" + + def visit_TSRANGE(self, type_): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_): + return "TSTZRANGE" + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_enum(self, type_): + if not type_.native_enum or not self.dialect.supports_native_enum: + return super(PGTypeCompiler, self).visit_enum(type_) + else: + return self.visit_ENUM(type_) + + def visit_ENUM(self, type_): + return self.dialect.identifier_preparer.format_type(type_) + + def visit_TIMESTAMP(self, type_): + return "TIMESTAMP%s %s" % ( + getattr(type_, 'precision', None) and "(%d)" % + type_.precision or "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + ) + + def visit_TIME(self, type_): + return "TIME%s %s" % ( + getattr(type_, 'precision', None) and "(%d)" % + type_.precision or "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + ) + + def visit_INTERVAL(self, type_): + if type_.precision is not None: + return "INTERVAL(%d)" % type_.precision + else: + return "INTERVAL" + + def visit_BIT(self, type_): + if type_.varying: + compiled = "BIT VARYING" + if type_.length is not None: + compiled += "(%d)" % type_.length + else: + compiled = "BIT(%d)" % type_.length + return compiled + + def visit_UUID(self, type_): + return "UUID" + + def visit_large_binary(self, type_): + return self.visit_BYTEA(type_) + + def visit_BYTEA(self, type_): + return "BYTEA" + + def visit_ARRAY(self, type_): + return self.process(type_.item_type) + ('[]' * (type_.dimensions + if type_.dimensions + is not None else 1)) + + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = RESERVED_WORDS + + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].\ + replace(self.escape_to_quote, self.escape_quote) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise exc.CompileError("Postgresql ENUM type requires a name.") + + name = self.quote(type_.name) + if not self.omit_schema and use_schema and type_.schema is not None: + name = self.quote_schema(type_.schema) + "." + name + return name + + +class PGInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_oid(self, table_name, schema=None): + """Return the oid from `table_name` and `schema`.""" + + return self.dialect.get_table_oid(self.bind, table_name, schema, + info_cache=self.info_cache) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" + + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq)), type_) + + def get_insert_default(self, column): + if column.primary_key and column is column.table._autoincrement_column: + if column.server_default and column.server_default.has_argument: + + # pre-execute passive defaults on primary key columns + return self._execute_scalar("select %s" % + column.server_default.arg, column.type) + + elif (column.default is None or + (column.default.is_sequence and + column.default.optional)): + + # execute the sequence associated with a SERIAL primary + # key column. for non-primary-key SERIAL, the ID just + # generates server side. + + try: + seq_name = column._postgresql_seq_name + except AttributeError: + tab = column.table.name + col = column.name + tab = tab[0:29 + max(0, (29 - len(col)))] + col = col[0:29 + max(0, (29 - len(tab)))] + name = "%s_%s_seq" % (tab, col) + column._postgresql_seq_name = seq_name = name + + sch = column.table.schema + if sch is not None: + exc = "select nextval('\"%s\".\"%s\"')" % \ + (sch, seq_name) + else: + exc = "select nextval('\"%s\"')" % \ + (seq_name, ) + + return self._execute_scalar(exc, column.type) + + return super(PGExecutionContext, self).get_insert_default(column) + + +class PGDialect(default.DefaultDialect): + name = 'postgresql' + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + supports_native_enum = True + supports_native_boolean = True + supports_smallserial = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = True + supports_empty_insert = False + supports_multivalues_insert = True + default_paramstyle = 'pyformat' + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + execution_ctx_cls = PGExecutionContext + inspector = PGInspector + isolation_level = None + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": {} + }), + (schema.Table, { + "ignore_search_path": False + }) + ] + + reflection_options = ('postgresql_ignore_search_path', ) + + _backslash_escapes = True + + def __init__(self, isolation_level=None, json_serializer=None, + json_deserializer=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer + + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (8, 2) and \ + self.__dict__.get('implicit_returning', True) + self.supports_native_enum = self.server_version_info >= (8, 3) + if not self.supports_native_enum: + self.colspecs = self.colspecs.copy() + # pop base Enum type + self.colspecs.pop(sqltypes.Enum, None) + # psycopg2, others may have placed ENUM here as well + self.colspecs.pop(ENUM, None) + + # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + self._backslash_escapes = self.server_version_info < (8, 2) or \ + connection.scalar( + "show standard_conforming_strings" + ) == 'off' + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + _isolation_lookup = set(['SERIALIZABLE', + 'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ']) + + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('show transaction isolation level') + val = cursor.fetchone()[0] + cursor.close() + return val.upper() + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase(self, connection, xid, + is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction + # context when committing recoverable transactions + # Must find out a way how to make the dbapi not + # open a transaction. + connection.execute("ROLLBACK") + connection.execute("ROLLBACK PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, + is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute("ROLLBACK") + connection.execute("COMMIT PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute( + sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def _get_default_schema_name(self, connection): + return connection.scalar("select current_schema()") + + def has_schema(self, connection, schema): + query = "select nspname from pg_namespace where lower(nspname)=:schema" + cursor = connection.execute( + sql.text( + query, + bindparams=[ + sql.bindparam( + 'schema', util.text_type(schema.lower()), + type_=sqltypes.Unicode)] + ) + ) + + return bool(cursor.first()) + + def has_table(self, connection, table_name, schema=None): + # seems like case gets folded in pg_class... + if schema is None: + cursor = connection.execute( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=current_schema() and " + "relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(table_name), + type_=sqltypes.Unicode)] + ) + ) + else: + cursor = connection.execute( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and " + "relname=:name", + bindparams=[ + sql.bindparam('name', + util.text_type(table_name), type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), type_=sqltypes.Unicode)] + ) + ) + return bool(cursor.first()) + + def has_sequence(self, connection, sequence_name, schema=None): + if schema is None: + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=current_schema() " + "and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode) + ] + ) + ) + else: + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), type_=sqltypes.Unicode) + ] + ) + ) + + return bool(cursor.first()) + + def has_type(self, connection, type_name, schema=None): + if schema is not None: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n + WHERE t.typnamespace = n.oid + AND t.typname = :typname + AND n.nspname = :nspname + ) + """ + query = sql.text(query) + else: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t + WHERE t.typname = :typname + AND pg_type_is_visible(t.oid) + ) + """ + query = sql.text(query) + query = query.bindparams( + sql.bindparam('typname', + util.text_type(type_name), type_=sqltypes.Unicode), + ) + if schema is not None: + query = query.bindparams( + sql.bindparam('nspname', + util.text_type(schema), type_=sqltypes.Unicode), + ) + cursor = connection.execute(query) + return bool(cursor.scalar()) + + def _get_server_version_info(self, connection): + v = connection.execute("select version()").scalar() + m = re.match( + '.*(?:PostgreSQL|EnterpriseDB) ' + '(\d+)\.(\d+)(?:\.(\d+))?(?:\.\d+)?(?:devel)?', + v) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % v) + return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + table_oid = None + if schema is not None: + schema_where_clause = "n.nspname = :schema" + else: + schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" + query = """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + """ % schema_where_clause + # Since we're binding to unicode, table_name and schema_name must be + # unicode. + table_name = util.text_type(table_name) + if schema is not None: + schema = util.text_type(schema) + s = sql.text(query).bindparams(table_name=sqltypes.Unicode) + s = s.columns(oid=sqltypes.Integer) + if schema: + s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode)) + c = connection.execute(s, table_name=table_name, schema=schema) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = """ + SELECT nspname + FROM pg_namespace + ORDER BY nspname + """ + rp = connection.execute(s) + # what about system tables? + + if util.py2k: + schema_names = [row[0].decode(self.encoding) for row in rp \ + if not row[0].startswith('pg_')] + else: + schema_names = [row[0] for row in rp \ + if not row[0].startswith('pg_')] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + result = connection.execute( + sql.text("SELECT relname FROM pg_class c " + "WHERE relkind = 'r' " + "AND '%s' = (select nspname from pg_namespace n " + "where n.oid = c.relnamespace) " % + current_schema, + typemap={'relname': sqltypes.Unicode} + ) + ) + return [row[0] for row in result] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + s = """ + SELECT relname + FROM pg_class c + WHERE relkind = 'v' + AND '%(schema)s' = (select nspname from pg_namespace n + where n.oid = c.relnamespace) + """ % dict(schema=current_schema) + + if util.py2k: + view_names = [row[0].decode(self.encoding) + for row in connection.execute(s)] + else: + view_names = [row[0] for row in connection.execute(s)] + return view_names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + """ + rp = connection.execute(sql.text(s), + view_name=view_name, schema=current_schema) + if rp: + if util.py2k: + view_def = rp.scalar().decode(self.encoding) + else: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) + FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum + AND a.atthasdef) + AS DEFAULT, + a.attnotnull, a.attnum, a.attrelid as table_oid + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ + s = sql.text(SQL_COLS, + bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={'attname': sqltypes.Unicode, 'default': sqltypes.Unicode} + ) + c = connection.execute(s, table_oid=table_oid) + rows = c.fetchall() + domains = self._load_domains(connection) + enums = self._load_enums(connection) + + # format columns + columns = [] + for name, format_type, default, notnull, attnum, table_oid in rows: + column_info = self._get_column_info( + name, format_type, default, notnull, domains, enums, schema) + columns.append(column_info) + return columns + + def _get_column_info(self, name, format_type, default, + notnull, domains, enums, schema): + ## strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = re.sub(r'\(.*\)', '', format_type) + + # strip '[]' from integer[], etc. + attype = re.sub(r'\[\]', '', attype) + + nullable = not notnull + is_array = format_type.endswith('[]') + charlen = re.search('\(([\d,]+)\)', format_type) + if charlen: + charlen = charlen.group(1) + args = re.search('\((.*)\)', format_type) + if args and args.group(1): + args = tuple(re.split('\s*,\s*', args.group(1))) + else: + args = () + kwargs = {} + + if attype == 'numeric': + if charlen: + prec, scale = charlen.split(',') + args = (int(prec), int(scale)) + else: + args = () + elif attype == 'double precision': + args = (53, ) + elif attype == 'integer': + args = () + elif attype in ('timestamp with time zone', + 'time with time zone'): + kwargs['timezone'] = True + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype in ('timestamp without time zone', + 'time without time zone', 'time'): + kwargs['timezone'] = False + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype == 'bit varying': + kwargs['varying'] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype in ('interval', 'interval year to month', + 'interval day to second'): + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif charlen: + args = (int(charlen),) + + while True: + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif attype in enums: + enum = enums[attype] + coltype = ENUM + if "." in attype: + kwargs['schema'], kwargs['name'] = attype.split('.') + else: + kwargs['name'] = attype + args = tuple(enum['labels']) + break + elif attype in domains: + domain = domains[attype] + attype = domain['attype'] + # A table can't override whether the domain is nullable. + nullable = domain['nullable'] + if domain['default'] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain['default'] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = ARRAY(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (attype, name)) + coltype = sqltypes.NULLTYPE + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + autoincrement = True + # the default is related to a Sequence + sch = schema + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = match.group(1) + \ + ('"%s"' % sch) + '.' + \ + match.group(2) + match.group(3) + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + return column_info + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + if self.server_version_info < (8, 4): + PK_SQL = """ + SELECT a.attname + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_attribute a + on t.oid=a.attrelid AND %s + WHERE + t.oid = :table_oid and ix.indisprimary = 't' + ORDER BY a.attnum + """ % self._pg_index_any("a.attnum", "ix.indkey") + + else: + # unnest() and generate_subscripts() both introduced in + # version 8.4 + PK_SQL = """ + SELECT a.attname + FROM pg_attribute a JOIN ( + SELECT unnest(ix.indkey) attnum, + generate_subscripts(ix.indkey, 1) ord + FROM pg_index ix + WHERE ix.indrelid = :table_oid AND ix.indisprimary + ) k ON a.attnum=k.attnum + WHERE a.attrelid = :table_oid + ORDER BY k.ord + """ + t = sql.text(PK_SQL, typemap={'attname': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + cols = [r[0] for r in c.fetchall()] + + PK_CONS_SQL = """ + SELECT conname + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table_oid AND r.contype = 'p' + ORDER BY 1 + """ + t = sql.text(PK_CONS_SQL, typemap={'conname': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + name = c.scalar() + + return {'constrained_columns': cols, 'name': name} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, + postgresql_ignore_search_path=False, **kw): + preparer = self.identifier_preparer + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + FK_SQL = """ + SELECT r.conname, + pg_catalog.pg_get_constraintdef(r.oid, true) as condef, + n.nspname as conschema + FROM pg_catalog.pg_constraint r, + pg_namespace n, + pg_class c + + WHERE r.conrelid = :table AND + r.contype = 'f' AND + c.oid = confrelid AND + n.oid = c.relnamespace + ORDER BY 1 + """ + # http://www.postgresql.org/docs/9.0/static/sql-createtable.html + FK_REGEX = re.compile( + r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)' + r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?' + r'[\s]?(ON UPDATE (CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(ON DELETE (CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?' + r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?' + ) + + t = sql.text(FK_SQL, typemap={ + 'conname': sqltypes.Unicode, + 'condef': sqltypes.Unicode}) + c = connection.execute(t, table=table_oid) + fkeys = [] + for conname, condef, conschema in c.fetchall(): + m = re.search(FK_REGEX, condef).groups() + + constrained_columns, referred_schema, \ + referred_table, referred_columns, \ + _, match, _, onupdate, _, ondelete, \ + deferrable, _, initially = m + + if deferrable is not None: + deferrable = True if deferrable == 'DEFERRABLE' else False + constrained_columns = [preparer._unquote_identifier(x) + for x in re.split(r'\s*,\s*', constrained_columns)] + + if postgresql_ignore_search_path: + # when ignoring search path, we use the actual schema + # provided it isn't the "default" schema + if conschema != self.default_schema_name: + referred_schema = conschema + else: + referred_schema = schema + elif referred_schema: + # referred_schema is the schema that we regexp'ed from + # pg_get_constraintdef(). If the schema is in the search + # path, pg_get_constraintdef() will give us None. + referred_schema = \ + preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == conschema: + # If the actual schema matches the schema of the table + # we're reflecting, then we will use that. + referred_schema = schema + + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) + for x in re.split(r'\s*,\s', referred_columns)] + fkey_d = { + 'name': conname, + 'constrained_columns': constrained_columns, + 'referred_schema': referred_schema, + 'referred_table': referred_table, + 'referred_columns': referred_columns, + 'options': { + 'onupdate': onupdate, + 'ondelete': ondelete, + 'deferrable': deferrable, + 'initially': initially, + 'match': match + } + } + fkeys.append(fkey_d) + return fkeys + + def _pg_index_any(self, col, compare_to): + if self.server_version_info < (8, 1): + # http://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us + # "In CVS tip you could replace this with "attnum = ANY (indkey)". + # Unfortunately, most array support doesn't work on int2vector in + # pre-8.1 releases, so I think you're kinda stuck with the above + # for now. + # regards, tom lane" + return "(%s)" % " OR ".join( + "%s[%d] = %s" % (compare_to, ind, col) + for ind in range(0, 10) + ) + else: + return "%s = ANY(%s)" % (col, compare_to) + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + # cast indkey as varchar since it's an int2vector, + # returned as a list by some drivers such as pypostgresql + + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, ix.indkey%s + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid=ix.indexrelid + left outer join + pg_attribute a + on t.oid=a.attrelid and %s + WHERE + t.relkind = 'r' + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ % ( + # version 8.3 here was based on observing the + # cast does not work in PG 8.2.4, does work in 8.3.0. + # nothing in PG changelogs regarding this. + "::varchar" if self.server_version_info >= (8, 3) else "", + self._pg_index_any("a.attnum", "ix.indkey") + ) + + t = sql.text(IDX_SQL, typemap={'attname': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + + indexes = defaultdict(lambda: defaultdict(dict)) + + sv_idx_name = None + for row in c.fetchall(): + idx_name, unique, expr, prd, col, col_num, idx_key = row + + if expr: + if idx_name != sv_idx_name: + util.warn( + "Skipped unsupported reflection of " + "expression-based index %s" + % idx_name) + sv_idx_name = idx_name + continue + + if prd and not idx_name == sv_idx_name: + util.warn( + "Predicate of partial index %s ignored during reflection" + % idx_name) + sv_idx_name = idx_name + + index = indexes[idx_name] + if col is not None: + index['cols'][col_num] = col + index['key'] = [int(k.strip()) for k in idx_key.split()] + index['unique'] = unique + + return [ + {'name': name, + 'unique': idx['unique'], + 'column_names': [idx['cols'][i] for i in idx['key']]} + for name, idx in indexes.items() + ] + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + UNIQUE_SQL = """ + SELECT + cons.conname as name, + cons.conkey as key, + a.attnum as col_num, + a.attname as col_name + FROM + pg_catalog.pg_constraint cons + join pg_attribute a + on cons.conrelid = a.attrelid AND a.attnum = ANY(cons.conkey) + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'u' + """ + + t = sql.text(UNIQUE_SQL, typemap={'col_name': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + + uniques = defaultdict(lambda: defaultdict(dict)) + for row in c.fetchall(): + uc = uniques[row.name] + uc["key"] = row.key + uc["cols"][row.col_num] = row.col_name + + return [ + {'name': name, + 'column_names': [uc["cols"][i] for i in uc["key"]]} + for name, uc in uniques.items() + ] + + def _load_enums(self, connection): + if not self.supports_native_enum: + return {} + + ## Load data types for enums: + SQL_ENUMS = """ + SELECT t.typname as "name", + -- no enum defaults in 8.4 at least + -- t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema", + e.enumlabel as "label" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid + WHERE t.typtype = 'e' + ORDER BY "name", e.oid -- e.oid gives us label order + """ + + s = sql.text(SQL_ENUMS, typemap={ + 'attname': sqltypes.Unicode, + 'label': sqltypes.Unicode}) + c = connection.execute(s) + + enums = {} + for enum in c.fetchall(): + if enum['visible']: + # 'visible' just means whether or not the enum is in a + # schema that's on the search path -- or not overridden by + # a schema with higher precedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + name = enum['name'] + else: + name = "%s.%s" % (enum['schema'], enum['name']) + + if name in enums: + enums[name]['labels'].append(enum['label']) + else: + enums[name] = { + 'labels': [enum['label']], + } + + return enums + + def _load_domains(self, connection): + ## Load data types for domains: + SQL_DOMAINS = """ + SELECT t.typname as "name", + pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", + not t.typnotnull as "nullable", + t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE t.typtype = 'd' + """ + + s = sql.text(SQL_DOMAINS, typemap={'attname': sqltypes.Unicode}) + c = connection.execute(s) + + domains = {} + for domain in c.fetchall(): + ## strip (30) from character varying(30) + attype = re.search('([^\(]+)', domain['attype']).group(1) + if domain['visible']: + # 'visible' just means whether or not the domain is in a + # schema that's on the search path -- or not overridden by + # a schema with higher precedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + name = domain['name'] + else: + name = "%s.%s" % (domain['schema'], domain['name']) + + domains[name] = { + 'attype': attype, + 'nullable': domain['nullable'], + 'default': domain['default'] + } + + return domains diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py new file mode 100644 index 00000000..f45cef1a --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/constraints.py @@ -0,0 +1,73 @@ +# Copyright (C) 2013-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from sqlalchemy.schema import ColumnCollectionConstraint +from sqlalchemy.sql import expression + +class ExcludeConstraint(ColumnCollectionConstraint): + """A table-level EXCLUDE constraint. + + Defines an EXCLUDE constraint as described in the `postgres + documentation`__. + + __ http://www.postgresql.org/docs/9.0/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE + """ + + __visit_name__ = 'exclude_constraint' + + where = None + + def __init__(self, *elements, **kw): + """ + :param \*elements: + A sequence of two tuples of the form ``(column, operator)`` where + column must be a column name or Column object and operator must + be a string containing the operator to use. + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param using: + Optional string. If set, emit USING when issuing DDL + for this constraint. Defaults to 'gist'. + + :param where: + Optional string. If set, emit WHERE when issuing DDL + for this constraint. + + """ + ColumnCollectionConstraint.__init__( + self, + *[col for col, op in elements], + name=kw.get('name'), + deferrable=kw.get('deferrable'), + initially=kw.get('initially') + ) + self.operators = {} + for col_or_string, op in elements: + name = getattr(col_or_string, 'name', col_or_string) + self.operators[name] = op + self.using = kw.get('using', 'gist') + where = kw.get('where') + if where: + self.where = expression._literal_as_text(where) + + def copy(self, **kw): + elements = [(col, self.operators[col]) + for col in self.columns.keys()] + c = self.__class__(*elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially) + c.dispatch._update(self.dispatch) + return c + diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py new file mode 100644 index 00000000..76562088 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -0,0 +1,369 @@ +# postgresql/hstore.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import re + +from .base import ARRAY, ischema_names +from ... import types as sqltypes +from ...sql import functions as sqlfunc +from ...sql.operators import custom_op +from ... import util + +__all__ = ('HSTORE', 'hstore') + +# My best guess at the parsing rules of hstore literals, since no formal +# grammar is given. This is mostly reverse engineered from PG's input parser +# behavior. +HSTORE_PAIR_RE = re.compile(r""" +( + "(?P (\\ . | [^"])* )" # Quoted key +) +[ ]* => [ ]* # Pair operator, optional adjoining whitespace +( + (?P NULL ) # NULL value + | "(?P (\\ . | [^"])* )" # Quoted value +) +""", re.VERBOSE) + +HSTORE_DELIMITER_RE = re.compile(r""" +[ ]* , [ ]* +""", re.VERBOSE) + + +def _parse_error(hstore_str, pos): + """format an unmarshalling error.""" + + ctx = 20 + hslen = len(hstore_str) + + parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] + residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] + + if len(parsed_tail) > ctx: + parsed_tail = '[...]' + parsed_tail[1:] + if len(residual) > ctx: + residual = residual[:-1] + '[...]' + + return "After %r, could not parse residual at position %d: %r" % ( + parsed_tail, pos, residual) + + +def _parse_hstore(hstore_str): + """Parse an hstore from it's literal string representation. + + Attempts to approximate PG's hstore input parsing rules as closely as + possible. Although currently this is not strictly necessary, since the + current implementation of hstore's output syntax is stricter than what it + accepts as input, the documentation makes no guarantees that will always + be the case. + + + + """ + result = {} + pos = 0 + pair_match = HSTORE_PAIR_RE.match(hstore_str) + + while pair_match is not None: + key = pair_match.group('key').replace(r'\"', '"').replace("\\\\", "\\") + if pair_match.group('value_null'): + value = None + else: + value = pair_match.group('value').replace(r'\"', '"').replace("\\\\", "\\") + result[key] = value + + pos += pair_match.end() + + delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) + if delim_match is not None: + pos += delim_match.end() + + pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) + + if pos != len(hstore_str): + raise ValueError(_parse_error(hstore_str, pos)) + + return result + + +def _serialize_hstore(val): + """Serialize a dictionary into an hstore literal. Keys and values must + both be strings (except None for values). + + """ + def esc(s, position): + if position == 'value' and s is None: + return 'NULL' + elif isinstance(s, util.string_types): + return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"') + else: + raise ValueError("%r in %s position is not a string." % + (s, position)) + + return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) + for k, v in val.items()) + + +class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine): + """Represent the Postgresql HSTORE type. + + The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', HSTORE) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.HSTORE` provides for a wide range of operations, including: + + * Index operations:: + + data_table.c.data['some key'] == 'some value' + + * Containment operations:: + + data_table.c.data.has_key('some key') + + data_table.c.data.has_all(['one', 'two', 'three']) + + * Concatenation:: + + data_table.c.data + {"k1": "v1"} + + For a full list of special methods see :class:`.HSTORE.comparator_factory`. + + For usage with the SQLAlchemy ORM, it may be desirable to combine + the usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary + now part of the :mod:`sqlalchemy.ext.mutable` + extension. This extension will allow "in-place" changes to the + dictionary, e.g. addition of new keys or replacement/removal of existing + keys to/from the current dictionary, to produce events which will be detected + by the unit of work:: + + from sqlalchemy.ext.mutable import MutableDict + + class MyClass(Base): + __tablename__ = 'data_table' + + id = Column(Integer, primary_key=True) + data = Column(MutableDict.as_mutable(HSTORE)) + + my_object = session.query(MyClass).one() + + # in-place mutation, requires Mutable extension + # in order for the ORM to detect + my_object.data['some_key'] = 'some value' + + session.commit() + + When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM + will not be alerted to any changes to the contents of an existing dictionary, + unless that dictionary value is re-assigned to the HSTORE-attribute itself, + thus generating a change event. + + .. versionadded:: 0.8 + + .. seealso:: + + :class:`.hstore` - render the Postgresql ``hstore()`` function. + + + """ + + __visit_name__ = 'HSTORE' + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for :class:`.HSTORE`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.expr.op('?')(other) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in the PG + array. + """ + return self.expr.op('?&')(other) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in the PG + array. + """ + return self.expr.op('?|')(other) + + def defined(self, key): + """Boolean expression. Test for presence of a non-NULL value for + the key. Note that the key may be a SQLA expression. + """ + return _HStoreDefinedFunction(self.expr, key) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys are a superset of the keys of + the argument hstore expression. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument hstore expression. + """ + return self.expr.op('<@')(other) + + def __getitem__(self, other): + """Text expression. Get the value at a given key. Note that the + key may be a SQLA expression. + """ + return self.expr.op('->', precedence=5)(other) + + def delete(self, key): + """HStore expression. Returns the contents of this hstore with the + given key deleted. Note that the key may be a SQLA expression. + """ + if isinstance(key, dict): + key = _serialize_hstore(key) + return _HStoreDeleteFunction(self.expr, key) + + def slice(self, array): + """HStore expression. Returns a subset of an hstore defined by + array of keys. + """ + return _HStoreSliceFunction(self.expr, array) + + def keys(self): + """Text array expression. Returns array of keys.""" + return _HStoreKeysFunction(self.expr) + + def vals(self): + """Text array expression. Returns array of values.""" + return _HStoreValsFunction(self.expr) + + def array(self): + """Text array expression. Returns array of alternating keys and + values. + """ + return _HStoreArrayFunction(self.expr) + + def matrix(self): + """Text array expression. Returns array of [key, value] pairs.""" + return _HStoreMatrixFunction(self.expr) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, custom_op): + if op.opstring in ['?', '?&', '?|', '@>', '<@']: + return op, sqltypes.Boolean + elif op.opstring == '->': + return op, sqltypes.Text + return sqltypes.Concatenable.Comparator.\ + _adapt_expression(self, op, other_comparator) + + def bind_processor(self, dialect): + if util.py2k: + encoding = dialect.encoding + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value).encode(encoding) + else: + return value + else: + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value) + else: + return value + return process + + def result_processor(self, dialect, coltype): + if util.py2k: + encoding = dialect.encoding + def process(value): + if value is not None: + return _parse_hstore(value.decode(encoding)) + else: + return value + else: + def process(value): + if value is not None: + return _parse_hstore(value) + else: + return value + return process + + +ischema_names['hstore'] = HSTORE + + +class hstore(sqlfunc.GenericFunction): + """Construct an hstore value within a SQL expression using the + Postgresql ``hstore()`` function. + + The :class:`.hstore` function accepts one or two arguments as described + in the Postgresql documentation. + + E.g.:: + + from sqlalchemy.dialects.postgresql import array, hstore + + select([hstore('key1', 'value1')]) + + select([ + hstore( + array(['key1', 'key2', 'key3']), + array(['value1', 'value2', 'value3']) + ) + ]) + + .. versionadded:: 0.8 + + .. seealso:: + + :class:`.HSTORE` - the Postgresql ``HSTORE`` datatype. + + """ + type = HSTORE + name = 'hstore' + + +class _HStoreDefinedFunction(sqlfunc.GenericFunction): + type = sqltypes.Boolean + name = 'defined' + + +class _HStoreDeleteFunction(sqlfunc.GenericFunction): + type = HSTORE + name = 'delete' + + +class _HStoreSliceFunction(sqlfunc.GenericFunction): + type = HSTORE + name = 'slice' + + +class _HStoreKeysFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'akeys' + + +class _HStoreValsFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'avals' + + +class _HStoreArrayFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'hstore_to_array' + + +class _HStoreMatrixFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = 'hstore_to_matrix' diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py new file mode 100644 index 00000000..2e29185e --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -0,0 +1,199 @@ +# postgresql/json.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from __future__ import absolute_import + +import json + +from .base import ischema_names +from ... import types as sqltypes +from ...sql.operators import custom_op +from ... import sql +from ...sql import elements +from ... import util + +__all__ = ('JSON', 'JSONElement') + + +class JSONElement(elements.BinaryExpression): + """Represents accessing an element of a :class:`.JSON` value. + + The :class:`.JSONElement` is produced whenever using the Python index + operator on an expression that has the type :class:`.JSON`:: + + expr = mytable.c.json_data['some_key'] + + The expression typically compiles to a JSON access such as ``col -> key``. + Modifiers are then available for typing behavior, including :meth:`.JSONElement.cast` + and :attr:`.JSONElement.astext`. + + """ + def __init__(self, left, right, astext=False, opstring=None, result_type=None): + self._astext = astext + if opstring is None: + if hasattr(right, '__iter__') and \ + not isinstance(right, util.string_types): + opstring = "#>" + right = "{%s}" % (", ".join(util.text_type(elem) for elem in right)) + else: + opstring = "->" + + self._json_opstring = opstring + operator = custom_op(opstring, precedence=5) + right = left._check_literal(left, operator, right) + super(JSONElement, self).__init__(left, right, operator, type_=result_type) + + @property + def astext(self): + """Convert this :class:`.JSONElement` to use the 'astext' operator + when evaluated. + + E.g.:: + + select([data_table.c.data['some key'].astext]) + + .. seealso:: + + :meth:`.JSONElement.cast` + + """ + if self._astext: + return self + else: + return JSONElement( + self.left, + self.right, + astext=True, + opstring=self._json_opstring + ">", + result_type=sqltypes.String(convert_unicode=True) + ) + + def cast(self, type_): + """Convert this :class:`.JSONElement` to apply both the 'astext' operator + as well as an explicit type cast when evaulated. + + E.g.:: + + select([data_table.c.data['some key'].cast(Integer)]) + + .. seealso:: + + :attr:`.JSONElement.astext` + + """ + if not self._astext: + return self.astext.cast(type_) + else: + return sql.cast(self, type_) + + +class JSON(sqltypes.TypeEngine): + """Represent the Postgresql JSON type. + + The :class:`.JSON` type stores arbitrary JSON format data, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSON) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.JSON` provides several operations: + + * Index operations:: + + data_table.c.data['some key'] + + * Index operations returning text (required for text comparison):: + + data_table.c.data['some key'].astext == 'some value' + + * Index operations with a built-in CAST call:: + + data_table.c.data['some key'].cast(Integer) == 5 + + * Path index operations:: + + data_table.c.data[('key_1', 'key_2', ..., 'key_n')] + + * Path index operations returning text (required for text comparison):: + + data_table.c.data[('key_1', 'key_2', ..., 'key_n')].astext == 'some value' + + Index operations return an instance of :class:`.JSONElement`, which represents + an expression such as ``column -> index``. This element then defines + methods such as :attr:`.JSONElement.astext` and :meth:`.JSONElement.cast` + for setting up type behavior. + + The :class:`.JSON` type, when used with the SQLAlchemy ORM, does not detect + in-place mutations to the structure. In order to detect these, the + :mod:`sqlalchemy.ext.mutable` extension must be used. This extension will + allow "in-place" changes to the datastructure to produce events which + will be detected by the unit of work. See the example at :class:`.HSTORE` + for a simple example involving a dictionary. + + Custom serializers and deserializers are specified at the dialect level, + that is using :func:`.create_engine`. The reason for this is that when + using psycopg2, the DBAPI only allows serializers at the per-cursor + or per-connection level. E.g.:: + + engine = create_engine("postgresql://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) + + When using the psycopg2 dialect, the json_deserializer is registered + against the database using ``psycopg2.extras.register_default_json``. + + .. versionadded:: 0.9 + + """ + + __visit_name__ = 'JSON' + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for :class:`.JSON`.""" + + def __getitem__(self, other): + """Get the value at a given key.""" + + return JSONElement(self.expr, other) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, custom_op): + if op.opstring == '->': + return op, sqltypes.Text + return sqltypes.Concatenable.Comparator.\ + _adapt_expression(self, op, other_comparator) + + def bind_processor(self, dialect): + json_serializer = dialect._json_serializer or json.dumps + if util.py2k: + encoding = dialect.encoding + def process(value): + return json_serializer(value).encode(encoding) + else: + def process(value): + return json_serializer(value) + return process + + def result_processor(self, dialect, coltype): + json_deserializer = dialect._json_deserializer or json.loads + if util.py2k: + encoding = dialect.encoding + def process(value): + return json_deserializer(value.decode(encoding)) + else: + def process(value): + return json_deserializer(value) + return process + + +ischema_names['json'] = JSON diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 00000000..bc73f975 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,126 @@ +# postgresql/pg8000.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql+pg8000 + :name: pg8000 + :dbapi: pg8000 + :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] + :url: http://pybrary.net/pg8000/ + +Unicode +------- + +pg8000 requires that the postgresql client encoding be +configured in the postgresql.conf file in order to use encodings +other than ascii. Set this value to the same value as the +"encoding" parameter on create_engine(), usually "utf-8". + +Interval +-------- + +Passing data from/to the Interval type is not supported as of +yet. + +""" +from ... import util, exc +import decimal +from ... import processors +from ... import types as sqltypes +from .base import PGDialect, \ + PGCompiler, PGIdentifierPreparer, PGExecutionContext,\ + _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES + + +class _PGNumeric(sqltypes.Numeric): + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, + self._effective_decimal_return_scale) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + + +class _PGNumericNoBind(_PGNumeric): + def bind_processor(self, dialect): + return None + + +class PGExecutionContext_pg8000(PGExecutionContext): + pass + + +class PGCompiler_pg8000(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgresql dialect " + "now automatically escapes '%' in text() " + "expressions to '%%'.") + return text.replace('%', '%%') + + +class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') + + +class PGDialect_pg8000(PGDialect): + driver = 'pg8000' + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = 'format' + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_pg8000 + statement_compiler = PGCompiler_pg8000 + preparer = PGIdentifierPreparer_pg8000 + description_encoding = 'use_encoding' + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PGNumericNoBind, + sqltypes.Float: _PGNumeric + } + ) + + @classmethod + def dbapi(cls): + return __import__('pg8000').dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + return "connection is closed" in str(e) + +dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 00000000..ac177062 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,515 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql+psycopg2 + :name: psycopg2 + :dbapi: psycopg2 + :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] + :url: http://pypi.python.org/pypi/psycopg2/ + +psycopg2 Connect Arguments +----------------------------------- + +psycopg2-specific keyword arguments which are accepted by +:func:`.create_engine()` are: + +* ``server_side_cursors``: Enable the usage of "server side cursors" for SQL + statements which support this feature. What this essentially means from a + psycopg2 point of view is that the cursor is created using a name, e.g. + ``connection.cursor('some name')``, which has the effect that result rows are + not immediately pre-fetched and buffered after statement execution, but are + instead left on the server and only retrieved as needed. SQLAlchemy's + :class:`~sqlalchemy.engine.ResultProxy` uses special row-buffering + behavior when this feature is enabled, such that groups of 100 rows at a + time are fetched over the wire to reduce conversational overhead. + Note that the ``stream_results=True`` execution option is a more targeted + way of enabling this mode on a per-execution basis. +* ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode + per connection. True by default. +* ``isolation_level``: This option, available for all Posgtresql dialects, + includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 + dialect. See :ref:`psycopg2_isolation_level`. + + +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +See also: + +`PQconnectdbParams `_ + +Per-Statement/Connection Execution Options +------------------------------------------- + +The following DBAPI-specific options are respected when used with +:meth:`.Connection.execution_options`, :meth:`.Executable.execution_options`, +:meth:`.Query.execution_options`, in addition to those not specific to DBAPIs: + +* isolation_level - Set the transaction isolation level for the lifespan of a + :class:`.Connection` (can only be set on a connection, not a statement + or query). See :ref:`psycopg2_isolation_level`. + +* stream_results - Enable or disable usage of psycopg2 server side cursors - + this feature makes use of "named" cursors in combination with special + result handling methods so that result rows are not fully buffered. + If ``None`` or not set, the ``server_side_cursors`` option of the + :class:`.Engine` is used. + +Unicode +------- + +By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE`` +extension, such that the DBAPI receives and returns all strings as Python +Unicode objects directly - SQLAlchemy passes these values through without +change. Psycopg2 here will encode/decode string values based on the +current "client encoding" setting; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +A second way to affect the client encoding is to set it within Psycopg2 +locally. SQLAlchemy will call psycopg2's ``set_client_encoding()`` +method (see: http://initd.org/psycopg/docs/connection.html#connection.set_client_encoding) +on all new connections based on the value passed to +:func:`.create_engine` using the ``client_encoding`` parameter:: + + engine = create_engine("postgresql://user:pass@host/dbname", client_encoding='utf8') + +This overrides the encoding specified in the Postgresql client configuration. + +.. versionadded:: 0.7.3 + The psycopg2-specific ``client_encoding`` parameter to + :func:`.create_engine`. + +SQLAlchemy can also be instructed to skip the usage of the psycopg2 +``UNICODE`` extension and to instead utilize it's own unicode encode/decode +services, which are normally reserved only for those DBAPIs that don't +fully support unicode directly. Passing ``use_native_unicode=False`` to +:func:`.create_engine` will disable usage of ``psycopg2.extensions.UNICODE``. +SQLAlchemy will instead encode data itself into Python bytestrings on the way +in and coerce from bytes on the way back, +using the value of the :func:`.create_engine` ``encoding`` parameter, which +defaults to ``utf-8``. +SQLAlchemy's own unicode encode/decode functionality is steadily becoming +obsolete as more DBAPIs support unicode fully along with the approach of +Python 3; in modern usage psycopg2 should be relied upon to handle unicode. + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + +.. _psycopg2_isolation_level: + +Psycopg2 Transaction Isolation Level +------------------------------------- + +As discussed in :ref:`postgresql_isolation_level`, +all Postgresql dialects support setting of transaction isolation level +both via the ``isolation_level`` parameter passed to :func:`.create_engine`, +as well as the ``isolation_level`` argument used by :meth:`.Connection.execution_options`. +When using the psycopg2 dialect, these options make use of +psycopg2's ``set_isolation_level()`` connection method, rather than +emitting a Postgresql directive; this is because psycopg2's API-level +setting is always emitted at the start of each transaction in any case. + +The psycopg2 dialect supports these constants for isolation level: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using + psycopg2. + + +NOTICE logging +--------------- + +The psycopg2 dialect will log Postgresql NOTICE messages via the +``sqlalchemy.dialects.postgresql`` logger:: + + import logging + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +.. _psycopg2_hstore:: + +HSTORE type +------------ + +The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of the +HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension +by default when it is detected that the target database has the HSTORE +type set up for use. In other words, when the dialect makes the first +connection, a sequence like the following is performed: + +1. Request the available HSTORE oids using ``psycopg2.extras.HstoreAdapter.get_oids()``. + If this function returns a list of HSTORE identifiers, we then determine that + the ``HSTORE`` extension is present. + +2. If the ``use_native_hstore`` flag is at it's default of ``True``, and + we've detected that ``HSTORE`` oids are available, the + ``psycopg2.extensions.register_hstore()`` extension is invoked for all + connections. + +The ``register_hstore()`` extension has the effect of **all Python dictionaries +being accepted as parameters regardless of the type of target column in SQL**. +The dictionaries are converted by this extension into a textual HSTORE expression. +If this behavior is not desired, disable the +use of the hstore extension by setting ``use_native_hstore`` to ``False`` as follows:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False) + +The ``HSTORE`` type is **still supported** when the ``psycopg2.extensions.register_hstore()`` +extension is not used. It merely means that the coercion between Python dictionaries and the HSTORE +string format, on both the parameter side and the result side, will take +place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2`` which +may be more performant. + +""" +from __future__ import absolute_import + +import re +import logging + +from ... import util, exc +import decimal +from ... import processors +from ...engine import result as _result +from ...sql import expression +from ... import types as sqltypes +from .base import PGDialect, PGCompiler, \ + PGIdentifierPreparer, PGExecutionContext, \ + ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\ + _INT_TYPES +from .hstore import HSTORE +from .json import JSON + + +logger = logging.getLogger('sqlalchemy.dialects.postgresql') + + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, + self._effective_decimal_return_scale) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + + +class _PGEnum(ENUM): + def result_processor(self, dialect, coltype): + if util.py2k and self.convert_unicode is True: + # we can't easily use PG's extensions here because + # the OID is on the fly, and we need to give it a python + # function anyway - not really worth it. + self.convert_unicode = "force_nocheck" + return super(_PGEnum, self).result_processor(dialect, coltype) + +class _PGHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).result_processor(dialect, coltype) + + +class _PGJSON(JSON): + + def result_processor(self, dialect, coltype): + if dialect._has_native_json: + return None + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + +# When we're handed literal SQL, ensure it's a SELECT-query. Since +# 8.3, combining cursors and "FOR UPDATE" has been fine. +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +_server_side_id = util.counter() + + +class PGExecutionContext_psycopg2(PGExecutionContext): + def create_cursor(self): + # TODO: coverage for server side cursors + select.for_update() + + if self.dialect.server_side_cursors: + is_server_side = \ + self.execution_options.get('stream_results', True) and ( + (self.compiled and isinstance(self.compiled.statement, expression.Selectable) \ + or \ + ( + (not self.compiled or + isinstance(self.compiled.statement, expression.TextClause)) + and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) + ) + ) + else: + is_server_side = \ + self.execution_options.get('stream_results', False) + + self.__is_server_side = is_server_side + if is_server_side: + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + else: + return self._dbapi_connection.cursor() + + def get_result_proxy(self): + # TODO: ouch + if logger.isEnabledFor(logging.INFO): + self._log_notices(self.cursor) + + if self.__is_server_side: + return _result.BufferedRowResultProxy(self) + else: + return _result.ResultProxy(self) + + def _log_notices(self, cursor): + for notice in cursor.connection.notices: + # NOTICE messages have a + # newline character at the end + logger.info(notice.rstrip()) + + cursor.connection.notices[:] = [] + + +class PGCompiler_psycopg2(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace('%', '%%') + + +class PGDialect_psycopg2(PGDialect): + driver = 'psycopg2' + if util.py2k: + supports_unicode_statements = False + + default_paramstyle = 'pyformat' + supports_sane_multi_rowcount = False # set to true based on psycopg2 version + execution_ctx_cls = PGExecutionContext_psycopg2 + statement_compiler = PGCompiler_psycopg2 + preparer = PGIdentifierPreparer_psycopg2 + psycopg2_version = (0, 0) + + _has_native_hstore = False + _has_native_json = False + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PGNumeric, + ENUM: _PGEnum, # needs force_unicode + sqltypes.Enum: _PGEnum, # needs force_unicode + HSTORE: _PGHStore, + JSON: _PGJSON + } + ) + + def __init__(self, server_side_cursors=False, use_native_unicode=True, + client_encoding=None, + use_native_hstore=True, + **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + self.use_native_unicode = use_native_unicode + self.use_native_hstore = use_native_hstore + self.supports_unicode_binds = use_native_unicode + self.client_encoding = client_encoding + if self.dbapi and hasattr(self.dbapi, '__version__'): + m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', + self.dbapi.__version__) + if m: + self.psycopg2_version = tuple( + int(x) + for x in m.group(1, 2, 3) + if x is not None) + + def initialize(self, connection): + super(PGDialect_psycopg2, self).initialize(connection) + self._has_native_hstore = self.use_native_hstore and \ + self._hstore_oids(connection.connection) \ + is not None + self._has_native_json = self.psycopg2_version >= (2, 5) + + # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 + self.supports_sane_multi_rowcount = self.psycopg2_version >= (2, 0, 9) + + @classmethod + def dbapi(cls): + import psycopg2 + return psycopg2 + + @util.memoized_property + def _isolation_lookup(self): + from psycopg2 import extensions + return { + 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT, + 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, + 'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + 'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ, + 'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE + } + + def set_isolation_level(self, connection, level): + try: + level = self._isolation_lookup[level.replace('_', ' ')] + except KeyError: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + + connection.set_isolation_level(level) + + def on_connect(self): + from psycopg2 import extras, extensions + + fns = [] + if self.client_encoding is not None: + def on_connect(conn): + conn.set_client_encoding(self.client_encoding) + fns.append(on_connect) + + if self.isolation_level is not None: + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) + + if self.dbapi and self.use_native_unicode: + def on_connect(conn): + extensions.register_type(extensions.UNICODE, conn) + extensions.register_type(extensions.UNICODEARRAY, conn) + fns.append(on_connect) + + if self.dbapi and self.use_native_hstore: + def on_connect(conn): + hstore_oids = self._hstore_oids(conn) + if hstore_oids is not None: + oid, array_oid = hstore_oids + if util.py2k: + extras.register_hstore(conn, oid=oid, + array_oid=array_oid, + unicode=True) + else: + extras.register_hstore(conn, oid=oid, + array_oid=array_oid) + fns.append(on_connect) + + if self.dbapi and self._json_deserializer: + def on_connect(conn): + extras.register_default_json(conn, loads=self._json_deserializer) + fns.append(on_connect) + + if fns: + def on_connect(conn): + for fn in fns: + fn(conn) + return on_connect + else: + return None + + @util.memoized_instancemethod + def _hstore_oids(self, conn): + if self.psycopg2_version >= (2, 4): + from psycopg2 import extras + oids = extras.HstoreAdapter.get_oids(conn) + if oids is not None and oids[0]: + return oids[0:2] + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + str_e = str(e).partition("\n")[0] + for msg in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + # TODO: these are sent through gettext in libpq and we can't + # check within other locales - consider using connection.closed + 'terminating connection', + 'closed the connection', + 'connection not open', + 'could not receive data from server', + 'could not send data to server', + # psycopg2 client errors, psycopg2/conenction.h, psycopg2/cursor.h + 'connection already closed', + 'cursor already closed', + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + 'losed the connection unexpectedly' + ]: + idx = str_e.find(msg) + if idx >= 0 and '"' not in str_e[:idx]: + return True + return False + +dialect = PGDialect_psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py new file mode 100644 index 00000000..f030d2c1 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -0,0 +1,78 @@ +# postgresql/pypostgresql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql+pypostgresql + :name: py-postgresql + :dbapi: pypostgresql + :connectstring: postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...] + :url: http://python.projects.pgfoundry.org/ + + +""" +from ... import util +from ... import types as sqltypes +from .base import PGDialect, PGExecutionContext +from ... import processors + + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return processors.to_str + + def result_processor(self, dialect, coltype): + if self.asdecimal: + return None + else: + return processors.to_float + + +class PGExecutionContext_pypostgresql(PGExecutionContext): + pass + + +class PGDialect_pypostgresql(PGDialect): + driver = 'pypostgresql' + + supports_unicode_statements = True + supports_unicode_binds = True + description_encoding = None + default_paramstyle = 'pyformat' + + # requires trunk version to support sane rowcounts + # TODO: use dbapi version information to set this flag appropriately + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = PGExecutionContext_pypostgresql + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: PGNumeric, + + # prevents PGNumeric from being used + sqltypes.Float: sqltypes.Float, + } + ) + + @classmethod + def dbapi(cls): + from postgresql.driver import dbapi20 + return dbapi20 + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + else: + opts['port'] = 5432 + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + return "connection is closed" in str(e) + +dialect = PGDialect_pypostgresql diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py new file mode 100644 index 00000000..57b0c4c3 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -0,0 +1,160 @@ +# Copyright (C) 2013-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .base import ischema_names +from ... import types as sqltypes + +__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') + +class RangeOperators(object): + """ + This mixin provides functionality for the Range Operators + listed in Table 9-44 of the `postgres documentation`__ for Range + Functions and Operators. It is used by all the range types + provided in the ``postgres`` dialect and can likely be used for + any range types you create yourself. + + __ http://www.postgresql.org/docs/devel/static/functions-range.html + + No extra support is provided for the Range Functions listed in + Table 9-45 of the postgres documentation. For these, the normal + :func:`~sqlalchemy.sql.expression.func` object should be used. + + .. versionadded:: 0.8.2 Support for Postgresql RANGE operations. + + """ + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for range types.""" + + def __ne__(self, other): + "Boolean expression. Returns true if two ranges are not equal" + return self.expr.op('<>')(other) + + def contains(self, other, **kw): + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.op('<@')(other) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.op('&&')(other) + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.op('<<')(other) + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.op('>>')(other) + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.op('&<')(other) + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.op('&>')(other) + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + return self.expr.op('-|-')(other) + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contigous. + """ + return self.expr.op('+')(other) + +class INT4RANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql INT4RANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'INT4RANGE' + +ischema_names['int4range'] = INT4RANGE + +class INT8RANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql INT8RANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'INT8RANGE' + +ischema_names['int8range'] = INT8RANGE + +class NUMRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql NUMRANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'NUMRANGE' + +ischema_names['numrange'] = NUMRANGE + +class DATERANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql DATERANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'DATERANGE' + +ischema_names['daterange'] = DATERANGE + +class TSRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql TSRANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'TSRANGE' + +ischema_names['tsrange'] = TSRANGE + +class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the Postgresql TSTZRANGE type. + + .. versionadded:: 0.8.2 + + """ + + __visit_name__ = 'TSTZRANGE' + +ischema_names['tstzrange'] = TSTZRANGE diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py new file mode 100644 index 00000000..67e7d53e --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -0,0 +1,45 @@ +# postgresql/zxjdbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: postgresql+zxjdbc + :name: zxJDBC for Jython + :dbapi: zxjdbc + :connectstring: postgresql+zxjdbc://scott:tiger@localhost/db + :driverurl: http://jdbc.postgresql.org/ + + +""" +from ...connectors.zxJDBC import ZxJDBCConnector +from .base import PGDialect, PGExecutionContext + + +class PGExecutionContext_zxjdbc(PGExecutionContext): + + def create_cursor(self): + cursor = self._dbapi_connection.cursor() + cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) + return cursor + + +class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): + jdbc_db_name = 'postgresql' + jdbc_driver_name = 'org.postgresql.Driver' + + execution_ctx_cls = PGExecutionContext_zxjdbc + + supports_native_decimal = True + + def __init__(self, *args, **kwargs): + super(PGDialect_zxjdbc, self).__init__(*args, **kwargs) + from com.ziclix.python.sql.handler import PostgresqlDataHandler + self.DataHandler = PostgresqlDataHandler + + def _get_server_version_info(self, connection): + parts = connection.connection.dbversion.split('.') + return tuple(int(x) for x in parts) + +dialect = PGDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 00000000..80846c9e --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,19 @@ +# sqlite/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.sqlite import base, pysqlite + +# default dialect +base.dialect = pysqlite.dialect + +from sqlalchemy.dialects.sqlite.base import ( + BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, REAL, + NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, dialect, +) + +__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', + 'FLOAT', 'INTEGER', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', + 'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 00000000..90df9c19 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,1049 @@ +# sqlite/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sqlite + :name: SQLite + + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does +not provide out of the box functionality for translating values between Python +`datetime` objects and a SQLite-supported format. SQLAlchemy's own +:class:`~sqlalchemy.types.DateTime` and related types provide date formatting +and parsing functionality when SQlite is used. The implementation classes are +:class:`~.sqlite.DATETIME`, :class:`~.sqlite.DATE` and :class:`~.sqlite.TIME`. +These types represent dates and times as ISO formatted strings, which also +nicely support ordering. There's no reliance on typical "libc" internals for +these functions so historical dates are fully supported. + +Auto Incrementing Behavior +-------------------------- + +Background on SQLite's autoincrement is at: http://sqlite.org/autoinc.html + +Two things to note: + +* The AUTOINCREMENT keyword is **not** required for SQLite tables to + generate primary key values automatically. AUTOINCREMENT only means that the + algorithm used to generate ROWID values should be slightly different. +* SQLite does **not** generate primary key (i.e. ROWID) values, even for + one column, if the table has a composite (i.e. multi-column) primary key. + This is regardless of the AUTOINCREMENT keyword being present or not. + +To specifically render the AUTOINCREMENT keyword on the primary key column when +rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table +construct:: + + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) + +Transaction Isolation Level +--------------------------- + +:func:`.create_engine` accepts an ``isolation_level`` parameter which results +in the command ``PRAGMA read_uncommitted `` being invoked for every new +connection. Valid values for this parameter are ``SERIALIZABLE`` and ``READ +UNCOMMITTED`` corresponding to a value of 0 and 1, respectively. See the +section :ref:`pysqlite_serializable` for an important workaround when using +serializable isolation with Pysqlite. + +Database Locking Behavior / Concurrency +--------------------------------------- + +Note that SQLite is not designed for a high level of concurrency. The database +itself, being a file, is locked completely during write operations and within +transactions, meaning exactly one connection has exclusive access to the +database during this period - all other connections will be blocked during this +time. + +The Python DBAPI specification also calls for a connection model that is always +in a transaction; there is no BEGIN method, only commit and rollback. This +implies that a SQLite DBAPI driver would technically allow only serialized +access to a particular database file at all times. The pysqlite driver attempts +to ameliorate this by deferring the actual BEGIN statement until the first DML +(INSERT, UPDATE, or DELETE) is received within a transaction. While this breaks +serializable isolation, it at least delays the exclusive locking inherent in +SQLite's design. + +SQLAlchemy's default mode of usage with the ORM is known as "autocommit=False", +which means the moment the :class:`.Session` begins to be used, a transaction +is begun. As the :class:`.Session` is used, the autoflush feature, also on by +default, will flush out pending changes to the database before each query. The +effect of this is that a :class:`.Session` used in its default mode will often +emit DML early on, long before the transaction is actually committed. This +again will have the effect of serializing access to the SQLite database. If +highly concurrent reads are desired against the SQLite database, it is advised +that the autoflush feature be disabled, and potentially even that autocommit be +re-enabled, which has the effect of each SQL statement and flush committing +changes immediately. + +For more information on SQLite's lack of concurrency by design, please see +`Situations Where Another RDBMS May Work Better - High Concurrency +`_ near the bottom of the page. + +.. _sqlite_foreign_keys: + +Foreign Key Support +------------------- + +SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables, +however by default these constraints have no effect on the operation of the +table. + +Constraint checking on SQLite has three prerequisites: + +* At least version 3.6.19 of SQLite must be in use +* The SQLite libary must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY + or SQLITE_OMIT_TRIGGER symbols enabled. +* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all connections + before use. + +SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for +new connections through the usage of events:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + +.. seealso:: + + `SQLite Foreign Key Support `_ - on + the SQLite web site. + + :ref:`event_toplevel` - SQLAlchemy event API. + +.. _sqlite_type_reflection: + +Type Reflection +--------------- + +SQLite types are unlike those of most other database backends, in that +the string name of the type usually does not correspond to a "type" in a +one-to-one fashion. Instead, SQLite links per-column typing behavior +to one of five so-called "type affinities" based on a string matching +pattern for the type. + +SQLAlchemy's reflection process, when inspecting types, uses a simple +lookup table to link the keywords returned to provided SQLAlchemy types. +This lookup table is present within the SQLite dialect as it is for all +other dialects. However, the SQLite dialect has a different "fallback" +routine for when a particular type name is not located in the lookup map; +it instead implements the SQLite "type affinity" scheme located at +http://www.sqlite.org/datatype3.html section 2.1. + +The provided typemap will make direct associations from an exact string +name match for the following types: + +:class:`~.types.BIGINT`, :class:`~.types.BLOB`, +:class:`~.types.BOOLEAN`, :class:`~.types.BOOLEAN`, +:class:`~.types.CHAR`, :class:`~.types.DATE`, +:class:`~.types.DATETIME`, :class:`~.types.FLOAT`, +:class:`~.types.DECIMAL`, :class:`~.types.FLOAT`, +:class:`~.types.INTEGER`, :class:`~.types.INTEGER`, +:class:`~.types.NUMERIC`, :class:`~.types.REAL`, +:class:`~.types.SMALLINT`, :class:`~.types.TEXT`, +:class:`~.types.TIME`, :class:`~.types.TIMESTAMP`, +:class:`~.types.VARCHAR`, :class:`~.types.NVARCHAR`, +:class:`~.types.NCHAR` + +When a type name does not match one of the above types, the "type affinity" +lookup is used instead: + +* :class:`~.types.INTEGER` is returned if the type name includes the + string ``INT`` +* :class:`~.types.TEXT` is returned if the type name includes the + string ``CHAR``, ``CLOB`` or ``TEXT`` +* :class:`~.types.NullType` is returned if the type name includes the + string ``BLOB`` +* :class:`~.types.REAL` is returned if the type name includes the string + ``REAL``, ``FLOA`` or ``DOUB``. +* Otherwise, the :class:`~.types.NUMERIC` type is used. + +.. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting + columns. + +""" + +import datetime +import re + +from ... import processors +from ... import sql, exc +from ... import types as sqltypes, schema as sa_schema +from ... import util +from ...engine import default, reflection +from ...sql import compiler + +from ...types import (BLOB, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, INTEGER, REAL, + NUMERIC, SMALLINT, TEXT, TIMESTAMP, VARCHAR) + + +class _DateTimeMixin(object): + _reg = None + _storage_format = None + + def __init__(self, storage_format=None, regexp=None, **kw): + super(_DateTimeMixin, self).__init__(**kw) + if regexp is not None: + self._reg = re.compile(regexp) + if storage_format is not None: + self._storage_format = storage_format + + def adapt(self, cls, **kw): + if issubclass(cls, _DateTimeMixin): + if self._storage_format: + kw["storage_format"] = self._storage_format + if self._reg: + kw["regexp"] = self._reg + return super(_DateTimeMixin, self).adapt(cls, **kw) + + def literal_processor(self, dialect): + bp = self.bind_processor(dialect) + def process(value): + return "'%s'" % bp(value) + return process + + +class DATETIME(_DateTimeMixin, sqltypes.DateTime): + """Represent a Python datetime object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(min)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 2011-03-15 12:05:57.10558 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATETIME + + dt = DATETIME( + storage_format="%(year)04d/%(month)02d/%(day)02d %(hour)02d:%(min)02d:%(second)02d", + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" + ) + + :param storage_format: format string which will be applied to the dict with + keys year, month, day, hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows. If the regexp contains named groups, the resulting match dict is + applied to the Python datetime() constructor as keyword arguments. + Otherwise, if positional groups are used, the the datetime() constructor + is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + """ + + _storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + ) + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop('truncate_microseconds', False) + super(DATETIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert 'storage_format' not in kwargs, "You can specify only "\ + "one of truncate_microseconds or storage_format." + assert 'regexp' not in kwargs, "You can specify only one of "\ + "truncate_microseconds or regexp." + self._storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d" + ) + + def bind_processor(self, dialect): + datetime_datetime = datetime.datetime + datetime_date = datetime.date + format = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_datetime): + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + 'hour': value.hour, + 'minute': value.minute, + 'second': value.second, + 'microsecond': value.microsecond, + } + elif isinstance(value, datetime_date): + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + 'hour': 0, + 'minute': 0, + 'second': 0, + 'microsecond': 0, + } + else: + raise TypeError("SQLite DateTime type only accepts Python " + "datetime and date objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.datetime) + else: + return processors.str_to_datetime + + +class DATE(_DateTimeMixin, sqltypes.Date): + """Represent a Python date object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d" + + e.g.:: + + 2011-03-15 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATE + + d = DATE( + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)") + ) + + :param storage_format: format string which will be applied to the + dict with keys year, month, and day. + + :param regexp: regular expression which will be applied to + incoming result rows. If the regexp contains named groups, the + resulting match dict is applied to the Python date() constructor + as keyword arguments. Otherwise, if positional groups are used, the + the date() constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + """ + + _storage_format = "%(year)04d-%(month)02d-%(day)02d" + + def bind_processor(self, dialect): + datetime_date = datetime.date + format = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_date): + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + } + else: + raise TypeError("SQLite Date type only accepts Python " + "date objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date) + else: + return processors.str_to_date + + +class TIME(_DateTimeMixin, sqltypes.Time): + """Represent a Python time object in SQLite using a string. + + The default string storage format is:: + + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 12:05:57.10558 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import TIME + + t = TIME( + storage_format="%(hour)02d-%(minute)02d-%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") + ) + + :param storage_format: format string which will be applied to the dict with + keys hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows. If the regexp contains named groups, the resulting match dict is + applied to the Python time() constructor as keyword arguments. Otherwise, + if positional groups are used, the the time() constructor is called with + positional arguments via ``*map(int, match_obj.groups(0))``. + """ + + _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop('truncate_microseconds', False) + super(TIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert 'storage_format' not in kwargs, "You can specify only "\ + "one of truncate_microseconds or storage_format." + assert 'regexp' not in kwargs, "You can specify only one of "\ + "truncate_microseconds or regexp." + self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" + + def bind_processor(self, dialect): + datetime_time = datetime.time + format = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_time): + return format % { + 'hour': value.hour, + 'minute': value.minute, + 'second': value.second, + 'microsecond': value.microsecond, + } + else: + raise TypeError("SQLite Time type only accepts Python " + "time objects as input.") + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time) + else: + return processors.str_to_time + +colspecs = { + sqltypes.Date: DATE, + sqltypes.DateTime: DATETIME, + sqltypes.Time: TIME, +} + +ischema_names = { + 'BIGINT': sqltypes.BIGINT, + 'BLOB': sqltypes.BLOB, + 'BOOL': sqltypes.BOOLEAN, + 'BOOLEAN': sqltypes.BOOLEAN, + 'CHAR': sqltypes.CHAR, + 'DATE': sqltypes.DATE, + 'DATETIME': sqltypes.DATETIME, + 'DOUBLE': sqltypes.FLOAT, + 'DECIMAL': sqltypes.DECIMAL, + 'FLOAT': sqltypes.FLOAT, + 'INT': sqltypes.INTEGER, + 'INTEGER': sqltypes.INTEGER, + 'NUMERIC': sqltypes.NUMERIC, + 'REAL': sqltypes.REAL, + 'SMALLINT': sqltypes.SMALLINT, + 'TEXT': sqltypes.TEXT, + 'TIME': sqltypes.TIME, + 'TIMESTAMP': sqltypes.TIMESTAMP, + 'VARCHAR': sqltypes.VARCHAR, + 'NVARCHAR': sqltypes.NVARCHAR, + 'NCHAR': sqltypes.NCHAR, +} + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W', + }) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_localtimestamp_func(self, func, **kw): + return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + + def visit_true(self, expr, **kw): + return '1' + + def visit_false(self, expr, **kw): + return '0' + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast, **kwargs) + else: + return self.process(cast.clause, **kwargs) + + def visit_extract(self, extract, **kw): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], + self.process(extract.expr, **kw) + ) + except KeyError: + raise exc.CompileError( + "%s is not a valid extract argument." % extract.field) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += "\n LIMIT " + self.process(sql.literal(select._limit)) + if select._offset is not None: + if select._limit is None: + text += "\n LIMIT " + self.process(sql.literal(-1)) + text += " OFFSET " + self.process(sql.literal(select._offset)) + else: + text += " OFFSET " + self.process(sql.literal(0)) + return text + + def for_update_clause(self, select): + # sqlite has no "FOR UPDATE" AFAICT + return '' + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + + def get_column_specification(self, column, **kwargs): + coltype = self.dialect.type_compiler.process(column.type) + colspec = self.preparer.format_column(column) + " " + coltype + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + if (column.primary_key and + column.table.dialect_options['sqlite']['autoincrement'] and + len(column.table.primary_key.columns) == 1 and + issubclass(column.type._type_affinity, sqltypes.Integer) and + not column.foreign_keys): + colspec += " PRIMARY KEY AUTOINCREMENT" + + return colspec + + def visit_primary_key_constraint(self, constraint): + # for columns with sqlite_autoincrement=True, + # the PRIMARY KEY constraint can only be inline + # with the column itself. + if len(constraint.columns) == 1: + c = list(constraint)[0] + if (c.primary_key and + c.table.dialect_options['sqlite']['autoincrement'] and + issubclass(c.type._type_affinity, sqltypes.Integer) and + not c.foreign_keys): + return None + + return super(SQLiteDDLCompiler, self).visit_primary_key_constraint( + constraint) + + def visit_foreign_key_constraint(self, constraint): + + local_table = list(constraint._elements.values())[0].parent.table + remote_table = list(constraint._elements.values())[0].column.table + + if local_table.schema != remote_table.schema: + return None + else: + return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( + constraint) + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table, use_schema=False) + + def visit_create_index(self, create): + return super(SQLiteDDLCompiler, self).visit_create_index( + create, include_table_schema=False) + + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', + 'into', 'is', 'isnull', 'join', 'key', 'left', 'like', 'limit', + 'match', 'natural', 'not', 'notnull', 'null', 'of', 'offset', 'on', + 'or', 'order', 'outer', 'plan', 'pragma', 'primary', 'query', + 'raise', 'references', 'reindex', 'rename', 'replace', 'restrict', + 'right', 'rollback', 'row', 'select', 'set', 'table', 'temp', + 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union', + 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', + 'when', 'where', + ]) + + def format_index(self, index, use_schema=True, name=None): + """Prepare a quoted index and schema name.""" + + if name is None: + name = index.name + result = self.quote(name, index.quote) + if (not self.omit_schema and + use_schema and + getattr(index.table, "schema", None)): + result = self.quote_schema(index.table.schema, + index.table.quote_schema) + "." + result + return result + + +class SQLiteExecutionContext(default.DefaultExecutionContext): + @util.memoized_property + def _preserve_raw_colnames(self): + return self.execution_options.get("sqlite_raw_colnames", False) + + def _translate_colname(self, colname): + # adjust for dotted column names. SQLite in the case of UNION may store + # col names as "tablename.colname" in cursor.description + if not self._preserve_raw_colnames and "." in colname: + return colname.split(".")[1], colname + else: + return colname, None + + +class SQLiteDialect(default.DefaultDialect): + name = 'sqlite' + supports_alter = False + supports_unicode_statements = True + supports_unicode_binds = True + supports_default_values = True + supports_empty_insert = False + supports_cast = True + supports_multivalues_insert = True + supports_right_nested_joins = False + + default_paramstyle = 'qmark' + execution_ctx_cls = SQLiteExecutionContext + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + isolation_level = None + + supports_cast = True + supports_default_values = True + + construct_arguments = [ + (sa_schema.Table, { + "autoincrement": False + }) + ] + + _broken_fk_pragma_quotes = False + + def __init__(self, isolation_level=None, native_datetime=False, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + # this flag used by pysqlite dialect, and perhaps others in the future, + # to indicate the driver is handling date/timestamp conversions (and + # perhaps datetime/time as well on some hypothetical driver ?) + self.native_datetime = native_datetime + + if self.dbapi is not None: + self.supports_default_values = ( + self.dbapi.sqlite_version_info >= (3, 3, 8)) + self.supports_cast = ( + self.dbapi.sqlite_version_info >= (3, 2, 3)) + self.supports_multivalues_insert = ( + # http://www.sqlite.org/releaselog/3_7_11.html + self.dbapi.sqlite_version_info >= (3, 7, 11)) + # see http://www.sqlalchemy.org/trac/ticket/2568 + # as well as http://www.sqlite.org/src/info/600482d161 + self._broken_fk_pragma_quotes = ( + self.dbapi.sqlite_version_info < (3, 6, 14)) + + _isolation_lookup = { + 'READ UNCOMMITTED': 1, + 'SERIALIZABLE': 0, + } + + def set_isolation_level(self, connection, level): + try: + isolation_level = self._isolation_lookup[level.replace('_', ' ')] + except KeyError: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('PRAGMA read_uncommitted') + res = cursor.fetchone() + if res: + value = res[0] + else: + # http://www.sqlite.org/changes.html#version_3_3_3 + # "Optional READ UNCOMMITTED isolation (instead of the + # default isolation level of SERIALIZABLE) and + # table level locking when database connections + # share a common cache."" + # pre-SQLite 3.3.0 default to 0 + value = 0 + cursor.close() + if value == 0: + return "SERIALIZABLE" + elif value == 1: + return "READ UNCOMMITTED" + else: + assert False, "Unknown isolation level %s" % value + + def on_connect(self): + if self.isolation_level is not None: + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + statement = "%stable_info(%s)" % (pragma, qtable) + cursor = _pragma_cursor(connection.execute(statement)) + row = cursor.fetchone() + + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while not cursor.closed and cursor.fetchone() is not None: + pass + + return row is not None + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + s = ("SELECT name FROM sqlite_master " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" + "AND type='view'") % (master, view_name) + rs = connection.execute(s) + else: + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + except exc.DBAPIError: + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + + result = rs.fetchall() + if result: + return result[0].sql + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + statement = "%stable_info(%s)" % (pragma, qtable) + c = _pragma_cursor(connection.execute(statement)) + + rows = c.fetchall() + columns = [] + for row in rows: + (name, type_, nullable, default, primary_key) = ( + row[1], row[2].upper(), not row[3], row[4], row[5]) + + columns.append(self._get_column_info(name, type_, nullable, + default, primary_key)) + return columns + + def _get_column_info(self, name, type_, nullable, default, primary_key): + coltype = self._resolve_type_affinity(type_) + + if default is not None: + default = util.text_type(default) + + return { + 'name': name, + 'type': coltype, + 'nullable': nullable, + 'default': default, + 'autoincrement': default is None, + 'primary_key': primary_key, + } + + def _resolve_type_affinity(self, type_): + """Return a data type from a reflected column, using affinity tules. + + SQLite's goal for universal compatability introduces some complexity + during reflection, as a column's defined type might not actually be a + type that SQLite understands - or indeed, my not be defined *at all*. + Internally, SQLite handles this with a 'data type affinity' for each + column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER', + 'REAL', or 'NONE' (raw bits). The algorithm that determines this is + listed in http://www.sqlite.org/datatype3.html section 2.1. + + This method allows SQLAlchemy to support that algorithm, while still + providing access to smarter reflection utilities by regcognizing + column definitions that SQLite only supports through affinity (like + DATE and DOUBLE). + + """ + match = re.match(r'([\w ]+)(\(.*?\))?', type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = '' + args = '' + + if coltype in self.ischema_names: + coltype = self.ischema_names[coltype] + elif 'INT' in coltype: + coltype = sqltypes.INTEGER + elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype: + coltype = sqltypes.TEXT + elif 'BLOB' in coltype or not coltype: + coltype = sqltypes.NullType + elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype: + coltype = sqltypes.REAL + else: + coltype = sqltypes.NUMERIC + + if args is not None: + args = re.findall(r'(\d+)', args) + try: + coltype = coltype(*[int(a) for a in args]) + except TypeError: + util.warn( + "Could not instantiate type %s with " + "reflected arguments %s; using no arguments." % + (coltype, args)) + coltype = coltype() + else: + coltype = coltype() + + return coltype + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + cols = self.get_columns(connection, table_name, schema, **kw) + pkeys = [] + for col in cols: + if col['primary_key']: + pkeys.append(col['name']) + return {'constrained_columns': pkeys, 'name': None} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + statement = "%sforeign_key_list(%s)" % (pragma, qtable) + c = _pragma_cursor(connection.execute(statement)) + fkeys = [] + fks = {} + while True: + row = c.fetchone() + if row is None: + break + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + + self._parse_fk(fks, fkeys, numerical_id, rtbl, lcol, rcol) + return fkeys + + def _parse_fk(self, fks, fkeys, numerical_id, rtbl, lcol, rcol): + # sqlite won't return rcol if the table was created with REFERENCES + # , no col + if rcol is None: + rcol = lcol + + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + + try: + fk = fks[numerical_id] + except KeyError: + fk = { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': rtbl, + 'referred_columns': [], + } + fkeys.append(fk) + fks[numerical_id] = fk + + if lcol not in fk['constrained_columns']: + fk['constrained_columns'].append(lcol) + if rcol not in fk['referred_columns']: + fk['referred_columns'].append(rcol) + return fk + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + include_auto_indexes = kw.pop('include_auto_indexes', False) + qtable = quote(table_name) + statement = "%sindex_list(%s)" % (pragma, qtable) + c = _pragma_cursor(connection.execute(statement)) + indexes = [] + while True: + row = c.fetchone() + if row is None: + break + # ignore implicit primary key index. + # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html + elif (not include_auto_indexes and + row[1].startswith('sqlite_autoindex')): + continue + + indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. + for idx in indexes: + statement = "%sindex_info(%s)" % (pragma, quote(idx['name'])) + c = connection.execute(statement) + cols = idx['column_names'] + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + return indexes + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + UNIQUE_SQL = """ + SELECT sql + FROM + sqlite_master + WHERE + type='table' AND + name=:table_name + """ + c = connection.execute(UNIQUE_SQL, table_name=table_name) + table_data = c.fetchone()[0] + + UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' + return [ + {'name': name, + 'column_names': [col.strip(' "') for col in cols.split(',')]} + for name, cols in re.findall(UNIQUE_PATTERN, table_data) + ] + + +def _pragma_cursor(cursor): + """work around SQLite issue whereby cursor.description + is blank when PRAGMA returns no rows.""" + + if cursor.closed: + cursor.fetchone = lambda: None + cursor.fetchall = lambda: [] + return cursor diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 00000000..b53f4d4a --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,335 @@ +# sqlite/pysqlite.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sqlite+pysqlite + :name: pysqlite + :dbapi: sqlite3 + :connectstring: sqlite+pysqlite:///file_path + :url: http://docs.python.org/library/sqlite3.html + + Note that ``pysqlite`` is the same driver as the ``sqlite3`` + module included with the Python distribution. + +Driver +------ + +When using Python 2.5 and above, the built in ``sqlite3`` driver is +already installed and no additional installation is needed. Otherwise, +the ``pysqlite2`` driver needs to be present. This is the same driver as +``sqlite3``, just with a different name. + +The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` +is loaded. This allows an explicitly installed pysqlite driver to take +precedence over the built in one. As with all dialects, a specific +DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control +this explicitly:: + + from sqlite3 import dbapi2 as sqlite + e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) + + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" +portion of the URL. Note that the format of a SQLAlchemy url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to +the **right** of the third slash. So connecting to a relative filepath +looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you +need **four** slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be +used. Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is +present. Specify ``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +Compatibility with sqlite3 "native" date and datetime types +----------------------------------------------------------- + +The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and +sqlite3.PARSE_COLNAMES options, which have the effect of any column +or expression explicitly cast as "date" or "timestamp" will be converted +to a Python date or datetime object. The date and datetime types provided +with the pysqlite dialect are not currently compatible with these options, +since they render the ISO date/datetime including microseconds, which +pysqlite's driver does not. Additionally, SQLAlchemy does not at +this time automatically render the "cast" syntax required for the +freestanding functions "current_timestamp" and "current_date" to return +datetime/date types natively. Unfortunately, pysqlite +does not provide the standard DBAPI types in ``cursor.description``, +leaving SQLAlchemy with no way to detect these types on the fly +without expensive per-row type checks. + +Keeping in mind that pysqlite's parsing option is not recommended, +nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES +can be forced if one configures "native_datetime=True" on create_engine():: + + engine = create_engine('sqlite://', + connect_args={'detect_types': sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True + ) + +With this flag enabled, the DATE and TIMESTAMP types (but note - not the +DATETIME or TIME types...confused yet ?) will not perform any bind parameter +or result processing. Execution of "func.current_date()" will return a string. +"func.current_timestamp()" is registered as returning a DATETIME type in +SQLAlchemy, so this function still receives SQLAlchemy-level result processing. + +.. _pysqlite_threading_pooling: + +Threading/Pooling Behavior +--------------------------- + +Pysqlite's default behavior is to prohibit the usage of a single connection +in more than one thread. This is originally intended to work with older +versions of SQLite that did not support multithreaded operation under +various circumstances. In particular, older SQLite versions +did not allow a ``:memory:`` database to be used in multiple threads +under any circumstances. + +Pysqlite does include a now-undocumented flag known as +``check_same_thread`` which will disable this check, however note that pysqlite +connections are still not safe to use in concurrently in multiple threads. +In particular, any statement execution calls would need to be externally +mutexed, as Pysqlite does not provide for thread-safe propagation of error +messages among other things. So while even ``:memory:`` databases can be +shared among threads in modern SQLite, Pysqlite doesn't provide enough +thread-safety to make this usage worth it. + +SQLAlchemy sets up pooling to work with Pysqlite's default behavior: + +* When a ``:memory:`` SQLite database is specified, the dialect by default + will use :class:`.SingletonThreadPool`. This pool maintains a single + connection per thread, so that all access to the engine within the current + thread use the same ``:memory:`` database - other threads would access a + different ``:memory:`` database. +* When a file-based database is specified, the dialect will use + :class:`.NullPool` as the source of connections. This pool closes and + discards connections which are returned to the pool immediately. SQLite + file-based connections have extremely low overhead, so pooling is not + necessary. The scheme also prevents a connection from being used again in + a different thread and works best with SQLite's coarse-grained file locking. + + .. versionchanged:: 0.7 + Default selection of :class:`.NullPool` for SQLite file-based databases. + Previous versions select :class:`.SingletonThreadPool` by + default for all SQLite databases. + + +Using a Memory Database in Multiple Threads +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use a ``:memory:`` database in a multithreaded scenario, the same connection +object must be shared among threads, since the database exists +only within the scope of that connection. The +:class:`.StaticPool` implementation will maintain a single connection +globally, and the ``check_same_thread`` flag can be passed to Pysqlite +as ``False``:: + + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite://', + connect_args={'check_same_thread':False}, + poolclass=StaticPool) + +Note that using a ``:memory:`` database in multiple threads requires a recent +version of SQLite. + +Using Temporary Tables with SQLite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to the way SQLite deals with temporary tables, if you wish to use a +temporary table in a file-based SQLite database across multiple checkouts +from the connection pool, such as when using an ORM :class:`.Session` where +the temporary table should continue to remain after :meth:`.Session.commit` or +:meth:`.Session.rollback` is called, a pool which maintains a single connection must +be used. Use :class:`.SingletonThreadPool` if the scope is only needed +within the current thread, or :class:`.StaticPool` is scope is needed within +multiple threads for this case:: + + # maintain the same connection per thread + from sqlalchemy.pool import SingletonThreadPool + engine = create_engine('sqlite:///mydb.db', + poolclass=SingletonThreadPool) + + + # maintain the same connection across all threads + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite:///mydb.db', + poolclass=StaticPool) + +Note that :class:`.SingletonThreadPool` should be configured for the number +of threads that are to be used; beyond that number, connections will be +closed out in a non deterministic way. + +Unicode +------- + +The pysqlite driver only returns Python ``unicode`` objects in result sets, +never plain strings, and accommodates ``unicode`` objects within bound +parameter values in all cases. Regardless of the SQLAlchemy string type in +use, string-based result values will by Python ``unicode`` in Python 2. +The :class:`.Unicode` type should still be used to indicate those columns that +require unicode, however, so that non-``unicode`` values passed inadvertently +will emit a warning. Pysqlite will emit an error if a non-``unicode`` string +is passed containing non-ASCII characters. + +.. _pysqlite_serializable: + +Serializable Transaction Isolation +---------------------------------- + +The pysqlite DBAPI driver has a long-standing bug in which transactional +state is not begun until the first DML statement, that is INSERT, UPDATE +or DELETE, is emitted. A SELECT statement will not cause transactional +state to begin. While this mode of usage is fine for typical situations +and has the advantage that the SQLite database file is not prematurely +locked, it breaks serializable transaction isolation, which requires +that the database file be locked upon any SQL being emitted. + +To work around this issue, the ``BEGIN`` keyword can be emitted +at the start of each transaction. The following recipe establishes +a :meth:`.ConnectionEvents.begin` handler to achieve this:: + + from sqlalchemy import create_engine, event + + engine = create_engine("sqlite:///myfile.db", isolation_level='SERIALIZABLE') + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.execute("BEGIN") + +""" + +from sqlalchemy.dialects.sqlite.base import SQLiteDialect, DATETIME, DATE +from sqlalchemy import exc, pool +from sqlalchemy import types as sqltypes +from sqlalchemy import util + +import os + + +class _SQLite_pysqliteTimeStamp(DATETIME): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATETIME.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATETIME.result_processor(self, dialect, coltype) + + +class _SQLite_pysqliteDate(DATE): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATE.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATE.result_processor(self, dialect, coltype) + + +class SQLiteDialect_pysqlite(SQLiteDialect): + default_paramstyle = 'qmark' + + colspecs = util.update_copy( + SQLiteDialect.colspecs, + { + sqltypes.Date: _SQLite_pysqliteDate, + sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, + } + ) + + if not util.py2k: + description_encoding = None + + driver = 'pysqlite' + + def __init__(self, **kwargs): + SQLiteDialect.__init__(self, **kwargs) + + if self.dbapi is not None: + sqlite_ver = self.dbapi.version_info + if sqlite_ver < (2, 1, 3): + util.warn( + ("The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) + + @classmethod + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError as e: + try: + from sqlite3 import dbapi2 as sqlite # try 2.5+ stdlib name. + except ImportError: + raise e + return sqlite + + @classmethod + def get_pool_class(cls, url): + if url.database and url.database != ':memory:': + return pool.NullPool + else: + return pool.SingletonThreadPool + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) + filename = url.database or ':memory:' + if filename != ':memory:': + filename = os.path.abspath(filename) + + opts = url.query.copy() + util.coerce_kw_type(opts, 'timeout', float) + util.coerce_kw_type(opts, 'isolation_level', str) + util.coerce_kw_type(opts, 'detect_types', int) + util.coerce_kw_type(opts, 'check_same_thread', bool) + util.coerce_kw_type(opts, 'cached_statements', int) + + return ([filename], opts) + + def is_disconnect(self, e, connection, cursor): + return isinstance(e, self.dbapi.ProgrammingError) and \ + "Cannot operate on a closed database." in str(e) + +dialect = SQLiteDialect_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 00000000..85f9dd9c --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,27 @@ +# sybase/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.sybase import base, pysybase, pyodbc + +# default dialect +base.dialect = pyodbc.dialect + +from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ + TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ + BIGINT, INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\ + IMAGE, BIT, MONEY, SMALLMONEY, TINYINT,\ + dialect + + +__all__ = ( + 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR', + 'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC', + 'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY', + 'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR', + 'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT', + 'dialect' +) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 00000000..50127077 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,816 @@ +# sybase/base.py +# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# get_select_precolumns(), limit_clause() implementation +# copyright (C) 2007 Fisch Asset Management +# AG http://www.fam.ch, with coding by Alexander Houben +# alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: sybase + :name: Sybase + +.. note:: + + The Sybase dialect functions on current SQLAlchemy versions + but is not regularly tested, and may have many issues and + caveats not currently handled. + +""" +import operator +import re + +from sqlalchemy.sql import compiler, expression, text, bindparam +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, exc + +from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ + TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ + BIGINT, INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ + UnicodeText, REAL + +RESERVED_WORDS = set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + "synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", "tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + + +class _SybaseUnitypeMixin(object): + """these types appear to return a buffer object.""" + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return str(value) # decode("ucs-2") + else: + return None + return process + + +class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = 'UNICHAR' + + +class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = 'UNIVARCHAR' + + +class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): + __visit_name__ = 'UNITEXT' + + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = 'IMAGE' + + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_UNICHAR(self, type_): + return "UNICHAR(%d)" % type_.length + + def visit_UNIVARCHAR(self, type_): + return "UNIVARCHAR(%d)" % type_.length + + def visit_UNITEXT(self, type_): + return "UNITEXT" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + +ischema_names = { + 'bigint': BIGINT, + 'int': INTEGER, + 'integer': INTEGER, + 'smallint': SMALLINT, + 'tinyint': TINYINT, + 'unsigned bigint': BIGINT, # TODO: unsigned flags + 'unsigned int': INTEGER, # TODO: unsigned flags + 'unsigned smallint': SMALLINT, # TODO: unsigned flags + 'numeric': NUMERIC, + 'decimal': DECIMAL, + 'dec': DECIMAL, + 'float': FLOAT, + 'double': NUMERIC, # TODO + 'double precision': NUMERIC, # TODO + 'real': REAL, + 'smallmoney': SMALLMONEY, + 'money': MONEY, + 'smalldatetime': DATETIME, + 'datetime': DATETIME, + 'date': DATE, + 'time': TIME, + 'char': CHAR, + 'character': CHAR, + 'varchar': VARCHAR, + 'character varying': VARCHAR, + 'char varying': VARCHAR, + 'unichar': UNICHAR, + 'unicode character': UNIVARCHAR, + 'nchar': NCHAR, + 'national char': NCHAR, + 'national character': NCHAR, + 'nvarchar': NVARCHAR, + 'nchar varying': NVARCHAR, + 'national char varying': NVARCHAR, + 'national character varying': NVARCHAR, + 'text': TEXT, + 'unitext': UNITEXT, + 'binary': BINARY, + 'varbinary': VARBINARY, + 'image': IMAGE, + 'bit': BIT, + +# not in documentation for ASE 15.7 + 'long varchar': TEXT, # TODO + 'timestamp': TIMESTAMP, + 'uniqueidentifier': UNIQUEIDENTIFIER, + +} + + +class SybaseInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_id(self, table_name, schema=None): + """Return the table id from `table_name` and `schema`.""" + + return self.dialect.get_table_id(self.bind, table_name, schema, + info_cache=self.info_cache) + + +class SybaseExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + + def set_ddl_autocommit(self, connection, value): + """Must be implemented by subclasses to accommodate DDL executions. + + "connection" is the raw unwrapped DBAPI connection. "value" + is True or False. when True, the connection should be configured + such that a DDL can take place subsequently. when False, + a DDL has taken place and the connection should be resumed + into non-autocommit mode. + + """ + raise NotImplementedError() + + def pre_exec(self): + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = \ + seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + if self.isddl: + # TODO: to enhance this, we can detect "ddl in tran" on the + # database settings. this error message should be improved to + # include a note about that. + if not self.should_autocommit: + raise exc.InvalidRequestError( + "The Sybase dialect only supports " + "DDL in 'autocommit' mode at this time.") + + self.root_connection.engine.logger.info( + "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + + self.set_ddl_autocommit( + self.root_connection.connection.connection, + True) + + def post_exec(self): + if self.isddl: + self.set_ddl_autocommit(self.root_connection, False) + + if self._enable_identity_insert: + self.cursor.execute( + "SET IDENTITY_INSERT %s OFF" % + self.dialect.identifier_preparer. + format_table(self.compiled.statement.table) + ) + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT @@identity AS lastrowid") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class SybaseSQLCompiler(compiler.SQLCompiler): + ansi_bind_rules = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + # TODO: don't think Sybase supports + # bind params for FIRST / TOP + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset + 1,) + return s + + def get_from_hint_text(self, table, text): + return text + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % ( + field, self.process(extract.expr, **kw)) + + def visit_now_func(self, fn, **kw): + return "GETDATE()" + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" + # which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select, **kw): + kw['literal_binds'] = True + order_by = self.process(select._order_by_clause, **kw) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + + if column.table is None: + raise exc.CompileError( + "The Sybase dialect requires Table-bound " + "columns in order to generate DDL") + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = isinstance(column.default, sa_schema.Sequence) \ + and column.default + if sequence: + start, increment = sequence.start or 1, \ + sequence.increment or 1 + else: + start, increment = 1, 1 + if (start, increment) == (1, 1): + colspec += " IDENTITY" + else: + # TODO: need correct syntax for this + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self._prepared_index_name(drop.element, + include_schema=False) + ) + + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + +class SybaseDialect(default.DefaultDialect): + name = 'sybase' + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_native_boolean = False + supports_unicode_binds = False + postfetch_lastrowid = True + + colspecs = {} + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + inspector = SybaseInspector + + construct_arguments = [] + + def _get_default_schema_name(self, connection): + return connection.scalar( + text("SELECT user_name() as user_name", + typemap={'user_name': Unicode}) + ) + + def initialize(self, connection): + super(SybaseDialect, self).initialize(connection) + if self.server_version_info is not None and\ + self.server_version_info < (15, ): + self.max_identifier_length = 30 + else: + self.max_identifier_length = 255 + + def get_table_id(self, connection, table_name, schema=None, **kw): + """Fetch the id for schema.table_name. + + Several reflection methods require the table id. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + + table_id = None + if schema is None: + schema = self.default_schema_name + + TABLEID_SQL = text(""" + SELECT o.id AS id + FROM sysobjects o JOIN sysusers u ON o.uid=u.uid + WHERE u.name = :schema_name + AND o.name = :table_name + AND o.type in ('U', 'V') + """) + + if util.py2k: + if isinstance(schema, unicode): + schema = schema.encode("ascii") + if isinstance(table_name, unicode): + table_name = table_name.encode("ascii") + result = connection.execute(TABLEID_SQL, + schema_name=schema, + table_name=table_name) + table_id = result.scalar() + if table_id is None: + raise exc.NoSuchTableError(table_name) + return table_id + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + COLUMN_SQL = text(""" + SELECT col.name AS name, + t.name AS type, + (col.status & 8) AS nullable, + (col.status & 128) AS autoincrement, + com.text AS 'default', + col.prec AS precision, + col.scale AS scale, + col.length AS length + FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON + col.cdefault = com.id + WHERE col.usertype = t.usertype + AND col.id = :table_id + ORDER BY col.colid + """) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + + columns = [] + for (name, type_, nullable, autoincrement, default, precision, scale, + length) in results: + col_info = self._get_column_info(name, type_, bool(nullable), + bool(autoincrement), default, precision, scale, + length) + columns.append(col_info) + + return columns + + def _get_column_info(self, name, type_, nullable, autoincrement, default, + precision, scale, length): + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + + if coltype in (NUMERIC, DECIMAL): + args = (precision, scale) + elif coltype == FLOAT: + args = (precision,) + elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR): + args = (length,) + else: + args = () + + if coltype: + coltype = coltype(*args, **kwargs) + #is this necessary + #if is_array: + # coltype = ARRAY(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + coltype = sqltypes.NULLTYPE + + if default: + default = re.sub("DEFAULT", "", default).strip() + default = re.sub("^'(.*)'$", lambda m: m.group(1), default) + else: + default = None + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + return column_info + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + table_cache = {} + column_cache = {} + foreign_keys = [] + + table_cache[table_id] = {"name": table_name, "schema": schema} + + COLUMN_SQL = text(""" + SELECT c.colid AS id, c.name AS name + FROM syscolumns c + WHERE c.id = :table_id + """) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + columns = {} + for col in results: + columns[col["id"]] = col["name"] + column_cache[table_id] = columns + + REFCONSTRAINT_SQL = text(""" + SELECT o.name AS name, r.reftabid AS reftable_id, + r.keycnt AS 'count', + r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3, + r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6, + r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9, + r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12, + r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15, + r.fokey16 AS fokey16, + r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3, + r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6, + r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9, + r.refkey10 AS refkey10, r.refkey11 AS refkey11, + r.refkey12 AS refkey12, r.refkey13 AS refkey13, + r.refkey14 AS refkey14, r.refkey15 AS refkey15, + r.refkey16 AS refkey16 + FROM sysreferences r JOIN sysobjects o on r.tableid = o.id + WHERE r.tableid = :table_id + """) + referential_constraints = connection.execute(REFCONSTRAINT_SQL, + table_id=table_id) + + REFTABLE_SQL = text(""" + SELECT o.name AS name, u.name AS 'schema' + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE o.id = :table_id + """) + + for r in referential_constraints: + reftable_id = r["reftable_id"] + + if reftable_id not in table_cache: + c = connection.execute(REFTABLE_SQL, table_id=reftable_id) + reftable = c.fetchone() + c.close() + table_info = {"name": reftable["name"], "schema": None} + if (schema is not None or + reftable["schema"] != self.default_schema_name): + table_info["schema"] = reftable["schema"] + + table_cache[reftable_id] = table_info + results = connection.execute(COLUMN_SQL, table_id=reftable_id) + reftable_columns = {} + for col in results: + reftable_columns[col["id"]] = col["name"] + column_cache[reftable_id] = reftable_columns + + reftable = table_cache[reftable_id] + reftable_columns = column_cache[reftable_id] + + constrained_columns = [] + referred_columns = [] + for i in range(1, r["count"] + 1): + constrained_columns.append(columns[r["fokey%i" % i]]) + referred_columns.append(reftable_columns[r["refkey%i" % i]]) + + fk_info = { + "constrained_columns": constrained_columns, + "referred_schema": reftable["schema"], + "referred_table": reftable["name"], + "referred_columns": referred_columns, + "name": r["name"] + } + + foreign_keys.append(fk_info) + + return foreign_keys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + INDEX_SQL = text(""" + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + (i.status & 0x2) AS 'unique', + index_col(object_name(i.id), i.indid, 1) AS col_1, + index_col(object_name(i.id), i.indid, 2) AS col_2, + index_col(object_name(i.id), i.indid, 3) AS col_3, + index_col(object_name(i.id), i.indid, 4) AS col_4, + index_col(object_name(i.id), i.indid, 5) AS col_5, + index_col(object_name(i.id), i.indid, 6) AS col_6, + index_col(object_name(i.id), i.indid, 7) AS col_7, + index_col(object_name(i.id), i.indid, 8) AS col_8, + index_col(object_name(i.id), i.indid, 9) AS col_9, + index_col(object_name(i.id), i.indid, 10) AS col_10, + index_col(object_name(i.id), i.indid, 11) AS col_11, + index_col(object_name(i.id), i.indid, 12) AS col_12, + index_col(object_name(i.id), i.indid, 13) AS col_13, + index_col(object_name(i.id), i.indid, 14) AS col_14, + index_col(object_name(i.id), i.indid, 15) AS col_15, + index_col(object_name(i.id), i.indid, 16) AS col_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 0 + AND i.indid BETWEEN 1 AND 254 + """) + + results = connection.execute(INDEX_SQL, table_id=table_id) + indexes = [] + for r in results: + column_names = [] + for i in range(1, r["count"]): + column_names.append(r["col_%i" % (i,)]) + index_info = {"name": r["name"], + "unique": bool(r["unique"]), + "column_names": column_names} + indexes.append(index_info) + + return indexes + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id(connection, table_name, schema, + info_cache=kw.get("info_cache")) + + PK_SQL = text(""" + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + index_col(object_name(i.id), i.indid, 1) AS pk_1, + index_col(object_name(i.id), i.indid, 2) AS pk_2, + index_col(object_name(i.id), i.indid, 3) AS pk_3, + index_col(object_name(i.id), i.indid, 4) AS pk_4, + index_col(object_name(i.id), i.indid, 5) AS pk_5, + index_col(object_name(i.id), i.indid, 6) AS pk_6, + index_col(object_name(i.id), i.indid, 7) AS pk_7, + index_col(object_name(i.id), i.indid, 8) AS pk_8, + index_col(object_name(i.id), i.indid, 9) AS pk_9, + index_col(object_name(i.id), i.indid, 10) AS pk_10, + index_col(object_name(i.id), i.indid, 11) AS pk_11, + index_col(object_name(i.id), i.indid, 12) AS pk_12, + index_col(object_name(i.id), i.indid, 13) AS pk_13, + index_col(object_name(i.id), i.indid, 14) AS pk_14, + index_col(object_name(i.id), i.indid, 15) AS pk_15, + index_col(object_name(i.id), i.indid, 16) AS pk_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 2048 + AND i.indid BETWEEN 1 AND 254 + """) + + results = connection.execute(PK_SQL, table_id=table_id) + pks = results.fetchone() + results.close() + + constrained_columns = [] + for i in range(1, pks["count"] + 1): + constrained_columns.append(pks["pk_%i" % (i,)]) + return {"constrained_columns": constrained_columns, + "name": pks["name"]} + + @reflection.cache + def get_schema_names(self, connection, **kw): + + SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u") + + schemas = connection.execute(SCHEMA_SQL) + + return [s["name"] for s in schemas] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + TABLE_SQL = text(""" + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'U' + """) + + if util.py2k: + if isinstance(schema, unicode): + schema = schema.encode("ascii") + + tables = connection.execute(TABLE_SQL, schema_name=schema) + + return [t["name"] for t in tables] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_DEF_SQL = text(""" + SELECT c.text + FROM syscomments c JOIN sysobjects o ON c.id = o.id + WHERE o.name = :view_name + AND o.type = 'V' + """) + + if util.py2k: + if isinstance(view_name, unicode): + view_name = view_name.encode("ascii") + + view = connection.execute(VIEW_DEF_SQL, view_name=view_name) + + return view.scalar() + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_SQL = text(""" + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'V' + """) + + if util.py2k: + if isinstance(schema, unicode): + schema = schema.encode("ascii") + views = connection.execute(VIEW_SQL, schema_name=schema) + + return [v["name"] for v in views] + + def has_table(self, connection, table_name, schema=None): + try: + self.get_table_id(connection, table_name, schema) + except exc.NoSuchTableError: + return False + else: + return True diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py new file mode 100644 index 00000000..f14d1c42 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -0,0 +1,32 @@ +# sybase/mxodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" + +.. dialect:: sybase+mxodbc + :name: mxODBC + :dbapi: mxodbc + :connectstring: sybase+mxodbc://:@ + :url: http://www.egenix.com/ + +.. note:: + + This dialect is a stub only and is likely non functional at this time. + + +""" +from sqlalchemy.dialects.sybase.base import SybaseDialect +from sqlalchemy.dialects.sybase.base import SybaseExecutionContext +from sqlalchemy.connectors.mxodbc import MxODBCConnector + + +class SybaseExecutionContext_mxodbc(SybaseExecutionContext): + pass + + +class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_mxodbc + +dialect = SybaseDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py new file mode 100644 index 00000000..f773e5a6 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -0,0 +1,84 @@ +# sybase/pyodbc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sybase+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: sybase+pyodbc://:@[/] + :url: http://pypi.python.org/pypi/pyodbc/ + + +Unicode Support +--------------- + +The pyodbc driver currently supports usage of these Sybase types with +Unicode or multibyte strings:: + + CHAR + NCHAR + NVARCHAR + TEXT + VARCHAR + +Currently *not* supported are:: + + UNICHAR + UNITEXT + UNIVARCHAR + +""" + +from sqlalchemy.dialects.sybase.base import SybaseDialect,\ + SybaseExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes, processors +import decimal + + +class _SybNumeric_pyodbc(sqltypes.Numeric): + """Turns Decimals with adjusted() < -6 into floats. + + It's not yet known how to get decimals with many + significant digits or very large adjusted() into Sybase + via pyodbc. + + """ + + def bind_processor(self, dialect): + super_process = super(_SybNumeric_pyodbc, self).\ + bind_processor(dialect) + + def process(value): + if self.asdecimal and \ + isinstance(value, decimal.Decimal): + + if value.adjusted() < -6: + return processors.to_float(value) + + if super_process: + return super_process(value) + else: + return value + return process + + +class SybaseExecutionContext_pyodbc(SybaseExecutionContext): + def set_ddl_autocommit(self, connection, value): + if value: + connection.autocommit = True + else: + connection.autocommit = False + + +class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_pyodbc + + colspecs = { + sqltypes.Numeric: _SybNumeric_pyodbc, + } + +dialect = SybaseDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py new file mode 100644 index 00000000..664bd9ac --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -0,0 +1,100 @@ +# sybase/pysybase.py +# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sybase+pysybase + :name: Python-Sybase + :dbapi: Sybase + :connectstring: sybase+pysybase://:@/[database name] + :url: http://python-sybase.sourceforge.net/ + +Unicode Support +--------------- + +The python-sybase driver does not appear to support non-ASCII strings of any +kind at this time. + +""" + +from sqlalchemy import types as sqltypes, processors +from sqlalchemy.dialects.sybase.base import SybaseDialect, \ + SybaseExecutionContext, SybaseSQLCompiler + + +class _SybNumeric(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class SybaseExecutionContext_pysybase(SybaseExecutionContext): + + def set_ddl_autocommit(self, dbapi_connection, value): + if value: + # call commit() on the Sybase connection directly, + # to avoid any side effects of calling a Connection + # transactional method inside of pre_exec() + dbapi_connection.commit() + + def pre_exec(self): + SybaseExecutionContext.pre_exec(self) + + for param in self.parameters: + for key in list(param): + param["@" + key] = param[key] + del param[key] + + +class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): + def bindparam_string(self, name, **kw): + return "@" + name + + +class SybaseDialect_pysybase(SybaseDialect): + driver = 'pysybase' + execution_ctx_cls = SybaseExecutionContext_pysybase + statement_compiler = SybaseSQLCompiler_pysybase + + colspecs = { + sqltypes.Numeric: _SybNumeric, + sqltypes.Float: sqltypes.Float + } + + @classmethod + def dbapi(cls): + import Sybase + return Sybase + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user', password='passwd') + + return ([opts.pop('host')], opts) + + def do_executemany(self, cursor, statement, parameters, context=None): + # calling python-sybase executemany yields: + # TypeError: string too long for buffer + for param in parameters: + cursor.execute(statement, param) + + def _get_server_version_info(self, connection): + vers = connection.scalar("select @@version_number") + # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), + # (12, 5, 0, 0) + return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, (self.dbapi.OperationalError, + self.dbapi.ProgrammingError)): + msg = str(e) + return ('Unable to complete network request to host' in msg or + 'Invalid connection state' in msg or + 'Invalid cursor state' in msg) + else: + return False + +dialect = SybaseDialect_pysybase diff --git a/lib/sqlalchemy/dialects/type_migration_guidelines.txt b/lib/sqlalchemy/dialects/type_migration_guidelines.txt new file mode 100644 index 00000000..1ca15f7f --- /dev/null +++ b/lib/sqlalchemy/dialects/type_migration_guidelines.txt @@ -0,0 +1,145 @@ +Rules for Migrating TypeEngine classes to 0.6 +--------------------------------------------- + +1. the TypeEngine classes are used for: + + a. Specifying behavior which needs to occur for bind parameters + or result row columns. + + b. Specifying types that are entirely specific to the database + in use and have no analogue in the sqlalchemy.types package. + + c. Specifying types where there is an analogue in sqlalchemy.types, + but the database in use takes vendor-specific flags for those + types. + + d. If a TypeEngine class doesn't provide any of this, it should be + *removed* from the dialect. + +2. the TypeEngine classes are *no longer* used for generating DDL. Dialects +now have a TypeCompiler subclass which uses the same visit_XXX model as +other compilers. + +3. the "ischema_names" and "colspecs" dictionaries are now required members on +the Dialect class. + +4. The names of types within dialects are now important. If a dialect-specific type +is a subclass of an existing generic type and is only provided for bind/result behavior, +the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, +end users would never need to use _PGNumeric directly. However, if a dialect-specific +type is specifying a type *or* arguments that are not present generically, it should +match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, +mysql.ENUM, postgresql.ARRAY. + +Or follow this handy flowchart: + + is the type meant to provide bind/result is the type the same name as an + behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ? + type in types.py ? | | + | no yes + yes | | + | | does your type need special + | +<--- yes --- behavior or arguments ? + | | | + | | no + name the type using | | + _MixedCase, i.e. v V + _OracleBoolean. it name the type don't make a + stays private to the dialect identically as that type, make sure the dialect's + and is invoked *only* via within the DB, base.py imports the types.py + the colspecs dict. using UPPERCASE UPPERCASE name into its namespace + | (i.e. BIT, NCHAR, INTERVAL). + | Users can import it. + | | + v v + subclass the closest is the name of this type + MixedCase type types.py, identical to an UPPERCASE + i.e. <--- no ------- name in types.py ? + class _DateTime(types.DateTime), + class DATETIME2(types.DateTime), | + class BIT(types.TypeEngine). yes + | + v + the type should + subclass the + UPPERCASE + type in types.py + (i.e. class BLOB(types.BLOB)) + + +Example 1. pysqlite needs bind/result processing for the DateTime type in types.py, +which applies to all DateTimes and subclasses. It's named _SLDateTime and +subclasses types.DateTime. + +Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument +that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py, +and subclasses types.TIME. Users can then say mssql.TIME(precision=10). + +Example 3. MS-SQL dialects also need special bind/result processing for date +But its DATE type doesn't render DDL differently than that of a plain +DATE, i.e. it takes no special arguments. Therefore we are just adding behavior +to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses +types.Date. + +Example 4. MySQL has a SET type, there's no analogue for this in types.py. So +MySQL names it SET in the dialect's base.py, and it subclasses types.String, since +it ultimately deals with strings. + +Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +Postgresql dialect therefore imports types.DATETIME into its base.py. + +Ideally one should be able to specify a schema using names imported completely from a +dialect, all matching the real name on that backend: + + from sqlalchemy.dialects.postgresql import base as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, +but the PG dialect makes them available in its own namespace. + +5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types +linked to types specified in the dialect. Again, if a type in the dialect does not +specify any special behavior for bind_processor() or result_processor() and does not +indicate a special type only available in this database, it must be *removed* from the +module and from this dictionary. + +6. "ischema_names" indicates string descriptions of types as returned from the database +linked to TypeEngine classes. + + a. The string name should be matched to the most specific type possible within + sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which + case it points to a dialect type. *It doesn't matter* if the dialect has it's + own subclass of that type with special bind/result behavior - reflect to the types.py + UPPERCASE type as much as possible. With very few exceptions, all types + should reflect to an UPPERCASE type. + + b. If the dialect contains a matching dialect-specific type that takes extra arguments + which the generic one does not, then point to the dialect-specific type. E.g. + mssql.VARCHAR takes a "collation" parameter which should be preserved. + +5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by +a subclass of compiler.GenericTypeCompiler. + + a. your TypeCompiler class will receive generic and uppercase types from + sqlalchemy.types. Do not assume the presence of dialect-specific attributes on + these types. + + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with + methods that produce a different DDL name. Uppercase types don't do any kind of + "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in + all cases, regardless of whether or not that type is legal on the backend database. + + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional + arguments and flags to those types. + + d. the visit_lowercase methods are overridden to provide an interpretation of a generic + type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". + + e. visit_lowercase methods should *never* render strings directly - it should always + be via calling a visit_UPPERCASE() method. diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py new file mode 100644 index 00000000..99251f63 --- /dev/null +++ b/lib/sqlalchemy/engine/__init__.py @@ -0,0 +1,372 @@ +# engine/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL connections, SQL execution and high-level DB-API interface. + +The engine package defines the basic components used to interface +DB-API modules with higher-level statement construction, +connection-management, execution and result contexts. The primary +"entry point" class into this package is the Engine and it's public +constructor ``create_engine()``. + +This package includes: + +base.py + Defines interface classes and some implementation classes which + comprise the basic components used to interface between a DB-API, + constructed and plain-text statements, connections, transactions, + and results. + +default.py + Contains default implementations of some of the components defined + in base.py. All current database dialects use the classes in + default.py as base classes for their own database-specific + implementations. + +strategies.py + The mechanics of constructing ``Engine`` objects are represented + here. Defines the ``EngineStrategy`` class which represents how + to go from arguments specified to the ``create_engine()`` + function, to a fully constructed ``Engine``, including + initialization of connection pooling, dialects, and specific + subclasses of ``Engine``. + +threadlocal.py + The ``TLEngine`` class is defined here, which is a subclass of + the generic ``Engine`` and tracks ``Connection`` and + ``Transaction`` objects against the identity of the current + thread. This allows certain programming patterns based around + the concept of a "thread-local connection" to be possible. + The ``TLEngine`` is created by using the "threadlocal" engine + strategy in conjunction with the ``create_engine()`` function. + +url.py + Defines the ``URL`` class which represents the individual + components of a string URL passed to ``create_engine()``. Also + defines a basic module-loading strategy for the dialect specifier + within a URL. +""" + +from .interfaces import ( + Connectable, + Dialect, + ExecutionContext, + + # backwards compat + Compiled, + TypeCompiler +) + +from .base import ( + Connection, + Engine, + NestedTransaction, + RootTransaction, + Transaction, + TwoPhaseTransaction, + ) + +from .result import ( + BufferedColumnResultProxy, + BufferedColumnRow, + BufferedRowResultProxy, + FullyBufferedResultProxy, + ResultProxy, + RowProxy, + ) + +from .util import ( + connection_memoize + ) + + +from . import util, strategies + +# backwards compat +from ..sql import ddl + +default_strategy = 'plain' + + +def create_engine(*args, **kwargs): + """Create a new :class:`.Engine` instance. + + The standard calling form is to send the URL as the + first positional argument, usually a string + that indicates database dialect and connection arguments:: + + + engine = create_engine("postgresql://scott:tiger@localhost/test") + + Additional keyword arguments may then follow it which + establish various options on the resulting :class:`.Engine` + and its underlying :class:`.Dialect` and :class:`.Pool` + constructs:: + + engine = create_engine("mysql://scott:tiger@hostname/dbname", + encoding='latin1', echo=True) + + The string form of the URL is + ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where + ``dialect`` is a database name such as ``mysql``, ``oracle``, + ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as + ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, + the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. + + ``**kwargs`` takes a wide variety of options which are routed + towards their appropriate components. Arguments may be specific to + the :class:`.Engine`, the underlying :class:`.Dialect`, as well as the + :class:`.Pool`. Specific dialects also accept keyword arguments that + are unique to that dialect. Here, we describe the parameters + that are common to most :func:`.create_engine()` usage. + + Once established, the newly resulting :class:`.Engine` will + request a connection from the underlying :class:`.Pool` once + :meth:`.Engine.connect` is called, or a method which depends on it + such as :meth:`.Engine.execute` is invoked. The :class:`.Pool` in turn + will establish the first actual DBAPI connection when this request + is received. The :func:`.create_engine` call itself does **not** + establish any actual DBAPI connections directly. + + .. seealso:: + + :doc:`/core/engines` + + :doc:`/dialects/index` + + :ref:`connections_toplevel` + + :param case_sensitive=True: if False, result column names + will match in a case-insensitive fashion, that is, + ``row['SomeColumn']``. + + .. versionchanged:: 0.8 + By default, result row names match case-sensitively. + In version 0.7 and prior, all matches were case-insensitive. + + :param connect_args: a dictionary of options which will be + passed directly to the DBAPI's ``connect()`` method as + additional keyword arguments. See the example + at :ref:`custom_dbapi_args`. + + :param convert_unicode=False: if set to True, sets + the default behavior of ``convert_unicode`` on the + :class:`.String` type to ``True``, regardless + of a setting of ``False`` on an individual + :class:`.String` type, thus causing all :class:`.String` + -based columns + to accommodate Python ``unicode`` objects. This flag + is useful as an engine-wide setting when using a + DBAPI that does not natively support Python + ``unicode`` objects and raises an error when + one is received (such as pyodbc with FreeTDS). + + See :class:`.String` for further details on + what this flag indicates. + + :param creator: a callable which returns a DBAPI connection. + This creation function will be passed to the underlying + connection pool and will be used to create all new database + connections. Usage of this function causes connection + parameters specified in the URL argument to be bypassed. + + :param echo=False: if True, the Engine will log all statements + as well as a repr() of their parameter lists to the engines + logger, which defaults to sys.stdout. The ``echo`` attribute of + ``Engine`` can be modified at any time to turn logging on and + off. If set to the string ``"debug"``, result rows will be + printed to the standard output as well. This flag ultimately + controls a Python logger; see :ref:`dbengine_logging` for + information on how to configure logging directly. + + :param echo_pool=False: if True, the connection pool will log + all checkouts/checkins to the logging stream, which defaults to + sys.stdout. This flag ultimately controls a Python logger; see + :ref:`dbengine_logging` for information on how to configure logging + directly. + + :param encoding: Defaults to ``utf-8``. This is the string + encoding used by SQLAlchemy for string encode/decode + operations which occur within SQLAlchemy, **outside of + the DBAPI.** Most modern DBAPIs feature some degree of + direct support for Python ``unicode`` objects, + what you see in Python 2 as a string of the form + ``u'some string'``. For those scenarios where the + DBAPI is detected as not supporting a Python ``unicode`` + object, this encoding is used to determine the + source/destination encoding. It is **not used** + for those cases where the DBAPI handles unicode + directly. + + To properly configure a system to accommodate Python + ``unicode`` objects, the DBAPI should be + configured to handle unicode to the greatest + degree as is appropriate - see + the notes on unicode pertaining to the specific + target database in use at :ref:`dialect_toplevel`. + + Areas where string encoding may need to be accommodated + outside of the DBAPI include zero or more of: + + * the values passed to bound parameters, corresponding to + the :class:`.Unicode` type or the :class:`.String` type + when ``convert_unicode`` is ``True``; + * the values returned in result set columns corresponding + to the :class:`.Unicode` type or the :class:`.String` + type when ``convert_unicode`` is ``True``; + * the string SQL statement passed to the DBAPI's + ``cursor.execute()`` method; + * the string names of the keys in the bound parameter + dictionary passed to the DBAPI's ``cursor.execute()`` + as well as ``cursor.setinputsizes()`` methods; + * the string column names retrieved from the DBAPI's + ``cursor.description`` attribute. + + When using Python 3, the DBAPI is required to support + *all* of the above values as Python ``unicode`` objects, + which in Python 3 are just known as ``str``. In Python 2, + the DBAPI does not specify unicode behavior at all, + so SQLAlchemy must make decisions for each of the above + values on a per-DBAPI basis - implementations are + completely inconsistent in their behavior. + + :param execution_options: Dictionary execution options which will + be applied to all connections. See + :meth:`~sqlalchemy.engine.Connection.execution_options` + + :param implicit_returning=True: When ``True``, a RETURNING- + compatible construct, if available, will be used to + fetch newly generated primary key values when a single row + INSERT statement is emitted with no existing returning() + clause. This applies to those backends which support RETURNING + or a compatible construct, including Postgresql, Firebird, Oracle, + Microsoft SQL Server. Set this to ``False`` to disable + the automatic usage of RETURNING. + + :param label_length=None: optional integer value which limits + the size of dynamically generated column labels to that many + characters. If less than 6, labels are generated as + "_(counter)". If ``None``, the value of + ``dialect.max_identifier_length`` is used instead. + + :param listeners: A list of one or more + :class:`~sqlalchemy.interfaces.PoolListener` objects which will + receive connection pool events. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.engine" logger. Defaults to a hexstring of the + object's id. + + :param max_overflow=10: the number of connections to allow in + connection pool "overflow", that is connections that can be + opened above and beyond the pool_size setting, which defaults + to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`. + + :param module=None: reference to a Python module object (the module + itself, not its string name). Specifies an alternate DBAPI module to + be used by the engine's dialect. Each sub-dialect references a + specific DBAPI which will be imported before first connect. This + parameter causes the import to be bypassed, and the given module to + be used instead. Can be used for testing of DBAPIs as well as to + inject "mock" DBAPI implementations into the :class:`.Engine`. + + :param pool=None: an already-constructed instance of + :class:`~sqlalchemy.pool.Pool`, such as a + :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this + pool will be used directly as the underlying connection pool + for the engine, bypassing whatever connection parameters are + present in the URL argument. For information on constructing + connection pools manually, see :ref:`pooling_toplevel`. + + :param poolclass=None: a :class:`~sqlalchemy.pool.Pool` + subclass, which will be used to create a connection pool + instance using the connection parameters given in the URL. Note + this differs from ``pool`` in that you don't actually + instantiate the pool in this case, you just indicate what type + of pool to be used. + + :param pool_logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + :param pool_size=5: the number of connections to keep open + inside the connection pool. This used with + :class:`~sqlalchemy.pool.QueuePool` as + well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With + :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting + of 0 indicates no limit; to disable pooling, set ``poolclass`` to + :class:`~sqlalchemy.pool.NullPool` instead. + + :param pool_recycle=-1: this setting causes the pool to recycle + connections after the given number of seconds has passed. It + defaults to -1, or no timeout. For example, setting to 3600 + means connections will be recycled after one hour. Note that + MySQL in particular will disconnect automatically if no + activity is detected on a connection for eight hours (although + this is configurable with the MySQLDB connection itself and the + server configuration as well). + + :param pool_reset_on_return='rollback': set the "reset on return" + behavior of the pool, which is whether ``rollback()``, + ``commit()``, or nothing is called upon connections + being returned to the pool. See the docstring for + ``reset_on_return`` at :class:`.Pool`. + + .. versionadded:: 0.7.6 + + :param pool_timeout=30: number of seconds to wait before giving + up on getting a connection from the pool. This is only used + with :class:`~sqlalchemy.pool.QueuePool`. + + :param strategy='plain': selects alternate engine implementations. + Currently available are: + + * the ``threadlocal`` strategy, which is described in + :ref:`threadlocal_strategy`; + * the ``mock`` strategy, which dispatches all statement + execution to a function passed as the argument ``executor``. + See `example in the FAQ + `_. + + :param executor=None: a function taking arguments + ``(sql, *multiparams, **params)``, to which the ``mock`` strategy will + dispatch all statement execution. Used only by ``strategy='mock'``. + + """ + + strategy = kwargs.pop('strategy', default_strategy) + strategy = strategies.strategies[strategy] + return strategy.create(*args, **kwargs) + + +def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): + """Create a new Engine instance using a configuration dictionary. + + The dictionary is typically produced from a config file where keys + are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The + 'prefix' argument indicates the prefix to be searched for. + + A select set of keyword arguments will be "coerced" to their + expected type based on string values. In a future release, this + functionality will be expanded and include dialect-specific + arguments. + """ + + options = dict((key[len(prefix):], configuration[key]) + for key in configuration + if key.startswith(prefix)) + options['_coerce_config'] = True + options.update(kwargs) + url = options.pop('url') + return create_engine(url, **options) + + +__all__ = ( + 'create_engine', + 'engine_from_config', + ) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py new file mode 100644 index 00000000..9f656cac --- /dev/null +++ b/lib/sqlalchemy/engine/base.py @@ -0,0 +1,1808 @@ +# engine/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from __future__ import with_statement + +"""Defines :class:`.Connection` and :class:`.Engine`. + +""" + + +import sys +from .. import exc, util, log, interfaces +from ..sql import expression, util as sql_util, schema, ddl +from .interfaces import Connectable, Compiled +from .util import _distill_params +import contextlib + + +class Connection(Connectable): + """Provides high-level functionality for a wrapped DB-API connection. + + Provides execution support for string-based SQL statements as well as + :class:`.ClauseElement`, :class:`.Compiled` and :class:`.DefaultGenerator` + objects. Provides a :meth:`begin` method to return :class:`.Transaction` + objects. + + The Connection object is **not** thread-safe. While a Connection can be + shared among threads using properly synchronized access, it is still + possible that the underlying DBAPI connection may not support shared + access between threads. Check the DBAPI documentation for details. + + The Connection object represents a single dbapi connection checked out + from the connection pool. In this state, the connection pool has no affect + upon the connection, including its expiration or timeout state. For the + connection pool to properly manage connections, connections should be + returned to the connection pool (i.e. ``connection.close()``) whenever the + connection is not in use. + + .. index:: + single: thread safety; Connection + + """ + + def __init__(self, engine, connection=None, close_with_result=False, + _branch=False, _execution_options=None, + _dispatch=None, + _has_events=None): + """Construct a new Connection. + + The constructor here is not public and is only called only by an + :class:`.Engine`. See :meth:`.Engine.connect` and + :meth:`.Engine.contextual_connect` methods. + + """ + self.engine = engine + self.dialect = engine.dialect + self.__connection = connection or engine.raw_connection() + self.__transaction = None + self.should_close_with_result = close_with_result + self.__savepoint_seq = 0 + self.__branch = _branch + self.__invalid = False + self.__can_reconnect = True + if _dispatch: + self.dispatch = _dispatch + elif _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. + self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events) + + self._echo = self.engine._should_log_info() + if _execution_options: + self._execution_options =\ + engine._execution_options.union(_execution_options) + else: + self._execution_options = engine._execution_options + + if self._has_events or self.engine._has_events: + self.dispatch.engine_connect(self, _branch) + + def _branch(self): + """Return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled, + and also whose close() method does nothing. + + This is used to execute "sub" statements within a single execution, + usually an INSERT statement. + """ + + return self.engine._connection_cls( + self.engine, + self.__connection, + _branch=True, + _has_events=self._has_events, + _dispatch=self.dispatch) + + def _clone(self): + """Create a shallow copy of this Connection. + + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + def execution_options(self, **opt): + """ Set non-SQL options for the connection which take effect + during execution. + + The method returns a copy of this :class:`.Connection` which references + the same underlying DBAPI connection, but also defines the given + execution options which will take effect for a call to + :meth:`execute`. As the new :class:`.Connection` references the same + underlying resource, it's usually a good idea to ensure that the copies + would be discarded immediately, which is implicit if used as in:: + + result = connection.execution_options(stream_results=True).\\ + execute(stmt) + + Note that any key/value can be passed to + :meth:`.Connection.execution_options`, and it will be stored in the + ``_execution_options`` dictionary of the :class:`.Connection`. It + is suitable for usage by end-user schemes to communicate with + event listeners, for example. + + The keywords that are currently recognized by SQLAlchemy itself + include all those listed under :meth:`.Executable.execution_options`, + as well as others that are specific to :class:`.Connection`. + + :param autocommit: Available on: Connection, statement. + When True, a COMMIT will be invoked after execution + when executed in 'autocommit' mode, i.e. when an explicit + transaction is not begun on the connection. Note that DBAPI + connections by default are always in a transaction - SQLAlchemy uses + rules applied to different kinds of statements to determine if + COMMIT will be invoked in order to provide its "autocommit" feature. + Typically, all INSERT/UPDATE/DELETE statements as well as + CREATE/DROP statements have autocommit behavior enabled; SELECT + constructs do not. Use this option when invoking a SELECT or other + specific SQL construct where COMMIT is desired (typically when + calling stored procedures and such), and an explicit + transaction is not in progress. + + :param compiled_cache: Available on: Connection. + A dictionary where :class:`.Compiled` objects + will be cached when the :class:`.Connection` compiles a clause + expression into a :class:`.Compiled` object. + It is the user's responsibility to + manage the size of this dictionary, which will have keys + corresponding to the dialect, clause element, the column + names within the VALUES or SET clause of an INSERT or UPDATE, + as well as the "batch" mode for an INSERT or UPDATE statement. + The format of this dictionary is not guaranteed to stay the + same in future releases. + + Note that the ORM makes use of its own "compiled" caches for + some operations, including flush operations. The caching + used by the ORM internally supersedes a cache dictionary + specified here. + + :param isolation_level: Available on: Connection. + Set the transaction isolation level for + the lifespan of this connection. Valid values include + those string values accepted by the ``isolation_level`` + parameter passed to :func:`.create_engine`, and are + database specific, including those for :ref:`sqlite_toplevel`, + :ref:`postgresql_toplevel` - see those dialect's documentation + for further info. + + Note that this option necessarily affects the underlying + DBAPI connection for the lifespan of the originating + :class:`.Connection`, and is not per-execution. This + setting is not removed until the underlying DBAPI connection + is returned to the connection pool, i.e. + the :meth:`.Connection.close` method is called. + + :param no_parameters: When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. + + .. versionadded:: 0.7.6 + + :param stream_results: Available on: Connection, statement. + Indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. This is a limitation + of many DBAPIs. The flag is currently understood only by the + psycopg2 dialect. + + """ + c = self._clone() + c._execution_options = c._execution_options.union(opt) + if self._has_events or self.engine._has_events: + self.dispatch.set_connection_execution_options(c, opt) + self.dialect.set_connection_execution_options(c, opt) + return c + + @property + def closed(self): + """Return True if this connection is closed.""" + + return '_Connection__connection' not in self.__dict__ \ + and not self.__can_reconnect + + @property + def invalidated(self): + """Return True if this connection was invalidated.""" + + return self.__invalid + + @property + def connection(self): + "The underlying DB-API connection managed by this Connection." + + try: + return self.__connection + except AttributeError: + return self._revalidate_connection() + + def _revalidate_connection(self): + if self.__can_reconnect and self.__invalid: + if self.__transaction is not None: + raise exc.InvalidRequestError( + "Can't reconnect until invalid " + "transaction is rolled back") + self.__connection = self.engine.raw_connection() + self.__invalid = False + return self.__connection + raise exc.ResourceClosedError("This Connection is closed") + + @property + def _connection_is_valid(self): + # use getattr() for is_valid to support exceptions raised in + # dialect initializer, where the connection is not wrapped in + # _ConnectionFairy + + return getattr(self.__connection, 'is_valid', False) + + @property + def _still_open_and_connection_is_valid(self): + return \ + not self.closed and \ + not self.invalidated and \ + getattr(self.__connection, 'is_valid', False) + + @property + def info(self): + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.Connection`, allowing user-defined + data to be associated with the connection. + + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`.Connection`. + + """ + + return self.connection.info + + def connect(self): + """Returns a branched version of this :class:`.Connection`. + + The :meth:`.Connection.close` method on the returned + :class:`.Connection` can be called and this + :class:`.Connection` will remain open. + + This method provides usage symmetry with + :meth:`.Engine.connect`, including for usage + with context managers. + + """ + + return self._branch() + + def contextual_connect(self, **kwargs): + """Returns a branched version of this :class:`.Connection`. + + The :meth:`.Connection.close` method on the returned + :class:`.Connection` can be called and this + :class:`.Connection` will remain open. + + This method provides usage symmetry with + :meth:`.Engine.contextual_connect`, including for usage + with context managers. + + """ + + return self._branch() + + def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with + this :class:`.Connection`. + + The underlying DBAPI connection is literally closed (if + possible), and is discarded. Its source connection pool will + typically lazily create a new connection to replace it. + + Upon the next use (where "use" typically means using the + :meth:`.Connection.execute` method or similar), + this :class:`.Connection` will attempt to + procure a new DBAPI connection using the services of the + :class:`.Pool` as a source of connectivty (e.g. a "reconnection"). + + If a transaction was in progress (e.g. the + :meth:`.Connection.begin` method has been called) when + :meth:`.Connection.invalidate` method is called, at the DBAPI + level all state associated with this transaction is lost, as + the DBAPI connection is closed. The :class:`.Connection` + will not allow a reconnection to proceed until the :class:`.Transaction` + object is ended, by calling the :meth:`.Transaction.rollback` + method; until that point, any attempt at continuing to use the + :class:`.Connection` will raise an + :class:`~sqlalchemy.exc.InvalidRequestError`. + This is to prevent applications from accidentally + continuing an ongoing transactional operations despite the + fact that the transaction has been lost due to an + invalidation. + + The :meth:`.Connection.invalidate` method, just like auto-invalidation, + will at the connection pool level invoke the :meth:`.PoolEvents.invalidate` + event. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + if self.invalidated: + return + + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + if self._connection_is_valid: + self.__connection.invalidate(exception) + del self.__connection + self.__invalid = True + + def detach(self): + """Detach the underlying DB-API connection from its connection pool. + + E.g.:: + + with engine.connect() as conn: + conn.detach() + conn.execute("SET search_path TO schema1, schema2") + + # work with connection + + # connection is fully closed (since we used "with:", can + # also call .close()) + + This :class:`.Connection` instance will remain usable. When closed + (or exited from a context manager context as above), + the DB-API connection will be literally closed and not + returned to its originating pool. + + This method can be used to insulate the rest of an application + from a modified state on a connection (such as a transaction + isolation level or similar). + + """ + + self.__connection.detach() + + def begin(self): + """Begin a transaction and return a transaction handle. + + The returned object is an instance of :class:`.Transaction`. + This object represents the "scope" of the transaction, + which completes when either the :meth:`.Transaction.rollback` + or :meth:`.Transaction.commit` method is called. + + Nested calls to :meth:`.begin` on the same :class:`.Connection` + will return new :class:`.Transaction` objects that represent + an emulated transaction within the scope of the enclosing + transaction, that is:: + + trans = conn.begin() # outermost transaction + trans2 = conn.begin() # "nested" + trans2.commit() # does nothing + trans.commit() # actually commits + + Calls to :meth:`.Transaction.commit` only have an effect + when invoked via the outermost :class:`.Transaction` object, though the + :meth:`.Transaction.rollback` method of any of the + :class:`.Transaction` objects will roll back the + transaction. + + See also: + + :meth:`.Connection.begin_nested` - use a SAVEPOINT + + :meth:`.Connection.begin_twophase` - use a two phase /XID transaction + + :meth:`.Engine.begin` - context manager available from + :class:`.Engine`. + + """ + + if self.__transaction is None: + self.__transaction = RootTransaction(self) + return self.__transaction + else: + return Transaction(self, self.__transaction) + + def begin_nested(self): + """Begin a nested transaction and return a transaction handle. + + The returned object is an instance of :class:`.NestedTransaction`. + + Nested transactions require SAVEPOINT support in the + underlying database. Any transaction in the hierarchy may + ``commit`` and ``rollback``, however the outermost transaction + still controls the overall ``commit`` or ``rollback`` of the + transaction of a whole. + + See also :meth:`.Connection.begin`, + :meth:`.Connection.begin_twophase`. + """ + if self.__transaction is None: + self.__transaction = RootTransaction(self) + else: + self.__transaction = NestedTransaction(self, self.__transaction) + return self.__transaction + + def begin_twophase(self, xid=None): + """Begin a two-phase or XA transaction and return a transaction + handle. + + The returned object is an instance of :class:`.TwoPhaseTransaction`, + which in addition to the methods provided by + :class:`.Transaction`, also provides a + :meth:`~.TwoPhaseTransaction.prepare` method. + + :param xid: the two phase transaction id. If not supplied, a + random id will be generated. + + See also :meth:`.Connection.begin`, + :meth:`.Connection.begin_twophase`. + + """ + + if self.__transaction is not None: + raise exc.InvalidRequestError( + "Cannot start a two phase transaction when a transaction " + "is already in progress.") + if xid is None: + xid = self.engine.dialect.create_xid() + self.__transaction = TwoPhaseTransaction(self, xid) + return self.__transaction + + def recover_twophase(self): + return self.engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid, recover=False): + self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid, recover=False): + self.engine.dialect.do_commit_twophase(self, xid, recover=recover) + + def in_transaction(self): + """Return True if a transaction is in progress.""" + + return self.__transaction is not None + + def _begin_impl(self, transaction): + if self._echo: + self.engine.logger.info("BEGIN (implicit)") + + if self._has_events or self.engine._has_events: + self.dispatch.begin(self) + + try: + self.engine.dialect.do_begin(self.connection) + if self.connection._reset_agent is None: + self.connection._reset_agent = transaction + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _rollback_impl(self): + if self._has_events or self.engine._has_events: + self.dispatch.rollback(self) + + if self._still_open_and_connection_is_valid: + if self._echo: + self.engine.logger.info("ROLLBACK") + try: + self.engine.dialect.do_rollback(self.connection) + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + else: + self.__transaction = None + + def _commit_impl(self, autocommit=False): + if self._has_events or self.engine._has_events: + self.dispatch.commit(self) + + if self._echo: + self.engine.logger.info("COMMIT") + try: + self.engine.dialect.do_commit(self.connection) + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + + def _savepoint_impl(self, name=None): + if self._has_events or self.engine._has_events: + self.dispatch.savepoint(self, name) + + if name is None: + self.__savepoint_seq += 1 + name = 'sa_savepoint_%s' % self.__savepoint_seq + if self._still_open_and_connection_is_valid: + self.engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name, context): + if self._has_events or self.engine._has_events: + self.dispatch.rollback_savepoint(self, name, context) + + if self._still_open_and_connection_is_valid: + self.engine.dialect.do_rollback_to_savepoint(self, name) + self.__transaction = context + + def _release_savepoint_impl(self, name, context): + if self._has_events or self.engine._has_events: + self.dispatch.release_savepoint(self, name, context) + + if self._still_open_and_connection_is_valid: + self.engine.dialect.do_release_savepoint(self, name) + self.__transaction = context + + def _begin_twophase_impl(self, transaction): + if self._echo: + self.engine.logger.info("BEGIN TWOPHASE (implicit)") + if self._has_events or self.engine._has_events: + self.dispatch.begin_twophase(self, transaction.xid) + + if self._still_open_and_connection_is_valid: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + + if self.connection._reset_agent is None: + self.connection._reset_agent = transaction + + def _prepare_twophase_impl(self, xid): + if self._has_events or self.engine._has_events: + self.dispatch.prepare_twophase(self, xid) + + if self._still_open_and_connection_is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.engine.dialect.do_prepare_twophase(self, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + if self._has_events or self.engine._has_events: + self.dispatch.rollback_twophase(self, xid, is_prepared) + + if self._still_open_and_connection_is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_rollback_twophase(self, xid, is_prepared) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + else: + self.__transaction = None + + def _commit_twophase_impl(self, xid, is_prepared): + if self._has_events or self.engine._has_events: + self.dispatch.commit_twophase(self, xid, is_prepared) + + if self._still_open_and_connection_is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + else: + self.__transaction = None + + def _autorollback(self): + if not self.in_transaction(): + self._rollback_impl() + + def close(self): + """Close this :class:`.Connection`. + + This results in a release of the underlying database + resources, that is, the DBAPI connection referenced + internally. The DBAPI connection is typically restored + back to the connection-holding :class:`.Pool` referenced + by the :class:`.Engine` that produced this + :class:`.Connection`. Any transactional state present on + the DBAPI connection is also unconditionally released via + the DBAPI connection's ``rollback()`` method, regardless + of any :class:`.Transaction` object that may be + outstanding with regards to this :class:`.Connection`. + + After :meth:`~.Connection.close` is called, the + :class:`.Connection` is permanently in a closed state, + and will allow no further operations. + + """ + try: + conn = self.__connection + except AttributeError: + pass + else: + if not self.__branch: + conn.close() + if conn._reset_agent is self.__transaction: + conn._reset_agent = None + del self.__connection + self.__can_reconnect = False + self.__transaction = None + + def scalar(self, object, *multiparams, **params): + """Executes and returns the first column of the first row. + + The underlying result/cursor is closed after execution. + """ + + return self.execute(object, *multiparams, **params).scalar() + + def execute(self, object, *multiparams, **params): + """Executes the a SQL statement construct and returns a + :class:`.ResultProxy`. + + :param object: The statement to be executed. May be + one of: + + * a plain string + * any :class:`.ClauseElement` construct that is also + a subclass of :class:`.Executable`, such as a + :func:`~.expression.select` construct + * a :class:`.FunctionElement`, such as that generated + by :attr:`.func`, will be automatically wrapped in + a SELECT statement, which is then executed. + * a :class:`.DDLElement` object + * a :class:`.DefaultGenerator` object + * a :class:`.Compiled` object + + :param \*multiparams/\**params: represent bound parameter + values to be used in the execution. Typically, + the format is either a collection of one or more + dictionaries passed to \*multiparams:: + + conn.execute( + table.insert(), + {"id":1, "value":"v1"}, + {"id":2, "value":"v2"} + ) + + ...or individual key/values interpreted by \**params:: + + conn.execute( + table.insert(), id=1, value="v1" + ) + + In the case that a plain SQL string is passed, and the underlying + DBAPI accepts positional bind parameters, a collection of tuples + or individual values in \*multiparams may be passed:: + + conn.execute( + "INSERT INTO table (id, value) VALUES (?, ?)", + (1, "v1"), (2, "v2") + ) + + conn.execute( + "INSERT INTO table (id, value) VALUES (?, ?)", + 1, "v1" + ) + + Note above, the usage of a question mark "?" or other + symbol is contingent upon the "paramstyle" accepted by the DBAPI + in use, which may be any of "qmark", "named", "pyformat", "format", + "numeric". See `pep-249 `_ + for details on paramstyle. + + To execute a textual SQL statement which uses bound parameters in a + DBAPI-agnostic way, use the :func:`~.expression.text` construct. + + """ + if isinstance(object, util.string_types[0]): + return self._execute_text(object, multiparams, params) + try: + meth = object._execute_on_connection + except AttributeError: + raise exc.InvalidRequestError( + "Unexecutable object type: %s" % + type(object)) + else: + return meth(self, multiparams, params) + + def _execute_function(self, func, multiparams, params): + """Execute a sql.FunctionElement object.""" + + return self._execute_clauseelement(func.select(), + multiparams, params) + + def _execute_default(self, default, multiparams, params): + """Execute a schema.ColumnDefault object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + default, multiparams, params = \ + fn(self, default, multiparams, params) + + try: + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + + dialect = self.dialect + ctx = dialect.execution_ctx_cls._init_default( + dialect, self, conn) + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) + + ret = ctx._exec_default(default, None) + if self.should_close_with_result: + self.close() + + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + default, multiparams, params, ret) + + return ret + + def _execute_ddl(self, ddl, multiparams, params): + """Execute a schema.DDL object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + ddl, multiparams, params = \ + fn(self, ddl, multiparams, params) + + dialect = self.dialect + + compiled = ddl.compile(dialect=dialect) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_ddl, + compiled, + None, + compiled + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + ddl, multiparams, params, ret) + return ret + + def _execute_clauseelement(self, elem, multiparams, params): + """Execute a sql.ClauseElement object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + elem, multiparams, params = \ + fn(self, elem, multiparams, params) + + distilled_params = _distill_params(multiparams, params) + if distilled_params: + # note this is usually dict but we support RowProxy + # as well; but dict.keys() as an iterator is OK + keys = distilled_params[0].keys() + else: + keys = [] + + dialect = self.dialect + if 'compiled_cache' in self._execution_options: + key = dialect, elem, tuple(keys), len(distilled_params) > 1 + if key in self._execution_options['compiled_cache']: + compiled_sql = self._execution_options['compiled_cache'][key] + else: + compiled_sql = elem.compile( + dialect=dialect, column_keys=keys, + inline=len(distilled_params) > 1) + self._execution_options['compiled_cache'][key] = compiled_sql + else: + compiled_sql = elem.compile( + dialect=dialect, column_keys=keys, + inline=len(distilled_params) > 1) + + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled_sql, + distilled_params, + compiled_sql, distilled_params + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + elem, multiparams, params, ret) + return ret + + def _execute_compiled(self, compiled, multiparams, params): + """Execute a sql.Compiled object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + compiled, multiparams, params = \ + fn(self, compiled, multiparams, params) + + dialect = self.dialect + parameters = _distill_params(multiparams, params) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled, + parameters, + compiled, parameters + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + compiled, multiparams, params, ret) + return ret + + def _execute_text(self, statement, multiparams, params): + """Execute a string SQL statement.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + statement, multiparams, params = \ + fn(self, statement, multiparams, params) + + dialect = self.dialect + parameters = _distill_params(multiparams, params) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_statement, + statement, + parameters, + statement, parameters + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + statement, multiparams, params, ret) + return ret + + def _execute_context(self, dialect, constructor, + statement, parameters, + *args): + """Create an :class:`.ExecutionContext` and execute, returning + a :class:`.ResultProxy`.""" + + try: + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + + context = constructor(dialect, self, conn, *args) + except Exception as e: + self._handle_dbapi_exception(e, + util.text_type(statement), parameters, + None, None) + + if context.compiled: + context.pre_exec() + + cursor, statement, parameters = context.cursor, \ + context.statement, \ + context.parameters + + if not context.executemany: + parameters = parameters[0] + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = \ + fn(self, cursor, statement, parameters, + context, context.executemany) + + if self._echo: + self.engine.logger.info(statement) + self.engine.logger.info("%r", + sql_util._repr_params(parameters, batches=10)) + try: + if context.executemany: + for fn in () if not self.dialect._has_events \ + else self.dialect.dispatch.do_executemany: + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_executemany( + cursor, + statement, + parameters, + context) + + elif not parameters and context.no_parameters: + for fn in () if not self.dialect._has_events \ + else self.dialect.dispatch.do_execute_no_params: + if fn(cursor, statement, context): + break + else: + self.dialect.do_execute_no_params( + cursor, + statement, + context) + + else: + for fn in () if not self.dialect._has_events \ + else self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_execute( + cursor, + statement, + parameters, + context) + except Exception as e: + self._handle_dbapi_exception( + e, + statement, + parameters, + cursor, + context) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute(self, cursor, + statement, + parameters, + context, + context.executemany) + + if context.compiled: + context.post_exec() + + if context.isinsert and not context.executemany: + context.post_insert() + + # create a resultproxy, get rowcount/implicit RETURNING + # rows, close cursor if no further results pending + result = context.get_result_proxy() + if context.isinsert: + if context._is_implicit_returning: + context._fetch_implicit_returning(result) + result.close(_autoclose_connection=False) + result._metadata = None + elif not context._is_explicit_returning: + result.close(_autoclose_connection=False) + result._metadata = None + elif context.isupdate and context._is_implicit_returning: + context._fetch_implicit_update_returning(result) + result.close(_autoclose_connection=False) + result._metadata = None + + elif result._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc), + result.rowcount + result.close(_autoclose_connection=False) + + if self.__transaction is None and context.should_autocommit: + self._commit_impl(autocommit=True) + + if result.closed and self.should_close_with_result: + self.close() + + return result + + def _cursor_execute(self, cursor, statement, parameters, context=None): + """Execute a statement + params on the given cursor. + + Adds appropriate logging and exception handling. + + This method is used by DefaultDialect for special-case + executions, such as for sequences and column defaults. + The path of statement execution in the majority of cases + terminates at _execute_context(). + + """ + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = \ + fn(self, cursor, statement, parameters, + context, + False) + + if self._echo: + self.engine.logger.info(statement) + self.engine.logger.info("%r", parameters) + try: + for fn in () if not self.dialect._has_events \ + else self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_execute( + cursor, + statement, + parameters, + context) + except Exception as e: + self._handle_dbapi_exception( + e, + statement, + parameters, + cursor, + context) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute(self, cursor, + statement, + parameters, + context, + False) + + def _safe_close_cursor(self, cursor): + """Close the given cursor, catching exceptions + and turning into log warnings. + + """ + try: + cursor.close() + except (SystemExit, KeyboardInterrupt): + raise + except Exception: + self.connection._logger.error( + "Error closing cursor", exc_info=True) + + _reentrant_error = False + _is_disconnect = False + + def _handle_dbapi_exception(self, + e, + statement, + parameters, + cursor, + context): + + exc_info = sys.exc_info() + + if not self._is_disconnect: + self._is_disconnect = isinstance(e, self.dialect.dbapi.Error) and \ + not self.closed and \ + self.dialect.is_disconnect(e, self.__connection, cursor) + + if self._reentrant_error: + util.raise_from_cause( + exc.DBAPIError.instance(statement, + parameters, + e, + self.dialect.dbapi.Error), + exc_info + ) + self._reentrant_error = True + try: + # non-DBAPI error - if we already got a context, + # or theres no string statement, don't wrap it + should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ + (statement is not None and context is None) + + if should_wrap and context: + if self._has_events or self.engine._has_events: + self.dispatch.dbapi_error(self, + cursor, + statement, + parameters, + context, + e) + context.handle_dbapi_exception(e) + + if not self._is_disconnect: + if cursor: + self._safe_close_cursor(cursor) + self._autorollback() + + if should_wrap: + util.raise_from_cause( + exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + connection_invalidated=self._is_disconnect), + exc_info + ) + + util.reraise(*exc_info) + + finally: + del self._reentrant_error + if self._is_disconnect: + del self._is_disconnect + dbapi_conn_wrapper = self.connection + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) + if self.should_close_with_result: + self.close() + + def default_schema_name(self): + return self.engine.dialect.get_default_schema_name(self) + + def transaction(self, callable_, *args, **kwargs): + """Execute the given function within a transaction boundary. + + The function is passed this :class:`.Connection` + as the first argument, followed by the given \*args and \**kwargs, + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + conn.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Connection.begin`:: + + with conn.begin(): + conn.execute("some statement", {'x':5, 'y':10}) + + As well as with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) + + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Engine.transaction` - engine-level version of + :meth:`.Connection.transaction` + + """ + + trans = self.begin() + try: + ret = self.run_callable(callable_, *args, **kwargs) + trans.commit() + return ret + except: + with util.safe_reraise(): + trans.rollback() + + def run_callable(self, callable_, *args, **kwargs): + """Given a callable object or function, execute it, passing + a :class:`.Connection` as the first argument. + + The given \*args and \**kwargs are passed subsequent + to the :class:`.Connection` argument. + + This function, along with :meth:`.Engine.run_callable`, + allows a function to be run with a :class:`.Connection` + or :class:`.Engine` object without the need to know + which one is being dealt with. + + """ + return callable_(self, *args, **kwargs) + + def _run_visitor(self, visitorcallable, element, **kwargs): + visitorcallable(self.dialect, self, + **kwargs).traverse_single(element) + + +class Transaction(object): + """Represent a database transaction in progress. + + The :class:`.Transaction` object is procured by + calling the :meth:`~.Connection.begin` method of + :class:`.Connection`:: + + from sqlalchemy import create_engine + engine = create_engine("postgresql://scott:tiger@localhost/test") + connection = engine.connect() + trans = connection.begin() + connection.execute("insert into x (a, b) values (1, 2)") + trans.commit() + + The object provides :meth:`.rollback` and :meth:`.commit` + methods in order to control transaction boundaries. It + also implements a context manager interface so that + the Python ``with`` statement can be used with the + :meth:`.Connection.begin` method:: + + with connection.begin(): + connection.execute("insert into x (a, b) values (1, 2)") + + The Transaction object is **not** threadsafe. + + See also: :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`, + :meth:`.Connection.begin_nested`. + + .. index:: + single: thread safety; Transaction + """ + + def __init__(self, connection, parent): + self.connection = connection + self._parent = parent or self + self.is_active = True + + def close(self): + """Close this :class:`.Transaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + if not self._parent.is_active: + return + if self._parent is self: + self.rollback() + + def rollback(self): + """Roll back this :class:`.Transaction`. + + """ + if not self._parent.is_active: + return + self._do_rollback() + self.is_active = False + + def _do_rollback(self): + self._parent.rollback() + + def commit(self): + """Commit this :class:`.Transaction`.""" + + if not self._parent.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self._do_commit() + self.is_active = False + + def _do_commit(self): + pass + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None and self.is_active: + try: + self.commit() + except: + with util.safe_reraise(): + self.rollback() + else: + self.rollback() + + +class RootTransaction(Transaction): + def __init__(self, connection): + super(RootTransaction, self).__init__(connection, None) + self.connection._begin_impl(self) + + def _do_rollback(self): + if self.is_active: + self.connection._rollback_impl() + + def _do_commit(self): + if self.is_active: + self.connection._commit_impl() + + +class NestedTransaction(Transaction): + """Represent a 'nested', or SAVEPOINT transaction. + + A new :class:`.NestedTransaction` object may be procured + using the :meth:`.Connection.begin_nested` method. + + The interface is the same as that of :class:`.Transaction`. + + """ + def __init__(self, connection, parent): + super(NestedTransaction, self).__init__(connection, parent) + self._savepoint = self.connection._savepoint_impl() + + def _do_rollback(self): + if self.is_active: + self.connection._rollback_to_savepoint_impl( + self._savepoint, self._parent) + + def _do_commit(self): + if self.is_active: + self.connection._release_savepoint_impl( + self._savepoint, self._parent) + + +class TwoPhaseTransaction(Transaction): + """Represent a two-phase transaction. + + A new :class:`.TwoPhaseTransaction` object may be procured + using the :meth:`.Connection.begin_twophase` method. + + The interface is the same as that of :class:`.Transaction` + with the addition of the :meth:`prepare` method. + + """ + def __init__(self, connection, xid): + super(TwoPhaseTransaction, self).__init__(connection, None) + self._is_prepared = False + self.xid = xid + self.connection._begin_twophase_impl(self) + + def prepare(self): + """Prepare this :class:`.TwoPhaseTransaction`. + + After a PREPARE, the transaction can be committed. + + """ + if not self._parent.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self.connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _do_rollback(self): + self.connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def _do_commit(self): + self.connection._commit_twophase_impl(self.xid, self._is_prepared) + + +class Engine(Connectable, log.Identified): + """ + Connects a :class:`~sqlalchemy.pool.Pool` and + :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a + source of database connectivity and behavior. + + An :class:`.Engine` object is instantiated publicly using the + :func:`~sqlalchemy.create_engine` function. + + See also: + + :doc:`/core/engines` + + :ref:`connections_toplevel` + + """ + + _execution_options = util.immutabledict() + _has_events = False + _connection_cls = Connection + + def __init__(self, pool, dialect, url, + logging_name=None, echo=None, proxy=None, + execution_options=None + ): + self.pool = pool + self.url = url + self.dialect = dialect + self.pool._dialect = dialect + if logging_name: + self.logging_name = logging_name + self.echo = echo + self.engine = self + log.instance_logger(self, echoflag=echo) + if proxy: + interfaces.ConnectionProxy._adapt_listener(self, proxy) + if execution_options: + self.update_execution_options(**execution_options) + + def update_execution_options(self, **opt): + """Update the default execution_options dictionary + of this :class:`.Engine`. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`.create_engine`. + + .. seealso:: + + :meth:`.Connection.execution_options` + + :meth:`.Engine.execution_options` + + """ + self._execution_options = \ + self._execution_options.union(opt) + self.dispatch.set_engine_execution_options(self, opt) + self.dialect.set_engine_execution_options(self, opt) + + def execution_options(self, **opt): + """Return a new :class:`.Engine` that will provide + :class:`.Connection` objects with the given execution options. + + The returned :class:`.Engine` remains related to the original + :class:`.Engine` in that it shares the same connection pool and + other state: + + * The :class:`.Pool` used by the new :class:`.Engine` is the + same instance. The :meth:`.Engine.dispose` method will replace + the connection pool instance for the parent engine as well + as this one. + * Event listeners are "cascaded" - meaning, the new :class:`.Engine` + inherits the events of the parent, and new events can be associated + with the new :class:`.Engine` individually. + * The logging configuration and logging_name is copied from the parent + :class:`.Engine`. + + The intent of the :meth:`.Engine.execution_options` method is + to implement "sharding" schemes where multiple :class:`.Engine` + objects refer to the same connection pool, but are differentiated + by options that would be consumed by a custom event:: + + primary_engine = create_engine("mysql://") + shard1 = primary_engine.execution_options(shard_id="shard1") + shard2 = primary_engine.execution_options(shard_id="shard2") + + Above, the ``shard1`` engine serves as a factory for + :class:`.Connection` objects that will contain the execution option + ``shard_id=shard1``, and ``shard2`` will produce :class:`.Connection` + objects that contain the execution option ``shard_id=shard2``. + + An event handler can consume the above execution option to perform + a schema switch or other operation, given a connection. Below + we emit a MySQL ``use`` statement to switch databases, at the same + time keeping track of which database we've established using the + :attr:`.Connection.info` dictionary, which gives us a persistent + storage space that follows the DBAPI connection:: + + from sqlalchemy import event + from sqlalchemy.engine import Engine + + shards = {"default": "base", shard_1: "db1", "shard_2": "db2"} + + @event.listens_for(Engine, "before_cursor_execute") + def _switch_shard(conn, cursor, stmt, params, context, executemany): + shard_id = conn._execution_options.get('shard_id', "default") + current_shard = conn.info.get("current_shard", None) + + if current_shard != shard_id: + cursor.execute("use %s" % shards[shard_id]) + conn.info["current_shard"] = shard_id + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.Connection.execution_options` - update execution options + on a :class:`.Connection` object. + + :meth:`.Engine.update_execution_options` - update the execution + options for a given :class:`.Engine` in place. + + """ + return OptionEngine(self, opt) + + @property + def name(self): + """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`.""" + + return self.dialect.name + + @property + def driver(self): + """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`.""" + + return self.dialect.driver + + echo = log.echo_property() + + def __repr__(self): + return 'Engine(%r)' % self.url + + def dispose(self): + """Dispose of the connection pool used by this :class:`.Engine`. + + A new connection pool is created immediately after the old one has + been disposed. This new pool, like all SQLAlchemy connection pools, + does not make any actual connections to the database until one is + first requested. + + This method has two general use cases: + + * When a dropped connection is detected, it is assumed that all + connections held by the pool are potentially dropped, and + the entire pool is replaced. + + * An application may want to use :meth:`dispose` within a test + suite that is creating multiple engines. + + It is critical to note that :meth:`dispose` does **not** guarantee + that the application will release all open database connections - only + those connections that are checked into the pool are closed. + Connections which remain checked out or have been detached from + the engine are not affected. + + """ + self.pool.dispose() + self.pool = self.pool.recreate() + + def _execute_default(self, default): + with self.contextual_connect() as conn: + return conn._execute_default(default, (), {}) + + @contextlib.contextmanager + def _optional_conn_ctx_manager(self, connection=None): + if connection is None: + with self.contextual_connect() as conn: + yield conn + else: + yield connection + + def _run_visitor(self, visitorcallable, element, + connection=None, **kwargs): + with self._optional_conn_ctx_manager(connection) as conn: + conn._run_visitor(visitorcallable, element, **kwargs) + + class _trans_ctx(object): + def __init__(self, conn, transaction, close_with_result): + self.conn = conn + self.transaction = transaction + self.close_with_result = close_with_result + + def __enter__(self): + return self.conn + + def __exit__(self, type, value, traceback): + if type is not None: + self.transaction.rollback() + else: + self.transaction.commit() + if not self.close_with_result: + self.conn.close() + + def begin(self, close_with_result=False): + """Return a context manager delivering a :class:`.Connection` + with a :class:`.Transaction` established. + + E.g.:: + + with engine.begin() as conn: + conn.execute("insert into table (x, y, z) values (1, 2, 3)") + conn.execute("my_special_procedure(5)") + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + The ``close_with_result`` flag is normally ``False``, and indicates + that the :class:`.Connection` will be closed when the operation + is complete. When set to ``True``, it indicates the + :class:`.Connection` is in "single use" mode, where the + :class:`.ResultProxy` returned by the first call to + :meth:`.Connection.execute` will close the :class:`.Connection` when + that :class:`.ResultProxy` has exhausted all result rows. + + .. versionadded:: 0.7.6 + + See also: + + :meth:`.Engine.connect` - procure a :class:`.Connection` from + an :class:`.Engine`. + + :meth:`.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`.Connection`. + + """ + conn = self.contextual_connect(close_with_result=close_with_result) + try: + trans = conn.begin() + except: + with util.safe_reraise(): + conn.close() + return Engine._trans_ctx(conn, trans, close_with_result) + + def transaction(self, callable_, *args, **kwargs): + """Execute the given function within a transaction boundary. + + The function is passed a :class:`.Connection` newly procured + from :meth:`.Engine.contextual_connect` as the first argument, + followed by the given \*args and \**kwargs. + + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + engine.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) + + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Connection.transaction` - connection-level version of + :meth:`.Engine.transaction` + + """ + + with self.contextual_connect() as conn: + return conn.transaction(callable_, *args, **kwargs) + + def run_callable(self, callable_, *args, **kwargs): + """Given a callable object or function, execute it, passing + a :class:`.Connection` as the first argument. + + The given \*args and \**kwargs are passed subsequent + to the :class:`.Connection` argument. + + This function, along with :meth:`.Connection.run_callable`, + allows a function to be run with a :class:`.Connection` + or :class:`.Engine` object without the need to know + which one is being dealt with. + + """ + with self.contextual_connect() as conn: + return conn.run_callable(callable_, *args, **kwargs) + + def execute(self, statement, *multiparams, **params): + """Executes the given construct and returns a :class:`.ResultProxy`. + + The arguments are the same as those used by + :meth:`.Connection.execute`. + + Here, a :class:`.Connection` is acquired using the + :meth:`~.Engine.contextual_connect` method, and the statement executed + with that connection. The returned :class:`.ResultProxy` is flagged + such that when the :class:`.ResultProxy` is exhausted and its + underlying cursor is closed, the :class:`.Connection` created here + will also be closed, which allows its associated DBAPI connection + resource to be returned to the connection pool. + + """ + + connection = self.contextual_connect(close_with_result=True) + return connection.execute(statement, *multiparams, **params) + + def scalar(self, statement, *multiparams, **params): + return self.execute(statement, *multiparams, **params).scalar() + + def _execute_clauseelement(self, elem, multiparams=None, params=None): + connection = self.contextual_connect(close_with_result=True) + return connection._execute_clauseelement(elem, multiparams, params) + + def _execute_compiled(self, compiled, multiparams, params): + connection = self.contextual_connect(close_with_result=True) + return connection._execute_compiled(compiled, multiparams, params) + + def connect(self, **kwargs): + """Return a new :class:`.Connection` object. + + The :class:`.Connection` object is a facade that uses a DBAPI + connection internally in order to communicate with the database. This + connection is procured from the connection-holding :class:`.Pool` + referenced by this :class:`.Engine`. When the + :meth:`~.Connection.close` method of the :class:`.Connection` object + is called, the underlying DBAPI connection is then returned to the + connection pool, where it may be used again in a subsequent call to + :meth:`~.Engine.connect`. + + """ + + return self._connection_cls(self, **kwargs) + + def contextual_connect(self, close_with_result=False, **kwargs): + """Return a :class:`.Connection` object which may be part of some + ongoing context. + + By default, this method does the same thing as :meth:`.Engine.connect`. + Subclasses of :class:`.Engine` may override this method + to provide contextual behavior. + + :param close_with_result: When True, the first :class:`.ResultProxy` + created by the :class:`.Connection` will call the + :meth:`.Connection.close` method of that connection as soon as any + pending result rows are exhausted. This is used to supply the + "connectionless execution" behavior provided by the + :meth:`.Engine.execute` method. + + """ + + return self._connection_cls(self, + self.pool.connect(), + close_with_result=close_with_result, + **kwargs) + + def table_names(self, schema=None, connection=None): + """Return a list of all table names available in the database. + + :param schema: Optional, retrieve names from a non-default schema. + + :param connection: Optional, use a specified connection. Default is + the ``contextual_connect`` for this ``Engine``. + """ + + with self._optional_conn_ctx_manager(connection) as conn: + if not schema: + schema = self.dialect.default_schema_name + return self.dialect.get_table_names(conn, schema) + + def has_table(self, table_name, schema=None): + """Return True if the given backend has a table of the given name. + + .. seealso:: + + :ref:`metadata_reflection_inspector` - detailed schema inspection using + the :class:`.Inspector` interface. + + :class:`.quoted_name` - used to pass quoting information along + with a schema identifier. + + """ + return self.run_callable(self.dialect.has_table, table_name, schema) + + def raw_connection(self): + """Return a "raw" DBAPI connection from the connection pool. + + The returned object is a proxied version of the DBAPI + connection object used by the underlying driver in use. + The object will have all the same behavior as the real DBAPI + connection, except that its ``close()`` method will result in the + connection being returned to the pool, rather than being closed + for real. + + This method provides direct DBAPI connection access for + special situations. In most situations, the :class:`.Connection` + object should be used, which is procured using the + :meth:`.Engine.connect` method. + + """ + + return self.pool.unique_connection() + + +class OptionEngine(Engine): + def __init__(self, proxied, execution_options): + self._proxied = proxied + self.url = proxied.url + self.dialect = proxied.dialect + self.logging_name = proxied.logging_name + self.echo = proxied.echo + log.instance_logger(self, echoflag=self.echo) + self.dispatch = self.dispatch._join(proxied.dispatch) + self._execution_options = proxied._execution_options + self.update_execution_options(**execution_options) + + def _get_pool(self): + return self._proxied.pool + + def _set_pool(self, pool): + self._proxied.pool = pool + + pool = property(_get_pool, _set_pool) + + def _get_has_events(self): + return self._proxied._has_events or \ + self.__dict__.get('_has_events', False) + + def _set_has_events(self, value): + self.__dict__['_has_events'] = value + + _has_events = property(_get_has_events, _set_has_events) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py new file mode 100644 index 00000000..0fd41105 --- /dev/null +++ b/lib/sqlalchemy/engine/default.py @@ -0,0 +1,957 @@ +# engine/default.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Default implementations of per-dialect sqlalchemy.engine classes. + +These are semi-private implementation classes which are only of importance +to database dialect authors; dialects will usually use the classes here +as the base class for their own corresponding classes. + +""" + +import re +import random +from . import reflection, interfaces, result +from ..sql import compiler, expression +from .. import types as sqltypes +from .. import exc, util, pool, processors +import codecs +import weakref +from .. import event + +AUTOCOMMIT_REGEXP = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', + re.I | re.UNICODE) + + + +class DefaultDialect(interfaces.Dialect): + """Default implementation of Dialect""" + + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler + preparer = compiler.IdentifierPreparer + supports_alter = True + + # the first value we'd get for an autoincrement + # column. + default_sequence_base = 1 + + # most DBAPIs happy with this for execute(). + # not cx_oracle. + execute_sequence_format = tuple + + supports_views = True + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + + supports_right_nested_joins = True + + supports_native_enum = False + supports_native_boolean = False + + supports_simple_order_by_label = True + + engine_config_types = util.immutabledict([ + ('convert_unicode', util.bool_or_str('force')), + ('pool_timeout', int), + ('echo', util.bool_or_str('debug')), + ('echo_pool', util.bool_or_str('debug')), + ('pool_recycle', int), + ('pool_size', int), + ('max_overflow', int), + ('pool_threadlocal', bool), + ('use_native_unicode', bool), + ]) + + # if the NUMERIC type + # returns decimal.Decimal. + # *not* the FLOAT type however. + supports_native_decimal = False + + if util.py3k: + supports_unicode_statements = True + supports_unicode_binds = True + returns_unicode_strings = True + description_encoding = None + else: + supports_unicode_statements = False + supports_unicode_binds = False + returns_unicode_strings = False + description_encoding = 'use_encoding' + + name = 'default' + + # length at which to truncate + # any identifier. + max_identifier_length = 9999 + + # length at which to truncate + # the name of an index. + # Usually None to indicate + # 'use max_identifier_length'. + # thanks to MySQL, sigh + max_index_name_length = None + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + dbapi_type_map = {} + colspecs = {} + default_paramstyle = 'named' + supports_default_values = False + supports_empty_insert = True + supports_multivalues_insert = False + + server_version_info = None + + construct_arguments = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the Postgresql dialect, + the :class:`.Index` construct will now accept the keyword arguments + ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. + Any other argument specified to the constructor of :class:`.Index` + which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. + + A dialect which does not include a ``construct_arguments`` member will + not participate in the argument validation system. For such a dialect, + any argument name is accepted by all participating constructs, within + the namespace of arguments prefixed with that dialect name. The rationale + here is so that third-party dialects that haven't yet implemented this + feature continue to function in the old way. + + .. versionadded:: 0.9.2 + + .. seealso:: + + :class:`.DialectKWArgs` - implementing base class which consumes + :attr:`.DefaultDialect.construct_arguments` + + + """ + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + reflection_options = () + + def __init__(self, convert_unicode=False, + encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, + supports_right_nested_joins=None, + case_sensitive=True, + supports_native_boolean=None, + label_length=None, **kwargs): + + if not getattr(self, 'ported_sqla_06', True): + util.warn( + "The %s dialect is not yet ported to the 0.6 format" % + self.name) + + self.convert_unicode = convert_unicode + self.encoding = encoding + self.positional = False + self._ischema = None + self.dbapi = dbapi + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning + self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + if supports_right_nested_joins is not None: + self.supports_right_nested_joins = supports_right_nested_joins + if supports_native_boolean is not None: + self.supports_native_boolean = supports_native_boolean + self.case_sensitive = case_sensitive + + if label_length and label_length > self.max_identifier_length: + raise exc.ArgumentError( + "Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) + self.label_length = label_length + + if self.description_encoding == 'use_encoding': + self._description_decoder = \ + processors.to_unicode_processor_factory( + encoding + ) + elif self.description_encoding is not None: + self._description_decoder = \ + processors.to_unicode_processor_factory( + self.description_encoding + ) + self._encoder = codecs.getencoder(self.encoding) + self._decoder = processors.to_unicode_processor_factory(self.encoding) + + + + @util.memoized_property + def _type_memos(self): + return weakref.WeakKeyDictionary() + + @property + def dialect_description(self): + return self.name + "+" + self.driver + + @classmethod + def get_pool_class(cls, url): + return getattr(cls, 'poolclass', pool.QueuePool) + + def initialize(self, connection): + try: + self.server_version_info = \ + self._get_server_version_info(connection) + except NotImplementedError: + self.server_version_info = None + try: + self.default_schema_name = \ + self._get_default_schema_name(connection) + except NotImplementedError: + self.default_schema_name = None + + try: + self.default_isolation_level = \ + self.get_isolation_level(connection.connection) + except NotImplementedError: + self.default_isolation_level = None + + self.returns_unicode_strings = self._check_unicode_returns(connection) + + if self.description_encoding is not None and \ + self._check_unicode_description(connection): + self._description_decoder = self.description_encoding = None + + self.do_rollback(connection.connection) + + def on_connect(self): + """return a callable which sets up a newly created DBAPI connection. + + This is used to set dialect-wide per-connection options such as + isolation modes, unicode modes, etc. + + If a callable is returned, it will be assembled into a pool listener + that receives the direct DBAPI connection, with all wrappers removed. + + If None is returned, no listener will be generated. + + """ + return None + + def _check_unicode_returns(self, connection, additional_tests=None): + if util.py2k and not self.supports_unicode_statements: + cast_to = util.binary_type + else: + cast_to = util.text_type + + if self.positional: + parameters = self.execute_sequence_format() + else: + parameters = {} + + def check_unicode(test): + statement = cast_to(expression.select([test]).compile(dialect=self)) + try: + cursor = connection.connection.cursor() + connection._cursor_execute(cursor, statement, parameters) + row = cursor.fetchone() + cursor.close() + except exc.DBAPIError as de: + # note that _cursor_execute() will have closed the cursor + # if an exception is thrown. + util.warn("Exception attempting to " + "detect unicode returns: %r" % de) + return False + else: + return isinstance(row[0], util.text_type) + + tests = [ + # detect plain VARCHAR + expression.cast( + expression.literal_column("'test plain returns'"), + sqltypes.VARCHAR(60) + ), + # detect if there's an NVARCHAR type with different behavior available + expression.cast( + expression.literal_column("'test unicode returns'"), + sqltypes.Unicode(60) + ), + ] + + if additional_tests: + tests += additional_tests + + results = set([check_unicode(test) for test in tests]) + + if results.issuperset([True, False]): + return "conditional" + else: + return results == set([True]) + + def _check_unicode_description(self, connection): + # all DBAPIs on Py2K return cursor.description as encoded, + # until pypy2.1beta2 with sqlite, so let's just check it - + # it's likely others will start doing this too in Py2k. + + if util.py2k and not self.supports_unicode_statements: + cast_to = util.binary_type + else: + cast_to = util.text_type + + cursor = connection.connection.cursor() + try: + cursor.execute( + cast_to( + expression.select([ + expression.literal_column("'x'").label("some_label") + ]).compile(dialect=self) + ) + ) + return isinstance(cursor.description[0][0], util.text_type) + finally: + cursor.close() + + def type_descriptor(self, typeobj): + """Provide a database-specific :class:`.TypeEngine` object, given + the generic object which comes from the types module. + + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to :func:`.types.adapt_type`. + + """ + return sqltypes.adapt_type(typeobj, self.colspecs) + + def reflecttable(self, connection, table, include_columns, exclude_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns, exclude_columns) + + def get_pk_constraint(self, conn, table_name, schema=None, **kw): + """Compatibility method, adapts the result of get_primary_keys() + for those dialects which don't implement get_pk_constraint(). + + """ + return { + 'constrained_columns': + self.get_primary_keys(conn, table_name, + schema=schema, **kw) + } + + def validate_identifier(self, ident): + if len(ident) > self.max_identifier_length: + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + + def create_connect_args(self, url): + opts = url.translate_connect_args() + opts.update(url.query) + return [[], opts] + + def set_engine_execution_options(self, engine, opts): + if 'isolation_level' in opts: + isolation_level = opts['isolation_level'] + @event.listens_for(engine, "engine_connect") + def set_isolation(connection, branch): + if not branch: + self._set_connection_isolation(connection, isolation_level) + + def set_connection_execution_options(self, connection, opts): + if 'isolation_level' in opts: + self._set_connection_isolation(connection, opts['isolation_level']) + + def _set_connection_isolation(self, connection, level): + self.set_isolation_level(connection.connection, level) + connection.connection._connection_record.\ + finalize_callback.append(self.reset_isolation_level) + + + def do_begin(self, dbapi_connection): + pass + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def do_close(self, dbapi_connection): + dbapi_connection.close() + + def create_xid(self): + """Create a random two-phase transaction ID. + + This id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). Its format is unspecified. + """ + + return "_sa_%032x" % random.randint(0, 2 ** 128) + + def do_savepoint(self, connection, name): + connection.execute(expression.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(expression.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(expression.ReleaseSavepointClause(name)) + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany(statement, parameters) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters) + + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + + def is_disconnect(self, e, connection, cursor): + return False + + def reset_isolation_level(self, dbapi_conn): + # default_isolation_level is read from the first connection + # after the initial set of 'isolation_level', if any, so is + # the configured default of this dialect. + self.set_isolation_level(dbapi_conn, self.default_isolation_level) + + +class DefaultExecutionContext(interfaces.ExecutionContext): + isinsert = False + isupdate = False + isdelete = False + isddl = False + executemany = False + result_map = None + compiled = None + statement = None + postfetch_cols = None + prefetch_cols = None + returning_cols = None + _is_implicit_returning = False + _is_explicit_returning = False + + # a hook for SQLite's translation of + # result column names + _translate_colname = None + + @classmethod + def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl): + """Initialize execution context for a DDLElement construct.""" + + self = cls.__new__(cls) + self.dialect = dialect + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.engine = connection.engine + + self.compiled = compiled = compiled_ddl + self.isddl = True + + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = dict(self.execution_options) + self.execution_options.update(connection._execution_options) + + if not dialect.supports_unicode_statements: + self.unicode_statement = util.text_type(compiled) + self.statement = dialect._encoder(self.unicode_statement)[0] + else: + self.statement = self.unicode_statement = util.text_type(compiled) + + self.cursor = self.create_cursor() + self.compiled_parameters = [] + + if dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [{}] + + return self + + @classmethod + def _init_compiled(cls, dialect, connection, dbapi_connection, + compiled, parameters): + """Initialize execution context for a Compiled construct.""" + + self = cls.__new__(cls) + self.dialect = dialect + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.engine = connection.engine + + self.compiled = compiled + + if not compiled.can_execute: + raise exc.ArgumentError("Not an executable clause") + + self.execution_options = compiled.statement._execution_options + if connection._execution_options: + self.execution_options = dict(self.execution_options) + self.execution_options.update(connection._execution_options) + + # compiled clauseelement. process bind params, process table defaults, + # track collections used by ResultProxy to target and process results + + self.result_map = compiled.result_map + + self.unicode_statement = util.text_type(compiled) + if not dialect.supports_unicode_statements: + self.statement = self.unicode_statement.encode( + self.dialect.encoding) + else: + self.statement = self.unicode_statement + + self.isinsert = compiled.isinsert + self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete + + if self.isinsert or self.isupdate or self.isdelete: + self._is_explicit_returning = bool(compiled.statement._returning) + self._is_implicit_returning = bool(compiled.returning and \ + not compiled.statement._returning) + + if not parameters: + self.compiled_parameters = [compiled.construct_params()] + else: + self.compiled_parameters = \ + [compiled.construct_params(m, _group_number=grp) for + grp, m in enumerate(parameters)] + + self.executemany = len(parameters) > 1 + + self.cursor = self.create_cursor() + if self.isinsert or self.isupdate: + self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch + self.returning_cols = self.compiled.returning + self.__process_defaults() + + processors = compiled._bind_processors + + # Convert the dictionary of bind parameter values + # into a dict or list to be sent to the DBAPI's + # execute() or executemany() method. + parameters = [] + if dialect.positional: + for compiled_params in self.compiled_parameters: + param = [] + for key in self.compiled.positiontup: + if key in processors: + param.append(processors[key](compiled_params[key])) + else: + param.append(compiled_params[key]) + parameters.append(dialect.execute_sequence_format(param)) + else: + encode = not dialect.supports_unicode_statements + for compiled_params in self.compiled_parameters: + param = {} + if encode: + for key in compiled_params: + if key in processors: + param[dialect._encoder(key)[0]] = \ + processors[key](compiled_params[key]) + else: + param[dialect._encoder(key)[0]] = \ + compiled_params[key] + else: + for key in compiled_params: + if key in processors: + param[key] = processors[key](compiled_params[key]) + else: + param[key] = compiled_params[key] + parameters.append(param) + self.parameters = dialect.execute_sequence_format(parameters) + + return self + + @classmethod + def _init_statement(cls, dialect, connection, dbapi_connection, + statement, parameters): + """Initialize execution context for a string SQL statement.""" + + self = cls.__new__(cls) + self.dialect = dialect + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.engine = connection.engine + + # plain text statement + self.execution_options = connection._execution_options + + if not parameters: + if self.dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [{}] + elif isinstance(parameters[0], dialect.execute_sequence_format): + self.parameters = parameters + elif isinstance(parameters[0], dict): + if dialect.supports_unicode_statements: + self.parameters = parameters + else: + self.parameters = [ + dict((dialect._encoder(k)[0], d[k]) for k in d) + for d in parameters + ] or [{}] + else: + self.parameters = [dialect.execute_sequence_format(p) + for p in parameters] + + self.executemany = len(parameters) > 1 + + if not dialect.supports_unicode_statements and \ + isinstance(statement, util.text_type): + self.unicode_statement = statement + self.statement = dialect._encoder(statement)[0] + else: + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + return self + + @classmethod + def _init_default(cls, dialect, connection, dbapi_connection): + """Initialize execution context for a ColumnDefault construct.""" + + self = cls.__new__(cls) + self.dialect = dialect + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.engine = connection.engine + self.execution_options = connection._execution_options + self.cursor = self.create_cursor() + return self + + @util.memoized_property + def no_parameters(self): + return self.execution_options.get("no_parameters", False) + + @util.memoized_property + def is_crud(self): + return self.isinsert or self.isupdate or self.isdelete + + @util.memoized_property + def should_autocommit(self): + autocommit = self.execution_options.get('autocommit', + not self.compiled and + self.statement and + expression.PARSE_AUTOCOMMIT + or False) + + if autocommit is expression.PARSE_AUTOCOMMIT: + return self.should_autocommit_text(self.unicode_statement) + else: + return autocommit + + def _execute_scalar(self, stmt, type_): + """Execute a string statement on the current cursor, returning a + scalar result. + + Used to fire off sequences, default phrases, and "select lastrowid" + types of statements individually or in the context of a parent INSERT + or UPDATE statement. + + """ + + conn = self.root_connection + if isinstance(stmt, util.text_type) and \ + not self.dialect.supports_unicode_statements: + stmt = self.dialect._encoder(stmt)[0] + + if self.dialect.positional: + default_params = self.dialect.execute_sequence_format() + else: + default_params = {} + + conn._cursor_execute(self.cursor, stmt, default_params, context=self) + r = self.cursor.fetchone()[0] + if type_ is not None: + # apply type post processors to the result + proc = type_._cached_result_processor( + self.dialect, + self.cursor.description[0][1] + ) + if proc: + return proc(r) + return r + + @property + def connection(self): + return self.root_connection._branch() + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + + def create_cursor(self): + return self._dbapi_connection.cursor() + + def pre_exec(self): + pass + + def post_exec(self): + pass + + def get_result_processor(self, type_, colname, coltype): + """Return a 'result processor' for a given type as present in + cursor.description. + + This has a default implementation that dialects can override + for context-sensitive result type handling. + + """ + return type_._cached_result_processor(self.dialect, coltype) + + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. + + """ + return self.cursor.lastrowid + + def handle_dbapi_exception(self, e): + pass + + def get_result_proxy(self): + return result.ResultProxy(self) + + @property + def rowcount(self): + return self.cursor.rowcount + + def supports_sane_rowcount(self): + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + return self.dialect.supports_sane_multi_rowcount + + def post_insert(self): + if not self._is_implicit_returning and \ + not self._is_explicit_returning and \ + not self.compiled.inline and \ + self.dialect.postfetch_lastrowid and \ + (not self.inserted_primary_key or \ + None in self.inserted_primary_key): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None) + if proc is not None: + lastrowid = proc(lastrowid) + + self.inserted_primary_key = [ + lastrowid if c is autoinc_col else v + for c, v in zip( + table.primary_key, + self.inserted_primary_key) + ] + + def _fetch_implicit_returning(self, resultproxy): + table = self.compiled.statement.table + row = resultproxy.fetchone() + + ipk = [] + for c, v in zip(table.primary_key, self.inserted_primary_key): + if v is not None: + ipk.append(v) + else: + ipk.append(row[c]) + + self.inserted_primary_key = ipk + self.returned_defaults = row + + def _fetch_implicit_update_returning(self, resultproxy): + row = resultproxy.fetchone() + self.returned_defaults = row + + def lastrow_has_defaults(self): + return (self.isinsert or self.isupdate) and \ + bool(self.postfetch_cols) + + def set_input_sizes(self, translate=None, exclude_types=None): + """Given a cursor and ClauseParameters, call the appropriate + style of ``setinputsizes()`` on the cursor, using DB-API types + from the bind parameter's ``TypeEngine`` objects. + + This method only called by those dialects which require it, + currently cx_oracle. + + """ + + if not hasattr(self.compiled, 'bind_names'): + return + + types = dict( + (self.compiled.bind_names[bindparam], bindparam.type) + for bindparam in self.compiled.bind_names) + + if self.dialect.positional: + inputsizes = [] + for key in self.compiled.positiontup: + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).\ + get_dbapi_type(self.dialect.dbapi) + if dbtype is not None and \ + (not exclude_types or dbtype not in exclude_types): + inputsizes.append(dbtype) + try: + self.cursor.setinputsizes(*inputsizes) + except Exception as e: + self.root_connection._handle_dbapi_exception( + e, None, None, None, self) + else: + inputsizes = {} + for key in self.compiled.bind_names.values(): + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).\ + get_dbapi_type(self.dialect.dbapi) + if dbtype is not None and \ + (not exclude_types or dbtype not in exclude_types): + if translate: + key = translate.get(key, key) + if not self.dialect.supports_unicode_binds: + key = self.dialect._encoder(key)[0] + inputsizes[key] = dbtype + try: + self.cursor.setinputsizes(**inputsizes) + except Exception as e: + self.root_connection._handle_dbapi_exception( + e, None, None, None, self) + + def _exec_default(self, default, type_): + if default.is_sequence: + return self.fire_sequence(default, type_) + elif default.is_callable: + return default.arg(self) + elif default.is_clause_element: + # TODO: expensive branching here should be + # pulled into _exec_scalar() + conn = self.connection + c = expression.select([default.arg]).compile(bind=conn) + return conn._execute_compiled(c, (), {}).scalar() + else: + return default.arg + + def get_insert_default(self, column): + if column.default is None: + return None + else: + return self._exec_default(column.default, column.type) + + def get_update_default(self, column): + if column.onupdate is None: + return None + else: + return self._exec_default(column.onupdate, column.type) + + def __process_defaults(self): + """Generate default values for compiled insert/update statements, + and generate inserted_primary_key collection. + """ + + key_getter = self.compiled._key_getters_for_crud_column[2] + + if self.executemany: + if len(self.compiled.prefetch): + scalar_defaults = {} + + # pre-determine scalar Python-side defaults + # to avoid many calls of get_insert_default()/ + # get_update_default() + for c in self.prefetch_cols: + if self.isinsert and c.default and c.default.is_scalar: + scalar_defaults[c] = c.default.arg + elif self.isupdate and c.onupdate and c.onupdate.is_scalar: + scalar_defaults[c] = c.onupdate.arg + + for param in self.compiled_parameters: + self.current_parameters = param + for c in self.prefetch_cols: + if c in scalar_defaults: + val = scalar_defaults[c] + elif self.isinsert: + val = self.get_insert_default(c) + else: + val = self.get_update_default(c) + if val is not None: + param[key_getter(c)] = val + del self.current_parameters + else: + self.current_parameters = compiled_parameters = \ + self.compiled_parameters[0] + + for c in self.compiled.prefetch: + if self.isinsert: + val = self.get_insert_default(c) + else: + val = self.get_update_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + del self.current_parameters + + if self.isinsert: + self.inserted_primary_key = [ + self.compiled_parameters[0].get(key_getter(c), None) + for c in self.compiled.\ + statement.table.primary_key + ] + + +DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py new file mode 100644 index 00000000..73722586 --- /dev/null +++ b/lib/sqlalchemy/engine/interfaces.py @@ -0,0 +1,849 @@ +# engine/interfaces.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Define core interfaces used by the engine system.""" + +from .. import util, event + +# backwards compat +from ..sql.compiler import Compiled, TypeCompiler + +class Dialect(object): + """Define the behavior of a specific database and DB-API combination. + + Any aspect of metadata definition, SQL query generation, + execution, result-set handling, or anything else which varies + between databases is defined under the general category of the + Dialect. The Dialect acts as a factory for other + database-specific object implementations including + ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. + + All Dialects implement the following attributes: + + name + identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + + driver + identifying name for the dialect's DBAPI + + positional + True if the paramstyle for this Dialect is positional. + + paramstyle + the paramstyle to be used (some DB-APIs support multiple + paramstyles). + + convert_unicode + True if Unicode conversion should be applied to all ``str`` + types. + + encoding + type of encoding to use for unicode, usually defaults to + 'utf-8'. + + statement_compiler + a :class:`.Compiled` class used to compile SQL statements + + ddl_compiler + a :class:`.Compiled` class used to compile DDL statements + + server_version_info + a tuple containing a version number for the DB backend in use. + This value is only available for supporting dialects, and is + typically populated during the initial connection to the database. + + default_schema_name + the name of the default schema. This value is only available for + supporting dialects, and is typically populated during the + initial connection to the database. + + execution_ctx_cls + a :class:`.ExecutionContext` class used to handle statement execution + + execute_sequence_format + either the 'tuple' or 'list' type, depending on what cursor.execute() + accepts for the second argument (they vary). + + preparer + a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to + quote identifiers. + + supports_alter + ``True`` if the database supports ``ALTER TABLE``. + + max_identifier_length + The maximum length of identifier names. + + supports_unicode_statements + Indicate whether the DB-API can receive SQL statements as Python + unicode strings + + supports_unicode_binds + Indicate whether the DB-API can receive string bind parameters + as Python unicode strings + + supports_sane_rowcount + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. + + supports_sane_multi_rowcount + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. + + preexecute_autoincrement_sequences + True if 'implicit' primary key functions must be executed separately + in order to get their value. This is currently oriented towards + Postgresql. + + implicit_returning + use RETURNING or equivalent during INSERT execution in order to load + newly generated primary keys and other column defaults in one execution, + which are then available via inserted_primary_key. + If an insert statement has returning() specified explicitly, + the "implicit" functionality is not used and inserted_primary_key + will not be available. + + dbapi_type_map + A mapping of DB-API type objects present in this Dialect's + DB-API implementation mapped to TypeEngine implementations used + by the dialect. + + This is used to apply types to result sets based on the DB-API + types present in cursor.description; it only takes effect for + result sets against textual statements where no explicit + typemap was present. + + colspecs + A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. + + supports_default_values + Indicates if the construct ``INSERT INTO tablename DEFAULT + VALUES`` is supported + + supports_sequences + Indicates if the dialect supports CREATE SEQUENCE or similar. + + sequences_optional + If True, indicates if the "optional" flag on the Sequence() construct + should signal to not generate a CREATE SEQUENCE. Applies only to + dialects that support sequences. Currently used only to allow Postgresql + SERIAL to be used on a column that specifies Sequence() for usage on + other backends. + + supports_native_enum + Indicates if the dialect supports a native ENUM construct. + This will prevent types.Enum from generating a CHECK + constraint when that type is used. + + supports_native_boolean + Indicates if the dialect supports a native boolean construct. + This will prevent types.Boolean from generating a CHECK + constraint when that type is used. + + """ + + _has_events = False + + + def create_connect_args(self, url): + """Build DB-API compatible connection arguments. + + Given a :class:`~sqlalchemy.engine.url.URL` object, returns a tuple + consisting of a `*args`/`**kwargs` suitable to send directly + to the dbapi's connect function. + + """ + + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj): + """Transform a generic type to a dialect-specific type. + + Dialect classes will usually use the + :func:`.types.adapt_type` function in the types module to + accomplish this. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. + + """ + + raise NotImplementedError() + + def initialize(self, connection): + """Called during strategized creation of the dialect with a + connection. + + Allows dialects to configure options based on server version info or + other properties. + + The connection passed here is a SQLAlchemy Connection object, + with full capabilities. + + The initalize() method of the base dialect should be called via + super(). + + """ + + pass + + def reflecttable(self, connection, table, include_columns, exclude_columns): + """Load table description from the database. + + Given a :class:`.Connection` and a + :class:`~sqlalchemy.schema.Table` object, reflect its columns and + properties from the database. + + The implementation of this method is provided by + :meth:`.DefaultDialect.reflecttable`, which makes use of + :class:`.Inspector` to retrieve column information. + + Dialects should **not** seek to implement this method, and should + instead implement individual schema inspection operations such as + :meth:`.Dialect.get_columns`, :meth:`.Dialect.get_pk_constraint`, + etc. + + """ + + raise NotImplementedError() + + def get_columns(self, connection, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a :class:`.Connection`, a string + `table_name`, and an optional string `schema`, return column + information as a list of dictionaries with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + autoincrement + boolean + + sequence + a dictionary of the form + {'name' : str, 'start' :int, 'increment': int} + + Additional column attributes may be present. + """ + + raise NotImplementedError() + + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + + Deprecated. This method is only called by the default + implementation of :meth:`.Dialect.get_pk_constraint`. Dialects should + instead implement the :meth:`.Dialect.get_pk_constraint` method directly. + + """ + + raise NotImplementedError() + + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + """Return information about the primary key constraint on + table_name`. + + Given a :class:`.Connection`, a string + `table_name`, and an optional string `schema`, return primary + key information as a dictionary with these keys: + + constrained_columns + a list of column names that make up the primary key + + name + optional name of the primary key constraint. + + """ + raise NotImplementedError() + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a :class:`.Connection`, a string + `table_name`, and an optional string `schema`, return foreign + key information as a list of dicts with these keys: + + name + the constraint's name + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + raise NotImplementedError() + + def get_table_names(self, connection, schema=None, **kw): + """Return a list of table names for `schema`.""" + + raise NotImplementedError + + def get_view_names(self, connection, schema=None, **kw): + """Return a list of all view names available in the database. + + schema: + Optional, retrieve names from a non-default schema. + """ + + raise NotImplementedError() + + def get_view_definition(self, connection, view_name, schema=None, **kw): + """Return view definition. + + Given a :class:`.Connection`, a string + `view_name`, and an optional string `schema`, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes(self, connection, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a :class:`.Connection`, a string + `table_name` and an optional string `schema`, return index + information as a list of dictionaries with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + raise NotImplementedError() + + def get_unique_constraints(self, connection, table_name, schema=None, **kw): + """Return information about unique constraints in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + unique constraint information as a list of dicts with these keys: + + name + the unique constraint's name + + column_names + list of column names in order + + \**kw + other options passed to the dialect's get_unique_constraints() method. + + .. versionadded:: 0.9.0 + + """ + + raise NotImplementedError() + + def normalize_name(self, name): + """convert the given name to lowercase if it is detected as + case insensitive. + + this method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name): + """convert the given name to a case insensitive identifier + for the backend if it is an all-lowercase name. + + this method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def has_table(self, connection, table_name, schema=None): + """Check the existence of a particular table in the database. + + Given a :class:`.Connection` object and a string + `table_name`, return True if the given table (possibly within + the specified `schema`) exists in the database, False + otherwise. + """ + + raise NotImplementedError() + + def has_sequence(self, connection, sequence_name, schema=None): + """Check the existence of a particular sequence in the database. + + Given a :class:`.Connection` object and a string + `sequence_name`, return True if the given sequence exists in + the database, False otherwise. + """ + + raise NotImplementedError() + + def _get_server_version_info(self, connection): + """Retrieve the server version info from the given connection. + + This is used by the default implementation to populate the + "server_version_info" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + """Return the string name of the currently selected schema from + the given connection. + + This is used by the default implementation to populate the + "default_schema_name" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def do_begin(self, dbapi_connection): + """Provide an implementation of ``connection.begin()``, given a + DB-API connection. + + The DBAPI has no dedicated "begin" method and it is expected + that transactions are implicit. This hook is provided for those + DBAPIs that might need additional help in this area. + + Note that :meth:`.Dialect.do_begin` is not called unless a + :class:`.Transaction` object is in use. The + :meth:`.Dialect.do_autocommit` + hook is provided for DBAPIs that need some extra commands emitted + after a commit in order to enter the next transaction, when the + SQLAlchemy :class:`.Connection` is used in it's default "autocommit" + mode. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_rollback(self, dbapi_connection): + """Provide an implementation of ``connection.rollback()``, given + a DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + + def do_commit(self, dbapi_connection): + """Provide an implementation of ``connection.commit()``, given a + DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_close(self, dbapi_connection): + """Provide an implementation of ``connection.close()``, given a DBAPI + connection. + + This hook is called by the :class:`.Pool` when a connection has been + detached from the pool, or is being returned beyond the normal + capacity of the pool. + + .. versionadded:: 0.8 + + """ + + raise NotImplementedError() + + def create_xid(self): + """Create a two-phase transaction ID. + + This id will be passed to do_begin_twophase(), + do_rollback_twophase(), do_commit_twophase(). Its format is + unspecified. + """ + + raise NotImplementedError() + + def do_savepoint(self, connection, name): + """Create a savepoint with the given name. + + :param connection: a :class:`.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_rollback_to_savepoint(self, connection, name): + """Rollback a connection to the named savepoint. + + :param connection: a :class:`.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_release_savepoint(self, connection, name): + """Release the named savepoint on a connection. + + :param connection: a :class:`.Connection`. + :param name: savepoint name. + """ + + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + """Begin a two phase transaction on the given connection. + + :param connection: a :class:`.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_prepare_twophase(self, connection, xid): + """Prepare a two phase transaction on the given connection. + + :param connection: a :class:`.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + """Rollback a two phase transaction on the given connection. + + :param connection: a :class:`.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + """Commit a two phase transaction on the given connection. + + + :param connection: a :class:`.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_recover_twophase(self, connection): + """Recover list of uncommited prepared two phase transaction + identifiers on the given connection. + + :param connection: a :class:`.Connection`. + + """ + + raise NotImplementedError() + + def do_executemany(self, cursor, statement, parameters, context=None): + """Provide an implementation of ``cursor.executemany(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of ``cursor.execute(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute_no_params(self, cursor, statement, parameters, + context=None): + """Provide an implementation of ``cursor.execute(statement)``. + + The parameter collection should not be sent. + + """ + + raise NotImplementedError() + + def is_disconnect(self, e, connection, cursor): + """Return True if the given DB-API error indicates an invalid + connection""" + + raise NotImplementedError() + + def connect(self): + """return a callable which sets up a newly created DBAPI connection. + + The callable accepts a single argument "conn" which is the + DBAPI connection itself. It has no return value. + + This is used to set dialect-wide per-connection options such as + isolation modes, unicode modes, etc. + + If a callable is returned, it will be assembled into a pool listener + that receives the direct DBAPI connection, with all wrappers removed. + + If None is returned, no listener will be generated. + + """ + return None + + def reset_isolation_level(self, dbapi_conn): + """Given a DBAPI connection, revert its isolation to the default.""" + + raise NotImplementedError() + + def set_isolation_level(self, dbapi_conn, level): + """Given a DBAPI connection, set its isolation level.""" + + raise NotImplementedError() + + def get_isolation_level(self, dbapi_conn): + """Given a DBAPI connection, return its isolation level.""" + + raise NotImplementedError() + + +class ExecutionContext(object): + """A messenger object for a Dialect that corresponds to a single + execution. + + ExecutionContext should have these data members: + + connection + Connection object which can be freely used by default value + generators to execute SQL. This Connection should reference the + same underlying connection/transactional resources of + root_connection. + + root_connection + Connection object which is the source of this ExecutionContext. This + Connection may have close_with_result=True set, in which case it can + only be used once. + + dialect + dialect which created this ExecutionContext. + + cursor + DB-API cursor procured from the connection, + + compiled + if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed, + + statement + string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed. + + parameters + bind parameters passed to the execute() method. For compiled + statements, this is a dictionary or list of dictionaries. For + textual statements, it should be in a format suitable for the + dialect's paramstyle (i.e. dict or list of dicts for non + positional, list or list of lists/tuples for positional). + + isinsert + True if the statement is an INSERT. + + isupdate + True if the statement is an UPDATE. + + should_autocommit + True if the statement is a "committable" statement. + + prefetch_cols + a list of Column objects for which a client-side default + was fired off. Applies to inserts and updates. + + postfetch_cols + a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates. + """ + + def create_cursor(self): + """Return a new cursor generated from this ExecutionContext's + connection. + + Some dialects may wish to change the behavior of + connection.cursor(), such as postgresql which may return a PG + "server side" cursor. + """ + + raise NotImplementedError() + + def pre_exec(self): + """Called before an execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `statement` and `parameters` datamembers must be + initialized after this statement is complete. + """ + + raise NotImplementedError() + + def post_exec(self): + """Called after the execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method completes. + """ + + raise NotImplementedError() + + def result(self): + """Return a result object corresponding to this ExecutionContext. + + Returns a ResultProxy. + """ + + raise NotImplementedError() + + def handle_dbapi_exception(self, e): + """Receive a DBAPI exception which occurred upon execute, result + fetch, etc.""" + + raise NotImplementedError() + + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to + a "committable" statement""" + + raise NotImplementedError() + + def lastrow_has_defaults(self): + """Return True if the last INSERT or UPDATE row contained + inlined or database-side defaults. + """ + + raise NotImplementedError() + + def get_rowcount(self): + """Return the DBAPI ``cursor.rowcount`` value, or in some + cases an interpreted value. + + See :attr:`.ResultProxy.rowcount` for details on this. + + """ + + raise NotImplementedError() + + +class Connectable(object): + """Interface for an object which supports execution of SQL constructs. + + The two implementations of :class:`.Connectable` are + :class:`.Connection` and :class:`.Engine`. + + Connectable must also implement the 'dialect' member which references a + :class:`.Dialect` instance. + + """ + + def connect(self, **kwargs): + """Return a :class:`.Connection` object. + + Depending on context, this may be ``self`` if this object + is already an instance of :class:`.Connection`, or a newly + procured :class:`.Connection` if this object is an instance + of :class:`.Engine`. + + """ + + def contextual_connect(self): + """Return a :class:`.Connection` object which may be part of an ongoing + context. + + Depending on context, this may be ``self`` if this object + is already an instance of :class:`.Connection`, or a newly + procured :class:`.Connection` if this object is an instance + of :class:`.Engine`. + + """ + + raise NotImplementedError() + + @util.deprecated("0.7", + "Use the create() method on the given schema " + "object directly, i.e. :meth:`.Table.create`, " + ":meth:`.Index.create`, :meth:`.MetaData.create_all`") + def create(self, entity, **kwargs): + """Emit CREATE statements for the given schema entity. + """ + + raise NotImplementedError() + + @util.deprecated("0.7", + "Use the drop() method on the given schema " + "object directly, i.e. :meth:`.Table.drop`, " + ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`") + def drop(self, entity, **kwargs): + """Emit DROP statements for the given schema entity. + """ + + raise NotImplementedError() + + def execute(self, object, *multiparams, **params): + """Executes the given construct and returns a :class:`.ResultProxy`.""" + raise NotImplementedError() + + def scalar(self, object, *multiparams, **params): + """Executes and returns the first column of the first row. + + The underlying cursor is closed after execution. + """ + raise NotImplementedError() + + def _run_visitor(self, visitorcallable, element, + **kwargs): + raise NotImplementedError() + + def _execute_clauseelement(self, elem, multiparams=None, params=None): + raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py new file mode 100644 index 00000000..45f10051 --- /dev/null +++ b/lib/sqlalchemy/engine/reflection.py @@ -0,0 +1,590 @@ +# engine/reflection.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. +""" + +from .. import exc, sql +from ..sql import schema as sa_schema +from .. import util +from ..sql.type_api import TypeEngine +from ..util import deprecated +from ..util import topological +from .. import inspection +from .base import Connectable + + +@util.decorator +def cache(fn, self, con, *args, **kw): + info_cache = kw.get('info_cache', None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, util.string_types)), + tuple((k, v) for k, v in kw.items() if + isinstance(v, + util.string_types + util.int_types + (float, ) + ) + ) + ) + ret = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +class Inspector(object): + """Performs database schema inspection. + + The Inspector acts as a proxy to the reflection methods of the + :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a + consistent interface as well as caching support for previously + fetched metadata. + + A :class:`.Inspector` object is usually created via the + :func:`.inspect` function:: + + from sqlalchemy import inspect, create_engine + engine = create_engine('...') + insp = inspect(engine) + + The inspection method above is equivalent to using the + :meth:`.Inspector.from_engine` method, i.e.:: + + engine = create_engine('...') + insp = Inspector.from_engine(engine) + + Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` may opt + to return an :class:`.Inspector` subclass that provides additional + methods specific to the dialect's target database. + + """ + + def __init__(self, bind): + """Initialize a new :class:`.Inspector`. + + :param bind: a :class:`~sqlalchemy.engine.Connectable`, + which is typically an instance of + :class:`~sqlalchemy.engine.Engine` or + :class:`~sqlalchemy.engine.Connection`. + + For a dialect-specific instance of :class:`.Inspector`, see + :meth:`.Inspector.from_engine` + + """ + # this might not be a connection, it could be an engine. + self.bind = bind + + # set the engine + if hasattr(bind, 'engine'): + self.engine = bind.engine + else: + self.engine = bind + + if self.engine is bind: + # if engine, ensure initialized + bind.connect().close() + + self.dialect = self.engine.dialect + self.info_cache = {} + + @classmethod + def from_engine(cls, bind): + """Construct a new dialect-specific Inspector object from the given + engine or connection. + + :param bind: a :class:`~sqlalchemy.engine.Connectable`, + which is typically an instance of + :class:`~sqlalchemy.engine.Engine` or + :class:`~sqlalchemy.engine.Connection`. + + This method differs from direct a direct constructor call of + :class:`.Inspector` in that the + :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to + provide a dialect-specific :class:`.Inspector` instance, which may + provide additional methods. + + See the example at :class:`.Inspector`. + + """ + if hasattr(bind.dialect, 'inspector'): + return bind.dialect.inspector(bind) + return Inspector(bind) + + @inspection._inspects(Connectable) + def _insp(bind): + return Inspector.from_engine(bind) + + @property + def default_schema_name(self): + """Return the default schema name presented by the dialect + for the current engine's database user. + + E.g. this is typically ``public`` for Postgresql and ``dbo`` + for SQL Server. + + """ + return self.dialect.default_schema_name + + def get_schema_names(self): + """Return all schema names. + """ + + if hasattr(self.dialect, 'get_schema_names'): + return self.dialect.get_schema_names(self.bind, + info_cache=self.info_cache) + return [] + + def get_table_names(self, schema=None, order_by=None): + """Return all table names in referred to within a particular schema. + + The names are expected to be real tables only, not views. + Views are instead returned using the :meth:`.Inspector.get_view_names` + method. + + + :param schema: Schema name. If ``schema`` is left at ``None``, the + database's default schema is + used, else the named schema is searched. If the database does not + support named schemas, behavior is undefined if ``schema`` is not + passed as ``None``. For special quoting, use :class:`.quoted_name`. + + :param order_by: Optional, may be the string "foreign_key" to sort + the result on foreign key dependencies. + + .. versionchanged:: 0.8 the "foreign_key" sorting sorts tables + in order of dependee to dependent; that is, in creation + order, rather than in drop order. This is to maintain + consistency with similar features such as + :attr:`.MetaData.sorted_tables` and :func:`.util.sort_tables`. + + .. seealso:: + + :attr:`.MetaData.sorted_tables` + + """ + + if hasattr(self.dialect, 'get_table_names'): + tnames = self.dialect.get_table_names(self.bind, + schema, info_cache=self.info_cache) + else: + tnames = self.engine.table_names(schema) + if order_by == 'foreign_key': + tuples = [] + for tname in tnames: + for fkey in self.get_foreign_keys(tname, schema): + if tname != fkey['referred_table']: + tuples.append((fkey['referred_table'], tname)) + tnames = list(topological.sort(tuples, tnames)) + return tnames + + def get_table_options(self, table_name, schema=None, **kw): + """Return a dictionary of options specified when the table of the + given name was created. + + This currently includes some options that apply to MySQL tables. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + if hasattr(self.dialect, 'get_table_options'): + return self.dialect.get_table_options( + self.bind, table_name, schema, + info_cache=self.info_cache, **kw) + return {} + + def get_view_names(self, schema=None): + """Return all view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + + """ + + return self.dialect.get_view_names(self.bind, schema, + info_cache=self.info_cache) + + def get_view_definition(self, view_name, schema=None): + """Return definition for `view_name`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + + """ + + return self.dialect.get_view_definition( + self.bind, view_name, schema, info_cache=self.info_cache) + + def get_columns(self, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + column information as a list of dicts with these keys: + + name + the column's name + + type + :class:`~sqlalchemy.types.TypeEngine` + + nullable + boolean + + default + the column's default value + + attrs + dict containing optional column attributes + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + + col_defs = self.dialect.get_columns(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs + + @deprecated('0.7', 'Call to deprecated method get_primary_keys.' + ' Use get_pk_constraint instead.') + def get_primary_keys(self, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a list of column names. + """ + + return self.dialect.get_pk_constraint(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw)['constrained_columns'] + + def get_pk_constraint(self, table_name, schema=None, **kw): + """Return information about primary key constraint on `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a dictionary with these keys: + + constrained_columns + a list of column names that make up the primary key + + name + optional name of the primary key constraint. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + return self.dialect.get_pk_constraint(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw) + + def get_foreign_keys(self, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + foreign key information as a list of dicts with these keys: + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + + name + optional name of the foreign key constraint. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + + return self.dialect.get_foreign_keys(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw) + + def get_indexes(self, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + + return self.dialect.get_indexes(self.bind, table_name, + schema, + info_cache=self.info_cache, **kw) + + def get_unique_constraints(self, table_name, schema=None, **kw): + """Return information about unique constraints in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + unique constraint information as a list of dicts with these keys: + + name + the unique constraint's name + + column_names + list of column names in order + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + .. versionadded:: 0.8.4 + + """ + + return self.dialect.get_unique_constraints( + self.bind, table_name, schema, info_cache=self.info_cache, **kw) + + def reflecttable(self, table, include_columns, exclude_columns=()): + """Given a Table object, load its internal constructs based on + introspection. + + This is the underlying method used by most dialects to produce + table reflection. Direct usage is like:: + + from sqlalchemy import create_engine, MetaData, Table + from sqlalchemy.engine import reflection + + engine = create_engine('...') + meta = MetaData() + user_table = Table('user', meta) + insp = Inspector.from_engine(engine) + insp.reflecttable(user_table, None) + + :param table: a :class:`~sqlalchemy.schema.Table` instance. + :param include_columns: a list of string column names to include + in the reflection process. If ``None``, all columns are reflected. + + """ + dialect = self.bind.dialect + + schema = table.schema + table_name = table.name + + # get table-level arguments that are specifically + # intended for reflection, e.g. oracle_resolve_synonyms. + # these are unconditionally passed to related Table + # objects + reflection_options = dict( + (k, table.dialect_kwargs.get(k)) + for k in dialect.reflection_options + if k in table.dialect_kwargs + ) + + # reflect table options, like mysql_engine + tbl_opts = self.get_table_options(table_name, schema, **table.dialect_kwargs) + if tbl_opts: + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) + + if util.py2k: + if isinstance(schema, str): + schema = schema.decode(dialect.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(dialect.encoding) + + found_table = False + cols_by_orig_name = {} + + for col_d in self.get_columns(table_name, schema, **table.dialect_kwargs): + found_table = True + orig_name = col_d['name'] + + table.dispatch.column_reflect(self, table, col_d) + + name = col_d['name'] + if include_columns and name not in include_columns: + continue + if exclude_columns and name in exclude_columns: + continue + + coltype = col_d['type'] + + col_kw = dict( + (k, col_d[k]) + for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] + if k in col_d + ) + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL + # expression, so is wrapped in text() so that no quoting + # occurs on re-issuance. + colargs.append( + sa_schema.DefaultClause( + sql.text(col_d['default']), _reflected=True + ) + ) + + if 'sequence' in col_d: + # TODO: mssql and sybase are using this. + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + cols_by_orig_name[orig_name] = col = \ + sa_schema.Column(name, coltype, *colargs, **col_kw) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col) + + if not found_table: + raise exc.NoSuchTableError(table.name) + + pk_cons = self.get_pk_constraint(table_name, schema, **table.dialect_kwargs) + if pk_cons: + pk_cols = [ + cols_by_orig_name[pk] + for pk in pk_cons['constrained_columns'] + if pk in cols_by_orig_name and pk not in exclude_columns + ] + + # update pk constraint name + table.primary_key.name = pk_cons.get('name') + + # tell the PKConstraint to re-initialize + # it's column collection + table.primary_key._reload(pk_cols) + + fkeys = self.get_foreign_keys(table_name, schema, **table.dialect_kwargs) + for fkey_d in fkeys: + conname = fkey_d['name'] + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_columns = [ + cols_by_orig_name[c].key + if c in cols_by_orig_name else c + for c in fkey_d['constrained_columns'] + ] + if exclude_columns and set(constrained_columns).intersection( + exclude_columns): + continue + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] + refspec = [] + if referred_schema is not None: + sa_schema.Table(referred_table, table.metadata, + autoload=True, schema=referred_schema, + autoload_with=self.bind, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join( + [referred_schema, referred_table, column])) + else: + sa_schema.Table(referred_table, table.metadata, autoload=True, + autoload_with=self.bind, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + if 'options' in fkey_d: + options = fkey_d['options'] + else: + options = {} + table.append_constraint( + sa_schema.ForeignKeyConstraint(constrained_columns, refspec, + conname, link_to_name=True, + **options)) + # Indexes + indexes = self.get_indexes(table_name, schema) + for index_d in indexes: + name = index_d['name'] + columns = index_d['column_names'] + unique = index_d['unique'] + flavor = index_d.get('type', 'unknown type') + if include_columns and \ + not set(columns).issubset(include_columns): + util.warn( + "Omitting %s KEY for (%s), key covers omitted columns." % + (flavor, ', '.join(columns))) + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + sa_schema.Index(name, *[ + cols_by_orig_name[c] if c in cols_by_orig_name + else table.c[c] + for c in columns + ], + **dict(unique=unique)) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py new file mode 100644 index 00000000..6c98dae1 --- /dev/null +++ b/lib/sqlalchemy/engine/result.py @@ -0,0 +1,1033 @@ +# engine/result.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Define result set constructs including :class:`.ResultProxy` +and :class:`.RowProxy.""" + + + +from .. import exc, util +from ..sql import expression, sqltypes +import collections +import operator + +# This reconstructor is necessary so that pickles with the C extension or +# without use the same Binary format. +try: + # We need a different reconstructor on the C extension so that we can + # add extra checks that fields have correctly been initialized by + # __setstate__. + from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor + + # The extra function embedding is needed so that the + # reconstructor function has the same signature whether or not + # the extension is present. + def rowproxy_reconstructor(cls, state): + return safe_rowproxy_reconstructor(cls, state) +except ImportError: + def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + +try: + from sqlalchemy.cresultproxy import BaseRowProxy +except ImportError: + class BaseRowProxy(object): + __slots__ = ('_parent', '_row', '_processors', '_keymap') + + def __init__(self, parent, row, processors, keymap): + """RowProxy objects are constructed by ResultProxy objects.""" + + self._parent = parent + self._row = row + self._processors = processors + self._keymap = keymap + + def __reduce__(self): + return (rowproxy_reconstructor, + (self.__class__, self.__getstate__())) + + def values(self): + """Return the values represented by this RowProxy as a list.""" + return list(self) + + def __iter__(self): + for processor, value in zip(self._processors, self._row): + if processor is None: + yield value + else: + yield processor(value) + + def __len__(self): + return len(self._row) + + def __getitem__(self, key): + try: + processor, obj, index = self._keymap[key] + except KeyError: + processor, obj, index = self._parent._key_fallback(key) + except TypeError: + if isinstance(key, slice): + l = [] + for processor, value in zip(self._processors[key], + self._row[key]): + if processor is None: + l.append(value) + else: + l.append(processor(value)) + return tuple(l) + else: + raise + if index is None: + raise exc.InvalidRequestError( + "Ambiguous column name '%s' in result set! " + "try 'use_labels' option on select statement." % key) + if processor is not None: + return processor(self._row[index]) + else: + return self._row[index] + + def __getattr__(self, name): + try: + return self[name] + except KeyError as e: + raise AttributeError(e.args[0]) + + +class RowProxy(BaseRowProxy): + """Proxy values from a single cursor row. + + Mostly follows "ordered dictionary" behavior, mapping result + values to the string-based column name, the integer position of + the result in the row, as well as Column instances which can be + mapped to the original Columns that produced this result set (for + results that correspond to constructed SQL expressions). + """ + __slots__ = () + + def __contains__(self, key): + return self._parent._has_key(self._row, key) + + def __getstate__(self): + return { + '_parent': self._parent, + '_row': tuple(self) + } + + def __setstate__(self, state): + self._parent = parent = state['_parent'] + self._row = state['_row'] + self._processors = parent._processors + self._keymap = parent._keymap + + __hash__ = None + + def _op(self, other, op): + return op(tuple(self), tuple(other)) \ + if isinstance(other, RowProxy) \ + else op(tuple(self), other) + + def __lt__(self, other): + return self._op(other, operator.lt) + + def __le__(self, other): + return self._op(other, operator.le) + + def __ge__(self, other): + return self._op(other, operator.ge) + + def __gt__(self, other): + return self._op(other, operator.gt) + + def __eq__(self, other): + return self._op(other, operator.eq) + + def __ne__(self, other): + return self._op(other, operator.ne) + + def __repr__(self): + return repr(tuple(self)) + + def has_key(self, key): + """Return True if this RowProxy contains the given key.""" + + return self._parent._has_key(self._row, key) + + def items(self): + """Return a list of tuples, each tuple containing a key/value pair.""" + # TODO: no coverage here + return [(key, self[key]) for key in self.keys()] + + def keys(self): + """Return the list of keys as strings represented by this RowProxy.""" + + return self._parent.keys + + def iterkeys(self): + return iter(self._parent.keys) + + def itervalues(self): + return iter(self) + +try: + # Register RowProxy with Sequence, + # so sequence protocol is implemented + from collections import Sequence + Sequence.register(RowProxy) +except ImportError: + pass + + +class ResultMetaData(object): + """Handle cursor.description, applying additional info from an execution + context.""" + + def __init__(self, parent, metadata): + self._processors = processors = [] + + # We do not strictly need to store the processor in the key mapping, + # though it is faster in the Python version (probably because of the + # saved attribute lookup self._processors) + self._keymap = keymap = {} + self.keys = [] + context = parent.context + dialect = context.dialect + typemap = dialect.dbapi_type_map + translate_colname = context._translate_colname + self.case_sensitive = dialect.case_sensitive + + # high precedence key values. + primary_keymap = {} + + for i, rec in enumerate(metadata): + colname = rec[0] + coltype = rec[1] + + if dialect.description_encoding: + colname = dialect._description_decoder(colname) + + if translate_colname: + colname, untranslated = translate_colname(colname) + + if dialect.requires_name_normalize: + colname = dialect.normalize_name(colname) + + if context.result_map: + try: + name, obj, type_ = context.result_map[colname + if self.case_sensitive + else colname.lower()] + except KeyError: + name, obj, type_ = \ + colname, None, typemap.get(coltype, sqltypes.NULLTYPE) + else: + name, obj, type_ = \ + colname, None, typemap.get(coltype, sqltypes.NULLTYPE) + + processor = context.get_result_processor(type_, colname, coltype) + + processors.append(processor) + rec = (processor, obj, i) + + # indexes as keys. This is only needed for the Python version of + # RowProxy (the C version uses a faster path for integer indexes). + primary_keymap[i] = rec + + # populate primary keymap, looking for conflicts. + if primary_keymap.setdefault( + name if self.case_sensitive + else name.lower(), + rec) is not rec: + # place a record that doesn't have the "index" - this + # is interpreted later as an AmbiguousColumnError, + # but only when actually accessed. Columns + # colliding by name is not a problem if those names + # aren't used; integer access is always + # unambiguous. + primary_keymap[name + if self.case_sensitive + else name.lower()] = rec = (None, obj, None) + + self.keys.append(colname) + if obj: + for o in obj: + keymap[o] = rec + # technically we should be doing this but we + # are saving on callcounts by not doing so. + # if keymap.setdefault(o, rec) is not rec: + # keymap[o] = (None, obj, None) + + if translate_colname and \ + untranslated: + keymap[untranslated] = rec + + # overwrite keymap values with those of the + # high precedence keymap. + keymap.update(primary_keymap) + + if parent._echo: + context.engine.logger.debug( + "Col %r", tuple(x[0] for x in metadata)) + + @util.pending_deprecation("0.8", "sqlite dialect uses " + "_translate_colname() now") + def _set_keymap_synonym(self, name, origname): + """Set a synonym for the given name. + + Some dialects (SQLite at the moment) may use this to + adjust the column names that are significant within a + row. + + """ + rec = (processor, obj, i) = self._keymap[origname if + self.case_sensitive + else origname.lower()] + if self._keymap.setdefault(name, rec) is not rec: + self._keymap[name] = (processor, obj, None) + + def _key_fallback(self, key, raiseerr=True): + map = self._keymap + result = None + if isinstance(key, util.string_types): + result = map.get(key if self.case_sensitive else key.lower()) + # fallback for targeting a ColumnElement to a textual expression + # this is a rare use case which only occurs when matching text() + # or colummn('name') constructs to ColumnElements, or after a + # pickle/unpickle roundtrip + elif isinstance(key, expression.ColumnElement): + if key._label and ( + key._label + if self.case_sensitive + else key._label.lower()) in map: + result = map[key._label + if self.case_sensitive + else key._label.lower()] + elif hasattr(key, 'name') and ( + key.name + if self.case_sensitive + else key.name.lower()) in map: + # match is only on name. + result = map[key.name + if self.case_sensitive + else key.name.lower()] + # search extra hard to make sure this + # isn't a column/label name overlap. + # this check isn't currently available if the row + # was unpickled. + if result is not None and \ + result[1] is not None: + for obj in result[1]: + if key._compare_name_for_result(obj): + break + else: + result = None + if result is None: + if raiseerr: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" % + expression._string_or_unprintable(key)) + else: + return None + else: + map[key] = result + return result + + def _has_key(self, row, key): + if key in self._keymap: + return True + else: + return self._key_fallback(key, False) is not None + + def __getstate__(self): + return { + '_pickled_keymap': dict( + (key, index) + for key, (processor, obj, index) in self._keymap.items() + if isinstance(key, util.string_types + util.int_types) + ), + 'keys': self.keys, + "case_sensitive": self.case_sensitive, + } + + def __setstate__(self, state): + # the row has been processed at pickling time so we don't need any + # processor anymore + self._processors = [None for _ in range(len(state['keys']))] + self._keymap = keymap = {} + for key, index in state['_pickled_keymap'].items(): + # not preserving "obj" here, unfortunately our + # proxy comparison fails with the unpickle + keymap[key] = (None, None, index) + self.keys = state['keys'] + self.case_sensitive = state['case_sensitive'] + self._echo = False + + +class ResultProxy(object): + """Wraps a DB-API cursor object to provide easier access to row columns. + + Individual columns may be accessed by their integer position, + case-insensitive column name, or by ``schema.Column`` + object. e.g.:: + + row = fetchone() + + col1 = row[0] # access via integer position + + col2 = row['col2'] # access via name + + col3 = row[mytable.c.mycol] # access via Column object. + + ``ResultProxy`` also handles post-processing of result column + data using ``TypeEngine`` objects, which are referenced from + the originating SQL statement that produced this result set. + + """ + + _process_row = RowProxy + out_parameters = None + _can_close_connection = False + _metadata = None + + def __init__(self, context): + self.context = context + self.dialect = context.dialect + self.closed = False + self.cursor = self._saved_cursor = context.cursor + self.connection = context.root_connection + self._echo = self.connection._echo and \ + context.engine._should_log_debug() + self._init_metadata() + + def _init_metadata(self): + metadata = self._cursor_description() + if metadata is not None: + self._metadata = ResultMetaData(self, metadata) + + def keys(self): + """Return the current set of string keys for rows.""" + if self._metadata: + return self._metadata.keys + else: + return [] + + @util.memoized_property + def rowcount(self): + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows *matched* + by the WHERE criterion of an UPDATE or DELETE statement. + + .. note:: + + Notes regarding :attr:`.ResultProxy.rowcount`: + + + * This attribute returns the number of rows *matched*, + which is not necessarily the same as the number of rows + that were actually *modified* - an UPDATE statement, for example, + may have no net change on a given row if the SET values + given are the same as those present in the row already. + Such a row would be matched but not modified. + On backends that feature both styles, such as MySQL, + rowcount is configured by default to return the match + count in all cases. + + * :attr:`.ResultProxy.rowcount` is *only* useful in conjunction + with an UPDATE or DELETE statement. Contrary to what the Python + DBAPI says, it does *not* return the + number of rows available from the results of a SELECT statement + as DBAPIs cannot support this functionality when rows are + unbuffered. + + * :attr:`.ResultProxy.rowcount` may not be fully implemented by + all dialects. In particular, most DBAPIs do not support an + aggregate rowcount result from an executemany call. + The :meth:`.ResultProxy.supports_sane_rowcount` and + :meth:`.ResultProxy.supports_sane_multi_rowcount` methods + will report from the dialect if each usage is known to be + supported. + + * Statements that use RETURNING may not return a correct + rowcount. + + """ + try: + return self.context.rowcount + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, self.cursor, self.context) + + @property + def lastrowid(self): + """return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary when + using insert() expression constructs; the + :attr:`~ResultProxy.inserted_primary_key` attribute provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ + try: + return self._saved_cursor.lastrowid + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, + self._saved_cursor, self.context) + + @property + def returns_rows(self): + """True if this :class:`.ResultProxy` returns rows. + + I.e. if it is legal to call the methods + :meth:`~.ResultProxy.fetchone`, + :meth:`~.ResultProxy.fetchmany` + :meth:`~.ResultProxy.fetchall`. + + """ + return self._metadata is not None + + @property + def is_insert(self): + """True if this :class:`.ResultProxy` is the result + of a executing an expression language compiled + :func:`.expression.insert` construct. + + When True, this implies that the + :attr:`inserted_primary_key` attribute is accessible, + assuming the statement did not include + a user defined "returning" construct. + + """ + return self.context.isinsert + + def _cursor_description(self): + """May be overridden by subclasses.""" + + return self._saved_cursor.description + + def close(self, _autoclose_connection=True): + """Close this ResultProxy. + + Closes the underlying DBAPI cursor corresponding to the execution. + + Note that any data cached within this ResultProxy is still available. + For some types of results, this may include buffered rows. + + If this ResultProxy was generated from an implicit execution, + the underlying Connection will also be closed (returns the + underlying DBAPI connection to the connection pool.) + + This method is called automatically when: + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. + + """ + + if not self.closed: + self.closed = True + self.connection._safe_close_cursor(self.cursor) + if _autoclose_connection and \ + self.connection.should_close_with_result: + self.connection.close() + # allow consistent errors + self.cursor = None + + def __iter__(self): + while True: + row = self.fetchone() + if row is None: + raise StopIteration + else: + yield row + + @util.memoized_property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. + + The return value is a list of scalar values + corresponding to the list of primary key columns + in the target table. + + This only applies to single row :func:`.insert` + constructs which did not explicitly specify + :meth:`.Insert.returning`. + + Note that primary key columns which specify a + server_default clause, + or otherwise do not qualify as "autoincrement" + columns (see the notes at :class:`.Column`), and were + generated using the database-side default, will + appear in this list as ``None`` unless the backend + supports "returning" and the insert statement executed + with the "implicit returning" enabled. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " + "expression construct.") + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() " + "expression construct.") + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError( + "Can't call inserted_primary_key " + "when returning() " + "is used.") + + return self.context.inserted_primary_key + + def last_updated_params(self): + """Return the collection of updated parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an update() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " + "expression construct.") + elif not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an update() " + "expression construct.") + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + def last_inserted_params(self): + """Return the collection of inserted parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " + "expression construct.") + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() " + "expression construct.") + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + @property + def returned_defaults(self): + """Return the values of default columns that were fetched using + the :meth:`.ValuesBase.return_defaults` feature. + + The value is an instance of :class:`.RowProxy`, or ``None`` + if :meth:`.ValuesBase.return_defaults` was not used or if the + backend does not support RETURNING. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.ValuesBase.return_defaults` + + """ + return self.context.returned_defaults + + def lastrow_has_defaults(self): + """Return ``lastrow_has_defaults()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + """ + + return self.context.lastrow_has_defaults() + + def postfetch_cols(self): + """Return ``postfetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " + "expression construct.") + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct.") + return self.context.postfetch_cols + + def prefetch_cols(self): + """Return ``prefetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled " + "expression construct.") + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct.") + return self.context.prefetch_cols + + def supports_sane_rowcount(self): + """Return ``supports_sane_rowcount`` from the dialect. + + See :attr:`.ResultProxy.rowcount` for background. + + """ + + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + """Return ``supports_sane_multi_rowcount`` from the dialect. + + See :attr:`.ResultProxy.rowcount` for background. + + """ + + return self.dialect.supports_sane_multi_rowcount + + def _fetchone_impl(self): + try: + return self.cursor.fetchone() + except AttributeError: + self._non_result() + + def _fetchmany_impl(self, size=None): + try: + if size is None: + return self.cursor.fetchmany() + else: + return self.cursor.fetchmany(size) + except AttributeError: + self._non_result() + + def _fetchall_impl(self): + try: + return self.cursor.fetchall() + except AttributeError: + self._non_result() + + def _non_result(self): + if self._metadata is None: + raise exc.ResourceClosedError( + "This result object does not return rows. " + "It has been closed automatically.", + ) + else: + raise exc.ResourceClosedError("This result object is closed.") + + def process_rows(self, rows): + process_row = self._process_row + metadata = self._metadata + keymap = metadata._keymap + processors = metadata._processors + if self._echo: + log = self.context.engine.logger.debug + l = [] + for row in rows: + log("Row %r", row) + l.append(process_row(metadata, row, processors, keymap)) + return l + else: + return [process_row(metadata, row, processors, keymap) + for row in rows] + + def fetchall(self): + """Fetch all rows, just like DB-API ``cursor.fetchall()``.""" + + try: + l = self.process_rows(self._fetchall_impl()) + self.close() + return l + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, + self.cursor, self.context) + + def fetchmany(self, size=None): + """Fetch many rows, just like DB-API + ``cursor.fetchmany(size=cursor.arraysize)``. + + If rows are present, the cursor remains open after this is called. + Else the cursor is automatically closed and an empty list is returned. + + """ + + try: + l = self.process_rows(self._fetchmany_impl(size)) + if len(l) == 0: + self.close() + return l + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, + self.cursor, self.context) + + def fetchone(self): + """Fetch one row, just like DB-API ``cursor.fetchone()``. + + If a row is present, the cursor remains open after this is called. + Else the cursor is automatically closed and None is returned. + + """ + try: + row = self._fetchone_impl() + if row is not None: + return self.process_rows([row])[0] + else: + self.close() + return None + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, + self.cursor, self.context) + + def first(self): + """Fetch the first row and then close the result set unconditionally. + + Returns None if no row is present. + + """ + if self._metadata is None: + self._non_result() + + try: + row = self._fetchone_impl() + except Exception as e: + self.connection._handle_dbapi_exception( + e, None, None, + self.cursor, self.context) + + try: + if row is not None: + return self.process_rows([row])[0] + else: + return None + finally: + self.close() + + def scalar(self): + """Fetch the first column of the first row, and close the result set. + + Returns None if no row is present. + + """ + row = self.first() + if row is not None: + return row[0] + else: + return None + + +class BufferedRowResultProxy(ResultProxy): + """A ResultProxy with row buffering behavior. + + ``ResultProxy`` that buffers the contents of a selection of rows + before ``fetchone()`` is called. This is to allow the results of + ``cursor.description`` to be available immediately, when + interfacing with a DB-API that requires rows to be consumed before + this information is available (currently psycopg2, when used with + server-side cursors). + + The pre-fetching behavior fetches only one row initially, and then + grows its buffer size by a fixed amount with each successive need + for additional rows up to a size of 100. + """ + + def _init_metadata(self): + self.__buffer_rows() + super(BufferedRowResultProxy, self)._init_metadata() + + # this is a "growth chart" for the buffering of rows. + # each successive __buffer_rows call will use the next + # value in the list for the buffer size until the max + # is reached + size_growth = { + 1: 5, + 5: 10, + 10: 20, + 20: 50, + 50: 100, + 100: 250, + 250: 500, + 500: 1000 + } + + def __buffer_rows(self): + size = getattr(self, '_bufsize', 1) + self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) + self._bufsize = self.size_growth.get(size, size) + + def _fetchone_impl(self): + if self.closed: + return None + if not self.__rowbuffer: + self.__buffer_rows() + if not self.__rowbuffer: + return None + return self.__rowbuffer.popleft() + + def _fetchmany_impl(self, size=None): + if size is None: + return self._fetchall_impl() + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + self.__rowbuffer.extend(self.cursor.fetchall()) + ret = self.__rowbuffer + self.__rowbuffer = collections.deque() + return ret + + +class FullyBufferedResultProxy(ResultProxy): + """A result proxy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + def _init_metadata(self): + super(FullyBufferedResultProxy, self)._init_metadata() + self.__rowbuffer = self._buffer_rows() + + def _buffer_rows(self): + return collections.deque(self.cursor.fetchall()) + + def _fetchone_impl(self): + if self.__rowbuffer: + return self.__rowbuffer.popleft() + else: + return None + + def _fetchmany_impl(self, size=None): + if size is None: + return self._fetchall_impl() + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + self.__rowbuffer = collections.deque() + return ret + + +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row, processors, keymap): + # preprocess row + row = list(row) + # this is a tad faster than using enumerate + index = 0 + for processor in parent._orig_processors: + if processor is not None: + row[index] = processor(row[index]) + index += 1 + row = tuple(row) + super(BufferedColumnRow, self).__init__(parent, row, + processors, keymap) + + +class BufferedColumnResultProxy(ResultProxy): + """A ResultProxy with column buffering behavior. + + ``ResultProxy`` that loads all columns into memory each time + fetchone() is called. If fetchmany() or fetchall() are called, + the full grid of results is fetched. This is to operate with + databases where result rows contain "live" results that fall out + of scope unless explicitly fetched. Currently this includes + cx_Oracle LOB objects. + + """ + + _process_row = BufferedColumnRow + + def _init_metadata(self): + super(BufferedColumnResultProxy, self)._init_metadata() + metadata = self._metadata + # orig_processors will be used to preprocess each row when they are + # constructed. + metadata._orig_processors = metadata._processors + # replace the all type processors by None processors. + metadata._processors = [None for _ in range(len(metadata.keys))] + keymap = {} + for k, (func, obj, index) in metadata._keymap.items(): + keymap[k] = (None, obj, index) + self._metadata._keymap = keymap + + def fetchall(self): + # can't call cursor.fetchall(), since rows must be + # fully processed before requesting more from the DBAPI. + l = [] + while True: + row = self.fetchone() + if row is None: + break + l.append(row) + return l + + def fetchmany(self, size=None): + # can't call cursor.fetchmany(), since rows must be + # fully processed before requesting more from the DBAPI. + if size is None: + return self.fetchall() + l = [] + for i in range(size): + row = self.fetchone() + if row is None: + break + l.append(row) + return l diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py new file mode 100644 index 00000000..a8a63bb3 --- /dev/null +++ b/lib/sqlalchemy/engine/strategies.py @@ -0,0 +1,258 @@ +# engine/strategies.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Strategies for creating new instances of Engine types. + +These are semi-private implementation classes which provide the +underlying behavior for the "strategy" keyword argument available on +:func:`~sqlalchemy.engine.create_engine`. Current available options are +``plain``, ``threadlocal``, and ``mock``. + +New strategies can be added via new ``EngineStrategy`` classes. +""" + +from operator import attrgetter + +from sqlalchemy.engine import base, threadlocal, url +from sqlalchemy import util, exc, event +from sqlalchemy import pool as poollib + +strategies = {} + + +class EngineStrategy(object): + """An adaptor that processes input arguments and produces an Engine. + + Provides a ``create`` method that receives input arguments and + produces an instance of base.Engine or a subclass. + + """ + + def __init__(self): + strategies[self.name] = self + + def create(self, *args, **kwargs): + """Given arguments, returns a new Engine instance.""" + + raise NotImplementedError() + + +class DefaultEngineStrategy(EngineStrategy): + """Base class for built-in strategies.""" + + def create(self, name_or_url, **kwargs): + # create url.URL object + u = url.make_url(name_or_url) + + dialect_cls = u.get_dialect() + + if kwargs.pop('_coerce_config', False): + def pop_kwarg(key, default=None): + value = kwargs.pop(key, default) + if key in dialect_cls.engine_config_types: + value = dialect_cls.engine_config_types[key](value) + return value + else: + pop_kwarg = kwargs.pop + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = pop_kwarg(k) + + dbapi = kwargs.pop('module', None) + if dbapi is None: + dbapi_args = {} + for k in util.get_func_kwargs(dialect_cls.dbapi): + if k in kwargs: + dbapi_args[k] = pop_kwarg(k) + dbapi = dialect_cls.dbapi(**dbapi_args) + + dialect_args['dbapi'] = dbapi + + # create dialect + dialect = dialect_cls(**dialect_args) + + # assemble connection arguments + (cargs, cparams) = dialect.create_connect_args(u) + cparams.update(pop_kwarg('connect_args', {})) + + # look for existing pool or create + pool = pop_kwarg('pool', None) + if pool is None: + def connect(): + try: + return dialect.connect(*cargs, **cparams) + except dialect.dbapi.Error as e: + invalidated = dialect.is_disconnect(e, None, None) + util.raise_from_cause( + exc.DBAPIError.instance(None, None, + e, dialect.dbapi.Error, + connection_invalidated=invalidated + ) + ) + + creator = pop_kwarg('creator', connect) + + poolclass = pop_kwarg('poolclass', None) + if poolclass is None: + poolclass = dialect_cls.get_pool_class(u) + pool_args = {} + + # consume pool arguments from kwargs, translating a few of + # the arguments + translate = {'logging_name': 'pool_logging_name', + 'echo': 'echo_pool', + 'timeout': 'pool_timeout', + 'recycle': 'pool_recycle', + 'events': 'pool_events', + 'use_threadlocal': 'pool_threadlocal', + 'reset_on_return': 'pool_reset_on_return'} + for k in util.get_cls_kwargs(poolclass): + tk = translate.get(k, k) + if tk in kwargs: + pool_args[k] = pop_kwarg(tk) + pool = poolclass(creator, **pool_args) + else: + if isinstance(pool, poollib._DBProxy): + pool = pool.get_pool(*cargs, **cparams) + else: + pool = pool + + # create engine. + engineclass = self.engine_cls + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = pop_kwarg(k) + + _initialize = kwargs.pop('_initialize', True) + + # all kwargs should be consumed + if kwargs: + raise TypeError( + "Invalid argument(s) %s sent to create_engine(), " + "using configuration %s/%s/%s. Please check that the " + "keyword arguments are appropriate for this combination " + "of components." % (','.join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__)) + + engine = engineclass(pool, dialect, u, **engine_args) + + if _initialize: + do_on_connect = dialect.on_connect() + if do_on_connect: + def on_connect(dbapi_connection, connection_record): + conn = getattr( + dbapi_connection, '_sqla_unwrap', dbapi_connection) + if conn is None: + return + do_on_connect(conn) + + event.listen(pool, 'first_connect', on_connect) + event.listen(pool, 'connect', on_connect) + + def first_connect(dbapi_connection, connection_record): + c = base.Connection(engine, connection=dbapi_connection, + _has_events=False) + + dialect.initialize(c) + event.listen(pool, 'first_connect', first_connect, once=True) + + return engine + + +class PlainEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring a regular Engine.""" + + name = 'plain' + engine_cls = base.Engine + +PlainEngineStrategy() + + +class ThreadLocalEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring an Engine with threadlocal behavior.""" + + name = 'threadlocal' + engine_cls = threadlocal.TLEngine + +ThreadLocalEngineStrategy() + + +class MockEngineStrategy(EngineStrategy): + """Strategy for configuring an Engine-like object with mocked execution. + + Produces a single mock Connectable object which dispatches + statement execution to a passed-in function. + + """ + + name = 'mock' + + def create(self, name_or_url, executor, **kwargs): + # create url.URL object + u = url.make_url(name_or_url) + + dialect_cls = u.get_dialect() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + # create dialect + dialect = dialect_cls(**dialect_args) + + return MockEngineStrategy.MockConnection(dialect, executor) + + class MockConnection(base.Connectable): + def __init__(self, dialect, execute): + self._dialect = dialect + self.execute = execute + + engine = property(lambda s: s) + dialect = property(attrgetter('_dialect')) + name = property(lambda s: s._dialect.name) + + def contextual_connect(self, **kwargs): + return self + + def execution_options(self, **kw): + return self + + def compiler(self, statement, parameters, **kwargs): + return self._dialect.compiler( + statement, parameters, engine=self, **kwargs) + + def create(self, entity, **kwargs): + kwargs['checkfirst'] = False + from sqlalchemy.engine import ddl + + ddl.SchemaGenerator( + self.dialect, self, **kwargs).traverse_single(entity) + + def drop(self, entity, **kwargs): + kwargs['checkfirst'] = False + from sqlalchemy.engine import ddl + ddl.SchemaDropper( + self.dialect, self, **kwargs).traverse_single(entity) + + def _run_visitor(self, visitorcallable, element, + connection=None, + **kwargs): + kwargs['checkfirst'] = False + visitorcallable(self.dialect, self, + **kwargs).traverse_single(element) + + def execute(self, object, *multiparams, **params): + raise NotImplementedError() + +MockEngineStrategy() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py new file mode 100644 index 00000000..ae647a78 --- /dev/null +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -0,0 +1,134 @@ +# engine/threadlocal.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides a thread-local transactional wrapper around the root Engine class. + +The ``threadlocal`` module is invoked when using the +``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`. +This module is semi-private and is invoked automatically when the threadlocal +engine strategy is used. +""" + +from .. import util +from . import base +import weakref + + +class TLConnection(base.Connection): + + def __init__(self, *arg, **kw): + super(TLConnection, self).__init__(*arg, **kw) + self.__opencount = 0 + + def _increment_connect(self): + self.__opencount += 1 + return self + + def close(self): + if self.__opencount == 1: + base.Connection.close(self) + self.__opencount -= 1 + + def _force_close(self): + self.__opencount = 0 + base.Connection.close(self) + + +class TLEngine(base.Engine): + """An Engine that includes support for thread-local managed + transactions. + + """ + _tl_connection_cls = TLConnection + + def __init__(self, *args, **kwargs): + super(TLEngine, self).__init__(*args, **kwargs) + self._connections = util.threading.local() + + def contextual_connect(self, **kw): + if not hasattr(self._connections, 'conn'): + connection = None + else: + connection = self._connections.conn() + + if connection is None or connection.closed: + # guards against pool-level reapers, if desired. + # or not connection.connection.is_valid: + connection = self._tl_connection_cls( + self, self.pool.connect(), **kw) + self._connections.conn = weakref.ref(connection) + + return connection._increment_connect() + + def begin_twophase(self, xid=None): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append( + self.contextual_connect().begin_twophase(xid=xid)) + return self + + def begin_nested(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append( + self.contextual_connect().begin_nested()) + return self + + def begin(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin()) + return self + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None: + self.commit() + else: + self.rollback() + + def prepare(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return + self._connections.trans[-1].prepare() + + def commit(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return + trans = self._connections.trans.pop(-1) + trans.commit() + + def rollback(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return + trans = self._connections.trans.pop(-1) + trans.rollback() + + def dispose(self): + self._connections = util.threading.local() + super(TLEngine, self).dispose() + + @property + def closed(self): + return not hasattr(self._connections, 'conn') or \ + self._connections.conn() is None or \ + self._connections.conn().closed + + def close(self): + if not self.closed: + self.contextual_connect().close() + connection = self._connections.conn() + connection._force_close() + del self._connections.conn + self._connections.trans = [] + + def __repr__(self): + return 'TLEngine(%s)' % str(self.url) diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py new file mode 100644 index 00000000..78ac0618 --- /dev/null +++ b/lib/sqlalchemy/engine/url.py @@ -0,0 +1,227 @@ +# engine/url.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates +information about a database connection specification. + +The URL object is created automatically when +:func:`~sqlalchemy.engine.create_engine` is called with a string +argument; alternatively, the URL is a public-facing construct which can +be used directly and is also accepted directly by ``create_engine()``. +""" + +import re +from .. import exc, util +from . import Dialect +from ..dialects import registry + + +class URL(object): + """ + Represent the components of a URL used to connect to a database. + + This object is suitable to be passed directly to a + :func:`~sqlalchemy.create_engine` call. The fields of the URL are parsed from a + string by the :func:`.make_url` function. the string + format of the URL is an RFC-1738-style string. + + All initialization parameters are available as public attributes. + + :param drivername: the name of the database backend. + This name will correspond to a module in sqlalchemy/databases + or a third party plug-in. + + :param username: The user name. + + :param password: database password. + + :param host: The name of the host. + + :param port: The port number. + + :param database: The database name. + + :param query: A dictionary of options to be passed to the + dialect and/or the DBAPI upon connect. + + """ + + def __init__(self, drivername, username=None, password=None, + host=None, port=None, database=None, query=None): + self.drivername = drivername + self.username = username + self.password = password + self.host = host + if port is not None: + self.port = int(port) + else: + self.port = None + self.database = database + self.query = query or {} + + def __to_string__(self, hide_password=True): + s = self.drivername + "://" + if self.username is not None: + s += _rfc_1738_quote(self.username) + if self.password is not None: + s += ':' + ('***' if hide_password + else _rfc_1738_quote(self.password)) + s += "@" + if self.host is not None: + if ':' in self.host: + s += "[%s]" % self.host + else: + s += self.host + if self.port is not None: + s += ':' + str(self.port) + if self.database is not None: + s += '/' + self.database + if self.query: + keys = list(self.query) + keys.sort() + s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) + return s + + def __str__(self): + return self.__to_string__(hide_password=False) + + def __repr__(self): + return self.__to_string__() + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return \ + isinstance(other, URL) and \ + self.drivername == other.drivername and \ + self.username == other.username and \ + self.password == other.password and \ + self.host == other.host and \ + self.database == other.database and \ + self.query == other.query + + def get_dialect(self): + """Return the SQLAlchemy database dialect class corresponding + to this URL's driver name. + """ + + if '+' not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace('+', '.') + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if hasattr(cls, 'dialect') and \ + isinstance(cls.dialect, type) and \ + issubclass(cls.dialect, Dialect): + return cls.dialect + else: + return cls + + def translate_connect_args(self, names=[], **kw): + """Translate url attributes into a dictionary of connection arguments. + + Returns attributes of this url (`host`, `database`, `username`, + `password`, `port`) as a plain dictionary. The attribute names are + used as the keys by default. Unset or false attributes are omitted + from the final dictionary. + + :param \**kw: Optional, alternate key names for url attributes. + + :param names: Deprecated. Same purpose as the keyword-based alternate + names, but correlates the name to the original positionally. + """ + + translated = {} + attribute_names = ['host', 'database', 'username', 'password', 'port'] + for sname in attribute_names: + if names: + name = names.pop(0) + elif sname in kw: + name = kw[sname] + else: + name = sname + if name is not None and getattr(self, sname, False): + translated[name] = getattr(self, sname) + return translated + + +def make_url(name_or_url): + """Given a string or unicode instance, produce a new URL instance. + + The given string is parsed according to the RFC 1738 spec. If an + existing URL object is passed, just returns the object. + """ + + if isinstance(name_or_url, util.string_types): + return _parse_rfc1738_args(name_or_url) + else: + return name_or_url + + +def _parse_rfc1738_args(name): + pattern = re.compile(r''' + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P.*))? + @)? + (?: + (?: + \[(?P[^/]+)\] | + (?P[^/:]+) + )? + (?::(?P[^/]*))? + )? + (?:/(?P.*))? + ''', re.X) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + if components['database'] is not None: + tokens = components['database'].split('?', 2) + components['database'] = tokens[0] + query = (len(tokens) > 1 and dict(util.parse_qsl(tokens[1]))) or None + if util.py2k and query is not None: + query = dict((k.encode('ascii'), query[k]) for k in query) + else: + query = None + components['query'] = query + + if components['username'] is not None: + components['username'] = _rfc_1738_unquote(components['username']) + + if components['password'] is not None: + components['password'] = _rfc_1738_unquote(components['password']) + + ipv4host = components.pop('ipv4host') + ipv6host = components.pop('ipv6host') + components['host'] = ipv4host or ipv6host + name = components.pop('name') + return URL(name, **components) + else: + raise exc.ArgumentError( + "Could not parse rfc1738 URL from string '%s'" % name) + + +def _rfc_1738_quote(text): + return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + +def _rfc_1738_unquote(text): + return util.unquote(text) + +def _parse_keyvalue_args(name): + m = re.match(r'(\w+)://(.*)', name) + if m is not None: + (name, args) = m.group(1, 2) + opts = dict(util.parse_qsl(args)) + return URL(name, *opts) + else: + return None diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py new file mode 100644 index 00000000..6c0644be --- /dev/null +++ b/lib/sqlalchemy/engine/util.py @@ -0,0 +1,72 @@ +# engine/util.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .. import util + +def connection_memoize(key): + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + """ + + @util.decorator + def decorated(fn, self, connection): + connection = connection.connect() + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return decorated + + +def py_fallback(): + def _distill_params(multiparams, params): + """Given arguments from the calling form *multiparams, **params, + return a list of bind parameter structures, usually a list of + dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + """ + + if not multiparams: + if params: + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if not zero or hasattr(zero[0], '__iter__') and \ + not hasattr(zero[0], 'strip'): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero + else: + # execute(stmt, ("value", "value")) + return [zero] + elif hasattr(zero, 'keys'): + # execute(stmt, {"key":"value"}) + return [zero] + else: + # execute(stmt, "value") + return [[zero]] + else: + if hasattr(multiparams[0], '__iter__') and \ + not hasattr(multiparams[0], 'strip'): + return multiparams + else: + return [multiparams] + + return locals() +try: + from sqlalchemy.cutils import _distill_params +except ImportError: + globals().update(py_fallback()) diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py new file mode 100644 index 00000000..b43bf9bf --- /dev/null +++ b/lib/sqlalchemy/event/__init__.py @@ -0,0 +1,10 @@ +# event/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .api import CANCEL, NO_RETVAL, listen, listens_for, remove, contains +from .base import Events, dispatcher +from .attr import RefCollection +from .legacy import _legacy_signature diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py new file mode 100644 index 00000000..b27ce799 --- /dev/null +++ b/lib/sqlalchemy/event/api.py @@ -0,0 +1,131 @@ +# event/api.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Public API functions for the event system. + +""" +from __future__ import absolute_import + +from .. import util, exc +from .base import _registrars +from .registry import _EventKey + +CANCEL = util.symbol('CANCEL') +NO_RETVAL = util.symbol('NO_RETVAL') + + +def _event_key(target, identifier, fn): + for evt_cls in _registrars[identifier]: + tgt = evt_cls._accept_with(target) + if tgt is not None: + return _EventKey(target, identifier, fn, tgt) + else: + raise exc.InvalidRequestError("No such event '%s' for target '%s'" % + (identifier, target)) + +def listen(target, identifier, fn, *args, **kw): + """Register a listener function for the given target. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + event.listen( + UniqueConstraint, + "after_parent_attach", + unique_constraint_name) + + + A given function can also be invoked for only the first invocation + of the event using the ``once`` argument:: + + def on_config(): + do_config() + + event.listen(Mapper, "before_configure", on_config, once=True) + + .. versionadded:: 0.9.3 Added ``once=True`` to :func:`.event.listen` + and :func:`.event.listens_for`. + + """ + + _event_key(target, identifier, fn).listen(*args, **kw) + + +def listens_for(target, identifier, *args, **kw): + """Decorate a function as a listener for the given target + identifier. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + @event.listens_for(UniqueConstraint, "after_parent_attach") + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + + A given function can also be invoked for only the first invocation + of the event using the ``once`` argument:: + + @event.listens_for(Mapper, "before_configure", once=True) + def on_config(): + do_config() + + + .. versionadded:: 0.9.3 Added ``once=True`` to :func:`.event.listen` + and :func:`.event.listens_for`. + + """ + def decorate(fn): + listen(target, identifier, fn, *args, **kw) + return fn + return decorate + + +def remove(target, identifier, fn): + """Remove an event listener. + + The arguments here should match exactly those which were sent to + :func:`.listen`; all the event registration which proceeded as a result + of this call will be reverted by calling :func:`.remove` with the same + arguments. + + e.g.:: + + # if a function was registered like this... + @event.listens_for(SomeMappedClass, "before_insert", propagate=True) + def my_listener_function(*arg): + pass + + # ... it's removed like this + event.remove(SomeMappedClass, "before_insert", my_listener_function) + + Above, the listener function associated with ``SomeMappedClass`` was also + propagated to subclasses of ``SomeMappedClass``; the :func:`.remove` function + will revert all of these operations. + + .. versionadded:: 0.9.0 + + """ + _event_key(target, identifier, fn).remove() + +def contains(target, identifier, fn): + """Return True if the given target/ident/fn is set up to listen. + + .. versionadded:: 0.9.0 + + """ + + return _event_key(target, identifier, fn).contains() diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py new file mode 100644 index 00000000..b44aeefc --- /dev/null +++ b/lib/sqlalchemy/event/attr.py @@ -0,0 +1,386 @@ +# event/attr.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Attribute implementation for _Dispatch classes. + +The various listener targets for a particular event class are represented +as attributes, which refer to collections of listeners to be fired off. +These collections can exist at the class level as well as at the instance +level. An event is fired off using code like this:: + + some_object.dispatch.first_connect(arg1, arg2) + +Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and +``first_connect`` is typically an instance of ``_ListenerCollection`` +if event listeners are present, or ``_EmptyListener`` if none are present. + +The attribute mechanics here spend effort trying to ensure listener functions +are available with a minimum of function call overhead, that unnecessary +objects aren't created (i.e. many empty per-instance listener collections), +as well as that everything is garbage collectable when owning references are +lost. Other features such as "propagation" of listener functions across +many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances, +as well as support for subclass propagation (e.g. events assigned to +``Pool`` vs. ``QueuePool``) are all implemented here. + +""" + +from __future__ import absolute_import, with_statement + +from .. import util +from ..util import threading +from . import registry +from . import legacy +from itertools import chain +import weakref + + +class RefCollection(object): + @util.memoized_property + def ref(self): + return weakref.ref(self, registry._collection_gced) + +class _DispatchDescriptor(RefCollection): + """Class-level attributes on :class:`._Dispatch` classes.""" + + def __init__(self, parent_dispatch_cls, fn): + self.__name__ = fn.__name__ + argspec = util.inspect_getargspec(fn) + self.arg_names = argspec.args[1:] + self.has_kw = bool(argspec.keywords) + self.legacy_signatures = list(reversed( + sorted( + getattr(fn, '_legacy_signatures', []), + key=lambda s: s[0] + ) + )) + self.__doc__ = fn.__doc__ = legacy._augment_fn_docs( + self, parent_dispatch_cls, fn) + + self._clslevel = weakref.WeakKeyDictionary() + self._empty_listeners = weakref.WeakKeyDictionary() + + def _adjust_fn_spec(self, fn, named): + if named: + fn = self._wrap_fn_for_kw(fn) + if self.legacy_signatures: + try: + argspec = util.get_callable_argspec(fn, no_self=True) + except TypeError: + pass + else: + fn = legacy._wrap_fn_for_legacy(self, fn, argspec) + return fn + + def _wrap_fn_for_kw(self, fn): + def wrap_kw(*args, **kw): + argdict = dict(zip(self.arg_names, args)) + argdict.update(kw) + return fn(**argdict) + return wrap_kw + + + def insert(self, event_key, propagate): + target = event_key.dispatch_target + assert isinstance(target, type), \ + "Class-level Event targets must be classes." + stack = [target] + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + if cls not in self._clslevel: + self._clslevel[cls] = [] + self._clslevel[cls].insert(0, event_key._listen_fn) + registry._stored_in_collection(event_key, self) + + def append(self, event_key, propagate): + target = event_key.dispatch_target + assert isinstance(target, type), \ + "Class-level Event targets must be classes." + + stack = [target] + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + if cls not in self._clslevel: + self._clslevel[cls] = [] + self._clslevel[cls].append(event_key._listen_fn) + registry._stored_in_collection(event_key, self) + + def update_subclass(self, target): + if target not in self._clslevel: + self._clslevel[target] = [] + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend([ + fn for fn + in self._clslevel[cls] + if fn not in clslevel + ]) + + def remove(self, event_key): + target = event_key.dispatch_target + stack = [target] + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + if cls in self._clslevel: + self._clslevel[cls].remove(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self): + """Clear all class level listeners""" + + to_clear = set() + for dispatcher in self._clslevel.values(): + to_clear.update(dispatcher) + dispatcher[:] = [] + registry._clear(self, to_clear) + + def for_modify(self, obj): + """Return an event collection which can be modified. + + For _DispatchDescriptor at the class level of + a dispatcher, this returns self. + + """ + return self + + def __get__(self, obj, cls): + if obj is None: + return self + elif obj._parent_cls in self._empty_listeners: + ret = self._empty_listeners[obj._parent_cls] + else: + self._empty_listeners[obj._parent_cls] = ret = \ + _EmptyListener(self, obj._parent_cls) + # assigning it to __dict__ means + # memoized for fast re-access. but more memory. + obj.__dict__[self.__name__] = ret + return ret + +class _HasParentDispatchDescriptor(object): + def _adjust_fn_spec(self, fn, named): + return self.parent._adjust_fn_spec(fn, named) + +class _EmptyListener(_HasParentDispatchDescriptor): + """Serves as a class-level interface to the events + served by a _DispatchDescriptor, when there are no + instance-level events present. + + Is replaced by _ListenerCollection when instance-level + events are added. + + """ + def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self.parent = parent # _DispatchDescriptor + self.parent_listeners = parent._clslevel[target_cls] + self.name = parent.__name__ + self.propagate = frozenset() + self.listeners = () + + + def for_modify(self, obj): + """Return an event collection which can be modified. + + For _EmptyListener at the instance level of + a dispatcher, this generates a new + _ListenerCollection, applies it to the instance, + and returns it. + + """ + result = _ListenerCollection(self.parent, obj._parent_cls) + if obj.__dict__[self.name] is self: + obj.__dict__[self.name] = result + return result + + def _needs_modify(self, *args, **kw): + raise NotImplementedError("need to call for_modify()") + + exec_once = insert = append = remove = clear = _needs_modify + + def __call__(self, *args, **kw): + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + + def __len__(self): + return len(self.parent_listeners) + + def __iter__(self): + return iter(self.parent_listeners) + + def __bool__(self): + return bool(self.parent_listeners) + + __nonzero__ = __bool__ + + +class _CompoundListener(_HasParentDispatchDescriptor): + _exec_once = False + + @util.memoized_property + def _exec_once_mutex(self): + return threading.Lock() + + def exec_once(self, *args, **kw): + """Execute this event, but only if it has not been + executed already for this collection.""" + + if not self._exec_once: + with self._exec_once_mutex: + if not self._exec_once: + try: + self(*args, **kw) + finally: + self._exec_once = True + + def __call__(self, *args, **kw): + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + for fn in self.listeners: + fn(*args, **kw) + + def __len__(self): + return len(self.parent_listeners) + len(self.listeners) + + def __iter__(self): + return chain(self.parent_listeners, self.listeners) + + def __bool__(self): + return bool(self.listeners or self.parent_listeners) + + __nonzero__ = __bool__ + +class _ListenerCollection(RefCollection, _CompoundListener): + """Instance-level attributes on instances of :class:`._Dispatch`. + + Represents a collection of listeners. + + As of 0.7.9, _ListenerCollection is only first + created via the _EmptyListener.for_modify() method. + + """ + + def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self.parent_listeners = parent._clslevel[target_cls] + self.parent = parent + self.name = parent.__name__ + self.listeners = [] + self.propagate = set() + + def for_modify(self, obj): + """Return an event collection which can be modified. + + For _ListenerCollection at the instance level of + a dispatcher, this returns self. + + """ + return self + + def _update(self, other, only_propagate=True): + """Populate from the listeners in another :class:`_Dispatch` + object.""" + + existing_listeners = self.listeners + existing_listener_set = set(existing_listeners) + self.propagate.update(other.propagate) + other_listeners = [l for l + in other.listeners + if l not in existing_listener_set + and not only_propagate or l in self.propagate + ] + + existing_listeners.extend(other_listeners) + + to_associate = other.propagate.union(other_listeners) + registry._stored_in_collection_multi(self, other, to_associate) + + def insert(self, event_key, propagate): + if event_key._listen_fn not in self.listeners: + event_key.prepend_to_list(self, self.listeners) + if propagate: + self.propagate.add(event_key._listen_fn) + + def append(self, event_key, propagate): + if event_key._listen_fn not in self.listeners: + event_key.append_to_list(self, self.listeners) + if propagate: + self.propagate.add(event_key._listen_fn) + + def remove(self, event_key): + self.listeners.remove(event_key._listen_fn) + self.propagate.discard(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self): + registry._clear(self, self.listeners) + self.propagate.clear() + self.listeners[:] = [] + + +class _JoinedDispatchDescriptor(object): + def __init__(self, name): + self.name = name + + def __get__(self, obj, cls): + if obj is None: + return self + else: + obj.__dict__[self.name] = ret = _JoinedListener( + obj.parent, self.name, + getattr(obj.local, self.name) + ) + return ret + + +class _JoinedListener(_CompoundListener): + _exec_once = False + + def __init__(self, parent, name, local): + self.parent = parent + self.name = name + self.local = local + self.parent_listeners = self.local + + @property + def listeners(self): + return getattr(self.parent, self.name) + + def _adjust_fn_spec(self, fn, named): + return self.local._adjust_fn_spec(fn, named) + + def for_modify(self, obj): + self.local = self.parent_listeners = self.local.for_modify(obj) + return self + + def insert(self, event_key, propagate): + self.local.insert(event_key, propagate) + + def append(self, event_key, propagate): + self.local.append(event_key, propagate) + + def remove(self, event_key): + self.local.remove(event_key) + + def clear(self): + raise NotImplementedError() + + diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py new file mode 100644 index 00000000..5c8d92cb --- /dev/null +++ b/lib/sqlalchemy/event/base.py @@ -0,0 +1,217 @@ +# event/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base implementation classes. + +The public-facing ``Events`` serves as the base class for an event interface; +it's public attributes represent different kinds of events. These attributes +are mirrored onto a ``_Dispatch`` class, which serves as a container for +collections of listener functions. These collections are represented both +at the class level of a particular ``_Dispatch`` class as well as within +instances of ``_Dispatch``. + +""" +from __future__ import absolute_import + +from .. import util +from .attr import _JoinedDispatchDescriptor, _EmptyListener, _DispatchDescriptor + +_registrars = util.defaultdict(list) + + +def _is_event_name(name): + return not name.startswith('_') and name != 'dispatch' + + +class _UnpickleDispatch(object): + """Serializable callable that re-generates an instance of + :class:`_Dispatch` given a particular :class:`.Events` subclass. + + """ + def __call__(self, _parent_cls): + for cls in _parent_cls.__mro__: + if 'dispatch' in cls.__dict__: + return cls.__dict__['dispatch'].dispatch_cls(_parent_cls) + else: + raise AttributeError("No class with a 'dispatch' member present.") + + +class _Dispatch(object): + """Mirror the event listening definitions of an Events class with + listener collections. + + Classes which define a "dispatch" member will return a + non-instantiated :class:`._Dispatch` subclass when the member + is accessed at the class level. When the "dispatch" member is + accessed at the instance level of its owner, an instance + of the :class:`._Dispatch` class is returned. + + A :class:`._Dispatch` class is generated for each :class:`.Events` + class defined, by the :func:`._create_dispatcher_class` function. + The original :class:`.Events` classes remain untouched. + This decouples the construction of :class:`.Events` subclasses from + the implementation used by the event internals, and allows + inspecting tools like Sphinx to work in an unsurprising + way against the public API. + + """ + + _events = None + """reference the :class:`.Events` class which this + :class:`._Dispatch` is created for.""" + + def __init__(self, _parent_cls): + self._parent_cls = _parent_cls + + @util.classproperty + def _listen(cls): + return cls._events._listen + + def _join(self, other): + """Create a 'join' of this :class:`._Dispatch` and another. + + This new dispatcher will dispatch events to both + :class:`._Dispatch` objects. + + """ + if '_joined_dispatch_cls' not in self.__class__.__dict__: + cls = type( + "Joined%s" % self.__class__.__name__, + (_JoinedDispatcher, self.__class__), {} + ) + for ls in _event_descriptors(self): + setattr(cls, ls.name, _JoinedDispatchDescriptor(ls.name)) + + self.__class__._joined_dispatch_cls = cls + return self._joined_dispatch_cls(self, other) + + def __reduce__(self): + return _UnpickleDispatch(), (self._parent_cls, ) + + def _update(self, other, only_propagate=True): + """Populate from the listeners in another :class:`_Dispatch` + object.""" + + for ls in _event_descriptors(other): + if isinstance(ls, _EmptyListener): + continue + getattr(self, ls.name).\ + for_modify(self)._update(ls, only_propagate=only_propagate) + + @util.hybridmethod + def _clear(self): + for attr in dir(self): + if _is_event_name(attr): + getattr(self, attr).for_modify(self).clear() + + +def _event_descriptors(target): + return [getattr(target, k) for k in dir(target) if _is_event_name(k)] + + +class _EventMeta(type): + """Intercept new Event subclasses and create + associated _Dispatch classes.""" + + def __init__(cls, classname, bases, dict_): + _create_dispatcher_class(cls, classname, bases, dict_) + return type.__init__(cls, classname, bases, dict_) + + +def _create_dispatcher_class(cls, classname, bases, dict_): + """Create a :class:`._Dispatch` class corresponding to an + :class:`.Events` class.""" + + # there's all kinds of ways to do this, + # i.e. make a Dispatch class that shares the '_listen' method + # of the Event class, this is the straight monkeypatch. + dispatch_base = getattr(cls, 'dispatch', _Dispatch) + dispatch_cls = type("%sDispatch" % classname, + (dispatch_base, ), {}) + cls._set_dispatch(cls, dispatch_cls) + + for k in dict_: + if _is_event_name(k): + setattr(dispatch_cls, k, _DispatchDescriptor(cls, dict_[k])) + _registrars[k].append(cls) + + if getattr(cls, '_dispatch_target', None): + cls._dispatch_target.dispatch = dispatcher(cls) + + +def _remove_dispatcher(cls): + for k in dir(cls): + if _is_event_name(k): + _registrars[k].remove(cls) + if not _registrars[k]: + del _registrars[k] + +class Events(util.with_metaclass(_EventMeta, object)): + """Define event listening functions for a particular target type.""" + + @staticmethod + def _set_dispatch(cls, dispatch_cls): + # this allows an Events subclass to define additional utility + # methods made available to the target via + # "self.dispatch._events." + # @staticemethod to allow easy "super" calls while in a metaclass + # constructor. + cls.dispatch = dispatch_cls + dispatch_cls._events = cls + + + @classmethod + def _accept_with(cls, target): + # Mapper, ClassManager, Session override this to + # also accept classes, scoped_sessions, sessionmakers, etc. + if hasattr(target, 'dispatch') and ( + isinstance(target.dispatch, cls.dispatch) or \ + isinstance(target.dispatch, type) and \ + issubclass(target.dispatch, cls.dispatch) + ): + return target + else: + return None + + @classmethod + def _listen(cls, event_key, propagate=False, insert=False, named=False): + event_key.base_listen(propagate=propagate, insert=insert, named=named) + + @classmethod + def _remove(cls, event_key): + event_key.remove() + + @classmethod + def _clear(cls): + cls.dispatch._clear() + + +class _JoinedDispatcher(object): + """Represent a connection between two _Dispatch objects.""" + + def __init__(self, local, parent): + self.local = local + self.parent = parent + self._parent_cls = local._parent_cls + + +class dispatcher(object): + """Descriptor used by target classes to + deliver the _Dispatch class at the class level + and produce new _Dispatch instances for target + instances. + + """ + def __init__(self, events): + self.dispatch_cls = events.dispatch + self.events = events + + def __get__(self, obj, cls): + if obj is None: + return self.dispatch_cls + obj.__dict__['dispatch'] = disp = self.dispatch_cls(cls) + return disp + diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py new file mode 100644 index 00000000..d8a66674 --- /dev/null +++ b/lib/sqlalchemy/event/legacy.py @@ -0,0 +1,156 @@ +# event/legacy.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Routines to handle adaption of legacy call signatures, +generation of deprecation notes and docstrings. + +""" + +from .. import util + +def _legacy_signature(since, argnames, converter=None): + def leg(fn): + if not hasattr(fn, '_legacy_signatures'): + fn._legacy_signatures = [] + fn._legacy_signatures.append((since, argnames, converter)) + return fn + return leg + +def _wrap_fn_for_legacy(dispatch_descriptor, fn, argspec): + for since, argnames, conv in dispatch_descriptor.legacy_signatures: + if argnames[-1] == "**kw": + has_kw = True + argnames = argnames[0:-1] + else: + has_kw = False + + if len(argnames) == len(argspec.args) \ + and has_kw is bool(argspec.keywords): + + if conv: + assert not has_kw + def wrap_leg(*args): + return fn(*conv(*args)) + else: + def wrap_leg(*args, **kw): + argdict = dict(zip(dispatch_descriptor.arg_names, args)) + args = [argdict[name] for name in argnames] + if has_kw: + return fn(*args, **kw) + else: + return fn(*args) + return wrap_leg + else: + return fn + +def _indent(text, indent): + return "\n".join( + indent + line + for line in text.split("\n") + ) + +def _standard_listen_example(dispatch_descriptor, sample_target, fn): + example_kw_arg = _indent( + "\n".join( + "%(arg)s = kw['%(arg)s']" % {"arg": arg} + for arg in dispatch_descriptor.arg_names[0:2] + ), + " ") + if dispatch_descriptor.legacy_signatures: + current_since = max(since for since, args, conv + in dispatch_descriptor.legacy_signatures) + else: + current_since = None + text = ( + "from sqlalchemy import event\n\n" + "# standard decorator style%(current_since)s\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... (event handling logic) ...\n" + ) + + if len(dispatch_descriptor.arg_names) > 3: + text += ( + + "\n# named argument style (new in 0.9)\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s', named=True)\n" + "def receive_%(event_name)s(**kw):\n" + " \"listen for the '%(event_name)s' event\"\n" + "%(example_kw_arg)s\n" + "\n # ... (event handling logic) ...\n" + ) + + text %= { + "current_since": " (arguments as of %s)" % + current_since if current_since else "", + "event_name": fn.__name__, + "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "", + "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), + "example_kw_arg": example_kw_arg, + "sample_target": sample_target + } + return text + +def _legacy_listen_examples(dispatch_descriptor, sample_target, fn): + text = "" + for since, args, conv in dispatch_descriptor.legacy_signatures: + text += ( + "\n# legacy calling style (pre-%(since)s)\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... (event handling logic) ...\n" % { + "since": since, + "event_name": fn.__name__, + "has_kw_arguments": " **kw" if dispatch_descriptor.has_kw else "", + "named_event_arguments": ", ".join(args), + "sample_target": sample_target + } + ) + return text + +def _version_signature_changes(dispatch_descriptor): + since, args, conv = dispatch_descriptor.legacy_signatures[0] + return ( + "\n.. versionchanged:: %(since)s\n" + " The ``%(event_name)s`` event now accepts the \n" + " arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n" + " Listener functions which accept the previous argument \n" + " signature(s) listed above will be automatically \n" + " adapted to the new signature." % { + "since": since, + "event_name": dispatch_descriptor.__name__, + "named_event_arguments": ", ".join(dispatch_descriptor.arg_names), + "has_kw_arguments": ", **kw" if dispatch_descriptor.has_kw else "" + } + ) + +def _augment_fn_docs(dispatch_descriptor, parent_dispatch_cls, fn): + header = ".. container:: event_signatures\n\n"\ + " Example argument forms::\n"\ + "\n" + + sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj") + text = ( + header + + _indent( + _standard_listen_example( + dispatch_descriptor, sample_target, fn), + " " * 8) + ) + if dispatch_descriptor.legacy_signatures: + text += _indent( + _legacy_listen_examples( + dispatch_descriptor, sample_target, fn), + " " * 8) + + text += _version_signature_changes(dispatch_descriptor) + + return util.inject_docstring_text(fn.__doc__, + text, + 1 + ) diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py new file mode 100644 index 00000000..6f3eb3e8 --- /dev/null +++ b/lib/sqlalchemy/event/registry.py @@ -0,0 +1,241 @@ +# event/registry.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides managed registration services on behalf of :func:`.listen` +arguments. + +By "managed registration", we mean that event listening functions and +other objects can be added to various collections in such a way that their +membership in all those collections can be revoked at once, based on +an equivalent :class:`._EventKey`. + +""" + +from __future__ import absolute_import + +import weakref +import collections +import types +from .. import exc, util + + +_key_to_collection = collections.defaultdict(dict) +""" +Given an original listen() argument, can locate all +listener collections and the listener fn contained + +(target, identifier, fn) -> { + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + } +""" + +_collection_to_key = collections.defaultdict(dict) +""" +Given a _ListenerCollection or _DispatchDescriptor, can locate +all the original listen() arguments and the listener fn contained + +ref(listenercollection) -> { + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + } +""" + +def _collection_gced(ref): + # defaultdict, so can't get a KeyError + if not _collection_to_key or ref not in _collection_to_key: + return + listener_to_key = _collection_to_key.pop(ref) + for key in listener_to_key.values(): + if key in _key_to_collection: + # defaultdict, so can't get a KeyError + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(ref) + if not dispatch_reg: + _key_to_collection.pop(key) + +def _stored_in_collection(event_key, owner): + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + owner_ref = owner.ref + listen_ref = weakref.ref(event_key._listen_fn) + + if owner_ref in dispatch_reg: + assert dispatch_reg[owner_ref] == listen_ref + else: + dispatch_reg[owner_ref] = listen_ref + + listener_to_key = _collection_to_key[owner_ref] + listener_to_key[listen_ref] = key + +def _removed_from_collection(event_key, owner): + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + listen_ref = weakref.ref(event_key._listen_fn) + + owner_ref = owner.ref + dispatch_reg.pop(owner_ref, None) + if not dispatch_reg: + del _key_to_collection[key] + + if owner_ref in _collection_to_key: + listener_to_key = _collection_to_key[owner_ref] + listener_to_key.pop(listen_ref) + +def _stored_in_collection_multi(newowner, oldowner, elements): + if not elements: + return + + oldowner = oldowner.ref + newowner = newowner.ref + + old_listener_to_key = _collection_to_key[oldowner] + new_listener_to_key = _collection_to_key[newowner] + + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + key = old_listener_to_key[listen_ref] + dispatch_reg = _key_to_collection[key] + if newowner in dispatch_reg: + assert dispatch_reg[newowner] == listen_ref + else: + dispatch_reg[newowner] = listen_ref + + new_listener_to_key[listen_ref] = key + +def _clear(owner, elements): + if not elements: + return + + owner = owner.ref + listener_to_key = _collection_to_key[owner] + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + key = listener_to_key[listen_ref] + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(owner, None) + + if not dispatch_reg: + del _key_to_collection[key] + + +class _EventKey(object): + """Represent :func:`.listen` arguments. + """ + + + def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): + self.target = target + self.identifier = identifier + self.fn = fn + if isinstance(fn, types.MethodType): + self.fn_key = id(fn.__func__), id(fn.__self__) + else: + self.fn_key = id(fn) + self.fn_wrap = _fn_wrap + self.dispatch_target = dispatch_target + + @property + def _key(self): + return (id(self.target), self.identifier, self.fn_key) + + def with_wrapper(self, fn_wrap): + if fn_wrap is self._listen_fn: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + self.dispatch_target, + _fn_wrap=fn_wrap + ) + + def with_dispatch_target(self, dispatch_target): + if dispatch_target is self.dispatch_target: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + dispatch_target, + _fn_wrap=self.fn_wrap + ) + + def listen(self, *args, **kw): + once = kw.pop("once", False) + if once: + self.with_wrapper(util.only_once(self._listen_fn)).listen(*args, **kw) + else: + self.dispatch_target.dispatch._listen(self, *args, **kw) + + def remove(self): + key = self._key + + if key not in _key_to_collection: + raise exc.InvalidRequestError( + "No listeners found for event %s / %r / %s " % + (self.target, self.identifier, self.fn) + ) + dispatch_reg = _key_to_collection.pop(key) + + for collection_ref, listener_ref in dispatch_reg.items(): + collection = collection_ref() + listener_fn = listener_ref() + if collection is not None and listener_fn is not None: + collection.remove(self.with_wrapper(listener_fn)) + + def contains(self): + """Return True if this event key is registered to listen. + """ + return self._key in _key_to_collection + + def base_listen(self, propagate=False, insert=False, + named=False): + + target, identifier, fn = \ + self.dispatch_target, self.identifier, self._listen_fn + + dispatch_descriptor = getattr(target.dispatch, identifier) + + fn = dispatch_descriptor._adjust_fn_spec(fn, named) + self = self.with_wrapper(fn) + + if insert: + dispatch_descriptor.\ + for_modify(target.dispatch).insert(self, propagate) + else: + dispatch_descriptor.\ + for_modify(target.dispatch).append(self, propagate) + + @property + def _listen_fn(self): + return self.fn_wrap or self.fn + + def append_value_to_list(self, owner, list_, value): + _stored_in_collection(self, owner) + list_.append(value) + + def append_to_list(self, owner, list_): + _stored_in_collection(self, owner) + list_.append(self._listen_fn) + + def remove_from_list(self, owner, list_): + _removed_from_collection(self, owner) + list_.remove(self._listen_fn) + + def prepend_to_list(self, owner, list_): + _stored_in_collection(self, owner) + list_.insert(0, self._listen_fn) + + diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py new file mode 100644 index 00000000..9ba6de68 --- /dev/null +++ b/lib/sqlalchemy/events.py @@ -0,0 +1,924 @@ +# sqlalchemy/events.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Core event interfaces.""" + +from . import event, exc +from .pool import Pool +from .engine import Connectable, Engine, Dialect +from .sql.base import SchemaEventTarget + +class DDLEvents(event.Events): + """ + Define event listeners for schema objects, + that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget` + subclasses, including :class:`.MetaData`, :class:`.Table`, + :class:`.Column`. + + :class:`.MetaData` and :class:`.Table` support events + specifically regarding when CREATE and DROP + DDL is emitted to the database. + + Attachment events are also provided to customize + behavior whenever a child schema element is associated + with a parent, such as, when a :class:`.Column` is associated + with its :class:`.Table`, when a :class:`.ForeignKeyConstraint` + is associated with a :class:`.Table`, etc. + + Example using the ``after_create`` event:: + + from sqlalchemy import event + from sqlalchemy import Table, Column, Metadata, Integer + + m = MetaData() + some_table = Table('some_table', m, Column('data', Integer)) + + def after_create(target, connection, **kw): + connection.execute("ALTER TABLE %s SET name=foo_%s" % + (target.name, target.name)) + + event.listen(some_table, "after_create", after_create) + + DDL events integrate closely with the + :class:`.DDL` class and the :class:`.DDLElement` hierarchy + of DDL clause constructs, which are themselves appropriate + as listener callables:: + + from sqlalchemy import DDL + event.listen( + some_table, + "after_create", + DDL("ALTER TABLE %(table)s SET name=foo_%(table)s") + ) + + The methods here define the name of an event as well + as the names of members that are passed to listener + functions. + + See also: + + :ref:`event_toplevel` + + :class:`.DDLElement` + + :class:`.DDL` + + :ref:`schema_ddl_sequences` + + """ + + _target_class_doc = "SomeSchemaClassOrObject" + _dispatch_target = SchemaEventTarget + + def before_create(self, target, connection, **kw): + """Called before CREATE statments are emitted. + + :param target: the :class:`.MetaData` or :class:`.Table` + object which is the target of the event. + :param connection: the :class:`.Connection` where the + CREATE statement or statements will be emitted. + :param \**kw: additional keyword arguments relevant + to the event. The contents of this dictionary + may vary across releases, and include the + list of tables being generated for a metadata-level + event, the checkfirst flag, and other + elements used by internal events. + + """ + + def after_create(self, target, connection, **kw): + """Called after CREATE statments are emitted. + + :param target: the :class:`.MetaData` or :class:`.Table` + object which is the target of the event. + :param connection: the :class:`.Connection` where the + CREATE statement or statements have been emitted. + :param \**kw: additional keyword arguments relevant + to the event. The contents of this dictionary + may vary across releases, and include the + list of tables being generated for a metadata-level + event, the checkfirst flag, and other + elements used by internal events. + + """ + + def before_drop(self, target, connection, **kw): + """Called before DROP statments are emitted. + + :param target: the :class:`.MetaData` or :class:`.Table` + object which is the target of the event. + :param connection: the :class:`.Connection` where the + DROP statement or statements will be emitted. + :param \**kw: additional keyword arguments relevant + to the event. The contents of this dictionary + may vary across releases, and include the + list of tables being generated for a metadata-level + event, the checkfirst flag, and other + elements used by internal events. + + """ + + def after_drop(self, target, connection, **kw): + """Called after DROP statments are emitted. + + :param target: the :class:`.MetaData` or :class:`.Table` + object which is the target of the event. + :param connection: the :class:`.Connection` where the + DROP statement or statements have been emitted. + :param \**kw: additional keyword arguments relevant + to the event. The contents of this dictionary + may vary across releases, and include the + list of tables being generated for a metadata-level + event, the checkfirst flag, and other + elements used by internal events. + + """ + + def before_parent_attach(self, target, parent): + """Called before a :class:`.SchemaItem` is associated with + a parent :class:`.SchemaItem`. + + :param target: the target object + :param parent: the parent to which the target is being attached. + + :func:`.event.listen` also accepts a modifier for this event: + + :param propagate=False: When True, the listener function will + be established for any copies made of the target object, + i.e. those copies that are generated when + :meth:`.Table.tometadata` is used. + + """ + + def after_parent_attach(self, target, parent): + """Called after a :class:`.SchemaItem` is associated with + a parent :class:`.SchemaItem`. + + :param target: the target object + :param parent: the parent to which the target is being attached. + + :func:`.event.listen` also accepts a modifier for this event: + + :param propagate=False: When True, the listener function will + be established for any copies made of the target object, + i.e. those copies that are generated when + :meth:`.Table.tometadata` is used. + + """ + + def column_reflect(self, inspector, table, column_info): + """Called for each unit of 'column info' retrieved when + a :class:`.Table` is being reflected. + + The dictionary of column information as returned by the + dialect is passed, and can be modified. The dictionary + is that returned in each element of the list returned + by :meth:`.reflection.Inspector.get_columns`. + + The event is called before any action is taken against + this dictionary, and the contents can be modified. + The :class:`.Column` specific arguments ``info``, ``key``, + and ``quote`` can also be added to the dictionary and + will be passed to the constructor of :class:`.Column`. + + Note that this event is only meaningful if either + associated with the :class:`.Table` class across the + board, e.g.:: + + from sqlalchemy.schema import Table + from sqlalchemy import event + + def listen_for_reflect(inspector, table, column_info): + "receive a column_reflect event" + # ... + + event.listen( + Table, + 'column_reflect', + listen_for_reflect) + + ...or with a specific :class:`.Table` instance using + the ``listeners`` argument:: + + def listen_for_reflect(inspector, table, column_info): + "receive a column_reflect event" + # ... + + t = Table( + 'sometable', + autoload=True, + listeners=[ + ('column_reflect', listen_for_reflect) + ]) + + This because the reflection process initiated by ``autoload=True`` + completes within the scope of the constructor for :class:`.Table`. + + """ + + + +class PoolEvents(event.Events): + """Available events for :class:`.Pool`. + + The methods here define the name of an event as well + as the names of members that are passed to listener + functions. + + e.g.:: + + from sqlalchemy import event + + def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): + "handle an on checkout event" + + event.listen(Pool, 'checkout', my_on_checkout) + + In addition to accepting the :class:`.Pool` class and + :class:`.Pool` instances, :class:`.PoolEvents` also accepts + :class:`.Engine` objects and the :class:`.Engine` class as + targets, which will be resolved to the ``.pool`` attribute of the + given engine or the :class:`.Pool` class:: + + engine = create_engine("postgresql://scott:tiger@localhost/test") + + # will associate with engine.pool + event.listen(engine, 'checkout', my_on_checkout) + + """ + + _target_class_doc = "SomeEngineOrPool" + _dispatch_target = Pool + + @classmethod + def _accept_with(cls, target): + if isinstance(target, type): + if issubclass(target, Engine): + return Pool + elif issubclass(target, Pool): + return target + elif isinstance(target, Engine): + return target.pool + else: + return target + + def connect(self, dbapi_connection, connection_record): + """Called at the moment a particular DBAPI connection is first + created for a given :class:`.Pool`. + + This event allows one to capture the point directly after which + the DBAPI module-level ``.connect()`` method has been used in order + to produce a new DBAPI connection. + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + """ + + def first_connect(self, dbapi_connection, connection_record): + """Called exactly once for the first time a DBAPI connection is + checked out from a particular :class:`.Pool`. + + The rationale for :meth:`.PoolEvents.first_connect` is to determine + information about a particular series of database connections based + on the settings used for all connections. Since a particular + :class:`.Pool` refers to a single "creator" function (which in terms + of a :class:`.Engine` refers to the URL and connection options used), + it is typically valid to make observations about a single connection + that can be safely assumed to be valid about all subsequent connections, + such as the database version, the server and client encoding settings, + collation settings, and many others. + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + """ + + def checkout(self, dbapi_connection, connection_record, connection_proxy): + """Called when a connection is retrieved from the Pool. + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + :param connection_proxy: the :class:`._ConnectionFairy` object which + will proxy the public interface of the DBAPI connection for the lifespan + of the checkout. + + If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + + .. seealso:: :meth:`.ConnectionEvents.engine_connect` - a similar event + which occurs upon creation of a new :class:`.Connection`. + + """ + + def checkin(self, dbapi_connection, connection_record): + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + """ + + def reset(self, dbapi_connnection, connection_record): + """Called before the "reset" action occurs for a pooled connection. + + This event represents + when the ``rollback()`` method is called on the DBAPI connection + before it is returned to the pool. The behavior of "reset" can + be controlled, including disabled, using the ``reset_on_return`` + pool argument. + + + The :meth:`.PoolEvents.reset` event is usually followed by the + the :meth:`.PoolEvents.checkin` event is called, except in those + cases where the connection is discarded immediately after reset. + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ConnectionEvents.rollback` + + :meth:`.ConnectionEvents.commit` + + """ + + def invalidate(self, dbapi_connection, connection_record, exception): + """Called when a DBAPI connection is to be "invalidated". + + This event is called any time the :meth:`._ConnectionRecord.invalidate` + method is invoked, either from API usage or via "auto-invalidation". + The event occurs before a final attempt to call ``.close()`` on the connection + occurs. + + :param dbapi_connection: a DBAPI connection. + + :param connection_record: the :class:`._ConnectionRecord` managing the + DBAPI connection. + + :param exception: the exception object corresponding to the reason + for this invalidation, if any. May be ``None``. + + .. versionadded:: 0.9.2 Added support for connection invalidation + listening. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + +class ConnectionEvents(event.Events): + """Available events for :class:`.Connectable`, which includes + :class:`.Connection` and :class:`.Engine`. + + The methods here define the name of an event as well as the names of + members that are passed to listener functions. + + An event listener can be associated with any :class:`.Connectable` + class or instance, such as an :class:`.Engine`, e.g.:: + + from sqlalchemy import event, create_engine + + def before_cursor_execute(conn, cursor, statement, parameters, context, + executemany): + log.info("Received statement: %s" % statement) + + engine = create_engine('postgresql://scott:tiger@localhost/test') + event.listen(engine, "before_cursor_execute", before_cursor_execute) + + or with a specific :class:`.Connection`:: + + with engine.begin() as conn: + @event.listens_for(conn, 'before_cursor_execute') + def before_cursor_execute(conn, cursor, statement, parameters, + context, executemany): + log.info("Received statement: %s" % statement) + + The :meth:`.before_execute` and :meth:`.before_cursor_execute` + events can also be established with the ``retval=True`` flag, which + allows modification of the statement and parameters to be sent + to the database. The :meth:`.before_cursor_execute` event is + particularly useful here to add ad-hoc string transformations, such + as comments, to all executions:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def comment_sql_calls(conn, cursor, statement, parameters, + context, executemany): + statement = statement + " -- some comment" + return statement, parameters + + .. note:: :class:`.ConnectionEvents` can be established on any + combination of :class:`.Engine`, :class:`.Connection`, as well + as instances of each of those classes. Events across all + four scopes will fire off for a given instance of + :class:`.Connection`. However, for performance reasons, the + :class:`.Connection` object determines at instantiation time + whether or not its parent :class:`.Engine` has event listeners + established. Event listeners added to the :class:`.Engine` + class or to an instance of :class:`.Engine` *after* the instantiation + of a dependent :class:`.Connection` instance will usually + *not* be available on that :class:`.Connection` instance. The newly + added listeners will instead take effect for :class:`.Connection` + instances created subsequent to those event listeners being + established on the parent :class:`.Engine` class or instance. + + :param retval=False: Applies to the :meth:`.before_execute` and + :meth:`.before_cursor_execute` events only. When True, the + user-defined event function must have a return value, which + is a tuple of parameters that replace the given statement + and parameters. See those methods for a description of + specific return arguments. + + .. versionchanged:: 0.8 :class:`.ConnectionEvents` can now be associated + with any :class:`.Connectable` including :class:`.Connection`, + in addition to the existing support for :class:`.Engine`. + + """ + + _target_class_doc = "SomeEngine" + _dispatch_target = Connectable + + + @classmethod + def _listen(cls, event_key, retval=False): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + target._has_events = True + + if not retval: + if identifier == 'before_execute': + orig_fn = fn + + def wrap_before_execute(conn, clauseelement, + multiparams, params): + orig_fn(conn, clauseelement, multiparams, params) + return clauseelement, multiparams, params + fn = wrap_before_execute + elif identifier == 'before_cursor_execute': + orig_fn = fn + + def wrap_before_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + orig_fn(conn, cursor, statement, + parameters, context, executemany) + return statement, parameters + fn = wrap_before_cursor_execute + + elif retval and \ + identifier not in ('before_execute', 'before_cursor_execute'): + raise exc.ArgumentError( + "Only the 'before_execute' and " + "'before_cursor_execute' engine " + "event listeners accept the 'retval=True' " + "argument.") + event_key.with_wrapper(fn).base_listen() + + def before_execute(self, conn, clauseelement, multiparams, params): + """Intercept high level execute() events, receiving uncompiled + SQL constructs and other objects prior to rendering into SQL. + + This event is good for debugging SQL compilation issues as well + as early manipulation of the parameters being sent to the database, + as the parameter lists will be in a consistent format here. + + This event can be optionally established with the ``retval=True`` + flag. The ``clauseelement``, ``multiparams``, and ``params`` + arguments should be returned as a three-tuple in this case:: + + @event.listens_for(Engine, "before_execute", retval=True) + def before_execute(conn, conn, clauseelement, multiparams, params): + # do something with clauseelement, multiparams, params + return clauseelement, multiparams, params + + :param conn: :class:`.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to :meth:`.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + + See also: + + :meth:`.before_cursor_execute` + + """ + + def after_execute(self, conn, clauseelement, multiparams, params, result): + """Intercept high level execute() events after execute. + + + :param conn: :class:`.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to :meth:`.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + :param result: :class:`.ResultProxy` generated by the execution. + + """ + + def before_cursor_execute(self, conn, cursor, statement, + parameters, context, executemany): + """Intercept low-level cursor execute() events before execution, + receiving the string + SQL statement and DBAPI-specific parameter list to be invoked + against a cursor. + + This event is a good choice for logging as well as late modifications + to the SQL string. It's less ideal for parameter modifications except + for those which are specific to a target backend. + + This event can be optionally established with the ``retval=True`` + flag. The ``statement`` and ``parameters`` arguments should be + returned as a two-tuple in this case:: + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def before_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + # do something with statement, parameters + return statement, parameters + + See the example at :class:`.ConnectionEvents`. + + :param conn: :class:`.Connection` object + :param cursor: DBAPI cursor object + :param statement: string SQL statement + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + See also: + + :meth:`.before_execute` + + :meth:`.after_cursor_execute` + + """ + + def after_cursor_execute(self, conn, cursor, statement, + parameters, context, executemany): + """Intercept low-level cursor execute() events after execution. + + :param conn: :class:`.Connection` object + :param cursor: DBAPI cursor object. Will have results pending + if the statement was a SELECT, but these should not be consumed + as they will be needed by the :class:`.ResultProxy`. + :param statement: string SQL statement + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + """ + + def dbapi_error(self, conn, cursor, statement, parameters, + context, exception): + """Intercept a raw DBAPI error. + + This event is called with the DBAPI exception instance + received from the DBAPI itself, *before* SQLAlchemy wraps the + exception with it's own exception wrappers, and before any + other operations are performed on the DBAPI cursor; the + existing transaction remains in effect as well as any state + on the cursor. + + The use case here is to inject low-level exception handling + into an :class:`.Engine`, typically for logging and + debugging purposes. In general, user code should **not** modify + any state or throw any exceptions here as this will + interfere with SQLAlchemy's cleanup and error handling + routines. + + Subsequent to this hook, SQLAlchemy may attempt any + number of operations on the connection/cursor, including + closing the cursor, rolling back of the transaction in the + case of connectionless execution, and disposing of the entire + connection pool if a "disconnect" was detected. The + exception is then wrapped in a SQLAlchemy DBAPI exception + wrapper and re-thrown. + + :param conn: :class:`.Connection` object + :param cursor: DBAPI cursor object + :param statement: string SQL statement + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param exception: The **unwrapped** exception emitted directly from the + DBAPI. The class here is specific to the DBAPI module in use. + + .. versionadded:: 0.7.7 + + """ + + def engine_connect(self, conn, branch): + """Intercept the creation of a new :class:`.Connection`. + + This event is called typically as the direct result of calling + the :meth:`.Engine.connect` method. + + It differs from the :meth:`.PoolEvents.connect` method, which + refers to the actual connection to a database at the DBAPI level; + a DBAPI connection may be pooled and reused for many operations. + In contrast, this event refers only to the production of a higher level + :class:`.Connection` wrapper around such a DBAPI connection. + + It also differs from the :meth:`.PoolEvents.checkout` event + in that it is specific to the :class:`.Connection` object, not the + DBAPI connection that :meth:`.PoolEvents.checkout` deals with, although + this DBAPI connection is available here via the :attr:`.Connection.connection` + attribute. But note there can in fact + be multiple :meth:`.PoolEvents.checkout` events within the lifespan + of a single :class:`.Connection` object, if that :class:`.Connection` + is invalidated and re-established. There can also be multiple + :class:`.Connection` objects generated for the same already-checked-out + DBAPI connection, in the case that a "branch" of a :class:`.Connection` + is produced. + + :param conn: :class:`.Connection` object. + :param branch: if True, this is a "branch" of an existing + :class:`.Connection`. A branch is generated within the course + of a statement execution to invoke supplemental statements, most + typically to pre-execute a SELECT of a default value for the purposes + of an INSERT statement. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.PoolEvents.checkout` the lower-level pool checkout event + for an individual DBAPI connection + + :meth:`.ConnectionEvents.set_connection_execution_options` - a copy of a + :class:`.Connection` is also made when the + :meth:`.Connection.execution_options` method is called. + + """ + + def set_connection_execution_options(self, conn, opts): + """Intercept when the :meth:`.Connection.execution_options` + method is called. + + This method is called after the new :class:`.Connection` has been + produced, with the newly updated execution options collection, but + before the :class:`.Dialect` has acted upon any of those new options. + + Note that this method is not called when a new :class:`.Connection` + is produced which is inheriting execution options from its parent + :class:`.Engine`; to intercept this condition, use the + :meth:`.ConnectionEvents.engine_connect` event. + + :param conn: The newly copied :class:`.Connection` object + + :param opts: dictionary of options that were passed to the + :meth:`.Connection.execution_options` method. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.ConnectionEvents.set_engine_execution_options` - event + which is called when :meth:`.Engine.execution_options` is called. + + + """ + + def set_engine_execution_options(self, engine, opts): + """Intercept when the :meth:`.Engine.execution_options` + method is called. + + The :meth:`.Engine.execution_options` method produces a shallow + copy of the :class:`.Engine` which stores the new options. That new + :class:`.Engine` is passed here. A particular application of this + method is to add a :meth:`.ConnectionEvents.engine_connect` event + handler to the given :class:`.Engine` which will perform some per- + :class:`.Connection` task specific to these execution options. + + :param conn: The newly copied :class:`.Engine` object + + :param opts: dictionary of options that were passed to the + :meth:`.Connection.execution_options` method. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.ConnectionEvents.set_connection_execution_options` - event + which is called when :meth:`.Connection.execution_options` is called. + + """ + + def begin(self, conn): + """Intercept begin() events. + + :param conn: :class:`.Connection` object + + """ + + def rollback(self, conn): + """Intercept rollback() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`.Pool` also "auto-rolls back" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to its default value of ``'rollback'``. + To intercept this + rollback, use the :meth:`.PoolEvents.reset` hook. + + :param conn: :class:`.Connection` object + + .. seealso:: + + :meth:`.PoolEvents.reset` + + """ + + def commit(self, conn): + """Intercept commit() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`.Pool` may also "auto-commit" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to the value ``'commit'``. To intercept this + commit, use the :meth:`.PoolEvents.reset` hook. + + :param conn: :class:`.Connection` object + """ + + def savepoint(self, conn, name): + """Intercept savepoint() events. + + :param conn: :class:`.Connection` object + :param name: specified name used for the savepoint. + + """ + + def rollback_savepoint(self, conn, name, context): + """Intercept rollback_savepoint() events. + + :param conn: :class:`.Connection` object + :param name: specified name used for the savepoint. + :param context: :class:`.ExecutionContext` in use. May be ``None``. + + """ + + def release_savepoint(self, conn, name, context): + """Intercept release_savepoint() events. + + :param conn: :class:`.Connection` object + :param name: specified name used for the savepoint. + :param context: :class:`.ExecutionContext` in use. May be ``None``. + + """ + + def begin_twophase(self, conn, xid): + """Intercept begin_twophase() events. + + :param conn: :class:`.Connection` object + :param xid: two-phase XID identifier + + """ + + def prepare_twophase(self, conn, xid): + """Intercept prepare_twophase() events. + + :param conn: :class:`.Connection` object + :param xid: two-phase XID identifier + """ + + def rollback_twophase(self, conn, xid, is_prepared): + """Intercept rollback_twophase() events. + + :param conn: :class:`.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + def commit_twophase(self, conn, xid, is_prepared): + """Intercept commit_twophase() events. + + :param conn: :class:`.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + +class DialectEvents(event.Events): + """event interface for execution-replacement functions. + + These events allow direct instrumentation and replacement + of key dialect functions which interact with the DBAPI. + + .. note:: + + :class:`.DialectEvents` hooks should be considered **semi-public** + and experimental. + These hooks are not for general use and are only for those situations where + intricate re-statement of DBAPI mechanics must be injected onto an existing + dialect. For general-use statement-interception events, please + use the :class:`.ConnectionEvents` interface. + + .. seealso:: + + :meth:`.ConnectionEvents.before_cursor_execute` + + :meth:`.ConnectionEvents.before_execute` + + :meth:`.ConnectionEvents.after_cursor_execute` + + :meth:`.ConnectionEvents.after_execute` + + + .. versionadded:: 0.9.4 + + """ + + _target_class_doc = "SomeEngine" + _dispatch_target = Dialect + + @classmethod + def _listen(cls, event_key, retval=False): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + target._has_events = True + event_key.base_listen() + + @classmethod + def _accept_with(cls, target): + if isinstance(target, type): + if issubclass(target, Engine): + return Dialect + elif issubclass(target, Dialect): + return target + elif isinstance(target, Engine): + return target.dialect + else: + return target + + def do_executemany(self, cursor, statement, parameters, context): + """Receive a cursor to have executemany() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute_no_params(self, cursor, statement, context): + """Receive a cursor to have execute() with no parameters called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute(self, cursor, statement, parameters, context): + """Receive a cursor to have execute() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py new file mode 100644 index 00000000..68e517e2 --- /dev/null +++ b/lib/sqlalchemy/exc.py @@ -0,0 +1,363 @@ +# sqlalchemy/exc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Exceptions used with SQLAlchemy. + +The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are +raised as a result of DBAPI exceptions are all subclasses of +:exc:`.DBAPIError`. + +""" + +import traceback + + +class SQLAlchemyError(Exception): + """Generic error class.""" + + +class ArgumentError(SQLAlchemyError): + """Raised when an invalid or conflicting function argument is supplied. + + This error generally corresponds to construction time state errors. + + """ + +class NoSuchModuleError(ArgumentError): + """Raised when a dynamically-loaded module (usually a database dialect) + of a particular name cannot be located.""" + +class NoForeignKeysError(ArgumentError): + """Raised when no foreign keys can be located between two selectables + during a join.""" + + +class AmbiguousForeignKeysError(ArgumentError): + """Raised when more than one foreign key matching can be located + between two selectables during a join.""" + + +class CircularDependencyError(SQLAlchemyError): + """Raised by topological sorts when a circular dependency is detected. + + There are two scenarios where this error occurs: + + * In a Session flush operation, if two objects are mutually dependent + on each other, they can not be inserted or deleted via INSERT or + DELETE statements alone; an UPDATE will be needed to post-associate + or pre-deassociate one of the foreign key constrained values. + The ``post_update`` flag described at :ref:`post_update` can resolve + this cycle. + * In a :meth:`.MetaData.create_all`, :meth:`.MetaData.drop_all`, + :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey` + or :class:`.ForeignKeyConstraint` objects mutually refer to each + other. Apply the ``use_alter=True`` flag to one or both, + see :ref:`use_alter`. + + """ + def __init__(self, message, cycles, edges, msg=None): + if msg is None: + message += " Cycles: %r all edges: %r" % (cycles, edges) + else: + message = msg + SQLAlchemyError.__init__(self, message) + self.cycles = cycles + self.edges = edges + + def __reduce__(self): + return self.__class__, (None, self.cycles, + self.edges, self.args[0]) + + +class CompileError(SQLAlchemyError): + """Raised when an error occurs during SQL compilation""" + +class UnsupportedCompilationError(CompileError): + """Raised when an operation is not supported by the given compiler. + + + .. versionadded:: 0.8.3 + + """ + + def __init__(self, compiler, element_type): + super(UnsupportedCompilationError, self).__init__( + "Compiler %r can't render element of type %s" % + (compiler, element_type)) + +class IdentifierError(SQLAlchemyError): + """Raised when a schema name is beyond the max character limit""" + + +class DisconnectionError(SQLAlchemyError): + """A disconnect is detected on a raw DB-API connection. + + This error is raised and consumed internally by a connection pool. It can + be raised by the :meth:`.PoolEvents.checkout` event so that the host pool + forces a retry; the exception will be caught three times in a row before + the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError` + regarding the connection attempt. + + """ + + +class TimeoutError(SQLAlchemyError): + """Raised when a connection pool times out on getting a connection.""" + + +class InvalidRequestError(SQLAlchemyError): + """SQLAlchemy was asked to do something it can't do. + + This error generally corresponds to runtime state errors. + + """ + + +class NoInspectionAvailable(InvalidRequestError): + """A subject passed to :func:`sqlalchemy.inspection.inspect` produced + no context for inspection.""" + + +class ResourceClosedError(InvalidRequestError): + """An operation was requested from a connection, cursor, or other + object that's in a closed state.""" + + +class NoSuchColumnError(KeyError, InvalidRequestError): + """A nonexistent column is requested from a ``RowProxy``.""" + + +class NoReferenceError(InvalidRequestError): + """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + + +class NoReferencedTableError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be + located. + + """ + def __init__(self, message, tname): + NoReferenceError.__init__(self, message) + self.table_name = tname + + def __reduce__(self): + return self.__class__, (self.args[0], self.table_name) + + +class NoReferencedColumnError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Column`` cannot be + located. + + """ + def __init__(self, message, tname, cname): + NoReferenceError.__init__(self, message) + self.table_name = tname + self.column_name = cname + + def __reduce__(self): + return self.__class__, (self.args[0], self.table_name, + self.column_name) + + +class NoSuchTableError(InvalidRequestError): + """Table does not exist or is not visible to a connection.""" + + +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" + + +class DontWrapMixin(object): + """A mixin class which, when applied to a user-defined Exception class, + will not be wrapped inside of :exc:`.StatementError` if the error is + emitted within the process of executing a statement. + + E.g.:: + + from sqlalchemy.exc import DontWrapMixin + + class MyCustomException(Exception, DontWrapMixin): + pass + + class MySpecialType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value == 'invalid': + raise MyCustomException("invalid!") + + """ + +# Moved to orm.exc; compatibility definition installed by orm import until 0.6 +UnmappedColumnError = None + + +class StatementError(SQLAlchemyError): + """An error occurred during execution of a SQL statement. + + :class:`StatementError` wraps the exception raised + during execution, and features :attr:`.statement` + and :attr:`.params` attributes which supply context regarding + the specifics of the statement which had an issue. + + The wrapped exception object is available in + the :attr:`.orig` attribute. + + """ + + statement = None + """The string SQL statement being invoked when this exception occurred.""" + + params = None + """The parameter list being used when this exception occurred.""" + + orig = None + """The DBAPI exception object.""" + + def __init__(self, message, statement, params, orig): + SQLAlchemyError.__init__(self, message) + self.statement = statement + self.params = params + self.orig = orig + self.detail = [] + + def add_detail(self, msg): + self.detail.append(msg) + + def __reduce__(self): + return self.__class__, (self.args[0], self.statement, + self.params, self.orig) + + def __str__(self): + from sqlalchemy.sql import util + params_repr = util._repr_params(self.params, 10) + + return ' '.join([ + "(%s)" % det for det in self.detail + ] + [ + SQLAlchemyError.__str__(self), + repr(self.statement), repr(params_repr) + ]) + + def __unicode__(self): + return self.__str__() + + +class DBAPIError(StatementError): + """Raised when the execution of a database operation fails. + + Wraps exceptions raised by the DB-API underlying the + database operation. Driver-specific implementations of the standard + DB-API exception types are wrapped by matching sub-types of SQLAlchemy's + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note + that there is no guarantee that different DB-API implementations will + raise the same exception type for any given error condition. + + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context + regarding the specifics of the statement which had an issue, for the + typical case when the error was raised within the context of + emitting a SQL statement. + + The wrapped exception object is available in the + :attr:`~.StatementError.orig` attribute. Its type and properties are + DB-API implementation specific. + + """ + + @classmethod + def instance(cls, statement, params, + orig, + dbapi_base_err, + connection_invalidated=False): + # Don't ever wrap these, just return them directly as if + # DBAPIError didn't exist. + if isinstance(orig, (KeyboardInterrupt, SystemExit, DontWrapMixin)): + return orig + + if orig is not None: + # not a DBAPI error, statement is present. + # raise a StatementError + if not isinstance(orig, dbapi_base_err) and statement: + msg = traceback.format_exception_only( + orig.__class__, orig)[-1].strip() + return StatementError( + "%s (original cause: %s)" % (str(orig), msg), + statement, params, orig + ) + + name, glob = orig.__class__.__name__, globals() + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + + return cls(statement, params, orig, connection_invalidated) + + def __reduce__(self): + return self.__class__, (self.statement, self.params, + self.orig, self.connection_invalidated) + + def __init__(self, statement, params, orig, connection_invalidated=False): + try: + text = str(orig) + except (KeyboardInterrupt, SystemExit): + raise + except Exception as e: + text = 'Error in str() of DB-API-generated exception: ' + str(e) + StatementError.__init__( + self, + '(%s) %s' % (orig.__class__.__name__, text), + statement, + params, + orig + ) + self.connection_invalidated = connection_invalidated + + +class InterfaceError(DBAPIError): + """Wraps a DB-API InterfaceError.""" + + +class DatabaseError(DBAPIError): + """Wraps a DB-API DatabaseError.""" + + +class DataError(DatabaseError): + """Wraps a DB-API DataError.""" + + +class OperationalError(DatabaseError): + """Wraps a DB-API OperationalError.""" + + +class IntegrityError(DatabaseError): + """Wraps a DB-API IntegrityError.""" + + +class InternalError(DatabaseError): + """Wraps a DB-API InternalError.""" + + +class ProgrammingError(DatabaseError): + """Wraps a DB-API ProgrammingError.""" + + +class NotSupportedError(DatabaseError): + """Wraps a DB-API NotSupportedError.""" + + +# Warnings + +class SADeprecationWarning(DeprecationWarning): + """Issued once per usage of a deprecated API.""" + + +class SAPendingDeprecationWarning(PendingDeprecationWarning): + """Issued once per usage of a deprecated API.""" + + +class SAWarning(RuntimeWarning): + """Issued at runtime.""" diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py new file mode 100644 index 00000000..1d77acaa --- /dev/null +++ b/lib/sqlalchemy/ext/__init__.py @@ -0,0 +1,5 @@ +# ext/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py new file mode 100644 index 00000000..045645f8 --- /dev/null +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,1053 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. + +""" +import itertools +import operator +import weakref +from .. import exc, orm, util +from ..orm import collections, interfaces +from ..sql import not_, or_ + + +def association_proxy(target_collection, attr, **kw): + """Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. + + The returned value is an instance of :class:`.AssociationProxy`. + + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. + + :param target_collection: Name of the attribute we'll proxy to. + This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. + + :param attr: Attribute on the associated instance or instances we'll + proxy for. + + For example, given a target collection of [obj1, obj2], a list created + by this proxy property would look like [getattr(obj1, *attr*), + getattr(obj2, *attr*)] + + If the relationship is one-to-one or otherwise uselist=False, then + simply: getattr(obj, *attr*) + + :param creator: optional. + + When new items are added to this proxied collection, new instances of + the class collected by the target collection will be created. For list + and set collections, the target class constructor will be called with + the 'value' for the new instance. For dict types, two arguments are + passed: key and value. + + If you want to construct instances differently, supply a *creator* + function that takes arguments as above and returns instances. + + For scalar relationships, creator() will be called if the target is None. + If the target is present, set operations are proxied to setattr() on the + associated object. + + If you have an associated object with multiple attributes, you may set + up multiple association proxies mapping to different attributes. See + the unit tests for examples, and for examples of how creator() functions + can be used to construct the scalar relationship on-demand in this + situation. + + :param \*\*kw: Passes along any other keyword arguments to + :class:`.AssociationProxy`. + + """ + return AssociationProxy(target_collection, attr, **kw) + + +ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') +"""Symbol indicating an :class:`_InspectionAttr` that's + of type :class:`.AssociationProxy`. + + Is assigned to the :attr:`._InspectionAttr.extension_type` + attibute. + +""" + +class AssociationProxy(interfaces._InspectionAttr): + """A descriptor that presents a read/write view of an object attribute.""" + + is_attribute = False + extension_type = ASSOCIATION_PROXY + + + def __init__(self, target_collection, attr, creator=None, + getset_factory=None, proxy_factory=None, + proxy_bulk_set=None): + """Construct a new :class:`.AssociationProxy`. + + The :func:`.association_proxy` function is provided as the usual + entrypoint here, though :class:`.AssociationProxy` can be instantiated + and/or subclassed directly. + + :param target_collection: Name of the collection we'll proxy to, + usually created with :func:`.relationship`. + + :param attr: Attribute on the collected instances we'll proxy + for. For example, given a target collection of [obj1, obj2], a + list created by this proxy property would look like + [getattr(obj1, attr), getattr(obj2, attr)] + + :param creator: Optional. When new items are added to this proxied + collection, new instances of the class collected by the target + collection will be created. For list and set collections, the + target class constructor will be called with the 'value' for the + new instance. For dict types, two arguments are passed: + key and value. + + If you want to construct instances differently, supply a 'creator' + function that takes arguments as above and returns instances. + + :param getset_factory: Optional. Proxied attribute access is + automatically handled by routines that get and set values based on + the `attr` argument for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. + + :param proxy_bulk_set: Optional, use with proxy_factory. See + the _set() method for details. + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + + self.owning_class = None + self.key = '_%s_%s_%s' % ( + type(self).__name__, target_collection, id(self)) + self.collection_class = None + + @property + def remote_attr(self): + """The 'remote' :class:`.MapperProperty` referenced by this + :class:`.AssociationProxy`. + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.attr` + + :attr:`.AssociationProxy.local_attr` + + """ + return getattr(self.target_class, self.value_attr) + + @property + def local_attr(self): + """The 'local' :class:`.MapperProperty` referenced by this + :class:`.AssociationProxy`. + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.attr` + + :attr:`.AssociationProxy.remote_attr` + + """ + return getattr(self.owning_class, self.target_collection) + + @property + def attr(self): + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute is convenient when specifying a join + using :meth:`.Query.join` across two relationships:: + + sess.query(Parent).join(*Parent.proxied.attr) + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.local_attr` + + :attr:`.AssociationProxy.remote_attr` + + """ + return (self.local_attr, self.remote_attr) + + def _get_property(self): + return (orm.class_mapper(self.owning_class). + get_property(self.target_collection)) + + @util.memoized_property + def target_class(self): + """The intermediary class handled by this :class:`.AssociationProxy`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + return self._get_property().mapper.class_ + + @util.memoized_property + def scalar(self): + """Return ``True`` if this :class:`.AssociationProxy` proxies a scalar + relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self): + return not self._get_property().\ + mapper.get_property(self.value_attr).uselist + + @util.memoized_property + def _target_is_object(self): + return getattr(self.target_class, self.value_attr).impl.uses_objects + + def __get__(self, obj, class_): + if self.owning_class is None: + self.owning_class = class_ and class_ or type(obj) + if obj is None: + return self + + if self.scalar: + target = getattr(obj, self.target_collection) + return self._scalar_get(target) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. + creator_id, proxy = getattr(obj, self.key) + if id(obj) == creator_id: + return proxy + except AttributeError: + pass + proxy = self._new(_lazy_collection(obj, self.target_collection)) + setattr(obj, self.key, (id(obj), proxy)) + return proxy + + def __set__(self, obj, values): + if self.owning_class is None: + self.owning_class = type(obj) + + if self.scalar: + creator = self.creator and self.creator or self.target_class + target = getattr(obj, self.target_collection) + if target is None: + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + else: + proxy = self.__get__(obj, None) + if proxy is not values: + proxy.clear() + self._set(proxy, values) + + def __delete__(self, obj): + if self.owning_class is None: + self.owning_class = type(obj) + delattr(obj, self.key) + + def _initialize_scalar_accessors(self): + if self.getset_factory: + get, set = self.getset_factory(None, self) + else: + get, set = self._default_getset(None) + self._scalar_get, self._scalar_set = get, set + + def _default_getset(self, collection_class): + attr = self.value_attr + _getter = operator.attrgetter(attr) + getter = lambda target: _getter(target) if target is not None else None + if collection_class is dict: + setter = lambda o, k, v: setattr(o, attr, v) + else: + setter = lambda o, v: setattr(o, attr, v) + return getter, setter + + def _new(self, lazy_collection): + creator = self.creator and self.creator or self.target_class + self.collection_class = util.duck_type_collection(lazy_collection()) + + if self.proxy_factory: + return self.proxy_factory( + lazy_collection, creator, self.value_attr, self) + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + if self.collection_class is list: + return _AssociationList( + lazy_collection, creator, getter, setter, self) + elif self.collection_class is dict: + return _AssociationDict( + lazy_collection, creator, getter, setter, self) + elif self.collection_class is set: + return _AssociationSet( + lazy_collection, creator, getter, setter, self) + else: + raise exc.ArgumentError( + 'could not guess which interface to use for ' + 'collection_class "%s" backing "%s"; specify a ' + 'proxy_factory and proxy_bulk_set manually' % + (self.collection_class.__name__, self.target_collection)) + + def _inflate(self, proxy): + creator = self.creator and self.creator or self.target_class + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _set(self, proxy, values): + if self.proxy_bulk_set: + self.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + proxy.extend(values) + elif self.collection_class is dict: + proxy.update(values) + elif self.collection_class is set: + proxy.update(values) + else: + raise exc.ArgumentError( + 'no proxy_bulk_set supplied for custom ' + 'collection_class implementation') + + @property + def _comparator(self): + return self._get_property().comparator + + def any(self, criterion=None, **kwargs): + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + + if self._value_is_scalar: + value_expr = getattr( + self.target_class, self.value_attr).has(criterion, **kwargs) + else: + value_expr = getattr( + self.target_class, self.value_attr).any(criterion, **kwargs) + + # check _value_is_scalar here, otherwise + # we're scalar->scalar - call .any() so that + # the "can't call any() on a scalar" msg is raised. + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + value_expr + ) + else: + return self._comparator.any( + value_expr + ) + + def has(self, criterion=None, **kwargs): + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + + if self._target_is_object: + return self._comparator.has( + getattr(self.target_class, self.value_attr).\ + has(criterion, **kwargs) + ) + else: + if criterion is not None or kwargs: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==") + return self._comparator.has() + + def contains(self, obj): + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + , :meth:`.RelationshipProperty.Comparator.has`, + and/or :meth:`.RelationshipProperty.Comparator.contains` + operators of the underlying proxied attributes. + """ + + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(obj) + ) + else: + return self._comparator.any(**{self.value_attr: obj}) + + def __eq__(self, obj): + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None + ) + else: + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj): + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj) + + +class _lazy_collection(object): + def __init__(self, obj, target): + self.ref = weakref.ref(obj) + self.target = target + + def __call__(self): + obj = self.ref() + if obj is None: + raise exc.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") + return getattr(obj, self.target) + + def __getstate__(self): + return {'obj': self.ref(), 'target': self.target} + + def __setstate__(self, state): + self.ref = weakref.ref(state['obj']) + self.target = state['target'] + + +class _AssociationCollection(object): + def __init__(self, lazy_collection, creator, getter, setter, parent): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + lazy_collection + A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship()) + + creator + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + + getter + A function. Given an associated object, return the 'value'. + + setter + A function. Given an associated object and a value, store that + value on the object. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + col = property(lambda self: self.lazy_collection()) + + def __len__(self): + return len(self.col) + + def __bool__(self): + return bool(self.col) + + __nonzero__ = __bool__ + + def __getstate__(self): + return {'parent': self.parent, 'lazy_collection': self.lazy_collection} + + def __setstate__(self, state): + self.parent = state['parent'] + self.lazy_collection = state['lazy_collection'] + self.parent._inflate(self) + + +class _AssociationList(_AssociationCollection): + """Generic, converting, list-to-list proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) + + def __getitem__(self, index): + return self._get(self.col[index]) + + def __setitem__(self, index, value): + if not isinstance(index, slice): + self._set(self.col[index], value) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) + if step == 1: + for i in rng: + del self[start] + i = start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + def __delitem__(self, index): + del self.col[index] + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __getslice__(self, start, end): + return [self._get(member) for member in self.col[start:end]] + + def __setslice__(self, start, end, values): + members = [self._create(v) for v in values] + self.col[start:end] = members + + def __delslice__(self, start, end): + del self.col[start:end] + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + raise StopIteration + + def append(self, value): + item = self._create(value) + self.col.append(item) + + def count(self, value): + return sum([1 for _ in + util.itertools_filter(lambda v: v == value, iter(self))]) + + def extend(self, values): + for v in values: + self.append(v) + + def insert(self, index, value): + self.col[index:index] = [self._create(value)] + + def pop(self, index=-1): + return self.getter(self.col.pop(index)) + + def remove(self, value): + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self): + """Not supported, use reversed(mylist)""" + + raise NotImplementedError + + def sort(self): + """Not supported, use sorted(mylist)""" + + raise NotImplementedError + + def clear(self): + del self.col[0:len(self.col)] + + def __eq__(self, other): + return list(self) == other + + def __ne__(self, other): + return list(self) != other + + def __lt__(self, other): + return list(self) < other + + def __le__(self, other): + return list(self) <= other + + def __gt__(self, other): + return list(self) > other + + def __ge__(self, other): + return list(self) >= other + + def __cmp__(self, other): + return cmp(list(self), other) + + def __add__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n): + if not isinstance(n, int): + return NotImplemented + return list(self) * n + __rmul__ = __mul__ + + def __iadd__(self, iterable): + self.extend(iterable) + return self + + def __imul__(self, n): + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + return NotImplemented + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + def copy(self): + return list(self) + + def __repr__(self): + return repr(list(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(list, func_name)): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +_NotProvided = util.symbol('_NotProvided') + + +class _AssociationDict(_AssociationCollection): + """Generic, converting, dict-to-dict proxy.""" + + def _create(self, key, value): + return self.creator(key, value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, key, value): + return self.setter(object, key, value) + + def __getitem__(self, key): + return self._get(self.col[key]) + + def __setitem__(self, key, value): + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key): + del self.col[key] + + def __contains__(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def has_key(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def __iter__(self): + return iter(self.col.keys()) + + def clear(self): + self.col.clear() + + def __eq__(self, other): + return dict(self) == other + + def __ne__(self, other): + return dict(self) != other + + def __lt__(self, other): + return dict(self) < other + + def __le__(self, other): + return dict(self) <= other + + def __gt__(self, other): + return dict(self) > other + + def __ge__(self, other): + return dict(self) >= other + + def __cmp__(self, other): + return cmp(dict(self), other) + + def __repr__(self): + return repr(dict(self.items())) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + if key not in self.col: + self.col[key] = self._create(key, default) + return default + else: + return self[key] + + def keys(self): + return self.col.keys() + + if util.py2k: + def iteritems(self): + return ((key, self._get(self.col[key])) for key in self.col) + + def itervalues(self): + return (self._get(self.col[key]) for key in self.col) + + def iterkeys(self): + return self.col.iterkeys() + + def values(self): + return [self._get(member) for member in self.col.values()] + + def items(self): + return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): + return ((key, self._get(self.col[key])) for key in self.col) + + def values(self): + return (self._get(self.col[key]) for key in self.col) + + def pop(self, key, default=_NotProvided): + if default is _NotProvided: + member = self.col.pop(key) + else: + member = self.col.pop(key, default) + return self._get(member) + + def popitem(self): + item = self.col.popitem() + return (item[0], self._get(item[1])) + + def update(self, *a, **kw): + if len(a) > 1: + raise TypeError('update expected at most 1 arguments, got %i' % + len(a)) + elif len(a) == 1: + seq_or_map = a[0] + # discern dict from sequence - took the advice from + # http://www.voidspace.org.uk/python/articles/duck_typing.shtml + # still not perfect :( + if hasattr(seq_or_map, 'keys'): + for item in seq_or_map: + self[item] = seq_or_map[item] + else: + try: + for k, v in seq_or_map: + self[k] = v + except ValueError: + raise ValueError( + "dictionary update sequence " + "requires 2-element tuples") + + for key, value in kw: + self[key] = value + + def copy(self): + return dict(self.items()) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(dict, func_name)): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationCollection): + """Generic, converting, set-to-set proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object): + return self.getter(object) + + def _set(self, object, value): + return self.setter(object, value) + + def __len__(self): + return len(self.col) + + def __bool__(self): + if self.col: + return True + else: + return False + + __nonzero__ = __bool__ + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + raise StopIteration + + def add(self, value): + if value not in self: + self.col.add(self._create(value)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + break + + def remove(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + return + raise KeyError(value) + + def pop(self): + if not self.col: + raise KeyError('pop from an empty set') + member = self.col.pop() + return self._get(member) + + def update(self, other): + for value in other: + self.add(value) + + def __ior__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.add(value) + return self + + def _set(self): + return set(iter(self)) + + def union(self, other): + return set(self).union(other) + + __or__ = union + + def difference(self, other): + return set(self).difference(other) + + __sub__ = difference + + def difference_update(self, other): + for value in other: + self.discard(value) + + def __isub__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.discard(value) + return self + + def intersection(self, other): + return set(self).intersection(other) + + __and__ = intersection + + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, other): + return set(self).symmetric_difference(other) + + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def issubset(self, other): + return set(self).issubset(other) + + def issuperset(self, other): + return set(self).issuperset(other) + + def clear(self): + self.col.clear() + + def copy(self): + return set(self) + + def __eq__(self, other): + return set(self) == other + + def __ne__(self, other): + return set(self) != other + + def __lt__(self, other): + return set(self) < other + + def __le__(self, other): + return set(self) <= other + + def __gt__(self, other): + return set(self) > other + + def __ge__(self, other): + return set(self) >= other + + def __repr__(self): + return repr(set(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(set, func_name)): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py new file mode 100644 index 00000000..dfc838da --- /dev/null +++ b/lib/sqlalchemy/ext/automap.py @@ -0,0 +1,907 @@ +# ext/automap.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system +which automatically generates mapped classes and relationships from a database +schema, typically though not necessarily one which is reflected. + +.. versionadded:: 0.9.1 Added :mod:`sqlalchemy.ext.automap`. + +.. note:: + + The :mod:`sqlalchemy.ext.automap` extension should be considered + **experimental** as of 0.9.1. Featureset and API stability is + not guaranteed at this time. + +It is hoped that the :class:`.AutomapBase` system provides a quick +and modernized solution to the problem that the very famous +`SQLSoup `_ +also tries to solve, that of generating a quick and rudimentary object +model from an existing database on the fly. By addressing the issue strictly +at the mapper configuration level, and integrating fully with existing +Declarative class techniques, :class:`.AutomapBase` seeks to provide +a well-integrated approach to the issue of expediently auto-generating ad-hoc +mappings. + + +Basic Use +========= + +The simplest usage is to reflect an existing database into a new model. +We create a new :class:`.AutomapBase` class in a similar manner as to how +we create a declarative base class, using :func:`.automap_base`. +We then call :meth:`.AutomapBase.prepare` on the resulting base class, +asking it to reflect the schema and produce mappings:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy.orm import Session + from sqlalchemy import create_engine + + Base = automap_base() + + # engine, suppose it has two tables 'user' and 'address' set up + engine = create_engine("sqlite:///mydatabase.db") + + # reflect the tables + Base.prepare(engine, reflect=True) + + # mapped classes are now created with names by default + # matching that of the table name. + User = Base.classes.user + Address = Base.classes.address + + session = Session(engine) + + # rudimentary relationships are produced + session.add(Address(email_address="foo@bar.com", user=User(name="foo"))) + session.commit() + + # collection-based relationships are by default named "_collection" + print (u1.address_collection) + +Above, calling :meth:`.AutomapBase.prepare` while passing along the +:paramref:`.AutomapBase.prepare.reflect` parameter indicates that the +:meth:`.MetaData.reflect` method will be called on this declarative base +classes' :class:`.MetaData` collection; then, each viable +:class:`.Table` within the :class:`.MetaData` will get a new mapped class +generated automatically. The :class:`.ForeignKeyConstraint` objects which +link the various tables together will be used to produce new, bidirectional +:func:`.relationship` objects between classes. The classes and relationships +follow along a default naming scheme that we can customize. At this point, +our basic mapping consisting of related ``User`` and ``Address`` classes is ready +to use in the traditional way. + +Generating Mappings from an Existing MetaData +============================================= + +We can pass a pre-declared :class:`.MetaData` object to :func:`.automap_base`. +This object can be constructed in any way, including programmatically, from +a serialized file, or from itself being reflected using :meth:`.MetaData.reflect`. +Below we illustrate a combination of reflection and explicit table declaration:: + + from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey + engine = create_engine("sqlite:///mydatabase.db") + + # produce our own MetaData object + metadata = MetaData() + + # we can reflect it ourselves from a database, using options + # such as 'only' to limit what tables we look at... + metadata.reflect(engine, only=['user', 'address']) + + # ... or just define our own Table objects with it (or combine both) + Table('user_order', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('user.id')) + ) + + # we can then produce a set of mappings from this MetaData. + Base = automap_base(metadata=metadata) + + # calling prepare() just sets up mapped classes and relationships. + Base.prepare() + + # mapped classes are ready + User, Address, Order = Base.classes.user, Base.classes.address, Base.classes.user_order + +Specifying Classes Explcitly +============================ + +The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined +explicitly, in a way similar to that of the :class:`.DeferredReflection` class. +Classes that extend from :class:`.AutomapBase` act like regular declarative +classes, but are not immediately mapped after their construction, and are instead +mapped when we call :meth:`.AutomapBase.prepare`. The :meth:`.AutomapBase.prepare` +method will make use of the classes we've established based on the table name +we use. If our schema contains tables ``user`` and ``address``, we can define +one or both of the classes to be used:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + # pre-declare User for the 'user' table + class User(Base): + __tablename__ = 'user' + + # override schema elements like Columns + user_name = Column('name', String) + + # override relationships too, if desired. + # we must use the same name that automap would use for the relationship, + # and also must refer to the class name that automap will generate + # for "address" + address_collection = relationship("address", collection_class=set) + + # reflect + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(engine, reflect=True) + + # we still have Address generated from the tablename "address", + # but User is the same as Base.classes.User now + + Address = Base.classes.address + + u1 = session.query(User).first() + print (u1.address_collection) + + # the backref is still there: + a1 = session.query(Address).first() + print (a1.user) + +Above, one of the more intricate details is that we illustrated overriding +one of the :func:`.relationship` objects that automap would have created. +To do this, we needed to make sure the names match up with what automap +would normally generate, in that the relationship name would be ``User.address_collection`` +and the name of the class referred to, from automap's perspective, is called +``address``, even though we are referring to it as ``Address`` within our usage +of this class. + +Overriding Naming Schemes +========================= + +:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and +relationship names based on a schema, which means it has decision points in how +these names are determined. These three decision points are provided using +functions which can be passed to the :meth:`.AutomapBase.prepare` method, and +are known as :func:`.classname_for_table`, +:func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship`. Any or all of these +functions are provided as in the example below, where we use a "camel case" +scheme for class names and a "pluralizer" for collection names using the +`Inflect `_ package:: + + import re + import inflect + + def camelize_classname(base, tablename, table): + "Produce a 'camelized' class name, e.g. " + "'words_and_underscores' -> 'WordsAndUnderscores'" + + return str(tablename[0].upper() + \\ + re.sub(r'_(\w)', lambda m: m.group(1).upper(), tablename[1:])) + + _pluralizer = inflect.engine() + def pluralize_collection(base, local_cls, referred_cls, constraint): + "Produce an 'uncamelized', 'pluralized' class name, e.g. " + "'SomeTerm' -> 'some_terms'" + + referred_name = referred_cls.__name__ + uncamelized = referred_name[0].lower() + \\ + re.sub(r'\W', + lambda m: "_%s" % m.group(0).lower(), + referred_name[1:]) + pluralized = _pluralizer.plural(uncamelized) + return pluralized + + from sqlalchemy.ext.automap import automap_base + + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + + Base.prepare(engine, reflect=True, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection + ) + +From the above mapping, we would now have classes ``User`` and ``Address``, +where the collection from ``User`` to ``Address`` is called ``User.addresses``:: + + User, Address = Base.classes.User, Base.classes.Address + + u1 = User(addresses=[Address(email="foo@bar.com")]) + +Relationship Detection +====================== + +The vast majority of what automap accomplishes is the generation of +:func:`.relationship` structures based on foreign keys. The mechanism +by which this works for many-to-one and one-to-many relationships is as follows: + +1. A given :class:`.Table`, known to be mapped to a particular class, + is examined for :class:`.ForeignKeyConstraint` objects. + +2. From each :class:`.ForeignKeyConstraint`, the remote :class:`.Table` + object present is matched up to the class to which it is to be mapped, + if any, else it is skipped. + +3. As the :class:`.ForeignKeyConstraint` we are examining correponds to a reference + from the immediate mapped class, + the relationship will be set up as a many-to-one referring to the referred class; + a corresponding one-to-many backref will be created on the referred class referring + to this class. + +4. The names of the relationships are determined using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + callable functions. It is important to note that the default relationship + naming derives the name from the **the actual class name**. If you've + given a particular class an explicit name by declaring it, or specified an + alternate class naming scheme, that's the name from which the relationship + name will be derived. + +5. The classes are inspected for an existing mapped property matching these + names. If one is detected on one side, but none on the other side, :class:`.AutomapBase` + attempts to create a relationship on the missing side, then uses the + :paramref:`.relationship.back_populates` parameter in order to point + the new relationship to the other side. + +6. In the usual case where no relationship is on either side, + :meth:`.AutomapBase.prepare` produces a :func:`.relationship` on the "many-to-one" + side and matches it to the other using the :paramref:`.relationship.backref` + parameter. + +7. Production of the :func:`.relationship` and optionally the :func:`.backref` + is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` + function, which can be supplied by the end-user in order to augment + the arguments passed to :func:`.relationship` or :func:`.backref` or to + make use of custom implementations of these functions. + +Custom Relationship Arguments +----------------------------- + +The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used +to add parameters to relationships. For most cases, we can make use of the +existing :func:`.automap.generate_relationship` function to return +the object, after augmenting the given keyword dictionary with our own +arguments. + +Below is an illustration of how to send +:paramref:`.relationship.cascade` and +:paramref:`.relationship.passive_deletes` +options along to all one-to-many relationships:: + + from sqlalchemy.ext.automap import generate_relationship + + def _gen_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw): + if direction is interfaces.ONETOMANY: + kw['cascade'] = 'all, delete-orphan' + kw['passive_deletes'] = True + # make use of the built-in function to actually return + # the result. + return generate_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(engine, reflect=True, + generate_relationship=_gen_relationship) + +Many-to-Many relationships +-------------------------- + +:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g. +those which contain a ``secondary`` argument. The process for producing these +is as follows: + +1. A given :class:`.Table` is examined for :class:`.ForeignKeyConstraint` objects, + before any mapped class has been assigned to it. + +2. If the table contains two and exactly two :class:`.ForeignKeyConstraint` + objects, and all columns within this table are members of these two + :class:`.ForeignKeyConstraint` objects, the table is assumed to be a + "secondary" table, and will **not be mapped directly**. + +3. The two (or one, for self-referential) external tables to which the :class:`.Table` + refers to are matched to the classes to which they will be mapped, if any. + +4. If mapped classes for both sides are located, a many-to-many bi-directional + :func:`.relationship` / :func:`.backref` pair is created between the two + classes. + +5. The override logic for many-to-many works the same as that of one-to-many/ + many-to-one; the :func:`.generate_relationship` function is called upon + to generate the strucures and existing attributes will be maintained. + +Relationships with Inheritance +------------------------------ + +:mod:`.sqlalchemy.ext.automap` will not generate any relationships between +two classes that are in an inheritance relationship. That is, with two classes +given as follows:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on': type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __mapper_args__ = { + 'polymorphic_identity':'engineer', + } + +The foreign key from ``Engineer`` to ``Employee`` is used not for a relationship, +but to establish joined inheritance between the two classes. + +Note that this means automap will not generate *any* relationships +for foreign keys that link from a subclass to a superclass. If a mapping +has actual relationships from subclass to superclass as well, those +need to be explicit. Below, as we have two separate foreign keys +from ``Engineer`` to ``Employee``, we need to set up both the relationship +we want as well as the ``inherit_condition``, as these are not things +SQLAlchemy can guess:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on':type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey('employee.id')) + + favorite_employee = relationship(Employee, foreign_keys=favorite_employee_id) + + __mapper_args__ = { + 'polymorphic_identity':'engineer', + 'inherit_condition': id == Employee.id + } + + +Using Automap with Explicit Declarations +======================================== + +As noted previously, automap has no dependency on reflection, and can make +use of any collection of :class:`.Table` objects within a :class:`.MetaData` +collection. From this, it follows that automap can also be used +generate missing relationships given an otherwise complete model that fully defines +table metadata:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import Column, Integer, String, ForeignKey + + Base = automap_base() + + class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + name = Column(String) + + class Address(Base): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + email = Column(String) + user_id = Column(ForeignKey('user.id')) + + # produce relationships + Base.prepare() + + # mapping is complete, with "address_collection" and + # "user" relationships + a1 = Address(email='u1') + a2 = Address(email='u2') + u1 = User(address_collection=[a1, a2]) + assert a1.user is u1 + +Above, given mostly complete ``User`` and ``Address`` mappings, the +:class:`.ForeignKey` which we defined on ``Address.user_id`` allowed a +bidirectional relationship pair ``Address.user`` and ``User.address_collection`` +to be generated on the mapped classes. + +Note that when subclassing :class:`.AutomapBase`, the :meth:`.AutomapBase.prepare` +method is required; if not called, the classes we've declared are in an +un-mapped state. + + +""" +from .declarative import declarative_base as _declarative_base +from .declarative.base import _DeferredMapperConfig +from ..sql import and_ +from ..schema import ForeignKeyConstraint +from ..orm import relationship, backref, interfaces +from .. import util + + +def classname_for_table(base, tablename, table): + """Return the class name that should be used, given the name + of a table. + + The default implementation is:: + + return str(tablename) + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.classname_for_table` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param tablename: string name of the :class:`.Table`. + + :param table: the :class:`.Table` object itself. + + :return: a string class name. + + .. note:: + + In Python 2, the string used for the class name **must** be a non-Unicode + object, e.g. a ``str()`` object. The ``.name`` attribute of + :class:`.Table` is typically a Python unicode subclass, so the ``str()`` + function should be applied to this name, after accounting for any non-ASCII + characters. + + """ + return str(tablename) + +def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + """Return the attribute name that should be used to refer from one + class to another, for a scalar object reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + +def name_for_collection_relationship(base, local_cls, referred_cls, constraint): + """Return the attribute name that should be used to refer from one + class to another, for a collection reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + "_collection" + + Alternate implementations + can be specified using the :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + "_collection" + +def generate_relationship(base, direction, return_fn, attrname, local_cls, referred_cls, **kw): + """Generate a :func:`.relationship` or :func:`.backref` on behalf of two + mapped classes. + + An alternate implementation of this function can be specified using the + :paramref:`.AutomapBase.prepare.generate_relationship` parameter. + + The default implementation of this function is as follows:: + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param direction: indicate the "direction" of the relationship; this will + be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOONE`. + + :param return_fn: the function that is used by default to create the + relationship. This will be either :func:`.relationship` or :func:`.backref`. + The :func:`.backref` function's result will be used to produce a new + :func:`.relationship` in a second step, so it is critical that user-defined + implementations correctly differentiate between the two functions, if + a custom relationship function is being used. + + :attrname: the attribute name to which this relationship is being assigned. + If the value of :paramref:`.generate_relationship.return_fn` is the + :func:`.backref` function, then this name is the name that is being + assigned to the backref. + + :param local_cls: the "local" class to which this relationship or backref + will be locally present. + + :param referred_cls: the "referred" class to which the relationship or backref + refers to. + + :param \**kw: all additional keyword arguments are passed along to the + function. + + :return: a :func:`.relationship` or :func:`.backref` construct, as dictated + by the :paramref:`.generate_relationship.return_fn` parameter. + + """ + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + +class AutomapBase(object): + """Base class for an "automap" schema. + + The :class:`.AutomapBase` class can be compared to the "declarative base" + class that is produced by the :func:`.declarative.declarative_base` + function. In practice, the :class:`.AutomapBase` class is always used + as a mixin along with an actual declarative base. + + A new subclassable :class:`.AutomapBase` is typically instantated + using the :func:`.automap_base` function. + + .. seealso:: + + :ref:`automap_toplevel` + + """ + __abstract__ = True + + classes = None + """An instance of :class:`.util.Properties` containing classes. + + This object behaves much like the ``.c`` collection on a table. Classes + are present under the name they were given, e.g.:: + + Base = automap_base() + Base.prepare(engine=some_engine, reflect=True) + + User, Address = Base.classes.User, Base.classes.Address + + """ + + @classmethod + def prepare(cls, + engine=None, + reflect=False, + classname_for_table=classname_for_table, + collection_class=list, + name_for_scalar_relationship=name_for_scalar_relationship, + name_for_collection_relationship=name_for_collection_relationship, + generate_relationship=generate_relationship): + + """Extract mapped classes and relationships from the :class:`.MetaData` and + perform mappings. + + :param engine: an :class:`.Engine` or :class:`.Connection` with which + to perform schema reflection, if specified. + If the :paramref:`.AutomapBase.prepare.reflect` argument is False, this + object is not used. + + :param reflect: if True, the :meth:`.MetaData.reflect` method is called + on the :class:`.MetaData` associated with this :class:`.AutomapBase`. + The :class:`.Engine` passed via :paramref:`.AutomapBase.prepare.engine` will + be used to perform the reflection if present; else, the :class:`.MetaData` + should already be bound to some engine else the operation will fail. + + :param classname_for_table: callable function which will be used to + produce new class names, given a table name. Defaults to + :func:`.classname_for_table`. + + :param name_for_scalar_relationship: callable function which will be used + to produce relationship names for scalar relationships. Defaults to + :func:`.name_for_scalar_relationship`. + + :param name_for_collection_relationship: callable function which will be used + to produce relationship names for collection-oriented relationships. Defaults to + :func:`.name_for_collection_relationship`. + + :param generate_relationship: callable function which will be used to + actually generate :func:`.relationship` and :func:`.backref` constructs. + Defaults to :func:`.generate_relationship`. + + :param collection_class: the Python collection class that will be used + when a new :func:`.relationship` object is created that represents a + collection. Defaults to ``list``. + + """ + if reflect: + cls.metadata.reflect( + engine, + extend_existing=True, + autoload_replace=False + ) + + table_to_map_config = dict( + (m.local_table, m) + for m in _DeferredMapperConfig. + classes_for_base(cls, sort=False) + ) + + many_to_many = [] + + for table in cls.metadata.tables.values(): + lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) + if lcl_m2m is not None: + many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) + elif not table.primary_key: + continue + elif table not in table_to_map_config: + mapped_cls = type( + classname_for_table(cls, table.name, table), + (cls, ), + {"__table__": table} + ) + map_config = _DeferredMapperConfig.config_for_cls(mapped_cls) + cls.classes[map_config.cls.__name__] = mapped_cls + table_to_map_config[table] = map_config + + for map_config in table_to_map_config.values(): + _relationships_for_fks(cls, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship) + + for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: + _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship) + + for map_config in _DeferredMapperConfig.classes_for_base(cls): + map_config.map() + + + _sa_decl_prepare = True + """Indicate that the mapping of classes should be deferred. + + The presence of this attribute name indicates to declarative + that the call to mapper() should not occur immediately; instead, + information about the table and attributes to be mapped are gathered + into an internal structure called _DeferredMapperConfig. These + objects can be collected later using classes_for_base(), additional + mapping decisions can be made, and then the map() method will actually + apply the mapping. + + The only real reason this deferral of the whole + thing is needed is to support primary key columns that aren't reflected + yet when the class is declared; everything else can theoretically be + added to the mapper later. However, the _DeferredMapperConfig is a + nice interface in any case which exists at that not usually exposed point + at which declarative has the class and the Table but hasn't called + mapper() yet. + + """ + +def automap_base(declarative_base=None, **kw): + """Produce a declarative automap base. + + This function produces a new base class that is a product of the + :class:`.AutomapBase` class as well a declarative base produced by + :func:`.declarative.declarative_base`. + + All parameters other than ``declarative_base`` are keyword arguments + that are passed directly to the :func:`.declarative.declarative_base` + function. + + :param declarative_base: an existing class produced by + :func:`.declarative.declarative_base`. When this is passed, the function + no longer invokes :func:`.declarative.declarative_base` itself, and all other + keyword arguments are ignored. + + :param \**kw: keyword arguments are passed along to + :func:`.declarative.declarative_base`. + + """ + if declarative_base is None: + Base = _declarative_base(**kw) + else: + Base = declarative_base + + return type( + Base.__name__, + (AutomapBase, Base,), + {"__abstract__": True, "classes": util.Properties({})} + ) + +def _is_many_to_many(automap_base, table): + fk_constraints = [const for const in table.constraints + if isinstance(const, ForeignKeyConstraint)] + if len(fk_constraints) != 2: + return None, None, None + + cols = sum( + [[fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints], []) + + if set(cols) != set(table.c): + return None, None, None + + return ( + fk_constraints[0].elements[0].column.table, + fk_constraints[1].elements[0].column.table, + fk_constraints + ) + +def _relationships_for_fks(automap_base, map_config, table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship): + local_table = map_config.local_table + local_cls = map_config.cls + + if local_table is None: + return + for constraint in local_table.constraints: + if isinstance(constraint, ForeignKeyConstraint): + fks = constraint.elements + referred_table = fks[0].column.table + referred_cfg = table_to_map_config.get(referred_table, None) + if referred_cfg is None: + continue + referred_cls = referred_cfg.cls + + if local_cls is not referred_cls and issubclass(local_cls, referred_cls): + continue + + relationship_name = name_for_scalar_relationship( + automap_base, + local_cls, + referred_cls, constraint) + backref_name = name_for_collection_relationship( + automap_base, + referred_cls, + local_cls, + constraint + ) + + create_backref = backref_name not in referred_cfg.properties + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship(automap_base, + interfaces.ONETOMANY, backref, + backref_name, referred_cls, local_cls, + collection_class=collection_class) + else: + backref_obj = None + rel = generate_relationship(automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements] + ) + if rel is not None: + map_config.properties[relationship_name] = rel + if not create_backref: + referred_cfg.properties[backref_name].back_populates = relationship_name + elif create_backref: + rel = generate_relationship(automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[relationship_name].back_populates = backref_name + +def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship): + + map_config = table_to_map_config.get(lcl_m2m, None) + referred_cfg = table_to_map_config.get(rem_m2m, None) + if map_config is None or referred_cfg is None: + return + + local_cls = map_config.cls + referred_cls = referred_cfg.cls + + relationship_name = name_for_collection_relationship( + automap_base, + local_cls, + referred_cls, m2m_const[0]) + backref_name = name_for_collection_relationship( + automap_base, + referred_cls, + local_cls, + m2m_const[1] + ) + + create_backref = backref_name not in referred_cfg.properties + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship(automap_base, + interfaces.MANYTOMANY, + backref, + backref_name, + referred_cls, local_cls, + collection_class=collection_class + ) + else: + backref_obj = None + rel = generate_relationship(automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, referred_cls, + secondary=table, + primaryjoin=and_(fk.column == fk.parent for fk in m2m_const[0].elements), + secondaryjoin=and_(fk.column == fk.parent for fk in m2m_const[1].elements), + backref=backref_obj, + collection_class=collection_class + ) + if rel is not None: + map_config.properties[relationship_name] = rel + + if not create_backref: + referred_cfg.properties[backref_name].back_populates = relationship_name + elif create_backref: + rel = generate_relationship(automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, local_cls, + secondary=table, + primaryjoin=and_(fk.column == fk.parent for fk in m2m_const[1].elements), + secondaryjoin=and_(fk.column == fk.parent for fk in m2m_const[0].elements), + back_populates=relationship_name, + collection_class=collection_class) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[relationship_name].back_populates = backref_name diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py new file mode 100644 index 00000000..5dde74e0 --- /dev/null +++ b/lib/sqlalchemy/ext/compiler.py @@ -0,0 +1,448 @@ +# ext/compiler.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining its compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + pass + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select([MyColumn('x'), MyColumn('y')]) + print str(s) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name) + +The second ``visit_alter_table`` will be invoked when any ``postgresql`` +dialect is used. + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc. The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select) + ) + + insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5)) + print insert + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + +.. note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`.Insert.from_select` method. + +.. note:: + + The above ``InsertFromSelect`` construct probably wants to have "autocommit" + enabled. See :ref:`enabling_compiled_autocommit` for this step. + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process(constraint.expression) + ) + +.. _enabling_compiled_autocommit: + +Enabling Autocommit on a Construct +================================== + +Recall from the section :ref:`autocommit` that the :class:`.Engine`, when +asked to execute a construct in the absence of a user-defined transaction, +detects if the given construct represents DML or DDL, that is, a data +modification or data definition statement, which requires (or may require, +in the case of DDL) that the transaction generated by the DBAPI be committed +(recall that DBAPI always has a transaction going on regardless of what +SQLAlchemy does). Checking for this is actually accomplished by checking for +the "autocommit" execution option on the construct. When building a +construct like an INSERT derivation, a new DDL type, or perhaps a stored +procedure that alters data, the "autocommit" option needs to be set in order +for the statement to function with "connectionless" execution +(as described in :ref:`dbengine_implicit`). + +Currently a quick way to do this is to subclass :class:`.Executable`, then +add the "autocommit" flag to the ``_execution_options`` dictionary (note this +is a "frozen" dictionary which supplies a generative ``union()`` method):: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class MyInsertThing(Executable, ClauseElement): + _execution_options = \\ + Executable._execution_options.union({'autocommit': True}) + +More succinctly, if the construct is truly similar to an INSERT, UPDATE, or +DELETE, :class:`.UpdateBase` can be used, which already is a subclass +of :class:`.Executable`, :class:`.ClauseElement` and includes the +``autocommit`` flag:: + + from sqlalchemy.sql.expression import UpdateBase + + class MyInsertThing(UpdateBase): + def __init__(self, ...): + ... + + + + +DDL elements that subclass :class:`.DDLElement` already have the +"autocommit" flag turned on. + + + + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. When +overriding the compilation of a built in SQL construct, the @compiles +decorator is invoked upon the appropriate class (be sure to use the class, +i.e. ``Insert`` or ``Select``, instead of the creation function such +as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation +routine, use the appropriate visit_XXX method - this +because compiler.process() will call upon the overriding routine and cause +an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when +compiled. + +.. _type_compilation_extension: + +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the +MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy +expression constructs. To make this easier, the expression and +schema packages feature a set of "bases" intended for common tasks. +A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM " + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses) + +* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement`` + subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``. + ``DDLElement`` also features ``Table`` and ``MetaData`` event hooks via the + ``execute_at()`` method, allowing the construct to be invoked during CREATE + TABLE and DROP TABLE sequences. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which + should be used with any expression class that represents a "standalone" + SQL statement that can be passed directly to an ``execute()`` method. It + is already implicit within ``DDLElement`` and ``FunctionElement``. + +Further Examples +================ + +"UTC timestamp" function +------------------------- + +A function that works like "CURRENT_TIMESTAMP" except applies the +appropriate conversions so that the time is in UTC time. Timestamps are best +stored in relational databases as UTC, without time zones. UTC so that your +database doesn't think time has gone backwards in the hour when daylight +savings ends, without timezones because timezones are like character +encodings - they're best applied only at the endpoints of an application +(i.e. convert to UTC upon user input, re-apply desired timezone upon display). + +For Postgresql and Microsoft SQL Server:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import DateTime + + class utcnow(expression.FunctionElement): + type = DateTime() + + @compiles(utcnow, 'postgresql') + def pg_utcnow(element, compiler, **kw): + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + @compiles(utcnow, 'mssql') + def ms_utcnow(element, compiler, **kw): + return "GETUTCDATE()" + +Example usage:: + + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) + metadata = MetaData() + event = Table("event", metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50), nullable=False), + Column("timestamp", DateTime, server_default=utcnow()) + ) + +"GREATEST" function +------------------- + +The "GREATEST" function is given any number of arguments and returns the one +that is of the highest value - it's equivalent to Python's ``max`` +function. A SQL standard version versus a CASE based version which only +accommodates two arguments:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import Numeric + + class greatest(expression.FunctionElement): + type = Numeric() + name = 'greatest' + + @compiles(greatest) + def default_greatest(element, compiler, **kw): + return compiler.visit_function(element) + + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') + def case_greatest(element, compiler, **kw): + arg1, arg2 = list(element.clauses) + return "CASE WHEN %s > %s THEN %s ELSE %s END" % ( + compiler.process(arg1), + compiler.process(arg2), + compiler.process(arg1), + compiler.process(arg2), + ) + +Example usage:: + + Session.query(Account).\\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) + +"false" expression +------------------ + +Render a "false" constant expression, rendering as "0" on platforms that +don't have a "false" constant:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + + class sql_false(expression.ColumnElement): + pass + + @compiles(sql_false) + def default_false(element, compiler, **kw): + return "false" + + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') + def int_false(element, compiler, **kw): + return "0" + +Example usage:: + + from sqlalchemy import select, union_all + + exp = union_all( + select([users.c.name, sql_false().label("enrolled")]), + select([customers.c.name, customers.c.enrolled]) + ) + +""" +from .. import exc +from ..sql import visitors + + +def compiles(class_, *specs): + """Register a function as a compiler for a + given :class:`.ClauseElement` type.""" + + def decorate(fn): + existing = class_.__dict__.get('_compiler_dispatcher', None) + existing_dispatch = class_.__dict__.get('_compiler_dispatch') + if not existing: + existing = _dispatcher() + + if existing_dispatch: + existing.specs['default'] = existing_dispatch + + # TODO: why is the lambda needed ? + setattr(class_, '_compiler_dispatch', + lambda *arg, **kw: existing(*arg, **kw)) + setattr(class_, '_compiler_dispatcher', existing) + + if specs: + for s in specs: + existing.specs[s] = fn + + else: + existing.specs['default'] = fn + return fn + return decorate + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`.ClauseElement` type.""" + + if hasattr(class_, '_compiler_dispatcher'): + # regenerate default _compiler_dispatch + visitors._generate_dispatch(class_) + # remove custom directive + del class_._compiler_dispatcher + + +class _dispatcher(object): + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. + fn = self.specs.get(compiler.dialect.name, None) + if not fn: + try: + fn = self.specs['default'] + except KeyError: + raise exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element)) + return fn(element, compiler, **kw) diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py new file mode 100644 index 00000000..4010789b --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -0,0 +1,1316 @@ +# ext/declarative/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Synopsis +======== + +SQLAlchemy object-relational configuration involves the +combination of :class:`.Table`, :func:`.mapper`, and class +objects to define a mapped class. +:mod:`~sqlalchemy.ext.declarative` allows all three to be +expressed at once within the class declaration. As much as +possible, regular SQLAlchemy schema and ORM constructs are +used directly, so that configuration between "classical" ORM +usage and declarative remain highly similar. + +As a simple example:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + +Above, the :func:`declarative_base` callable returns a new base class from +which all mapped classes should inherit. When the class definition is +completed, a new :class:`.Table` and :func:`.mapper` will have been generated. + +The resulting table and mapper are accessible via +``__table__`` and ``__mapper__`` attributes on the +``SomeClass`` class:: + + # access the mapped Table + SomeClass.__table__ + + # access the Mapper + SomeClass.__mapper__ + +Defining Attributes +=================== + +In the previous example, the :class:`.Column` objects are +automatically named with the name of the attribute to which they are +assigned. + +To name columns explicitly with a name distinct from their mapped attribute, +just give the column a name. Below, column "some_table_id" is mapped to the +"id" attribute of `SomeClass`, but in SQL will be represented as +"some_table_id":: + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column("some_table_id", Integer, primary_key=True) + +Attributes may be added to the class after its construction, and they will be +added to the underlying :class:`.Table` and +:func:`.mapper` definitions as appropriate:: + + SomeClass.data = Column('data', Unicode) + SomeClass.related = relationship(RelatedInfo) + +Classes which are constructed using declarative can interact freely +with classes that are mapped explicitly with :func:`.mapper`. + +It is recommended, though not required, that all tables +share the same underlying :class:`~sqlalchemy.schema.MetaData` object, +so that string-configured :class:`~sqlalchemy.schema.ForeignKey` +references can be resolved without issue. + +Accessing the MetaData +======================= + +The :func:`declarative_base` base class contains a +:class:`.MetaData` object where newly defined +:class:`.Table` objects are collected. This object is +intended to be accessed directly for +:class:`.MetaData`-specific operations. Such as, to issue +CREATE statements for all tables:: + + engine = create_engine('sqlite://') + Base.metadata.create_all(engine) + +:func:`declarative_base` can also receive a pre-existing +:class:`.MetaData` object, which allows a +declarative setup to be associated with an already +existing traditional collection of :class:`~sqlalchemy.schema.Table` +objects:: + + mymetadata = MetaData() + Base = declarative_base(metadata=mymetadata) + + +.. _declarative_configuring_relationships: + +Configuring Relationships +========================= + +Relationships to other classes are done in the usual way, with the added +feature that the class specified to :func:`~sqlalchemy.orm.relationship` +may be a string name. The "class registry" associated with ``Base`` +is used at mapper compilation time to resolve the name into the actual +class object, which is expected to have been defined once the mapper +configuration is used:: + + class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relationship("Address", backref="user") + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + +Column constructs, since they are just that, are immediately usable, +as below where we define a primary join condition on the ``Address`` +class using them:: + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + user = relationship(User, primaryjoin=user_id == User.id) + +In addition to the main argument for :func:`~sqlalchemy.orm.relationship`, +other arguments which depend upon the columns present on an as-yet +undefined class may also be specified as strings. These strings are +evaluated as Python expressions. The full namespace available within +this evaluation includes all classes mapped for this declarative base, +as well as the contents of the ``sqlalchemy`` package, including +expression functions like :func:`~sqlalchemy.sql.expression.desc` and +:attr:`~sqlalchemy.sql.expression.func`:: + + class User(Base): + # .... + addresses = relationship("Address", + order_by="desc(Address.email)", + primaryjoin="Address.user_id==User.id") + +For the case where more than one module contains a class of the same name, +string class names can also be specified as module-qualified paths +within any of these string expressions:: + + class User(Base): + # .... + addresses = relationship("myapp.model.address.Address", + order_by="desc(myapp.model.address.Address.email)", + primaryjoin="myapp.model.address.Address.user_id==" + "myapp.model.user.User.id") + +The qualified path can be any partial path that removes ambiguity between +the names. For example, to disambiguate between +``myapp.model.address.Address`` and ``myapp.model.lookup.Address``, +we can specify ``address.Address`` or ``lookup.Address``:: + + class User(Base): + # .... + addresses = relationship("address.Address", + order_by="desc(address.Address.email)", + primaryjoin="address.Address.user_id==" + "User.id") + +.. versionadded:: 0.8 + module-qualified paths can be used when specifying string arguments + with Declarative, in order to specify specific modules. + +Two alternatives also exist to using string-based attributes. A lambda +can also be used, which will be evaluated after all mappers have been +configured:: + + class User(Base): + # ... + addresses = relationship(lambda: Address, + order_by=lambda: desc(Address.email), + primaryjoin=lambda: Address.user_id==User.id) + +Or, the relationship can be added to the class explicitly after the classes +are available:: + + User.addresses = relationship(Address, + primaryjoin=Address.user_id==User.id) + + + +.. _declarative_many_to_many: + +Configuring Many-to-Many Relationships +====================================== + +Many-to-many relationships are also declared in the same way +with declarative as with traditional mappings. The +``secondary`` argument to +:func:`.relationship` is as usual passed a +:class:`.Table` object, which is typically declared in the +traditional way. The :class:`.Table` usually shares +the :class:`.MetaData` object used by the declarative base:: + + keywords = Table( + 'keywords', Base.metadata, + Column('author_id', Integer, ForeignKey('authors.id')), + Column('keyword_id', Integer, ForeignKey('keywords.id')) + ) + + class Author(Base): + __tablename__ = 'authors' + id = Column(Integer, primary_key=True) + keywords = relationship("Keyword", secondary=keywords) + +Like other :func:`~sqlalchemy.orm.relationship` arguments, a string is accepted +as well, passing the string name of the table as defined in the +``Base.metadata.tables`` collection:: + + class Author(Base): + __tablename__ = 'authors' + id = Column(Integer, primary_key=True) + keywords = relationship("Keyword", secondary="keywords") + +As with traditional mapping, its generally not a good idea to use +a :class:`.Table` as the "secondary" argument which is also mapped to +a class, unless the :func:`.relationship` is declared with ``viewonly=True``. +Otherwise, the unit-of-work system may attempt duplicate INSERT and +DELETE statements against the underlying table. + +.. _declarative_sql_expressions: + +Defining SQL Expressions +======================== + +See :ref:`mapper_sql_expressions` for examples on declaratively +mapping attributes to SQL expressions. + +.. _declarative_table_args: + +Table Configuration +=================== + +Table arguments other than the name, metadata, and mapped Column +arguments are specified using the ``__table_args__`` class attribute. +This attribute accommodates both positional as well as keyword +arguments that are normally sent to the +:class:`~sqlalchemy.schema.Table` constructor. +The attribute can be specified in one of two forms. One is as a +dictionary:: + + class MyClass(Base): + __tablename__ = 'sometable' + __table_args__ = {'mysql_engine':'InnoDB'} + +The other, a tuple, where each argument is positional +(usually constraints):: + + class MyClass(Base): + __tablename__ = 'sometable' + __table_args__ = ( + ForeignKeyConstraint(['id'], ['remote_table.id']), + UniqueConstraint('foo'), + ) + +Keyword arguments can be specified with the above form by +specifying the last argument as a dictionary:: + + class MyClass(Base): + __tablename__ = 'sometable' + __table_args__ = ( + ForeignKeyConstraint(['id'], ['remote_table.id']), + UniqueConstraint('foo'), + {'autoload':True} + ) + +Using a Hybrid Approach with __table__ +======================================= + +As an alternative to ``__tablename__``, a direct +:class:`~sqlalchemy.schema.Table` construct may be used. The +:class:`~sqlalchemy.schema.Column` objects, which in this case require +their names, will be added to the mapping just like a regular mapping +to a table:: + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + +``__table__`` provides a more focused point of control for establishing +table metadata, while still getting most of the benefits of using declarative. +An application that uses reflection might want to load table metadata elsewhere +and pass it to declarative classes:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + Base.metadata.reflect(some_engine) + + class User(Base): + __table__ = metadata.tables['user'] + + class Address(Base): + __table__ = metadata.tables['address'] + +Some configuration schemes may find it more appropriate to use ``__table__``, +such as those which already take advantage of the data-driven nature of +:class:`.Table` to customize and/or automate schema definition. + +Note that when the ``__table__`` approach is used, the object is immediately +usable as a plain :class:`.Table` within the class declaration body itself, +as a Python class is only another syntactical block. Below this is illustrated +by using the ``id`` column in the ``primaryjoin`` condition of a +:func:`.relationship`:: + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + + widgets = relationship(Widget, + primaryjoin=Widget.myclass_id==__table__.c.id) + +Similarly, mapped attributes which refer to ``__table__`` can be placed inline, +as below where we assign the ``name`` column to the attribute ``_name``, +generating a synonym for ``name``:: + + from sqlalchemy.ext.declarative import synonym_for + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + + _name = __table__.c.name + + @synonym_for("_name") + def name(self): + return "Name: %s" % _name + +Using Reflection with Declarative +================================= + +It's easy to set up a :class:`.Table` that uses ``autoload=True`` +in conjunction with a mapped class:: + + class MyClass(Base): + __table__ = Table('mytable', Base.metadata, + autoload=True, autoload_with=some_engine) + +However, one improvement that can be made here is to not +require the :class:`.Engine` to be available when classes are +being first declared. To achieve this, use the +:class:`.DeferredReflection` mixin, which sets up mappings +only after a special ``prepare(engine)`` step is called:: + + from sqlalchemy.ext.declarative import declarative_base, DeferredReflection + + Base = declarative_base(cls=DeferredReflection) + + class Foo(Base): + __tablename__ = 'foo' + bars = relationship("Bar") + + class Bar(Base): + __tablename__ = 'bar' + + # illustrate overriding of "bar.foo_id" to have + # a foreign key constraint otherwise not + # reflected, such as when using MySQL + foo_id = Column(Integer, ForeignKey('foo.id')) + + Base.prepare(e) + +.. versionadded:: 0.8 + Added :class:`.DeferredReflection`. + +Mapper Configuration +==================== + +Declarative makes use of the :func:`~.orm.mapper` function internally +when it creates the mapping to the declared table. The options +for :func:`~.orm.mapper` are passed directly through via the +``__mapper_args__`` class attribute. As always, arguments which reference +locally mapped columns can reference them directly from within the +class declaration:: + + from datetime import datetime + + class Widget(Base): + __tablename__ = 'widgets' + + id = Column(Integer, primary_key=True) + timestamp = Column(DateTime, nullable=False) + + __mapper_args__ = { + 'version_id_col': timestamp, + 'version_id_generator': lambda v:datetime.now() + } + +.. _declarative_inheritance: + +Inheritance Configuration +========================= + +Declarative supports all three forms of inheritance as intuitively +as possible. The ``inherits`` mapper keyword argument is not needed +as declarative will determine this from the class itself. The various +"polymorphic" keyword arguments are specified using ``__mapper_args__``. + +Joined Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~ + +Joined table inheritance is defined as a subclass that defines its own +table:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity': 'engineer'} + id = Column(Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column(String(50)) + +Note that above, the ``Engineer.id`` attribute, since it shares the +same attribute name as the ``Person.id`` attribute, will in fact +represent the ``people.id`` and ``engineers.id`` columns together, +with the "Engineer.id" column taking precedence if queried directly. +To provide the ``Engineer`` class with an attribute that represents +only the ``engineers.id`` column, give it a different attribute name:: + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity': 'engineer'} + engineer_id = Column('id', Integer, ForeignKey('people.id'), + primary_key=True) + primary_language = Column(String(50)) + + +.. versionchanged:: 0.7 joined table inheritance favors the subclass + column over that of the superclass, such as querying above + for ``Engineer.id``. Prior to 0.7 this was the reverse. + +.. _declarative_single_table: + +Single Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~ + +Single table inheritance is defined as a subclass that does not have +its own table; you just leave out the ``__table__`` and ``__tablename__`` +attributes:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity': 'engineer'} + primary_language = Column(String(50)) + +When the above mappers are configured, the ``Person`` class is mapped +to the ``people`` table *before* the ``primary_language`` column is +defined, and this column will not be included in its own mapping. +When ``Engineer`` then defines the ``primary_language`` column, the +column is added to the ``people`` table so that it is included in the +mapping for ``Engineer`` and is also part of the table's full set of +columns. Columns which are not mapped to ``Person`` are also excluded +from any other single or joined inheriting classes using the +``exclude_properties`` mapper argument. Below, ``Manager`` will have +all the attributes of ``Person`` and ``Manager`` but *not* the +``primary_language`` attribute of ``Engineer``:: + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity': 'manager'} + golf_swing = Column(String(50)) + +The attribute exclusion logic is provided by the +``exclude_properties`` mapper argument, and declarative's default +behavior can be disabled by passing an explicit ``exclude_properties`` +collection (empty or otherwise) to the ``__mapper_args__``. + +Resolving Column Conflicts +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Note above that the ``primary_language`` and ``golf_swing`` columns +are "moved up" to be applied to ``Person.__table__``, as a result of their +declaration on a subclass that has no table of its own. A tricky case +comes up when two subclasses want to specify *the same* column, as below:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity': 'engineer'} + start_date = Column(DateTime) + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity': 'manager'} + start_date = Column(DateTime) + +Above, the ``start_date`` column declared on both ``Engineer`` and ``Manager`` +will result in an error:: + + sqlalchemy.exc.ArgumentError: Column 'start_date' on class + conflicts with existing + column 'people.start_date' + +In a situation like this, Declarative can't be sure +of the intent, especially if the ``start_date`` columns had, for example, +different types. A situation like this can be resolved by using +:class:`.declared_attr` to define the :class:`.Column` conditionally, taking +care to return the **existing column** via the parent ``__table__`` if it +already exists:: + + from sqlalchemy.ext.declarative import declared_attr + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity': 'engineer'} + + @declared_attr + def start_date(cls): + "Start date column, if not present already." + return Person.__table__.c.get('start_date', Column(DateTime)) + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity': 'manager'} + + @declared_attr + def start_date(cls): + "Start date column, if not present already." + return Person.__table__.c.get('start_date', Column(DateTime)) + +Above, when ``Manager`` is mapped, the ``start_date`` column is +already present on the ``Person`` class. Declarative lets us return +that :class:`.Column` as a result in this case, where it knows to skip +re-assigning the same column. If the mapping is mis-configured such +that the ``start_date`` column is accidentally re-assigned to a +different table (such as, if we changed ``Manager`` to be joined +inheritance without fixing ``start_date``), an error is raised which +indicates an existing :class:`.Column` is trying to be re-assigned to +a different owning :class:`.Table`. + +.. versionadded:: 0.8 :class:`.declared_attr` can be used on a non-mixin + class, and the returned :class:`.Column` or other mapped attribute + will be applied to the mapping as any other attribute. Previously, + the resulting attribute would be ignored, and also result in a warning + being emitted when a subclass was created. + +.. versionadded:: 0.8 :class:`.declared_attr`, when used either with a + mixin or non-mixin declarative class, can return an existing + :class:`.Column` already assigned to the parent :class:`.Table`, + to indicate that the re-assignment of the :class:`.Column` should be + skipped, however should still be mapped on the target class, + in order to resolve duplicate column conflicts. + +The same concept can be used with mixin classes (see +:ref:`declarative_mixins`):: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class HasStartDate(object): + @declared_attr + def start_date(cls): + return cls.__table__.c.get('start_date', Column(DateTime)) + + class Engineer(HasStartDate, Person): + __mapper_args__ = {'polymorphic_identity': 'engineer'} + + class Manager(HasStartDate, Person): + __mapper_args__ = {'polymorphic_identity': 'manager'} + +The above mixin checks the local ``__table__`` attribute for the column. +Because we're using single table inheritance, we're sure that in this case, +``cls.__table__`` refers to ``People.__table__``. If we were mixing joined- +and single-table inheritance, we might want our mixin to check more carefully +if ``cls.__table__`` is really the :class:`.Table` we're looking for. + +Concrete Table Inheritance +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Concrete is defined as a subclass which has its own table and sets the +``concrete`` keyword argument to ``True``:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'concrete':True} + id = Column(Integer, primary_key=True) + primary_language = Column(String(50)) + name = Column(String(50)) + +Usage of an abstract base class is a little less straightforward as it +requires usage of :func:`~sqlalchemy.orm.util.polymorphic_union`, +which needs to be created with the :class:`.Table` objects +before the class is built:: + + engineers = Table('engineers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('primary_language', String(50)) + ) + managers = Table('managers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('golf_swing', String(50)) + ) + + punion = polymorphic_union({ + 'engineer':engineers, + 'manager':managers + }, 'type', 'punion') + + class Person(Base): + __table__ = punion + __mapper_args__ = {'polymorphic_on':punion.c.type} + + class Engineer(Person): + __table__ = engineers + __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True} + + class Manager(Person): + __table__ = managers + __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True} + +.. _declarative_concrete_helpers: + +Using the Concrete Helpers +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Helper classes provides a simpler pattern for concrete inheritance. +With these objects, the ``__declare_first__`` helper is used to configure the +"polymorphic" loader for the mapper after all subclasses have been declared. + +.. versionadded:: 0.7.3 + +An abstract base can be declared using the +:class:`.AbstractConcreteBase` class:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Employee(AbstractConcreteBase, Base): + pass + +To have a concrete ``employee`` table, use :class:`.ConcreteBase` instead:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + +Either ``Employee`` base can be used in the normal fashion:: + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + class Engineer(Employee): + __tablename__ = 'engineer' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + engineer_info = Column(String(40)) + __mapper_args__ = {'polymorphic_identity':'engineer', + 'concrete':True} + + +The :class:`.AbstractConcreteBase` class is itself mapped, and can be +used as a target of relationships:: + + class Company(Base): + __tablename__ = 'company' + + id = Column(Integer, primary_key=True) + employees = relationship("Employee", + primaryjoin="Company.id == Employee.company_id") + + +.. versionchanged:: 0.9.3 Support for use of :class:`.AbstractConcreteBase` + as the target of a :func:`.relationship` has been improved. + +It can also be queried directly:: + + for employee in session.query(Employee).filter(Employee.name == 'qbert'): + print(employee) + + +.. _declarative_mixins: + +Mixin and Custom Base Classes +============================== + +A common need when using :mod:`~sqlalchemy.ext.declarative` is to +share some functionality, such as a set of common columns, some common +table options, or other mapped properties, across many +classes. The standard Python idioms for this is to have the classes +inherit from a base which includes these common features. + +When using :mod:`~sqlalchemy.ext.declarative`, this idiom is allowed +via the usage of a custom declarative base class, as well as a "mixin" class +which is inherited from in addition to the primary base. Declarative +includes several helper features to make this work in terms of how +mappings are declared. An example of some commonly mixed-in +idioms is below:: + + from sqlalchemy.ext.declarative import declared_attr + + class MyMixin(object): + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {'mysql_engine': 'InnoDB'} + __mapper_args__= {'always_refresh': True} + + id = Column(Integer, primary_key=True) + + class MyModel(MyMixin, Base): + name = Column(String(1000)) + +Where above, the class ``MyModel`` will contain an "id" column +as the primary key, a ``__tablename__`` attribute that derives +from the name of the class itself, as well as ``__table_args__`` +and ``__mapper_args__`` defined by the ``MyMixin`` mixin class. + +There's no fixed convention over whether ``MyMixin`` precedes +``Base`` or not. Normal Python method resolution rules apply, and +the above example would work just as well with:: + + class MyModel(Base, MyMixin): + name = Column(String(1000)) + +This works because ``Base`` here doesn't define any of the +variables that ``MyMixin`` defines, i.e. ``__tablename__``, +``__table_args__``, ``id``, etc. If the ``Base`` did define +an attribute of the same name, the class placed first in the +inherits list would determine which attribute is used on the +newly defined class. + +Augmenting the Base +~~~~~~~~~~~~~~~~~~~ + +In addition to using a pure mixin, most of the techniques in this +section can also be applied to the base class itself, for patterns that +should apply to all classes derived from a particular base. This is achieved +using the ``cls`` argument of the :func:`.declarative_base` function:: + + from sqlalchemy.ext.declarative import declared_attr + + class Base(object): + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {'mysql_engine': 'InnoDB'} + + id = Column(Integer, primary_key=True) + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base(cls=Base) + + class MyModel(Base): + name = Column(String(1000)) + +Where above, ``MyModel`` and all other classes that derive from ``Base`` will +have a table name derived from the class name, an ``id`` primary key column, +as well as the "InnoDB" engine for MySQL. + +Mixing in Columns +~~~~~~~~~~~~~~~~~ + +The most basic way to specify a column on a mixin is by simple +declaration:: + + class TimestampMixin(object): + created_at = Column(DateTime, default=func.now()) + + class MyModel(TimestampMixin, Base): + __tablename__ = 'test' + + id = Column(Integer, primary_key=True) + name = Column(String(1000)) + +Where above, all declarative classes that include ``TimestampMixin`` +will also have a column ``created_at`` that applies a timestamp to +all row insertions. + +Those familiar with the SQLAlchemy expression language know that +the object identity of clause elements defines their role in a schema. +Two ``Table`` objects ``a`` and ``b`` may both have a column called +``id``, but the way these are differentiated is that ``a.c.id`` +and ``b.c.id`` are two distinct Python objects, referencing their +parent tables ``a`` and ``b`` respectively. + +In the case of the mixin column, it seems that only one +:class:`.Column` object is explicitly created, yet the ultimate +``created_at`` column above must exist as a distinct Python object +for each separate destination class. To accomplish this, the declarative +extension creates a **copy** of each :class:`.Column` object encountered on +a class that is detected as a mixin. + +This copy mechanism is limited to simple columns that have no foreign +keys, as a :class:`.ForeignKey` itself contains references to columns +which can't be properly recreated at this level. For columns that +have foreign keys, as well as for the variety of mapper-level constructs +that require destination-explicit context, the +:class:`~.declared_attr` decorator is provided so that +patterns common to many classes can be defined as callables:: + + from sqlalchemy.ext.declarative import declared_attr + + class ReferenceAddressMixin(object): + @declared_attr + def address_id(cls): + return Column(Integer, ForeignKey('address.id')) + + class User(ReferenceAddressMixin, Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + +Where above, the ``address_id`` class-level callable is executed at the +point at which the ``User`` class is constructed, and the declarative +extension can use the resulting :class:`.Column` object as returned by +the method without the need to copy it. + +.. versionchanged:: > 0.6.5 + Rename 0.6.5 ``sqlalchemy.util.classproperty`` + into :class:`~.declared_attr`. + +Columns generated by :class:`~.declared_attr` can also be +referenced by ``__mapper_args__`` to a limited degree, currently +by ``polymorphic_on`` and ``version_id_col``, by specifying the +classdecorator itself into the dictionary - the declarative extension +will resolve them at class construction time:: + + class MyMixin: + @declared_attr + def type_(cls): + return Column(String(50)) + + __mapper_args__= {'polymorphic_on':type_} + + class MyModel(MyMixin, Base): + __tablename__='test' + id = Column(Integer, primary_key=True) + + + +Mixing in Relationships +~~~~~~~~~~~~~~~~~~~~~~~ + +Relationships created by :func:`~sqlalchemy.orm.relationship` are provided +with declarative mixin classes exclusively using the +:class:`.declared_attr` approach, eliminating any ambiguity +which could arise when copying a relationship and its possibly column-bound +contents. Below is an example which combines a foreign key column and a +relationship so that two classes ``Foo`` and ``Bar`` can both be configured to +reference a common target class via many-to-one:: + + class RefTargetMixin(object): + @declared_attr + def target_id(cls): + return Column('target_id', ForeignKey('target.id')) + + @declared_attr + def target(cls): + return relationship("Target") + + class Foo(RefTargetMixin, Base): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + + class Bar(RefTargetMixin, Base): + __tablename__ = 'bar' + id = Column(Integer, primary_key=True) + + class Target(Base): + __tablename__ = 'target' + id = Column(Integer, primary_key=True) + +Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +:func:`~sqlalchemy.orm.relationship` definitions which require explicit +primaryjoin, order_by etc. expressions should in all but the most +simplistic cases use **late bound** forms +for these arguments, meaning, using either the string form or a lambda. +The reason for this is that the related :class:`.Column` objects which are to +be configured using ``@declared_attr`` are not available to another +``@declared_attr`` attribute; while the methods will work and return new +:class:`.Column` objects, those are not the :class:`.Column` objects that +Declarative will be using as it calls the methods on its own, thus using +*different* :class:`.Column` objects. + +The canonical example is the primaryjoin condition that depends upon +another mixed-in column:: + + class RefTargetMixin(object): + @declared_attr + def target_id(cls): + return Column('target_id', ForeignKey('target.id')) + + @declared_attr + def target(cls): + return relationship(Target, + primaryjoin=Target.id==cls.target_id # this is *incorrect* + ) + +Mapping a class using the above mixin, we will get an error like:: + + sqlalchemy.exc.InvalidRequestError: this ForeignKey's parent column is not + yet associated with a Table. + +This is because the ``target_id`` :class:`.Column` we've called upon in our ``target()`` +method is not the same :class:`.Column` that declarative is actually going to map +to our table. + +The condition above is resolved using a lambda:: + + class RefTargetMixin(object): + @declared_attr + def target_id(cls): + return Column('target_id', ForeignKey('target.id')) + + @declared_attr + def target(cls): + return relationship(Target, + primaryjoin=lambda: Target.id==cls.target_id + ) + +or alternatively, the string form (which ultmately generates a lambda):: + + class RefTargetMixin(object): + @declared_attr + def target_id(cls): + return Column('target_id', ForeignKey('target.id')) + + @declared_attr + def target(cls): + return relationship("Target", + primaryjoin="Target.id==%s.target_id" % cls.__name__ + ) + +Mixing in deferred(), column_property(), and other MapperProperty classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Like :func:`~sqlalchemy.orm.relationship`, all +:class:`~sqlalchemy.orm.interfaces.MapperProperty` subclasses such as +:func:`~sqlalchemy.orm.deferred`, :func:`~sqlalchemy.orm.column_property`, +etc. ultimately involve references to columns, and therefore, when +used with declarative mixins, have the :class:`.declared_attr` +requirement so that no reliance on copying is needed:: + + class SomethingMixin(object): + + @declared_attr + def dprop(cls): + return deferred(Column(Integer)) + + class Something(SomethingMixin, Base): + __tablename__ = "something" + +Mixing in Association Proxy and Other Attributes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Mixins can specify user-defined attributes as well as other extension +units such as :func:`.association_proxy`. The usage of +:class:`.declared_attr` is required in those cases where the attribute must +be tailored specifically to the target subclass. An example is when +constructing multiple :func:`.association_proxy` attributes which each +target a different type of child object. Below is an +:func:`.association_proxy` / mixin example which provides a scalar list of +string values to an implementing class:: + + from sqlalchemy import Column, Integer, ForeignKey, String + from sqlalchemy.orm import relationship + from sqlalchemy.ext.associationproxy import association_proxy + from sqlalchemy.ext.declarative import declarative_base, declared_attr + + Base = declarative_base() + + class HasStringCollection(object): + @declared_attr + def _strings(cls): + class StringAttribute(Base): + __tablename__ = cls.string_table_name + id = Column(Integer, primary_key=True) + value = Column(String(50), nullable=False) + parent_id = Column(Integer, + ForeignKey('%s.id' % cls.__tablename__), + nullable=False) + def __init__(self, value): + self.value = value + + return relationship(StringAttribute) + + @declared_attr + def strings(cls): + return association_proxy('_strings', 'value') + + class TypeA(HasStringCollection, Base): + __tablename__ = 'type_a' + string_table_name = 'type_a_strings' + id = Column(Integer(), primary_key=True) + + class TypeB(HasStringCollection, Base): + __tablename__ = 'type_b' + string_table_name = 'type_b_strings' + id = Column(Integer(), primary_key=True) + +Above, the ``HasStringCollection`` mixin produces a :func:`.relationship` +which refers to a newly generated class called ``StringAttribute``. The +``StringAttribute`` class is generated with it's own :class:`.Table` +definition which is local to the parent class making usage of the +``HasStringCollection`` mixin. It also produces an :func:`.association_proxy` +object which proxies references to the ``strings`` attribute onto the ``value`` +attribute of each ``StringAttribute`` instance. + +``TypeA`` or ``TypeB`` can be instantiated given the constructor +argument ``strings``, a list of strings:: + + ta = TypeA(strings=['foo', 'bar']) + tb = TypeA(strings=['bat', 'bar']) + +This list will generate a collection +of ``StringAttribute`` objects, which are persisted into a table that's +local to either the ``type_a_strings`` or ``type_b_strings`` table:: + + >>> print ta._strings + [<__main__.StringAttribute object at 0x10151cd90>, + <__main__.StringAttribute object at 0x10151ce10>] + +When constructing the :func:`.association_proxy`, the +:class:`.declared_attr` decorator must be used so that a distinct +:func:`.association_proxy` object is created for each of the ``TypeA`` +and ``TypeB`` classes. + +.. versionadded:: 0.8 :class:`.declared_attr` is usable with non-mapped + attributes, including user-defined attributes as well as + :func:`.association_proxy`. + + +Controlling table inheritance with mixins +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``__tablename__`` attribute in conjunction with the hierarchy of +classes involved in a declarative mixin scenario controls what type of +table inheritance, if any, +is configured by the declarative extension. + +If the ``__tablename__`` is computed by a mixin, you may need to +control which classes get the computed attribute in order to get the +type of table inheritance you require. + +For example, if you had a mixin that computes ``__tablename__`` but +where you wanted to use that mixin in a single table inheritance +hierarchy, you can explicitly specify ``__tablename__`` as ``None`` to +indicate that the class should not have a table mapped:: + + from sqlalchemy.ext.declarative import declared_attr + + class Tablename: + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + class Person(Tablename, Base): + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = None + __mapper_args__ = {'polymorphic_identity': 'engineer'} + primary_language = Column(String(50)) + +Alternatively, you can make the mixin intelligent enough to only +return a ``__tablename__`` in the event that no table is already +mapped in the inheritance hierarchy. To help with this, a +:func:`~sqlalchemy.ext.declarative.has_inherited_table` helper +function is provided that returns ``True`` if a parent class already +has a mapped table. + +As an example, here's a mixin that will only allow single table +inheritance:: + + from sqlalchemy.ext.declarative import declared_attr + from sqlalchemy.ext.declarative import has_inherited_table + + class Tablename(object): + @declared_attr + def __tablename__(cls): + if has_inherited_table(cls): + return None + return cls.__name__.lower() + + class Person(Tablename, Base): + id = Column(Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + primary_language = Column(String(50)) + __mapper_args__ = {'polymorphic_identity': 'engineer'} + + +Combining Table/Mapper Arguments from Multiple Mixins +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the case of ``__table_args__`` or ``__mapper_args__`` +specified with declarative mixins, you may want to combine +some parameters from several mixins with those you wish to +define on the class iteself. The +:class:`.declared_attr` decorator can be used +here to create user-defined collation routines that pull +from multiple collections:: + + from sqlalchemy.ext.declarative import declared_attr + + class MySQLSettings(object): + __table_args__ = {'mysql_engine':'InnoDB'} + + class MyOtherMixin(object): + __table_args__ = {'info':'foo'} + + class MyModel(MySQLSettings, MyOtherMixin, Base): + __tablename__='my_model' + + @declared_attr + def __table_args__(cls): + args = dict() + args.update(MySQLSettings.__table_args__) + args.update(MyOtherMixin.__table_args__) + return args + + id = Column(Integer, primary_key=True) + +Creating Indexes with Mixins +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To define a named, potentially multicolumn :class:`.Index` that applies to all +tables derived from a mixin, use the "inline" form of :class:`.Index` and +establish it as part of ``__table_args__``:: + + class MyMixin(object): + a = Column(Integer) + b = Column(Integer) + + @declared_attr + def __table_args__(cls): + return (Index('test_idx_%s' % cls.__tablename__, 'a', 'b'),) + + class MyModel(MyMixin, Base): + __tablename__ = 'atable' + c = Column(Integer,primary_key=True) + +Special Directives +================== + +``__declare_last__()`` +~~~~~~~~~~~~~~~~~~~~~~ + +The ``__declare_last__()`` hook allows definition of +a class level function that is automatically called by the +:meth:`.MapperEvents.after_configured` event, which occurs after mappings are +assumed to be completed and the 'configure' step has finished:: + + class MyClass(Base): + @classmethod + def __declare_last__(cls): + "" + # do something with mappings + +.. versionadded:: 0.7.3 + +``__declare_first__()`` +~~~~~~~~~~~~~~~~~~~~~~~ + +Like ``__declare_last__()``, but is called at the beginning of mapper configuration +via the :meth:`.MapperEvents.before_configured` event:: + + class MyClass(Base): + @classmethod + def __declare_first__(cls): + "" + # do something before mappings are configured + +.. versionadded:: 0.9.3 + +.. _declarative_abstract: + +``__abstract__`` +~~~~~~~~~~~~~~~~~~~ + +``__abstract__`` causes declarative to skip the production +of a table or mapper for the class entirely. A class can be added within a +hierarchy in the same way as mixin (see :ref:`declarative_mixins`), allowing +subclasses to extend just from the special class:: + + class SomeAbstractBase(Base): + __abstract__ = True + + def some_helpful_method(self): + "" + + @declared_attr + def __mapper_args__(cls): + return {"helpful mapper arguments":True} + + class MyMappedClass(SomeAbstractBase): + "" + +One possible use of ``__abstract__`` is to use a distinct +:class:`.MetaData` for different bases:: + + Base = declarative_base() + + class DefaultBase(Base): + __abstract__ = True + metadata = MetaData() + + class OtherBase(Base): + __abstract__ = True + metadata = MetaData() + +Above, classes which inherit from ``DefaultBase`` will use one +:class:`.MetaData` as the registry of tables, and those which inherit from +``OtherBase`` will use a different one. The tables themselves can then be +created perhaps within distinct databases:: + + DefaultBase.metadata.create_all(some_engine) + OtherBase.metadata_create_all(some_other_engine) + +.. versionadded:: 0.7.3 + +Class Constructor +================= + +As a convenience feature, the :func:`declarative_base` sets a default +constructor on classes which takes keyword arguments, and assigns them +to the named attributes:: + + e = Engineer(primary_language='python') + +Sessions +======== + +Note that ``declarative`` does nothing special with sessions, and is +only intended as an easier way to configure mappers and +:class:`~sqlalchemy.schema.Table` objects. A typical application +setup using :class:`~sqlalchemy.orm.scoping.scoped_session` might look like:: + + engine = create_engine('postgresql://scott:tiger@localhost/test') + Session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) + Base = declarative_base() + +Mapped instances then make usage of +:class:`~sqlalchemy.orm.session.Session` in the usual way. + +""" + +from .api import declarative_base, synonym_for, comparable_using, \ + instrument_declarative, ConcreteBase, AbstractConcreteBase, \ + DeclarativeMeta, DeferredReflection, has_inherited_table,\ + declared_attr, as_declarative + + +__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table', + 'comparable_using', 'instrument_declarative', 'declared_attr', + 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta', + 'DeferredReflection'] diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py new file mode 100644 index 00000000..941f02b0 --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -0,0 +1,513 @@ +# ext/declarative/api.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Public API functions and helpers for declarative.""" + + +from ...schema import Table, MetaData +from ...orm import synonym as _orm_synonym, mapper,\ + comparable_property,\ + interfaces, properties +from ...orm.util import polymorphic_union +from ...orm.base import _mapper_or_none +from ...util import OrderedDict +from ... import exc +import weakref + +from .base import _as_declarative, \ + _declarative_constructor,\ + _DeferredMapperConfig, _add_attribute +from .clsregistry import _class_resolver +from . import clsregistry + +def instrument_declarative(cls, registry, metadata): + """Given a class, configure the class declaratively, + using the given registry, which can be any dictionary, and + MetaData object. + + """ + if '_decl_class_registry' in cls.__dict__: + raise exc.InvalidRequestError( + "Class %r already has been " + "instrumented declaratively" % cls) + cls._decl_class_registry = registry + cls.metadata = metadata + _as_declarative(cls, cls.__name__, cls.__dict__) + + +def has_inherited_table(cls): + """Given a class, return True if any of the classes it inherits from has a + mapped table, otherwise return False. + """ + for class_ in cls.__mro__[1:]: + if getattr(class_, '__table__', None) is not None: + return True + return False + + +class DeclarativeMeta(type): + def __init__(cls, classname, bases, dict_): + if '_decl_class_registry' not in cls.__dict__: + _as_declarative(cls, classname, cls.__dict__) + type.__init__(cls, classname, bases, dict_) + + def __setattr__(cls, key, value): + _add_attribute(cls, key, value) + + +def synonym_for(name, map_column=False): + """Decorator, make a Python @property a query synonym for a column. + + A decorator version of :func:`~sqlalchemy.orm.synonym`. The function being + decorated is the 'descriptor', otherwise passes its arguments through to + synonym():: + + @synonym_for('col') + @property + def prop(self): + return 'special sauce' + + The regular ``synonym()`` is also usable directly in a declarative setting + and may be convenient for read/write properties:: + + prop = synonym('col', descriptor=property(_read_prop, _write_prop)) + + """ + def decorate(fn): + return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate + + +def comparable_using(comparator_factory): + """Decorator, allow a Python @property to be used in query criteria. + + This is a decorator front end to + :func:`~sqlalchemy.orm.comparable_property` that passes + through the comparator_factory and the function being decorated:: + + @comparable_using(MyComparatorType) + @property + def prop(self): + return 'special sauce' + + The regular ``comparable_property()`` is also usable directly in a + declarative setting and may be convenient for read/write properties:: + + prop = comparable_property(MyComparatorType) + + """ + def decorate(fn): + return comparable_property(comparator_factory, fn) + return decorate + + +class declared_attr(interfaces._MappedAttribute, property): + """Mark a class-level method as representing the definition of + a mapped property or special declarative member name. + + @declared_attr turns the attribute into a scalar-like + property that can be invoked from the uninstantiated class. + Declarative treats attributes specifically marked with + @declared_attr as returning a construct that is specific + to mapping or declarative table configuration. The name + of the attribute is that of what the non-dynamic version + of the attribute would be. + + @declared_attr is more often than not applicable to mixins, + to define relationships that are to be applied to different + implementors of the class:: + + class ProvidesUser(object): + "A mixin that adds a 'user' relationship to classes." + + @declared_attr + def user(self): + return relationship("User") + + It also can be applied to mapped classes, such as to provide + a "polymorphic" scheme for inheritance:: + + class Employee(Base): + id = Column(Integer, primary_key=True) + type = Column(String(50), nullable=False) + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + @declared_attr + def __mapper_args__(cls): + if cls.__name__ == 'Employee': + return { + "polymorphic_on":cls.type, + "polymorphic_identity":"Employee" + } + else: + return {"polymorphic_identity":cls.__name__} + + .. versionchanged:: 0.8 :class:`.declared_attr` can be used with + non-ORM or extension attributes, such as user-defined attributes + or :func:`.association_proxy` objects, which will be assigned + to the class at class construction time. + + + """ + + def __init__(self, fget, *arg, **kw): + super(declared_attr, self).__init__(fget, *arg, **kw) + self.__doc__ = fget.__doc__ + + def __get__(desc, self, cls): + return desc.fget(cls) + + +def declarative_base(bind=None, metadata=None, mapper=None, cls=object, + name='Base', constructor=_declarative_constructor, + class_registry=None, + metaclass=DeclarativeMeta): + """Construct a base class for declarative class definitions. + + The new base class will be given a metaclass that produces + appropriate :class:`~sqlalchemy.schema.Table` objects and makes + the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the + information provided declaratively in the class and any subclasses + of the class. + + :param bind: An optional + :class:`~sqlalchemy.engine.Connectable`, will be assigned + the ``bind`` attribute on the :class:`~sqlalchemy.schema.MetaData` + instance. + + :param metadata: + An optional :class:`~sqlalchemy.schema.MetaData` instance. All + :class:`~sqlalchemy.schema.Table` objects implicitly declared by + subclasses of the base will share this MetaData. A MetaData instance + will be created if none is provided. The + :class:`~sqlalchemy.schema.MetaData` instance will be available via the + `metadata` attribute of the generated declarative base class. + + :param mapper: + An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will + be used to map subclasses to their Tables. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the generated + declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param constructor: + Defaults to + :func:`~sqlalchemy.ext.declarative._declarative_constructor`, an + __init__ implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param class_registry: optional dictionary that will serve as the + registry of class names-> mapped classes when string names + are used to identify classes inside of :func:`.relationship` + and others. Allows two or more declarative base classes + to share the same registry of class names for simplified + inter-base relationships. + + :param metaclass: + Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + .. seealso:: + + :func:`.as_declarative` + + """ + lcl_metadata = metadata or MetaData() + if bind: + lcl_metadata.bind = bind + + if class_registry is None: + class_registry = weakref.WeakValueDictionary() + + bases = not isinstance(cls, tuple) and (cls,) or cls + class_dict = dict(_decl_class_registry=class_registry, + metadata=lcl_metadata) + + if constructor: + class_dict['__init__'] = constructor + if mapper: + class_dict['__mapper_cls__'] = mapper + + return metaclass(name, bases, class_dict) + +def as_declarative(**kw): + """ + Class decorator for :func:`.declarative_base`. + + Provides a syntactical shortcut to the ``cls`` argument + sent to :func:`.declarative_base`, allowing the base class + to be converted in-place to a "declarative" base:: + + from sqlalchemy.ext.declarative import as_declarative + + @as_declarative() + class Base(object): + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + id = Column(Integer, primary_key=True) + + class MyMappedClass(Base): + # ... + + All keyword arguments passed to :func:`.as_declarative` are passed + along to :func:`.declarative_base`. + + .. versionadded:: 0.8.3 + + .. seealso:: + + :func:`.declarative_base` + + """ + def decorate(cls): + kw['cls'] = cls + kw['name'] = cls.__name__ + return declarative_base(**kw) + + return decorate + +class ConcreteBase(object): + """A helper class for 'concrete' declarative mappings. + + :class:`.ConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.ConcreteBase` produces a mapped + table for the class itself. Compare to :class:`.AbstractConcreteBase`, + which does not. + + Example:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + """ + + @classmethod + def _create_polymorphic_union(cls, mappers): + return polymorphic_union(OrderedDict( + (mp.polymorphic_identity, mp.local_table) + for mp in mappers + ), 'type', 'pjoin') + + @classmethod + def __declare_first__(cls): + m = cls.__mapper__ + if m.with_polymorphic: + return + + mappers = list(m.self_and_descendants) + pjoin = cls._create_polymorphic_union(mappers) + m._set_with_polymorphic(("*", pjoin)) + m._set_polymorphic_on(pjoin.c.type) + + +class AbstractConcreteBase(ConcreteBase): + """A helper class for 'concrete' declarative mappings. + + :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.AbstractConcreteBase` does not produce a mapped + table for the class itself. Compare to :class:`.ConcreteBase`, + which does. + + Example:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Employee(AbstractConcreteBase, Base): + pass + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + """ + + __abstract__ = True + + @classmethod + def __declare_first__(cls): + if hasattr(cls, '__mapper__'): + return + + clsregistry.add_class(cls.__name__, cls) + # can't rely on 'self_and_descendants' here + # since technically an immediate subclass + # might not be mapped, but a subclass + # may be. + mappers = [] + stack = list(cls.__subclasses__()) + while stack: + klass = stack.pop() + stack.extend(klass.__subclasses__()) + mn = _mapper_or_none(klass) + if mn is not None: + mappers.append(mn) + pjoin = cls._create_polymorphic_union(mappers) + cls.__mapper__ = m = mapper(cls, pjoin, polymorphic_on=pjoin.c.type) + + for scls in cls.__subclasses__(): + sm = _mapper_or_none(scls) + if sm.concrete and cls in scls.__bases__: + sm._set_concrete_base(m) + + +class DeferredReflection(object): + """A helper class for construction of mappings based on + a deferred reflection step. + + Normally, declarative can be used with reflection by + setting a :class:`.Table` object using autoload=True + as the ``__table__`` attribute on a declarative class. + The caveat is that the :class:`.Table` must be fully + reflected, or at the very least have a primary key column, + at the point at which a normal declarative mapping is + constructed, meaning the :class:`.Engine` must be available + at class declaration time. + + The :class:`.DeferredReflection` mixin moves the construction + of mappers to be at a later point, after a specific + method is called which first reflects all :class:`.Table` + objects created so far. Classes can define it as such:: + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + + class MyClass(DeferredReflection, Base): + __tablename__ = 'mytable' + + Above, ``MyClass`` is not yet mapped. After a series of + classes have been defined in the above fashion, all tables + can be reflected and mappings created using + :meth:`.prepare`:: + + engine = create_engine("someengine://...") + DeferredReflection.prepare(engine) + + The :class:`.DeferredReflection` mixin can be applied to individual + classes, used as the base for the declarative base itself, + or used in a custom abstract class. Using an abstract base + allows that only a subset of classes to be prepared for a + particular prepare step, which is necessary for applications + that use more than one engine. For example, if an application + has two engines, you might use two bases, and prepare each + separately, e.g.:: + + class ReflectedOne(DeferredReflection, Base): + __abstract__ = True + + class ReflectedTwo(DeferredReflection, Base): + __abstract__ = True + + class MyClass(ReflectedOne): + __tablename__ = 'mytable' + + class MyOtherClass(ReflectedOne): + __tablename__ = 'myothertable' + + class YetAnotherClass(ReflectedTwo): + __tablename__ = 'yetanothertable' + + # ... etc. + + Above, the class hierarchies for ``ReflectedOne`` and + ``ReflectedTwo`` can be configured separately:: + + ReflectedOne.prepare(engine_one) + ReflectedTwo.prepare(engine_two) + + .. versionadded:: 0.8 + + """ + @classmethod + def prepare(cls, engine): + """Reflect all :class:`.Table` objects for all current + :class:`.DeferredReflection` subclasses""" + + to_map = _DeferredMapperConfig.classes_for_base(cls) + for thingy in to_map: + cls._sa_decl_prepare(thingy.local_table, engine) + thingy.map() + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + for rel in mapper._props.values(): + if isinstance(rel, properties.RelationshipProperty) and \ + rel.secondary is not None: + if isinstance(rel.secondary, Table): + cls._reflect_table(rel.secondary, engine) + elif isinstance(rel.secondary, _class_resolver): + rel.secondary._resolvers += ( + cls._sa_deferred_table_resolver(engine, metadata), + ) + + @classmethod + def _sa_deferred_table_resolver(cls, engine, metadata): + def _resolve(key): + t1 = Table(key, metadata) + cls._reflect_table(t1, engine) + return t1 + return _resolve + + @classmethod + def _sa_decl_prepare(cls, local_table, engine): + # autoload Table, which is already + # present in the metadata. This + # will fill in db-loaded columns + # into the existing Table object. + if local_table is not None: + cls._reflect_table(local_table, engine) + + @classmethod + def _reflect_table(cls, table, engine): + Table(table.name, + table.metadata, + extend_existing=True, + autoload_replace=False, + autoload=True, + autoload_with=engine, + schema=table.schema) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py new file mode 100644 index 00000000..eb66f12b --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -0,0 +1,532 @@ +# ext/declarative/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Internal implementation for declarative.""" + +from ...schema import Table, Column +from ...orm import mapper, class_mapper, synonym +from ...orm.interfaces import MapperProperty +from ...orm.properties import ColumnProperty, CompositeProperty +from ...orm.attributes import QueryableAttribute +from ...orm.base import _is_mapped_class +from ... import util, exc +from ...util import topological +from ...sql import expression +from ... import event +from . import clsregistry +import collections +import weakref + +def _declared_mapping_info(cls): + # deferred mapping + if _DeferredMapperConfig.has_cls(cls): + return _DeferredMapperConfig.config_for_cls(cls) + # regular mapping + elif _is_mapped_class(cls): + return class_mapper(cls, configure=False) + else: + return None + + +def _as_declarative(cls, classname, dict_): + from .api import declared_attr + + # dict_ will be a dictproxy, which we can't write to, and we need to! + dict_ = dict(dict_) + + column_copies = {} + potential_columns = {} + + mapper_args_fn = None + table_args = inherited_table_args = None + tablename = None + + declarative_props = (declared_attr, util.classproperty) + + for base in cls.__mro__: + _is_declarative_inherits = hasattr(base, '_decl_class_registry') + + if '__declare_last__' in base.__dict__: + @event.listens_for(mapper, "after_configured") + def go(): + cls.__declare_last__() + if '__declare_first__' in base.__dict__: + @event.listens_for(mapper, "before_configured") + def go(): + cls.__declare_first__() + if '__abstract__' in base.__dict__: + if (base is cls or + (base in cls.__bases__ and not _is_declarative_inherits) + ): + return + + class_mapped = _declared_mapping_info(base) is not None + + for name, obj in vars(base).items(): + if name == '__mapper_args__': + if not mapper_args_fn and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + mapper_args_fn = lambda: cls.__mapper_args__ + elif name == '__tablename__': + if not tablename and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + tablename = cls.__tablename__ + elif name == '__table_args__': + if not table_args and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + table_args = cls.__table_args__ + if not isinstance(table_args, (tuple, dict, type(None))): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None") + if base is not cls: + inherited_table_args = True + elif class_mapped: + if isinstance(obj, declarative_props): + util.warn("Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls)) + continue + elif base is not cls: + # we're a mixin. + if isinstance(obj, Column): + if getattr(cls, name) is not obj: + # if column has been overridden + # (like by the InstrumentedAttribute of the + # superclass), skip + continue + if obj.foreign_keys: + raise exc.InvalidRequestError( + "Columns with foreign keys to other columns " + "must be declared as @declared_attr callables " + "on declarative mixin classes. ") + if name not in dict_ and not ( + '__table__' in dict_ and + (obj.name or name) in dict_['__table__'].c + ) and name not in potential_columns: + potential_columns[name] = \ + column_copies[obj] = \ + obj.copy() + column_copies[obj]._creation_order = \ + obj._creation_order + elif isinstance(obj, MapperProperty): + raise exc.InvalidRequestError( + "Mapper properties (i.e. deferred," + "column_property(), relationship(), etc.) must " + "be declared as @declared_attr callables " + "on declarative mixin classes.") + elif isinstance(obj, declarative_props): + dict_[name] = ret = \ + column_copies[obj] = getattr(cls, name) + if isinstance(ret, (Column, MapperProperty)) and \ + ret.doc is None: + ret.doc = obj.__doc__ + + # apply inherited columns as we should + for k, v in potential_columns.items(): + dict_[k] = v + + if inherited_table_args and not tablename: + table_args = None + + clsregistry.add_class(classname, cls) + our_stuff = util.OrderedDict() + + for k in list(dict_): + + # TODO: improve this ? all dunders ? + if k in ('__table__', '__tablename__', '__mapper_args__'): + continue + + value = dict_[k] + if isinstance(value, declarative_props): + value = getattr(cls, k) + + elif isinstance(value, QueryableAttribute) and \ + value.class_ is not cls and \ + value.key != k: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = synonym(value.key) + setattr(cls, k, value) + + + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + if not isinstance(value, (Column, MapperProperty)): + if not k.startswith('__'): + dict_.pop(k) + setattr(cls, k, value) + continue + if k == 'metadata': + raise exc.InvalidRequestError( + "Attribute name 'metadata' is reserved " + "for the MetaData instance when using a " + "declarative base class." + ) + prop = clsregistry._deferred_relationship(cls, value) + our_stuff[k] = prop + + # set up attributes in the order they were created + our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) + + # extract columns from the class dict + declared_columns = set() + name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): + if isinstance(c, (ColumnProperty, CompositeProperty)): + for col in c.columns: + if isinstance(col, Column) and \ + col.table is None: + _undefer_column_name(key, col) + if not isinstance(c, CompositeProperty): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + name_to_prop_key[c.name].add(key) + declared_columns.add(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + for name, keys in name_to_prop_key.items(): + if len(keys) > 1: + util.warn( + "On class %r, Column object %r named directly multiple times, " + "only one will be used: %s" % + (classname, name, (", ".join(sorted(keys)))) + ) + + declared_columns = sorted( + declared_columns, key=lambda c: c._creation_order) + table = None + + if hasattr(cls, '__table_cls__'): + table_cls = util.unbound_method_to_callable(cls.__table_cls__) + else: + table_cls = Table + + if '__table__' not in dict_: + if tablename is not None: + + args, table_kw = (), {} + if table_args: + if isinstance(table_args, dict): + table_kw = table_args + elif isinstance(table_args, tuple): + if isinstance(table_args[-1], dict): + args, table_kw = table_args[0:-1], table_args[-1] + else: + args = table_args + + autoload = dict_.get('__autoload__') + if autoload: + table_kw['autoload'] = True + + cls.__table__ = table = table_cls( + tablename, cls.metadata, + *(tuple(declared_columns) + tuple(args)), + **table_kw) + else: + table = cls.__table__ + if declared_columns: + for c in declared_columns: + if not table.c.contains_column(c): + raise exc.ArgumentError( + "Can't add additional column %r when " + "specifying __table__" % c.key + ) + + if hasattr(cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) + else: + mapper_cls = mapper + + for c in cls.__bases__: + if _declared_mapping_info(c) is not None: + inherits = c + break + else: + inherits = None + + if table is None and inherits is None: + raise exc.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing " + "table-mapped class." % cls + ) + elif inherits: + inherited_mapper = _declared_mapping_info(inherits) + inherited_table = inherited_mapper.local_table + inherited_mapped_table = inherited_mapper.mapped_table + + if table is None: + # single table inheritance. + # ensure no table args + if table_args: + raise exc.ArgumentError( + "Can't place __table_args__ on an inherited class " + "with no table." + ) + # add any columns declared here to the inherited table. + for c in declared_columns: + if c.primary_key: + raise exc.ArgumentError( + "Can't place primary key columns on an inherited " + "class with no table." + ) + if c.name in inherited_table.c: + if inherited_table.c[c.name] is c: + continue + raise exc.ArgumentError( + "Column '%s' on class %s conflicts with " + "existing column '%s'" % + (c, cls, inherited_table.c[c.name]) + ) + inherited_table.append_column(c) + if inherited_mapped_table is not None and \ + inherited_mapped_table is not inherited_table: + inherited_mapped_table._refresh_for_new_column(c) + + defer_map = hasattr(cls, '_sa_decl_prepare') + if defer_map: + cfg_cls = _DeferredMapperConfig + else: + cfg_cls = _MapperConfig + mt = cfg_cls(mapper_cls, + cls, table, + inherits, + declared_columns, + column_copies, + our_stuff, + mapper_args_fn) + if not defer_map: + mt.map() + + +class _MapperConfig(object): + + mapped_table = None + + def __init__(self, mapper_cls, + cls, + table, + inherits, + declared_columns, + column_copies, + properties, mapper_args_fn): + self.mapper_cls = mapper_cls + self.cls = cls + self.local_table = table + self.inherits = inherits + self.properties = properties + self.mapper_args_fn = mapper_args_fn + self.declared_columns = declared_columns + self.column_copies = column_copies + + + def _prepare_mapper_arguments(self): + properties = self.properties + if self.mapper_args_fn: + mapper_args = self.mapper_args_fn() + else: + mapper_args = {} + + # make sure that column copies are used rather + # than the original columns from any mixins + for k in ('version_id_col', 'polymorphic_on',): + if k in mapper_args: + v = mapper_args[k] + mapper_args[k] = self.column_copies.get(v, v) + + assert 'inherits' not in mapper_args, \ + "Can't specify 'inherits' explicitly with declarative mappings" + + if self.inherits: + mapper_args['inherits'] = self.inherits + + if self.inherits and not mapper_args.get('concrete', False): + # single or joined inheritance + # exclude any cols on the inherited table which are + # not mapped on the parent class, to avoid + # mapping columns specific to sibling/nephew classes + inherited_mapper = _declared_mapping_info(self.inherits) + inherited_table = inherited_mapper.local_table + + if 'exclude_properties' not in mapper_args: + mapper_args['exclude_properties'] = exclude_properties = \ + set([c.key for c in inherited_table.c + if c not in inherited_mapper._columntoproperty]) + exclude_properties.difference_update( + [c.key for c in self.declared_columns]) + + # look through columns in the current mapper that + # are keyed to a propname different than the colname + # (if names were the same, we'd have popped it out above, + # in which case the mapper makes this combination). + # See if the superclass has a similar column property. + # If so, join them together. + for k, col in list(properties.items()): + if not isinstance(col, expression.ColumnElement): + continue + if k in inherited_mapper._props: + p = inherited_mapper._props[k] + if isinstance(p, ColumnProperty): + # note here we place the subclass column + # first. See [ticket:1892] for background. + properties[k] = [col] + p.columns + result_mapper_args = mapper_args.copy() + result_mapper_args['properties'] = properties + return result_mapper_args + + def map(self): + mapper_args = self._prepare_mapper_arguments() + self.cls.__mapper__ = self.mapper_cls( + self.cls, + self.local_table, + **mapper_args + ) + +class _DeferredMapperConfig(_MapperConfig): + _configs = util.OrderedDict() + + @property + def cls(self): + return self._cls() + + @cls.setter + def cls(self, class_): + self._cls = weakref.ref(class_, self._remove_config_cls) + self._configs[self._cls] = self + + @classmethod + def _remove_config_cls(cls, ref): + cls._configs.pop(ref, None) + + @classmethod + def has_cls(cls, class_): + # 2.6 fails on weakref if class_ is an old style class + return isinstance(class_, type) and \ + weakref.ref(class_) in cls._configs + + @classmethod + def config_for_cls(cls, class_): + return cls._configs[weakref.ref(class_)] + + + @classmethod + def classes_for_base(cls, base_cls, sort=True): + classes_for_base = [m for m in cls._configs.values() + if issubclass(m.cls, base_cls)] + if not sort: + return classes_for_base + + all_m_by_cls = dict( + (m.cls, m) + for m in classes_for_base + ) + + tuples = [] + for m_cls in all_m_by_cls: + tuples.extend( + (all_m_by_cls[base_cls], all_m_by_cls[m_cls]) + for base_cls in m_cls.__bases__ + if base_cls in all_m_by_cls + ) + return list( + topological.sort( + tuples, + classes_for_base + ) + ) + + def map(self): + self._configs.pop(self._cls, None) + super(_DeferredMapperConfig, self).map() + + +def _add_attribute(cls, key, value): + """add an attribute to an existing declarative class. + + This runs through the logic to determine MapperProperty, + adds it to the Mapper, adds a column to the mapped Table, etc. + + """ + + if '__mapper__' in cls.__dict__: + if isinstance(value, Column): + _undefer_column_name(key, value) + cls.__table__.append_column(value) + cls.__mapper__.add_property(key, value) + elif isinstance(value, ColumnProperty): + for col in value.columns: + if isinstance(col, Column) and col.table is None: + _undefer_column_name(key, col) + cls.__table__.append_column(col) + cls.__mapper__.add_property(key, value) + elif isinstance(value, MapperProperty): + cls.__mapper__.add_property( + key, + clsregistry._deferred_relationship(cls, value) + ) + elif isinstance(value, QueryableAttribute) and value.key != key: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = synonym(value.key) + cls.__mapper__.add_property( + key, + clsregistry._deferred_relationship(cls, value) + ) + else: + type.__setattr__(cls, key, value) + else: + type.__setattr__(cls, key, value) + + +def _declarative_constructor(self, **kwargs): + """A simple constructor that allows initialization from kwargs. + + Sets attributes on the constructed instance using the names and + values in ``kwargs``. + + Only keys that are present as + attributes of the instance's class are allowed. These could be, + for example, any mapped columns or relationships. + """ + cls_ = type(self) + for k in kwargs: + if not hasattr(cls_, k): + raise TypeError( + "%r is an invalid keyword argument for %s" % + (k, cls_.__name__)) + setattr(self, k, kwargs[k]) +_declarative_constructor.__name__ = '__init__' + + +def _undefer_column_name(key, column): + if column.key is None: + column.key = key + if column.name is None: + column.name = key diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py new file mode 100644 index 00000000..8b846746 --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -0,0 +1,309 @@ +# ext/declarative/clsregistry.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Routines to handle the string class registry used by declarative. + +This system allows specification of classes and expressions used in +:func:`.relationship` using strings. + +""" +from ...orm.properties import ColumnProperty, RelationshipProperty, \ + SynonymProperty +from ...schema import _get_table_key +from ...orm import class_mapper, interfaces +from ... import util +from ... import inspection +from ... import exc +import weakref + +# strong references to registries which we place in +# the _decl_class_registry, which is usually weak referencing. +# the internal registries here link to classes with weakrefs and remove +# themselves when all references to contained classes are removed. +_registries = set() + + +def add_class(classname, cls): + """Add a class to the _decl_class_registry associated with the + given declarative class. + + """ + if classname in cls._decl_class_registry: + # class already exists. + existing = cls._decl_class_registry[classname] + if not isinstance(existing, _MultipleClassMarker): + existing = \ + cls._decl_class_registry[classname] = \ + _MultipleClassMarker([cls, existing]) + else: + cls._decl_class_registry[classname] = cls + + try: + root_module = cls._decl_class_registry['_sa_module_registry'] + except KeyError: + cls._decl_class_registry['_sa_module_registry'] = \ + root_module = _ModuleMarker('_sa_module_registry', None) + + tokens = cls.__module__.split(".") + + # build up a tree like this: + # modulename: myapp.snacks.nuts + # + # myapp->snack->nuts->(classes) + # snack->nuts->(classes) + # nuts->(classes) + # + # this allows partial token paths to be used. + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + module.add_class(classname, cls) + + +class _MultipleClassMarker(object): + """refers to multiple classes of the same name + within _decl_class_registry. + + """ + + def __init__(self, classes, on_remove=None): + self.on_remove = on_remove + self.contents = set([ + weakref.ref(item, self._remove_item) for item in classes]) + _registries.add(self) + + def __iter__(self): + return (ref() for ref in self.contents) + + def attempt_get(self, path, key): + if len(self.contents) > 1: + raise exc.InvalidRequestError( + "Multiple classes found for path \"%s\" " + "in the registry of this declarative " + "base. Please use a fully module-qualified path." % + (".".join(path + [key])) + ) + else: + ref = list(self.contents)[0] + cls = ref() + if cls is None: + raise NameError(key) + return cls + + def _remove_item(self, ref): + self.contents.remove(ref) + if not self.contents: + _registries.discard(self) + if self.on_remove: + self.on_remove() + + def add_item(self, item): + modules = set([cls().__module__ for cls in self.contents]) + if item.__module__ in modules: + util.warn( + "This declarative base already contains a class with the " + "same class name and module name as %s.%s, and will " + "be replaced in the string-lookup table." % ( + item.__module__, + item.__name__ + ) + ) + self.contents.add(weakref.ref(item, self._remove_item)) + + +class _ModuleMarker(object): + """"refers to a module name within + _decl_class_registry. + + """ + def __init__(self, name, parent): + self.parent = parent + self.name = name + self.contents = {} + self.mod_ns = _ModNS(self) + if self.parent: + self.path = self.parent.path + [self.name] + else: + self.path = [] + _registries.add(self) + + def __contains__(self, name): + return name in self.contents + + def __getitem__(self, name): + return self.contents[name] + + def _remove_item(self, name): + self.contents.pop(name, None) + if not self.contents and self.parent is not None: + self.parent._remove_item(self.name) + _registries.discard(self) + + def resolve_attr(self, key): + return getattr(self.mod_ns, key) + + def get_module(self, name): + if name not in self.contents: + marker = _ModuleMarker(name, self) + self.contents[name] = marker + else: + marker = self.contents[name] + return marker + + def add_class(self, name, cls): + if name in self.contents: + existing = self.contents[name] + existing.add_item(cls) + else: + existing = self.contents[name] = \ + _MultipleClassMarker([cls], + on_remove=lambda: self._remove_item(name)) + + +class _ModNS(object): + def __init__(self, parent): + self.__parent = parent + + def __getattr__(self, key): + try: + value = self.__parent.contents[key] + except KeyError: + pass + else: + if value is not None: + if isinstance(value, _ModuleMarker): + return value.mod_ns + else: + assert isinstance(value, _MultipleClassMarker) + return value.attempt_get(self.__parent.path, key) + raise AttributeError("Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key)) + + +class _GetColumns(object): + def __init__(self, cls): + self.cls = cls + + def __getattr__(self, key): + mp = class_mapper(self.cls, configure=False) + if mp: + if key not in mp.all_orm_descriptors: + raise exc.InvalidRequestError( + "Class %r does not have a mapped column named %r" + % (self.cls, key)) + + desc = mp.all_orm_descriptors[key] + if desc.extension_type is interfaces.NOT_EXTENSION: + prop = desc.property + if isinstance(prop, SynonymProperty): + key = prop.name + elif not isinstance(prop, ColumnProperty): + raise exc.InvalidRequestError( + "Property %r is not an instance of" + " ColumnProperty (i.e. does not correspond" + " directly to a Column)." % key) + return getattr(self.cls, key) + +inspection._inspects(_GetColumns)( + lambda target: inspection.inspect(target.cls)) + + +class _GetTable(object): + def __init__(self, key, metadata): + self.key = key + self.metadata = metadata + + def __getattr__(self, key): + return self.metadata.tables[ + _get_table_key(key, self.key) + ] + + +def _determine_container(key, value): + if isinstance(value, _MultipleClassMarker): + value = value.attempt_get([], key) + return _GetColumns(value) + + +class _class_resolver(object): + def __init__(self, cls, prop, fallback, arg): + self.cls = cls + self.prop = prop + self.arg = self._declarative_arg = arg + self.fallback = fallback + self._dict = util.PopulateDict(self._access_cls) + self._resolvers = () + + def _access_cls(self, key): + cls = self.cls + if key in cls._decl_class_registry: + return _determine_container(key, cls._decl_class_registry[key]) + elif key in cls.metadata.tables: + return cls.metadata.tables[key] + elif key in cls.metadata._schemas: + return _GetTable(key, cls.metadata) + elif '_sa_module_registry' in cls._decl_class_registry and \ + key in cls._decl_class_registry['_sa_module_registry']: + registry = cls._decl_class_registry['_sa_module_registry'] + return registry.resolve_attr(key) + elif self._resolvers: + for resolv in self._resolvers: + value = resolv(key) + if value is not None: + return value + + return self.fallback[key] + + def __call__(self): + try: + x = eval(self.arg, globals(), self._dict) + + if isinstance(x, _GetColumns): + return x.cls + else: + return x + except NameError as n: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." % + (self.prop.parent, self.arg, n.args[0], self.cls) + ) + + +def _resolver(cls, prop): + import sqlalchemy + from sqlalchemy.orm import foreign, remote + + fallback = sqlalchemy.__dict__.copy() + fallback.update({'foreign': foreign, 'remote': remote}) + + def resolve_arg(arg): + return _class_resolver(cls, prop, fallback, arg) + return resolve_arg + + +def _deferred_relationship(cls, prop): + + if isinstance(prop, RelationshipProperty): + resolve_arg = _resolver(cls, prop) + + for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin', + 'secondary', '_user_defined_foreign_keys', 'remote_side'): + v = getattr(prop, attr) + if isinstance(v, util.string_types): + setattr(prop, attr, resolve_arg(v)) + + if prop.backref and isinstance(prop.backref, tuple): + key, kwargs = prop.backref + for attr in ('primaryjoin', 'secondaryjoin', 'secondary', + 'foreign_keys', 'remote_side', 'order_by'): + if attr in kwargs and isinstance(kwargs[attr], str): + kwargs[attr] = resolve_arg(kwargs[attr]) + + return prop diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 00000000..8b3f968d --- /dev/null +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,128 @@ +# ext/horizontal_shard.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. + +""" + +from .. import util +from ..orm.session import Session +from ..orm.query import Query + +__all__ = ['ShardedSession', 'ShardedQuery'] + + +class ShardedQuery(Query): + def __init__(self, *args, **kwargs): + super(ShardedQuery, self).__init__(*args, **kwargs) + self.id_chooser = self.session.id_chooser + self.query_chooser = self.session.query_chooser + self._shard_id = None + + def set_shard(self, shard_id): + """return a new query, limited to a single shard ID. + + all subsequent operations with the returned query will + be against the single shard regardless of other state. + """ + + q = self._clone() + q._shard_id = shard_id + return q + + def _execute_and_instances(self, context): + def iter_for_shard(shard_id): + context.attributes['shard_id'] = shard_id + result = self._connection_from_session( + mapper=self._mapper_zero(), + shard_id=shard_id).execute( + context.statement, + self._params) + return self.instances(result, context) + + if self._shard_id is not None: + return iter_for_shard(self._shard_id) + else: + partial = [] + for shard_id in self.query_chooser(self): + partial.extend(iter_for_shard(shard_id)) + + # if some kind of in memory 'sorting' + # were done, this is where it would happen + return iter(partial) + + def get(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).get(ident) + else: + ident = util.to_list(ident) + for shard_id in self.id_chooser(self, ident): + o = self.set_shard(shard_id).get(ident, **kwargs) + if o is not None: + return o + else: + return None + + +class ShardedSession(Session): + def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, + query_cls=ShardedQuery, **kwargs): + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. + + :param id_chooser: A callable, passed a query and a tuple of identity + values, which should return a list of shard ids where the ID might + reside. The databases will be queried in the order of this listing. + + :param query_chooser: For a given Query, returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. + + """ + super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + self.query_chooser = query_chooser + self.__binds = {} + self.connection_callable = self.connection + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + + if self.transaction is not None: + return self.transaction.connection(mapper, shard_id=shard_id) + else: + return self.get_bind(mapper, + shard_id=shard_id, + instance=instance).contextual_connect(**kwargs) + + def get_bind(self, mapper, shard_id=None, + instance=None, clause=None, **kw): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance, clause=clause) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py new file mode 100644 index 00000000..576e0bd4 --- /dev/null +++ b/lib/sqlalchemy/ext/hybrid.py @@ -0,0 +1,808 @@ +# ext/hybrid.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Define attributes on ORM-mapped classes that have "hybrid" behavior. + +"hybrid" means the attribute has distinct behaviors defined at the +class level and at the instance level. + +The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of +method decorator, is around 50 lines of code and has almost no +dependencies on the rest of SQLAlchemy. It can, in theory, work with +any descriptor-based expression system. + +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce +SQL expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +as the class itself:: + + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session, aliased + from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method + + Base = declarative_base() + + class Interval(Base): + __tablename__ = 'interval' + + id = Column(Integer, primary_key=True) + start = Column(Integer, nullable=False) + end = Column(Integer, nullable=False) + + def __init__(self, start, end): + self.start = start + self.end = end + + @hybrid_property + def length(self): + return self.end - self.start + + @hybrid_method + def contains(self,point): + return (self.start <= point) & (point < self.end) + + @hybrid_method + def intersects(self, other): + return self.contains(other.start) | self.contains(other.end) + +Above, the ``length`` property returns the difference between the +``end`` and ``start`` attributes. With an instance of ``Interval``, +this subtraction occurs in Python, using normal Python descriptor +mechanics:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression:: + + >>> print Interval.length + interval."end" - interval.start + + >>> print Session().query(Interval).filter(Interval.length > 10) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval."end" - interval.start > :param_1 + +ORM methods such as :meth:`~.Query.filter_by` generally use ``getattr()`` to +locate attributes, so can also be used with hybrid attributes:: + + >>> print Session().query(Interval).filter_by(length=5) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval."end" - interval.start = :param_1 + +The ``Interval`` class example also illustrates two methods, +``contains()`` and ``intersects()``, decorated with +:class:`.hybrid_method`. This decorator applies the same idea to +methods that :class:`.hybrid_property` applies to attributes. The +methods return boolean values, and take advantage of the Python ``|`` +and ``&`` bitwise operators to produce equivalent instance-level and +SQL expression-level boolean behavior:: + + >>> i1.contains(6) + True + >>> i1.contains(15) + False + >>> i1.intersects(Interval(7, 18)) + True + >>> i1.intersects(Interval(25, 29)) + False + + >>> print Session().query(Interval).filter(Interval.contains(15)) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval.start <= :start_1 AND interval."end" > :end_1 + + >>> ia = aliased(Interval) + >>> print Session().query(Interval, ia).filter(Interval.intersects(ia)) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end, interval_1.id AS interval_1_id, + interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end + FROM interval, interval AS interval_1 + WHERE interval.start <= interval_1.start + AND interval."end" > interval_1.start + OR interval.start <= interval_1."end" + AND interval."end" > interval_1."end" + +Defining Expression Behavior Distinct from Attribute Behavior +-------------------------------------------------------------- + +Our usage of the ``&`` and ``|`` bitwise operators above was +fortunate, considering our functions operated on two boolean values to +return a new one. In many cases, the construction of an in-Python +function and a SQLAlchemy SQL expression have enough differences that +two separate Python expressions should be defined. The +:mod:`~sqlalchemy.ext.hybrid` decorators define the +:meth:`.hybrid_property.expression` modifier for this purpose. As an +example we'll define the radius of the interval, which requires the +usage of the absolute value function:: + + from sqlalchemy import func + + class Interval(object): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + @radius.expression + def radius(cls): + return func.abs(cls.length) / 2 + +Above the Python function ``abs()`` is used for instance-level +operations, the SQL function ``ABS()`` is used via the :attr:`.func` +object for class-level expressions:: + + >>> i1.radius + 2 + + >>> print Session().query(Interval).filter(Interval.radius > 5) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 + +Defining Setters +---------------- + +Hybrid properties can also define setter methods. If we wanted +``length`` above, when set, to modify the endpoint value:: + + class Interval(object): + # ... + + @hybrid_property + def length(self): + return self.end - self.start + + @length.setter + def length(self, value): + self.end = self.start + value + +The ``length(self, value)`` method is now called upon set:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + >>> i1.length = 12 + >>> i1.end + 17 + +Working with Relationships +-------------------------- + +There's no essential difference when creating hybrids that work with +related objects as opposed to column-based data. The need for distinct +expressions tends to be greater. Two variants of we'll illustrate +are the "join-dependent" hybrid, and the "correlated subquery" hybrid. + +Join-Dependent Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Consider the following declarative +mapping which relates a ``User`` to a ``SavingsAccount``:: + + from sqlalchemy import Column, Integer, ForeignKey, Numeric, String + from sqlalchemy.orm import relationship + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.hybrid import hybrid_property + + Base = declarative_base() + + class SavingsAccount(Base): + __tablename__ = 'account' + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('user.id'), nullable=False) + balance = Column(Numeric(15, 5)) + + class User(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + + accounts = relationship("SavingsAccount", backref="owner") + + @hybrid_property + def balance(self): + if self.accounts: + return self.accounts[0].balance + else: + return None + + @balance.setter + def balance(self, value): + if not self.accounts: + account = Account(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @balance.expression + def balance(cls): + return SavingsAccount.balance + +The above hybrid property ``balance`` works with the first +``SavingsAccount`` entry in the list of accounts for this user. The +in-Python getter/setter methods can treat ``accounts`` as a Python +list available on ``self``. + +However, at the expression level, it's expected that the ``User`` class will +be used in an appropriate context such that an appropriate join to +``SavingsAccount`` will be present:: + + >>> print Session().query(User, User.balance).\\ + ... join(User.accounts).filter(User.balance > 5000) + SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" JOIN account ON "user".id = account.user_id + WHERE account.balance > :balance_1 + +Note however, that while the instance level accessors need to worry +about whether ``self.accounts`` is even present, this issue expresses +itself differently at the SQL expression level, where we basically +would use an outer join:: + + >>> from sqlalchemy import or_ + >>> print (Session().query(User, User.balance).outerjoin(User.accounts). + ... filter(or_(User.balance < 5000, User.balance == None))) + SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id + WHERE account.balance < :balance_1 OR account.balance IS NULL + +Correlated Subquery Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We can, of course, forego being dependent on the enclosing query's usage +of joins in favor of the correlated subquery, which can portably be packed +into a single column expression. A correlated subquery is more portable, but +often performs more poorly at the SQL level. Using the same technique +illustrated at :ref:`mapper_column_property_sql_expressions`, +we can adjust our ``SavingsAccount`` example to aggregate the balances for +*all* accounts, and use a correlated subquery for the column expression:: + + from sqlalchemy import Column, Integer, ForeignKey, Numeric, String + from sqlalchemy.orm import relationship + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy import select, func + + Base = declarative_base() + + class SavingsAccount(Base): + __tablename__ = 'account' + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('user.id'), nullable=False) + balance = Column(Numeric(15, 5)) + + class User(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + + accounts = relationship("SavingsAccount", backref="owner") + + @hybrid_property + def balance(self): + return sum(acc.balance for acc in self.accounts) + + @balance.expression + def balance(cls): + return select([func.sum(SavingsAccount.balance)]).\\ + where(SavingsAccount.user_id==cls.id).\\ + label('total_balance') + +The above recipe will give us the ``balance`` column which renders +a correlated SELECT:: + + >>> print s.query(User).filter(User.balance > 400) + SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE (SELECT sum(account.balance) AS sum_1 + FROM account + WHERE account.user_id = "user".id) > :param_1 + +.. _hybrid_custom_comparators: + +Building Custom Comparators +--------------------------- + +The hybrid property also includes a helper that allows construction of +custom comparators. A comparator object allows one to customize the +behavior of each SQLAlchemy expression operator individually. They +are useful when creating custom types that have some highly +idiosyncratic behavior on the SQL side. + +The example class below allows case-insensitive comparisons on the attribute +named ``word_insensitive``:: + + from sqlalchemy.ext.hybrid import Comparator, hybrid_property + from sqlalchemy import func, Column, Integer, String + from sqlalchemy.orm import Session + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class CaseInsensitiveComparator(Comparator): + def __eq__(self, other): + return func.lower(self.__clause_element__()) == func.lower(other) + + class SearchWord(Base): + __tablename__ = 'searchword' + id = Column(Integer, primary_key=True) + word = Column(String(255), nullable=False) + + @hybrid_property + def word_insensitive(self): + return self.word.lower() + + @word_insensitive.comparator + def word_insensitive(cls): + return CaseInsensitiveComparator(cls.word) + +Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()`` +SQL function to both sides:: + + >>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks") + SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + FROM searchword + WHERE lower(searchword.word) = lower(:lower_1) + +The ``CaseInsensitiveComparator`` above implements part of the +:class:`.ColumnOperators` interface. A "coercion" operation like +lowercasing can be applied to all comparison operations (i.e. ``eq``, +``lt``, ``gt``, etc.) using :meth:`.Operators.operate`:: + + class CaseInsensitiveComparator(Comparator): + def operate(self, op, other): + return op(func.lower(self.__clause_element__()), func.lower(other)) + +Hybrid Value Objects +-------------------- + +Note in our previous example, if we were to compare the +``word_insensitive`` attribute of a ``SearchWord`` instance to a plain +Python string, the plain Python string would not be coerced to lower +case - the ``CaseInsensitiveComparator`` we built, being returned by +``@word_insensitive.comparator``, only applies to the SQL side. + +A more comprehensive form of the custom comparator is to construct a +*Hybrid Value Object*. This technique applies the target value or +expression to a value object which is then returned by the accessor in +all cases. The value object allows control of all operations upon +the value as well as how compared values are treated, both on the SQL +expression side as well as the Python value side. Replacing the +previous ``CaseInsensitiveComparator`` class with a new +``CaseInsensitiveWord`` class:: + + class CaseInsensitiveWord(Comparator): + "Hybrid value representing a lower case representation of a word." + + def __init__(self, word): + if isinstance(word, basestring): + self.word = word.lower() + elif isinstance(word, CaseInsensitiveWord): + self.word = word.word + else: + self.word = func.lower(word) + + def operate(self, op, other): + if not isinstance(other, CaseInsensitiveWord): + other = CaseInsensitiveWord(other) + return op(self.word, other.word) + + def __clause_element__(self): + return self.word + + def __str__(self): + return self.word + + key = 'word' + "Label to apply to Query tuple results" + +Above, the ``CaseInsensitiveWord`` object represents ``self.word``, +which may be a SQL function, or may be a Python native. By +overriding ``operate()`` and ``__clause_element__()`` to work in terms +of ``self.word``, all comparison operations will work against the +"converted" form of ``word``, whether it be SQL side or Python side. +Our ``SearchWord`` class can now deliver the ``CaseInsensitiveWord`` +object unconditionally from a single hybrid call:: + + class SearchWord(Base): + __tablename__ = 'searchword' + id = Column(Integer, primary_key=True) + word = Column(String(255), nullable=False) + + @hybrid_property + def word_insensitive(self): + return CaseInsensitiveWord(self.word) + +The ``word_insensitive`` attribute now has case-insensitive comparison +behavior universally, including SQL expression vs. Python expression +(note the Python value is converted to lower case on the Python side +here):: + + >>> print Session().query(SearchWord).filter_by(word_insensitive="Trucks") + SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + FROM searchword + WHERE lower(searchword.word) = :lower_1 + +SQL expression versus SQL expression:: + + >>> sw1 = aliased(SearchWord) + >>> sw2 = aliased(SearchWord) + >>> print Session().query( + ... sw1.word_insensitive, + ... sw2.word_insensitive).\\ + ... filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) + SELECT lower(searchword_1.word) AS lower_1, + lower(searchword_2.word) AS lower_2 + FROM searchword AS searchword_1, searchword AS searchword_2 + WHERE lower(searchword_1.word) > lower(searchword_2.word) + +Python only expression:: + + >>> ws1 = SearchWord(word="SomeWord") + >>> ws1.word_insensitive == "sOmEwOrD" + True + >>> ws1.word_insensitive == "XOmEwOrX" + False + >>> print ws1.word_insensitive + someword + +The Hybrid Value pattern is very useful for any kind of value that may +have multiple representations, such as timestamps, time deltas, units +of measurement, currencies and encrypted passwords. + +.. seealso:: + + `Hybrids and Value Agnostic Types + `_ - + on the techspot.zzzeek.org blog + + `Value Agnostic Types, Part II + `_ - + on the techspot.zzzeek.org blog + +.. _hybrid_transformers: + +Building Transformers +---------------------- + +A *transformer* is an object which can receive a :class:`.Query` +object and return a new one. The :class:`.Query` object includes a +method :meth:`.with_transformation` that returns a new :class:`.Query` +transformed by the given function. + +We can combine this with the :class:`.Comparator` class to produce one type +of recipe which can both set up the FROM clause of a query as well as assign +filtering criterion. + +Consider a mapped class ``Node``, which assembles using adjacency list +into a hierarchical tree pattern:: + + from sqlalchemy import Column, Integer, ForeignKey + from sqlalchemy.orm import relationship + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + + class Node(Base): + __tablename__ = 'node' + id =Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('node.id')) + parent = relationship("Node", remote_side=id) + +Suppose we wanted to add an accessor ``grandparent``. This would +return the ``parent`` of ``Node.parent``. When we have an instance of +``Node``, this is simple:: + + from sqlalchemy.ext.hybrid import hybrid_property + + class Node(Base): + # ... + + @hybrid_property + def grandparent(self): + return self.parent.parent + +For the expression, things are not so clear. We'd need to construct +a :class:`.Query` where we :meth:`~.Query.join` twice along +``Node.parent`` to get to the ``grandparent``. We can instead return +a transforming callable that we'll combine with the +:class:`.Comparator` class to receive any :class:`.Query` object, and +return a new one that's joined to the ``Node.parent`` attribute and +filtered based on the given criterion:: + + from sqlalchemy.ext.hybrid import Comparator + + class GrandparentTransformer(Comparator): + def operate(self, op, other): + def transform(q): + cls = self.__clause_element__() + parent_alias = aliased(cls) + return q.join(parent_alias, cls.parent).\\ + filter(op(parent_alias.parent, other)) + return transform + + Base = declarative_base() + + class Node(Base): + __tablename__ = 'node' + id =Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('node.id')) + parent = relationship("Node", remote_side=id) + + @hybrid_property + def grandparent(self): + return self.parent.parent + + @grandparent.comparator + def grandparent(cls): + return GrandparentTransformer(cls) + +The ``GrandparentTransformer`` overrides the core +:meth:`.Operators.operate` method at the base of the +:class:`.Comparator` hierarchy to return a query-transforming +callable, which then runs the given comparison operation in a +particular context. Such as, in the example above, the ``operate`` +method is called, given the :attr:`.Operators.eq` callable as well as +the right side of the comparison ``Node(id=5)``. A function +``transform`` is then returned which will transform a :class:`.Query` +first to join to ``Node.parent``, then to compare ``parent_alias`` +using :attr:`.Operators.eq` against the left and right sides, passing +into :class:`.Query.filter`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.orm import Session + >>> session = Session() + {sql}>>> session.query(Node).\\ + ... with_transformation(Node.grandparent==Node(id=5)).\\ + ... all() + SELECT node.id AS node_id, node.parent_id AS node_parent_id + FROM node JOIN node AS node_1 ON node_1.id = node.parent_id + WHERE :param_1 = node_1.parent_id + {stop} + +We can modify the pattern to be more verbose but flexible by separating +the "join" step from the "filter" step. The tricky part here is ensuring +that successive instances of ``GrandparentTransformer`` use the same +:class:`.AliasedClass` object against ``Node``. Below we use a simple +memoizing approach that associates a ``GrandparentTransformer`` +with each class:: + + class Node(Base): + + # ... + + @grandparent.comparator + def grandparent(cls): + # memoize a GrandparentTransformer + # per class + if '_gp' not in cls.__dict__: + cls._gp = GrandparentTransformer(cls) + return cls._gp + + class GrandparentTransformer(Comparator): + + def __init__(self, cls): + self.parent_alias = aliased(cls) + + @property + def join(self): + def go(q): + return q.join(self.parent_alias, Node.parent) + return go + + def operate(self, op, other): + return op(self.parent_alias.parent, other) + +.. sourcecode:: pycon+sql + + {sql}>>> session.query(Node).\\ + ... with_transformation(Node.grandparent.join).\\ + ... filter(Node.grandparent==Node(id=5)) + SELECT node.id AS node_id, node.parent_id AS node_parent_id + FROM node JOIN node AS node_1 ON node_1.id = node.parent_id + WHERE :param_1 = node_1.parent_id + {stop} + +The "transformer" pattern is an experimental pattern that starts +to make usage of some functional programming paradigms. +While it's only recommended for advanced and/or patient developers, +there's probably a whole lot of amazing things it can be used for. + +""" +from .. import util +from ..orm import attributes, interfaces + +HYBRID_METHOD = util.symbol('HYBRID_METHOD') +"""Symbol indicating an :class:`_InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`._InspectionAttr.extension_type` + attibute. + + .. seealso:: + + :attr:`.Mapper.all_orm_attributes` + +""" + +HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') +"""Symbol indicating an :class:`_InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`._InspectionAttr.extension_type` + attibute. + + .. seealso:: + + :attr:`.Mapper.all_orm_attributes` + +""" + +class hybrid_method(interfaces._InspectionAttr): + """A decorator which allows definition of a Python object method with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HYBRID_METHOD + + def __init__(self, func, expr=None): + """Create a new :class:`.hybrid_method`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_method + + class SomeClass(object): + @hybrid_method + def value(self, x, y): + return self._value + x + y + + @value.expression + def value(self, x, y): + return func.some_function(self._value, x, y) + + """ + self.func = func + self.expr = expr or func + + def __get__(self, instance, owner): + if instance is None: + return self.expr.__get__(owner, owner.__class__) + else: + return self.func.__get__(instance, owner) + + def expression(self, expr): + """Provide a modifying decorator that defines a + SQL-expression producing method.""" + + self.expr = expr + return self + + +class hybrid_property(interfaces._InspectionAttr): + """A decorator which allows definition of a Python descriptor with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HYBRID_PROPERTY + + def __init__(self, fget, fset=None, fdel=None, expr=None): + """Create a new :class:`.hybrid_property`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_property + + class SomeClass(object): + @hybrid_property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + """ + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = expr or fget + util.update_wrapper(self, fget) + + def __get__(self, instance, owner): + if instance is None: + return self.expr(owner) + else: + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError("can't set attribute") + self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError("can't delete attribute") + self.fdel(instance) + + def setter(self, fset): + """Provide a modifying decorator that defines a value-setter method.""" + + self.fset = fset + return self + + def deleter(self, fdel): + """Provide a modifying decorator that defines a + value-deletion method.""" + + self.fdel = fdel + return self + + def expression(self, expr): + """Provide a modifying decorator that defines a SQL-expression + producing method.""" + + self.expr = expr + return self + + def comparator(self, comparator): + """Provide a modifying decorator that defines a custom + comparator producing method. + + The return value of the decorated method should be an instance of + :class:`~.hybrid.Comparator`. + + """ + + proxy_attr = attributes.\ + create_proxied_attribute(self) + + def expr(owner): + return proxy_attr(owner, self.__name__, self, comparator(owner)) + self.expr = expr + return self + + +class Comparator(interfaces.PropComparator): + """A helper class that allows easy construction of custom + :class:`~.orm.interfaces.PropComparator` + classes for usage with hybrids.""" + + property = None + + def __init__(self, expression): + self.expression = expression + + def __clause_element__(self): + expr = self.expression + while hasattr(expr, '__clause_element__'): + expr = expr.__clause_element__() + return expr + + def adapt_to_entity(self, adapt_to_entity): + # interesting.... + return self diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py new file mode 100644 index 00000000..2cf36e9b --- /dev/null +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -0,0 +1,407 @@ +"""Extensible class instrumentation. + +The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate +systems of class instrumentation within the ORM. Class instrumentation +refers to how the ORM places attributes on the class which maintain +data and track changes to that data, as well as event hooks installed +on the class. + +.. note:: + The extension package is provided for the benefit of integration + with other object management packages, which already perform + their own instrumentation. It is not intended for general use. + +For examples of how the instrumentation extension is used, +see the example :ref:`examples_instrumentation`. + +.. versionchanged:: 0.8 + The :mod:`sqlalchemy.orm.instrumentation` was split out so + that all functionality having to do with non-standard + instrumentation was moved out to :mod:`sqlalchemy.ext.instrumentation`. + When imported, the module installs itself within + :mod:`sqlalchemy.orm.instrumentation` so that it + takes effect, including recognition of + ``__sa_instrumentation_manager__`` on mapped classes, as + well :data:`.instrumentation_finders` + being used to determine class instrumentation resolution. + +""" +from ..orm import instrumentation as orm_instrumentation +from ..orm.instrumentation import ( + ClassManager, InstrumentationFactory, _default_state_getter, + _default_dict_getter, _default_manager_getter +) +from ..orm import attributes, collections, base as orm_base +from .. import util +from ..orm import exc as orm_exc +import weakref + +INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +"""Attribute, elects custom instrumentation when present on a mapped class. + +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an InstrumentationManager or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a ClassManager or subclass + +This attribute is consulted by SQLAlchemy instrumentation +resolution, once the :mod:`sqlalchemy.ext.instrumentation` module +has been imported. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) + +instrumentation_finders = [find_native_user_instrumentation_hook] +"""An extensible sequence of callables which return instrumentation +implementations + +When a class is registered, each callable will be passed a class object. +If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + + +class ExtendedInstrumentationRegistry(InstrumentationFactory): + """Extends :class:`.InstrumentationFactory` with additional + bookkeeping, to accommodate multiple types of + class managers. + + """ + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = weakref.WeakKeyDictionary() + _dict_finders = weakref.WeakKeyDictionary() + _extended = False + + def _locate_extended_factory(self, class_): + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + manager = self._extended_class_manager(class_, factory) + return manager, factory + else: + return None, None + + def _check_conflicts(self, class_, factory): + existing_factories = self._collect_management_factories_for(class_).\ + difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" % ( + class_.__name__, list(existing_factories))) + + def _extended_class_manager(self, class_, factory): + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_instrumented_lookups() + + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a + hierarchy. + + Traverses the entire inheritance graph of a cls and returns a + collection of instrumentation factories for those classes. Factories + are extracted from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = self.manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def unregister(self, class_): + if class_ in self._manager_finders: + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + super(ExtendedInstrumentationRegistry, self).unregister(class_) + + def manager_of_class(self, cls): + if cls is None: + return None + return self._manager_finders.get(cls, _default_manager_getter)(cls) + + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._state_finders.get( + instance.__class__, _default_state_getter)(instance) + + def dict_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._dict_finders.get( + instance.__class__, _default_dict_getter)(instance) + + +orm_instrumentation._instrumentation_factory = \ + _instrumentation_factory = ExtendedInstrumentationRegistry() +orm_instrumentation.instrumentation_finders = instrumentation_finders + + +class InstrumentationManager(object): + """User-defined class instrumentation extension. + + :class:`.InstrumentationManager` can be subclassed in order + to change + how class instrumentation proceeds. This class exists for + the purposes of integration with other object management + frameworks which would like to entirely modify the + instrumentation methodology of the ORM, and is not intended + for regular usage. For interception of class instrumentation + events, see :class:`.InstrumentationEvents`. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + .. versionchanged:: 0.8 + :class:`.InstrumentationManager` was moved from + :mod:`sqlalchemy.orm.instrumentation` to + :mod:`sqlalchemy.ext.instrumentation`. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. + + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, '_default_class_manager', manager) + + def dispose(self, class_, manager): + delattr(class_, '_default_class_manager') + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, '_default_state', state) + + def remove_state(self, class_, instance): + delattr(instance, '_default_state') + + def state_getter(self, class_): + return lambda instance: getattr(instance, '_default_state') + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_) + + def manage(self): + self._adapted.manage(self.class_, self) + + def dispose(self): + self._adapted.dispose(self.class_) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + super(_ClassInstrumentationAdapter, self).post_configure_attribute(key) + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, 'initialize_collection', None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection(self, key, + state, factory) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._state_constructor(instance, self) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + self._get_state(instance) + except orm_exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + + +def _install_instrumented_lookups(): + """Replace global class/object management functions + with ExtendedInstrumentationRegistry implementations, which + allow multiple types of class managers to be present, + at the cost of performance. + + This function is called only by ExtendedInstrumentationRegistry + and unit tests specific to this behavior. + + The _reinstall_default_lookups() function can be called + after this one to re-establish the default functions. + + """ + _install_lookups( + dict( + instance_state=_instrumentation_factory.state_of, + instance_dict=_instrumentation_factory.dict_of, + manager_of_class=_instrumentation_factory.manager_of_class + ) + ) + + +def _reinstall_default_lookups(): + """Restore simplified lookups.""" + _install_lookups( + dict( + instance_state=_default_state_getter, + instance_dict=_default_dict_getter, + manager_of_class=_default_manager_getter + ) + ) + + +def _install_lookups(lookups): + global instance_state, instance_dict, manager_of_class + instance_state = lookups['instance_state'] + instance_dict = lookups['instance_dict'] + manager_of_class = lookups['manager_of_class'] + orm_base.instance_state = attributes.instance_state = \ + orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = attributes.instance_dict = \ + orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = attributes.manager_of_class = \ + orm_instrumentation.manager_of_class = manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py new file mode 100644 index 00000000..7869e888 --- /dev/null +++ b/lib/sqlalchemy/ext/mutable.py @@ -0,0 +1,636 @@ +# ext/mutable.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Provide support for tracking of in-place changes to scalar values, +which are propagated into ORM change events on owning parent objects. + +.. versionadded:: 0.7 :mod:`sqlalchemy.ext.mutable` replaces SQLAlchemy's + legacy approach to in-place mutations of scalar values; see + :ref:`07_migration_mutation_extension`. + +.. _mutable_scalars: + +Establishing Mutability on Scalar Column Values +=============================================== + +A typical example of a "mutable" structure is a Python dictionary. +Following the example introduced in :ref:`types_toplevel`, we +begin with a custom type that marshals Python dictionaries into +JSON strings before being persisted:: + + from sqlalchemy.types import TypeDecorator, VARCHAR + import json + + class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = VARCHAR + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + +The usage of ``json`` is only for the purposes of example. The +:mod:`sqlalchemy.ext.mutable` extension can be used +with any type whose target Python type may be mutable, including +:class:`.PickleType`, :class:`.postgresql.ARRAY`, etc. + +When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself +tracks all parents which reference it. Below, we illustrate the a simple +version of the :class:`.MutableDict` dictionary object, which applies +the :class:`.Mutable` mixin to a plain Python dictionary:: + + from sqlalchemy.ext.mutable import Mutable + + class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + +The above dictionary class takes the approach of subclassing the Python +built-in ``dict`` to produce a dict +subclass which routes all mutation events through ``__setitem__``. There are +variants on this approach, such as subclassing ``UserDict.UserDict`` or +``collections.MutableMapping``; the part that's important to this example is +that the :meth:`.Mutable.changed` method is called whenever an in-place +change to the datastructure takes place. + +We also redefine the :meth:`.Mutable.coerce` method which will be used to +convert any values that are not instances of ``MutableDict``, such +as the plain dictionaries returned by the ``json`` module, into the +appropriate type. Defining this method is optional; we could just as well +created our ``JSONEncodedDict`` such that it always returns an instance +of ``MutableDict``, and additionally ensured that all calling code +uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not +overridden, any values applied to a parent object which are not instances +of the mutable type will raise a ``ValueError``. + +Our new ``MutableDict`` type offers a class method +:meth:`~.Mutable.as_mutable` which we can use within column metadata +to associate with types. This method grabs the given type object or +class and associates a listener that will detect all future mappings +of this type, applying event listening instrumentation to the mapped +attribute. Such as, with classical table metadata:: + + from sqlalchemy import Table, Column, Integer + + my_data = Table('my_data', metadata, + Column('id', Integer, primary_key=True), + Column('data', MutableDict.as_mutable(JSONEncodedDict)) + ) + +Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` +(if the type object was not an instance already), which will intercept any +attributes which are mapped against this type. Below we establish a simple +mapping against the ``my_data`` table:: + + from sqlalchemy import mapper + + class MyDataClass(object): + pass + + # associates mutation listeners with MyDataClass.data + mapper(MyDataClass, my_data) + +The ``MyDataClass.data`` member will now be notified of in place changes +to its value. + +There's no difference in usage when using declarative:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class MyDataClass(Base): + __tablename__ = 'my_data' + id = Column(Integer, primary_key=True) + data = Column(MutableDict.as_mutable(JSONEncodedDict)) + +Any in-place changes to the ``MyDataClass.data`` member +will flag the attribute as "dirty" on the parent object:: + + >>> from sqlalchemy.orm import Session + + >>> sess = Session() + >>> m1 = MyDataClass(data={'value1':'foo'}) + >>> sess.add(m1) + >>> sess.commit() + + >>> m1.data['value1'] = 'bar' + >>> assert m1 in sess.dirty + True + +The ``MutableDict`` can be associated with all future instances +of ``JSONEncodedDict`` in one step, using +:meth:`~.Mutable.associate_with`. This is similar to +:meth:`~.Mutable.as_mutable` except it will intercept all occurrences +of ``MutableDict`` in all mappings unconditionally, without +the need to declare it individually:: + + MutableDict.associate_with(JSONEncodedDict) + + class MyDataClass(Base): + __tablename__ = 'my_data' + id = Column(Integer, primary_key=True) + data = Column(JSONEncodedDict) + + +Supporting Pickling +-------------------- + +The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the +placement of a ``weakref.WeakKeyDictionary`` upon the value object, which +stores a mapping of parent mapped objects keyed to the attribute name under +which they are associated with this value. ``WeakKeyDictionary`` objects are +not picklable, due to the fact that they contain weakrefs and function +callbacks. In our case, this is a good thing, since if this dictionary were +picklable, it could lead to an excessively large pickle size for our value +objects that are pickled by themselves outside of the context of the parent. +The developer responsibility here is only to provide a ``__getstate__`` method +that excludes the :meth:`~MutableBase._parents` collection from the pickle +stream:: + + class MyMutableType(Mutable): + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_parents', None) + return d + +With our dictionary example, we need to return the contents of the dict itself +(and also restore them on __setstate__):: + + class MutableDict(Mutable, dict): + # .... + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + +In the case that our mutable value object is pickled as it is attached to one +or more parent objects that are also part of the pickle, the :class:`.Mutable` +mixin will re-establish the :attr:`.Mutable._parents` collection on each value +object as the owning parents themselves are unpickled. + +.. _mutable_composites: + +Establishing Mutability on Composites +===================================== + +Composites are a special ORM feature which allow a single scalar attribute to +be assigned an object value which represents information "composed" from one +or more columns from the underlying mapped table. The usual example is that of +a geometric "point", and is introduced in :ref:`mapper_composite`. + +.. versionchanged:: 0.7 + The internals of :func:`.orm.composite` have been + greatly simplified and in-place mutation detection is no longer enabled by + default; instead, the user-defined value must detect changes on its own and + propagate them to all owning parents. The :mod:`sqlalchemy.ext.mutable` + extension provides the helper class :class:`.MutableComposite`, which is a + slight variant on the :class:`.Mutable` class. + +As is the case with :class:`.Mutable`, the user-defined composite class +subclasses :class:`.MutableComposite` as a mixin, and detects and delivers +change events to its parents via the :meth:`.MutableComposite.changed` method. +In the case of a composite class, the detection is usually via the usage of +Python descriptors (i.e. ``@property``), or alternatively via the special +Python method ``__setattr__()``. Below we expand upon the ``Point`` class +introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite` +and to also route attribute set events via ``__setattr__`` to the +:meth:`.MutableComposite.changed` method:: + + from sqlalchemy.ext.mutable import MutableComposite + + class Point(MutableComposite): + def __init__(self, x, y): + self.x = x + self.y = y + + def __setattr__(self, key, value): + "Intercept set events" + + # set the attribute + object.__setattr__(self, key, value) + + # alert all parents to the change + self.changed() + + def __composite_values__(self): + return self.x, self.y + + def __eq__(self, other): + return isinstance(other, Point) and \\ + other.x == self.x and \\ + other.y == self.y + + def __ne__(self, other): + return not self.__eq__(other) + +The :class:`.MutableComposite` class uses a Python metaclass to automatically +establish listeners for any usage of :func:`.orm.composite` that specifies our +``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class, +listeners are established which will route change events from ``Point`` +objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: + + from sqlalchemy.orm import composite, mapper + from sqlalchemy import Table, Column + + vertices = Table('vertices', metadata, + Column('id', Integer, primary_key=True), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + ) + + class Vertex(object): + pass + + mapper(Vertex, vertices, properties={ + 'start': composite(Point, vertices.c.x1, vertices.c.y1), + 'end': composite(Point, vertices.c.x2, vertices.c.y2) + }) + +Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members +will flag the attribute as "dirty" on the parent object:: + + >>> from sqlalchemy.orm import Session + + >>> sess = Session() + >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15)) + >>> sess.add(v1) + >>> sess.commit() + + >>> v1.end.x = 8 + >>> assert v1 in sess.dirty + True + +Coercing Mutable Composites +--------------------------- + +The :meth:`.MutableBase.coerce` method is also supported on composite types. +In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce` +method is only called for attribute set operations, not load operations. +Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent +to using a :func:`.validates` validation routine for all attributes which +make use of the custom composite type:: + + class Point(MutableComposite): + # other Point methods + # ... + + def coerce(cls, key, value): + if isinstance(value, tuple): + value = Point(*value) + elif not isinstance(value, Point): + raise ValueError("tuple or Point expected") + return value + +.. versionadded:: 0.7.10,0.8.0b2 + Support for the :meth:`.MutableBase.coerce` method in conjunction with + objects of type :class:`.MutableComposite`. + +Supporting Pickling +-------------------- + +As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper +class uses a ``weakref.WeakKeyDictionary`` available via the +:meth:`MutableBase._parents` attribute which isn't picklable. If we need to +pickle instances of ``Point`` or its owning class ``Vertex``, we at least need +to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary. +Below we define both a ``__getstate__`` and a ``__setstate__`` that package up +the minimal form of our ``Point`` class:: + + class Point(MutableComposite): + # ... + + def __getstate__(self): + return self.x, self.y + + def __setstate__(self, state): + self.x, self.y = state + +As with :class:`.Mutable`, the :class:`.MutableComposite` augments the +pickling process of the parent's object-relational state so that the +:meth:`MutableBase._parents` collection is restored to all ``Point`` objects. + +""" +from ..orm.attributes import flag_modified +from .. import event, types +from ..orm import mapper, object_mapper, Mapper +from ..util import memoized_property +import weakref + + +class MutableBase(object): + """Common base class to :class:`.Mutable` + and :class:`.MutableComposite`. + + """ + + @memoized_property + def _parents(self): + """Dictionary of parent object->attribute name on the parent. + + This attribute is a so-called "memoized" property. It initializes + itself with a new ``weakref.WeakKeyDictionary`` the first time + it is accessed, returning the same object upon subsequent access. + + """ + + return weakref.WeakKeyDictionary() + + @classmethod + def coerce(cls, key, value): + """Given a value, coerce it into the target type. + + Can be overridden by custom subclasses to coerce incoming + data into a particular type. + + By default, raises ``ValueError``. + + This method is called in different scenarios depending on if + the parent class is of type :class:`.Mutable` or of type + :class:`.MutableComposite`. In the case of the former, it is called + for both attribute-set operations as well as during ORM loading + operations. For the latter, it is only called during attribute-set + operations; the mechanics of the :func:`.composite` construct + handle coercion during load operations. + + + :param key: string name of the ORM-mapped attribute being set. + :param value: the incoming value. + :return: the method should return the coerced value, or raise + ``ValueError`` if the coercion cannot be completed. + + """ + if value is None: + return None + msg = "Attribute '%s' does not accept objects of type %s" + raise ValueError(msg % (key, type(value))) + + @classmethod + def _listen_on_attribute(cls, attribute, coerce, parent_cls): + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + key = attribute.key + if parent_cls is not attribute.class_: + return + + # rely on "propagate" here + parent_cls = attribute.class_ + + def load(state, *args): + """Listen for objects loaded or refreshed. + + Wrap the target data member's value with + ``Mutable``. + + """ + val = state.dict.get(key, None) + if val is not None: + if coerce: + val = cls.coerce(key, val) + state.dict[key] = val + val._parents[state.obj()] = key + + def set(target, value, oldvalue, initiator): + """Listen for set/replace events on the target + data member. + + Establish a weak reference to the parent object + on the incoming value, remove it for the one + outgoing. + + """ + if value is oldvalue: + return value + + if not isinstance(value, cls): + value = cls.coerce(key, value) + if value is not None: + value._parents[target.obj()] = key + if isinstance(oldvalue, cls): + oldvalue._parents.pop(target.obj(), None) + return value + + def pickle(state, state_dict): + val = state.dict.get(key, None) + if val is not None: + if 'ext.mutable.values' not in state_dict: + state_dict['ext.mutable.values'] = [] + state_dict['ext.mutable.values'].append(val) + + def unpickle(state, state_dict): + if 'ext.mutable.values' in state_dict: + for val in state_dict['ext.mutable.values']: + val._parents[state.obj()] = key + + event.listen(parent_cls, 'load', load, + raw=True, propagate=True) + event.listen(parent_cls, 'refresh', load, + raw=True, propagate=True) + event.listen(attribute, 'set', set, + raw=True, retval=True, propagate=True) + event.listen(parent_cls, 'pickle', pickle, + raw=True, propagate=True) + event.listen(parent_cls, 'unpickle', unpickle, + raw=True, propagate=True) + + +class Mutable(MutableBase): + """Mixin that defines transparent propagation of change + events to a parent object. + + See the example in :ref:`mutable_scalars` for usage information. + + """ + + def changed(self): + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + flag_modified(parent, key) + + @classmethod + def associate_with_attribute(cls, attribute): + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + cls._listen_on_attribute(attribute, True, attribute.class_) + + @classmethod + def associate_with(cls, sqltype): + """Associate this wrapper with all future mapped columns + of the given type. + + This is a convenience method that calls + ``associate_with_attribute`` automatically. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.associate_with` for types that are permanent to an + application, not with ad-hoc types else this will cause unbounded + growth in memory usage. + + """ + + def listen_for_type(mapper, class_): + for prop in mapper.column_attrs: + if isinstance(prop.columns[0].type, sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(mapper, 'mapper_configured', listen_for_type) + + @classmethod + def as_mutable(cls, sqltype): + """Associate a SQL type with this mutable Python type. + + This establishes listeners that will detect ORM mappings against + the given type, adding mutation event trackers to those mappings. + + The type is returned, unconditionally as an instance, so that + :meth:`.as_mutable` can be used inline:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) + ) + + Note that the returned type is always an instance, even if a class + is given, and that only columns which are declared specifically with + that type instance receive additional instrumentation. + + To associate a particular mutable type with all occurrences of a + particular type, use the :meth:`.Mutable.associate_with` classmethod + of the particular :class:`.Mutable` subclass to establish a global + association. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.as_mutable` for types that are permanent to an application, + not with ad-hoc types else this will cause unbounded growth + in memory usage. + + """ + sqltype = types.to_instance(sqltype) + + def listen_for_type(mapper, class_): + for prop in mapper.column_attrs: + if prop.columns[0].type is sqltype: + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(mapper, 'mapper_configured', listen_for_type) + + return sqltype + + + +class MutableComposite(MutableBase): + """Mixin that defines transparent propagation of change + events on a SQLAlchemy "composite" object to its + owning parent or parents. + + See the example in :ref:`mutable_composites` for usage information. + + """ + + def changed(self): + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + + prop = object_mapper(parent).get_property(key) + for value, attr_name in zip( + self.__composite_values__(), + prop._attribute_keys): + setattr(parent, attr_name, value) + +def _setup_composite_listener(): + def _listen_for_type(mapper, class_): + for prop in mapper.iterate_properties: + if (hasattr(prop, 'composite_class') and + isinstance(prop.composite_class, type) and + issubclass(prop.composite_class, MutableComposite)): + prop.composite_class._listen_on_attribute( + getattr(class_, prop.key), False, class_) + if not event.contains(Mapper, "mapper_configured", _listen_for_type): + event.listen(Mapper, 'mapper_configured', _listen_for_type) +_setup_composite_listener() + + +class MutableDict(Mutable, dict): + """A dictionary type that implements :class:`.Mutable`. + + .. versionadded:: 0.8 + + """ + + def __setitem__(self, key, value): + """Detect dictionary set events and emit change events.""" + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + """Detect dictionary del events and emit change events.""" + dict.__delitem__(self, key) + self.changed() + + def clear(self): + dict.clear(self) + self.changed() + + @classmethod + def coerce(cls, key, value): + """Convert plain dictionary to MutableDict.""" + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py new file mode 100644 index 00000000..c4ba6d57 --- /dev/null +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -0,0 +1,376 @@ +# ext/orderinglist.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""A custom list that manages index/position information for contained +elements. + +:author: Jason Kirtland + +``orderinglist`` is a helper for mutable ordered relationships. It will +intercept list operations performed on a :func:`.relationship`-managed +collection and +automatically synchronize changes in list position onto a target scalar +attribute. + +Example: A ``slide`` table, where each row refers to zero or more entries +in a related ``bullet`` table. The bullets within a slide are +displayed in order based on the value of the ``position`` column in the +``bullet`` table. As entries are reordered in memory, the value of the +``position`` attribute should be updated to reflect the new sort order:: + + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position") + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +The standard relationship mapping will produce a list-like attribute on each +``Slide`` containing all related ``Bullet`` objects, +but coping with changes in ordering is not handled automatically. +When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position`` +attribute will remain unset until manually assigned. When the ``Bullet`` +is inserted into the middle of the list, the following ``Bullet`` objects +will also need to be renumbered. + +The :class:`.OrderingList` object automates this task, managing the +``position`` attribute on all ``Bullet`` objects in the collection. It is +constructed using the :func:`.ordering_list` factory:: + + from sqlalchemy.ext.orderinglist import ordering_list + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +With the above mapping the ``Bullet.position`` attribute is managed:: + + s = Slide() + s.bullets.append(Bullet()) + s.bullets.append(Bullet()) + s.bullets[1].position + >>> 1 + s.bullets.insert(1, Bullet()) + s.bullets[2].position + >>> 2 + +The :class:`.OrderingList` construct only works with **changes** to a collection, +and not the initial load from the database, and requires that the list be +sorted when loaded. Therefore, be sure to +specify ``order_by`` on the :func:`.relationship` against the target ordering +attribute, so that the ordering is correct when first loaded. + +.. warning:: + + :class:`.OrderingList` only provides limited functionality when a primary + key column or unique column is the target of the sort. Operations + that are unsupported or are problematic include: + + * two entries must trade values. This is not supported directly in the + case of a primary key or unique constraint because it means at least + one row would need to be temporarily removed first, or changed to + a third, neutral value while the switch occurs. + + * an entry must be deleted in order to make room for a new entry. + SQLAlchemy's unit of work performs all INSERTs before DELETEs within a + single flush. In the case of a primary key, it will trade + an INSERT/DELETE of the same primary key for an UPDATE statement in order + to lessen the impact of this lmitation, however this does not take place + for a UNIQUE column. + A future feature will allow the "DELETE before INSERT" behavior to be + possible, allevating this limitation, though this feature will require + explicit configuration at the mapper level for sets of columns that + are to be handled in this way. + +:func:`.ordering_list` takes the name of the related object's ordering attribute as +an argument. By default, the zero-based integer index of the object's +position in the :func:`.ordering_list` is synchronized with the ordering attribute: +index 0 will get position 0, index 1 position 1, etc. To start numbering at 1 +or some other integer, provide ``count_from=1``. + + +""" +from ..orm.collections import collection +from .. import util + +__all__ = ['ordering_list'] + + +def ordering_list(attr, count_from=None, **kw): + """Prepares an :class:`OrderingList` factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper + relationship's ``collection_class`` option. e.g.:: + + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: + Name of the mapped attribute to use for storage and retrieval of + ordering information + + :param count_from: + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Additional arguments are passed to the :class:`.OrderingList` constructor. + + """ + + kw = _unsugar_count_from(count_from=count_from, **kw) + return lambda: OrderingList(attr, **kw) + + +# Ordering utility functions + + +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + try: + f.__name__ = 'count_from_%i' % start + except TypeError: + pass + return f + + +def _unsugar_count_from(**kw): + """Builds counting functions from keyword arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ + + count_from = kw.pop('count_from', None) + if kw.get('ordering_func', None) is None and count_from is not None: + if count_from == 0: + kw['ordering_func'] = count_from_0 + elif count_from == 1: + kw['ordering_func'] = count_from_1 + else: + kw['ordering_func'] = count_from_n_factory(count_from) + return kw + + +class OrderingList(list): + """A custom list that manages position information for its children. + + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`.relationship` function. + + """ + + def __init__(self, ordering_attr=None, ordering_func=None, + reorder_on_append=False): + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. + + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relationship. + + :param ordering_attr: + Name of the attribute that stores the object's order in the + relationship. + + :param ordering_func: Optional. A function that maps the position in + the Python list to a value to store in the + ``ordering_attr``. Values returned are usually (but need not be!) + integers. + + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. + + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. + + :param reorder_on_append: + Default False. When appending an object with an existing (non-None) + ordering value, that value will be left untouched unless + ``reorder_on_append`` is true. This is an optimization to avoid a + variety of dangerous unexpected database writes. + + SQLAlchemy will add instances to the list via append() when your + object loads. If for some reason the result set from the database + skips a step in the ordering (say, row '1' is missing but you get + '2', '3', and '4'), reorder_on_append=True would immediately + renumber the items to '1', '2', '3'. If you have multiple sessions + making changes, any of whom happen to load this collection even in + passing, all of the sessions would try to "clean up" the numbering + in their commits, possibly causing all but one to fail with a + concurrent modification error. + + Recommend leaving this with the default of False, and just call + ``reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. + + """ + self.ordering_attr = ordering_attr + if ordering_func is None: + ordering_func = count_from_0 + self.ordering_func = ordering_func + self.reorder_on_append = reorder_on_append + + # More complex serialization schemes (multi column, e.g.) are possible by + # subclassing and reimplementing these two methods. + def _get_order_value(self, entity): + return getattr(entity, self.ordering_attr) + + def _set_order_value(self, entity, value): + setattr(entity, self.ordering_attr, value) + + def reorder(self): + """Synchronize ordering for the entire collection. + + Sweeps through the list and ensures that each object has accurate + ordering information set. + + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super(OrderingList, self).append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super(OrderingList, self).append(entity) + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super(OrderingList, self).insert(index, entity) + self._reorder() + + def remove(self, entity): + super(OrderingList, self).remove(entity) + self._reorder() + + def pop(self, index=-1): + entity = super(OrderingList, self).pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super(OrderingList, self).__setitem__(index, entity) + + def __delitem__(self, index): + super(OrderingList, self).__delitem__(index) + self._reorder() + + def __setslice__(self, start, end, values): + super(OrderingList, self).__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super(OrderingList, self).__delslice__(start, end) + self._reorder() + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(list, func_name)): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +def _reconstitute(cls, dict_, items): + """ Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py new file mode 100644 index 00000000..388cd404 --- /dev/null +++ b/lib/sqlalchemy/ext/serializer.py @@ -0,0 +1,156 @@ +# ext/serializer.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. + +Usage is nearly the same as that of the standard Python pickle module:: + + from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) + Session = scoped_session(sessionmaker()) + + # ... define mappers + + query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + + # pickle the query + serialized = dumps(query) + + # unpickle. Pass in metadata + scoped_session + query2 = loads(serialized, metadata, Session) + + print query2.all() + +Similar restrictions as when using raw pickle apply; mapped classes must be +themselves be pickleable, meaning they are importable from a module-level +namespace. + +The serializer module is only appropriate for query structures. It is not +needed for: + +* instances of user-defined classes. These contain no references to engines, + sessions or expression constructs in the typical case and can be serialized + directly. + +* Table metadata that is to be loaded entirely from the serialized structure + (i.e. is not already declared in the application). Regular + pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object, + typically one which was reflected from an existing database at some previous + point in time. The serializer module is specifically for the opposite case, + where the Table metadata is already present in memory. + +""" + +from ..orm import class_mapper +from ..orm.session import Session +from ..orm.mapper import Mapper +from ..orm.interfaces import MapperProperty +from ..orm.attributes import QueryableAttribute +from .. import Table, Column +from ..engine import Engine +from ..util import pickle, byte_buffer, b64encode, b64decode, text_type +import re + + +__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] + + +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) + + def persistent_id(obj): + #print "serializing:", repr(obj) + if isinstance(obj, QueryableAttribute): + cls = obj.impl.class_ + key = obj.impl.key + id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) + elif isinstance(obj, Mapper) and not obj.non_primary: + id = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: + id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \ + ":" + obj.key + elif isinstance(obj, Table): + id = "table:" + text_type(obj.key) + elif isinstance(obj, Column) and isinstance(obj.table, Table): + id = "column:" + text_type(obj.table.key) + ":" + text_type(obj.key) + elif isinstance(obj, Session): + id = "session:" + elif isinstance(obj, Engine): + id = "engine:" + else: + return None + return id + + pickler.persistent_id = persistent_id + return pickler + +our_ids = re.compile( + r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)') + + +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) + + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind + else: + return None + + def persistent_load(id): + m = our_ids.match(text_type(id)) + if not m: + return None + else: + type_, args = m.group(1, 2) + if type_ == 'attribute': + key, clsarg = args.split(":") + cls = pickle.loads(b64decode(clsarg)) + return getattr(cls, key) + elif type_ == "mapper": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls) + elif type_ == "mapperprop": + mapper, keyname = args.split(':') + cls = pickle.loads(b64decode(mapper)) + return class_mapper(cls).attrs[keyname] + elif type_ == "table": + return metadata.tables[args] + elif type_ == "column": + table, colname = args.split(':') + return metadata.tables[table].c[colname] + elif type_ == "session": + return scoped_session() + elif type_ == "engine": + return get_engine() + else: + raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = persistent_load + return unpickler + + +def dumps(obj, protocol=0): + buf = byte_buffer() + pickler = Serializer(buf, protocol) + pickler.dump(obj) + return buf.getvalue() + + +def loads(data, metadata=None, scoped_session=None, engine=None): + buf = byte_buffer(data) + unpickler = Deserializer(buf, metadata, scoped_session, engine) + return unpickler.load() diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py new file mode 100644 index 00000000..fe9e4055 --- /dev/null +++ b/lib/sqlalchemy/inspection.py @@ -0,0 +1,92 @@ +# sqlalchemy/inspect.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The inspection module provides the :func:`.inspect` function, +which delivers runtime information about a wide variety +of SQLAlchemy objects, both within the Core as well as the +ORM. + +The :func:`.inspect` function is the entry point to SQLAlchemy's +public API for viewing the configuration and construction +of in-memory objects. Depending on the type of object +passed to :func:`.inspect`, the return value will either be +a related object which provides a known interface, or in many +cases it will return the object itself. + +The rationale for :func:`.inspect` is twofold. One is that +it replaces the need to be aware of a large variety of "information +getting" functions in SQLAlchemy, such as :meth:`.Inspector.from_engine`, +:func:`.orm.attributes.instance_state`, :func:`.orm.class_mapper`, +and others. The other is that the return value of :func:`.inspect` +is guaranteed to obey a documented API, thus allowing third party +tools which build on top of SQLAlchemy configurations to be constructed +in a forwards-compatible way. + +.. versionadded:: 0.8 The :func:`.inspect` system is introduced + as of version 0.8. + +""" + +from . import util, exc +_registrars = util.defaultdict(list) + + +def inspect(subject, raiseerr=True): + """Produce an inspection object for the given target. + + The returned value in some cases may be the + same object as the one given, such as if a + :class:`.Mapper` object is passed. In other + cases, it will be an instance of the registered + inspection type for the given object, such as + if an :class:`.engine.Engine` is passed, an + :class:`.Inspector` object is returned. + + :param subject: the subject to be inspected. + :param raiseerr: When ``True``, if the given subject + does not + correspond to a known SQLAlchemy inspected type, + :class:`sqlalchemy.exc.NoInspectionAvailable` + is raised. If ``False``, ``None`` is returned. + + """ + type_ = type(subject) + for cls in type_.__mro__: + if cls in _registrars: + reg = _registrars[cls] + if reg is True: + return subject + ret = reg(subject) + if ret is not None: + break + else: + reg = ret = None + + if raiseerr and ( + reg is None or ret is None + ): + raise exc.NoInspectionAvailable( + "No inspection system is " + "available for object of type %s" % + type_) + return ret + + +def _inspects(*types): + def decorate(fn_or_cls): + for type_ in types: + if type_ in _registrars: + raise AssertionError( + "Type %s is already " + "registered" % type_) + _registrars[type_] = fn_or_cls + return fn_or_cls + return decorate + + +def _self_inspects(cls): + _inspects(cls)(True) + return cls diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py new file mode 100644 index 00000000..ed50a645 --- /dev/null +++ b/lib/sqlalchemy/interfaces.py @@ -0,0 +1,310 @@ +# sqlalchemy/interfaces.py +# Copyright (C) 2007-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2007 Jason Kirtland jek@discorporate.us +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Deprecated core event interfaces. + +This module is **deprecated** and is superseded by the +event system. + +""" + +from . import event, util + + +class PoolListener(object): + """Hooks into the lifecycle of connections in a :class:`.Pool`. + + .. note:: + + :class:`.PoolListener` is deprecated. Please + refer to :class:`.PoolEvents`. + + Usage:: + + class MyListener(PoolListener): + def connect(self, dbapi_con, con_record): + '''perform connect operations''' + # etc. + + # create a new pool with a listener + p = QueuePool(..., listeners=[MyListener()]) + + # add a listener after the fact + p.add_listener(MyListener()) + + # usage with create_engine() + e = create_engine("url://", listeners=[MyListener()]) + + All of the standard connection :class:`~sqlalchemy.pool.Pool` types can + accept event listeners for key connection lifecycle events: + creation, pool check-out and check-in. There are no events fired + when a connection closes. + + For any given DB-API connection, there will be one ``connect`` + event, `n` number of ``checkout`` events, and either `n` or `n - 1` + ``checkin`` events. (If a ``Connection`` is detached from its + pool via the ``detach()`` method, it won't be checked back in.) + + These are low-level events for low-level objects: raw Python + DB-API connections, without the conveniences of the SQLAlchemy + ``Connection`` wrapper, ``Dialect`` services or ``ClauseElement`` + execution. If you execute SQL through the connection, explicitly + closing all cursors and other resources is recommended. + + Events also receive a ``_ConnectionRecord``, a long-lived internal + ``Pool`` object that basically represents a "slot" in the + connection pool. ``_ConnectionRecord`` objects have one public + attribute of note: ``info``, a dictionary whose contents are + scoped to the lifetime of the DB-API connection managed by the + record. You can use this shared storage area however you like. + + There is no need to subclass ``PoolListener`` to handle events. + Any class that implements one or more of these methods can be used + as a pool listener. The ``Pool`` will inspect the methods + provided by a listener object and add the listener to one or more + internal event queues based on its capabilities. In terms of + efficiency and function call overhead, you're much better off only + providing implementations for the hooks you'll be using. + + """ + + @classmethod + def _adapt_listener(cls, self, listener): + """Adapt a :class:`.PoolListener` to individual + :class:`event.Dispatch` events. + + """ + + listener = util.as_interface(listener, methods=('connect', + 'first_connect', 'checkout', 'checkin')) + if hasattr(listener, 'connect'): + event.listen(self, 'connect', listener.connect) + if hasattr(listener, 'first_connect'): + event.listen(self, 'first_connect', listener.first_connect) + if hasattr(listener, 'checkout'): + event.listen(self, 'checkout', listener.checkout) + if hasattr(listener, 'checkin'): + event.listen(self, 'checkin', listener.checkin) + + def connect(self, dbapi_con, con_record): + """Called once for each new DB-API connection or Pool's ``creator()``. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + def first_connect(self, dbapi_con, con_record): + """Called exactly once for the first DB-API connection. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + def checkout(self, dbapi_con, con_record, con_proxy): + """Called when a connection is retrieved from the Pool. + + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + con_proxy + The ``_ConnectionFairy`` which manages the connection for the span of + the current checkout. + + If you raise an ``exc.DisconnectionError``, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + """ + + def checkin(self, dbapi_con, con_record): + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + +class ConnectionProxy(object): + """Allows interception of statement execution by Connections. + + .. note:: + + :class:`.ConnectionProxy` is deprecated. Please + refer to :class:`.ConnectionEvents`. + + Either or both of the ``execute()`` and ``cursor_execute()`` + may be implemented to intercept compiled statement and + cursor level executions, e.g.:: + + class MyProxy(ConnectionProxy): + def execute(self, conn, execute, clauseelement, + *multiparams, **params): + print "compiled statement:", clauseelement + return execute(clauseelement, *multiparams, **params) + + def cursor_execute(self, execute, cursor, statement, + parameters, context, executemany): + print "raw statement:", statement + return execute(cursor, statement, parameters, context) + + The ``execute`` argument is a function that will fulfill the default + execution behavior for the operation. The signature illustrated + in the example should be used. + + The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via + the ``proxy`` argument:: + + e = create_engine('someurl://', proxy=MyProxy()) + + """ + + @classmethod + def _adapt_listener(cls, self, listener): + + def adapt_execute(conn, clauseelement, multiparams, params): + + def execute_wrapper(clauseelement, *multiparams, **params): + return clauseelement, multiparams, params + + return listener.execute(conn, execute_wrapper, + clauseelement, *multiparams, + **params) + + event.listen(self, 'before_execute', adapt_execute) + + def adapt_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + + def execute_wrapper( + cursor, + statement, + parameters, + context, + ): + return statement, parameters + + return listener.cursor_execute( + execute_wrapper, + cursor, + statement, + parameters, + context, + executemany, + ) + + event.listen(self, 'before_cursor_execute', adapt_cursor_execute) + + def do_nothing_callback(*arg, **kw): + pass + + def adapt_listener(fn): + + def go(conn, *arg, **kw): + fn(conn, do_nothing_callback, *arg, **kw) + + return util.update_wrapper(go, fn) + + event.listen(self, 'begin', adapt_listener(listener.begin)) + event.listen(self, 'rollback', + adapt_listener(listener.rollback)) + event.listen(self, 'commit', adapt_listener(listener.commit)) + event.listen(self, 'savepoint', + adapt_listener(listener.savepoint)) + event.listen(self, 'rollback_savepoint', + adapt_listener(listener.rollback_savepoint)) + event.listen(self, 'release_savepoint', + adapt_listener(listener.release_savepoint)) + event.listen(self, 'begin_twophase', + adapt_listener(listener.begin_twophase)) + event.listen(self, 'prepare_twophase', + adapt_listener(listener.prepare_twophase)) + event.listen(self, 'rollback_twophase', + adapt_listener(listener.rollback_twophase)) + event.listen(self, 'commit_twophase', + adapt_listener(listener.commit_twophase)) + + def execute(self, conn, execute, clauseelement, *multiparams, **params): + """Intercept high level execute() events.""" + + return execute(clauseelement, *multiparams, **params) + + def cursor_execute(self, execute, cursor, statement, parameters, + context, executemany): + """Intercept low-level cursor execute() events.""" + + return execute(cursor, statement, parameters, context) + + def begin(self, conn, begin): + """Intercept begin() events.""" + + return begin() + + def rollback(self, conn, rollback): + """Intercept rollback() events.""" + + return rollback() + + def commit(self, conn, commit): + """Intercept commit() events.""" + + return commit() + + def savepoint(self, conn, savepoint, name=None): + """Intercept savepoint() events.""" + + return savepoint(name=name) + + def rollback_savepoint(self, conn, rollback_savepoint, name, context): + """Intercept rollback_savepoint() events.""" + + return rollback_savepoint(name, context) + + def release_savepoint(self, conn, release_savepoint, name, context): + """Intercept release_savepoint() events.""" + + return release_savepoint(name, context) + + def begin_twophase(self, conn, begin_twophase, xid): + """Intercept begin_twophase() events.""" + + return begin_twophase(xid) + + def prepare_twophase(self, conn, prepare_twophase, xid): + """Intercept prepare_twophase() events.""" + + return prepare_twophase(xid) + + def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared): + """Intercept rollback_twophase() events.""" + + return rollback_twophase(xid, is_prepared) + + def commit_twophase(self, conn, commit_twophase, xid, is_prepared): + """Intercept commit_twophase() events.""" + + return commit_twophase(xid, is_prepared) diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py new file mode 100644 index 00000000..935761d5 --- /dev/null +++ b/lib/sqlalchemy/log.py @@ -0,0 +1,214 @@ +# sqlalchemy/log.py +# Copyright (C) 2006-2014 the SQLAlchemy authors and contributors +# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Logging control and utilities. + +Control of logging for SA can be performed from the regular python logging +module. The regular dotted module namespace is used, starting at +'sqlalchemy'. For class-level logging, the class name is appended. + +The "echo" keyword parameter, available on SQLA :class:`.Engine` +and :class:`.Pool` objects, corresponds to a logger specific to that +instance only. + +""" + +import logging +import sys + +# set initial level to WARN. This so that +# log statements don't occur in the absense of explicit +# logging being enabled for 'sqlalchemy'. +rootlogger = logging.getLogger('sqlalchemy') +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) + + +def _add_default_handler(logger): + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s %(name)s %(message)s')) + logger.addHandler(handler) + + +_logged_classes = set() + + +def class_logger(cls): + logger = logging.getLogger(cls.__module__ + "." + cls.__name__) + cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG) + cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO) + cls.logger = logger + _logged_classes.add(cls) + return cls + +class Identified(object): + logging_name = None + + def _should_log_debug(self): + return self.logger.isEnabledFor(logging.DEBUG) + + def _should_log_info(self): + return self.logger.isEnabledFor(logging.INFO) + + +class InstanceLogger(object): + """A logger adapter (wrapper) for :class:`.Identified` subclasses. + + This allows multiple instances (e.g. Engine or Pool instances) + to share a logger, but have its verbosity controlled on a + per-instance basis. + + The basic functionality is to return a logging level + which is based on an instance's echo setting. + + Default implementation is: + + 'debug' -> logging.DEBUG + True -> logging.INFO + False -> Effective level of underlying logger + (logging.WARNING by default) + None -> same as False + """ + + # Map echo settings to logger levels + _echo_map = { + None: logging.NOTSET, + False: logging.NOTSET, + True: logging.INFO, + 'debug': logging.DEBUG, + } + + def __init__(self, echo, name): + self.echo = echo + self.logger = logging.getLogger(name) + + # if echo flag is enabled and no handlers, + # add a handler to the list + if self._echo_map[echo] <= logging.INFO \ + and not self.logger.handlers: + _add_default_handler(self.logger) + + # + # Boilerplate convenience methods + # + def debug(self, msg, *args, **kwargs): + """Delegate a debug call to the underlying logger.""" + + self.log(logging.DEBUG, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + """Delegate an info call to the underlying logger.""" + + self.log(logging.INFO, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """Delegate a warning call to the underlying logger.""" + + self.log(logging.WARNING, msg, *args, **kwargs) + + warn = warning + + def error(self, msg, *args, **kwargs): + """ + Delegate an error call to the underlying logger. + """ + self.log(logging.ERROR, msg, *args, **kwargs) + + def exception(self, msg, *args, **kwargs): + """Delegate an exception call to the underlying logger.""" + + kwargs["exc_info"] = 1 + self.log(logging.ERROR, msg, *args, **kwargs) + + def critical(self, msg, *args, **kwargs): + """Delegate a critical call to the underlying logger.""" + + self.log(logging.CRITICAL, msg, *args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + """Delegate a log call to the underlying logger. + + The level here is determined by the echo + flag as well as that of the underlying logger, and + logger._log() is called directly. + + """ + + # inline the logic from isEnabledFor(), + # getEffectiveLevel(), to avoid overhead. + + if self.logger.manager.disable >= level: + return + + selected_level = self._echo_map[self.echo] + if selected_level == logging.NOTSET: + selected_level = self.logger.getEffectiveLevel() + + if level >= selected_level: + self.logger._log(level, msg, args, **kwargs) + + def isEnabledFor(self, level): + """Is this logger enabled for level 'level'?""" + + if self.logger.manager.disable >= level: + return False + return level >= self.getEffectiveLevel() + + def getEffectiveLevel(self): + """What's the effective level for this logger?""" + + level = self._echo_map[self.echo] + if level == logging.NOTSET: + level = self.logger.getEffectiveLevel() + return level + + +def instance_logger(instance, echoflag=None): + """create a logger for an instance that implements :class:`.Identified`.""" + + if instance.logging_name: + name = "%s.%s.%s" % (instance.__class__.__module__, + instance.__class__.__name__, instance.logging_name) + else: + name = "%s.%s" % (instance.__class__.__module__, + instance.__class__.__name__) + + instance._echo = echoflag + + if echoflag in (False, None): + # if no echo setting or False, return a Logger directly, + # avoiding overhead of filtering + logger = logging.getLogger(name) + else: + # if a specified echo flag, return an EchoLogger, + # which checks the flag, overrides normal log + # levels by calling logger._log() + logger = InstanceLogger(echoflag, name) + + instance.logger = logger + + +class echo_property(object): + __doc__ = """\ + When ``True``, enable log output for this element. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + """ + + def __get__(self, instance, owner): + if instance is None: + return self + else: + return instance._echo + + def __set__(self, instance, value): + instance_logger(instance, echoflag=value) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py new file mode 100644 index 00000000..7825a70a --- /dev/null +++ b/lib/sqlalchemy/orm/__init__.py @@ -0,0 +1,267 @@ +# orm/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Functional constructs for ORM configuration. + +See the SQLAlchemy object relational tutorial and mapper configuration +documentation for an overview of how this module is used. + +""" + +from . import exc +from .mapper import ( + Mapper, + _mapper_registry, + class_mapper, + configure_mappers, + reconstructor, + validates + ) +from .interfaces import ( + EXT_CONTINUE, + EXT_STOP, + PropComparator, + ) +from .deprecated_interfaces import ( + MapperExtension, + SessionExtension, + AttributeExtension, +) +from .util import ( + aliased, + join, + object_mapper, + outerjoin, + polymorphic_union, + was_deleted, + with_parent, + with_polymorphic, + ) +from .properties import ColumnProperty +from .relationships import RelationshipProperty +from .descriptor_props import ( + ComparableProperty, + CompositeProperty, + SynonymProperty, + ) +from .relationships import ( + foreign, + remote, +) +from .session import ( + Session, + object_session, + sessionmaker, + make_transient +) +from .scoping import ( + scoped_session +) +from . import mapper as mapperlib +from .query import AliasOption, Query, Bundle +from ..util.langhelpers import public_factory +from .. import util as _sa_util +from . import strategies as _strategies + +def create_session(bind=None, **kwargs): + """Create a new :class:`.Session` + with no automation enabled by default. + + This function is used primarily for testing. The usual + route to :class:`.Session` creation is via its constructor + or the :func:`.sessionmaker` function. + + :param bind: optional, a single Connectable to use for all + database access in the created + :class:`~sqlalchemy.orm.session.Session`. + + :param \*\*kwargs: optional, passed through to the + :class:`.Session` constructor. + + :returns: an :class:`~sqlalchemy.orm.session.Session` instance + + The defaults of create_session() are the opposite of that of + :func:`sessionmaker`; ``autoflush`` and ``expire_on_commit`` are + False, ``autocommit`` is True. In this sense the session acts + more like the "classic" SQLAlchemy 0.3 session with these. + + Usage:: + + >>> from sqlalchemy.orm import create_session + >>> session = create_session() + + It is recommended to use :func:`sessionmaker` instead of + create_session(). + + """ + kwargs.setdefault('autoflush', False) + kwargs.setdefault('autocommit', True) + kwargs.setdefault('expire_on_commit', False) + return Session(bind=bind, **kwargs) + +relationship = public_factory(RelationshipProperty, ".orm.relationship") + +def relation(*arg, **kw): + """A synonym for :func:`relationship`.""" + + return relationship(*arg, **kw) + + +def dynamic_loader(argument, **kw): + """Construct a dynamically-loading mapper property. + + This is essentially the same as + using the ``lazy='dynamic'`` argument with :func:`relationship`:: + + dynamic_loader(SomeClass) + + # is the same as + + relationship(SomeClass, lazy="dynamic") + + See the section :ref:`dynamic_relationship` for more details + on dynamic loading. + + """ + kw['lazy'] = 'dynamic' + return relationship(argument, **kw) + + +column_property = public_factory(ColumnProperty, ".orm.column_property") +composite = public_factory(CompositeProperty, ".orm.composite") + + +def backref(name, **kwargs): + """Create a back reference with explicit keyword arguments, which are the + same arguments one can send to :func:`relationship`. + + Used with the ``backref`` keyword argument to :func:`relationship` in + place of a string argument, e.g.:: + + 'items':relationship(SomeItem, backref=backref('parent', lazy='subquery')) + + """ + return (name, kwargs) + + +def deferred(*columns, **kw): + """Indicate a column-based mapped attribute that by default will + not load unless accessed. + + :param \*columns: columns to be mapped. This is typically a single + :class:`.Column` object, however a collection is supported in order + to support multiple columns mapped under the same attribute. + + :param \**kw: additional keyword arguments passed to :class:`.ColumnProperty`. + + .. seealso:: + + :ref:`deferred` + + """ + return ColumnProperty(deferred=True, *columns, **kw) + + +mapper = public_factory(Mapper, ".orm.mapper") + +synonym = public_factory(SynonymProperty, ".orm.synonym") + +comparable_property = public_factory(ComparableProperty, + ".orm.comparable_property") + + +@_sa_util.deprecated("0.7", message=":func:`.compile_mappers` " + "is renamed to :func:`.configure_mappers`") +def compile_mappers(): + """Initialize the inter-mapper relationships of all mappers that have + been defined. + + """ + configure_mappers() + + +def clear_mappers(): + """Remove all mappers from all classes. + + This function removes all instrumentation from classes and disposes + of their associated mappers. Once called, the classes are unmapped + and can be later re-mapped with new mappers. + + :func:`.clear_mappers` is *not* for normal use, as there is literally no + valid usage for it outside of very specific testing scenarios. Normally, + mappers are permanent structural components of user-defined classes, and + are never discarded independently of their class. If a mapped class itself + is garbage collected, its mapper is automatically disposed of as well. As + such, :func:`.clear_mappers` is only for usage in test suites that re-use + the same classes with different mappings, which is itself an extremely rare + use case - the only such use case is in fact SQLAlchemy's own test suite, + and possibly the test suites of other ORM extension libraries which + intend to test various combinations of mapper construction upon a fixed + set of classes. + + """ + mapperlib._CONFIGURE_MUTEX.acquire() + try: + while _mapper_registry: + try: + # can't even reliably call list(weakdict) in jython + mapper, b = _mapper_registry.popitem() + mapper.dispose() + except KeyError: + pass + finally: + mapperlib._CONFIGURE_MUTEX.release() + +from . import strategy_options + +joinedload = strategy_options.joinedload._unbound_fn +joinedload_all = strategy_options.joinedload._unbound_all_fn +contains_eager = strategy_options.contains_eager._unbound_fn +defer = strategy_options.defer._unbound_fn +undefer = strategy_options.undefer._unbound_fn +undefer_group = strategy_options.undefer_group._unbound_fn +load_only = strategy_options.load_only._unbound_fn +lazyload = strategy_options.lazyload._unbound_fn +lazyload_all = strategy_options.lazyload_all._unbound_all_fn +subqueryload = strategy_options.subqueryload._unbound_fn +subqueryload_all = strategy_options.subqueryload_all._unbound_all_fn +immediateload = strategy_options.immediateload._unbound_fn +noload = strategy_options.noload._unbound_fn +defaultload = strategy_options.defaultload._unbound_fn + +from .strategy_options import Load + +def eagerload(*args, **kwargs): + """A synonym for :func:`joinedload()`.""" + return joinedload(*args, **kwargs) + + +def eagerload_all(*args, **kwargs): + """A synonym for :func:`joinedload_all()`""" + return joinedload_all(*args, **kwargs) + + + + +contains_alias = public_factory(AliasOption, ".orm.contains_alias") + + + +def __go(lcls): + global __all__ + from .. import util as sa_util + from . import dynamic + from . import events + import inspect as _inspect + + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) + + _sa_util.dependencies.resolve_all("sqlalchemy.orm") + +__go(locals()) + diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py new file mode 100644 index 00000000..3a786c73 --- /dev/null +++ b/lib/sqlalchemy/orm/attributes.py @@ -0,0 +1,1541 @@ +# orm/attributes.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines instrumentation for class attributes and their interaction +with instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + + +""" + +import operator +from .. import util, event, inspection +from . import interfaces, collections, exc as orm_exc + +from .base import instance_state, instance_dict, manager_of_class + +from .base import PASSIVE_NO_RESULT, ATTR_WAS_SET, ATTR_EMPTY, NO_VALUE,\ + NEVER_SET, NO_CHANGE, CALLABLES_OK, SQL_OK, RELATED_OBJECT_OK,\ + INIT_OK, NON_PERSISTENT_OK, LOAD_AGAINST_COMMITTED, PASSIVE_OFF,\ + PASSIVE_RETURN_NEVER_SET, PASSIVE_NO_INITIALIZE, PASSIVE_NO_FETCH,\ + PASSIVE_NO_FETCH_RELATED, PASSIVE_ONLY_PERSISTENT, NO_AUTOFLUSH +from .base import state_str, instance_str + +@inspection._self_inspects +class QueryableAttribute(interfaces._MappedAttribute, + interfaces._InspectionAttr, + interfaces.PropComparator): + """Base class for :term:`descriptor` objects that intercept + attribute events on behalf of a :class:`.MapperProperty` + object. The actual :class:`.MapperProperty` is accessible + via the :attr:`.QueryableAttribute.property` + attribute. + + + .. seealso:: + + :class:`.InstrumentedAttribute` + + :class:`.MapperProperty` + + :attr:`.Mapper.all_orm_descriptors` + + :attr:`.Mapper.attrs` + """ + + is_attribute = True + + def __init__(self, class_, key, impl=None, + comparator=None, parententity=None, + of_type=None): + self.class_ = class_ + self.key = key + self.impl = impl + self.comparator = comparator + self._parententity = parententity + self._of_type = of_type + + manager = manager_of_class(class_) + # manager is None in the case of AliasedClass + if manager: + # propagate existing event listeners from + # immediate superclass + for base in manager._bases: + if key in base: + self.dispatch._update(base[key].dispatch) + + @util.memoized_property + def _supports_population(self): + return self.impl.supports_population + + def get_history(self, instance, passive=PASSIVE_OFF): + return self.impl.get_history(instance_state(instance), + instance_dict(instance), passive) + + def __selectable__(self): + # TODO: conditionally attach this method based on clause_element ? + return self + + + @util.memoized_property + def info(self): + """Return the 'info' dictionary for the underlying SQL element. + + The behavior here is as follows: + + * If the attribute is a column-mapped property, i.e. + :class:`.ColumnProperty`, which is mapped directly + to a schema-level :class:`.Column` object, this attribute + will return the :attr:`.SchemaItem.info` dictionary associated + with the core-level :class:`.Column` object. + + * If the attribute is a :class:`.ColumnProperty` but is mapped to + any other kind of SQL expression other than a :class:`.Column`, + the attribute will refer to the :attr:`.MapperProperty.info` dictionary + associated directly with the :class:`.ColumnProperty`, assuming the SQL + expression itself does not have it's own ``.info`` attribute + (which should be the case, unless a user-defined SQL construct + has defined one). + + * If the attribute refers to any other kind of :class:`.MapperProperty`, + including :class:`.RelationshipProperty`, the attribute will refer + to the :attr:`.MapperProperty.info` dictionary associated with + that :class:`.MapperProperty`. + + * To access the :attr:`.MapperProperty.info` dictionary of the :class:`.MapperProperty` + unconditionally, including for a :class:`.ColumnProperty` that's + associated directly with a :class:`.schema.Column`, the attribute + can be referred to using :attr:`.QueryableAttribute.property` + attribute, as ``MyClass.someattribute.property.info``. + + .. versionadded:: 0.8.0 + + .. seealso:: + + :attr:`.SchemaItem.info` + + :attr:`.MapperProperty.info` + + """ + return self.comparator.info + + @util.memoized_property + def parent(self): + """Return an inspection instance representing the parent. + + This will be either an instance of :class:`.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. + + """ + return inspection.inspect(self._parententity) + + @property + def expression(self): + return self.comparator.__clause_element__() + + def __clause_element__(self): + return self.comparator.__clause_element__() + + def _query_clause_element(self): + """like __clause_element__(), but called specifically + by :class:`.Query` to allow special behavior.""" + + return self.comparator._query_clause_element() + + def adapt_to_entity(self, adapt_to_entity): + assert not self._of_type + return self.__class__(adapt_to_entity.entity, self.key, impl=self.impl, + comparator=self.comparator.adapt_to_entity(adapt_to_entity), + parententity=adapt_to_entity) + + def of_type(self, cls): + return QueryableAttribute( + self.class_, + self.key, + self.impl, + self.comparator.of_type(cls), + self._parententity, + of_type=cls) + + def label(self, name): + return self._query_clause_element().label(name) + + def operate(self, op, *other, **kwargs): + return op(self.comparator, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) + + def hasparent(self, state, optimistic=False): + return self.impl.hasparent(state, optimistic=optimistic) is not False + + def __getattr__(self, key): + try: + return getattr(self.comparator, key) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object associated with %s ' + 'has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + self, + key) + ) + + def __str__(self): + return "%s.%s" % (self.class_.__name__, self.key) + + @util.memoized_property + def property(self): + """Return the :class:`.MapperProperty` associated with this + :class:`.QueryableAttribute`. + + + Return values here will commonly be instances of + :class:`.ColumnProperty` or :class:`.RelationshipProperty`. + + + """ + return self.comparator.property + + +class InstrumentedAttribute(QueryableAttribute): + """Class bound instrumented attribute which adds basic + :term:`descriptor` methods. + + See :class:`.QueryableAttribute` for a description of most features. + + + """ + + def __set__(self, instance, value): + self.impl.set(instance_state(instance), + instance_dict(instance), value, None) + + def __delete__(self, instance): + self.impl.delete(instance_state(instance), instance_dict(instance)) + + def __get__(self, instance, owner): + if instance is None: + return self + + dict_ = instance_dict(instance) + if self._supports_population and self.key in dict_: + return dict_[self.key] + else: + return self.impl.get(instance_state(instance), dict_) + + +def create_proxied_attribute(descriptor): + """Create an QueryableAttribute / user descriptor hybrid. + + Returns a new QueryableAttribute type that delegates descriptor + behavior and getattr() to the given descriptor. + """ + + # TODO: can move this to descriptor_props if the need for this + # function is removed from ext/hybrid.py + + class Proxy(QueryableAttribute): + """Presents the :class:`.QueryableAttribute` interface as a + proxy on top of a Python descriptor / :class:`.PropComparator` + combination. + + """ + + def __init__(self, class_, key, descriptor, + comparator, + adapt_to_entity=None, doc=None, + original_property=None): + self.class_ = class_ + self.key = key + self.descriptor = descriptor + self.original_property = original_property + self._comparator = comparator + self._adapt_to_entity = adapt_to_entity + self.__doc__ = doc + + @property + def property(self): + return self.comparator.property + + @util.memoized_property + def comparator(self): + if util.callable(self._comparator): + self._comparator = self._comparator() + if self._adapt_to_entity: + self._comparator = self._comparator.adapt_to_entity( + self._adapt_to_entity) + return self._comparator + + def adapt_to_entity(self, adapt_to_entity): + return self.__class__(adapt_to_entity.entity, self.key, self.descriptor, + self._comparator, + adapt_to_entity) + + def __get__(self, instance, owner): + if instance is None: + return self + else: + return self.descriptor.__get__(instance, owner) + + def __str__(self): + return "%s.%s" % (self.class_.__name__, self.key) + + def __getattr__(self, attribute): + """Delegate __getattr__ to the original descriptor and/or + comparator.""" + + try: + return getattr(descriptor, attribute) + except AttributeError: + try: + return getattr(self.comparator, attribute) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object associated with %s ' + 'has an attribute %r' % ( + type(descriptor).__name__, + type(self.comparator).__name__, + self, + attribute) + ) + + Proxy.__name__ = type(descriptor).__name__ + 'Proxy' + + util.monkeypatch_proxied_specials(Proxy, type(descriptor), + name='descriptor', + from_instance=descriptor) + return Proxy + +OP_REMOVE = util.symbol("REMOVE") +OP_APPEND = util.symbol("APPEND") +OP_REPLACE = util.symbol("REPLACE") + +class Event(object): + """A token propagated throughout the course of a chain of attribute + events. + + Serves as an indicator of the source of the event and also provides + a means of controlling propagation across a chain of attribute + operations. + + The :class:`.Event` object is sent as the ``initiator`` argument + when dealing with the :meth:`.AttributeEvents.append`, + :meth:`.AttributeEvents.set`, + and :meth:`.AttributeEvents.remove` events. + + The :class:`.Event` object is currently interpreted by the backref + event handlers, and is used to control the propagation of operations + across two mutually-dependent attributes. + + .. versionadded:: 0.9.0 + + """ + + impl = None + """The :class:`.AttributeImpl` which is the current event initiator. + """ + + op = None + """The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or :attr:`.OP_REPLACE`, + indicating the source operation. + + """ + + def __init__(self, attribute_impl, op): + self.impl = attribute_impl + self.op = op + self.parent_token = self.impl.parent_token + + + @property + def key(self): + return self.impl.key + + def hasparent(self, state): + return self.impl.hasparent(state) + +class AttributeImpl(object): + """internal implementation for instrumented attributes.""" + + def __init__(self, class_, key, + callable_, dispatch, trackparent=False, extension=None, + compare_function=None, active_history=False, + parent_token=None, expire_missing=True, + send_modified_events=True, + **kwargs): + """Construct an AttributeImpl. + + \class_ + associated class + + key + string name of the attribute + + \callable_ + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present + already. + + trackparent + if True, attempt to track if an instance has a parent attached + to it via this attribute. + + extension + a single or list of AttributeExtension object(s) which will + receive set/delete/append/remove/etc. events. Deprecated. + The event package is now used. + + compare_function + a function that compares two values which are normally + assignable to this attribute. + + active_history + indicates that get_history() should always return the "old" value, + even if it means executing a lazy callable upon attribute change. + + parent_token + Usually references the MapperProperty, used as a key for + the hasparent() function to identify an "owning" attribute. + Allows multiple AttributeImpls to all match a single + owner attribute. + + expire_missing + if False, don't add an "expiry" callable to this attribute + during state.expire_attributes(None), if no value is present + for this key. + + send_modified_events + if False, the InstanceState._modified_event method will have no effect; + this means the attribute will never show up as changed in a + history entry. + """ + self.class_ = class_ + self.key = key + self.callable_ = callable_ + self.dispatch = dispatch + self.trackparent = trackparent + self.parent_token = parent_token or self + self.send_modified_events = send_modified_events + if compare_function is None: + self.is_equal = operator.eq + else: + self.is_equal = compare_function + + # TODO: pass in the manager here + # instead of doing a lookup + attr = manager_of_class(class_)[key] + + for ext in util.to_list(extension or []): + ext._adapt_listener(attr, ext) + + if active_history: + self.dispatch._active_history = True + + self.expire_missing = expire_missing + + def __str__(self): + return "%s.%s" % (self.class_.__name__, self.key) + + def _get_active_history(self): + """Backwards compat for impl.active_history""" + + return self.dispatch._active_history + + def _set_active_history(self, value): + self.dispatch._active_history = value + + active_history = property(_get_active_history, _set_active_history) + + def hasparent(self, state, optimistic=False): + """Return the boolean value of a `hasparent` flag attached to + the given state. + + The `optimistic` flag determines what the default return value + should be if no `hasparent` flag can be located. + + As this function is used to determine if an instance is an + *orphan*, instances that were loaded from storage should be + assumed to not be orphans, until a True/False value for this + flag is set. + + An instance attribute that is loaded by a callable function + will also not have a `hasparent` flag. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + return state.parents.get(id(self.parent_token), optimistic) \ + is not False + + def sethasparent(self, state, parent_state, value): + """Set a boolean flag on the given item corresponding to + whether or not it is attached to a parent object via the + attribute represented by this ``InstrumentedAttribute``. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + id_ = id(self.parent_token) + if value: + state.parents[id_] = parent_state + else: + if id_ in state.parents: + last_parent = state.parents[id_] + + if last_parent is not False and \ + last_parent.key != parent_state.key: + + if last_parent.obj() is None: + raise orm_exc.StaleDataError( + "Removing state %s from parent " + "state %s along attribute '%s', " + "but the parent record " + "has gone stale, can't be sure this " + "is the most recent parent." % + (state_str(state), + state_str(parent_state), + self.key)) + + return + + state.parents[id_] = False + + def set_callable(self, state, callable_): + """Set a callable function for this attribute on the given object. + + This callable will be executed when the attribute is next + accessed, and is assumed to construct part of the instances + previously stored state. When its value or values are loaded, + they will be established as part of the instance's *committed + state*. While *trackparent* information will be assembled for + these instances, attribute-level event handlers will not be + fired. + + The callable overrides the class level callable set in the + ``InstrumentedAttribute`` constructor. + + """ + state.callables[self.key] = callable_ + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + raise NotImplementedError() + + def get_all_pending(self, state, dict_): + """Return a list of tuples of (state, obj) + for all objects in this attribute's current state + + history. + + Only applies to object-based attributes. + + This is an inlining of existing functionality + which roughly corresponds to: + + get_state_history( + state, + key, + passive=PASSIVE_NO_INITIALIZE).sum() + + """ + raise NotImplementedError() + + def initialize(self, state, dict_): + """Initialize the given state's attribute with an empty value.""" + + dict_[self.key] = None + return None + + def get(self, state, dict_, passive=PASSIVE_OFF): + """Retrieve a value from the given object. + If a callable is assembled on this object's attribute, and + passive is False, the callable will be executed and the + resulting value will be set as the new value for this attribute. + """ + if self.key in dict_: + return dict_[self.key] + else: + # if history present, don't load + key = self.key + if key not in state.committed_state or \ + state.committed_state[key] is NEVER_SET: + if not passive & CALLABLES_OK: + return PASSIVE_NO_RESULT + + if key in state.callables: + callable_ = state.callables[key] + value = callable_(state, passive) + elif self.callable_: + value = self.callable_(state, passive) + else: + value = ATTR_EMPTY + + if value is PASSIVE_NO_RESULT or value is NEVER_SET: + return value + elif value is ATTR_WAS_SET: + try: + return dict_[key] + except KeyError: + # TODO: no test coverage here. + raise KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key) + elif value is not ATTR_EMPTY: + return self.set_committed_value(state, dict_, value) + + if not passive & INIT_OK: + return NEVER_SET + else: + # Return a new, empty value + return self.initialize(state, dict_) + + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, value, initiator, passive=passive) + + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator, + passive=passive, check_old=value) + + def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator, + passive=passive, check_old=value, pop=True) + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, check_old=None, pop=False): + raise NotImplementedError() + + def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): + """return the unchanged value of this attribute""" + + if self.key in state.committed_state: + value = state.committed_state[self.key] + if value is NO_VALUE: + return None + else: + return value + else: + return self.get(state, dict_, passive=passive) + + def set_committed_value(self, state, dict_, value): + """set an attribute value on the given instance and 'commit' it.""" + + dict_[self.key] = value + state._commit(dict_, [self.key]) + return value + + +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" + + accepts_scalar_loader = True + uses_objects = False + supports_population = True + collection = False + + def delete(self, state, dict_): + + # TODO: catch key errors, convert to attributeerror? + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.remove: + self.fire_remove_event(state, dict_, old, self._remove_token) + state._modified_event(dict_, self, old) + del dict_[self.key] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_scalar_attribute(self, state, dict_[self.key]) + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_scalar_attribute(self, state, current) + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, check_old=None, pop=False): + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.set: + value = self.fire_replace_event(state, dict_, + value, old, initiator) + state._modified_event(dict_, self, old) + dict_[self.key] = value + + @util.memoized_property + def _replace_token(self): + return Event(self, OP_REPLACE) + + @util.memoized_property + def _append_token(self): + return Event(self, OP_REPLACE) + + @util.memoized_property + def _remove_token(self): + return Event(self, OP_REMOVE) + + def fire_replace_event(self, state, dict_, value, previous, initiator): + for fn in self.dispatch.set: + value = fn(state, value, previous, initiator or self._replace_token) + return value + + def fire_remove_event(self, state, dict_, value, initiator): + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + @property + def type(self): + self.property.columns[0].type + + +class ScalarObjectAttributeImpl(ScalarAttributeImpl): + """represents a scalar-holding InstrumentedAttribute, + where the target object is also instrumented. + + Adds events to delete/set operations. + + """ + + accepts_scalar_loader = False + uses_objects = True + supports_population = True + collection = False + + def delete(self, state, dict_): + old = self.get(state, dict_) + self.fire_remove_event(state, dict_, old, self._remove_token) + del dict_[self.key] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_object_attribute(self, state, dict_[self.key]) + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_object_attribute(self, state, current) + + def get_all_pending(self, state, dict_): + if self.key in dict_: + current = dict_[self.key] + if current is not None: + ret = [(instance_state(current), current)] + else: + ret = [(None, None)] + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original not in (NEVER_SET, PASSIVE_NO_RESULT, None) and \ + original is not current: + + ret.append((instance_state(original), original)) + return ret + else: + return [] + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, check_old=None, pop=False): + """Set a value on the given InstanceState. + + """ + if self.dispatch._active_history: + old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT | NO_AUTOFLUSH) + else: + old = self.get(state, dict_, passive=PASSIVE_NO_FETCH) + + if check_old is not None and \ + old is not PASSIVE_NO_RESULT and \ + check_old is not old: + if pop: + return + else: + raise ValueError( + "Object %s not associated with %s on attribute '%s'" % ( + instance_str(check_old), + state_str(state), + self.key + )) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value + + + def fire_remove_event(self, state, dict_, value, initiator): + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + state._modified_event(dict_, self, value) + + def fire_replace_event(self, state, dict_, value, previous, initiator): + if self.trackparent: + if (previous is not value and + previous is not None and + previous is not PASSIVE_NO_RESULT): + self.sethasparent(instance_state(previous), state, False) + + for fn in self.dispatch.set: + value = fn(state, value, previous, initiator or self._replace_token) + + state._modified_event(dict_, self, previous) + + if self.trackparent: + if value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + +class CollectionAttributeImpl(AttributeImpl): + """A collection-holding attribute that instruments changes in membership. + + Only handles collections of instrumented objects. + + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent bag + semantics to the orm layer independent of the user data implementation. + + """ + accepts_scalar_loader = False + uses_objects = True + supports_population = True + collection = True + + def __init__(self, class_, key, callable_, dispatch, + typecallable=None, trackparent=False, extension=None, + copy_function=None, compare_function=None, **kwargs): + super(CollectionAttributeImpl, self).__init__( + class_, + key, + callable_, dispatch, + trackparent=trackparent, + extension=extension, + compare_function=compare_function, + **kwargs) + + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function + self.collection_factory = typecallable + + def __copy(self, item): + return [y for y in collections.collection_adapter(item)] + + def get_history(self, state, dict_, passive=PASSIVE_OFF): + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_collection(self, state, current) + + def get_all_pending(self, state, dict_): + if self.key not in dict_: + return [] + + current = dict_[self.key] + current = getattr(current, '_sa_adapter') + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original not in (NO_VALUE, NEVER_SET): + current_states = [((c is not None) and + instance_state(c) or None, c) + for c in current] + original_states = [((c is not None) and + instance_state(c) or None, c) + for c in original] + + current_set = dict(current_states) + original_set = dict(original_states) + + return \ + [(s, o) for s, o in current_states + if s not in original_set] + \ + [(s, o) for s, o in current_states + if s in original_set] + \ + [(s, o) for s, o in original_states + if s not in current_set] + + return [(instance_state(o), o) for o in current] + + @util.memoized_property + def _append_token(self): + return Event(self, OP_APPEND) + + @util.memoized_property + def _remove_token(self): + return Event(self, OP_REMOVE) + + def fire_append_event(self, state, dict_, value, initiator): + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token) + + state._modified_event(dict_, self, NEVER_SET, True) + + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + def fire_pre_remove_event(self, state, dict_, initiator): + state._modified_event(dict_, self, NEVER_SET, True) + + def fire_remove_event(self, state, dict_, value, initiator): + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + state._modified_event(dict_, self, NEVER_SET, True) + + def delete(self, state, dict_): + if self.key not in dict_: + return + + state._modified_event(dict_, self, NEVER_SET, True) + + collection = self.get_collection(state, state.dict) + collection.clear_with_event() + # TODO: catch key errors, convert to attributeerror? + del dict_[self.key] + + def initialize(self, state, dict_): + """Initialize this attribute with an empty collection.""" + + _, user_data = self._initialize_collection(state) + dict_[self.key] = user_data + return user_data + + def _initialize_collection(self, state): + return state.manager.initialize_collection( + self.key, state, self.collection_factory) + + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + collection = self.get_collection(state, dict_, passive=passive) + if collection is PASSIVE_NO_RESULT: + value = self.fire_append_event(state, dict_, value, initiator) + assert self.key not in dict_, \ + "Collection was loaded during event handling." + state._get_pending_mutation(self.key).append(value) + else: + collection.append_with_event(value, initiator) + + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + collection = self.get_collection(state, state.dict, passive=passive) + if collection is PASSIVE_NO_RESULT: + self.fire_remove_event(state, dict_, value, initiator) + assert self.key not in dict_, \ + "Collection was loaded during event handling." + state._get_pending_mutation(self.key).remove(value) + else: + collection.remove_with_event(value, initiator) + + def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + try: + # TODO: better solution here would be to add + # a "popper" role to collections.py to complement + # "remover". + self.remove(state, dict_, value, initiator, passive=passive) + except (ValueError, KeyError, IndexError): + pass + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, pop=False): + """Set a value on the given object. + + """ + + self._set_iterable( + state, dict_, value, + lambda adapter, i: adapter.adapt_like_to_iterable(i)) + + def _set_iterable(self, state, dict_, iterable, adapter=None): + """Set a collection value from an iterable of state-bearers. + + ``adapter`` is an optional callable invoked with a CollectionAdapter + and the iterable. Should return an iterable of state-bearing + instances suitable for appending via a CollectionAdapter. Can be used + for, e.g., adapting an incoming dictionary into an iterator of values + rather than keys. + + """ + # pulling a new collection first so that an adaptation exception does + # not trigger a lazy load of the old collection. + new_collection, user_data = self._initialize_collection(state) + if adapter: + new_values = list(adapter(new_collection, iterable)) + else: + new_values = list(iterable) + + old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT) + if old is PASSIVE_NO_RESULT: + old = self.initialize(state, dict_) + elif old is iterable: + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) + return + + # place a copy of "old" in state.committed_state + state._modified_event(dict_, self, old, True) + + old_collection = getattr(old, '_sa_adapter') + + dict_[self.key] = user_data + + collections.bulk_replace(new_values, old_collection, new_collection) + old_collection.unlink(old) + + def _invalidate_collection(self, collection): + adapter = getattr(collection, '_sa_adapter') + adapter.invalidated = True + + def set_committed_value(self, state, dict_, value): + """Set an attribute value on the given instance and 'commit' it.""" + + collection, user_data = self._initialize_collection(state) + + if value: + collection.append_multiple_without_event(value) + + state.dict[self.key] = user_data + + state._commit(dict_, [self.key]) + + if self.key in state._pending_mutations: + # pending items exist. issue a modified event, + # add/remove new items. + state._modified_event(dict_, self, user_data, True) + + pending = state._pending_mutations.pop(self.key) + added = pending.added_items + removed = pending.deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + + return user_data + + def get_collection(self, state, dict_, + user_data=None, passive=PASSIVE_OFF): + """Retrieve the CollectionAdapter associated with the given state. + + Creates a new CollectionAdapter if one does not exist. + + """ + if user_data is None: + user_data = self.get(state, dict_, passive=passive) + if user_data is PASSIVE_NO_RESULT: + return user_data + + return getattr(user_data, '_sa_adapter') + + +def backref_listeners(attribute, key, uselist): + """Apply listeners to synchronize a two-way relationship.""" + + # use easily recognizable names for stack traces + + parent_token = attribute.impl.parent_token + parent_impl = attribute.impl + + def _acceptable_key_err(child_state, initiator, child_impl): + raise ValueError( + "Bidirectional attribute conflict detected: " + 'Passing object %s to attribute "%s" ' + 'triggers a modify event on attribute "%s" ' + 'via the backref "%s".' % ( + state_str(child_state), + initiator.parent_token, + child_impl.parent_token, + attribute.impl.parent_token + ) + ) + + def emit_backref_from_scalar_set_event(state, child, oldchild, initiator): + if oldchild is child: + return child + if oldchild is not None and oldchild is not PASSIVE_NO_RESULT: + # With lazy=None, there's no guarantee that the full collection is + # present when updating via a backref. + old_state, old_dict = instance_state(oldchild),\ + instance_dict(oldchild) + impl = old_state.manager[key].impl + + if initiator.impl is not impl or \ + initiator.op not in (OP_REPLACE, OP_REMOVE): + impl.pop(old_state, + old_dict, + state.obj(), + parent_impl._append_token, + passive=PASSIVE_NO_FETCH) + + if child is not None: + child_state, child_dict = instance_state(child),\ + instance_dict(child) + child_impl = child_state.manager[key].impl + if initiator.parent_token is not parent_token and \ + initiator.parent_token is not child_impl.parent_token: + _acceptable_key_err(state, initiator, child_impl) + elif initiator.impl is not child_impl or \ + initiator.op not in (OP_APPEND, OP_REPLACE): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) + return child + + def emit_backref_from_collection_append_event(state, child, initiator): + if child is None: + return + + child_state, child_dict = instance_state(child), \ + instance_dict(child) + child_impl = child_state.manager[key].impl + + if initiator.parent_token is not parent_token and \ + initiator.parent_token is not child_impl.parent_token: + _acceptable_key_err(state, initiator, child_impl) + elif initiator.impl is not child_impl or \ + initiator.op not in (OP_APPEND, OP_REPLACE): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) + return child + + def emit_backref_from_collection_remove_event(state, child, initiator): + if child is not None: + child_state, child_dict = instance_state(child),\ + instance_dict(child) + child_impl = child_state.manager[key].impl + if initiator.impl is not child_impl or \ + initiator.op not in (OP_REMOVE, OP_REPLACE): + child_impl.pop( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) + + if uselist: + event.listen(attribute, "append", + emit_backref_from_collection_append_event, + retval=True, raw=True) + else: + event.listen(attribute, "set", + emit_backref_from_scalar_set_event, + retval=True, raw=True) + # TODO: need coverage in test/orm/ of remove event + event.listen(attribute, "remove", + emit_backref_from_collection_remove_event, + retval=True, raw=True) + +_NO_HISTORY = util.symbol('NO_HISTORY') +_NO_STATE_SYMBOLS = frozenset([ + id(PASSIVE_NO_RESULT), + id(NO_VALUE), + id(NEVER_SET)]) + +History = util.namedtuple("History", [ + "added", "unchanged", "deleted" + ]) + + +class History(History): + """A 3-tuple of added, unchanged and deleted values, + representing the changes which have occurred on an instrumented + attribute. + + The easiest way to get a :class:`.History` object for a particular + attribute on an object is to use the :func:`.inspect` function:: + + from sqlalchemy import inspect + + hist = inspect(myobject).attrs.myattribute.history + + Each tuple member is an iterable sequence: + + * ``added`` - the collection of items added to the attribute (the first + tuple element). + + * ``unchanged`` - the collection of items that have not changed on the + attribute (the second tuple element). + + * ``deleted`` - the collection of items that have been removed from the + attribute (the third tuple element). + + """ + + def __bool__(self): + return self != HISTORY_BLANK + __nonzero__ = __bool__ + + def empty(self): + """Return True if this :class:`.History` has no changes + and no existing, unchanged state. + + """ + + return not bool( + (self.added or self.deleted) + or self.unchanged and self.unchanged != [None] + ) + + def sum(self): + """Return a collection of added + unchanged + deleted.""" + + return (self.added or []) +\ + (self.unchanged or []) +\ + (self.deleted or []) + + def non_deleted(self): + """Return a collection of added + unchanged.""" + + return (self.added or []) +\ + (self.unchanged or []) + + def non_added(self): + """Return a collection of unchanged + deleted.""" + + return (self.unchanged or []) +\ + (self.deleted or []) + + def has_changes(self): + """Return True if this :class:`.History` has changes.""" + + return bool(self.added or self.deleted) + + def as_state(self): + return History( + [(c is not None) + and instance_state(c) or None + for c in self.added], + [(c is not None) + and instance_state(c) or None + for c in self.unchanged], + [(c is not None) + and instance_state(c) or None + for c in self.deleted], + ) + + @classmethod + def from_scalar_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if original is _NO_HISTORY: + if current is NEVER_SET: + return cls((), (), ()) + else: + return cls((), [current], ()) + # don't let ClauseElement expressions here trip things up + elif attribute.is_equal(current, original) is True: + return cls((), [current], ()) + else: + # current convention on native scalars is to not + # include information + # about missing previous value in "deleted", but + # we do include None, which helps in some primary + # key situations + if id(original) in _NO_STATE_SYMBOLS: + deleted = () + else: + deleted = [original] + if current is NEVER_SET: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_object_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if original is _NO_HISTORY: + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), ()) + else: + return cls((), [current], ()) + elif current is original: + return cls((), [current], ()) + else: + # current convention on related objects is to not + # include information + # about missing previous value in "deleted", and + # to also not include None - the dependency.py rules + # ignore the None in any case. + if id(original) in _NO_STATE_SYMBOLS or original is None: + deleted = () + else: + deleted = [original] + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_collection(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), ()) + + current = getattr(current, '_sa_adapter') + if original in (NO_VALUE, NEVER_SET): + return cls(list(current), (), ()) + elif original is _NO_HISTORY: + return cls((), list(current), ()) + else: + + current_states = [((c is not None) and instance_state(c) + or None, c) + for c in current + ] + original_states = [((c is not None) and instance_state(c) + or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return cls( + [o for s, o in current_states if s not in original_set], + [o for s, o in current_states if s in original_set], + [o for s, o in original_states if s not in current_set] + ) + +HISTORY_BLANK = History(None, None, None) + + +def get_history(obj, key, passive=PASSIVE_OFF): + """Return a :class:`.History` record for the given object + and attribute key. + + :param obj: an object whose class is instrumented by the + attributes package. + + :param key: string attribute name. + + :param passive: indicates loading behavior for the attribute + if the value is not already present. This is a + bitflag attribute, which defaults to the symbol + :attr:`.PASSIVE_OFF` indicating all necessary SQL + should be emitted. + + """ + if passive is True: + util.warn_deprecated("Passing True for 'passive' is deprecated. " + "Use attributes.PASSIVE_NO_INITIALIZE") + passive = PASSIVE_NO_INITIALIZE + elif passive is False: + util.warn_deprecated("Passing False for 'passive' is " + "deprecated. Use attributes.PASSIVE_OFF") + passive = PASSIVE_OFF + + return get_state_history(instance_state(obj), key, passive) + + +def get_state_history(state, key, passive=PASSIVE_OFF): + return state.get_history(key, passive) + + +def has_parent(cls, obj, key, optimistic=False): + """TODO""" + manager = manager_of_class(cls) + state = instance_state(obj) + return manager.has_parent(state, key, optimistic) + + +def register_attribute(class_, key, **kw): + comparator = kw.pop('comparator', None) + parententity = kw.pop('parententity', None) + doc = kw.pop('doc', None) + desc = register_descriptor(class_, key, + comparator, parententity, doc=doc) + register_attribute_impl(class_, key, **kw) + return desc + + +def register_attribute_impl(class_, key, + uselist=False, callable_=None, + useobject=False, + impl_class=None, backref=None, **kw): + + manager = manager_of_class(class_) + if uselist: + factory = kw.pop('typecallable', None) + typecallable = manager.instrument_collection_class( + key, factory or list) + else: + typecallable = kw.pop('typecallable', None) + + dispatch = manager[key].dispatch + + if impl_class: + impl = impl_class(class_, key, typecallable, dispatch, **kw) + elif uselist: + impl = CollectionAttributeImpl(class_, key, callable_, dispatch, + typecallable=typecallable, **kw) + elif useobject: + impl = ScalarObjectAttributeImpl(class_, key, callable_, + dispatch, **kw) + else: + impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) + + manager[key].impl = impl + + if backref: + backref_listeners(manager[key], backref, uselist) + + manager.post_configure_attribute(key) + return manager[key] + + +def register_descriptor(class_, key, comparator=None, + parententity=None, doc=None): + manager = manager_of_class(class_) + + descriptor = InstrumentedAttribute(class_, key, comparator=comparator, + parententity=parententity) + + descriptor.__doc__ = doc + + manager.instrument_attribute(key, descriptor) + return descriptor + + +def unregister_attribute(class_, key): + manager_of_class(class_).uninstrument_attribute(key) + + +def init_collection(obj, key): + """Initialize a collection attribute and return the collection adapter. + + This function is used to provide direct access to collection internals + for a previously unloaded attribute. e.g.:: + + collection_adapter = init_collection(someobject, 'elements') + for elem in values: + collection_adapter.append_without_event(elem) + + For an easier way to do the above, see + :func:`~sqlalchemy.orm.attributes.set_committed_value`. + + obj is an instrumented object instance. An InstanceState + is accepted directly for backwards compatibility but + this usage is deprecated. + + """ + state = instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) + + +def init_state_collection(state, dict_, key): + """Initialize a collection attribute and return the collection adapter.""" + + attr = state.manager[key].impl + user_data = attr.initialize(state, dict_) + return attr.get_collection(state, dict_, user_data) + + +def set_committed_value(instance, key, value): + """Set the value of an attribute with no history events. + + Cancels any previous history present. The value should be + a scalar value for scalar-holding attributes, or + an iterable for any collection-holding attribute. + + This is the same underlying method used when a lazy loader + fires off and loads additional data from the database. + In particular, this method can be used by application code + which has loaded additional attributes or collections through + separate queries, which can then be attached to an instance + as though it were part of its original loaded state. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set_committed_value(state, dict_, value) + + +def set_attribute(instance, key, value): + """Set the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set(state, dict_, value, None) + + +def get_attribute(instance, key): + """Get the value of an attribute, firing any callables required. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to make usage of attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + return state.manager[key].impl.get(state, dict_) + + +def del_attribute(instance, key): + """Delete the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.delete(state, dict_) + + +def flag_modified(instance, key): + """Mark an attribute on an instance as 'modified'. + + This sets the 'modified' flag on the instance and + establishes an unconditional change event for the given attribute. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + impl = state.manager[key].impl + state._modified_event(dict_, impl, NO_VALUE, force=True) diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py new file mode 100644 index 00000000..e973de89 --- /dev/null +++ b/lib/sqlalchemy/orm/base.py @@ -0,0 +1,457 @@ +# orm/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Constants and rudimental functions used throughout the ORM. + +""" + +from .. import util, inspection, exc as sa_exc +from ..sql import expression +from . import exc +import operator + +PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT', +"""Symbol returned by a loader callable or other attribute/history +retrieval operation when a value could not be determined, based +on loader callable flags. +""" +) + +ATTR_WAS_SET = util.symbol('ATTR_WAS_SET', +"""Symbol returned by a loader callable to indicate the +retrieved value, or values, were assigned to their attributes +on the target object. +""") + +ATTR_EMPTY = util.symbol('ATTR_EMPTY', +"""Symbol used internally to indicate an attribute had no callable. +""") + +NO_VALUE = util.symbol('NO_VALUE', +"""Symbol which may be placed as the 'previous' value of an attribute, +indicating no value was loaded for an attribute when it was modified, +and flags indicated we were not to load it. +""" +) + +NEVER_SET = util.symbol('NEVER_SET', +"""Symbol which may be placed as the 'previous' value of an attribute +indicating that the attribute had not been assigned to previously. +""" +) + +NO_CHANGE = util.symbol("NO_CHANGE", +"""No callables or SQL should be emitted on attribute access +and no state should change""", canonical=0 +) + +CALLABLES_OK = util.symbol("CALLABLES_OK", +"""Loader callables can be fired off if a value +is not present.""", canonical=1 +) + +SQL_OK = util.symbol("SQL_OK", +"""Loader callables can emit SQL at least on scalar value +attributes.""", canonical=2) + +RELATED_OBJECT_OK = util.symbol("RELATED_OBJECT_OK", +"""callables can use SQL to load related objects as well +as scalar value attributes. +""", canonical=4 +) + +INIT_OK = util.symbol("INIT_OK", +"""Attributes should be initialized with a blank +value (None or an empty collection) upon get, if no other +value can be obtained. +""", canonical=8 +) + +NON_PERSISTENT_OK = util.symbol("NON_PERSISTENT_OK", +"""callables can be emitted if the parent is not persistent.""", +canonical=16 +) + +LOAD_AGAINST_COMMITTED = util.symbol("LOAD_AGAINST_COMMITTED", +"""callables should use committed values as primary/foreign keys during a load +""", canonical=32 +) + +NO_AUTOFLUSH = util.symbol("NO_AUTOFLUSH", +"""loader callables should disable autoflush. +""", canonical=64) + +# pre-packaged sets of flags used as inputs +PASSIVE_OFF = util.symbol("PASSIVE_OFF", + "Callables can be emitted in all cases.", + canonical=(RELATED_OBJECT_OK | NON_PERSISTENT_OK | + INIT_OK | CALLABLES_OK | SQL_OK) +) +PASSIVE_RETURN_NEVER_SET = util.symbol("PASSIVE_RETURN_NEVER_SET", + """PASSIVE_OFF ^ INIT_OK""", + canonical=PASSIVE_OFF ^ INIT_OK +) +PASSIVE_NO_INITIALIZE = util.symbol("PASSIVE_NO_INITIALIZE", + "PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK", + canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK +) +PASSIVE_NO_FETCH = util.symbol("PASSIVE_NO_FETCH", + "PASSIVE_OFF ^ SQL_OK", + canonical=PASSIVE_OFF ^ SQL_OK +) +PASSIVE_NO_FETCH_RELATED = util.symbol("PASSIVE_NO_FETCH_RELATED", + "PASSIVE_OFF ^ RELATED_OBJECT_OK", + canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK +) +PASSIVE_ONLY_PERSISTENT = util.symbol("PASSIVE_ONLY_PERSISTENT", + "PASSIVE_OFF ^ NON_PERSISTENT_OK", + canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK +) + +DEFAULT_MANAGER_ATTR = '_sa_class_manager' +DEFAULT_STATE_ATTR = '_sa_instance_state' +_INSTRUMENTOR = ('mapper', 'instrumentor') + +EXT_CONTINUE = util.symbol('EXT_CONTINUE') +EXT_STOP = util.symbol('EXT_STOP') + +ONETOMANY = util.symbol('ONETOMANY', +"""Indicates the one-to-many direction for a :func:`.relationship`. + +This symbol is typically used by the internals but may be exposed within +certain API features. + +""") + +MANYTOONE = util.symbol('MANYTOONE', +"""Indicates the many-to-one direction for a :func:`.relationship`. + +This symbol is typically used by the internals but may be exposed within +certain API features. + +""") + +MANYTOMANY = util.symbol('MANYTOMANY', +"""Indicates the many-to-many direction for a :func:`.relationship`. + +This symbol is typically used by the internals but may be exposed within +certain API features. + +""") + +NOT_EXTENSION = util.symbol('NOT_EXTENSION', +"""Symbol indicating an :class:`_InspectionAttr` that's + not part of sqlalchemy.ext. + + Is assigned to the :attr:`._InspectionAttr.extension_type` + attibute. + +""") + +_none_set = frozenset([None]) + + +def _generative(*assertions): + """Mark a method as generative, e.g. method-chained.""" + + @util.decorator + def generate(fn, *args, **kw): + self = args[0]._clone() + for assertion in assertions: + assertion(self, fn.__name__) + fn(self, *args[1:], **kw) + return self + return generate + + +# these can be replaced by sqlalchemy.ext.instrumentation +# if augmented class instrumentation is enabled. +def manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) + +instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) + +instance_dict = operator.attrgetter('__dict__') + +def instance_str(instance): + """Return a string describing an instance.""" + + return state_str(instance_state(instance)) + +def state_str(state): + """Return a string describing an instance via its InstanceState.""" + + if state is None: + return "None" + else: + return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj())) + +def state_class_str(state): + """Return a string describing an instance's class via its InstanceState.""" + + if state is None: + return "None" + else: + return '<%s>' % (state.class_.__name__, ) + + +def attribute_str(instance, attribute): + return instance_str(instance) + "." + attribute + + +def state_attribute_str(state, attribute): + return state_str(state) + "." + attribute + +def object_mapper(instance): + """Given an object, return the primary Mapper associated with the object + instance. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + This function is available via the inspection system as:: + + inspect(instance).mapper + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. + + """ + return object_state(instance).mapper + + +def object_state(instance): + """Given an object, return the :class:`.InstanceState` + associated with the object. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + Equivalent functionality is available via the :func:`.inspect` + function as:: + + inspect(instance) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. + + """ + state = _inspect_mapped_object(instance) + if state is None: + raise exc.UnmappedInstanceError(instance) + else: + return state + + +@inspection._inspects(object) +def _inspect_mapped_object(instance): + try: + return instance_state(instance) + # TODO: whats the py-2/3 syntax to catch two + # different kinds of exceptions at once ? + except exc.UnmappedClassError: + return None + except exc.NO_STATE: + return None + + + +def _class_to_mapper(class_or_mapper): + insp = inspection.inspect(class_or_mapper, False) + if insp is not None: + return insp.mapper + else: + raise exc.UnmappedClassError(class_or_mapper) + + +def _mapper_or_none(entity): + """Return the :class:`.Mapper` for the given class or None if the + class is not mapped.""" + + insp = inspection.inspect(entity, False) + if insp is not None: + return insp.mapper + else: + return None + + +def _is_mapped_class(entity): + """Return True if the given object is a mapped class, + :class:`.Mapper`, or :class:`.AliasedClass`.""" + + insp = inspection.inspect(entity, False) + return insp is not None and \ + hasattr(insp, "mapper") and \ + ( + insp.is_mapper + or insp.is_aliased_class + ) + +def _attr_as_key(attr): + if hasattr(attr, 'key'): + return attr.key + else: + return expression._column_as_key(attr) + + + +def _orm_columns(entity): + insp = inspection.inspect(entity, False) + if hasattr(insp, 'selectable'): + return [c for c in insp.selectable.c] + else: + return [entity] + + + +def _is_aliased_class(entity): + insp = inspection.inspect(entity, False) + return insp is not None and \ + getattr(insp, "is_aliased_class", False) + + +def _entity_descriptor(entity, key): + """Return a class attribute given an entity and string name. + + May return :class:`.InstrumentedAttribute` or user-defined + attribute. + + """ + insp = inspection.inspect(entity) + if insp.is_selectable: + description = entity + entity = insp.c + elif insp.is_aliased_class: + entity = insp.entity + description = entity + elif hasattr(insp, "mapper"): + description = entity = insp.mapper.class_ + else: + description = entity + + try: + return getattr(entity, key) + except AttributeError: + raise sa_exc.InvalidRequestError( + "Entity '%s' has no property '%s'" % + (description, key) + ) + +_state_mapper = util.dottedgetter('manager.mapper') + +@inspection._inspects(type) +def _inspect_mapped_class(class_, configure=False): + try: + class_manager = manager_of_class(class_) + if not class_manager.is_mapped: + return None + mapper = class_manager.mapper + if configure and mapper._new_mappers: + mapper._configure_all() + return mapper + + except exc.NO_STATE: + return None + +def class_mapper(class_, configure=True): + """Given a class, return the primary :class:`.Mapper` associated + with the key. + + Raises :exc:`.UnmappedClassError` if no mapping is configured + on the given class, or :exc:`.ArgumentError` if a non-class + object is passed. + + Equivalent functionality is available via the :func:`.inspect` + function as:: + + inspect(some_mapped_class) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped. + + """ + mapper = _inspect_mapped_class(class_, configure=configure) + if mapper is None: + if not isinstance(class_, type): + raise sa_exc.ArgumentError( + "Class object expected, got '%r'." % (class_, )) + raise exc.UnmappedClassError(class_) + else: + return mapper + + +class _InspectionAttr(object): + """A base class applied to all ORM objects that can be returned + by the :func:`.inspect` function. + + The attributes defined here allow the usage of simple boolean + checks to test basic facts about the object returned. + + While the boolean checks here are basically the same as using + the Python isinstance() function, the flags here can be used without + the need to import all of these classes, and also such that + the SQLAlchemy class system can change while leaving the flags + here intact for forwards-compatibility. + + """ + + is_selectable = False + """Return True if this object is an instance of :class:`.Selectable`.""" + + is_aliased_class = False + """True if this object is an instance of :class:`.AliasedClass`.""" + + is_instance = False + """True if this object is an instance of :class:`.InstanceState`.""" + + is_mapper = False + """True if this object is an instance of :class:`.Mapper`.""" + + is_property = False + """True if this object is an instance of :class:`.MapperProperty`.""" + + is_attribute = False + """True if this object is a Python :term:`descriptor`. + + This can refer to one of many types. Usually a + :class:`.QueryableAttribute` which handles attributes events on behalf + of a :class:`.MapperProperty`. But can also be an extension type + such as :class:`.AssociationProxy` or :class:`.hybrid_property`. + The :attr:`._InspectionAttr.extension_type` will refer to a constant + identifying the specific subtype. + + .. seealso:: + + :attr:`.Mapper.all_orm_descriptors` + + """ + + is_clause_element = False + """True if this object is an instance of :class:`.ClauseElement`.""" + + extension_type = NOT_EXTENSION + """The extension type, if any. + Defaults to :data:`.interfaces.NOT_EXTENSION` + + .. versionadded:: 0.8.0 + + .. seealso:: + + :data:`.HYBRID_METHOD` + + :data:`.HYBRID_PROPERTY` + + :data:`.ASSOCIATION_PROXY` + + """ + +class _MappedAttribute(object): + """Mixin for attributes which should be replaced by mapper-assigned + attributes. + + """ diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py new file mode 100644 index 00000000..87e351b6 --- /dev/null +++ b/lib/sqlalchemy/orm/collections.py @@ -0,0 +1,1550 @@ +# orm/collections.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for collections of mapped entities. + +The collections package supplies the machinery used to inform the ORM of +collection membership changes. An instrumentation via decoration approach is +used, allowing arbitrary types (including built-ins) to be used as entity +collections without requiring inheritance from a base class. + +Instrumentation decoration relays membership change events to the +:class:`.CollectionAttributeImpl` that is currently managing the collection. +The decorators observe function call arguments and return values, tracking +entities entering or leaving the collection. Two decorator approaches are +provided. One is a bundle of generic decorators that map function arguments +and return values to events:: + + from sqlalchemy.orm.collections import collection + class MyClass(object): + # ... + + @collection.adds(1) + def store(self, item): + self.data.append(item) + + @collection.removes_return() + def pop(self): + return self.data.pop() + + +The second approach is a bundle of targeted decorators that wrap appropriate +append and remove notifiers around the mutation methods present in the +standard Python ``list``, ``set`` and ``dict`` interfaces. These could be +specified in terms of generic decorator recipes, but are instead hand-tooled +for increased efficiency. The targeted decorators occasionally implement +adapter-like behavior, such as mapping bulk-set methods (``extend``, +``update``, ``__setslice__``, etc.) into the series of atomic mutation events +that the ORM requires. + +The targeted decorators are used internally for automatic instrumentation of +entity collection classes. Every collection class goes through a +transformation process roughly like so: + +1. If the class is a built-in, substitute a trivial sub-class +2. Is this class already instrumented? +3. Add in generic decorators +4. Sniff out the collection interface through duck-typing +5. Add targeted decoration to any undecorated interface method + +This process modifies the class at runtime, decorating methods and adding some +bookkeeping properties. This isn't possible (or desirable) for built-in +classes like ``list``, so trivial sub-classes are substituted to hold +decoration:: + + class InstrumentedList(list): + pass + +Collection classes can be specified in ``relationship(collection_class=)`` as +types or a function that returns an instance. Collection classes are +inspected and instrumented during the mapper compilation phase. The +collection_class callable will be executed once to produce a specimen +instance, and the type of that specimen will be instrumented. Functions that +return built-in types like ``lists`` will be adapted to produce instrumented +instances. + +When extending a known type like ``list``, additional decorations are not +generally not needed. Odds are, the extension method will delegate to a +method that's already instrumented. For example:: + + class QueueIsh(list): + def push(self, item): + self.append(item) + def shift(self): + return self.pop(0) + +There's no need to decorate these methods. ``append`` and ``pop`` are already +instrumented as part of the ``list`` interface. Decorating them would fire +duplicate events, which should be avoided. + +The targeted decoration tries not to rely on other methods in the underlying +collection class, but some are unavoidable. Many depend on 'read' methods +being present to properly instrument a 'write', for example, ``__setitem__`` +needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also +reimplemented in terms of atomic appends and removes, so the ``extend`` +decoration will actually perform many ``append`` operations and not call the +underlying method at all. + +Tight control over bulk operation and the firing of events is also possible by +implementing the instrumentation internally in your methods. The basic +instrumentation package works under the general assumption that collection +mutation will not raise unusual exceptions. If you want to closely +orchestrate append and remove events with exception management, internal +instrumentation may be the answer. Within your method, +``collection_adapter(self)`` will retrieve an object that you can use for +explicit control over triggering append and remove events. + +The owning object and :class:`.CollectionAttributeImpl` are also reachable +through the adapter, allowing for some very sophisticated behavior. + +""" + +import inspect +import operator +import weakref + +from ..sql import expression +from .. import util, exc as sa_exc +from . import base + + +__all__ = ['collection', 'collection_adapter', + 'mapped_collection', 'column_mapped_collection', + 'attribute_mapped_collection'] + +__instrumentation_mutex = util.threading.Lock() + + +class _PlainColumnGetter(object): + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. + + """ + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetter(object): + """Column-based getter used in version 0.7.6 only. + + Remains here for pickle compatibility with 0.7.6. + + """ + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return _SerializableColumnGetter, (self.colkeys,) + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + key = [m._get_state_attr_by_column( + state, state.dict, + m.mapped_table.columns[k]) + for k in self.colkeys] + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, 'metadata', None) + for (ckey, tkey) in self.colkeys: + if tkey is None or \ + metadata is None or \ + tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from mapping_spec, which may be a Column or a sequence + of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + cols = [expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return lambda: MappedCollection(keyfunc) + + +class _SerializableAttrGetter(object): + def __init__(self, name): + self.name = name + self.getter = operator.attrgetter(name) + + def __call__(self, target): + return self.getter(target) + + def __reduce__(self): + return _SerializableAttrGetter, (self.name, ) + + +def attribute_mapped_collection(attr_name): + """A dictionary-based collection type with attribute-based keying. + + Returns a :class:`.MappedCollection` factory with a keying based on the + 'attr_name' attribute of entities in the collection, where ``attr_name`` + is the string name of the attribute. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + getter = _SerializableAttrGetter(attr_name) + return lambda: MappedCollection(getter) + + +def mapped_collection(keyfunc): + """A dictionary-based collection type with arbitrary keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return lambda: MappedCollection(keyfunc) + + +class collection(object): + """Decorators for entity collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, linker, converter, + internally_instrumented) indicate the method's purpose and take no + arguments. They are not written with parens:: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments:: + + @collection.adds('entity') + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + """ + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + @staticmethod + def appender(fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated:: + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = 'appender' + return fn + + @staticmethod + def remover(fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + :meth:`removes_return` if not already decorated:: + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = 'remover' + return fn + + @staticmethod + def iterator(fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members:: + + @collection.iterator + def __iter__(self): ... + + """ + fn._sa_instrument_role = 'iterator' + return fn + + @staticmethod + def internally_instrumented(fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the + method. Use this if you are orchestrating your own calls to + :func:`.collection_adapter` in one of the basic SQLAlchemy + interface methods, or to prevent an automatic ABC method + decoration from wrapping your implementation:: + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + + """ + fn._sa_instrumented = True + return fn + + @staticmethod + def linker(fn): + """Tag the method as a "linked to attribute" event handler. + + This optional event handler will be called when the collection class + is linked to or unlinked from the InstrumentedAttribute. It is + invoked immediately after the '_sa_adapter' property is set on + the instance. A single argument is passed: the collection adapter + that has been linked, or None if unlinking. + + """ + fn._sa_instrument_role = 'linker' + return fn + + link = linker + """deprecated; synonym for :meth:`.collection.linker`.""" + + @staticmethod + def converter(fn): + """Tag the method as the collection converter. + + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + it's sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be convert into an iterable + of dictionary values, and other types will simply be iterated:: + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. + + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + + """ + fn._sa_instrument_role = 'converter' + return fn + + @staticmethod + def adds(arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value. Arguments can be specified positionally (i.e. integer) or by + name:: + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + + """ + def decorator(fn): + fn._sa_instrument_before = ('fire_append_event', arg) + return fn + return decorator + + @staticmethod + def replaces(arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.replaces(2) + def __setitem__(self, index, item): ... + + """ + def decorator(fn): + fn._sa_instrument_before = ('fire_append_event', arg) + fn._sa_instrument_after = 'fire_remove_event' + return fn + return decorator + + @staticmethod + def removes(arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name:: + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + + """ + def decorator(fn): + fn._sa_instrument_before = ('fire_remove_event', arg) + return fn + return decorator + + @staticmethod + def removes_return(): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return value + of the method, if any, is considered the value to remove. The method + arguments are not inspected:: + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + + """ + def decorator(fn): + fn._sa_instrument_after = 'fire_remove_event' + return fn + return decorator + + +collection_adapter = operator.attrgetter('_sa_adapter') +"""Fetch the :class:`.CollectionAdapter` for a collection.""" + + +class CollectionAdapter(object): + """Bridges between the ORM and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + + The ORM uses :class:`.CollectionAdapter` exclusively for interaction with + entity collections. + + + """ + invalidated = False + + def __init__(self, attr, owner_state, data): + self._key = attr.key + self._data = weakref.ref(data) + self.owner_state = owner_state + self.link_to_self(data) + + def _warn_invalidated(self): + util.warn("This collection has been invalidated.") + + @property + def data(self): + "The entity collection being adapted." + return self._data() + + @util.memoized_property + def attr(self): + return self.owner_state.manager[self._key].impl + + def link_to_self(self, data): + """Link a collection to this adapter""" + + data._sa_adapter = self + if data._sa_linker: + data._sa_linker(self) + + + def unlink(self, data): + """Unlink a collection from any adapter""" + + del data._sa_adapter + if data._sa_linker: + data._sa_linker(None) + + def adapt_like_to_iterable(self, obj): + """Converts collection-compatible objects to an iterable of values. + + Can be passed any type of object, and if the underlying collection + determines that it can be adapted into a stream of values it can + use, returns an iterable of values suitable for append()ing. + + This method may raise TypeError or any other suitable exception + if adaptation fails. + + If a converter implementation is not supplied on the collection, + a default duck-typing-based implementation is used. + + """ + converter = self._data()._sa_converter + if converter is not None: + return converter(obj) + + setting_type = util.duck_type_collection(obj) + receiving_type = util.duck_type_collection(self._data()) + + if obj is None or setting_type != receiving_type: + given = obj is None and 'None' or obj.__class__.__name__ + if receiving_type is None: + wanted = self._data().__class__.__name__ + else: + wanted = receiving_type.__name__ + + raise TypeError( + "Incompatible collection type: %s is not %s-like" % ( + given, wanted)) + + # If the object is an adapted collection, return the (iterable) + # adapter. + if getattr(obj, '_sa_adapter', None) is not None: + return obj._sa_adapter + elif setting_type == dict: + if util.py3k: + return obj.values() + else: + return getattr(obj, 'itervalues', obj.values)() + else: + return iter(obj) + + def append_with_event(self, item, initiator=None): + """Add an entity to the collection, firing mutation events.""" + + self._data()._sa_appender(item, _sa_initiator=initiator) + + def append_without_event(self, item): + """Add or restore an entity to the collection, firing no events.""" + self._data()._sa_appender(item, _sa_initiator=False) + + def append_multiple_without_event(self, items): + """Add or restore an entity to the collection, firing no events.""" + appender = self._data()._sa_appender + for item in items: + appender(item, _sa_initiator=False) + + def remove_with_event(self, item, initiator=None): + """Remove an entity from the collection, firing mutation events.""" + self._data()._sa_remover(item, _sa_initiator=initiator) + + def remove_without_event(self, item): + """Remove an entity from the collection, firing no events.""" + self._data()._sa_remover(item, _sa_initiator=False) + + def clear_with_event(self, initiator=None): + """Empty the collection, firing a mutation event for each entity.""" + + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=initiator) + + def clear_without_event(self): + """Empty the collection, firing no events.""" + + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=False) + + def __iter__(self): + """Iterate over entities in the collection.""" + + return iter(self._data()._sa_iterator()) + + def __len__(self): + """Count entities in the collection.""" + return len(list(self._data()._sa_iterator())) + + def __bool__(self): + return True + + __nonzero__ = __bool__ + + def fire_append_event(self, item, initiator=None): + """Notify that a entity has entered the collection. + + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + return self.attr.fire_append_event( + self.owner_state, + self.owner_state.dict, + item, initiator) + else: + return item + + def fire_remove_event(self, item, initiator=None): + """Notify that a entity has been removed from the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + self.attr.fire_remove_event( + self.owner_state, + self.owner_state.dict, + item, initiator) + + def fire_pre_remove_event(self, initiator=None): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). + + """ + if self.invalidated: + self._warn_invalidated() + self.attr.fire_pre_remove_event( + self.owner_state, + self.owner_state.dict, + initiator=initiator) + + def __getstate__(self): + return {'key': self._key, + 'owner_state': self.owner_state, + 'data': self.data} + + def __setstate__(self, d): + self._key = d['key'] + self.owner_state = d['owner_state'] + self._data = weakref.ref(d['data']) + + +def bulk_replace(values, existing_adapter, new_adapter): + """Load a new collection, firing events based on prior like membership. + + Appends instances in ``values`` onto the ``new_adapter``. Events will be + fired for any instance not present in the ``existing_adapter``. Any + instances in ``existing_adapter`` not present in ``values`` will have + remove events fired upon them. + + :param values: An iterable of collection member instances + + :param existing_adapter: A :class:`.CollectionAdapter` of + instances to be replaced + + :param new_adapter: An empty :class:`.CollectionAdapter` + to load with ``values`` + + + """ + if not isinstance(values, list): + values = list(values) + + idset = util.IdentitySet + existing_idset = idset(existing_adapter or ()) + constants = existing_idset.intersection(values or ()) + additions = idset(values or ()).difference(constants) + removals = existing_idset.difference(constants) + + for member in values or (): + if member in additions: + new_adapter.append_with_event(member) + elif member in constants: + new_adapter.append_without_event(member) + + if existing_adapter: + for member in removals: + existing_adapter.remove_with_event(member) + + +def prepare_instrumentation(factory): + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + + """ + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + factory = __canned_instrumentation[factory] + + # Create a specimen + cls = type(factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + # Wrap it so that it returns our 'Instrumented*' + factory = __converting_factory(cls, factory) + cls = factory() + + # Instrument the class if needed. + if __instrumentation_mutex.acquire(): + try: + if getattr(cls, '_sa_instrumented', None) != id(cls): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return factory + + +def __converting_factory(specimen_cls, original_factory): + """Return a wrapper that converts a "canned" collection like + set, dict, list into the Instrumented* version. + + """ + + instrumented_cls = __canned_instrumentation[specimen_cls] + + def wrapper(): + collection = original_factory() + return instrumented_cls(collection) + + # often flawed but better than nothing + wrapper.__name__ = "%sWrapper" % original_factory.__name__ + wrapper.__doc__ = original_factory.__doc__ + + return wrapper + +def _instrument_class(cls): + """Modify methods in a class and install instrumentation.""" + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == '__builtin__': + raise sa_exc.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one.") + + roles = {} + methods = {} + + # search for _sa_instrument_role-decorated methods in + # method resolution order, assign to roles + for supercls in cls.__mro__: + for name, method in vars(supercls).items(): + if not util.callable(method): + continue + + # note role declarations + if hasattr(method, '_sa_instrument_role'): + role = method._sa_instrument_role + assert role in ('appender', 'remover', 'iterator', + 'linker', 'converter') + roles.setdefault(role, name) + + # transfer instrumentation requests from decorated function + # to the combined queue + before, after = None, None + if hasattr(method, '_sa_instrument_before'): + op, argument = method._sa_instrument_before + assert op in ('fire_append_event', 'fire_remove_event') + before = op, argument + if hasattr(method, '_sa_instrument_after'): + op = method._sa_instrument_after + assert op in ('fire_append_event', 'fire_remove_event') + after = op + if before: + methods[name] = before[0], before[1], after + elif after: + methods[name] = None, None, after + + # see if this class has "canned" roles based on a known + # collection type (dict, set, list). Apply those roles + # as needed to the "roles" dictionary, and also + # prepare "decorator" methods + collection_type = util.duck_type_collection(cls) + if collection_type in __interfaces: + canned_roles, decorators = __interfaces[collection_type] + for role, name in canned_roles.items(): + roles.setdefault(role, name) + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if (fn and method not in methods and + not hasattr(fn, '_sa_instrumented')): + setattr(cls, method, decorator(fn)) + + # ensure all roles are present, and apply implicit instrumentation if + # needed + if 'appender' not in roles or not hasattr(cls, roles['appender']): + raise sa_exc.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__) + elif (roles['appender'] not in methods and + not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')): + methods[roles['appender']] = ('fire_append_event', 1, None) + + if 'remover' not in roles or not hasattr(cls, roles['remover']): + raise sa_exc.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__) + elif (roles['remover'] not in methods and + not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')): + methods[roles['remover']] = ('fire_remove_event', 1, None) + + if 'iterator' not in roles or not hasattr(cls, roles['iterator']): + raise sa_exc.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__) + + # apply ad-hoc instrumentation from decorators, class-level defaults + # and implicit role declarations + for method_name, (before, argument, after) in methods.items(): + setattr(cls, method_name, + _instrument_membership_mutator(getattr(cls, method_name), + before, argument, after)) + # intern the role map + for role, method_name in roles.items(): + setattr(cls, '_sa_%s' % role, getattr(cls, method_name)) + + cls._sa_adapter = None + if not hasattr(cls, '_sa_linker'): + cls._sa_linker = None + if not hasattr(cls, '_sa_converter'): + cls._sa_converter = None + cls._sa_instrumented = id(cls) + + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection adapter.""" + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) + if type(argument) is int: + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) + else: + pos_arg = None + named_arg = argument + del fn_args + + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument) + value = kw[named_arg] + else: + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] + else: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument) + + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = args[0]._sa_adapter + + if before and executor: + getattr(executor, before)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + + wrapper._sa_instrumented = True + if hasattr(method, "_sa_instrument_role"): + wrapper._sa_instrument_role = method._sa_instrument_role + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + return wrapper + + +def __set(collection, item, _sa_initiator=None): + """Run set events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + item = executor.fire_append_event(item, _sa_initiator) + return item + + +def __del(collection, item, _sa_initiator=None): + """Run del events, may eventually be inlined into decorators.""" + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + executor.fire_remove_event(item, _sa_initiator) + + +def __before_delete(collection, _sa_initiator=None): + """Special method to run 'commit existing value' methods""" + executor = collection._sa_adapter + if executor: + executor.fire_pre_remove_event(_sa_initiator) + + +def _list_decorators(): + """Tailored instrumentation wrappers for any list-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(list, fn.__name__).__doc__ + + def append(fn): + def append(self, item, _sa_initiator=None): + item = __set(self, item, _sa_initiator) + fn(self, item) + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + __before_delete(self, _sa_initiator) + # testlib.pragma exempt:__eq__ + fn(self, value) + __del(self, value, _sa_initiator) + _tidy(remove) + return remove + + def insert(fn): + def insert(self, index, value): + value = __set(self, value) + fn(self, index, value) + _tidy(insert) + return insert + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing) + value = __set(self, value) + fn(self, index, value) + else: + # slice assignment requires __delitem__, insert, __len__ + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + if index.stop is not None: + stop = index.stop + else: + stop = len(self) + if stop < 0: + stop += len(self) + + if step == 1: + for i in range(start, stop, step): + if len(self) > start: + del self[start] + + for i, item in enumerate(value): + self.insert(i + start, item) + else: + rng = list(range(start, stop, step)) + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), + len(rng))) + for i, item in zip(rng, value): + self.__setitem__(i, item) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + if not isinstance(index, slice): + item = self[index] + __del(self, item) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item) + fn(self, index) + _tidy(__delitem__) + return __delitem__ + + if util.py2k: + def __setslice__(fn): + def __setslice__(self, start, end, values): + for value in self[start:end]: + __del(self, value) + values = [__set(self, value) for value in values] + fn(self, start, end, values) + _tidy(__setslice__) + return __setslice__ + + def __delslice__(fn): + def __delslice__(self, start, end): + for value in self[start:end]: + __del(self, value) + fn(self, start, end) + _tidy(__delslice__) + return __delslice__ + + def extend(fn): + def extend(self, iterable): + for value in iterable: + self.append(value) + _tidy(extend) + return extend + + def __iadd__(fn): + def __iadd__(self, iterable): + # list.__iadd__ takes any iterable and seems to let TypeError raise + # as-is instead of returning NotImplemented + for value in iterable: + self.append(value) + return self + _tidy(__iadd__) + return __iadd__ + + def pop(fn): + def pop(self, index=-1): + __before_delete(self) + item = fn(self, index) + __del(self, item) + return item + _tidy(pop) + return pop + + if not util.py2k: + def clear(fn): + def clear(self, index=-1): + for item in self: + __del(self, item) + fn(self) + _tidy(clear) + return clear + + # __imul__ : not wrapping this. all members of the collection are already + # present, so no need to fire appends... wrapping it with an explicit + # decorator is still possible, so events on *= can be had if they're + # desired. hard to imagine a use case for __imul__, though. + + l = locals().copy() + l.pop('_tidy') + return l + + +def _dict_decorators(): + """Tailored instrumentation wrappers for any dict-like mapping class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(dict, fn.__name__).__doc__ + + Unspecified = util.symbol('Unspecified') + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + value = __set(self, value, _sa_initiator) + fn(self, key, value) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + fn(self, key) + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key]) + fn(self) + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=Unspecified): + if key in self: + __del(self, self[key]) + if default is Unspecified: + return fn(self, key) + else: + return fn(self, key, default) + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + __before_delete(self) + item = fn(self) + __del(self, item[1]) + return item + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self: + self.__setitem__(key, default) + return default + else: + return self.__getitem__(key) + _tidy(setdefault) + return setdefault + + def update(fn): + def update(self, __other=Unspecified, **kw): + if __other is not Unspecified: + if hasattr(__other, 'keys'): + for key in list(__other): + if (key not in self or + self[key] is not __other[key]): + self[key] = __other[key] + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + _tidy(update) + return update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + +_set_binop_bases = (set, frozenset) + + +def _set_binops_check_strict(self, obj): + """Allow only set, frozenset and self.__class__-derived + objects in binops.""" + return isinstance(obj, _set_binop_bases + (self.__class__,)) + + +def _set_binops_check_loose(self, obj): + """Allow anything set-like to participate in set binops.""" + return (isinstance(obj, _set_binop_bases + (self.__class__,)) or + util.duck_type_collection(obj) == set) + + +def _set_decorators(): + """Tailored instrumentation wrappers for any set-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(set, fn.__name__).__doc__ + + Unspecified = util.symbol('Unspecified') + + def add(fn): + def add(self, value, _sa_initiator=None): + if value not in self: + value = __set(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(add) + return add + + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + __before_delete(self) + item = fn(self) + __del(self, item) + return item + _tidy(pop) + return pop + + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + _tidy(clear) + return clear + + def update(fn): + def update(self, value): + for item in value: + self.add(item) + _tidy(update) + return update + + def __ior__(fn): + def __ior__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.add(item) + return self + _tidy(__ior__) + return __ior__ + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + _tidy(difference_update) + return difference_update + + def __isub__(fn): + def __isub__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.discard(item) + return self + _tidy(__isub__) + return __isub__ + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(intersection_update) + return intersection_update + + def __iand__(fn): + def __iand__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__iand__) + return __iand__ + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(symmetric_difference_update) + return symmetric_difference_update + + def __ixor__(fn): + def __ixor__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__ixor__) + return __ixor__ + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + + +class InstrumentedList(list): + """An instrumented version of the built-in list.""" + + +class InstrumentedSet(set): + """An instrumented version of the built-in set.""" + + +class InstrumentedDict(dict): + """An instrumented version of the built-in dict.""" + + +__canned_instrumentation = { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } + +__interfaces = { + list: ( + {'appender': 'append', 'remover': 'remove', + 'iterator': '__iter__'}, _list_decorators() + ), + + set: ({'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__'}, _set_decorators() + ), + + # decorators are required for dicts and object collections. + dict: ({'iterator': 'values'}, _dict_decorators()) if util.py3k + else ({'iterator': 'itervalues'}, _dict_decorators()), + } + + +class MappedCollection(dict): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection + classes require. ``set`` and ``remove`` are implemented in terms + of a keying function: any callable that takes an object and + returns an object for use as a dictionary key. + + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable any callable that takes an object and + returns an object for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + + @collection.appender + @collection.internally_instrumented + def set(self, value, _sa_initiator=None): + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + + @collection.remover + @collection.internally_instrumented + def remove(self, value, _sa_initiator=None): + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % + (value, self[key], key)) + self.__delitem__(key, _sa_initiator) + + @collection.converter + def _convert(self, dictlike): + """Validate and convert a dict-like object into values for set()ing. + + This is called behind the scenes when a MappedCollection is replaced + entirely by another collection, as in:: + + myobj.mappedcollection = {'a':obj1, 'b': obj2} # ... + + Raises a TypeError if the key in any (key, value) pair in the dictlike + object does not match the key that this collection's keyfunc would + have assigned for that value. + + """ + for incoming_key, value in util.dictlike_iteritems(dictlike): + new_key = self.keyfunc(value) + if incoming_key != new_key: + raise TypeError( + "Found incompatible key %r for value %r; this " + "collection's " + "keying function requires a key of %r for this value." % ( + incoming_key, value, new_key)) + yield value + +# ensure instrumentation is associated with +# these built-in classes; if a user-defined class +# subclasses these and uses @internally_instrumented, +# the superclass is otherwise not instrumented. +# see [ticket:2406]. +_instrument_class(MappedCollection) +_instrument_class(InstrumentedList) +_instrument_class(InstrumentedSet) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py new file mode 100644 index 00000000..34a2af39 --- /dev/null +++ b/lib/sqlalchemy/orm/dependency.py @@ -0,0 +1,1165 @@ +# orm/dependency.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Relationship dependencies. + +""" + +from .. import sql, util, exc as sa_exc +from . import attributes, exc, sync, unitofwork, \ + util as mapperutil +from .interfaces import ONETOMANY, MANYTOONE, MANYTOMANY + + +class DependencyProcessor(object): + def __init__(self, prop): + self.prop = prop + self.cascade = prop.cascade + self.mapper = prop.mapper + self.parent = prop.parent + self.secondary = prop.secondary + self.direction = prop.direction + self.post_update = prop.post_update + self.passive_deletes = prop.passive_deletes + self.passive_updates = prop.passive_updates + self.enable_typechecks = prop.enable_typechecks + if self.passive_deletes: + self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_delete_flag = attributes.PASSIVE_OFF + if self.passive_updates: + self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_update_flag = attributes.PASSIVE_OFF + + self.key = prop.key + if not self.prop.synchronize_pairs: + raise sa_exc.ArgumentError( + "Can't build a DependencyProcessor for relationship %s. " + "No target attributes to populate between parent and " + "child are present" % + self.prop) + + @classmethod + def from_relationship(cls, prop): + return _direction_to_processor[prop.direction](prop) + + def hasparent(self, state): + """return True if the given object instance has a parent, + according to the ``InstrumentedAttribute`` handled by this + ``DependencyProcessor``. + + """ + return self.parent.class_manager.get_impl(self.key).hasparent(state) + + def per_property_preprocessors(self, uow): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states in + the aggregate. + + """ + uow.register_preprocessor(self, True) + + def per_property_flush_actions(self, uow): + after_save = unitofwork.ProcessAll(uow, self, False, True) + before_delete = unitofwork.ProcessAll(uow, self, True, True) + + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.primary_base_mapper + ) + child_saves = unitofwork.SaveUpdateAll( + uow, + self.mapper.primary_base_mapper + ) + + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.primary_base_mapper + ) + child_deletes = unitofwork.DeleteAll( + uow, + self.mapper.primary_base_mapper + ) + + self.per_property_dependencies(uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete + ) + + def per_state_flush_actions(self, uow, states, isdelete): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states + individually. This occurs only if there are cycles + in the 'aggregated' version of events. + + """ + + parent_base_mapper = self.parent.primary_base_mapper + child_base_mapper = self.mapper.primary_base_mapper + child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper) + child_deletes = unitofwork.DeleteAll(uow, child_base_mapper) + + # locate and disable the aggregate processors + # for this dependency + + if isdelete: + before_delete = unitofwork.ProcessAll(uow, self, True, True) + before_delete.disabled = True + else: + after_save = unitofwork.ProcessAll(uow, self, False, True) + after_save.disabled = True + + # check if the "child" side is part of the cycle + + if child_saves not in uow.cycles: + # based on the current dependencies we use, the saves/ + # deletes should always be in the 'cycles' collection + # together. if this changes, we will have to break up + # this method a bit more. + assert child_deletes not in uow.cycles + + # child side is not part of the cycle, so we will link per-state + # actions to the aggregate "saves", "deletes" actions + child_actions = [ + (child_saves, False), (child_deletes, True) + ] + child_in_cycles = False + else: + child_in_cycles = True + + # check if the "parent" side is part of the cycle + if not isdelete: + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) + parent_deletes = before_delete = None + if parent_saves in uow.cycles: + parent_in_cycles = True + else: + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.base_mapper) + parent_saves = after_save = None + if parent_deletes in uow.cycles: + parent_in_cycles = True + + # now create actions /dependencies for each state. + for state in states: + # detect if there's anything changed or loaded + # by a preprocessor on this state/attribute. if not, + # we should be able to skip it entirely. + sum_ = state.manager[self.key].impl.get_all_pending( + state, state.dict) + + if not sum_: + continue + + if isdelete: + before_delete = unitofwork.ProcessState(uow, + self, True, state) + if parent_in_cycles: + parent_deletes = unitofwork.DeleteState( + uow, + state, + parent_base_mapper) + else: + after_save = unitofwork.ProcessState(uow, self, False, state) + if parent_in_cycles: + parent_saves = unitofwork.SaveUpdateState( + uow, + state, + parent_base_mapper) + + if child_in_cycles: + child_actions = [] + for child_state, child in sum_: + if child_state not in uow.states: + child_action = (None, None) + else: + (deleted, listonly) = uow.states[child_state] + if deleted: + child_action = ( + unitofwork.DeleteState( + uow, child_state, + child_base_mapper), + True) + else: + child_action = ( + unitofwork.SaveUpdateState( + uow, child_state, + child_base_mapper), + False) + child_actions.append(child_action) + + # establish dependencies between our possibly per-state + # parent action and our possibly per-state child action. + for child_action, childisdelete in child_actions: + self.per_state_dependencies(uow, parent_saves, + parent_deletes, + child_action, + after_save, before_delete, + isdelete, childisdelete) + + def presort_deletes(self, uowcommit, states): + return False + + def presort_saves(self, uowcommit, states): + return False + + def process_deletes(self, uowcommit, states): + pass + + def process_saves(self, uowcommit, states): + pass + + def prop_has_changes(self, uowcommit, states, isdelete): + if not isdelete or self.passive_deletes: + passive = attributes.PASSIVE_NO_INITIALIZE + elif self.direction is MANYTOONE: + passive = attributes.PASSIVE_NO_FETCH_RELATED + else: + passive = attributes.PASSIVE_OFF + + for s in states: + # TODO: add a high speed method + # to InstanceState which returns: attribute + # has a non-None value, or had one + history = uowcommit.get_attribute_history( + s, + self.key, + passive) + if history and not history.empty(): + return True + else: + return states and \ + not self.prop._is_self_referential and \ + self.mapper in uowcommit.mappers + + def _verify_canload(self, state): + if self.prop.uselist and state is None: + raise exc.FlushError( + "Can't flush None value found in " + "collection %s" % (self.prop, )) + elif state is not None and \ + not self.mapper._canload(state, + allow_subtypes=not self.enable_typechecks): + if self.mapper._canload(state, allow_subtypes=True): + raise exc.FlushError('Attempting to flush an item of type ' + '%(x)s as a member of collection ' + '"%(y)s". Expected an object of type ' + '%(z)s or a polymorphic subclass of ' + 'this type. If %(x)s is a subclass of ' + '%(z)s, configure mapper "%(zm)s" to ' + 'load this subtype polymorphically, or ' + 'set enable_typechecks=False to allow ' + 'any subtype to be accepted for flush. ' + % { + 'x': state.class_, + 'y': self.prop, + 'z': self.mapper.class_, + 'zm': self.mapper, + }) + else: + raise exc.FlushError( + 'Attempting to flush an item of type ' + '%(x)s as a member of collection ' + '"%(y)s". Expected an object of type ' + '%(z)s or a polymorphic subclass of ' + 'this type.' % { + 'x': state.class_, + 'y': self.prop, + 'z': self.mapper.class_, + }) + + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit): + raise NotImplementedError() + + def _get_reversed_processed_set(self, uow): + if not self.prop._reverse_property: + return None + + process_key = tuple(sorted( + [self.key] + + [p.key for p in self.prop._reverse_property] + )) + return uow.memo( + ('reverse_key', process_key), + set + ) + + def _post_update(self, state, uowcommit, related): + for x in related: + if x is not None: + uowcommit.issue_post_update( + state, + [r for l, r in self.prop.synchronize_pairs] + ) + break + + def _pks_changed(self, uowcommit, state): + raise NotImplementedError() + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.prop) + + +class OneToManyDP(DependencyProcessor): + + def per_property_dependencies(self, uow, parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + if self.post_update: + child_post_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + False) + child_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + True) + + uow.dependencies.update([ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, child_post_updates), + + (before_delete, child_pre_updates), + (child_pre_updates, parent_deletes), + (child_pre_updates, child_deletes), + + ]) + else: + uow.dependencies.update([ + (parent_saves, after_save), + (after_save, child_saves), + (after_save, child_deletes), + + (child_saves, parent_deletes), + (child_deletes, parent_deletes), + + (before_delete, child_saves), + (before_delete, child_deletes), + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): + + if self.post_update: + + child_post_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + False) + child_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + True) + + # TODO: this whole block is not covered + # by any tests + if not isdelete: + if childisdelete: + uow.dependencies.update([ + (child_action, after_save), + (after_save, child_post_updates), + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + (after_save, child_post_updates), + ]) + else: + if childisdelete: + uow.dependencies.update([ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ]) + else: + uow.dependencies.update([ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ]) + elif not isdelete: + uow.dependencies.update([ + (save_parent, after_save), + (after_save, child_action), + (save_parent, child_action) + ]) + else: + uow.dependencies.update([ + (before_delete, child_action), + (child_action, delete_parent) + ]) + + def presort_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their + # foreign key to the parent set to NULL + should_null_fks = not self.cascade.delete and \ + not self.passive_deletes == 'all' + + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + if self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=True) + else: + uowcommit.register_object(child) + + if should_null_fks: + for child in history.unchanged: + if child is not None: + uowcommit.register_object(child, + operation="delete", prop=self.prop) + + def presort_saves(self, uowcommit, states): + children_added = uowcommit.memo(('children_added', self), set) + + for state in states: + pks_changed = self._pks_changed(uowcommit, state) + + if not pks_changed or self.passive_updates: + passive = attributes.PASSIVE_NO_INITIALIZE + else: + passive = attributes.PASSIVE_OFF + + history = uowcommit.get_attribute_history( + state, + self.key, + passive) + if history: + for child in history.added: + if child is not None: + uowcommit.register_object(child, cancel_delete=True, + operation="add", + prop=self.prop) + + children_added.update(history.added) + + for child in history.deleted: + if not self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=False, + operation='delete', + prop=self.prop) + elif self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True, + operation="delete", prop=self.prop) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + 'delete', child): + uowcommit.register_object( + st_, + isdelete=True) + + if pks_changed: + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, + False, + self.passive_updates, + operation="pk change", + prop=self.prop) + + def process_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their foreign + # key to the parent set to NULL this phase can be called + # safely for any cascade but is unnecessary if delete cascade + # is on. + + if self.post_update or not self.passive_deletes == 'all': + children_added = uowcommit.memo(('children_added', self), set) + + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if child is not None and \ + self.hasparent(child) is False: + self._synchronize( + state, + child, + None, True, + uowcommit, False) + if self.post_update and child: + self._post_update(child, uowcommit, [state]) + + if self.post_update or not self.cascade.delete: + for child in set(history.unchanged).\ + difference(children_added): + if child is not None: + self._synchronize( + state, + child, + None, True, + uowcommit, False) + if self.post_update and child: + self._post_update(child, + uowcommit, + [state]) + + # technically, we can even remove each child from the + # collection here too. but this would be a somewhat + # inconsistent behavior since it wouldn't happen + #if the old parent wasn't deleted but child was moved. + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + for child in history.added: + self._synchronize(state, child, None, + False, uowcommit, False) + if child is not None and self.post_update: + self._post_update(child, uowcommit, [state]) + + for child in history.deleted: + if not self.cascade.delete_orphan and \ + not self.hasparent(child): + self._synchronize(state, child, None, True, + uowcommit, False) + + if self._pks_changed(uowcommit, state): + for child in history.unchanged: + self._synchronize(state, child, None, + False, uowcommit, True) + + def _synchronize(self, state, child, + associationrow, clearkeys, uowcommit, + pks_changed): + source = state + dest = child + self._verify_canload(child) + if dest is None or \ + (not self.post_update and uowcommit.is_deleted(dest)): + return + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate(source, self.parent, dest, self.mapper, + self.prop.synchronize_pairs, uowcommit, + self.passive_updates and pks_changed) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, + state, + self.parent, + self.prop.synchronize_pairs) + + +class ManyToOneDP(DependencyProcessor): + def __init__(self, prop): + DependencyProcessor.__init__(self, prop) + self.mapper._dependency_processors.append(DetectKeySwitch(prop)) + + def per_property_dependencies(self, uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete): + + if self.post_update: + parent_post_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + False) + parent_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + True) + + uow.dependencies.update([ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, parent_post_updates), + + (after_save, parent_pre_updates), + (before_delete, parent_pre_updates), + + (parent_pre_updates, child_deletes), + ]) + else: + uow.dependencies.update([ + (child_saves, after_save), + (after_save, parent_saves), + (parent_saves, child_deletes), + (parent_deletes, child_deletes) + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): + + if self.post_update: + + if not isdelete: + parent_post_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + False) + if childisdelete: + uow.dependencies.update([ + (after_save, parent_post_updates), + (parent_post_updates, child_action) + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + + (after_save, parent_post_updates) + ]) + else: + parent_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + True) + + uow.dependencies.update([ + (before_delete, parent_pre_updates), + (parent_pre_updates, delete_parent), + (parent_pre_updates, child_action) + ]) + + elif not isdelete: + if not childisdelete: + uow.dependencies.update([ + (child_action, after_save), + (after_save, save_parent), + ]) + else: + uow.dependencies.update([ + (after_save, save_parent), + ]) + + else: + if childisdelete: + uow.dependencies.update([ + (delete_parent, child_action) + ]) + + def presort_deletes(self, uowcommit, states): + if self.cascade.delete or self.cascade.delete_orphan: + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + if self.cascade.delete_orphan: + todelete = history.sum() + else: + todelete = history.non_deleted() + for child in todelete: + if child is None: + continue + uowcommit.register_object(child, isdelete=True, + operation="delete", prop=self.prop) + t = self.mapper.cascade_iterator('delete', child) + for c, m, st_, dct_ in t: + uowcommit.register_object( + st_, isdelete=True) + + def presort_saves(self, uowcommit, states): + for state in states: + uowcommit.register_object(state, operation="add", prop=self.prop) + if self.cascade.delete_orphan: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True, + operation="delete", prop=self.prop) + + t = self.mapper.cascade_iterator('delete', child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + if self.post_update and \ + not self.cascade.delete_orphan and \ + not self.passive_deletes == 'all': + + # post_update means we have to update our + # row to not reference the child object + # before we can DELETE the row + for state in states: + self._synchronize(state, None, None, True, uowcommit) + if state and self.post_update: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + self._post_update(state, uowcommit, history.sum()) + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + for child in history.added: + self._synchronize(state, child, None, False, + uowcommit, "add") + + if self.post_update: + self._post_update(state, uowcommit, history.sum()) + + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit, operation=None): + if state is None or \ + (not self.post_update and uowcommit.is_deleted(state)): + return + + if operation is not None and \ + child is not None and \ + not uowcommit.session._contains_state(child): + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" % + (mapperutil.state_class_str(child), operation, self.prop)) + return + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate(child, self.mapper, state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + False) + + +class DetectKeySwitch(DependencyProcessor): + """For many-to-one relationships with no one-to-many backref, + searches for parents through the unit of work when a primary + key has changed and updates them. + + Theoretically, this approach could be expanded to support transparent + deletion of objects referenced via many-to-one as well, although + the current attribute system doesn't do enough bookkeeping for this + to be efficient. + + """ + + def per_property_preprocessors(self, uow): + if self.prop._reverse_property: + if self.passive_updates: + return + else: + if False in (prop.passive_updates for \ + prop in self.prop._reverse_property): + return + + uow.register_preprocessor(self, False) + + def per_property_flush_actions(self, uow): + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) + after_save = unitofwork.ProcessAll(uow, self, False, False) + uow.dependencies.update([ + (parent_saves, after_save) + ]) + + def per_state_flush_actions(self, uow, states, isdelete): + pass + + def presort_deletes(self, uowcommit, states): + pass + + def presort_saves(self, uow, states): + if not self.passive_updates: + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + self._process_key_switches(states, uow) + + def prop_has_changes(self, uow, states, isdelete): + if not isdelete and self.passive_updates: + d = self._key_switchers(uow, states) + return bool(d) + + return False + + def process_deletes(self, uowcommit, states): + assert False + + def process_saves(self, uowcommit, states): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + assert self.passive_updates + self._process_key_switches(states, uowcommit) + + def _key_switchers(self, uow, states): + switched, notswitched = uow.memo( + ('pk_switchers', self), + lambda: (set(), set()) + ) + + allstates = switched.union(notswitched) + for s in states: + if s not in allstates: + if self._pks_changed(uow, s): + switched.add(s) + else: + notswitched.add(s) + return switched + + def _process_key_switches(self, deplist, uowcommit): + switchers = self._key_switchers(uowcommit, deplist) + if switchers: + # if primary key values have actually changed somewhere, perform + # a linear search through the UOW in search of a parent. + for state in uowcommit.session.identity_map.all_states(): + if not issubclass(state.class_, self.parent.class_): + continue + dict_ = state.dict + related = state.get_impl(self.key).get(state, dict_, + passive=self._passive_update_flag) + if related is not attributes.PASSIVE_NO_RESULT and \ + related is not None: + related_state = attributes.instance_state(dict_[self.key]) + if related_state in switchers: + uowcommit.register_object(state, + False, + self.passive_updates) + sync.populate( + related_state, + self.mapper, state, + self.parent, self.prop.synchronize_pairs, + uowcommit, self.passive_updates) + + def _pks_changed(self, uowcommit, state): + return bool(state.key) and sync.source_modified(uowcommit, + state, + self.mapper, + self.prop.synchronize_pairs) + + +class ManyToManyDP(DependencyProcessor): + + def per_property_dependencies(self, uow, parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete + ): + + uow.dependencies.update([ + (parent_saves, after_save), + (child_saves, after_save), + (after_save, child_deletes), + + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), + + (before_delete, parent_deletes), + (before_delete, child_deletes), + (before_delete, child_saves), + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): + if not isdelete: + if childisdelete: + uow.dependencies.update([ + (save_parent, after_save), + (after_save, child_action), + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + ]) + else: + uow.dependencies.update([ + (before_delete, child_action), + (before_delete, delete_parent) + ]) + + def presort_deletes(self, uowcommit, states): + # TODO: no tests fail if this whole + # thing is removed !!!! + if not self.passive_deletes: + # if no passive deletes, load history on + # the collection, so that prop_has_changes() + # returns True + for state in states: + uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + + def presort_saves(self, uowcommit, states): + if not self.passive_updates: + # if no passive updates, load history on + # each collection where parent has changed PK, + # so that prop_has_changes() returns True + for state in states: + if self._pks_changed(uowcommit, state): + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_OFF) + + if not self.cascade.delete_orphan: + return + + # check for child items removed from the collection + # if delete_orphan check is turned on. + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object(child, isdelete=True, + operation="delete", prop=self.prop) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + 'delete', + child): + uowcommit.register_object( + st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + for state in states: + # this history should be cached already, as + # we loaded it in preprocess_deletes + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.non_added(): + if child is None or \ + (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, uowcommit, "delete"): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.non_added()) + + if processed is not None: + processed.update(tmp) + + self._run_crud(uowcommit, secondary_insert, + secondary_update, secondary_delete) + + def process_saves(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + + for state in states: + need_cascade_pks = not self.passive_updates and \ + self._pks_changed(uowcommit, state) + if need_cascade_pks: + passive = attributes.PASSIVE_OFF + else: + passive = attributes.PASSIVE_NO_INITIALIZE + history = uowcommit.get_attribute_history(state, self.key, + passive) + if history: + for child in history.added: + if (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize(state, + child, + associationrow, + False, uowcommit, "add"): + continue + secondary_insert.append(associationrow) + for child in history.deleted: + if (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize(state, + child, + associationrow, + False, uowcommit, "delete"): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) + for c in history.added + history.deleted) + + if need_cascade_pks: + + for child in history.unchanged: + associationrow = {} + sync.update(state, + self.parent, + associationrow, + "old_", + self.prop.synchronize_pairs) + sync.update(child, + self.mapper, + associationrow, + "old_", + self.prop.secondary_synchronize_pairs) + + secondary_update.append(associationrow) + + if processed is not None: + processed.update(tmp) + + self._run_crud(uowcommit, secondary_insert, + secondary_update, secondary_delete) + + def _run_crud(self, uowcommit, secondary_insert, + secondary_update, secondary_delete): + connection = uowcommit.transaction.connection(self.mapper) + + if secondary_delete: + associationrow = secondary_delete[0] + statement = self.secondary.delete(sql.and_(*[ + c == sql.bindparam(c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ])) + result = connection.execute(statement, secondary_delete) + + if result.supports_sane_multi_rowcount() and \ + result.rowcount != len(secondary_delete): + raise exc.StaleDataError( + "DELETE statement on table '%s' expected to delete " + "%d row(s); Only %d were matched." % + (self.secondary.description, len(secondary_delete), + result.rowcount) + ) + + if secondary_update: + associationrow = secondary_update[0] + statement = self.secondary.update(sql.and_(*[ + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ])) + result = connection.execute(statement, secondary_update) + if result.supports_sane_multi_rowcount() and \ + result.rowcount != len(secondary_update): + raise exc.StaleDataError( + "UPDATE statement on table '%s' expected to update " + "%d row(s); Only %d were matched." % + (self.secondary.description, len(secondary_update), + result.rowcount) + ) + + if secondary_insert: + statement = self.secondary.insert() + connection.execute(statement, secondary_insert) + + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit, operation): + + # this checks for None if uselist=True + self._verify_canload(child) + + # but if uselist=False we get here. If child is None, + # no association row can be generated, so return. + if child is None: + return False + + if child is not None and not uowcommit.session._contains_state(child): + if not child.deleted: + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" % + (mapperutil.state_class_str(child), operation, self.prop)) + return False + + sync.populate_dict(state, self.parent, associationrow, + self.prop.synchronize_pairs) + sync.populate_dict(child, self.mapper, associationrow, + self.prop.secondary_synchronize_pairs) + + return True + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, + state, + self.parent, + self.prop.synchronize_pairs) + +_direction_to_processor = { + ONETOMANY: OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY: ManyToManyDP, +} diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py new file mode 100644 index 00000000..020b7c71 --- /dev/null +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -0,0 +1,590 @@ +# orm/deprecated_interfaces.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .. import event, util +from .interfaces import EXT_CONTINUE + +@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces") +class MapperExtension(object): + """Base implementation for :class:`.Mapper` event hooks. + + .. note:: + + :class:`.MapperExtension` is deprecated. Please + refer to :func:`.event.listen` as well as + :class:`.MapperEvents`. + + New extension classes subclass :class:`.MapperExtension` and are specified + using the ``extension`` mapper() argument, which is a single + :class:`.MapperExtension` or a list of such:: + + from sqlalchemy.orm.interfaces import MapperExtension + + class MyExtension(MapperExtension): + def before_insert(self, mapper, connection, instance): + print "instance %s before insert !" % instance + + m = mapper(User, users_table, extension=MyExtension()) + + A single mapper can maintain a chain of ``MapperExtension`` + objects. When a particular mapping event occurs, the + corresponding method on each ``MapperExtension`` is invoked + serially, and each method has the ability to halt the chain + from proceeding further:: + + m = mapper(User, users_table, extension=[ext1, ext2, ext3]) + + Each ``MapperExtension`` method returns the symbol + EXT_CONTINUE by default. This symbol generally means "move + to the next ``MapperExtension`` for processing". For methods + that return objects like translated rows or new object + instances, EXT_CONTINUE means the result of the method + should be ignored. In some cases it's required for a + default mapper activity to be performed, such as adding a + new instance to a result list. + + The symbol EXT_STOP has significance within a chain + of ``MapperExtension`` objects that the chain will be stopped + when this symbol is returned. Like EXT_CONTINUE, it also + has additional significance in some cases that a default + mapper activity will not be performed. + + """ + + @classmethod + def _adapt_instrument_class(cls, self, listener): + cls._adapt_listener_methods(self, listener, ('instrument_class',)) + + @classmethod + def _adapt_listener(cls, self, listener): + cls._adapt_listener_methods( + self, listener, + ( + 'init_instance', + 'init_failed', + 'translate_row', + 'create_instance', + 'append_result', + 'populate_instance', + 'reconstruct_instance', + 'before_insert', + 'after_insert', + 'before_update', + 'after_update', + 'before_delete', + 'after_delete' + )) + + @classmethod + def _adapt_listener_methods(cls, self, listener, methods): + + for meth in methods: + me_meth = getattr(MapperExtension, meth) + ls_meth = getattr(listener, meth) + + if not util.methods_equivalent(me_meth, ls_meth): + if meth == 'reconstruct_instance': + def go(ls_meth): + def reconstruct(instance, ctx): + ls_meth(self, instance) + return reconstruct + event.listen(self.class_manager, 'load', + go(ls_meth), raw=False, propagate=True) + elif meth == 'init_instance': + def go(ls_meth): + def init_instance(instance, args, kwargs): + ls_meth(self, self.class_, + self.class_manager.original_init, + instance, args, kwargs) + return init_instance + event.listen(self.class_manager, 'init', + go(ls_meth), raw=False, propagate=True) + elif meth == 'init_failed': + def go(ls_meth): + def init_failed(instance, args, kwargs): + util.warn_exception(ls_meth, self, self.class_, + self.class_manager.original_init, + instance, args, kwargs) + + return init_failed + event.listen(self.class_manager, 'init_failure', + go(ls_meth), raw=False, propagate=True) + else: + event.listen(self, "%s" % meth, ls_meth, + raw=False, retval=True, propagate=True) + + def instrument_class(self, mapper, class_): + """Receive a class when the mapper is first constructed, and has + applied instrumentation to the mapped class. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + """Receive an instance when it's constructor is called. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + """Receive an instance when it's constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def translate_row(self, mapper, context, row): + """Perform pre-processing on the given result row and return a + new row instance. + + This is called when the mapper first receives a row, before + the object identity or the instance itself has been derived + from that row. The given row may or may not be a + ``RowProxy`` object - it will always be a dictionary-like + object which contains mapped columns as keys. The + returned object should also be a dictionary-like object + which recognizes mapped columns as keys. + + If the ultimate return value is EXT_CONTINUE, the row + is not translated. + + """ + return EXT_CONTINUE + + def create_instance(self, mapper, selectcontext, row, class_): + """Receive a row when a new object instance is about to be + created from that row. + + The method can choose to create the instance itself, or it can return + EXT_CONTINUE to indicate normal object creation should take place. + + mapper + The mapper doing the operation + + selectcontext + The QueryContext generated from the Query. + + row + The result row from the database + + class\_ + The class we are mapping. + + return value + A new object instance, or EXT_CONTINUE + + """ + return EXT_CONTINUE + + def append_result(self, mapper, selectcontext, row, instance, + result, **flags): + """Receive an object instance before that instance is appended + to a result list. + + If this method returns EXT_CONTINUE, result appending will proceed + normally. if this method returns any other value or None, + result appending will not proceed for this instance, giving + this extension an opportunity to do the appending itself, if + desired. + + mapper + The mapper doing the operation. + + selectcontext + The QueryContext generated from the Query. + + row + The result row from the database. + + instance + The object instance to be appended to the result. + + result + List to which results are being appended. + + \**flags + extra information about the row, same as criterion in + ``create_row_processor()`` method of + :class:`~sqlalchemy.orm.interfaces.MapperProperty` + """ + + return EXT_CONTINUE + + def populate_instance(self, mapper, selectcontext, row, + instance, **flags): + """Receive an instance before that instance has + its attributes populated. + + This usually corresponds to a newly loaded instance but may + also correspond to an already-loaded instance which has + unloaded attributes to be populated. The method may be called + many times for a single instance, as multiple result rows are + used to populate eagerly loaded collections. + + If this method returns EXT_CONTINUE, instance population will + proceed normally. If any other value or None is returned, + instance population will not proceed, giving this extension an + opportunity to populate the instance itself, if desired. + + .. deprecated:: 0.5 + Most usages of this hook are obsolete. For a + generic "object has been newly created from a row" hook, use + ``reconstruct_instance()``, or the ``@orm.reconstructor`` + decorator. + + """ + return EXT_CONTINUE + + def reconstruct_instance(self, mapper, instance): + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + Note that during a result-row load, this method is called upon + the first row received for this instance. Note that some + attributes and collections may or may not be loaded or even + initialized, depending on what's present in the result rows. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + return EXT_CONTINUE + + def before_insert(self, mapper, connection, instance): + """Receive an object instance before that instance is inserted + into its table. + + This is a good place to set up primary key values and such + that aren't handled otherwise. + + Column-based attributes can be modified within this method + which will result in the new value being inserted. However + *no* changes to the overall flush plan can be made, and + manipulation of the ``Session`` will not have the desired effect. + To manipulate the ``Session`` within an extension, use + ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + """Receive an object instance after that instance is inserted. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def before_update(self, mapper, connection, instance): + """Receive an object instance before that instance is updated. + + Note that this method is called for all instances that are marked as + "dirty", even those which have no net changes to their column-based + attributes. An object is marked as dirty when any of its column-based + attributes have a "set attribute" operation called or when any of its + collections are modified. If, at update time, no column-based + attributes have any net changes, no UPDATE statement will be issued. + This means that an instance being sent to before_update is *not* a + guarantee that an UPDATE statement will be issued (although you can + affect the outcome here). + + To detect if the column-based attributes on the object have net + changes, and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + Column-based attributes can be modified within this method + which will result in the new value being updated. However + *no* changes to the overall flush plan can be made, and + manipulation of the ``Session`` will not have the desired effect. + To manipulate the ``Session`` within an extension, use + ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_update(self, mapper, connection, instance): + """Receive an object instance after that instance is updated. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def before_delete(self, mapper, connection, instance): + """Receive an object instance before that instance is deleted. + + Note that *no* changes to the overall flush plan can be made + here; and manipulation of the ``Session`` will not have the + desired effect. To manipulate the ``Session`` within an + extension, use ``SessionExtension``. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + def after_delete(self, mapper, connection, instance): + """Receive an object instance after that instance is deleted. + + The return value is only significant within the ``MapperExtension`` + chain; the parent mapper's behavior isn't modified by this method. + + """ + + return EXT_CONTINUE + + +@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces") +class SessionExtension(object): + + """Base implementation for :class:`.Session` event hooks. + + .. note:: + + :class:`.SessionExtension` is deprecated. Please + refer to :func:`.event.listen` as well as + :class:`.SessionEvents`. + + Subclasses may be installed into a :class:`.Session` (or + :class:`.sessionmaker`) using the ``extension`` keyword + argument:: + + from sqlalchemy.orm.interfaces import SessionExtension + + class MySessionExtension(SessionExtension): + def before_commit(self, session): + print "before commit!" + + Session = sessionmaker(extension=MySessionExtension()) + + The same :class:`.SessionExtension` instance can be used + with any number of sessions. + + """ + + @classmethod + def _adapt_listener(cls, self, listener): + for meth in [ + 'before_commit', + 'after_commit', + 'after_rollback', + 'before_flush', + 'after_flush', + 'after_flush_postexec', + 'after_begin', + 'after_attach', + 'after_bulk_update', + 'after_bulk_delete', + ]: + me_meth = getattr(SessionExtension, meth) + ls_meth = getattr(listener, meth) + + if not util.methods_equivalent(me_meth, ls_meth): + event.listen(self, meth, getattr(listener, meth)) + + def before_commit(self, session): + """Execute right before commit is called. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def after_commit(self, session): + """Execute after a commit has occurred. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def after_rollback(self, session): + """Execute after a rollback has occurred. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def before_flush(self, session, flush_context, instances): + """Execute before flush process has started. + + `instances` is an optional list of objects which were passed to + the ``flush()`` method. """ + + def after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been + called. + + Note that the session's state is still in pre-flush, i.e. 'new', + 'dirty', and 'deleted' lists still show pre-flush state as well + as the history settings on instance attributes.""" + + def after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec + state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in + their final state. An actual commit() may or may not have + occurred, depending on whether or not the flush started its own + transaction or participated in a larger transaction. """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called + after an engine level transaction is begun on a connection. """ + + def after_attach(self, session, instance): + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. """ + + def after_bulk_update(self, session, query, query_context, result): + """Execute after a bulk update operation to the session. + + This is called after a session.query(...).update() + + `query` is the query object that this update operation was + called on. `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + + def after_bulk_delete(self, session, query, query_context, result): + """Execute after a bulk delete operation to the session. + + This is called after a session.query(...).delete() + + `query` is the query object that this delete operation was + called on. `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + + +@util.langhelpers.dependency_for("sqlalchemy.orm.interfaces") +class AttributeExtension(object): + """Base implementation for :class:`.AttributeImpl` event hooks, events + that fire upon attribute mutations in user code. + + .. note:: + + :class:`.AttributeExtension` is deprecated. Please + refer to :func:`.event.listen` as well as + :class:`.AttributeEvents`. + + :class:`.AttributeExtension` is used to listen for set, + remove, and append events on individual mapped attributes. + It is established on an individual mapped attribute using + the `extension` argument, available on + :func:`.column_property`, :func:`.relationship`, and + others:: + + from sqlalchemy.orm.interfaces import AttributeExtension + from sqlalchemy.orm import mapper, relationship, column_property + + class MyAttrExt(AttributeExtension): + def append(self, state, value, initiator): + print "append event !" + return value + + def set(self, state, value, oldvalue, initiator): + print "set event !" + return value + + mapper(SomeClass, sometable, properties={ + 'foo':column_property(sometable.c.foo, extension=MyAttrExt()), + 'bar':relationship(Bar, extension=MyAttrExt()) + }) + + Note that the :class:`.AttributeExtension` methods + :meth:`~.AttributeExtension.append` and + :meth:`~.AttributeExtension.set` need to return the + ``value`` parameter. The returned value is used as the + effective value, and allows the extension to change what is + ultimately persisted. + + AttributeExtension is assembled within the descriptors associated + with a mapped class. + + """ + + active_history = True + """indicates that the set() method would like to receive the 'old' value, + even if it means firing lazy callables. + + Note that ``active_history`` can also be set directly via + :func:`.column_property` and :func:`.relationship`. + + """ + + @classmethod + def _adapt_listener(cls, self, listener): + event.listen(self, 'append', listener.append, + active_history=listener.active_history, + raw=True, retval=True) + event.listen(self, 'remove', listener.remove, + active_history=listener.active_history, + raw=True, retval=True) + event.listen(self, 'set', listener.set, + active_history=listener.active_history, + raw=True, retval=True) + + def append(self, state, value, initiator): + """Receive a collection append event. + + The returned value will be used as the actual value to be + appended. + + """ + return value + + def remove(self, state, value, initiator): + """Receive a remove event. + + No return value is defined. + + """ + pass + + def set(self, state, value, oldvalue, initiator): + """Receive a set event. + + The returned value will be used as the actual value to be + set. + + """ + return value diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py new file mode 100644 index 00000000..9ecc9bb6 --- /dev/null +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -0,0 +1,678 @@ +# orm/descriptor_props.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Descriptor properties are more "auxiliary" properties +that exist as configurational elements, but don't participate +as actively in the load/persist ORM loop. + +""" + +from .interfaces import MapperProperty, PropComparator +from .util import _none_set +from . import attributes +from .. import util, sql, exc as sa_exc, event, schema +from ..sql import expression +from . import properties +from . import query + + +class DescriptorProperty(MapperProperty): + """:class:`.MapperProperty` which proxies access to a + user-defined descriptor.""" + + doc = None + + def instrument_class(self, mapper): + prop = self + + class _ProxyImpl(object): + accepts_scalar_loader = False + expire_missing = True + collection = False + + def __init__(self, key): + self.key = key + + if hasattr(prop, 'get_history'): + def get_history(self, state, dict_, + passive=attributes.PASSIVE_OFF): + return prop.get_history(state, dict_, passive) + + if self.descriptor is None: + desc = getattr(mapper.class_, self.key, None) + if mapper._is_userland_descriptor(desc): + self.descriptor = desc + + if self.descriptor is None: + def fset(obj, value): + setattr(obj, self.name, value) + + def fdel(obj): + delattr(obj, self.name) + + def fget(obj): + return getattr(obj, self.name) + + self.descriptor = property( + fget=fget, + fset=fset, + fdel=fdel, + ) + + proxy_attr = attributes.\ + create_proxied_attribute(self.descriptor)\ + ( + self.parent.class_, + self.key, + self.descriptor, + lambda: self._comparator_factory(mapper), + doc=self.doc, + original_property=self + ) + proxy_attr.impl = _ProxyImpl(self.key) + mapper.class_manager.instrument_attribute(self.key, proxy_attr) + + +@util.langhelpers.dependency_for("sqlalchemy.orm.properties") +class CompositeProperty(DescriptorProperty): + """Defines a "composite" mapped attribute, representing a collection + of columns as one attribute. + + :class:`.CompositeProperty` is constructed using the :func:`.composite` + function. + + .. seealso:: + + :ref:`mapper_composite` + + """ + def __init__(self, class_, *attrs, **kwargs): + """Return a composite column-based property for use with a Mapper. + + See the mapping documentation section :ref:`mapper_composite` for a full + usage example. + + The :class:`.MapperProperty` returned by :func:`.composite` + is the :class:`.CompositeProperty`. + + :param class\_: + The "composite type" class. + + :param \*cols: + List of Column objects to be mapped. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. See the same flag on :func:`.column_property`. + + .. versionchanged:: 0.7 + This flag specifically becomes meaningful + - previously it was a placeholder. + + :param group: + A group name for this property when marked as deferred. + + :param deferred: + When True, the column property is "deferred", meaning that it does not + load immediately, and is instead loaded when the attribute is first + accessed on an instance. See also :func:`~sqlalchemy.orm.deferred`. + + :param comparator_factory: a class which extends + :class:`.CompositeProperty.Comparator` which provides custom SQL clause + generation for comparison operations. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + .. versionadded:: 0.8 + + :param extension: + an :class:`.AttributeExtension` instance, + or list of extensions, which will be prepended to the list of + attribute listeners for the resulting descriptor placed on the class. + **Deprecated.** Please see :class:`.AttributeEvents`. + + """ + + self.attrs = attrs + self.composite_class = class_ + self.active_history = kwargs.get('active_history', False) + self.deferred = kwargs.get('deferred', False) + self.group = kwargs.get('group', None) + self.comparator_factory = kwargs.pop('comparator_factory', + self.__class__.Comparator) + if 'info' in kwargs: + self.info = kwargs.pop('info') + + util.set_creation_order(self) + self._create_descriptor() + + + def instrument_class(self, mapper): + super(CompositeProperty, self).instrument_class(mapper) + self._setup_event_handlers() + + def do_init(self): + """Initialization which occurs after the :class:`.CompositeProperty` + has been associated with its parent mapper. + + """ + self._setup_arguments_on_columns() + + def _create_descriptor(self): + """Create the Python descriptor that will serve as + the access point on instances of the mapped class. + + """ + + def fget(instance): + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + + if self.key not in dict_: + # key not present. Iterate through related + # attributes, retrieve their values. This + # ensures they all load. + values = [ + getattr(instance, key) + for key in self._attribute_keys + ] + + # current expected behavior here is that the composite is + # created on access if the object is persistent or if + # col attributes have non-None. This would be better + # if the composite were created unconditionally, + # but that would be a behavioral change. + if self.key not in dict_ and ( + state.key is not None or + not _none_set.issuperset(values) + ): + dict_[self.key] = self.composite_class(*values) + state.manager.dispatch.refresh(state, None, [self.key]) + + return dict_.get(self.key, None) + + def fset(instance, value): + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + attr = state.manager[self.key] + previous = dict_.get(self.key, attributes.NO_VALUE) + for fn in attr.dispatch.set: + value = fn(state, value, previous, attr.impl) + dict_[self.key] = value + if value is None: + for key in self._attribute_keys: + setattr(instance, key, None) + else: + for key, value in zip( + self._attribute_keys, + value.__composite_values__()): + setattr(instance, key, value) + + def fdel(instance): + state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) + previous = dict_.pop(self.key, attributes.NO_VALUE) + attr = state.manager[self.key] + attr.dispatch.remove(state, previous, attr.impl) + for key in self._attribute_keys: + setattr(instance, key, None) + + self.descriptor = property(fget, fset, fdel) + + @util.memoized_property + def _comparable_elements(self): + return [ + getattr(self.parent.class_, prop.key) + for prop in self.props + ] + + @util.memoized_property + def props(self): + props = [] + for attr in self.attrs: + if isinstance(attr, str): + prop = self.parent.get_property(attr, _configure_mappers=False) + elif isinstance(attr, schema.Column): + prop = self.parent._columntoproperty[attr] + elif isinstance(attr, attributes.InstrumentedAttribute): + prop = attr.property + else: + raise sa_exc.ArgumentError( + "Composite expects Column objects or mapped " + "attributes/attribute names as arguments, got: %r" + % (attr,)) + props.append(prop) + return props + + @property + def columns(self): + return [a for a in self.attrs if isinstance(a, schema.Column)] + + def _setup_arguments_on_columns(self): + """Propagate configuration arguments made on this composite + to the target columns, for those that apply. + + """ + for prop in self.props: + prop.active_history = self.active_history + if self.deferred: + prop.deferred = self.deferred + prop.strategy_class = prop._strategy_lookup( + ("deferred", True), + ("instrument", True)) + prop.group = self.group + + def _setup_event_handlers(self): + """Establish events that populate/expire the composite attribute.""" + + def load_handler(state, *args): + dict_ = state.dict + + if self.key in dict_: + return + + # if column elements aren't loaded, skip. + # __get__() will initiate a load for those + # columns + for k in self._attribute_keys: + if k not in dict_: + return + + #assert self.key not in dict_ + dict_[self.key] = self.composite_class( + *[state.dict[key] for key in + self._attribute_keys] + ) + + def expire_handler(state, keys): + if keys is None or set(self._attribute_keys).intersection(keys): + state.dict.pop(self.key, None) + + def insert_update_handler(mapper, connection, state): + """After an insert or update, some columns may be expired due + to server side defaults, or re-populated due to client side + defaults. Pop out the composite value here so that it + recreates. + + """ + + state.dict.pop(self.key, None) + + event.listen(self.parent, 'after_insert', + insert_update_handler, raw=True) + event.listen(self.parent, 'after_update', + insert_update_handler, raw=True) + event.listen(self.parent, 'load', + load_handler, raw=True, propagate=True) + event.listen(self.parent, 'refresh', + load_handler, raw=True, propagate=True) + event.listen(self.parent, 'expire', + expire_handler, raw=True, propagate=True) + + # TODO: need a deserialize hook here + + @util.memoized_property + def _attribute_keys(self): + return [ + prop.key for prop in self.props + ] + + def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): + """Provided for userland code that uses attributes.get_history().""" + + added = [] + deleted = [] + + has_history = False + for prop in self.props: + key = prop.key + hist = state.manager[key].impl.get_history(state, dict_) + if hist.has_changes(): + has_history = True + + non_deleted = hist.non_deleted() + if non_deleted: + added.extend(non_deleted) + else: + added.append(None) + if hist.deleted: + deleted.extend(hist.deleted) + else: + deleted.append(None) + + if has_history: + return attributes.History( + [self.composite_class(*added)], + (), + [self.composite_class(*deleted)] + ) + else: + return attributes.History( + (), [self.composite_class(*added)], () + ) + + def _comparator_factory(self, mapper): + return self.comparator_factory(self, mapper) + + class CompositeBundle(query.Bundle): + def __init__(self, property, expr): + self.property = property + super(CompositeProperty.CompositeBundle, self).__init__( + property.key, *expr) + + def create_row_processor(self, query, procs, labels): + def proc(row, result): + return self.property.composite_class(*[proc(row, result) for proc in procs]) + return proc + + + class Comparator(PropComparator): + """Produce boolean, comparison, and other operators for + :class:`.CompositeProperty` attributes. + + See the example in :ref:`composite_operations` for an overview + of usage , as well as the documentation for :class:`.PropComparator`. + + See also: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + + __hash__ = None + + @property + def clauses(self): + return self.__clause_element__() + + def __clause_element__(self): + return expression.ClauseList(group=False, *self._comparable_elements) + + def _query_clause_element(self): + return CompositeProperty.CompositeBundle(self.prop, self.__clause_element__()) + + @util.memoized_property + def _comparable_elements(self): + if self._adapt_to_entity: + return [ + getattr( + self._adapt_to_entity.entity, + prop.key + ) for prop in self.prop._comparable_elements + ] + else: + return self.prop._comparable_elements + + def __eq__(self, other): + if other is None: + values = [None] * len(self.prop._comparable_elements) + else: + values = other.__composite_values__() + comparisons = [ + a == b + for a, b in zip(self.prop._comparable_elements, values) + ] + if self._adapt_to_entity: + comparisons = [self.adapter(x) for x in comparisons] + return sql.and_(*comparisons) + + def __ne__(self, other): + return sql.not_(self.__eq__(other)) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + + +@util.langhelpers.dependency_for("sqlalchemy.orm.properties") +class ConcreteInheritedProperty(DescriptorProperty): + """A 'do nothing' :class:`.MapperProperty` that disables + an attribute on a concrete subclass that is only present + on the inherited mapper, not the concrete classes' mapper. + + Cases where this occurs include: + + * When the superclass mapper is mapped against a + "polymorphic union", which includes all attributes from + all subclasses. + * When a relationship() is configured on an inherited mapper, + but not on the subclass mapper. Concrete mappers require + that relationship() is configured explicitly on each + subclass. + + """ + + def _comparator_factory(self, mapper): + comparator_callable = None + + for m in self.parent.iterate_to_root(): + p = m._props[self.key] + if not isinstance(p, ConcreteInheritedProperty): + comparator_callable = p.comparator_factory + break + return comparator_callable + + def __init__(self): + def warn(): + raise AttributeError("Concrete %s does not implement " + "attribute %r at the instance level. Add this " + "property explicitly to %s." % + (self.parent, self.key, self.parent)) + + class NoninheritedConcreteProp(object): + def __set__(s, obj, value): + warn() + + def __delete__(s, obj): + warn() + + def __get__(s, obj, owner): + if obj is None: + return self.descriptor + warn() + self.descriptor = NoninheritedConcreteProp() + + +@util.langhelpers.dependency_for("sqlalchemy.orm.properties") +class SynonymProperty(DescriptorProperty): + + def __init__(self, name, map_column=None, + descriptor=None, comparator_factory=None, + doc=None): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :param name: the name of the existing mapped property. This + can refer to the string name of any :class:`.MapperProperty` + configured on the class, including column-bound attributes + and relationships. + + :param descriptor: a Python :term:`descriptor` that will be used + as a getter (and potentially a setter) when this attribute is + accessed at the instance level. + + :param map_column: if ``True``, the :func:`.synonym` construct will + locate the existing named :class:`.MapperProperty` based on the + attribute name of this :func:`.synonym`, and assign it to a new + attribute linked to the name of this :func:`.synonym`. + That is, given a mapping like:: + + class MyClass(Base): + __tablename__ = 'my_table' + + id = Column(Integer, primary_key=True) + job_status = Column(String(50)) + + job_status = synonym("_job_status", map_column=True) + + The above class ``MyClass`` will now have the ``job_status`` + :class:`.Column` object mapped to the attribute named ``_job_status``, + and the attribute named ``job_status`` will refer to the synonym + itself. This feature is typically used in conjunction with the + ``descriptor`` argument in order to link a user-defined descriptor + as a "wrapper" for an existing column. + + :param comparator_factory: A subclass of :class:`.PropComparator` + that will provide custom comparison behavior at the SQL expression + level. + + .. note:: + + For the use case of providing an attribute which redefines both + Python-level and SQL-expression level behavior of an attribute, + please refer to the Hybrid attribute introduced at + :ref:`mapper_hybrids` for a more effective technique. + + .. seealso:: + + :ref:`synonyms` - examples of functionality. + + :ref:`mapper_hybrids` - Hybrids provide a better approach for + more complicated attribute-wrapping schemes than synonyms. + + """ + + self.name = name + self.map_column = map_column + self.descriptor = descriptor + self.comparator_factory = comparator_factory + self.doc = doc or (descriptor and descriptor.__doc__) or None + + util.set_creation_order(self) + + # TODO: when initialized, check _proxied_property, + # emit a warning if its not a column-based property + + @util.memoized_property + def _proxied_property(self): + return getattr(self.parent.class_, self.name).property + + def _comparator_factory(self, mapper): + prop = self._proxied_property + + if self.comparator_factory: + comp = self.comparator_factory(prop, mapper) + else: + comp = prop.comparator_factory(prop, mapper) + return comp + + def set_parent(self, parent, init): + if self.map_column: + # implement the 'map_column' option. + if self.key not in parent.mapped_table.c: + raise sa_exc.ArgumentError( + "Can't compile synonym '%s': no column on table " + "'%s' named '%s'" + % (self.name, parent.mapped_table.description, self.key)) + elif parent.mapped_table.c[self.key] in \ + parent._columntoproperty and \ + parent._columntoproperty[ + parent.mapped_table.c[self.key] + ].key == self.name: + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" % + (self.key, self.name, self.name, self.key) + ) + p = properties.ColumnProperty(parent.mapped_table.c[self.key]) + parent._configure_property( + self.name, p, + init=init, + setparent=True) + p._mapped_by_synonym = self.key + + self.parent = parent + + +@util.langhelpers.dependency_for("sqlalchemy.orm.properties") +class ComparableProperty(DescriptorProperty): + """Instruments a Python property for use in query expressions.""" + + def __init__(self, comparator_factory, descriptor=None, doc=None): + """Provides a method of applying a :class:`.PropComparator` + to any Python descriptor attribute. + + .. versionchanged:: 0.7 + :func:`.comparable_property` is superseded by + the :mod:`~sqlalchemy.ext.hybrid` extension. See the example + at :ref:`hybrid_custom_comparators`. + + Allows any Python descriptor to behave like a SQL-enabled + attribute when used at the class level in queries, allowing + redefinition of expression operator behavior. + + In the example below we redefine :meth:`.PropComparator.operate` + to wrap both sides of an expression in ``func.lower()`` to produce + case-insensitive comparison:: + + from sqlalchemy.orm import comparable_property + from sqlalchemy.orm.interfaces import PropComparator + from sqlalchemy.sql import func + from sqlalchemy import Integer, String, Column + from sqlalchemy.ext.declarative import declarative_base + + class CaseInsensitiveComparator(PropComparator): + def __clause_element__(self): + return self.prop + + def operate(self, op, other): + return op( + func.lower(self.__clause_element__()), + func.lower(other) + ) + + Base = declarative_base() + + class SearchWord(Base): + __tablename__ = 'search_word' + id = Column(Integer, primary_key=True) + word = Column(String) + word_insensitive = comparable_property(lambda prop, mapper: + CaseInsensitiveComparator(mapper.c.word, mapper) + ) + + + A mapping like the above allows the ``word_insensitive`` attribute + to render an expression like:: + + >>> print SearchWord.word_insensitive == "Trucks" + lower(search_word.word) = lower(:lower_1) + + :param comparator_factory: + A PropComparator subclass or factory that defines operator behavior + for this property. + + :param descriptor: + Optional when used in a ``properties={}`` declaration. The Python + descriptor or property to layer comparison behavior on top of. + + The like-named descriptor will be automatically retrieved from the + mapped class if left blank in a ``properties`` declaration. + + """ + self.descriptor = descriptor + self.comparator_factory = comparator_factory + self.doc = doc or (descriptor and descriptor.__doc__) or None + util.set_creation_order(self) + + def _comparator_factory(self, mapper): + return self.comparator_factory(self, mapper) + + diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py new file mode 100644 index 00000000..bae09d32 --- /dev/null +++ b/lib/sqlalchemy/orm/dynamic.py @@ -0,0 +1,368 @@ +# orm/dynamic.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Dynamic collection API. + +Dynamic collections act like Query() objects for read operations and support +basic add/delete mutation. + +""" + +from .. import log, util, exc +from ..sql import operators +from . import ( + attributes, object_session, util as orm_util, strategies, + object_mapper, exc as orm_exc, properties + ) +from .query import Query + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="dynamic") +class DynaLoader(strategies.AbstractRelationshipLoader): + def init_class_attribute(self, mapper): + self.is_class_level = True + if not self.uselist: + raise exc.InvalidRequestError( + "On relationship %s, 'dynamic' loaders cannot be used with " + "many-to-one/one-to-one relationships and/or " + "uselist=False." % self.parent_property) + strategies._register_attribute(self, + mapper, + useobject=True, + uselist=True, + impl_class=DynamicAttributeImpl, + target_mapper=self.parent_property.mapper, + order_by=self.parent_property.order_by, + query_class=self.parent_property.query_class, + backref=self.parent_property.back_populates, + ) + +class DynamicAttributeImpl(attributes.AttributeImpl): + uses_objects = True + accepts_scalar_loader = False + supports_population = False + collection = False + + def __init__(self, class_, key, typecallable, + dispatch, + target_mapper, order_by, query_class=None, **kw): + super(DynamicAttributeImpl, self).\ + __init__(class_, key, typecallable, dispatch, **kw) + self.target_mapper = target_mapper + self.order_by = order_by + if not query_class: + self.query_class = AppenderQuery + elif AppenderMixin in query_class.mro(): + self.query_class = query_class + else: + self.query_class = mixin_user_query(query_class) + + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if not passive & attributes.SQL_OK: + return self._get_collection_history(state, + attributes.PASSIVE_NO_INITIALIZE).added_items + else: + return self.query_class(self, state) + + def get_collection(self, state, dict_, user_data=None, + passive=attributes.PASSIVE_NO_INITIALIZE): + if not passive & attributes.SQL_OK: + return self._get_collection_history(state, + passive).added_items + else: + history = self._get_collection_history(state, passive) + return history.added_plus_unchanged + + @util.memoized_property + def _append_token(self): + return attributes.Event(self, attributes.OP_APPEND) + + @util.memoized_property + def _remove_token(self): + return attributes.Event(self, attributes.OP_REMOVE) + + def fire_append_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_added(value) + + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, True) + + def fire_remove_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_removed(value) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + def _modified_event(self, state, dict_): + + if self.key not in state.committed_state: + state.committed_state[self.key] = CollectionHistory(self, state) + + state._modified_event(dict_, + self, + attributes.NEVER_SET) + + # this is a hack to allow the fixtures.ComparableEntity fixture + # to work + dict_[self.key] = True + return state.committed_state[self.key] + + def set(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF, + check_old=None, pop=False): + if initiator and initiator.parent_token is self.parent_token: + return + + if pop and value is None: + return + self._set_iterable(state, dict_, value) + + def _set_iterable(self, state, dict_, iterable, adapter=None): + new_values = list(iterable) + if state.has_identity: + old_collection = util.IdentitySet(self.get(state, dict_)) + + collection_history = self._modified_event(state, dict_) + if not state.has_identity: + old_collection = collection_history.added_items + else: + old_collection = old_collection.union( + collection_history.added_items) + + idset = util.IdentitySet + constants = old_collection.intersection(new_values) + additions = idset(new_values).difference(constants) + removals = old_collection.difference(constants) + + for member in new_values: + if member in additions: + self.fire_append_event(state, dict_, member, None, + collection_history=collection_history) + + for member in removals: + self.fire_remove_event(state, dict_, member, None, + collection_history=collection_history) + + def delete(self, *args, **kwargs): + raise NotImplementedError() + + def set_committed_value(self, state, dict_, value): + raise NotImplementedError("Dynamic attributes don't support " + "collection population.") + + def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): + c = self._get_collection_history(state, passive) + return c.as_history() + + def get_all_pending(self, state, dict_): + c = self._get_collection_history( + state, attributes.PASSIVE_NO_INITIALIZE) + return [ + (attributes.instance_state(x), x) + for x in + c.all_items + ] + + def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF): + if self.key in state.committed_state: + c = state.committed_state[self.key] + else: + c = CollectionHistory(self, state) + + if state.has_identity and (passive & attributes.INIT_OK): + return CollectionHistory(self, state, apply_to=c) + else: + return c + + def append(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): + if initiator is not self: + self.fire_append_event(state, dict_, value, initiator) + + def remove(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): + if initiator is not self: + self.fire_remove_event(state, dict_, value, initiator) + + def pop(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): + self.remove(state, dict_, value, initiator, passive=passive) + + +class AppenderMixin(object): + query_class = None + + def __init__(self, attr, state): + super(AppenderMixin, self).__init__(attr.target_mapper, None) + self.instance = instance = state.obj() + self.attr = attr + + mapper = object_mapper(instance) + prop = mapper._props[self.attr.key] + self._criterion = prop.compare( + operators.eq, + instance, + value_is_parent=True, + alias_secondary=False) + + if self.attr.order_by: + self._order_by = self.attr.order_by + + def session(self): + sess = object_session(self.instance) + if sess is not None and self.autoflush and sess.autoflush \ + and self.instance in sess: + sess.flush() + if not orm_util.has_identity(self.instance): + return None + else: + return sess + session = property(session, lambda s, x: None) + + def __iter__(self): + sess = self.session + if sess is None: + return iter(self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE).added_items) + else: + return iter(self._clone(sess)) + + def __getitem__(self, index): + sess = self.session + if sess is None: + return self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE).indexed(index) + else: + return self._clone(sess).__getitem__(index) + + def count(self): + sess = self.session + if sess is None: + return len(self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE).added_items) + else: + return self._clone(sess).count() + + def _clone(self, sess=None): + # note we're returning an entirely new Query class instance + # here without any assignment capabilities; the class of this + # query is determined by the session. + instance = self.instance + if sess is None: + sess = object_session(instance) + if sess is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session, and no " + "contextual session is established; lazy load operation " + "of attribute '%s' cannot proceed" % ( + orm_util.instance_str(instance), self.attr.key)) + + if self.query_class: + query = self.query_class(self.attr.target_mapper, session=sess) + else: + query = sess.query(self.attr.target_mapper) + + query._criterion = self._criterion + query._order_by = self._order_by + + return query + + def extend(self, iterator): + for item in iterator: + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + + def append(self, item): + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + + def remove(self, item): + self.attr.remove( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + + +class AppenderQuery(AppenderMixin, Query): + """A dynamic query that supports basic collection storage operations.""" + + +def mixin_user_query(cls): + """Return a new class with AppenderQuery functionality layered over.""" + name = 'Appender' + cls.__name__ + return type(name, (AppenderMixin, cls), {'query_class': cls}) + + +class CollectionHistory(object): + """Overrides AttributeHistory to receive append/remove events directly.""" + + def __init__(self, attr, state, apply_to=None): + if apply_to: + coll = AppenderQuery(attr, state).autoflush(False) + self.unchanged_items = util.OrderedIdentitySet(coll) + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + self._reconcile_collection = True + else: + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + + @property + def added_plus_unchanged(self): + return list(self.added_items.union(self.unchanged_items)) + + @property + def all_items(self): + return list(self.added_items.union( + self.unchanged_items).union(self.deleted_items)) + + def as_history(self): + if self._reconcile_collection: + added = self.added_items.difference(self.unchanged_items) + deleted = self.deleted_items.intersection(self.unchanged_items) + unchanged = self.unchanged_items.difference(deleted) + else: + added, unchanged, deleted = self.added_items,\ + self.unchanged_items,\ + self.deleted_items + return attributes.History( + list(added), + list(unchanged), + list(deleted), + ) + + def indexed(self, index): + return list(self.added_items)[index] + + def add_added(self, value): + self.added_items.add(value) + + def add_removed(self, value): + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) + diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py new file mode 100644 index 00000000..e1dd9606 --- /dev/null +++ b/lib/sqlalchemy/orm/evaluator.py @@ -0,0 +1,123 @@ +# orm/evaluator.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import operator +from ..sql import operators + + +class UnevaluatableError(Exception): + pass + +_straight_ops = set(getattr(operators, op) + for op in ('add', 'mul', 'sub', + 'div', + 'mod', 'truediv', + 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) + + +_notimplemented_ops = set(getattr(operators, op) + for op in ('like_op', 'notlike_op', 'ilike_op', + 'notilike_op', 'between_op', 'in_op', + 'notin_op', 'endswith_op', 'concat_op')) + + +class EvaluatorCompiler(object): + def process(self, clause): + meth = getattr(self, "visit_%s" % clause.__visit_name__, None) + if not meth: + raise UnevaluatableError( + "Cannot evaluate %s" % type(clause).__name__) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_false(self, clause): + return lambda obj: False + + def visit_true(self, clause): + return lambda obj: True + + def visit_column(self, clause): + if 'parentmapper' in clause._annotations: + key = clause._annotations['parentmapper'].\ + _columntoproperty[clause].key + else: + key = clause.key + get_corresponding_attr = operator.attrgetter(key) + return lambda obj: get_corresponding_attr(obj) + + def visit_clauselist(self, clause): + evaluators = list(map(self.process, clause.clauses)) + if clause.operator is operators.or_: + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + elif clause.operator is operators.and_: + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if not value: + if value is None: + return None + return False + return True + else: + raise UnevaluatableError( + "Cannot evaluate clauselist with operator %s" % + clause.operator) + + return evaluate + + def visit_binary(self, clause): + eval_left, eval_right = list(map(self.process, + [clause.left, clause.right])) + operator = clause.operator + if operator is operators.is_: + def evaluate(obj): + return eval_left(obj) == eval_right(obj) + elif operator is operators.isnot: + def evaluate(obj): + return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + return operator(eval_left(obj), eval_right(obj)) + else: + raise UnevaluatableError( + "Cannot evaluate %s with operator %s" % + (type(clause).__name__, clause.operator)) + return evaluate + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + def evaluate(obj): + value = eval_inner(obj) + if value is None: + return None + return not value + return evaluate + raise UnevaluatableError( + "Cannot evaluate %s with operator %s" % + (type(clause).__name__, clause.operator)) + + def visit_bindparam(self, clause): + val = clause.value + return lambda obj: val diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py new file mode 100644 index 00000000..078f4d12 --- /dev/null +++ b/lib/sqlalchemy/orm/events.py @@ -0,0 +1,1711 @@ +# orm/events.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""ORM event interfaces. + +""" +from .. import event, exc, util +from .base import _mapper_or_none +import inspect +import weakref +from . import interfaces +from . import mapperlib, instrumentation +from .session import Session, sessionmaker +from .scoping import scoped_session +from .attributes import QueryableAttribute + +class InstrumentationEvents(event.Events): + """Events related to class instrumentation events. + + The listeners here support being established against + any new style class, that is any object that is a subclass + of 'type'. Events will then be fired off for events + against that class. If the "propagate=True" flag is passed + to event.listen(), the event will fire off for subclasses + of that class as well. + + The Python ``type`` builtin is also accepted as a target, + which when used has the effect of events being emitted + for all classes. + + Note the "propagate" flag here is defaulted to ``True``, + unlike the other class level events where it defaults + to ``False``. This means that new subclasses will also + be the subject of these events, when a listener + is established on a superclass. + + .. versionchanged:: 0.8 - events here will emit based + on comparing the incoming class to the type of class + passed to :func:`.event.listen`. Previously, the + event would fire for any class unconditionally regardless + of what class was sent for listening, despite + documentation which stated the contrary. + + """ + + _target_class_doc = "SomeBaseClass" + _dispatch_target = instrumentation.InstrumentationFactory + + + @classmethod + def _accept_with(cls, target): + if isinstance(target, type): + return _InstrumentationEventsHold(target) + else: + return None + + @classmethod + def _listen(cls, event_key, propagate=True, **kw): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + def listen(target_cls, *arg): + listen_cls = target() + if propagate and issubclass(target_cls, listen_cls): + return fn(target_cls, *arg) + elif not propagate and target_cls is listen_cls: + return fn(target_cls, *arg) + + def remove(ref): + key = event.registry._EventKey(None, identifier, listen, + instrumentation._instrumentation_factory) + getattr(instrumentation._instrumentation_factory.dispatch, + identifier).remove(key) + + target = weakref.ref(target.class_, remove) + + event_key.\ + with_dispatch_target(instrumentation._instrumentation_factory).\ + with_wrapper(listen).base_listen(**kw) + + @classmethod + def _clear(cls): + super(InstrumentationEvents, cls)._clear() + instrumentation._instrumentation_factory.dispatch._clear() + + def class_instrument(self, cls): + """Called after the given class is instrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def class_uninstrument(self, cls): + """Called before the given class is uninstrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def attribute_instrument(self, cls, key, inst): + """Called when an attribute is instrumented.""" + + + +class _InstrumentationEventsHold(object): + """temporary marker object used to transfer from _accept_with() to + _listen() on the InstrumentationEvents class. + + """ + def __init__(self, class_): + self.class_ = class_ + + dispatch = event.dispatcher(InstrumentationEvents) + +class InstanceEvents(event.Events): + """Define events specific to object lifecycle. + + e.g.:: + + from sqlalchemy import event + + def my_load_listener(target, context): + print "on load!" + + event.listen(SomeClass, 'load', my_load_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`.Mapper` objects + * the :class:`.Mapper` class itself and the :func:`.mapper` + function indicate listening for all mappers. + + .. versionchanged:: 0.8.0 instance events can be associated with + unmapped superclasses of mapped classes. + + Instance events are closely related to mapper events, but + are more specific to the instance and its instrumentation, + rather than its system of persistence. + + When using :class:`.InstanceEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting classes as well as the + class which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + + """ + + _target_class_doc = "SomeClass" + + _dispatch_target = instrumentation.ClassManager + + @classmethod + def _new_classmanager_instance(cls, class_, classmanager): + _InstanceEventsHold.populate(class_, classmanager) + + @classmethod + @util.dependencies("sqlalchemy.orm") + def _accept_with(cls, orm, target): + if isinstance(target, instrumentation.ClassManager): + return target + elif isinstance(target, mapperlib.Mapper): + return target.class_manager + elif target is orm.mapper: + return instrumentation.ClassManager + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return instrumentation.ClassManager + else: + manager = instrumentation.manager_of_class(target) + if manager: + return manager + else: + return _InstanceEventsHold(target) + return None + + @classmethod + def _listen(cls, event_key, raw=False, propagate=False, **kw): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + if not raw: + def wrap(state, *arg, **kw): + return fn(state.obj(), *arg, **kw) + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate, **kw) + + if propagate: + for mgr in target.subclass_managers(True): + event_key.with_dispatch_target(mgr).base_listen(propagate=True) + + @classmethod + def _clear(cls): + super(InstanceEvents, cls)._clear() + _InstanceEventsHold._clear() + + def first_init(self, manager, cls): + """Called when the first instance of a particular mapping is called. + + """ + + def init(self, target, args, kwargs): + """Receive an instance when it's constructor is called. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + """ + + def init_failure(self, target, args, kwargs): + """Receive an instance when it's constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + """ + + def load(self, target, context): + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + Note that during a result-row load, this method is called upon + the first row received for this instance. Note that some + attributes and collections may or may not be loaded or even + initialized, depending on what's present in the result rows. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`.Query` in progress. This argument may be + ``None`` if the load does not correspond to a :class:`.Query`, + such as during :meth:`.Session.merge`. + + """ + + def refresh(self, target, context, attrs): + """Receive an object instance after one or more attributes have + been refreshed from a query. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`.Query` in progress. + :param attrs: iterable collection of attribute names which + were populated, or None if all column-mapped, non-deferred + attributes were populated. + + """ + + def expire(self, target, attrs): + """Receive an object instance after its attributes or some subset + have been expired. + + 'keys' is a list of attribute names. If None, the entire + state was expired. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param attrs: iterable collection of attribute + names which were expired, or None if all attributes were + expired. + + """ + + def resurrect(self, target): + """Receive an object instance as it is 'resurrected' from + garbage collection, which occurs when a "dirty" state falls + out of scope. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + + """ + + def pickle(self, target, state_dict): + """Receive an object instance when its associated state is + being pickled. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary returned by + :class:`.InstanceState.__getstate__`, containing the state + to be pickled. + + """ + + def unpickle(self, target, state_dict): + """Receive an object instance after it's associated state has + been unpickled. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary sent to + :class:`.InstanceState.__setstate__`, containing the state + dictionary which was pickled. + + """ + +class _EventsHold(event.RefCollection): + """Hold onto listeners against unmapped, uninstrumented classes. + + Establish _listen() for that class' mapper/instrumentation when + those objects are created for that class. + + """ + def __init__(self, class_): + self.class_ = class_ + + @classmethod + def _clear(cls): + cls.all_holds.clear() + + class HoldEvents(object): + _dispatch_target = None + + @classmethod + def _listen(cls, event_key, raw=False, propagate=False, **kw): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + if target.class_ in target.all_holds: + collection = target.all_holds[target.class_] + else: + collection = target.all_holds[target.class_] = {} + + event.registry._stored_in_collection(event_key, target) + collection[event_key._key] = (event_key, raw, propagate) + + if propagate: + stack = list(target.class_.__subclasses__()) + while stack: + subclass = stack.pop(0) + stack.extend(subclass.__subclasses__()) + subject = target.resolve(subclass) + if subject is not None: + # we are already going through __subclasses__() + # so leave generic propagate flag False + event_key.with_dispatch_target(subject).\ + listen(raw=raw, propagate=False, **kw) + + def remove(self, event_key): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + if isinstance(target, _EventsHold): + collection = target.all_holds[target.class_] + del collection[event_key._key] + + @classmethod + def populate(cls, class_, subject): + for subclass in class_.__mro__: + if subclass in cls.all_holds: + collection = cls.all_holds[subclass] + for event_key, raw, propagate in collection.values(): + if propagate or subclass is class_: + # since we can't be sure in what order different classes + # in a hierarchy are triggered with populate(), + # we rely upon _EventsHold for all event + # assignment, instead of using the generic propagate + # flag. + event_key.with_dispatch_target(subject).\ + listen(raw=raw, propagate=False) + + +class _InstanceEventsHold(_EventsHold): + all_holds = weakref.WeakKeyDictionary() + + def resolve(self, class_): + return instrumentation.manager_of_class(class_) + + class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): + pass + + dispatch = event.dispatcher(HoldInstanceEvents) + + +class MapperEvents(event.Events): + """Define events specific to mappings. + + e.g.:: + + from sqlalchemy import event + + def my_before_insert_listener(mapper, connection, target): + # execute a stored procedure upon INSERT, + # apply the value to the row to be inserted + target.calculated_value = connection.scalar( + "select my_special_function(%d)" + % target.special_number) + + # associate the listener function with SomeClass, + # to execute during the "before_insert" hook + event.listen( + SomeClass, 'before_insert', my_before_insert_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`.Mapper` objects + * the :class:`.Mapper` class itself and the :func:`.mapper` + function indicate listening for all mappers. + + .. versionchanged:: 0.8.0 mapper events can be associated with + unmapped superclasses of mapped classes. + + Mapper events provide hooks into critical sections of the + mapper, including those related to object instrumentation, + object loading, and object persistence. In particular, the + persistence methods :meth:`~.MapperEvents.before_insert`, + and :meth:`~.MapperEvents.before_update` are popular + places to augment the state being persisted - however, these + methods operate with several significant restrictions. The + user is encouraged to evaluate the + :meth:`.SessionEvents.before_flush` and + :meth:`.SessionEvents.after_flush` methods as more + flexible and user-friendly hooks in which to apply + additional database state during a flush. + + When using :class:`.MapperEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting mappers and/or the mappers of + inheriting classes, as well as any + mapper which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event function + must have a return value, the purpose of which is either to + control subsequent event propagation, or to otherwise alter + the operation in progress by the mapper. Possible return + values are: + + * ``sqlalchemy.orm.interfaces.EXT_CONTINUE`` - continue event + processing normally. + * ``sqlalchemy.orm.interfaces.EXT_STOP`` - cancel all subsequent + event handlers in the chain. + * other values - the return value specified by specific listeners, + such as :meth:`~.MapperEvents.translate_row` or + :meth:`~.MapperEvents.create_instance`. + + """ + + _target_class_doc = "SomeClass" + _dispatch_target = mapperlib.Mapper + + @classmethod + def _new_mapper_instance(cls, class_, mapper): + _MapperEventsHold.populate(class_, mapper) + + @classmethod + @util.dependencies("sqlalchemy.orm") + def _accept_with(cls, orm, target): + if target is orm.mapper: + return mapperlib.Mapper + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return target + else: + mapper = _mapper_or_none(target) + if mapper is not None: + return mapper + else: + return _MapperEventsHold(target) + else: + return target + + @classmethod + def _listen(cls, event_key, raw=False, retval=False, propagate=False, **kw): + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + if identifier in ("before_configured", "after_configured") and \ + target is not mapperlib.Mapper: + util.warn( + "'before_configured' and 'after_configured' ORM events " + "only invoke with the mapper() function or Mapper class " + "as the target.") + + if not raw or not retval: + if not raw: + meth = getattr(cls, identifier) + try: + target_index = \ + inspect.getargspec(meth)[0].index('target') - 1 + except ValueError: + target_index = None + + def wrap(*arg, **kw): + if not raw and target_index is not None: + arg = list(arg) + arg[target_index] = arg[target_index].obj() + if not retval: + fn(*arg, **kw) + return interfaces.EXT_CONTINUE + else: + return fn(*arg, **kw) + event_key = event_key.with_wrapper(wrap) + + if propagate: + for mapper in target.self_and_descendants: + event_key.with_dispatch_target(mapper).base_listen( + propagate=True, **kw) + else: + event_key.base_listen(**kw) + + @classmethod + def _clear(cls): + super(MapperEvents, cls)._clear() + _MapperEventsHold._clear() + + def instrument_class(self, mapper, class_): + """Receive a class when the mapper is first constructed, + before instrumentation is applied to the mapped class. + + This event is the earliest phase of mapper construction. + Most attributes of the mapper are not yet initialized. + + This listener can either be applied to the :class:`.Mapper` + class overall, or to any un-mapped class which serves as a base + for classes that will be mapped (using the ``propagate=True`` flag):: + + Base = declarative_base() + + @event.listens_for(Base, "instrument_class", propagate=True) + def on_new_class(mapper, cls_): + " ... " + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + """ + + def mapper_configured(self, mapper, class_): + """Called when the mapper for the class is fully configured. + + This event is the latest phase of mapper construction, and + is invoked when the mapped classes are first used, so that + relationships between mappers can be resolved. When the event is + called, the mapper should be in its final state. + + While the configuration event normally occurs automatically, + it can be forced to occur ahead of time, in the case where the event + is needed before any actual mapper usage, by using the + :func:`.configure_mappers` function. + + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + """ + # TODO: need coverage for this event + + def before_configured(self): + """Called before a series of mappers have been configured. + + This corresponds to the :func:`.orm.configure_mappers` call, which + note is usually called automatically as mappings are first + used. + + This event can **only** be applied to the :class:`.Mapper` class + or :func:`.mapper` function, and not to individual mappings or + mapped classes. It is only invoked for all mappings as a whole:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "before_configured") + def go(): + # ... + + Theoretically this event is called once per + application, but is actually called any time new mappers + are to be affected by a :func:`.orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event can be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "before_configured", once=True) + def go(): + # ... + + + .. versionadded:: 0.9.3 + + """ + + def after_configured(self): + """Called after a series of mappers have been configured. + + This corresponds to the :func:`.orm.configure_mappers` call, which + note is usually called automatically as mappings are first + used. + + This event can **only** be applied to the :class:`.Mapper` class + or :func:`.mapper` function, and not to individual mappings or + mapped classes. It is only invoked for all mappings as a whole:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "after_configured") + def go(): + # ... + + Theoretically this event is called once per + application, but is actually called any time new mappers + have been affected by a :func:`.orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event can be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "after_configured", once=True) + def go(): + # ... + + """ + + def translate_row(self, mapper, context, row): + """Perform pre-processing on the given result row and return a + new row instance. + + This listener is typically registered with ``retval=True``. + It is called when the mapper first receives a row, before + the object identity or the instance itself has been derived + from that row. The given row may or may not be a + :class:`.RowProxy` object - it will always be a dictionary-like + object which contains mapped columns as keys. The + returned object should also be a dictionary-like object + which recognizes mapped columns as keys. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param context: the :class:`.QueryContext`, which includes + a handle to the current :class:`.Query` in progress as well + as additional state information. + :param row: the result row being handled. This may be + an actual :class:`.RowProxy` or may be a dictionary containing + :class:`.Column` objects as keys. + :return: When configured with ``retval=True``, the function + should return a dictionary-like row object, or ``EXT_CONTINUE``, + indicating the original row should be used. + + + """ + + def create_instance(self, mapper, context, row, class_): + """Receive a row when a new object instance is about to be + created from that row. + + The method can choose to create the instance itself, or it can return + EXT_CONTINUE to indicate normal object creation should take place. + This listener is typically registered with ``retval=True``. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param context: the :class:`.QueryContext`, which includes + a handle to the current :class:`.Query` in progress as well + as additional state information. + :param row: the result row being handled. This may be + an actual :class:`.RowProxy` or may be a dictionary containing + :class:`.Column` objects as keys. + :param class\_: the mapped class. + :return: When configured with ``retval=True``, the return value + should be a newly created instance of the mapped class, + or ``EXT_CONTINUE`` indicating that default object construction + should take place. + + """ + + def append_result(self, mapper, context, row, target, + result, **flags): + """Receive an object instance before that instance is appended + to a result list. + + This is a rarely used hook which can be used to alter + the construction of a result list returned by :class:`.Query`. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param context: the :class:`.QueryContext`, which includes + a handle to the current :class:`.Query` in progress as well + as additional state information. + :param row: the result row being handled. This may be + an actual :class:`.RowProxy` or may be a dictionary containing + :class:`.Column` objects as keys. + :param target: the mapped instance being populated. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param result: a list-like object where results are being + appended. + :param \**flags: Additional state information about the + current handling of the row. + :return: If this method is registered with ``retval=True``, + a return value of ``EXT_STOP`` will prevent the instance + from being appended to the given result list, whereas a + return value of ``EXT_CONTINUE`` will result in the default + behavior of appending the value to the result list. + + """ + + def populate_instance(self, mapper, context, row, + target, **flags): + """Receive an instance before that instance has + its attributes populated. + + This usually corresponds to a newly loaded instance but may + also correspond to an already-loaded instance which has + unloaded attributes to be populated. The method may be called + many times for a single instance, as multiple result rows are + used to populate eagerly loaded collections. + + Most usages of this hook are obsolete. For a + generic "object has been newly created from a row" hook, use + :meth:`.InstanceEvents.load`. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param context: the :class:`.QueryContext`, which includes + a handle to the current :class:`.Query` in progress as well + as additional state information. + :param row: the result row being handled. This may be + an actual :class:`.RowProxy` or may be a dictionary containing + :class:`.Column` objects as keys. + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: When configured with ``retval=True``, a return + value of ``EXT_STOP`` will bypass instance population by + the mapper. A value of ``EXT_CONTINUE`` indicates that + default instance population should take place. + + """ + + def before_insert(self, mapper, connection, target): + """Receive an object instance before an INSERT statement + is emitted corresponding to that instance. + + This event is used to modify local, non-object related + attributes on the instance before an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class before their INSERT statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :func:`.mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** + :class:`.Connection` **only.** Handlers here should **not** make + alterations to the state of the :class:`.Session` overall, and + in general should not affect any :func:`.relationship` -mapped + attributes, as session cascade rules will not function properly, + nor is it always known if the related class has already been + handled. Operations that **are not supported in mapper + events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, or + another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + """ + + def after_insert(self, mapper, connection, target): + """Receive an object instance after an INSERT statement + is emitted corresponding to that instance. + + This event is used to modify in-Python-only + state on the instance after an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class after their INSERT statements have been + emitted at once in a previous step. In the extremely + rare case that this is not desirable, the + :func:`.mapper` can be configured with ``batch=False``, + which will cause batches of instances to be broken up + into individual (and more poorly performing) + event->persist->event steps. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** + :class:`.Connection` **only.** Handlers here should **not** make + alterations to the state of the :class:`.Session` overall, and in + general should not affect any :func:`.relationship` -mapped + attributes, as session cascade rules will not function properly, + nor is it always known if the related class has already been + handled. Operations that **are not supported in mapper + events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, + or another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + """ + + def before_update(self, mapper, connection, target): + """Receive an object instance before an UPDATE statement + is emitted corresponding to that instance. + + This event is used to modify local, non-object related + attributes on the instance before an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.before_update` is + *not* a guarantee that an UPDATE statement will be + issued, although you can affect the outcome here by + modifying attributes so that a net change in value does + exist. + + To detect if the column-based attributes on the object have net + changes, and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class before their UPDATE statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :func:`.mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** :class:`.Connection` + **only.** Handlers here should **not** make alterations to the + state of the :class:`.Session` overall, and in general should not + affect any :func:`.relationship` -mapped attributes, as + session cascade rules will not function properly, nor is it + always known if the related class has already been handled. + Operations that **are not supported in mapper events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, + or another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + """ + + def after_update(self, mapper, connection, target): + """Receive an object instance after an UPDATE statement + is emitted corresponding to that instance. + + This event is used to modify in-Python-only + state on the instance after an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*, and for which + no UPDATE statement has proceeded. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.after_update` is + *not* a guarantee that an UPDATE statement has been + issued. + + To detect if the column-based attributes on the object have net + changes, and therefore resulted in an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class after their UPDATE statements have been emitted at + once in a previous step. In the extremely rare case that + this is not desirable, the :func:`.mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** :class:`.Connection` + **only.** Handlers here should **not** make alterations to the + state of the :class:`.Session` overall, and in general should not + affect any :func:`.relationship` -mapped attributes, as + session cascade rules will not function properly, nor is it + always known if the related class has already been handled. + Operations that **are not supported in mapper events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, + or another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + """ + + def before_delete(self, mapper, connection, target): + """Receive an object instance before a DELETE statement + is emitted corresponding to that instance. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class before their DELETE statements are emitted at + once in a later step. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** :class:`.Connection` + **only.** Handlers here should **not** make alterations to the + state of the :class:`.Session` overall, and in general should not + affect any :func:`.relationship` -mapped attributes, as + session cascade rules will not function properly, nor is it + always known if the related class has already been handled. + Operations that **are not supported in mapper events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, + or another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + """ + + def after_delete(self, mapper, connection, target): + """Receive an object instance after a DELETE statement + has been emitted corresponding to that instance. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class after their DELETE statements have been emitted at + once in a previous step. + + .. warning:: + Mapper-level flush events are designed to operate **on attributes + local to the immediate object being handled + and via SQL operations with the given** :class:`.Connection` + **only.** Handlers here should **not** make alterations to the + state of the :class:`.Session` overall, and in general should not + affect any :func:`.relationship` -mapped attributes, as + session cascade rules will not function properly, nor is it + always known if the related class has already been handled. + Operations that **are not supported in mapper events** include: + + * :meth:`.Session.add` + * :meth:`.Session.delete` + * Mapped collection append, add, remove, delete, discard, etc. + * Mapped relationship attribute set/del events, + i.e. ``someobject.related = someotherobject`` + + Operations which manipulate the state of the object + relative to other objects are better handled: + + * In the ``__init__()`` method of the mapped object itself, + or another method designed to establish some particular state. + * In a ``@validates`` handler, see :ref:`simple_validators` + * Within the :meth:`.SessionEvents.before_flush` event. + + :param mapper: the :class:`.Mapper` which is the target + of this event. + :param connection: the :class:`.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + """ + +class _MapperEventsHold(_EventsHold): + all_holds = weakref.WeakKeyDictionary() + + def resolve(self, class_): + return _mapper_or_none(class_) + + class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents): + pass + + dispatch = event.dispatcher(HoldMapperEvents) + + +class SessionEvents(event.Events): + """Define events specific to :class:`.Session` lifecycle. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import sessionmaker + + def my_before_commit(session): + print "before commit!" + + Session = sessionmaker() + + event.listen(Session, "before_commit", my_before_commit) + + The :func:`~.event.listen` function will accept + :class:`.Session` objects as well as the return result + of :class:`~.sessionmaker()` and :class:`~.scoped_session()`. + + Additionally, it accepts the :class:`.Session` class which + will apply listeners to all :class:`.Session` instances + globally. + + """ + + _target_class_doc = "SomeSessionOrFactory" + + _dispatch_target = Session + + @classmethod + def _accept_with(cls, target): + if isinstance(target, scoped_session): + + target = target.session_factory + if not isinstance(target, sessionmaker) and \ + ( + not isinstance(target, type) or + not issubclass(target, Session) + ): + raise exc.ArgumentError( + "Session event listen on a scoped_session " + "requires that its creation callable " + "is associated with the Session class.") + + if isinstance(target, sessionmaker): + return target.class_ + elif isinstance(target, type): + if issubclass(target, scoped_session): + return Session + elif issubclass(target, Session): + return target + elif isinstance(target, Session): + return target + else: + return None + + def after_transaction_create(self, session, transaction): + """Execute when a new :class:`.SessionTransaction` is created. + + This event differs from :meth:`~.SessionEvents.after_begin` + in that it occurs for each :class:`.SessionTransaction` + overall, as opposed to when transactions are begun + on individual database connections. It is also invoked + for nested transactions and subtransactions, and is always + matched by a corresponding + :meth:`~.SessionEvents.after_transaction_end` event + (assuming normal operation of the :class:`.Session`). + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_transaction_end(self, session, transaction): + """Execute when the span of a :class:`.SessionTransaction` ends. + + This event differs from :meth:`~.SessionEvents.after_commit` + in that it corresponds to all :class:`.SessionTransaction` + objects in use, including those for nested transactions + and subtransactions, and is always matched by a corresponding + :meth:`~.SessionEvents.after_transaction_create` event. + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`~.SessionEvents.after_transaction_create` + + """ + + def before_commit(self, session): + """Execute before commit is called. + + .. note:: + + The :meth:`~.SessionEvents.before_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or :meth:`~.SessionEvents.after_flush_postexec` + events. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_commit(self, session): + """Execute after a commit has occurred. + + .. note:: + + The :meth:`~.SessionEvents.after_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or :meth:`~.SessionEvents.after_flush_postexec` + events. + + .. note:: + + The :class:`.Session` is not in an active tranasction + when the :meth:`~.SessionEvents.after_commit` event is invoked, and therefore + can not emit SQL. To emit SQL corresponding to every transaction, + use the :meth:`~.SessionEvents.before_commit` event. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_rollback(self, session): + """Execute after a real DBAPI rollback has occurred. + + Note that this event only fires when the *actual* rollback against + the database occurs - it does *not* fire each time the + :meth:`.Session.rollback` method is called, if the underlying + DBAPI transaction has already been rolled back. In many + cases, the :class:`.Session` will not be in + an "active" state during this event, as the current + transaction is not valid. To acquire a :class:`.Session` + which is active after the outermost rollback has proceeded, + use the :meth:`.SessionEvents.after_soft_rollback` event, checking the + :attr:`.Session.is_active` flag. + + :param session: The target :class:`.Session`. + + """ + + def after_soft_rollback(self, session, previous_transaction): + """Execute after any rollback has occurred, including "soft" + rollbacks that don't actually emit at the DBAPI level. + + This corresponds to both nested and outer rollbacks, i.e. + the innermost rollback that calls the DBAPI's + rollback() method, as well as the enclosing rollback + calls that only pop themselves from the transaction stack. + + The given :class:`.Session` can be used to invoke SQL and + :meth:`.Session.query` operations after an outermost rollback + by first checking the :attr:`.Session.is_active` flag:: + + @event.listens_for(Session, "after_soft_rollback") + def do_something(session, previous_transaction): + if session.is_active: + session.execute("select * from some_table") + + :param session: The target :class:`.Session`. + :param previous_transaction: The :class:`.SessionTransaction` + transactional marker object which was just closed. The current + :class:`.SessionTransaction` for the given :class:`.Session` is + available via the :attr:`.Session.transaction` attribute. + + .. versionadded:: 0.7.3 + + """ + + def before_flush(self, session, flush_context, instances): + """Execute before flush process has started. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + :param instances: Usually ``None``, this is the collection of + objects which can be passed to the :meth:`.Session.flush` method + (note this usage is deprecated). + + .. seealso:: + + :meth:`~.SessionEvents.after_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + """ + + def after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been + called. + + Note that the session's state is still in pre-flush, i.e. 'new', + 'dirty', and 'deleted' lists still show pre-flush state as well + as the history settings on instance attributes. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + """ + + def after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec + state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in + their final state. An actual commit() may or may not have + occurred, depending on whether or not the flush started its own + transaction or participated in a larger transaction. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush` + + """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + :param session: The target :class:`.Session`. + :param transaction: The :class:`.SessionTransaction`. + :param connection: The :class:`~.engine.Connection` object + which will be used for SQL statements. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def before_attach(self, session, instance): + """Execute before an instance is attached to a session. + + This is called before an add, delete or merge causes + the object to be part of the session. + + .. versionadded:: 0.8. Note that :meth:`~.SessionEvents.after_attach` now + fires off after the item is part of the session. + :meth:`.before_attach` is provided for those cases where + the item should not yet be part of the session state. + + .. seealso:: + + :meth:`~.SessionEvents.after_attach` + + """ + + def after_attach(self, session, instance): + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. + + .. note:: + + As of 0.8, this event fires off *after* the item + has been fully associated with the session, which is + different than previous releases. For event + handlers that require the object not yet + be part of session state (such as handlers which + may autoflush while the target object is not + yet complete) consider the + new :meth:`.before_attach` event. + + .. seealso:: + + :meth:`~.SessionEvents.before_attach` + + """ + + @event._legacy_signature("0.9", + ["session", "query", "query_context", "result"], + lambda update_context: ( + update_context.session, + update_context.query, + update_context.context, + update_context.result)) + def after_bulk_update(self, update_context): + """Execute after a bulk update operation to the session. + + This is called as a result of the :meth:`.Query.update` method. + + :param update_context: an "update context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`.Query` object that this update operation was + called upon. + * ``context`` The :class:`.QueryContext` object, corresponding + to the invocation of an ORM query. + * ``result`` the :class:`.ResultProxy` returned as a result of the + bulk UPDATE operation. + + + """ + + @event._legacy_signature("0.9", + ["session", "query", "query_context", "result"], + lambda delete_context: ( + delete_context.session, + delete_context.query, + delete_context.context, + delete_context.result)) + def after_bulk_delete(self, delete_context): + """Execute after a bulk delete operation to the session. + + This is called as a result of the :meth:`.Query.delete` method. + + :param delete_context: a "delete context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`.Query` object that this update operation was + called upon. + * ``context`` The :class:`.QueryContext` object, corresponding + to the invocation of an ORM query. + * ``result`` the :class:`.ResultProxy` returned as a result of the + bulk DELETE operation. + + + """ + + +class AttributeEvents(event.Events): + """Define events for object attributes. + + These are typically defined on the class-bound descriptor for the + target class. + + e.g.:: + + from sqlalchemy import event + + def my_append_listener(target, value, initiator): + print "received append event for target: %s" % target + + event.listen(MyClass.collection, 'append', my_append_listener) + + Listeners have the option to return a possibly modified version + of the value, when the ``retval=True`` flag is passed + to :func:`~.event.listen`:: + + def validate_phone(target, value, oldvalue, initiator): + "Strip non-numeric characters from a phone number" + + return re.sub(r'(?![0-9])', '', value) + + # setup listener on UserContact.phone attribute, instructing + # it to use the return value + listen(UserContact.phone, 'set', validate_phone, retval=True) + + A validation function like the above can also raise an exception + such as :exc:`ValueError` to halt the operation. + + Several modifiers are available to the :func:`~.event.listen` function. + + :param active_history=False: When True, indicates that the + "set" event would like to receive the "old" value being + replaced unconditionally, even if this requires firing off + database loads. Note that ``active_history`` can also be + set directly via :func:`.column_property` and + :func:`.relationship`. + + :param propagate=False: When True, the listener function will + be established not just for the class attribute given, but + for attributes of the same name on all current subclasses + of that class, as well as all future subclasses of that + class, using an additional listener that listens for + instrumentation events. + :param raw=False: When True, the "target" argument to the + event will be the :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event + listening must return the "value" argument from the + function. This gives the listening function the opportunity + to change the value that is ultimately used for a "set" + or "append" event. + + """ + + _target_class_doc = "SomeClass.some_attribute" + _dispatch_target = QueryableAttribute + + @staticmethod + def _set_dispatch(cls, dispatch_cls): + event.Events._set_dispatch(cls, dispatch_cls) + dispatch_cls._active_history = False + + @classmethod + def _accept_with(cls, target): + # TODO: coverage + if isinstance(target, interfaces.MapperProperty): + return getattr(target.parent.class_, target.key) + else: + return target + + @classmethod + def _listen(cls, event_key, active_history=False, + raw=False, retval=False, + propagate=False): + + target, identifier, fn = \ + event_key.dispatch_target, event_key.identifier, event_key.fn + + if active_history: + target.dispatch._active_history = True + + if not raw or not retval: + def wrap(target, value, *arg): + if not raw: + target = target.obj() + if not retval: + fn(target, value, *arg) + return value + else: + return fn(target, value, *arg) + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate) + + if propagate: + manager = instrumentation.manager_of_class(target.class_) + + for mgr in manager.subclass_managers(True): + event_key.with_dispatch_target(mgr[target.key]).base_listen(propagate=True) + + def append(self, target, value, initiator): + """Receive a collection append event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being appended. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from it's original value by backref handlers in order to control + chained event propagation. + + .. versionchanged:: 0.9.0 the ``initiator`` argument is now + passed as a :class:`.attributes.Event` object, and may be modified + by backref handlers within a chain of backref-linked events. + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + """ + + def remove(self, target, value, initiator): + """Receive a collection remove event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being removed. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from it's original value by backref handlers in order to control + chained event propagation. + + .. versionchanged:: 0.9.0 the ``initiator`` argument is now + passed as a :class:`.attributes.Event` object, and may be modified + by backref handlers within a chain of backref-linked events. + + :return: No return value is defined for this event. + """ + + def set(self, target, value, oldvalue, initiator): + """Receive a scalar set event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being set. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param oldvalue: the previous value being replaced. This + may also be the symbol ``NEVER_SET`` or ``NO_VALUE``. + If the listener is registered with ``active_history=True``, + the previous value of the attribute will be loaded from + the database if the existing value is currently unloaded + or expired. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from it's original value by backref handlers in order to control + chained event propagation. + + .. versionchanged:: 0.9.0 the ``initiator`` argument is now + passed as a :class:`.attributes.Event` object, and may be modified + by backref handlers within a chain of backref-linked events. + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + """ + diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py new file mode 100644 index 00000000..d1ef1ded --- /dev/null +++ b/lib/sqlalchemy/orm/exc.py @@ -0,0 +1,163 @@ +# orm/exc.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQLAlchemy ORM exceptions.""" +from .. import exc as sa_exc, util + +NO_STATE = (AttributeError, KeyError) +"""Exception types that may be raised by instrumentation implementations.""" + + +class StaleDataError(sa_exc.SQLAlchemyError): + """An operation encountered database state that is unaccounted for. + + Conditions which cause this to happen include: + + * A flush may have attempted to update or delete rows + and an unexpected number of rows were matched during + the UPDATE or DELETE statement. Note that when + version_id_col is used, rows in UPDATE or DELETE statements + are also matched against the current known version + identifier. + + * A mapped object with version_id_col was refreshed, + and the version number coming back from the database does + not match that of the object itself. + + * A object is detached from its parent object, however + the object was previously attached to a different parent + identity which was garbage collected, and a decision + cannot be made if the new parent was really the most + recent "parent". + + .. versionadded:: 0.7.4 + + """ + +ConcurrentModificationError = StaleDataError + + +class FlushError(sa_exc.SQLAlchemyError): + """A invalid condition was detected during flush().""" + + +class UnmappedError(sa_exc.InvalidRequestError): + """Base for exceptions that involve expected mappings not present.""" + + +class ObjectDereferencedError(sa_exc.SQLAlchemyError): + """An operation cannot complete due to an object being garbage + collected. + + """ + + +class DetachedInstanceError(sa_exc.SQLAlchemyError): + """An attempt to access unloaded attributes on a + mapped instance that is detached.""" + + +class UnmappedInstanceError(UnmappedError): + """An mapping operation was requested for an unknown instance.""" + + @util.dependencies("sqlalchemy.orm.base") + def __init__(self, base, obj, msg=None): + if not msg: + try: + base.class_mapper(type(obj)) + name = _safe_cls_name(type(obj)) + msg = ("Class %r is mapped, but this instance lacks " + "instrumentation. This occurs when the instance" + "is created before sqlalchemy.orm.mapper(%s) " + "was called." % (name, name)) + except UnmappedClassError: + msg = _default_unmapped(type(obj)) + if isinstance(obj, type): + msg += ( + '; was a class (%s) supplied where an instance was ' + 'required?' % _safe_cls_name(obj)) + UnmappedError.__init__(self, msg) + + def __reduce__(self): + return self.__class__, (None, self.args[0]) + + +class UnmappedClassError(UnmappedError): + """An mapping operation was requested for an unknown class.""" + + def __init__(self, cls, msg=None): + if not msg: + msg = _default_unmapped(cls) + UnmappedError.__init__(self, msg) + + def __reduce__(self): + return self.__class__, (None, self.args[0]) + + +class ObjectDeletedError(sa_exc.InvalidRequestError): + """A refresh operation failed to retrieve the database + row corresponding to an object's known primary key identity. + + A refresh operation proceeds when an expired attribute is + accessed on an object, or when :meth:`.Query.get` is + used to retrieve an object which is, upon retrieval, detected + as expired. A SELECT is emitted for the target row + based on primary key; if no row is returned, this + exception is raised. + + The true meaning of this exception is simply that + no row exists for the primary key identifier associated + with a persistent object. The row may have been + deleted, or in some cases the primary key updated + to a new value, outside of the ORM's management of the target + object. + + """ + @util.dependencies("sqlalchemy.orm.base") + def __init__(self, base, state, msg=None): + if not msg: + msg = "Instance '%s' has been deleted, or its "\ + "row is otherwise not present." % base.state_str(state) + + sa_exc.InvalidRequestError.__init__(self, msg) + + def __reduce__(self): + return self.__class__, (None, self.args[0]) + + +class UnmappedColumnError(sa_exc.InvalidRequestError): + """Mapping operation was requested on an unknown column.""" + + +class NoResultFound(sa_exc.InvalidRequestError): + """A database result was required but none was found.""" + + +class MultipleResultsFound(sa_exc.InvalidRequestError): + """A single database result was required but more than one were found.""" + + +def _safe_cls_name(cls): + try: + cls_name = '.'.join((cls.__module__, cls.__name__)) + except AttributeError: + cls_name = getattr(cls, '__name__', None) + if cls_name is None: + cls_name = repr(cls) + return cls_name + +@util.dependencies("sqlalchemy.orm.base") +def _default_unmapped(base, cls): + try: + mappers = base.manager_of_class(cls).mappers + except NO_STATE: + mappers = {} + except TypeError: + mappers = {} + name = _safe_cls_name(cls) + + if not mappers: + return "Class '%s' is not mapped" % name diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py new file mode 100644 index 00000000..a91085d2 --- /dev/null +++ b/lib/sqlalchemy/orm/identity.py @@ -0,0 +1,240 @@ +# orm/identity.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import weakref +from . import attributes +from .. import util + +class IdentityMap(dict): + def __init__(self): + self._modified = set() + self._wr = weakref.ref(self) + + def replace(self, state): + raise NotImplementedError() + + def add(self, state): + raise NotImplementedError() + + def update(self, dict): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def clear(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def _manage_incoming_state(self, state): + state._instance_dict = self._wr + + if state.modified: + self._modified.add(state) + + def _manage_removed_state(self, state): + del state._instance_dict + self._modified.discard(state) + + def _dirty_states(self): + return self._modified + + def check_modified(self): + """return True if any InstanceStates present have been marked + as 'modified'. + + """ + return bool(self._modified) + + def has_key(self, key): + return key in self + + def popitem(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def pop(self, key, *args): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def setdefault(self, key, default=None): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def copy(self): + raise NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def __delitem__(self, key): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + +class WeakInstanceDict(IdentityMap): + def __init__(self): + IdentityMap.__init__(self) + + def __getitem__(self, key): + state = dict.__getitem__(self, key) + o = state.obj() + if o is None: + raise KeyError(key) + return o + + def __contains__(self, key): + try: + if dict.__contains__(self, key): + state = dict.__getitem__(self, key) + o = state.obj() + else: + return False + except KeyError: + return False + else: + return o is not None + + def contains_state(self, state): + return dict.get(self, state.key) is state + + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + + def add(self, state): + key = state.key + # inline of self.__contains__ + if dict.__contains__(self, key): + try: + existing_state = dict.__getitem__(self, key) + if existing_state is not state: + o = existing_state.obj() + if o is not None: + raise AssertionError( + "A conflicting state is already " + "present in the identity map for key %r" + % (key, )) + else: + return + except KeyError: + pass + dict.__setitem__(self, key, state) + self._manage_incoming_state(state) + + def get(self, key, default=None): + state = dict.get(self, key, default) + if state is default: + return default + o = state.obj() + if o is None: + return default + return o + + def _items(self): + values = self.all_states() + result = [] + for state in values: + value = state.obj() + if value is not None: + result.append((state.key, value)) + return result + + def _values(self): + values = self.all_states() + result = [] + for state in values: + value = state.obj() + if value is not None: + result.append(value) + + return result + + if util.py2k: + items = _items + values = _values + + def iteritems(self): + return iter(self.items()) + + def itervalues(self): + return iter(self.values()) + else: + def items(self): + return iter(self._items()) + + def values(self): + return iter(self._values()) + + def all_states(self): + if util.py2k: + return dict.values(self) + else: + return list(dict.values(self)) + + def discard(self, state): + st = dict.get(self, state.key, None) + if st is state: + dict.pop(self, state.key, None) + self._manage_removed_state(state) + + def prune(self): + return 0 + + +class StrongInstanceDict(IdentityMap): + def all_states(self): + return [attributes.instance_state(o) for o in self.values()] + + def contains_state(self, state): + return ( + state.key in self and + attributes.instance_state(self[state.key]) is state) + + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + existing = attributes.instance_state(existing) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + + def add(self, state): + if state.key in self: + if attributes.instance_state(dict.__getitem__(self, + state.key)) is not state: + raise AssertionError('A conflicting state is already ' + 'present in the identity map for key %r' + % (state.key, )) + else: + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + + def discard(self, state): + obj = dict.get(self, state.key, None) + if obj is not None: + st = attributes.instance_state(obj) + if st is state: + dict.pop(self, state.key, None) + self._manage_removed_state(state) + + def prune(self): + """prune unreferenced, non-dirty states.""" + + ref_count = len(self) + dirty = [s.obj() for s in self.all_states() if s.modified] + + # work around http://bugs.python.org/issue6149 + keepers = weakref.WeakValueDictionary() + keepers.update(self) + + dict.clear(self) + dict.update(self, keepers) + self.modified = bool(dirty) + return ref_count - len(self) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py new file mode 100644 index 00000000..68b4f061 --- /dev/null +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -0,0 +1,499 @@ +# orm/instrumentation.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines SQLAlchemy's system of class instrumentation. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +instrumentation.py deals with registration of end-user classes +for state tracking. It interacts closely with state.py +and attributes.py which establish per-instance and per-class-attribute +instrumentation, respectively. + +The class instrumentation system can be customized on a per-class +or global basis using the :mod:`sqlalchemy.ext.instrumentation` +module, which provides the means to build and specify +alternate instrumentation forms. + +.. versionchanged: 0.8 + The instrumentation extension system was moved out of the + ORM and into the external :mod:`sqlalchemy.ext.instrumentation` + package. When that package is imported, it installs + itself within sqlalchemy.orm so that its more comprehensive + resolution mechanics take effect. + +""" + + +from . import exc, collections, interfaces, state +from .. import util +from . import base + +class ClassManager(dict): + """tracks state information at the class level.""" + + MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR + STATE_ATTR = base.DEFAULT_STATE_ATTR + + deferred_scalar_loader = None + + original_init = object.__init__ + + factory = None + + def __init__(self, class_): + self.class_ = class_ + self.info = {} + self.new_init = None + self.local_attrs = {} + self.originals = {} + + self._bases = [mgr for mgr in [ + manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ] if mgr is not None] + + for base in self._bases: + self.update(base) + + self.dispatch._events._new_classmanager_instance(class_, self) + #events._InstanceEventsHold.populate(class_, self) + + for basecls in class_.__mro__: + mgr = manager_of_class(basecls) + if mgr is not None: + self.dispatch._update(mgr.dispatch) + self.manage() + self._instrument_init() + + if '__del__' in class_.__dict__: + util.warn("__del__() method on class %s will " + "cause unreachable cycles and memory leaks, " + "as SQLAlchemy instrumentation often creates " + "reference cycles. Please remove this method." % + class_) + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return other is self + + @property + def is_mapped(self): + return 'mapper' in self.__dict__ + + @util.memoized_property + def mapper(self): + # raises unless self.mapper has been assigned + raise exc.UnmappedClassError(self.class_) + + def _all_sqla_attributes(self, exclude=None): + """return an iterator of all classbound attributes that are + implement :class:`._InspectionAttr`. + + This includes :class:`.QueryableAttribute` as well as extension + types such as :class:`.hybrid_property` and :class:`.AssociationProxy`. + + """ + if exclude is None: + exclude = set() + for supercls in self.class_.__mro__: + for key in set(supercls.__dict__).difference(exclude): + exclude.add(key) + val = supercls.__dict__[key] + if isinstance(val, interfaces._InspectionAttr): + yield key, val + + + def _attr_has_impl(self, key): + """Return True if the given attribute is fully initialized. + + i.e. has an impl. + """ + + return key in self and self[key].impl is not None + + def _subclass_manager(self, cls): + """Create a new ClassManager for a subclass of this ClassManager's + class. + + This is called automatically when attributes are instrumented so that + the attributes can be propagated to subclasses against their own + class-local manager, without the need for mappers etc. to have already + pre-configured managers for the full class hierarchy. Mappers + can post-configure the auto-generated ClassManager when needed. + + """ + manager = manager_of_class(cls) + if manager is None: + manager = _instrumentation_factory.create_manager_for_cls(cls) + return manager + + def _instrument_init(self): + # TODO: self.class_.__init__ is often the already-instrumented + # __init__ from an instrumented superclass. We still need to make + # our own wrapper, but it would + # be nice to wrap the original __init__ and not our existing wrapper + # of such, since this adds method overhead. + self.original_init = self.class_.__init__ + self.new_init = _generate_init(self.class_, self) + self.install_member('__init__', self.new_init) + + def _uninstrument_init(self): + if self.new_init: + self.uninstall_member('__init__') + self.new_init = None + + @util.memoized_property + def _state_constructor(self): + self.dispatch.first_init(self, self.class_) + return state.InstanceState + + def manage(self): + """Mark this instance as the manager for its class.""" + + setattr(self.class_, self.MANAGER_ATTR, self) + + def dispose(self): + """Dissasociate this manager from its class.""" + + delattr(self.class_, self.MANAGER_ATTR) + + @util.hybridmethod + def manager_getter(self): + return _default_manager_getter + + @util.hybridmethod + def state_getter(self): + """Return a (instance) -> InstanceState callable. + + "state getter" callables should raise either KeyError or + AttributeError if no InstanceState could be found for the + instance. + """ + + return _default_state_getter + + @util.hybridmethod + def dict_getter(self): + return _default_dict_getter + + + def instrument_attribute(self, key, inst, propagated=False): + if propagated: + if key in self.local_attrs: + return # don't override local attr with inherited attr + else: + self.local_attrs[key] = inst + self.install_descriptor(key, inst) + self[key] = inst + + for cls in self.class_.__subclasses__(): + manager = self._subclass_manager(cls) + manager.instrument_attribute(key, inst, True) + + def subclass_managers(self, recursive): + for cls in self.class_.__subclasses__(): + mgr = manager_of_class(cls) + if mgr is not None and mgr is not self: + yield mgr + if recursive: + for m in mgr.subclass_managers(True): + yield m + + def post_configure_attribute(self, key): + _instrumentation_factory.dispatch.\ + attribute_instrument(self.class_, key, self[key]) + + def uninstrument_attribute(self, key, propagated=False): + if key not in self: + return + if propagated: + if key in self.local_attrs: + return # don't get rid of local attr + else: + del self.local_attrs[key] + self.uninstall_descriptor(key) + del self[key] + for cls in self.class_.__subclasses__(): + manager = manager_of_class(cls) + if manager: + manager.uninstrument_attribute(key, True) + + def unregister(self): + """remove all instrumentation established by this ClassManager.""" + + self._uninstrument_init() + + self.mapper = self.dispatch = None + self.info.clear() + + for key in list(self): + if key in self.local_attrs: + self.uninstrument_attribute(key) + + def install_descriptor(self, key, inst): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % + key) + setattr(self.class_, key, inst) + + def uninstall_descriptor(self, key): + delattr(self.class_, key) + + def install_member(self, key, implementation): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % + key) + self.originals.setdefault(key, getattr(self.class_, key, None)) + setattr(self.class_, key, implementation) + + def uninstall_member(self, key): + original = self.originals.pop(key, None) + if original is not None: + setattr(self.class_, key, original) + + def instrument_collection_class(self, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def initialize_collection(self, key, state, factory): + user_data = factory() + adapter = collections.CollectionAdapter( + self.get_impl(key), state, user_data) + return adapter, user_data + + def is_instrumented(self, key, search=False): + if search: + return key in self + else: + return key in self.local_attrs + + def get_impl(self, key): + return self[key].impl + + @property + def attributes(self): + return iter(self.values()) + + ## InstanceState management + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + setattr(instance, self.STATE_ATTR, + state or self._state_constructor(instance, self)) + return instance + + def setup_instance(self, instance, state=None): + setattr(instance, self.STATE_ATTR, + state or self._state_constructor(instance, self)) + + def teardown_instance(self, instance): + delattr(instance, self.STATE_ATTR) + + def _serialize(self, state, state_dict): + return _SerializeManager(state, state_dict) + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + + """ + if hasattr(instance, self.STATE_ATTR): + return False + elif self.class_ is not instance.__class__ and \ + self.is_mapped: + # this will create a new ClassManager for the + # subclass, without a mapper. This is likely a + # user error situation but allow the object + # to be constructed, so that it is usable + # in a non-ORM context at least. + return self._subclass_manager(instance.__class__).\ + _new_state_if_none(instance) + else: + state = self._state_constructor(instance, self) + setattr(instance, self.STATE_ATTR, state) + return state + + def has_state(self, instance): + return hasattr(instance, self.STATE_ATTR) + + def has_parent(self, state, key, optimistic=False): + """TODO""" + return self.get_impl(key).hasparent(state, optimistic=optimistic) + + def __bool__(self): + """All ClassManagers are non-zero regardless of attribute state.""" + return True + + __nonzero__ = __bool__ + + def __repr__(self): + return '<%s of %r at %x>' % ( + self.__class__.__name__, self.class_, id(self)) + +class _SerializeManager(object): + """Provide serialization of a :class:`.ClassManager`. + + The :class:`.InstanceState` uses ``__init__()`` on serialize + and ``__call__()`` on deserialize. + + """ + def __init__(self, state, d): + self.class_ = state.class_ + manager = state.manager + manager.dispatch.pickle(state, d) + + def __call__(self, state, inst, state_dict): + state.manager = manager = manager_of_class(self.class_) + if manager is None: + raise exc.UnmappedInstanceError( + inst, + "Cannot deserialize object of type %r - " + "no mapper() has " + "been configured for this class within the current " + "Python process!" % + self.class_) + elif manager.is_mapped and not manager.mapper.configured: + manager.mapper._configure_all() + + # setup _sa_instance_state ahead of time so that + # unpickle events can access the object normally. + # see [ticket:2362] + if inst is not None: + manager.setup_instance(inst, state) + manager.dispatch.unpickle(state, state_dict) + +class InstrumentationFactory(object): + """Factory for new ClassManager instances.""" + + def create_manager_for_cls(self, class_): + assert class_ is not None + assert manager_of_class(class_) is None + + # give a more complicated subclass + # a chance to do what it wants here + manager, factory = self._locate_extended_factory(class_) + + if factory is None: + factory = ClassManager + manager = factory(class_) + + self._check_conflicts(class_, factory) + + manager.factory = factory + + self.dispatch.class_instrument(class_) + return manager + + def _locate_extended_factory(self, class_): + """Overridden by a subclass to do an extended lookup.""" + return None, None + + def _check_conflicts(self, class_, factory): + """Overridden by a subclass to test for conflicting factories.""" + return + + def unregister(self, class_): + manager = manager_of_class(class_) + manager.unregister() + manager.dispose() + self.dispatch.class_uninstrument(class_) + if ClassManager.MANAGER_ATTR in class_.__dict__: + delattr(class_, ClassManager.MANAGER_ATTR) + +# this attribute is replaced by sqlalchemy.ext.instrumentation +# when importred. +_instrumentation_factory = InstrumentationFactory() + +# these attributes are replaced by sqlalchemy.ext.instrumentation +# when a non-standard InstrumentationManager class is first +# used to instrument a class. +instance_state = _default_state_getter = base.instance_state + +instance_dict = _default_dict_getter = base.instance_dict + +manager_of_class = _default_manager_getter = base.manager_of_class + +def register_class(class_): + """Register class instrumentation. + + Returns the existing or newly created class manager. + + """ + + manager = manager_of_class(class_) + if manager is None: + manager = _instrumentation_factory.create_manager_for_cls(class_) + return manager + + +def unregister_class(class_): + """Unregister class instrumentation.""" + + _instrumentation_factory.unregister(class_) + + +def is_instrumented(instance, key): + """Return True if the given attribute on the given instance is + instrumented by the attributes package. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + + """ + return manager_of_class(instance.__class__).\ + is_instrumented(key, search=True) + + +def _generate_init(class_, class_manager): + """Build an __init__ decorator that triggers ClassManager events.""" + + # TODO: we should use the ClassManager's notion of the + # original '__init__' method, once ClassManager is fixed + # to always reference that. + original__init__ = class_.__init__ + assert original__init__ + + # Go through some effort here and don't change the user's __init__ + # calling signature, including the unlikely case that it has + # a return value. + # FIXME: need to juggle local names to avoid constructor argument + # clashes. + func_body = """\ +def __init__(%(apply_pos)s): + new_state = class_manager._new_state_if_none(%(self_arg)s) + if new_state: + return new_state._initialize_instance(%(apply_kw)s) + else: + return original__init__(%(apply_kw)s) +""" + func_vars = util.format_argspec_init(original__init__, grouped=False) + func_text = func_body % func_vars + + if util.py2k: + func = getattr(original__init__, 'im_func', original__init__) + func_defaults = getattr(func, 'func_defaults', None) + else: + func_defaults = getattr(original__init__, '__defaults__', None) + func_kw_defaults = getattr(original__init__, '__kwdefaults__', None) + + env = locals().copy() + exec(func_text, env) + __init__ = env['__init__'] + __init__.__doc__ = original__init__.__doc__ + + if func_defaults: + __init__.__defaults__ = func_defaults + if not util.py2k and func_kw_defaults: + __init__.__kwdefaults__ = func_kw_defaults + + return __init__ diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py new file mode 100644 index 00000000..1b0bf48a --- /dev/null +++ b/lib/sqlalchemy/orm/interfaces.py @@ -0,0 +1,577 @@ +# orm/interfaces.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +Contains various base classes used throughout the ORM. + +Defines the now deprecated ORM extension classes as well +as ORM internals. + +Other than the deprecated extensions, this module and the +classes within should be considered mostly private. + +""" + +from __future__ import absolute_import + +from .. import exc as sa_exc, util, inspect +from ..sql import operators +from collections import deque +from .base import ONETOMANY, MANYTOONE, MANYTOMANY, EXT_CONTINUE, EXT_STOP, NOT_EXTENSION +from .base import _InspectionAttr, _MappedAttribute +from .path_registry import PathRegistry +import collections + + +__all__ = ( + 'AttributeExtension', + 'EXT_CONTINUE', + 'EXT_STOP', + 'ONETOMANY', + 'MANYTOMANY', + 'MANYTOONE', + 'NOT_EXTENSION', + 'LoaderStrategy', + 'MapperExtension', + 'MapperOption', + 'MapperProperty', + 'PropComparator', + 'SessionExtension', + 'StrategizedProperty', + ) + + + +class MapperProperty(_MappedAttribute, _InspectionAttr): + """Manage the relationship of a ``Mapper`` to a single class + attribute, as well as that attribute as it appears on individual + instances of the class, including attribute instrumentation, + attribute access, loading behavior, and dependency calculations. + + The most common occurrences of :class:`.MapperProperty` are the + mapped :class:`.Column`, which is represented in a mapping as + an instance of :class:`.ColumnProperty`, + and a reference to another class produced by :func:`.relationship`, + represented in the mapping as an instance of + :class:`.RelationshipProperty`. + + """ + + cascade = frozenset() + """The set of 'cascade' attribute names. + + This collection is checked before the 'cascade_iterator' method is called. + + """ + + is_property = True + + def setup(self, context, entity, path, adapter, **kwargs): + """Called by Query for the purposes of constructing a SQL statement. + + Each MapperProperty associated with the target mapper processes the + statement referenced by the query context, adding columns and/or + criterion as appropriate. + """ + + pass + + def create_row_processor(self, context, path, + mapper, row, adapter): + """Return a 3-tuple consisting of three row processing functions. + + """ + return None, None, None + + def cascade_iterator(self, type_, state, visited_instances=None, + halt_on=None): + """Iterate through instances related to the given instance for + a particular 'cascade', starting with this MapperProperty. + + Return an iterator3-tuples (instance, mapper, state). + + Note that the 'cascade' collection on this MapperProperty is + checked first for the given type before cascade_iterator is called. + + See PropertyLoader for the related instance implementation. + """ + + return iter(()) + + def set_parent(self, parent, init): + self.parent = parent + + def instrument_class(self, mapper): # pragma: no-coverage + raise NotImplementedError() + + @util.memoized_property + def info(self): + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.MapperProperty`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`.relationship`, or :func:`.composite` + functions. + + .. versionadded:: 0.8 Added support for .info to all + :class:`.MapperProperty` subclasses. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + + _configure_started = False + _configure_finished = False + + def init(self): + """Called after all mappers are created to assemble + relationships between mappers and perform other post-mapper-creation + initialization steps. + + """ + self._configure_started = True + self.do_init() + self._configure_finished = True + + @property + def class_attribute(self): + """Return the class-bound descriptor corresponding to this + :class:`.MapperProperty`. + + This is basically a ``getattr()`` call:: + + return getattr(self.parent.class_, self.key) + + I.e. if this :class:`.MapperProperty` were named ``addresses``, + and the class to which it is mapped is ``User``, this sequence + is possible:: + + >>> from sqlalchemy import inspect + >>> mapper = inspect(User) + >>> addresses_property = mapper.attrs.addresses + >>> addresses_property.class_attribute is User.addresses + True + >>> User.addresses.property is addresses_property + True + + + """ + + return getattr(self.parent.class_, self.key) + + def do_init(self): + """Perform subclass-specific initialization post-mapper-creation + steps. + + This is a template method called by the ``MapperProperty`` + object's init() method. + + """ + + pass + + def post_instrument_class(self, mapper): + """Perform instrumentation adjustments that need to occur + after init() has completed. + + """ + pass + + def is_primary(self): + """Return True if this ``MapperProperty``'s mapper is the + primary mapper for its class. + + This flag is used to indicate that the ``MapperProperty`` can + define attribute instrumentation for the class at the class + level (as opposed to the individual instance level). + """ + + return not self.parent.non_primary + + def merge(self, session, source_state, source_dict, dest_state, + dest_dict, load, _recursive): + """Merge the attribute represented by this ``MapperProperty`` + from source to destination object""" + + pass + + def compare(self, operator, value, **kw): + """Return a compare operation for the columns represented by + this ``MapperProperty`` to the given value, which may be a + column value or an instance. 'operator' is an operator from + the operators module, or from sql.Comparator. + + By default uses the PropComparator attached to this MapperProperty + under the attribute name "comparator". + """ + + return operator(self.comparator, value) + + def __repr__(self): + return '<%s at 0x%x; %s>' % ( + self.__class__.__name__, + id(self), getattr(self, 'key', 'no key')) + +class PropComparator(operators.ColumnOperators): + """Defines boolean, comparison, and other operators for + :class:`.MapperProperty` objects. + + SQLAlchemy allows for operators to + be redefined at both the Core and ORM level. :class:`.PropComparator` + is the base class of operator redefinition for ORM-level operations, + including those of :class:`.ColumnProperty`, + :class:`.RelationshipProperty`, and :class:`.CompositeProperty`. + + .. note:: With the advent of Hybrid properties introduced in SQLAlchemy + 0.7, as well as Core-level operator redefinition in + SQLAlchemy 0.8, the use case for user-defined :class:`.PropComparator` + instances is extremely rare. See :ref:`hybrids_toplevel` as well + as :ref:`types_operators`. + + User-defined subclasses of :class:`.PropComparator` may be created. The + built-in Python comparison and math operator methods, such as + :meth:`.operators.ColumnOperators.__eq__`, + :meth:`.operators.ColumnOperators.__lt__`, and + :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide + new operator behavior. The custom :class:`.PropComparator` is passed to + the :class:`.MapperProperty` instance via the ``comparator_factory`` + argument. In each case, + the appropriate subclass of :class:`.PropComparator` should be used:: + + # definition of custom PropComparator subclasses + + from sqlalchemy.orm.properties import \\ + ColumnProperty,\\ + CompositeProperty,\\ + RelationshipProperty + + class MyColumnComparator(ColumnProperty.Comparator): + def __eq__(self, other): + return self.__clause_element__() == other + + class MyRelationshipComparator(RelationshipProperty.Comparator): + def any(self, expression): + "define the 'any' operation" + # ... + + class MyCompositeComparator(CompositeProperty.Comparator): + def __gt__(self, other): + "redefine the 'greater than' operation" + + return sql.and_(*[a>b for a, b in + zip(self.__clause_element__().clauses, + other.__composite_values__())]) + + + # application of custom PropComparator subclasses + + from sqlalchemy.orm import column_property, relationship, composite + from sqlalchemy import Column, String + + class SomeMappedClass(Base): + some_column = column_property(Column("some_column", String), + comparator_factory=MyColumnComparator) + + some_relationship = relationship(SomeOtherClass, + comparator_factory=MyRelationshipComparator) + + some_composite = composite( + Column("a", String), Column("b", String), + comparator_factory=MyCompositeComparator + ) + + Note that for column-level operator redefinition, it's usually + simpler to define the operators at the Core level, using the + :attr:`.TypeEngine.comparator_factory` attribute. See + :ref:`types_operators` for more detail. + + See also: + + :class:`.ColumnProperty.Comparator` + + :class:`.RelationshipProperty.Comparator` + + :class:`.CompositeProperty.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + def __init__(self, prop, parentmapper, adapt_to_entity=None): + self.prop = self.property = prop + self._parentmapper = parentmapper + self._adapt_to_entity = adapt_to_entity + + def __clause_element__(self): + raise NotImplementedError("%r" % self) + + def _query_clause_element(self): + return self.__clause_element__() + + def adapt_to_entity(self, adapt_to_entity): + """Return a copy of this PropComparator which will use the given + :class:`.AliasedInsp` to produce corresponding expressions. + """ + return self.__class__(self.prop, self._parentmapper, adapt_to_entity) + + @property + def adapter(self): + """Produce a callable that adapts column expressions + to suit an aliased version of this comparator. + + """ + if self._adapt_to_entity is None: + return None + else: + return self._adapt_to_entity._adapt_element + + @util.memoized_property + def info(self): + return self.property.info + + @staticmethod + def any_op(a, b, **kwargs): + return a.any(b, **kwargs) + + @staticmethod + def has_op(a, b, **kwargs): + return a.has(b, **kwargs) + + @staticmethod + def of_type_op(a, class_): + return a.of_type(class_) + + def of_type(self, class_): + """Redefine this object in terms of a polymorphic subclass. + + Returns a new PropComparator from which further criterion can be + evaluated. + + e.g.:: + + query.join(Company.employees.of_type(Engineer)).\\ + filter(Engineer.name=='foo') + + :param \class_: a class or mapper indicating that criterion will be + against this specific subclass. + + + """ + + return self.operate(PropComparator.of_type_op, class_) + + def any(self, criterion=None, **kwargs): + """Return true if this collection contains any member that meets the + given criterion. + + The usual implementation of ``any()`` is + :meth:`.RelationshipProperty.Comparator.any`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. + + """ + + return self.operate(PropComparator.any_op, criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + """Return true if this element references a member which meets the + given criterion. + + The usual implementation of ``has()`` is + :meth:`.RelationshipProperty.Comparator.has`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. + + """ + + return self.operate(PropComparator.has_op, criterion, **kwargs) + + +class StrategizedProperty(MapperProperty): + """A MapperProperty which uses selectable strategies to affect + loading behavior. + + There is a single strategy selected by default. Alternate + strategies can be selected at Query time through the usage of + ``StrategizedOption`` objects via the Query.options() method. + + """ + + strategy_wildcard_key = None + + def _get_context_loader(self, context, path): + load = None + + # use EntityRegistry.__getitem__()->PropRegistry here so + # that the path is stated in terms of our base + search_path = dict.__getitem__(path, self) + + # search among: exact match, "attr.*", "default" strategy + # if any. + for path_key in ( + search_path._loader_key, + search_path._wildcard_path_loader_key, + search_path._default_path_loader_key + ): + if path_key in context.attributes: + load = context.attributes[path_key] + break + + return load + + def _get_strategy(self, key): + try: + return self._strategies[key] + except KeyError: + cls = self._strategy_lookup(*key) + self._strategies[key] = self._strategies[cls] = strategy = cls(self) + return strategy + + def _get_strategy_by_cls(self, cls): + return self._get_strategy(cls._strategy_keys[0]) + + def setup(self, context, entity, path, adapter, **kwargs): + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.setup_query(context, entity, path, loader, adapter, **kwargs) + + def create_row_processor(self, context, path, mapper, row, adapter): + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + return strat.create_row_processor(context, path, loader, + mapper, row, adapter) + + def do_init(self): + self._strategies = {} + self.strategy = self._get_strategy_by_cls(self.strategy_class) + + def post_instrument_class(self, mapper): + if self.is_primary() and \ + not mapper.class_manager._attr_has_impl(self.key): + self.strategy.init_class_attribute(mapper) + + + _strategies = collections.defaultdict(dict) + + @classmethod + def strategy_for(cls, **kw): + def decorate(dec_cls): + dec_cls._strategy_keys = [] + key = tuple(sorted(kw.items())) + cls._strategies[cls][key] = dec_cls + dec_cls._strategy_keys.append(key) + return dec_cls + return decorate + + @classmethod + def _strategy_lookup(cls, *key): + for prop_cls in cls.__mro__: + if prop_cls in cls._strategies: + strategies = cls._strategies[prop_cls] + try: + return strategies[key] + except KeyError: + pass + raise Exception("can't locate strategy for %s %s" % (cls, key)) + + +class MapperOption(object): + """Describe a modification to a Query.""" + + propagate_to_loaders = False + """if True, indicate this option should be carried along + Query object generated by scalar or object lazy loaders. + """ + + def process_query(self, query): + pass + + def process_query_conditionally(self, query): + """same as process_query(), except that this option may not + apply to the given query. + + Used when secondary loaders resend existing options to a new + Query.""" + + self.process_query(query) + + + + +class LoaderStrategy(object): + """Describe the loading behavior of a StrategizedProperty object. + + The ``LoaderStrategy`` interacts with the querying process in three + ways: + + * it controls the configuration of the ``InstrumentedAttribute`` + placed on a class to handle the behavior of the attribute. this + may involve setting up class-level callable functions to fire + off a select operation when the attribute is first accessed + (i.e. a lazy load) + + * it processes the ``QueryContext`` at statement construction time, + where it can modify the SQL statement that is being produced. + Simple column attributes may add their represented column to the + list of selected columns, *eager loading* properties may add + ``LEFT OUTER JOIN`` clauses to the statement. + + * It produces "row processor" functions at result fetching time. + These "row processor" functions populate a particular attribute + on a particular mapped instance. + + """ + def __init__(self, parent): + self.parent_property = parent + self.is_class_level = False + self.parent = self.parent_property.parent + self.key = self.parent_property.key + + def init_class_attribute(self, mapper): + pass + + def setup_query(self, context, entity, path, loadopt, adapter, **kwargs): + pass + + def create_row_processor(self, context, path, loadopt, mapper, + row, adapter): + """Return row processing functions which fulfill the contract + specified by MapperProperty.create_row_processor. + + StrategizedProperty delegates its create_row_processor method + directly to this method. """ + + return None, None, None + + def __str__(self): + return str(self.parent_property) diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py new file mode 100644 index 00000000..b79ea429 --- /dev/null +++ b/lib/sqlalchemy/orm/loading.py @@ -0,0 +1,611 @@ +# orm/loading.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to convert database +rows into object instances and associated state. + +the functions here are called primarily by Query, Mapper, +as well as some of the attribute loading strategies. + +""" + + +from .. import util +from . import attributes, exc as orm_exc, state as statelib +from .interfaces import EXT_CONTINUE +from ..sql import util as sql_util +from .util import _none_set, state_str +from .. import exc as sa_exc + +_new_runid = util.counter() + + +def instances(query, cursor, context): + """Return an ORM result as an iterator.""" + session = query.session + + context.runid = _new_runid() + + filter_fns = [ent.filter_fn + for ent in query._entities] + filtered = id in filter_fns + + single_entity = len(query._entities) == 1 and \ + query._entities[0].supports_single_entity + + if filtered: + if single_entity: + filter_fn = id + else: + def filter_fn(row): + return tuple(fn(x) for x, fn in zip(row, filter_fns)) + + custom_rows = single_entity and \ + query._entities[0].custom_rows + + (process, labels) = \ + list(zip(*[ + query_entity.row_processor(query, + context, custom_rows) + for query_entity in query._entities + ])) + + while True: + context.progress = {} + context.partials = {} + + if query._yield_per: + fetch = cursor.fetchmany(query._yield_per) + if not fetch: + break + else: + fetch = cursor.fetchall() + + if custom_rows: + rows = [] + for row in fetch: + process[0](row, rows) + elif single_entity: + rows = [process[0](row, None) for row in fetch] + else: + rows = [util.KeyedTuple([proc(row, None) for proc in process], + labels) for row in fetch] + + if filtered: + rows = util.unique_list(rows, filter_fn) + + if context.refresh_state and query._only_load_props \ + and context.refresh_state in context.progress: + context.refresh_state._commit( + context.refresh_state.dict, query._only_load_props) + context.progress.pop(context.refresh_state) + + statelib.InstanceState._commit_all_states( + list(context.progress.items()), + session.identity_map + ) + + for state, (dict_, attrs) in context.partials.items(): + state._commit(dict_, attrs) + + for row in rows: + yield row + + if not query._yield_per: + break + + +@util.dependencies("sqlalchemy.orm.query") +def merge_result(querylib, query, iterator, load=True): + """Merge a result into this :class:`.Query` object's Session.""" + + session = query.session + if load: + # flush current contents if we expect to load data + session._autoflush() + + autoflush = session.autoflush + try: + session.autoflush = False + single_entity = len(query._entities) == 1 + if single_entity: + if isinstance(query._entities[0], querylib._MapperEntity): + result = [session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive={}) + for instance in iterator] + else: + result = list(iterator) + else: + mapped_entities = [i for i, e in enumerate(query._entities) + if isinstance(e, querylib._MapperEntity)] + result = [] + keys = [ent._label_name for ent in query._entities] + for row in iterator: + newrow = list(row) + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, _recursive={}) + result.append(util.KeyedTuple(newrow, keys)) + + return iter(result) + finally: + session.autoflush = autoflush + + +def get_from_identity(session, key, passive): + """Look up the given key in the given session's identity map, + check the object for expired state if found. + + """ + instance = session.identity_map.get(key) + if instance is not None: + + state = attributes.instance_state(instance) + + # expired - ensure it still exists + if state.expired: + if not passive & attributes.SQL_OK: + # TODO: no coverage here + return attributes.PASSIVE_NO_RESULT + elif not passive & attributes.RELATED_OBJECT_OK: + # this mode is used within a flush and the instance's + # expired state will be checked soon enough, if necessary + return instance + try: + state(state, passive) + except orm_exc.ObjectDeletedError: + session._remove_newly_deleted([state]) + return None + return instance + else: + return None + + +def load_on_ident(query, key, + refresh_state=None, lockmode=None, + only_load_props=None): + """Load the given identity key from the database.""" + + if key is not None: + ident = key[1] + else: + ident = None + + if refresh_state is None: + q = query._clone() + q._get_condition() + else: + q = query._clone() + + if ident is not None: + mapper = query._mapper_zero() + + (_get_clause, _get_params) = mapper._get_clause + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in ident: + nones = set([ + _get_params[col].key for col, value in + zip(mapper.primary_key, ident) if value is None + ]) + _get_clause = sql_util.adapt_criterion_to_null( + _get_clause, nones) + + _get_clause = q._adapt_clause(_get_clause, True, False) + q._criterion = _get_clause + + params = dict([ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip(ident, mapper.primary_key) + ]) + + q._params = params + + if lockmode is not None: + version_check = True + q = q.with_lockmode(lockmode) + elif query._for_update_arg is not None: + version_check = True + q._for_update_arg = query._for_update_arg + else: + version_check = False + + q._get_options( + populate_existing=bool(refresh_state), + version_check=version_check, + only_load_props=only_load_props, + refresh_state=refresh_state) + q._order_by = None + + try: + return q.one() + except orm_exc.NoResultFound: + return None + + +def instance_processor(mapper, context, path, adapter, + polymorphic_from=None, + only_load_props=None, + refresh_state=None, + polymorphic_discriminator=None): + + """Produce a mapper level row processor callable + which processes rows into mapped instances.""" + + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. + + pk_cols = mapper.primary_key + + if polymorphic_from or refresh_state: + polymorphic_on = None + else: + if polymorphic_discriminator is not None: + polymorphic_on = polymorphic_discriminator + else: + polymorphic_on = mapper.polymorphic_on + polymorphic_instances = util.PopulateDict( + _configure_subclass_mapper( + mapper, + context, path, adapter) + ) + + version_id_col = mapper.version_id_col + + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + if polymorphic_on is not None: + polymorphic_on = adapter.columns[polymorphic_on] + if version_id_col is not None: + version_id_col = adapter.columns[version_id_col] + + identity_class = mapper._identity_class + + new_populators = [] + existing_populators = [] + eager_populators = [] + + load_path = context.query._current_path + path \ + if context.query._current_path.path \ + else path + + def populate_state(state, dict_, row, isnew, only_load_props): + if isnew: + if context.propagate_options: + state.load_options = context.propagate_options + if state.load_options: + state.load_path = load_path + + if not new_populators: + _populators(mapper, context, path, row, adapter, + new_populators, + existing_populators, + eager_populators + ) + + if isnew: + populators = new_populators + else: + populators = existing_populators + + if only_load_props is None: + for key, populator in populators: + populator(state, dict_, row) + elif only_load_props: + for key, populator in populators: + if key in only_load_props: + populator(state, dict_, row) + + session_identity_map = context.session.identity_map + + listeners = mapper.dispatch + + translate_row = listeners.translate_row or None + create_instance = listeners.create_instance or None + populate_instance = listeners.populate_instance or None + append_result = listeners.append_result or None + populate_existing = context.populate_existing or mapper.always_refresh + invoke_all_eagers = context.invoke_all_eagers + + if mapper.allow_partial_pks: + is_not_primary_key = _none_set.issuperset + else: + is_not_primary_key = _none_set.issubset + + def _instance(row, result): + if not new_populators and invoke_all_eagers: + _populators(mapper, context, path, row, adapter, + new_populators, + existing_populators, + eager_populators + ) + + if translate_row: + for fn in translate_row: + ret = fn(mapper, context, row) + if ret is not EXT_CONTINUE: + row = ret + break + + if polymorphic_on is not None: + discriminator = row[polymorphic_on] + if discriminator is not None: + _instance = polymorphic_instances[discriminator] + if _instance: + return _instance(row, result) + + # determine identity key + if refresh_state: + identitykey = refresh_state.key + if identitykey is None: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur within a flush() + identitykey = mapper._identity_key_from_state(refresh_state) + else: + identitykey = ( + identity_class, + tuple([row[column] for column in pk_cols]) + ) + + instance = session_identity_map.get(identitykey) + + if instance is not None: + state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) + + isnew = state.runid != context.runid + currentload = not isnew + loaded_instance = False + + if not currentload and \ + version_id_col is not None and \ + context.version_check and \ + mapper._get_state_attr_by_column( + state, + dict_, + mapper.version_id_col) != \ + row[version_id_col]: + + raise orm_exc.StaleDataError( + "Instance '%s' has version id '%s' which " + "does not match database-loaded version id '%s'." + % (state_str(state), + mapper._get_state_attr_by_column( + state, dict_, + mapper.version_id_col), + row[version_id_col])) + elif refresh_state: + # out of band refresh_state detected (i.e. its not in the + # session.identity_map) honor it anyway. this can happen + # if a _get() occurs within save_obj(), such as + # when eager_defaults is True. + state = refresh_state + instance = state.obj() + dict_ = attributes.instance_dict(instance) + isnew = state.runid != context.runid + currentload = True + loaded_instance = False + else: + # check for non-NULL values in the primary key columns, + # else no entity is returned for the row + if is_not_primary_key(identitykey[1]): + return None + + isnew = True + currentload = True + loaded_instance = True + + if create_instance: + for fn in create_instance: + instance = fn(mapper, context, + row, mapper.class_) + if instance is not EXT_CONTINUE: + manager = attributes.manager_of_class( + instance.__class__) + # TODO: if manager is None, raise a friendly error + # about returning instances of unmapped types + manager.setup_instance(instance) + break + else: + instance = mapper.class_manager.new_instance() + else: + instance = mapper.class_manager.new_instance() + + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + state.key = identitykey + + # attach instance to session. + state.session_id = context.session.hash_key + session_identity_map.add(state) + + if currentload or populate_existing: + # state is being fully loaded, so populate. + # add to the "context.progress" collection. + if isnew: + state.runid = context.runid + context.progress[state] = dict_ + + if populate_instance: + for fn in populate_instance: + ret = fn(mapper, context, row, state, + only_load_props=only_load_props, + instancekey=identitykey, isnew=isnew) + if ret is not EXT_CONTINUE: + break + else: + populate_state(state, dict_, row, isnew, only_load_props) + else: + populate_state(state, dict_, row, isnew, only_load_props) + + if loaded_instance: + state.manager.dispatch.load(state, context) + elif isnew: + state.manager.dispatch.refresh(state, context, only_load_props) + + elif state in context.partials or state.unloaded or eager_populators: + # state is having a partial set of its attributes + # refreshed. Populate those attributes, + # and add to the "context.partials" collection. + if state in context.partials: + isnew = False + (d_, attrs) = context.partials[state] + else: + isnew = True + attrs = state.unloaded + context.partials[state] = (dict_, attrs) + + if populate_instance: + for fn in populate_instance: + ret = fn(mapper, context, row, state, + only_load_props=attrs, + instancekey=identitykey, isnew=isnew) + if ret is not EXT_CONTINUE: + break + else: + populate_state(state, dict_, row, isnew, attrs) + else: + populate_state(state, dict_, row, isnew, attrs) + + for key, pop in eager_populators: + if key not in state.unloaded: + pop(state, dict_, row) + + if isnew: + state.manager.dispatch.refresh(state, context, attrs) + + if result is not None: + if append_result: + for fn in append_result: + if fn(mapper, context, row, state, + result, instancekey=identitykey, + isnew=isnew) is not EXT_CONTINUE: + break + else: + result.append(instance) + else: + result.append(instance) + + return instance + return _instance + + +def _populators(mapper, context, path, row, adapter, + new_populators, existing_populators, eager_populators): + """Produce a collection of attribute level row processor + callables.""" + + delayed_populators = [] + pops = (new_populators, existing_populators, delayed_populators, + eager_populators) + + for prop in mapper._props.values(): + + for i, pop in enumerate(prop.create_row_processor( + context, + path, + mapper, row, adapter)): + if pop is not None: + pops[i].append((prop.key, pop)) + + if delayed_populators: + new_populators.extend(delayed_populators) + + +def _configure_subclass_mapper(mapper, context, path, adapter): + """Produce a mapper level row processor callable factory for mappers + inheriting this one.""" + + def configure_subclass_mapper(discriminator): + try: + sub_mapper = mapper.polymorphic_map[discriminator] + except KeyError: + raise AssertionError( + "No such polymorphic_identity %r is defined" % + discriminator) + if sub_mapper is mapper: + return None + + return instance_processor( + sub_mapper, + context, + path, + adapter, + polymorphic_from=mapper) + return configure_subclass_mapper + + +def load_scalar_attributes(mapper, state, attribute_names): + """initiate a column-based attribute refresh operation.""" + + #assert mapper is _state_mapper(state) + session = state.session + if not session: + raise orm_exc.DetachedInstanceError( + "Instance %s is not bound to a Session; " + "attribute refresh operation cannot proceed" % + (state_str(state))) + + has_key = bool(state.key) + + result = False + + if mapper.inherits and not mapper.concrete: + statement = mapper._optimized_get_statement(state, attribute_names) + if statement is not None: + result = load_on_ident( + session.query(mapper).from_statement(statement), + None, + only_load_props=attribute_names, + refresh_state=state + ) + + if result is False: + if has_key: + identity_key = state.key + else: + # this codepath is rare - only valid when inside a flush, and the + # object is becoming persistent but hasn't yet been assigned + # an identity_key. + # check here to ensure we have the attrs we need. + pk_attrs = [mapper._columntoproperty[col].key + for col in mapper.primary_key] + if state.expired_attributes.intersection(pk_attrs): + raise sa_exc.InvalidRequestError( + "Instance %s cannot be refreshed - it's not " + " persistent and does not " + "contain a full primary key." % state_str(state)) + identity_key = mapper._identity_key_from_state(state) + + if (_none_set.issubset(identity_key) and \ + not mapper.allow_partial_pks) or \ + _none_set.issuperset(identity_key): + util.warn("Instance %s to be refreshed doesn't " + "contain a full primary key - can't be refreshed " + "(and shouldn't be expired, either)." + % state_str(state)) + return + + result = load_on_ident( + session.query(mapper), + identity_key, + refresh_state=state, + only_load_props=attribute_names) + + # if instance is pending, a refresh operation + # may not complete (even if PK attributes are assigned) + if has_key and result is None: + raise orm_exc.ObjectDeletedError(state) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py new file mode 100644 index 00000000..a939cb9c --- /dev/null +++ b/lib/sqlalchemy/orm/mapper.py @@ -0,0 +1,2709 @@ +# orm/mapper.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Logic to map Python classes to and from selectables. + +Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central +configurational unit which associates a class with a database table. + +This is a semi-private module; the main configurational API of the ORM is +available in :class:`~sqlalchemy.orm.`. + +""" +from __future__ import absolute_import + +import types +import weakref +from itertools import chain +from collections import deque + +from .. import sql, util, log, exc as sa_exc, event, schema, inspection +from ..sql import expression, visitors, operators, util as sql_util +from . import instrumentation, attributes, exc as orm_exc, loading +from . import properties +from .interfaces import MapperProperty, _InspectionAttr, _MappedAttribute + +from .base import _class_to_mapper, _state_mapper, class_mapper, \ + state_str, _INSTRUMENTOR +from .path_registry import PathRegistry + +import sys + + +_mapper_registry = weakref.WeakKeyDictionary() +_already_compiling = False + +_memoized_configured_property = util.group_expirable_memoized_property() + + +# a constant returned by _get_attr_by_column to indicate +# this mapper is not handling an attribute for a particular +# column +NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') + +# lock used to synchronize the "mapper configure" step +_CONFIGURE_MUTEX = util.threading.RLock() + + +@inspection._self_inspects +@log.class_logger +class Mapper(_InspectionAttr): + """Define the correlation of class attributes to database table + columns. + + The :class:`.Mapper` object is instantiated using the + :func:`~sqlalchemy.orm.mapper` function. For information + about instantiating new :class:`.Mapper` objects, see + that function's documentation. + + + When :func:`.mapper` is used + explicitly to link a user defined class with table + metadata, this is referred to as *classical mapping*. + Modern SQLAlchemy usage tends to favor the + :mod:`sqlalchemy.ext.declarative` extension for class + configuration, which + makes usage of :func:`.mapper` behind the scenes. + + Given a particular class known to be mapped by the ORM, + the :class:`.Mapper` which maintains it can be acquired + using the :func:`.inspect` function:: + + from sqlalchemy import inspect + + mapper = inspect(MyClass) + + A class which was mapped by the :mod:`sqlalchemy.ext.declarative` + extension will also have its mapper available via the ``__mapper__`` + attribute. + + + """ + + _new_mappers = False + + def __init__(self, + class_, + local_table=None, + properties=None, + primary_key=None, + non_primary=False, + inherits=None, + inherit_condition=None, + inherit_foreign_keys=None, + extension=None, + order_by=False, + always_refresh=False, + version_id_col=None, + version_id_generator=None, + polymorphic_on=None, + _polymorphic_map=None, + polymorphic_identity=None, + concrete=False, + with_polymorphic=None, + allow_partial_pks=True, + batch=True, + column_prefix=None, + include_properties=None, + exclude_properties=None, + passive_updates=True, + confirm_deleted_rows=True, + eager_defaults=False, + legacy_is_orphan=False, + _compiled_cache_size=100, + ): + """Return a new :class:`~.Mapper` object. + + This function is typically used behind the scenes + via the Declarative extension. When using Declarative, + many of the usual :func:`.mapper` arguments are handled + by the Declarative extension itself, including ``class_``, + ``local_table``, ``properties``, and ``inherits``. + Other options are passed to :func:`.mapper` using + the ``__mapper_args__`` class variable:: + + class MyClass(Base): + __tablename__ = 'my_table' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + alt = Column("some_alt", Integer) + + __mapper_args__ = { + 'polymorphic_on' : type + } + + + Explicit use of :func:`.mapper` + is often referred to as *classical mapping*. The above + declarative example is equivalent in classical form to:: + + my_table = Table("my_table", metadata, + Column('id', Integer, primary_key=True), + Column('type', String(50)), + Column("some_alt", Integer) + ) + + class MyClass(object): + pass + + mapper(MyClass, my_table, + polymorphic_on=my_table.c.type, + properties={ + 'alt':my_table.c.some_alt + }) + + .. seealso:: + + :ref:`classical_mapping` - discussion of direct usage of + :func:`.mapper` + + :param class\_: The class to be mapped. When using Declarative, + this argument is automatically passed as the declared class + itself. + + :param local_table: The :class:`.Table` or other selectable + to which the class is mapped. May be ``None`` if + this mapper inherits from another mapper using single-table + inheritance. When using Declarative, this argument is + automatically passed by the extension, based on what + is configured via the ``__table__`` argument or via the + :class:`.Table` produced as a result of the ``__tablename__`` + and :class:`.Column` arguments present. + + :param always_refresh: If True, all query operations for this mapped + class will overwrite all data within object instances that already + exist within the session, erasing any in-memory changes with + whatever information was loaded from the database. Usage of this + flag is highly discouraged; as an alternative, see the method + :meth:`.Query.populate_existing`. + + :param allow_partial_pks: Defaults to True. Indicates that a + composite primary key with some NULL values should be considered as + possibly existing within the database. This affects whether a + mapper will assign an incoming row to an existing identity, as well + as if :meth:`.Session.merge` will check the database first for a + particular primary key value. A "partial primary key" can occur if + one has mapped to an OUTER JOIN, for example. + + :param batch: Defaults to ``True``, indicating that save operations + of multiple entities can be batched together for efficiency. + Setting to False indicates + that an instance will be fully saved before saving the next + instance. This is used in the extremely rare case that a + :class:`.MapperEvents` listener requires being called + in between individual row persistence operations. + + :param column_prefix: A string which will be prepended + to the mapped attribute name when :class:`.Column` + objects are automatically assigned as attributes to the + mapped class. Does not affect explicitly specified + column-based properties. + + See the section :ref:`column_prefix` for an example. + + :param concrete: If True, indicates this mapper should use concrete + table inheritance with its parent mapper. + + See the section :ref:`concrete_inheritance` for an example. + + :param confirm_deleted_rows: defaults to True; when a DELETE occurs + of one more more rows based on specific primary keys, a warning is + emitted when the number of rows matched does not equal the number + of rows expected. This parameter may be set to False to handle the case + where database ON DELETE CASCADE rules may be deleting some of those + rows automatically. The warning may be changed to an exception + in a future release. + + .. versionadded:: 0.9.4 - added :paramref:`.mapper.confirm_deleted_rows` + as well as conditional matched row checking on delete. + + :param eager_defaults: if True, the ORM will immediately fetch the + value of server-generated default values after an INSERT or UPDATE, + rather than leaving them as expired to be fetched on next access. + This can be used for event schemes where the server-generated values + are needed immediately before the flush completes. By default, + this scheme will emit an individual ``SELECT`` statement per row + inserted or updated, which note can add significant performance + overhead. However, if the + target database supports :term:`RETURNING`, the default values will be + returned inline with the INSERT or UPDATE statement, which can + greatly enhance performance for an application that needs frequent + access to just-generated server defaults. + + .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now + make use of :term:`RETURNING` for backends which support it. + + :param exclude_properties: A list or set of string column names to + be excluded from mapping. + + See :ref:`include_exclude_cols` for an example. + + :param extension: A :class:`.MapperExtension` instance or + list of :class:`.MapperExtension` instances which will be applied + to all operations by this :class:`.Mapper`. **Deprecated.** + Please see :class:`.MapperEvents`. + + :param include_properties: An inclusive list or set of string column + names to map. + + See :ref:`include_exclude_cols` for an example. + + :param inherits: A mapped class or the corresponding :class:`.Mapper` + of one indicating a superclass to which this :class:`.Mapper` + should *inherit* from. The mapped class here must be a subclass + of the other mapper's class. When using Declarative, this argument + is passed automatically as a result of the natural class + hierarchy of the declared classes. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param inherit_condition: For joined table inheritance, a SQL + expression which will + define how the two tables are joined; defaults to a natural join + between the two tables. + + :param inherit_foreign_keys: When ``inherit_condition`` is used and the + columns present are missing a :class:`.ForeignKey` configuration, + this parameter can be used to specify which columns are "foreign". + In most cases can be left as ``None``. + + :param legacy_is_orphan: Boolean, defaults to ``False``. + When ``True``, specifies that "legacy" orphan consideration + is to be applied to objects mapped by this mapper, which means + that a pending (that is, not persistent) object is auto-expunged + from an owning :class:`.Session` only when it is de-associated + from *all* parents that specify a ``delete-orphan`` cascade towards + this mapper. The new default behavior is that the object is auto-expunged + when it is de-associated with *any* of its parents that specify + ``delete-orphan`` cascade. This behavior is more consistent with + that of a persistent object, and allows behavior to be consistent + in more scenarios independently of whether or not an orphanable + object has been flushed yet or not. + + See the change note and example at :ref:`legacy_is_orphan_addition` + for more detail on this change. + + .. versionadded:: 0.8 - the consideration of a pending object as + an "orphan" has been modified to more closely match the + behavior as that of persistent objects, which is that the object + is expunged from the :class:`.Session` as soon as it is + de-associated from any of its orphan-enabled parents. Previously, + the pending object would be expunged only if de-associated + from all of its orphan-enabled parents. The new flag ``legacy_is_orphan`` + is added to :func:`.orm.mapper` which re-establishes the + legacy behavior. + + :param non_primary: Specify that this :class:`.Mapper` is in addition + to the "primary" mapper, that is, the one used for persistence. + The :class:`.Mapper` created here may be used for ad-hoc + mapping of the class to an alternate selectable, for loading + only. + + :paramref:`.Mapper.non_primary` is not an often used option, but + is useful in some specific :func:`.relationship` cases. + + .. seealso:: + + :ref:`relationship_non_primary_mapper` + + :param order_by: A single :class:`.Column` or list of :class:`.Column` + objects for which selection operations should use as the default + ordering for entities. By default mappers have no pre-defined + ordering. + + :param passive_updates: Indicates UPDATE behavior of foreign key + columns when a primary key column changes on a joined-table + inheritance mapping. Defaults to ``True``. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will handle + propagation of an UPDATE from a source column to dependent columns + on joined-table rows. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The unit of work process will + emit an UPDATE statement for the dependent columns during a + primary key change. + + .. seealso:: + + :ref:`passive_updates` - description of a similar feature as + used with :func:`.relationship` + + :param polymorphic_on: Specifies the column, attribute, or + SQL expression used to determine the target class for an + incoming row, when inheriting classes are present. + + This value is commonly a :class:`.Column` object that's + present in the mapped :class:`.Table`:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + + __mapper_args__ = { + "polymorphic_on":discriminator, + "polymorphic_identity":"employee" + } + + It may also be specified + as a SQL expression, as in this example where we + use the :func:`.case` construct to provide a conditional + approach:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + + __mapper_args__ = { + "polymorphic_on":case([ + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + ], else_="employee"), + "polymorphic_identity":"employee" + } + + It may also refer to any attribute + configured with :func:`.column_property`, or to the + string name of one:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + employee_type = column_property( + case([ + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + ], else_="employee") + ) + + __mapper_args__ = { + "polymorphic_on":employee_type, + "polymorphic_identity":"employee" + } + + .. versionchanged:: 0.7.4 + ``polymorphic_on`` may be specified as a SQL expression, + or refer to any attribute configured with + :func:`.column_property`, or to the string name of one. + + When setting ``polymorphic_on`` to reference an + attribute or expression that's not present in the + locally mapped :class:`.Table`, yet the value + of the discriminator should be persisted to the database, + the value of the + discriminator is not automatically set on new + instances; this must be handled by the user, + either through manual means or via event listeners. + A typical approach to establishing such a listener + looks like:: + + from sqlalchemy import event + from sqlalchemy.orm import object_mapper + + @event.listens_for(Employee, "init", propagate=True) + def set_identity(instance, *arg, **kw): + mapper = object_mapper(instance) + instance.discriminator = mapper.polymorphic_identity + + Where above, we assign the value of ``polymorphic_identity`` + for the mapped class to the ``discriminator`` attribute, + thus persisting the value to the ``discriminator`` column + in the database. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param polymorphic_identity: Specifies the value which + identifies this particular class as returned by the + column expression referred to by the ``polymorphic_on`` + setting. As rows are received, the value corresponding + to the ``polymorphic_on`` column expression is compared + to this value, indicating which subclass should + be used for the newly reconstructed object. + + :param properties: A dictionary mapping the string names of object + attributes to :class:`.MapperProperty` instances, which define the + persistence behavior of that attribute. Note that :class:`.Column` + objects present in + the mapped :class:`.Table` are automatically placed into + ``ColumnProperty`` instances upon mapping, unless overridden. + When using Declarative, this argument is passed automatically, + based on all those :class:`.MapperProperty` instances declared + in the declared class body. + + :param primary_key: A list of :class:`.Column` objects which define the + primary key to be used against this mapper's selectable unit. + This is normally simply the primary key of the ``local_table``, but + can be overridden here. + + :param version_id_col: A :class:`.Column` + that will be used to keep a running version id of rows + in the table. This is used to detect concurrent updates or + the presence of stale data in a flush. The methodology is to + detect if an UPDATE statement does not match the last known + version id, a + :class:`~sqlalchemy.orm.exc.StaleDataError` exception is + thrown. + By default, the column must be of :class:`.Integer` type, + unless ``version_id_generator`` specifies an alternative version + generator. + + .. seealso:: + + :ref:`mapper_version_counter` - discussion of version counting + and rationale. + + :param version_id_generator: Define how new version ids should + be generated. Defaults to ``None``, which indicates that + a simple integer counting scheme be employed. To provide a custom + versioning scheme, provide a callable function of the form:: + + def generate_version(version): + return next_version + + Alternatively, server-side versioning functions such as triggers, + or programmatic versioning schemes outside of the version id generator + may be used, by specifying the value ``False``. + Please see :ref:`server_side_version_counter` for a discussion + of important points when using this option. + + .. versionadded:: 0.9.0 ``version_id_generator`` supports server-side + version number generation. + + .. seealso:: + + :ref:`custom_version_counter` + + :ref:`server_side_version_counter` + + + :param with_polymorphic: A tuple in the form ``(, + )`` indicating the default style of "polymorphic" + loading, that is, which tables are queried at once. is + any single or list of mappers and/or classes indicating the + inherited classes that should be loaded at once. The special value + ``'*'`` may be used to indicate all descending classes should be + loaded immediately. The second tuple argument + indicates a selectable that will be used to query for multiple + classes. + + .. seealso:: + + :ref:`with_polymorphic` - discussion of polymorphic querying techniques. + + """ + + self.class_ = util.assert_arg_type(class_, type, 'class_') + + self.class_manager = None + + self._primary_key_argument = util.to_list(primary_key) + self.non_primary = non_primary + + if order_by is not False: + self.order_by = util.to_list(order_by) + else: + self.order_by = order_by + + self.always_refresh = always_refresh + + if isinstance(version_id_col, MapperProperty): + self.version_id_prop = version_id_col + self.version_id_col = None + else: + self.version_id_col = version_id_col + if version_id_generator is False: + self.version_id_generator = False + elif version_id_generator is None: + self.version_id_generator = lambda x: (x or 0) + 1 + else: + self.version_id_generator = version_id_generator + + self.concrete = concrete + self.single = False + self.inherits = inherits + self.local_table = local_table + self.inherit_condition = inherit_condition + self.inherit_foreign_keys = inherit_foreign_keys + self._init_properties = properties or {} + self._delete_orphans = [] + self.batch = batch + self.eager_defaults = eager_defaults + self.column_prefix = column_prefix + self.polymorphic_on = expression._clause_element_as_expr( + polymorphic_on) + self._dependency_processors = [] + self.validators = util.immutabledict() + self.passive_updates = passive_updates + self.legacy_is_orphan = legacy_is_orphan + self._clause_adapter = None + self._requires_row_aliasing = False + self._inherits_equated_pairs = None + self._memoized_values = {} + self._compiled_cache_size = _compiled_cache_size + self._reconstructor = None + self._deprecated_extensions = util.to_list(extension or []) + self.allow_partial_pks = allow_partial_pks + + if self.inherits and not self.concrete: + self.confirm_deleted_rows = False + else: + self.confirm_deleted_rows = confirm_deleted_rows + + self._set_with_polymorphic(with_polymorphic) + + if isinstance(self.local_table, expression.SelectBase): + raise sa_exc.InvalidRequestError( + "When mapping against a select() construct, map against " + "an alias() of the construct instead." + "This because several databases don't allow a " + "SELECT from a subquery that does not have an alias." + ) + + if self.with_polymorphic and \ + isinstance(self.with_polymorphic[1], + expression.SelectBase): + self.with_polymorphic = (self.with_polymorphic[0], + self.with_polymorphic[1].alias()) + + # our 'polymorphic identity', a string name that when located in a + # result set row indicates this Mapper should be used to construct + # the object instance for that row. + self.polymorphic_identity = polymorphic_identity + + # a dictionary of 'polymorphic identity' names, associating those + # names with Mappers that will be used to construct object instances + # upon a select operation. + if _polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = _polymorphic_map + + if include_properties is not None: + self.include_properties = util.to_set(include_properties) + else: + self.include_properties = None + if exclude_properties: + self.exclude_properties = util.to_set(exclude_properties) + else: + self.exclude_properties = None + + self.configured = False + + # prevent this mapper from being constructed + # while a configure_mappers() is occurring (and defer a + # configure_mappers() until construction succeeds) + _CONFIGURE_MUTEX.acquire() + try: + self.dispatch._events._new_mapper_instance(class_, self) + self._configure_inheritance() + self._configure_legacy_instrument_class() + self._configure_class_instrumentation() + self._configure_listeners() + self._configure_properties() + self._configure_polymorphic_setter() + self._configure_pks() + Mapper._new_mappers = True + self._log("constructed") + self._expire_memoizations() + finally: + _CONFIGURE_MUTEX.release() + + # major attributes initialized at the classlevel so that + # they can be Sphinx-documented. + + is_mapper = True + """Part of the inspection API.""" + + @property + def mapper(self): + """Part of the inspection API. + + Returns self. + + """ + return self + + @property + def entity(self): + """Part of the inspection API. + + Returns self.class\_. + + """ + return self.class_ + + local_table = None + """The :class:`.Selectable` which this :class:`.Mapper` manages. + + Typically is an instance of :class:`.Table` or :class:`.Alias`. + May also be ``None``. + + The "local" table is the + selectable that the :class:`.Mapper` is directly responsible for + managing from an attribute access and flush perspective. For + non-inheriting mappers, the local table is the same as the + "mapped" table. For joined-table inheritance mappers, local_table + will be the particular sub-table of the overall "join" which + this :class:`.Mapper` represents. If this mapper is a + single-table inheriting mapper, local_table will be ``None``. + + .. seealso:: + + :attr:`~.Mapper.mapped_table`. + + """ + + mapped_table = None + """The :class:`.Selectable` to which this :class:`.Mapper` is mapped. + + Typically an instance of :class:`.Table`, :class:`.Join`, or + :class:`.Alias`. + + The "mapped" table is the selectable that + the mapper selects from during queries. For non-inheriting + mappers, the mapped table is the same as the "local" table. + For joined-table inheritance mappers, mapped_table references the + full :class:`.Join` representing full rows for this particular + subclass. For single-table inheritance mappers, mapped_table + references the base table. + + .. seealso:: + + :attr:`~.Mapper.local_table`. + + """ + + inherits = None + """References the :class:`.Mapper` which this :class:`.Mapper` + inherits from, if any. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + configured = None + """Represent ``True`` if this :class:`.Mapper` has been configured. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + .. seealso:: + + :func:`.configure_mappers`. + + """ + + concrete = None + """Represent ``True`` if this :class:`.Mapper` is a concrete + inheritance mapper. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + tables = None + """An iterable containing the collection of :class:`.Table` objects + which this :class:`.Mapper` is aware of. + + If the mapper is mapped to a :class:`.Join`, or an :class:`.Alias` + representing a :class:`.Select`, the individual :class:`.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + primary_key = None + """An iterable containing the collection of :class:`.Column` objects + which comprise the 'primary key' of the mapped table, from the + perspective of this :class:`.Mapper`. + + This list is against the selectable in :attr:`~.Mapper.mapped_table`. In + the case of inheriting mappers, some columns may be managed by a + superclass mapper. For example, in the case of a :class:`.Join`, the + primary key is determined by all of the primary key columns across all + tables referenced by the :class:`.Join`. + + The list is also not necessarily the same as the primary key column + collection associated with the underlying tables; the :class:`.Mapper` + features a ``primary_key`` argument that can override what the + :class:`.Mapper` considers as primary key columns. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_ = None + """The Python class which this :class:`.Mapper` maps. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_manager = None + """The :class:`.ClassManager` which maintains event listeners + and class-bound descriptors for this :class:`.Mapper`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + single = None + """Represent ``True`` if this :class:`.Mapper` is a single table + inheritance mapper. + + :attr:`~.Mapper.local_table` will be ``None`` if this flag is set. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + non_primary = None + """Represent ``True`` if this :class:`.Mapper` is a "non-primary" + mapper, e.g. a mapper that is used only to selet rows but not for + persistence management. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_on = None + """The :class:`.Column` or SQL expression specified as the + ``polymorphic_on`` argument + for this :class:`.Mapper`, within an inheritance scenario. + + This attribute is normally a :class:`.Column` instance but + may also be an expression, such as one derived from + :func:`.cast`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_map = None + """A mapping of "polymorphic identity" identifiers mapped to + :class:`.Mapper` instances, within an inheritance scenario. + + The identifiers can be of any type which is comparable to the + type of column represented by :attr:`~.Mapper.polymorphic_on`. + + An inheritance chain of mappers will all reference the same + polymorphic map object. The object is used to correlate incoming + result rows to target mappers. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_identity = None + """Represent an identifier which is matched against the + :attr:`~.Mapper.polymorphic_on` column during result row loading. + + Used only with inheritance, this object can be of any type which is + comparable to the type of column represented by + :attr:`~.Mapper.polymorphic_on`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + base_mapper = None + """The base-most :class:`.Mapper` in an inheritance chain. + + In a non-inheriting scenario, this attribute will always be this + :class:`.Mapper`. In an inheritance scenario, it references + the :class:`.Mapper` which is parent to all other :class:`.Mapper` + objects in the inheritance chain. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + columns = None + """A collection of :class:`.Column` or other scalar expression + objects maintained by this :class:`.Mapper`. + + The collection behaves the same as that of the ``c`` attribute on + any :class:`.Table` object, except that only those columns included in + this mapping are present, and are keyed based on the attribute name + defined in the mapping, not necessarily the ``key`` attribute of the + :class:`.Column` itself. Additionally, scalar expressions mapped + by :func:`.column_property` are also present here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators = None + """An immutable dictionary of attributes which have been decorated + using the :func:`~.orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + c = None + """A synonym for :attr:`~.Mapper.columns`.""" + + @util.memoized_property + def _path_registry(self): + return PathRegistry.per_mapper(self) + + def _configure_inheritance(self): + """Configure settings related to inherting and/or inherited mappers + being present.""" + + # a set of all mappers which inherit from this one. + self._inheriting_mappers = util.WeakSequence() + + if self.inherits: + if isinstance(self.inherits, type): + self.inherits = class_mapper(self.inherits, configure=False) + if not issubclass(self.class_, self.inherits.class_): + raise sa_exc.ArgumentError( + "Class '%s' does not inherit from '%s'" % + (self.class_.__name__, self.inherits.class_.__name__)) + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" + raise sa_exc.ArgumentError( + "Inheritance of %s mapper for class '%s' is " + "only allowed from a %s mapper" % + (np, self.class_.__name__, np)) + # inherit_condition is optional. + if self.local_table is None: + self.local_table = self.inherits.local_table + self.mapped_table = self.inherits.mapped_table + self.single = True + elif not self.local_table is self.inherits.local_table: + if self.concrete: + self.mapped_table = self.local_table + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + else: + if self.inherit_condition is None: + # figure out inherit condition from our table to the + # immediate table of the inherited mapper, not its + # full table which could pull in other stuff we dont + # want (allows test/inheritance.InheritTest4 to pass) + self.inherit_condition = sql_util.join_condition( + self.inherits.local_table, + self.local_table) + self.mapped_table = sql.join( + self.inherits.mapped_table, + self.local_table, + self.inherit_condition) + + fks = util.to_set(self.inherit_foreign_keys) + self._inherits_equated_pairs = sql_util.criterion_as_pairs( + self.mapped_table.onclause, + consider_as_foreign_keys=fks) + else: + self.mapped_table = self.local_table + + if self.polymorphic_identity is not None and not self.concrete: + self._identity_class = self.inherits._identity_class + else: + self._identity_class = self.class_ + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col + self.version_id_generator = self.inherits.version_id_generator + elif self.inherits.version_id_col is not None and \ + self.version_id_col is not self.inherits.version_id_col: + util.warn( + "Inheriting version_id_col '%s' does not match inherited " + "version_id_col '%s' and will not automatically populate " + "the inherited versioning column. " + "version_id_col should only be specified on " + "the base-most mapper that includes versioning." % + (self.version_id_col.description, + self.inherits.version_id_col.description) + ) + + if self.order_by is False and \ + not self.concrete and \ + self.inherits.order_by is not False: + self.order_by = self.inherits.order_by + + self.polymorphic_map = self.inherits.polymorphic_map + self.batch = self.inherits.batch + self.inherits._inheriting_mappers.append(self) + self.base_mapper = self.inherits.base_mapper + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + + else: + self._all_tables = set() + self.base_mapper = self + self.mapped_table = self.local_table + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + self._identity_class = self.class_ + + if self.mapped_table is None: + raise sa_exc.ArgumentError( + "Mapper '%s' does not have a mapped_table specified." + % self) + + def _set_with_polymorphic(self, with_polymorphic): + if with_polymorphic == '*': + self.with_polymorphic = ('*', None) + elif isinstance(with_polymorphic, (tuple, list)): + if isinstance(with_polymorphic[0], util.string_types + (tuple, list)): + self.with_polymorphic = with_polymorphic + else: + self.with_polymorphic = (with_polymorphic, None) + elif with_polymorphic is not None: + raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") + else: + self.with_polymorphic = None + + if isinstance(self.local_table, expression.SelectBase): + raise sa_exc.InvalidRequestError( + "When mapping against a select() construct, map against " + "an alias() of the construct instead." + "This because several databases don't allow a " + "SELECT from a subquery that does not have an alias." + ) + + if self.with_polymorphic and \ + isinstance(self.with_polymorphic[1], + expression.SelectBase): + self.with_polymorphic = (self.with_polymorphic[0], + self.with_polymorphic[1].alias()) + if self.configured: + self._expire_memoizations() + + def _set_concrete_base(self, mapper): + """Set the given :class:`.Mapper` as the 'inherits' for this + :class:`.Mapper`, assuming this :class:`.Mapper` is concrete + and does not already have an inherits.""" + + assert self.concrete + assert not self.inherits + assert isinstance(mapper, Mapper) + self.inherits = mapper + self.inherits.polymorphic_map.update(self.polymorphic_map) + self.polymorphic_map = self.inherits.polymorphic_map + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + self.batch = self.inherits.batch + for mp in self.self_and_descendants: + mp.base_mapper = self.inherits.base_mapper + self.inherits._inheriting_mappers.append(self) + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + for key, prop in mapper._props.items(): + if key not in self._props and \ + not self._should_exclude(key, key, local=False, + column=None): + self._adapt_inherited_property(key, prop, False) + + def _set_polymorphic_on(self, polymorphic_on): + self.polymorphic_on = polymorphic_on + self._configure_polymorphic_setter(True) + + def _configure_legacy_instrument_class(self): + + if self.inherits: + self.dispatch._update(self.inherits.dispatch) + super_extensions = set( + chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) + else: + super_extensions = set() + + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_instrument_class(self, ext) + + def _configure_listeners(self): + if self.inherits: + super_extensions = set( + chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) + else: + super_extensions = set() + + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_listener(self, ext) + + def _configure_class_instrumentation(self): + """If this mapper is to be a primary mapper (i.e. the + non_primary flag is not set), associate this Mapper with the + given class_ and entity name. + + Subsequent calls to ``class_mapper()`` for the class_/entity + name combination will return this mapper. Also decorate the + `__init__` method on the mapped class to include optional + auto-session attachment logic. + + """ + manager = attributes.manager_of_class(self.class_) + + if self.non_primary: + if not manager or not manager.is_mapped: + raise sa_exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper." % self.class_) + self.class_manager = manager + self._identity_class = manager.mapper._identity_class + _mapper_registry[self] = True + return + + if manager is not None: + assert manager.class_ is self.class_ + if manager.is_mapped: + raise sa_exc.ArgumentError( + "Class '%s' already has a primary mapper defined. " + "Use non_primary=True to " + "create a non primary Mapper. clear_mappers() will " + "remove *all* current mappers from all classes." % + self.class_) + #else: + # a ClassManager may already exist as + # ClassManager.instrument_attribute() creates + # new managers for each subclass if they don't yet exist. + + _mapper_registry[self] = True + + self.dispatch.instrument_class(self, self.class_) + + if manager is None: + manager = instrumentation.register_class(self.class_) + + self.class_manager = manager + + manager.mapper = self + manager.deferred_scalar_loader = util.partial( + loading.load_scalar_attributes, self) + + # The remaining members can be added by any mapper, + # e_name None or not. + if manager.info.get(_INSTRUMENTOR, False): + return + + event.listen(manager, 'first_init', _event_on_first_init, raw=True) + event.listen(manager, 'init', _event_on_init, raw=True) + event.listen(manager, 'resurrect', _event_on_resurrect, raw=True) + + for key, method in util.iterate_attributes(self.class_): + if isinstance(method, types.FunctionType): + if hasattr(method, '__sa_reconstructor__'): + self._reconstructor = method + event.listen(manager, 'load', _event_on_load, raw=True) + elif hasattr(method, '__sa_validators__'): + validation_opts = method.__sa_validation_opts__ + for name in method.__sa_validators__: + self.validators = self.validators.union( + {name: (method, validation_opts)} + ) + + manager.info[_INSTRUMENTOR] = self + + + @classmethod + def _configure_all(cls): + """Class-level path to the :func:`.configure_mappers` call. + """ + configure_mappers() + + def dispose(self): + # Disable any attribute-based compilation. + self.configured = True + + if hasattr(self, '_configure_failed'): + del self._configure_failed + + if not self.non_primary and \ + self.class_manager is not None and \ + self.class_manager.is_mapped and \ + self.class_manager.mapper is self: + instrumentation.unregister_class(self.class_) + + def _configure_pks(self): + + self.tables = sql_util.find_tables(self.mapped_table) + + self._pks_by_table = {} + self._cols_by_table = {} + + all_cols = util.column_set(chain(*[ + col.proxy_set for col in + self._columntoproperty])) + + pk_cols = util.column_set(c for c in all_cols if c.primary_key) + + # identify primary key columns which are also mapped by this mapper. + tables = set(self.tables + [self.mapped_table]) + self._all_tables.update(tables) + for t in tables: + if t.primary_key and pk_cols.issuperset(t.primary_key): + # ordering is important since it determines the ordering of + # mapper.primary_key (and therefore query.get()) + self._pks_by_table[t] = \ + util.ordered_column_set(t.primary_key).\ + intersection(pk_cols) + self._cols_by_table[t] = \ + util.ordered_column_set(t.c).\ + intersection(all_cols) + + # determine cols that aren't expressed within our tables; mark these + # as "read only" properties which are refreshed upon INSERT/UPDATE + self._readonly_props = set( + self._columntoproperty[col] + for col in self._columntoproperty + if not hasattr(col, 'table') or + col.table not in self._cols_by_table) + + # if explicit PK argument sent, add those columns to the + # primary key mappings + if self._primary_key_argument: + for k in self._primary_key_argument: + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + # otherwise, see that we got a full PK for the mapped table + elif self.mapped_table not in self._pks_by_table or \ + len(self._pks_by_table[self.mapped_table]) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % + (self, self.mapped_table.description)) + elif self.local_table not in self._pks_by_table and \ + isinstance(self.local_table, schema.Table): + util.warn("Could not assemble any primary " + "keys for locally mapped table '%s' - " + "no rows will be persisted in this Table." + % self.local_table.description) + + if self.inherits and \ + not self.concrete and \ + not self._primary_key_argument: + # if inheriting, the "primary key" for this mapper is + # that of the inheriting (unless concrete or explicit) + self.primary_key = self.inherits.primary_key + else: + # determine primary key from argument or mapped_table pks - + # reduce to the minimal set of columns + if self._primary_key_argument: + primary_key = sql_util.reduce_columns( + [self.mapped_table.corresponding_column(c) for c in + self._primary_key_argument], + ignore_nonexistent_tables=True) + else: + primary_key = sql_util.reduce_columns( + self._pks_by_table[self.mapped_table], + ignore_nonexistent_tables=True) + + if len(primary_key) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % + (self, self.mapped_table.description)) + + self.primary_key = tuple(primary_key) + self._log("Identified primary key columns: %s", primary_key) + + def _configure_properties(self): + + # Column and other ClauseElement objects which are mapped + self.columns = self.c = util.OrderedProperties() + + # object attribute names mapped to MapperProperty objects + self._props = util.OrderedDict() + + # table columns mapped to lists of MapperProperty objects + # using a list allows a single column to be defined as + # populating multiple object attributes + self._columntoproperty = _ColumnMapping(self) + + # load custom properties + if self._init_properties: + for key, prop in self._init_properties.items(): + self._configure_property(key, prop, False) + + # pull properties from the inherited mapper if any. + if self.inherits: + for key, prop in self.inherits._props.items(): + if key not in self._props and \ + not self._should_exclude(key, key, local=False, + column=None): + self._adapt_inherited_property(key, prop, False) + + # create properties for each column in the mapped table, + # for those columns which don't already map to a property + for column in self.mapped_table.columns: + if column in self._columntoproperty: + continue + + column_key = (self.column_prefix or '') + column.key + + if self._should_exclude( + column.key, column_key, + local=self.local_table.c.contains_column(column), + column=column + ): + continue + + # adjust the "key" used for this column to that + # of the inheriting mapper + for mapper in self.iterate_to_root(): + if column in mapper._columntoproperty: + column_key = mapper._columntoproperty[column].key + + self._configure_property(column_key, + column, + init=False, + setparent=True) + + def _configure_polymorphic_setter(self, init=False): + """Configure an attribute on the mapper representing the + 'polymorphic_on' column, if applicable, and not + already generated by _configure_properties (which is typical). + + Also create a setter function which will assign this + attribute to the value of the 'polymorphic_identity' + upon instance construction, also if applicable. This + routine will run when an instance is created. + + """ + setter = False + + if self.polymorphic_on is not None: + setter = True + + if isinstance(self.polymorphic_on, util.string_types): + # polymorphic_on specified as as string - link + # it to mapped ColumnProperty + try: + self.polymorphic_on = self._props[self.polymorphic_on] + except KeyError: + raise sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on) + + if self.polymorphic_on in self._columntoproperty: + # polymorphic_on is a column that is already mapped + # to a ColumnProperty + prop = self._columntoproperty[self.polymorphic_on] + polymorphic_key = prop.key + self.polymorphic_on = prop.columns[0] + polymorphic_key = prop.key + elif isinstance(self.polymorphic_on, MapperProperty): + # polymorphic_on is directly a MapperProperty, + # ensure it's a ColumnProperty + if not isinstance(self.polymorphic_on, + properties.ColumnProperty): + raise sa_exc.ArgumentError( + "Only direct column-mapped " + "property or SQL expression " + "can be passed for polymorphic_on") + prop = self.polymorphic_on + self.polymorphic_on = prop.columns[0] + polymorphic_key = prop.key + elif not expression._is_column(self.polymorphic_on): + # polymorphic_on is not a Column and not a ColumnProperty; + # not supported right now. + raise sa_exc.ArgumentError( + "Only direct column-mapped " + "property or SQL expression " + "can be passed for polymorphic_on" + ) + else: + # polymorphic_on is a Column or SQL expression and + # doesn't appear to be mapped. this means it can be 1. + # only present in the with_polymorphic selectable or + # 2. a totally standalone SQL expression which we'd + # hope is compatible with this mapper's mapped_table + col = self.mapped_table.corresponding_column( + self.polymorphic_on) + if col is None: + # polymorphic_on doesn't derive from any + # column/expression isn't present in the mapped + # table. we will make a "hidden" ColumnProperty + # for it. Just check that if it's directly a + # schema.Column and we have with_polymorphic, it's + # likely a user error if the schema.Column isn't + # represented somehow in either mapped_table or + # with_polymorphic. Otherwise as of 0.7.4 we + # just go with it and assume the user wants it + # that way (i.e. a CASE statement) + setter = False + instrument = False + col = self.polymorphic_on + if isinstance(col, schema.Column) and ( + self.with_polymorphic is None or \ + self.with_polymorphic[1].\ + corresponding_column(col) is None + ): + raise sa_exc.InvalidRequestError( + "Could not map polymorphic_on column " + "'%s' to the mapped table - polymorphic " + "loads will not function properly" + % col.description) + else: + # column/expression that polymorphic_on derives from + # is present in our mapped table + # and is probably mapped, but polymorphic_on itself + # is not. This happens when + # the polymorphic_on is only directly present in the + # with_polymorphic selectable, as when use + # polymorphic_union. + # we'll make a separate ColumnProperty for it. + instrument = True + key = getattr(col, 'key', None) + if key: + if self._should_exclude(col.key, col.key, False, col): + raise sa_exc.InvalidRequestError( + "Cannot exclude or override the " + "discriminator column %r" % + col.key) + else: + self.polymorphic_on = col = \ + col.label("_sa_polymorphic_on") + key = col.key + + self._configure_property( + key, + properties.ColumnProperty(col, + _instrument=instrument), + init=init, setparent=True) + polymorphic_key = key + else: + # no polymorphic_on was set. + # check inheriting mappers for one. + for mapper in self.iterate_to_root(): + # determine if polymorphic_on of the parent + # should be propagated here. If the col + # is present in our mapped table, or if our mapped + # table is the same as the parent (i.e. single table + # inheritance), we can use it + if mapper.polymorphic_on is not None: + if self.mapped_table is mapper.mapped_table: + self.polymorphic_on = mapper.polymorphic_on + else: + self.polymorphic_on = \ + self.mapped_table.corresponding_column( + mapper.polymorphic_on) + # we can use the parent mapper's _set_polymorphic_identity + # directly; it ensures the polymorphic_identity of the + # instance's mapper is used so is portable to subclasses. + if self.polymorphic_on is not None: + self._set_polymorphic_identity = \ + mapper._set_polymorphic_identity + self._validate_polymorphic_identity = \ + mapper._validate_polymorphic_identity + else: + self._set_polymorphic_identity = None + return + + if setter: + def _set_polymorphic_identity(state): + dict_ = state.dict + state.get_impl(polymorphic_key).set(state, dict_, + state.manager.mapper.polymorphic_identity, None) + + def _validate_polymorphic_identity(mapper, state, dict_): + if polymorphic_key in dict_ and \ + dict_[polymorphic_key] not in \ + mapper._acceptable_polymorphic_identities: + util.warn( + "Flushing object %s with " + "incompatible polymorphic identity %r; the " + "object may not refresh and/or load correctly" % ( + state_str(state), + dict_[polymorphic_key] + ) + ) + + self._set_polymorphic_identity = _set_polymorphic_identity + self._validate_polymorphic_identity = _validate_polymorphic_identity + else: + self._set_polymorphic_identity = None + + + _validate_polymorphic_identity = None + + @_memoized_configured_property + def _version_id_prop(self): + if self.version_id_col is not None: + return self._columntoproperty[self.version_id_col] + else: + return None + + @_memoized_configured_property + def _acceptable_polymorphic_identities(self): + identities = set() + + stack = deque([self]) + while stack: + item = stack.popleft() + if item.mapped_table is self.mapped_table: + identities.add(item.polymorphic_identity) + stack.extend(item._inheriting_mappers) + + return identities + + def _adapt_inherited_property(self, key, prop, init): + if not self.concrete: + self._configure_property(key, prop, init=False, setparent=False) + elif key not in self._props: + self._configure_property( + key, + properties.ConcreteInheritedProperty(), + init=init, setparent=True) + + def _configure_property(self, key, prop, init=True, setparent=True): + self._log("_configure_property(%s, %s)", key, prop.__class__.__name__) + + if not isinstance(prop, MapperProperty): + prop = self._property_from_column(key, prop) + + if isinstance(prop, properties.ColumnProperty): + col = self.mapped_table.corresponding_column(prop.columns[0]) + + # if the column is not present in the mapped table, + # test if a column has been added after the fact to the + # parent table (or their parent, etc.) [ticket:1570] + if col is None and self.inherits: + path = [self] + for m in self.inherits.iterate_to_root(): + col = m.local_table.corresponding_column(prop.columns[0]) + if col is not None: + for m2 in path: + m2.mapped_table._reset_exported() + col = self.mapped_table.corresponding_column( + prop.columns[0]) + break + path.append(m) + + # subquery expression, column not present in the mapped + # selectable. + if col is None: + col = prop.columns[0] + + # column is coming in after _readonly_props was + # initialized; check for 'readonly' + if hasattr(self, '_readonly_props') and \ + (not hasattr(col, 'table') or + col.table not in self._cols_by_table): + self._readonly_props.add(prop) + + else: + # if column is coming in after _cols_by_table was + # initialized, ensure the col is in the right set + if hasattr(self, '_cols_by_table') and \ + col.table in self._cols_by_table and \ + col not in self._cols_by_table[col.table]: + self._cols_by_table[col.table].add(col) + + # if this properties.ColumnProperty represents the "polymorphic + # discriminator" column, mark it. We'll need this when rendering + # columns in SELECT statements. + if not hasattr(prop, '_is_polymorphic_discriminator'): + prop._is_polymorphic_discriminator = \ + (col is self.polymorphic_on or + prop.columns[0] is self.polymorphic_on) + + self.columns[key] = col + for col in prop.columns + prop._orig_columns: + for col in col.proxy_set: + self._columntoproperty[col] = prop + + prop.key = key + + if setparent: + prop.set_parent(self, init) + + if key in self._props and \ + getattr(self._props[key], '_mapped_by_synonym', False): + syn = self._props[key]._mapped_by_synonym + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" % (syn, key, key, syn) + ) + + if key in self._props and \ + not isinstance(prop, properties.ColumnProperty) and \ + not isinstance(self._props[key], properties.ColumnProperty): + util.warn("Property %s on %s being replaced with new " + "property %s; the old property will be discarded" % ( + self._props[key], + self, + prop, + )) + + self._props[key] = prop + + if not self.non_primary: + prop.instrument_class(self) + + for mapper in self._inheriting_mappers: + mapper._adapt_inherited_property(key, prop, init) + + if init: + prop.init() + prop.post_instrument_class(self) + + if self.configured: + self._expire_memoizations() + + def _property_from_column(self, key, prop): + """generate/update a :class:`.ColumnProprerty` given a + :class:`.Column` object. """ + + # we were passed a Column or a list of Columns; + # generate a properties.ColumnProperty + columns = util.to_list(prop) + column = columns[0] + if not expression._is_column(column): + raise sa_exc.ArgumentError( + "%s=%r is not an instance of MapperProperty or Column" + % (key, prop)) + + prop = self._props.get(key, None) + + if isinstance(prop, properties.ColumnProperty): + if prop.parent is self: + raise sa_exc.InvalidRequestError( + "Implicitly combining column %s with column " + "%s under attribute '%s'. Please configure one " + "or more attributes for these same-named columns " + "explicitly." + % (prop.columns[-1], column, key)) + + # existing properties.ColumnProperty from an inheriting + # mapper. make a copy and append our column to it + prop = prop.copy() + prop.columns.insert(0, column) + self._log("inserting column to existing list " + "in properties.ColumnProperty %s" % (key)) + return prop + elif prop is None or isinstance(prop, + properties.ConcreteInheritedProperty): + mapped_column = [] + for c in columns: + mc = self.mapped_table.corresponding_column(c) + if mc is None: + mc = self.local_table.corresponding_column(c) + if mc is not None: + # if the column is in the local table but not the + # mapped table, this corresponds to adding a + # column after the fact to the local table. + # [ticket:1523] + self.mapped_table._reset_exported() + mc = self.mapped_table.corresponding_column(c) + if mc is None: + raise sa_exc.ArgumentError( + "When configuring property '%s' on %s, " + "column '%s' is not represented in the mapper's " + "table. Use the `column_property()` function to " + "force this column to be mapped as a read-only " + "attribute." % (key, self, c)) + mapped_column.append(mc) + return properties.ColumnProperty(*mapped_column) + else: + raise sa_exc.ArgumentError( + "WARNING: when configuring property '%s' on %s, " + "column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a " + "different name in the 'properties' dictionary. Or, " + "to remove all awareness of the column entirely " + "(including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' " + "mapper arguments to control specifically which table " + "columns get mapped." % + (key, self, column.key, prop)) + + def _post_configure_properties(self): + """Call the ``init()`` method on all ``MapperProperties`` + attached to this mapper. + + This is a deferred configuration step which is intended + to execute once all mappers have been constructed. + + """ + + self._log("_post_configure_properties() started") + l = [(key, prop) for key, prop in self._props.items()] + for key, prop in l: + self._log("initialize prop %s", key) + + if prop.parent is self and not prop._configure_started: + prop.init() + + if prop._configure_finished: + prop.post_instrument_class(self) + + self._log("_post_configure_properties() complete") + self.configured = True + + def add_properties(self, dict_of_properties): + """Add the given dictionary of properties to this mapper, + using `add_property`. + + """ + for key, value in dict_of_properties.items(): + self.add_property(key, value) + + def add_property(self, key, prop): + """Add an individual MapperProperty to this mapper. + + If the mapper has not been configured yet, just adds the + property to the initial properties dictionary sent to the + constructor. If this Mapper has already been configured, then + the given MapperProperty is configured immediately. + + """ + self._init_properties[key] = prop + self._configure_property(key, prop, init=self.configured) + + def _expire_memoizations(self): + for mapper in self.iterate_to_root(): + _memoized_configured_property.expire_instance(mapper) + + @property + def _log_desc(self): + return "(" + self.class_.__name__ + \ + "|" + \ + (self.local_table is not None and + self.local_table.description or + str(self.local_table)) +\ + (self.non_primary and + "|non-primary" or "") + ")" + + def _log(self, msg, *args): + self.logger.info( + "%s " + msg, *((self._log_desc,) + args) + ) + + def _log_debug(self, msg, *args): + self.logger.debug( + "%s " + msg, *((self._log_desc,) + args) + ) + + def __repr__(self): + return '' % ( + id(self), self.class_.__name__) + + def __str__(self): + return "Mapper|%s|%s%s" % ( + self.class_.__name__, + self.local_table is not None and + self.local_table.description or None, + self.non_primary and "|non-primary" or "" + ) + + def _is_orphan(self, state): + orphan_possible = False + for mapper in self.iterate_to_root(): + for (key, cls) in mapper._delete_orphans: + orphan_possible = True + + has_parent = attributes.manager_of_class(cls).has_parent( + state, key, optimistic=state.has_identity) + + if self.legacy_is_orphan and has_parent: + return False + elif not self.legacy_is_orphan and not has_parent: + return True + + if self.legacy_is_orphan: + return orphan_possible + else: + return False + + def has_property(self, key): + return key in self._props + + def get_property(self, key, _configure_mappers=True): + """return a MapperProperty associated with the given key. + """ + + if _configure_mappers and Mapper._new_mappers: + configure_mappers() + + try: + return self._props[key] + except KeyError: + raise sa_exc.InvalidRequestError( + "Mapper '%s' has no property '%s'" % (self, key)) + + def get_property_by_column(self, column): + """Given a :class:`.Column` object, return the + :class:`.MapperProperty` which maps this column.""" + + return self._columntoproperty[column] + + @property + def iterate_properties(self): + """return an iterator of all MapperProperty objects.""" + if Mapper._new_mappers: + configure_mappers() + return iter(self._props.values()) + + def _mappers_from_spec(self, spec, selectable): + """given a with_polymorphic() argument, return the set of mappers it + represents. + + Trims the list of mappers to just those represented within the given + selectable, if present. This helps some more legacy-ish mappings. + + """ + if spec == '*': + mappers = list(self.self_and_descendants) + elif spec: + mappers = set() + for m in util.to_list(spec): + m = _class_to_mapper(m) + if not m.isa(self): + raise sa_exc.InvalidRequestError( + "%r does not inherit from %r" % + (m, self)) + + if selectable is None: + mappers.update(m.iterate_to_root()) + else: + mappers.add(m) + mappers = [m for m in self.self_and_descendants if m in mappers] + else: + mappers = [] + + if selectable is not None: + tables = set(sql_util.find_tables(selectable, + include_aliases=True)) + mappers = [m for m in mappers if m.local_table in tables] + return mappers + + def _selectable_from_mappers(self, mappers, innerjoin): + """given a list of mappers (assumed to be within this mapper's + inheritance hierarchy), construct an outerjoin amongst those mapper's + mapped tables. + + """ + from_obj = self.mapped_table + for m in mappers: + if m is self: + continue + if m.concrete: + raise sa_exc.InvalidRequestError( + "'with_polymorphic()' requires 'selectable' argument " + "when concrete-inheriting mappers are used.") + elif not m.single: + if innerjoin: + from_obj = from_obj.join(m.local_table, + m.inherit_condition) + else: + from_obj = from_obj.outerjoin(m.local_table, + m.inherit_condition) + + return from_obj + + @_memoized_configured_property + def _single_table_criterion(self): + if self.single and \ + self.inherits and \ + self.polymorphic_on is not None: + return self.polymorphic_on.in_( + m.polymorphic_identity + for m in self.self_and_descendants) + else: + return None + + @_memoized_configured_property + def _with_polymorphic_mappers(self): + if Mapper._new_mappers: + configure_mappers() + if not self.with_polymorphic: + return [] + return self._mappers_from_spec(*self.with_polymorphic) + + @_memoized_configured_property + def _with_polymorphic_selectable(self): + if not self.with_polymorphic: + return self.mapped_table + + spec, selectable = self.with_polymorphic + if selectable is not None: + return selectable + else: + return self._selectable_from_mappers( + self._mappers_from_spec(spec, selectable), + False) + + with_polymorphic_mappers = _with_polymorphic_mappers + """The list of :class:`.Mapper` objects included in the + default "polymorphic" query. + + """ + + @property + def selectable(self): + """The :func:`.select` construct this :class:`.Mapper` selects from + by default. + + Normally, this is equivalent to :attr:`.mapped_table`, unless + the ``with_polymorphic`` feature is in use, in which case the + full "polymorphic" selectable is returned. + + """ + return self._with_polymorphic_selectable + + def _with_polymorphic_args(self, spec=None, selectable=False, + innerjoin=False): + if self.with_polymorphic: + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + elif selectable is False: + selectable = None + mappers = self._mappers_from_spec(spec, selectable) + if selectable is not None: + return mappers, selectable + else: + return mappers, self._selectable_from_mappers(mappers, + innerjoin) + + @_memoized_configured_property + def _polymorphic_properties(self): + return list(self._iterate_polymorphic_properties( + self._with_polymorphic_mappers)) + + + def _iterate_polymorphic_properties(self, mappers=None): + """Return an iterator of MapperProperty objects which will render into + a SELECT.""" + if mappers is None: + mappers = self._with_polymorphic_mappers + + if not mappers: + for c in self.iterate_properties: + yield c + else: + # in the polymorphic case, filter out discriminator columns + # from other mappers, as these are sometimes dependent on that + # mapper's polymorphic selectable (which we don't want rendered) + for c in util.unique_list( + chain(*[ + list(mapper.iterate_properties) for mapper in + [self] + mappers + ]) + ): + if getattr(c, '_is_polymorphic_discriminator', False) and \ + (self.polymorphic_on is None or + c.columns[0] is not self.polymorphic_on): + continue + yield c + + @util.memoized_property + def attrs(self): + """A namespace of all :class:`.MapperProperty` objects + associated this mapper. + + This is an object that provides each property based on + its key name. For instance, the mapper for a + ``User`` class which has ``User.name`` attribute would + provide ``mapper.attrs.name``, which would be the + :class:`.ColumnProperty` representing the ``name`` + column. The namespace object can also be iterated, + which would yield each :class:`.MapperProperty`. + + :class:`.Mapper` has several pre-filtered views + of this attribute which limit the types of properties + returned, inclding :attr:`.synonyms`, :attr:`.column_attrs`, + :attr:`.relationships`, and :attr:`.composites`. + + .. seealso:: + + :attr:`.Mapper.all_orm_descriptors` + + """ + if Mapper._new_mappers: + configure_mappers() + return util.ImmutableProperties(self._props) + + @util.memoized_property + def all_orm_descriptors(self): + """A namespace of all :class:`._InspectionAttr` attributes associated + with the mapped class. + + These attributes are in all cases Python :term:`descriptors` associated + with the mapped class or its superclasses. + + This namespace includes attributes that are mapped to the class + as well as attributes declared by extension modules. + It includes any Python descriptor type that inherits from + :class:`._InspectionAttr`. This includes :class:`.QueryableAttribute`, + as well as extension types such as :class:`.hybrid_property`, + :class:`.hybrid_method` and :class:`.AssociationProxy`. + + To distinguish between mapped attributes and extension attributes, + the attribute :attr:`._InspectionAttr.extension_type` will refer + to a constant that distinguishes between different extension types. + + When dealing with a :class:`.QueryableAttribute`, the + :attr:`.QueryableAttribute.property` attribute refers to the + :class:`.MapperProperty` property, which is what you get when referring + to the collection of mapped properties via :attr:`.Mapper.attrs`. + + .. versionadded:: 0.8.0 + + .. seealso:: + + :attr:`.Mapper.attrs` + + """ + return util.ImmutableProperties( + dict(self.class_manager._all_sqla_attributes())) + + @_memoized_configured_property + def synonyms(self): + """Return a namespace of all :class:`.SynonymProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.SynonymProperty) + + @_memoized_configured_property + def column_attrs(self): + """Return a namespace of all :class:`.ColumnProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.ColumnProperty) + + @_memoized_configured_property + def relationships(self): + """Return a namespace of all :class:`.RelationshipProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.RelationshipProperty) + + @_memoized_configured_property + def composites(self): + """Return a namespace of all :class:`.CompositeProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.CompositeProperty) + + def _filter_properties(self, type_): + if Mapper._new_mappers: + configure_mappers() + return util.ImmutableProperties(util.OrderedDict( + (k, v) for k, v in self._props.items() + if isinstance(v, type_) + )) + + @_memoized_configured_property + def _get_clause(self): + """create a "get clause" based on the primary key. this is used + by query.get() and many-to-one lazyloads to load this item + by primary key. + + """ + params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) + for primary_key in self.primary_key] + return sql.and_(*[k == v for (k, v) in params]), \ + util.column_dict(params) + + @_memoized_configured_property + def _equivalent_columns(self): + """Create a map of all *equivalent* columns, based on + the determination of column pairs that are equated to + one another based on inherit condition. This is designed + to work with the queries that util.polymorphic_union + comes up with, which often don't include the columns from + the base table directly (including the subclass table columns + only). + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, i.e. + + { + tablea.col1: + set([tableb.col1, tablec.col1]), + tablea.col2: + set([tabled.col2]) + } + + """ + result = util.column_dict() + + def visit_binary(binary): + if binary.operator == operators.eq: + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = util.column_set((binary.right,)) + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = util.column_set((binary.left,)) + for mapper in self.base_mapper.self_and_descendants: + if mapper.inherit_condition is not None: + visitors.traverse( + mapper.inherit_condition, {}, + {'binary': visit_binary}) + + return result + + def _is_userland_descriptor(self, obj): + if isinstance(obj, (_MappedAttribute, + instrumentation.ClassManager, + expression.ColumnElement)): + return False + else: + return True + + def _should_exclude(self, name, assigned_name, local, column): + """determine whether a particular property should be implicitly + present on the class. + + This occurs when properties are propagated from an inherited class, or + are applied from the columns present in the mapped table. + + """ + + # check for class-bound attributes and/or descriptors, + # either local or from an inherited class + if local: + if self.class_.__dict__.get(assigned_name, None) is not None \ + and self._is_userland_descriptor( + self.class_.__dict__[assigned_name]): + return True + else: + if getattr(self.class_, assigned_name, None) is not None \ + and self._is_userland_descriptor( + getattr(self.class_, assigned_name)): + return True + + if self.include_properties is not None and \ + name not in self.include_properties and \ + (column is None or column not in self.include_properties): + self._log("not including property %s" % (name)) + return True + + if self.exclude_properties is not None and \ + ( + name in self.exclude_properties or \ + (column is not None and column in self.exclude_properties) + ): + self._log("excluding property %s" % (name)) + return True + + return False + + def common_parent(self, other): + """Return true if the given mapper shares a + common inherited parent as this mapper.""" + + return self.base_mapper is other.base_mapper + + def _canload(self, state, allow_subtypes): + s = self.primary_mapper() + if self.polymorphic_on is not None or allow_subtypes: + return _state_mapper(state).isa(s) + else: + return _state_mapper(state) is s + + def isa(self, other): + """Return True if the this mapper inherits from the given mapper.""" + + m = self + while m and m is not other: + m = m.inherits + return bool(m) + + def iterate_to_root(self): + m = self + while m: + yield m + m = m.inherits + + @_memoized_configured_property + def self_and_descendants(self): + """The collection including this mapper and all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + """ + descendants = [] + stack = deque([self]) + while stack: + item = stack.popleft() + descendants.append(item) + stack.extend(item._inheriting_mappers) + return util.WeakSequence(descendants) + + def polymorphic_iterator(self): + """Iterate through the collection including this mapper and + all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + To iterate through an entire hierarchy, use + ``mapper.base_mapper.polymorphic_iterator()``. + + """ + return iter(self.self_and_descendants) + + def primary_mapper(self): + """Return the primary mapper corresponding to this mapper's class key + (class).""" + + return self.class_manager.mapper + + @property + def primary_base_mapper(self): + return self.class_manager.mapper.base_mapper + + def identity_key_from_row(self, row, adapter=None): + """Return an identity-map key for use in storing/retrieving an + item from the identity map. + + :param row: A :class:`.RowProxy` instance. The columns which are mapped + by this :class:`.Mapper` should be locatable in the row, preferably + via the :class:`.Column` object directly (as is the case when a + :func:`.select` construct is executed), or via string names of the form + ``_``. + + """ + pk_cols = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + + return self._identity_class, \ + tuple(row[column] for column in pk_cols) + + def identity_key_from_primary_key(self, primary_key): + """Return an identity-map key for use in storing/retrieving an + item from an identity map. + + :param primary_key: A list of values indicating the identifier. + + """ + return self._identity_class, tuple(primary_key) + + def identity_key_from_instance(self, instance): + """Return the identity key for the given instance, based on + its primary key attributes. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + This value is typically also found on the instance state under the + attribute name `key`. + + """ + return self.identity_key_from_primary_key( + self.primary_key_from_instance(instance)) + + def _identity_key_from_state(self, state): + dict_ = state.dict + manager = state.manager + return self._identity_class, tuple([ + manager[self._columntoproperty[col].key].\ + impl.get(state, dict_, attributes.PASSIVE_OFF) + for col in self.primary_key + ]) + + def primary_key_from_instance(self, instance): + """Return the list of primary key values for the given + instance. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + """ + state = attributes.instance_state(instance) + return self._primary_key_from_state(state) + + def _primary_key_from_state(self, state): + dict_ = state.dict + manager = state.manager + return [ + manager[self._columntoproperty[col].key].\ + impl.get(state, dict_, attributes.PASSIVE_OFF) + for col in self.primary_key + ] + + def _get_state_attr_by_column(self, state, dict_, column, + passive=attributes.PASSIVE_OFF): + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get(state, dict_, passive=passive) + + def _set_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set(state, dict_, value, None) + + def _get_committed_attr_by_column(self, obj, column): + state = attributes.instance_state(obj) + dict_ = attributes.instance_dict(obj) + return self._get_committed_state_attr_by_column(state, dict_, column) + + def _get_committed_state_attr_by_column(self, state, dict_, + column, passive=attributes.PASSIVE_OFF): + + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.\ + get_committed_value(state, dict_, passive=passive) + + def _optimized_get_statement(self, state, attribute_names): + """assemble a WHERE clause which retrieves a given state by primary + key, using a minimized set of tables. + + Applies to a joined-table inheritance mapper where the + requested attribute names are only present on joined tables, + not the base table. The WHERE clause attempts to include + only those tables to minimize joins. + + """ + props = self._props + + tables = set(chain( + *[sql_util.find_tables(c, check_columns=True) + for key in attribute_names + for c in props[key].columns] + )) + + if self.base_mapper.local_table in tables: + return None + + class ColumnsNotAvailable(Exception): + pass + + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + + if leftcol.table not in tables: + leftval = self._get_committed_state_attr_by_column( + state, state.dict, + leftcol, + passive=attributes.PASSIVE_NO_INITIALIZE) + if leftval is attributes.PASSIVE_NO_RESULT or leftval is None: + raise ColumnsNotAvailable() + binary.left = sql.bindparam(None, leftval, + type_=binary.right.type) + elif rightcol.table not in tables: + rightval = self._get_committed_state_attr_by_column( + state, state.dict, + rightcol, + passive=attributes.PASSIVE_NO_INITIALIZE) + if rightval is attributes.PASSIVE_NO_RESULT or \ + rightval is None: + raise ColumnsNotAvailable() + binary.right = sql.bindparam(None, rightval, + type_=binary.right.type) + + allconds = [] + + try: + start = False + for mapper in reversed(list(self.iterate_to_root())): + if mapper.local_table in tables: + start = True + elif not isinstance(mapper.local_table, expression.TableClause): + return None + if start and not mapper.single: + allconds.append(visitors.cloned_traverse( + mapper.inherit_condition, + {}, + {'binary': visit_binary} + ) + ) + except ColumnsNotAvailable: + return None + + cond = sql.and_(*allconds) + + cols = [] + for key in attribute_names: + cols.extend(props[key].columns) + return sql.select(cols, cond, use_labels=True) + + def cascade_iterator(self, type_, state, halt_on=None): + """Iterate each element and its mapper in an object graph, + for all relationships that meet the given cascade rule. + + :param type_: + The name of the cascade rule (i.e. save-update, delete, + etc.) + + :param state: + The lead InstanceState. child items will be processed per + the relationships defined for this object's mapper. + + the return value are object instances; this provides a strong + reference so that they don't fall out of scope immediately. + + """ + visited_states = set() + prp, mpp = object(), object() + + visitables = deque([(deque(self._props.values()), prp, + state, state.dict)]) + + while visitables: + iterator, item_type, parent_state, parent_dict = visitables[-1] + if not iterator: + visitables.pop() + continue + + if item_type is prp: + prop = iterator.popleft() + if type_ not in prop.cascade: + continue + queue = deque(prop.cascade_iterator(type_, parent_state, + parent_dict, visited_states, halt_on)) + if queue: + visitables.append((queue, mpp, None, None)) + elif item_type is mpp: + instance, instance_mapper, corresponding_state, \ + corresponding_dict = iterator.popleft() + yield instance, instance_mapper, \ + corresponding_state, corresponding_dict + visitables.append((deque(instance_mapper._props.values()), + prp, corresponding_state, + corresponding_dict)) + + @_memoized_configured_property + def _compiled_cache(self): + return util.LRUCache(self._compiled_cache_size) + + @_memoized_configured_property + def _sorted_tables(self): + table_to_mapper = {} + + for mapper in self.base_mapper.self_and_descendants: + for t in mapper.tables: + table_to_mapper.setdefault(t, mapper) + + extra_dependencies = [] + for table, mapper in table_to_mapper.items(): + super_ = mapper.inherits + if super_: + extra_dependencies.extend([ + (super_table, table) + for super_table in super_.tables + ]) + + def skip(fk): + # attempt to skip dependencies that are not + # significant to the inheritance chain + # for two tables that are related by inheritance. + # while that dependency may be important, it's techinically + # not what we mean to sort on here. + parent = table_to_mapper.get(fk.parent.table) + dep = table_to_mapper.get(fk.column.table) + if parent is not None and \ + dep is not None and \ + dep is not parent and \ + dep.inherit_condition is not None: + cols = set(sql_util._find_columns(dep.inherit_condition)) + if parent.inherit_condition is not None: + cols = cols.union(sql_util._find_columns( + parent.inherit_condition)) + return fk.parent not in cols and fk.column not in cols + else: + return fk.parent not in cols + return False + + sorted_ = sql_util.sort_tables(table_to_mapper, + skip_fn=skip, + extra_dependencies=extra_dependencies) + + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + + def _memo(self, key, callable_): + if key in self._memoized_values: + return self._memoized_values[key] + else: + self._memoized_values[key] = value = callable_() + return value + + @util.memoized_property + def _table_to_equated(self): + """memoized map of tables to collections of columns to be + synchronized upwards to the base mapper.""" + + result = util.defaultdict(list) + + for table in self._sorted_tables: + cols = set(table.c) + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and \ + cols.intersection( + util.reduce(set.union, + [l.proxy_set for l, r in m._inherits_equated_pairs]) + ): + result[table].append((m, m._inherits_equated_pairs)) + + return result + + +def configure_mappers(): + """Initialize the inter-mapper relationships of all mappers that + have been constructed thus far. + + This function can be called any number of times, but in + most cases is handled internally. + + """ + + if not Mapper._new_mappers: + return + + _CONFIGURE_MUTEX.acquire() + try: + global _already_compiling + if _already_compiling: + return + _already_compiling = True + try: + + # double-check inside mutex + if not Mapper._new_mappers: + return + + Mapper.dispatch(Mapper).before_configured() + # initialize properties on all mappers + # note that _mapper_registry is unordered, which + # may randomly conceal/reveal issues related to + # the order of mapper compilation + + for mapper in list(_mapper_registry): + if getattr(mapper, '_configure_failed', False): + e = sa_exc.InvalidRequestError( + "One or more mappers failed to initialize - " + "can't proceed with initialization of other " + "mappers. Original exception was: %s" + % mapper._configure_failed) + e._configure_failed = mapper._configure_failed + raise e + if not mapper.configured: + try: + mapper._post_configure_properties() + mapper._expire_memoizations() + mapper.dispatch.mapper_configured( + mapper, mapper.class_) + except: + exc = sys.exc_info()[1] + if not hasattr(exc, '_configure_failed'): + mapper._configure_failed = exc + raise + + Mapper._new_mappers = False + finally: + _already_compiling = False + finally: + _CONFIGURE_MUTEX.release() + Mapper.dispatch(Mapper).after_configured() + + +def reconstructor(fn): + """Decorate a method as the 'reconstructor' hook. + + Designates a method as the "reconstructor", an ``__init__``-like + method that will be called by the ORM after the instance has been + loaded from the database or otherwise reconstituted. + + The reconstructor will be invoked with no arguments. Scalar + (non-collection) database-mapped attributes of the instance will + be available for use within the function. Eagerly-loaded + collections are generally not yet available and will usually only + contain the first element. ORM state changes made to objects at + this stage will not be recorded for the next flush() operation, so + the activity within a reconstructor should be conservative. + + """ + fn.__sa_reconstructor__ = True + return fn + + +def validates(*names, **kw): + """Decorate a method as a 'validator' for one or more named properties. + + Designates a method as a validator, a method which receives the + name of the attribute as well as a value to be assigned, or in the + case of a collection, the value to be added to the collection. + The function can then raise validation exceptions to halt the + process from continuing (where Python's built-in ``ValueError`` + and ``AssertionError`` exceptions are reasonable choices), or can + modify or replace the value before proceeding. The function should + otherwise return the given value. + + Note that a validator for a collection **cannot** issue a load of that + collection within the validation routine - this usage raises + an assertion to avoid recursion overflows. This is a reentrant + condition which is not supported. + + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. + + .. versionadded:: 0.7.7 + :param include_backrefs: defaults to ``True``; if ``False``, the + validation function will not emit if the originator is an attribute + event related via a backref. This can be used for bi-directional + :func:`.validates` usage where only one validator should emit per + attribute operation. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`simple_validators` - usage examples for :func:`.validates` + + """ + include_removes = kw.pop('include_removes', False) + include_backrefs = kw.pop('include_backrefs', True) + + def wrap(fn): + fn.__sa_validators__ = names + fn.__sa_validation_opts__ = { + "include_removes": include_removes, + "include_backrefs": include_backrefs + } + return fn + return wrap + + +def _event_on_load(state, ctx): + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + if instrumenting_mapper._reconstructor: + instrumenting_mapper._reconstructor(state.obj()) + + +def _event_on_first_init(manager, cls): + """Initial mapper compilation trigger. + + instrumentation calls this one when InstanceState + is first generated, and is needed for legacy mutable + attributes to work. + """ + + instrumenting_mapper = manager.info.get(_INSTRUMENTOR) + if instrumenting_mapper: + if Mapper._new_mappers: + configure_mappers() + + +def _event_on_init(state, args, kwargs): + """Run init_instance hooks. + + This also includes mapper compilation, normally not needed + here but helps with some piecemeal configuration + scenarios (such as in the ORM tutorial). + + """ + + instrumenting_mapper = state.manager.info.get(_INSTRUMENTOR) + if instrumenting_mapper: + if Mapper._new_mappers: + configure_mappers() + if instrumenting_mapper._set_polymorphic_identity: + instrumenting_mapper._set_polymorphic_identity(state) + + +def _event_on_resurrect(state): + # re-populate the primary key elements + # of the dict based on the mapping. + instrumenting_mapper = state.manager.info.get(_INSTRUMENTOR) + if instrumenting_mapper: + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column( + state, state.dict, col, val) + + +class _ColumnMapping(dict): + """Error reporting helper for mapper._columntoproperty.""" + + def __init__(self, mapper): + self.mapper = mapper + + def __missing__(self, column): + prop = self.mapper._props.get(column) + if prop: + raise orm_exc.UnmappedColumnError( + "Column '%s.%s' is not available, due to " + "conflicting property '%s':%r" % ( + column.table.name, column.name, column.key, prop)) + raise orm_exc.UnmappedColumnError( + "No column %s is configured on mapper %s..." % + (column, self.mapper)) diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py new file mode 100644 index 00000000..3397626b --- /dev/null +++ b/lib/sqlalchemy/orm/path_registry.py @@ -0,0 +1,261 @@ +# orm/path_registry.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Path tracking utilities, representing mapper graph traversals. + +""" + +from .. import inspection +from .. import util +from .. import exc +from itertools import chain +from .base import class_mapper + +def _unreduce_path(path): + return PathRegistry.deserialize(path) + + +_WILDCARD_TOKEN = "*" +_DEFAULT_TOKEN = "_sa_default" + +class PathRegistry(object): + """Represent query load paths and registry functions. + + Basically represents structures like: + + (, "orders", , "items", ) + + These structures are generated by things like + query options (joinedload(), subqueryload(), etc.) and are + used to compose keys stored in the query._attributes dictionary + for various options. + + They are then re-composed at query compile/result row time as + the query is formed and as rows are fetched, where they again + serve to compose keys to look up options in the context.attributes + dictionary, which is copied from query._attributes. + + The path structure has a limited amount of caching, where each + "root" ultimately pulls from a fixed registry associated with + the first mapper, that also contains elements for each of its + property keys. However paths longer than two elements, which + are the exception rather than the rule, are generated on an + as-needed basis. + + """ + + def __eq__(self, other): + return other is not None and \ + self.path == other.path + + def set(self, attributes, key, value): + attributes[(key, self.path)] = value + + def setdefault(self, attributes, key, value): + attributes.setdefault((key, self.path), value) + + def get(self, attributes, key, value=None): + key = (key, self.path) + if key in attributes: + return attributes[key] + else: + return value + + def __len__(self): + return len(self.path) + + @property + def length(self): + return len(self.path) + + def pairs(self): + path = self.path + for i in range(0, len(path), 2): + yield path[i], path[i + 1] + + def contains_mapper(self, mapper): + for path_mapper in [ + self.path[i] for i in range(0, len(self.path), 2) + ]: + if path_mapper.is_mapper and \ + path_mapper.isa(mapper): + return True + else: + return False + + def contains(self, attributes, key): + return (key, self.path) in attributes + + def __reduce__(self): + return _unreduce_path, (self.serialize(), ) + + def serialize(self): + path = self.path + return list(zip( + [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], + [path[i].key for i in range(1, len(path), 2)] + [None] + )) + + @classmethod + def deserialize(cls, path): + if path is None: + return None + + p = tuple(chain(*[(class_mapper(mcls), + class_mapper(mcls).attrs[key] + if key is not None else None) + for mcls, key in path])) + if p and p[-1] is None: + p = p[0:-1] + return cls.coerce(p) + + @classmethod + def per_mapper(cls, mapper): + return EntityRegistry( + cls.root, mapper + ) + + @classmethod + def coerce(cls, raw): + return util.reduce(lambda prev, next: prev[next], raw, cls.root) + + def token(self, token): + if token.endswith(':' + _WILDCARD_TOKEN): + return TokenRegistry(self, token) + elif token.endswith(":" + _DEFAULT_TOKEN): + return TokenRegistry(self.root, token) + else: + raise exc.ArgumentError("invalid token: %s" % token) + + def __add__(self, other): + return util.reduce( + lambda prev, next: prev[next], + other.path, self) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.path, ) + + +class RootRegistry(PathRegistry): + """Root registry, defers to mappers so that + paths are maintained per-root-mapper. + + """ + path = () + has_entity = False + def __getitem__(self, entity): + return entity._path_registry + +PathRegistry.root = RootRegistry() + +class TokenRegistry(PathRegistry): + def __init__(self, parent, token): + self.token = token + self.parent = parent + self.path = parent.path + (token,) + + has_entity = False + + def __getitem__(self, entity): + raise NotImplementedError() + +class PropRegistry(PathRegistry): + def __init__(self, parent, prop): + # restate this path in terms of the + # given MapperProperty's parent. + insp = inspection.inspect(parent[-1]) + if not insp.is_aliased_class or insp._use_mapper_path: + parent = parent.parent[prop.parent] + elif insp.is_aliased_class and insp.with_polymorphic_mappers: + if prop.parent is not insp.mapper and \ + prop.parent in insp.with_polymorphic_mappers: + subclass_entity = parent[-1]._entity_for_mapper(prop.parent) + parent = parent.parent[subclass_entity] + + self.prop = prop + self.parent = parent + self.path = parent.path + (prop,) + + @util.memoized_property + def has_entity(self): + return hasattr(self.prop, "mapper") + + @util.memoized_property + def entity(self): + return self.prop.mapper + + @util.memoized_property + def _wildcard_path_loader_key(self): + """Given a path (mapper A, prop X), replace the prop with the wildcard, + e.g. (mapper A, 'relationship:.*') or (mapper A, 'column:.*'), then + return within the ("loader", path) structure. + + """ + return ("loader", + self.parent.token( + "%s:%s" % (self.prop.strategy_wildcard_key, _WILDCARD_TOKEN) + ).path + ) + + @util.memoized_property + def _default_path_loader_key(self): + return ("loader", + self.parent.token( + "%s:%s" % (self.prop.strategy_wildcard_key, _DEFAULT_TOKEN) + ).path + ) + + @util.memoized_property + def _loader_key(self): + return ("loader", self.path) + + @property + def mapper(self): + return self.entity + + @property + def entity_path(self): + return self[self.entity] + + def __getitem__(self, entity): + if isinstance(entity, (int, slice)): + return self.path[entity] + else: + return EntityRegistry( + self, entity + ) + +class EntityRegistry(PathRegistry, dict): + is_aliased_class = False + has_entity = True + + def __init__(self, parent, entity): + self.key = entity + self.parent = parent + self.is_aliased_class = entity.is_aliased_class + self.entity = entity + self.path = parent.path + (entity,) + self.entity_path = self + + @property + def mapper(self): + return inspection.inspect(self.entity).mapper + + def __bool__(self): + return True + __nonzero__ = __bool__ + + def __getitem__(self, entity): + if isinstance(entity, (int, slice)): + return self.path[entity] + else: + return dict.__getitem__(self, entity) + + def __missing__(self, key): + self[key] = item = PropRegistry(self, key) + return item + + + diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py new file mode 100644 index 00000000..1bd432f1 --- /dev/null +++ b/lib/sqlalchemy/orm/persistence.py @@ -0,0 +1,1107 @@ +# orm/persistence.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" + +import operator +from itertools import groupby +from .. import sql, util, exc as sa_exc, schema +from . import attributes, sync, exc as orm_exc, evaluator +from .base import _state_mapper, state_str, _attr_as_key +from ..sql import expression +from . import loading + + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. + + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_insert, states_to_update = _organize_states_for_save( + base_mapper, + states, + uowtransaction) + + cached_connections = _cached_connection_dict(base_mapper) + + for table, mapper in base_mapper._sorted_tables.items(): + insert = _collect_insert_commands(base_mapper, uowtransaction, + table, states_to_insert) + + update = _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update) + + if update: + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + if insert: + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, insert) + + _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update) + + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + cached_connections = _cached_connection_dict(base_mapper) + + states_to_update = _organize_states_for_post_update( + base_mapper, + states, uowtransaction) + + for table, mapper in base_mapper._sorted_tables.items(): + update = _collect_post_update_commands(base_mapper, uowtransaction, + table, states_to_update, + post_update_cols) + + if update: + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + cached_connections = _cached_connection_dict(base_mapper) + + states_to_delete = _organize_states_for_delete( + base_mapper, + states, + uowtransaction) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(list(table_to_mapper.keys())): + delete = _collect_delete_commands(base_mapper, uowtransaction, + table, states_to_delete) + + mapper = table_to_mapper[table] + + _emit_delete_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, delete) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + states_to_insert = [] + states_to_update = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + if mapper._validate_polymorphic_identity: + mapper._validate_polymorphic_identity(mapper, state, dict_) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. + if not has_identity and \ + instance_key in uowtransaction.session.identity_map: + instance = \ + uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, + state_str(existing))) + + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if not has_identity and not row_switch: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + else: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + + return states_to_insert, states_to_update + + +def _organize_states_for_post_update(base_mapper, states, + uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return list(_connections_for_states(base_mapper, uowtransaction, + states)) + + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + states_to_delete = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + mapper.dispatch.before_delete(mapper, connection, state) + + states_to_delete.append((state, dict_, mapper, + bool(state.key), connection)) + return states_to_delete + + +def _collect_insert_commands(base_mapper, uowtransaction, table, + states_to_insert): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + insert = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + has_all_pks = True + has_all_defaults = True + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col and \ + mapper.version_id_generator is not False: + val = mapper.version_id_generator(None) + params[col.key] = val + else: + # pull straight from the dict for + # pending objects + prop = mapper._columntoproperty[col] + value = state_dict.get(prop.key, None) + + if value is None: + if col in pks: + has_all_pks = False + elif col.default is None and \ + col.server_default is None: + params[col.key] = value + elif col.server_default is not None and \ + mapper.base_mapper.eager_defaults: + has_all_defaults = False + + elif isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + insert.append((state, state_dict, params, mapper, + connection, value_params, has_all_pks, + has_all_defaults)) + return insert + + +def _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + update = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + hasdata = hasnull = False + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = \ + mapper._get_committed_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + params[col.key] = history.added[0] + hasdata = True + else: + if mapper.version_id_generator is not False: + val = mapper.version_id_generator(params[col._label]) + params[col.key] = val + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.values(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True + else: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + if isinstance(history.added[0], + sql.ClauseElement): + value_params[col] = history.added[0] + else: + value = history.added[0] + params[col.key] = value + + if col in pks: + if history.deleted and \ + not row_switch: + # if passive_updates and sync detected + # this was a pk->pk sync, use the new + # value to locate the row, since the + # DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + value = history.added[0] + params[col._label] = value + else: + # use the old value to + # locate the row + value = history.deleted[0] + params[col._label] = value + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the update params + # so the update doesn't + # attempt to include the pk in the + # update statement + del params[col.key] + value = history.added[0] + params[col._label] = value + if value is None: + hasnull = True + else: + hasdata = True + elif col in pks: + value = state.manager[prop.key].impl.get( + state, state_dict) + if value is None: + hasnull = True + params[col._label] = value + if hasdata: + if hasnull: + raise orm_exc.FlushError( + "Can't update table " + "using NULL for primary " + "key value") + update.append((state, state_dict, params, mapper, + connection, value_params)) + return update + + +def _collect_post_update_commands(base_mapper, uowtransaction, table, + states_to_update, post_update_cols): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. + + """ + + update = [] + for state, state_dict, mapper, connection in states_to_update: + if table not in mapper._pks_by_table: + continue + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = \ + mapper._get_state_attr_by_column( + state, + state_dict, col) + + elif col in post_update_cols: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + update.append((state, state_dict, params, mapper, + connection)) + return update + + +def _collect_delete_commands(base_mapper, uowtransaction, table, + states_to_delete): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + delete = util.defaultdict(list) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + if not has_identity or table not in mapper._pks_by_table: + continue + + params = {} + delete[connection].append(params) + for col in mapper._pks_by_table[table]: + params[col.key] = \ + value = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, col) + if value is None: + raise orm_exc.FlushError( + "Can't delete from table " + "using NULL for primary " + "key value") + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) + return delete + + +def _emit_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, + type_=mapper.version_id_col.type)) + + stmt = table.update(clause) + if mapper.base_mapper.eager_defaults: + stmt = stmt.return_defaults() + elif mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt + + statement = base_mapper._memo(('update', table), update_stmt) + + rows = 0 + for state, state_dict, params, mapper, \ + connection, value_params in update: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(update), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) + + +def _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, insert): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + statement = base_mapper._memo(('insert', table), table.insert) + + for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ + records in groupby(insert, + lambda rec: (rec[4], + list(rec[2].keys()), + bool(rec[5]), + rec[6], rec[7]) + ): + if \ + ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) and has_all_pks and not hasvalue: + + records = list(records) + multiparams = [rec[2] for rec in records] + + c = cached_connections[connection].\ + execute(statement, multiparams) + + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) + + else: + if not has_all_defaults and base_mapper.eager_defaults: + statement = statement.return_defaults() + elif mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) + + for state, state_dict, params, mapper_rec, \ + connection, value_params, \ + has_all_pks, has_all_defaults in records: + + if value_params: + result = connection.execute( + statement.values(value_params), + params) + else: + result = cached_connections[connection].\ + execute(statement, params) + + primary_key = result.context.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper_rec._set_state_attr_by_column( + state, + state_dict, + col, pk) + + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params) + + +def _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('post_update', table), update_stmt) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). + for key, grouper in groupby( + update, lambda rec: (rec[4], list(rec[2].keys())) + ): + connection = key[0] + multiparams = [params for state, state_dict, + params, mapper, conn in grouper] + cached_connections[connection].\ + execute(statement, multiparams) + + +def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, + mapper, table, delete): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def delete_stmt(): + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append( + col == sql.bindparam(col.key, type_=col.type)) + + if need_version_id: + clause.clauses.append( + mapper.version_id_col == + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + + return table.delete(clause) + + for connection, del_objects in delete.items(): + statement = base_mapper._memo(('delete', table), delete_stmt) + + connection = cached_connections[connection] + + expected = len(del_objects) + rows_matched = -1 + only_warn = False + if connection.dialect.supports_sane_multi_rowcount: + c = connection.execute(statement, del_objects) + + if not need_version_id: + only_warn = True + + rows_matched = c.rowcount + + elif need_version_id: + if connection.dialect.supports_sane_rowcount: + rows_matched = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows_matched += c.rowcount + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + connection.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + connection.execute(statement, del_objects) + + if base_mapper.confirm_deleted_rows and \ + rows_matched > -1 and expected != rows_matched: + if only_warn: + util.warn( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." % + (table.description, expected, rows_matched) + ) + else: + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." % + (table.description, expected, rows_matched) + ) + +def _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert + \ + states_to_update: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state._expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. + toload_now = [] + + if base_mapper.eager_defaults: + toload_now.extend(state._unloaded_non_object) + elif mapper.version_id_col is not None and \ + mapper.version_id_generator is False: + prop = mapper._columntoproperty[mapper.version_id_col] + if prop.key in state.unloaded: + toload_now.extend([prop.key]) + + if toload_now: + state.key = base_mapper._identity_key_from_state(state) + loading.load_on_ident( + uowtransaction.session.query(base_mapper), + state.key, refresh_state=state, + only_load_props=toload_now) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + + +def _postfetch(mapper, uowtransaction, table, + state, dict_, result, params, value_params): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + prefetch_cols = result.context.prefetch_cols + postfetch_cols = result.context.postfetch_cols + returning_cols = result.context.returning_cols + + if mapper.version_id_col is not None: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + if returning_cols: + row = result.context.returned_defaults + if row is not None: + for col in returning_cols: + if col.primary_key: + continue + mapper._set_state_attr_by_column(state, dict_, col, row[col]) + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) + + if postfetch_cols: + state._expire_attributes(state.dict, + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) + + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. + + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = \ + uowtransaction.session.connection_callable + else: + connection = None + connection_callable = None + + for state in _sort_states(states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + elif not connection: + connection = uowtransaction.transaction.connection( + base_mapper) + + mapper = _state_mapper(state) + + yield state, state.dict, mapper, connection + + +def _cached_connection_dict(base_mapper): + # dictionary of connection->connection_with_cache_options. + return util.PopulateDict( + lambda conn: conn.execution_options( + compiled_cache=base_mapper._compiled_cache + )) + + +def _sort_states(states): + pending = set(states) + persistent = set(s for s in pending if s.key is not None) + pending.difference_update(persistent) + return sorted(pending, key=operator.attrgetter("insert_order")) + \ + sorted(persistent, key=lambda q: q.key[1]) + + +class BulkUD(object): + """Handle bulk update and deletes via a :class:`.Query`.""" + + def __init__(self, query): + self.query = query.enable_eagerloads(False) + + @property + def session(self): + return self.query.session + + @classmethod + def _factory(cls, lookup, synchronize_session, *arg): + try: + klass = lookup[synchronize_session] + except KeyError: + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are %s" % (", ".join(sorted(repr(x) + for x in lookup)))) + else: + return klass(*arg) + + def exec_(self): + self._do_pre() + self._do_pre_synchronize() + self._do_exec() + self._do_post_synchronize() + self._do_post() + + def _do_pre(self): + query = self.query + self.context = context = query._compile_context() + if len(context.statement.froms) != 1 or \ + not isinstance(context.statement.froms[0], schema.Table): + + self.primary_table = query._only_entity_zero( + "This operation requires only one Table or " + "entity be specified as the target." + ).mapper.local_table + else: + self.primary_table = context.statement.froms[0] + + session = query.session + + if query._autoflush: + session._autoflush() + + def _do_pre_synchronize(self): + pass + + def _do_post_synchronize(self): + pass + + +class BulkEvaluate(BulkUD): + """BulkUD which does the 'evaluate' method of session state resolution.""" + + def _additional_evaluators(self, evaluator_compiler): + pass + + def _do_pre_synchronize(self): + query = self.query + try: + evaluator_compiler = evaluator.EvaluatorCompiler() + if query.whereclause is not None: + eval_condition = evaluator_compiler.process( + query.whereclause) + else: + def eval_condition(obj): + return True + + self._additional_evaluators(evaluator_compiler) + + except evaluator.UnevaluatableError: + raise sa_exc.InvalidRequestError( + "Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the " + "synchronize_session parameter.") + target_cls = query._mapper_zero().class_ + + #TODO: detect when the where clause is a trivial primary key match + self.matched_objects = [ + obj for (cls, pk), obj in + query.session.identity_map.items() + if issubclass(cls, target_cls) and + eval_condition(obj)] + + +class BulkFetch(BulkUD): + """BulkUD which does the 'fetch' method of session state resolution.""" + + def _do_pre_synchronize(self): + query = self.query + session = query.session + select_stmt = self.context.statement.with_only_columns( + self.primary_table.primary_key) + self.matched_rows = session.execute( + select_stmt, + params=query._params).fetchall() + + +class BulkUpdate(BulkUD): + """BulkUD which handles UPDATEs.""" + + def __init__(self, query, values): + super(BulkUpdate, self).__init__(query) + self.query._no_select_modifiers("update") + self.values = values + + @classmethod + def factory(cls, query, synchronize_session, values): + return BulkUD._factory({ + "evaluate": BulkUpdateEvaluate, + "fetch": BulkUpdateFetch, + False: BulkUpdate + }, synchronize_session, query, values) + + def _do_exec(self): + update_stmt = sql.update(self.primary_table, + self.context.whereclause, self.values) + + self.result = self.query.session.execute( + update_stmt, params=self.query._params) + self.rowcount = self.result.rowcount + + def _do_post(self): + session = self.query.session + session.dispatch.after_bulk_update(self) + + +class BulkDelete(BulkUD): + """BulkUD which handles DELETEs.""" + + def __init__(self, query): + super(BulkDelete, self).__init__(query) + self.query._no_select_modifiers("delete") + + @classmethod + def factory(cls, query, synchronize_session): + return BulkUD._factory({ + "evaluate": BulkDeleteEvaluate, + "fetch": BulkDeleteFetch, + False: BulkDelete + }, synchronize_session, query) + + def _do_exec(self): + delete_stmt = sql.delete(self.primary_table, + self.context.whereclause) + + self.result = self.query.session.execute(delete_stmt, + params=self.query._params) + self.rowcount = self.result.rowcount + + def _do_post(self): + session = self.query.session + session.dispatch.after_bulk_delete(self) + + +class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): + """BulkUD which handles UPDATEs using the "evaluate" + method of session resolution.""" + + def _additional_evaluators(self, evaluator_compiler): + self.value_evaluators = {} + for key, value in self.values.items(): + key = _attr_as_key(key) + self.value_evaluators[key] = evaluator_compiler.process( + expression._literal_as_binds(value)) + + def _do_post_synchronize(self): + session = self.query.session + states = set() + evaluated_keys = list(self.value_evaluators.keys()) + for obj in self.matched_objects: + state, dict_ = attributes.instance_state(obj),\ + attributes.instance_dict(obj) + + # only evaluate unmodified attributes + to_evaluate = state.unmodified.intersection( + evaluated_keys) + for key in to_evaluate: + dict_[key] = self.value_evaluators[key](obj) + + state._commit(dict_, list(to_evaluate)) + + # expire attributes with pending changes + # (there was no autoflush, so they are overwritten) + state._expire_attributes(dict_, + set(evaluated_keys). + difference(to_evaluate)) + states.add(state) + session._register_altered(states) + + +class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): + """BulkUD which handles DELETEs using the "evaluate" + method of session resolution.""" + + def _do_post_synchronize(self): + self.query.session._remove_newly_deleted( + [attributes.instance_state(obj) + for obj in self.matched_objects]) + + +class BulkUpdateFetch(BulkFetch, BulkUpdate): + """BulkUD which handles UPDATEs using the "fetch" + method of session resolution.""" + + def _do_post_synchronize(self): + session = self.query.session + target_mapper = self.query._mapper_zero() + + states = set([ + attributes.instance_state(session.identity_map[identity_key]) + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key)) + for primary_key in self.matched_rows + ] + if identity_key in session.identity_map + ]) + attrib = [_attr_as_key(k) for k in self.values] + for state in states: + session._expire_state(state, attrib) + session._register_altered(states) + + +class BulkDeleteFetch(BulkFetch, BulkDelete): + """BulkUD which handles DELETEs using the "fetch" + method of session resolution.""" + + def _do_post_synchronize(self): + session = self.query.session + target_mapper = self.query._mapper_zero() + for primary_key in self.matched_rows: + # TODO: inline this and call remove_newly_deleted + # once + identity_key = target_mapper.identity_key_from_primary_key( + list(primary_key)) + if identity_key in session.identity_map: + session._remove_newly_deleted( + [attributes.instance_state( + session.identity_map[identity_key] + )] + ) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py new file mode 100644 index 00000000..a0def7d3 --- /dev/null +++ b/lib/sqlalchemy/orm/properties.py @@ -0,0 +1,259 @@ +# orm/properties.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""MapperProperty implementations. + +This is a private module which defines the behavior of invidual ORM- +mapped attributes. + +""" +from __future__ import absolute_import + +from .. import util, log +from ..sql import expression +from . import attributes +from .util import _orm_full_deannotate + +from .interfaces import PropComparator, StrategizedProperty + +__all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty', + 'ComparableProperty', 'RelationshipProperty'] + + +@log.class_logger +class ColumnProperty(StrategizedProperty): + """Describes an object attribute that corresponds to a table column. + + Public constructor is the :func:`.orm.column_property` function. + + """ + + strategy_wildcard_key = 'column' + + def __init__(self, *columns, **kwargs): + """Provide a column-level property for use with a Mapper. + + Column-based properties can normally be applied to the mapper's + ``properties`` dictionary using the :class:`.Column` element directly. + Use this function when the given column is not directly present within the + mapper's selectable; examples include SQL expressions, functions, and + scalar SELECT queries. + + Columns that aren't present in the mapper's selectable won't be persisted + by the mapper and are effectively "read-only" attributes. + + :param \*cols: + list of Column objects to be mapped. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. Normally, history tracking logic for + simple non-primary-key scalar values only needs to be + aware of the "new" value in order to perform a flush. This + flag is available for applications that make use of + :func:`.attributes.get_history` or :meth:`.Session.is_modified` + which also need to know + the "previous" value of the attribute. + + .. versionadded:: 0.6.6 + + :param comparator_factory: a class which extends + :class:`.ColumnProperty.Comparator` which provides custom SQL clause + generation for comparison operations. + + :param group: + a group name for this property when marked as deferred. + + :param deferred: + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param expire_on_flush=True: + Disable expiry on flush. A column_property() which refers + to a SQL expression (and not a single table-bound column) + is considered to be a "read only" property; populating it + has no effect on the state of data, and it can only return + database state. For this reason a column_property()'s value + is expired whenever the parent object is involved in a + flush, that is, has any kind of "dirty" state within a flush. + Setting this parameter to ``False`` will have the effect of + leaving any existing value present after the flush proceeds. + Note however that the :class:`.Session` with default expiration + settings still expires + all attributes after a :meth:`.Session.commit` call, however. + + .. versionadded:: 0.7.3 + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + .. versionadded:: 0.8 + + :param extension: + an + :class:`.AttributeExtension` + instance, or list of extensions, which will be prepended + to the list of attribute listeners for the resulting + descriptor placed on the class. + **Deprecated.** Please see :class:`.AttributeEvents`. + + """ + self._orig_columns = [expression._labeled(c) for c in columns] + self.columns = [expression._labeled(_orm_full_deannotate(c)) + for c in columns] + self.group = kwargs.pop('group', None) + self.deferred = kwargs.pop('deferred', False) + self.instrument = kwargs.pop('_instrument', True) + self.comparator_factory = kwargs.pop('comparator_factory', + self.__class__.Comparator) + self.descriptor = kwargs.pop('descriptor', None) + self.extension = kwargs.pop('extension', None) + self.active_history = kwargs.pop('active_history', False) + self.expire_on_flush = kwargs.pop('expire_on_flush', True) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + if 'doc' in kwargs: + self.doc = kwargs.pop('doc') + else: + for col in reversed(self.columns): + doc = getattr(col, 'doc', None) + if doc is not None: + self.doc = doc + break + else: + self.doc = None + + if kwargs: + raise TypeError( + "%s received unexpected keyword argument(s): %s" % ( + self.__class__.__name__, + ', '.join(sorted(kwargs.keys())))) + + util.set_creation_order(self) + + self.strategy_class = self._strategy_lookup( + ("deferred", self.deferred), + ("instrument", self.instrument) + ) + + @property + def expression(self): + """Return the primary column or expression for this ColumnProperty. + + """ + return self.columns[0] + + def instrument_class(self, mapper): + if not self.instrument: + return + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc + ) + + def do_init(self): + super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and \ + set(self.parent.primary_key).issuperset(self.columns): + util.warn( + ("On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. " + "Use explicit properties to give each column its own mapped " + "attribute name.") % (self.parent, self.columns[1], + self.columns[0], self.key)) + + def copy(self): + return ColumnProperty( + deferred=self.deferred, + group=self.group, + active_history=self.active_history, + *self.columns) + + def _getcommitted(self, state, dict_, column, + passive=attributes.PASSIVE_OFF): + return state.get_impl(self.key).\ + get_committed_value(state, dict_, passive=passive) + + def merge(self, session, source_state, source_dict, dest_state, + dest_dict, load, _recursive): + if not self.instrument: + return + elif self.key in source_dict: + value = source_dict[self.key] + + if not load: + dest_dict[self.key] = value + else: + impl = dest_state.get_impl(self.key) + impl.set(dest_state, dest_dict, value, None) + elif dest_state.has_identity and self.key not in dest_dict: + dest_state._expire_attributes(dest_dict, [self.key]) + + class Comparator(PropComparator): + """Produce boolean, comparison, and other operators for + :class:`.ColumnProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview. + + See also: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + @util.memoized_instancemethod + def __clause_element__(self): + if self.adapter: + return self.adapter(self.prop.columns[0]) + else: + return self.prop.columns[0]._annotate({ + "parententity": self._parentmapper, + "parentmapper": self._parentmapper}) + + @util.memoized_property + def info(self): + ce = self.__clause_element__() + try: + return ce.info + except AttributeError: + return self.prop.info + + def __getattr__(self, key): + """proxy attribute access down to the mapped column. + + this allows user-defined comparison methods to be accessed. + """ + return getattr(self.__clause_element__(), key) + + def operate(self, op, *other, **kwargs): + return op(self.__clause_element__(), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py new file mode 100644 index 00000000..afcbf350 --- /dev/null +++ b/lib/sqlalchemy/orm/query.py @@ -0,0 +1,3564 @@ +# orm/query.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The Query class and support. + +Defines the :class:`.Query` class, the central +construct used by the ORM to construct database queries. + +The :class:`.Query` class should not be confused with the +:class:`.Select` class, which defines database +SELECT operations at the SQL (non-ORM) level. ``Query`` differs from +``Select`` in that it returns ORM-mapped objects and interacts with an +ORM session, whereas the ``Select`` construct interacts directly with the +database to return iterable result sets. + +""" + +from itertools import chain + +from . import ( + attributes, interfaces, object_mapper, persistence, + exc as orm_exc, loading + ) +from .base import _entity_descriptor, _is_aliased_class, \ + _is_mapped_class, _orm_columns, _generative +from .path_registry import PathRegistry +from .util import ( + AliasedClass, ORMAdapter, join as orm_join, with_parent, aliased + ) +from .. import sql, util, log, exc as sa_exc, inspect, inspection +from ..sql.expression import _interpret_as_from +from ..sql import ( + util as sql_util, + expression, visitors + ) +from ..sql.base import ColumnCollection +from . import properties + +__all__ = ['Query', 'QueryContext', 'aliased'] + + +_path_registry = PathRegistry.root + +@inspection._self_inspects +@log.class_logger +class Query(object): + """ORM-level SQL construction object. + + :class:`.Query` is the source of all SELECT statements generated by the + ORM, both those formulated by end-user query operations as well as by + high level internal operations such as related collection loading. It + features a generative interface whereby successive calls return a new + :class:`.Query` object, a copy of the former with additional + criteria and options associated with it. + + :class:`.Query` objects are normally initially generated using the + :meth:`~.Session.query` method of :class:`.Session`. For a full + walkthrough of :class:`.Query` usage, see the + :ref:`ormtutorial_toplevel`. + + """ + + _enable_eagerloads = True + _enable_assertions = True + _with_labels = False + _criterion = None + _yield_per = None + _order_by = False + _group_by = False + _having = None + _distinct = False + _prefixes = None + _offset = None + _limit = None + _for_update_arg = None + _statement = None + _correlate = frozenset() + _populate_existing = False + _invoke_all_eagers = True + _version_check = False + _autoflush = True + _only_load_props = None + _refresh_state = None + _from_obj = () + _join_entities = () + _select_from_entity = None + _mapper_adapter_map = {} + _filter_aliases = None + _from_obj_alias = None + _joinpath = _joinpoint = util.immutabledict() + _execution_options = util.immutabledict() + _params = util.immutabledict() + _attributes = util.immutabledict() + _with_options = () + _with_hints = () + _enable_single_crit = True + + _current_path = _path_registry + + def __init__(self, entities, session=None): + self.session = session + self._polymorphic_adapters = {} + self._set_entities(entities) + + def _set_entities(self, entities, entity_wrapper=None): + if entity_wrapper is None: + entity_wrapper = _QueryEntity + self._entities = [] + self._primary_entity = None + for ent in util.to_list(entities): + entity_wrapper(self, ent) + + self._set_entity_selectables(self._entities) + + def _set_entity_selectables(self, entities): + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() + + for ent in entities: + for entity in ent.entities: + if entity not in d: + ext_info = inspect(entity) + if not ext_info.is_aliased_class and \ + ext_info.mapper.with_polymorphic: + if ext_info.mapper.mapped_table not in \ + self._polymorphic_adapters: + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + ext_info.selectable, + ext_info.mapper._equivalent_columns + ) + ) + aliased_adapter = None + elif ext_info.is_aliased_class: + aliased_adapter = sql_util.ColumnAdapter( + ext_info.selectable, + ext_info.mapper._equivalent_columns + ) + else: + aliased_adapter = None + + d[entity] = ( + ext_info, + aliased_adapter + ) + ent.setup_entity(*d[entity]) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + def _set_select_from(self, obj, set_base_alias): + fa = [] + select_from_alias = None + + for from_obj in obj: + info = inspect(from_obj) + + if hasattr(info, 'mapper') and \ + (info.is_mapper or info.is_aliased_class): + if set_base_alias: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set.") + fa.append(info.selectable) + elif not info.is_selectable: + raise sa_exc.ArgumentError( + "argument is not a mapped class, mapper, " + "aliased(), or FromClause instance.") + else: + if isinstance(from_obj, expression.SelectBase): + from_obj = from_obj.alias() + if set_base_alias: + select_from_alias = from_obj + fa.append(from_obj) + + self._from_obj = tuple(fa) + + if set_base_alias and \ + len(self._from_obj) == 1 and \ + isinstance(select_from_alias, expression.Alias): + equivs = self.__all_equivs() + self._from_obj_alias = sql_util.ColumnAdapter( + self._from_obj[0], equivs) + + def _reset_polymorphic_adapter(self, mapper): + for m2 in mapper._with_polymorphic_mappers: + self._polymorphic_adapters.pop(m2, None) + for m in m2.iterate_to_root(): + self._polymorphic_adapters.pop(m.local_table, None) + + def _adapt_polymorphic_element(self, element): + if "parententity" in element._annotations: + search = element._annotations['parententity'] + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, 'table'): + search = element.table + else: + return None + + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def _adapt_col_list(self, cols): + return [ + self._adapt_clause( + expression._literal_as_text(o), + True, True) + for o in cols + ] + + @_generative() + def _adapt_all_clauses(self): + self._orm_only_adapt = False + + def _adapt_clause(self, clause, as_filter, orm_only): + """Adapt incoming clauses to transformations which + have been applied within this query.""" + + adapters = [] + # do we adapt all expression elements or only those + # tagged as 'ORM' constructs ? + orm_only = getattr(self, '_orm_only_adapt', orm_only) + + if as_filter and self._filter_aliases: + for fa in self._filter_aliases._visitor_iterator: + adapters.append( + ( + orm_only, fa.replace + ) + ) + + if self._from_obj_alias: + # for the "from obj" alias, apply extra rule to the + # 'ORM only' check, if this query were generated from a + # subquery of itself, i.e. _from_selectable(), apply adaption + # to all SQL constructs. + adapters.append( + ( + getattr(self, '_orm_only_from_obj_alias', orm_only), + self._from_obj_alias.replace + ) + ) + + if self._polymorphic_adapters: + adapters.append( + ( + orm_only, self._adapt_polymorphic_element + ) + ) + + if not adapters: + return clause + + def replace(elem): + for _orm_only, adapter in adapters: + # if 'orm only', look for ORM annotations + # in the element before adapting. + if not _orm_only or \ + '_orm_adapt' in elem._annotations or \ + "parententity" in elem._annotations: + + e = adapter(elem) + if e is not None: + return e + + return visitors.replacement_traverse( + clause, + {}, + replace + ) + + def _entity_zero(self): + return self._entities[0] + + def _mapper_zero(self): + return self._select_from_entity or \ + self._entity_zero().entity_zero + + @property + def _mapper_entities(self): + for ent in self._entities: + if isinstance(ent, _MapperEntity): + yield ent + + def _joinpoint_zero(self): + return self._joinpoint.get( + '_joinpoint_entity', + self._mapper_zero() + ) + + def _mapper_zero_or_none(self): + if self._primary_entity: + return self._primary_entity.mapper + else: + return None + + def _only_mapper_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale or + "This operation requires a Query " + "against a single mapper." + ) + return self._mapper_zero() + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale or + "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def __all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + + def _get_condition(self): + return self._no_criterion_condition("get", order_by=False, distinct=False) + + def _get_existing_condition(self): + self._no_criterion_assertion("get", order_by=False, distinct=False) + + def _no_criterion_assertion(self, meth, order_by=True, distinct=True): + if not self._enable_assertions: + return + if self._criterion is not None or \ + self._statement is not None or self._from_obj or \ + self._limit is not None or self._offset is not None or \ + self._group_by or (order_by and self._order_by) or \ + (distinct and self._distinct): + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth) + + def _no_criterion_condition(self, meth, order_by=True, distinct=True): + self._no_criterion_assertion(meth, order_by, distinct) + + self._from_obj = () + self._statement = self._criterion = None + self._order_by = self._group_by = self._distinct = False + + def _no_clauseelement_condition(self, meth): + if not self._enable_assertions: + return + if self._order_by: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth) + self._no_criterion_condition(meth) + + def _no_statement_condition(self, meth): + if not self._enable_assertions: + return + if self._statement is not None: + raise sa_exc.InvalidRequestError( + ("Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion.") % meth) + + def _no_limit_offset(self, meth): + if not self._enable_assertions: + return + if self._limit is not None or self._offset is not None: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a Query which already has LIMIT " + "or OFFSET applied. To modify the row-limited results of a " + " Query, call from_self() first. " + "Otherwise, call %s() before limit() or offset() " + "are applied." + % (meth, meth) + ) + + def _no_select_modifiers(self, meth): + if not self._enable_assertions: + return + for attr, methname, notset in ( + ('_limit', 'limit()', None), + ('_offset', 'offset()', None), + ('_order_by', 'order_by()', False), + ('_group_by', 'group_by()', False), + ('_distinct', 'distinct()', False), + ): + if getattr(self, attr) is not notset: + raise sa_exc.InvalidRequestError( + "Can't call Query.%s() when %s has been called" % + (meth, methname) + ) + + def _get_options(self, populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None): + if populate_existing: + self._populate_existing = populate_existing + if version_check: + self._version_check = version_check + if refresh_state: + self._refresh_state = refresh_state + if only_load_props: + self._only_load_props = set(only_load_props) + return self + + def _clone(self): + cls = self.__class__ + q = cls.__new__(cls) + q.__dict__ = self.__dict__.copy() + return q + + @property + def statement(self): + """The full SELECT statement represented by this Query. + + The statement by default will not have disambiguating labels + applied to the construct unless with_labels(True) is called + first. + + """ + + stmt = self._compile_context(labels=self._with_labels).\ + statement + if self._params: + stmt = stmt.params(self._params) + + + # TODO: there's no tests covering effects of + # the annotation not being there + return stmt._annotate({'no_replacement_traverse': True}) + + def subquery(self, name=None, with_labels=False, reduce_columns=False): + """return the full SELECT statement represented by + this :class:`.Query`, embedded within an :class:`.Alias`. + + Eager JOIN generation within the query is disabled. + + :param name: string name to be assigned as the alias; + this is passed through to :meth:`.FromClause.alias`. + If ``None``, a name will be deterministically generated + at compile time. + + :param with_labels: if True, :meth:`.with_labels` will be called + on the :class:`.Query` first to apply table-qualified labels + to all columns. + + :param reduce_columns: if True, :meth:`.Select.reduce_columns` will + be called on the resulting :func:`.select` construct, + to remove same-named columns where one also refers to the other + via foreign key or WHERE clause equivalence. + + .. versionchanged:: 0.8 the ``with_labels`` and ``reduce_columns`` + keyword arguments were added. + + """ + q = self.enable_eagerloads(False) + if with_labels: + q = q.with_labels() + q = q.statement + if reduce_columns: + q = q.reduce_columns() + return q.alias(name=name) + + def cte(self, name=None, recursive=False): + """Return the full SELECT statement represented by this + :class:`.Query` represented as a common table expression (CTE). + + .. versionadded:: 0.7.6 + + Parameters and usage are the same as those of the + :meth:`.SelectBase.cte` method; see that method for + further details. + + Here is the `Postgresql WITH + RECURSIVE example + `_. + Note that, in this example, the ``included_parts`` cte and the + ``incl_alias`` alias of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. The + ``parts_alias`` object is an :func:`.orm.aliased` instance of the + ``Part`` entity, so column-mapped attributes are available + directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\\ + filter(Part.part=="our part").\\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.part, + parts_alias.sub_part, + parts_alias.quantity).\\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ).\\ + group_by(included_parts.c.sub_part) + + .. seealso:: + + :meth:`.SelectBase.cte` + + """ + return self.enable_eagerloads(False).\ + statement.cte(name=name, recursive=recursive) + + def label(self, name): + """Return the full SELECT statement represented by this + :class:`.Query`, converted + to a scalar subquery with a label of the given name. + + Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.label`. + + .. versionadded:: 0.6.5 + + """ + + return self.enable_eagerloads(False).statement.label(name) + + def as_scalar(self): + """Return the full SELECT statement represented by this + :class:`.Query`, converted to a scalar subquery. + + Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.as_scalar`. + + .. versionadded:: 0.6.5 + + """ + + return self.enable_eagerloads(False).statement.as_scalar() + + @property + def selectable(self): + """Return the :class:`.Select` object emitted by this :class:`.Query`. + + Used for :func:`.inspect` compatibility, this is equivalent to:: + + query.enable_eagerloads(False).with_labels().statement + + """ + return self.__clause_element__() + + def __clause_element__(self): + return self.enable_eagerloads(False).with_labels().statement + + @_generative() + def enable_eagerloads(self, value): + """Control whether or not eager joins and subqueries are + rendered. + + When set to False, the returned Query will not render + eager joins regardless of :func:`~sqlalchemy.orm.joinedload`, + :func:`~sqlalchemy.orm.subqueryload` options + or mapper-level ``lazy='joined'``/``lazy='subquery'`` + configurations. + + This is used primarily when nesting the Query's + statement into a subquery or other + selectable. + + """ + self._enable_eagerloads = value + + @_generative() + def with_labels(self): + """Apply column labels to the return value of Query.statement. + + Indicates that this Query's `statement` accessor should return + a SELECT statement that applies labels to all columns in the + form _; this is commonly used to + disambiguate columns from multiple tables which have the same + name. + + When the `Query` actually issues SQL to load rows, it always + uses column labeling. + + """ + self._with_labels = True + + @_generative() + def enable_assertions(self, value): + """Control whether assertions are generated. + + When set to False, the returned Query will + not assert its state before certain operations, + including that LIMIT/OFFSET has not been applied + when filter() is called, no criterion exists + when get() is called, and no "from_statement()" + exists when filter()/order_by()/group_by() etc. + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or + other modifiers outside of the usual usage patterns. + + Care should be taken to ensure that the usage + pattern is even possible. A statement applied + by from_statement() will override any criterion + set by filter() or order_by(), for example. + + """ + self._enable_assertions = value + + @property + def whereclause(self): + """A readonly attribute which returns the current WHERE criterion for + this Query. + + This returned value is a SQL expression construct, or ``None`` if no + criterion has been established. + + """ + return self._criterion + + @_generative() + def _with_current_path(self, path): + """indicate that this query applies to objects loaded + within a certain path. + + Used by deferred loaders (see strategies.py) which transfer + query options from an originating query to a newly generated + query intended for the deferred load. + + """ + self._current_path = path + + @_generative(_no_clauseelement_condition) + def with_polymorphic(self, + cls_or_mappers, + selectable=None, + polymorphic_on=None): + """Load columns for inheriting classes. + + :meth:`.Query.with_polymorphic` applies transformations + to the "main" mapped class represented by this :class:`.Query`. + The "main" mapped class here means the :class:`.Query` + object's first argument is a full class, i.e. + ``session.query(SomeClass)``. These transformations allow additional + tables to be present in the FROM clause so that columns for a + joined-inheritance subclass are available in the query, both for the + purposes of load-time efficiency as well as the ability to use + these columns at query time. + + See the documentation section :ref:`with_polymorphic` for + details on how this method is used. + + .. versionchanged:: 0.8 + A new and more flexible function + :func:`.orm.with_polymorphic` supersedes + :meth:`.Query.with_polymorphic`, as it can apply the equivalent + functionality to any set of columns or classes in the + :class:`.Query`, not just the "zero mapper". See that + function for a description of arguments. + + """ + + if not self._primary_entity: + raise sa_exc.InvalidRequestError( + "No primary mapper set up for this Query.") + entity = self._entities[0]._clone() + self._entities = [entity] + self._entities[1:] + entity.set_with_polymorphic(self, + cls_or_mappers, + selectable=selectable, + polymorphic_on=polymorphic_on) + + @_generative() + def yield_per(self, count): + """Yield only ``count`` rows at a time. + + WARNING: use this method with caution; if the same instance is present + in more than one batch of rows, end-user changes to attributes will be + overwritten. + + In particular, it's usually impossible to use this setting with + eagerly loaded collections (i.e. any lazy='joined' or 'subquery') + since those collections will be cleared for a new load when + encountered in a subsequent result batch. In the case of 'subquery' + loading, the full result for all rows is fetched which generally + defeats the purpose of :meth:`~sqlalchemy.orm.query.Query.yield_per`. + + Also note that while :meth:`~sqlalchemy.orm.query.Query.yield_per` + will set the ``stream_results`` execution option to True, currently + this is only understood by :mod:`~sqlalchemy.dialects.postgresql.psycopg2` dialect + which will stream results using server side cursors instead of pre-buffer + all rows for this query. Other DBAPIs pre-buffer all rows before + making them available. + + """ + self._yield_per = count + self._execution_options = self._execution_options.union( + {"stream_results": True}) + + def get(self, ident): + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + E.g.:: + + my_user = session.query(User).get(5) + + some_object = session.query(VersionedFoo).get((5, 10)) + + :meth:`~.Query.get` is special in that it provides direct + access to the identity map of the owning :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`~.Query.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :meth:`~.Query.get` is only used to return a single + mapped instance, not multiple instances or + individual column constructs, and strictly + on a single primary key value. The originating + :class:`.Query` must be constructed in this way, + i.e. against a single mapped entity, + with no additional filtering criterion. Loading + options via :meth:`~.Query.options` may be applied + however, and will be used if the object is not + yet locally present. + + A lazy-loading, many-to-one attribute configured + by :func:`.relationship`, using a simple + foreign-key-to-primary-key criterion, will also use an + operation equivalent to :meth:`~.Query.get` in order to retrieve + the target value from the local identity map + before querying the database. See :doc:`/orm/loading` + for further details on relationship loading. + + :param ident: A scalar or tuple value representing + the primary key. For a composite primary key, + the order of identifiers corresponds in most cases + to that of the mapped :class:`.Table` object's + primary key columns. For a :func:`.mapper` that + was given the ``primary key`` argument during + construction, the order of identifiers corresponds + to the elements present in this collection. + + :return: The object instance, or ``None``. + + """ + + # convert composite types to individual args + if hasattr(ident, '__composite_values__'): + ident = ident.__composite_values__() + + ident = util.to_list(ident) + + mapper = self._only_full_mapper_zero("get") + + if len(ident) != len(mapper.primary_key): + raise sa_exc.InvalidRequestError( + "Incorrect number of values in identifier to formulate " + "primary key for query.get(); primary key columns are %s" % + ','.join("'%s'" % c for c in mapper.primary_key)) + + key = mapper.identity_key_from_primary_key(ident) + + if not self._populate_existing and \ + not mapper.always_refresh and \ + self._for_update_arg is None: + + instance = loading.get_from_identity( + self.session, key, attributes.PASSIVE_OFF) + if instance is not None: + self._get_existing_condition() + # reject calls for id in identity map but class + # mismatch. + if not issubclass(instance.__class__, mapper.class_): + return None + return instance + + return loading.load_on_ident(self, key) + + @_generative() + def correlate(self, *args): + """Return a :class:`.Query` construct which will correlate the given + FROM clauses to that of an enclosing :class:`.Query` or + :func:`~.expression.select`. + + The method here accepts mapped classes, :func:`.aliased` constructs, + and :func:`.mapper` constructs as arguments, which are resolved into + expression constructs, in addition to appropriate expression + constructs. + + The correlation arguments are ultimately passed to + :meth:`.Select.correlate` after coercion to expression constructs. + + The correlation arguments take effect in such cases + as when :meth:`.Query.from_self` is used, or when + a subquery as returned by :meth:`.Query.subquery` is + embedded in another :func:`~.expression.select` construct. + + """ + + self._correlate = self._correlate.union( + _interpret_as_from(s) + if s is not None else None + for s in args) + + @_generative() + def autoflush(self, setting): + """Return a Query with a specific 'autoflush' setting. + + Note that a Session with autoflush=False will + not autoflush, even if this flag is set to True at the + Query level. Therefore this flag is usually used only + to disable autoflush for a specific Query. + + """ + self._autoflush = setting + + @_generative() + def populate_existing(self): + """Return a :class:`.Query` that will expire and refresh all instances + as they are loaded, or reused from the current :class:`.Session`. + + :meth:`.populate_existing` does not improve behavior when + the ORM is used normally - the :class:`.Session` object's usual + behavior of maintaining a transaction and expiring all attributes + after rollback or commit handles object state automatically. + This method is not intended for general use. + + """ + self._populate_existing = True + + @_generative() + def _with_invoke_all_eagers(self, value): + """Set the 'invoke all eagers' flag which causes joined- and + subquery loaders to traverse into already-loaded related objects + and collections. + + Default is that of :attr:`.Query._invoke_all_eagers`. + + """ + self._invoke_all_eagers = value + + def with_parent(self, instance, property=None): + """Add filtering criterion that relates the given instance + to a child object or collection, using its attribute state + as well as an established :func:`.relationship()` + configuration. + + The method uses the :func:`.with_parent` function to generate + the clause, the result of which is passed to :meth:`.Query.filter`. + + Parameters are the same as :func:`.with_parent`, with the exception + that the given property can be None, in which case a search is + performed against this :class:`.Query` object's target mapper. + + """ + + if property is None: + mapper = object_mapper(instance) + + for prop in mapper.iterate_properties: + if isinstance(prop, properties.RelationshipProperty) and \ + prop.mapper is self._mapper_zero(): + property = prop + break + else: + raise sa_exc.InvalidRequestError( + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" % + ( + self._mapper_zero().class_.__name__, + instance.__class__.__name__) + ) + + return self.filter(with_parent(instance, property)) + + @_generative() + def add_entity(self, entity, alias=None): + """add a mapped entity to the list of result columns + to be returned.""" + + if alias is not None: + entity = aliased(entity, alias) + + self._entities = list(self._entities) + m = _MapperEntity(self, entity) + self._set_entity_selectables([m]) + + @_generative() + def with_session(self, session): + """Return a :class:`.Query` that will use the given :class:`.Session`. + + """ + + self.session = session + + def from_self(self, *entities): + """return a Query that selects from this Query's + SELECT statement. + + \*entities - optional list of entities which will replace + those being selected. + + """ + fromclause = self.with_labels().enable_eagerloads(False).\ + _enable_single_crit(False).\ + statement.correlate(None) + q = self._from_selectable(fromclause) + if entities: + q._set_entities(entities) + return q + + @_generative() + def _enable_single_crit(self, val): + self._enable_single_crit = val + + @_generative() + def _from_selectable(self, fromclause): + for attr in ( + '_statement', '_criterion', + '_order_by', '_group_by', + '_limit', '_offset', + '_joinpath', '_joinpoint', + '_distinct', '_having', + '_prefixes', + ): + self.__dict__.pop(attr, None) + self._set_select_from([fromclause], True) + + # this enables clause adaptation for non-ORM + # expressions. + self._orm_only_from_obj_alias = False + + old_entities = self._entities + self._entities = [] + for e in old_entities: + e.adapt_to_selectable(self, self._from_obj[0]) + + def values(self, *columns): + """Return an iterator yielding result tuples corresponding + to the given list of columns""" + + if not columns: + return iter(()) + q = self._clone() + q._set_entities(columns, entity_wrapper=_ColumnEntity) + if not q._yield_per: + q._yield_per = 10 + return iter(q) + _values = values + + def value(self, column): + """Return a scalar result corresponding to the given + column expression.""" + try: + return next(self.values(column))[0] + except StopIteration: + return None + + @_generative() + def with_entities(self, *entities): + """Return a new :class:`.Query` replacing the SELECT list with the + given entities. + + e.g.:: + + # Users, filtered on some arbitrary criterion + # and then ordered by related email address + q = session.query(User).\\ + join(User.address).\\ + filter(User.name.like('%ed%')).\\ + order_by(Address.email) + + # given *only* User.id==5, Address.email, and 'q', what + # would the *next* User in the result be ? + subq = q.with_entities(Address.email).\\ + order_by(None).\\ + filter(User.id==5).\\ + subquery() + q = q.join((subq, subq.c.email < Address.email)).\\ + limit(1) + + .. versionadded:: 0.6.5 + + """ + self._set_entities(entities) + + @_generative() + def add_columns(self, *column): + """Add one or more column expressions to the list + of result columns to be returned.""" + + self._entities = list(self._entities) + l = len(self._entities) + for c in column: + _ColumnEntity(self, c) + # _ColumnEntity may add many entities if the + # given arg is a FROM clause + self._set_entity_selectables(self._entities[l:]) + + @util.pending_deprecation("0.7", + ":meth:`.add_column` is superseded by :meth:`.add_columns`", + False) + def add_column(self, column): + """Add a column expression to the list of result columns to be + returned. + + Pending deprecation: :meth:`.add_column` will be superseded by + :meth:`.add_columns`. + + """ + return self.add_columns(column) + + def options(self, *args): + """Return a new Query object, applying the given list of + mapper options. + + Most supplied options regard changing how column- and + relationship-mapped attributes are loaded. See the sections + :ref:`deferred` and :doc:`/orm/loading` for reference + documentation. + + """ + return self._options(False, *args) + + def _conditional_options(self, *args): + return self._options(True, *args) + + @_generative() + def _options(self, conditional, *args): + # most MapperOptions write to the '_attributes' dictionary, + # so copy that as well + self._attributes = self._attributes.copy() + opts = tuple(util.flatten_iterator(args)) + self._with_options = self._with_options + opts + if conditional: + for opt in opts: + opt.process_query_conditionally(self) + else: + for opt in opts: + opt.process_query(self) + + def with_transformation(self, fn): + """Return a new :class:`.Query` object transformed by + the given function. + + E.g.:: + + def filter_something(criterion): + def transform(q): + return q.filter(criterion) + return transform + + q = q.with_transformation(filter_something(x==5)) + + This allows ad-hoc recipes to be created for :class:`.Query` + objects. See the example at :ref:`hybrid_transformers`. + + .. versionadded:: 0.7.4 + + """ + return fn(self) + + @_generative() + def with_hint(self, selectable, text, dialect_name='*'): + """Add an indexing hint for the given entity or selectable to + this :class:`.Query`. + + Functionality is passed straight through to + :meth:`~sqlalchemy.sql.expression.Select.with_hint`, + with the addition that ``selectable`` can be a + :class:`.Table`, :class:`.Alias`, or ORM entity / mapped class + /etc. + """ + selectable = inspect(selectable).selectable + + self._with_hints += ((selectable, text, dialect_name),) + + @_generative() + def execution_options(self, **kwargs): + """ Set non-SQL options which take effect during execution. + + The options are the same as those accepted by + :meth:`.Connection.execution_options`. + + Note that the ``stream_results`` execution option is enabled + automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()` + method is used. + + """ + self._execution_options = self._execution_options.union(kwargs) + + @_generative() + def with_lockmode(self, mode): + """Return a new :class:`.Query` object with the specified "locking mode", + which essentially refers to the ``FOR UPDATE`` clause. + + .. deprecated:: 0.9.0 superseded by :meth:`.Query.with_for_update`. + + :param mode: a string representing the desired locking mode. + Valid values are: + + * ``None`` - translates to no lockmode + + * ``'update'`` - translates to ``FOR UPDATE`` + (standard SQL, supported by most dialects) + + * ``'update_nowait'`` - translates to ``FOR UPDATE NOWAIT`` + (supported by Oracle, PostgreSQL 8.1 upwards) + + * ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL), + and ``FOR SHARE`` (for PostgreSQL) + + .. seealso:: + + :meth:`.Query.with_for_update` - improved API for + specifying the ``FOR UPDATE`` clause. + + """ + self._for_update_arg = LockmodeArg.parse_legacy_query(mode) + + @_generative() + def with_for_update(self, read=False, nowait=False, of=None): + """return a new :class:`.Query` with the specified options for the + ``FOR UPDATE`` clause. + + The behavior of this method is identical to that of + :meth:`.SelectBase.with_for_update`. When called with no arguments, + the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause + appended. When additional arguments are specified, backend-specific + options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE`` + can take effect. + + E.g.:: + + q = sess.query(User).with_for_update(nowait=True, of=User) + + The above query on a Postgresql backend will render like:: + + SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT + + .. versionadded:: 0.9.0 :meth:`.Query.with_for_update` supersedes + the :meth:`.Query.with_lockmode` method. + + .. seealso:: + + :meth:`.GenerativeSelect.with_for_update` - Core level method with + full argument and behavioral description. + + """ + self._for_update_arg = LockmodeArg(read=read, nowait=nowait, of=of) + + @_generative() + def params(self, *args, **kwargs): + """add values for bind parameters which may have been + specified in filter(). + + parameters may be specified using \**kwargs, or optionally a single + dictionary as the first positional argument. The reason for both is + that \**kwargs is convenient, however some parameter dictionaries + contain unicode keys in which case \**kwargs cannot be used. + + """ + if len(args) == 1: + kwargs.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary.") + self._params = self._params.copy() + self._params.update(kwargs) + + @_generative(_no_statement_condition, _no_limit_offset) + def filter(self, *criterion): + """apply the given filtering criterion to a copy + of this :class:`.Query`, using SQL expressions. + + e.g.:: + + session.query(MyClass).filter(MyClass.name == 'some name') + + Multiple criteria are joined together by AND:: + + session.query(MyClass).\\ + filter(MyClass.name == 'some name', MyClass.id > 5) + + The criterion is any SQL expression object applicable to the + WHERE clause of a select. String expressions are coerced + into SQL expression constructs via the :func:`.text` construct. + + .. versionchanged:: 0.7.5 + Multiple criteria joined by AND. + + .. seealso:: + + :meth:`.Query.filter_by` - filter on keyword expressions. + + """ + for criterion in list(criterion): + criterion = expression._literal_as_text(criterion) + + criterion = self._adapt_clause(criterion, True, True) + + if self._criterion is not None: + self._criterion = self._criterion & criterion + else: + self._criterion = criterion + + def filter_by(self, **kwargs): + """apply the given filtering criterion to a copy + of this :class:`.Query`, using keyword expressions. + + e.g.:: + + session.query(MyClass).filter_by(name = 'some name') + + Multiple criteria are joined together by AND:: + + session.query(MyClass).\\ + filter_by(name = 'some name', id = 5) + + The keyword expressions are extracted from the primary + entity of the query, or the last entity that was the + target of a call to :meth:`.Query.join`. + + .. seealso:: + + :meth:`.Query.filter` - filter on SQL expressions. + + """ + + clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value + for key, value in kwargs.items()] + return self.filter(sql.and_(*clauses)) + + @_generative(_no_statement_condition, _no_limit_offset) + def order_by(self, *criterion): + """apply one or more ORDER BY criterion to the query and return + the newly resulting ``Query`` + + All existing ORDER BY settings can be suppressed by + passing ``None`` - this will suppress any ORDER BY configured + on mappers as well. + + Alternatively, an existing ORDER BY setting on the Query + object can be entirely cancelled by passing ``False`` + as the value - use this before calling methods where + an ORDER BY is invalid. + + """ + + if len(criterion) == 1: + if criterion[0] is False: + if '_order_by' in self.__dict__: + del self._order_by + return + if criterion[0] is None: + self._order_by = None + return + + criterion = self._adapt_col_list(criterion) + + if self._order_by is False or self._order_by is None: + self._order_by = criterion + else: + self._order_by = self._order_by + criterion + + @_generative(_no_statement_condition, _no_limit_offset) + def group_by(self, *criterion): + """apply one or more GROUP BY criterion to the query and return + the newly resulting :class:`.Query`""" + + criterion = list(chain(*[_orm_columns(c) for c in criterion])) + criterion = self._adapt_col_list(criterion) + + if self._group_by is False: + self._group_by = criterion + else: + self._group_by = self._group_by + criterion + + @_generative(_no_statement_condition, _no_limit_offset) + def having(self, criterion): + """apply a HAVING criterion to the query and return the + newly resulting :class:`.Query`. + + :meth:`~.Query.having` is used in conjunction with :meth:`~.Query.group_by`. + + HAVING criterion makes it possible to use filters on aggregate + functions like COUNT, SUM, AVG, MAX, and MIN, eg.:: + + q = session.query(User.id).\\ + join(User.addresses).\\ + group_by(User.id).\\ + having(func.count(Address.id) > 2) + + """ + + if isinstance(criterion, util.string_types): + criterion = sql.text(criterion) + + if criterion is not None and \ + not isinstance(criterion, sql.ClauseElement): + raise sa_exc.ArgumentError( + "having() argument must be of type " + "sqlalchemy.sql.ClauseElement or string") + + criterion = self._adapt_clause(criterion, True, True) + + if self._having is not None: + self._having = self._having & criterion + else: + self._having = criterion + + def union(self, *q): + """Produce a UNION of this Query against one or more queries. + + e.g.:: + + q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') + q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') + + q3 = q1.union(q2) + + The method accepts multiple Query objects so as to control + the level of nesting. A series of ``union()`` calls such as:: + + x.union(y).union(z).all() + + will nest on each ``union()``, and produces:: + + SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION + SELECT * FROM y) UNION SELECT * FROM Z) + + Whereas:: + + x.union(y, z).all() + + produces:: + + SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION + SELECT * FROM Z) + + Note that many database backends do not allow ORDER BY to + be rendered on a query called within UNION, EXCEPT, etc. + To disable all ORDER BY clauses including those configured + on mappers, issue ``query.order_by(None)`` - the resulting + :class:`.Query` object will not render ORDER BY within + its SELECT statement. + + """ + + return self._from_selectable( + expression.union(*([self] + list(q)))) + + def union_all(self, *q): + """Produce a UNION ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + """ + return self._from_selectable( + expression.union_all(*([self] + list(q))) + ) + + def intersect(self, *q): + """Produce an INTERSECT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + """ + return self._from_selectable( + expression.intersect(*([self] + list(q))) + ) + + def intersect_all(self, *q): + """Produce an INTERSECT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + """ + return self._from_selectable( + expression.intersect_all(*([self] + list(q))) + ) + + def except_(self, *q): + """Produce an EXCEPT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + """ + return self._from_selectable( + expression.except_(*([self] + list(q))) + ) + + def except_all(self, *q): + """Produce an EXCEPT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + """ + return self._from_selectable( + expression.except_all(*([self] + list(q))) + ) + + def join(self, *props, **kwargs): + """Create a SQL JOIN against this :class:`.Query` object's criterion + and apply generatively, returning the newly resulting :class:`.Query`. + + **Simple Relationship Joins** + + Consider a mapping between two classes ``User`` and ``Address``, + with a relationship ``User.addresses`` representing a collection + of ``Address`` objects associated with each ``User``. The most common + usage of :meth:`~.Query.join` is to create a JOIN along this + relationship, using the ``User.addresses`` attribute as an indicator + for how this should occur:: + + q = session.query(User).join(User.addresses) + + Where above, the call to :meth:`~.Query.join` along ``User.addresses`` + will result in SQL equivalent to:: + + SELECT user.* FROM user JOIN address ON user.id = address.user_id + + In the above example we refer to ``User.addresses`` as passed to + :meth:`~.Query.join` as the *on clause*, that is, it indicates + how the "ON" portion of the JOIN should be constructed. For a + single-entity query such as the one above (i.e. we start by selecting + only from ``User`` and nothing else), the relationship can also be + specified by its string name:: + + q = session.query(User).join("addresses") + + :meth:`~.Query.join` can also accommodate multiple + "on clause" arguments to produce a chain of joins, such as below + where a join across four related entities is constructed:: + + q = session.query(User).join("orders", "items", "keywords") + + The above would be shorthand for three separate calls to + :meth:`~.Query.join`, each using an explicit attribute to indicate + the source entity:: + + q = session.query(User).\\ + join(User.orders).\\ + join(Order.items).\\ + join(Item.keywords) + + **Joins to a Target Entity or Selectable** + + A second form of :meth:`~.Query.join` allows any mapped entity + or core selectable construct as a target. In this usage, + :meth:`~.Query.join` will attempt + to create a JOIN along the natural foreign key relationship between + two entities:: + + q = session.query(User).join(Address) + + The above calling form of :meth:`~.Query.join` will raise an error if + either there are no foreign keys between the two entities, or if + there are multiple foreign key linkages between them. In the + above calling form, :meth:`~.Query.join` is called upon to + create the "on clause" automatically for us. The target can + be any mapped entity or selectable, such as a :class:`.Table`:: + + q = session.query(User).join(addresses_table) + + **Joins to a Target with an ON Clause** + + The third calling form allows both the target entity as well + as the ON clause to be passed explicitly. Suppose for + example we wanted to join to ``Address`` twice, using + an alias the second time. We use :func:`~sqlalchemy.orm.aliased` + to create a distinct alias of ``Address``, and join + to it using the ``target, onclause`` form, so that the + alias can be specified explicitly as the target along with + the relationship to instruct how the ON clause should proceed:: + + a_alias = aliased(Address) + + q = session.query(User).\\ + join(User.addresses).\\ + join(a_alias, User.addresses).\\ + filter(Address.email_address=='ed@foo.com').\\ + filter(a_alias.email_address=='ed@bar.com') + + Where above, the generated SQL would be similar to:: + + SELECT user.* FROM user + JOIN address ON user.id = address.user_id + JOIN address AS address_1 ON user.id=address_1.user_id + WHERE address.email_address = :email_address_1 + AND address_1.email_address = :email_address_2 + + The two-argument calling form of :meth:`~.Query.join` + also allows us to construct arbitrary joins with SQL-oriented + "on clause" expressions, not relying upon configured relationships + at all. Any SQL expression can be passed as the ON clause + when using the two-argument form, which should refer to the target + entity in some way as well as an applicable source entity:: + + q = session.query(User).join(Address, User.id==Address.user_id) + + .. versionchanged:: 0.7 + In SQLAlchemy 0.6 and earlier, the two argument form of + :meth:`~.Query.join` requires the usage of a tuple: + ``query(User).join((Address, User.id==Address.user_id))``\ . + This calling form is accepted in 0.7 and further, though + is not necessary unless multiple join conditions are passed to + a single :meth:`~.Query.join` call, which itself is also not + generally necessary as it is now equivalent to multiple + calls (this wasn't always the case). + + **Advanced Join Targeting and Adaption** + + There is a lot of flexibility in what the "target" can be when using + :meth:`~.Query.join`. As noted previously, it also accepts + :class:`.Table` constructs and other selectables such as + :func:`.alias` and :func:`.select` constructs, with either the one + or two-argument forms:: + + addresses_q = select([Address.user_id]).\\ + where(Address.email_address.endswith("@bar.com")).\\ + alias() + + q = session.query(User).\\ + join(addresses_q, addresses_q.c.user_id==User.id) + + :meth:`~.Query.join` also features the ability to *adapt* a + :meth:`~sqlalchemy.orm.relationship` -driven ON clause to the target + selectable. Below we construct a JOIN from ``User`` to a subquery + against ``Address``, allowing the relationship denoted by + ``User.addresses`` to *adapt* itself to the altered target:: + + address_subq = session.query(Address).\\ + filter(Address.email_address == 'ed@foo.com').\\ + subquery() + + q = session.query(User).join(address_subq, User.addresses) + + Producing SQL similar to:: + + SELECT user.* FROM user + JOIN ( + SELECT address.id AS id, + address.user_id AS user_id, + address.email_address AS email_address + FROM address + WHERE address.email_address = :email_address_1 + ) AS anon_1 ON user.id = anon_1.user_id + + The above form allows one to fall back onto an explicit ON + clause at any time:: + + q = session.query(User).\\ + join(address_subq, User.id==address_subq.c.user_id) + + **Controlling what to Join From** + + While :meth:`~.Query.join` exclusively deals with the "right" + side of the JOIN, we can also control the "left" side, in those + cases where it's needed, using :meth:`~.Query.select_from`. + Below we construct a query against ``Address`` but can still + make usage of ``User.addresses`` as our ON clause by instructing + the :class:`.Query` to select first from the ``User`` + entity:: + + q = session.query(Address).select_from(User).\\ + join(User.addresses).\\ + filter(User.name == 'ed') + + Which will produce SQL similar to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + **Constructing Aliases Anonymously** + + :meth:`~.Query.join` can construct anonymous aliases + using the ``aliased=True`` flag. This feature is useful + when a query is being joined algorithmically, such as + when querying self-referentially to an arbitrary depth:: + + q = session.query(Node).\\ + join("children", "children", aliased=True) + + When ``aliased=True`` is used, the actual "alias" construct + is not explicitly available. To work with it, methods such as + :meth:`.Query.filter` will adapt the incoming entity to + the last join point:: + + q = session.query(Node).\\ + join("children", "children", aliased=True).\\ + filter(Node.name == 'grandchild 1') + + When using automatic aliasing, the ``from_joinpoint=True`` + argument can allow a multi-node join to be broken into + multiple calls to :meth:`~.Query.join`, so that + each path along the way can be further filtered:: + + q = session.query(Node).\\ + join("children", aliased=True).\\ + filter(Node.name='child 1').\\ + join("children", aliased=True, from_joinpoint=True).\\ + filter(Node.name == 'grandchild 1') + + The filtering aliases above can then be reset back to the + original ``Node`` entity using :meth:`~.Query.reset_joinpoint`:: + + q = session.query(Node).\\ + join("children", "children", aliased=True).\\ + filter(Node.name == 'grandchild 1').\\ + reset_joinpoint().\\ + filter(Node.name == 'parent 1) + + For an example of ``aliased=True``, see the distribution + example :ref:`examples_xmlpersistence` which illustrates + an XPath-like query system using algorithmic joins. + + :param \*props: A collection of one or more join conditions, + each consisting of a relationship-bound attribute or string + relationship name representing an "on clause", or a single + target entity, or a tuple in the form of ``(target, onclause)``. + A special two-argument calling form of the form ``target, onclause`` + is also accepted. + :param aliased=False: If True, indicate that the JOIN target should be + anonymously aliased. Subsequent calls to :meth:`~.Query.filter` + and similar will adapt the incoming criterion to the target + alias, until :meth:`~.Query.reset_joinpoint` is called. + :param from_joinpoint=False: When using ``aliased=True``, a setting + of True here will cause the join to be from the most recent + joined target, rather than starting back from the original + FROM clauses of the query. + + .. seealso:: + + :ref:`ormtutorial_joins` in the ORM tutorial. + + :ref:`inheritance_toplevel` for details on how :meth:`~.Query.join` + is used for inheritance relationships. + + :func:`.orm.join` - a standalone ORM-level join function, + used internally by :meth:`.Query.join`, which in previous + SQLAlchemy versions was the primary ORM-level joining interface. + + """ + aliased, from_joinpoint = kwargs.pop('aliased', False),\ + kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % + ','.join(kwargs.keys)) + return self._join(props, + outerjoin=False, create_aliases=aliased, + from_joinpoint=from_joinpoint) + + def outerjoin(self, *props, **kwargs): + """Create a left outer join against this ``Query`` object's criterion + and apply generatively, returning the newly resulting ``Query``. + + Usage is the same as the ``join()`` method. + + """ + aliased, from_joinpoint = kwargs.pop('aliased', False), \ + kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % + ','.join(kwargs)) + return self._join(props, + outerjoin=True, create_aliases=aliased, + from_joinpoint=from_joinpoint) + + def _update_joinpoint(self, jp): + self._joinpoint = jp + # copy backwards to the root of the _joinpath + # dict, so that no existing dict in the path is mutated + while 'prev' in jp: + f, prev = jp['prev'] + prev = prev.copy() + prev[f] = jp + jp['prev'] = (f, prev) + jp = prev + self._joinpath = jp + + @_generative(_no_statement_condition, _no_limit_offset) + def _join(self, keys, outerjoin, create_aliases, from_joinpoint): + """consumes arguments from join() or outerjoin(), places them into a + consistent format with which to form the actual JOIN constructs. + + """ + + if not from_joinpoint: + self._reset_joinpoint() + + if len(keys) == 2 and \ + isinstance(keys[0], (expression.FromClause, + type, AliasedClass)) and \ + isinstance(keys[1], (str, expression.ClauseElement, + interfaces.PropComparator)): + # detect 2-arg form of join and + # convert to a tuple. + keys = (keys,) + + for arg1 in util.to_list(keys): + if isinstance(arg1, tuple): + # "tuple" form of join, multiple + # tuples are accepted as well. The simpler + # "2-arg" form is preferred. May deprecate + # the "tuple" usage. + arg1, arg2 = arg1 + else: + arg2 = None + + # determine onclause/right_entity. there + # is a little bit of legacy behavior still at work here + # which means they might be in either order. may possibly + # lock this down to (right_entity, onclause) in 0.6. + if isinstance(arg1, (interfaces.PropComparator, util.string_types)): + right_entity, onclause = arg2, arg1 + else: + right_entity, onclause = arg1, arg2 + + left_entity = prop = None + + if isinstance(onclause, util.string_types): + left_entity = self._joinpoint_zero() + + descriptor = _entity_descriptor(left_entity, onclause) + onclause = descriptor + + # check for q.join(Class.propname, from_joinpoint=True) + # and Class is that of the current joinpoint + elif from_joinpoint and \ + isinstance(onclause, interfaces.PropComparator): + left_entity = onclause._parententity + + info = inspect(self._joinpoint_zero()) + left_mapper, left_selectable, left_is_aliased = \ + getattr(info, 'mapper', None), \ + info.selectable, \ + getattr(info, 'is_aliased_class', None) + + if left_mapper is left_entity: + left_entity = self._joinpoint_zero() + descriptor = _entity_descriptor(left_entity, + onclause.key) + onclause = descriptor + + if isinstance(onclause, interfaces.PropComparator): + if right_entity is None: + right_entity = onclause.property.mapper + of_type = getattr(onclause, '_of_type', None) + if of_type: + right_entity = of_type + else: + right_entity = onclause.property.mapper + + left_entity = onclause._parententity + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + if not create_aliases: + # check for this path already present. + # don't render in that case. + edge = (left_entity, right_entity, prop.key) + if edge in self._joinpoint: + # The child's prev reference might be stale -- + # it could point to a parent older than the + # current joinpoint. If this is the case, + # then we need to update it and then fix the + # tree's spine with _update_joinpoint. Copy + # and then mutate the child, which might be + # shared by a different query object. + jp = self._joinpoint[edge].copy() + jp['prev'] = (edge, self._joinpoint) + self._update_joinpoint(jp) + continue + + elif onclause is not None and right_entity is None: + # TODO: no coverage here + raise NotImplementedError("query.join(a==b) not supported.") + + self._join_left_to_right( + left_entity, + right_entity, onclause, + outerjoin, create_aliases, prop) + + + def _join_left_to_right(self, left, right, + onclause, outerjoin, create_aliases, prop): + """append a JOIN to the query's from clause.""" + + self._polymorphic_adapters = self._polymorphic_adapters.copy() + + if left is None: + if self._from_obj: + left = self._from_obj[0] + elif self._entities: + left = self._entities[0].entity_zero_or_selectable + + if left is None: + raise sa_exc.InvalidRequestError( + "Don't know how to join from %s; please use " + "select_from() to establish the left " + "entity/selectable of this join" % self._entities[0]) + + if left is right and \ + not create_aliases: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % + (left, right)) + + l_info = inspect(left) + r_info = inspect(right) + + + overlap = False + if not create_aliases: + right_mapper = getattr(r_info, "mapper", None) + # if the target is a joined inheritance mapping, + # be more liberal about auto-aliasing. + if right_mapper and ( + right_mapper.with_polymorphic or + isinstance(right_mapper.mapped_table, expression.Join) + ): + for from_obj in self._from_obj or [l_info.selectable]: + if sql_util.selectables_overlap(l_info.selectable, from_obj) and \ + sql_util.selectables_overlap(from_obj, r_info.selectable): + overlap = True + break + elif sql_util.selectables_overlap(l_info.selectable, r_info.selectable): + overlap = True + + + if overlap and l_info.selectable is r_info.selectable: + raise sa_exc.InvalidRequestError( + "Can't join table/selectable '%s' to itself" % + l_info.selectable) + + right, onclause = self._prepare_right_side( + r_info, right, onclause, + create_aliases, + prop, overlap) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity': right, + 'prev': ((left, right, prop.key), self._joinpoint) + }) + else: + self._joinpoint = {'_joinpoint_entity': right} + + self._join_to_left(l_info, left, right, onclause, outerjoin) + + def _prepare_right_side(self, r_info, right, onclause, create_aliases, + prop, overlap): + info = r_info + + right_mapper, right_selectable, right_is_aliased = \ + getattr(info, 'mapper', None), \ + info.selectable, \ + getattr(info, 'is_aliased_class', False) + + if right_mapper: + self._join_entities += (info, ) + + if right_mapper and prop and \ + not right_mapper.common_parent(prop.mapper): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) + ) + + if not right_mapper and prop: + right_mapper = prop.mapper + + need_adapter = False + + if right_mapper and right is right_selectable: + if not right_selectable.is_derived_from( + right_mapper.mapped_table): + raise sa_exc.InvalidRequestError( + "Selectable '%s' is not derived from '%s'" % + (right_selectable.description, + right_mapper.mapped_table.description)) + + if isinstance(right_selectable, expression.SelectBase): + # TODO: this isn't even covered now! + right_selectable = right_selectable.alias() + need_adapter = True + + right = aliased(right_mapper, right_selectable) + + aliased_entity = right_mapper and \ + not right_is_aliased and \ + ( + right_mapper.with_polymorphic and isinstance( + right_mapper._with_polymorphic_selectable, + expression.Alias) + or + overlap # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + ) + + if not need_adapter and (create_aliases or aliased_entity): + right = aliased(right, flat=True) + need_adapter = True + + # if an alias() of the right side was generated here, + # apply an adapter to all subsequent filter() calls + # until reset_joinpoint() is called. + if need_adapter: + self._filter_aliases = ORMAdapter(right, + equivalents=right_mapper and + right_mapper._equivalent_columns or {}, + chain_to=self._filter_aliases) + + # if the onclause is a ClauseElement, adapt it with any + # adapters that are in place right now + if isinstance(onclause, expression.ClauseElement): + onclause = self._adapt_clause(onclause, True, True) + + # if an alias() on the right side was generated, + # which is intended to wrap a the right side in a subquery, + # ensure that columns retrieved from this target in the result + # set are also adapted. + if aliased_entity and not create_aliases: + self._mapper_loads_polymorphically_with( + right_mapper, + ORMAdapter( + right, + equivalents=right_mapper._equivalent_columns + ) + ) + + return right, onclause + + def _join_to_left(self, l_info, left, right, onclause, outerjoin): + info = l_info + left_mapper = getattr(info, 'mapper', None) + left_selectable = info.selectable + + if self._from_obj: + replace_clause_index, clause = sql_util.find_join_source( + self._from_obj, + left_selectable) + if clause is not None: + try: + clause = orm_join(clause, + right, + onclause, isouter=outerjoin) + except sa_exc.ArgumentError as ae: + raise sa_exc.InvalidRequestError( + "Could not find a FROM clause to join from. " + "Tried joining to %s, but got: %s" % (right, ae)) + + self._from_obj = \ + self._from_obj[:replace_clause_index] + \ + (clause, ) + \ + self._from_obj[replace_clause_index + 1:] + return + + if left_mapper: + for ent in self._entities: + if ent.corresponds_to(left): + clause = ent.selectable + break + else: + clause = left + else: + clause = left_selectable + + assert clause is not None + try: + clause = orm_join(clause, right, onclause, isouter=outerjoin) + except sa_exc.ArgumentError as ae: + raise sa_exc.InvalidRequestError( + "Could not find a FROM clause to join from. " + "Tried joining to %s, but got: %s" % (right, ae)) + self._from_obj = self._from_obj + (clause,) + + def _reset_joinpoint(self): + self._joinpoint = self._joinpath + self._filter_aliases = None + + @_generative(_no_statement_condition) + def reset_joinpoint(self): + """Return a new :class:`.Query`, where the "join point" has + been reset back to the base FROM entities of the query. + + This method is usually used in conjunction with the + ``aliased=True`` feature of the :meth:`~.Query.join` + method. See the example in :meth:`~.Query.join` for how + this is used. + + """ + self._reset_joinpoint() + + @_generative(_no_clauseelement_condition) + def select_from(self, *from_obj): + """Set the FROM clause of this :class:`.Query` explicitly. + + :meth:`.Query.select_from` is often used in conjunction with + :meth:`.Query.join` in order to control which entity is selected + from on the "left" side of the join. + + The entity or selectable object here effectively replaces the + "left edge" of any calls to :meth:`~.Query.join`, when no + joinpoint is otherwise established - usually, the default "join + point" is the leftmost entity in the :class:`~.Query` object's + list of entities to be selected. + + A typical example:: + + q = session.query(Address).select_from(User).\\ + join(User.addresses).\\ + filter(User.name == 'ed') + + Which produces SQL equivalent to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + :param \*from_obj: collection of one or more entities to apply + to the FROM clause. Entities can be mapped classes, + :class:`.AliasedClass` objects, :class:`.Mapper` objects + as well as core :class:`.FromClause` elements like subqueries. + + .. versionchanged:: 0.9 + This method no longer applies the given FROM object + to be the selectable from which matching entities + select from; the :meth:`.select_entity_from` method + now accomplishes this. See that method for a description + of this behavior. + + .. seealso:: + + :meth:`~.Query.join` + + :meth:`.Query.select_entity_from` + + """ + + self._set_select_from(from_obj, False) + + @_generative(_no_clauseelement_condition) + def select_entity_from(self, from_obj): + """Set the FROM clause of this :class:`.Query` to a + core selectable, applying it as a replacement FROM clause + for corresponding mapped entities. + + This method is similar to the :meth:`.Query.select_from` + method, in that it sets the FROM clause of the query. However, + where :meth:`.Query.select_from` only affects what is placed + in the FROM, this method also applies the given selectable + to replace the FROM which the selected entities would normally + select from. + + The given ``from_obj`` must be an instance of a :class:`.FromClause`, + e.g. a :func:`.select` or :class:`.Alias` construct. + + An example would be a :class:`.Query` that selects ``User`` entities, + but uses :meth:`.Query.select_entity_from` to have the entities + selected from a :func:`.select` construct instead of the + base ``user`` table:: + + select_stmt = select([User]).where(User.id == 7) + + q = session.query(User).\\ + select_entity_from(select_stmt).\\ + filter(User.name == 'ed') + + The query generated will select ``User`` entities directly + from the given :func:`.select` construct, and will be:: + + SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name + FROM (SELECT "user".id AS id, "user".name AS name + FROM "user" + WHERE "user".id = :id_1) AS anon_1 + WHERE anon_1.name = :name_1 + + Notice above that even the WHERE criterion was "adapted" such that + the ``anon_1`` subquery effectively replaces all references to the + ``user`` table, except for the one that it refers to internally. + + Compare this to :meth:`.Query.select_from`, which as of + version 0.9, does not affect existing entities. The + statement below:: + + q = session.query(User).\\ + select_from(select_stmt).\\ + filter(User.name == 'ed') + + Produces SQL where both the ``user`` table as well as the + ``select_stmt`` construct are present as separate elements + in the FROM clause. No "adaptation" of the ``user`` table + is applied:: + + SELECT "user".id AS user_id, "user".name AS user_name + FROM "user", (SELECT "user".id AS id, "user".name AS name + FROM "user" + WHERE "user".id = :id_1) AS anon_1 + WHERE "user".name = :name_1 + + :meth:`.Query.select_entity_from` maintains an older + behavior of :meth:`.Query.select_from`. In modern usage, + similar results can also be achieved using :func:`.aliased`:: + + select_stmt = select([User]).where(User.id == 7) + user_from_select = aliased(User, select_stmt.alias()) + + q = session.query(user_from_select) + + :param from_obj: a :class:`.FromClause` object that will replace + the FROM clause of this :class:`.Query`. + + .. seealso:: + + :meth:`.Query.select_from` + + .. versionadded:: 0.8 + :meth:`.Query.select_entity_from` was added to specify + the specific behavior of entity replacement, however + the :meth:`.Query.select_from` maintains this behavior + as well until 0.9. + + """ + + self._set_select_from([from_obj], True) + + def __getitem__(self, item): + if isinstance(item, slice): + start, stop, step = util.decode_slice(item) + + if isinstance(stop, int) and \ + isinstance(start, int) and \ + stop - start <= 0: + return [] + + # perhaps we should execute a count() here so that we + # can still use LIMIT/OFFSET ? + elif (isinstance(start, int) and start < 0) \ + or (isinstance(stop, int) and stop < 0): + return list(self)[item] + + res = self.slice(start, stop) + if step is not None: + return list(res)[None:None:item.step] + else: + return list(res) + else: + if item == -1: + return list(self)[-1] + else: + return list(self[item:item + 1])[0] + + @_generative(_no_statement_condition) + def slice(self, start, stop): + """apply LIMIT/OFFSET to the ``Query`` based on a " + "range and return the newly resulting ``Query``.""" + + if start is not None and stop is not None: + self._offset = (self._offset or 0) + start + self._limit = stop - start + elif start is None and stop is not None: + self._limit = stop + elif start is not None and stop is None: + self._offset = (self._offset or 0) + start + + if self._offset == 0: + self._offset = None + + @_generative(_no_statement_condition) + def limit(self, limit): + """Apply a ``LIMIT`` to the query and return the newly resulting + + ``Query``. + + """ + self._limit = limit + + @_generative(_no_statement_condition) + def offset(self, offset): + """Apply an ``OFFSET`` to the query and return the newly resulting + ``Query``. + + """ + self._offset = offset + + @_generative(_no_statement_condition) + def distinct(self, *criterion): + """Apply a ``DISTINCT`` to the query and return the newly resulting + ``Query``. + + :param \*expr: optional column expressions. When present, + the Postgresql dialect will render a ``DISTINCT ON (>)`` + construct. + + """ + if not criterion: + self._distinct = True + else: + criterion = self._adapt_col_list(criterion) + if isinstance(self._distinct, list): + self._distinct += criterion + else: + self._distinct = criterion + + @_generative() + def prefix_with(self, *prefixes): + """Apply the prefixes to the query and return the newly resulting + ``Query``. + + :param \*prefixes: optional prefixes, typically strings, + not using any commas. In particular is useful for MySQL keywords. + + e.g.:: + + query = sess.query(User.name).\\ + prefix_with('HIGH_PRIORITY').\\ + prefix_with('SQL_SMALL_RESULT', 'ALL') + + Would render:: + + SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL users.name AS users_name + FROM users + + .. versionadded:: 0.7.7 + + """ + if self._prefixes: + self._prefixes += prefixes + else: + self._prefixes = prefixes + + def all(self): + """Return the results represented by this ``Query`` as a list. + + This results in an execution of the underlying query. + + """ + return list(self) + + @_generative(_no_clauseelement_condition) + def from_statement(self, statement): + """Execute the given SELECT statement and return results. + + This method bypasses all internal statement compilation, and the + statement is executed without modification. + + The statement argument is either a string, a ``select()`` construct, + or a ``text()`` construct, and should return the set of columns + appropriate to the entity class represented by this ``Query``. + + """ + if isinstance(statement, util.string_types): + statement = sql.text(statement) + + if not isinstance(statement, + (expression.TextClause, + expression.SelectBase)): + raise sa_exc.ArgumentError( + "from_statement accepts text(), select(), " + "and union() objects only.") + + self._statement = statement + + def first(self): + """Return the first result of this ``Query`` or + None if the result doesn't contain any row. + + first() applies a limit of one within the generated SQL, so that + only one primary entity row is generated on the server side + (note this may consist of multiple result rows if join-loaded + collections are present). + + Calling ``first()`` results in an execution of the underlying query. + + """ + if self._statement is not None: + ret = list(self)[0:1] + else: + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] + else: + return None + + def one(self): + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that does not return object + identities. + + Note that an entity query, that is, one which selects one or + more mapped classes as opposed to individual column attributes, + may ultimately represent many rows but only one row of + unique entity or entities - this is a successful result for one(). + + Calling ``one()`` results in an execution of the underlying query. + + .. versionchanged:: 0.6 + ``one()`` fully fetches all results instead of applying + any kind of limit, so that the "unique"-ing of entities does not + conceal multiple object identities. + + """ + ret = list(self) + + l = len(ret) + if l == 1: + return ret[0] + elif l == 0: + raise orm_exc.NoResultFound("No row was found for one()") + else: + raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one()") + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + >>> session.query(Item).scalar() + + >>> session.query(Item.id).scalar() + 1 + >>> session.query(Item.id).filter(Item.id < 0).scalar() + None + >>> session.query(Item.id, Item.name).scalar() + 1 + >>> session.query(func.count(Parent.id)).scalar() + 20 + + This results in an execution of the underlying query. + + """ + try: + ret = self.one() + if not isinstance(ret, tuple): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def __iter__(self): + context = self._compile_context() + context.statement.use_labels = True + if self._autoflush and not self._populate_existing: + self.session._autoflush() + return self._execute_and_instances(context) + + def _connection_from_session(self, **kw): + conn = self.session.connection( + **kw) + if self._execution_options: + conn = conn.execution_options(**self._execution_options) + return conn + + def _execute_and_instances(self, querycontext): + conn = self._connection_from_session( + mapper=self._mapper_zero_or_none(), + clause=querycontext.statement, + close_with_result=True) + + result = conn.execute(querycontext.statement, self._params) + return loading.instances(self, result, querycontext) + + @property + def column_descriptions(self): + """Return metadata about the columns which would be + returned by this :class:`.Query`. + + Format is a list of dictionaries:: + + user_alias = aliased(User, name='user2') + q = sess.query(User, User.id, user_alias) + + # this expression: + q.column_descriptions + + # would return: + [ + { + 'name':'User', + 'type':User, + 'aliased':False, + 'expr':User, + }, + { + 'name':'id', + 'type':Integer(), + 'aliased':False, + 'expr':User.id, + }, + { + 'name':'user2', + 'type':User, + 'aliased':True, + 'expr':user_alias + } + ] + + """ + return [ + { + 'name': ent._label_name, + 'type': ent.type, + 'aliased': getattr(ent, 'is_aliased_class', False), + 'expr': ent.expr + } + for ent in self._entities + ] + + def instances(self, cursor, __context=None): + """Given a ResultProxy cursor as returned by connection.execute(), + return an ORM result as an iterator. + + e.g.:: + + result = engine.execute("select * from users") + for u in session.query(User).instances(result): + print u + """ + context = __context + if context is None: + context = QueryContext(self) + + return loading.instances(self, cursor, context) + + def merge_result(self, iterator, load=True): + """Merge a result into this :class:`.Query` object's Session. + + Given an iterator returned by a :class:`.Query` of the same structure + as this one, return an identical iterator of results, with all mapped + instances merged into the session using :meth:`.Session.merge`. This + is an optimized method which will merge all mapped instances, + preserving the structure of the result rows and unmapped columns with + less method overhead than that of calling :meth:`.Session.merge` + explicitly for each value. + + The structure of the results is determined based on the column list of + this :class:`.Query` - if these do not correspond, unchecked errors + will occur. + + The 'load' argument is the same as that of :meth:`.Session.merge`. + + For an example of how :meth:`~.Query.merge_result` is used, see + the source code for the example :ref:`examples_caching`, where + :meth:`~.Query.merge_result` is used to efficiently restore state + from a cache back into a target :class:`.Session`. + + """ + + return loading.merge_result(self, iterator, load) + + @property + def _select_args(self): + return { + 'limit': self._limit, + 'offset': self._offset, + 'distinct': self._distinct, + 'prefixes': self._prefixes, + 'group_by': self._group_by or None, + 'having': self._having + } + + @property + def _should_nest_selectable(self): + kwargs = self._select_args + return (kwargs.get('limit') is not None or + kwargs.get('offset') is not None or + kwargs.get('distinct', False)) + + def exists(self): + """A convenience method that turns a query into an EXISTS subquery + of the form EXISTS (SELECT 1 FROM ... WHERE ...). + + e.g.:: + + q = session.query(User).filter(User.name == 'fred') + session.query(q.exists()) + + Producing SQL similar to:: + + SELECT EXISTS ( + SELECT 1 FROM users WHERE users.name = :name_1 + ) AS anon_1 + + .. versionadded:: 0.8.1 + + """ + + # .add_columns() for the case that we are a query().select_from(X), + # so that ".statement" can be produced (#2995) but also without + # omitting the FROM clause from a query(X) (#2818); + # .with_only_columns() after we have a core select() so that + # we get just "SELECT 1" without any entities. + return sql.exists(self.add_columns('1').with_labels(). + statement.with_only_columns(['1'])) + + def count(self): + """Return a count of rows this Query would return. + + This generates the SQL for this Query as follows:: + + SELECT count(1) AS count_1 FROM ( + SELECT + ) AS anon_1 + + .. versionchanged:: 0.7 + The above scheme is newly refined as of 0.7b3. + + For fine grained control over specific columns + to count, to skip the usage of a subquery or + otherwise control of the FROM clause, + or to use other aggregate functions, + use :attr:`~sqlalchemy.sql.expression.func` + expressions in conjunction + with :meth:`~.Session.query`, i.e.:: + + from sqlalchemy import func + + # count User records, without + # using a subquery. + session.query(func.count(User.id)) + + # return count of user "id" grouped + # by "name" + session.query(func.count(User.id)).\\ + group_by(User.name) + + from sqlalchemy import distinct + + # count distinct "name" values + session.query(func.count(distinct(User.name))) + + """ + col = sql.func.count(sql.literal_column('*')) + return self.from_self(col).scalar() + + def delete(self, synchronize_session='evaluate'): + """Perform a bulk delete query. + + Deletes rows matched by this query from the database. + + :param synchronize_session: chooses the strategy for the removal of + matched objects from the session. Valid values are: + + ``False`` - don't synchronize the session. This option is the most + efficient and is reliable once the session is expired, which + typically occurs after a commit(), or explicitly using + expire_all(). Before the expiration, objects may still remain in + the session which were in fact deleted which can lead to confusing + results if they are accessed via get() or already loaded + collections. + + ``'fetch'`` - performs a select query before the delete to find + objects that are matched by the delete query and need to be + removed from the session. Matched objects are removed from the + session. + + ``'evaluate'`` - Evaluate the query's criteria in Python straight + on the objects in the session. If evaluation of the criteria isn't + implemented, an error is raised. In that case you probably + want to use the 'fetch' strategy as a fallback. + + The expression evaluator currently doesn't account for differing + string collations between the database and Python. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + This method has several key caveats: + + * The method does **not** offer in-Python cascading of relationships - it + is assumed that ON DELETE CASCADE/SET NULL/etc. is configured for any foreign key + references which require it, otherwise the database may emit an + integrity violation if foreign key references are being enforced. + + After the DELETE, dependent objects in the :class:`.Session` which + were impacted by an ON DELETE may not contain the current + state, or may have been deleted. This issue is resolved once the + :class:`.Session` is expired, + which normally occurs upon :meth:`.Session.commit` or can be forced + by using :meth:`.Session.expire_all`. Accessing an expired object + whose row has been deleted will invoke a SELECT to locate the + row; when the row is not found, an :class:`~sqlalchemy.orm.exc.ObjectDeletedError` + is raised. + + * The :meth:`.MapperEvents.before_delete` and + :meth:`.MapperEvents.after_delete` + events are **not** invoked from this method. Instead, the + :meth:`.SessionEvents.after_bulk_delete` method is provided to act + upon a mass DELETE of entity rows. + + .. seealso:: + + :meth:`.Query.update` + + :ref:`inserts_and_updates` - Core SQL tutorial + + """ + #TODO: cascades need handling. + + delete_op = persistence.BulkDelete.factory( + self, synchronize_session) + delete_op.exec_() + return delete_op.rowcount + + def update(self, values, synchronize_session='evaluate'): + """Perform a bulk update query. + + Updates rows matched by this query in the database. + + :param values: a dictionary with attributes names as keys and literal + values or sql expressions as values. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. Valid values are: + + ``False`` - don't synchronize the session. This option is the most + efficient and is reliable once the session is expired, which + typically occurs after a commit(), or explicitly using + expire_all(). Before the expiration, updated objects may still + remain in the session with stale values on their attributes, which + can lead to confusing results. + + ``'fetch'`` - performs a select query before the update to find + objects that are matched by the update query. The updated + attributes are expired on matched objects. + + ``'evaluate'`` - Evaluate the Query's criteria in Python straight + on the objects in the session. If evaluation of the criteria isn't + implemented, an exception is raised. + + The expression evaluator currently doesn't account for differing + string collations between the database and Python. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + This method has several key caveats: + + * The method does **not** offer in-Python cascading of relationships - it + is assumed that ON UPDATE CASCADE is configured for any foreign key + references which require it, otherwise the database may emit an + integrity violation if foreign key references are being enforced. + + After the UPDATE, dependent objects in the :class:`.Session` which + were impacted by an ON UPDATE CASCADE may not contain the current + state; this issue is resolved once the :class:`.Session` is expired, + which normally occurs upon :meth:`.Session.commit` or can be forced + by using :meth:`.Session.expire_all`. + + * As of 0.8, this method will support multiple table updates, as detailed + in :ref:`multi_table_updates`, and this behavior does extend to support + updates of joined-inheritance and other multiple table mappings. However, + the **join condition of an inheritance mapper is currently not + automatically rendered**. + Care must be taken in any multiple-table update to explicitly include + the joining condition between those tables, even in mappings where + this is normally automatic. + E.g. if a class ``Engineer`` subclasses ``Employee``, an UPDATE of the + ``Engineer`` local table using criteria against the ``Employee`` + local table might look like:: + + session.query(Engineer).\\ + filter(Engineer.id == Employee.id).\\ + filter(Employee.name == 'dilbert').\\ + update({"engineer_type": "programmer"}) + + * The :meth:`.MapperEvents.before_update` and + :meth:`.MapperEvents.after_update` + events are **not** invoked from this method. Instead, the + :meth:`.SessionEvents.after_bulk_update` method is provided to act + upon a mass UPDATE of entity rows. + + .. seealso:: + + :meth:`.Query.delete` + + :ref:`inserts_and_updates` - Core SQL tutorial + + """ + + #TODO: value keys need to be mapped to corresponding sql cols and + # instr.attr.s to string keys + #TODO: updates of manytoone relationships need to be converted to + # fk assignments + #TODO: cascades need handling. + + update_op = persistence.BulkUpdate.factory( + self, synchronize_session, values) + update_op.exec_() + return update_op.rowcount + + + def _compile_context(self, labels=True): + context = QueryContext(self) + + if context.statement is not None: + return context + + context.labels = labels + + context._for_update_arg = self._for_update_arg + + for entity in self._entities: + entity.setup_context(self, context) + + for rec in context.create_eager_joins: + strategy = rec[0] + strategy(*rec[1:]) + + if context.from_clause: + # "load from explicit FROMs" mode, + # i.e. when select_from() or join() is used + context.froms = list(context.from_clause) + else: + # "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM + context.froms = context.froms + + if self._enable_single_crit: + self._adjust_for_single_inheritance(context) + + if not context.primary_columns: + if self._only_load_props: + raise sa_exc.InvalidRequestError( + "No column-based properties specified for " + "refresh operation. Use session.expire() " + "to reload collections and related items.") + else: + raise sa_exc.InvalidRequestError( + "Query contains no columns with which to " + "SELECT from.") + + if context.multi_row_eager_loaders and self._should_nest_selectable: + context.statement = self._compound_eager_statement(context) + else: + context.statement = self._simple_statement(context) + return context + + def _compound_eager_statement(self, context): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.unwrap_order_by(o) + for o in context.order_by + ]) + ) + else: + context.order_by = None + order_by_col_expr = [] + + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + # TODO: this order_by is only needed if + # LIMIT/OFFSET is present in self._select_args, + # else the application on the outside is enough + order_by=context.order_by, + **self._select_args + ) + + for hint in self._with_hints: + inner = inner.with_hint(*hint) + + if self._correlate: + inner = inner.correlate(*self._correlate) + + inner = inner.alias() + + equivs = self.__all_equivs() + + context.adapter = sql_util.ColumnAdapter(inner, equivs) + + statement = sql.select( + [inner] + context.secondary_columns, + use_labels=context.labels) + + statement._for_update_arg = context._for_update_arg + + from_clause = inner + for eager_join in context.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, + eager_join, eager_join.stop_on) + + statement.append_from(from_clause) + + if context.order_by: + statement.append_order_by( + *context.adapter.copy_and_process( + context.order_by + ) + ) + + statement.append_order_by(*context.eager_order_by) + return statement + + def _simple_statement(self, context): + if not context.order_by: + context.order_by = None + + if self._distinct and context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.unwrap_order_by(o) + for o in context.order_by + ]) + ) + context.primary_columns += order_by_col_expr + + context.froms += tuple(context.eager_joins.values()) + + statement = sql.select( + context.primary_columns + + context.secondary_columns, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + order_by=context.order_by, + **self._select_args + ) + + statement._for_update_arg = context._for_update_arg + + for hint in self._with_hints: + statement = statement.with_hint(*hint) + + if self._correlate: + statement = statement.correlate(*self._correlate) + + if context.eager_order_by: + statement.append_order_by(*context.eager_order_by) + return statement + + def _adjust_for_single_inheritance(self, context): + """Apply single-table-inheritance filtering. + + For all distinct single-table-inheritance mappers represented in + the columns clause of this query, add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. + + """ + for (ext_info, adapter) in self._mapper_adapter_map.values(): + if ext_info in self._join_entities: + continue + single_crit = ext_info.mapper._single_table_criterion + if single_crit is not None: + if adapter: + single_crit = adapter.traverse(single_crit) + single_crit = self._adapt_clause(single_crit, False, False) + context.whereclause = sql.and_( + sql.True_._ifnone(context.whereclause), + single_crit) + + def __str__(self): + return str(self._compile_context().statement) + +from ..sql.selectable import ForUpdateArg + +class LockmodeArg(ForUpdateArg): + @classmethod + def parse_legacy_query(self, mode): + if mode in (None, False): + return None + + if mode == "read": + read = True + nowait = False + elif mode == "update": + read = nowait = False + elif mode == "update_nowait": + nowait = True + read = False + else: + raise sa_exc.ArgumentError( + "Unknown with_lockmode argument: %r" % mode) + + return LockmodeArg(read=read, nowait=nowait) + +class _QueryEntity(object): + """represent an entity column returned within a Query result.""" + + def __new__(cls, *args, **kwargs): + if cls is _QueryEntity: + entity = args[1] + if not isinstance(entity, util.string_types) and \ + _is_mapped_class(entity): + cls = _MapperEntity + elif isinstance(entity, Bundle): + cls = _BundleEntity + else: + cls = _ColumnEntity + return object.__new__(cls) + + def _clone(self): + q = self.__class__.__new__(self.__class__) + q.__dict__ = self.__dict__.copy() + return q + + +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" + + def __init__(self, query, entity): + if not query._primary_entity: + query._primary_entity = self + query._entities.append(self) + + self.entities = [entity] + self.expr = entity + + supports_single_entity = True + + def setup_entity(self, ext_info, aliased_adapter): + self.mapper = ext_info.mapper + self.aliased_adapter = aliased_adapter + self.selectable = ext_info.selectable + self.is_aliased_class = ext_info.is_aliased_class + self._with_polymorphic = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = \ + ext_info.polymorphic_on + self.entity_zero = ext_info + if ext_info.is_aliased_class: + self._label_name = self.entity_zero.name + else: + self._label_name = self.mapper.class_.__name__ + self.path = self.entity_zero._path_registry + self.custom_rows = bool(self.mapper.dispatch.append_result) + + def set_with_polymorphic(self, query, cls_or_mappers, + selectable, polymorphic_on): + """Receive an update from a call to query.with_polymorphic(). + + Note the newer style of using a free standing with_polymporphic() + construct doesn't make use of this method. + + + """ + if self.is_aliased_class: + # TODO: invalidrequest ? + raise NotImplementedError( + "Can't use with_polymorphic() against " + "an Aliased object" + ) + + if cls_or_mappers is None: + query._reset_polymorphic_adapter(self.mapper) + return + + mappers, from_obj = self.mapper._with_polymorphic_args( + cls_or_mappers, selectable) + self._with_polymorphic = mappers + self._polymorphic_discriminator = polymorphic_on + + self.selectable = from_obj + query._mapper_loads_polymorphically_with(self.mapper, + sql_util.ColumnAdapter(from_obj, + self.mapper._equivalent_columns)) + + filter_fn = id + + @property + def type(self): + return self.mapper.class_ + + @property + def entity_zero_or_selectable(self): + return self.entity_zero + + def corresponds_to(self, entity): + if entity.is_aliased_class: + if self.is_aliased_class: + if entity._base_alias is self.entity_zero._base_alias: + return True + return False + elif self.is_aliased_class: + if self.entity_zero._use_mapper_path: + return entity in self._with_polymorphic + else: + return entity is self.entity_zero + + return entity.common_parent(self.entity_zero) + + def adapt_to_selectable(self, query, sel): + query._entities.append(self) + + def _get_entity_clauses(self, query, context): + + adapter = None + + if not self.is_aliased_class: + if query._polymorphic_adapters: + adapter = query._polymorphic_adapters.get(self.mapper, None) + else: + adapter = self.aliased_adapter + + if adapter: + if query._from_obj_alias: + ret = adapter.wrap(query._from_obj_alias) + else: + ret = adapter + else: + ret = query._from_obj_alias + + return ret + + def row_processor(self, query, context, custom_rows): + adapter = self._get_entity_clauses(query, context) + + if context.adapter and adapter: + adapter = adapter.wrap(context.adapter) + elif not adapter: + adapter = context.adapter + + # polymorphic mappers which have concrete tables in + # their hierarchy usually + # require row aliasing unconditionally. + if not adapter and self.mapper._requires_row_aliasing: + adapter = sql_util.ColumnAdapter( + self.selectable, + self.mapper._equivalent_columns) + + if query._primary_entity is self: + _instance = loading.instance_processor( + self.mapper, + context, + self.path, + adapter, + only_load_props=query._only_load_props, + refresh_state=context.refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator + ) + else: + _instance = loading.instance_processor( + self.mapper, + context, + self.path, + adapter, + polymorphic_discriminator=self._polymorphic_discriminator + ) + + return _instance, self._label_name + + def setup_context(self, query, context): + adapter = self._get_entity_clauses(query, context) + + #if self._adapted_selectable is None: + context.froms += (self.selectable,) + + if context.order_by is False and self.mapper.order_by: + context.order_by = self.mapper.order_by + + # apply adaptation to the mapper's order_by if needed. + if adapter: + context.order_by = adapter.adapt_list( + util.to_list( + context.order_by + ) + ) + + if self._with_polymorphic: + poly_properties = self.mapper._iterate_polymorphic_properties( + self._with_polymorphic) + else: + poly_properties = self.mapper._polymorphic_properties + + for value in poly_properties: + if query._only_load_props and \ + value.key not in query._only_load_props: + continue + value.setup( + context, + self, + self.path, + adapter, + only_load_props=query._only_load_props, + column_collection=context.primary_columns + ) + + if self._polymorphic_discriminator is not None and \ + self._polymorphic_discriminator \ + is not self.mapper.polymorphic_on: + + if adapter: + pd = adapter.columns[self._polymorphic_discriminator] + else: + pd = self._polymorphic_discriminator + context.primary_columns.append(pd) + + def __str__(self): + return str(self.mapper) + +@inspection._self_inspects +class Bundle(object): + """A grouping of SQL expressions that are returned by a :class:`.Query` + under one namespace. + + The :class:`.Bundle` essentially allows nesting of the tuple-based + results returned by a column-oriented :class:`.Query` object. It also + is extensible via simple subclassing, where the primary capability + to override is that of how the set of expressions should be returned, + allowing post-processing as well as custom return types, without + involving ORM identity-mapped classes. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`bundles` + + """ + + single_entity = False + """If True, queries for a single Bundle will be returned as a single + entity, rather than an element within a keyed tuple.""" + + def __init__(self, name, *exprs, **kw): + """Construct a new :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + for row in session.query(bn).filter(bn.c.x == 5).filter(bn.c.y == 4): + print(row.mybundle.x, row.mybundle.y) + + :param name: name of the bundle. + :param \*exprs: columns or SQL expressions comprising the bundle. + :param single_entity=False: if True, rows for this :class:`.Bundle` + can be returned as a "single entity" outside of any enclosing tuple + in the same manner as a mapped entity. + + """ + self.name = self._label = name + self.exprs = exprs + self.c = self.columns = ColumnCollection() + self.columns.update((getattr(col, "key", col._label), col) + for col in exprs) + self.single_entity = kw.pop('single_entity', self.single_entity) + + columns = None + """A namespace of SQL expressions referred to by this :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + q = sess.query(bn).filter(bn.c.x == 5) + + Nesting of bundles is also supported:: + + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) + + q = sess.query(b1).filter(b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + + .. seealso:: + + :attr:`.Bundle.c` + + """ + + c = None + """An alias for :attr:`.Bundle.columns`.""" + + def _clone(self): + cloned = self.__class__.__new__(self.__class__) + cloned.__dict__.update(self.__dict__) + return cloned + + def __clause_element__(self): + return expression.ClauseList(group=False, *self.c) + + @property + def clauses(self): + return self.__clause_element__().clauses + + def label(self, name): + """Provide a copy of this :class:`.Bundle` passing a new label.""" + + cloned = self._clone() + cloned.name = name + return cloned + + def create_row_processor(self, query, procs, labels): + """Produce the "row processing" function for this :class:`.Bundle`. + + May be overridden by subclasses. + + .. seealso:: + + :ref:`bundles` - includes an example of subclassing. + + """ + def proc(row, result): + return util.KeyedTuple([proc(row, None) for proc in procs], labels) + return proc + + +class _BundleEntity(_QueryEntity): + def __init__(self, query, bundle, setup_entities=True): + query._entities.append(self) + self.bundle = self.expr = bundle + self.type = type(bundle) + self._label_name = bundle.name + self._entities = [] + + if setup_entities: + for expr in bundle.exprs: + if isinstance(expr, Bundle): + _BundleEntity(self, expr) + else: + _ColumnEntity(self, expr, namespace=self) + + self.entities = () + + self.filter_fn = lambda item: item + + self.supports_single_entity = self.bundle.single_entity + + custom_rows = False + + @property + def entity_zero(self): + for ent in self._entities: + ezero = ent.entity_zero + if ezero is not None: + return ezero + else: + return None + + def corresponds_to(self, entity): + # TODO: this seems to have no effect for + # _ColumnEntity either + return False + + @property + def entity_zero_or_selectable(self): + for ent in self._entities: + ezero = ent.entity_zero_or_selectable + if ezero is not None: + return ezero + else: + return None + + def adapt_to_selectable(self, query, sel): + c = _BundleEntity(query, self.bundle, setup_entities=False) + #c._label_name = self._label_name + #c.entity_zero = self.entity_zero + #c.entities = self.entities + + for ent in self._entities: + ent.adapt_to_selectable(c, sel) + + def setup_entity(self, ext_info, aliased_adapter): + for ent in self._entities: + ent.setup_entity(ext_info, aliased_adapter) + + def setup_context(self, query, context): + for ent in self._entities: + ent.setup_context(query, context) + + def row_processor(self, query, context, custom_rows): + procs, labels = zip( + *[ent.row_processor(query, context, custom_rows) + for ent in self._entities] + ) + + proc = self.bundle.create_row_processor(query, procs, labels) + + return proc, self._label_name + +class _ColumnEntity(_QueryEntity): + """Column/expression based entity.""" + + def __init__(self, query, column, namespace=None): + self.expr = column + self.namespace = namespace + + if isinstance(column, util.string_types): + column = sql.literal_column(column) + self._label_name = column.name + elif isinstance(column, ( + attributes.QueryableAttribute, + interfaces.PropComparator + )): + self._label_name = column.key + column = column._query_clause_element() + else: + self._label_name = getattr(column, 'key', None) + + if not isinstance(column, expression.ColumnElement) and \ + hasattr(column, '_select_iterable'): + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return + elif isinstance(column, Bundle): + _BundleEntity(query, column) + return + + if not isinstance(column, sql.ColumnElement): + raise sa_exc.InvalidRequestError( + "SQL expression, column, or mapped entity " + "expected - got '%r'" % (column, ) + ) + + self.type = type_ = column.type + if type_.hashable: + self.filter_fn = lambda item: item + else: + counter = util.counter() + self.filter_fn = lambda item: counter() + + # If the Column is unnamed, give it a + # label() so that mutable column expressions + # can be located in the result even + # if the expression's identity has been changed + # due to adaption. + + if not column._label and not getattr(column, 'is_literal', False): + column = column.label(self._label_name) + + query._entities.append(self) + + self.column = column + self.froms = set() + + # look for ORM entities represented within the + # given expression. Try to count only entities + # for columns whose FROM object is in the actual list + # of FROMs for the overall expression - this helps + # subqueries which were built from ORM constructs from + # leaking out their entities into the main select construct + self.actual_froms = actual_froms = set(column._from_objects) + + self.entities = util.OrderedSet( + elem._annotations['parententity'] + for elem in visitors.iterate(column, {}) + if 'parententity' in elem._annotations + and actual_froms.intersection(elem._from_objects) + ) + + if self.entities: + self.entity_zero = list(self.entities)[0] + elif self.namespace is not None: + self.entity_zero = self.namespace + else: + self.entity_zero = None + + supports_single_entity = False + custom_rows = False + + @property + def entity_zero_or_selectable(self): + if self.entity_zero is not None: + return self.entity_zero + elif self.actual_froms: + return list(self.actual_froms)[0] + else: + return None + + def adapt_to_selectable(self, query, sel): + c = _ColumnEntity(query, sel.corresponding_column(self.column)) + c._label_name = self._label_name + c.entity_zero = self.entity_zero + c.entities = self.entities + + def setup_entity(self, ext_info, aliased_adapter): + if 'selectable' not in self.__dict__: + self.selectable = ext_info.selectable + self.froms.add(ext_info.selectable) + + def corresponds_to(self, entity): + # TODO: just returning False here, + # no tests fail + if self.entity_zero is None: + return False + elif _is_aliased_class(entity): + # TODO: polymorphic subclasses ? + return entity is self.entity_zero + else: + return not _is_aliased_class(self.entity_zero) and \ + entity.common_parent(self.entity_zero) + + def _resolve_expr_against_query_aliases(self, query, expr, context): + return query._adapt_clause(expr, False, True) + + def row_processor(self, query, context, custom_rows): + column = self._resolve_expr_against_query_aliases( + query, self.column, context) + + if context.adapter: + column = context.adapter.columns[column] + + def proc(row, result): + return row[column] + + return proc, self._label_name + + def setup_context(self, query, context): + column = self._resolve_expr_against_query_aliases( + query, self.column, context) + context.froms += tuple(self.froms) + context.primary_columns.append(column) + + def __str__(self): + return str(self.column) + + +class QueryContext(object): + multi_row_eager_loaders = False + adapter = None + froms = () + for_update = None + + def __init__(self, query): + + if query._statement is not None: + if isinstance(query._statement, expression.SelectBase) and \ + not query._statement._textual and \ + not query._statement.use_labels: + self.statement = query._statement.apply_labels() + else: + self.statement = query._statement + else: + self.statement = None + self.from_clause = query._from_obj + self.whereclause = query._criterion + self.order_by = query._order_by + + self.query = query + self.session = query.session + self.populate_existing = query._populate_existing + self.invoke_all_eagers = query._invoke_all_eagers + self.version_check = query._version_check + self.refresh_state = query._refresh_state + self.primary_columns = [] + self.secondary_columns = [] + self.eager_order_by = [] + self.eager_joins = {} + self.create_eager_joins = [] + self.propagate_options = set(o for o in query._with_options if + o.propagate_to_loaders) + self.attributes = query._attributes.copy() + + +class AliasOption(interfaces.MapperOption): + + def __init__(self, alias): + """Return a :class:`.MapperOption` that will indicate to the :class:`.Query` + that the main table has been aliased. + + This is a seldom-used option to suit the + very rare case that :func:`.contains_eager` + is being used in conjunction with a user-defined SELECT + statement that aliases the parent table. E.g.:: + + # define an aliased UNION called 'ulist' + ulist = users.select(users.c.user_id==7).\\ + union(users.select(users.c.user_id>7)).\\ + alias('ulist') + + # add on an eager load of "addresses" + statement = ulist.outerjoin(addresses).\\ + select().apply_labels() + + # create query, indicating "ulist" will be an + # alias for the main table, "addresses" + # property should be eager loaded + query = session.query(User).options( + contains_alias(ulist), + contains_eager(User.addresses)) + + # then get results via the statement + results = query.from_statement(statement).all() + + :param alias: is the string name of an alias, or a + :class:`~.sql.expression.Alias` object representing + the alias. + + """ + self.alias = alias + + def process_query(self, query): + if isinstance(self.alias, util.string_types): + alias = query._mapper_zero().mapped_table.alias(self.alias) + else: + alias = self.alias + query._from_obj_alias = sql_util.ColumnAdapter(alias) + + diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py new file mode 100644 index 00000000..311fba47 --- /dev/null +++ b/lib/sqlalchemy/orm/relationships.py @@ -0,0 +1,2646 @@ +# orm/relationships.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Heuristics related to join conditions as used in +:func:`.relationship`. + +Provides the :class:`.JoinCondition` object, which encapsulates +SQL annotation and aliasing behavior focused on the `primaryjoin` +and `secondaryjoin` aspects of :func:`.relationship`. + +""" + +from .. import sql, util, exc as sa_exc, schema, log + +from .util import CascadeOptions, _orm_annotate, _orm_deannotate +from . import dependency +from . import attributes +from ..sql.util import ( + ClauseAdapter, + join_condition, _shallow_annotate, visit_binary_product, + _deep_deannotate, selectables_overlap + ) +from ..sql import operators, expression, visitors +from .interfaces import MANYTOMANY, MANYTOONE, ONETOMANY, StrategizedProperty, PropComparator +from ..inspection import inspect +from . import mapper as mapperlib + +def remote(expr): + """Annotate a portion of a primaryjoin expression + with a 'remote' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. versionadded:: 0.8 + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.foreign` + + """ + return _annotate_columns(expression._clause_element_as_expr(expr), + {"remote": True}) + + +def foreign(expr): + """Annotate a portion of a primaryjoin expression + with a 'foreign' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. versionadded:: 0.8 + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.remote` + + """ + + return _annotate_columns(expression._clause_element_as_expr(expr), + {"foreign": True}) + + +@log.class_logger +@util.langhelpers.dependency_for("sqlalchemy.orm.properties") +class RelationshipProperty(StrategizedProperty): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + + Public constructor is the :func:`.orm.relationship` function. + + See also: + + :ref:`relationship_config_toplevel` + + """ + + strategy_wildcard_key = 'relationship' + + _dependency_processor = None + + def __init__(self, argument, + secondary=None, primaryjoin=None, + secondaryjoin=None, + foreign_keys=None, + uselist=None, + order_by=False, + backref=None, + back_populates=None, + post_update=False, + cascade=False, extension=None, + viewonly=False, lazy=True, + collection_class=None, passive_deletes=False, + passive_updates=True, remote_side=None, + enable_typechecks=True, join_depth=None, + comparator_factory=None, + single_parent=False, innerjoin=False, + distinct_target_key=None, + doc=None, + active_history=False, + cascade_backrefs=True, + load_on_pending=False, + strategy_class=None, _local_remote_pairs=None, + query_class=None, + info=None): + """Provide a relationship between two mapped classes. + + This corresponds to a parent-child or associative table relationship. The + constructed class is an instance of :class:`.RelationshipProperty`. + + A typical :func:`.relationship`, used in a classical mapping:: + + mapper(Parent, properties={ + 'children': relationship(Child) + }) + + Some arguments accepted by :func:`.relationship` optionally accept a + callable function, which when called produces the desired value. + The callable is invoked by the parent :class:`.Mapper` at "mapper + initialization" time, which happens only when mappers are first used, and + is assumed to be after all mappings have been constructed. This can be + used to resolve order-of-declaration and other dependency issues, such as + if ``Child`` is declared below ``Parent`` in the same file:: + + mapper(Parent, properties={ + "children":relationship(lambda: Child, + order_by=lambda: Child.id) + }) + + When using the :ref:`declarative_toplevel` extension, the Declarative + initializer allows string arguments to be passed to :func:`.relationship`. + These string arguments are converted into callables that evaluate + the string as Python code, using the Declarative + class-registry as a namespace. This allows the lookup of related + classes to be automatic via their string name, and removes the need to + import related classes at all into the local module space:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class Parent(Base): + __tablename__ = 'parent' + id = Column(Integer, primary_key=True) + children = relationship("Child", order_by="Child.id") + + .. seealso:: + + :ref:`relationship_config_toplevel` - Full introductory and reference + documentation for :func:`.relationship`. + + :ref:`orm_tutorial_relationship` - ORM tutorial introduction. + + :param argument: + a mapped class, or actual :class:`.Mapper` instance, representing the + target of the relationship. + + :paramref:`~.relationship.argument` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + .. seealso:: + + :ref:`declarative_configuring_relationships` - further detail + on relationship configuration when using Declarative. + + :param secondary: + for a many-to-many relationship, specifies the intermediary + table, and is typically an instance of :class:`.Table`. + In less common circumstances, the argument may also be specified + as an :class:`.Alias` construct, or even a :class:`.Join` construct. + + :paramref:`~.relationship.secondary` may + also be passed as a callable function which is evaluated at + mapper initialization time. When using Declarative, it may also + be a string argument noting the name of a :class:`.Table` that is + present in the :class:`.MetaData` collection associated with the + parent-mapped :class:`.Table`. + + The :paramref:`~.relationship.secondary` keyword argument is typically + applied in the case where the intermediary :class:`.Table` is not + otherwise exprssed in any direct class mapping. If the "secondary" table + is also explicitly mapped elsewhere + (e.g. as in :ref:`association_pattern`), one should consider applying + the :paramref:`~.relationship.viewonly` flag so that this :func:`.relationship` + is not used for persistence operations which may conflict with those + of the association object pattern. + + .. seealso:: + + :ref:`relationships_many_to_many` - Reference example of "many to many". + + :ref:`orm_tutorial_many_to_many` - ORM tutorial introduction to + many-to-many relationships. + + :ref:`self_referential_many_to_many` - Specifics on using many-to-many + in a self-referential case. + + :ref:`declarative_many_to_many` - Additional options when using + Declarative. + + :ref:`association_pattern` - an alternative to :paramref:`~.relationship.secondary` + when composing association table relationships, allowing additional + attributes to be specified on the association table. + + :ref:`composite_secondary_join` - a lesser-used pattern which in some + cases can enable complex :func:`.relationship` SQL conditions + to be used. + + .. versionadded:: 0.9.2 :paramref:`~.relationship.secondary` works + more effectively when referring to a :class:`.Join` instance. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + many-to-one reference should be loaded when replaced, if + not already loaded. Normally, history tracking logic for + simple many-to-ones only needs to be aware of the "new" + value in order to perform a flush. This flag is available + for applications that make use of + :func:`.attributes.get_history` which also need to know + the "previous" value of the attribute. + + :param backref: + indicates the string name of a property to be placed on the related + mapper's class that will handle this relationship in the other + direction. The other property will be created automatically + when the mappers are configured. Can also be passed as a + :func:`.backref` object to control the configuration of the + new relationship. + + .. seealso:: + + :ref:`relationships_backref` - Introductory documentation and + examples. + + :paramref:`~.relationship.back_populates` - alternative form + of backref specification. + + :func:`.backref` - allows control over :func:`.relationship` + configuration when using :paramref:`~.relationship.backref`. + + + :param back_populates: + Takes a string name and has the same meaning as :paramref:`~.relationship.backref`, + except the complementing property is **not** created automatically, + and instead must be configured explicitly on the other mapper. The + complementing property should also indicate :paramref:`~.relationship.back_populates` + to this relationship to ensure proper functioning. + + .. seealso:: + + :ref:`relationships_backref` - Introductory documentation and + examples. + + :paramref:`~.relationship.backref` - alternative form + of backref specification. + + :param cascade: + a comma-separated list of cascade rules which determines how + Session operations should be "cascaded" from parent to child. + This defaults to ``False``, which means the default cascade + should be used - this default cascade is ``"save-update, merge"``. + + The available cascades are ``save-update``, ``merge``, + ``expunge``, ``delete``, ``delete-orphan``, and ``refresh-expire``. + An additional option, ``all`` indicates shorthand for + ``"save-update, merge, refresh-expire, + expunge, delete"``, and is often used as in ``"all, delete-orphan"`` + to indicate that related objects should follow along with the + parent object in all cases, and be deleted when de-associated. + + .. seealso:: + + :ref:`unitofwork_cascades` - Full detail on each of the available + cascade options. + + :ref:`tutorial_delete_cascade` - Tutorial example describing + a delete cascade. + + :param cascade_backrefs=True: + a boolean value indicating if the ``save-update`` cascade should + operate along an assignment event intercepted by a backref. + When set to ``False``, the attribute managed by this relationship + will not cascade an incoming transient object into the session of a + persistent parent, if the event is received via backref. + + .. seealso:: + + :ref:`backref_cascade` - Full discussion and examples on how + the :paramref:`~.relationship.cascade_backrefs` option is used. + + :param collection_class: + a class or callable that returns a new list-holding object. will + be used in place of a plain list for storing elements. + + .. seealso:: + + :ref:`custom_collections` - Introductory documentation and + examples. + + :param comparator_factory: + a class which extends :class:`.RelationshipProperty.Comparator` which + provides custom SQL clause generation for comparison operations. + + .. seealso:: + + :class:`.PropComparator` - some detail on redefining comparators + at this level. + + :ref:`custom_comparators` - Brief intro to this feature. + + + :param distinct_target_key=None: + Indicate if a "subquery" eager load should apply the DISTINCT + keyword to the innermost SELECT statement. When left as ``None``, + the DISTINCT keyword will be applied in those cases when the target + columns do not comprise the full primary key of the target table. + When set to ``True``, the DISTINCT keyword is applied to the innermost + SELECT unconditionally. + + It may be desirable to set this flag to False when the DISTINCT is + reducing performance of the innermost subquery beyond that of what + duplicate innermost rows may be causing. + + .. versionadded:: 0.8.3 - :paramref:`~.relationship.distinct_target_key` + allows the + subquery eager loader to apply a DISTINCT modifier to the + innermost SELECT. + + .. versionchanged:: 0.9.0 - :paramref:`~.relationship.distinct_target_key` + now defaults to ``None``, so that the feature enables itself automatically for + those cases where the innermost query targets a non-unique + key. + + .. seealso:: + + :ref:`loading_toplevel` - includes an introduction to subquery + eager loading. + + :param doc: + docstring which will be applied to the resulting descriptor. + + :param extension: + an :class:`.AttributeExtension` instance, or list of extensions, + which will be prepended to the list of attribute listeners for + the resulting descriptor placed on the class. + + .. deprecated:: 0.7 Please see :class:`.AttributeEvents`. + + :param foreign_keys: + + a list of columns which are to be used as "foreign key" + columns, or columns which refer to the value in a remote + column, within the context of this :func:`.relationship` + object's :paramref:`~.relationship.primaryjoin` condition. + That is, if the :paramref:`~.relationship.primaryjoin` + condition of this :func:`.relationship` is ``a.id == + b.a_id``, and the values in ``b.a_id`` are required to be + present in ``a.id``, then the "foreign key" column of this + :func:`.relationship` is ``b.a_id``. + + In normal cases, the :paramref:`~.relationship.foreign_keys` + parameter is **not required.** :func:`.relationship` will + automatically determine which columns in the + :paramref:`~.relationship.primaryjoin` conditition are to be + considered "foreign key" columns based on those + :class:`.Column` objects that specify :class:`.ForeignKey`, + or are otherwise listed as referencing columns in a + :class:`.ForeignKeyConstraint` construct. + :paramref:`~.relationship.foreign_keys` is only needed when: + + 1. There is more than one way to construct a join from the local + table to the remote table, as there are multiple foreign key + references present. Setting ``foreign_keys`` will limit the + :func:`.relationship` to consider just those columns specified + here as "foreign". + + .. versionchanged:: 0.8 + A multiple-foreign key join ambiguity can be resolved by + setting the :paramref:`~.relationship.foreign_keys` parameter alone, without the + need to explicitly set :paramref:`~.relationship.primaryjoin` as well. + + 2. The :class:`.Table` being mapped does not actually have + :class:`.ForeignKey` or :class:`.ForeignKeyConstraint` + constructs present, often because the table + was reflected from a database that does not support foreign key + reflection (MySQL MyISAM). + + 3. The :paramref:`~.relationship.primaryjoin` argument is used to construct a non-standard + join condition, which makes use of columns or expressions that do + not normally refer to their "parent" column, such as a join condition + expressed by a complex comparison using a SQL function. + + The :func:`.relationship` construct will raise informative + error messages that suggest the use of the + :paramref:`~.relationship.foreign_keys` parameter when + presented with an ambiguous condition. In typical cases, + if :func:`.relationship` doesn't raise any exceptions, the + :paramref:`~.relationship.foreign_keys` parameter is usually + not needed. + + :paramref:`~.relationship.foreign_keys` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + .. seealso:: + + :ref:`relationship_foreign_keys` + + :ref:`relationship_custom_foreign` + + :func:`.foreign` - allows direct annotation of the "foreign" columns + within a :paramref:`~.relationship.primaryjoin` condition. + + .. versionadded:: 0.8 + The :func:`.foreign` annotation can also be applied + directly to the :paramref:`~.relationship.primaryjoin` expression, which is an alternate, + more specific system of describing which columns in a particular + :paramref:`~.relationship.primaryjoin` should be considered "foreign". + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + .. versionadded:: 0.8 + + :param innerjoin=False: + when ``True``, joined eager loads will use an inner join to join + against related tables instead of an outer join. The purpose + of this option is generally one of performance, as inner joins + generally perform better than outer joins. + + This flag can be set to ``True`` when the relationship references an + object via many-to-one using local foreign keys that are not nullable, + or when the reference is one-to-one or a collection that is guaranteed + to have one or at least one entry. + + If the joined-eager load is chained onto an existing LEFT OUTER JOIN, + ``innerjoin=True`` will be bypassed and the join will continue to + chain as LEFT OUTER JOIN so that the results don't change. As an alternative, + specify the value ``"nested"``. This will instead nest the join + on the right side, e.g. using the form "a LEFT OUTER JOIN (b JOIN c)". + + .. versionadded:: 0.9.4 Added ``innerjoin="nested"`` option to support + nesting of eager "inner" joins. + + .. seealso:: + + :ref:`what_kind_of_loading` - Discussion of some details of + various loader options. + + :paramref:`.joinedload.innerjoin` - loader option version + + :param join_depth: + when non-``None``, an integer value indicating how many levels + deep "eager" loaders should join on a self-referring or cyclical + relationship. The number counts how many times the same Mapper + shall be present in the loading condition along a particular join + branch. When left at its default of ``None``, eager loaders + will stop chaining when they encounter a the same target mapper + which is already higher up in the chain. This option applies + both to joined- and subquery- eager loaders. + + .. seealso:: + + :ref:`self_referential_eager_loading` - Introductory documentation + and examples. + + :param lazy='select': specifies + how the related items should be loaded. Default value is + ``select``. Values include: + + * ``select`` - items should be loaded lazily when the property is first + accessed, using a separate SELECT statement, or identity map + fetch for simple many-to-one references. + + * ``immediate`` - items should be loaded as the parents are loaded, + using a separate SELECT statement, or identity map fetch for + simple many-to-one references. + + * ``joined`` - items should be loaded "eagerly" in the same query as + that of the parent, using a JOIN or LEFT OUTER JOIN. Whether + the join is "outer" or not is determined by the + :paramref:`~.relationship.innerjoin` parameter. + + * ``subquery`` - items should be loaded "eagerly" as the parents are + loaded, using one additional SQL statement, which issues a JOIN to a + subquery of the original statement, for each collection requested. + + * ``noload`` - no loading should occur at any time. This is to + support "write-only" attributes, or attributes which are + populated in some manner specific to the application. + + * ``dynamic`` - the attribute will return a pre-configured + :class:`.Query` object for all read + operations, onto which further filtering operations can be + applied before iterating the results. See + the section :ref:`dynamic_relationship` for more details. + + * True - a synonym for 'select' + + * False - a synonym for 'joined' + + * None - a synonym for 'noload' + + .. seealso:: + + :doc:`/orm/loading` - Full documentation on relationship loader + configuration. + + :ref:`dynamic_relationship` - detail on the ``dynamic`` option. + + :param load_on_pending=False: + Indicates loading behavior for transient or pending parent objects. + + When set to ``True``, causes the lazy-loader to + issue a query for a parent object that is not persistent, meaning it has + never been flushed. This may take effect for a pending object when + autoflush is disabled, or for a transient object that has been + "attached" to a :class:`.Session` but is not part of its pending + collection. + + The :paramref:`~.relationship.load_on_pending` flag does not improve behavior + when the ORM is used normally - object references should be constructed + at the object level, not at the foreign key level, so that they + are present in an ordinary way before a flush proceeds. This flag + is not not intended for general use. + + .. seealso:: + + :meth:`.Session.enable_relationship_loading` - this method establishes + "load on pending" behavior for the whole object, and also allows + loading on objects that remain transient or detached. + + :param order_by: + indicates the ordering that should be applied when loading these + items. :paramref:`~.relationship.order_by` is expected to refer to one + of the :class:`.Column` + objects to which the target class is mapped, or + the attribute itself bound to the target class which refers + to the column. + + :paramref:`~.relationship.order_by` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + :param passive_deletes=False: + Indicates loading behavior during delete operations. + + A value of True indicates that unloaded child items should not + be loaded during a delete operation on the parent. Normally, + when a parent item is deleted, all child items are loaded so + that they can either be marked as deleted, or have their + foreign key to the parent set to NULL. Marking this flag as + True usually implies an ON DELETE rule is in + place which will handle updating/deleting child rows on the + database side. + + Additionally, setting the flag to the string value 'all' will + disable the "nulling out" of the child foreign keys, when there + is no delete or delete-orphan cascade enabled. This is + typically used when a triggering or error raise scenario is in + place on the database side. Note that the foreign key + attributes on in-session child objects will not be changed + after a flush occurs so this is a very special use-case + setting. + + .. seealso:: + + :ref:`passive_deletes` - Introductory documentation + and examples. + + :param passive_updates=True: + Indicates loading and INSERT/UPDATE/DELETE behavior when the + source of a foreign key value changes (i.e. an "on update" + cascade), which are typically the primary key columns of the + source row. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will + handle propagation of an UPDATE from a source column to + dependent rows. Note that with databases which enforce + referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables), + ON UPDATE CASCADE is required for this operation. The + relationship() will update the value of the attribute on related + items which are locally present in the session during a flush. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The relationship() will issue the + appropriate UPDATE statements to the database in response to the + change of a referenced key, and items locally present in the + session during a flush will also be refreshed. + + This flag should probably be set to False if primary key changes + are expected and the database in use doesn't support CASCADE + (i.e. SQLite, MySQL MyISAM tables). + + .. seealso:: + + :ref:`passive_updates` - Introductory documentation and + examples. + + :paramref:`.mapper.passive_updates` - a similar flag which + takes effect for joined-table inheritance mappings. + + :param post_update: + this indicates that the relationship should be handled by a + second UPDATE statement after an INSERT or before a + DELETE. Currently, it also will issue an UPDATE after the + instance was UPDATEd as well, although this technically should + be improved. This flag is used to handle saving bi-directional + dependencies between two individual rows (i.e. each row + references the other), where it would otherwise be impossible to + INSERT or DELETE both rows fully since one row exists before the + other. Use this flag when a particular mapping arrangement will + incur two rows that are dependent on each other, such as a table + that has a one-to-many relationship to a set of child rows, and + also has a column that references a single child row within that + list (i.e. both tables contain a foreign key to each other). If + a flush operation returns an error that a "cyclical + dependency" was detected, this is a cue that you might want to + use :paramref:`~.relationship.post_update` to "break" the cycle. + + .. seealso:: + + :ref:`post_update` - Introductory documentation and examples. + + :param primaryjoin: + a SQL expression that will be used as the primary + join of this child object against the parent object, or in a + many-to-many relationship the join of the primary object to the + association table. By default, this value is computed based on the + foreign key relationships of the parent and child tables (or association + table). + + :paramref:`~.relationship.primaryjoin` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + .. seealso:: + + :ref:`relationship_primaryjoin` + + :param remote_side: + used for self-referential relationships, indicates the column or + list of columns that form the "remote side" of the relationship. + + :paramref:`.relationship.remote_side` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + .. versionchanged:: 0.8 + The :func:`.remote` annotation can also be applied + directly to the ``primaryjoin`` expression, which is an alternate, + more specific system of describing which columns in a particular + ``primaryjoin`` should be considered "remote". + + .. seealso:: + + :ref:`self_referential` - in-depth explaination of how + :paramref:`~.relationship.remote_side` + is used to configure self-referential relationships. + + :func:`.remote` - an annotation function that accomplishes the same + purpose as :paramref:`~.relationship.remote_side`, typically + when a custom :paramref:`~.relationship.primaryjoin` condition + is used. + + :param query_class: + a :class:`.Query` subclass that will be used as the base of the + "appender query" returned by a "dynamic" relationship, that + is, a relationship that specifies ``lazy="dynamic"`` or was + otherwise constructed using the :func:`.orm.dynamic_loader` + function. + + .. seealso:: + + :ref:`dynamic_relationship` - Introduction to "dynamic" relationship + loaders. + + :param secondaryjoin: + a SQL expression that will be used as the join of + an association table to the child object. By default, this value is + computed based on the foreign key relationships of the association and + child tables. + + :paramref:`~.relationship.secondaryjoin` may also be passed as a callable function + which is evaluated at mapper initialization time, and may be passed as a + Python-evaluable string when using Declarative. + + .. seealso:: + + :ref:`relationship_primaryjoin` + + :param single_parent: + when True, installs a validator which will prevent objects + from being associated with more than one parent at a time. + This is used for many-to-one or many-to-many relationships that + should be treated either as one-to-one or one-to-many. Its usage + is optional, except for :func:`.relationship` constructs which + are many-to-one or many-to-many and also + specify the ``delete-orphan`` cascade option. The :func:`.relationship` + construct itself will raise an error instructing when this option + is required. + + .. seealso:: + + :ref:`unitofwork_cascades` - includes detail on when the + :paramref:`~.relationship.single_parent` flag may be appropriate. + + :param uselist: + a boolean that indicates if this property should be loaded as a + list or a scalar. In most cases, this value is determined + automatically by :func:`.relationship` at mapper configuration + time, based on the type and direction + of the relationship - one to many forms a list, many to one + forms a scalar, many to many is a list. If a scalar is desired + where normally a list would be present, such as a bi-directional + one-to-one relationship, set :paramref:`~.relationship.uselist` to False. + + The :paramref:`~.relationship.uselist` flag is also available on an + existing :func:`.relationship` construct as a read-only attribute, which + can be used to determine if this :func:`.relationship` deals with + collections or scalar attributes:: + + >>> User.addresses.property.uselist + True + + .. seealso:: + + :ref:`relationships_one_to_one` - Introduction to the "one to one" + relationship pattern, which is typically when the + :paramref:`~.relationship.uselist` flag is needed. + + :param viewonly=False: + when set to True, the relationship is used only for loading objects, + and not for any persistence operation. A :func:`.relationship` + which specifies :paramref:`~.relationship.viewonly` can work + with a wider range of SQL operations within the :paramref:`~.relationship.primaryjoin` + condition, including operations that feature the use of + a variety of comparison operators as well as SQL functions such + as :func:`~.sql.expression.cast`. The :paramref:`~.relationship.viewonly` + flag is also of general use when defining any kind of :func:`~.relationship` + that doesn't represent the full set of related objects, to prevent + modifications of the collection from resulting in persistence operations. + + + """ + + self.uselist = uselist + self.argument = argument + self.secondary = secondary + self.primaryjoin = primaryjoin + self.secondaryjoin = secondaryjoin + self.post_update = post_update + self.direction = None + self.viewonly = viewonly + self.lazy = lazy + self.single_parent = single_parent + self._user_defined_foreign_keys = foreign_keys + self.collection_class = collection_class + self.passive_deletes = passive_deletes + self.cascade_backrefs = cascade_backrefs + self.passive_updates = passive_updates + self.remote_side = remote_side + self.enable_typechecks = enable_typechecks + self.query_class = query_class + self.innerjoin = innerjoin + self.distinct_target_key = distinct_target_key + self.doc = doc + self.active_history = active_history + self.join_depth = join_depth + self.local_remote_pairs = _local_remote_pairs + self.extension = extension + self.load_on_pending = load_on_pending + self.comparator_factory = comparator_factory or \ + RelationshipProperty.Comparator + self.comparator = self.comparator_factory(self, None) + util.set_creation_order(self) + + if info is not None: + self.info = info + + if strategy_class: + self.strategy_class = strategy_class + else: + self.strategy_class = self._strategy_lookup(("lazy", self.lazy)) + + self._reverse_property = set() + + self.cascade = cascade if cascade is not False \ + else "save-update, merge" + + self.order_by = order_by + + self.back_populates = back_populates + + if self.back_populates: + if backref: + raise sa_exc.ArgumentError( + "backref and back_populates keyword arguments " + "are mutually exclusive") + self.backref = None + else: + self.backref = backref + + def instrument_class(self, mapper): + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc, + ) + + class Comparator(PropComparator): + """Produce boolean, comparison, and other operators for + :class:`.RelationshipProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief overview + of ORM level operator definition. + + See also: + + :class:`.PropComparator` + + :class:`.ColumnProperty.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + _of_type = None + + def __init__(self, prop, parentmapper, adapt_to_entity=None, of_type=None): + """Construction of :class:`.RelationshipProperty.Comparator` + is internal to the ORM's attribute mechanics. + + """ + self.prop = prop + self._parentmapper = parentmapper + self._adapt_to_entity = adapt_to_entity + if of_type: + self._of_type = of_type + + def adapt_to_entity(self, adapt_to_entity): + return self.__class__(self.property, self._parentmapper, + adapt_to_entity=adapt_to_entity, + of_type=self._of_type) + + @util.memoized_property + def mapper(self): + """The target :class:`.Mapper` referred to by this + :class:`.RelationshipProperty.Comparator`. + + This is the "target" or "remote" side of the + :func:`.relationship`. + + """ + return self.property.mapper + + @util.memoized_property + def _parententity(self): + return self.property.parent + + def _source_selectable(self): + if self._adapt_to_entity: + return self._adapt_to_entity.selectable + else: + return self.property.parent._with_polymorphic_selectable + + def __clause_element__(self): + adapt_from = self._source_selectable() + if self._of_type: + of_type = inspect(self._of_type).mapper + else: + of_type = None + + pj, sj, source, dest, \ + secondary, target_adapter = self.property._create_joins( + source_selectable=adapt_from, + source_polymorphic=True, + of_type=of_type) + if sj is not None: + return pj & sj + else: + return pj + + def of_type(self, cls): + """Produce a construct that represents a particular 'subtype' of + attribute for the parent class. + + Currently this is usable in conjunction with :meth:`.Query.join` + and :meth:`.Query.outerjoin`. + + """ + return RelationshipProperty.Comparator( + self.property, + self._parentmapper, + adapt_to_entity=self._adapt_to_entity, + of_type=cls) + + def in_(self, other): + """Produce an IN clause - this is not implemented + for :func:`~.orm.relationship`-based attributes at this time. + + """ + raise NotImplementedError('in_() not yet supported for ' + 'relationships. For a simple many-to-one, use ' + 'in_() against the set of foreign key values.') + + __hash__ = None + + def __eq__(self, other): + """Implement the ``==`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop == + + this will typically produce a + clause such as:: + + mytable.related_id == + + Where ```` is the primary key of the given + object. + + The ``==`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use :meth:`~.RelationshipProperty.Comparator.contains`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.RelationshipProperty.Comparator.has` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce a NOT EXISTS clause. + + """ + if isinstance(other, (util.NoneType, expression.Null)): + if self.property.direction in [ONETOMANY, MANYTOMANY]: + return ~self._criterion_exists() + else: + return _orm_annotate(self.property._optimized_compare( + None, adapt_source=self.adapter)) + elif self.property.uselist: + raise sa_exc.InvalidRequestError("Can't compare a colle" + "ction to an object or collection; use " + "contains() to test for membership.") + else: + return _orm_annotate(self.property._optimized_compare(other, + adapt_source=self.adapter)) + + def _criterion_exists(self, criterion=None, **kwargs): + if getattr(self, '_of_type', None): + info = inspect(self._of_type) + target_mapper, to_selectable, is_aliased_class = \ + info.mapper, info.selectable, info.is_aliased_class + if self.property._is_self_referential and not is_aliased_class: + to_selectable = to_selectable.alias() + + single_crit = target_mapper._single_table_criterion + if single_crit is not None: + if criterion is not None: + criterion = single_crit & criterion + else: + criterion = single_crit + else: + is_aliased_class = False + to_selectable = None + + if self.adapter: + source_selectable = self._source_selectable() + else: + source_selectable = None + + pj, sj, source, dest, secondary, target_adapter = \ + self.property._create_joins(dest_polymorphic=True, + dest_selectable=to_selectable, + source_selectable=source_selectable) + + for k in kwargs: + crit = getattr(self.property.mapper.class_, k) == kwargs[k] + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + + # annotate the *local* side of the join condition, in the case + # of pj + sj this is the full primaryjoin, in the case of just + # pj its the local side of the primaryjoin. + if sj is not None: + j = _orm_annotate(pj) & sj + else: + j = _orm_annotate(pj, exclude=self.property.remote_side) + + if criterion is not None and target_adapter and not is_aliased_class: + # limit this adapter to annotated only? + criterion = target_adapter.traverse(criterion) + + # only have the "joined left side" of what we + # return be subject to Query adaption. The right + # side of it is used for an exists() subquery and + # should not correlate or otherwise reach out + # to anything in the enclosing query. + if criterion is not None: + criterion = criterion._annotate( + {'no_replacement_traverse': True}) + + crit = j & sql.True_._ifnone(criterion) + + ex = sql.exists([1], crit, from_obj=dest).correlate_except(dest) + if secondary is not None: + ex = ex.correlate_except(secondary) + return ex + + def any(self, criterion=None, **kwargs): + """Produce an expression that tests a collection against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.any(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id + AND related.x=2) + + Because :meth:`~.RelationshipProperty.Comparator.any` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. + + :meth:`~.RelationshipProperty.Comparator.any` is particularly + useful for testing for empty collections:: + + session.query(MyClass).filter( + ~MyClass.somereference.any() + ) + + will produce:: + + SELECT * FROM my_table WHERE + NOT EXISTS (SELECT 1 FROM related WHERE + related.my_id=my_table.id) + + :meth:`~.RelationshipProperty.Comparator.any` is only + valid for collections, i.e. a :func:`.relationship` + that has ``uselist=True``. For scalar references, + use :meth:`~.RelationshipProperty.Comparator.has`. + + """ + if not self.property.uselist: + raise sa_exc.InvalidRequestError( + "'any()' not implemented for scalar " + "attributes. Use has()." + ) + + return self._criterion_exists(criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + """Produce an expression that tests a scalar reference against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.has(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE + related.id==my_table.related_id AND related.x=2) + + Because :meth:`~.RelationshipProperty.Comparator.has` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. + + :meth:`~.RelationshipProperty.Comparator.has` is only + valid for scalar references, i.e. a :func:`.relationship` + that has ``uselist=False``. For collection references, + use :meth:`~.RelationshipProperty.Comparator.any`. + + """ + if self.property.uselist: + raise sa_exc.InvalidRequestError( + "'has()' not implemented for collections. " + "Use any().") + return self._criterion_exists(criterion, **kwargs) + + def contains(self, other, **kwargs): + """Return a simple expression that tests a collection for + containment of a particular item. + + :meth:`~.RelationshipProperty.Comparator.contains` is + only valid for a collection, i.e. a + :func:`~.orm.relationship` that implements + one-to-many or many-to-many with ``uselist=True``. + + When used in a simple one-to-many context, an + expression like:: + + MyClass.contains(other) + + Produces a clause like:: + + mytable.id == + + Where ```` is the value of the foreign key + attribute on ``other`` which refers to the primary + key of its parent object. From this it follows that + :meth:`~.RelationshipProperty.Comparator.contains` is + very useful when used with simple one-to-many + operations. + + For many-to-many operations, the behavior of + :meth:`~.RelationshipProperty.Comparator.contains` + has more caveats. The association table will be + rendered in the statement, producing an "implicit" + join, that is, includes multiple tables in the FROM + clause which are equated in the WHERE clause:: + + query(MyClass).filter(MyClass.contains(other)) + + Produces a query like:: + + SELECT * FROM my_table, my_association_table AS + my_association_table_1 WHERE + my_table.id = my_association_table_1.parent_id + AND my_association_table_1.child_id = + + Where ```` would be the primary key of + ``other``. From the above, it is clear that + :meth:`~.RelationshipProperty.Comparator.contains` + will **not** work with many-to-many collections when + used in queries that move beyond simple AND + conjunctions, such as multiple + :meth:`~.RelationshipProperty.Comparator.contains` + expressions joined by OR. In such cases subqueries or + explicit "outer joins" will need to be used instead. + See :meth:`~.RelationshipProperty.Comparator.any` for + a less-performant alternative using EXISTS, or refer + to :meth:`.Query.outerjoin` as well as :ref:`ormtutorial_joins` + for more details on constructing outer joins. + + """ + if not self.property.uselist: + raise sa_exc.InvalidRequestError( + "'contains' not implemented for scalar " + "attributes. Use ==") + clause = self.property._optimized_compare(other, + adapt_source=self.adapter) + + if self.property.secondaryjoin is not None: + clause.negation_clause = \ + self.__negated_contains_or_equals(other) + + return clause + + def __negated_contains_or_equals(self, other): + if self.property.direction == MANYTOONE: + state = attributes.instance_state(other) + + def state_bindparam(x, state, col): + o = state.obj() # strong ref + return sql.bindparam(x, unique=True, callable_=lambda: \ + self.property.mapper._get_committed_attr_by_column(o, col)) + + def adapt(col): + if self.adapter: + return self.adapter(col) + else: + return col + + if self.property._use_get: + return sql.and_(*[ + sql.or_( + adapt(x) != state_bindparam(adapt(x), state, y), + adapt(x) == None) + for (x, y) in self.property.local_remote_pairs]) + + criterion = sql.and_(*[x == y for (x, y) in + zip( + self.property.mapper.primary_key, + self.property.\ + mapper.\ + primary_key_from_instance(other)) + ]) + return ~self._criterion_exists(criterion) + + def __ne__(self, other): + """Implement the ``!=`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop != + + This will typically produce a clause such as:: + + mytable.related_id != + + Where ```` is the primary key of the + given object. + + The ``!=`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use + :meth:`~.RelationshipProperty.Comparator.contains` + in conjunction with :func:`~.expression.not_`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.RelationshipProperty.Comparator.has` in + conjunction with :func:`~.expression.not_` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce an EXISTS clause. + + """ + if isinstance(other, (util.NoneType, expression.Null)): + if self.property.direction == MANYTOONE: + return sql.or_(*[x != None for x in + self.property._calculated_foreign_keys]) + else: + return self._criterion_exists() + elif self.property.uselist: + raise sa_exc.InvalidRequestError("Can't compare a collection" + " to an object or collection; use " + "contains() to test for membership.") + else: + return self.__negated_contains_or_equals(other) + + @util.memoized_property + def property(self): + if mapperlib.Mapper._new_mappers: + mapperlib.Mapper._configure_all() + return self.prop + + def compare(self, op, value, + value_is_parent=False, + alias_secondary=True): + if op == operators.eq: + if value is None: + if self.uselist: + return ~sql.exists([1], self.primaryjoin) + else: + return self._optimized_compare(None, + value_is_parent=value_is_parent, + alias_secondary=alias_secondary) + else: + return self._optimized_compare(value, + value_is_parent=value_is_parent, + alias_secondary=alias_secondary) + else: + return op(self.comparator, value) + + def _optimized_compare(self, value, value_is_parent=False, + adapt_source=None, + alias_secondary=True): + if value is not None: + value = attributes.instance_state(value) + return self._lazy_strategy.lazy_clause(value, + reverse_direction=not value_is_parent, + alias_secondary=alias_secondary, + adapt_source=adapt_source) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + + def merge(self, + session, + source_state, + source_dict, + dest_state, + dest_dict, + load, _recursive): + + if load: + for r in self._reverse_property: + if (source_state, r) in _recursive: + return + + if not "merge" in self._cascade: + return + + if self.key not in source_dict: + return + + if self.uselist: + instances = source_state.get_impl(self.key).\ + get(source_state, source_dict) + if hasattr(instances, '_sa_adapter'): + # convert collections to adapters to get a true iterator + instances = instances._sa_adapter + + if load: + # for a full merge, pre-load the destination collection, + # so that individual _merge of each item pulls from identity + # map for those already present. + # also assumes CollectionAttrbiuteImpl behavior of loading + # "old" list in any case + dest_state.get_impl(self.key).get(dest_state, dest_dict) + + dest_list = [] + for current in instances: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, + load=load, _recursive=_recursive) + if obj is not None: + dest_list.append(obj) + + if not load: + coll = attributes.init_state_collection(dest_state, + dest_dict, self.key) + for c in dest_list: + coll.append_without_event(c) + else: + dest_state.get_impl(self.key)._set_iterable(dest_state, + dest_dict, dest_list) + else: + current = source_dict[self.key] + if current is not None: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge(current_state, current_dict, + load=load, _recursive=_recursive) + else: + obj = None + + if not load: + dest_dict[self.key] = obj + else: + dest_state.get_impl(self.key).set(dest_state, + dest_dict, obj, None) + + def _value_as_iterable(self, state, dict_, key, + passive=attributes.PASSIVE_OFF): + """Return a list of tuples (state, obj) for the given + key. + + returns an empty list if the value is None/empty/PASSIVE_NO_RESULT + """ + + impl = state.manager[key].impl + x = impl.get(state, dict_, passive=passive) + if x is attributes.PASSIVE_NO_RESULT or x is None: + return [] + elif hasattr(impl, 'get_collection'): + return [ + (attributes.instance_state(o), o) for o in + impl.get_collection(state, dict_, x, passive=passive) + ] + else: + return [(attributes.instance_state(x), x)] + + def cascade_iterator(self, type_, state, dict_, + visited_states, halt_on=None): + #assert type_ in self._cascade + + # only actively lazy load on the 'delete' cascade + if type_ != 'delete' or self.passive_deletes: + passive = attributes.PASSIVE_NO_INITIALIZE + else: + passive = attributes.PASSIVE_OFF + + if type_ == 'save-update': + tuples = state.manager[self.key].impl.\ + get_all_pending(state, dict_) + + else: + tuples = self._value_as_iterable(state, dict_, self.key, + passive=passive) + + skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \ + not in self._cascade + + for instance_state, c in tuples: + if instance_state in visited_states: + continue + + if c is None: + # would like to emit a warning here, but + # would not be consistent with collection.append(None) + # current behavior of silently skipping. + # see [ticket:2229] + continue + + instance_dict = attributes.instance_dict(c) + + if halt_on and halt_on(instance_state): + continue + + if skip_pending and not instance_state.key: + continue + + instance_mapper = instance_state.manager.mapper + + if not instance_mapper.isa(self.mapper.class_manager.mapper): + raise AssertionError("Attribute '%s' on class '%s' " + "doesn't handle objects " + "of type '%s'" % ( + self.key, + self.parent.class_, + c.__class__ + )) + + visited_states.add(instance_state) + + yield c, instance_mapper, instance_state, instance_dict + + def _add_reverse_property(self, key): + other = self.mapper.get_property(key, _configure_mappers=False) + self._reverse_property.add(other) + other._reverse_property.add(self) + + if not other.mapper.common_parent(self.parent): + raise sa_exc.ArgumentError('reverse_property %r on ' + 'relationship %s references relationship %s, which ' + 'does not reference mapper %s' % (key, self, other, + self.parent)) + if self.direction in (ONETOMANY, MANYTOONE) and self.direction \ + == other.direction: + raise sa_exc.ArgumentError('%s and back-reference %s are ' + 'both of the same direction %r. Did you mean to ' + 'set remote_side on the many-to-one side ?' + % (other, self, self.direction)) + + @util.memoized_property + def mapper(self): + """Return the targeted :class:`.Mapper` for this + :class:`.RelationshipProperty`. + + This is a lazy-initializing static attribute. + + """ + if util.callable(self.argument) and \ + not isinstance(self.argument, (type, mapperlib.Mapper)): + argument = self.argument() + else: + argument = self.argument + + if isinstance(argument, type): + mapper_ = mapperlib.class_mapper(argument, + configure=False) + elif isinstance(self.argument, mapperlib.Mapper): + mapper_ = argument + else: + raise sa_exc.ArgumentError("relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(argument))) + return mapper_ + + @util.memoized_property + @util.deprecated("0.7", "Use .target") + def table(self): + """Return the selectable linked to this + :class:`.RelationshipProperty` object's target + :class:`.Mapper`. + """ + return self.target + + def do_init(self): + self._check_conflicts() + self._process_dependent_arguments() + self._setup_join_conditions() + self._check_cascade_settings(self._cascade) + self._post_init() + self._generate_backref() + super(RelationshipProperty, self).do_init() + self._lazy_strategy = self._get_strategy((("lazy", "select"),)) + + + def _process_dependent_arguments(self): + """Convert incoming configuration arguments to their + proper form. + + Callables are resolved, ORM annotations removed. + + """ + # accept callables for other attributes which may require + # deferred initialization. This technique is used + # by declarative "string configs" and some recipes. + for attr in ( + 'order_by', 'primaryjoin', 'secondaryjoin', + 'secondary', '_user_defined_foreign_keys', 'remote_side', + ): + attr_value = getattr(self, attr) + if util.callable(attr_value): + setattr(self, attr, attr_value()) + + # remove "annotations" which are present if mapped class + # descriptors are used to create the join expression. + for attr in 'primaryjoin', 'secondaryjoin': + val = getattr(self, attr) + if val is not None: + setattr(self, attr, _orm_deannotate( + expression._only_column_elements(val, attr)) + ) + + # ensure expressions in self.order_by, foreign_keys, + # remote_side are all columns, not strings. + if self.order_by is not False and self.order_by is not None: + self.order_by = [ + expression._only_column_elements(x, "order_by") + for x in + util.to_list(self.order_by)] + + self._user_defined_foreign_keys = \ + util.column_set( + expression._only_column_elements(x, "foreign_keys") + for x in util.to_column_set( + self._user_defined_foreign_keys + )) + + self.remote_side = \ + util.column_set( + expression._only_column_elements(x, "remote_side") + for x in + util.to_column_set(self.remote_side)) + + self.target = self.mapper.mapped_table + + + def _setup_join_conditions(self): + self._join_condition = jc = JoinCondition( + parent_selectable=self.parent.mapped_table, + child_selectable=self.mapper.mapped_table, + parent_local_selectable=self.parent.local_table, + child_local_selectable=self.mapper.local_table, + primaryjoin=self.primaryjoin, + secondary=self.secondary, + secondaryjoin=self.secondaryjoin, + parent_equivalents=self.parent._equivalent_columns, + child_equivalents=self.mapper._equivalent_columns, + consider_as_foreign_keys=self._user_defined_foreign_keys, + local_remote_pairs=self.local_remote_pairs, + remote_side=self.remote_side, + self_referential=self._is_self_referential, + prop=self, + support_sync=not self.viewonly, + can_be_synced_fn=self._columns_are_mapped + ) + self.primaryjoin = jc.deannotated_primaryjoin + self.secondaryjoin = jc.deannotated_secondaryjoin + self.direction = jc.direction + self.local_remote_pairs = jc.local_remote_pairs + self.remote_side = jc.remote_columns + self.local_columns = jc.local_columns + self.synchronize_pairs = jc.synchronize_pairs + self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs + + def _check_conflicts(self): + """Test that this relationship is legal, warn about + inheritance conflicts.""" + + if not self.is_primary() \ + and not mapperlib.class_mapper( + self.parent.class_, + configure=False).has_property(self.key): + raise sa_exc.ArgumentError("Attempting to assign a new " + "relationship '%s' to a non-primary mapper on " + "class '%s'. New relationships can only be added " + "to the primary mapper, i.e. the very first mapper " + "created for class '%s' " % (self.key, + self.parent.class_.__name__, + self.parent.class_.__name__)) + + # check for conflicting relationship() on superclass + if not self.parent.concrete: + for inheriting in self.parent.iterate_to_root(): + if inheriting is not self.parent \ + and inheriting.has_property(self.key): + util.warn("Warning: relationship '%s' on mapper " + "'%s' supersedes the same relationship " + "on inherited mapper '%s'; this can " + "cause dependency issues during flush" + % (self.key, self.parent, inheriting)) + + def _get_cascade(self): + """Return the current cascade setting for this + :class:`.RelationshipProperty`. + """ + return self._cascade + + def _set_cascade(self, cascade): + cascade = CascadeOptions(cascade) + if 'mapper' in self.__dict__: + self._check_cascade_settings(cascade) + self._cascade = cascade + + if self._dependency_processor: + self._dependency_processor.cascade = cascade + + cascade = property(_get_cascade, _set_cascade) + + def _check_cascade_settings(self, cascade): + if cascade.delete_orphan and not self.single_parent \ + and (self.direction is MANYTOMANY or self.direction + is MANYTOONE): + raise sa_exc.ArgumentError( + 'On %s, delete-orphan cascade is not supported ' + 'on a many-to-many or many-to-one relationship ' + 'when single_parent is not set. Set ' + 'single_parent=True on the relationship().' + % self) + if self.direction is MANYTOONE and self.passive_deletes: + util.warn("On %s, 'passive_deletes' is normally configured " + "on one-to-many, one-to-one, many-to-many " + "relationships only." + % self) + + if self.passive_deletes == 'all' and \ + ("delete" in cascade or + "delete-orphan" in cascade): + raise sa_exc.ArgumentError( + "On %s, can't set passive_deletes='all' in conjunction " + "with 'delete' or 'delete-orphan' cascade" % self) + + if cascade.delete_orphan: + self.mapper.primary_mapper()._delete_orphans.append( + (self.key, self.parent.class_) + ) + + def _columns_are_mapped(self, *cols): + """Return True if all columns in the given collection are + mapped by the tables referenced by this :class:`.Relationship`. + + """ + for c in cols: + if self.secondary is not None \ + and self.secondary.c.contains_column(c): + continue + if not self.parent.mapped_table.c.contains_column(c) and \ + not self.target.c.contains_column(c): + return False + return True + + def _generate_backref(self): + """Interpret the 'backref' instruction to create a + :func:`.relationship` complementary to this one.""" + + if not self.is_primary(): + return + if self.backref is not None and not self.back_populates: + if isinstance(self.backref, util.string_types): + backref_key, kwargs = self.backref, {} + else: + backref_key, kwargs = self.backref + mapper = self.mapper.primary_mapper() + + check = set(mapper.iterate_to_root()).\ + union(mapper.self_and_descendants) + for m in check: + if m.has_property(backref_key): + raise sa_exc.ArgumentError("Error creating backref " + "'%s' on relationship '%s': property of that " + "name exists on mapper '%s'" % (backref_key, + self, m)) + + # determine primaryjoin/secondaryjoin for the + # backref. Use the one we had, so that + # a custom join doesn't have to be specified in + # both directions. + if self.secondary is not None: + # for many to many, just switch primaryjoin/ + # secondaryjoin. use the annotated + # pj/sj on the _join_condition. + pj = kwargs.pop('primaryjoin', + self._join_condition.secondaryjoin_minus_local) + sj = kwargs.pop('secondaryjoin', + self._join_condition.primaryjoin_minus_local) + else: + pj = kwargs.pop('primaryjoin', + self._join_condition.primaryjoin_reverse_remote) + sj = kwargs.pop('secondaryjoin', None) + if sj: + raise sa_exc.InvalidRequestError( + "Can't assign 'secondaryjoin' on a backref " + "against a non-secondary relationship." + ) + + foreign_keys = kwargs.pop('foreign_keys', + self._user_defined_foreign_keys) + parent = self.parent.primary_mapper() + kwargs.setdefault('viewonly', self.viewonly) + kwargs.setdefault('post_update', self.post_update) + kwargs.setdefault('passive_updates', self.passive_updates) + self.back_populates = backref_key + relationship = RelationshipProperty( + parent, self.secondary, + pj, sj, + foreign_keys=foreign_keys, + back_populates=self.key, + **kwargs) + mapper._configure_property(backref_key, relationship) + + if self.back_populates: + self._add_reverse_property(self.back_populates) + + def _post_init(self): + if self.uselist is None: + self.uselist = self.direction is not MANYTOONE + if not self.viewonly: + self._dependency_processor = \ + dependency.DependencyProcessor.from_relationship(self) + + @util.memoized_property + def _use_get(self): + """memoize the 'use_get' attribute of this RelationshipLoader's + lazyloader.""" + + strategy = self._lazy_strategy + return strategy.use_get + + @util.memoized_property + def _is_self_referential(self): + return self.mapper.common_parent(self.parent) + + def _create_joins(self, source_polymorphic=False, + source_selectable=None, dest_polymorphic=False, + dest_selectable=None, of_type=None): + if source_selectable is None: + if source_polymorphic and self.parent.with_polymorphic: + source_selectable = self.parent._with_polymorphic_selectable + + aliased = False + if dest_selectable is None: + if dest_polymorphic and self.mapper.with_polymorphic: + dest_selectable = self.mapper._with_polymorphic_selectable + aliased = True + else: + dest_selectable = self.mapper.mapped_table + + if self._is_self_referential and source_selectable is None: + dest_selectable = dest_selectable.alias() + aliased = True + else: + aliased = True + + dest_mapper = of_type or self.mapper + + single_crit = dest_mapper._single_table_criterion + aliased = aliased or (source_selectable is not None) + + primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = \ + self._join_condition.join_targets( + source_selectable, dest_selectable, aliased, single_crit + ) + if source_selectable is None: + source_selectable = self.parent.local_table + if dest_selectable is None: + dest_selectable = self.mapper.local_table + return (primaryjoin, secondaryjoin, source_selectable, + dest_selectable, secondary, target_adapter) + +def _annotate_columns(element, annotations): + def clone(elem): + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + + +class JoinCondition(object): + def __init__(self, + parent_selectable, + child_selectable, + parent_local_selectable, + child_local_selectable, + primaryjoin=None, + secondary=None, + secondaryjoin=None, + parent_equivalents=None, + child_equivalents=None, + consider_as_foreign_keys=None, + local_remote_pairs=None, + remote_side=None, + self_referential=False, + prop=None, + support_sync=True, + can_be_synced_fn=lambda *c: True + ): + self.parent_selectable = parent_selectable + self.parent_local_selectable = parent_local_selectable + self.child_selectable = child_selectable + self.child_local_selectable = child_local_selectable + self.parent_equivalents = parent_equivalents + self.child_equivalents = child_equivalents + self.primaryjoin = primaryjoin + self.secondaryjoin = secondaryjoin + self.secondary = secondary + self.consider_as_foreign_keys = consider_as_foreign_keys + self._local_remote_pairs = local_remote_pairs + self._remote_side = remote_side + self.prop = prop + self.self_referential = self_referential + self.support_sync = support_sync + self.can_be_synced_fn = can_be_synced_fn + self._determine_joins() + self._annotate_fks() + self._annotate_remote() + self._annotate_local() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._determine_direction() + self._check_remote_side() + self._log_joins() + + def _log_joins(self): + if self.prop is None: + return + log = self.prop.logger + log.info('%s setup primary join %s', self.prop, + self.primaryjoin) + log.info('%s setup secondary join %s', self.prop, + self.secondaryjoin) + log.info('%s synchronize pairs [%s]', self.prop, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.synchronize_pairs)) + log.info('%s secondary synchronize pairs [%s]', self.prop, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.secondary_synchronize_pairs or [])) + log.info('%s local/remote pairs [%s]', self.prop, + ','.join('(%s / %s)' % (l, r) for (l, r) in + self.local_remote_pairs)) + log.info('%s remote columns [%s]', self.prop, + ','.join('%s' % col for col in self.remote_columns) + ) + log.info('%s local columns [%s]', self.prop, + ','.join('%s' % col for col in self.local_columns) + ) + log.info('%s relationship direction %s', self.prop, + self.direction) + + def _determine_joins(self): + """Determine the 'primaryjoin' and 'secondaryjoin' attributes, + if not passed to the constructor already. + + This is based on analysis of the foreign key relationships + between the parent and target mapped selectables. + + """ + if self.secondaryjoin is not None and self.secondary is None: + raise sa_exc.ArgumentError( + "Property %s specified with secondary " + "join condition but " + "no secondary argument" % self.prop) + + # find a join between the given mapper's mapped table and + # the given table. will try the mapper's local table first + # for more specificity, then if not found will try the more + # general mapped table, which in the case of inheritance is + # a join. + try: + consider_as_foreign_keys = self.consider_as_foreign_keys or None + if self.secondary is not None: + if self.secondaryjoin is None: + self.secondaryjoin = \ + join_condition( + self.child_selectable, + self.secondary, + a_subset=self.child_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys + ) + if self.primaryjoin is None: + self.primaryjoin = \ + join_condition( + self.parent_selectable, + self.secondary, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys + ) + else: + if self.primaryjoin is None: + self.primaryjoin = \ + join_condition( + self.parent_selectable, + self.child_selectable, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys + ) + except sa_exc.NoForeignKeysError: + if self.secondary is not None: + raise sa_exc.NoForeignKeysError("Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." + % (self.prop, self.secondary)) + else: + raise sa_exc.NoForeignKeysError("Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." + % self.prop) + except sa_exc.AmbiguousForeignKeysError: + if self.secondary is not None: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." + % (self.prop, self.secondary)) + else: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." + % self.prop) + + @property + def primaryjoin_minus_local(self): + return _deep_deannotate(self.primaryjoin, values=("local", "remote")) + + @property + def secondaryjoin_minus_local(self): + return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) + + @util.memoized_property + def primaryjoin_reverse_remote(self): + """Return the primaryjoin condition suitable for the + "reverse" direction. + + If the primaryjoin was delivered here with pre-existing + "remote" annotations, the local/remote annotations + are reversed. Otherwise, the local/remote annotations + are removed. + + """ + if self._has_remote_annotations: + def replace(element): + if "remote" in element._annotations: + v = element._annotations.copy() + del v['remote'] + v['local'] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = element._annotations.copy() + del v['local'] + v['remote'] = True + return element._with_annotations(v) + return visitors.replacement_traverse( + self.primaryjoin, {}, replace) + else: + if self._has_foreign_annotations: + # TODO: coverage + return _deep_deannotate(self.primaryjoin, + values=("local", "remote")) + else: + return _deep_deannotate(self.primaryjoin) + + def _has_annotation(self, clause, annotation): + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + + @util.memoized_property + def _has_foreign_annotations(self): + return self._has_annotation(self.primaryjoin, "foreign") + + @util.memoized_property + def _has_remote_annotations(self): + return self._has_annotation(self.primaryjoin, "remote") + + def _annotate_fks(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'foreign' annotations marking columns + considered as foreign. + + """ + if self._has_foreign_annotations: + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self): + def check_fk(col): + if col in self.consider_as_foreign_keys: + return col._annotate({"foreign": True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, + {}, + check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, + {}, + check_fk + ) + + def _annotate_present_fks(self): + if self.secondary is not None: + secondarycols = util.column_set(self.secondary.c) + else: + secondarycols = set() + + def is_foreign(a, b): + if isinstance(a, schema.Column) and \ + isinstance(b, schema.Column): + if a.references(b): + return a + elif b.references(a): + return b + + if secondarycols: + if a in secondarycols and b not in secondarycols: + return a + elif b in secondarycols and a not in secondarycols: + return b + + def visit_binary(binary): + if not isinstance(binary.left, sql.ColumnElement) or \ + not isinstance(binary.right, sql.ColumnElement): + return + + if "foreign" not in binary.left._annotations and \ + "foreign" not in binary.right._annotations: + col = is_foreign(binary.left, binary.right) + if col is not None: + if col.compare(binary.left): + binary.left = binary.left._annotate( + {"foreign": True}) + elif col.compare(binary.right): + binary.right = binary.right._annotate( + {"foreign": True}) + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, + {}, + {"binary": visit_binary} + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.cloned_traverse( + self.secondaryjoin, + {}, + {"binary": visit_binary} + ) + + def _refers_to_parent_table(self): + """Return True if the join condition contains column + comparisons where both columns are in both tables. + + """ + pt = self.parent_selectable + mt = self.child_selectable + result = [False] + + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + visitors.traverse( + self.primaryjoin, + {}, + {"binary": visit_binary} + ) + return result[0] + + def _tables_overlap(self): + """Return True if parent/child tables have some overlap.""" + + return selectables_overlap(self.parent_selectable, self.child_selectable) + + def _annotate_remote(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'remote' annotations marking columns + considered as part of the 'remote' side. + + """ + if self._has_remote_annotations: + return + + if self.secondary is not None: + self._annotate_remote_secondary() + elif self._local_remote_pairs or self._remote_side: + self._annotate_remote_from_args() + elif self._refers_to_parent_table(): + self._annotate_selfref(lambda col: "foreign" in col._annotations) + elif self._tables_overlap(): + self._annotate_remote_with_overlap() + else: + self._annotate_remote_distinct_selectables() + + def _annotate_remote_secondary(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when 'secondary' is present. + + """ + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + + def _annotate_selfref(self, fn): + """annotate 'remote' in primaryjoin, secondaryjoin + when the relationship is detected as self-referential. + + """ + def visit_binary(binary): + equated = binary.left.compare(binary.right) + if isinstance(binary.left, expression.ColumnClause) and \ + isinstance(binary.right, expression.ColumnClause): + # assume one to many - FKs are "remote" + if fn(binary.left): + binary.left = binary.left._annotate({"remote": True}) + if fn(binary.right) and not equated: + binary.right = binary.right._annotate( + {"remote": True}) + else: + self._warn_non_column_elements() + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, + {"binary": visit_binary}) + + def _annotate_remote_from_args(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the 'remote_side' or '_local_remote_pairs' + arguments are used. + + """ + if self._local_remote_pairs: + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument.") + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + self._annotate_selfref(lambda col: col in remote_side) + else: + def repl(element): + if element in remote_side: + return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _annotate_remote_with_overlap(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables have some set of + tables in common, though is not a fully self-referential + relationship. + + """ + def visit_binary(binary): + binary.left, binary.right = proc_left_right(binary.left, + binary.right) + binary.right, binary.left = proc_left_right(binary.right, + binary.left) + + def proc_left_right(left, right): + if isinstance(left, expression.ColumnClause) and \ + isinstance(right, expression.ColumnClause): + if self.child_selectable.c.contains_column(right) and \ + self.parent_selectable.c.contains_column(left): + right = right._annotate({"remote": True}) + else: + self._warn_non_column_elements() + + return left, right + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, + {"binary": visit_binary}) + + def _annotate_remote_distinct_selectables(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables are entirely + separate. + + """ + def repl(element): + if self.child_selectable.c.contains_column(element) and \ + ( + not self.parent_local_selectable.c.\ + contains_column(element) + or self.child_local_selectable.c.\ + contains_column(element)): + return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _warn_non_column_elements(self): + util.warn( + "Non-simple column elements in primary " + "join condition for property %s - consider using " + "remote() annotations to mark the remote side." + % self.prop + ) + + def _annotate_local(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'local' annotations. + + This annotates all column elements found + simultaneously in the parent table + and the join condition that don't have a + 'remote' annotation set up from + _annotate_remote() or user-defined. + + """ + if self._has_annotation(self.primaryjoin, "local"): + return + + if self._local_remote_pairs: + local_side = util.column_set([l for (l, r) + in self._local_remote_pairs]) + else: + local_side = util.column_set(self.parent_selectable.c) + + def locals_(elem): + if "remote" not in elem._annotations and \ + elem in local_side: + return elem._annotate({"local": True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, locals_ + ) + + def _check_remote_side(self): + if not self.local_remote_pairs: + raise sa_exc.ArgumentError('Relationship %s could ' + 'not determine any unambiguous local/remote column ' + 'pairs based on join condition and remote_side ' + 'arguments. ' + 'Consider using the remote() annotation to ' + 'accurately mark those elements of the join ' + 'condition that are on the remote side of ' + 'the relationship.' + % (self.prop, )) + + def _check_foreign_cols(self, join_condition, primary): + """Check the foreign key columns collected and emit error + messages.""" + + can_sync = False + + foreign_cols = self._gather_columns_with_annotation( + join_condition, "foreign") + + has_foreign = bool(foreign_cols) + + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) + + if self.support_sync and can_sync or \ + (not self.support_sync and has_foreign): + return + + # from here below is just determining the best error message + # to report. Check for a join condition using any operator + # (not just ==), perhaps they need to turn on "viewonly=True". + if self.support_sync and has_foreign and not can_sync: + err = "Could not locate any simple equality expressions "\ + "involving locally mapped foreign key columns for "\ + "%s join condition "\ + "'%s' on relationship %s." % ( + primary and 'primary' or 'secondary', + join_condition, + self.prop + ) + err += \ + " Ensure that referencing columns are associated "\ + "with a ForeignKey or ForeignKeyConstraint, or are "\ + "annotated in the join condition with the foreign() "\ + "annotation. To allow comparison operators other than "\ + "'==', the relationship can be marked as viewonly=True." + + raise sa_exc.ArgumentError(err) + else: + err = "Could not locate any relevant foreign key columns "\ + "for %s join condition '%s' on relationship %s." % ( + primary and 'primary' or 'secondary', + join_condition, + self.prop + ) + err += \ + ' Ensure that referencing columns are associated '\ + 'with a ForeignKey or ForeignKeyConstraint, or are '\ + 'annotated in the join condition with the foreign() '\ + 'annotation.' + raise sa_exc.ArgumentError(err) + + def _determine_direction(self): + """Determine if this relationship is one to many, many to one, + many to many. + + """ + if self.secondaryjoin is not None: + self.direction = MANYTOMANY + else: + parentcols = util.column_set(self.parent_selectable.c) + targetcols = util.column_set(self.child_selectable.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection( + self.foreign_key_columns) + + # fk collection which suggests MANYTOONE. + + manytoone_fk = parentcols.intersection( + self.foreign_key_columns) + + if onetomany_fk and manytoone_fk: + # fks on both sides. test for overlap of local/remote + # with foreign key + self_equated = self.remote_columns.intersection( + self.local_columns + ) + onetomany_local = self.remote_columns.\ + intersection(self.foreign_key_columns).\ + difference(self_equated) + manytoone_local = self.local_columns.\ + intersection(self.foreign_key_columns).\ + difference(self_equated) + if onetomany_local and not manytoone_local: + self.direction = ONETOMANY + elif manytoone_local and not onetomany_local: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship" + " direction for relationship '%s' - foreign " + "key columns within the join condition are present " + "in both the parent and the child's mapped tables. " + "Ensure that only those columns referring " + "to a parent column are marked as foreign, " + "either via the foreign() annotation or " + "via the foreign_keys argument." % self.prop) + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError("Can't determine relationship " + "direction for relationship '%s' - foreign " + "key columns are present in neither the parent " + "nor the child's mapped tables" % self.prop) + + def _deannotate_pairs(self, collection): + """provide deannotation for the various lists of + pairs, so that using them in hashes doesn't incur + high-overhead __eq__() comparisons against + original columns mapped. + + """ + return [(x._deannotate(), y._deannotate()) + for x, y in collection] + + def _setup_pairs(self): + sync_pairs = [] + lrp = util.OrderedSet([]) + secondary_sync_pairs = [] + + def go(joincond, collection): + def visit_binary(binary, left, right): + if "remote" in right._annotations and \ + "remote" not in left._annotations and \ + self.can_be_synced_fn(left): + lrp.add((left, right)) + elif "remote" in left._annotations and \ + "remote" not in right._annotations and \ + self.can_be_synced_fn(right): + lrp.add((right, left)) + if binary.operator is operators.eq and \ + self.can_be_synced_fn(left, right): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs) + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = self._deannotate_pairs(lrp) + self.synchronize_pairs = self._deannotate_pairs(sync_pairs) + self.secondary_synchronize_pairs = \ + self._deannotate_pairs(secondary_sync_pairs) + + @util.memoized_property + def remote_columns(self): + return self._gather_join_annotations("remote") + + @util.memoized_property + def local_columns(self): + return self._gather_join_annotations("local") + + @util.memoized_property + def foreign_key_columns(self): + return self._gather_join_annotations("foreign") + + @util.memoized_property + def deannotated_primaryjoin(self): + return _deep_deannotate(self.primaryjoin) + + @util.memoized_property + def deannotated_secondaryjoin(self): + if self.secondaryjoin is not None: + return _deep_deannotate(self.secondaryjoin) + else: + return None + + def _gather_join_annotations(self, annotation): + s = set( + self._gather_columns_with_annotation( + self.primaryjoin, annotation) + ) + if self.secondaryjoin is not None: + s.update( + self._gather_columns_with_annotation( + self.secondaryjoin, annotation) + ) + return set([x._deannotate() for x in s]) + + def _gather_columns_with_annotation(self, clause, *annotation): + annotation = set(annotation) + return set([ + col for col in visitors.iterate(clause, {}) + if annotation.issubset(col._annotations) + ]) + + def join_targets(self, source_selectable, + dest_selectable, + aliased, + single_crit=None): + """Given a source and destination selectable, create a + join between them. + + This takes into account aliasing the join clause + to reference the appropriate corresponding columns + in the target objects, as well as the extra child + criterion, equivalent column sets, etc. + + """ + + # place a barrier on the destination such that + # replacement traversals won't ever dig into it. + # its internal structure remains fixed + # regardless of context. + dest_selectable = _shallow_annotate( + dest_selectable, + {'no_replacement_traverse': True}) + + primaryjoin, secondaryjoin, secondary = self.primaryjoin, \ + self.secondaryjoin, self.secondary + + # adjust the join condition for single table inheritance, + # in the case that the join is to a subclass + # this is analogous to the + # "_adjust_for_single_table_inheritance()" method in Query. + + if single_crit is not None: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & single_crit + else: + primaryjoin = primaryjoin & single_crit + + if aliased: + if secondary is not None: + secondary = secondary.alias(flat=True) + primary_aliasizer = ClauseAdapter(secondary) + secondary_aliasizer = \ + ClauseAdapter(dest_selectable, + equivalents=self.child_equivalents).\ + chain(primary_aliasizer) + if source_selectable is not None: + primary_aliasizer = \ + ClauseAdapter(secondary).\ + chain(ClauseAdapter(source_selectable, + equivalents=self.parent_equivalents)) + secondaryjoin = \ + secondary_aliasizer.traverse(secondaryjoin) + else: + primary_aliasizer = ClauseAdapter(dest_selectable, + exclude_fn=_ColInAnnotations("local"), + equivalents=self.child_equivalents) + if source_selectable is not None: + primary_aliasizer.chain( + ClauseAdapter(source_selectable, + exclude_fn=_ColInAnnotations("remote"), + equivalents=self.parent_equivalents)) + secondary_aliasizer = None + + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer + target_adapter.exclude_fn = None + else: + target_adapter = None + return primaryjoin, secondaryjoin, secondary, \ + target_adapter, dest_selectable + + def create_lazy_clause(self, reverse_direction=False): + binds = util.column_dict() + lookup = util.column_dict() + equated_columns = util.column_dict() + being_replaced = set() + + if reverse_direction and self.secondaryjoin is None: + for l, r in self.local_remote_pairs: + _list = lookup.setdefault(r, []) + _list.append((r, l)) + equated_columns[l] = r + else: + # replace all "local side" columns, which is + # anything that isn't marked "remote" + being_replaced.update(self.local_columns) + for l, r in self.local_remote_pairs: + _list = lookup.setdefault(l, []) + _list.append((l, r)) + equated_columns[r] = l + + def col_to_bind(col): + if col in being_replaced or col in lookup: + if col in lookup: + for tobind, equated in lookup[col]: + if equated in binds: + return None + else: + assert not reverse_direction + if col not in binds: + binds[col] = sql.bindparam( + None, None, type_=col.type, unique=True) + return binds[col] + return None + + lazywhere = self.deannotated_primaryjoin + + if self.deannotated_secondaryjoin is None or not reverse_direction: + lazywhere = visitors.replacement_traverse( + lazywhere, {}, col_to_bind) + + if self.deannotated_secondaryjoin is not None: + secondaryjoin = self.deannotated_secondaryjoin + if reverse_direction: + secondaryjoin = visitors.replacement_traverse( + secondaryjoin, {}, col_to_bind) + lazywhere = sql.and_(lazywhere, secondaryjoin) + + bind_to_col = dict((binds[col].key, col) for col in binds) + + return lazywhere, bind_to_col, equated_columns + +class _ColInAnnotations(object): + """Seralizable equivalent to: + + lambda c: "name" in c._annotations + """ + def __init__(self, name): + self.name = name + + def __call__(self, c): + return self.name in c._annotations diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py new file mode 100644 index 00000000..c1f8f319 --- /dev/null +++ b/lib/sqlalchemy/orm/scoping.py @@ -0,0 +1,176 @@ +# orm/scoping.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .. import exc as sa_exc +from ..util import ScopedRegistry, ThreadLocalRegistry, warn +from . import class_mapper, exc as orm_exc +from .session import Session + + +__all__ = ['scoped_session'] + + +class scoped_session(object): + """Provides scoped management of :class:`.Session` objects. + + See :ref:`unitofwork_contextual` for a tutorial. + + """ + + def __init__(self, session_factory, scopefunc=None): + """Construct a new :class:`.scoped_session`. + + :param session_factory: a factory to create new :class:`.Session` + instances. This is usually, but not necessarily, an instance + of :class:`.sessionmaker`. + :param scopefunc: optional function which defines + the current scope. If not passed, the :class:`.scoped_session` + object assumes "thread-local" scope, and will use + a Python ``threading.local()`` in order to maintain the current + :class:`.Session`. If passed, the function should return + a hashable token; this token will be used as the key in a + dictionary in order to store and retrieve the current + :class:`.Session`. + + """ + self.session_factory = session_factory + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) + + def __call__(self, **kw): + """Return the current :class:`.Session`, creating it + using the session factory if not present. + + :param \**kw: Keyword arguments will be passed to the + session factory callable, if an existing :class:`.Session` + is not present. If the :class:`.Session` is present and + keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + scope = kw.pop('scope', False) + if scope is not None: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified.") + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + return sess + else: + return self.session_factory(**kw) + else: + return self.registry() + + def remove(self): + """Dispose of the current :class:`.Session`, if present. + + This will first call :meth:`.Session.close` method + on the current :class:`.Session`, which releases any existing + transactional/connection resources still being held; transactions + specifically are rolled back. The :class:`.Session` is then + discarded. Upon next usage within the same scope, + the :class:`.scoped_session` will produce a new + :class:`.Session` object. + + """ + + if self.registry.has(): + self.registry().close() + self.registry.clear() + + def configure(self, **kwargs): + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn('At least one scoped session is already present. ' + ' configure() can not affect sessions that have ' + 'already been created.') + + self.session_factory.configure(**kwargs) + + def query_property(self, query_cls=None): + """return a class property which produces a :class:`.Query` object + against the class and the current :class:`.Session` when called. + + e.g.:: + + Session = scoped_session(sessionmaker()) + + class MyClass(object): + query = Session.query_property() + + # after mappers are defined + result = MyClass.query.filter(MyClass.name=='foo').all() + + Produces instances of the session's configured query class by + default. To override and use a custom implementation, provide + a ``query_cls`` callable. The callable will be invoked with + the class's mapper as a positional argument and a session + keyword argument. + + There is no limit to the number of query properties placed on + a class. + + """ + class query(object): + def __get__(s, instance, owner): + try: + mapper = class_mapper(owner) + if mapper: + if query_cls: + # custom query class + return query_cls(mapper, session=self.registry()) + else: + # session's configured query class + return self.registry().query(mapper) + except orm_exc.UnmappedClassError: + return None + return query() + +ScopedSession = scoped_session +"""Old name for backwards compatibility.""" + + +def instrument(name): + def do(self, *args, **kwargs): + return getattr(self.registry(), name)(*args, **kwargs) + return do + +for meth in Session.public_methods: + setattr(scoped_session, meth, instrument(meth)) + + +def makeprop(name): + def set(self, attr): + setattr(self.registry(), name, attr) + + def get(self): + return getattr(self.registry(), name) + + return property(get, set) + +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', + 'is_active', 'autoflush', 'no_autoflush', 'info'): + setattr(scoped_session, prop, makeprop(prop)) + + +def clslevel(name): + def do(cls, *args, **kwargs): + return getattr(Session, name)(*args, **kwargs) + return classmethod(do) + +for prop in ('close_all', 'object_session', 'identity_key'): + setattr(scoped_session, prop, clslevel(prop)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py new file mode 100644 index 00000000..5bd46691 --- /dev/null +++ b/lib/sqlalchemy/orm/session.py @@ -0,0 +1,2407 @@ +# orm/session.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Provides the Session class and related utilities.""" + + + +import weakref +from .. import util, sql, engine, exc as sa_exc +from ..sql import util as sql_util, expression +from . import ( + SessionExtension, attributes, exc, query, + loading, identity + ) +from ..inspection import inspect +from .base import ( + object_mapper, class_mapper, + _class_to_mapper, _state_mapper, object_state, + _none_set, state_str, instance_str + ) +from .unitofwork import UOWTransaction +from . import state as statelib +import sys + +__all__ = ['Session', 'SessionTransaction', 'SessionExtension', 'sessionmaker'] + +_sessions = weakref.WeakValueDictionary() +"""Weak-referencing dictionary of :class:`.Session` objects. +""" + +def _state_session(state): + """Given an :class:`.InstanceState`, return the :class:`.Session` + associated, if any. + """ + if state.session_id: + try: + return _sessions[state.session_id] + except KeyError: + pass + return None + + + +class _SessionClassMethods(object): + """Class-level methods for :class:`.Session`, :class:`.sessionmaker`.""" + + @classmethod + def close_all(cls): + """Close *all* sessions in memory.""" + + for sess in _sessions.values(): + sess.close() + + @classmethod + @util.dependencies("sqlalchemy.orm.util") + def identity_key(cls, orm_util, *args, **kwargs): + """Return an identity key. + + This is an alias of :func:`.util.identity_key`. + + """ + return orm_util.identity_key(*args, **kwargs) + + @classmethod + def object_session(cls, instance): + """Return the :class:`.Session` to which an object belongs. + + This is an alias of :func:`.object_session`. + + """ + + return object_session(instance) + + +ACTIVE = util.symbol('ACTIVE') +PREPARED = util.symbol('PREPARED') +COMMITTED = util.symbol('COMMITTED') +DEACTIVE = util.symbol('DEACTIVE') +CLOSED = util.symbol('CLOSED') + +class SessionTransaction(object): + """A :class:`.Session`-level transaction. + + :class:`.SessionTransaction` is a mostly behind-the-scenes object + not normally referenced directly by application code. It coordinates + among multiple :class:`.Connection` objects, maintaining a database + transaction for each one individually, committing or rolling them + back all at once. It also provides optional two-phase commit behavior + which can augment this coordination operation. + + The :attr:`.Session.transaction` attribute of :class:`.Session` + refers to the current :class:`.SessionTransaction` object in use, if any. + + + A :class:`.SessionTransaction` is associated with a :class:`.Session` + in its default mode of ``autocommit=False`` immediately, associated + with no database connections. As the :class:`.Session` is called upon + to emit SQL on behalf of various :class:`.Engine` or :class:`.Connection` + objects, a corresponding :class:`.Connection` and associated + :class:`.Transaction` is added to a collection within the + :class:`.SessionTransaction` object, becoming one of the + connection/transaction pairs maintained by the + :class:`.SessionTransaction`. + + The lifespan of the :class:`.SessionTransaction` ends when the + :meth:`.Session.commit`, :meth:`.Session.rollback` or + :meth:`.Session.close` methods are called. At this point, the + :class:`.SessionTransaction` removes its association with its parent + :class:`.Session`. A :class:`.Session` that is in ``autocommit=False`` + mode will create a new :class:`.SessionTransaction` to replace it + immediately, whereas a :class:`.Session` that's in ``autocommit=True`` + mode will remain without a :class:`.SessionTransaction` until the + :meth:`.Session.begin` method is called. + + Another detail of :class:`.SessionTransaction` behavior is that it is + capable of "nesting". This means that the :meth:`.Session.begin` method + can be called while an existing :class:`.SessionTransaction` is already + present, producing a new :class:`.SessionTransaction` that temporarily + replaces the parent :class:`.SessionTransaction`. When a + :class:`.SessionTransaction` is produced as nested, it assigns itself to + the :attr:`.Session.transaction` attribute. When it is ended via + :meth:`.Session.commit` or :meth:`.Session.rollback`, it restores its + parent :class:`.SessionTransaction` back onto the + :attr:`.Session.transaction` attribute. The behavior is effectively a + stack, where :attr:`.Session.transaction` refers to the current head of + the stack. + + The purpose of this stack is to allow nesting of + :meth:`.Session.rollback` or :meth:`.Session.commit` calls in context + with various flavors of :meth:`.Session.begin`. This nesting behavior + applies to when :meth:`.Session.begin_nested` is used to emit a + SAVEPOINT transaction, and is also used to produce a so-called + "subtransaction" which allows a block of code to use a + begin/rollback/commit sequence regardless of whether or not its enclosing + code block has begun a transaction. The :meth:`.flush` method, whether + called explicitly or via autoflush, is the primary consumer of the + "subtransaction" feature, in that it wishes to guarantee that it works + within in a transaction block regardless of whether or not the + :class:`.Session` is in transactional mode when the method is called. + + See also: + + :meth:`.Session.rollback` + + :meth:`.Session.commit` + + :meth:`.Session.begin` + + :meth:`.Session.begin_nested` + + :attr:`.Session.is_active` + + :meth:`.SessionEvents.after_commit` + + :meth:`.SessionEvents.after_rollback` + + :meth:`.SessionEvents.after_soft_rollback` + + """ + + _rollback_exception = None + + def __init__(self, session, parent=None, nested=False): + self.session = session + self._connections = {} + self._parent = parent + self.nested = nested + self._state = ACTIVE + if not parent and nested: + raise sa_exc.InvalidRequestError( + "Can't start a SAVEPOINT transaction when no existing " + "transaction is in progress") + + if self.session._enable_transaction_accounting: + self._take_snapshot() + + if self.session.dispatch.after_transaction_create: + self.session.dispatch.after_transaction_create(self.session, self) + + @property + def is_active(self): + return self.session is not None and self._state is ACTIVE + + def _assert_active(self, prepared_ok=False, + rollback_ok=False, + deactive_ok=False, + closed_msg="This transaction is closed"): + if self._state is COMMITTED: + raise sa_exc.InvalidRequestError( + "This session is in 'committed' state; no further " + "SQL can be emitted within this transaction." + ) + elif self._state is PREPARED: + if not prepared_ok: + raise sa_exc.InvalidRequestError( + "This session is in 'prepared' state; no further " + "SQL can be emitted within this transaction." + ) + elif self._state is DEACTIVE: + if not deactive_ok and not rollback_ok: + if self._rollback_exception: + raise sa_exc.InvalidRequestError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + " Original exception was: %s" + % self._rollback_exception + ) + elif not deactive_ok: + raise sa_exc.InvalidRequestError( + "This Session's transaction has been rolled back " + "by a nested rollback() call. To begin a new " + "transaction, issue Session.rollback() first." + ) + elif self._state is CLOSED: + raise sa_exc.ResourceClosedError(closed_msg) + + @property + def _is_transaction_boundary(self): + return self.nested or not self._parent + + def connection(self, bindkey, **kwargs): + self._assert_active() + bind = self.session.get_bind(bindkey, **kwargs) + return self._connection_for_bind(bind) + + def _begin(self, nested=False): + self._assert_active() + return SessionTransaction( + self.session, self, nested=nested) + + def _iterate_parents(self, upto=None): + if self._parent is upto: + return (self,) + else: + if self._parent is None: + raise sa_exc.InvalidRequestError( + "Transaction %s is not on the active transaction list" % ( + upto)) + return (self,) + self._parent._iterate_parents(upto) + + def _take_snapshot(self): + if not self._is_transaction_boundary: + self._new = self._parent._new + self._deleted = self._parent._deleted + self._dirty = self._parent._dirty + self._key_switches = self._parent._key_switches + return + + if not self.session._flushing: + self.session.flush() + + self._new = weakref.WeakKeyDictionary() + self._deleted = weakref.WeakKeyDictionary() + self._dirty = weakref.WeakKeyDictionary() + self._key_switches = weakref.WeakKeyDictionary() + + def _restore_snapshot(self, dirty_only=False): + assert self._is_transaction_boundary + + for s in set(self._new).union(self.session._new): + self.session._expunge_state(s) + if s.key: + del s.key + + for s, (oldkey, newkey) in self._key_switches.items(): + self.session.identity_map.discard(s) + s.key = oldkey + self.session.identity_map.replace(s) + + for s in set(self._deleted).union(self.session._deleted): + if s.deleted: + #assert s in self._deleted + del s.deleted + self.session._update_impl(s, discard_existing=True) + + assert not self.session._deleted + + for s in self.session.identity_map.all_states(): + if not dirty_only or s.modified or s in self._dirty: + s._expire(s.dict, self.session.identity_map._modified) + + def _remove_snapshot(self): + assert self._is_transaction_boundary + + if not self.nested and self.session.expire_on_commit: + for s in self.session.identity_map.all_states(): + s._expire(s.dict, self.session.identity_map._modified) + for s in self._deleted: + s.session_id = None + self._deleted.clear() + + + def _connection_for_bind(self, bind): + self._assert_active() + + if bind in self._connections: + return self._connections[bind][0] + + if self._parent: + conn = self._parent._connection_for_bind(bind) + if not self.nested: + return conn + else: + if isinstance(bind, engine.Connection): + conn = bind + if conn.engine in self._connections: + raise sa_exc.InvalidRequestError( + "Session already has a Connection associated for the " + "given Connection's Engine") + else: + conn = bind.contextual_connect() + + if self.session.twophase and self._parent is None: + transaction = conn.begin_twophase() + elif self.nested: + transaction = conn.begin_nested() + else: + transaction = conn.begin() + + self._connections[conn] = self._connections[conn.engine] = \ + (conn, transaction, conn is not bind) + self.session.dispatch.after_begin(self.session, self, conn) + return conn + + def prepare(self): + if self._parent is not None or not self.session.twophase: + raise sa_exc.InvalidRequestError( + "'twophase' mode not enabled, or not root transaction; " + "can't prepare.") + self._prepare_impl() + + def _prepare_impl(self): + self._assert_active() + if self._parent is None or self.nested: + self.session.dispatch.before_commit(self.session) + + stx = self.session.transaction + if stx is not self: + for subtransaction in stx._iterate_parents(upto=self): + subtransaction.commit() + + if not self.session._flushing: + for _flush_guard in range(100): + if self.session._is_clean(): + break + self.session.flush() + else: + raise exc.FlushError( + "Over 100 subsequent flushes have occurred within " + "session.commit() - is an after_flush() hook " + "creating new objects?") + + if self._parent is None and self.session.twophase: + try: + for t in set(self._connections.values()): + t[1].prepare() + except: + with util.safe_reraise(): + self.rollback() + + self._state = PREPARED + + def commit(self): + self._assert_active(prepared_ok=True) + if self._state is not PREPARED: + self._prepare_impl() + + if self._parent is None or self.nested: + for t in set(self._connections.values()): + t[1].commit() + + self._state = COMMITTED + self.session.dispatch.after_commit(self.session) + + if self.session._enable_transaction_accounting: + self._remove_snapshot() + + self.close() + return self._parent + + def rollback(self, _capture_exception=False): + self._assert_active(prepared_ok=True, rollback_ok=True) + + stx = self.session.transaction + if stx is not self: + for subtransaction in stx._iterate_parents(upto=self): + subtransaction.close() + + if self._state in (ACTIVE, PREPARED): + for transaction in self._iterate_parents(): + if transaction._parent is None or transaction.nested: + transaction._rollback_impl() + transaction._state = DEACTIVE + break + else: + transaction._state = DEACTIVE + + sess = self.session + + if self.session._enable_transaction_accounting and \ + not sess._is_clean(): + # if items were added, deleted, or mutated + # here, we need to re-restore the snapshot + util.warn( + "Session's state has been changed on " + "a non-active transaction - this state " + "will be discarded.") + self._restore_snapshot(dirty_only=self.nested) + + self.close() + if self._parent and _capture_exception: + self._parent._rollback_exception = sys.exc_info()[1] + + sess.dispatch.after_soft_rollback(sess, self) + + return self._parent + + def _rollback_impl(self): + for t in set(self._connections.values()): + t[1].rollback() + + if self.session._enable_transaction_accounting: + self._restore_snapshot(dirty_only=self.nested) + + self.session.dispatch.after_rollback(self.session) + + def close(self): + self.session.transaction = self._parent + if self._parent is None: + for connection, transaction, autoclose in \ + set(self._connections.values()): + if autoclose: + connection.close() + else: + transaction.close() + + self._state = CLOSED + if self.session.dispatch.after_transaction_end: + self.session.dispatch.after_transaction_end(self.session, self) + + if self._parent is None: + if not self.session.autocommit: + self.session.begin() + self.session = None + self._connections = None + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self._assert_active(deactive_ok=True, prepared_ok=True) + if self.session.transaction is None: + return + if type is None: + try: + self.commit() + except: + with util.safe_reraise(): + self.rollback() + else: + self.rollback() + + +class Session(_SessionClassMethods): + """Manages persistence operations for ORM-mapped objects. + + The Session's usage paradigm is described at :doc:`/orm/session`. + + + """ + + public_methods = ( + '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', + 'close', 'commit', 'connection', 'delete', 'execute', 'expire', + 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', + 'is_modified', + 'merge', 'query', 'refresh', 'rollback', + 'scalar') + + def __init__(self, bind=None, autoflush=True, expire_on_commit=True, + _enable_transaction_accounting=True, + autocommit=False, twophase=False, + weak_identity_map=True, binds=None, extension=None, + info=None, + query_cls=query.Query): + """Construct a new Session. + + See also the :class:`.sessionmaker` function which is used to + generate a :class:`.Session`-producing callable with a given + set of arguments. + + :param autocommit: + + .. warning:: + + The autocommit flag is **not for general use**, and if it is used, + queries should only be invoked within the span of a + :meth:`.Session.begin` / :meth:`.Session.commit` pair. Executing + queries outside of a demarcated transaction is a legacy mode + of usage, and can in some cases lead to concurrent connection + checkouts. + + Defaults to ``False``. When ``True``, the + :class:`.Session` does not keep a persistent transaction running, and + will acquire connections from the engine on an as-needed basis, + returning them immediately after their use. Flushes will begin and + commit (or possibly rollback) their own transaction if no + transaction is present. When using this mode, the + :meth:`.Session.begin` method is used to explicitly start + transactions. + + .. seealso:: + + :ref:`session_autocommit` + + :param autoflush: When ``True``, all query operations will issue a + ``flush()`` call to this ``Session`` before proceeding. This is a + convenience feature so that ``flush()`` need not be called + repeatedly in order for database queries to retrieve results. It's + typical that ``autoflush`` is used in conjunction with + ``autocommit=False``. In this scenario, explicit calls to + ``flush()`` are rarely needed; you usually only need to call + ``commit()`` (which flushes) to finalize changes. + + :param bind: An optional ``Engine`` or ``Connection`` to which this + ``Session`` should be bound. When specified, all SQL operations + performed by this session will execute via this connectable. + + :param binds: An optional dictionary which contains more granular + "bind" information than the ``bind`` parameter provides. This + dictionary can map individual ``Table`` instances as well as + ``Mapper`` instances to individual ``Engine`` or ``Connection`` + objects. Operations which proceed relative to a particular + ``Mapper`` will consult this dictionary for the direct ``Mapper`` + instance as well as the mapper's ``mapped_table`` attribute in + order to locate an connectable to use. The full resolution is + described in the ``get_bind()`` method of ``Session``. + Usage looks like:: + + Session = sessionmaker(binds={ + SomeMappedClass: create_engine('postgresql://engine1'), + somemapper: create_engine('postgresql://engine2'), + some_table: create_engine('postgresql://engine3'), + }) + + Also see the :meth:`.Session.bind_mapper` + and :meth:`.Session.bind_table` methods. + + :param \class_: Specify an alternate class other than + ``sqlalchemy.orm.session.Session`` which should be used by the + returned class. This is the only argument that is local to the + ``sessionmaker()`` function, and is not sent directly to the + constructor for ``Session``. + + :param _enable_transaction_accounting: Defaults to ``True``. A + legacy-only flag which when ``False`` disables *all* 0.5-style + object accounting on transaction boundaries, including auto-expiry + of instances on rollback and commit, maintenance of the "new" and + "deleted" lists upon rollback, and autoflush of pending changes upon + begin(), all of which are interdependent. + + :param expire_on_commit: Defaults to ``True``. When ``True``, all + instances will be fully expired after each ``commit()``, so that + all attribute/object access subsequent to a completed transaction + will load from the most recent database state. + + :param extension: An optional + :class:`~.SessionExtension` instance, or a list + of such instances, which will receive pre- and post- commit and + flush events, as well as a post-rollback event. **Deprecated.** + Please see :class:`.SessionEvents`. + + :param info: optional dictionary of arbitrary data to be associated + with this :class:`.Session`. Is available via the :attr:`.Session.info` + attribute. Note the dictionary is copied at construction time so + that modifications to the per-:class:`.Session` dictionary will be local + to that :class:`.Session`. + + .. versionadded:: 0.9.0 + + :param query_cls: Class which should be used to create new Query + objects, as returned by the ``query()`` method. Defaults to + :class:`~sqlalchemy.orm.query.Query`. + + :param twophase: When ``True``, all transactions will be started as + a "two phase" transaction, i.e. using the "two phase" semantics + of the database in use along with an XID. During a ``commit()``, + after ``flush()`` has been issued for all attached databases, the + ``prepare()`` method on each database's ``TwoPhaseTransaction`` + will be called. This allows each database to roll back the entire + transaction, before each transaction is committed. + + :param weak_identity_map: Defaults to ``True`` - when set to + ``False``, objects placed in the :class:`.Session` will be + strongly referenced until explicitly removed or the + :class:`.Session` is closed. **Deprecated** - this option + is obsolete. + + """ + + if weak_identity_map: + self._identity_cls = identity.WeakInstanceDict + else: + util.warn_deprecated("weak_identity_map=False is deprecated. " + "This feature is not needed.") + self._identity_cls = identity.StrongInstanceDict + self.identity_map = self._identity_cls() + + self._new = {} # InstanceState->object, strong refs object + self._deleted = {} # same + self.bind = bind + self.__binds = {} + self._flushing = False + self._warn_on_events = False + self.transaction = None + self.hash_key = _new_sessionid() + self.autoflush = autoflush + self.autocommit = autocommit + self.expire_on_commit = expire_on_commit + self._enable_transaction_accounting = _enable_transaction_accounting + self.twophase = twophase + self._query_cls = query_cls + if info: + self.info.update(info) + + if extension: + for ext in util.to_list(extension): + SessionExtension._adapt_listener(self, ext) + + if binds is not None: + for mapperortable, bind in binds.items(): + insp = inspect(mapperortable) + if insp.is_selectable: + self.bind_table(mapperortable, bind) + elif insp.is_mapper: + self.bind_mapper(mapperortable, bind) + else: + assert False + + + if not self.autocommit: + self.begin() + _sessions[self.hash_key] = self + + connection_callable = None + + transaction = None + """The current active or inactive :class:`.SessionTransaction`.""" + + @util.memoized_property + def info(self): + """A user-modifiable dictionary. + + The initial value of this dictioanry can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + .. versionadded:: 0.9.0 + + """ + return {} + + def begin(self, subtransactions=False, nested=False): + """Begin a transaction on this Session. + + If this Session is already within a transaction, either a plain + transaction or nested transaction, an error is raised, unless + ``subtransactions=True`` or ``nested=True`` is specified. + + The ``subtransactions=True`` flag indicates that this + :meth:`~.Session.begin` can create a subtransaction if a transaction + is already in progress. For documentation on subtransactions, please + see :ref:`session_subtransactions`. + + The ``nested`` flag begins a SAVEPOINT transaction and is equivalent + to calling :meth:`~.Session.begin_nested`. For documentation on + SAVEPOINT transactions, please see :ref:`session_begin_nested`. + + """ + if self.transaction is not None: + if subtransactions or nested: + self.transaction = self.transaction._begin( + nested=nested) + else: + raise sa_exc.InvalidRequestError( + "A transaction is already begun. Use " + "subtransactions=True to allow subtransactions.") + else: + self.transaction = SessionTransaction( + self, nested=nested) + return self.transaction # needed for __enter__/__exit__ hook + + def begin_nested(self): + """Begin a `nested` transaction on this Session. + + The target database(s) must support SQL SAVEPOINTs or a + SQLAlchemy-supported vendor implementation of the idea. + + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. + + """ + return self.begin(nested=True) + + def rollback(self): + """Rollback the current transaction in progress. + + If no transaction is in progress, this method is a pass-through. + + This method rolls back the current transaction or nested transaction + regardless of subtransactions being in effect. All subtransactions up + to the first real transaction are closed. Subtransactions occur when + begin() is called multiple times. + + .. seealso:: + + :ref:`session_rollback` + + """ + if self.transaction is None: + pass + else: + self.transaction.rollback() + + def commit(self): + """Flush pending changes and commit the current transaction. + + If no transaction is in progress, this method raises an + :exc:`~sqlalchemy.exc.InvalidRequestError`. + + By default, the :class:`.Session` also expires all database + loaded state on all ORM-managed attributes after transaction commit. + This so that subsequent operations load the most recent + data from the database. This behavior can be disabled using + the ``expire_on_commit=False`` option to :class:`.sessionmaker` or + the :class:`.Session` constructor. + + If a subtransaction is in effect (which occurs when begin() is called + multiple times), the subtransaction will be closed, and the next call + to ``commit()`` will operate on the enclosing transaction. + + When using the :class:`.Session` in its default mode of + ``autocommit=False``, a new transaction will + be begun immediately after the commit, but note that the newly begun + transaction does *not* use any connection resources until the first + SQL is actually emitted. + + .. seealso:: + + :ref:`session_committing` + + """ + if self.transaction is None: + if not self.autocommit: + self.begin() + else: + raise sa_exc.InvalidRequestError("No transaction is begun.") + + self.transaction.commit() + + def prepare(self): + """Prepare the current transaction in progress for two phase commit. + + If no transaction is in progress, this method raises an + :exc:`~sqlalchemy.exc.InvalidRequestError`. + + Only root transactions of two phase sessions can be prepared. If the + current transaction is not such, an + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if self.transaction is None: + if not self.autocommit: + self.begin() + else: + raise sa_exc.InvalidRequestError("No transaction is begun.") + + self.transaction.prepare() + + def connection(self, mapper=None, clause=None, + bind=None, + close_with_result=False, + **kw): + """Return a :class:`.Connection` object corresponding to this + :class:`.Session` object's transactional state. + + If this :class:`.Session` is configured with ``autocommit=False``, + either the :class:`.Connection` corresponding to the current + transaction is returned, or if no transaction is in progress, a new + one is begun and the :class:`.Connection` returned (note that no + transactional state is established with the DBAPI until the first + SQL statement is emitted). + + Alternatively, if this :class:`.Session` is configured with + ``autocommit=True``, an ad-hoc :class:`.Connection` is returned + using :meth:`.Engine.contextual_connect` on the underlying + :class:`.Engine`. + + Ambiguity in multi-bind or unbound :class:`.Session` objects can be + resolved through any of the optional keyword arguments. This + ultimately makes usage of the :meth:`.get_bind` method for resolution. + + :param bind: + Optional :class:`.Engine` to be used as the bind. If + this engine is already involved in an ongoing transaction, + that connection will be used. This argument takes precedence + over ``mapper``, ``clause``. + + :param mapper: + Optional :func:`.mapper` mapped class, used to identify + the appropriate bind. This argument takes precedence over + ``clause``. + + :param clause: + A :class:`.ClauseElement` (i.e. :func:`~.sql.expression.select`, + :func:`~.sql.expression.text`, + etc.) which will be used to locate a bind, if a bind + cannot otherwise be identified. + + :param close_with_result: Passed to :meth:`.Engine.connect`, indicating + the :class:`.Connection` should be considered "single use", + automatically closing when the first result set is closed. This + flag only has an effect if this :class:`.Session` is configured with + ``autocommit=True`` and does not already have a transaction + in progress. + + :param \**kw: + Additional keyword arguments are sent to :meth:`get_bind()`, + allowing additional arguments to be passed to custom + implementations of :meth:`get_bind`. + + """ + if bind is None: + bind = self.get_bind(mapper, clause=clause, **kw) + + return self._connection_for_bind(bind, + close_with_result=close_with_result) + + def _connection_for_bind(self, engine, **kwargs): + if self.transaction is not None: + return self.transaction._connection_for_bind(engine) + else: + return engine.contextual_connect(**kwargs) + + def execute(self, clause, params=None, mapper=None, bind=None, **kw): + """Execute a SQL expression construct or string statement within + the current transaction. + + Returns a :class:`.ResultProxy` representing + results of the statement execution, in the same manner as that of an + :class:`.Engine` or + :class:`.Connection`. + + E.g.:: + + result = session.execute( + user_table.select().where(user_table.c.id == 5) + ) + + :meth:`~.Session.execute` accepts any executable clause construct, such + as :func:`~.sql.expression.select`, + :func:`~.sql.expression.insert`, + :func:`~.sql.expression.update`, + :func:`~.sql.expression.delete`, and + :func:`~.sql.expression.text`. Plain SQL strings can be passed + as well, which in the case of :meth:`.Session.execute` only + will be interpreted the same as if it were passed via a + :func:`~.expression.text` construct. That is, the following usage:: + + result = session.execute( + "SELECT * FROM user WHERE id=:param", + {"param":5} + ) + + is equivalent to:: + + from sqlalchemy import text + result = session.execute( + text("SELECT * FROM user WHERE id=:param"), + {"param":5} + ) + + The second positional argument to :meth:`.Session.execute` is an + optional parameter set. Similar to that of + :meth:`.Connection.execute`, whether this is passed as a single + dictionary, or a list of dictionaries, determines whether the DBAPI + cursor's ``execute()`` or ``executemany()`` is used to execute the + statement. An INSERT construct may be invoked for a single row:: + + result = session.execute(users.insert(), {"id": 7, "name": "somename"}) + + or for multiple rows:: + + result = session.execute(users.insert(), [ + {"id": 7, "name": "somename7"}, + {"id": 8, "name": "somename8"}, + {"id": 9, "name": "somename9"} + ]) + + The statement is executed within the current transactional context of + this :class:`.Session`. The :class:`.Connection` which is used + to execute the statement can also be acquired directly by + calling the :meth:`.Session.connection` method. Both methods use + a rule-based resolution scheme in order to determine the + :class:`.Connection`, which in the average case is derived directly + from the "bind" of the :class:`.Session` itself, and in other cases + can be based on the :func:`.mapper` + and :class:`.Table` objects passed to the method; see the documentation + for :meth:`.Session.get_bind` for a full description of this scheme. + + The :meth:`.Session.execute` method does *not* invoke autoflush. + + The :class:`.ResultProxy` returned by the :meth:`.Session.execute` + method is returned with the "close_with_result" flag set to true; + the significance of this flag is that if this :class:`.Session` is + autocommitting and does not have a transaction-dedicated + :class:`.Connection` available, a temporary :class:`.Connection` is + established for the statement execution, which is closed (meaning, + returned to the connection pool) when the :class:`.ResultProxy` has + consumed all available data. This applies *only* when the + :class:`.Session` is configured with autocommit=True and no + transaction has been started. + + :param clause: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`.expression.select`) or string SQL statement + to be executed. + + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param mapper: + Optional :func:`.mapper` or mapped class, used to identify + the appropriate bind. This argument takes precedence over + ``clause`` when locating a bind. See :meth:`.Session.get_bind` + for more details. + + :param bind: + Optional :class:`.Engine` to be used as the bind. If + this engine is already involved in an ongoing transaction, + that connection will be used. This argument takes + precedence over ``mapper`` and ``clause`` when locating + a bind. + + :param \**kw: + Additional keyword arguments are sent to :meth:`.Session.get_bind()` + to allow extensibility of "bind" schemes. + + .. seealso:: + + :ref:`sqlexpression_toplevel` - Tutorial on using Core SQL + constructs. + + :ref:`connections_toplevel` - Further information on direct + statement execution. + + :meth:`.Connection.execute` - core level statement execution + method, which is :meth:`.Session.execute` ultimately uses + in order to execute the statement. + + """ + clause = expression._literal_as_text(clause) + + if bind is None: + bind = self.get_bind(mapper, clause=clause, **kw) + + return self._connection_for_bind(bind, close_with_result=True).execute( + clause, params or {}) + + def scalar(self, clause, params=None, mapper=None, bind=None, **kw): + """Like :meth:`~.Session.execute` but return a scalar result.""" + + return self.execute( + clause, params=params, mapper=mapper, bind=bind, **kw).scalar() + + def close(self): + """Close this Session. + + This clears all items and ends any transaction in progress. + + If this session were created with ``autocommit=False``, a new + transaction is immediately begun. Note that this new transaction does + not use any connection resources until they are first needed. + + """ + self.expunge_all() + if self.transaction is not None: + for transaction in self.transaction._iterate_parents(): + transaction.close() + + def expunge_all(self): + """Remove all object instances from this ``Session``. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + """ + for state in self.identity_map.all_states() + list(self._new): + state._detach() + + self.identity_map = self._identity_cls() + self._new = {} + self._deleted = {} + + # TODO: need much more test coverage for bind_mapper() and similar ! + # TODO: + crystalize + document resolution order + # vis. bind_mapper/bind_table + + def bind_mapper(self, mapper, bind): + """Bind operations for a mapper to a Connectable. + + mapper + A mapper instance or mapped class + + bind + Any Connectable: a ``Engine`` or ``Connection``. + + All subsequent operations involving this mapper will use the given + `bind`. + + """ + if isinstance(mapper, type): + mapper = class_mapper(mapper) + + self.__binds[mapper.base_mapper] = bind + for t in mapper._all_tables: + self.__binds[t] = bind + + def bind_table(self, table, bind): + """Bind operations on a Table to a Connectable. + + table + A ``Table`` instance + + bind + Any Connectable: a ``Engine`` or ``Connection``. + + All subsequent operations involving this ``Table`` will use the + given `bind`. + + """ + self.__binds[table] = bind + + def get_bind(self, mapper=None, clause=None): + """Return a "bind" to which this :class:`.Session` is bound. + + The "bind" is usually an instance of :class:`.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`.Connection`. + + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. + + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. if mapper given and session.binds is present, + locate a bind based on mapper. + 2. if clause given and session.binds is present, + locate a bind based on :class:`.Table` objects + found in the given clause present in session.binds. + 3. if session.bind is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`.MetaData` ultimately + associated with the :class:`.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + :param mapper: + Optional :func:`.mapper` mapped class or instance of + :class:`.Mapper`. The bind can be derived from a :class:`.Mapper` + first by consulting the "binds" map associated with this + :class:`.Session`, and secondly by consulting the :class:`.MetaData` + associated with the :class:`.Table` to which the :class:`.Mapper` + is mapped for a bind. + + :param clause: + A :class:`.ClauseElement` (i.e. :func:`~.sql.expression.select`, + :func:`~.sql.expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`.Table` associated with + bound :class:`.MetaData`. + + """ + if mapper is clause is None: + if self.bind: + return self.bind + else: + raise sa_exc.UnboundExecutionError( + "This session is not bound to a single Engine or " + "Connection, and no context was provided to locate " + "a binding.") + + c_mapper = mapper is not None and _class_to_mapper(mapper) or None + + # manually bound? + if self.__binds: + if c_mapper: + if c_mapper.base_mapper in self.__binds: + return self.__binds[c_mapper.base_mapper] + elif c_mapper.mapped_table in self.__binds: + return self.__binds[c_mapper.mapped_table] + if clause is not None: + for t in sql_util.find_tables(clause, include_crud=True): + if t in self.__binds: + return self.__binds[t] + + if self.bind: + return self.bind + + if isinstance(clause, sql.expression.ClauseElement) and clause.bind: + return clause.bind + + if c_mapper and c_mapper.mapped_table.bind: + return c_mapper.mapped_table.bind + + context = [] + if mapper is not None: + context.append('mapper %s' % c_mapper) + if clause is not None: + context.append('SQL expression') + + raise sa_exc.UnboundExecutionError( + "Could not locate a bind configured on %s or this Session" % ( + ', '.join(context))) + + def query(self, *entities, **kwargs): + """Return a new ``Query`` object corresponding to this ``Session``.""" + + return self._query_cls(entities, self, **kwargs) + + @property + @util.contextmanager + def no_autoflush(self): + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + .. versionadded:: 0.7.6 + + """ + autoflush = self.autoflush + self.autoflush = False + yield self + self.autoflush = autoflush + + def _autoflush(self): + if self.autoflush and not self._flushing: + try: + self.flush() + except sa_exc.StatementError as e: + # note we are reraising StatementError as opposed to + # raising FlushError with "chaining" to remain compatible + # with code that catches StatementError, IntegrityError, + # etc. + e.add_detail( + "raised as a result of Query-invoked autoflush; " + "consider using a session.no_autoflush block if this " + "flush is occuring prematurely") + util.raise_from_cause(e) + + def refresh(self, instance, attribute_names=None, lockmode=None): + """Expire and refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + Lazy-loaded relational attributes will remain lazily loaded, so that + the instance-wide refresh operation will be followed immediately by + the lazy load of that attribute. + + Eagerly-loaded relational attributes will eagerly load within the + single refresh operation. + + Note that a highly isolated transaction will return the same values as + were previously read in that same transaction, regardless of changes + in database state outside of that transaction - usage of + :meth:`~Session.refresh` usually only makes sense if non-ORM SQL + statement were emitted in the ongoing transaction, or if autocommit + mode is turned on. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. + + :param lockmode: Passed to the :class:`~sqlalchemy.orm.query.Query` + as used by :meth:`~sqlalchemy.orm.query.Query.with_lockmode`. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.expire_all` + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + self._expire_state(state, attribute_names) + + if loading.load_on_ident( + self.query(object_mapper(instance)), + state.key, refresh_state=state, + lockmode=lockmode, + only_load_props=attribute_names) is None: + raise sa_exc.InvalidRequestError( + "Could not refresh instance '%s'" % + instance_str(instance)) + + def expire_all(self): + """Expires all persistent instances within this Session. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` should not be needed when + autocommit is ``False``, assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + """ + for state in self.identity_map.all_states(): + state._expire(state.dict, self.identity_map._modified) + + def expire(self, instance, attribute_names=None): + """Expire the attributes on an instance. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + self._expire_state(state, attribute_names) + + def _expire_state(self, state, attribute_names): + self._validate_persistent(state) + if attribute_names: + state._expire_attributes(state.dict, attribute_names) + else: + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list(state.manager.mapper.cascade_iterator( + 'refresh-expire', state)) + self._conditional_expire(state) + for o, m, st_, dct_ in cascaded: + self._conditional_expire(st_) + + def _conditional_expire(self, state): + """Expire a state if persistent, else expunge if pending""" + + if state.key: + state._expire(state.dict, self.identity_map._modified) + elif state in self._new: + self._new.pop(state) + state._detach() + + @util.deprecated("0.7", "The non-weak-referencing identity map " + "feature is no longer needed.") + def prune(self): + """Remove unreferenced instances cached in the identity map. + + Note that this method is only meaningful if "weak_identity_map" is set + to False. The default weak identity map is self-pruning. + + Removes any object in this Session's identity map that is not + referenced in user code, modified, new or scheduled for deletion. + Returns the number of objects pruned. + + """ + return self.identity_map.prune() + + def expunge(self, instance): + """Remove the `instance` from this ``Session``. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + if state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Instance %s is not present in this Session" % + state_str(state)) + + cascaded = list(state.manager.mapper.cascade_iterator( + 'expunge', state)) + self._expunge_state(state) + for o, m, st_, dct_ in cascaded: + self._expunge_state(st_) + + def _expunge_state(self, state): + if state in self._new: + self._new.pop(state) + state._detach() + elif self.identity_map.contains_state(state): + self.identity_map.discard(state) + self._deleted.pop(state, None) + state._detach() + elif self.transaction: + self.transaction._deleted.pop(state, None) + + def _register_newly_persistent(self, states): + for state in states: + mapper = _state_mapper(state) + + # prevent against last minute dereferences of the object + obj = state.obj() + if obj is not None: + + instance_key = mapper._identity_key_from_state(state) + + if _none_set.issubset(instance_key[1]) and \ + not mapper.allow_partial_pks or \ + _none_set.issuperset(instance_key[1]): + raise exc.FlushError( + "Instance %s has a NULL identity key. If this is an " + "auto-generated value, check that the database table " + "allows generation of new primary key values, and " + "that the mapped Column object is configured to " + "expect these generated values. Ensure also that " + "this flush() is not occurring at an inappropriate " + "time, such aswithin a load() event." + % state_str(state) + ) + + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch. use discard() in case another + # state has already replaced this one in the identity + # map (see test/orm/test_naturalpks.py ReversePKsTest) + self.identity_map.discard(state) + if state in self.transaction._key_switches: + orig_key = self.transaction._key_switches[state][0] + else: + orig_key = state.key + self.transaction._key_switches[state] = ( + orig_key, instance_key) + state.key = instance_key + + self.identity_map.replace(state) + + statelib.InstanceState._commit_all_states( + ((state, state.dict) for state in states), + self.identity_map + ) + + self._register_altered(states) + # remove from new last, might be the last strong ref + for state in set(states).intersection(self._new): + self._new.pop(state) + + def _register_altered(self, states): + if self._enable_transaction_accounting and self.transaction: + for state in states: + if state in self._new: + self.transaction._new[state] = True + else: + self.transaction._dirty[state] = True + + def _remove_newly_deleted(self, states): + for state in states: + if self._enable_transaction_accounting and self.transaction: + self.transaction._deleted[state] = True + + self.identity_map.discard(state) + self._deleted.pop(state, None) + state.deleted = True + + def add(self, instance, _warn=True): + """Place an object in the ``Session``. + + Its state will be persisted to the database on the next flush + operation. + + Repeated calls to ``add()`` will be ignored. The opposite of ``add()`` + is ``expunge()``. + + """ + if _warn and self._warn_on_events: + self._flush_warning("Session.add()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + self._save_or_update_state(state) + + def add_all(self, instances): + """Add the given collection of instances to this ``Session``.""" + + if self._warn_on_events: + self._flush_warning("Session.add_all()") + + for instance in instances: + self.add(instance, _warn=False) + + def _save_or_update_state(self, state): + self._save_or_update_impl(state) + + mapper = _state_mapper(state) + for o, m, st_, dct_ in mapper.cascade_iterator( + 'save-update', + state, + halt_on=self._contains_state): + self._save_or_update_impl(st_) + + def delete(self, instance): + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + """ + if self._warn_on_events: + self._flush_warning("Session.delete()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + state_str(state)) + + if state in self._deleted: + return + + # ensure object is attached to allow the + # cascade operation to load deferred attributes + # and collections + self._attach(state, include_before=True) + + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + # the strong reference to the instance itself is significant here + cascade_states = list(state.manager.mapper.cascade_iterator( + 'delete', state)) + + self._deleted[state] = state.obj() + self.identity_map.add(state) + + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_) + + def merge(self, instance, load=True): + """Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target instance. + The resulting target instance is then returned by the method; the + original source instance is left unmodified, and un-associated with the + :class:`.Session` if not already. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. + + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the method. + + """ + + if self._warn_on_events: + self._flush_warning("Session.merge()") + + _recursive = {} + + if load: + # flush current contents if we expect to load data + self._autoflush() + + object_mapper(instance) # verify mapped + autoflush = self.autoflush + try: + self.autoflush = False + return self._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive=_recursive) + finally: + self.autoflush = autoflush + + def _merge(self, state, state_dict, load=True, _recursive=None): + mapper = _state_mapper(state) + if state in _recursive: + return _recursive[state] + + new_instance = False + key = state.key + + if key is None: + if not load: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects transient (i.e. unpersisted) objects. flush() " + "all changes on mapped instances before merging with " + "load=False.") + key = mapper._identity_key_from_state(state) + + if key in self.identity_map: + merged = self.identity_map[key] + + elif not load: + if state.modified: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects marked as 'dirty'. flush() all changes on " + "mapped instances before merging with load=False.") + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_state.key = key + self._update_impl(merged_state) + new_instance = True + + elif not _none_set.issubset(key[1]) or \ + (mapper.allow_partial_pks and + not _none_set.issuperset(key[1])): + merged = self.query(mapper.class_).get(key[1]) + else: + merged = None + + if merged is None: + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + new_instance = True + self._save_or_update_state(merged_state) + else: + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + + _recursive[state] = merged + + # check that we didn't just pull the exact same + # state out. + if state is not merged_state: + # version check if applicable + if mapper.version_id_col is not None: + existing_version = mapper._get_state_attr_by_column( + state, + state_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + merged_version = mapper._get_state_attr_by_column( + merged_state, + merged_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + if existing_version is not attributes.PASSIVE_NO_RESULT and \ + merged_version is not attributes.PASSIVE_NO_RESULT and \ + existing_version != merged_version: + raise exc.StaleDataError( + "Version id '%s' on merged state %s " + "does not match existing version '%s'. " + "Leave the version attribute unset when " + "merging to update the most recent version." + % ( + existing_version, + state_str(merged_state), + merged_version + )) + + merged_state.load_path = state.load_path + merged_state.load_options = state.load_options + + for prop in mapper.iterate_properties: + prop.merge(self, state, state_dict, + merged_state, merged_dict, + load, _recursive) + + if not load: + # remove any history + merged_state._commit_all(merged_dict, self.identity_map) + + if new_instance: + merged_state.manager.dispatch.load(merged_state, None) + return merged + + def _validate_persistent(self, state): + if not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persistent within this Session" % + state_str(state)) + + def _save_impl(self, state): + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Object '%s' already has an identity - it can't be registered " + "as pending" % state_str(state)) + + self._before_attach(state) + if state not in self._new: + self._new[state] = state.obj() + state.insert_order = len(self._new) + self._attach(state) + + def _update_impl(self, state, discard_existing=False): + if (self.identity_map.contains_state(state) and + state not in self._deleted): + return + + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + state_str(state)) + + if state.deleted: + raise sa_exc.InvalidRequestError( + "Instance '%s' has been deleted. Use the make_transient() " + "function to send this object back to the transient state." % + state_str(state) + ) + self._before_attach(state) + self._deleted.pop(state, None) + if discard_existing: + self.identity_map.replace(state) + else: + self.identity_map.add(state) + self._attach(state) + + def _save_or_update_impl(self, state): + if state.key is None: + self._save_impl(state) + else: + self._update_impl(state) + + def _delete_impl(self, state): + if state in self._deleted: + return + + if state.key is None: + return + + self._attach(state, include_before=True) + self._deleted[state] = state.obj() + self.identity_map.add(state) + + def enable_relationship_loading(self, obj): + """Associate an object with this :class:`.Session` for related + object loading. + + .. warning:: + + :meth:`.enable_relationship_loading` exists to serve special + use cases and is not recommended for general use. + + Accesses of attributes mapped with :func:`.relationship` + will attempt to load a value from the database using this + :class:`.Session` as the source of connectivity. The values + will be loaded based on foreign key values present on this + object - it follows that this functionality + generally only works for many-to-one-relationships. + + The object will be attached to this session, but will + **not** participate in any persistence operations; its state + for almost all purposes will remain either "transient" or + "detached", except for the case of relationship loading. + + Also note that backrefs will often not work as expected. + Altering a relationship-bound attribute on the target object + may not fire off a backref event, if the effective value + is what was already loaded from a foreign-key-holding value. + + The :meth:`.Session.enable_relationship_loading` method is + similar to the ``load_on_pending`` flag on :func:`.relationship`. Unlike + that flag, :meth:`.Session.enable_relationship_loading` allows + an object to remain transient while still being able to load + related items. + + To make a transient object associated with a :class:`.Session` + via :meth:`.Session.enable_relationship_loading` pending, add + it to the :class:`.Session` using :meth:`.Session.add` normally. + + :meth:`.Session.enable_relationship_loading` does not improve + behavior when the ORM is used normally - object references should be + constructed at the object level, not at the foreign key level, so + that they are present in an ordinary way before flush() + proceeds. This method is not intended for general use. + + .. versionadded:: 0.8 + + .. seealso:: + + ``load_on_pending`` at :func:`.relationship` - this flag + allows per-relationship loading of many-to-ones on items that + are pending. + + """ + state = attributes.instance_state(obj) + self._attach(state, include_before=True) + state._load_pending = True + + def _before_attach(self, state): + if state.session_id != self.hash_key and \ + self.dispatch.before_attach: + self.dispatch.before_attach(self, state.obj()) + + def _attach(self, state, include_before=False): + if state.key and \ + state.key in self.identity_map and \ + not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError("Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." + % (state_str(state), state.key)) + + if state.session_id and \ + state.session_id is not self.hash_key and \ + state.session_id in _sessions: + raise sa_exc.InvalidRequestError( + "Object '%s' is already attached to session '%s' " + "(this is '%s')" % (state_str(state), + state.session_id, self.hash_key)) + + if state.session_id != self.hash_key: + if include_before and \ + self.dispatch.before_attach: + self.dispatch.before_attach(self, state.obj()) + state.session_id = self.hash_key + if state.modified and state._strong_obj is None: + state._strong_obj = state.obj() + if self.dispatch.after_attach: + self.dispatch.after_attach(self, state.obj()) + + def __contains__(self, instance): + """Return True if the instance is associated with this session. + + The instance may be pending or persistent within the Session for a + result of True. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + return self._contains_state(state) + + def __iter__(self): + """Iterate over all pending or persistent instances within this + Session. + + """ + return iter(list(self._new.values()) + list(self.identity_map.values())) + + def _contains_state(self, state): + return state in self._new or self.identity_map.contains_state(state) + + def flush(self, objects=None): + """Flush all the object changes to the database. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction, unless an + error occurs, in which case the entire transaction is rolled back. + You may flush() as often as you like within a transaction to move + changes from Python to the database's transaction buffer. + + For ``autocommit`` Sessions with no active manual transaction, flush() + will create a transaction on the fly that surrounds the entire set of + operations int the flush. + + :param objects: Optional; restricts the flush operation to operate + only on elements that are in the given collection. + + This feature is for an extremely narrow set of use cases where + particular objects may need to be operated upon before the + full flush() occurs. It is not intended for general use. + + """ + + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + + def _flush_warning(self, method): + util.warn( + "Usage of the '%s' operation is not currently supported " + "within the execution stage of the flush process. " + "Results may not be consistent. Consider using alternative " + "event listeners or connection-level operations instead." + % method) + + def _is_clean(self): + return not self.identity_map.check_modified() and \ + not self._deleted and \ + not self._new + + def _flush(self, objects=None): + + dirty = self._dirty_states + if not dirty and not self._deleted and not self._new: + self.identity_map._modified.clear() + return + + flush_context = UOWTransaction(self) + + if self.dispatch.before_flush: + self.dispatch.before_flush(self, flush_context, objects) + # re-establish "dirty states" in case the listeners + # added + dirty = self._dirty_states + + deleted = set(self._deleted) + new = set(self._new) + + dirty = set(dirty).difference(deleted) + + # create the set of all objects we want to operate upon + if objects: + # specific list passed in + objset = set() + for o in objects: + try: + state = attributes.instance_state(o) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(o) + objset.add(state) + else: + objset = None + + # store objects whose fate has been decided + processed = set() + + # put all saves/updates into the flush context. detect top-level + # orphans and throw them into deleted. + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: + is_orphan = ( + _state_mapper(state)._is_orphan(state) and state.has_identity) + flush_context.register_object(state, isdelete=is_orphan) + processed.add(state) + + # put all remaining deletes into the flush context. + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: + flush_context.register_object(state, isdelete=True) + + if not flush_context.has_work: + return + + flush_context.transaction = transaction = self.begin( + subtransactions=True) + try: + self._warn_on_events = True + try: + flush_context.execute() + finally: + self._warn_on_events = False + + self.dispatch.after_flush(self, flush_context) + + flush_context.finalize_flush_changes() + + if not objects and self.identity_map._modified: + len_ = len(self.identity_map._modified) + + statelib.InstanceState._commit_all_states( + [(state, state.dict) for state in + self.identity_map._modified], + instance_dict=self.identity_map) + util.warn("Attribute history events accumulated on %d " + "previously clean instances " + "within inner-flush event handlers have been reset, " + "and will not result in database updates. " + "Consider using set_committed_value() within " + "inner-flush event handlers to avoid this warning." + % len_) + + # useful assertions: + #if not objects: + # assert not self.identity_map._modified + #else: + # assert self.identity_map._modified == \ + # self.identity_map._modified.difference(objects) + + self.dispatch.after_flush_postexec(self, flush_context) + + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + + def is_modified(self, instance, include_collections=True, + passive=True): + """Return ``True`` if the given instance has locally + modified attributes. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + .. versionchanged:: 0.8 + When using SQLAlchemy 0.7 and earlier, the ``passive`` + flag should **always** be explicitly set to ``True``, + else SQL loads/autoflushes may proceed which can affect + the modified state itself: + ``session.is_modified(someobject, passive=True)``\ . + In 0.8 and above, the behavior is corrected and + this flag is ignored. + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may report + ``False`` when tested with this method. This is because + the object may have received change events via attribute + mutation, thus placing it in :attr:`.Session.dirty`, + but ultimately the state is the same as that loaded from + the database, resulting in no net change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + :param passive: + .. versionchanged:: 0.8 + Ignored for backwards compatibility. + When using SQLAlchemy 0.7 and earlier, this flag should always + be set to ``True``. + + """ + state = object_state(instance) + + if not state.modified: + return False + + dict_ = state.dict + + for attr in state.manager.attributes: + if \ + ( + not include_collections and + hasattr(attr.impl, 'get_collection') + ) or not hasattr(attr.impl, 'get_history'): + continue + + (added, unchanged, deleted) = \ + attr.impl.get_history(state, dict_, + passive=attributes.NO_CHANGE) + + if added or deleted: + return True + else: + return False + + @property + def is_active(self): + """True if this :class:`.Session` is in "transaction mode" and + is not in "partial rollback" state. + + The :class:`.Session` in its default mode of ``autocommit=False`` + is essentially always in "transaction mode", in that a + :class:`.SessionTransaction` is associated with it as soon as + it is instantiated. This :class:`.SessionTransaction` is immediately + replaced with a new one as soon as it is ended, due to a rollback, + commit, or close operation. + + "Transaction mode" does *not* indicate whether + or not actual database connection resources are in use; the + :class:`.SessionTransaction` object coordinates among zero or more + actual database transactions, and starts out with none, accumulating + individual DBAPI connections as different data sources are used + within its scope. The best way to track when a particular + :class:`.Session` has actually begun to use DBAPI resources is to + implement a listener using the :meth:`.SessionEvents.after_begin` + method, which will deliver both the :class:`.Session` as well as the + target :class:`.Connection` to a user-defined event listener. + + The "partial rollback" state refers to when an "inner" transaction, + typically used during a flush, encounters an error and emits a + rollback of the DBAPI connection. At this point, the + :class:`.Session` is in "partial rollback" and awaits for the user to + call :meth:`.Session.rollback`, in order to close out the + transaction stack. It is in this "partial rollback" period that the + :attr:`.is_active` flag returns False. After the call to + :meth:`.Session.rollback`, the :class:`.SessionTransaction` is replaced + with a new one and :attr:`.is_active` returns ``True`` again. + + When a :class:`.Session` is used in ``autocommit=True`` mode, the + :class:`.SessionTransaction` is only instantiated within the scope + of a flush call, or when :meth:`.Session.begin` is called. So + :attr:`.is_active` will always be ``False`` outside of a flush or + :meth:`.Session.begin` block in this mode, and will be ``True`` + within the :meth:`.Session.begin` block as long as it doesn't enter + "partial rollback" state. + + From all the above, it follows that the only purpose to this flag is + for application frameworks that wish to detect is a "rollback" is + necessary within a generic error handling routine, for + :class:`.Session` objects that would otherwise be in + "partial rollback" mode. In a typical integration case, this is also + not necessary as it is standard practice to emit + :meth:`.Session.rollback` unconditionally within the outermost + exception catch. + + To track the transactional state of a :class:`.Session` fully, + use event listeners, primarily the :meth:`.SessionEvents.after_begin`, + :meth:`.SessionEvents.after_commit`, + :meth:`.SessionEvents.after_rollback` and related events. + + """ + return self.transaction and self.transaction.is_active + + identity_map = None + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + .. seealso:: + + :func:`.identity_key` - helper function to produce the keys used + in this dictionary. + + """ + + @property + def _dirty_states(self): + """The set of all persistent states considered dirty. + + This method returns all states that were modified including + those that were possibly deleted. + + """ + return self.identity_map._dirty_states() + + @property + def dirty(self): + """The set of all persistent instances considered dirty. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + """ + return util.IdentitySet( + [state.obj() + for state in self._dirty_states + if state not in self._deleted]) + + @property + def deleted(self): + "The set of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(list(self._deleted.values())) + + @property + def new(self): + "The set of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(list(self._new.values())) + + +class sessionmaker(_SessionClassMethods): + """A configurable :class:`.Session` factory. + + The :class:`.sessionmaker` factory generates new + :class:`.Session` objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + # global scope + Session = sessionmaker(autoflush=False) + + # later, in a local scope, create and use a session: + sess = Session() + + Any keyword arguments sent to the constructor itself will override the + "configured" keywords:: + + Session = sessionmaker() + + # bind an individual session to a connection + sess = Session(bind=connection) + + The class also includes a method :meth:`.configure`, which can + be used to specify additional keyword arguments to the factory, which + will take effect for subsequent :class:`.Session` objects generated. + This is usually used to associate one or more :class:`.Engine` objects + with an existing :class:`.sessionmaker` factory before it is first + used:: + + # application starts + Session = sessionmaker() + + # ... later + engine = create_engine('sqlite:///foo.db') + Session.configure(bind=engine) + + sess = Session() + + .. seealso: + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ + + def __init__(self, bind=None, class_=Session, autoflush=True, + autocommit=False, + expire_on_commit=True, + info=None, **kw): + """Construct a new :class:`.sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.Session.__init__` docstring for more details on parameters. + + :param bind: a :class:`.Engine` or other :class:`.Connectable` with + which newly created :class:`.Session` objects will be associated. + :param class_: class to use in order to create new :class:`.Session` + objects. Defaults to :class:`.Session`. + :param autoflush: The autoflush setting to use with newly created + :class:`.Session` objects. + :param autocommit: The autocommit setting to use with newly created + :class:`.Session` objects. + :param expire_on_commit=True: the expire_on_commit setting to use + with newly created :class:`.Session` objects. + :param info: optional dictionary of information that will be available + via :attr:`.Session.info`. Note this dictionary is *updated*, not + replaced, when the ``info`` parameter is specified to the specific + :class:`.Session` construction operation. + + .. versionadded:: 0.9.0 + + :param \**kw: all other keyword arguments are passed to the constructor + of newly created :class:`.Session` objects. + + """ + kw['bind'] = bind + kw['autoflush'] = autoflush + kw['autocommit'] = autocommit + kw['expire_on_commit'] = expire_on_commit + if info is not None: + kw['info'] = info + self.kw = kw + # make our own subclass of the given class, so that + # events can be associated with it specifically. + self.class_ = type(class_.__name__, (class_,), {}) + + def __call__(self, **local_kw): + """Produce a new :class:`.Session` object using the configuration + established in this :class:`.sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + Session = sessionmaker() + session = Session() # invokes sessionmaker.__call__() + + """ + for k, v in self.kw.items(): + if k == 'info' and 'info' in local_kw: + d = v.copy() + d.update(local_kw['info']) + local_kw['info'] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw): + """(Re)configure the arguments for this sessionmaker. + + e.g.:: + + Session = sessionmaker() + + Session.configure(bind=create_engine('sqlite://')) + """ + self.kw.update(new_kw) + + def __repr__(self): + return "%s(class_=%r,%s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()) + ) + + + +def make_transient(instance): + """Make the given instance 'transient'. + + This will remove its association with any + session and additionally will remove its "identity key", + such that it's as though the object were newly constructed, + except retaining its values. It also resets the + "deleted" flag on the state if this object + had been explicitly deleted by its session. + + Attributes which were "expired" or deferred at the + instance level are reverted to undefined, and + will not trigger any loads. + + """ + state = attributes.instance_state(instance) + s = _state_session(state) + if s: + s._expunge_state(state) + + # remove expired state and + # deferred callables + state.callables.clear() + if state.key: + del state.key + if state.deleted: + del state.deleted + + +def object_session(instance): + """Return the ``Session`` to which instance belongs. + + If the instance is not a mapped instance, an error is raised. + + """ + + try: + return _state_session(attributes.instance_state(instance)) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + + +_new_sessionid = util.counter() diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py new file mode 100644 index 00000000..fb5db1fc --- /dev/null +++ b/lib/sqlalchemy/orm/state.py @@ -0,0 +1,617 @@ +# orm/state.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines instrumentation of instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +""" + +import weakref +from .. import util +from . import exc as orm_exc, interfaces +from .path_registry import PathRegistry +from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ + NO_VALUE, PASSIVE_NO_INITIALIZE, INIT_OK, PASSIVE_OFF +from . import base + +class InstanceState(interfaces._InspectionAttr): + """tracks state information at the instance level.""" + + session_id = None + key = None + runid = None + load_options = util.EMPTY_SET + load_path = () + insert_order = None + _strong_obj = None + modified = False + expired = False + deleted = False + _load_pending = False + + is_instance = True + + def __init__(self, obj, manager): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.callables = {} + self.committed_state = {} + + @util.memoized_property + def attrs(self): + """Return a namespace representing each attribute on + the mapped object, including its current value + and history. + + The returned object is an instance of :class:`.AttributeState`. + + """ + return util.ImmutableProperties( + dict( + (key, AttributeState(self, key)) + for key in self.manager + ) + ) + + @property + def transient(self): + """Return true if the object is transient.""" + return self.key is None and \ + not self._attached + + @property + def pending(self): + """Return true if the object is pending.""" + return self.key is None and \ + self._attached + + @property + def persistent(self): + """Return true if the object is persistent.""" + return self.key is not None and \ + self._attached + + @property + def detached(self): + """Return true if the object is detached.""" + return self.key is not None and \ + not self._attached + + @property + @util.dependencies("sqlalchemy.orm.session") + def _attached(self, sessionlib): + return self.session_id is not None and \ + self.session_id in sessionlib._sessions + + @property + @util.dependencies("sqlalchemy.orm.session") + def session(self, sessionlib): + """Return the owning :class:`.Session` for this instance, + or ``None`` if none available.""" + return sessionlib._state_session(self) + + @property + def object(self): + """Return the mapped object represented by this + :class:`.InstanceState`.""" + return self.obj() + + @property + def identity(self): + """Return the mapped identity of the mapped object. + This is the primary key identity as persisted by the ORM + which can always be passed directly to + :meth:`.Query.get`. + + Returns ``None`` if the object has no primary key identity. + + .. note:: + An object which is transient or pending + does **not** have a mapped identity until it is flushed, + even if its attributes include primary key values. + + """ + if self.key is None: + return None + else: + return self.key[1] + + @property + def identity_key(self): + """Return the identity key for the mapped object. + + This is the key used to locate the object within + the :attr:`.Session.identity_map` mapping. It contains + the identity as returned by :attr:`.identity` within it. + + + """ + # TODO: just change .key to .identity_key across + # the board ? probably + return self.key + + @util.memoized_property + def parents(self): + return {} + + @util.memoized_property + def _pending_mutations(self): + return {} + + @util.memoized_property + def mapper(self): + """Return the :class:`.Mapper` used for this mapepd object.""" + return self.manager.mapper + + @property + def has_identity(self): + """Return ``True`` if this object has an identity key. + + This should always have the same value as the + expression ``state.persistent or state.detached``. + + """ + return bool(self.key) + + def _detach(self): + self.session_id = self._strong_obj = None + + def _dispose(self): + self._detach() + del self.obj + + def _cleanup(self, ref): + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.discard(self) + + self.callables = {} + self.session_id = self._strong_obj = None + del self.obj + + def obj(self): + return None + + @property + def dict(self): + o = self.obj() + if o is not None: + return base.instance_dict(o) + else: + return {} + + def _initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager + + manager.dispatch.init(self, args, kwargs) + + try: + return manager.original_init(*mixed[1:], **kwargs) + except: + manager.dispatch.init_failure(self, args, kwargs) + raise + + def get_history(self, key, passive): + return self.manager[key].impl.get_history(self, self.dict, passive) + + def get_impl(self, key): + return self.manager[key].impl + + def _get_pending_mutation(self, key): + if key not in self._pending_mutations: + self._pending_mutations[key] = PendingCollection() + return self._pending_mutations[key] + + def __getstate__(self): + state_dict = {'instance': self.obj()} + state_dict.update( + (k, self.__dict__[k]) for k in ( + 'committed_state', '_pending_mutations', 'modified', 'expired', + 'callables', 'key', 'parents', 'load_options', + 'class_', + ) if k in self.__dict__ + ) + if self.load_path: + state_dict['load_path'] = self.load_path.serialize() + + state_dict['manager'] = self.manager._serialize(self, state_dict) + + return state_dict + + def __setstate__(self, state_dict): + inst = state_dict['instance'] + if inst is not None: + self.obj = weakref.ref(inst, self._cleanup) + self.class_ = inst.__class__ + else: + # None being possible here generally new as of 0.7.4 + # due to storage of state in "parents". "class_" + # also new. + self.obj = None + self.class_ = state_dict['class_'] + + self.committed_state = state_dict.get('committed_state', {}) + self._pending_mutations = state_dict.get('_pending_mutations', {}) + self.parents = state_dict.get('parents', {}) + self.modified = state_dict.get('modified', False) + self.expired = state_dict.get('expired', False) + self.callables = state_dict.get('callables', {}) + + self.__dict__.update([ + (k, state_dict[k]) for k in ( + 'key', 'load_options', + ) if k in state_dict + ]) + + if 'load_path' in state_dict: + self.load_path = PathRegistry.\ + deserialize(state_dict['load_path']) + + state_dict['manager'](self, inst, state_dict) + + def _initialize(self, key): + """Set this attribute to an empty value or collection, + based on the AttributeImpl in use.""" + + self.manager.get_impl(key).initialize(self, self.dict) + + def _reset(self, dict_, key): + """Remove the given attribute and any + callables associated with it.""" + + old = dict_.pop(key, None) + if old is not None and self.manager[key].impl.collection: + self.manager[key].impl._invalidate_collection(old) + self.callables.pop(key, None) + + def _expire_attribute_pre_commit(self, dict_, key): + """a fast expire that can be called by column loaders during a load. + + The additional bookkeeping is finished up in commit_all(). + + Should only be called for scalar attributes. + + This method is actually called a lot with joined-table + loading, when the second table isn't present in the result. + + """ + dict_.pop(key, None) + self.callables[key] = self + + @classmethod + def _row_processor(cls, manager, fn, key): + impl = manager[key].impl + if impl.collection: + def _set_callable(state, dict_, row): + old = dict_.pop(key, None) + if old is not None: + impl._invalidate_collection(old) + state.callables[key] = fn + else: + def _set_callable(state, dict_, row): + state.callables[key] = fn + return _set_callable + + def _expire(self, dict_, modified_set): + self.expired = True + if self.modified: + modified_set.discard(self) + + self.modified = False + self._strong_obj = None + + self.committed_state.clear() + + InstanceState._pending_mutations._reset(self) + + # clear out 'parents' collection. not + # entirely clear how we can best determine + # which to remove, or not. + InstanceState.parents._reset(self) + + for key in self.manager: + impl = self.manager[key].impl + if impl.accepts_scalar_loader and \ + (impl.expire_missing or key in dict_): + self.callables[key] = self + old = dict_.pop(key, None) + if impl.collection and old is not None: + impl._invalidate_collection(old) + + self.manager.dispatch.expire(self, None) + + def _expire_attributes(self, dict_, attribute_names): + pending = self.__dict__.get('_pending_mutations', None) + + for key in attribute_names: + impl = self.manager[key].impl + if impl.accepts_scalar_loader: + self.callables[key] = self + old = dict_.pop(key, None) + if impl.collection and old is not None: + impl._invalidate_collection(old) + + self.committed_state.pop(key, None) + if pending: + pending.pop(key, None) + + self.manager.dispatch.expire(self, attribute_names) + + def __call__(self, state, passive): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + + if not passive & SQL_OK: + return PASSIVE_NO_RESULT + + toload = self.expired_attributes.\ + intersection(self.unmodified) + + self.manager.deferred_scalar_loader(self, toload) + + # if the loader failed, or this + # instance state didn't have an identity, + # the attributes still might be in the callables + # dict. ensure they are removed. + for k in toload.intersection(self.callables): + del self.callables[k] + + return ATTR_WAS_SET + + @property + def unmodified(self): + """Return the set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + def unmodified_intersection(self, keys): + """Return self.unmodified.intersection(keys).""" + + return set(keys).intersection(self.manager).\ + difference(self.committed_state) + + @property + def unloaded(self): + """Return the set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return set(self.manager).\ + difference(self.committed_state).\ + difference(self.dict) + + @property + def _unloaded_non_object(self): + return self.unloaded.intersection( + attr for attr in self.manager + if self.manager[attr].impl.accepts_scalar_loader + ) + + @property + def expired_attributes(self): + """Return the set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs. + + """ + return set([k for k, v in self.callables.items() if v is self]) + + def _instance_dict(self): + return None + + def _modified_event(self, dict_, attr, previous, collection=False, force=False): + if not attr.send_modified_events: + return + if attr.key not in self.committed_state or force: + if collection: + if previous is NEVER_SET: + if attr.key in dict_: + previous = dict_[attr.key] + + if previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + self.committed_state[attr.key] = previous + + # assert self._strong_obj is None or self.modified + + if (self.session_id and self._strong_obj is None) \ + or not self.modified: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict._modified.add(self) + + # only create _strong_obj link if attached + # to a session + + inst = self.obj() + if self.session_id: + self._strong_obj = inst + + if inst is None: + raise orm_exc.ObjectDereferencedError( + "Can't emit change event for attribute '%s' - " + "parent object of type %s has been garbage " + "collected." + % ( + self.manager[attr.key], + base.state_class_str(self) + )) + self.modified = True + + def _commit(self, dict_, keys): + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + for key in keys: + self.committed_state.pop(key, None) + + self.expired = False + + for key in set(self.callables).\ + intersection(keys).\ + intersection(dict_): + del self.callables[key] + + def _commit_all(self, dict_, instance_dict=None): + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables for attributes loaded are removed. + + Attributes marked as "expired" can potentially remain + "expired" after this step if a value was not populated in state.dict. + + """ + self._commit_all_states([(self, dict_)], instance_dict) + + @classmethod + def _commit_all_states(self, iter, instance_dict=None): + """Mass version of commit_all().""" + + for state, dict_ in iter: + state.committed_state.clear() + InstanceState._pending_mutations._reset(state) + + callables = state.callables + for key in list(callables): + if key in dict_ and callables[key] is state: + del callables[key] + + if instance_dict and state.modified: + instance_dict._modified.discard(state) + + state.modified = state.expired = False + state._strong_obj = None + + +class AttributeState(object): + """Provide an inspection interface corresponding + to a particular attribute on a particular mapped object. + + The :class:`.AttributeState` object is accessed + via the :attr:`.InstanceState.attrs` collection + of a particular :class:`.InstanceState`:: + + from sqlalchemy import inspect + + insp = inspect(some_mapped_object) + attr_state = insp.attrs.some_attribute + + """ + + def __init__(self, state, key): + self.state = state + self.key = key + + @property + def loaded_value(self): + """The current value of this attribute as loaded from the database. + + If the value has not been loaded, or is otherwise not present + in the object's dictionary, returns NO_VALUE. + + """ + return self.state.dict.get(self.key, NO_VALUE) + + @property + def value(self): + """Return the value of this attribute. + + This operation is equivalent to accessing the object's + attribute directly or via ``getattr()``, and will fire + off any pending loader callables if needed. + + """ + return self.state.manager[self.key].__get__( + self.state.obj(), self.state.class_) + + @property + def history(self): + """Return the current pre-flush change history for + this attribute, via the :class:`.History` interface. + + This method will **not** emit loader callables if the value of the + attribute is unloaded. + + .. seealso:: + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. + + :func:`.attributes.get_history` - underlying function + + """ + return self.state.get_history(self.key, + PASSIVE_NO_INITIALIZE) + + def load_history(self): + """Return the current pre-flush change history for + this attribute, via the :class:`.History` interface. + + This method **will** emit loader callables if the value of the + attribute is unloaded. + + .. seealso:: + + :attr:`.AttributeState.history` + + :func:`.attributes.get_history` - underlying function + + .. versionadded:: 0.9.0 + + """ + return self.state.get_history(self.key, + PASSIVE_OFF ^ INIT_OK) + + + +class PendingCollection(object): + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + else: + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py new file mode 100644 index 00000000..4a07e785 --- /dev/null +++ b/lib/sqlalchemy/orm/strategies.py @@ -0,0 +1,1459 @@ +# orm/strategies.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""sqlalchemy.orm.interfaces.LoaderStrategy + implementations, and related MapperOptions.""" + +from .. import exc as sa_exc, inspect +from .. import util, log, event +from ..sql import util as sql_util, visitors +from .. import sql +from . import ( + attributes, interfaces, exc as orm_exc, loading, + unitofwork, util as orm_util + ) +from .state import InstanceState +from .util import _none_set +from . import properties +from .interfaces import ( + LoaderStrategy, StrategizedProperty + ) +from .session import _state_session +import itertools + +def _register_attribute(strategy, mapper, useobject, + compare_function=None, + typecallable=None, + uselist=False, + callable_=None, + proxy_property=None, + active_history=False, + impl_class=None, + **kw +): + + prop = strategy.parent_property + + attribute_ext = list(util.to_list(prop.extension, default=[])) + + listen_hooks = [] + + if useobject and prop.single_parent: + listen_hooks.append(single_parent_validator) + + if prop.key in prop.parent.validators: + fn, opts = prop.parent.validators[prop.key] + listen_hooks.append( + lambda desc, prop: orm_util._validator_events(desc, + prop.key, fn, **opts) + ) + + if useobject: + listen_hooks.append(unitofwork.track_cascade_events) + + # need to assemble backref listeners + # after the singleparentvalidator, mapper validator + backref = kw.pop('backref', None) + if backref: + listen_hooks.append( + lambda desc, prop: attributes.backref_listeners(desc, + backref, + uselist) + ) + + for m in mapper.self_and_descendants: + if prop is m._props.get(prop.key): + + desc = attributes.register_attribute_impl( + m.class_, + prop.key, + parent_token=prop, + uselist=uselist, + compare_function=compare_function, + useobject=useobject, + extension=attribute_ext, + trackparent=useobject and (prop.single_parent + or prop.direction is interfaces.ONETOMANY), + typecallable=typecallable, + callable_=callable_, + active_history=active_history, + impl_class=impl_class, + send_modified_events=not useobject or not prop.viewonly, + doc=prop.doc, + **kw + ) + + for hook in listen_hooks: + hook(desc, prop) + +@properties.ColumnProperty.strategy_for(instrument=False, deferred=False) +class UninstrumentedColumnLoader(LoaderStrategy): + """Represent the a non-instrumented MapperProperty. + + The polymorphic_on argument of mapper() often results in this, + if the argument is against the with_polymorphic selectable. + + """ + def __init__(self, parent): + super(UninstrumentedColumnLoader, self).__init__(parent) + self.columns = self.parent_property.columns + + def setup_query(self, context, entity, path, loadopt, adapter, + column_collection=None, **kwargs): + for c in self.columns: + if adapter: + c = adapter.columns[c] + column_collection.append(c) + + def create_row_processor(self, context, path, loadopt, mapper, row, adapter): + return None, None, None + + +@log.class_logger +@properties.ColumnProperty.strategy_for(instrument=True, deferred=False) +class ColumnLoader(LoaderStrategy): + """Provide loading behavior for a :class:`.ColumnProperty`.""" + + def __init__(self, parent): + super(ColumnLoader, self).__init__(parent) + self.columns = self.parent_property.columns + self.is_composite = hasattr(self.parent_property, 'composite_class') + + def setup_query(self, context, entity, path, loadopt, + adapter, column_collection, **kwargs): + for c in self.columns: + if adapter: + c = adapter.columns[c] + column_collection.append(c) + + def init_class_attribute(self, mapper): + self.is_class_level = True + coltype = self.columns[0].type + # TODO: check all columns ? check for foreign key as well? + active_history = self.parent_property.active_history or \ + self.columns[0].primary_key or \ + mapper.version_id_col in set(self.columns) + + _register_attribute(self, mapper, useobject=False, + compare_function=coltype.compare_values, + active_history=active_history + ) + + def create_row_processor(self, context, path, + loadopt, mapper, row, adapter): + key = self.key + # look through list of columns represented here + # to see which, if any, is present in the row. + for col in self.columns: + if adapter: + col = adapter.columns[col] + if col is not None and col in row: + def fetch_col(state, dict_, row): + dict_[key] = row[col] + return fetch_col, None, None + else: + def expire_for_non_present_col(state, dict_, row): + state._expire_attribute_pre_commit(dict_, key) + return expire_for_non_present_col, None, None + + + +@log.class_logger +@properties.ColumnProperty.strategy_for(deferred=True, instrument=True) +class DeferredColumnLoader(LoaderStrategy): + """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" + + def __init__(self, parent): + super(DeferredColumnLoader, self).__init__(parent) + if hasattr(self.parent_property, 'composite_class'): + raise NotImplementedError("Deferred loading for composite " + "types not implemented yet") + self.columns = self.parent_property.columns + self.group = self.parent_property.group + + def create_row_processor(self, context, path, loadopt, mapper, row, adapter): + col = self.columns[0] + if adapter: + col = adapter.columns[col] + + key = self.key + if col in row: + return self.parent_property._get_strategy_by_cls(ColumnLoader).\ + create_row_processor( + context, path, loadopt, mapper, row, adapter) + + elif not self.is_class_level: + set_deferred_for_local_state = InstanceState._row_processor( + mapper.class_manager, + LoadDeferredColumns(key), key) + return set_deferred_for_local_state, None, None + else: + def reset_col_for_deferred(state, dict_, row): + # reset state on the key so that deferred callables + # fire off on next access. + state._reset(dict_, key) + return reset_col_for_deferred, None, None + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute(self, mapper, useobject=False, + compare_function=self.columns[0].type.compare_values, + callable_=self._load_for_state, + expire_missing=False + ) + + def setup_query(self, context, entity, path, loadopt, adapter, + only_load_props=None, **kwargs): + if ( + loadopt and self.group and + loadopt.local_opts.get('undefer_group', False) == self.group + ) or (only_load_props and self.key in only_load_props): + self.parent_property._get_strategy_by_cls(ColumnLoader).\ + setup_query(context, entity, + path, loadopt, adapter, **kwargs) + + def _load_for_state(self, state, passive): + if not state.key: + return attributes.ATTR_EMPTY + + if not passive & attributes.SQL_OK: + return attributes.PASSIVE_NO_RESULT + + localparent = state.manager.mapper + + if self.group: + toload = [ + p.key for p in + localparent.iterate_properties + if isinstance(p, StrategizedProperty) and + isinstance(p.strategy, DeferredColumnLoader) and + p.group == self.group + ] + else: + toload = [self.key] + + # narrow the keys down to just those which have no history + group = [k for k in toload if k in state.unmodified] + + session = _state_session(state) + if session is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "deferred load operation of attribute '%s' cannot proceed" % + (orm_util.state_str(state), self.key) + ) + + query = session.query(localparent) + if loading.load_on_ident(query, state.key, + only_load_props=group, refresh_state=state) is None: + raise orm_exc.ObjectDeletedError(state) + + return attributes.ATTR_WAS_SET + + + +class LoadDeferredColumns(object): + """serializable loader object used by DeferredColumnLoader""" + + def __init__(self, key): + self.key = key + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + + localparent = state.manager.mapper + prop = localparent._props[key] + strategy = prop._strategies[DeferredColumnLoader] + return strategy._load_for_state(state, passive) + + + +class AbstractRelationshipLoader(LoaderStrategy): + """LoaderStratgies which deal with related objects.""" + + def __init__(self, parent): + super(AbstractRelationshipLoader, self).__init__(parent) + self.mapper = self.parent_property.mapper + self.target = self.parent_property.target + self.uselist = self.parent_property.uselist + + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="noload") +@properties.RelationshipProperty.strategy_for(lazy=None) +class NoLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.RelationshipProperty` + with "lazy=None". + + """ + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute(self, mapper, + useobject=True, + uselist=self.parent_property.uselist, + typecallable=self.parent_property.collection_class, + ) + + def create_row_processor(self, context, path, loadopt, mapper, row, adapter): + def invoke_no_load(state, dict_, row): + state._initialize(self.key) + return invoke_no_load, None, None + + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy=True) +@properties.RelationshipProperty.strategy_for(lazy="select") +class LazyLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.RelationshipProperty` + with "lazy=True", that is loads when first accessed. + + """ + + def __init__(self, parent): + super(LazyLoader, self).__init__(parent) + join_condition = self.parent_property._join_condition + self._lazywhere, \ + self._bind_to_col, \ + self._equated_columns = join_condition.create_lazy_clause() + + self._rev_lazywhere, \ + self._rev_bind_to_col, \ + self._rev_equated_columns = join_condition.create_lazy_clause( + reverse_direction=True) + + self.logger.info("%s lazy loading clause %s", self, self._lazywhere) + + # determine if our "lazywhere" clause is the same as the mapper's + # get() clause. then we can just use mapper.get() + self.use_get = not self.uselist and \ + self.mapper._get_clause[0].compare( + self._lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns + ) + + if self.use_get: + for col in list(self._equated_columns): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + + self.logger.info("%s will use query.get() to " + "optimize instance loads" % self) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + active_history = ( + self.parent_property.active_history or + self.parent_property.direction is not interfaces.MANYTOONE or + not self.use_get + ) + + # MANYTOONE currently only needs the + # "old" value for delete-orphan + # cascades. the required _SingleParentValidator + # will enable active_history + # in that case. otherwise we don't need the + # "old" value during backref operations. + _register_attribute(self, + mapper, + useobject=True, + callable_=self._load_for_state, + uselist=self.parent_property.uselist, + backref=self.parent_property.back_populates, + typecallable=self.parent_property.collection_class, + active_history=active_history + ) + + def lazy_clause(self, state, reverse_direction=False, + alias_secondary=False, + adapt_source=None, + passive=None): + if state is None: + return self._lazy_none_clause( + reverse_direction, + adapt_source=adapt_source) + + if not reverse_direction: + criterion, bind_to_col, rev = \ + self._lazywhere, \ + self._bind_to_col, \ + self._equated_columns + else: + criterion, bind_to_col, rev = \ + self._rev_lazywhere, \ + self._rev_bind_to_col, \ + self._rev_equated_columns + + if reverse_direction: + mapper = self.parent_property.mapper + else: + mapper = self.parent_property.parent + + o = state.obj() # strong ref + dict_ = attributes.instance_dict(o) + + # use the "committed state" only if we're in a flush + # for this state. + + if passive and passive & attributes.LOAD_AGAINST_COMMITTED: + def visit_bindparam(bindparam): + if bindparam._identifying_key in bind_to_col: + bindparam.callable = \ + lambda: mapper._get_committed_state_attr_by_column( + state, dict_, + bind_to_col[bindparam._identifying_key]) + else: + def visit_bindparam(bindparam): + if bindparam._identifying_key in bind_to_col: + bindparam.callable = \ + lambda: mapper._get_state_attr_by_column( + state, dict_, + bind_to_col[bindparam._identifying_key]) + + if self.parent_property.secondary is not None and alias_secondary: + criterion = sql_util.ClauseAdapter( + self.parent_property.secondary.alias()).\ + traverse(criterion) + + criterion = visitors.cloned_traverse( + criterion, {}, {'bindparam': visit_bindparam}) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): + if not reverse_direction: + criterion, bind_to_col, rev = \ + self._lazywhere, \ + self._bind_to_col,\ + self._equated_columns + else: + criterion, bind_to_col, rev = \ + self._rev_lazywhere, \ + self._rev_bind_to_col, \ + self._rev_equated_columns + + criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _load_for_state(self, state, passive): + if not state.key and \ + ( + ( + not self.parent_property.load_on_pending + and not state._load_pending + ) + or not state.session_id + ): + return attributes.ATTR_EMPTY + + pending = not state.key + ident_key = None + + if ( + (not passive & attributes.SQL_OK and not self.use_get) + or + (not passive & attributes.NON_PERSISTENT_OK and pending) + ): + return attributes.PASSIVE_NO_RESULT + + session = _state_session(state) + if not session: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "lazy load operation of attribute '%s' cannot proceed" % + (orm_util.state_str(state), self.key) + ) + + # if we have a simple primary key load, check the + # identity map without generating a Query at all + if self.use_get: + ident = self._get_ident_for_use_get( + session, + state, + passive + ) + if attributes.PASSIVE_NO_RESULT in ident: + return attributes.PASSIVE_NO_RESULT + elif attributes.NEVER_SET in ident: + return attributes.NEVER_SET + + if _none_set.issuperset(ident): + return None + + ident_key = self.mapper.identity_key_from_primary_key(ident) + instance = loading.get_from_identity(session, ident_key, passive) + if instance is not None: + return instance + elif not passive & attributes.SQL_OK or \ + not passive & attributes.RELATED_OBJECT_OK: + return attributes.PASSIVE_NO_RESULT + + return self._emit_lazyload(session, state, ident_key, passive) + + def _get_ident_for_use_get(self, session, state, passive): + instance_mapper = state.manager.mapper + + if passive & attributes.LOAD_AGAINST_COMMITTED: + get_attr = instance_mapper._get_committed_state_attr_by_column + else: + get_attr = instance_mapper._get_state_attr_by_column + + dict_ = state.dict + + return [ + get_attr( + state, + dict_, + self._equated_columns[pk], + passive=passive) + for pk in self.mapper.primary_key + ] + + @util.dependencies("sqlalchemy.orm.strategy_options") + def _emit_lazyload(self, strategy_options, session, state, ident_key, passive): + q = session.query(self.mapper)._adapt_all_clauses() + + if self.parent_property.secondary is not None: + q = q.select_from(self.mapper, self.parent_property.secondary) + + q = q._with_invoke_all_eagers(False) + + pending = not state.key + + # don't autoflush on pending + if pending or passive & attributes.NO_AUTOFLUSH: + q = q.autoflush(False) + + + if state.load_path: + q = q._with_current_path(state.load_path[self.parent_property]) + + if state.load_options: + q = q._conditional_options(*state.load_options) + + if self.use_get: + return loading.load_on_ident(q, ident_key) + + if self.parent_property.order_by: + q = q.order_by(*util.to_list(self.parent_property.order_by)) + + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. + if rev.direction is interfaces.MANYTOONE and \ + rev._use_get and \ + not isinstance(rev.strategy, LazyLoader): + q = q.options(strategy_options.Load(rev.parent).lazyload(rev.key)) + + lazy_clause = self.lazy_clause(state, passive=passive) + + if pending: + bind_values = sql_util.bind_values(lazy_clause) + if None in bind_values: + return None + + q = q.filter(lazy_clause) + + + result = q.all() + if self.uselist: + return result + else: + l = len(result) + if l: + if l > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for lazily-loaded attribute '%s' " + % self.parent_property) + + return result[0] + else: + return None + + def create_row_processor(self, context, path, loadopt, + mapper, row, adapter): + key = self.key + if not self.is_class_level: + # we are not the primary manager for this attribute + # on this class - set up a + # per-instance lazyloader, which will override the + # class-level behavior. + # this currently only happens when using a + # "lazyload" option on a "no load" + # attribute - "eager" attributes always have a + # class-level lazyloader installed. + set_lazy_callable = InstanceState._row_processor( + mapper.class_manager, + LoadLazyAttribute(key), key) + + return set_lazy_callable, None, None + else: + def reset_for_lazy_callable(state, dict_, row): + # we are the primary manager for this attribute on + # this class - reset its + # per-instance attribute state, so that the class-level + # lazy loader is + # executed when next referenced on this instance. + # this is needed in + # populate_existing() types of scenarios to reset + # any existing state. + state._reset(dict_, key) + + return reset_for_lazy_callable, None, None + + + +class LoadLazyAttribute(object): + """serializable loader object used by LazyLoader""" + + def __init__(self, key): + self.key = key + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + instance_mapper = state.manager.mapper + prop = instance_mapper._props[key] + strategy = prop._strategies[LazyLoader] + + return strategy._load_for_state(state, passive) + + +@properties.RelationshipProperty.strategy_for(lazy="immediate") +class ImmediateLoader(AbstractRelationshipLoader): + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy_by_cls(LazyLoader).\ + init_class_attribute(mapper) + + def setup_query(self, context, entity, + path, loadopt, adapter, column_collection=None, + parentmapper=None, **kwargs): + pass + + def create_row_processor(self, context, path, loadopt, + mapper, row, adapter): + def load_immediate(state, dict_, row): + state.get_impl(self.key).get(state, dict_) + + return None, None, load_immediate + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="subquery") +class SubqueryLoader(AbstractRelationshipLoader): + def __init__(self, parent): + super(SubqueryLoader, self).__init__(parent) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy_by_cls(LazyLoader).\ + init_class_attribute(mapper) + + def setup_query(self, context, entity, + path, loadopt, adapter, + column_collection=None, + parentmapper=None, **kwargs): + + if not context.query._enable_eagerloads: + return + + path = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_info = path.get(context.attributes, + "path_with_polymorphic", None) + if with_poly_info is not None: + effective_entity = with_poly_info.entity + else: + effective_entity = self.mapper + + subq_path = context.attributes.get(('subquery_path', None), + orm_util.PathRegistry.root) + + subq_path = subq_path + path + + # if not via query option, check for + # a cycle + if not path.contains(context.attributes, "loader"): + if self.join_depth: + if path.length / 2 > self.join_depth: + return + elif subq_path.contains_mapper(self.mapper): + return + + subq_mapper, leftmost_mapper, leftmost_attr, leftmost_relationship = \ + self._get_leftmost(subq_path) + + orig_query = context.attributes.get( + ("orig_query", SubqueryLoader), + context.query) + + # generate a new Query from the original, then + # produce a subquery from it. + left_alias = self._generate_from_original_query( + orig_query, leftmost_mapper, + leftmost_attr, leftmost_relationship, + entity.mapper + ) + + # generate another Query that will join the + # left alias to the target relationships. + # basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + q = orig_query.session.query(effective_entity) + q._attributes = { + ("orig_query", SubqueryLoader): orig_query, + ('subquery_path', None): subq_path + } + q = q._enable_single_crit(False) + + to_join, local_attr, parent_alias = \ + self._prep_for_joins(left_alias, subq_path) + q = q.order_by(*local_attr) + q = q.add_columns(*local_attr) + + q = self._apply_joins(q, to_join, left_alias, + parent_alias, effective_entity) + + q = self._setup_options(q, subq_path, orig_query, effective_entity) + q = self._setup_outermost_orderby(q) + + # add new query to attributes to be picked up + # by create_row_processor + path.set(context.attributes, "subquery", q) + + def _get_leftmost(self, subq_path): + subq_path = subq_path.path + subq_mapper = orm_util._class_to_mapper(subq_path[0]) + + # determine attributes of the leftmost mapper + if self.parent.isa(subq_mapper) and self.parent_property is subq_path[1]: + leftmost_mapper, leftmost_prop = \ + self.parent, self.parent_property + else: + leftmost_mapper, leftmost_prop = \ + subq_mapper, \ + subq_path[1] + + leftmost_cols = leftmost_prop.local_columns + + leftmost_attr = [ + leftmost_mapper._columntoproperty[c].class_attribute + for c in leftmost_cols + ] + return subq_mapper, leftmost_mapper, leftmost_attr, leftmost_prop + + def _generate_from_original_query(self, + orig_query, leftmost_mapper, + leftmost_attr, leftmost_relationship, + entity_mapper + ): + # reformat the original query + # to look only for significant columns + q = orig_query._clone().correlate(None) + + # set a real "from" if not present, as this is more + # accurate than just going off of the column expression + if not q._from_obj and entity_mapper.isa(leftmost_mapper): + q._set_select_from([entity_mapper], False) + + target_cols = q._adapt_col_list(leftmost_attr) + + # select from the identity columns of the outer + q._set_entities(target_cols) + + distinct_target_key = leftmost_relationship.distinct_target_key + + if distinct_target_key is True: + q._distinct = True + elif distinct_target_key is None: + # if target_cols refer to a non-primary key or only + # part of a composite primary key, set the q as distinct + for t in set(c.table for c in target_cols): + if not set(target_cols).issuperset(t.primary_key): + q._distinct = True + break + + if q._order_by is False: + q._order_by = leftmost_mapper.order_by + + # don't need ORDER BY if no limit/offset + if q._limit is None and q._offset is None: + q._order_by = None + + # the original query now becomes a subquery + # which we'll join onto. + + embed_q = q.with_labels().subquery() + left_alias = orm_util.AliasedClass(leftmost_mapper, embed_q, + use_mapper_path=True) + return left_alias + + def _prep_for_joins(self, left_alias, subq_path): + # figure out what's being joined. a.k.a. the fun part + to_join = [] + pairs = list(subq_path.pairs()) + + for i, (mapper, prop) in enumerate(pairs): + if i > 0: + # look at the previous mapper in the chain - + # if it is as or more specific than this prop's + # mapper, use that instead. + # note we have an assumption here that + # the non-first element is always going to be a mapper, + # not an AliasedClass + + prev_mapper = pairs[i - 1][1].mapper + to_append = prev_mapper if prev_mapper.isa(mapper) else mapper + else: + to_append = mapper + + to_join.append((to_append, prop.key)) + + # determine the immediate parent class we are joining from, + # which needs to be aliased. + if len(to_join) > 1: + info = inspect(to_join[-1][0]) + + if len(to_join) < 2: + # in the case of a one level eager load, this is the + # leftmost "left_alias". + parent_alias = left_alias + elif info.mapper.isa(self.parent): + # In the case of multiple levels, retrieve + # it from subq_path[-2]. This is the same as self.parent + # in the vast majority of cases, and [ticket:2014] + # illustrates a case where sub_path[-2] is a subclass + # of self.parent + parent_alias = orm_util.AliasedClass(to_join[-1][0], + use_mapper_path=True) + else: + # if of_type() were used leading to this relationship, + # self.parent is more specific than subq_path[-2] + parent_alias = orm_util.AliasedClass(self.parent, + use_mapper_path=True) + + local_cols = self.parent_property.local_columns + + local_attr = [ + getattr(parent_alias, self.parent._columntoproperty[c].key) + for c in local_cols + ] + return to_join, local_attr, parent_alias + + def _apply_joins(self, q, to_join, left_alias, parent_alias, + effective_entity): + for i, (mapper, key) in enumerate(to_join): + + # we need to use query.join() as opposed to + # orm.join() here because of the + # rich behavior it brings when dealing with + # "with_polymorphic" mappers. "aliased" + # and "from_joinpoint" take care of most of + # the chaining and aliasing for us. + + first = i == 0 + middle = i < len(to_join) - 1 + second_to_last = i == len(to_join) - 2 + last = i == len(to_join) - 1 + + if first: + attr = getattr(left_alias, key) + if last and effective_entity is not self.mapper: + attr = attr.of_type(effective_entity) + else: + if last and effective_entity is not self.mapper: + attr = getattr(parent_alias, key).\ + of_type(effective_entity) + else: + attr = key + + if second_to_last: + q = q.join(parent_alias, attr, from_joinpoint=True) + else: + q = q.join(attr, aliased=middle, from_joinpoint=True) + return q + + def _setup_options(self, q, subq_path, orig_query, effective_entity): + # propagate loader options etc. to the new query. + # these will fire relative to subq_path. + q = q._with_current_path(subq_path) + q = q._conditional_options(*orig_query._with_options) + if orig_query._populate_existing: + q._populate_existing = orig_query._populate_existing + + return q + + def _setup_outermost_orderby(self, q): + if self.parent_property.order_by: + # if there's an ORDER BY, alias it the same + # way joinedloader does, but we have to pull out + # the "eagerjoin" from the query. + # this really only picks up the "secondary" table + # right now. + eagerjoin = q._from_obj[0] + eager_order_by = \ + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) + q = q.order_by(*eager_order_by) + return q + + class _SubqCollections(object): + """Given a :class:`.Query` used to emit the "subquery load", + provide a load interface that executes the query at the + first moment a value is needed. + + """ + _data = None + + def __init__(self, subq): + self.subq = subq + + def get(self, key, default): + if self._data is None: + self._load() + return self._data.get(key, default) + + def _load(self): + self._data = dict( + (k, [vv[0] for vv in v]) + for k, v in itertools.groupby( + self.subq, + lambda x: x[1:] + ) + ) + + def loader(self, state, dict_, row): + if self._data is None: + self._load() + + def create_row_processor(self, context, path, loadopt, + mapper, row, adapter): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % + self) + + path = path[self.parent_property] + + subq = path.get(context.attributes, 'subquery') + + if subq is None: + return None, None, None + + local_cols = self.parent_property.local_columns + + # cache the loaded collections in the context + # so that inheriting mappers don't re-load when they + # call upon create_row_processor again + collections = path.get(context.attributes, "collections") + if collections is None: + collections = self._SubqCollections(subq) + path.set(context.attributes, 'collections', collections) + + if adapter: + local_cols = [adapter.columns[c] for c in local_cols] + + if self.uselist: + return self._create_collection_loader(collections, local_cols) + else: + return self._create_scalar_loader(collections, local_cols) + + def _create_collection_loader(self, collections, local_cols): + def load_collection_from_subq(state, dict_, row): + collection = collections.get( + tuple([row[col] for col in local_cols]), + () + ) + state.get_impl(self.key).\ + set_committed_value(state, dict_, collection) + + return load_collection_from_subq, None, None, collections.loader + + def _create_scalar_loader(self, collections, local_cols): + def load_scalar_from_subq(state, dict_, row): + collection = collections.get( + tuple([row[col] for col in local_cols]), + (None,) + ) + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + + scalar = collection[0] + state.get_impl(self.key).\ + set_committed_value(state, dict_, scalar) + + return load_scalar_from_subq, None, None, collections.loader + + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="joined") +@properties.RelationshipProperty.strategy_for(lazy=False) +class JoinedLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.RelationshipProperty` + using joined eager loading. + + """ + def __init__(self, parent): + super(JoinedLoader, self).__init__(parent) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy_by_cls(LazyLoader).init_class_attribute(mapper) + + def setup_query(self, context, entity, path, loadopt, adapter, \ + column_collection=None, parentmapper=None, + **kwargs): + """Add a left outer join to the statement thats being constructed.""" + + if not context.query._enable_eagerloads: + return + + path = path[self.parent_property] + + with_polymorphic = None + + user_defined_adapter = self._init_user_defined_eager_proc( + loadopt, context) if loadopt else False + + if user_defined_adapter is not False: + clauses, adapter, add_to_collection = \ + self._setup_query_on_user_defined_adapter( + context, entity, path, adapter, + user_defined_adapter + ) + else: + # if not via query option, check for + # a cycle + if not path.contains(context.attributes, "loader"): + if self.join_depth: + if path.length / 2 > self.join_depth: + return + elif path.contains_mapper(self.mapper): + return + + clauses, adapter, add_to_collection = self._generate_row_adapter( + context, entity, path, loadopt, adapter, + column_collection, parentmapper + ) + + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic", + None + ) + if with_poly_info is not None: + with_polymorphic = with_poly_info.with_polymorphic_mappers + else: + with_polymorphic = None + + path = path[self.mapper] + + for value in self.mapper._iterate_polymorphic_properties( + mappers=with_polymorphic): + value.setup( + context, + entity, + path, + clauses, + parentmapper=self.mapper, + column_collection=add_to_collection) + + if with_poly_info is not None and \ + None in set(context.secondary_columns): + raise sa_exc.InvalidRequestError( + "Detected unaliased columns when generating joined " + "load. Make sure to use aliased=True or flat=True " + "when using joined loading with with_polymorphic()." + ) + + def _init_user_defined_eager_proc(self, loadopt, context): + + # check if the opt applies at all + if "eager_from_alias" not in loadopt.local_opts: + # nope + return False + + path = loadopt.path.parent + + # the option applies. check if the "user_defined_eager_row_processor" + # has been built up. + adapter = path.get(context.attributes, + "user_defined_eager_row_processor", False) + if adapter is not False: + # just return it + return adapter + + # otherwise figure it out. + alias = loadopt.local_opts["eager_from_alias"] + + root_mapper, prop = path[-2:] + + #from .mapper import Mapper + #from .interfaces import MapperProperty + #assert isinstance(root_mapper, Mapper) + #assert isinstance(prop, MapperProperty) + + if alias is not None: + if isinstance(alias, str): + alias = prop.target.alias(alias) + adapter = sql_util.ColumnAdapter(alias, + equivalents=prop.mapper._equivalent_columns) + else: + if path.contains(context.attributes, "path_with_polymorphic"): + with_poly_info = path.get(context.attributes, + "path_with_polymorphic") + adapter = orm_util.ORMAdapter( + with_poly_info.entity, + equivalents=prop.mapper._equivalent_columns) + else: + adapter = context.query._polymorphic_adapters.get(prop.mapper, None) + path.set(context.attributes, + "user_defined_eager_row_processor", + adapter) + + return adapter + + def _setup_query_on_user_defined_adapter(self, context, entity, + path, adapter, user_defined_adapter): + + # apply some more wrapping to the "user defined adapter" + # if we are setting up the query for SQL render. + adapter = entity._get_entity_clauses(context.query, context) + + if adapter and user_defined_adapter: + user_defined_adapter = user_defined_adapter.wrap(adapter) + path.set(context.attributes, "user_defined_eager_row_processor", + user_defined_adapter) + elif adapter: + user_defined_adapter = adapter + path.set(context.attributes, "user_defined_eager_row_processor", + user_defined_adapter) + + add_to_collection = context.primary_columns + return user_defined_adapter, adapter, add_to_collection + + def _generate_row_adapter(self, + context, entity, path, loadopt, adapter, + column_collection, parentmapper + ): + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic", + None + ) + if with_poly_info: + to_adapt = with_poly_info.entity + else: + to_adapt = orm_util.AliasedClass(self.mapper, + flat=True, + use_mapper_path=True) + clauses = orm_util.ORMAdapter( + to_adapt, + equivalents=self.mapper._equivalent_columns, + adapt_required=True) + assert clauses.aliased_class is not None + + if self.parent_property.direction != interfaces.MANYTOONE: + context.multi_row_eager_loaders = True + + innerjoin = ( + loadopt.local_opts.get( + 'innerjoin', self.parent_property.innerjoin) + if loadopt is not None + else self.parent_property.innerjoin + ) + + context.create_eager_joins.append( + (self._create_eager_join, context, + entity, path, adapter, + parentmapper, clauses, innerjoin) + ) + + add_to_collection = context.secondary_columns + path.set(context.attributes, "eager_row_processor", clauses) + + return clauses, adapter, add_to_collection + + def _create_eager_join(self, context, entity, + path, adapter, parentmapper, + clauses, innerjoin): + + if parentmapper is None: + localparent = entity.mapper + else: + localparent = parentmapper + + # whether or not the Query will wrap the selectable in a subquery, + # and then attach eager load joins to that (i.e., in the case of + # LIMIT/OFFSET etc.) + should_nest_selectable = context.multi_row_eager_loaders and \ + context.query._should_nest_selectable + + entity_key = None + + if entity not in context.eager_joins and \ + not should_nest_selectable and \ + context.from_clause: + index, clause = \ + sql_util.find_join_source( + context.from_clause, entity.selectable) + if clause is not None: + # join to an existing FROM clause on the query. + # key it to its list index in the eager_joins dict. + # Query._compile_context will adapt as needed and + # append to the FROM clause of the select(). + entity_key, default_towrap = index, clause + + if entity_key is None: + entity_key, default_towrap = entity, entity.selectable + + towrap = context.eager_joins.setdefault(entity_key, default_towrap) + + if adapter: + if getattr(adapter, 'aliased_class', None): + onclause = getattr( + adapter.aliased_class, self.key, + self.parent_property) + else: + onclause = getattr( + orm_util.AliasedClass( + self.parent, + adapter.selectable, + use_mapper_path=True + ), + self.key, self.parent_property + ) + + else: + onclause = self.parent_property + + assert clauses.aliased_class is not None + + join_to_outer = innerjoin and isinstance(towrap, sql.Join) and towrap.isouter + + if join_to_outer and innerjoin == 'nested': + inner = orm_util.join( + towrap.right, + clauses.aliased_class, + onclause, + isouter=False + ) + + eagerjoin = orm_util.join( + towrap.left, + inner, + towrap.onclause, + isouter=True + ) + eagerjoin._target_adapter = inner._target_adapter + else: + if join_to_outer: + innerjoin = False + eagerjoin = orm_util.join( + towrap, + clauses.aliased_class, + onclause, + isouter=not innerjoin + ) + context.eager_joins[entity_key] = eagerjoin + + # send a hint to the Query as to where it may "splice" this join + eagerjoin.stop_on = entity.selectable + + if self.parent_property.secondary is None and \ + not parentmapper: + # for parentclause that is the non-eager end of the join, + # ensure all the parent cols in the primaryjoin are actually + # in the + # columns clause (i.e. are not deferred), so that aliasing applied + # by the Query propagates those columns outward. + # This has the effect + # of "undefering" those columns. + for col in sql_util._find_columns( + self.parent_property.primaryjoin): + if localparent.mapped_table.c.contains_column(col): + if adapter: + col = adapter.columns[col] + context.primary_columns.append(col) + + if self.parent_property.order_by: + context.eager_order_by += \ + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) + + def _create_eager_adapter(self, context, row, adapter, path, loadopt): + user_defined_adapter = self._init_user_defined_eager_proc( + loadopt, context) if loadopt else False + + if user_defined_adapter is not False: + decorator = user_defined_adapter + # user defined eagerloads are part of the "primary" + # portion of the load. + # the adapters applied to the Query should be honored. + if context.adapter and decorator: + decorator = decorator.wrap(context.adapter) + elif context.adapter: + decorator = context.adapter + else: + decorator = path.get(context.attributes, "eager_row_processor") + if decorator is None: + return False + + try: + self.mapper.identity_key_from_row(row, decorator) + return decorator + except KeyError: + # no identity key - dont return a row + # processor, will cause a degrade to lazy + return False + + def create_row_processor(self, context, path, loadopt, mapper, row, adapter): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % + self) + + our_path = path[self.parent_property] + + eager_adapter = self._create_eager_adapter( + context, + row, + adapter, our_path, loadopt) + + if eager_adapter is not False: + key = self.key + + _instance = loading.instance_processor( + self.mapper, + context, + our_path[self.mapper], + eager_adapter) + + if not self.uselist: + return self._create_scalar_loader(context, key, _instance) + else: + return self._create_collection_loader(context, key, _instance) + else: + return self.parent_property.\ + _get_strategy_by_cls(LazyLoader).\ + create_row_processor( + context, path, loadopt, + mapper, row, adapter) + + def _create_collection_loader(self, context, key, _instance): + def load_collection_from_joined_new_row(state, dict_, row): + collection = attributes.init_state_collection( + state, dict_, key) + result_list = util.UniqueAppender(collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + _instance(row, result_list) + + def load_collection_from_joined_existing_row(state, dict_, row): + if (state, key) in context.attributes: + result_list = context.attributes[(state, key)] + else: + # appender_key can be absent from context.attributes + # with isnew=False when self-referential eager loading + # is used; the same instance may be present in two + # distinct sets of result columns + collection = attributes.init_state_collection(state, + dict_, key) + result_list = util.UniqueAppender( + collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + _instance(row, result_list) + + def load_collection_from_joined_exec(state, dict_, row): + _instance(row, None) + + return load_collection_from_joined_new_row, \ + load_collection_from_joined_existing_row, \ + None, load_collection_from_joined_exec + + def _create_scalar_loader(self, context, key, _instance): + def load_scalar_from_joined_new_row(state, dict_, row): + # set a scalar object instance directly on the parent + # object, bypassing InstrumentedAttribute event handlers. + dict_[key] = _instance(row, None) + + def load_scalar_from_joined_existing_row(state, dict_, row): + # call _instance on the row, even though the object has + # been created, so that we further descend into properties + existing = _instance(row, None) + if existing is not None \ + and key in dict_ \ + and existing is not dict_[key]: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + + def load_scalar_from_joined_exec(state, dict_, row): + _instance(row, None) + + return load_scalar_from_joined_new_row, \ + load_scalar_from_joined_existing_row, \ + None, load_scalar_from_joined_exec + + + +def single_parent_validator(desc, prop): + def _do_check(state, value, oldvalue, initiator): + if value is not None and initiator.key == prop.key: + hasparent = initiator.hasparent(attributes.instance_state(value)) + if hasparent and oldvalue is not value: + raise sa_exc.InvalidRequestError( + "Instance %s is already associated with an instance " + "of %s via its %s attribute, and is only allowed a " + "single parent." % + (orm_util.instance_str(value), state.class_, prop) + ) + return value + + def append(state, value, initiator): + return _do_check(state, value, None, initiator) + + def set_(state, value, oldvalue, initiator): + return _do_check(state, value, oldvalue, initiator) + + event.listen(desc, 'append', append, raw=True, retval=True, + active_history=True) + event.listen(desc, 'set', set_, raw=True, retval=True, + active_history=True) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py new file mode 100644 index 00000000..317fc081 --- /dev/null +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -0,0 +1,948 @@ +# orm/strategy_options.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" + +""" + +from .interfaces import MapperOption, PropComparator +from .. import util +from ..sql.base import _generative, Generative +from .. import exc as sa_exc, inspect +from .base import _is_aliased_class, _class_to_mapper +from . import util as orm_util +from .path_registry import PathRegistry, TokenRegistry, \ + _WILDCARD_TOKEN, _DEFAULT_TOKEN + +class Load(Generative, MapperOption): + """Represents loader options which modify the state of a + :class:`.Query` in order to affect how various mapped attributes are loaded. + + .. versionadded:: 0.9.0 The :meth:`.Load` system is a new foundation for + the existing system of loader options, including options such as + :func:`.orm.joinedload`, :func:`.orm.defer`, and others. In particular, + it introduces a new method-chained system that replaces the need for + dot-separated paths as well as "_all()" options such as :func:`.orm.joinedload_all`. + + A :class:`.Load` object can be used directly or indirectly. To use one + directly, instantiate given the parent class. This style of usage is + useful when dealing with a :class:`.Query` that has multiple entities, + or when producing a loader option that can be applied generically to + any style of query:: + + myopt = Load(MyClass).joinedload("widgets") + + The above ``myopt`` can now be used with :meth:`.Query.options`:: + + session.query(MyClass).options(myopt) + + The :class:`.Load` construct is invoked indirectly whenever one makes use + of the various loader options that are present in ``sqlalchemy.orm``, including + options such as :func:`.orm.joinedload`, :func:`.orm.defer`, :func:`.orm.subqueryload`, + and all the rest. These constructs produce an "anonymous" form of the + :class:`.Load` object which tracks attributes and options, but is not linked + to a parent class until it is associated with a parent :class:`.Query`:: + + # produce "unbound" Load object + myopt = joinedload("widgets") + + # when applied using options(), the option is "bound" to the + # class observed in the given query, e.g. MyClass + session.query(MyClass).options(myopt) + + Whether the direct or indirect style is used, the :class:`.Load` object + returned now represents a specific "path" along the entities of a :class:`.Query`. + This path can be traversed using a standard method-chaining approach. + Supposing a class hierarchy such as ``User``, ``User.addresses -> Address``, + ``User.orders -> Order`` and ``Order.items -> Item``, we can specify a variety + of loader options along each element in the "path":: + + session.query(User).options( + joinedload("addresses"), + subqueryload("orders").joinedload("items") + ) + + Where above, the ``addresses`` collection will be joined-loaded, the + ``orders`` collection will be subquery-loaded, and within that subquery load + the ``items`` collection will be joined-loaded. + + + """ + def __init__(self, entity): + insp = inspect(entity) + self.path = insp._path_registry + self.context = {} + self.local_opts = {} + + def _generate(self): + cloned = super(Load, self)._generate() + cloned.local_opts = {} + return cloned + + strategy = None + propagate_to_loaders = False + + def process_query(self, query): + self._process(query, True) + + def process_query_conditionally(self, query): + self._process(query, False) + + def _process(self, query, raiseerr): + current_path = query._current_path + if current_path: + for (token, start_path), loader in self.context.items(): + chopped_start_path = self._chop_path(start_path, current_path) + if chopped_start_path is not None: + query._attributes[(token, chopped_start_path)] = loader + else: + query._attributes.update(self.context) + + def _generate_path(self, path, attr, wildcard_key, raiseerr=True): + if raiseerr and not path.has_entity: + if isinstance(path, TokenRegistry): + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by another entity") + else: + raise sa_exc.ArgumentError( + "Attribute '%s' of entity '%s' does not " + "refer to a mapped entity" % + (path.prop.key, path.parent.entity) + ) + + if isinstance(attr, util.string_types): + default_token = attr.endswith(_DEFAULT_TOKEN) + if attr.endswith(_WILDCARD_TOKEN) or default_token: + if default_token: + self.propagate_to_loaders = False + if wildcard_key: + attr = "%s:%s" % (wildcard_key, attr) + return path.token(attr) + + try: + # use getattr on the class to work around + # synonyms, hybrids, etc. + attr = getattr(path.entity.class_, attr) + except AttributeError: + if raiseerr: + raise sa_exc.ArgumentError( + "Can't find property named '%s' on the " + "mapped entity %s in this Query. " % ( + attr, path.entity) + ) + else: + return None + else: + attr = attr.property + + path = path[attr] + else: + prop = attr.property + + if not prop.parent.common_parent(path.mapper): + if raiseerr: + raise sa_exc.ArgumentError("Attribute '%s' does not " + "link from element '%s'" % (attr, path.entity)) + else: + return None + + if getattr(attr, '_of_type', None): + ac = attr._of_type + ext_info = inspect(ac) + + path_element = ext_info.mapper + if not ext_info.is_aliased_class: + ac = orm_util.with_polymorphic( + ext_info.mapper.base_mapper, + ext_info.mapper, aliased=True, + _use_mapper_path=True) + path.entity_path[prop].set(self.context, + "path_with_polymorphic", inspect(ac)) + path = path[prop][path_element] + else: + path = path[prop] + + if path.has_entity: + path = path.entity_path + return path + + def _coerce_strat(self, strategy): + if strategy is not None: + strategy = tuple(sorted(strategy.items())) + return strategy + + @_generative + def set_relationship_strategy(self, attr, strategy, propagate_to_loaders=True): + strategy = self._coerce_strat(strategy) + + self.propagate_to_loaders = propagate_to_loaders + # if the path is a wildcard, this will set propagate_to_loaders=False + self.path = self._generate_path(self.path, attr, "relationship") + self.strategy = strategy + if strategy is not None: + self._set_path_strategy() + + @_generative + def set_column_strategy(self, attrs, strategy, opts=None): + strategy = self._coerce_strat(strategy) + + for attr in attrs: + path = self._generate_path(self.path, attr, "column") + cloned = self._generate() + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + if opts: + cloned.local_opts.update(opts) + cloned._set_path_strategy() + + def _set_path_strategy(self): + if self.path.has_entity: + self.path.parent.set(self.context, "loader", self) + else: + self.path.set(self.context, "loader", self) + + def __getstate__(self): + d = self.__dict__.copy() + d["path"] = self.path.serialize() + return d + + def __setstate__(self, state): + self.__dict__.update(state) + self.path = PathRegistry.deserialize(self.path) + + def _chop_path(self, to_chop, path): + i = -1 + + for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)): + if isinstance(c_token, util.string_types): + # TODO: this is approximated from the _UnboundLoad + # version and probably has issues, not fully covered. + + if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN): + return to_chop + elif c_token != 'relationship:%s' % (_WILDCARD_TOKEN,) and c_token != p_token.key: + return None + + if c_token is p_token: + continue + else: + return None + return to_chop[i+1:] + + +class _UnboundLoad(Load): + """Represent a loader option that isn't tied to a root entity. + + The loader option will produce an entity-linked :class:`.Load` + object when it is passed :meth:`.Query.options`. + + This provides compatibility with the traditional system + of freestanding options, e.g. ``joinedload('x.y.z')``. + + """ + def __init__(self): + self.path = () + self._to_bind = set() + self.local_opts = {} + + _is_chain_link = False + + def _set_path_strategy(self): + self._to_bind.add(self) + + def _generate_path(self, path, attr, wildcard_key): + if wildcard_key and isinstance(attr, util.string_types) and \ + attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN): + if attr == _DEFAULT_TOKEN: + self.propagate_to_loaders = False + attr = "%s:%s" % (wildcard_key, attr) + + return path + (attr, ) + + def __getstate__(self): + d = self.__dict__.copy() + d['path'] = ret = [] + for token in util.to_list(self.path): + if isinstance(token, PropComparator): + ret.append((token._parentmapper.class_, token.key)) + else: + ret.append(token) + return d + + def __setstate__(self, state): + ret = [] + for key in state['path']: + if isinstance(key, tuple): + cls, propkey = key + ret.append(getattr(cls, propkey)) + else: + ret.append(key) + state['path'] = tuple(ret) + self.__dict__ = state + + def _process(self, query, raiseerr): + for val in self._to_bind: + val._bind_loader(query, query._attributes, raiseerr) + + @classmethod + def _from_keys(self, meth, keys, chained, kw): + opt = _UnboundLoad() + + def _split_key(key): + if isinstance(key, util.string_types): + # coerce fooload('*') into "default loader strategy" + if key == _WILDCARD_TOKEN: + return (_DEFAULT_TOKEN, ) + # coerce fooload(".*") into "wildcard on default entity" + elif key.startswith("." + _WILDCARD_TOKEN): + key = key[1:] + return key.split(".") + else: + return (key,) + all_tokens = [token for key in keys for token in _split_key(key)] + + for token in all_tokens[0:-1]: + if chained: + opt = meth(opt, token, **kw) + else: + opt = opt.defaultload(token) + opt._is_chain_link = True + + opt = meth(opt, all_tokens[-1], **kw) + opt._is_chain_link = False + + return opt + + + def _chop_path(self, to_chop, path): + i = -1 + for i, (c_token, (p_mapper, p_prop)) in enumerate(zip(to_chop, path.pairs())): + if isinstance(c_token, util.string_types): + if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN): + return to_chop + elif c_token != 'relationship:%s' % (_WILDCARD_TOKEN,) and c_token != p_prop.key: + return None + elif isinstance(c_token, PropComparator): + if c_token.property is not p_prop: + return None + else: + i += 1 + + return to_chop[i:] + + + def _bind_loader(self, query, context, raiseerr): + start_path = self.path + # _current_path implies we're in a + # secondary load with an existing path + + current_path = query._current_path + if current_path: + start_path = self._chop_path(start_path, current_path) + + if not start_path: + return None + + token = start_path[0] + if isinstance(token, util.string_types): + entity = self._find_entity_basestring(query, token, raiseerr) + elif isinstance(token, PropComparator): + prop = token.property + entity = self._find_entity_prop_comparator( + query, + prop.key, + token._parententity, + raiseerr) + + else: + raise sa_exc.ArgumentError( + "mapper option expects " + "string key or list of attributes") + + if not entity: + return + + path_element = entity.entity_zero + + # transfer our entity-less state into a Load() object + # with a real entity path. + loader = Load(path_element) + loader.context = context + loader.strategy = self.strategy + + path = loader.path + for token in start_path: + loader.path = path = loader._generate_path( + loader.path, token, None, raiseerr) + if path is None: + return + + loader.local_opts.update(self.local_opts) + + if loader.path.has_entity: + effective_path = loader.path.parent + else: + effective_path = loader.path + + # prioritize "first class" options over those + # that were "links in the chain", e.g. "x" and "y" in someload("x.y.z") + # versus someload("x") / someload("x.y") + if self._is_chain_link: + effective_path.setdefault(context, "loader", loader) + else: + effective_path.set(context, "loader", loader) + + def _find_entity_prop_comparator(self, query, token, mapper, raiseerr): + if _is_aliased_class(mapper): + searchfor = mapper + else: + searchfor = _class_to_mapper(mapper) + for ent in query._mapper_entities: + if ent.corresponds_to(searchfor): + return ent + else: + if raiseerr: + if not list(query._mapper_entities): + raise sa_exc.ArgumentError( + "Query has only expression-based entities - " + "can't find property named '%s'." + % (token, ) + ) + else: + raise sa_exc.ArgumentError( + "Can't find property '%s' on any entity " + "specified in this Query. Note the full path " + "from root (%s) to target entity must be specified." + % (token, ",".join(str(x) for + x in query._mapper_entities)) + ) + else: + return None + + def _find_entity_basestring(self, query, token, raiseerr): + if token.endswith(':' + _WILDCARD_TOKEN): + if len(list(query._mapper_entities)) != 1: + if raiseerr: + raise sa_exc.ArgumentError( + "Wildcard loader can only be used with exactly " + "one entity. Use Load(ent) to specify " + "specific entities.") + elif token.endswith(_DEFAULT_TOKEN): + raiseerr = False + + for ent in query._mapper_entities: + # return only the first _MapperEntity when searching + # based on string prop name. Ideally object + # attributes are used to specify more exactly. + return ent + else: + if raiseerr: + raise sa_exc.ArgumentError( + "Query has only expression-based entities - " + "can't find property named '%s'." + % (token, ) + ) + else: + return None + + + +class loader_option(object): + def __init__(self): + pass + + def __call__(self, fn): + self.name = name = fn.__name__ + self.fn = fn + if hasattr(Load, name): + raise TypeError("Load class already has a %s method." % (name)) + setattr(Load, name, fn) + + return self + + def _add_unbound_fn(self, fn): + self._unbound_fn = fn + fn_doc = self.fn.__doc__ + self.fn.__doc__ = """Produce a new :class:`.Load` object with the +:func:`.orm.%(name)s` option applied. + +See :func:`.orm.%(name)s` for usage examples. + +""" % {"name": self.name} + + fn.__doc__ = fn_doc + return self + + def _add_unbound_all_fn(self, fn): + self._unbound_all_fn = fn + fn.__doc__ = """Produce a standalone "all" option for :func:`.orm.%(name)s`. + +.. deprecated:: 0.9.0 + + The "_all()" style is replaced by method chaining, e.g.:: + + session.query(MyClass).options( + %(name)s("someattribute").%(name)s("anotherattribute") + ) + +""" % {"name": self.name} + return self + +@loader_option() +def contains_eager(loadopt, attr, alias=None): + """Indicate that the given attribute should be eagerly loaded from + columns stated manually in the query. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + The option is used in conjunction with an explicit join that loads + the desired rows, i.e.:: + + sess.query(Order).\\ + join(Order.user).\\ + options(contains_eager(Order.user)) + + The above query would join from the ``Order`` entity to its related + ``User`` entity, and the returned ``Order`` objects would have the + ``Order.user`` attribute pre-populated. + + :func:`contains_eager` also accepts an `alias` argument, which is the + string name of an alias, an :func:`~sqlalchemy.sql.expression.alias` + construct, or an :func:`~sqlalchemy.orm.aliased` construct. Use this when + the eagerly-loaded rows are to come from an aliased table:: + + user_alias = aliased(User) + sess.query(Order).\\ + join((user_alias, Order.user)).\\ + options(contains_eager(Order.user, alias=user_alias)) + + .. seealso:: + + :ref:`contains_eager` + + """ + if alias is not None: + if not isinstance(alias, str): + info = inspect(alias) + alias = info.selectable + + cloned = loadopt.set_relationship_strategy( + attr, + {"lazy": "joined"}, + propagate_to_loaders=False + ) + cloned.local_opts['eager_from_alias'] = alias + return cloned + +@contains_eager._add_unbound_fn +def contains_eager(*keys, **kw): + return _UnboundLoad()._from_keys(_UnboundLoad.contains_eager, keys, True, kw) + +@loader_option() +def load_only(loadopt, *attrs): + """Indicate that for a particular entity, only the given list + of column-based attribute names should be loaded; all others will be + deferred. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + Example - given a class ``User``, load only the ``name`` and ``fullname`` + attributes:: + + session.query(User).options(load_only("name", "fullname")) + + Example - given a relationship ``User.addresses -> Address``, specify + subquery loading for the ``User.addresses`` collection, but on each ``Address`` + object load only the ``email_address`` attribute:: + + session.query(User).options( + subqueryload("addreses").load_only("email_address") + ) + + For a :class:`.Query` that has multiple entities, the lead entity can be + specifically referred to using the :class:`.Load` constructor:: + + session.query(User, Address).join(User.addresses).options( + Load(User).load_only("name", "fullname"), + Load(Address).load_only("email_addres") + ) + + + .. versionadded:: 0.9.0 + + """ + cloned = loadopt.set_column_strategy( + attrs, + {"deferred": False, "instrument": True} + ) + cloned.set_column_strategy("*", + {"deferred": True, "instrument": True}) + return cloned + +@load_only._add_unbound_fn +def load_only(*attrs): + return _UnboundLoad().load_only(*attrs) + +@loader_option() +def joinedload(loadopt, attr, innerjoin=None): + """Indicate that the given attribute should be loaded using joined + eager loading. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # joined-load the "orders" collection on "User" + query(User).options(joinedload(User.orders)) + + # joined-load Order.items and then Item.keywords + query(Order).options(joinedload(Order.items).joinedload(Item.keywords)) + + # lazily load Order.items, but when Items are loaded, + # joined-load the keywords collection + query(Order).options(lazyload(Order.items).joinedload(Item.keywords)) + + :param innerjoin: if ``True``, indicates that the joined eager load should + use an inner join instead of the default of left outer join:: + + query(Order).options(joinedload(Order.user, innerjoin=True)) + + If the joined-eager load is chained onto an existing LEFT OUTER JOIN, + ``innerjoin=True`` will be bypassed and the join will continue to + chain as LEFT OUTER JOIN so that the results don't change. As an alternative, + specify the value ``"nested"``. This will instead nest the join + on the right side, e.g. using the form "a LEFT OUTER JOIN (b JOIN c)". + + .. versionadded:: 0.9.4 Added ``innerjoin="nested"`` option to support + nesting of eager "inner" joins. + + .. note:: + + The joins produced by :func:`.orm.joinedload` are **anonymously aliased**. + The criteria by which the join proceeds cannot be modified, nor can the + :class:`.Query` refer to these joins in any way, including ordering. + + To produce a specific SQL JOIN which is explicitly available, use + :meth:`.Query.join`. To combine explicit JOINs with eager loading + of collections, use :func:`.orm.contains_eager`; see :ref:`contains_eager`. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`contains_eager` + + :func:`.orm.subqueryload` + + :func:`.orm.lazyload` + + :paramref:`.relationship.lazy` + + :paramref:`.relationship.innerjoin` - :func:`.relationship`-level version + of the :paramref:`.joinedload.innerjoin` option. + + """ + loader = loadopt.set_relationship_strategy(attr, {"lazy": "joined"}) + if innerjoin is not None: + loader.local_opts['innerjoin'] = innerjoin + return loader + +@joinedload._add_unbound_fn +def joinedload(*keys, **kw): + return _UnboundLoad._from_keys( + _UnboundLoad.joinedload, keys, False, kw) + +@joinedload._add_unbound_all_fn +def joinedload_all(*keys, **kw): + return _UnboundLoad._from_keys( + _UnboundLoad.joinedload, keys, True, kw) + + +@loader_option() +def subqueryload(loadopt, attr): + """Indicate that the given attribute should be loaded using + subquery eager loading. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # subquery-load the "orders" collection on "User" + query(User).options(subqueryload(User.orders)) + + # subquery-load Order.items and then Item.keywords + query(Order).options(subqueryload(Order.items).subqueryload(Item.keywords)) + + # lazily load Order.items, but when Items are loaded, + # subquery-load the keywords collection + query(Order).options(lazyload(Order.items).subqueryload(Item.keywords)) + + + .. seealso:: + + :ref:`loading_toplevel` + + :func:`.orm.joinedload` + + :func:`.orm.lazyload` + + :paramref:`.relationship.lazy` + + """ + return loadopt.set_relationship_strategy(attr, {"lazy": "subquery"}) + +@subqueryload._add_unbound_fn +def subqueryload(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.subqueryload, keys, False, {}) + +@subqueryload._add_unbound_all_fn +def subqueryload_all(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.subqueryload, keys, True, {}) + +@loader_option() +def lazyload(loadopt, attr): + """Indicate that the given attribute should be loaded using "lazy" + loading. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + .. seealso:: + + :paramref:`.relationship.lazy` + + """ + return loadopt.set_relationship_strategy(attr, {"lazy": "select"}) + +@lazyload._add_unbound_fn +def lazyload(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.lazyload, keys, False, {}) + +@lazyload._add_unbound_all_fn +def lazyload_all(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.lazyload, keys, True, {}) + +@loader_option() +def immediateload(loadopt, attr): + """Indicate that the given attribute should be loaded using + an immediate load with a per-attribute SELECT statement. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + .. seealso:: + + :ref:`loading_toplevel` + + :func:`.orm.joinedload` + + :func:`.orm.lazyload` + + :paramref:`.relationship.lazy` + + """ + loader = loadopt.set_relationship_strategy(attr, {"lazy": "immediate"}) + return loader + +@immediateload._add_unbound_fn +def immediateload(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.immediateload, keys, False, {}) + + +@loader_option() +def noload(loadopt, attr): + """Indicate that the given relationship attribute should remain unloaded. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + :func:`.orm.noload` applies to :func:`.relationship` attributes; for + column-based attributes, see :func:`.orm.defer`. + + """ + + return loadopt.set_relationship_strategy(attr, {"lazy": "noload"}) + +@noload._add_unbound_fn +def noload(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.noload, keys, False, {}) + +@loader_option() +def defaultload(loadopt, attr): + """Indicate an attribute should load using its default loader style. + + This method is used to link to other loader options, such as + to set the :func:`.orm.defer` option on a class that is linked to + a relationship of the parent class being loaded, :func:`.orm.defaultload` + can be used to navigate this path without changing the loading style + of the relationship:: + + session.query(MyClass).options(defaultload("someattr").defer("some_column")) + + .. seealso:: + + :func:`.orm.defer` + + :func:`.orm.undefer` + + """ + return loadopt.set_relationship_strategy( + attr, + None + ) + +@defaultload._add_unbound_fn +def defaultload(*keys): + return _UnboundLoad._from_keys(_UnboundLoad.defaultload, keys, False, {}) + +@loader_option() +def defer(loadopt, key): + """Indicate that the given column-oriented attribute should be deferred, e.g. + not loaded until accessed. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + e.g.:: + + from sqlalchemy.orm import defer + + session.query(MyClass).options( + defer("attribute_one"), + defer("attribute_two")) + + session.query(MyClass).options( + defer(MyClass.attribute_one), + defer(MyClass.attribute_two)) + + To specify a deferred load of an attribute on a related class, + the path can be specified one token at a time, specifying the loading + style for each link along the chain. To leave the loading style + for a link unchanged, use :func:`.orm.defaultload`:: + + session.query(MyClass).options(defaultload("someattr").defer("some_column")) + + A :class:`.Load` object that is present on a certain path can have + :meth:`.Load.defer` called multiple times, each will operate on the same + parent entity:: + + + session.query(MyClass).options( + defaultload("someattr"). + defer("some_column"). + defer("some_other_column"). + defer("another_column") + ) + + :param key: Attribute to be deferred. + + :param \*addl_attrs: Deprecated; this option supports the old 0.8 style + of specifying a path as a series of attributes, which is now superseded + by the method-chained style. + + .. seealso:: + + :ref:`deferred` + + :func:`.orm.undefer` + + """ + return loadopt.set_column_strategy( + (key, ), + {"deferred": True, "instrument": True} + ) + + +@defer._add_unbound_fn +def defer(key, *addl_attrs): + return _UnboundLoad._from_keys(_UnboundLoad.defer, (key, ) + addl_attrs, False, {}) + +@loader_option() +def undefer(loadopt, key): + """Indicate that the given column-oriented attribute should be undeferred, e.g. + specified within the SELECT statement of the entity as a whole. + + The column being undeferred is typically set up on the mapping as a + :func:`.deferred` attribute. + + This function is part of the :class:`.Load` interface and supports + both method-chained and standalone operation. + + Examples:: + + # undefer two columns + session.query(MyClass).options(undefer("col1"), undefer("col2")) + + # undefer all columns specific to a single class using Load + * + session.query(MyClass, MyOtherClass).options(Load(MyClass).undefer("*")) + + :param key: Attribute to be undeferred. + + :param \*addl_attrs: Deprecated; this option supports the old 0.8 style + of specifying a path as a series of attributes, which is now superseded + by the method-chained style. + + .. seealso:: + + :ref:`deferred` + + :func:`.orm.defer` + + :func:`.orm.undefer_group` + + """ + return loadopt.set_column_strategy( + (key, ), + {"deferred": False, "instrument": True} + ) + +@undefer._add_unbound_fn +def undefer(key, *addl_attrs): + return _UnboundLoad._from_keys(_UnboundLoad.undefer, (key, ) + addl_attrs, False, {}) + +@loader_option() +def undefer_group(loadopt, name): + """Indicate that columns within the given deferred group name should be undeferred. + + The columns being undeferred are set up on the mapping as + :func:`.deferred` attributes and include a "group" name. + + E.g:: + + session.query(MyClass).options(undefer_group("large_attrs")) + + To undefer a group of attributes on a related entity, the path can be + spelled out using relationship loader options, such as :func:`.orm.defaultload`:: + + session.query(MyClass).options(defaultload("someattr").undefer_group("large_attrs")) + + .. versionchanged:: 0.9.0 :func:`.orm.undefer_group` is now specific to a + particiular entity load path. + + .. seealso:: + + :ref:`deferred` + + :func:`.orm.defer` + + :func:`.orm.undefer` + + """ + return loadopt.set_column_strategy( + "*", + None, + {"undefer_group": name} + ) + +@undefer_group._add_unbound_fn +def undefer_group(name): + return _UnboundLoad().undefer_group(name) + diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py new file mode 100644 index 00000000..cf735fc5 --- /dev/null +++ b/lib/sqlalchemy/orm/sync.py @@ -0,0 +1,118 @@ +# orm/sync.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used for copying data +between instances based on join conditions. + +""" + +from . import exc, util as orm_util, attributes + + +def populate(source, source_mapper, dest, dest_mapper, + synchronize_pairs, uowcommit, flag_cascaded_pks): + source_dict = source.dict + dest_dict = dest.dict + + for l, r in synchronize_pairs: + try: + # inline of source_mapper._get_state_attr_by_column + prop = source_mapper._columntoproperty[l] + value = source.manager[prop.key].impl.get(source, source_dict, + attributes.PASSIVE_OFF) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) + + try: + # inline of dest_mapper._set_state_attr_by_column + prop = dest_mapper._columntoproperty[r] + dest.manager[prop.key].impl.set(dest, dest_dict, value, None) + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + + # technically the "r.primary_key" check isn't + # needed here, but we check for this condition to limit + # how often this logic is invoked for memory/performance + # reasons, since we only need this info for a primary key + # destination. + if flag_cascaded_pks and l.primary_key and \ + r.primary_key and \ + r.references(l): + uowcommit.attributes[("pk_cascaded", dest, r)] = True + + +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if r.primary_key: + raise AssertionError( + "Dependency rule tried to blank-out primary key " + "column '%s' on instance '%s'" % + (r, orm_util.state_str(dest)) + ) + try: + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) + except exc.UnmappedColumnError: + _raise_col_to_prop(True, None, l, dest_mapper, r) + + +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column( + source.obj(), l) + value = source_mapper._get_state_attr_by_column( + source, source.dict, l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue + + +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column( + source, source.dict, l) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + + dict_[r.key] = value + + +def source_modified(uowcommit, source, source_mapper, synchronize_pairs): + """return true if the source object has changes from an old to a + new value on the given synchronize pairs + + """ + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + history = uowcommit.get_attribute_history(source, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if bool(history.deleted): + return True + else: + return False + + +def _raise_col_to_prop(isdest, source_mapper, source_column, + dest_mapper, dest_column): + if isdest: + raise exc.UnmappedColumnError("Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, + dest_mapper)) + else: + raise exc.UnmappedColumnError("Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." + % (source_column, source_mapper, dest_column)) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py new file mode 100644 index 00000000..2964705a --- /dev/null +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -0,0 +1,646 @@ +# orm/unitofwork.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The internals for the unit of work system. + +The session's flush() process passes objects to a contextual object +here, which assembles flush tasks based on mappers and their properties, +organizes them in order of dependency, and executes. + +""" + +from .. import util, event +from ..util import topological +from . import attributes, persistence, util as orm_util + + +def track_cascade_events(descriptor, prop): + """Establish event listeners on object attributes which handle + cascade-on-set/append. + + """ + key = prop.key + + def append(state, item, initiator): + # process "save_update" cascade rules for when + # an instance is appended to the list of another instance + + if item is None: + return + + sess = state.session + if sess: + if sess._warn_on_events: + sess._flush_warning("collection append") + + prop = state.manager.mapper._props[key] + item_state = attributes.instance_state(item) + if prop._cascade.save_update and \ + (prop.cascade_backrefs or key == initiator.key) and \ + not sess._contains_state(item_state): + sess._save_or_update_state(item_state) + return item + + def remove(state, item, initiator): + if item is None: + return + + sess = state.session + if sess: + + prop = state.manager.mapper._props[key] + + if sess._warn_on_events: + sess._flush_warning( + "collection remove" + if prop.uselist + else "related attribute delete") + + # expunge pending orphans + item_state = attributes.instance_state(item) + if prop._cascade.delete_orphan and \ + item_state in sess._new and \ + prop.mapper._is_orphan(item_state): + sess.expunge(item) + + def set_(state, newvalue, oldvalue, initiator): + # process "save_update" cascade rules for when an instance + # is attached to another instance + if oldvalue is newvalue: + return newvalue + + sess = state.session + if sess: + + if sess._warn_on_events: + sess._flush_warning("related attribute set") + + prop = state.manager.mapper._props[key] + if newvalue is not None: + newvalue_state = attributes.instance_state(newvalue) + if prop._cascade.save_update and \ + (prop.cascade_backrefs or key == initiator.key) and \ + not sess._contains_state(newvalue_state): + sess._save_or_update_state(newvalue_state) + + if oldvalue is not None and \ + oldvalue is not attributes.PASSIVE_NO_RESULT and \ + prop._cascade.delete_orphan: + # possible to reach here with attributes.NEVER_SET ? + oldvalue_state = attributes.instance_state(oldvalue) + + if oldvalue_state in sess._new and \ + prop.mapper._is_orphan(oldvalue_state): + sess.expunge(oldvalue) + return newvalue + + event.listen(descriptor, 'append', append, raw=True, retval=True) + event.listen(descriptor, 'remove', remove, raw=True, retval=True) + event.listen(descriptor, 'set', set_, raw=True, retval=True) + + +class UOWTransaction(object): + def __init__(self, session): + self.session = session + + # dictionary used by external actors to + # store arbitrary state information. + self.attributes = {} + + # dictionary of mappers to sets of + # DependencyProcessors, which are also + # set to be part of the sorted flush actions, + # which have that mapper as a parent. + self.deps = util.defaultdict(set) + + # dictionary of mappers to sets of InstanceState + # items pending for flush which have that mapper + # as a parent. + self.mappers = util.defaultdict(set) + + # a dictionary of Preprocess objects, which gather + # additional states impacted by the flush + # and determine if a flush action is needed + self.presort_actions = {} + + # dictionary of PostSortRec objects, each + # one issues work during the flush within + # a certain ordering. + self.postsort_actions = {} + + # a set of 2-tuples, each containing two + # PostSortRec objects where the second + # is dependent on the first being executed + # first + self.dependencies = set() + + # dictionary of InstanceState-> (isdelete, listonly) + # tuples, indicating if this state is to be deleted + # or insert/updated, or just refreshed + self.states = {} + + # tracks InstanceStates which will be receiving + # a "post update" call. Keys are mappers, + # values are a set of states and a set of the + # columns which should be included in the update. + self.post_update_states = util.defaultdict(lambda: (set(), set())) + + @property + def has_work(self): + return bool(self.states) + + def is_deleted(self, state): + """return true if the given state is marked as deleted + within this uowtransaction.""" + + return state in self.states and self.states[state][0] + + def memo(self, key, callable_): + if key in self.attributes: + return self.attributes[key] + else: + self.attributes[key] = ret = callable_() + return ret + + def remove_state_actions(self, state): + """remove pending actions for a state from the uowtransaction.""" + + isdelete = self.states[state][0] + + self.states[state] = (isdelete, True) + + def get_attribute_history(self, state, key, + passive=attributes.PASSIVE_NO_INITIALIZE): + """facade to attributes.get_state_history(), including + caching of results.""" + + hashkey = ("history", state, key) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + + if hashkey in self.attributes: + history, state_history, cached_passive = self.attributes[hashkey] + # if the cached lookup was "passive" and now + # we want non-passive, do a non-passive lookup and re-cache + + if not cached_passive & attributes.SQL_OK \ + and passive & attributes.SQL_OK: + impl = state.manager[key].impl + history = impl.get_history(state, state.dict, + attributes.PASSIVE_OFF | + attributes.LOAD_AGAINST_COMMITTED) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) + else: + impl = state.manager[key].impl + # TODO: store the history as (state, object) tuples + # so we don't have to keep converting here + history = impl.get_history(state, state.dict, passive | + attributes.LOAD_AGAINST_COMMITTED) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, + passive) + + return state_history + + def has_dep(self, processor): + return (processor, True) in self.presort_actions + + def register_preprocessor(self, processor, fromparent): + key = (processor, fromparent) + if key not in self.presort_actions: + self.presort_actions[key] = Preprocess(processor, fromparent) + + def register_object(self, state, isdelete=False, + listonly=False, cancel_delete=False, + operation=None, prop=None): + if not self.session._contains_state(state): + if not state.deleted and operation is not None: + util.warn("Object of type %s not in session, %s operation " + "along '%s' will not proceed" % + (orm_util.state_class_str(state), operation, prop)) + return False + + if state not in self.states: + mapper = state.manager.mapper + + if mapper not in self.mappers: + self._per_mapper_flush_actions(mapper) + + self.mappers[mapper].add(state) + self.states[state] = (isdelete, listonly) + else: + if not listonly and (isdelete or cancel_delete): + self.states[state] = (isdelete, False) + return True + + def issue_post_update(self, state, post_update_cols): + mapper = state.manager.mapper.base_mapper + states, cols = self.post_update_states[mapper] + states.add(state) + cols.update(post_update_cols) + + def _per_mapper_flush_actions(self, mapper): + saves = SaveUpdateAll(self, mapper.base_mapper) + deletes = DeleteAll(self, mapper.base_mapper) + self.dependencies.add((saves, deletes)) + + for dep in mapper._dependency_processors: + dep.per_property_preprocessors(self) + + for prop in mapper.relationships: + if prop.viewonly: + continue + dep = prop._dependency_processor + dep.per_property_preprocessors(self) + + @util.memoized_property + def _mapper_for_dep(self): + """return a dynamic mapping of (Mapper, DependencyProcessor) to + True or False, indicating if the DependencyProcessor operates + on objects of that Mapper. + + The result is stored in the dictionary persistently once + calculated. + + """ + return util.PopulateDict( + lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop + ) + + def filter_states_for_dep(self, dep, states): + """Filter the given list of InstanceStates to those relevant to the + given DependencyProcessor. + + """ + mapper_for_dep = self._mapper_for_dep + return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]] + + def states_for_mapper_hierarchy(self, mapper, isdelete, listonly): + checktup = (isdelete, listonly) + for mapper in mapper.base_mapper.self_and_descendants: + for state in self.mappers[mapper]: + if self.states[state] == checktup: + yield state + + def _generate_actions(self): + """Generate the full, unsorted collection of PostSortRecs as + well as dependency pairs for this UOWTransaction. + + """ + # execute presort_actions, until all states + # have been processed. a presort_action might + # add new states to the uow. + while True: + ret = False + for action in list(self.presort_actions.values()): + if action.execute(self): + ret = True + if not ret: + break + + # see if the graph of mapper dependencies has cycles. + self.cycles = cycles = topological.find_cycles( + self.dependencies, + list(self.postsort_actions.values())) + + if cycles: + # if yes, break the per-mapper actions into + # per-state actions + convert = dict( + (rec, set(rec.per_state_flush_actions(self))) + for rec in cycles + ) + + # rewrite the existing dependencies to point to + # the per-state actions for those per-mapper actions + # that were broken up. + for edge in list(self.dependencies): + if None in edge or \ + edge[0].disabled or edge[1].disabled or \ + cycles.issuperset(edge): + self.dependencies.remove(edge) + elif edge[0] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[0]]: + self.dependencies.add((dep, edge[1])) + elif edge[1] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[1]]: + self.dependencies.add((edge[0], dep)) + + return set([a for a in self.postsort_actions.values() + if not a.disabled + ] + ).difference(cycles) + + def execute(self): + postsort_actions = self._generate_actions() + + #sort = topological.sort(self.dependencies, postsort_actions) + #print "--------------" + #print "\ndependencies:", self.dependencies + #print "\ncycles:", self.cycles + #print "\nsort:", list(sort) + #print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions) + + # execute + if self.cycles: + for set_ in topological.sort_as_subsets( + self.dependencies, + postsort_actions): + while set_: + n = set_.pop() + n.execute_aggregate(self, set_) + else: + for rec in topological.sort( + self.dependencies, + postsort_actions): + rec.execute(self) + + def finalize_flush_changes(self): + """mark processed objects as clean / deleted after a successful + flush(). + + this method is called within the flush() method after the + execute() method has succeeded and the transaction has been committed. + + """ + states = set(self.states) + isdel = set( + s for (s, (isdelete, listonly)) in self.states.items() + if isdelete + ) + other = states.difference(isdel) + self.session._remove_newly_deleted(isdel) + self.session._register_newly_persistent(other) + + +class IterateMappersMixin(object): + def _mappers(self, uow): + if self.fromparent: + return iter( + m for m in + self.dependency_processor.parent.self_and_descendants + if uow._mapper_for_dep[(m, self.dependency_processor)] + ) + else: + return self.dependency_processor.mapper.self_and_descendants + + +class Preprocess(IterateMappersMixin): + def __init__(self, dependency_processor, fromparent): + self.dependency_processor = dependency_processor + self.fromparent = fromparent + self.processed = set() + self.setup_flush_actions = False + + def execute(self, uow): + delete_states = set() + save_states = set() + + for mapper in self._mappers(uow): + for state in uow.mappers[mapper].difference(self.processed): + (isdelete, listonly) = uow.states[state] + if not listonly: + if isdelete: + delete_states.add(state) + else: + save_states.add(state) + + if delete_states: + self.dependency_processor.presort_deletes(uow, delete_states) + self.processed.update(delete_states) + if save_states: + self.dependency_processor.presort_saves(uow, save_states) + self.processed.update(save_states) + + if (delete_states or save_states): + if not self.setup_flush_actions and ( + self.dependency_processor.\ + prop_has_changes(uow, delete_states, True) or + self.dependency_processor.\ + prop_has_changes(uow, save_states, False) + ): + self.dependency_processor.per_property_flush_actions(uow) + self.setup_flush_actions = True + return True + else: + return False + + +class PostSortRec(object): + disabled = False + + def __new__(cls, uow, *args): + key = (cls, ) + args + if key in uow.postsort_actions: + return uow.postsort_actions[key] + else: + uow.postsort_actions[key] = \ + ret = \ + object.__new__(cls) + return ret + + def execute_aggregate(self, uow, recs): + self.execute(uow) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + ",".join(str(x) for x in self.__dict__.values()) + ) + + +class ProcessAll(IterateMappersMixin, PostSortRec): + def __init__(self, uow, dependency_processor, delete, fromparent): + self.dependency_processor = dependency_processor + self.delete = delete + self.fromparent = fromparent + uow.deps[dependency_processor.parent.base_mapper].\ + add(dependency_processor) + + def execute(self, uow): + states = self._elements(uow) + if self.delete: + self.dependency_processor.process_deletes(uow, states) + else: + self.dependency_processor.process_saves(uow, states) + + def per_state_flush_actions(self, uow): + # this is handled by SaveUpdateAll and DeleteAll, + # since a ProcessAll should unconditionally be pulled + # into per-state if either the parent/child mappers + # are part of a cycle + return iter([]) + + def __repr__(self): + return "%s(%s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + self.delete + ) + + def _elements(self, uow): + for mapper in self._mappers(uow): + for state in uow.mappers[mapper]: + (isdelete, listonly) = uow.states[state] + if isdelete == self.delete and not listonly: + yield state + + +class IssuePostUpdate(PostSortRec): + def __init__(self, uow, mapper, isdelete): + self.mapper = mapper + self.isdelete = isdelete + + def execute(self, uow): + states, cols = uow.post_update_states[self.mapper] + states = [s for s in states if uow.states[s][0] == self.isdelete] + + persistence.post_update(self.mapper, states, uow, cols) + + +class SaveUpdateAll(PostSortRec): + def __init__(self, uow, mapper): + self.mapper = mapper + assert mapper is mapper.base_mapper + + def execute(self, uow): + persistence.save_obj(self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, False, False), + uow + ) + + def per_state_flush_actions(self, uow): + states = list(uow.states_for_mapper_hierarchy( + self.mapper, False, False)) + base_mapper = self.mapper.base_mapper + delete_all = DeleteAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = SaveUpdateState(uow, state, base_mapper) + uow.dependencies.add((action, delete_all)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, False) + + +class DeleteAll(PostSortRec): + def __init__(self, uow, mapper): + self.mapper = mapper + assert mapper is mapper.base_mapper + + def execute(self, uow): + persistence.delete_obj(self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, True, False), + uow + ) + + def per_state_flush_actions(self, uow): + states = list(uow.states_for_mapper_hierarchy( + self.mapper, True, False)) + base_mapper = self.mapper.base_mapper + save_all = SaveUpdateAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = DeleteState(uow, state, base_mapper) + uow.dependencies.add((save_all, action)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, True) + + +class ProcessState(PostSortRec): + def __init__(self, uow, dependency_processor, delete, state): + self.dependency_processor = dependency_processor + self.delete = delete + self.state = state + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + dependency_processor = self.dependency_processor + delete = self.delete + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.dependency_processor is dependency_processor and + r.delete is delete] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + if delete: + dependency_processor.process_deletes(uow, states) + else: + dependency_processor.process_saves(uow, states) + + def __repr__(self): + return "%s(%s, %s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + orm_util.state_str(self.state), + self.delete + ) + + +class SaveUpdateState(PostSortRec): + def __init__(self, uow, state, mapper): + self.state = state + self.mapper = mapper + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + persistence.save_obj(mapper, + [self.state] + + [r.state for r in our_recs], + uow) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state) + ) + + +class DeleteState(PostSortRec): + def __init__(self, uow, state, mapper): + self.state = state + self.mapper = mapper + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + persistence.delete_obj(mapper, + [s for s in states if uow.states[s][0]], + uow) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state) + ) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py new file mode 100644 index 00000000..dd85f2ef --- /dev/null +++ b/lib/sqlalchemy/orm/util.py @@ -0,0 +1,953 @@ +# orm/util.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +from .. import sql, util, event, exc as sa_exc, inspection +from ..sql import expression, util as sql_util, operators +from .interfaces import PropComparator, MapperProperty +from . import attributes +import re + +from .base import instance_str, state_str, state_class_str, attribute_str, \ + state_attribute_str, object_mapper, object_state, _none_set +from .base import class_mapper, _class_to_mapper +from .base import _InspectionAttr +from .path_registry import PathRegistry + +all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", + "expunge", "save-update", "refresh-expire", + "none")) + + +class CascadeOptions(frozenset): + """Keeps track of the options sent to relationship().cascade""" + + _add_w_all_cascades = all_cascades.difference([ + 'all', 'none', 'delete-orphan']) + _allowed_cascades = all_cascades + + def __new__(cls, arg): + values = set([ + c for c + in re.split('\s*,\s*', arg or "") + if c + ]) + + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" % + ", ".join([repr(x) for x in + sorted( + values.difference(cls._allowed_cascades) + )]) + ) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard('all') + + self = frozenset.__new__(CascadeOptions, values) + self.save_update = 'save-update' in values + self.delete = 'delete' in values + self.refresh_expire = 'refresh-expire' in values + self.merge = 'merge' in values + self.expunge = 'expunge' in values + self.delete_orphan = "delete-orphan" in values + + if self.delete_orphan and not self.delete: + util.warn("The 'delete-orphan' cascade " + "option requires 'delete'.") + return self + + def __repr__(self): + return "CascadeOptions(%r)" % ( + ",".join([x for x in sorted(self)]) + ) + + +def _validator_events(desc, key, validator, include_removes, include_backrefs): + """Runs a validation method on an attribute value to be set or appended.""" + + if not include_backrefs: + def detect_is_backref(state, initiator): + impl = state.manager[key].impl + return initiator.impl is not impl + + if include_removes: + def append(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value + + def remove(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + validator(state.obj(), key, value, True) + + else: + def append(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value + + event.listen(desc, 'append', append, raw=True, retval=True) + event.listen(desc, 'set', set_, raw=True, retval=True) + if include_removes: + event.listen(desc, "remove", remove, raw=True, retval=True) + + +def polymorphic_union(table_map, typecolname, + aliasname='p_union', cast_nulls=True): + """Create a ``UNION`` statement used by a polymorphic mapper. + + See :ref:`concrete_inheritance` for an example of how + this is used. + + :param table_map: mapping of polymorphic identities to + :class:`.Table` objects. + :param typecolname: string name of a "discriminator" column, which will be + derived from the query, producing the polymorphic identity for + each row. If ``None``, no polymorphic discriminator is generated. + :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()` + construct generated. + :param cast_nulls: if True, non-existent columns, which are represented + as labeled NULLs, will be passed into CAST. This is a legacy behavior + that is problematic on some backends such as Oracle - in which case it + can be set to False. + + """ + + colnames = util.OrderedSet() + colnamemaps = {} + types = {} + for key in table_map: + table = table_map[key] + + # mysql doesnt like selecting from a select; + # make it an alias of the select + if isinstance(table, sql.Select): + table = table.alias() + table_map[key] = table + + m = {} + for c in table.c: + colnames.add(c.key) + m[c.key] = c + types[c.key] = c.type + colnamemaps[table] = m + + def col(name, table): + try: + return colnamemaps[table][name] + except KeyError: + if cast_nulls: + return sql.cast(sql.null(), types[name]).label(name) + else: + return sql.type_coerce(sql.null(), types[name]).label(name) + + result = [] + for type, table in table_map.items(): + if typecolname is not None: + result.append( + sql.select([col(name, table) for name in colnames] + + [sql.literal_column(sql_util._quote_ddl_expr(type)). + label(typecolname)], + from_obj=[table])) + else: + result.append(sql.select([col(name, table) for name in colnames], + from_obj=[table])) + return sql.union_all(*result).alias(aliasname) + + +def identity_key(*args, **kwargs): + """Generate "identity key" tuples, as are used as keys in the + :attr:`.Session.identity_map` dictionary. + + This function has several call styles: + + * ``identity_key(class, ident)`` + + This form receives a mapped class and a primary key scalar or + tuple as an argument. + + E.g.:: + + >>> identity_key(MyClass, (1, 2)) + (, (1, 2)) + + :param class: mapped class (must be a positional argument) + :param ident: primary key, may be a scalar or tuple argument. + + + * ``identity_key(instance=instance)`` + + This form will produce the identity key for a given instance. The + instance need not be persistent, only that its primary key attributes + are populated (else the key will contain ``None`` for those missing + values). + + E.g.:: + + >>> instance = MyClass(1, 2) + >>> identity_key(instance=instance) + (, (1, 2)) + + In this form, the given instance is ultimately run though + :meth:`.Mapper.identity_key_from_instance`, which will have the + effect of performing a database check for the corresponding row + if the object is expired. + + :param instance: object instance (must be given as a keyword arg) + + * ``identity_key(class, row=row)`` + + This form is similar to the class/tuple form, except is passed a + database result row as a :class:`.RowProxy` object. + + E.g.:: + + >>> row = engine.execute("select * from table where a=1 and b=2").first() + >>> identity_key(MyClass, row=row) + (, (1, 2)) + + :param class: mapped class (must be a positional argument) + :param row: :class:`.RowProxy` row returned by a :class:`.ResultProxy` + (must be given as a keyword arg) + + """ + if args: + if len(args) == 1: + class_ = args[0] + try: + row = kwargs.pop("row") + except KeyError: + ident = kwargs.pop("ident") + elif len(args) == 2: + class_, ident = args + elif len(args) == 3: + class_, ident = args + else: + raise sa_exc.ArgumentError("expected up to three " + "positional arguments, got %s" % len(args)) + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs)) + mapper = class_mapper(class_) + if "ident" in locals(): + return mapper.identity_key_from_primary_key(util.to_list(ident)) + return mapper.identity_key_from_row(row) + instance = kwargs.pop("instance") + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys)) + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) + + +class ORMAdapter(sql_util.ColumnAdapter): + """Extends ColumnAdapter to accept ORM entities. + + The selectable is extracted from the given entity, + and the AliasedClass if any is referenced. + + """ + def __init__(self, entity, equivalents=None, adapt_required=False, + chain_to=None): + info = inspection.inspect(entity) + + self.mapper = info.mapper + selectable = info.selectable + is_aliased_class = info.is_aliased_class + if is_aliased_class: + self.aliased_class = entity + else: + self.aliased_class = None + sql_util.ColumnAdapter.__init__(self, selectable, + equivalents, chain_to, + adapt_required=adapt_required) + + def replace(self, elem): + entity = elem._annotations.get('parentmapper', None) + if not entity or entity.isa(self.mapper): + return sql_util.ColumnAdapter.replace(self, elem) + else: + return None + +class AliasedClass(object): + """Represents an "aliased" form of a mapped class for usage with Query. + + The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias` + construct, this object mimics the mapped class using a + __getattr__ scheme and maintains a reference to a + real :class:`~sqlalchemy.sql.expression.Alias` object. + + Usage is via the :func:`.orm.aliased` function, or alternatively + via the :func:`.orm.with_polymorphic` function. + + Usage example:: + + # find all pairs of users with the same name + user_alias = aliased(User) + session.query(User, user_alias).\\ + join((user_alias, User.id > user_alias.id)).\\ + filter(User.name==user_alias.name) + + The resulting object is an instance of :class:`.AliasedClass`. + This object implements an attribute scheme which produces the + same attribute and method interface as the original mapped + class, allowing :class:`.AliasedClass` to be compatible + with any attribute technique which works on the original class, + including hybrid attributes (see :ref:`hybrids_toplevel`). + + The :class:`.AliasedClass` can be inspected for its underlying + :class:`.Mapper`, aliased selectable, and other information + using :func:`.inspect`:: + + from sqlalchemy import inspect + my_alias = aliased(MyClass) + insp = inspect(my_alias) + + The resulting inspection object is an instance of :class:`.AliasedInsp`. + + See :func:`.aliased` and :func:`.with_polymorphic` for construction + argument descriptions. + + """ + def __init__(self, cls, alias=None, + name=None, + flat=False, + adapt_on_names=False, + # TODO: None for default here? + with_polymorphic_mappers=(), + with_polymorphic_discriminator=None, + base_alias=None, + use_mapper_path=False): + mapper = _class_to_mapper(cls) + if alias is None: + alias = mapper._with_polymorphic_selectable.alias( + name=name, flat=flat) + self._aliased_insp = AliasedInsp( + self, + mapper, + alias, + name, + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers, + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on, + base_alias, + use_mapper_path, + adapt_on_names + ) + + self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__ + + def __getattr__(self, key): + try: + _aliased_insp = self.__dict__['_aliased_insp'] + except KeyError: + raise AttributeError() + else: + for base in _aliased_insp._target.__mro__: + try: + attr = object.__getattribute__(base, key) + except AttributeError: + continue + else: + break + else: + raise AttributeError(key) + + if isinstance(attr, PropComparator): + ret = attr.adapt_to_entity(_aliased_insp) + setattr(self, key, ret) + return ret + elif hasattr(attr, 'func_code'): + is_method = getattr(_aliased_insp._target, key, None) + if is_method and is_method.__self__ is not None: + return util.types.MethodType(attr.__func__, self, self) + else: + return None + elif hasattr(attr, '__get__'): + ret = attr.__get__(None, self) + if isinstance(ret, PropComparator): + return ret.adapt_to_entity(_aliased_insp) + else: + return ret + else: + return attr + + def __repr__(self): + return '' % ( + id(self), self._aliased_insp._target.__name__) + + +class AliasedInsp(_InspectionAttr): + """Provide an inspection interface for an + :class:`.AliasedClass` object. + + The :class:`.AliasedInsp` object is returned + given an :class:`.AliasedClass` using the + :func:`.inspect` function:: + + from sqlalchemy import inspect + from sqlalchemy.orm import aliased + + my_alias = aliased(MyMappedClass) + insp = inspect(my_alias) + + Attributes on :class:`.AliasedInsp` + include: + + * ``entity`` - the :class:`.AliasedClass` represented. + * ``mapper`` - the :class:`.Mapper` mapping the underlying class. + * ``selectable`` - the :class:`.Alias` construct which ultimately + represents an aliased :class:`.Table` or :class:`.Select` + construct. + * ``name`` - the name of the alias. Also is used as the attribute + name when returned in a result tuple from :class:`.Query`. + * ``with_polymorphic_mappers`` - collection of :class:`.Mapper` objects + indicating all those mappers expressed in the select construct + for the :class:`.AliasedClass`. + * ``polymorphic_on`` - an alternate column or SQL expression which + will be used as the "discriminator" for a polymorphic load. + + .. seealso:: + + :ref:`inspection_toplevel` + + """ + + def __init__(self, entity, mapper, selectable, name, + with_polymorphic_mappers, polymorphic_on, + _base_alias, _use_mapper_path, adapt_on_names): + self.entity = entity + self.mapper = mapper + self.selectable = selectable + self.name = name + self.with_polymorphic_mappers = with_polymorphic_mappers + self.polymorphic_on = polymorphic_on + self._base_alias = _base_alias or self + self._use_mapper_path = _use_mapper_path + + self._adapter = sql_util.ClauseAdapter(selectable, + equivalents=mapper._equivalent_columns, + adapt_on_names=adapt_on_names) + + self._adapt_on_names = adapt_on_names + self._target = mapper.class_ + + for poly in self.with_polymorphic_mappers: + if poly is not mapper: + setattr(self.entity, poly.class_.__name__, + AliasedClass(poly.class_, selectable, base_alias=self, + adapt_on_names=adapt_on_names, + use_mapper_path=_use_mapper_path)) + + is_aliased_class = True + "always returns True" + + @property + def class_(self): + """Return the mapped class ultimately represented by this + :class:`.AliasedInsp`.""" + return self.mapper.class_ + + @util.memoized_property + def _path_registry(self): + if self._use_mapper_path: + return self.mapper._path_registry + else: + return PathRegistry.per_mapper(self) + + def __getstate__(self): + return { + 'entity': self.entity, + 'mapper': self.mapper, + 'alias': self.selectable, + 'name': self.name, + 'adapt_on_names': self._adapt_on_names, + 'with_polymorphic_mappers': + self.with_polymorphic_mappers, + 'with_polymorphic_discriminator': + self.polymorphic_on, + 'base_alias': self._base_alias, + 'use_mapper_path': self._use_mapper_path + } + + def __setstate__(self, state): + self.__init__( + state['entity'], + state['mapper'], + state['alias'], + state['name'], + state['with_polymorphic_mappers'], + state['with_polymorphic_discriminator'], + state['base_alias'], + state['use_mapper_path'], + state['adapt_on_names'] + ) + + def _adapt_element(self, elem): + return self._adapter.traverse(elem).\ + _annotate({ + 'parententity': self.entity, + 'parentmapper': self.mapper} + ) + + def _entity_for_mapper(self, mapper): + self_poly = self.with_polymorphic_mappers + if mapper in self_poly: + return getattr(self.entity, mapper.class_.__name__)._aliased_insp + elif mapper.isa(self.mapper): + return self + else: + assert False, "mapper %s doesn't correspond to %s" % (mapper, self) + + def __repr__(self): + return '' % ( + id(self), self.class_.__name__) + + +inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) +inspection._inspects(AliasedInsp)(lambda target: target) + + +def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + + The :func:`.aliased` function is used to create an ad-hoc mapping + of a mapped class to a new selectable. By default, a selectable + is generated from the normally mapped selectable (typically a + :class:`.Table`) using the :meth:`.FromClause.alias` method. + However, :func:`.aliased` can also be used to link the class to + a new :func:`.select` statement. Also, the :func:`.with_polymorphic` + function is a variant of :func:`.aliased` that is intended to specify + a so-called "polymorphic selectable", that corresponds to the union + of several joined-inheritance subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`.FromClause` constructs, such as a :class:`.Table` or + :func:`.select` construct. In those cases, the :meth:`.FromClause.alias` + method is called on the object and the new :class:`.Alias` object + returned. The returned :class:`.Alias` is not ORM-mapped in this case. + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`.FromClause` element. + + :param alias: Optional selectable unit to map the element to. This should + normally be a :class:`.Alias` object corresponding to the :class:`.Table` + to which the class is mapped, or to a :func:`.select` construct that + is compatible with the mapping. By default, a simple anonymous + alias of the mapped table is generated. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`.Query` object. + + :param flat: Boolean, will be passed through to the :meth:`.FromClause.alias` + call so that aliases of :class:`.Join` objects don't include an enclosing + SELECT. This can lead to more efficient queries in many circumstances. + A JOIN against a nested JOIN will be rewritten as a JOIN against an aliased + SELECT subquery on backends that don't support this syntax. + + .. versionadded:: 0.9.0 + + .. seealso:: :meth:`.Join.alias` + + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``fund.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + .. versionadded:: 0.7.3 + + + """ + if isinstance(element, expression.FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + return element.alias(name, flat=flat) + else: + return AliasedClass(element, alias=alias, flat=flat, + name=name, adapt_on_names=adapt_on_names) + + +def with_polymorphic(base, classes, selectable=False, + flat=False, + polymorphic_on=None, aliased=False, + innerjoin=False, _use_mapper_path=False): + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + .. versionadded:: 0.8 + :func:`.orm.with_polymorphic` is in addition to the existing + :class:`.Query` method :meth:`.Query.with_polymorphic`, + which has the same purpose but is not as flexible in its usage. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + See the examples at :ref:`with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be wrapped in an + alias, that is ``(SELECT * FROM ) AS anon_1``. + This can be important when using the with_polymorphic() + to create the target of a JOIN on a backend that does not + support parenthesized joins, such as SQLite and older + versions of MySQL. + + :param flat: Boolean, will be passed through to the :meth:`.FromClause.alias` + call so that aliases of :class:`.Join` objects don't include an enclosing + SELECT. This can lead to more efficient queries in many circumstances. + A JOIN against a nested JOIN will be rewritten as a JOIN against an aliased + SELECT subquery on backends that don't support this syntax. + + Setting ``flat`` to ``True`` implies the ``aliased`` flag is + also ``True``. + + .. versionadded:: 0.9.0 + + .. seealso:: :meth:`.Join.alias` + + :param selectable: a table or select() statement that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + """ + primary_mapper = _class_to_mapper(base) + mappers, selectable = primary_mapper.\ + _with_polymorphic_args(classes, selectable, + innerjoin=innerjoin) + if aliased or flat: + selectable = selectable.alias(flat=flat) + return AliasedClass(base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path) + + +def _orm_annotate(element, exclude=None): + """Deep copy the given ClauseElement, annotating each element with the + "_orm_adapt" flag. + + Elements within the exclude collection will be cloned but not annotated. + + """ + return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude) + + +def _orm_deannotate(element): + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`.orm.foreign` and :func:`.orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate(element, + values=("_orm_adapt", "parententity") + ) + + +def _orm_full_deannotate(element): + return sql_util._deep_deannotate(element) + + +class _ORMJoin(expression.Join): + """Extend Join to support ORM constructs as input.""" + + __visit_name__ = expression.Join.__visit_name__ + + def __init__(self, left, right, onclause=None, isouter=False): + + left_info = inspection.inspect(left) + left_orm_info = getattr(left, '_joined_from_info', left_info) + + right_info = inspection.inspect(right) + adapt_to = right_info.selectable + + self._joined_from_info = right_info + + if isinstance(onclause, util.string_types): + onclause = getattr(left_orm_info.entity, onclause) + + if isinstance(onclause, attributes.QueryableAttribute): + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + elif isinstance(onclause, MapperProperty): + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + + if prop: + if sql_util.clause_is_present(on_selectable, left_info.selectable): + adapt_from = on_selectable + else: + adapt_from = left_info.selectable + + pj, sj, source, dest, \ + secondary, target_adapter = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + dest_polymorphic=True, + of_type=right_info.mapper) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + self._target_adapter = target_adapter + + expression.Join.__init__(self, left, right, onclause, isouter) + + def join(self, right, onclause=None, isouter=False, join_to_left=None): + return _ORMJoin(self, right, onclause, isouter) + + def outerjoin(self, right, onclause=None, join_to_left=None): + return _ORMJoin(self, right, onclause, True) + + +def join(left, right, onclause=None, isouter=False, join_to_left=None): + """Produce an inner join between left and right clauses. + + :func:`.orm.join` is an extension to the core join interface + provided by :func:`.sql.expression.join()`, where the + left and right selectables may be not only core selectable + objects such as :class:`.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression, or an attribute or string name + referencing a configured :func:`.relationship`. + + :func:`.orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`.Query.join` method, which features a + significant amount of automation beyond :func:`.orm.join` + by itself. Explicit usage of :func:`.orm.join` + with :class:`.Query` involves usage of the + :meth:`.Query.select_from` method, as in:: + + from sqlalchemy.orm import join + session.query(User).\\ + select_from(join(User, Address, User.addresses)).\\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + session.query(User).\\ + join(User.addresses).\\ + filter(Address.email_address=='foo@bar.com') + + See :meth:`.Query.join` for information on modern usage + of ORM level joins. + + .. versionchanged:: 0.8.1 - the ``join_to_left`` parameter + is no longer used, and is deprecated. + + """ + return _ORMJoin(left, right, onclause, isouter) + + +def outerjoin(left, right, onclause=None, join_to_left=None): + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`.orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True) + + +def with_parent(instance, prop): + """Create filtering criterion that relates this query's primary entity + to the given related instance, using established :func:`.relationship()` + configuration. + + The SQL rendered is the same as that rendered when a lazy loader + would fire off from the given parent on that attribute, meaning + that the appropriate state is taken from the parent object in + Python without the need to render joins to the parent table + in the rendered statement. + + .. versionchanged:: 0.6.4 + This method accepts parent instances in all + persistence states, including transient, persistent, and detached. + Only the requisite primary key/foreign key attributes need to + be populated. Previous versions didn't work with transient + instances. + + :param instance: + An instance which has some :func:`.relationship`. + + :param property: + String property name, or class-bound attribute, which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. + + """ + if isinstance(prop, util.string_types): + mapper = object_mapper(instance) + prop = getattr(mapper.class_, prop).property + elif isinstance(prop, attributes.QueryableAttribute): + prop = prop.property + + return prop.compare(operators.eq, + instance, + value_is_parent=True) + + + +def has_identity(object): + """Return True if the given object has a database + identity. + + This typically corresponds to the object being + in either the persistent or detached state. + + .. seealso:: + + :func:`.was_deleted` + + """ + state = attributes.instance_state(object) + return state.has_identity + +def was_deleted(object): + """Return True if the given object was deleted + within a session flush. + + .. versionadded:: 0.8.0 + + """ + + state = attributes.instance_state(object) + return state.deleted + + + + +def randomize_unitofwork(): + """Use random-ordering sets within the unit of work in order + to detect unit of work sorting issues. + + This is a utility function that can be used to help reproduce + inconsistent unit of work sorting issues. For example, + if two kinds of objects A and B are being inserted, and + B has a foreign key reference to A - the A must be inserted first. + However, if there is no relationship between A and B, the unit of work + won't know to perform this sorting, and an operation may or may not + fail, depending on how the ordering works out. Since Python sets + and dictionaries have non-deterministic ordering, such an issue may + occur on some runs and not on others, and in practice it tends to + have a great dependence on the state of the interpreter. This leads + to so-called "heisenbugs" where changing entirely irrelevant aspects + of the test program still cause the failure behavior to change. + + By calling ``randomize_unitofwork()`` when a script first runs, the + ordering of a key series of sets within the unit of work implementation + are randomized, so that the script can be minimized down to the fundamental + mapping and operation that's failing, while still reproducing the issue + on at least some runs. + + This utility is also available when running the test suite via the + ``--reversetop`` flag. + + .. versionadded:: 0.8.1 created a standalone version of the + ``--reversetop`` feature. + + """ + from sqlalchemy.orm import unitofwork, session, mapper, dependency + from sqlalchemy.util import topological + from sqlalchemy.testing.util import RandomSet + topological.set = unitofwork.set = session.set = mapper.set = \ + dependency.set = RandomSet + diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py new file mode 100644 index 00000000..79944354 --- /dev/null +++ b/lib/sqlalchemy/pool.py @@ -0,0 +1,1250 @@ +# sqlalchemy/pool.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +"""Connection pooling for DB-API connections. + +Provides a number of connection pool implementations for a variety of +usage scenarios and thread behavior requirements imposed by the +application, DB-API or database itself. + +Also provides a DB-API 2.0 connection proxying mechanism allowing +regular DB-API connect() methods to be transparently managed by a +SQLAlchemy connection pool. +""" + +import time +import traceback +import weakref + +from . import exc, log, event, interfaces, util +from .util import queue as sqla_queue +from .util import threading, memoized_property, \ + chop_traceback + +from collections import deque +proxies = {} + + +def manage(module, **params): + """Return a proxy for a DB-API module that automatically + pools connections. + + Given a DB-API 2.0 module and pool management parameters, returns + a proxy for the module that will automatically pool connections, + creating new connection pools for each distinct set of connection + arguments sent to the decorated module's connect() function. + + :param module: a DB-API 2.0 database module + + :param poolclass: the class used by the pool module to provide + pooling. Defaults to :class:`.QueuePool`. + + :param \*\*params: will be passed through to *poolclass* + + """ + try: + return proxies[module] + except KeyError: + return proxies.setdefault(module, _DBProxy(module, **params)) + + +def clear_managers(): + """Remove all current DB-API 2.0 managers. + + All pools and connections are disposed. + """ + + for manager in proxies.values(): + manager.close() + proxies.clear() + +reset_rollback = util.symbol('reset_rollback') +reset_commit = util.symbol('reset_commit') +reset_none = util.symbol('reset_none') + +class _ConnDialect(object): + """partial implementation of :class:`.Dialect` + which provides DBAPI connection methods. + + When a :class:`.Pool` is combined with an :class:`.Engine`, + the :class:`.Engine` replaces this with its own + :class:`.Dialect`. + + """ + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def do_close(self, dbapi_connection): + dbapi_connection.close() + +class Pool(log.Identified): + """Abstract base class for connection pools.""" + + _dialect = _ConnDialect() + + def __init__(self, + creator, recycle=-1, echo=None, + use_threadlocal=False, + logging_name=None, + reset_on_return=True, + listeners=None, + events=None, + _dispatch=None, + _dialect=None): + """ + Construct a Pool. + + :param creator: a callable function that returns a DB-API + connection object. The function will be called with + parameters. + + :param recycle: If set to non -1, number of seconds between + connection recycling, which means upon checkout, if this + timeout is surpassed the connection will be closed and + replaced with a newly opened connection. Defaults to -1. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + :param echo: If True, connections being pulled and retrieved + from the pool will be logged to the standard output, as well + as pool sizing information. Echoing can also be achieved by + enabling logging for the "sqlalchemy.pool" + namespace. Defaults to False. + + :param use_threadlocal: If set to True, repeated calls to + :meth:`connect` within the same application thread will be + guaranteed to return the same connection object, if one has + already been retrieved from the pool and has not been + returned yet. Offers a slight performance advantage at the + cost of individual transactions by default. The + :meth:`.Pool.unique_connection` method is provided to return + a consistenty unique connection to bypass this behavior + when the flag is set. + + .. warning:: The :paramref:`.Pool.use_threadlocal` flag + **does not affect the behavior** of :meth:`.Engine.connect`. + :meth:`.Engine.connect` makes use of the :meth:`.Pool.unique_connection` + method which **does not use thread local context**. + To produce a :class:`.Connection` which refers to the + :meth:`.Pool.connect` method, use + :meth:`.Engine.contextual_connect`. + + Note that other SQLAlchemy connectivity systems such as + :meth:`.Engine.execute` as well as the orm + :class:`.Session` make use of + :meth:`.Engine.contextual_connect` internally, so these functions + are compatible with the :paramref:`.Pool.use_threadlocal` setting. + + .. seealso:: + + :ref:`threadlocal_strategy` - contains detail on the + "threadlocal" engine strategy, which provides a more comprehensive + approach to "threadlocal" connectivity for the specific + use case of using :class:`.Engine` and :class:`.Connection` objects + directly. + + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool. + reset_on_return can have any of these values: + + * ``"rollback"`` - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * ``True`` - same as 'rollback', this is here for + backwards compatibility. + * ``"commit"`` - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * ``None`` - don't do anything on the connection. + This setting should only be made on a database + that has no transaction support at all, + namely MySQL MyISAM. By not doing anything, + performance can be improved. This + setting should **never be selected** for a + database that supports transactions, + as it will lead to deadlocks and stale + state. + * ``False`` - same as None, this is here for + backwards compatibility. + + .. versionchanged:: 0.7.6 + :paramref:`.Pool.reset_on_return` accepts ``"rollback"`` + and ``"commit"`` arguments. + + :param events: a list of 2-tuples, each of the form + ``(callable, target)`` which will be passed to :func:`.event.listen` + upon construction. Provided here so that event listeners + can be assigned via :func:`.create_engine` before dialect-level + listeners are applied. + + :param listeners: Deprecated. A list of + :class:`~sqlalchemy.interfaces.PoolListener`-like objects or + dictionaries of callables that receive events when DB-API + connections are created, checked out and checked in to the + pool. This has been superseded by + :func:`~sqlalchemy.event.listen`. + + """ + if logging_name: + self.logging_name = self._orig_logging_name = logging_name + else: + self._orig_logging_name = None + + log.instance_logger(self, echoflag=echo) + self._threadconns = threading.local() + self._creator = creator + self._recycle = recycle + self._invalidate_time = 0 + self._use_threadlocal = use_threadlocal + if reset_on_return in ('rollback', True, reset_rollback): + self._reset_on_return = reset_rollback + elif reset_on_return in (None, False, reset_none): + self._reset_on_return = reset_none + elif reset_on_return in ('commit', reset_commit): + self._reset_on_return = reset_commit + else: + raise exc.ArgumentError( + "Invalid value for 'reset_on_return': %r" + % reset_on_return) + + self.echo = echo + if _dispatch: + self.dispatch._update(_dispatch, only_propagate=False) + if _dialect: + self._dialect = _dialect + if events: + for fn, target in events: + event.listen(self, target, fn) + if listeners: + util.warn_deprecated( + "The 'listeners' argument to Pool (and " + "create_engine()) is deprecated. Use event.listen().") + for l in listeners: + self.add_listener(l) + + def _close_connection(self, connection): + self.logger.debug("Closing connection %r", connection) + try: + self._dialect.do_close(connection) + except (SystemExit, KeyboardInterrupt): + raise + except: + self.logger.error("Exception closing connection %r", + connection, exc_info=True) + + @util.deprecated( + 2.7, "Pool.add_listener is deprecated. Use event.listen()") + def add_listener(self, listener): + """Add a :class:`.PoolListener`-like object to this pool. + + ``listener`` may be an object that implements some or all of + PoolListener, or a dictionary of callables containing implementations + of some or all of the named methods in PoolListener. + + """ + interfaces.PoolListener._adapt_listener(self, listener) + + def unique_connection(self): + """Produce a DBAPI connection that is not referenced by any + thread-local context. + + This method is equivalent to :meth:`.Pool.connect` when the + :paramref:`.Pool.use_threadlocal` flag is not set to True. + When :paramref:`.Pool.use_threadlocal` is True, the :meth:`.Pool.unique_connection` + method provides a means of bypassing the threadlocal context. + + """ + return _ConnectionFairy._checkout(self) + + def _create_connection(self): + """Called by subclasses to create a new ConnectionRecord.""" + + return _ConnectionRecord(self) + + def _invalidate(self, connection, exception=None): + """Mark all connections established within the generation + of the given connection as invalidated. + + If this pool's last invalidate time is before when the given + connection was created, update the timestamp til now. Otherwise, + no action is performed. + + Connections with a start time prior to this pool's invalidation + time will be recycled upon next checkout. + """ + rec = getattr(connection, "_connection_record", None) + if not rec or self._invalidate_time < rec.starttime: + self._invalidate_time = time.time() + if getattr(connection, 'is_valid', False): + connection.invalidate(exception) + + + def recreate(self): + """Return a new :class:`.Pool`, of the same class as this one + and configured with identical creation arguments. + + This method is used in conjunection with :meth:`dispose` + to close out an entire :class:`.Pool` and create a new one in + its place. + + """ + + raise NotImplementedError() + + def dispose(self): + """Dispose of this pool. + + This method leaves the possibility of checked-out connections + remaining open, as it only affects connections that are + idle in the pool. + + See also the :meth:`Pool.recreate` method. + + """ + + raise NotImplementedError() + + def connect(self): + """Return a DBAPI connection from the pool. + + The connection is instrumented such that when its + ``close()`` method is called, the connection will be returned to + the pool. + + """ + if not self._use_threadlocal: + return _ConnectionFairy._checkout(self) + + try: + rec = self._threadconns.current() + except AttributeError: + pass + else: + if rec is not None: + return rec._checkout_existing() + + return _ConnectionFairy._checkout(self, self._threadconns) + + def _return_conn(self, record): + """Given a _ConnectionRecord, return it to the :class:`.Pool`. + + This method is called when an instrumented DBAPI connection + has its ``close()`` method called. + + """ + if self._use_threadlocal: + try: + del self._threadconns.current + except AttributeError: + pass + self._do_return_conn(record) + + def _do_get(self): + """Implementation for :meth:`get`, supplied by subclasses.""" + + raise NotImplementedError() + + def _do_return_conn(self, conn): + """Implementation for :meth:`return_conn`, supplied by subclasses.""" + + raise NotImplementedError() + + def status(self): + raise NotImplementedError() + + +class _ConnectionRecord(object): + """Internal object which maintains an individual DBAPI connection + referenced by a :class:`.Pool`. + + The :class:`._ConnectionRecord` object always exists for any particular + DBAPI connection whether or not that DBAPI connection has been + "checked out". This is in contrast to the :class:`._ConnectionFairy` + which is only a public facade to the DBAPI connection while it is checked + out. + + A :class:`._ConnectionRecord` may exist for a span longer than that + of a single DBAPI connection. For example, if the + :meth:`._ConnectionRecord.invalidate` + method is called, the DBAPI connection associated with this + :class:`._ConnectionRecord` + will be discarded, but the :class:`._ConnectionRecord` may be used again, + in which case a new DBAPI connection is produced when the :class:`.Pool` + next uses this record. + + The :class:`._ConnectionRecord` is delivered along with connection + pool events, including :meth:`.PoolEvents.connect` and + :meth:`.PoolEvents.checkout`, however :class:`._ConnectionRecord` still + remains an internal object whose API and internals may change. + + .. seealso:: + + :class:`._ConnectionFairy` + + """ + + def __init__(self, pool): + self.__pool = pool + self.connection = self.__connect() + self.finalize_callback = deque() + + pool.dispatch.first_connect.\ + for_modify(pool.dispatch).\ + exec_once(self.connection, self) + pool.dispatch.connect(self.connection, self) + + connection = None + """A reference to the actual DBAPI connection being tracked. + + May be ``None`` if this :class:`._ConnectionRecord` has been marked + as invalidated; a new DBAPI connection may replace it if the owning + pool calls upon this :class:`._ConnectionRecord` to reconnect. + + """ + + @util.memoized_property + def info(self): + """The ``.info`` dictionary associated with the DBAPI connection. + + This dictionary is shared among the :attr:`._ConnectionFairy.info` + and :attr:`.Connection.info` accessors. + + """ + return {} + + @classmethod + def checkout(cls, pool): + rec = pool._do_get() + try: + dbapi_connection = rec.get_connection() + except: + rec.checkin() + raise + fairy = _ConnectionFairy(dbapi_connection, rec) + rec.fairy_ref = weakref.ref( + fairy, + lambda ref: _finalize_fairy and \ + _finalize_fairy( + dbapi_connection, + rec, pool, ref, pool._echo) + ) + _refs.add(rec) + if pool._echo: + pool.logger.debug("Connection %r checked out from pool", + dbapi_connection) + return fairy + + def checkin(self): + self.fairy_ref = None + connection = self.connection + pool = self.__pool + while self.finalize_callback: + finalizer = self.finalize_callback.pop() + finalizer(connection) + if pool.dispatch.checkin: + pool.dispatch.checkin(connection, self) + pool._return_conn(self) + + + def close(self): + if self.connection is not None: + self.__close() + + def invalidate(self, e=None): + """Invalidate the DBAPI connection held by this :class:`._ConnectionRecord`. + + This method is called for all connection invalidations, including + when the :meth:`._ConnectionFairy.invalidate` or :meth:`.Connection.invalidate` + methods are called, as well as when any so-called "automatic invalidation" + condition occurs. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + self.__pool.dispatch.invalidate(self.connection, self, e) + if e is not None: + self.__pool.logger.info( + "Invalidate connection %r (reason: %s:%s)", + self.connection, e.__class__.__name__, e) + else: + self.__pool.logger.info( + "Invalidate connection %r", self.connection) + self.__close() + self.connection = None + + def get_connection(self): + recycle = False + if self.connection is None: + self.connection = self.__connect() + self.info.clear() + if self.__pool.dispatch.connect: + self.__pool.dispatch.connect(self.connection, self) + elif self.__pool._recycle > -1 and \ + time.time() - self.starttime > self.__pool._recycle: + self.__pool.logger.info( + "Connection %r exceeded timeout; recycling", + self.connection) + recycle = True + elif self.__pool._invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to pool invalidation; recycling", + self.connection + ) + recycle = True + + if recycle: + self.__close() + self.connection = self.__connect() + self.info.clear() + if self.__pool.dispatch.connect: + self.__pool.dispatch.connect(self.connection, self) + return self.connection + + def __close(self): + self.__pool._close_connection(self.connection) + + def __connect(self): + try: + self.starttime = time.time() + connection = self.__pool._creator() + self.__pool.logger.debug("Created new connection %r", connection) + return connection + except Exception as e: + self.__pool.logger.debug("Error on connect(): %s", e) + raise + + +def _finalize_fairy(connection, connection_record, pool, ref, echo, fairy=None): + """Cleanup for a :class:`._ConnectionFairy` whether or not it's already + been garbage collected. + + """ + _refs.discard(connection_record) + + if ref is not None and \ + connection_record.fairy_ref is not ref: + return + + if connection is not None: + if connection_record and echo: + pool.logger.debug("Connection %r being returned to pool", + connection) + + try: + fairy = fairy or _ConnectionFairy(connection, connection_record) + assert fairy.connection is connection + fairy._reset(pool, echo) + + # Immediately close detached instances + if not connection_record: + pool._close_connection(connection) + except Exception as e: + if connection_record: + connection_record.invalidate(e=e) + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise + + if connection_record: + connection_record.checkin() + + +_refs = set() + + +class _ConnectionFairy(object): + """Proxies a DBAPI connection and provides return-on-dereference + support. + + This is an internal object used by the :class:`.Pool` implementation + to provide context management to a DBAPI connection delivered by + that :class:`.Pool`. + + The name "fairy" is inspired by the fact that the :class:`._ConnectionFairy` + object's lifespan is transitory, as it lasts only for the length of a + specific DBAPI connection being checked out from the pool, and additionally + that as a transparent proxy, it is mostly invisible. + + .. seealso:: + + :class:`._ConnectionRecord` + + """ + + def __init__(self, dbapi_connection, connection_record): + self.connection = dbapi_connection + self._connection_record = connection_record + + connection = None + """A reference to the actual DBAPI connection being tracked.""" + + _connection_record = None + """A reference to the :class:`._ConnectionRecord` object associated + with the DBAPI connection. + + This is currently an internal accessor which is subject to change. + + """ + + _reset_agent = None + """Refer to an object with a ``.commit()`` and ``.rollback()`` method; + if non-None, the "reset-on-return" feature will call upon this object + rather than directly against the dialect-level do_rollback() and do_commit() + methods. + + In practice, a :class:`.Connection` assigns a :class:`.Transaction` object + to this variable when one is in scope so that the :class:`.Transaction` + takes the job of committing or rolling back on return if + :meth:`.Connection.close` is called while the :class:`.Transaction` + still exists. + + This is essentially an "event handler" of sorts but is simplified as an + instance variable both for performance/simplicity as well as that there + can only be one "reset agent" at a time. + """ + + @classmethod + def _checkout(cls, pool, threadconns=None, fairy=None): + if not fairy: + fairy = _ConnectionRecord.checkout(pool) + + fairy._pool = pool + fairy._counter = 0 + fairy._echo = pool._should_log_debug() + + if threadconns is not None: + threadconns.current = weakref.ref(fairy) + + if fairy.connection is None: + raise exc.InvalidRequestError("This connection is closed") + fairy._counter += 1 + + if not pool.dispatch.checkout or fairy._counter != 1: + return fairy + + # Pool listeners can trigger a reconnection on checkout + attempts = 2 + while attempts > 0: + try: + pool.dispatch.checkout(fairy.connection, + fairy._connection_record, + fairy) + return fairy + except exc.DisconnectionError as e: + pool.logger.info( + "Disconnection detected on checkout: %s", e) + fairy._connection_record.invalidate(e) + fairy.connection = fairy._connection_record.get_connection() + attempts -= 1 + + pool.logger.info("Reconnection attempts exhausted on checkout") + fairy.invalidate() + raise exc.InvalidRequestError("This connection is closed") + + def _checkout_existing(self): + return _ConnectionFairy._checkout(self._pool, fairy=self) + + def _checkin(self): + _finalize_fairy(self.connection, self._connection_record, + self._pool, None, self._echo, fairy=self) + self.connection = None + self._connection_record = None + + _close = _checkin + + def _reset(self, pool, echo): + if pool.dispatch.reset: + pool.dispatch.reset(self, self._connection_record) + if pool._reset_on_return is reset_rollback: + if echo: + pool.logger.debug("Connection %s rollback-on-return%s", + self.connection, + ", via agent" + if self._reset_agent else "") + if self._reset_agent: + self._reset_agent.rollback() + else: + pool._dialect.do_rollback(self) + elif pool._reset_on_return is reset_commit: + if echo: + pool.logger.debug("Connection %s commit-on-return%s", + self.connection, + ", via agent" + if self._reset_agent else "") + if self._reset_agent: + self._reset_agent.commit() + else: + pool._dialect.do_commit(self) + + @property + def _logger(self): + return self._pool.logger + + @property + def is_valid(self): + """Return True if this :class:`._ConnectionFairy` still refers + to an active DBAPI connection.""" + + return self.connection is not None + + @util.memoized_property + def info(self): + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ConnectionFairy`, allowing user-defined + data to be associated with the connection. + + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`._ConnectionFairy`. It is shared + with the :attr:`._ConnectionRecord.info` and :attr:`.Connection.info` + accessors. + + """ + return self._connection_record.info + + def invalidate(self, e=None): + """Mark this connection as invalidated. + + This method can be called directly, and is also called as a result + of the :meth:`.Connection.invalidate` method. When invoked, + the DBAPI connection is immediately closed and discarded from + further use by the pool. The invalidation mechanism proceeds + via the :meth:`._ConnectionRecord.invalidate` internal method. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + if self.connection is None: + util.warn("Can't invalidate an already-closed connection.") + return + if self._connection_record: + self._connection_record.invalidate(e=e) + self.connection = None + self._checkin() + + def cursor(self, *args, **kwargs): + """Return a new DBAPI cursor for the underlying connection. + + This method is a proxy for the ``connection.cursor()`` DBAPI + method. + + """ + return self.connection.cursor(*args, **kwargs) + + def __getattr__(self, key): + return getattr(self.connection, key) + + + def detach(self): + """Separate this connection from its Pool. + + This means that the connection will no longer be returned to the + pool when closed, and will instead be literally closed. The + containing ConnectionRecord is separated from the DB-API connection, + and will create a new connection when next used. + + Note that any overall connection limiting constraints imposed by a + Pool implementation may be violated after a detach, as the detached + connection is removed from the pool's knowledge and control. + """ + + if self._connection_record is not None: + _refs.remove(self._connection_record) + self._connection_record.fairy_ref = None + self._connection_record.connection = None + # TODO: should this be _return_conn? + self._pool._do_return_conn(self._connection_record) + self.info = self.info.copy() + self._connection_record = None + + def close(self): + self._counter -= 1 + if self._counter == 0: + self._checkin() + + + +class SingletonThreadPool(Pool): + """A Pool that maintains one connection per thread. + + Maintains one connection per each thread, never moving a connection to a + thread other than the one which it was created in. + + Options are the same as those of :class:`.Pool`, as well as: + + :param pool_size: The number of threads in which to maintain connections + at once. Defaults to five. + + :class:`.SingletonThreadPool` is used by the SQLite dialect + automatically when a memory-based database is used. + See :ref:`sqlite_toplevel`. + + """ + + def __init__(self, creator, pool_size=5, **kw): + kw['use_threadlocal'] = True + Pool.__init__(self, creator, **kw) + self._conn = threading.local() + self._all_conns = set() + self.size = pool_size + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + _dialect=self._dialect) + + def dispose(self): + """Dispose of this pool.""" + + for conn in self._all_conns: + try: + conn.close() + except (SystemExit, KeyboardInterrupt): + raise + except: + # pysqlite won't even let you close a conn from a thread + # that didn't create it + pass + + self._all_conns.clear() + + def _cleanup(self): + while len(self._all_conns) >= self.size: + c = self._all_conns.pop() + c.close() + + def status(self): + return "SingletonThreadPool id:%d size: %d" % \ + (id(self), len(self._all_conns)) + + def _do_return_conn(self, conn): + pass + + def _do_get(self): + try: + c = self._conn.current() + if c: + return c + except AttributeError: + pass + c = self._create_connection() + self._conn.current = weakref.ref(c) + if len(self._all_conns) >= self.size: + self._cleanup() + self._all_conns.add(c) + return c + + +class QueuePool(Pool): + """A :class:`.Pool` that imposes a limit on the number of open connections. + + :class:`.QueuePool` is the default pooling implementation used for + all :class:`.Engine` objects, unless the SQLite dialect is in use. + + """ + + def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30, + **kw): + """ + Construct a QueuePool. + + :param creator: a callable function that returns a DB-API + connection object, same as that of :paramref:`.Pool.creator`. + + :param pool_size: The size of the pool to be maintained, + defaults to 5. This is the largest number of connections that + will be kept persistently in the pool. Note that the pool + begins with no connections; once this number of connections + is requested, that number of connections will remain. + ``pool_size`` can be set to 0 to indicate no size limit; to + disable pooling, use a :class:`~sqlalchemy.pool.NullPool` + instead. + + :param max_overflow: The maximum overflow size of the + pool. When the number of checked-out connections reaches the + size set in pool_size, additional connections will be + returned up to this limit. When those additional connections + are returned to the pool, they are disconnected and + discarded. It follows then that the total number of + simultaneous connections the pool will allow is pool_size + + `max_overflow`, and the total number of "sleeping" + connections the pool will allow is pool_size. `max_overflow` + can be set to -1 to indicate no overflow limit; no limit + will be placed on the total number of concurrent + connections. Defaults to 10. + + :param timeout: The number of seconds to wait before giving up + on returning a connection. Defaults to 30. + + :param \**kw: Other keyword arguments including :paramref:`.Pool.recycle`, + :paramref:`.Pool.echo`, :paramref:`.Pool.reset_on_return` and others + are passed to the :class:`.Pool` constructor. + + """ + Pool.__init__(self, creator, **kw) + self._pool = sqla_queue.Queue(pool_size) + self._overflow = 0 - pool_size + self._max_overflow = max_overflow + self._timeout = timeout + self._overflow_lock = threading.Lock() + + def _do_return_conn(self, conn): + try: + self._pool.put(conn, False) + except sqla_queue.Full: + try: + conn.close() + finally: + self._dec_overflow() + + def _do_get(self): + use_overflow = self._max_overflow > -1 + + try: + wait = use_overflow and self._overflow >= self._max_overflow + return self._pool.get(wait, self._timeout) + except sqla_queue.Empty: + if use_overflow and self._overflow >= self._max_overflow: + if not wait: + return self._do_get() + else: + raise exc.TimeoutError( + "QueuePool limit of size %d overflow %d reached, " + "connection timed out, timeout %d" % + (self.size(), self.overflow(), self._timeout)) + + if self._inc_overflow(): + try: + return self._create_connection() + except: + self._dec_overflow() + raise + else: + return self._do_get() + + def _inc_overflow(self): + if self._max_overflow == -1: + self._overflow += 1 + return True + with self._overflow_lock: + if self._overflow < self._max_overflow: + self._overflow += 1 + return True + else: + return False + + def _dec_overflow(self): + if self._max_overflow == -1: + self._overflow -= 1 + return True + with self._overflow_lock: + self._overflow -= 1 + return True + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(self._creator, pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, + timeout=self._timeout, + recycle=self._recycle, echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + _dialect=self._dialect) + + def dispose(self): + while True: + try: + conn = self._pool.get(False) + conn.close() + except sqla_queue.Empty: + break + + self._overflow = 0 - self.size() + self.logger.info("Pool disposed. %s", self.status()) + + def status(self): + return "Pool size: %d Connections in pool: %d "\ + "Current Overflow: %d Current Checked out "\ + "connections: %d" % (self.size(), + self.checkedin(), + self.overflow(), + self.checkedout()) + + def size(self): + return self._pool.maxsize + + def checkedin(self): + return self._pool.qsize() + + def overflow(self): + return self._overflow + + def checkedout(self): + return self._pool.maxsize - self._pool.qsize() + self._overflow + + +class NullPool(Pool): + """A Pool which does not pool connections. + + Instead it literally opens and closes the underlying DB-API connection + per each connection open/close. + + Reconnect-related functions such as ``recycle`` and connection + invalidation are not supported by this Pool implementation, since + no connections are held persistently. + + .. versionchanged:: 0.7 + :class:`.NullPool` is used by the SQlite dialect automatically + when a file-based database is used. See :ref:`sqlite_toplevel`. + + """ + + def status(self): + return "NullPool" + + def _do_return_conn(self, conn): + conn.close() + + def _do_get(self): + return self._create_connection() + + def recreate(self): + self.logger.info("Pool recreating") + + return self.__class__(self._creator, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + _dialect=self._dialect) + + def dispose(self): + pass + + +class StaticPool(Pool): + """A Pool of exactly one connection, used for all requests. + + Reconnect-related functions such as ``recycle`` and connection + invalidation (which is also used to support auto-reconnect) are not + currently supported by this Pool implementation but may be implemented + in a future release. + + """ + + @memoized_property + def _conn(self): + return self._creator() + + @memoized_property + def connection(self): + return _ConnectionRecord(self) + + def status(self): + return "StaticPool" + + def dispose(self): + if '_conn' in self.__dict__: + self._conn.close() + self._conn = None + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(creator=self._creator, + recycle=self._recycle, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + _dialect=self._dialect) + + def _create_connection(self): + return self._conn + + def _do_return_conn(self, conn): + pass + + def _do_get(self): + return self.connection + + +class AssertionPool(Pool): + """A :class:`.Pool` that allows at most one checked out connection at + any given time. + + This will raise an exception if more than one connection is checked out + at a time. Useful for debugging code that is using more connections + than desired. + + .. versionchanged:: 0.7 + :class:`.AssertionPool` also logs a traceback of where + the original connection was checked out, and reports + this in the assertion error raised. + + """ + def __init__(self, *args, **kw): + self._conn = None + self._checked_out = False + self._store_traceback = kw.pop('store_traceback', True) + self._checkout_traceback = None + Pool.__init__(self, *args, **kw) + + def status(self): + return "AssertionPool" + + def _do_return_conn(self, conn): + if not self._checked_out: + raise AssertionError("connection is not checked out") + self._checked_out = False + assert conn is self._conn + + def dispose(self): + self._checked_out = False + if self._conn: + self._conn.close() + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(self._creator, echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + _dialect=self._dialect) + + def _do_get(self): + if self._checked_out: + if self._checkout_traceback: + suffix = ' at:\n%s' % ''.join( + chop_traceback(self._checkout_traceback)) + else: + suffix = '' + raise AssertionError("connection is already checked out" + suffix) + + if not self._conn: + self._conn = self._create_connection() + + self._checked_out = True + if self._store_traceback: + self._checkout_traceback = traceback.format_stack() + return self._conn + + +class _DBProxy(object): + """Layers connection pooling behavior on top of a standard DB-API module. + + Proxies a DB-API 2.0 connect() call to a connection pool keyed to the + specific connect parameters. Other functions and attributes are delegated + to the underlying DB-API module. + """ + + def __init__(self, module, poolclass=QueuePool, **kw): + """Initializes a new proxy. + + module + a DB-API 2.0 module + + poolclass + a Pool class, defaulting to QueuePool + + Other parameters are sent to the Pool object's constructor. + + """ + + self.module = module + self.kw = kw + self.poolclass = poolclass + self.pools = {} + self._create_pool_mutex = threading.Lock() + + def close(self): + for key in list(self.pools): + del self.pools[key] + + def __del__(self): + self.close() + + def __getattr__(self, key): + return getattr(self.module, key) + + def get_pool(self, *args, **kw): + key = self._serialize(*args, **kw) + try: + return self.pools[key] + except KeyError: + self._create_pool_mutex.acquire() + try: + if key not in self.pools: + kw.pop('sa_pool_key', None) + pool = self.poolclass(lambda: + self.module.connect(*args, **kw), **self.kw) + self.pools[key] = pool + return pool + else: + return self.pools[key] + finally: + self._create_pool_mutex.release() + + def connect(self, *args, **kw): + """Activate a connection to the database. + + Connect to the database using this DBProxy's module and the given + connect arguments. If the arguments match an existing pool, the + connection will be returned from the pool's current thread-local + connection instance, or if there is no thread-local connection + instance it will be checked out from the set of pooled connections. + + If the pool has no available connections and allows new connections + to be created, a new database connection will be made. + + """ + + return self.get_pool(*args, **kw).connect() + + def dispose(self, *args, **kw): + """Dispose the pool referenced by the given connect arguments.""" + + key = self._serialize(*args, **kw) + try: + del self.pools[key] + except KeyError: + pass + + def _serialize(self, *args, **kw): + if "sa_pool_key" in kw: + return kw['sa_pool_key'] + + return tuple( + list(args) + + [(k, kw[k]) for k in sorted(kw)] + ) diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py new file mode 100644 index 00000000..d0f52e42 --- /dev/null +++ b/lib/sqlalchemy/processors.py @@ -0,0 +1,152 @@ +# sqlalchemy/processors.py +# Copyright (C) 2010-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" + +import codecs +import re +import datetime +from . import util + + +def str_to_datetime_processor_factory(regexp, type_): + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + + def process(value): + if value is None: + return None + else: + try: + m = rmatch(value) + except TypeError: + raise ValueError("Couldn't parse %s string '%r' " + "- value is not a string." % + (type_.__name__, value)) + if m is None: + raise ValueError("Couldn't parse %s string: " + "'%s'" % (type_.__name__, value)) + if has_named_groups: + groups = m.groupdict(0) + return type_(**dict(list(zip(iter(groups.keys()), + list(map(int, iter(groups.values()))))))) + else: + return type_(*list(map(int, m.groups(0)))) + return process + + +def boolean_to_int(value): + if value is None: + return None + else: + return int(value) + + +def py_fallback(): + def to_unicode_processor_factory(encoding, errors=None): + decoder = codecs.getdecoder(encoding) + + def process(value): + if value is None: + return None + else: + # decoder returns a tuple: (value, len). Simply dropping the + # len part is safe: it is done that way in the normal + # 'xx'.decode(encoding) code path. + return decoder(value, errors)[0] + return process + + def to_conditional_unicode_processor_factory(encoding, errors=None): + decoder = codecs.getdecoder(encoding) + + def process(value): + if value is None: + return None + elif isinstance(value, util.text_type): + return value + else: + # decoder returns a tuple: (value, len). Simply dropping the + # len part is safe: it is done that way in the normal + # 'xx'.decode(encoding) code path. + return decoder(value, errors)[0] + return process + + def to_decimal_processor_factory(target_class, scale): + fstring = "%%.%df" % scale + + def process(value): + if value is None: + return None + else: + return target_class(fstring % value) + return process + + def to_float(value): + if value is None: + return None + else: + return float(value) + + def to_str(value): + if value is None: + return None + else: + return str(value) + + def int_to_boolean(value): + if value is None: + return None + else: + return value and True or False + + DATETIME_RE = re.compile( + "(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") + TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?") + DATE_RE = re.compile("(\d+)-(\d+)-(\d+)") + + str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, + datetime.datetime) + str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) + str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) + return locals() + +try: + from sqlalchemy.cprocessors import UnicodeResultProcessor, \ + DecimalResultProcessor, \ + to_float, to_str, int_to_boolean, \ + str_to_datetime, str_to_time, \ + str_to_date + + def to_unicode_processor_factory(encoding, errors=None): + if errors is not None: + return UnicodeResultProcessor(encoding, errors).process + else: + return UnicodeResultProcessor(encoding).process + + def to_conditional_unicode_processor_factory(encoding, errors=None): + if errors is not None: + return UnicodeResultProcessor(encoding, errors).conditional_process + else: + return UnicodeResultProcessor(encoding).conditional_process + + def to_decimal_processor_factory(target_class, scale): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process + +except ImportError: + globals().update(py_fallback()) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py new file mode 100644 index 00000000..8556272a --- /dev/null +++ b/lib/sqlalchemy/schema.py @@ -0,0 +1,61 @@ +# schema.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Compatiblity namespace for sqlalchemy.sql.schema and related. + +""" + +from .sql.base import ( + SchemaVisitor + ) + + +from .sql.schema import ( + CheckConstraint, + Column, + ColumnDefault, + Constraint, + DefaultClause, + DefaultGenerator, + FetchedValue, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + PassiveDefault, + PrimaryKeyConstraint, + SchemaItem, + Sequence, + Table, + ThreadLocalMetaData, + UniqueConstraint, + _get_table_key, + ColumnCollectionConstraint, + ) + + +from .sql.naming import conv + + +from .sql.ddl import ( + DDL, + CreateTable, + DropTable, + CreateSequence, + DropSequence, + CreateIndex, + DropIndex, + CreateSchema, + DropSchema, + _DropView, + CreateColumn, + AddConstraint, + DropConstraint, + DDLBase, + DDLElement, + _CreateDropBase, + _DDLCompiles +) diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..95dae5aa --- /dev/null +++ b/lib/sqlalchemy/sql/__init__.py @@ -0,0 +1,90 @@ +# sql/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .expression import ( + Alias, + ClauseElement, + ColumnCollection, + ColumnElement, + CompoundSelect, + Delete, + FromClause, + Insert, + Join, + Select, + Selectable, + TableClause, + Update, + alias, + and_, + asc, + between, + bindparam, + case, + cast, + collate, + column, + delete, + desc, + distinct, + except_, + except_all, + exists, + extract, + false, + False_, + func, + insert, + intersect, + intersect_all, + join, + label, + literal, + literal_column, + modifier, + not_, + null, + or_, + outerjoin, + outparam, + over, + select, + subquery, + table, + text, + true, + True_, + tuple_, + type_coerce, + union, + union_all, + update, + ) + +from .visitors import ClauseVisitor + +def __go(lcls): + global __all__ + from .. import util as _sa_util + + import inspect as _inspect + + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) + + from .annotation import _prepare_annotations, Annotated + from .elements import AnnotatedColumnElement, ClauseList + from .selectable import AnnotatedFromClause + _prepare_annotations(ColumnElement, AnnotatedColumnElement) + _prepare_annotations(FromClause, AnnotatedFromClause) + _prepare_annotations(ClauseList, Annotated) + + _sa_util.dependencies.resolve_all("sqlalchemy.sql") + + from . import naming + +__go(locals()) + diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py new file mode 100644 index 00000000..11b06667 --- /dev/null +++ b/lib/sqlalchemy/sql/annotation.py @@ -0,0 +1,191 @@ +# sql/annotation.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The :class:`.Annotated` class and related routines; creates hash-equivalent +copies of SQL constructs which contain context-specific markers and associations. + +""" + +from .. import util +from . import operators + +class Annotated(object): + """clones a ClauseElement and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + """ + + def __new__(cls, *args): + if not args: + # clone constructor + return object.__new__(cls) + else: + element, values = args + # pull appropriate subclass from registry of annotated + # classes + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return object.__new__(cls) + + def __init__(self, element, values): + self.__dict__ = element.__dict__.copy() + self.__element = element + self._annotations = values + + def _annotate(self, values): + _values = self._annotations.copy() + _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone._annotations = values + return clone + + def _deannotate(self, values=None, clone=True): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) + + def _compiler_dispatch(self, visitor, **kw): + return self.__element.__class__._compiler_dispatch(self, visitor, **kw) + + @property + def _constructor(self): + return self.__element._constructor + + def _clone(self): + clone = self.__element._clone() + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occurred + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return self.__class__(clone, self._annotations) + + def __hash__(self): + return hash(self.__element) + + def __eq__(self, other): + if isinstance(self.__element, operators.ColumnOperators): + return self.__element.__class__.__eq__(self, other) + else: + return hash(other) == hash(self) + + + +# hard-generate Annotated subclasses. this technique +# is used instead of on-the-fly types (i.e. type.__new__()) +# so that the resulting objects are pickleable. +annotated_classes = {} + + + +def _deep_annotate(element, annotations, exclude=None): + """Deep copy the given ClauseElement, annotating each element + with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + def clone(elem): + if exclude and \ + hasattr(elem, 'proxy_set') and \ + elem.proxy_set.intersection(exclude): + newelem = elem._clone() + elif annotations != elem._annotations: + newelem = elem._annotate(annotations) + else: + newelem = elem + newelem._copy_internals(clone=clone) + return newelem + + if element is not None: + element = clone(element) + return element + + +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" + + cloned = util.column_dict() + + def clone(elem): + # if a values dict is given, + # the elem must be cloned each time it appears, + # as there may be different annotations in source + # elements that are remaining. if totally + # removing all annotations, can assume the same + # slate... + if values or elem not in cloned: + newelem = elem._deannotate(values=values, clone=True) + newelem._copy_internals(clone=clone) + if not values: + cloned[elem] = newelem + return newelem + else: + return cloned[elem] + + if element is not None: + element = clone(element) + return element + + +def _shallow_annotate(element, annotations): + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. + + Basically used to apply a "dont traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element + +def _new_annotation_type(cls, base_cls): + if issubclass(cls, Annotated): + return cls + elif cls in annotated_classes: + return annotated_classes[cls] + + for super_ in cls.__mro__: + # check if an Annotated subclass more specific than + # the given base_cls is already registered, such + # as AnnotatedColumnElement. + if super_ in annotated_classes: + base_cls = annotated_classes[super_] + break + + annotated_classes[cls] = anno_cls = type( + "Annotated%s" % cls.__name__, + (base_cls, cls), {}) + globals()["Annotated%s" % cls.__name__] = anno_cls + return anno_cls + +def _prepare_annotations(target_hierarchy, base_cls): + stack = [target_hierarchy] + while stack: + cls = stack.pop() + stack.extend(cls.__subclasses__()) + + _new_annotation_type(cls, base_cls) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py new file mode 100644 index 00000000..379f61ed --- /dev/null +++ b/lib/sqlalchemy/sql/base.py @@ -0,0 +1,621 @@ +# sql/base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Foundational utilities common to many sql modules. + +""" + + +from .. import util, exc +import itertools +from .visitors import ClauseVisitor +import re +import collections + +PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT') +NO_ARG = util.symbol('NO_ARG') + +class Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def _clone(self): + return self + + + +def _from_objects(*elements): + return itertools.chain(*[element._from_objects for element in elements]) + +@util.decorator +def _generative(fn, *args, **kw): + """Mark a method as generative.""" + + self = args[0]._generate() + fn(self, *args[1:], **kw) + return self + + +class _DialectArgView(collections.MutableMapping): + """A dictionary view of dialect-level arguments in the form + _. + + """ + def __init__(self, obj): + self.obj = obj + + def _key(self, key): + try: + dialect, value_key = key.split("_", 1) + except ValueError: + raise KeyError(key) + else: + return dialect, value_key + + def __getitem__(self, key): + dialect, value_key = self._key(key) + + try: + opt = self.obj.dialect_options[dialect] + except exc.NoSuchModuleError: + raise KeyError(key) + else: + return opt[value_key] + + def __setitem__(self, key, value): + try: + dialect, value_key = self._key(key) + except KeyError: + raise exc.ArgumentError( + "Keys must be of the form _") + else: + self.obj.dialect_options[dialect][value_key] = value + + def __delitem__(self, key): + dialect, value_key = self._key(key) + del self.obj.dialect_options[dialect][value_key] + + def __len__(self): + return sum(len(args._non_defaults) for args in + self.obj.dialect_options.values()) + + def __iter__(self): + return ( + "%s_%s" % (dialect_name, value_name) + for dialect_name in self.obj.dialect_options + for value_name in self.obj.dialect_options[dialect_name]._non_defaults + ) + +class _DialectArgDict(collections.MutableMapping): + """A dictionary view of dialect-level arguments for a specific + dialect. + + Maintains a separate collection of user-specified arguments + and dialect-specified default arguments. + + """ + def __init__(self): + self._non_defaults = {} + self._defaults = {} + + def __len__(self): + return len(set(self._non_defaults).union(self._defaults)) + + def __iter__(self): + return iter(set(self._non_defaults).union(self._defaults)) + + def __getitem__(self, key): + if key in self._non_defaults: + return self._non_defaults[key] + else: + return self._defaults[key] + + def __setitem__(self, key, value): + self._non_defaults[key] = value + + def __delitem__(self, key): + del self._non_defaults[key] + + +class DialectKWArgs(object): + """Establish the ability for a class to have dialect-specific arguments + with defaults and constructor validation. + + The :class:`.DialectKWArgs` interacts with the + :attr:`.DefaultDialect.construct_arguments` present on a dialect. + + .. seealso:: + + :attr:`.DefaultDialect.construct_arguments` + + """ + + @classmethod + def argument_for(cls, dialect_name, argument_name, default): + """Add a new kind of dialect-specific keyword argument for this class. + + E.g.:: + + Index.argument_for("mydialect", "length", None) + + some_index = Index('a', 'b', mydialect_length=5) + + The :meth:`.DialectKWArgs.argument_for` method is a per-argument + way adding extra arguments to the :attr:`.DefaultDialect.construct_arguments` + dictionary. This dictionary provides a list of argument names accepted by + various schema-level constructs on behalf of a dialect. + + New dialects should typically specify this dictionary all at once as a data + member of the dialect class. The use case for ad-hoc addition of + argument names is typically for end-user code that is also using + a custom compilation scheme which consumes the additional arguments. + + :param dialect_name: name of a dialect. The dialect must be locatable, + else a :class:`.NoSuchModuleError` is raised. The dialect must + also include an existing :attr:`.DefaultDialect.construct_arguments` collection, + indicating that it participates in the keyword-argument validation and + default system, else :class:`.ArgumentError` is raised. + If the dialect does not include this collection, then any keyword argument + can be specified on behalf of this dialect already. All dialects + packaged within SQLAlchemy include this collection, however for third + party dialects, support may vary. + + :param argument_name: name of the parameter. + + :param default: default value of the parameter. + + .. versionadded:: 0.9.4 + + """ + + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + if construct_arg_dictionary is None: + raise exc.ArgumentError("Dialect '%s' does have keyword-argument " + "validation and defaults enabled configured" % + dialect_name) + construct_arg_dictionary[cls][argument_name] = default + + @util.memoized_property + def dialect_kwargs(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + The arguments are present here in their original ``_`` + format. Only arguments that were actually passed are included; + unlike the :attr:`.DialectKWArgs.dialect_options` collection, which + contains all options known by this dialect including defaults. + + The collection is also writable; keys are accepted of the + form ``_`` where the value will be assembled + into the list of options. + + .. versionadded:: 0.9.2 + + .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs` + collection is now writable. + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_options` - nested dictionary form + + """ + return _DialectArgView(self) + + @property + def kwargs(self): + """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" + return self.dialect_kwargs + + @util.dependencies("sqlalchemy.dialects") + def _kw_reg_for_dialect(dialects, dialect_name): + dialect_cls = dialects.registry.load(dialect_name) + if dialect_cls.construct_arguments is None: + return None + return dict(dialect_cls.construct_arguments) + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) + + def _kw_reg_for_dialect_cls(self, dialect_name): + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + d = _DialectArgDict() + + if construct_arg_dictionary is None: + d._defaults.update({"*": None}) + else: + for cls in reversed(self.__class__.__mro__): + if cls in construct_arg_dictionary: + d._defaults.update(construct_arg_dictionary[cls]) + return d + + @util.memoized_property + def dialect_options(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + This is a two-level nested registry, keyed to ```` + and ````. For example, the ``postgresql_where`` argument + would be locatable as:: + + arg = my_object.dialect_options['postgresql']['where'] + + .. versionadded:: 0.9.2 + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form + + """ + + return util.PopulateDict( + util.portable_instancemethod(self._kw_reg_for_dialect_cls) + ) + + def _validate_dialect_kwargs(self, kwargs): + # validate remaining kwargs that they all specify DB prefixes + + if not kwargs: + return + + for k in kwargs: + m = re.match('^(.+?)_(.+)$', k) + if not m: + raise TypeError("Additional arguments should be " + "named _, got '%s'" % k) + dialect_name, arg_name = m.group(1, 2) + + try: + construct_arg_dictionary = self.dialect_options[dialect_name] + except exc.NoSuchModuleError: + util.warn( + "Can't validate argument %r; can't " + "locate any SQLAlchemy dialect named %r" % + (k, dialect_name)) + self.dialect_options[dialect_name] = d = _DialectArgDict() + d._defaults.update({"*": None}) + d._non_defaults[arg_name] = kwargs[k] + else: + if "*" not in construct_arg_dictionary and \ + arg_name not in construct_arg_dictionary: + raise exc.ArgumentError( + "Argument %r is not accepted by " + "dialect %r on behalf of %r" % ( + k, + dialect_name, self.__class__ + )) + else: + construct_arg_dictionary[arg_name] = kwargs[k] + + +class Generative(object): + """Allow a ClauseElement to generate itself via the + @_generative decorator. + + """ + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +class Executable(Generative): + """Mark a ClauseElement as supporting execution. + + :class:`.Executable` is a superclass for all "statement" types + of objects, including :func:`select`, :func:`delete`, :func:`update`, + :func:`insert`, :func:`text`. + + """ + + supports_execution = True + _execution_options = util.immutabledict() + _bind = None + + @_generative + def execution_options(self, **kw): + """ Set non-SQL options for the statement which take effect during + execution. + + Execution options can be set on a per-statement or + per :class:`.Connection` basis. Additionally, the + :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide + access to execution options which they in turn configure upon + connections. + + The :meth:`execution_options` method is generative. A new + instance of this statement is returned that contains the options:: + + statement = select([table.c.x, table.c.y]) + statement = statement.execution_options(autocommit=True) + + Note that only a subset of possible execution options can be applied + to a statement - these include "autocommit" and "stream_results", + but not "isolation_level" or "compiled_cache". + See :meth:`.Connection.execution_options` for a full list of + possible options. + + .. seealso:: + + :meth:`.Connection.execution_options()` + + :meth:`.Query.execution_options()` + + """ + if 'isolation_level' in kw: + raise exc.ArgumentError( + "'isolation_level' execution option may only be specified " + "on Connection.execution_options(), or " + "per-engine using the isolation_level " + "argument to create_engine()." + ) + if 'compiled_cache' in kw: + raise exc.ArgumentError( + "'compiled_cache' execution option may only be specified " + "on Connection.execution_options(), not per statement." + ) + self._execution_options = self._execution_options.union(kw) + + def execute(self, *multiparams, **params): + """Compile and execute this :class:`.Executable`.""" + e = self.bind + if e is None: + label = getattr(self, 'description', self.__class__.__name__) + msg = ('This %s is not directly bound to a Connection or Engine.' + 'Use the .execute() method of a Connection or Engine ' + 'to execute this construct.' % label) + raise exc.UnboundExecutionError(msg) + return e._execute_clauseelement(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Compile and execute this :class:`.Executable`, returning the + result's scalar representation. + + """ + return self.execute(*multiparams, **params).scalar() + + @property + def bind(self): + """Returns the :class:`.Engine` or :class:`.Connection` to + which this :class:`.Executable` is bound, or None if none found. + + This is a traversal which checks locally, then + checks among the "from" clauses of associated objects + until a bound engine or connection is found. + + """ + if self._bind is not None: + return self._bind + + for f in _from_objects(self): + if f is self: + continue + engine = f.bind + if engine is not None: + return engine + else: + return None + + +class SchemaEventTarget(object): + """Base class for elements that are the targets of :class:`.DDLEvents` + events. + + This includes :class:`.SchemaItem` as well as :class:`.SchemaType`. + + """ + + def _set_parent(self, parent): + """Associate with this SchemaEvent's parent object.""" + + raise NotImplementedError() + + def _set_parent_with_dispatch(self, parent): + self.dispatch.before_parent_attach(self, parent) + self._set_parent(parent) + self.dispatch.after_parent_attach(self, parent) + +class SchemaVisitor(ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {'schema_visitor': True} + +class ColumnCollection(util.OrderedProperties): + """An ordered dictionary that stores a list of ColumnElement + instances. + + Overrides the ``__eq__()`` method to produce SQL clauses between + sets of correlated columns. + + """ + + def __init__(self): + super(ColumnCollection, self).__init__() + self.__dict__['_all_col_set'] = util.column_set() + self.__dict__['_all_columns'] = [] + + def __str__(self): + return repr([str(c) for c in self]) + + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. + + e.g.:: + + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + + """ + remove_col = None + if column.name in self and column.key != column.name: + other = self[column.name] + if other.name == other.key: + remove_col = other + self._all_col_set.remove(other) + del self._data[other.key] + + if column.key in self._data: + remove_col = self._data[column.key] + self._all_col_set.remove(remove_col) + + self._all_col_set.add(column) + self._data[column.key] = column + if remove_col is not None: + self._all_columns[:] = [column if c is remove_col + else c for c in self._all_columns] + else: + self._all_columns.append(column) + + + def add(self, column): + """Add a column to this collection. + + The key attribute of the column will be used as the hash key + for this dictionary. + + """ + self[column.key] = column + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, object): + raise NotImplementedError() + + def __setitem__(self, key, value): + if key in self: + + # this warning is primarily to catch select() statements + # which have conflicting column names in their exported + # columns collection + + existing = self[key] + if not existing.shares_lineage(value): + util.warn('Column %r on table %r being replaced by ' + '%r, which has the same key. Consider ' + 'use_labels for select() statements.' % (key, + getattr(existing, 'table', None), value)) + + # pop out memoized proxy_set as this + # operation may very well be occurring + # in a _make_proxy operation + util.memoized_property.reset(value, "proxy_set") + + self._all_col_set.add(value) + self._all_columns.append(value) + self._data[key] = value + + def clear(self): + raise NotImplementedError() + + def remove(self, column): + del self._data[column.key] + self._all_col_set.remove(column) + self._all_columns[:] = [c for c in self._all_columns if c is not column] + + def update(self, iter): + cols = list(iter) + self._all_columns.extend(c for label, c in cols if c not in self._all_col_set) + self._all_col_set.update(c for label, c in cols) + self._data.update((label, c) for label, c in cols) + + def extend(self, iter): + cols = list(iter) + self._all_columns.extend(c for c in cols if c not in self._all_col_set) + self._all_col_set.update(cols) + self._data.update((c.key, c) for c in cols) + + __hash__ = None + + @util.dependencies("sqlalchemy.sql.elements") + def __eq__(self, elements, other): + l = [] + for c in getattr(other, "_all_columns", other): + for local in self._all_columns: + if c.shares_lineage(local): + l.append(c == local) + return elements.and_(*l) + + def __contains__(self, other): + if not isinstance(other, util.string_types): + raise exc.ArgumentError("__contains__ requires a string argument") + return util.OrderedProperties.__contains__(self, other) + + def __getstate__(self): + return {'_data': self.__dict__['_data'], + '_all_columns': self.__dict__['_all_columns']} + + def __setstate__(self, state): + self.__dict__['_data'] = state['_data'] + self.__dict__['_all_columns'] = state['_all_columns'] + self.__dict__['_all_col_set'] = util.column_set(state['_all_columns']) + + def contains_column(self, col): + # this has to be done via set() membership + return col in self._all_col_set + + def as_immutable(self): + return ImmutableColumnCollection(self._data, self._all_col_set, self._all_columns) + + +class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): + def __init__(self, data, colset, all_columns): + util.ImmutableProperties.__init__(self, data) + self.__dict__['_all_col_set'] = colset + self.__dict__['_all_columns'] = all_columns + + extend = remove = util.ImmutableProperties._immutable + + +class ColumnSet(util.ordered_column_set): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __add__(self, other): + return list(self) + list(other) + + @util.dependencies("sqlalchemy.sql.elements") + def __eq__(self, elements, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c == local) + return elements.and_(*l) + + def __hash__(self): + return hash(tuple(x for x in self)) + +def _bind_or_error(schemaitem, msg=None): + bind = schemaitem.bind + if not bind: + name = schemaitem.__class__.__name__ + label = getattr(schemaitem, 'fullname', + getattr(schemaitem, 'name', None)) + if label: + item = '%s object %r' % (name, label) + else: + item = '%s object' % name + if msg is None: + msg = "%s is not bound to an Engine or Connection. "\ + "Execution can not proceed without a database to execute "\ + "against." % item + raise exc.UnboundExecutionError(msg) + return bind diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py new file mode 100644 index 00000000..5165ee78 --- /dev/null +++ b/lib/sqlalchemy/sql/compiler.py @@ -0,0 +1,2957 @@ +# sql/compiler.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base SQL and DDL compiler implementations. + +Classes provided include: + +:class:`.compiler.SQLCompiler` - renders SQL +strings + +:class:`.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:doc:`/ext/compiler`. + +""" + +import re +from . import schema, sqltypes, operators, functions, \ + util as sql_util, visitors, elements, selectable, base +from .. import util, exc +import decimal +import itertools +import operator + +RESERVED_WORDS = set([ + 'all', 'analyse', 'analyze', 'and', 'any', 'array', + 'as', 'asc', 'asymmetric', 'authorization', 'between', + 'binary', 'both', 'case', 'cast', 'check', 'collate', + 'column', 'constraint', 'create', 'cross', 'current_date', + 'current_role', 'current_time', 'current_timestamp', + 'current_user', 'default', 'deferrable', 'desc', + 'distinct', 'do', 'else', 'end', 'except', 'false', + 'for', 'foreign', 'freeze', 'from', 'full', 'grant', + 'group', 'having', 'ilike', 'in', 'initially', 'inner', + 'intersect', 'into', 'is', 'isnull', 'join', 'leading', + 'left', 'like', 'limit', 'localtime', 'localtimestamp', + 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', + 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', + 'placing', 'primary', 'references', 'right', 'select', + 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', + 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', + 'using', 'verbose', 'when', 'where']) + +LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) +ILLEGAL_INITIAL_CHARACTERS = set([str(x) for x in range(0, 10)]).union(['$']) + +BIND_PARAMS = re.compile(r'(? ', + operators.ge: ' >= ', + operators.eq: ' = ', + operators.concat_op: ' || ', + operators.between_op: ' BETWEEN ', + operators.match_op: ' MATCH ', + operators.in_op: ' IN ', + operators.notin_op: ' NOT IN ', + operators.comma_op: ', ', + operators.from_: ' FROM ', + operators.as_: ' AS ', + operators.is_: ' IS ', + operators.isnot: ' IS NOT ', + operators.collate: ' COLLATE ', + + # unary + operators.exists: 'EXISTS ', + operators.distinct_op: 'DISTINCT ', + operators.inv: 'NOT ', + + # modifiers + operators.desc_op: ' DESC', + operators.asc_op: ' ASC', + operators.nullsfirst_op: ' NULLS FIRST', + operators.nullslast_op: ' NULLS LAST', + +} + +FUNCTIONS = { + functions.coalesce: 'coalesce%(expr)s', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', + functions.current_timestamp: 'CURRENT_TIMESTAMP', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', + functions.localtimestamp: 'LOCALTIMESTAMP', + functions.random: 'random%(expr)s', + functions.sysdate: 'sysdate', + functions.session_user: 'SESSION_USER', + functions.user: 'USER' +} + +EXTRACT_MAP = { + 'month': 'month', + 'day': 'day', + 'year': 'year', + 'second': 'second', + 'hour': 'hour', + 'doy': 'doy', + 'minute': 'minute', + 'quarter': 'quarter', + 'dow': 'dow', + 'week': 'week', + 'epoch': 'epoch', + 'milliseconds': 'milliseconds', + 'microseconds': 'microseconds', + 'timezone_hour': 'timezone_hour', + 'timezone_minute': 'timezone_minute' +} + +COMPOUND_KEYWORDS = { + selectable.CompoundSelect.UNION: 'UNION', + selectable.CompoundSelect.UNION_ALL: 'UNION ALL', + selectable.CompoundSelect.EXCEPT: 'EXCEPT', + selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', + selectable.CompoundSelect.INTERSECT: 'INTERSECT', + selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' +} + +class Compiled(object): + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, bind=None, + compile_kwargs=util.immutabledict()): + """Construct a new ``Compiled`` object. + + :param dialect: ``Dialect`` to compile against. + + :param statement: ``ClauseElement`` to be compiled. + + :param bind: Optional Engine or Connection to compile this + statement against. + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + .. versionadded:: 0.8 + + """ + + self.dialect = dialect + self.bind = bind + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + self.string = self.process(self.statement, **compile_kwargs) + + @util.deprecated("0.7", ":class:`.Compiled` objects now compile " + "within the constructor.") + def compile(self): + """Produce the internal string representation of this element. + """ + pass + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_compiled(self, multiparams, params) + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or '' + + def construct_params(self, params=None): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exc.UnboundExecutionError( + "This Compiled object is not bound to any Engine " + "or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's + scalar value.""" + + return self.execute(*multiparams, **params).scalar() + + +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + + +class _CompileLabel(visitors.Visitable): + """lightweight label object which acts as an expression.Label.""" + + __visit_name__ = 'label' + __slots__ = 'element', 'name' + + def __init__(self, col, name, alt_names=()): + self.element = col + self.name = name + self._alt_names = (col,) + alt_names + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + +class SQLCompiler(Compiled): + """Default implementation of Compiled. + + Compiles ClauseElements into SQL strings. Uses a similar visit + paradigm as visitors.ClauseVisitor but implements its own traversal. + + """ + + extract_map = EXTRACT_MAP + + compound_keywords = COMPOUND_KEYWORDS + + isdelete = isinsert = isupdate = False + """class-level defaults which can be set at the instance + level to define if this Compiled instance represents + INSERT/UPDATE/DELETE + """ + + returning = None + """holds the "returning" collection of columns if + the statement is CRUD and defines returning columns + either implicitly or explicitly + """ + + returning_precedes_values = False + """set to True classwide to generate RETURNING + clauses before the VALUES or WHERE clause (i.e. MSSQL) + """ + + render_table_with_column_in_update_from = False + """set to True classwide to indicate the SET clause + in a multi-table UPDATE statement should qualify + columns with the table name (i.e. MySQL only) + """ + + ansi_bind_rules = False + """SQL 92 doesn't allow bind parameters to be used + in the columns clause of a SELECT, nor does it allow + ambiguous expressions like "? = ?". A compiler + subclass can set this flag to False if the target + driver/DB enforces this + """ + + def __init__(self, dialect, statement, column_keys=None, + inline=False, **kwargs): + """Construct a new ``DefaultCompiler`` object. + + dialect + Dialect to be used + + statement + ClauseElement to be compiled + + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. + + """ + self.column_keys = column_keys + + # compile INSERT/UPDATE defaults/sequences inlined (no pre- + # execute) + self.inline = inline or getattr(statement, 'inline', False) + + # a dictionary of bind parameter keys to BindParameter + # instances. + self.binds = {} + + # a dictionary of BindParameter instances to "compiled" names + # that are actually present in the generated SQL + self.bind_names = util.column_dict() + + # stack which keeps track of nested SELECT statements + self.stack = [] + + # relates label names in the final SQL to a tuple of local + # column/label name, ColumnElement object (if any) and + # TypeEngine. ResultProxy uses this for type processing and + # column targeting + self.result_map = {} + + # true if the paramstyle is positional + self.positional = dialect.positional + if self.positional: + self.positiontup = [] + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + self.ctes = None + + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = dialect.identifier_preparer + self.label_length = dialect.label_length \ + or dialect.max_identifier_length + + # a map which tracks "anonymous" identifiers that are created on + # the fly here + self.anon_map = util.PopulateDict(self._process_anon) + + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length + self.truncated_names = {} + Compiled.__init__(self, dialect, statement, **kwargs) + + if self.positional and dialect.paramstyle == 'numeric': + self._apply_numbered_params() + + @util.memoized_instancemethod + def _init_cte_state(self): + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_by_name = {} + self.ctes_recursive = False + if self.positional: + self.cte_positional = [] + + def _apply_numbered_params(self): + poscount = itertools.count(1) + self.string = re.sub( + r'\[_POSITION\]', + lambda m: str(util.next(poscount)), + self.string) + + @util.memoized_property + def _bind_processors(self): + return dict( + (key, value) for key, value in + ((self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect)) + for bindparam in self.bind_names) + if value is not None + ) + + def is_subquery(self): + return len(self.stack) > 1 + + @property + def sql_compiler(self): + return self + + def construct_params(self, params=None, _group_number=None, _check=True): + """return a dictionary of bind parameter keys and values""" + + if params: + pd = {} + for bindparam, name in self.bind_names.items(): + if bindparam.key in params: + pd[name] = params[bindparam.key] + elif name in params: + pd[name] = params[name] + elif _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) + else: + pd[name] = bindparam.effective_value + return pd + else: + pd = {} + for bindparam in self.bind_names: + if _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) + pd[self.bind_names[bindparam]] = bindparam.effective_value + return pd + + @property + def params(self): + """Return the bind param dictionary embedded into this + compiled object, for those values that are present.""" + return self.construct_params(_check=False) + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + + """ + return "" + + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_label(self, label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw): + # only render labels within the columns clause + # or ORDER BY clause of a select. dialect-specific compilers + # can modify this behavior. + render_label_with_as = within_columns_clause and not within_label_clause + render_label_only = render_label_as_label is label + + if render_label_only or render_label_with_as: + if isinstance(label.name, elements._truncated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name + + if render_label_with_as: + if add_to_result_map is not None: + add_to_result_map( + labelname, + label.name, + (label, labelname, ) + label._alt_names, + label.type + ) + + return label.element._compiler_dispatch(self, + within_columns_clause=True, + within_label_clause=True, + **kw) + \ + OPERATORS[operators.as_] + \ + self.preparer.format_label(label, labelname) + elif render_label_only: + return labelname + else: + return label.element._compiler_dispatch(self, + within_columns_clause=False, + **kw) + + def visit_column(self, column, add_to_result_map=None, + include_table=True, **kwargs): + name = orig_name = column.name + if name is None: + raise exc.CompileError("Cannot compile Column object until " + "its 'name' is assigned.") + + is_literal = column.is_literal + if not is_literal and isinstance(name, elements._truncated_label): + name = self._truncated_identifier("colident", name) + + if add_to_result_map is not None: + add_to_result_map( + name, + orig_name, + (column, name, column.key), + column.type + ) + + if is_literal: + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(name) + + table = column.table + if table is None or not include_table or not table.named_with_column: + return name + else: + if table.schema: + schema_prefix = self.preparer.quote_schema(table.schema) + '.' + else: + schema_prefix = '' + tablename = table.name + if isinstance(tablename, elements._truncated_label): + tablename = self._truncated_identifier("alias", tablename) + + return schema_prefix + \ + self.preparer.quote(tablename) + \ + "." + name + + def escape_literal_column(self, text): + """provide escaping for the literal_column() construct.""" + + # TODO: some dialects might need different behavior here + return text.replace('%', '%%') + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kwargs): + return self.dialect.type_compiler.process(typeclause.type) + + def post_process_text(self, text): + return text + + def visit_textclause(self, textclause, **kw): + def do_bindparam(m): + name = m.group(1) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) + else: + return self.bindparam_string(name, **kw) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + BIND_PARAMS.sub(do_bindparam, + self.post_process_text(textclause.text)) + ) + + def visit_text_as_from(self, taf, iswrapper=False, + compound_index=0, force_result_map=False, + asfrom=False, + parens=True, **kw): + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = force_result_map or ( + compound_index == 0 and ( + toplevel or \ + entry['iswrapper'] + ) + ) + + if populate_result_map: + for c in taf.column_args: + self.process(c, within_columns_clause=True, + add_to_result_map=self._add_to_result_map) + + text = self.process(taf.element, **kw) + if asfrom and parens: + text = "(%s)" % text + return text + + + def visit_null(self, expr, **kw): + return 'NULL' + + def visit_true(self, expr, **kw): + if self.dialect.supports_native_boolean: + return 'true' + else: + return "1" + + def visit_false(self, expr, **kw): + if self.dialect.supports_native_boolean: + return 'false' + else: + return "0" + + def visit_clauselist(self, clauselist, order_by_select=None, **kw): + if order_by_select is not None: + return self._order_by_clauselist( + clauselist, order_by_select, **kw) + + sep = clauselist.operator + if sep is None: + sep = " " + else: + sep = OPERATORS[clauselist.operator] + return sep.join( + s for s in + ( + c._compiler_dispatch(self, **kw) + for c in clauselist.clauses) + if s) + + def _order_by_clauselist(self, clauselist, order_by_select, **kw): + # look through raw columns collection for labels. + # note that its OK we aren't expanding tables and other selectables + # here; we can only add a label in the ORDER BY for an individual + # label expression in the columns clause. + + raw_col = set(l._order_by_label_element.name + for l in order_by_select._raw_columns + if l._order_by_label_element is not None) + + return ", ".join( + s for s in + ( + c._compiler_dispatch(self, + render_label_as_label= + c._order_by_label_element if + c._order_by_label_element is not None and + c._order_by_label_element.name in raw_col + else None, + **kw) + for c in clauselist.clauses) + if s) + + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value is not None: + x += clause.value._compiler_dispatch(self, **kwargs) + " " + for cond, result in clause.whens: + x += "WHEN " + cond._compiler_dispatch( + self, **kwargs + ) + " THEN " + result._compiler_dispatch( + self, **kwargs) + " " + if clause.else_ is not None: + x += "ELSE " + clause.else_._compiler_dispatch( + self, **kwargs + ) + " " + x += "END" + return x + + def visit_cast(self, cast, **kwargs): + return "CAST(%s AS %s)" % \ + (cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs)) + + def visit_over(self, over, **kwargs): + return "%s OVER (%s)" % ( + over.func._compiler_dispatch(self, **kwargs), + ' '.join( + '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ('PARTITION', over.partition_by), + ('ORDER', over.order_by) + ) + if clause is not None and len(clause) + ) + ) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % (field, + extract.expr._compiler_dispatch(self, **kwargs)) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + if add_to_result_map is not None: + add_to_result_map( + func.name, func.name, (), func.type + ) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) + else: + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(list(func.packagenames) + [name]) % \ + {'expr': self.function_argspec(func, **kwargs)} + + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." % + self.dialect.name + ) + + def function_argspec(self, func, **kwargs): + return func.clause_expr._compiler_dispatch(self, **kwargs) + + + def visit_compound_select(self, cs, asfrom=False, + parens=True, compound_index=0, **kwargs): + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + self.stack.append( + { + 'correlate_froms': entry['correlate_froms'], + 'iswrapper': toplevel, + 'asfrom_froms': entry['asfrom_froms'] + }) + + keyword = self.compound_keywords.get(cs.keyword) + + text = (" " + keyword + " ").join( + (c._compiler_dispatch(self, + asfrom=asfrom, parens=False, + compound_index=i, **kwargs) + for i, c in enumerate(cs.selects)) + ) + + group_by = cs._group_by_clause._compiler_dispatch( + self, asfrom=asfrom, **kwargs) + if group_by: + text += " GROUP BY " + group_by + + text += self.order_by_clause(cs, **kwargs) + text += (cs._limit is not None or cs._offset is not None) and \ + self.limit_clause(cs) or "" + + if self.ctes and \ + compound_index == 0 and toplevel: + text = self._render_cte_clause() + text + + self.stack.pop(-1) + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def visit_unary(self, unary, **kw): + if unary.operator: + if unary.modifier: + raise exc.CompileError( + "Unary expression does not support operator " + "and modifier simultaneously") + disp = getattr(self, "visit_%s_unary_operator" % + unary.operator.__name__, None) + if disp: + return disp(unary, unary.operator, **kw) + else: + return self._generate_generic_unary_operator(unary, + OPERATORS[unary.operator], **kw) + elif unary.modifier: + disp = getattr(self, "visit_%s_unary_modifier" % + unary.modifier.__name__, None) + if disp: + return disp(unary, unary.modifier, **kw) + else: + return self._generate_generic_unary_modifier(unary, + OPERATORS[unary.modifier], **kw) + else: + raise exc.CompileError( + "Unary expression has no operator or modifier") + + def visit_istrue_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_isfalse_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + + def visit_binary(self, binary, **kw): + # don't allow "? = ?" to render + if self.ansi_bind_rules and \ + isinstance(binary.left, elements.BindParameter) and \ + isinstance(binary.right, elements.BindParameter): + kw['literal_binds'] = True + + operator = binary.operator + disp = getattr(self, "visit_%s_binary" % operator.__name__, None) + if disp: + return disp(binary, operator, **kw) + else: + try: + opstring = OPERATORS[operator] + except KeyError: + raise exc.UnsupportedCompilationError(self, operator) + else: + return self._generate_generic_binary(binary, opstring, **kw) + + def visit_custom_op_binary(self, element, operator, **kw): + return self._generate_generic_binary(element, + " " + operator.opstring + " ", **kw) + + def visit_custom_op_unary_operator(self, element, operator, **kw): + return self._generate_generic_unary_operator(element, + operator.opstring + " ", **kw) + + def visit_custom_op_unary_modifier(self, element, operator, **kw): + return self._generate_generic_unary_modifier(element, + " " + operator.opstring, **kw) + + def _generate_generic_binary(self, binary, opstring, **kw): + return binary.left._compiler_dispatch(self, **kw) + \ + opstring + \ + binary.right._compiler_dispatch(self, **kw) + + def _generate_generic_unary_operator(self, unary, opstring, **kw): + return opstring + unary.element._compiler_dispatch(self, **kw) + + def _generate_generic_unary_modifier(self, unary, opstring, **kw): + return unary.element._compiler_dispatch(self, **kw) + opstring + + @util.memoized_property + def _like_percent_literal(self): + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) + + def visit_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right).__add__(percent) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_notcontains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right).__add__(percent) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__radd__( + binary.right + ) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_notstartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__radd__( + binary.right + ) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_notendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + # TODO: use ternary here, not "and"/ "or" + return '%s LIKE %s' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notlike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_bindparam(self, bindparam, within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs): + if not skip_bind_expression and bindparam.type._has_bind_expression: + bind_expression = bindparam.type.bind_expression(bindparam) + return self.process(bind_expression, + skip_bind_expression=True) + + if literal_binds or \ + (within_columns_clause and \ + self.ansi_bind_rules): + if bindparam.value is None and bindparam.callable is None: + raise exc.CompileError("Bind parameter '%s' without a " + "renderable value not allowed here." + % bindparam.key) + return self.render_literal_bindparam(bindparam, + within_columns_clause=True, **kwargs) + + name = self._truncate_bindparam(bindparam) + + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam: + if (existing.unique or bindparam.unique) and \ + not existing.proxy_set.intersection( + bindparam.proxy_set): + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % + bindparam.key + ) + elif existing._is_crud or bindparam._is_crud: + raise exc.CompileError( + "bindparam() name '%s' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using bindparam() " + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) + ) + + self.binds[bindparam.key] = self.binds[name] = bindparam + + return self.bindparam_string(name, **kwargs) + + def render_literal_bindparam(self, bindparam, **kw): + value = bindparam.effective_value + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind parameters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) + else: + raise NotImplementedError( + "Don't know how to literal-quote value %r" % value) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + if isinstance(bind_name, elements._truncated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.truncated_names: + return self.truncated_names[(ident_class, name)] + + anonname = name.apply_map(self.anon_map) + + if len(anonname) > self.label_length: + counter = self.truncated_names.get(ident_class, 1) + truncname = anonname[0:max(self.label_length - 6, 0)] + \ + "_" + hex(counter)[2:] + self.truncated_names[ident_class] = counter + 1 + else: + truncname = anonname + self.truncated_names[(ident_class, name)] = truncname + return truncname + + def _anonymize(self, name): + return name % self.anon_map + + def _process_anon(self, key): + (ident, derived) = key.split(' ', 1) + anonymous_counter = self.anon_map.get(derived, 1) + self.anon_map[derived] = anonymous_counter + 1 + return derived + "_" + str(anonymous_counter) + + def bindparam_string(self, name, positional_names=None, **kw): + if self.positional: + if positional_names is not None: + positional_names.append(name) + else: + self.positiontup.append(name) + return self.bindtemplate % {'name': name} + + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, + **kwargs): + self._init_cte_state() + if self.positional: + kwargs['positional_names'] = self.cte_positional + + if isinstance(cte.name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + + if cte_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_name] + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte in existing_cte._restates or cte is existing_cte: + return self.preparer.format_alias(cte, cte_name) + elif existing_cte in cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". + del self.ctes[existing_cte] + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % + cte_name) + + self.ctes_by_name[cte_name] = cte + + if cte._cte_alias is not None: + orig_cte = cte._cte_alias + if orig_cte not in self.ctes: + self.visit_cte(orig_cte) + cte_alias_name = cte._cte_alias.name + if isinstance(cte_alias_name, elements._truncated_label): + cte_alias_name = self._truncated_identifier("alias", cte_alias_name) + else: + orig_cte = cte + cte_alias_name = None + if not cte_alias_name and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, selectable.Select): + col_source = cte.original + elif isinstance(cte.original, selectable.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c for c in + util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join( + self.preparer.format_column(ident) + for ident in recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + + if asfrom: + if cte_alias_name: + text = self.preparer.format_alias(cte, cte_alias_name) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + + def visit_alias(self, alias, asfrom=False, ashint=False, + iscrud=False, + fromhints=None, **kwargs): + if asfrom or ashint: + if isinstance(alias.name, elements._truncated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + ret = alias.original._compiler_dispatch(self, + asfrom=True, **kwargs) + \ + " AS " + \ + self.preparer.format_alias(alias, alias_name) + + if fromhints and alias in fromhints: + ret = self.format_from_hint_text(ret, alias, + fromhints[alias], iscrud) + + return ret + else: + return alias.original._compiler_dispatch(self, **kwargs) + + def _add_to_result_map(self, keyname, name, objects, type_): + if not self.dialect.case_sensitive: + keyname = keyname.lower() + + if keyname in self.result_map: + # conflicting keyname, just double up the list + # of objects. this will cause an "ambiguous name" + # error if an attempt is made by the result set to + # access. + e_name, e_obj, e_type = self.result_map[keyname] + self.result_map[keyname] = e_name, e_obj + objects, e_type + else: + self.result_map[keyname] = name, objects, type_ + + def _label_select_column(self, select, column, + populate_result_map, + asfrom, column_clause_args, + name=None, + within_columns_clause=True): + """produce labeled columns present in a select().""" + + if column.type._has_column_expression and \ + populate_result_map: + col_expr = column.type.column_expression(column) + add_to_result_map = lambda keyname, name, objects, type_: \ + self._add_to_result_map( + keyname, name, + objects + (column,), type_) + else: + col_expr = column + if populate_result_map: + add_to_result_map = self._add_to_result_map + else: + add_to_result_map = None + + if not within_columns_clause: + result_expr = col_expr + elif isinstance(column, elements.Label): + if col_expr is not column: + result_expr = _CompileLabel( + col_expr, + column.name, + alt_names=(column.element,) + ) + else: + result_expr = col_expr + + elif select is not None and name: + result_expr = _CompileLabel( + col_expr, + name, + alt_names=(column._key_label,) + ) + + elif \ + asfrom and \ + isinstance(column, elements.ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, selectable.Select): + result_expr = _CompileLabel(col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,)) + elif not isinstance(column, + (elements.UnaryExpression, elements.TextClause)) \ + and (not hasattr(column, 'name') or \ + isinstance(column, functions.Function)): + result_expr = _CompileLabel(col_expr, column.anon_label) + elif col_expr is not column: + # TODO: are we sure "column" has a .name and .key here ? + # assert isinstance(column, elements.ColumnClause) + result_expr = _CompileLabel(col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,)) + else: + result_expr = col_expr + + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map + ) + return result_expr._compiler_dispatch( + self, + **column_clause_args + ) + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext + + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + + def get_crud_hint_text(self, table, text): + return None + + def _transform_select_for_nested_joins(self, select): + """Rewrite any "a JOIN (b JOIN c)" expression as + "a JOIN (select * from b JOIN c) AS anon", to support + databases that can't parse a parenthesized join correctly + (i.e. sqlite the main one). + + """ + cloned = {} + column_translate = [{}] + + + def visit(element, **kw): + if element in column_translate[-1]: + return column_translate[-1][element] + + elif element in cloned: + return cloned[element] + + newelem = cloned[element] = element._clone() + + if newelem.is_selectable and newelem._is_join and \ + isinstance(newelem.right, selectable.FromGrouping): + + newelem._reset_exported() + newelem.left = visit(newelem.left, **kw) + + right = visit(newelem.right, **kw) + + selectable_ = selectable.Select( + [right.element], + use_labels=True).alias() + for c in selectable_.c: + c._key_label = c.key + c._label = c.name + + translate_dict = dict( + zip(newelem.right.element.c, selectable_.c) + ) + + # translating from both the old and the new + # because different select() structures will lead us + # to traverse differently + translate_dict[right.element.left] = selectable_ + translate_dict[right.element.right] = selectable_ + translate_dict[newelem.right.element.left] = selectable_ + translate_dict[newelem.right.element.right] = selectable_ + + # propagate translations that we've gained + # from nested visit(newelem.right) outwards + # to the enclosing select here. this happens + # only when we have more than one level of right + # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))" + for k, v in list(column_translate[-1].items()): + if v in translate_dict: + # remarkably, no current ORM tests (May 2013) + # hit this condition, only test_join_rewriting + # does. + column_translate[-1][k] = translate_dict[v] + + column_translate[-1].update(translate_dict) + + newelem.right = selectable_ + + newelem.onclause = visit(newelem.onclause, **kw) + + elif newelem.is_selectable and newelem._is_from_container: + # if we hit an Alias or CompoundSelect, put a marker in the + # stack. + kw['transform_clue'] = 'select_container' + newelem._copy_internals(clone=visit, **kw) + elif newelem.is_selectable and newelem._is_select: + barrier_select = kw.get('transform_clue', None) == 'select_container' + # if we're still descended from an Alias/CompoundSelect, we're + # in a FROM clause, so start with a new translate collection + if barrier_select: + column_translate.append({}) + kw['transform_clue'] = 'inside_select' + newelem._copy_internals(clone=visit, **kw) + if barrier_select: + del column_translate[-1] + else: + newelem._copy_internals(clone=visit, **kw) + + return newelem + + return visit(select) + + def _transform_result_map_for_nested_joins(self, select, transformed_select): + inner_col = dict((c._key_label, c) for + c in transformed_select.inner_columns) + + d = dict( + (inner_col[c._key_label], c) + for c in select.inner_columns + ) + for key, (name, objs, typ) in list(self.result_map.items()): + objs = tuple([d.get(col, col) for col in objs]) + self.result_map[key] = (name, objs, typ) + + + _default_stack_entry = util.immutabledict([ + ('iswrapper', False), + ('correlate_froms', frozenset()), + ('asfrom_froms', frozenset()) + ]) + + def _display_froms_for_select(self, select, asfrom): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms=\ + correlate_froms.difference(asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + return froms + + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, fromhints=None, + compound_index=0, + force_result_map=False, + positional_names=None, + nested_join_translation=False, + **kwargs): + + needs_nested_translation = \ + select.use_labels and \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins + + if needs_nested_translation: + transformed_select = self._transform_select_for_nested_joins(select) + text = self.visit_select( + transformed_select, asfrom=asfrom, parens=parens, + iswrapper=iswrapper, fromhints=fromhints, + compound_index=compound_index, + force_result_map=force_result_map, + positional_names=positional_names, + nested_join_translation=True, **kwargs + ) + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + + populate_result_map = force_result_map or ( + compound_index == 0 and ( + toplevel or \ + entry['iswrapper'] + ) + ) + + if needs_nested_translation: + if populate_result_map: + self._transform_result_map_for_nested_joins( + select, transformed_select) + return text + + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms= + correlate_froms.difference(asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + + new_correlate_froms = set(selectable._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry = { + 'asfrom_froms': new_correlate_froms, + 'iswrapper': iswrapper, + 'correlate_froms': all_correlate_froms + } + self.stack.append(new_entry) + + column_clause_args = kwargs.copy() + column_clause_args.update({ + 'positional_names': positional_names, + 'within_label_clause': False, + 'within_columns_clause': False + }) + + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c for c in [ + self._label_select_column(select, + column, + populate_result_map, asfrom, + column_clause_args, + name=name) + for name, column in select._columns_plus_names + ] + if c is not None + ] + + text = "SELECT " # we're off to a good start ! + + if select._hints: + byfrom = dict([ + (from_, hinttext % { + 'name':from_._compiler_dispatch( + self, ashint=True) + }) + for (from_, dialect), hinttext in + select._hints.items() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + if hint_text: + text += hint_text + " " + + if select._prefixes: + text += self._generate_prefixes(select, select._prefixes, **kwargs) + + text += self.get_select_precolumns(select) + text += ', '.join(inner_columns) + + if froms: + text += " \nFROM " + + if select._hints: + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, fromhints=byfrom, + **kwargs) + for f in froms]) + else: + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, **kwargs) + for f in froms]) + else: + text += self.default_from() + + if select._whereclause is not None: + t = select._whereclause._compiler_dispatch(self, **kwargs) + if t: + text += " \nWHERE " + t + + if select._group_by_clause.clauses: + group_by = select._group_by_clause._compiler_dispatch( + self, **kwargs) + if group_by: + text += " GROUP BY " + group_by + + if select._having is not None: + t = select._having._compiler_dispatch(self, **kwargs) + if t: + text += " \nHAVING " + t + + if select._order_by_clause.clauses: + if self.dialect.supports_simple_order_by_label: + order_by_select = select + else: + order_by_select = None + + text += self.order_by_clause(select, + order_by_select=order_by_select, **kwargs) + + if select._limit is not None or select._offset is not None: + text += self.limit_clause(select) + + if select._for_update_arg is not None: + text += self.for_update_clause(select) + + if self.ctes and \ + compound_index == 0 and toplevel: + text = self._render_cte_clause() + text + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def _generate_prefixes(self, stmt, prefixes, **kw): + clause = " ".join( + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name is None or + dialect_name == self.dialect.name + ) + if clause: + clause += " " + return clause + + def _render_cte_clause(self): + if self.positional: + self.positiontup = self.cte_positional + self.positiontup + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + return cte_text + + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just + before column list. + + """ + return select._distinct and "DISTINCT " or "" + + def order_by_clause(self, select, **kw): + order_by = select._order_by_clause._compiler_dispatch(self, **kw) + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select): + return " FOR UPDATE" + + def returning_clause(self, stmt, returning_cols): + raise exc.CompileError( + "RETURNING is not supported by this " + "dialect's statement compiler.") + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += "\n LIMIT " + self.process(elements.literal(select._limit)) + if select._offset is not None: + if select._limit is None: + text += "\n LIMIT -1" + text += " OFFSET " + self.process(elements.literal(select._offset)) + return text + + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, + fromhints=None, **kwargs): + if asfrom or ashint: + if getattr(table, "schema", None): + ret = self.preparer.quote_schema(table.schema) + \ + "." + self.preparer.quote(table.name) + else: + ret = self.preparer.quote(table.name) + if fromhints and table in fromhints: + ret = self.format_from_hint_text(ret, table, + fromhints[table], iscrud) + return ret + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return ( + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) + ) + + def visit_insert(self, insert_stmt, **kw): + self.isinsert = True + colparams = self._get_colparams(insert_stmt, **kw) + + if not colparams and \ + not self.dialect.supports_default_values and \ + not self.dialect.supports_empty_insert: + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % + self.dialect.name) + + if insert_stmt._has_multi_parameters: + if not self.dialect.supports_multivalues_insert: + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % + self.dialect.name) + colparams_single = colparams[0] + else: + colparams_single = colparams + + + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT " + + if insert_stmt._prefixes: + text += self._generate_prefixes(insert_stmt, + insert_stmt._prefixes, **kw) + + text += "INTO " + table_text = preparer.format_table(insert_stmt.table) + + if insert_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + insert_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if insert_stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, + insert_stmt.table, + dialect_hints[insert_stmt.table], + True + ) + + text += table_text + + if colparams_single or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams_single]) + + if self.returning or insert_stmt._returning: + self.returning = self.returning or insert_stmt._returning + returning_clause = self.returning_clause( + insert_stmt, self.returning) + + if self.returning_precedes_values: + text += " " + returning_clause + + if insert_stmt.select is not None: + text += " %s" % self.process(insert_stmt.select, **kw) + elif not colparams and supports_default_values: + text += " DEFAULT VALUES" + elif insert_stmt._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" % ( + ', '.join(c[1] for c in colparam_set) + ) + for colparam_set in colparams + ) + ) + else: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if self.returning and not self.returning_precedes_values: + text += " " + returning_clause + + return text + + def update_limit_clause(self, update_stmt): + """Provide a hook for MySQL to add LIMIT to the UPDATE""" + return None + + def update_tables_clause(self, update_stmt, from_table, + extra_froms, **kw): + """Provide a hook to override the initial table clause + in an UPDATE statement. + + MySQL overrides this. + + """ + return from_table._compiler_dispatch(self, asfrom=True, + iscrud=True, **kw) + + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): + """Provide a hook to override the generation of an + UPDATE..FROM clause. + + MySQL and MSSQL override this. + + """ + return "FROM " + ', '.join( + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in extra_froms) + + def visit_update(self, update_stmt, **kw): + self.stack.append( + {'correlate_froms': set([update_stmt.table]), + "iswrapper": False, + "asfrom_froms": set([update_stmt.table])}) + + self.isupdate = True + + extra_froms = update_stmt._extra_froms + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes(update_stmt, + update_stmt._prefixes, **kw) + + table_text = self.update_tables_clause(update_stmt, update_stmt.table, + extra_froms, **kw) + + colparams = self._get_colparams(update_stmt, **kw) + + if update_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + update_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if update_stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, + update_stmt.table, + dialect_hints[update_stmt.table], + True + ) + else: + dialect_hints = None + + text += table_text + + text += ' SET ' + include_table = extra_froms and \ + self.render_table_with_column_in_update_from + text += ', '.join( + c[0]._compiler_dispatch(self, + include_table=include_table) + + '=' + c[1] for c in colparams + ) + + if self.returning or update_stmt._returning: + if not self.returning: + self.returning = update_stmt._returning + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, self.returning) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + extra_froms, + dialect_hints, **kw) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._whereclause is not None: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if self.returning and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, self.returning) + + self.stack.pop(-1) + + return text + + def _create_crud_bind_param(self, col, value, required=False, name=None): + if name is None: + name = col.key + bindparam = elements.BindParameter(name, value, + type_=col.type, required=required) + bindparam._is_crud = True + return bindparam._compiler_dispatch(self) + + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the ResultProxy's prefetch_cols() and postfetch_cols() + collections. + + """ + + self.postfetch = [] + self.prefetch = [] + self.returning = [] + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if self.column_keys is None and stmt.parameters is None: + return [ + (c, self._create_crud_bind_param(c, + None, required=True)) + for c in stmt.table.columns + ] + + if stmt._has_multi_parameters: + stmt_parameters = stmt.parameters[0] + else: + stmt_parameters = stmt.parameters + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + + # if we have statement parameters - set defaults in the + # compiled params + if self.column_keys is None: + parameters = {} + else: + parameters = dict((_column_as_key(key), REQUIRED) + for key in self.column_keys + if not stmt_parameters or + key not in stmt_parameters) + + # create a list of column assignment clauses as tuples + values = [] + + if stmt_parameters is not None: + for k, v in stmt_parameters.items(): + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add it to values() in an "as-is" state, + # coercing right side to bound param + if elements._is_literal(v): + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) + else: + v = self.process(v.self_group(), **kw) + + values.append((k, v)) + + need_pks = self.isinsert and \ + not self.inline and \ + not stmt._returning + + implicit_returning = need_pks and \ + self.dialect.implicit_returning and \ + stmt.table.implicit_returning + + if self.isinsert: + implicit_return_defaults = implicit_returning and stmt._return_defaults + elif self.isupdate: + implicit_return_defaults = self.dialect.implicit_returning and \ + stmt.table.implicit_returning and \ + stmt._return_defaults + + if implicit_return_defaults: + if stmt._return_defaults is True: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults) + + postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid + + check_columns = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if self.isupdate and stmt._extra_froms and stmt_parameters: + normalized_params = dict( + (elements._clause_element_as_expr(c), param) + for c, param in stmt_parameters.items() + ) + affected_tables = set() + for t in stmt._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + if elements._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) + else: + self.postfetch.append(c) + value = self.process(value.self_group(), **kw) + values.append((c, value)) + # determine tables which are actually + # to be updated - process onupdate and + # server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) + ) + self.postfetch.append(c) + else: + values.append( + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) + ) + self.prefetch.append(c) + elif c.server_onupdate is not None: + self.postfetch.append(c) + + if self.isinsert and stmt.select_names: + # for an insert from select, we can only use names that + # are given, so only select for those names. + cols = (stmt.table.c[_column_as_key(name)] + for name in stmt.select_names) + else: + # iterate through all table columns to maintain + # ordering, even for those cols that aren't included + cols = stmt.table.columns + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) + if elements._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is REQUIRED, + name=_col_bind_name(c) + if not stmt._has_multi_parameters + else "%s_0" % _col_bind_name(c) + ) + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + + if c.primary_key and implicit_returning: + self.returning.append(c) + value = self.process(value.self_group(), **kw) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + value = self.process(value.self_group(), **kw) + else: + self.postfetch.append(c) + value = self.process(value.self_group(), **kw) + values.append((c, value)) + + elif self.isinsert: + if c.primary_key and \ + need_pks and \ + ( + implicit_returning or + not postfetch_lastrowid or + c is not stmt.table._autoincrement_column + ): + + if implicit_returning: + if c.default is not None: + if c.default.is_sequence: + if self.dialect.supports_sequences and \ + (not c.default.optional or \ + not self.dialect.sequences_optional): + proc = self.process(c.default, **kw) + values.append((c, proc)) + self.returning.append(c) + elif c.default.is_clause_element: + values.append( + (c, + self.process(c.default.arg.self_group(), **kw)) + ) + self.returning.append(c) + else: + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if ( + c.default is not None and + ( + not c.default.is_sequence or + self.dialect.supports_sequences + ) + ) or \ + c is stmt.table._autoincrement_column and ( + self.dialect.supports_sequences or + self.dialect.preexecute_autoincrement_sequences + ): + + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + + self.prefetch.append(c) + + elif c.default is not None: + if c.default.is_sequence: + if self.dialect.supports_sequences and \ + (not c.default.optional or \ + not self.dialect.sequences_optional): + proc = self.process(c.default, **kw) + values.append((c, proc)) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: + self.postfetch.append(c) + elif c.default.is_clause_element: + values.append( + (c, self.process(c.default.arg.self_group(), **kw)) + ) + + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: + # dont add primary key column to postfetch + self.postfetch.append(c) + else: + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) + elif c.server_default is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: + self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + + elif self.isupdate: + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, self.process(c.onupdate.arg.self_group(), **kw)) + ) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) + else: + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) + elif c.server_onupdate is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + + if parameters and stmt_parameters: + check = set(parameters).intersection( + _column_as_key(k) for k in stmt.parameters + ).difference(check_columns) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" % + (", ".join("%s" % c for c in check)) + ) + + if stmt._has_multi_parameters: + values_0 = values + values = [values] + + values.extend( + [ + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) + ) + + return values + + def visit_delete(self, delete_stmt, **kw): + self.stack.append({'correlate_froms': set([delete_stmt.table]), + "iswrapper": False, + "asfrom_froms": set([delete_stmt.table])}) + self.isdelete = True + + text = "DELETE " + + if delete_stmt._prefixes: + text += self._generate_prefixes(delete_stmt, + delete_stmt._prefixes, **kw) + + text += "FROM " + table_text = delete_stmt.table._compiler_dispatch(self, + asfrom=True, iscrud=True) + + if delete_stmt._hints: + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + delete_stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if delete_stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, + delete_stmt.table, + dialect_hints[delete_stmt.table], + True + ) + + else: + dialect_hints = None + + text += table_text + + if delete_stmt._returning: + self.returning = delete_stmt._returning + if self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning) + + if delete_stmt._whereclause is not None: + text += " WHERE " + text += delete_stmt._whereclause._compiler_dispatch(self) + + if self.returning and not self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning) + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % \ + self.preparer.format_savepoint(savepoint_stmt) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % \ + self.preparer.format_savepoint(savepoint_stmt) + + +class DDLCompiler(Compiled): + + @util.memoized_property + def sql_compiler(self): + return self.dialect.statement_compiler(self.dialect, None) + + @util.memoized_property + def type_compiler(self): + return self.dialect.type_compiler + + @property + def preparer(self): + return self.dialect.identifier_preparer + + def construct_params(self, params=None): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.target, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.target) + if len(path) == 1: + table, sch = path[0], '' + else: + table, sch = path[-1], path[0] + + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.target)) + + return self.sql_compiler.post_process_text(ddl.statement % context) + + def visit_create_schema(self, create): + schema = self.preparer.format_schema(create.element) + return "CREATE SCHEMA " + schema + + def visit_drop_schema(self, drop): + schema = self.preparer.format_schema(drop.element) + text = "DROP SCHEMA " + schema + if drop.cascade: + text += " CASCADE" + return text + + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ + ['TABLE', + preparer.format_table(table), + "("]) + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for create_column in create.columns: + column = create_column.element + try: + processed = self.process(create_column, + first_pk=column.primary_key + and not first_pk) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed + if column.primary_key: + first_pk = True + except exc.CompileError as ce: + util.raise_from_cause( + exc.CompileError(util.u("(in table '%s', column '%s'): %s") % ( + table.description, + column.name, + ce.args[0] + ))) + + const = self.create_table_constraints(table) + if const: + text += ", \n\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_create_column(self, create, first_pk=False): + column = create.element + + if column.system: + return None + + text = self.get_column_specification( + column, + first_pk=first_pk + ) + const = " ".join(self.process(constraint) \ + for constraint in column.constraints) + if const: + text += " " + const + + return text + + def create_table_constraints(self, table): + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + constraints = [] + if table.primary_key: + constraints.append(table.primary_key) + + constraints.extend([c for c in table._sorted_constraints + if c is not table.primary_key]) + + return ", \n\t".join(p for p in + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) + + def visit_drop_table(self, drop): + return "\nDROP TABLE " + self.preparer.format_table(drop.element) + + def visit_drop_view(self, drop): + return "\nDROP VIEW " + self.preparer.format_table(drop.element) + + + def _verify_index_table(self, index): + if index.table is None: + raise exc.CompileError("Index '%s' is not associated " + "with any table." % index.name) + + + def visit_create_index(self, create, include_schema=False, + include_table_schema=True): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % ( + self._prepared_index_name(index, + include_schema=include_schema), + preparer.format_table(index.table, + use_schema=include_table_schema), + ', '.join( + self.sql_compiler.process(expr, + include_table=False, literal_binds=True) for + expr in index.expressions) + ) + return text + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + self._prepared_index_name(index, + include_schema=True) + + def _prepared_index_name(self, index, include_schema=False): + if include_schema and index.table is not None and index.table.schema: + schema = index.table.schema + schema_name = self.preparer.quote_schema(schema) + else: + schema_name = None + + ident = index.name + if isinstance(ident, elements._truncated_label): + max_ = self.dialect.max_index_name_length or \ + self.dialect.max_identifier_length + if len(ident) > max_: + ident = ident[0:max_ - 8] + \ + "_" + util.md5_hex(ident)[-4:] + else: + self.dialect.validate_identifier(ident) + + index_name = self.preparer.quote(ident) + + if schema_name: + index_name = schema_name + "." + index_name + return index_name + + def visit_add_constraint(self, create): + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element) + ) + + def visit_create_sequence(self, create): + text = "CREATE SEQUENCE %s" % \ + self.preparer.format_sequence(create.element) + if create.element.increment is not None: + text += " INCREMENT BY %d" % create.element.increment + if create.element.start is not None: + text += " START WITH %d" % create.element.start + return text + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % \ + self.preparer.format_sequence(drop.element) + + def visit_drop_constraint(self, drop): + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element), + drop.cascade and " CASCADE" or "" + ) + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def post_create_table(self, table): + return '' + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + if isinstance(column.server_default.arg, util.string_types): + return "'%s'" % column.server_default.arg + else: + return self.sql_compiler.process(column.server_default.arg) + else: + return None + + def visit_check_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, + include_table=False, + literal_binds=True) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_column_check_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return '' + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint(self, constraint): + preparer = self.dialect.identifier_preparer + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) + remote_table = list(constraint._elements.values())[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ', '.join(preparer.quote(f.parent.name) + for f in constraint._elements.values()), + self.define_constraint_remote_table( + constraint, remote_table, preparer), + ', '.join(preparer.quote(f.column.name) + for f in constraint._elements.values()) + ) + text += self.define_constraint_match(constraint) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table) + + def visit_unique_constraint(self, constraint): + if len(constraint) == 0: + return '' + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "UNIQUE (%s)" % ( + ', '.join(self.preparer.quote(c.name) + for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + if constraint.onupdate is not None: + text += " ON UPDATE %s" % constraint.onupdate + return text + + def define_constraint_deferrability(self, constraint): + text = "" + if constraint.deferrable is not None: + if constraint.deferrable: + text += " DEFERRABLE" + else: + text += " NOT DEFERRABLE" + if constraint.initially is not None: + text += " INITIALLY %s" % constraint.initially + return text + + def define_constraint_match(self, constraint): + text = "" + if constraint.match is not None: + text += " MATCH %s" % constraint.match + return text + + +class GenericTypeCompiler(TypeCompiler): + + def visit_FLOAT(self, type_): + return "FLOAT" + + def visit_REAL(self, type_): + return "REAL" + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + elif type_.scale is None: + return "NUMERIC(%(precision)s)" % \ + {'precision': type_.precision} + else: + return "NUMERIC(%(precision)s, %(scale)s)" % \ + {'precision': type_.precision, + 'scale': type_.scale} + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return "DECIMAL" + elif type_.scale is None: + return "DECIMAL(%(precision)s)" % \ + {'precision': type_.precision} + else: + return "DECIMAL(%(precision)s, %(scale)s)" % \ + {'precision': type_.precision, + 'scale': type_.scale} + + def visit_INTEGER(self, type_): + return "INTEGER" + + def visit_SMALLINT(self, type_): + return "SMALLINT" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_CLOB(self, type_): + return "CLOB" + + def visit_NCLOB(self, type_): + return "NCLOB" + + def _render_string_type(self, type_, name): + + text = name + if type_.length: + text += "(%d)" % type_.length + if type_.collation: + text += ' COLLATE "%s"' % type_.collation + return text + + def visit_CHAR(self, type_): + return self._render_string_type(type_, "CHAR") + + def visit_NCHAR(self, type_): + return self._render_string_type(type_, "NCHAR") + + def visit_VARCHAR(self, type_): + return self._render_string_type(type_, "VARCHAR") + + def visit_NVARCHAR(self, type_): + return self._render_string_type(type_, "NVARCHAR") + + def visit_TEXT(self, type_): + return self._render_string_type(type_, "TEXT") + + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BINARY(self, type_): + return "BINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_VARBINARY(self, type_): + return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_large_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + + def visit_time(self, type_): + return self.visit_TIME(type_) + + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + + def visit_date(self, type_): + return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return self.visit_BIGINT(type_) + + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + + def visit_real(self, type_): + return self.visit_REAL(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + + def visit_unicode(self, type_): + return self.visit_VARCHAR(type_) + + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_unicode_text(self, type_): + return self.visit_TEXT(type_) + + def visit_enum(self, type_): + return self.visit_VARCHAR(type_) + + def visit_null(self, type_): + raise exc.CompileError("Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_) + + def visit_type_decorator(self, type_): + return self.process(type_.type_engine(self.dialect)) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + + +class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + def __init__(self, dialect, initial_quote='"', + final_quote=None, escape_quote='"', omit_schema=False): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to + `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.escape_quote = escape_quote + self.escape_to_quote = self.escape_quote * 2 + self.omit_schema = omit_schema + self._strings = {} + + def _escape_identifier(self, value): + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + return value.replace(self.escape_quote, self.escape_to_quote) + + def _unescape_identifier(self, value): + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace(self.escape_to_quote, self.escape_quote) + + def quote_identifier(self, value): + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return self.initial_quote + \ + self._escape_identifier(value) + \ + self.final_quote + + def _requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value)) + + def quote_schema(self, schema, force=None): + """Conditionally quote a schema. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + the 'force' flag should be considered deprecated. + + """ + return self.quote(schema, force) + + def quote(self, ident, force=None): + """Conditionally quote an identifier. + + the 'force' flag should be considered deprecated. + """ + + force = getattr(ident, "quote", None) + + if force is None: + if ident in self._strings: + return self._strings[ident] + else: + if self._requires_quotes(ident): + self._strings[ident] = self.quote_identifier(ident) + else: + self._strings[ident] = ident + return self._strings[ident] + elif force: + return self.quote_identifier(ident) + else: + return ident + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence.name) + if not self.omit_schema and use_schema and sequence.schema is not None: + name = self.quote_schema(sequence.schema) + "." + name + return name + + def format_label(self, label, name=None): + return self.quote(name or label.name) + + def format_alias(self, alias, name=None): + return self.quote(name or alias.name) + + def format_savepoint(self, savepoint, name=None): + return self.quote(name or savepoint.ident) + + def format_constraint(self, constraint): + return self.quote(constraint.name) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + result = self.quote(name) + if not self.omit_schema and use_schema \ + and getattr(table, "schema", None): + result = self.quote_schema(table.schema) + "." + result + return result + + def format_schema(self, name, quote=None): + """Prepare a quoted schema name.""" + + return self.quote(name, quote) + + def format_column(self, column, use_table=False, + name=None, table_name=None): + """Prepare a quoted column name.""" + + if name is None: + name = column.name + if not getattr(column, 'is_literal', False): + if use_table: + return self.format_table( + column.table, use_schema=False, + name=table_name) + "." + self.quote(name) + else: + return self.quote(name) + else: + # literal textual elements get stuck into ColumnClause a lot, + # which shouldn't get quoted + + if use_table: + return self.format_table(column.table, + use_schema=False, name=table_name) + '.' + name + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + if not self.omit_schema and use_schema and \ + getattr(table, 'schema', None): + return (self.quote_schema(table.schema), + self.format_table(table, use_schema=False)) + else: + return (self.format_table(table, use_schema=False), ) + + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + {'initial': initial, + 'final': final, + 'escaped': escaped_final}) + return r + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + r = self._r_identifiers + return [self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)]] diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py new file mode 100644 index 00000000..bda87650 --- /dev/null +++ b/lib/sqlalchemy/sql/ddl.py @@ -0,0 +1,864 @@ +# sql/ddl.py +# Copyright (C) 2009-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +Provides the hierarchy of DDL-defining schema items as well as routines +to invoke them for a create/drop call. + +""" + +from .. import util +from .elements import ClauseElement +from .visitors import traverse +from .base import Executable, _generative, SchemaVisitor, _bind_or_error +from ..util import topological +from .. import event +from .. import exc + +class _DDLCompiles(ClauseElement): + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + + +class DDLElement(Executable, _DDLCompiles): + """Base class for DDL expression constructs. + + This class is the base for the general purpose :class:`.DDL` class, + as well as the various create/drop clause constructs such as + :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`, + etc. + + :class:`.DDLElement` integrates closely with SQLAlchemy events, + introduced in :ref:`event_toplevel`. An instance of one is + itself an event receiving callable:: + + event.listen( + users, + 'after_create', + AddConstraint(constraint).execute_if(dialect='postgresql') + ) + + .. seealso:: + + :class:`.DDL` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + :ref:`schema_ddl_sequences` + + """ + + _execution_options = Executable.\ + _execution_options.union({'autocommit': True}) + + target = None + on = None + dialect = None + callable_ = None + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_ddl(self, multiparams, params) + + def execute(self, bind=None, target=None): + """Execute this DDL immediately. + + Executes the DDL statement in isolation using the supplied + :class:`.Connectable` or + :class:`.Connectable` assigned to the ``.bind`` + property, if not supplied. If the DDL has a conditional ``on`` + criteria, it will be invoked with None as the event. + + :param bind: + Optional, an ``Engine`` or ``Connection``. If not supplied, a valid + :class:`.Connectable` must be present in the + ``.bind`` property. + + :param target: + Optional, defaults to None. The target SchemaItem for the + execute call. Will be passed to the ``on`` callable if any, + and may also provide string expansion data for the + statement. See ``execute_at`` for more information. + + """ + + if bind is None: + bind = _bind_or_error(self) + + if self._should_execute(target, bind): + return bind.execute(self.against(target)) + else: + bind.engine.logger.info( + "DDL execution skipped, criteria not met.") + + @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as " + ":meth:`.DDLElement.execute_if`.") + def execute_at(self, event_name, target): + """Link execution of this DDL to the DDL lifecycle of a SchemaItem. + + Links this ``DDLElement`` to a ``Table`` or ``MetaData`` instance, + executing it when that schema item is created or dropped. The DDL + statement will be executed using the same Connection and transactional + context as the Table create/drop itself. The ``.bind`` property of + this statement is ignored. + + :param event: + One of the events defined in the schema item's ``.ddl_events``; + e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop' + + :param target: + The Table or MetaData instance for which this DDLElement will + be associated with. + + A DDLElement instance can be linked to any number of schema items. + + ``execute_at`` builds on the ``append_ddl_listener`` interface of + :class:`.MetaData` and :class:`.Table` objects. + + Caveat: Creating or dropping a Table in isolation will also trigger + any DDL set to ``execute_at`` that Table's MetaData. This may change + in a future release. + + """ + + def call_event(target, connection, **kw): + if self._should_execute_deprecated(event_name, + target, connection, **kw): + return connection.execute(self.against(target)) + + event.listen(target, "" + event_name.replace('-', '_'), call_event) + + @_generative + def against(self, target): + """Return a copy of this DDL against a specific schema item.""" + + self.target = target + + @_generative + def execute_if(self, dialect=None, callable_=None, state=None): + """Return a callable that will execute this + DDLElement conditionally. + + Used to provide a wrapper for event listening:: + + event.listen( + metadata, + 'before_create', + DDL("my_ddl").execute_if(dialect='postgresql') + ) + + :param dialect: May be a string, tuple or a callable + predicate. If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something').execute_if(dialect='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something').execute_if(dialect=('postgresql', 'mysql')) + + :param callable_: A callable, which will be invoked with + four positional arguments as well as optional keyword + arguments: + + :ddl: + This DDL element. + + :target: + The :class:`.Table` or :class:`.MetaData` object which is the + target of this event. May be None if the DDL is executed + explicitly. + + :bind: + The :class:`.Connection` being used for DDL execution + + :tables: + Optional keyword argument - a list of Table objects which are to + be created/ dropped within a MetaData.create_all() or drop_all() + method call. + + :state: + Optional keyword argument - will be the ``state`` argument + passed to this function. + + :checkfirst: + Keyword argument, will be True if the 'checkfirst' flag was + set during the call to ``create()``, ``create_all()``, + ``drop()``, ``drop_all()``. + + If the callable returns a true value, the DDL statement will be + executed. + + :param state: any value which will be passed to the callable\_ + as the ``state`` keyword argument. + + .. seealso:: + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + self.dialect = dialect + self.callable_ = callable_ + self.state = state + + def _should_execute(self, target, bind, **kw): + if self.on is not None and \ + not self._should_execute_deprecated(None, target, bind, **kw): + return False + + if isinstance(self.dialect, util.string_types): + if self.dialect != bind.engine.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if bind.engine.name not in self.dialect: + return False + if self.callable_ is not None and \ + not self.callable_(self, target, bind, state=self.state, **kw): + return False + + return True + + def _should_execute_deprecated(self, event, target, bind, **kw): + if self.on is None: + return True + elif isinstance(self.on, util.string_types): + return self.on == bind.engine.name + elif isinstance(self.on, (tuple, list, set)): + return bind.engine.name in self.on + else: + return self.on(self, event, target, bind, **kw) + + def __call__(self, target, bind, **kw): + """Execute the DDL as a ddl_listener.""" + + if self._should_execute(target, bind, **kw): + return bind.execute(self.against(target)) + + def _check_ddl_on(self, on): + if (on is not None and + (not isinstance(on, util.string_types + (tuple, list, set)) and + not util.callable(on))): + raise exc.ArgumentError( + "Expected the name of a database dialect, a tuple " + "of names, or a callable for " + "'on' criteria, got type '%s'." % type(on).__name__) + + def bind(self): + if self._bind: + return self._bind + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +class DDL(DDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects + function as DDL event listeners, and can be subscribed to those events + listed in :class:`.DDLEvents`, using either :class:`.Table` or + :class:`.MetaData` objects as targets. Basic templating support allows + a single DDL instance to handle repetitive tasks for multiple tables. + + Examples:: + + from sqlalchemy import event, DDL + + tbl = Table('users', metadata, Column('uid', Integer)) + event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger')) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE') + event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb')) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + + When operating on Table events, the following ``statement`` + string substitions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's "context", if any, will be combined with the standard + substutions noted above. Keys present in the context will override + the standard substitutions. + + """ + + __visit_name__ = "ddl" + + def __init__(self, statement, on=None, context=None, bind=None): + """Create a DDL statement. + + :param statement: + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator. See the + ``context`` argument and the ``execute_at`` method. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + :param on: + .. deprecated:: 0.7 + See :meth:`.DDLElement.execute_if`. + + Optional filtering criteria. May be a string, tuple or a callable + predicate. If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something', on='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something', on=('postgresql', 'mysql')) + + If a callable, it will be invoked with four positional arguments + as well as optional keyword arguments: + + :ddl: + This DDL element. + + :event: + The name of the event that has triggered this DDL, such as + 'after-create' Will be None if the DDL is executed explicitly. + + :target: + The ``Table`` or ``MetaData`` object which is the target of + this event. May be None if the DDL is executed explicitly. + + :connection: + The ``Connection`` being used for DDL execution + + :tables: + Optional keyword argument - a list of Table objects which are to + be created/ dropped within a MetaData.create_all() or drop_all() + method call. + + + If the callable returns a true value, the DDL statement will be + executed. + + :param context: + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + :param bind: + Optional. A :class:`.Connectable`, used by + default when ``execute()`` is invoked without a bind argument. + + + .. seealso:: + + :class:`.DDLEvents` + + :mod:`sqlalchemy.event` + + """ + + if not isinstance(statement, util.string_types): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" % + statement) + + self.statement = statement + self.context = context or {} + + self._check_ddl_on(on) + self.on = on + self._bind = bind + + def __repr__(self): + return '<%s@%s; %s>' % ( + type(self).__name__, id(self), + ', '.join([repr(self.statement)] + + ['%s=%r' % (key, getattr(self, key)) + for key in ('on', 'context') + if getattr(self, key)])) + + + +class _CreateDropBase(DDLElement): + """Base class for DDL constucts that represent CREATE and DROP or + equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + def __init__(self, element, on=None, bind=None): + self.element = element + self._check_ddl_on(on) + self.on = on + self.bind = bind + + def _create_rule_disable(self, compiler): + """Allow disable of _create_rule using a callable. + + Pass to _create_rule using + util.portable_instancemethod(self._create_rule_disable) + to retain serializability. + + """ + return False + + +class CreateSchema(_CreateDropBase): + """Represent a CREATE SCHEMA statement. + + .. versionadded:: 0.7.4 + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "create_schema" + + def __init__(self, name, quote=None, **kw): + """Create a new :class:`.CreateSchema` construct.""" + + self.quote = quote + super(CreateSchema, self).__init__(name, **kw) + + +class DropSchema(_CreateDropBase): + """Represent a DROP SCHEMA statement. + + The argument here is the string name of the schema. + + .. versionadded:: 0.7.4 + + """ + + __visit_name__ = "drop_schema" + + def __init__(self, name, quote=None, cascade=False, **kw): + """Create a new :class:`.DropSchema` construct.""" + + self.quote = quote + self.cascade = cascade + super(DropSchema, self).__init__(name, **kw) + + +class CreateTable(_CreateDropBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + + def __init__(self, element, on=None, bind=None): + """Create a :class:`.CreateTable` construct. + + :param element: a :class:`.Table` that's the subject + of the CREATE + :param on: See the description for 'on' in :class:`.DDL`. + :param bind: See the description for 'bind' in :class:`.DDL`. + + """ + super(CreateTable, self).__init__(element, on=on, bind=bind) + self.columns = [CreateColumn(column) + for column in element.columns + ] + + +class _DropView(_CreateDropBase): + """Semi-public 'DROP VIEW' construct. + + Used by the test suite for dialect-agnostic drops of views. + This object will eventually be part of a public "view" API. + + """ + __visit_name__ = "drop_view" + + +class CreateColumn(_DDLCompiles): + """Represent a :class:`.Column` as rendered in a CREATE TABLE statement, + via the :class:`.CreateTable` construct. + + This is provided to support custom column DDL within the generation + of CREATE TABLE statements, by using the + compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel` + to extend :class:`.CreateColumn`. + + Typical integration is to examine the incoming :class:`.Column` + object, and to redirect compilation if a particular flag or condition + is found:: + + from sqlalchemy import schema + from sqlalchemy.ext.compiler import compiles + + @compiles(schema.CreateColumn) + def compile(element, compiler, **kw): + column = element.element + + if "special" not in column.info: + return compiler.visit_create_column(element, **kw) + + text = "%s SPECIAL DIRECTIVE %s" % ( + column.name, + compiler.type_compiler.process(column.type) + ) + default = compiler.get_column_default_string(column) + if default is not None: + text += " DEFAULT " + default + + if not column.nullable: + text += " NOT NULL" + + if column.constraints: + text += " ".join( + compiler.process(const) + for const in column.constraints) + return text + + The above construct can be applied to a :class:`.Table` as follows:: + + from sqlalchemy import Table, Metadata, Column, Integer, String + from sqlalchemy import schema + + metadata = MetaData() + + table = Table('mytable', MetaData(), + Column('x', Integer, info={"special":True}, primary_key=True), + Column('y', String(50)), + Column('z', String(20), info={"special":True}) + ) + + metadata.create_all(conn) + + Above, the directives we've added to the :attr:`.Column.info` collection + will be detected by our custom compilation scheme:: + + CREATE TABLE mytable ( + x SPECIAL DIRECTIVE INTEGER NOT NULL, + y VARCHAR(50), + z SPECIAL DIRECTIVE VARCHAR(20), + PRIMARY KEY (x) + ) + + The :class:`.CreateColumn` construct can also be used to skip certain + columns when producing a ``CREATE TABLE``. This is accomplished by + creating a compilation rule that conditionally returns ``None``. + This is essentially how to produce the same effect as using the + ``system=True`` argument on :class:`.Column`, which marks a column + as an implicitly-present "system" column. + + For example, suppose we wish to produce a :class:`.Table` which skips + rendering of the Postgresql ``xmin`` column against the Postgresql backend, + but on other backends does render it, in anticipation of a triggered rule. + A conditional compilation rule could skip this name only on Postgresql:: + + from sqlalchemy.schema import CreateColumn + + @compiles(CreateColumn, "postgresql") + def skip_xmin(element, compiler, **kw): + if element.element.name == 'xmin': + return None + else: + return compiler.visit_create_column(element, **kw) + + + my_table = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('xmin', Integer) + ) + + Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE`` + which only includes the ``id`` column in the string; the ``xmin`` column + will be omitted, but only against the Postgresql backend. + + .. versionadded:: 0.8.3 The :class:`.CreateColumn` construct supports + skipping of columns by returning ``None`` from a custom compilation rule. + + .. versionadded:: 0.8 The :class:`.CreateColumn` construct was added + to support custom column creation styles. + + """ + __visit_name__ = 'create_column' + + def __init__(self, element): + self.element = element + + +class DropTable(_CreateDropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + + +class CreateSequence(_CreateDropBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + + +class DropSequence(_CreateDropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + + +class CreateIndex(_CreateDropBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + + +class DropIndex(_CreateDropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + + +class AddConstraint(_CreateDropBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + + def __init__(self, element, *args, **kw): + super(AddConstraint, self).__init__(element, *args, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable) + + +class DropConstraint(_CreateDropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, **kw): + self.cascade = cascade + super(DropConstraint, self).__init__(element, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable) + + +class DDLBase(SchemaVisitor): + def __init__(self, connection): + self.connection = connection + + +class SchemaGenerator(DDLBase): + + def __init__(self, dialect, connection, checkfirst=False, + tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def _can_create_table(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or \ + not self.dialect.has_table(self.connection, + table.name, schema=table.schema) + + def _can_create_sequence(self, sequence): + return self.dialect.supports_sequences and \ + ( + (not self.dialect.sequences_optional or + not sequence.optional) and + ( + not self.checkfirst or + not self.dialect.has_sequence( + self.connection, + sequence.name, + schema=sequence.schema) + ) + ) + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + collection = [t for t in sort_tables(tables) + if self._can_create_table(t)] + seq_coll = [s for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s)] + + metadata.dispatch.before_create(metadata, self.connection, + tables=collection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + for seq in seq_coll: + self.traverse_single(seq, create_ok=True) + + for table in collection: + self.traverse_single(table, create_ok=True) + + metadata.dispatch.after_create(metadata, self.connection, + tables=collection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + def visit_table(self, table, create_ok=False): + if not create_ok and not self._can_create_table(table): + return + + table.dispatch.before_create(table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(CreateTable(table)) + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + table.dispatch.after_create(table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + def visit_sequence(self, sequence, create_ok=False): + if not create_ok and not self._can_create_sequence(sequence): + return + self.connection.execute(CreateSequence(sequence)) + + def visit_index(self, index): + self.connection.execute(CreateIndex(index)) + + +class SchemaDropper(DDLBase): + + def __init__(self, dialect, connection, checkfirst=False, + tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + collection = [ + t + for t in reversed(sort_tables(tables)) + if self._can_drop_table(t) + ] + + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_drop_sequence(s) + ] + + metadata.dispatch.before_drop( + metadata, self.connection, tables=collection, + checkfirst=self.checkfirst, _ddl_runner=self) + + for table in collection: + self.traverse_single(table, drop_ok=True) + + for seq in seq_coll: + self.traverse_single(seq, drop_ok=True) + + metadata.dispatch.after_drop( + metadata, self.connection, tables=collection, + checkfirst=self.checkfirst, _ddl_runner=self) + + def _can_drop_table(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, + table.name, schema=table.schema) + + def _can_drop_sequence(self, sequence): + return self.dialect.supports_sequences and \ + ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + self.dialect.has_sequence( + self.connection, + sequence.name, + schema=sequence.schema)) + ) + + def visit_index(self, index): + self.connection.execute(DropIndex(index)) + + def visit_table(self, table, drop_ok=False): + if not drop_ok and not self._can_drop_table(table): + return + + table.dispatch.before_drop(table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(DropTable(table)) + + table.dispatch.after_drop(table, self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self) + + def visit_sequence(self, sequence, drop_ok=False): + if not drop_ok and not self._can_drop_sequence(sequence): + return + self.connection.execute(DropSequence(sequence)) + +def sort_tables(tables, skip_fn=None, extra_dependencies=None): + """sort a collection of Table objects in order of + their foreign-key dependency.""" + + tables = list(tables) + tuples = [] + if extra_dependencies is not None: + tuples.extend(extra_dependencies) + + def visit_foreign_key(fkey): + if fkey.use_alter: + return + elif skip_fn and skip_fn(fkey): + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + if parent_table is not child_table: + tuples.append((parent_table, child_table)) + + for table in tables: + traverse(table, + {'schema_visitor': True}, + {'foreign_key': visit_foreign_key}) + + tuples.extend( + [parent, table] for parent in table._extra_dependencies + ) + + return list(topological.sort(tuples, tables)) + diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py new file mode 100644 index 00000000..6d595450 --- /dev/null +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -0,0 +1,282 @@ +# sql/default_comparator.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Default implementation of SQL comparison operations. +""" + +from .. import exc, util +from . import operators +from . import type_api +from .elements import BindParameter, True_, False_, BinaryExpression, \ + Null, _const_expr, _clause_element_as_expr, \ + ClauseList, ColumnElement, TextClause, UnaryExpression, \ + collate, _is_literal, _literal_as_text, ClauseElement +from .selectable import SelectBase, Alias, Selectable, ScalarSelect + +class _DefaultColumnComparator(operators.ColumnOperators): + """Defines comparison and math operations. + + See :class:`.ColumnOperators` and :class:`.Operators` for descriptions + of all operations. + + """ + + @util.memoized_property + def type(self): + return self.expr.type + + def operate(self, op, *other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, *(other + o[1:]), **kwargs) + + def reverse_operate(self, op, other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs) + + def _adapt_expression(self, op, other_comparator): + """evaluate the return type of , + and apply any adaptations to the given operator. + + This method determines the type of a resulting binary expression + given two source types and an operator. For example, two + :class:`.Column` objects, both of the type :class:`.Integer`, will + produce a :class:`.BinaryExpression` that also has the type + :class:`.Integer` when compared via the addition (``+``) operator. + However, using the addition operator with an :class:`.Integer` + and a :class:`.Date` object will produce a :class:`.Date`, assuming + "days delta" behavior by the database (in reality, most databases + other than Postgresql don't accept this particular operation). + + The method returns a tuple of the form , . + The resulting operator and type will be those applied to the + resulting :class:`.BinaryExpression` as the final operator and the + right-hand side of the expression. + + Note that only a subset of operators make usage of + :meth:`._adapt_expression`, + including math operators and user-defined operators, but not + boolean comparison or special SQL keywords like MATCH or BETWEEN. + + """ + return op, other_comparator.type + + def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, + _python_is_types=(util.NoneType, bool), + **kwargs): + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + + # allow x ==/!= True/False to be treated as a literal. + # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and \ + isinstance(obj, (bool, True_, False_)): + return BinaryExpression(expr, + _literal_as_text(obj), + op, + type_=type_api.BOOLEANTYPE, + negate=negate, modifiers=kwargs) + else: + # all other None/True/False uses IS, IS NOT + if op in (operators.eq, operators.is_): + return BinaryExpression(expr, _const_expr(obj), + operators.is_, + negate=operators.isnot) + elif op in (operators.ne, operators.isnot): + return BinaryExpression(expr, _const_expr(obj), + operators.isnot, + negate=operators.is_) + else: + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'isnot()' operators can " + "be used with None/True/False") + else: + obj = self._check_literal(expr, op, obj) + + if reverse: + return BinaryExpression(obj, + expr, + op, + type_=type_api.BOOLEANTYPE, + negate=negate, modifiers=kwargs) + else: + return BinaryExpression(expr, + obj, + op, + type_=type_api.BOOLEANTYPE, + negate=negate, modifiers=kwargs) + + def _binary_operate(self, expr, op, obj, reverse=False, result_type=None, + **kw): + obj = self._check_literal(expr, op, obj) + + if reverse: + left, right = obj, expr + else: + left, right = expr, obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator) + + return BinaryExpression(left, right, op, type_=result_type) + + def _scalar(self, expr, op, fn, **kw): + return fn(expr) + + def _in_impl(self, expr, op, seq_or_selectable, negate_op, **kw): + seq_or_selectable = _clause_element_as_expr(seq_or_selectable) + + if isinstance(seq_or_selectable, ScalarSelect): + return self._boolean_compare(expr, op, seq_or_selectable, + negate=negate_op) + elif isinstance(seq_or_selectable, SelectBase): + + # TODO: if we ever want to support (x, y, z) IN (select x, + # y, z from table), we would need a multi-column version of + # as_scalar() to produce a multi- column selectable that + # does not export itself as a FROM clause + + return self._boolean_compare( + expr, op, seq_or_selectable.as_scalar(), + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, (Selectable, TextClause)): + return self._boolean_compare(expr, op, seq_or_selectable, + negate=negate_op, **kw) + elif isinstance(seq_or_selectable, ClauseElement): + raise exc.InvalidRequestError('in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % seq_or_selectable) + + # Handle non selectable arguments as sequences + args = [] + for o in seq_or_selectable: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + raise exc.InvalidRequestError('in_() accepts' + ' either a list of expressions ' + 'or a selectable: %r' % o) + elif o is None: + o = Null() + else: + o = expr._bind_param(op, o) + args.append(o) + if len(args) == 0: + + # Special case handling for empty IN's, behave like + # comparison against zero row selectable. We use != to + # build the contradiction as it handles NULL values + # appropriately, i.e. "not (x IN ())" should not return NULL + # values for x. + + util.warn('The IN-predicate on "%s" was invoked with an ' + 'empty sequence. This results in a ' + 'contradiction, which nonetheless can be ' + 'expensive to evaluate. Consider alternative ' + 'strategies for improved performance.' % expr) + if op is operators.in_op: + return expr != expr + else: + return expr == expr + + return self._boolean_compare(expr, op, + ClauseList(*args).self_group(against=op), + negate=negate_op) + + def _unsupported_impl(self, expr, op, *arg, **kw): + raise NotImplementedError("Operator '%s' is not supported on " + "this expression" % op.__name__) + + def _neg_impl(self, expr, op, **kw): + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg) + + def _match_impl(self, expr, op, other, **kw): + """See :meth:`.ColumnOperators.match`.""" + return self._boolean_compare(expr, operators.match_op, + self._check_literal(expr, operators.match_op, + other)) + + def _distinct_impl(self, expr, op, **kw): + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression(expr, operator=operators.distinct_op, + type_=expr.type) + + def _between_impl(self, expr, op, cleft, cright, **kw): + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ClauseList( + self._check_literal(expr, operators.and_, cleft), + self._check_literal(expr, operators.and_, cright), + operator=operators.and_, + group=False, group_contents=False), + operators.between_op) + + def _collate_impl(self, expr, op, other, **kw): + return collate(expr, other) + + # a mapping of operators with the method they use, along with + # their negated operator for comparison operators + operators = { + "add": (_binary_operate,), + "mul": (_binary_operate,), + "sub": (_binary_operate,), + "div": (_binary_operate,), + "mod": (_binary_operate,), + "truediv": (_binary_operate,), + "custom_op": (_binary_operate,), + "concat_op": (_binary_operate,), + "lt": (_boolean_compare, operators.ge), + "le": (_boolean_compare, operators.gt), + "ne": (_boolean_compare, operators.eq), + "gt": (_boolean_compare, operators.le), + "ge": (_boolean_compare, operators.lt), + "eq": (_boolean_compare, operators.ne), + "like_op": (_boolean_compare, operators.notlike_op), + "ilike_op": (_boolean_compare, operators.notilike_op), + "notlike_op": (_boolean_compare, operators.like_op), + "notilike_op": (_boolean_compare, operators.ilike_op), + "contains_op": (_boolean_compare, operators.notcontains_op), + "startswith_op": (_boolean_compare, operators.notstartswith_op), + "endswith_op": (_boolean_compare, operators.notendswith_op), + "desc_op": (_scalar, UnaryExpression._create_desc), + "asc_op": (_scalar, UnaryExpression._create_asc), + "nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst), + "nullslast_op": (_scalar, UnaryExpression._create_nullslast), + "in_op": (_in_impl, operators.notin_op), + "notin_op": (_in_impl, operators.in_op), + "is_": (_boolean_compare, operators.is_), + "isnot": (_boolean_compare, operators.isnot), + "collate": (_collate_impl,), + "match_op": (_match_impl,), + "distinct_op": (_distinct_impl,), + "between_op": (_between_impl, ), + "neg": (_neg_impl,), + "getitem": (_unsupported_impl,), + "lshift": (_unsupported_impl,), + "rshift": (_unsupported_impl,), + } + + def _check_literal(self, expr, operator, other): + if isinstance(other, (ColumnElement, TextClause)): + if isinstance(other, BindParameter) and \ + other.type._isnull: + other = other._clone() + other.type = expr.type + return other + elif hasattr(other, '__clause_element__'): + other = other.__clause_element__() + elif isinstance(other, type_api.TypeEngine.Comparator): + other = other.expr + + if isinstance(other, (SelectBase, Alias)): + return other.as_scalar() + elif not isinstance(other, (ColumnElement, TextClause)): + return expr._bind_param(operator, other) + else: + return other + diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py new file mode 100644 index 00000000..098f2d58 --- /dev/null +++ b/lib/sqlalchemy/sql/dml.py @@ -0,0 +1,770 @@ +# sql/dml.py +# Copyright (C) 2009-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. + +""" + +from .base import Executable, _generative, _from_objects, DialectKWArgs +from .elements import ClauseElement, _literal_as_text, Null, and_, _clone +from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes +from .. import util +from .. import exc + +class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. + + """ + + __visit_name__ = 'update_base' + + _execution_options = \ + Executable._execution_options.union({'autocommit': True}) + _hints = util.immutabledict() + _prefixes = () + + def _process_colparams(self, parameters): + def process_single(p): + if isinstance(p, (list, tuple)): + return dict( + (c.key, pval) + for c, pval in zip(self.table.c, p) + ) + else: + return p + + if isinstance(parameters, (list, tuple)) and \ + parameters and \ + isinstance(parameters[0], (list, tuple, dict)): + + if not self._supports_multi_parameters: + raise exc.InvalidRequestError( + "This construct does not support " + "multiple parameter sets.") + + return [process_single(p) for p in parameters], True + else: + return process_single(parameters), False + + def params(self, *arg, **kw): + """Set the parameters for the statement. + + This method raises ``NotImplementedError`` on the base class, + and is overridden by :class:`.ValuesBase` to provide the + SET/VALUES clause of UPDATE and INSERT. + + """ + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters).") + + def bind(self): + """Return a 'bind' linked to this :class:`.UpdateBase` + or a :class:`.Table` associated with it. + + """ + return self._bind or self.table.bind + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + @_generative + def returning(self, *cols): + """Add a :term:`RETURNING` or equivalent clause to this statement. + + e.g.:: + + stmt = table.update().\\ + where(table.c.data == 'value').\\ + values(status='X').\\ + returning(table.c.server_flag, table.c.updated_timestamp) + + for server_flag, updated_timestamp in connection.execute(stmt): + print(server_flag, updated_timestamp) + + The given collection of column expressions should be derived from + the table that is + the target of the INSERT, UPDATE, or DELETE. While :class:`.Column` + objects are typical, the elements can also be expressions:: + + stmt = table.insert().returning( + (table.c.first_name + " " + table.c.last_name).label('fullname') + ) + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned + are made available via the result set and can be iterated + using :meth:`.ResultProxy.fetchone` and similar. For DBAPIs which do not + natively support returning values (i.e. cx_oracle), + SQLAlchemy will approximate this behavior at the result level + so that a reasonable amount of behavioral neutrality is + provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + .. seealso:: + + :meth:`.ValuesBase.return_defaults` - an alternative method tailored + towards efficient fetching of server-side defaults and triggers + for single-row INSERTs or UPDATEs. + + + """ + self._returning = cols + + + @_generative + def with_hint(self, text, selectable=None, dialect_name="*"): + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use + :meth:`.UpdateBase.prefix_with`. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`.Table` that is the subject of this + statement, or optionally to that of the given + :class:`.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + .. versionadded:: 0.7.6 + + :param text: Text of the hint. + :param selectable: optional :class:`.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + + self._hints = self._hints.union( + {(selectable, dialect_name): text}) + + +class ValuesBase(UpdateBase): + """Supplies support for :meth:`.ValuesBase.values` to + INSERT and UPDATE constructs.""" + + __visit_name__ = 'values_base' + + _supports_multi_parameters = False + _has_multi_parameters = False + select = None + + def __init__(self, table, values, prefixes): + self.table = _interpret_as_from(table) + self.parameters, self._has_multi_parameters = \ + self._process_colparams(values) + if prefixes: + self._setup_prefixes(prefixes) + + @_generative + def values(self, *args, **kwargs): + """specify a fixed VALUES clause for an INSERT statement, or the SET + clause for an UPDATE. + + Note that the :class:`.Insert` and :class:`.Update` constructs support + per-execution time formatting of the VALUES and/or SET clauses, + based on the arguments passed to :meth:`.Connection.execute`. However, + the :meth:`.ValuesBase.values` method can be used to "fix" a particular + set of parameters into the statement. + + Multiple calls to :meth:`.ValuesBase.values` will produce a new + construct, each one with the parameter list modified to include + the new parameters sent. In the typical case of a single + dictionary of parameters, the newly passed keys will replace + the same keys in the previous construct. In the case of a list-based + "multiple values" construct, each new list of values is extended + onto the existing list of values. + + :param \**kwargs: key value pairs representing the string key + of a :class:`.Column` mapped to the value to be rendered into the + VALUES or SET clause:: + + users.insert().values(name="some name") + + users.update().where(users.c.id==5).values(name="some name") + + :param \*args: Alternatively, a dictionary, tuple or list + of dictionaries or tuples can be passed as a single positional + argument in order to form the VALUES or + SET clause of the statement. The single dictionary form + works the same as the kwargs form:: + + users.insert().values({"name": "some name"}) + + If a tuple is passed, the tuple should contain the same number + of columns as the target :class:`.Table`:: + + users.insert().values((5, "some name")) + + The :class:`.Insert` construct also supports multiply-rendered VALUES + construct, for those backends which support this SQL syntax + (SQLite, Postgresql, MySQL). This mode is indicated by passing a list + of one or more dictionaries/tuples:: + + users.insert().values([ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ]) + + In the case of an :class:`.Update` + construct, only the single dictionary/tuple form is accepted, + else an exception is raised. It is also an exception case to + attempt to mix the single-/multiple- value styles together, + either through multiple :meth:`.ValuesBase.values` calls + or by sending a list + kwargs at the same time. + + .. note:: + + Passing a multiple values list is *not* the same + as passing a multiple values list to the :meth:`.Connection.execute` + method. Passing a list of parameter sets to :meth:`.ValuesBase.values` + produces a construct of this form:: + + INSERT INTO table (col1, col2, col3) VALUES + (col1_0, col2_0, col3_0), + (col1_1, col2_1, col3_1), + ... + + whereas a multiple list passed to :meth:`.Connection.execute` + has the effect of using the DBAPI + `executemany() `_ + method, which provides a high-performance system of invoking + a single-row INSERT statement many times against a series + of parameter sets. The "executemany" style is supported by + all database backends, as it does not depend on a special SQL + syntax. + + .. versionadded:: 0.8 + Support for multiple-VALUES INSERT statements. + + + .. seealso:: + + :ref:`inserts_and_updates` - SQL Expression + Language Tutorial + + :func:`~.expression.insert` - produce an ``INSERT`` statement + + :func:`~.expression.update` - produce an ``UPDATE`` statement + + """ + if self.select is not None: + raise exc.InvalidRequestError( + "This construct already inserts from a SELECT") + if self._has_multi_parameters and kwargs: + raise exc.InvalidRequestError( + "This construct already has multiple parameter sets.") + + if args: + if len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally.") + v = args[0] + else: + v = {} + + if self.parameters is None: + self.parameters, self._has_multi_parameters = \ + self._process_colparams(v) + else: + if self._has_multi_parameters: + self.parameters = list(self.parameters) + p, self._has_multi_parameters = self._process_colparams(v) + if not self._has_multi_parameters: + raise exc.ArgumentError( + "Can't mix single-values and multiple values " + "formats in one statement") + + self.parameters.extend(p) + else: + self.parameters = self.parameters.copy() + p, self._has_multi_parameters = self._process_colparams(v) + if self._has_multi_parameters: + raise exc.ArgumentError( + "Can't mix single-values and multiple values " + "formats in one statement") + self.parameters.update(p) + + if kwargs: + if self._has_multi_parameters: + raise exc.ArgumentError( + "Can't pass kwargs and multiple parameter sets " + "simultaenously") + else: + self.parameters.update(kwargs) + + @_generative + def return_defaults(self, *cols): + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults. + + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against a backend that supports RETURNING, all column + values generated by SQL expression or server-side-default will be added + to any existing RETURNING clause, provided that + :meth:`.UpdateBase.returning` is not used simultaneously. The column values + will then be available on the result using the + :attr:`.ResultProxy.returned_defaults` accessor as a + dictionary, referring to values keyed to the :class:`.Column` object + as well as its ``.key``. + + This method differs from :meth:`.UpdateBase.returning` in these ways: + + 1. :meth:`.ValuesBase.return_defaults` is only intended for use with + an INSERT or an UPDATE statement that matches exactly one row. + While the RETURNING construct in the general sense supports multiple + rows for a multi-row UPDATE or DELETE statement, or for special + cases of INSERT that return multiple rows (e.g. INSERT from SELECT, + multi-valued VALUES clause), :meth:`.ValuesBase.return_defaults` + is intended only + for an "ORM-style" single-row INSERT/UPDATE statement. The row + returned by the statement is also consumed implcitly when + :meth:`.ValuesBase.return_defaults` is used. By contrast, + :meth:`.UpdateBase.returning` leaves the RETURNING result-set intact + with a collection of any number of rows. + + 2. It is compatible with the existing logic to fetch auto-generated + primary key values, also known as "implicit returning". Backends that + support RETURNING will automatically make use of RETURNING in order + to fetch the value of newly generated primary keys; while the + :meth:`.UpdateBase.returning` method circumvents this behavior, + :meth:`.ValuesBase.return_defaults` leaves it intact. + + 3. It can be called against any backend. Backends that don't support + RETURNING will skip the usage of the feature, rather than raising + an exception. The return value of :attr:`.ResultProxy.returned_defaults` + will be ``None`` + + :meth:`.ValuesBase.return_defaults` is used by the ORM to provide + an efficient implementation for the ``eager_defaults`` feature of + :func:`.mapper`. + + :param cols: optional list of column key names or :class:`.Column` + objects. If omitted, all column expressions evaulated on the server + are added to the returning list. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.UpdateBase.returning` + + :attr:`.ResultProxy.returned_defaults` + + """ + self._return_defaults = cols or True + + +class Insert(ValuesBase): + """Represent an INSERT construct. + + The :class:`.Insert` object is created using the + :func:`~.expression.insert()` function. + + .. seealso:: + + :ref:`coretutorial_insert_expressions` + + """ + __visit_name__ = 'insert' + + _supports_multi_parameters = True + + def __init__(self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + **dialect_kw): + """Construct an :class:`.Insert` object. + + Similar functionality is available via the + :meth:`~.TableClause.insert` method on + :class:`~.schema.Table`. + + :param table: :class:`.TableClause` which is the subject of the insert. + + :param values: collection of values to be inserted; see + :meth:`.Insert.values` for a description of allowed formats here. + Can be omitted entirely; a :class:`.Insert` construct will also + dynamically render the VALUES clause at execution time based on + the parameters passed to :meth:`.Connection.execute`. + + :param inline: if True, SQL defaults will be compiled 'inline' into the + statement and not pre-executed. + + If both `values` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within `values` on a per-key basis. + + The keys within `values` can be either :class:`~sqlalchemy.schema.Column` + objects or their string identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + + .. seealso:: + + :ref:`coretutorial_insert_expressions` - SQL Expression Tutorial + + :ref:`inserts_and_updates` - SQL Expression Tutorial + + """ + ValuesBase.__init__(self, table, values, prefixes) + self._bind = bind + self.select = self.select_names = None + self.inline = inline + self._returning = returning + self._validate_dialect_kwargs(dialect_kw) + self._return_defaults = return_defaults + + def get_children(self, **kwargs): + if self.select is not None: + return self.select, + else: + return () + + @_generative + def from_select(self, names, select): + """Return a new :class:`.Insert` construct which represents + an ``INSERT...FROM SELECT`` statement. + + e.g.:: + + sel = select([table1.c.a, table1.c.b]).where(table1.c.c > 5) + ins = table2.insert().from_select(['a', 'b'], sel) + + :param names: a sequence of string column names or :class:`.Column` + objects representing the target columns. + :param select: a :func:`.select` construct, :class:`.FromClause` + or other construct which resolves into a :class:`.FromClause`, + such as an ORM :class:`.Query` object, etc. The order of + columns returned from this FROM clause should correspond to the + order of columns sent as the ``names`` parameter; while this + is not checked before passing along to the database, the database + would normally raise an exception if these column lists don't + correspond. + + .. note:: + + Depending on backend, it may be necessary for the :class:`.Insert` + statement to be constructed using the ``inline=True`` flag; this + flag will prevent the implicit usage of ``RETURNING`` when the + ``INSERT`` statement is rendered, which isn't supported on a backend + such as Oracle in conjunction with an ``INSERT..SELECT`` combination:: + + sel = select([table1.c.a, table1.c.b]).where(table1.c.c > 5) + ins = table2.insert(inline=True).from_select(['a', 'b'], sel) + + .. note:: + + A SELECT..INSERT construct in SQL has no VALUES clause. Therefore + :class:`.Column` objects which utilize Python-side defaults + (e.g. as described at :ref:`metadata_defaults_toplevel`) + will **not** take effect when using :meth:`.Insert.from_select`. + + .. versionadded:: 0.8.3 + + """ + if self.parameters: + raise exc.InvalidRequestError( + "This construct already inserts value expressions") + + self.parameters, self._has_multi_parameters = \ + self._process_colparams(dict((n, Null()) for n in names)) + + self.select_names = names + self.select = _interpret_as_select(select) + + def _copy_internals(self, clone=_clone, **kw): + # TODO: coverage + self.parameters = self.parameters.copy() + if self.select is not None: + self.select = _clone(self.select) + + +class Update(ValuesBase): + """Represent an Update construct. + + The :class:`.Update` object is created using the :func:`update()` function. + + """ + __visit_name__ = 'update' + + def __init__(self, + table, + whereclause=None, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + **dialect_kw): + """Construct an :class:`.Update` object. + + E.g.:: + + from sqlalchemy import update + + stmt = update(users).where(users.c.id==5).\\ + values(name='user #5') + + Similar functionality is available via the + :meth:`~.TableClause.update` method on + :class:`.Table`:: + + stmt = users.update().\\ + where(users.c.id==5).\\ + values(name='user #5') + + :param table: A :class:`.Table` object representing the database + table to be updated. + + :param whereclause: Optional SQL expression describing the ``WHERE`` + condition of the ``UPDATE`` statement. Modern applications + may prefer to use the generative :meth:`~Update.where()` + method to specify the ``WHERE`` clause. + + The WHERE clause can refer to multiple tables. + For databases which support this, an ``UPDATE FROM`` clause will + be generated, or on MySQL, a multi-table update. The statement + will fail on databases that don't have support for multi-table + update statements. A SQL-standard method of referring to + additional tables in the WHERE clause is to use a correlated + subquery:: + + users.update().values(name='ed').where( + users.c.name==select([addresses.c.email_address]).\\ + where(addresses.c.user_id==users.c.id).\\ + as_scalar() + ) + + .. versionchanged:: 0.7.4 + The WHERE clause can refer to multiple tables. + + :param values: + Optional dictionary which specifies the ``SET`` conditions of the + ``UPDATE``. If left as ``None``, the ``SET`` + conditions are determined from those parameters passed to the + statement during the execution and/or compilation of the + statement. When compiled standalone without any parameters, + the ``SET`` clause generates for all columns. + + Modern applications may prefer to use the generative + :meth:`.Update.values` method to set the values of the + UPDATE statement. + + :param inline: + if True, SQL defaults present on :class:`.Column` objects via + the ``default`` keyword will be compiled 'inline' into the statement + and not pre-executed. This means that their values will not + be available in the dictionary returned from + :meth:`.ResultProxy.last_updated_params`. + + If both ``values`` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within ``values`` on a per-key basis. + + The keys within ``values`` can be either :class:`.Column` + objects or their string identifiers (specifically the "key" of the + :class:`.Column`, normally but not necessarily equivalent to + its "name"). Normally, the + :class:`.Column` objects used here are expected to be + part of the target :class:`.Table` that is the table + to be updated. However when using MySQL, a multiple-table + UPDATE statement can refer to columns from any of + the tables referred to in the WHERE clause. + + The values referred to in ``values`` are typically: + + * a literal data value (i.e. string, number, etc.) + * a SQL expression, such as a related :class:`.Column`, + a scalar-returning :func:`.select` construct, + etc. + + When combining :func:`.select` constructs within the values + clause of an :func:`.update` construct, + the subquery represented by the :func:`.select` should be + *correlated* to the parent table, that is, providing criterion + which links the table inside the subquery to the outer table + being updated:: + + users.update().values( + name=select([addresses.c.email_address]).\\ + where(addresses.c.user_id==users.c.id).\\ + as_scalar() + ) + + .. seealso:: + + :ref:`inserts_and_updates` - SQL Expression + Language Tutorial + + + """ + ValuesBase.__init__(self, table, values, prefixes) + self._bind = bind + self._returning = returning + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + self.inline = inline + self._validate_dialect_kwargs(dialect_kw) + self._return_defaults = return_defaults + + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + def _copy_internals(self, clone=_clone, **kw): + # TODO: coverage + self._whereclause = clone(self._whereclause, **kw) + self.parameters = self.parameters.copy() + + @_generative + def where(self, whereclause): + """return a new update() construct with the given expression added to + its WHERE clause, joined to the existing clause via AND, if any. + + """ + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, + _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + @property + def _extra_froms(self): + # TODO: this could be made memoized + # if the memoization is reset on each generative call. + froms = [] + seen = set([self.table]) + + if self._whereclause is not None: + for item in _from_objects(self._whereclause): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + return froms + + +class Delete(UpdateBase): + """Represent a DELETE construct. + + The :class:`.Delete` object is created using the :func:`delete()` function. + + """ + + __visit_name__ = 'delete' + + def __init__(self, + table, + whereclause=None, + bind=None, + returning=None, + prefixes=None, + **dialect_kw): + """Construct :class:`.Delete` object. + + Similar functionality is available via the + :meth:`~.TableClause.delete` method on + :class:`~.schema.Table`. + + :param table: The table to be updated. + + :param whereclause: A :class:`.ClauseElement` describing the ``WHERE`` + condition of the ``UPDATE`` statement. Note that the + :meth:`~Delete.where()` generative method may be used instead. + + .. seealso:: + + :ref:`deletes` - SQL Expression Tutorial + + """ + self._bind = bind + self.table = _interpret_as_from(table) + self._returning = returning + + if prefixes: + self._setup_prefixes(prefixes) + + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + + self._validate_dialect_kwargs(dialect_kw) + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + @_generative + def where(self, whereclause): + """Add the given WHERE clause to a newly returned delete construct.""" + + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, + _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + def _copy_internals(self, clone=_clone, **kw): + # TODO: coverage + self._whereclause = clone(self._whereclause, **kw) + diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py new file mode 100644 index 00000000..c230fb0d --- /dev/null +++ b/lib/sqlalchemy/sql/elements.py @@ -0,0 +1,3451 @@ +# sql/elements.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Core SQL expression elements, including :class:`.ClauseElement`, +:class:`.ColumnElement`, and derived classes. + +""" + +from __future__ import unicode_literals + +from .. import util, exc, inspection +from . import type_api +from . import operators +from .visitors import Visitable, cloned_traverse, traverse +from .annotation import Annotated +import itertools +from .base import Executable, PARSE_AUTOCOMMIT, Immutable, NO_ARG +from .base import _generative, Generative + +import re +import operator + +def _clone(element, **kw): + return element._clone() + +def collate(expression, collation): + """Return the clause ``expression COLLATE collation``. + + e.g.:: + + collate(mycolumn, 'utf8_bin') + + produces:: + + mycolumn COLLATE utf8_bin + + """ + + expr = _literal_as_binds(expression) + return BinaryExpression( + expr, + _literal_as_text(collation), + operators.collate, type_=expr.type) + +def between(expr, lower_bound, upper_bound): + """Produce a ``BETWEEN`` predicate clause. + + E.g.:: + + from sqlalchemy import between + stmt = select([users_table]).where(between(users_table.c.id, 5, 7)) + + Would produce SQL resembling:: + + SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 + + The :func:`.between` function is a standalone version of the + :meth:`.ColumnElement.between` method available on all + SQL expressions, as in:: + + stmt = select([users_table]).where(users_table.c.id.between(5, 7)) + + All arguments passed to :func:`.between`, including the left side + column expression, are coerced from Python scalar values if a + the value is not a :class:`.ColumnElement` subclass. For example, + three fixed values can be compared as in:: + + print(between(5, 3, 7)) + + Which would produce:: + + :param_1 BETWEEN :param_2 AND :param_3 + + :param expr: a column expression, typically a :class:`.ColumnElement` + instance or alternatively a Python scalar expression to be coerced + into a column expression, serving as the left side of the ``BETWEEN`` + expression. + + :param lower_bound: a column or Python scalar expression serving as the lower + bound of the right side of the ``BETWEEN`` expression. + + :param upper_bound: a column or Python scalar expression serving as the + upper bound of the right side of the ``BETWEEN`` expression. + + .. seealso:: + + :meth:`.ColumnElement.between` + + """ + expr = _literal_as_binds(expr) + return expr.between(lower_bound, upper_bound) + +def literal(value, type_=None): + """Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- :class:`.ClauseElement` + objects (such as strings, ints, dates, etc.) are used in a comparison + operation with a :class:`.ColumnElement` + subclass, such as a :class:`~sqlalchemy.schema.Column` object. + Use this function to force the + generation of a literal clause, which will be created as a + :class:`BindParameter` with a bound value. + + :param value: the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type argument. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which + will provide bind-parameter translation for this literal. + + """ + return BindParameter(None, value, type_=type_, unique=True) + + + +def type_coerce(expression, type_): + """Associate a SQL expression with a particular type, without rendering + ``CAST``. + + E.g.:: + + from sqlalchemy import type_coerce + + stmt = select([type_coerce(log_table.date_string, StringDateTime())]) + + The above construct will produce SQL that is usually otherwise unaffected + by the :func:`.type_coerce` call:: + + SELECT date_string FROM log + + However, when result rows are fetched, the ``StringDateTime`` type + will be applied to result rows on behalf of the ``date_string`` column. + + A type that features bound-value handling will also have that behavior + take effect when literal values or :func:`.bindparam` constructs are + passed to :func:`.type_coerce` as targets. + For example, if a type implements the :meth:`.TypeEngine.bind_expression` + method or :meth:`.TypeEngine.bind_processor` method or equivalent, + these functions will take effect at statement compliation/execution time + when a literal value is passed, as in:: + + # bound-value handling of MyStringType will be applied to the + # literal value "some string" + stmt = select([type_coerce("some string", MyStringType)]) + + :func:`.type_coerce` is similar to the :func:`.cast` function, + except that it does not render the ``CAST`` expression in the resulting + statement. + + :param expression: A SQL expression, such as a :class:`.ColumnElement` expression + or a Python string which will be coerced into a bound literal value. + + :param type_: A :class:`.TypeEngine` class or instance indicating + the type to which the the expression is coerced. + + .. seealso:: + + :func:`.cast` + + """ + type_ = type_api.to_instance(type_) + + if hasattr(expression, '__clause_element__'): + return type_coerce(expression.__clause_element__(), type_) + elif isinstance(expression, BindParameter): + bp = expression._clone() + bp.type = type_ + return bp + elif not isinstance(expression, Visitable): + if expression is None: + return Null() + else: + return literal(expression, type_=type_) + else: + return Label(None, expression, type_=type_) + + + + + +def outparam(key, type_=None): + """Create an 'OUT' parameter for usage in functions (stored procedures), + for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + :class:`~sqlalchemy.engine.ResultProxy` object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + + """ + return BindParameter( + key, None, type_=type_, unique=False, isoutparam=True) + + + + +def not_(clause): + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + :class:`.ColumnElement` subclasses to produce the + same result. + + """ + return operators.inv(_literal_as_binds(clause)) + + + +@inspection._self_inspects +class ClauseElement(Visitable): + """Base class for elements of a programmatically constructed SQL + expression. + + """ + __visit_name__ = 'clause' + + _annotations = {} + supports_execution = False + _from_objects = [] + bind = None + _is_clone_of = None + is_selectable = False + is_clause_element = True + + _order_by_label_element = None + + def _clone(self): + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + ClauseElement._cloned_set._reset(c) + ColumnElement.comparator._reset(c) + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + c._is_clone_of = self + + return c + + @property + def _constructor(self): + """return the 'constructor' for this ClauseElement. + + This is for the purposes for creating a new object of + this type. Usually, its just the element's __class__. + However, the "Annotated" version of the object overrides + to return the class of its proxied element. + + """ + return self.__class__ + + @util.memoized_property + def _cloned_set(self): + """Return the set consisting all cloned ancestors of this + ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ + s = util.column_set() + f = self + while f is not None: + s.add(f) + f = f._is_clone_of + return s + + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_is_clone_of', None) + return d + + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated(self, values) + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + return Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_clauseelement(self, multiparams, params) + + def unique_params(self, *optionaldict, **kwargs): + """Return a copy with :func:`bindparam()` elements replaced. + + Same functionality as ``params()``, except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. + + """ + return self._params(True, optionaldict, kwargs) + + def params(self, *optionaldict, **kwargs): + """Return a copy with :func:`bindparam()` elements replaced. + + Returns a copy of this ClauseElement with :func:`bindparam()` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print clause.compile().params + {'foo':None} + >>> print clause.params({'foo':7}).compile().params + {'foo':7} + + """ + return self._params(False, optionaldict, kwargs) + + def _params(self, unique, optionaldict, kwargs): + if len(optionaldict) == 1: + kwargs.update(optionaldict[0]) + elif len(optionaldict) > 1: + raise exc.ArgumentError( + "params() takes zero or one positional dictionary argument") + + def visit_bindparam(bind): + if bind.key in kwargs: + bind.value = kwargs[bind.key] + bind.required = False + if unique: + bind._convert_to_unique() + return cloned_traverse(self, {}, {'bindparam': visit_bindparam}) + + def compare(self, other, **kw): + """Compare this ClauseElement to the given ClauseElement. + + Subclasses should override the default behavior, which is a + straight identity comparison. + + \**kw are arguments consumed by subclass compare() methods and + may be used to modify the criteria for comparison. + (see :class:`.ColumnElement`) + + """ + return self is other + + def _copy_internals(self, clone=_clone, **kw): + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + + """ + pass + + def get_children(self, **kwargs): + """Return immediate child elements of this :class:`.ClauseElement`. + + This is used for visit traversal. + + \**kwargs may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + return [] + + def self_group(self, against=None): + """Apply a 'grouping' to this :class:`.ClauseElement`. + + This method is overridden by subclasses to return a + "grouping" construct, i.e. parenthesis. In particular + it's used by "binary" expressions to provide a grouping + around themselves when placed into a larger expression, + as well as by :func:`.select` constructs when placed into + the FROM clause of another :func:`.select`. (Note that + subqueries should be normally created using the + :meth:`.Select.alias` method, as many platforms require + nested SELECT statements to be named). + + As expressions are composed together, the application of + :meth:`self_group` is automatic - end-user code should never + need to use this method directly. Note that SQLAlchemy's + clause constructs take operator precedence into account - + so parenthesis might not be needed, for example, in + an expression like ``x OR (y AND z)`` - AND takes precedence + over OR. + + The base :meth:`self_group` method of :class:`.ClauseElement` + just returns self. + """ + return self + + @util.dependencies("sqlalchemy.engine.default") + def compile(self, default, bind=None, dialect=None, **kw): + """Compile this SQL expression. + + The return value is a :class:`~.Compiled` object. + Calling ``str()`` or ``unicode()`` on the returned value will yield a + string representation of the result. The + :class:`~.Compiled` object also can return a + dictionary of bind parameter names and values + using the ``params`` accessor. + + :param bind: An ``Engine`` or ``Connection`` from which a + ``Compiled`` will be acquired. This argument takes precedence over + this :class:`.ClauseElement`'s bound engine, if any. + + :param column_keys: Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause of the + compiled statement. If ``None``, all columns from the target table + object are rendered. + + :param dialect: A ``Dialect`` instance from which a ``Compiled`` + will be acquired. This argument takes precedence over the `bind` + argument as well as this :class:`.ClauseElement`'s bound engine, if + any. + + :param inline: Used for INSERT statements, for a dialect which does + not support inline retrieval of newly generated primary key + columns, will force the expression used to create the new primary + key value to be rendered inline within the INSERT statement's + VALUES clause. This typically refers to Sequence execution but may + also refer to any server-side default generation function + associated with a primary key `Column`. + + """ + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + bind = self.bind + else: + dialect = default.DefaultDialect() + return self._compiler(dialect, bind=bind, **kw) + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + + def __str__(self): + if util.py3k: + return str(self.compile()) + else: + return unicode(self.compile()).encode('ascii', 'backslashreplace') + + def __and__(self, other): + return and_(self, other) + + def __or__(self, other): + return or_(self, other) + + def __invert__(self): + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return self._negate() + + def __bool__(self): + raise TypeError("Boolean value of this clause is not defined") + + __nonzero__ = __bool__ + + def _negate(self): + return UnaryExpression( + self.self_group(against=operators.inv), + operator=operators.inv, + negate=None) + + def __repr__(self): + friendly = getattr(self, 'description', None) + if friendly is None: + return object.__repr__(self) + else: + return '<%s.%s at 0x%x; %s>' % ( + self.__module__, self.__class__.__name__, id(self), friendly) + + + +class ColumnElement(ClauseElement, operators.ColumnOperators): + """Represent a column-oriented SQL expression suitable for usage in the + "columns" clause, WHERE clause etc. of a statement. + + While the most familiar kind of :class:`.ColumnElement` is the + :class:`.Column` object, :class:`.ColumnElement` serves as the basis + for any unit that may be present in a SQL expression, including + the expressions themselves, SQL functions, bound parameters, + literal expressions, keywords such as ``NULL``, etc. + :class:`.ColumnElement` is the ultimate base class for all such elements. + + A wide variety of SQLAlchemy Core functions work at the SQL expression level, + and are intended to accept instances of :class:`.ColumnElement` as arguments. + These functions will typically document that they accept a "SQL expression" + as an argument. What this means in terms of SQLAlchemy usually refers + to an input which is either already in the form of a :class:`.ColumnElement` + object, or a value which can be **coerced** into one. The coercion + rules followed by most, but not all, SQLAlchemy Core functions with regards + to SQL expressions are as follows: + + * a literal Python value, such as a string, integer or floating + point value, boolean, datetime, ``Decimal`` object, or virtually + any other Python object, will be coerced into a "literal bound value". + This generally means that a :func:`.bindparam` will be produced + featuring the given value embedded into the construct; the resulting + :class:`.BindParameter` object is an instance of :class:`.ColumnElement`. + The Python value will ultimately be sent to the DBAPI at execution time as a + paramterized argument to the ``execute()`` or ``executemany()`` methods, + after SQLAlchemy type-specific converters (e.g. those provided by + any associated :class:`.TypeEngine` objects) are applied to the value. + + * any special object value, typically ORM-level constructs, which feature + a method called ``__clause_element__()``. The Core expression system + looks for this method when an object of otherwise unknown type is passed + to a function that is looking to coerce the argument into a :class:`.ColumnElement` + expression. The ``__clause_element__()`` method, if present, should + return a :class:`.ColumnElement` instance. The primary use of + ``__clause_element__()`` within SQLAlchemy is that of class-bound attributes + on ORM-mapped classes; a ``User`` class which contains a mapped attribute + named ``.name`` will have a method ``User.name.__clause_element__()`` + which when invoked returns the :class:`.Column` called ``name`` associated + with the mapped table. + + * The Python ``None`` value is typically interpreted as ``NULL``, which + in SQLAlchemy Core produces an instance of :func:`.null`. + + A :class:`.ColumnElement` provides the ability to generate new + :class:`.ColumnElement` + objects using Python expressions. This means that Python operators + such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations, + and allow the instantiation of further :class:`.ColumnElement` instances + which are composed from other, more fundamental :class:`.ColumnElement` + objects. For example, two :class:`.ColumnClause` objects can be added + together with the addition operator ``+`` to produce + a :class:`.BinaryExpression`. + Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses + of :class:`.ColumnElement`:: + + >>> from sqlalchemy.sql import column + >>> column('a') + column('b') + + >>> print column('a') + column('b') + a + b + + .. seealso:: + + :class:`.Column` + + :func:`.expression.column` + + """ + + __visit_name__ = 'column' + primary_key = False + foreign_keys = [] + _label = None + _key_label = key = None + _alt_names = () + + def self_group(self, against=None): + if against in (operators.and_, operators.or_, operators._asbool) and \ + self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + return AsBoolean(self, operators.istrue, operators.isfalse) + else: + return self + + def _negate(self): + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + return AsBoolean(self, operators.isfalse, operators.istrue) + else: + return super(ColumnElement, self)._negate() + + @util.memoized_property + def type(self): + return type_api.NULLTYPE + + @util.memoized_property + def comparator(self): + return self.type.comparator_factory(self) + + def __getattr__(self, key): + try: + return getattr(self.comparator, key) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + key) + ) + + def operate(self, op, *other, **kwargs): + return op(self.comparator, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) + + def _bind_param(self, operator, obj): + return BindParameter(None, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + @property + def expression(self): + """Return a column expression. + + Part of the inspection interface; returns self. + + """ + return self + + @property + def _select_iterable(self): + return (self, ) + + @util.memoized_property + def base_columns(self): + return util.column_set(c for c in self.proxy_set + if not hasattr(c, '_proxies')) + + @util.memoized_property + def proxy_set(self): + s = util.column_set([self]) + if hasattr(self, '_proxies'): + for c in self._proxies: + s.update(c.proxy_set) + return s + + def shares_lineage(self, othercolumn): + """Return True if the given :class:`.ColumnElement` + has a common ancestor to this :class:`.ColumnElement`.""" + + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) + + def _compare_name_for_result(self, other): + """Return True if the given column element compares to this one + when targeting within a result row.""" + + return hasattr(other, 'name') and hasattr(self, 'name') and \ + other.name == self.name + + def _make_proxy(self, selectable, name=None, name_is_truncatable=False, **kw): + """Create a new :class:`.ColumnElement` representing this + :class:`.ColumnElement` as it appears in the select list of a + descending selectable. + + """ + if name is None: + name = self.anon_label + if self.key: + key = self.key + else: + try: + key = str(self) + except exc.UnsupportedCompilationError: + key = self.anon_label + + else: + key = name + co = ColumnClause( + _as_truncated(name) if name_is_truncatable else name, + type_=getattr(self, 'type', None), + _selectable=selectable + ) + co._proxies = [self] + if selectable._is_clone_of is not None: + co._is_clone_of = \ + selectable._is_clone_of.columns.get(key) + selectable._columns[key] = co + return co + + def compare(self, other, use_proxies=False, equivalents=None, **kw): + """Compare this ColumnElement to another. + + Special arguments understood: + + :param use_proxies: when True, consider two columns that + share a common base column as equivalent (i.e. shares_lineage()) + + :param equivalents: a dictionary of columns as keys mapped to sets + of columns. If the given "other" column is present in this + dictionary, if any of the columns in the corresponding set() pass the + comparison test, the result is True. This is used to expand the + comparison to other columns that may be known to be equivalent to + this one via foreign key or other criterion. + + """ + to_compare = (other, ) + if equivalents and other in equivalents: + to_compare = equivalents[other].union(to_compare) + + for oth in to_compare: + if use_proxies and self.shares_lineage(oth): + return True + elif hash(oth) == hash(self): + return True + else: + return False + + def label(self, name): + """Produce a column label, i.e. `` AS ``. + + This is a shortcut to the :func:`~.expression.label` function. + + if 'name' is None, an anonymous label name will be generated. + + """ + return Label(name, self, self.type) + + @util.memoized_property + def anon_label(self): + """provides a constant 'anonymous label' for this ColumnElement. + + This is a label() expression which will be named at compile time. + The same label() is returned each time anon_label is called so + that expressions can reference anon_label multiple times, producing + the same label name at compile time. + + the compiler uses this function automatically at compile time + for expressions that are known to be 'unnamed' like binary + expressions and function calls. + + """ + return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self, + 'name', 'anon'))) + + +class BindParameter(ColumnElement): + """Represent a "bound expression". + + :class:`.BindParameter` is invoked explicitly using the + :func:`.bindparam` function, as in:: + + from sqlalchemy import bindparam + + stmt = select([users_table]).\\ + where(users_table.c.name == bindparam('username')) + + Detailed discussion of how :class:`.BindParameter` is used is + at :func:`.bindparam`. + + .. seealso:: + + :func:`.bindparam` + + """ + + __visit_name__ = 'bindparam' + + _is_crud = False + + def __init__(self, key, value=NO_ARG, type_=None, + unique=False, required=NO_ARG, + quote=None, callable_=None, + isoutparam=False, + _compared_to_operator=None, + _compared_to_type=None): + """Produce a "bound expression". + + The return value is an instance of :class:`.BindParameter`; this + is a :class:`.ColumnElement` subclass which represents a so-called + "placeholder" value in a SQL expression, the value of which is supplied + at the point at which the statement in executed against a database + connection. + + In SQLAlchemy, the :func:`.bindparam` construct has + the ability to carry along the actual value that will be ultimately + used at expression time. In this way, it serves not just as + a "placeholder" for eventual population, but also as a means of + representing so-called "unsafe" values which should not be rendered + directly in a SQL statement, but rather should be passed along + to the :term:`DBAPI` as values which need to be correctly escaped + and potentially handled for type-safety. + + When using :func:`.bindparam` explicitly, the use case is typically + one of traditional deferment of parameters; the :func:`.bindparam` + construct accepts a name which can then be referred to at execution + time:: + + from sqlalchemy import bindparam + + stmt = select([users_table]).\\ + where(users_table.c.name == bindparam('username')) + + The above statement, when rendered, will produce SQL similar to:: + + SELECT id, name FROM user WHERE name = :username + + In order to populate the value of ``:username`` above, the value + would typically be applied at execution time to a method + like :meth:`.Connection.execute`:: + + result = connection.execute(stmt, username='wendy') + + Explicit use of :func:`.bindparam` is also common when producing + UPDATE or DELETE statements that are to be invoked multiple times, + where the WHERE criterion of the statement is to change on each + invocation, such as:: + + stmt = users_table.update().\\ + where(user_table.c.name == bindparam('username')).\\ + values(fullname=bindparam('fullname')) + + connection.execute(stmt, [ + {"username": "wendy", "fullname": "Wendy Smith"}, + {"username": "jack", "fullname": "Jack Jones"}, + ]) + + SQLAlchemy's Core expression system makes wide use of :func:`.bindparam` + in an implicit sense. It is typical that Python literal values passed to + virtually all SQL expression functions are coerced into fixed + :func:`.bindparam` constructs. For example, given a comparison operation + such as:: + + expr = users_table.c.name == 'Wendy' + + The above expression will produce a :class:`.BinaryExpression` + contruct, where the left side is the :class:`.Column` object + representing the ``name`` column, and the right side is a :class:`.BindParameter` + representing the literal value:: + + print(repr(expr.right)) + BindParameter('%(4327771088 name)s', 'Wendy', type_=String()) + + The expression above will render SQL such as:: + + user.name = :name_1 + + Where the ``:name_1`` parameter name is an anonymous name. The + actual string ``Wendy`` is not in the rendered string, but is carried + along where it is later used within statement execution. If we + invoke a statement like the following:: + + stmt = select([users_table]).where(users_table.c.name == 'Wendy') + result = connection.execute(stmt) + + We would see SQL logging output as:: + + SELECT "user".id, "user".name + FROM "user" + WHERE "user".name = %(name_1)s + {'name_1': 'Wendy'} + + Above, we see that ``Wendy`` is passed as a parameter to the database, + while the placeholder ``:name_1`` is rendered in the appropriate form + for the target database, in this case the Postgresql database. + + Similarly, :func:`.bindparam` is invoked automatically + when working with :term:`CRUD` statements as far as the "VALUES" + portion is concerned. The :func:`.insert` construct produces an + ``INSERT`` expression which will, at statement execution time, generate + bound placeholders based on the arguments passed, as in:: + + stmt = users_table.insert() + result = connection.execute(stmt, name='Wendy') + + The above will produce SQL output as:: + + INSERT INTO "user" (name) VALUES (%(name)s) + {'name': 'Wendy'} + + The :class:`.Insert` construct, at compilation/execution time, + rendered a single :func:`.bindparam` mirroring the column + name ``name`` as a result of the single ``name`` parameter + we passed to the :meth:`.Connection.execute` method. + + :param key: + the key (e.g. the name) for this bind param. + Will be used in the generated + SQL statement for dialects that use named parameters. This + value may be modified when part of a compilation operation, + if other :class:`BindParameter` objects exist with the same + key, or if its length is too long and truncation is + required. + + :param value: + Initial value for this bind param. Will be used at statement + execution time as the value for this parameter passed to the + DBAPI, if no other value is indicated to the statement execution + method for this particular parameter name. Defaults to ``None``. + + :param callable\_: + A callable function that takes the place of "value". The function + will be called at statement execution time to determine the + ultimate value. Used for scenarios where the actual bind + value cannot be determined at the point at which the clause + construct is created, but embedded bind values are still desirable. + + :param type\_: + A :class:`.TypeEngine` class or instance representing an optional + datatype for this :func:`.bindparam`. If not passed, a type + may be determined automatically for the bind, based on the given + value; for example, trivial Python types such as ``str``, + ``int``, ``bool`` + may result in the :class:`.String`, :class:`.Integer` or + :class:`.Boolean` types being autoamtically selected. + + The type of a :func:`.bindparam` is significant especially in that + the type will apply pre-processing to the value before it is + passed to the database. For example, a :func:`.bindparam` which + refers to a datetime value, and is specified as holding the + :class:`.DateTime` type, may apply conversion needed to the + value (such as stringification on SQLite) before passing the value + to the database. + + :param unique: + if True, the key name of this :class:`.BindParameter` will be + modified if another :class:`.BindParameter` of the same name + already has been located within the containing + expression. This flag is used generally by the internals + when producing so-called "anonymous" bound expressions, it + isn't generally applicable to explicitly-named :func:`.bindparam` + constructs. + + :param required: + If ``True``, a value is required at execution time. If not passed, + it defaults to ``True`` if neither :paramref:`.bindparam.value` + or :paramref:`.bindparam.callable` were passed. If either of these + parameters are present, then :paramref:`.bindparam.required` defaults + to ``False``. + + .. versionchanged:: 0.8 If the ``required`` flag is not specified, + it will be set automatically to ``True`` or ``False`` depending + on whether or not the ``value`` or ``callable`` parameters + were specified. + + :param quote: + True if this parameter name requires quoting and is not + currently known as a SQLAlchemy reserved word; this currently + only applies to the Oracle backend, where bound names must + sometimes be quoted. + + :param isoutparam: + if True, the parameter should be treated like a stored procedure + "OUT" parameter. This applies to backends such as Oracle which + support OUT parameters. + + .. seealso:: + + :ref:`coretutorial_bind_param` + + :ref:`coretutorial_insert_expressions` + + :func:`.outparam` + + """ + if isinstance(key, ColumnClause): + type_ = key.type + key = key.name + if required is NO_ARG: + required = (value is NO_ARG and callable_ is None) + if value is NO_ARG: + value = None + + if quote is not None: + key = quoted_name(key, quote) + + if unique: + self.key = _anonymous_label('%%(%d %s)s' % (id(self), key + or 'param')) + else: + self.key = key or _anonymous_label('%%(%d param)s' + % id(self)) + + # identifying key that won't change across + # clones, used to identify the bind's logical + # identity + self._identifying_key = self.key + + # key that was passed in the first place, used to + # generate new keys + self._orig_key = key or 'param' + + self.unique = unique + self.value = value + self.callable = callable_ + self.isoutparam = isoutparam + self.required = required + if type_ is None: + if _compared_to_type is not None: + self.type = \ + _compared_to_type.coerce_compared_value( + _compared_to_operator, value) + else: + self.type = type_api._type_map.get(type(value), + type_api.NULLTYPE) + elif isinstance(type_, type): + self.type = type_() + else: + self.type = type_ + + def _with_value(self, value): + """Return a copy of this :class:`.BindParameter` with the given value set.""" + cloned = self._clone() + cloned.value = value + cloned.callable = None + cloned.required = False + if cloned.type is type_api.NULLTYPE: + cloned.type = type_api._type_map.get(type(value), + type_api.NULLTYPE) + return cloned + + @property + def effective_value(self): + """Return the value of this bound parameter, + taking into account if the ``callable`` parameter + was set. + + The ``callable`` value will be evaluated + and returned if present, else ``value``. + + """ + if self.callable: + return self.callable() + else: + return self.value + + def _clone(self): + c = ClauseElement._clone(self) + if self.unique: + c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key + or 'param')) + return c + + def _convert_to_unique(self): + if not self.unique: + self.unique = True + self.key = _anonymous_label('%%(%d %s)s' % (id(self), + self._orig_key or 'param')) + + def compare(self, other, **kw): + """Compare this :class:`BindParameter` to the given + clause.""" + + return isinstance(other, BindParameter) \ + and self.type._compare_type_affinity(other.type) \ + and self.value == other.value + + def __getstate__(self): + """execute a deferred value for serialization purposes.""" + + d = self.__dict__.copy() + v = self.value + if self.callable: + v = self.callable() + d['callable'] = None + d['value'] = v + return d + + def __repr__(self): + return 'BindParameter(%r, %r, type_=%r)' % (self.key, + self.value, self.type) + + +class TypeClause(ClauseElement): + """Handle a type keyword in a SQL statement. + + Used by the ``Case`` statement. + + """ + + __visit_name__ = 'typeclause' + + def __init__(self, type): + self.type = type + + +class TextClause(Executable, ClauseElement): + """Represent a literal SQL text fragment. + + E.g.:: + + from sqlalchemy import text + + t = text("SELECT * FROM users") + result = connection.execute(t) + + + The :class:`.Text` construct is produced using the :func:`.text` + function; see that function for full documentation. + + .. seealso:: + + :func:`.text` + + """ + + __visit_name__ = 'textclause' + + _bind_params_regex = re.compile(r'(?`` + to specify bind parameters; they will be compiled to their + engine-specific format. + + :param autocommit: + Deprecated. Use .execution_options(autocommit=) + to set the autocommit option. + + :param bind: + an optional connection or engine to be used for this text query. + + :param bindparams: + Deprecated. A list of :func:`.bindparam` instances used to + provide information about parameters embedded in the statement. + This argument now invokes the :meth:`.TextClause.bindparams` + method on the construct before returning it. E.g.:: + + stmt = text("SELECT * FROM table WHERE id=:id", + bindparams=[bindparam('id', value=5, type_=Integer)]) + + Is equivalent to:: + + stmt = text("SELECT * FROM table WHERE id=:id").\\ + bindparams(bindparam('id', value=5, type_=Integer)) + + .. deprecated:: 0.9.0 the :meth:`.TextClause.bindparams` method + supersedes the ``bindparams`` argument to :func:`.text`. + + :param typemap: + Deprecated. A dictionary mapping the names of columns + represented in the columns clause of a ``SELECT`` statement + to type objects, + which will be used to perform post-processing on columns within + the result set. This parameter now invokes the :meth:`.TextClause.columns` + method, which returns a :class:`.TextAsFrom` construct that gains + a ``.c`` collection and can be embedded in other expressions. E.g.:: + + stmt = text("SELECT * FROM table", + typemap={'id': Integer, 'name': String}, + ) + + Is equivalent to:: + + stmt = text("SELECT * FROM table").columns(id=Integer, name=String) + + Or alternatively:: + + from sqlalchemy.sql import column + stmt = text("SELECT * FROM table").columns( + column('id', Integer), + column('name', String) + ) + + .. deprecated:: 0.9.0 the :meth:`.TextClause.columns` method + supersedes the ``typemap`` argument to :func:`.text`. + + """ + stmt = TextClause(text, bind=bind) + if bindparams: + stmt = stmt.bindparams(*bindparams) + if typemap: + stmt = stmt.columns(**typemap) + if autocommit is not None: + util.warn_deprecated('autocommit on text() is deprecated. ' + 'Use .execution_options(autocommit=True)') + stmt = stmt.execution_options(autocommit=autocommit) + + return stmt + + @_generative + def bindparams(self, *binds, **names_to_values): + """Establish the values and/or types of bound parameters within + this :class:`.TextClause` construct. + + Given a text construct such as:: + + from sqlalchemy import text + stmt = text("SELECT id, name FROM user WHERE name=:name " + "AND timestamp=:timestamp") + + the :meth:`.TextClause.bindparams` method can be used to establish + the initial value of ``:name`` and ``:timestamp``, + using simple keyword arguments:: + + stmt = stmt.bindparams(name='jack', + timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)) + + Where above, new :class:`.BindParameter` objects + will be generated with the names ``name`` and ``timestamp``, and + values of ``jack`` and ``datetime.datetime(2012, 10, 8, 15, 12, 5)``, + respectively. The types will be + inferred from the values given, in this case :class:`.String` and + :class:`.DateTime`. + + When specific typing behavior is needed, the positional ``*binds`` + argument can be used in which to specify :func:`.bindparam` constructs + directly. These constructs must include at least the ``key`` argument, + then an optional value and type:: + + from sqlalchemy import bindparam + stmt = stmt.bindparams( + bindparam('name', value='jack', type_=String), + bindparam('timestamp', type_=DateTime) + ) + + Above, we specified the type of :class:`.DateTime` for the ``timestamp`` + bind, and the type of :class:`.String` for the ``name`` bind. In + the case of ``name`` we also set the default value of ``"jack"``. + + Additional bound parameters can be supplied at statement execution + time, e.g.:: + + result = connection.execute(stmt, + timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)) + + The :meth:`.TextClause.bindparams` method can be called repeatedly, where + it will re-use existing :class:`.BindParameter` objects to add new information. + For example, we can call :meth:`.TextClause.bindparams` first with + typing information, and a second time with value information, and it + will be combined:: + + stmt = text("SELECT id, name FROM user WHERE name=:name " + "AND timestamp=:timestamp") + stmt = stmt.bindparams( + bindparam('name', type_=String), + bindparam('timestamp', type_=DateTime) + ) + stmt = stmt.bindparams( + name='jack', + timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5) + ) + + + .. versionadded:: 0.9.0 The :meth:`.TextClause.bindparams` method supersedes + the argument ``bindparams`` passed to :func:`~.expression.text`. + + + """ + self._bindparams = new_params = self._bindparams.copy() + + for bind in binds: + try: + existing = new_params[bind.key] + except KeyError: + raise exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % bind.key) + else: + new_params[existing.key] = bind + + for key, value in names_to_values.items(): + try: + existing = new_params[key] + except KeyError: + raise exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % key) + else: + new_params[key] = existing._with_value(value) + + + + @util.dependencies('sqlalchemy.sql.selectable') + def columns(self, selectable, *cols, **types): + """Turn this :class:`.TextClause` object into a :class:`.TextAsFrom` + object that can be embedded into another statement. + + This function essentially bridges the gap between an entirely + textual SELECT statement and the SQL expression language concept + of a "selectable":: + + from sqlalchemy.sql import column, text + + stmt = text("SELECT id, name FROM some_table") + stmt = stmt.columns(column('id'), column('name')).alias('st') + + stmt = select([mytable]).\\ + select_from( + mytable.join(stmt, mytable.c.name == stmt.c.name) + ).where(stmt.c.id > 5) + + Above, we used untyped :func:`.column` elements. These can also have + types specified, which will impact how the column behaves in expressions + as well as determining result set behavior:: + + stmt = text("SELECT id, name, timestamp FROM some_table") + stmt = stmt.columns( + column('id', Integer), + column('name', Unicode), + column('timestamp', DateTime) + ) + + for id, name, timestamp in connection.execute(stmt): + print(id, name, timestamp) + + Keyword arguments allow just the names and types of columns to be specified, + where the :func:`.column` elements will be generated automatically:: + + stmt = text("SELECT id, name, timestamp FROM some_table") + stmt = stmt.columns( + id=Integer, + name=Unicode, + timestamp=DateTime + ) + + for id, name, timestamp in connection.execute(stmt): + print(id, name, timestamp) + + The :meth:`.TextClause.columns` method provides a direct + route to calling :meth:`.FromClause.alias` as well as :meth:`.SelectBase.cte` + against a textual SELECT statement:: + + stmt = stmt.columns(id=Integer, name=String).cte('st') + + stmt = select([sometable]).where(sometable.c.id == stmt.c.id) + + .. versionadded:: 0.9.0 :func:`.text` can now be converted into a fully + featured "selectable" construct using the :meth:`.TextClause.columns` + method. This method supersedes the ``typemap`` argument to + :func:`.text`. + + """ + + input_cols = [ + ColumnClause(col.key, types.pop(col.key)) + if col.key in types + else col + for col in cols + ] + [ColumnClause(key, type_) for key, type_ in types.items()] + return selectable.TextAsFrom(self, input_cols) + + @property + def type(self): + return type_api.NULLTYPE + + @property + def comparator(self): + return self.type.comparator_factory(self) + + def self_group(self, against=None): + if against is operators.in_op: + return Grouping(self) + else: + return self + + def _copy_internals(self, clone=_clone, **kw): + self._bindparams = dict((b.key, clone(b, **kw)) + for b in self._bindparams.values()) + + def get_children(self, **kwargs): + return list(self._bindparams.values()) + + +class Null(ColumnElement): + """Represent the NULL keyword in a SQL statement. + + :class:`.Null` is accessed as a constant via the + :func:`.null` function. + + """ + + __visit_name__ = 'null' + + @util.memoized_property + def type(self): + return type_api.NULLTYPE + + @classmethod + def _singleton(cls): + """Return a constant :class:`.Null` construct.""" + + return NULL + + def compare(self, other): + return isinstance(other, Null) + + +class False_(ColumnElement): + """Represent the ``false`` keyword, or equivalent, in a SQL statement. + + :class:`.False_` is accessed as a constant via the + :func:`.false` function. + + """ + + __visit_name__ = 'false' + + @util.memoized_property + def type(self): + return type_api.BOOLEANTYPE + + def _negate(self): + return TRUE + + @classmethod + def _singleton(cls): + """Return a constant :class:`.False_` construct. + + E.g.:: + + >>> from sqlalchemy import false + >>> print select([t.c.x]).where(false()) + SELECT x FROM t WHERE false + + A backend which does not support true/false constants will render as + an expression against 1 or 0:: + + >>> print select([t.c.x]).where(false()) + SELECT x FROM t WHERE 0 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction:: + + >>> print select([t.c.x]).where(or_(t.c.x > 5, true())) + SELECT x FROM t WHERE true + + >>> print select([t.c.x]).where(and_(t.c.x > 5, false())) + SELECT x FROM t WHERE false + + .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature + better integrated behavior within conjunctions and on dialects + that don't support true/false constants. + + .. seealso:: + + :func:`.true` + + """ + + return FALSE + + def compare(self, other): + return isinstance(other, False_) + +class True_(ColumnElement): + """Represent the ``true`` keyword, or equivalent, in a SQL statement. + + :class:`.True_` is accessed as a constant via the + :func:`.true` function. + + """ + + __visit_name__ = 'true' + + @util.memoized_property + def type(self): + return type_api.BOOLEANTYPE + + def _negate(self): + return FALSE + + @classmethod + def _ifnone(cls, other): + if other is None: + return cls._singleton() + else: + return other + + @classmethod + def _singleton(cls): + """Return a constant :class:`.True_` construct. + + E.g.:: + + >>> from sqlalchemy import true + >>> print select([t.c.x]).where(true()) + SELECT x FROM t WHERE true + + A backend which does not support true/false constants will render as + an expression against 1 or 0:: + + >>> print select([t.c.x]).where(true()) + SELECT x FROM t WHERE 1 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction:: + + >>> print select([t.c.x]).where(or_(t.c.x > 5, true())) + SELECT x FROM t WHERE true + + >>> print select([t.c.x]).where(and_(t.c.x > 5, false())) + SELECT x FROM t WHERE false + + .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature + better integrated behavior within conjunctions and on dialects + that don't support true/false constants. + + .. seealso:: + + :func:`.false` + + """ + + return TRUE + + def compare(self, other): + return isinstance(other, True_) + +NULL = Null() +FALSE = False_() +TRUE = True_() + +class ClauseList(ClauseElement): + """Describe a list of clauses, separated by an operator. + + By default, is comma-separated, such as a column listing. + + """ + __visit_name__ = 'clauselist' + + def __init__(self, *clauses, **kwargs): + self.operator = kwargs.pop('operator', operators.comma_op) + self.group = kwargs.pop('group', True) + self.group_contents = kwargs.pop('group_contents', True) + if self.group_contents: + self.clauses = [ + _literal_as_text(clause).self_group(against=self.operator) + for clause in clauses] + else: + self.clauses = [ + _literal_as_text(clause) + for clause in clauses] + + def __iter__(self): + return iter(self.clauses) + + def __len__(self): + return len(self.clauses) + + @property + def _select_iterable(self): + return iter(self) + + def append(self, clause): + if self.group_contents: + self.clauses.append(_literal_as_text(clause).\ + self_group(against=self.operator)) + else: + self.clauses.append(_literal_as_text(clause)) + + def _copy_internals(self, clone=_clone, **kw): + self.clauses = [clone(clause, **kw) for clause in self.clauses] + + def get_children(self, **kwargs): + return self.clauses + + @property + def _from_objects(self): + return list(itertools.chain(*[c._from_objects for c in self.clauses])) + + def self_group(self, against=None): + if self.group and operators.is_precedent(self.operator, against): + return Grouping(self) + else: + return self + + def compare(self, other, **kw): + """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`, + including a comparison of all the clause items. + + """ + if not isinstance(other, ClauseList) and len(self.clauses) == 1: + return self.clauses[0].compare(other, **kw) + elif isinstance(other, ClauseList) and \ + len(self.clauses) == len(other.clauses): + for i in range(0, len(self.clauses)): + if not self.clauses[i].compare(other.clauses[i], **kw): + return False + else: + return self.operator == other.operator + else: + return False + + + +class BooleanClauseList(ClauseList, ColumnElement): + __visit_name__ = 'clauselist' + + def __init__(self, *arg, **kw): + raise NotImplementedError( + "BooleanClauseList has a private constructor") + + @classmethod + def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): + convert_clauses = [] + + clauses = util.coerce_generator_arg(clauses) + for clause in clauses: + clause = _literal_as_text(clause) + + if isinstance(clause, continue_on): + continue + elif isinstance(clause, skip_on): + return clause.self_group(against=operators._asbool) + + convert_clauses.append(clause) + + if len(convert_clauses) == 1: + return convert_clauses[0].self_group(against=operators._asbool) + elif not convert_clauses and clauses: + return clauses[0].self_group(against=operators._asbool) + + convert_clauses = [c.self_group(against=operator) + for c in convert_clauses] + + self = cls.__new__(cls) + self.clauses = convert_clauses + self.group = True + self.operator = operator + self.group_contents = True + self.type = type_api.BOOLEANTYPE + return self + + @classmethod + def and_(cls, *clauses): + """Produce a conjunction of expressions joined by ``AND``. + + E.g.:: + + from sqlalchemy import and_ + + stmt = select([users_table]).where( + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) + + The :func:`.and_` conjunction is also available using the + Python ``&`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select([users_table]).where( + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) + + The :func:`.and_` operation is also implicit in some cases; + the :meth:`.Select.where` method for example can be invoked multiple + times against a statement, which will have the effect of each + clause being combined using :func:`.and_`:: + + stmt = select([users_table]).\\ + where(users_table.c.name == 'wendy').\\ + where(users_table.c.enrolled == True) + + .. seealso:: + + :func:`.or_` + + """ + return cls._construct(operators.and_, True_, False_, *clauses) + + @classmethod + def or_(cls, *clauses): + """Produce a conjunction of expressions joined by ``OR``. + + E.g.:: + + from sqlalchemy import or_ + + stmt = select([users_table]).where( + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) + + The :func:`.or_` conjunction is also available using the + Python ``|`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select([users_table]).where( + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) + + .. seealso:: + + :func:`.and_` + + """ + return cls._construct(operators.or_, False_, True_, *clauses) + + @property + def _select_iterable(self): + return (self, ) + + def self_group(self, against=None): + if not self.clauses: + return self + else: + return super(BooleanClauseList, self).self_group(against=against) + + def _negate(self): + return ClauseList._negate(self) + + +and_ = BooleanClauseList.and_ +or_ = BooleanClauseList.or_ + +class Tuple(ClauseList, ColumnElement): + """Represent a SQL tuple.""" + + def __init__(self, *clauses, **kw): + """Return a :class:`.Tuple`. + + Main usage is to produce a composite IN construct:: + + from sqlalchemy import tuple_ + + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) + + .. warning:: + + The composite IN construct is not supported by all backends, + and is currently known to work on Postgresql and MySQL, + but not SQLite. Unsupported backends will raise + a subclass of :class:`~sqlalchemy.exc.DBAPIError` when such + an expression is invoked. + + """ + + clauses = [_literal_as_binds(c) for c in clauses] + self._type_tuple = [arg.type for arg in clauses] + self.type = kw.pop('type_', self._type_tuple[0] + if self._type_tuple else type_api.NULLTYPE) + + super(Tuple, self).__init__(*clauses, **kw) + + @property + def _select_iterable(self): + return (self, ) + + def _bind_param(self, operator, obj): + return Tuple(*[ + BindParameter(None, o, _compared_to_operator=operator, + _compared_to_type=type_, unique=True) + for o, type_ in zip(obj, self._type_tuple) + ]).self_group() + + +class Case(ColumnElement): + """Represent a ``CASE`` expression. + + :class:`.Case` is produced using the :func:`.case` factory function, + as in:: + + from sqlalchemy import case + + stmt = select([users_table]).\\ + where( + case( + [ + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') + ], + else_='E' + ) + ) + + Details on :class:`.Case` usage is at :func:`.case`. + + .. seealso:: + + :func:`.case` + + """ + + __visit_name__ = 'case' + + def __init__(self, whens, value=None, else_=None): + """Produce a ``CASE`` expression. + + The ``CASE`` construct in SQL is a conditional object that + acts somewhat analogously to an "if/then" construct in other + languages. It returns an instance of :class:`.Case`. + + :func:`.case` in its usual form is passed a list of "when" + contructs, that is, a list of conditions and results as tuples:: + + from sqlalchemy import case + + stmt = select([users_table]).\\ + where( + case( + [ + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') + ], + else_='E' + ) + ) + + The above statement will produce SQL resembling:: + + SELECT id, name FROM user + WHERE CASE + WHEN (name = :name_1) THEN :param_1 + WHEN (name = :name_2) THEN :param_2 + ELSE :param_3 + END + + When simple equality expressions of several values against a single + parent column are needed, :func:`.case` also has a "shorthand" format + used via the + :paramref:`.case.value` parameter, which is passed a column + expression to be compared. In this form, the :paramref:`.case.whens` + parameter is passed as a dictionary containing expressions to be compared + against keyed to result expressions. The statement below is equivalent + to the preceding statement:: + + stmt = select([users_table]).\\ + where( + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name, + else_='E' + ) + ) + + The values which are accepted as result values in + :paramref:`.case.whens` as well as with :paramref:`.case.else_` are + coerced from Python literals into :func:`.bindparam` constructs. + SQL expressions, e.g. :class:`.ColumnElement` constructs, are accepted + as well. To coerce a literal string expression into a constant + expression rendered inline, use the :func:`.literal_column` construct, + as in:: + + from sqlalchemy import case, literal_column + + case( + [ + ( + orderline.c.qty > 100, + literal_column("'greaterthan100'") + ), + ( + orderline.c.qty > 10, + literal_column("'greaterthan10'") + ) + ], + else_=literal_column("'lessthan10'") + ) + + The above will render the given constants without using bound + parameters for the result values (but still for the comparison + values), as in:: + + CASE + WHEN (orderline.qty > :qty_1) THEN 'greaterthan100' + WHEN (orderline.qty > :qty_2) THEN 'greaterthan10' + ELSE 'lessthan10' + END + + :param whens: The criteria to be compared against, :paramref:`.case.whens` + accepts two different forms, based on whether or not :paramref:`.case.value` + is used. + + In the first form, it accepts a list of 2-tuples; each 2-tuple consists + of ``(, )``, where the SQL expression is a + boolean expression and "value" is a resulting value, e.g.:: + + case([ + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') + ]) + + In the second form, it accepts a Python dictionary of comparison values + mapped to a resulting value; this form requires :paramref:`.case.value` + to be present, and values will be compared using the ``==`` operator, + e.g.:: + + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name + ) + + :param value: An optional SQL expression which will be used as a + fixed "comparison point" for candidate values within a dictionary + passed to :paramref:`.case.whens`. + + :param else\_: An optional SQL expression which will be the evaluated + result of the ``CASE`` construct if all expressions within + :paramref:`.case.whens` evaluate to false. When omitted, most + databases will produce a result of NULL if none of the "when" + expressions evaulate to true. + + + """ + + try: + whens = util.dictlike_iteritems(whens) + except TypeError: + pass + + if value is not None: + whenlist = [ + (_literal_as_binds(c).self_group(), + _literal_as_binds(r)) for (c, r) in whens + ] + else: + whenlist = [ + (_no_literals(c).self_group(), + _literal_as_binds(r)) for (c, r) in whens + ] + + if whenlist: + type_ = list(whenlist[-1])[-1].type + else: + type_ = None + + if value is None: + self.value = None + else: + self.value = _literal_as_binds(value) + + self.type = type_ + self.whens = whenlist + if else_ is not None: + self.else_ = _literal_as_binds(else_) + else: + self.else_ = None + + def _copy_internals(self, clone=_clone, **kw): + if self.value is not None: + self.value = clone(self.value, **kw) + self.whens = [(clone(x, **kw), clone(y, **kw)) + for x, y in self.whens] + if self.else_ is not None: + self.else_ = clone(self.else_, **kw) + + def get_children(self, **kwargs): + if self.value is not None: + yield self.value + for x, y in self.whens: + yield x + yield y + if self.else_ is not None: + yield self.else_ + + @property + def _from_objects(self): + return list(itertools.chain(*[x._from_objects for x in + self.get_children()])) + + +def literal_column(text, type_=None): + """Return a textual column expression, as would be in the columns + clause of a ``SELECT`` statement. + + The object returned supports further expressions in the same way as any + other column object, including comparison, math and string operations. + The type\_ parameter is important to determine proper expression behavior + (such as, '+' means string concatenation or numerical addition based on + the type). + + :param text: the text of the expression; can be any SQL expression. + Quoting rules will not be applied. To specify a column-name expression + which should be subject to quoting rules, use the :func:`column` + function. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` + object which will + provide result-set translation and additional expression semantics for + this column. If left as None the type will be NullType. + + """ + return ColumnClause(text, type_=type_, is_literal=True) + + + +class Cast(ColumnElement): + """Represent a ``CAST`` expression. + + :class:`.Cast` is produced using the :func:`.cast` factory function, + as in:: + + from sqlalchemy import cast, Numeric + + stmt = select([ + cast(product_table.c.unit_price, Numeric(10, 4)) + ]) + + Details on :class:`.Cast` usage is at :func:`.cast`. + + .. seealso:: + + :func:`.cast` + + """ + + __visit_name__ = 'cast' + + def __init__(self, expression, type_): + """Produce a ``CAST`` expression. + + :func:`.cast` returns an instance of :class:`.Cast`. + + E.g.:: + + from sqlalchemy import cast, Numeric + + stmt = select([ + cast(product_table.c.unit_price, Numeric(10, 4)) + ]) + + The above statement will produce SQL resembling:: + + SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product + + The :func:`.cast` function performs two distinct functions when + used. The first is that it renders the ``CAST`` expression within + the resulting SQL string. The second is that it associates the given + type (e.g. :class:`.TypeEngine` class or instance) with the column + expression on the Python side, which means the expression will take + on the expression operator behavior associated with that type, + as well as the bound-value handling and result-row-handling behavior + of the type. + + .. versionchanged:: 0.9.0 :func:`.cast` now applies the given type + to the expression such that it takes effect on the bound-value, + e.g. the Python-to-database direction, in addition to the + result handling, e.g. database-to-Python, direction. + + An alternative to :func:`.cast` is the :func:`.type_coerce` function. + This function performs the second task of associating an expression + with a specific type, but does not render the ``CAST`` expression + in SQL. + + :param expression: A SQL expression, such as a :class:`.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type_: A :class:`.TypeEngine` class or instance indicating + the type to which the ``CAST`` should apply. + + .. seealso:: + + :func:`.type_coerce` - Python-side type coercion without emitting + CAST. + + """ + self.type = type_api.to_instance(type_) + self.clause = _literal_as_binds(expression, type_=self.type) + self.typeclause = TypeClause(self.type) + + def _copy_internals(self, clone=_clone, **kw): + self.clause = clone(self.clause, **kw) + self.typeclause = clone(self.typeclause, **kw) + + def get_children(self, **kwargs): + return self.clause, self.typeclause + + @property + def _from_objects(self): + return self.clause._from_objects + + +class Extract(ColumnElement): + """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``.""" + + __visit_name__ = 'extract' + + def __init__(self, field, expr, **kwargs): + """Return a :class:`.Extract` construct. + + This is typically available as :func:`.extract` + as well as ``func.extract`` from the + :data:`.func` namespace. + + """ + self.type = type_api.INTEGERTYPE + self.field = field + self.expr = _literal_as_binds(expr, None) + + def _copy_internals(self, clone=_clone, **kw): + self.expr = clone(self.expr, **kw) + + def get_children(self, **kwargs): + return self.expr, + + @property + def _from_objects(self): + return self.expr._from_objects + + +class UnaryExpression(ColumnElement): + """Define a 'unary' expression. + + A unary expression has a single column expression + and an operator. The operator can be placed on the left + (where it is called the 'operator') or right (where it is called the + 'modifier') of the column expression. + + :class:`.UnaryExpression` is the basis for several unary operators + including those used by :func:`.desc`, :func:`.asc`, :func:`.distinct`, + :func:`.nullsfirst` and :func:`.nullslast`. + + """ + __visit_name__ = 'unary' + + def __init__(self, element, operator=None, modifier=None, + type_=None, negate=None): + self.operator = operator + self.modifier = modifier + self.element = element.self_group(against=self.operator or self.modifier) + self.type = type_api.to_instance(type_) + self.negate = negate + + @classmethod + def _create_nullsfirst(cls, column): + """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. + + :func:`.nullsfirst` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nullsfirst + + stmt = select([users_table]).\\ + order_by(nullsfirst(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS FIRST + + Like :func:`.asc` and :func:`.desc`, :func:`.nullsfirst` is typically + invoked from the column expression itself using :meth:`.ColumnElement.nullsfirst`, + rather than as its standalone function version, as in:: + + stmt = select([users_table]).\\ + order_by(users_table.c.name.desc().nullsfirst()) + + .. seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nullslast` + + :meth:`.Select.order_by` + + """ + return UnaryExpression( + _literal_as_text(column), modifier=operators.nullsfirst_op) + + + @classmethod + def _create_nullslast(cls, column): + """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. + + :func:`.nullslast` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nullslast + + stmt = select([users_table]).\\ + order_by(nullslast(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS LAST + + Like :func:`.asc` and :func:`.desc`, :func:`.nullslast` is typically + invoked from the column expression itself using :meth:`.ColumnElement.nullslast`, + rather than as its standalone function version, as in:: + + stmt = select([users_table]).\\ + order_by(users_table.c.name.desc().nullslast()) + + .. seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nullsfirst` + + :meth:`.Select.order_by` + + """ + return UnaryExpression( + _literal_as_text(column), modifier=operators.nullslast_op) + + + @classmethod + def _create_desc(cls, column): + """Produce a descending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import desc + + stmt = select([users_table]).order_by(desc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name DESC + + The :func:`.desc` function is a standalone version of the + :meth:`.ColumnElement.desc` method available on all SQL expressions, + e.g.:: + + + stmt = select([users_table]).order_by(users_table.c.name.desc()) + + :param column: A :class:`.ColumnElement` (e.g. scalar SQL expression) + with which to apply the :func:`.desc` operation. + + .. seealso:: + + :func:`.asc` + + :func:`.nullsfirst` + + :func:`.nullslast` + + :meth:`.Select.order_by` + + """ + return UnaryExpression( + _literal_as_text(column), modifier=operators.desc_op) + + @classmethod + def _create_asc(cls, column): + """Produce an ascending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import asc + stmt = select([users_table]).order_by(asc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name ASC + + The :func:`.asc` function is a standalone version of the + :meth:`.ColumnElement.asc` method available on all SQL expressions, + e.g.:: + + + stmt = select([users_table]).order_by(users_table.c.name.asc()) + + :param column: A :class:`.ColumnElement` (e.g. scalar SQL expression) + with which to apply the :func:`.asc` operation. + + .. seealso:: + + :func:`.desc` + + :func:`.nullsfirst` + + :func:`.nullslast` + + :meth:`.Select.order_by` + + """ + return UnaryExpression( + _literal_as_text(column), modifier=operators.asc_op) + + @classmethod + def _create_distinct(cls, expr): + """Produce an column-expression-level unary ``DISTINCT`` clause. + + This applies the ``DISTINCT`` keyword to an individual column + expression, and is typically contained within an aggregate function, + as in:: + + from sqlalchemy import distinct, func + stmt = select([func.count(distinct(users_table.c.name))]) + + The above would produce an expression resembling:: + + SELECT COUNT(DISTINCT name) FROM user + + The :func:`.distinct` function is also available as a column-level + method, e.g. :meth:`.ColumnElement.distinct`, as in:: + + stmt = select([func.count(users_table.c.name.distinct())]) + + The :func:`.distinct` operator is different from the + :meth:`.Select.distinct` method of :class:`.Select`, + which produces a ``SELECT`` statement + with ``DISTINCT`` applied to the result set as a whole, + e.g. a ``SELECT DISTINCT`` expression. See that method for further + information. + + .. seealso:: + + :meth:`.ColumnElement.distinct` + + :meth:`.Select.distinct` + + :data:`.func` + + """ + expr = _literal_as_binds(expr) + return UnaryExpression(expr, + operator=operators.distinct_op, type_=expr.type) + + @util.memoized_property + def _order_by_label_element(self): + if self.modifier in (operators.desc_op, operators.asc_op): + return self.element._order_by_label_element + else: + return None + + @property + def _from_objects(self): + return self.element._from_objects + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + + def get_children(self, **kwargs): + return self.element, + + def compare(self, other, **kw): + """Compare this :class:`UnaryExpression` against the given + :class:`.ClauseElement`.""" + + return ( + isinstance(other, UnaryExpression) and + self.operator == other.operator and + self.modifier == other.modifier and + self.element.compare(other.element, **kw) + ) + + def _negate(self): + if self.negate is not None: + return UnaryExpression( + self.element, + operator=self.negate, + negate=self.operator, + modifier=self.modifier, + type_=self.type) + else: + return ClauseElement._negate(self) + + def self_group(self, against=None): + if self.operator and operators.is_precedent(self.operator, against): + return Grouping(self) + else: + return self + + +class AsBoolean(UnaryExpression): + + def __init__(self, element, operator, negate): + self.element = element + self.type = type_api.BOOLEANTYPE + self.operator = operator + self.negate = negate + self.modifier = None + + def self_group(self, against=None): + return self + + def _negate(self): + return self.element._negate() + + +class BinaryExpression(ColumnElement): + """Represent an expression that is ``LEFT RIGHT``. + + A :class:`.BinaryExpression` is generated automatically + whenever two column expressions are used in a Python binary expresion:: + + >>> from sqlalchemy.sql import column + >>> column('a') + column('b') + + >>> print column('a') + column('b') + a + b + + """ + + __visit_name__ = 'binary' + + def __init__(self, left, right, operator, type_=None, + negate=None, modifiers=None): + # allow compatibility with libraries that + # refer to BinaryExpression directly and pass strings + if isinstance(operator, util.string_types): + operator = operators.custom_op(operator) + self._orig = (left, right) + self.left = left.self_group(against=operator) + self.right = right.self_group(against=operator) + self.operator = operator + self.type = type_api.to_instance(type_) + self.negate = negate + + if modifiers is None: + self.modifiers = {} + else: + self.modifiers = modifiers + + def __bool__(self): + if self.operator in (operator.eq, operator.ne): + return self.operator(hash(self._orig[0]), hash(self._orig[1])) + else: + raise TypeError("Boolean value of this clause is not defined") + + __nonzero__ = __bool__ + + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + @property + def _from_objects(self): + return self.left._from_objects + self.right._from_objects + + def _copy_internals(self, clone=_clone, **kw): + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) + + def get_children(self, **kwargs): + return self.left, self.right + + def compare(self, other, **kw): + """Compare this :class:`BinaryExpression` against the + given :class:`BinaryExpression`.""" + + return ( + isinstance(other, BinaryExpression) and + self.operator == other.operator and + ( + self.left.compare(other.left, **kw) and + self.right.compare(other.right, **kw) or + ( + operators.is_commutative(self.operator) and + self.left.compare(other.right, **kw) and + self.right.compare(other.left, **kw) + ) + ) + ) + + def self_group(self, against=None): + if operators.is_precedent(self.operator, against): + return Grouping(self) + else: + return self + + def _negate(self): + if self.negate is not None: + return BinaryExpression( + self.left, + self.right, + self.negate, + negate=self.operator, + type_=type_api.BOOLEANTYPE, + modifiers=self.modifiers) + else: + return super(BinaryExpression, self)._negate() + + + + +class Grouping(ColumnElement): + """Represent a grouping within a column expression""" + + __visit_name__ = 'grouping' + + def __init__(self, element): + self.element = element + self.type = getattr(element, 'type', type_api.NULLTYPE) + + def self_group(self, against=None): + return self + + @property + def _label(self): + return getattr(self.element, '_label', None) or self.anon_label + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + + def get_children(self, **kwargs): + return self.element, + + @property + def _from_objects(self): + return self.element._from_objects + + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {'element': self.element, 'type': self.type} + + def __setstate__(self, state): + self.element = state['element'] + self.type = state['type'] + + def compare(self, other, **kw): + return isinstance(other, Grouping) and \ + self.element.compare(other.element) + + +class Over(ColumnElement): + """Represent an OVER clause. + + This is a special operator against a so-called + "window" function, as well as any aggregate function, + which produces results relative to the result set + itself. It's supported only by certain database + backends. + + """ + __visit_name__ = 'over' + + order_by = None + partition_by = None + + def __init__(self, func, partition_by=None, order_by=None): + """Produce an :class:`.Over` object against a function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + E.g.:: + + from sqlalchemy import over + over(func.row_number(), order_by='x') + + Would produce "ROW_NUMBER() OVER(ORDER BY x)". + + :param func: a :class:`.FunctionElement` construct, typically + generated by :data:`~.expression.func`. + :param partition_by: a column element or string, or a list + of such, that will be used as the PARTITION BY clause + of the OVER construct. + :param order_by: a column element or string, or a list + of such, that will be used as the ORDER BY clause + of the OVER construct. + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.over` method. + + .. versionadded:: 0.7 + + """ + self.func = func + if order_by is not None: + self.order_by = ClauseList(*util.to_list(order_by)) + if partition_by is not None: + self.partition_by = ClauseList(*util.to_list(partition_by)) + + @util.memoized_property + def type(self): + return self.func.type + + def get_children(self, **kwargs): + return [c for c in + (self.func, self.partition_by, self.order_by) + if c is not None] + + def _copy_internals(self, clone=_clone, **kw): + self.func = clone(self.func, **kw) + if self.partition_by is not None: + self.partition_by = clone(self.partition_by, **kw) + if self.order_by is not None: + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return list(itertools.chain( + *[c._from_objects for c in + (self.func, self.partition_by, self.order_by) + if c is not None] + )) + + +class Label(ColumnElement): + """Represents a column label (AS). + + Represent a label, as typically applied to any column-level + element using the ``AS`` sql keyword. + + """ + + __visit_name__ = 'label' + + def __init__(self, name, element, type_=None): + """Return a :class:`Label` object for the + given :class:`.ColumnElement`. + + A label changes the name of an element in the columns clause of a + ``SELECT`` statement, typically via the ``AS`` SQL keyword. + + This functionality is more conveniently available via the + :meth:`.ColumnElement.label` method on :class:`.ColumnElement`. + + :param name: label name + + :param obj: a :class:`.ColumnElement`. + + """ + while isinstance(element, Label): + element = element.element + if name: + self.name = name + else: + self.name = _anonymous_label('%%(%d %s)s' % (id(self), + getattr(element, 'name', 'anon'))) + self.key = self._label = self._key_label = self.name + self._element = element + self._type = type_ + self._proxies = [element] + + def __reduce__(self): + return self.__class__, (self.name, self._element, self._type) + + @util.memoized_property + def _order_by_label_element(self): + return self + + @util.memoized_property + def type(self): + return type_api.to_instance( + self._type or getattr(self._element, 'type', None) + ) + + @util.memoized_property + def element(self): + return self._element.self_group(against=operators.as_) + + def self_group(self, against=None): + sub_element = self._element.self_group(against=against) + if sub_element is not self._element: + return Label(self.name, + sub_element, + type_=self._type) + else: + return self + + @property + def primary_key(self): + return self.element.primary_key + + @property + def foreign_keys(self): + return self.element.foreign_keys + + def get_children(self, **kwargs): + return self.element, + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + + @property + def _from_objects(self): + return self.element._from_objects + + def _make_proxy(self, selectable, name=None, **kw): + e = self.element._make_proxy(selectable, + name=name if name else self.name) + e._proxies.append(self) + if self._type is not None: + e.type = self._type + return e + + +class ColumnClause(Immutable, ColumnElement): + """Represents a column expression from any textual string. + + The :class:`.ColumnClause`, a lightweight analogue to the + :class:`.Column` class, is typically invoked using the + :func:`.column` function, as in:: + + from sqlalchemy.sql import column + + id, name = column("id"), column("name") + stmt = select([id, name]).select_from("user") + + The above statement would produce SQL like:: + + SELECT id, name FROM user + + :class:`.ColumnClause` is the immediate superclass of the schema-specific + :class:`.Column` object. While the :class:`.Column` class has all the + same capabilities as :class:`.ColumnClause`, the :class:`.ColumnClause` + class is usable by itself in those cases where behavioral requirements + are limited to simple SQL expression generation. The object has none of the + associations with schema-level metadata or with execution-time behavior + that :class:`.Column` does, so in that sense is a "lightweight" version + of :class:`.Column`. + + Full details on :class:`.ColumnClause` usage is at :func:`.column`. + + .. seealso:: + + :func:`.column` + + :class:`.Column` + + """ + __visit_name__ = 'column' + + onupdate = default = server_default = server_onupdate = None + + _memoized_property = util.group_expirable_memoized_property() + + def __init__(self, text, type_=None, is_literal=False, _selectable=None): + """Produce a :class:`.ColumnClause` object. + + The :class:`.ColumnClause` is a lightweight analogue to the + :class:`.Column` class. The :func:`.column` function can + be invoked with just a name alone, as in:: + + from sqlalchemy.sql import column + + id, name = column("id"), column("name") + stmt = select([id, name]).select_from("user") + + The above statement would produce SQL like:: + + SELECT id, name FROM user + + Once constructed, :func:`.column` may be used like any other SQL expression + element such as within :func:`.select` constructs:: + + from sqlalchemy.sql import column + + id, name = column("id"), column("name") + stmt = select([id, name]).select_from("user") + + The text handled by :func:`.column` is assumed to be handled + like the name of a database column; if the string contains mixed case, + special characters, or matches a known reserved word on the target + backend, the column expression will render using the quoting + behavior determined by the backend. To produce a textual SQL + expression that is rendered exactly without any quoting, + use :func:`.literal_column` instead, or pass ``True`` as the + value of :paramref:`.column.is_literal`. Additionally, full SQL + statements are best handled using the :func:`.text` construct. + + :func:`.column` can be used in a table-like + fashion by combining it with the :func:`.table` function + (which is the lightweight analogue to :class:`.Table`) to produce + a working table construct with minimal boilerplate:: + + from sqlalchemy.sql import table, column + + user = table("user", + column("id"), + column("name"), + column("description"), + ) + + stmt = select([user.c.description]).where(user.c.name == 'wendy') + + A :func:`.column` / :func:`.table` construct like that illustrated + above can be created in an + ad-hoc fashion and is not associated with any :class:`.schema.MetaData`, + DDL, or events, unlike its :class:`.Table` counterpart. + + :param text: the text of the element. + + :param type: :class:`.types.TypeEngine` object which can associate + this :class:`.ColumnClause` with a type. + + :param is_literal: if True, the :class:`.ColumnClause` is assumed to + be an exact expression that will be delivered to the output with no + quoting rules applied regardless of case sensitive settings. the + :func:`.literal_column()` function essentially invokes :func:`.column` + while passing ``is_literal=True``. + + .. seealso:: + + :class:`.Column` + + :func:`.literal_column` + + :func:`.text` + + :ref:`metadata_toplevel` + + """ + + self.key = self.name = text + self.table = _selectable + self.type = type_api.to_instance(type_) + self.is_literal = is_literal + + def _compare_name_for_result(self, other): + if self.is_literal or \ + self.table is None or self.table._textual or \ + not hasattr(other, 'proxy_set') or ( + isinstance(other, ColumnClause) and + (other.is_literal or + other.table is None or + other.table._textual) + ): + return (hasattr(other, 'name') and self.name == other.name) or \ + (hasattr(other, '_label') and self._label == other._label) + else: + return other.proxy_set.intersection(self.proxy_set) + + def _get_table(self): + return self.__dict__['table'] + + def _set_table(self, table): + self._memoized_property.expire_instance(self) + self.__dict__['table'] = table + table = property(_get_table, _set_table) + + @_memoized_property + def _from_objects(self): + t = self.table + if t is not None: + return [t] + else: + return [] + + @util.memoized_property + def description(self): + if util.py3k: + return self.name + else: + return self.name.encode('ascii', 'backslashreplace') + + @_memoized_property + def _key_label(self): + if self.key != self.name: + return self._gen_label(self.key) + else: + return self._label + + @_memoized_property + def _label(self): + return self._gen_label(self.name) + + def _gen_label(self, name): + t = self.table + + if self.is_literal: + return None + + elif t is not None and t.named_with_column: + if getattr(t, 'schema', None): + label = t.schema.replace('.', '_') + "_" + \ + t.name + "_" + name + else: + label = t.name + "_" + name + + # propagate name quoting rules for labels. + if getattr(name, "quote", None) is not None: + if isinstance(label, quoted_name): + label.quote = name.quote + else: + label = quoted_name(label, name.quote) + elif getattr(t.name, "quote", None) is not None: + # can't get this situation to occur, so let's + # assert false on it for now + assert not isinstance(label, quoted_name) + label = quoted_name(label, t.name.quote) + + # ensure the label name doesn't conflict with that + # of an existing column + if label in t.c: + _label = label + counter = 1 + while _label in t.c: + _label = label + "_" + str(counter) + counter += 1 + label = _label + + return _as_truncated(label) + + else: + return name + + def _bind_param(self, operator, obj): + return BindParameter(self.name, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + unique=True) + + def _make_proxy(self, selectable, name=None, attach=True, + name_is_truncatable=False, **kw): + # propagate the "is_literal" flag only if we are keeping our name, + # otherwise its considered to be a label + is_literal = self.is_literal and (name is None or name == self.name) + c = self._constructor( + _as_truncated(name or self.name) if \ + name_is_truncatable else \ + (name or self.name), + type_=self.type, + _selectable=selectable, + is_literal=is_literal + ) + if name is None: + c.key = self.key + c._proxies = [self] + if selectable._is_clone_of is not None: + c._is_clone_of = \ + selectable._is_clone_of.columns.get(c.key) + + if attach: + selectable._columns[c.key] = c + return c + + +class _IdentifiedClause(Executable, ClauseElement): + + __visit_name__ = 'identified' + _execution_options = \ + Executable._execution_options.union({'autocommit': False}) + + def __init__(self, ident): + self.ident = ident + + +class SavepointClause(_IdentifiedClause): + __visit_name__ = 'savepoint' + + +class RollbackToSavepointClause(_IdentifiedClause): + __visit_name__ = 'rollback_to_savepoint' + + +class ReleaseSavepointClause(_IdentifiedClause): + __visit_name__ = 'release_savepoint' + + +class quoted_name(util.text_type): + """Represent a SQL identifier combined with quoting preferences. + + :class:`.quoted_name` is a Python unicode/str subclass which + represents a particular identifier name along with a + ``quote`` flag. This ``quote`` flag, when set to + ``True`` or ``False``, overrides automatic quoting behavior + for this identifier in order to either unconditionally quote + or to not quote the name. If left at its default of ``None``, + quoting behavior is applied to the identifier on a per-backend basis + based on an examination of the token itself. + + A :class:`.quoted_name` object with ``quote=True`` is also + prevented from being modified in the case of a so-called + "name normalize" option. Certain database backends, such as + Oracle, Firebird, and DB2 "normalize" case-insensitive names + as uppercase. The SQLAlchemy dialects for these backends + convert from SQLAlchemy's lower-case-means-insensitive convention + to the upper-case-means-insensitive conventions of those backends. + The ``quote=True`` flag here will prevent this conversion from occurring + to support an identifier that's quoted as all lower case against + such a backend. + + The :class:`.quoted_name` object is normally created automatically + when specifying the name for key schema constructs such as :class:`.Table`, + :class:`.Column`, and others. The class can also be passed explicitly + as the name to any function that receives a name which can be quoted. + Such as to use the :meth:`.Engine.has_table` method with an unconditionally + quoted name:: + + from sqlaclchemy import create_engine + from sqlalchemy.sql.elements import quoted_name + + engine = create_engine("oracle+cx_oracle://some_dsn") + engine.has_table(quoted_name("some_table", True)) + + The above logic will run the "has table" logic against the Oracle backend, + passing the name exactly as ``"some_table"`` without converting to + upper case. + + .. versionadded:: 0.9.0 + + """ + + def __new__(cls, value, quote): + if value is None: + return None + # experimental - don't bother with quoted_name + # if quote flag is None. doesn't seem to make any dent + # in performance however + # elif not sprcls and quote is None: + # return value + elif isinstance(value, cls) and ( + quote is None or value.quote == quote + ): + return value + self = super(quoted_name, cls).__new__(cls, value) + self.quote = quote + return self + + def __reduce__(self): + return quoted_name, (util.text_type(self), self.quote) + + @util.memoized_instancemethod + def lower(self): + if self.quote: + return self + else: + return util.text_type(self).lower() + + @util.memoized_instancemethod + def upper(self): + if self.quote: + return self + else: + return util.text_type(self).upper() + + def __repr__(self): + backslashed = self.encode('ascii', 'backslashreplace') + if not util.py2k: + backslashed = backslashed.decode('ascii') + return "'%s'" % backslashed + +class _truncated_label(quoted_name): + """A unicode subclass used to identify symbolic " + "names that may require truncation.""" + + def __new__(cls, value, quote=None): + quote = getattr(value, "quote", quote) + #return super(_truncated_label, cls).__new__(cls, value, quote, True) + return super(_truncated_label, cls).__new__(cls, value, quote) + + def __reduce__(self): + return self.__class__, (util.text_type(self), self.quote) + + def apply_map(self, map_): + return self + +# for backwards compatibility in case +# someone is re-implementing the +# _truncated_identifier() sequence in a custom +# compiler +_generated_label = _truncated_label + + +class _anonymous_label(_truncated_label): + """A unicode subclass used to identify anonymously + generated names.""" + + def __add__(self, other): + return _anonymous_label( + quoted_name( + util.text_type.__add__(self, util.text_type(other)), + self.quote) + ) + + def __radd__(self, other): + return _anonymous_label( + quoted_name( + util.text_type.__add__(util.text_type(other), self), + self.quote) + ) + + def apply_map(self, map_): + if self.quote is not None: + # preserve quoting only if necessary + return quoted_name(self % map_, self.quote) + else: + # else skip the constructor call + return self % map_ + + +def _as_truncated(value): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(value, _truncated_label): + return value + else: + return _truncated_label(value) + + +def _string_or_unprintable(element): + if isinstance(element, util.string_types): + return element + else: + try: + return str(element) + except: + return "unprintable element %r" % element + + +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set(elem for elem in a + if all_overlap.intersection(elem._cloned_set)) + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set(elem for elem in a + if not all_overlap.intersection(elem._cloned_set)) + + +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + + +def _is_column(col): + """True if ``col`` is an instance of :class:`.ColumnElement`.""" + + return isinstance(col, ColumnElement) + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.column_set() + traverse(clause, {}, {'column': cols.add}) + return cols + + +# there is some inconsistency here between the usage of +# inspect() vs. checking for Visitable and __clause_element__. +# Ideally all functions here would derive from inspect(), +# however the inspect() versions add significant callcount +# overhead for critical functions like _interpret_as_column_or_from(). +# Generally, the column-based functions are more performance critical +# and are fine just checking for __clause_element__(). it's only +# _interpret_as_from() where we'd like to be able to receive ORM entities +# that have no defined namespace, hence inspect() is needed there. + + +def _column_as_key(element): + if isinstance(element, util.string_types): + return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + try: + return element.key + except AttributeError: + return None + + +def _clause_element_as_expr(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + else: + return element + + +def _literal_as_text(element): + if isinstance(element, Visitable): + return element + elif hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif isinstance(element, util.string_types): + return TextClause(util.text_type(element)) + elif isinstance(element, (util.NoneType, bool)): + return _const_expr(element) + else: + raise exc.ArgumentError( + "SQL expression object or string expected." + ) + + +def _no_literals(element): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' " + "function to indicate a SQL expression " + "literal, or 'literal()' to indicate a " + "bound value." % element) + else: + return element + + +def _is_literal(element): + return not isinstance(element, Visitable) and \ + not hasattr(element, '__clause_element__') + + +def _only_column_elements_or_none(element, name): + if element is None: + return None + else: + return _only_column_elements(element, name) + + +def _only_column_elements(element, name): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, ColumnElement): + raise exc.ArgumentError( + "Column-based expression object expected for argument " + "'%s'; got: '%s', type %s" % (name, element, type(element))) + return element + +def _literal_as_binds(element, name=None, type_=None): + if hasattr(element, '__clause_element__'): + return element.__clause_element__() + elif not isinstance(element, Visitable): + if element is None: + return Null() + else: + return BindParameter(name, element, type_=type_, unique=True) + else: + return element + + +def _interpret_as_column_or_from(element): + if isinstance(element, Visitable): + return element + elif hasattr(element, '__clause_element__'): + return element.__clause_element__() + + insp = inspection.inspect(element, raiseerr=False) + if insp is None: + if isinstance(element, (util.NoneType, bool)): + return _const_expr(element) + elif hasattr(insp, "selectable"): + return insp.selectable + + return ColumnClause(str(element), is_literal=True) + + +def _const_expr(element): + if isinstance(element, (Null, False_, True_)): + return element + elif element is None: + return Null() + elif element is False: + return False_() + elif element is True: + return True_() + else: + raise exc.ArgumentError( + "Expected None, False, or True" + ) + + +def _type_from_args(args): + for a in args: + if not a.type._isnull: + return a.type + else: + return type_api.NULLTYPE + + +def _corresponding_column_or_error(fromclause, column, + require_embedded=False): + c = fromclause.corresponding_column(column, + require_embedded=require_embedded) + if c is None: + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % + (column, + getattr(column, 'table', None), + fromclause.description) + ) + return c + + +class AnnotatedColumnElement(Annotated): + def __init__(self, element, values): + Annotated.__init__(self, element, values) + ColumnElement.comparator._reset(self) + for attr in ('name', 'key', 'table'): + if self.__dict__.get(attr, False) is None: + self.__dict__.pop(attr) + + def _with_annotations(self, values): + clone = super(AnnotatedColumnElement, self)._with_annotations(values) + ColumnElement.comparator._reset(clone) + return clone + + @util.memoized_property + def name(self): + """pull 'name' from parent, if not present""" + return self._Annotated__element.name + + @util.memoized_property + def table(self): + """pull 'table' from parent, if not present""" + return self._Annotated__element.table + + @util.memoized_property + def key(self): + """pull 'key' from parent, if not present""" + return self._Annotated__element.key + + @util.memoized_property + def info(self): + return self._Annotated__element.info + diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py new file mode 100644 index 00000000..c99665b4 --- /dev/null +++ b/lib/sqlalchemy/sql/expression.py @@ -0,0 +1,127 @@ +# sql/expression.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines the public namespace for SQL expression constructs. + +Prior to version 0.9, this module contained all of "elements", "dml", +"default_comparator" and "selectable". The module was broken up +and most "factory" functions were moved to be grouped with their associated +class. + +""" + +__all__ = [ + 'Alias', 'ClauseElement', 'ColumnCollection', 'ColumnElement', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Select', + 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', + 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', + 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', + 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label', + 'literal', 'literal_column', 'not_', 'null', 'nullsfirst', 'nullslast', + 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery', + 'table', 'text', + 'tuple_', 'type_coerce', 'union', 'union_all', 'update'] + + +from .visitors import Visitable +from .functions import func, modifier, FunctionElement +from ..util.langhelpers import public_factory +from .elements import ClauseElement, ColumnElement,\ + BindParameter, UnaryExpression, BooleanClauseList, \ + Label, Cast, Case, ColumnClause, TextClause, Over, Null, \ + True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \ + Grouping, not_, \ + collate, literal_column, between,\ + literal, outparam, type_coerce, ClauseList + +from .elements import SavepointClause, RollbackToSavepointClause, \ + ReleaseSavepointClause + +from .base import ColumnCollection, Generative, Executable, \ + PARSE_AUTOCOMMIT + +from .selectable import Alias, Join, Select, Selectable, TableClause, \ + CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ + alias, GenerativeSelect, \ + subquery, HasPrefixes, Exists, ScalarSelect, TextAsFrom + + +from .dml import Insert, Update, Delete, UpdateBase, ValuesBase + +# factory functions - these pull class-bound constructors and classmethods +# from SQL elements and selectables into public functions. This allows +# the functions to be available in the sqlalchemy.sql.* namespace and +# to be auto-cross-documenting from the function to the class itself. + +and_ = public_factory(BooleanClauseList.and_, ".expression.and_") +or_ = public_factory(BooleanClauseList.or_, ".expression.or_") +bindparam = public_factory(BindParameter, ".expression.bindparam") +select = public_factory(Select, ".expression.select") +text = public_factory(TextClause._create_text, ".expression.text") +table = public_factory(TableClause, ".expression.table") +column = public_factory(ColumnClause, ".expression.column") +over = public_factory(Over, ".expression.over") +label = public_factory(Label, ".expression.label") +case = public_factory(Case, ".expression.case") +cast = public_factory(Cast, ".expression.cast") +extract = public_factory(Extract, ".expression.extract") +tuple_ = public_factory(Tuple, ".expression.tuple_") +except_ = public_factory(CompoundSelect._create_except, ".expression.except_") +except_all = public_factory(CompoundSelect._create_except_all, ".expression.except_all") +intersect = public_factory(CompoundSelect._create_intersect, ".expression.intersect") +intersect_all = public_factory(CompoundSelect._create_intersect_all, ".expression.intersect_all") +union = public_factory(CompoundSelect._create_union, ".expression.union") +union_all = public_factory(CompoundSelect._create_union_all, ".expression.union_all") +exists = public_factory(Exists, ".expression.exists") +nullsfirst = public_factory(UnaryExpression._create_nullsfirst, ".expression.nullsfirst") +nullslast = public_factory(UnaryExpression._create_nullslast, ".expression.nullslast") +asc = public_factory(UnaryExpression._create_asc, ".expression.asc") +desc = public_factory(UnaryExpression._create_desc, ".expression.desc") +distinct = public_factory(UnaryExpression._create_distinct, ".expression.distinct") +true = public_factory(True_._singleton, ".expression.true") +false = public_factory(False_._singleton, ".expression.false") +null = public_factory(Null._singleton, ".expression.null") +join = public_factory(Join._create_join, ".expression.join") +outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin") +insert = public_factory(Insert, ".expression.insert") +update = public_factory(Update, ".expression.update") +delete = public_factory(Delete, ".expression.delete") + + +# internal functions still being called from tests and the ORM, +# these might be better off in some other namespace +from .base import _from_objects +from .elements import _literal_as_text, _clause_element_as_expr,\ + _is_column, _labeled, _only_column_elements, _string_or_unprintable, \ + _truncated_label, _clone, _cloned_difference, _cloned_intersection,\ + _column_as_key, _literal_as_binds, _select_iterables, \ + _corresponding_column_or_error +from .selectable import _interpret_as_from + + + +# old names for compatibility +_Executable = Executable +_BindParamClause = BindParameter +_Label = Label +_SelectBase = SelectBase +_BinaryExpression = BinaryExpression +_Cast = Cast +_Null = Null +_False = False_ +_True = True_ +_TextClause = TextClause +_UnaryExpression = UnaryExpression +_Case = Case +_Tuple = Tuple +_Over = Over +_Generative = Generative +_TypeClause = TypeClause +_Extract = Extract +_Exists = Exists +_Grouping = Grouping +_FromGrouping = FromGrouping +_ScalarSelect = ScalarSelect diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py new file mode 100644 index 00000000..a9b88b13 --- /dev/null +++ b/lib/sqlalchemy/sql/functions.py @@ -0,0 +1,534 @@ +# sql/functions.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL function API, factories, and built-in functions. + +""" +from . import sqltypes, schema +from .base import Executable +from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ + literal_column, _type_from_args, ColumnElement, _clone,\ + Over, BindParameter +from .selectable import FromClause, Select + +from . import operators +from .visitors import VisitableType +from .. import util +from . import annotation + +_registry = util.defaultdict(dict) + + +def register_function(identifier, fn, package="_default"): + """Associate a callable with a particular func. name. + + This is normally called by _GenericMeta, but is also + available by itself so that a non-Function construct + can be associated with the :data:`.func` accessor (i.e. + CAST, EXTRACT). + + """ + reg = _registry[package] + reg[identifier] = fn + + +class FunctionElement(Executable, ColumnElement, FromClause): + """Base for SQL function-oriented constructs. + + .. seealso:: + + :class:`.Function` - named SQL function. + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + + """ + + packagenames = () + + def __init__(self, *clauses, **kwargs): + """Construct a :class:`.FunctionElement`. + """ + args = [_literal_as_binds(c, self.name) for c in clauses] + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args).\ + self_group() + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_function(self, multiparams, params) + + @property + def columns(self): + """Fulfill the 'columns' contract of :class:`.ColumnElement`. + + Returns a single-element list consisting of this object. + + """ + return [self] + + @util.memoized_property + def clauses(self): + """Return the underlying :class:`.ClauseList` which contains + the arguments for this :class:`.FunctionElement`. + + """ + return self.clause_expr.element + + def over(self, partition_by=None, order_by=None): + """Produce an OVER clause against this function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + The expression:: + + func.row_number().over(order_by='x') + + is shorthand for:: + + from sqlalchemy import over + over(func.row_number(), order_by='x') + + See :func:`~.expression.over` for a full description. + + .. versionadded:: 0.7 + + """ + return Over(self, partition_by=partition_by, order_by=order_by) + + @property + def _from_objects(self): + return self.clauses._from_objects + + def get_children(self, **kwargs): + return self.clause_expr, + + def _copy_internals(self, clone=_clone, **kw): + self.clause_expr = clone(self.clause_expr, **kw) + self._reset_exported() + FunctionElement.clauses._reset(self) + + def select(self): + """Produce a :func:`~.expression.select` construct + against this :class:`.FunctionElement`. + + This is shorthand for:: + + s = select([function_element]) + + """ + s = Select([self]) + if self._execution_options: + s = s.execution_options(**self._execution_options) + return s + + def scalar(self): + """Execute this :class:`.FunctionElement` against an embedded + 'bind' and return a scalar value. + + This first calls :meth:`~.FunctionElement.select` to + produce a SELECT construct. + + Note that :class:`.FunctionElement` can be passed to + the :meth:`.Connectable.scalar` method of :class:`.Connection` + or :class:`.Engine`. + + """ + return self.select().execute().scalar() + + def execute(self): + """Execute this :class:`.FunctionElement` against an embedded + 'bind'. + + This first calls :meth:`~.FunctionElement.select` to + produce a SELECT construct. + + Note that :class:`.FunctionElement` can be passed to + the :meth:`.Connectable.execute` method of :class:`.Connection` + or :class:`.Engine`. + + """ + return self.select().execute() + + def _bind_param(self, operator, obj): + return BindParameter(None, obj, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + + +class _FunctionGenerator(object): + """Generate :class:`.Function` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts + + def __getattr__(self, name): + # passthru __ attributes; fixes pydoc + if name.startswith('__'): + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(name) + + elif name.endswith('_'): + name = name[0:-1] + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f + + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + + tokens = len(self.__names) + + if tokens == 2: + package, fname = self.__names + elif tokens == 1: + package, fname = "_default", self.__names[0] + else: + package = None + + if package is not None: + func = _registry[package].get(fname) + if func is not None: + return func(*c, **o) + + return Function(self.__names[-1], + packagenames=self.__names[0:-1], *c, **o) + + +func = _FunctionGenerator() +"""Generate SQL function expressions. + + :data:`.func` is a special object instance which generates SQL + functions based on name-based attributes, e.g.:: + + >>> print func.count(1) + count(:param_1) + + The element is a column-oriented SQL element like any other, and is + used in that way:: + + >>> print select([func.count(table.c.id)]) + SELECT count(sometable.id) FROM sometable + + Any name can be given to :data:`.func`. If the function name is unknown to + SQLAlchemy, it will be rendered exactly as is. For common SQL functions + which SQLAlchemy is aware of, the name may be interpreted as a *generic + function* which will be compiled appropriately to the target database:: + + >>> print func.current_timestamp() + CURRENT_TIMESTAMP + + To call functions which are present in dot-separated packages, + specify them in the same manner:: + + >>> print func.stats.yield_curve(5, 10) + stats.yield_curve(:yield_curve_1, :yield_curve_2) + + SQLAlchemy can be made aware of the return type of functions to enable + type-specific lexical and result-based behavior. For example, to ensure + that a string-based function returns a Unicode value and is similarly + treated as a string in expressions, specify + :class:`~sqlalchemy.types.Unicode` as the type: + + >>> print func.my_string(u'hi', type_=Unicode) + ' ' + \ + ... func.my_string(u'there', type_=Unicode) + my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3) + + The object returned by a :data:`.func` call is usually an instance of + :class:`.Function`. + This object meets the "column" interface, including comparison and labeling + functions. The object can also be passed the :meth:`~.Connectable.execute` + method of a :class:`.Connection` or :class:`.Engine`, where it will be + wrapped inside of a SELECT statement first:: + + print connection.execute(func.current_timestamp()).scalar() + + In a few exception cases, the :data:`.func` accessor + will redirect a name to a built-in expression such as :func:`.cast` + or :func:`.extract`, as these names have well-known meaning + but are not exactly the same as "functions" from a SQLAlchemy + perspective. + + .. versionadded:: 0.8 :data:`.func` can return non-function expression + constructs for common quasi-functional names like :func:`.cast` + and :func:`.extract`. + + Functions which are interpreted as "generic" functions know how to + calculate their return type automatically. For a listing of known generic + functions, see :ref:`generic_functions`. + +""" + +modifier = _FunctionGenerator(group=False) + +class Function(FunctionElement): + """Describe a named SQL function. + + See the superclass :class:`.FunctionElement` for a description + of public methods. + + .. seealso:: + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + + """ + + __visit_name__ = 'function' + + def __init__(self, name, *clauses, **kw): + """Construct a :class:`.Function`. + + The :data:`.func` construct is normally used to construct + new :class:`.Function` instances. + + """ + self.packagenames = kw.pop('packagenames', None) or [] + self.name = name + self._bind = kw.get('bind', None) + self.type = sqltypes.to_instance(kw.get('type_', None)) + + FunctionElement.__init__(self, *clauses, **kw) + + def _bind_param(self, operator, obj): + return BindParameter(self.name, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + unique=True) + +class _GenericMeta(VisitableType): + def __init__(cls, clsname, bases, clsdict): + if annotation.Annotated not in cls.__mro__: + cls.name = name = clsdict.get('name', clsname) + cls.identifier = identifier = clsdict.get('identifier', name) + package = clsdict.pop('package', '_default') + # legacy + if '__return_type__' in clsdict: + cls.type = clsdict['__return_type__'] + register_function(identifier, cls, package) + super(_GenericMeta, cls).__init__(clsname, bases, clsdict) + + +class GenericFunction(util.with_metaclass(_GenericMeta, Function)): + """Define a 'generic' function. + + A generic function is a pre-established :class:`.Function` + class that is instantiated automatically when called + by name from the :data:`.func` attribute. Note that + calling any name from :data:`.func` has the effect that + a new :class:`.Function` instance is created automatically, + given that name. The primary use case for defining + a :class:`.GenericFunction` class is so that a function + of a particular name may be given a fixed return type. + It can also include custom argument parsing schemes as well + as additional methods. + + Subclasses of :class:`.GenericFunction` are automatically + registered under the name of the class. For + example, a user-defined function ``as_utc()`` would + be available immediately:: + + from sqlalchemy.sql.functions import GenericFunction + from sqlalchemy.types import DateTime + + class as_utc(GenericFunction): + type = DateTime + + print select([func.as_utc()]) + + User-defined generic functions can be organized into + packages by specifying the "package" attribute when defining + :class:`.GenericFunction`. Third party libraries + containing many functions may want to use this in order + to avoid name conflicts with other systems. For example, + if our ``as_utc()`` function were part of a package + "time":: + + class as_utc(GenericFunction): + type = DateTime + package = "time" + + The above function would be available from :data:`.func` + using the package name ``time``:: + + print select([func.time.as_utc()]) + + A final option is to allow the function to be accessed + from one name in :data:`.func` but to render as a different name. + The ``identifier`` attribute will override the name used to + access the function as loaded from :data:`.func`, but will retain + the usage of ``name`` as the rendered name:: + + class GeoBuffer(GenericFunction): + type = Geometry + package = "geo" + name = "ST_Buffer" + identifier = "buffer" + + The above function will render as follows:: + + >>> print func.geo.buffer() + ST_Buffer() + + .. versionadded:: 0.8 :class:`.GenericFunction` now supports + automatic registration of new functions as well as package + and custom naming support. + + .. versionchanged:: 0.8 The attribute name ``type`` is used + to specify the function's return type at the class level. + Previously, the name ``__return_type__`` was used. This + name is still recognized for backwards-compatibility. + + """ + + coerce_arguments = True + + def __init__(self, *args, **kwargs): + parsed_args = kwargs.pop('_parsed_args', None) + if parsed_args is None: + parsed_args = [_literal_as_binds(c) for c in args] + self.packagenames = [] + self._bind = kwargs.get('bind', None) + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *parsed_args).self_group() + self.type = sqltypes.to_instance( + kwargs.pop("type_", None) or getattr(self, 'type', None)) + +register_function("cast", Cast) +register_function("extract", Extract) + + +class next_value(GenericFunction): + """Represent the 'next value', given a :class:`.Sequence` + as it's single argument. + + Compiles into the appropriate function on each backend, + or will raise NotImplementedError if used on a backend + that does not provide support for sequences. + + """ + type = sqltypes.Integer() + name = "next_value" + + def __init__(self, seq, **kw): + assert isinstance(seq, schema.Sequence), \ + "next_value() accepts a Sequence object as input." + self._bind = kw.get('bind', None) + self.sequence = seq + + @property + def _from_objects(self): + return [] + + +class AnsiFunction(GenericFunction): + def __init__(self, **kwargs): + GenericFunction.__init__(self, **kwargs) + + +class ReturnTypeFromArgs(GenericFunction): + """Define a function whose return type is the same as its arguments.""" + + def __init__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + kwargs.setdefault('type_', _type_from_args(args)) + kwargs['_parsed_args'] = args + GenericFunction.__init__(self, *args, **kwargs) + + +class coalesce(ReturnTypeFromArgs): + pass + + +class max(ReturnTypeFromArgs): + pass + + +class min(ReturnTypeFromArgs): + pass + + +class sum(ReturnTypeFromArgs): + pass + + +class now(GenericFunction): + type = sqltypes.DateTime + + +class concat(GenericFunction): + type = sqltypes.String + + +class char_length(GenericFunction): + type = sqltypes.Integer + + def __init__(self, arg, **kwargs): + GenericFunction.__init__(self, arg, **kwargs) + + +class random(GenericFunction): + pass + + +class count(GenericFunction): + """The ANSI COUNT aggregate function. With no arguments, + emits COUNT \*. + + """ + type = sqltypes.Integer + + def __init__(self, expression=None, **kwargs): + if expression is None: + expression = literal_column('*') + GenericFunction.__init__(self, expression, **kwargs) + + +class current_date(AnsiFunction): + type = sqltypes.Date + + +class current_time(AnsiFunction): + type = sqltypes.Time + + +class current_timestamp(AnsiFunction): + type = sqltypes.DateTime + + +class current_user(AnsiFunction): + type = sqltypes.String + + +class localtime(AnsiFunction): + type = sqltypes.DateTime + + +class localtimestamp(AnsiFunction): + type = sqltypes.DateTime + + +class session_user(AnsiFunction): + type = sqltypes.String + + +class sysdate(AnsiFunction): + type = sqltypes.DateTime + + +class user(AnsiFunction): + type = sqltypes.String diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py new file mode 100644 index 00000000..1c5fae19 --- /dev/null +++ b/lib/sqlalchemy/sql/naming.py @@ -0,0 +1,165 @@ +# sqlalchemy/naming.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Establish constraint and index naming conventions. + + +""" + +from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \ + UniqueConstraint, CheckConstraint, Index, Table, Column +from .. import event, events +from .. import exc +from .elements import _truncated_label +import re + +class conv(_truncated_label): + """Mark a string indicating that a name has already been converted + by a naming convention. + + This is a string subclass that indicates a name that should not be + subject to any further naming conventions. + + E.g. when we create a :class:`.Constraint` using a naming convention + as follows:: + + m = MetaData(naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + t = Table('t', m, Column('x', Integer), + CheckConstraint('x > 5', name='x5')) + + The name of the above constraint will be rendered as ``"ck_t_x5"``. That is, + the existing name ``x5`` is used in the naming convention as the ``constraint_name`` + token. + + In some situations, such as in migration scripts, we may be rendering + the above :class:`.CheckConstraint` with a name that's already been + converted. In order to make sure the name isn't double-modified, the + new name is applied using the :func:`.schema.conv` marker. We can + use this explicitly as follows:: + + + m = MetaData(naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + t = Table('t', m, Column('x', Integer), + CheckConstraint('x > 5', name=conv('ck_t_x5'))) + + Where above, the :func:`.schema.conv` marker indicates that the constraint + name here is final, and the name will render as ``"ck_t_x5"`` and not + ``"ck_t_ck_t_x5"`` + + .. versionadded:: 0.9.4 + + .. seealso:: + + :ref:`constraint_naming_conventions` + + """ + +class ConventionDict(object): + def __init__(self, const, table, convention): + self.const = const + self._is_fk = isinstance(const, ForeignKeyConstraint) + self.table = table + self.convention = convention + self._const_name = const.name + + def _key_table_name(self): + return self.table.name + + def _column_X(self, idx): + if self._is_fk: + fk = self.const.elements[idx] + return fk.parent + else: + return list(self.const.columns)[idx] + + def _key_constraint_name(self): + if not self._const_name: + raise exc.InvalidRequestError( + "Naming convention including " + "%(constraint_name)s token requires that " + "constraint is explicitly named." + ) + if not isinstance(self._const_name, conv): + self.const.name = None + return self._const_name + + def _key_column_X_name(self, idx): + return self._column_X(idx).name + + def _key_column_X_label(self, idx): + return self._column_X(idx)._label + + def _key_referred_table_name(self): + fk = self.const.elements[0] + refs = fk.target_fullname.split(".") + if len(refs) == 3: + refschema, reftable, refcol = refs + else: + reftable, refcol = refs + return reftable + + def _key_referred_column_X_name(self, idx): + fk = self.const.elements[idx] + refs = fk.target_fullname.split(".") + if len(refs) == 3: + refschema, reftable, refcol = refs + else: + reftable, refcol = refs + return refcol + + def __getitem__(self, key): + if key in self.convention: + return self.convention[key](self.const, self.table) + elif hasattr(self, '_key_%s' % key): + return getattr(self, '_key_%s' % key)() + else: + col_template = re.match(r".*_?column_(\d+)_.+", key) + if col_template: + idx = col_template.group(1) + attr = "_key_" + key.replace(idx, "X") + idx = int(idx) + if hasattr(self, attr): + return getattr(self, attr)(idx) + raise KeyError(key) + +_prefix_dict = { + Index: "ix", + PrimaryKeyConstraint: "pk", + CheckConstraint: "ck", + UniqueConstraint: "uq", + ForeignKeyConstraint: "fk" +} + +def _get_convention(dict_, key): + + for super_ in key.__mro__: + if super_ in _prefix_dict and _prefix_dict[super_] in dict_: + return dict_[_prefix_dict[super_]] + elif super_ in dict_: + return dict_[super_] + else: + return None + + +@event.listens_for(Constraint, "after_parent_attach") +@event.listens_for(Index, "after_parent_attach") +def _constraint_name(const, table): + if isinstance(table, Column): + # for column-attached constraint, set another event + # to link the column attached to the table as this constraint + # associated with the table. + event.listen(table, "after_parent_attach", + lambda col, table: _constraint_name(const, table) + ) + elif isinstance(table, Table): + metadata = table.metadata + convention = _get_convention(metadata.naming_convention, type(const)) + if convention is not None: + newname = conv( + convention % ConventionDict(const, table, metadata.naming_convention) + ) + if const.name is None: + const.name = newname diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py new file mode 100644 index 00000000..91301c78 --- /dev/null +++ b/lib/sqlalchemy/sql/operators.py @@ -0,0 +1,867 @@ +# sql/operators.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines operators used in SQL expressions.""" + +from .. import util + + +from operator import ( + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg, + getitem, lshift, rshift + ) + +if util.py2k: + from operator import div +else: + div = truediv + + + +class Operators(object): + """Base of comparison and logical operators. + + Implements base methods :meth:`~sqlalchemy.sql.operators.Operators.operate` and + :meth:`~sqlalchemy.sql.operators.Operators.reverse_operate`, as well as + :meth:`~sqlalchemy.sql.operators.Operators.__and__`, + :meth:`~sqlalchemy.sql.operators.Operators.__or__`, + :meth:`~sqlalchemy.sql.operators.Operators.__invert__`. + + Usually is used via its most common subclass + :class:`.ColumnOperators`. + + """ + def __and__(self, other): + """Implement the ``&`` operator. + + When used with SQL expressions, results in an + AND operation, equivalent to + :func:`~.expression.and_`, that is:: + + a & b + + is equivalent to:: + + from sqlalchemy import and_ + and_(a, b) + + Care should be taken when using ``&`` regarding + operator precedence; the ``&`` operator has the highest precedence. + The operands should be enclosed in parenthesis if they contain + further sub expressions:: + + (a == 2) & (b == 4) + + """ + return self.operate(and_, other) + + def __or__(self, other): + """Implement the ``|`` operator. + + When used with SQL expressions, results in an + OR operation, equivalent to + :func:`~.expression.or_`, that is:: + + a | b + + is equivalent to:: + + from sqlalchemy import or_ + or_(a, b) + + Care should be taken when using ``|`` regarding + operator precedence; the ``|`` operator has the highest precedence. + The operands should be enclosed in parenthesis if they contain + further sub expressions:: + + (a == 2) | (b == 4) + + """ + return self.operate(or_, other) + + def __invert__(self): + """Implement the ``~`` operator. + + When used with SQL expressions, results in a + NOT operation, equivalent to + :func:`~.expression.not_`, that is:: + + ~a + + is equivalent to:: + + from sqlalchemy import not_ + not_(a) + + """ + return self.operate(inv) + + def op(self, opstring, precedence=0, is_comparison=False): + """produce a generic operator function. + + e.g.:: + + somecolumn.op("*")(5) + + produces:: + + somecolumn * 5 + + This function can also be used to make bitwise operators explicit. For + example:: + + somecolumn.op('&')(0xff) + + is a bitwise AND of the value in ``somecolumn``. + + :param operator: a string which will be output as the infix operator + between this element and the expression passed to the + generated function. + + :param precedence: precedence to apply to the operator, when + parenthesizing expressions. A lower number will cause the expression + to be parenthesized when applied against another operator with + higher precedence. The default value of ``0`` is lower than all + operators except for the comma (``,``) and ``AS`` operators. + A value of 100 will be higher or equal to all operators, and -100 + will be lower than or equal to all operators. + + .. versionadded:: 0.8 - added the 'precedence' argument. + + :param is_comparison: if True, the operator will be considered as a + "comparison" operator, that is which evaulates to a boolean true/false + value, like ``==``, ``>``, etc. This flag should be set so that + ORM relationships can establish that the operator is a comparison + operator when used in a custom join condition. + + .. versionadded:: 0.9.2 - added the :paramref:`.Operators.op.is_comparison` + flag. + + .. seealso:: + + :ref:`types_operators` + + :ref:`relationship_custom_operator` + + """ + operator = custom_op(opstring, precedence, is_comparison) + + def against(other): + return operator(self, other) + return against + + def operate(self, op, *other, **kwargs): + """Operate on an argument. + + This is the lowest level of operation, raises + :class:`NotImplementedError` by default. + + Overriding this on a subclass can allow common + behavior to be applied to all operations. + For example, overriding :class:`.ColumnOperators` + to apply ``func.lower()`` to the left and right + side:: + + class MyComparator(ColumnOperators): + def operate(self, op, other): + return op(func.lower(self), func.lower(other)) + + :param op: Operator callable. + :param \*other: the 'other' side of the operation. Will + be a single scalar for most operations. + :param \**kwargs: modifiers. These may be passed by special + operators such as :meth:`ColumnOperators.contains`. + + + """ + raise NotImplementedError(str(op)) + + def reverse_operate(self, op, other, **kwargs): + """Reverse operate on an argument. + + Usage is the same as :meth:`operate`. + + """ + raise NotImplementedError(str(op)) + + +class custom_op(object): + """Represent a 'custom' operator. + + :class:`.custom_op` is normally instantitated when the + :meth:`.ColumnOperators.op` method is used to create a + custom operator callable. The class can also be used directly + when programmatically constructing expressions. E.g. + to represent the "factorial" operation:: + + from sqlalchemy.sql import UnaryExpression + from sqlalchemy.sql import operators + from sqlalchemy import Numeric + + unary = UnaryExpression(table.c.somecolumn, + modifier=operators.custom_op("!"), + type_=Numeric) + + """ + __name__ = 'custom_op' + + def __init__(self, opstring, precedence=0, is_comparison=False): + self.opstring = opstring + self.precedence = precedence + self.is_comparison = is_comparison + + def __eq__(self, other): + return isinstance(other, custom_op) and \ + other.opstring == self.opstring + + def __hash__(self): + return id(self) + + def __call__(self, left, right, **kw): + return left.operate(self, right, **kw) + + +class ColumnOperators(Operators): + """Defines boolean, comparison, and other operators for + :class:`.ColumnElement` expressions. + + By default, all methods call down to + :meth:`.operate` or :meth:`.reverse_operate`, + passing in the appropriate operator function from the + Python builtin ``operator`` module or + a SQLAlchemy-specific operator function from + :mod:`sqlalchemy.expression.operators`. For example + the ``__eq__`` function:: + + def __eq__(self, other): + return self.operate(operators.eq, other) + + Where ``operators.eq`` is essentially:: + + def eq(a, b): + return a == b + + The core column expression unit :class:`.ColumnElement` + overrides :meth:`.Operators.operate` and others + to return further :class:`.ColumnElement` constructs, + so that the ``==`` operation above is replaced by a clause + construct. + + See also: + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + :class:`.ColumnOperators` + + :class:`.PropComparator` + + """ + + timetuple = None + """Hack, allows datetime objects to be compared on the LHS.""" + + def __lt__(self, other): + """Implement the ``<`` operator. + + In a column context, produces the clause ``a < b``. + + """ + return self.operate(lt, other) + + def __le__(self, other): + """Implement the ``<=`` operator. + + In a column context, produces the clause ``a <= b``. + + """ + return self.operate(le, other) + + __hash__ = Operators.__hash__ + + def __eq__(self, other): + """Implement the ``==`` operator. + + In a column context, produces the clause ``a = b``. + If the target is ``None``, produces ``a IS NULL``. + + """ + return self.operate(eq, other) + + def __ne__(self, other): + """Implement the ``!=`` operator. + + In a column context, produces the clause ``a != b``. + If the target is ``None``, produces ``a IS NOT NULL``. + + """ + return self.operate(ne, other) + + def __gt__(self, other): + """Implement the ``>`` operator. + + In a column context, produces the clause ``a > b``. + + """ + return self.operate(gt, other) + + def __ge__(self, other): + """Implement the ``>=`` operator. + + In a column context, produces the clause ``a >= b``. + + """ + return self.operate(ge, other) + + def __neg__(self): + """Implement the ``-`` operator. + + In a column context, produces the clause ``-a``. + + """ + return self.operate(neg) + + def __getitem__(self, index): + """Implement the [] operator. + + This can be used by some database-specific types + such as Postgresql ARRAY and HSTORE. + + """ + return self.operate(getitem, index) + + def __lshift__(self, other): + """implement the << operator. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + << as an extension point. + """ + return self.operate(lshift, other) + + def __rshift__(self, other): + """implement the >> operator. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + >> as an extension point. + """ + return self.operate(rshift, other) + + def concat(self, other): + """Implement the 'concat' operator. + + In a column context, produces the clause ``a || b``, + or uses the ``concat()`` operator on MySQL. + + """ + return self.operate(concat_op, other) + + def like(self, other, escape=None): + """Implement the ``like`` operator. + + In a column context, produces the clause ``a LIKE other``. + + E.g.:: + + select([sometable]).where(sometable.c.column.like("%foobar%")) + + :param other: expression to be compared + :param escape: optional escape character, renders the ``ESCAPE`` + keyword, e.g.:: + + somecolumn.like("foo/%bar", escape="/") + + .. seealso:: + + :meth:`.ColumnOperators.ilike` + + """ + return self.operate(like_op, other, escape=escape) + + def ilike(self, other, escape=None): + """Implement the ``ilike`` operator. + + In a column context, produces the clause ``a ILIKE other``. + + E.g.:: + + select([sometable]).where(sometable.c.column.ilike("%foobar%")) + + :param other: expression to be compared + :param escape: optional escape character, renders the ``ESCAPE`` + keyword, e.g.:: + + somecolumn.ilike("foo/%bar", escape="/") + + .. seealso:: + + :meth:`.ColumnOperators.like` + + """ + return self.operate(ilike_op, other, escape=escape) + + def in_(self, other): + """Implement the ``in`` operator. + + In a column context, produces the clause ``a IN other``. + "other" may be a tuple/list of column expressions, + or a :func:`~.expression.select` construct. + + """ + return self.operate(in_op, other) + + def notin_(self, other): + """implement the ``NOT IN`` operator. + + This is equivalent to using negation with :meth:`.ColumnOperators.in_`, + i.e. ``~x.in_(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.in_` + + """ + return self.operate(notin_op, other) + + def notlike(self, other, escape=None): + """implement the ``NOT LIKE`` operator. + + This is equivalent to using negation with + :meth:`.ColumnOperators.like`, i.e. ``~x.like(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.like` + + """ + return self.operate(notlike_op, other, escape=escape) + + def notilike(self, other, escape=None): + """implement the ``NOT ILIKE`` operator. + + This is equivalent to using negation with + :meth:`.ColumnOperators.ilike`, i.e. ``~x.ilike(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.ilike` + + """ + return self.operate(notilike_op, other, escape=escape) + + def is_(self, other): + """Implement the ``IS`` operator. + + Normally, ``IS`` is generated automatically when comparing to a + value of ``None``, which resolves to ``NULL``. However, explicit + usage of ``IS`` may be desirable if comparing to boolean values + on certain platforms. + + .. versionadded:: 0.7.9 + + .. seealso:: :meth:`.ColumnOperators.isnot` + + """ + return self.operate(is_, other) + + def isnot(self, other): + """Implement the ``IS NOT`` operator. + + Normally, ``IS NOT`` is generated automatically when comparing to a + value of ``None``, which resolves to ``NULL``. However, explicit + usage of ``IS NOT`` may be desirable if comparing to boolean values + on certain platforms. + + .. versionadded:: 0.7.9 + + .. seealso:: :meth:`.ColumnOperators.is_` + + """ + return self.operate(isnot, other) + + def startswith(self, other, **kwargs): + """Implement the ``startwith`` operator. + + In a column context, produces the clause ``LIKE '%'`` + + """ + return self.operate(startswith_op, other, **kwargs) + + def endswith(self, other, **kwargs): + """Implement the 'endswith' operator. + + In a column context, produces the clause ``LIKE '%'`` + + """ + return self.operate(endswith_op, other, **kwargs) + + def contains(self, other, **kwargs): + """Implement the 'contains' operator. + + In a column context, produces the clause ``LIKE '%%'`` + + """ + return self.operate(contains_op, other, **kwargs) + + def match(self, other, **kwargs): + """Implements the 'match' operator. + + In a column context, this produces a MATCH clause, i.e. + ``MATCH ''``. The allowed contents of ``other`` + are database backend specific. + + """ + return self.operate(match_op, other, **kwargs) + + def desc(self): + """Produce a :func:`~.expression.desc` clause against the + parent object.""" + return self.operate(desc_op) + + def asc(self): + """Produce a :func:`~.expression.asc` clause against the + parent object.""" + return self.operate(asc_op) + + def nullsfirst(self): + """Produce a :func:`~.expression.nullsfirst` clause against the + parent object.""" + return self.operate(nullsfirst_op) + + def nullslast(self): + """Produce a :func:`~.expression.nullslast` clause against the + parent object.""" + return self.operate(nullslast_op) + + def collate(self, collation): + """Produce a :func:`~.expression.collate` clause against + the parent object, given the collation string.""" + return self.operate(collate, collation) + + def __radd__(self, other): + """Implement the ``+`` operator in reverse. + + See :meth:`.ColumnOperators.__add__`. + + """ + return self.reverse_operate(add, other) + + def __rsub__(self, other): + """Implement the ``-`` operator in reverse. + + See :meth:`.ColumnOperators.__sub__`. + + """ + return self.reverse_operate(sub, other) + + def __rmul__(self, other): + """Implement the ``*`` operator in reverse. + + See :meth:`.ColumnOperators.__mul__`. + + """ + return self.reverse_operate(mul, other) + + def __rdiv__(self, other): + """Implement the ``/`` operator in reverse. + + See :meth:`.ColumnOperators.__div__`. + + """ + return self.reverse_operate(div, other) + + def between(self, cleft, cright): + """Produce a :func:`~.expression.between` clause against + the parent object, given the lower and upper range.""" + return self.operate(between_op, cleft, cright) + + def distinct(self): + """Produce a :func:`~.expression.distinct` clause against the + parent object. + + """ + return self.operate(distinct_op) + + def __add__(self, other): + """Implement the ``+`` operator. + + In a column context, produces the clause ``a + b`` + if the parent object has non-string affinity. + If the parent object has a string affinity, + produces the concatenation operator, ``a || b`` - + see :meth:`.ColumnOperators.concat`. + + """ + return self.operate(add, other) + + def __sub__(self, other): + """Implement the ``-`` operator. + + In a column context, produces the clause ``a - b``. + + """ + return self.operate(sub, other) + + def __mul__(self, other): + """Implement the ``*`` operator. + + In a column context, produces the clause ``a * b``. + + """ + return self.operate(mul, other) + + def __div__(self, other): + """Implement the ``/`` operator. + + In a column context, produces the clause ``a / b``. + + """ + return self.operate(div, other) + + def __mod__(self, other): + """Implement the ``%`` operator. + + In a column context, produces the clause ``a % b``. + + """ + return self.operate(mod, other) + + def __truediv__(self, other): + """Implement the ``//`` operator. + + In a column context, produces the clause ``a / b``. + + """ + return self.operate(truediv, other) + + def __rtruediv__(self, other): + """Implement the ``//`` operator in reverse. + + See :meth:`.ColumnOperators.__truediv__`. + + """ + return self.reverse_operate(truediv, other) + + +def from_(): + raise NotImplementedError() + + +def as_(): + raise NotImplementedError() + + +def exists(): + raise NotImplementedError() + + +def istrue(a): + raise NotImplementedError() + +def isfalse(a): + raise NotImplementedError() + +def is_(a, b): + return a.is_(b) + + +def isnot(a, b): + return a.isnot(b) + + +def collate(a, b): + return a.collate(b) + + +def op(a, opstring, b): + return a.op(opstring)(b) + + +def like_op(a, b, escape=None): + return a.like(b, escape=escape) + + +def notlike_op(a, b, escape=None): + return a.notlike(b, escape=escape) + + +def ilike_op(a, b, escape=None): + return a.ilike(b, escape=escape) + + +def notilike_op(a, b, escape=None): + return a.notilike(b, escape=escape) + + +def between_op(a, b, c): + return a.between(b, c) + + +def in_op(a, b): + return a.in_(b) + + +def notin_op(a, b): + return a.notin_(b) + + +def distinct_op(a): + return a.distinct() + + +def startswith_op(a, b, escape=None): + return a.startswith(b, escape=escape) + + +def notstartswith_op(a, b, escape=None): + return ~a.startswith(b, escape=escape) + + +def endswith_op(a, b, escape=None): + return a.endswith(b, escape=escape) + + +def notendswith_op(a, b, escape=None): + return ~a.endswith(b, escape=escape) + + +def contains_op(a, b, escape=None): + return a.contains(b, escape=escape) + + +def notcontains_op(a, b, escape=None): + return ~a.contains(b, escape=escape) + + +def match_op(a, b): + return a.match(b) + + +def comma_op(a, b): + raise NotImplementedError() + + +def concat_op(a, b): + return a.concat(b) + + +def desc_op(a): + return a.desc() + + +def asc_op(a): + return a.asc() + + +def nullsfirst_op(a): + return a.nullsfirst() + + +def nullslast_op(a): + return a.nullslast() + + +_commutative = set([eq, ne, add, mul]) + +_comparison = set([eq, ne, lt, gt, ge, le, between_op]) + + +def is_comparison(op): + return op in _comparison or \ + isinstance(op, custom_op) and op.is_comparison + + +def is_commutative(op): + return op in _commutative + + +def is_ordering_modifier(op): + return op in (asc_op, desc_op, + nullsfirst_op, nullslast_op) + +_associative = _commutative.union([concat_op, and_, or_]) + +_natural_self_precedent = _associative.union([getitem]) +"""Operators where if we have (a op b) op c, we don't want to +parenthesize (a op b). + +""" + +_asbool = util.symbol('_asbool', canonical=-10) +_smallest = util.symbol('_smallest', canonical=-100) +_largest = util.symbol('_largest', canonical=100) + +_PRECEDENCE = { + from_: 15, + getitem: 15, + mul: 8, + truediv: 8, + div: 8, + mod: 8, + neg: 8, + add: 7, + sub: 7, + + concat_op: 6, + match_op: 6, + + ilike_op: 6, + notilike_op: 6, + like_op: 6, + notlike_op: 6, + in_op: 6, + notin_op: 6, + + is_: 6, + isnot: 6, + + eq: 5, + ne: 5, + gt: 5, + lt: 5, + ge: 5, + le: 5, + + between_op: 5, + distinct_op: 5, + inv: 5, + istrue: 5, + isfalse: 5, + and_: 3, + or_: 2, + comma_op: -1, + + desc_op: 3, + asc_op: 3, + collate: 4, + + as_: -1, + exists: 0, + _asbool: -10, + _smallest: _smallest, + _largest: _largest +} + + +def is_precedent(operator, against): + if operator is against and operator in _natural_self_precedent: + return False + else: + return (_PRECEDENCE.get(operator, + getattr(operator, 'precedence', _smallest)) <= + _PRECEDENCE.get(against, + getattr(against, 'precedence', _largest))) diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py new file mode 100644 index 00000000..e29fe456 --- /dev/null +++ b/lib/sqlalchemy/sql/schema.py @@ -0,0 +1,3386 @@ +# sql/schema.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The schema module provides the building blocks for database metadata. + +Each element within this module describes a database entity which can be +created and dropped, or is otherwise part of such an entity. Examples include +tables, columns, sequences, and indexes. + +All entities are subclasses of :class:`~sqlalchemy.schema.SchemaItem`, and as +defined in this module they are intended to be agnostic of any vendor-specific +constructs. + +A collection of entities are grouped into a unit called +:class:`~sqlalchemy.schema.MetaData`. MetaData serves as a logical grouping of +schema elements, and can also be associated with an actual database connection +such that operations involving the contained elements can contact the database +as needed. + +Two of the elements here also build upon their "syntactic" counterparts, which +are defined in :class:`~sqlalchemy.sql.expression.`, specifically +:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.Column`. +Since these objects are part of the SQL expression language, they are usable +as components in SQL expressions. + +""" +from __future__ import absolute_import + +import inspect +from .. import exc, util, event, inspection +from .base import SchemaEventTarget, DialectKWArgs +from . import visitors +from . import type_api +from .base import _bind_or_error, ColumnCollection +from .elements import ClauseElement, ColumnClause, _truncated_label, \ + _as_truncated, TextClause, _literal_as_text,\ + ColumnElement, _find_columns, quoted_name +from .selectable import TableClause +import collections +import sqlalchemy +from . import ddl +import types + +RETAIN_SCHEMA = util.symbol('retain_schema') + + +def _get_table_key(name, schema): + if schema is None: + return name + else: + return schema + "." + name + + + +@inspection._self_inspects +class SchemaItem(SchemaEventTarget, visitors.Visitable): + """Base class for items that define a database schema.""" + + __visit_name__ = 'schema_item' + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_default(self, multiparams, params) + + def _init_items(self, *args): + """Initialize the list of child items for this SchemaItem.""" + + for item in args: + if item is not None: + item._set_parent_with_dispatch(self) + + def get_children(self, **kwargs): + """used to allow SchemaVisitor access""" + return [] + + def __repr__(self): + return util.generic_repr(self) + + @property + @util.deprecated('0.9', 'Use ``.name.quote``') + def quote(self): + """Return the value of the ``quote`` flag passed + to this schema object, for those schema items which + have a ``name`` field. + + """ + + return self.name.quote + + @util.memoized_property + def info(self): + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.SchemaItem`. + + The dictionary is automatically generated when first accessed. + It can also be specified in the constructor of some objects, + such as :class:`.Table` and :class:`.Column`. + + """ + return {} + + def _schema_item_copy(self, schema_item): + if 'info' in self.__dict__: + schema_item.info = self.info.copy() + schema_item.dispatch._update(self.dispatch) + return schema_item + + +class Table(DialectKWArgs, SchemaItem, TableClause): + """Represent a table in a database. + + e.g.:: + + mytable = Table("mytable", metadata, + Column('mytable_id', Integer, primary_key=True), + Column('value', String(50)) + ) + + The :class:`.Table` object constructs a unique instance of itself based + on its name and optional schema name within the given + :class:`.MetaData` object. Calling the :class:`.Table` + constructor with the same name and same :class:`.MetaData` argument + a second time will return the *same* :class:`.Table` object - in this way + the :class:`.Table` constructor acts as a registry function. + + .. seealso:: + + :ref:`metadata_describing` - Introduction to database metadata + + Constructor arguments are as follows: + + :param name: The name of this table as represented in the database. + + The table name, along with the value of the ``schema`` parameter, + forms a key which uniquely identifies this :class:`.Table` within + the owning :class:`.MetaData` collection. + Additional calls to :class:`.Table` with the same name, metadata, + and schema name will return the same :class:`.Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word or contain special characters. + A name with any number of upper case characters is considered + to be case sensitive, and will be sent as quoted. + + To enable unconditional quoting for the table name, specify the flag + ``quote=True`` to the constructor, or use the :class:`.quoted_name` + construct to specify the name. + + :param metadata: a :class:`.MetaData` object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`.Connectable`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`.Column` objects contained within this + table. Similar to the style of a CREATE TABLE statement, other + :class:`.SchemaItem` constructs may be added here, including + :class:`.PrimaryKeyConstraint`, and :class:`.ForeignKeyConstraint`. + + :param autoload: Defaults to False: the Columns for this table should + be reflected from the database. Usually there will be no Column + objects in the constructor if this property is set. + + :param autoload_replace: If ``True``, when using ``autoload=True`` + and ``extend_existing=True``, + replace ``Column`` objects already present in the ``Table`` that's + in the ``MetaData`` registry with + what's reflected. Otherwise, all existing columns will be + excluded from the reflection process. Note that this does + not impact ``Column`` objects specified in the same call to ``Table`` + which includes ``autoload``, those always take precedence. + Defaults to ``True``. + + .. versionadded:: 0.7.5 + + :param autoload_with: If autoload==True, this is an optional Engine + or Connection instance to be used for the table reflection. If + ``None``, the underlying MetaData's bound connectable will be used. + + :param extend_existing: When ``True``, indicates that if this + :class:`.Table` is already present in the given :class:`.MetaData`, + apply further arguments within the constructor to the existing + :class:`.Table`. + + If ``extend_existing`` or ``keep_existing`` are not set, an error is + raised if additional table modifiers are specified when + the given :class:`.Table` is already present in the :class:`.MetaData`. + + .. versionchanged:: 0.7.4 + ``extend_existing`` will work in conjunction + with ``autoload=True`` to run a new reflection operation against + the database; new :class:`.Column` objects will be produced + from database metadata to replace those existing with the same + name, and additional :class:`.Column` objects not present + in the :class:`.Table` will be added. + + As is always the case with ``autoload=True``, :class:`.Column` + objects can be specified in the same :class:`.Table` constructor, + which will take precedence. I.e.:: + + Table("mytable", metadata, + Column('y', Integer), + extend_existing=True, + autoload=True, + autoload_with=engine + ) + + The above will overwrite all columns within ``mytable`` which + are present in the database, except for ``y`` which will be used as is + from the above definition. If the ``autoload_replace`` flag + is set to False, no existing columns will be replaced. + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + create_engine() also provides an implicit_returning flag. + + :param include_columns: A list of strings indicating a subset of + columns to be loaded via the ``autoload`` operation; table columns who + aren't present in this list will not be represented on the resulting + ``Table`` object. Defaults to ``None`` which indicates all columns + should be reflected. + + :param info: Optional data dictionary which will be populated into the + :attr:`.SchemaItem.info` attribute of this object. + + :param keep_existing: When ``True``, indicates that if this Table + is already present in the given :class:`.MetaData`, ignore + further arguments within the constructor to the existing + :class:`.Table`, and return the :class:`.Table` object as + originally created. This is to allow a function that wishes + to define a new :class:`.Table` on first call, but on + subsequent calls will return the same :class:`.Table`, + without any of the declarations (particularly constraints) + being applied a second time. Also see extend_existing. + + If extend_existing or keep_existing are not set, an error is + raised if additional table modifiers are specified when + the given :class:`.Table` is already present in the :class:`.MetaData`. + + :param listeners: A list of tuples of the form ``(, )`` + which will be passed to :func:`.event.listen` upon construction. + This alternate hook to :func:`.event.listen` allows the establishment + of a listener function specific to this :class:`.Table` before + the "autoload" process begins. Particularly useful for + the :meth:`.DDLEvents.column_reflect` event:: + + def listen_for_reflect(table, column_info): + "handle the column reflection event" + # ... + + t = Table( + 'sometable', + autoload=True, + listeners=[ + ('column_reflect', listen_for_reflect) + ]) + + :param mustexist: When ``True``, indicates that this Table must already + be present in the given :class:`.MetaData` collection, else + an exception is raised. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The schema name for this table, which is required if + the table resides in a schema other than the default selected schema + for the engine's database connection. Defaults to ``None``. + + The quoting rules for the schema name are the same as those for the + ``name`` parameter, in that quoting is applied for reserved words or + case-sensitive names; to enable unconditional quoting for the + schema name, specify the flag + ``quote_schema=True`` to the constructor, or use the :class:`.quoted_name` + construct to specify the name. + + :param useexisting: Deprecated. Use extend_existing. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ + + __visit_name__ = 'table' + + def __new__(cls, *args, **kw): + if not args: + # python3k pickle seems to call this + return object.__new__(cls) + + try: + name, metadata, args = args[0], args[1], args[2:] + except IndexError: + raise TypeError("Table() takes at least two arguments") + + schema = kw.get('schema', None) + if schema is None: + schema = metadata.schema + keep_existing = kw.pop('keep_existing', False) + extend_existing = kw.pop('extend_existing', False) + if 'useexisting' in kw: + msg = "useexisting is deprecated. Use extend_existing." + util.warn_deprecated(msg) + if extend_existing: + msg = "useexisting is synonymous with extend_existing." + raise exc.ArgumentError(msg) + extend_existing = kw.pop('useexisting', False) + + if keep_existing and extend_existing: + msg = "keep_existing and extend_existing are mutually exclusive." + raise exc.ArgumentError(msg) + + mustexist = kw.pop('mustexist', False) + key = _get_table_key(name, schema) + if key in metadata.tables: + if not keep_existing and not extend_existing and bool(args): + raise exc.InvalidRequestError( + "Table '%s' is already defined for this MetaData " + "instance. Specify 'extend_existing=True' " + "to redefine " + "options and columns on an " + "existing Table object." % key) + table = metadata.tables[key] + if extend_existing: + table._init_existing(*args, **kw) + return table + else: + if mustexist: + raise exc.InvalidRequestError( + "Table '%s' not defined" % (key)) + table = object.__new__(cls) + table.dispatch.before_parent_attach(table, metadata) + metadata._add_table(name, schema, table) + try: + table._init(name, metadata, *args, **kw) + table.dispatch.after_parent_attach(table, metadata) + return table + except: + metadata._remove_table(name, schema) + raise + + + @property + @util.deprecated('0.9', 'Use ``table.schema.quote``') + def quote_schema(self): + """Return the value of the ``quote_schema`` flag passed + to this :class:`.Table`. + """ + + return self.schema.quote + + def __init__(self, *args, **kw): + """Constructor for :class:`~.schema.Table`. + + This method is a no-op. See the top-level + documentation for :class:`~.schema.Table` + for constructor arguments. + + """ + # __init__ is overridden to prevent __new__ from + # calling the superclass constructor. + + def _init(self, name, metadata, *args, **kwargs): + super(Table, self).__init__(quoted_name(name, kwargs.pop('quote', None))) + self.metadata = metadata + + self.schema = kwargs.pop('schema', None) + if self.schema is None: + self.schema = metadata.schema + else: + quote_schema = kwargs.pop('quote_schema', None) + self.schema = quoted_name(self.schema, quote_schema) + + self.indexes = set() + self.constraints = set() + self._columns = ColumnCollection() + PrimaryKeyConstraint()._set_parent_with_dispatch(self) + self.foreign_keys = set() + self._extra_dependencies = set() + if self.schema is not None: + self.fullname = "%s.%s" % (self.schema, self.name) + else: + self.fullname = self.name + + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + # this argument is only used with _init_existing() + kwargs.pop('autoload_replace', True) + include_columns = kwargs.pop('include_columns', None) + + self.implicit_returning = kwargs.pop('implicit_returning', True) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + if 'listeners' in kwargs: + listeners = kwargs.pop('listeners') + for evt, fn in listeners: + event.listen(self, evt, fn) + + self._prefixes = kwargs.pop('prefixes', []) + + self._extra_kwargs(**kwargs) + + # load column definitions from the database if 'autoload' is defined + # we do it after the table is in the singleton dictionary to support + # circular foreign keys + if autoload: + self._autoload(metadata, autoload_with, include_columns) + + # initialize all the column, etc. objects. done after reflection to + # allow user-overrides + self._init_items(*args) + + def _autoload(self, metadata, autoload_with, include_columns, + exclude_columns=()): + + if autoload_with: + autoload_with.run_callable( + autoload_with.dialect.reflecttable, + self, include_columns, exclude_columns + ) + else: + bind = _bind_or_error(metadata, + msg="No engine is bound to this Table's MetaData. " + "Pass an engine to the Table via " + "autoload_with=, " + "or associate the MetaData with an engine via " + "metadata.bind=") + bind.run_callable( + bind.dialect.reflecttable, + self, include_columns, exclude_columns + ) + + @property + def _sorted_constraints(self): + """Return the set of constraints as a list, sorted by creation + order. + + """ + return sorted(self.constraints, key=lambda c: c._creation_order) + + def _init_existing(self, *args, **kwargs): + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + autoload_replace = kwargs.pop('autoload_replace', True) + schema = kwargs.pop('schema', None) + if schema and schema != self.schema: + raise exc.ArgumentError( + "Can't change schema of existing table from '%s' to '%s'", + (self.schema, schema)) + + include_columns = kwargs.pop('include_columns', None) + + if include_columns is not None: + for c in self.c: + if c.name not in include_columns: + self._columns.remove(c) + + for key in ('quote', 'quote_schema'): + if key in kwargs: + raise exc.ArgumentError( + "Can't redefine 'quote' or 'quote_schema' arguments") + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + if autoload: + if not autoload_replace: + exclude_columns = [c.name for c in self.c] + else: + exclude_columns = () + self._autoload( + self.metadata, autoload_with, include_columns, exclude_columns) + + self._extra_kwargs(**kwargs) + self._init_items(*args) + + def _extra_kwargs(self, **kwargs): + self._validate_dialect_kwargs(kwargs) + + def _init_collections(self): + pass + + @util.memoized_property + def _autoincrement_column(self): + for col in self.primary_key: + if col.autoincrement and \ + col.type._type_affinity is not None and \ + issubclass(col.type._type_affinity, type_api.INTEGERTYPE._type_affinity) and \ + (not col.foreign_keys or col.autoincrement == 'ignore_fk') and \ + isinstance(col.default, (type(None), Sequence)) and \ + (col.server_default is None or col.server_default.reflected): + return col + + @property + def key(self): + """Return the 'key' for this :class:`.Table`. + + This value is used as the dictionary key within the + :attr:`.MetaData.tables` collection. It is typically the same + as that of :attr:`.Table.name` for a table with no :attr:`.Table.schema` + set; otherwise it is typically of the form ``schemaname.tablename``. + + """ + return _get_table_key(self.name, self.schema) + + def __repr__(self): + return "Table(%s)" % ', '.join( + [repr(self.name)] + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) + + def __str__(self): + return _get_table_key(self.description, self.schema) + + @property + def bind(self): + """Return the connectable associated with this Table.""" + + return self.metadata and self.metadata.bind or None + + def add_is_dependent_on(self, table): + """Add a 'dependency' for this Table. + + This is another Table object which must be created + first before this one can, or dropped after this one. + + Usually, dependencies between tables are determined via + ForeignKey objects. However, for other situations that + create dependencies outside of foreign keys (rules, inheriting), + this method can manually establish such a link. + + """ + self._extra_dependencies.add(table) + + def append_column(self, column): + """Append a :class:`~.schema.Column` to this :class:`~.schema.Table`. + + The "key" of the newly added :class:`~.schema.Column`, i.e. the + value of its ``.key`` attribute, will then be available + in the ``.c`` collection of this :class:`~.schema.Table`, and the + column definition will be included in any CREATE TABLE, SELECT, + UPDATE, etc. statements generated from this :class:`~.schema.Table` + construct. + + Note that this does **not** change the definition of the table + as it exists within any underlying database, assuming that + table has already been created in the database. Relational + databases support the addition of columns to existing tables + using the SQL ALTER command, which would need to be + emitted for an already-existing table that doesn't contain + the newly added column. + + """ + + column._set_parent_with_dispatch(self) + + def append_constraint(self, constraint): + """Append a :class:`~.schema.Constraint` to this + :class:`~.schema.Table`. + + This has the effect of the constraint being included in any + future CREATE TABLE statement, assuming specific DDL creation + events have not been associated with the given + :class:`~.schema.Constraint` object. + + Note that this does **not** produce the constraint within the + relational database automatically, for a table that already exists + in the database. To add a constraint to an + existing relational database table, the SQL ALTER command must + be used. SQLAlchemy also provides the + :class:`.AddConstraint` construct which can produce this SQL when + invoked as an executable clause. + + """ + + constraint._set_parent_with_dispatch(self) + + def append_ddl_listener(self, event_name, listener): + """Append a DDL event listener to this ``Table``. + + .. deprecated:: 0.7 + See :class:`.DDLEvents`. + + """ + + def adapt_listener(target, connection, **kw): + listener(event_name, target, connection) + + event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + + def _set_parent(self, metadata): + metadata._add_table(self.name, self.schema, self) + self.metadata = metadata + + def get_children(self, column_collections=True, + schema_visitor=False, **kw): + if not schema_visitor: + return TableClause.get_children( + self, column_collections=column_collections, **kw) + else: + if column_collections: + return list(self.columns) + else: + return [] + + def exists(self, bind=None): + """Return True if this table exists.""" + + if bind is None: + bind = _bind_or_error(self) + + return bind.run_callable(bind.dialect.has_table, + self.name, schema=self.schema) + + def create(self, bind=None, checkfirst=False): + """Issue a ``CREATE`` statement for this + :class:`.Table`, using the given :class:`.Connectable` + for connectivity. + + .. seealso:: + + :meth:`.MetaData.create_all`. + + """ + + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaGenerator, + self, + checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=False): + """Issue a ``DROP`` statement for this + :class:`.Table`, using the given :class:`.Connectable` + for connectivity. + + .. seealso:: + + :meth:`.MetaData.drop_all`. + + """ + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaDropper, + self, + checkfirst=checkfirst) + + def tometadata(self, metadata, schema=RETAIN_SCHEMA, referred_schema_fn=None): + """Return a copy of this :class:`.Table` associated with a different + :class:`.MetaData`. + + E.g.:: + + m1 = MetaData() + + user = Table('user', m1, Column('id', Integer, priamry_key=True)) + + m2 = MetaData() + user_copy = user.tometadata(m2) + + :param metadata: Target :class:`.MetaData` object, into which the + new :class:`.Table` object will be created. + + :param schema: optional string name indicating the target schema. + Defaults to the special symbol :attr:`.RETAIN_SCHEMA` which indicates + that no change to the schema name should be made in the new + :class:`.Table`. If set to a string name, the new :class:`.Table` + will have this new name as the ``.schema``. If set to ``None``, the + schema will be set to that of the schema set on the target + :class:`.MetaData`, which is typically ``None`` as well, unless + set explicitly:: + + m2 = MetaData(schema='newschema') + + # user_copy_one will have "newschema" as the schema name + user_copy_one = user.tometadata(m2, schema=None) + + m3 = MetaData() # schema defaults to None + + # user_copy_two will have None as the schema name + user_copy_two = user.tometadata(m3, schema=None) + + :param referred_schema_fn: optional callable which can be supplied + in order to provide for the schema name that should be assigned + to the referenced table of a :class:`.ForeignKeyConstraint`. + The callable accepts this parent :class:`.Table`, the + target schema that we are changing to, the :class:`.ForeignKeyConstraint` + object, and the existing "target schema" of that constraint. The + function should return the string schema name that should be applied. + E.g.:: + + def referred_schema_fn(table, to_schema, + constraint, referred_schema): + if referred_schema == 'base_tables': + return referred_schema + else: + return to_schema + + new_table = table.tometadata(m2, schema="alt_schema", + referred_schema_fn=referred_schema_fn) + + .. versionadded:: 0.9.2 + + """ + + if schema is RETAIN_SCHEMA: + schema = self.schema + elif schema is None: + schema = metadata.schema + key = _get_table_key(self.name, schema) + if key in metadata.tables: + util.warn("Table '%s' already exists within the given " + "MetaData - not copying." % self.description) + return metadata.tables[key] + + args = [] + for c in self.columns: + args.append(c.copy(schema=schema)) + table = Table( + self.name, metadata, schema=schema, + *args, **self.kwargs + ) + for c in self.constraints: + if isinstance(c, ForeignKeyConstraint): + referred_schema = c._referred_schema + if referred_schema_fn: + fk_constraint_schema = referred_schema_fn(self, schema, c, referred_schema) + else: + fk_constraint_schema = schema if referred_schema == self.schema else None + table.append_constraint(c.copy(schema=fk_constraint_schema, target_table=table)) + + else: + table.append_constraint(c.copy(schema=schema, target_table=table)) + for index in self.indexes: + # skip indexes that would be generated + # by the 'index' flag on Column + if len(index.columns) == 1 and \ + list(index.columns)[0].index: + continue + Index(index.name, + unique=index.unique, + *[table.c[col] for col in index.columns.keys()], + **index.kwargs) + return self._schema_item_copy(table) + + +class Column(SchemaItem, ColumnClause): + """Represents a column in a database table.""" + + __visit_name__ = 'column' + + def __init__(self, *args, **kwargs): + """ + Construct a new ``Column`` object. + + :param name: The name of this column as represented in the database. + This argument may be the first positional argument, or specified + via keyword. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word. Names with any number of upper + case characters will be quoted and sent exactly. Note that this + behavior applies even for databases which standardize upper + case names as case insensitive such as Oracle. + + The name field may be omitted at construction time and applied + later, at any time before the Column is associated with a + :class:`.Table`. This is to support convenient + usage within the :mod:`~sqlalchemy.ext.declarative` extension. + + :param type\_: The column's type, indicated using an instance which + subclasses :class:`~sqlalchemy.types.TypeEngine`. If no arguments + are required for the type, the class of the type can be sent + as well, e.g.:: + + # use a type with arguments + Column('data', String(50)) + + # use no arguments + Column('level', Integer) + + The ``type`` argument may be the second positional argument + or specified by keyword. + + If the ``type`` is ``None`` or is omitted, it will first default to the special + type :class:`.NullType`. If and when this :class:`.Column` is + made to refer to another column using :class:`.ForeignKey` + and/or :class:`.ForeignKeyConstraint`, the type of the remote-referenced + column will be copied to this column as well, at the moment that + the foreign key is resolved against that remote :class:`.Column` + object. + + .. versionchanged:: 0.9.0 + Support for propagation of type to a :class:`.Column` from its + :class:`.ForeignKey` object has been improved and should be + more reliable and timely. + + :param \*args: Additional positional arguments include various + :class:`.SchemaItem` derived constructs which will be applied + as options to the column. These include instances of + :class:`.Constraint`, :class:`.ForeignKey`, :class:`.ColumnDefault`, + and :class:`.Sequence`. In some cases an equivalent keyword + argument is available such as ``server_default``, ``default`` + and ``unique``. + + :param autoincrement: This flag may be set to ``False`` to + indicate an integer primary key column that should not be + considered to be the "autoincrement" column, that is + the integer primary key column which generates values + implicitly upon INSERT and whose value is usually returned + via the DBAPI cursor.lastrowid attribute. It defaults + to ``True`` to satisfy the common use case of a table + with a single integer primary key column. If the table + has a composite primary key consisting of more than one + integer column, set this flag to True only on the + column that should be considered "autoincrement". + + The setting *only* has an effect for columns which are: + + * Integer derived (i.e. INT, SMALLINT, BIGINT). + + * Part of the primary key + + * Are not referenced by any foreign keys, unless + the value is specified as ``'ignore_fk'`` + + .. versionadded:: 0.7.4 + + * have no server side or client side defaults (with the exception + of Postgresql SERIAL). + + The setting has these two effects on columns that meet the + above criteria: + + * DDL issued for the column will include database-specific + keywords intended to signify this column as an + "autoincrement" column, such as AUTO INCREMENT on MySQL, + SERIAL on Postgresql, and IDENTITY on MS-SQL. It does + *not* issue AUTOINCREMENT for SQLite since this is a + special SQLite flag that is not required for autoincrementing + behavior. See the SQLite dialect documentation for + information on SQLite's AUTOINCREMENT. + + * The column will be considered to be available as + cursor.lastrowid or equivalent, for those dialects which + "post fetch" newly inserted identifiers after a row has + been inserted (SQLite, MySQL, MS-SQL). It does not have + any effect in this regard for databases that use sequences + to generate primary key identifiers (i.e. Firebird, Postgresql, + Oracle). + + .. versionchanged:: 0.7.4 + ``autoincrement`` accepts a special value ``'ignore_fk'`` + to indicate that autoincrementing status regardless of foreign + key references. This applies to certain composite foreign key + setups, such as the one demonstrated in the ORM documentation + at :ref:`post_update`. + + :param default: A scalar, Python callable, or + :class:`.ColumnElement` expression representing the + *default value* for this column, which will be invoked upon insert + if this column is otherwise not specified in the VALUES clause of + the insert. This is a shortcut to using :class:`.ColumnDefault` as + a positional argument; see that class for full detail on the + structure of the argument. + + Contrast this argument to ``server_default`` which creates a + default generator on the database side. + + :param doc: optional String that can be used by the ORM or similar + to document attributes. This attribute does not render SQL + comments (a future attribute 'comment' will achieve that). + + :param key: An optional string identifier which will identify this + ``Column`` object on the :class:`.Table`. When a key is provided, + this is the only identifier referencing the ``Column`` within the + application, including ORM attribute mapping; the ``name`` field + is used only when rendering SQL. + + :param index: When ``True``, indicates that the column is indexed. + This is a shortcut for using a :class:`.Index` construct on the + table. To specify indexes with explicit names or indexes that + contain multiple columns, use the :class:`.Index` construct + instead. + + :param info: Optional data dictionary which will be populated into the + :attr:`.SchemaItem.info` attribute of this object. + + :param nullable: If set to the default of ``True``, indicates the + column will be rendered as allowing NULL, else it's rendered as + NOT NULL. This parameter is only used when issuing CREATE TABLE + statements. + + :param onupdate: A scalar, Python callable, or + :class:`~sqlalchemy.sql.expression.ClauseElement` representing a + default value to be applied to the column within UPDATE + statements, which wil be invoked upon update if this column is not + present in the SET clause of the update. This is a shortcut to + using :class:`.ColumnDefault` as a positional argument with + ``for_update=True``. + + :param primary_key: If ``True``, marks this column as a primary key + column. Multiple columns can have this flag set to specify + composite primary keys. As an alternative, the primary key of a + :class:`.Table` can be specified via an explicit + :class:`.PrimaryKeyConstraint` object. + + :param server_default: A :class:`.FetchedValue` instance, str, Unicode + or :func:`~sqlalchemy.sql.expression.text` construct representing + the DDL DEFAULT value for the column. + + String types will be emitted as-is, surrounded by single quotes:: + + Column('x', Text, server_default="val") + + x TEXT DEFAULT 'val' + + A :func:`~sqlalchemy.sql.expression.text` expression will be + rendered as-is, without quotes:: + + Column('y', DateTime, server_default=text('NOW()')) + + y DATETIME DEFAULT NOW() + + Strings and text() will be converted into a :class:`.DefaultClause` + object upon initialization. + + Use :class:`.FetchedValue` to indicate that an already-existing + column will generate a default value on the database side which + will be available to SQLAlchemy for post-fetch after inserts. This + construct does not specify any DDL and the implementation is left + to the database, such as via a trigger. + + :param server_onupdate: A :class:`.FetchedValue` instance + representing a database-side default generation function. This + indicates to SQLAlchemy that a newly generated value will be + available after updates. This construct does not specify any DDL + and the implementation is left to the database, such as via a + trigger. + + :param quote: Force quoting of this column's name on or off, + corresponding to ``True`` or ``False``. When left at its default + of ``None``, the column identifier will be quoted according to + whether the name is case sensitive (identifiers with at least one + upper case character are treated as case sensitive), or if it's a + reserved word. This flag is only needed to force quoting of a + reserved word which is not known by the SQLAlchemy dialect. + + :param unique: When ``True``, indicates that this column contains a + unique constraint, or if ``index`` is ``True`` as well, indicates + that the :class:`.Index` should be created with the unique flag. + To specify multiple columns in the constraint/index or to specify + an explicit name, use the :class:`.UniqueConstraint` or + :class:`.Index` constructs explicitly. + + :param system: When ``True``, indicates this is a "system" column, + that is a column which is automatically made available by the + database, and should not be included in the columns list for a + ``CREATE TABLE`` statement. + + For more elaborate scenarios where columns should be conditionally + rendered differently on different backends, consider custom + compilation rules for :class:`.CreateColumn`. + + ..versionadded:: 0.8.3 Added the ``system=True`` parameter to + :class:`.Column`. + + """ + + name = kwargs.pop('name', None) + type_ = kwargs.pop('type_', None) + args = list(args) + if args: + if isinstance(args[0], util.string_types): + if name is not None: + raise exc.ArgumentError( + "May not pass name positionally and as a keyword.") + name = args.pop(0) + if args: + coltype = args[0] + + if hasattr(coltype, "_sqla_type"): + if type_ is not None: + raise exc.ArgumentError( + "May not pass type_ positionally and as a keyword.") + type_ = args.pop(0) + + if name is not None: + name = quoted_name(name, kwargs.pop('quote', None)) + elif "quote" in kwargs: + raise exc.ArgumentError("Explicit 'name' is required when " + "sending 'quote' argument") + + super(Column, self).__init__(name, type_) + self.key = kwargs.pop('key', name) + self.primary_key = kwargs.pop('primary_key', False) + self.nullable = kwargs.pop('nullable', not self.primary_key) + self.default = kwargs.pop('default', None) + self.server_default = kwargs.pop('server_default', None) + self.server_onupdate = kwargs.pop('server_onupdate', None) + + # these default to None because .index and .unique is *not* + # an informational flag about Column - there can still be an + # Index or UniqueConstraint referring to this Column. + self.index = kwargs.pop('index', None) + self.unique = kwargs.pop('unique', None) + + self.system = kwargs.pop('system', False) + self.doc = kwargs.pop('doc', None) + self.onupdate = kwargs.pop('onupdate', None) + self.autoincrement = kwargs.pop('autoincrement', True) + self.constraints = set() + self.foreign_keys = set() + + # check if this Column is proxying another column + if '_proxies' in kwargs: + self._proxies = kwargs.pop('_proxies') + # otherwise, add DDL-related events + elif isinstance(self.type, SchemaEventTarget): + self.type._set_parent_with_dispatch(self) + + if self.default is not None: + if isinstance(self.default, (ColumnDefault, Sequence)): + args.append(self.default) + else: + if getattr(self.type, '_warn_on_bytestring', False): + if isinstance(self.default, util.binary_type): + util.warn("Unicode column received non-unicode " + "default value.") + args.append(ColumnDefault(self.default)) + + if self.server_default is not None: + if isinstance(self.server_default, FetchedValue): + args.append(self.server_default._as_for_update(False)) + else: + args.append(DefaultClause(self.server_default)) + + if self.onupdate is not None: + if isinstance(self.onupdate, (ColumnDefault, Sequence)): + args.append(self.onupdate) + else: + args.append(ColumnDefault(self.onupdate, for_update=True)) + + if self.server_onupdate is not None: + if isinstance(self.server_onupdate, FetchedValue): + args.append(self.server_onupdate._as_for_update(True)) + else: + args.append(DefaultClause(self.server_onupdate, + for_update=True)) + self._init_items(*args) + + util.set_creation_order(self) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + if kwargs: + raise exc.ArgumentError( + "Unknown arguments passed to Column: " + repr(list(kwargs))) + +# @property +# def quote(self): +# return getattr(self.name, "quote", None) + + def __str__(self): + if self.name is None: + return "(no name)" + elif self.table is not None: + if self.table.named_with_column: + return (self.table.description + "." + self.description) + else: + return self.description + else: + return self.description + + def references(self, column): + """Return True if this Column references the given column via foreign + key.""" + + for fk in self.foreign_keys: + if fk.column.proxy_set.intersection(column.proxy_set): + return True + else: + return False + + def append_foreign_key(self, fk): + fk._set_parent_with_dispatch(self) + + def __repr__(self): + kwarg = [] + if self.key != self.name: + kwarg.append('key') + if self.primary_key: + kwarg.append('primary_key') + if not self.nullable: + kwarg.append('nullable') + if self.onupdate: + kwarg.append('onupdate') + if self.default: + kwarg.append('default') + if self.server_default: + kwarg.append('server_default') + return "Column(%s)" % ', '.join( + [repr(self.name)] + [repr(self.type)] + + [repr(x) for x in self.foreign_keys if x is not None] + + [repr(x) for x in self.constraints] + + [(self.table is not None and "table=<%s>" % + self.table.description or "table=None")] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) + + def _set_parent(self, table): + if not self.name: + raise exc.ArgumentError( + "Column must be constructed with a non-blank name or " + "assign a non-blank .name before adding to a Table.") + if self.key is None: + self.key = self.name + + existing = getattr(self, 'table', None) + if existing is not None and existing is not table: + raise exc.ArgumentError( + "Column object already assigned to Table '%s'" % + existing.description) + + if self.key in table._columns: + col = table._columns.get(self.key) + if col is not self: + for fk in col.foreign_keys: + table.foreign_keys.remove(fk) + if fk.constraint in table.constraints: + # this might have been removed + # already, if it's a composite constraint + # and more than one col being replaced + table.constraints.remove(fk.constraint) + + table._columns.replace(self) + + if self.primary_key: + table.primary_key._replace(self) + Table._autoincrement_column._reset(table) + elif self.key in table.primary_key: + raise exc.ArgumentError( + "Trying to redefine primary-key column '%s' as a " + "non-primary-key column on table '%s'" % ( + self.key, table.fullname)) + self.table = table + + if self.index: + if isinstance(self.index, util.string_types): + raise exc.ArgumentError( + "The 'index' keyword argument on Column is boolean only. " + "To create indexes with a specific name, create an " + "explicit Index object external to the Table.") + Index(None, self, unique=bool(self.unique)) + elif self.unique: + if isinstance(self.unique, util.string_types): + raise exc.ArgumentError( + "The 'unique' keyword argument on Column is boolean " + "only. To create unique constraints or indexes with a " + "specific name, append an explicit UniqueConstraint to " + "the Table's list of elements, or create an explicit " + "Index object external to the Table.") + table.append_constraint(UniqueConstraint(self.key)) + + fk_key = (table.key, self.key) + if fk_key in self.table.metadata._fk_memos: + for fk in self.table.metadata._fk_memos[fk_key]: + fk._set_remote_table(table) + + def _on_table_attach(self, fn): + if self.table is not None: + fn(self, self.table) + event.listen(self, 'after_parent_attach', fn) + + def copy(self, **kw): + """Create a copy of this ``Column``, unitialized. + + This is used in ``Table.tometadata``. + + """ + + # Constraint objects plus non-constraint-bound ForeignKey objects + args = \ + [c.copy(**kw) for c in self.constraints] + \ + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + + type_ = self.type + if isinstance(type_, SchemaEventTarget): + type_ = type_.copy(**kw) + + c = self._constructor( + name=self.name, + type_=type_, + key=self.key, + primary_key=self.primary_key, + nullable=self.nullable, + unique=self.unique, + system=self.system, + #quote=self.quote, + index=self.index, + autoincrement=self.autoincrement, + default=self.default, + server_default=self.server_default, + onupdate=self.onupdate, + server_onupdate=self.server_onupdate, + doc=self.doc, + *args + ) + return self._schema_item_copy(c) + + def _make_proxy(self, selectable, name=None, key=None, + name_is_truncatable=False, **kw): + """Create a *proxy* for this column. + + This is a copy of this ``Column`` referenced by a different parent + (such as an alias or select statement). The column should + be used only in select scenarios, as its full DDL/default + information is not transferred. + + """ + fk = [ForeignKey(f.column, _constraint=f.constraint) + for f in self.foreign_keys] + if name is None and self.name is None: + raise exc.InvalidRequestError("Cannot initialize a sub-selectable" + " with this Column object until it's 'name' has " + "been assigned.") + try: + c = self._constructor( + _as_truncated(name or self.name) if \ + name_is_truncatable else (name or self.name), + self.type, + key=key if key else name if name else self.key, + primary_key=self.primary_key, + nullable=self.nullable, + _proxies=[self], *fk) + except TypeError: + util.raise_from_cause( + TypeError( + "Could not create a copy of this %r object. " + "Ensure the class includes a _constructor() " + "attribute or method which accepts the " + "standard Column constructor arguments, or " + "references the Column class itself." % self.__class__) + ) + + c.table = selectable + selectable._columns.add(c) + if selectable._is_clone_of is not None: + c._is_clone_of = selectable._is_clone_of.columns[c.key] + if self.primary_key: + selectable.primary_key.add(c) + c.dispatch.after_parent_attach(c, selectable) + return c + + def get_children(self, schema_visitor=False, **kwargs): + if schema_visitor: + return [x for x in (self.default, self.onupdate) + if x is not None] + \ + list(self.foreign_keys) + list(self.constraints) + else: + return ColumnClause.get_children(self, **kwargs) + + +class ForeignKey(DialectKWArgs, SchemaItem): + """Defines a dependency between two columns. + + ``ForeignKey`` is specified as an argument to a :class:`.Column` object, + e.g.:: + + t = Table("remote_table", metadata, + Column("remote_id", ForeignKey("main_table.id")) + ) + + Note that ``ForeignKey`` is only a marker object that defines + a dependency between two columns. The actual constraint + is in all cases represented by the :class:`.ForeignKeyConstraint` + object. This object will be generated automatically when + a ``ForeignKey`` is associated with a :class:`.Column` which + in turn is associated with a :class:`.Table`. Conversely, + when :class:`.ForeignKeyConstraint` is applied to a :class:`.Table`, + ``ForeignKey`` markers are automatically generated to be + present on each associated :class:`.Column`, which are also + associated with the constraint object. + + Note that you cannot define a "composite" foreign key constraint, + that is a constraint between a grouping of multiple parent/child + columns, using ``ForeignKey`` objects. To define this grouping, + the :class:`.ForeignKeyConstraint` object must be used, and applied + to the :class:`.Table`. The associated ``ForeignKey`` objects + are created automatically. + + The ``ForeignKey`` objects associated with an individual + :class:`.Column` object are available in the `foreign_keys` collection + of that column. + + Further examples of foreign key configuration are in + :ref:`metadata_foreignkeys`. + + """ + + __visit_name__ = 'foreign_key' + + def __init__(self, column, _constraint=None, use_alter=False, name=None, + onupdate=None, ondelete=None, deferrable=None, + initially=None, link_to_name=False, match=None, + **dialect_kw): + """ + Construct a column-level FOREIGN KEY. + + The :class:`.ForeignKey` object when constructed generates a + :class:`.ForeignKeyConstraint` which is associated with the parent + :class:`.Table` object's collection of constraints. + + :param column: A single target column for the key relationship. A + :class:`.Column` object or a column name as a string: + ``tablename.columnkey`` or ``schema.tablename.columnkey``. + ``columnkey`` is the ``key`` which has been assigned to the column + (defaults to the column name itself), unless ``link_to_name`` is + ``True`` in which case the rendered name of the column is used. + + .. versionadded:: 0.7.4 + Note that if the schema name is not included, and the + underlying :class:`.MetaData` has a "schema", that value will + be used. + + :param name: Optional string. An in-database name for the key if + `constraint` is not provided. + + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + + :param initially: Optional string. If set, emit INITIALLY when + issuing DDL for this constraint. + + :param link_to_name: if True, the string name given in ``column`` is + the rendered name of the referenced column, not its locally + assigned ``key``. + + :param use_alter: passed to the underlying + :class:`.ForeignKeyConstraint` to indicate the constraint should be + generated/dropped externally from the CREATE TABLE/ DROP TABLE + statement. See that classes' constructor for details. + + :param match: Optional string. If set, emit MATCH when issuing + DDL for this constraint. Typical values include SIMPLE, PARTIAL + and FULL. + + :param \**dialect_kw: Additional keyword arguments are dialect specific, + and passed in the form ``_``. The arguments + are ultimately handled by a corresponding :class:`.ForeignKeyConstraint`. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + .. versionadded:: 0.9.2 + + """ + + self._colspec = column + if isinstance(self._colspec, util.string_types): + self._table_column = None + else: + if hasattr(self._colspec, '__clause_element__'): + self._table_column = self._colspec.__clause_element__() + else: + self._table_column = self._colspec + + if not isinstance(self._table_column, ColumnClause): + raise exc.ArgumentError( + "String, Column, or Column-bound argument " + "expected, got %r" % self._table_column) + elif not isinstance(self._table_column.table, (util.NoneType, TableClause)): + raise exc.ArgumentError( + "ForeignKey received Column not bound " + "to a Table, got: %r" % self._table_column.table + ) + + # the linked ForeignKeyConstraint. + # ForeignKey will create this when parent Column + # is attached to a Table, *or* ForeignKeyConstraint + # object passes itself in when creating ForeignKey + # markers. + self.constraint = _constraint + self.parent = None + self.use_alter = use_alter + self.name = name + self.onupdate = onupdate + self.ondelete = ondelete + self.deferrable = deferrable + self.initially = initially + self.link_to_name = link_to_name + self.match = match + self._unvalidated_dialect_kw = dialect_kw + + def __repr__(self): + return "ForeignKey(%r)" % self._get_colspec() + + def copy(self, schema=None): + """Produce a copy of this :class:`.ForeignKey` object. + + The new :class:`.ForeignKey` will not be bound + to any :class:`.Column`. + + This method is usually used by the internal + copy procedures of :class:`.Column`, :class:`.Table`, + and :class:`.MetaData`. + + :param schema: The returned :class:`.ForeignKey` will + reference the original table and column name, qualified + by the given string schema name. + + """ + + fk = ForeignKey( + self._get_colspec(schema=schema), + use_alter=self.use_alter, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + deferrable=self.deferrable, + initially=self.initially, + link_to_name=self.link_to_name, + match=self.match, + **self._unvalidated_dialect_kw + ) + return self._schema_item_copy(fk) + + + def _get_colspec(self, schema=None): + """Return a string based 'column specification' for this + :class:`.ForeignKey`. + + This is usually the equivalent of the string-based "tablename.colname" + argument first passed to the object's constructor. + + """ + if schema: + _schema, tname, colname = self._column_tokens + return "%s.%s.%s" % (schema, tname, colname) + elif self._table_column is not None: + return "%s.%s" % ( + self._table_column.table.fullname, self._table_column.key) + else: + return self._colspec + + @property + def _referred_schema(self): + return self._column_tokens[0] + + + def _table_key(self): + if self._table_column is not None: + if self._table_column.table is None: + return None + else: + return self._table_column.table.key + else: + schema, tname, colname = self._column_tokens + return _get_table_key(tname, schema) + + + + target_fullname = property(_get_colspec) + + def references(self, table): + """Return True if the given :class:`.Table` is referenced by this + :class:`.ForeignKey`.""" + + return table.corresponding_column(self.column) is not None + + def get_referent(self, table): + """Return the :class:`.Column` in the given :class:`.Table` + referenced by this :class:`.ForeignKey`. + + Returns None if this :class:`.ForeignKey` does not reference the given + :class:`.Table`. + + """ + + return table.corresponding_column(self.column) + + @util.memoized_property + def _column_tokens(self): + """parse a string-based _colspec into its component parts.""" + + m = self._get_colspec().split('.') + if m is None: + raise exc.ArgumentError( + "Invalid foreign key column specification: %s" % + self._colspec) + if (len(m) == 1): + tname = m.pop() + colname = None + else: + colname = m.pop() + tname = m.pop() + + # A FK between column 'bar' and table 'foo' can be + # specified as 'foo', 'foo.bar', 'dbo.foo.bar', + # 'otherdb.dbo.foo.bar'. Once we have the column name and + # the table name, treat everything else as the schema + # name. Some databases (e.g. Sybase) support + # inter-database foreign keys. See tickets#1341 and -- + # indirectly related -- Ticket #594. This assumes that '.' + # will never appear *within* any component of the FK. + + if (len(m) > 0): + schema = '.'.join(m) + else: + schema = None + return schema, tname, colname + + def _resolve_col_tokens(self): + if self.parent is None: + raise exc.InvalidRequestError( + "this ForeignKey object does not yet have a " + "parent Column associated with it.") + + elif self.parent.table is None: + raise exc.InvalidRequestError( + "this ForeignKey's parent column is not yet associated " + "with a Table.") + + parenttable = self.parent.table + + # assertion, can be commented out. + # basically Column._make_proxy() sends the actual + # target Column to the ForeignKey object, so the + # string resolution here is never called. + for c in self.parent.base_columns: + if isinstance(c, Column): + assert c.table is parenttable + break + else: + assert False + ###################### + + schema, tname, colname = self._column_tokens + + if schema is None and parenttable.metadata.schema is not None: + schema = parenttable.metadata.schema + + tablekey = _get_table_key(tname, schema) + return parenttable, tablekey, colname + + + def _link_to_col_by_colstring(self, parenttable, table, colname): + if not hasattr(self.constraint, '_referred_table'): + self.constraint._referred_table = table + else: + assert self.constraint._referred_table is table + + _column = None + if colname is None: + # colname is None in the case that ForeignKey argument + # was specified as table name only, in which case we + # match the column name to the same column on the + # parent. + key = self.parent + _column = table.c.get(self.parent.key, None) + elif self.link_to_name: + key = colname + for c in table.c: + if c.name == colname: + _column = c + else: + key = colname + _column = table.c.get(colname, None) + + if _column is None: + raise exc.NoReferencedColumnError( + "Could not initialize target column for ForeignKey '%s' on table '%s': " + "table '%s' has no column named '%s'" % ( + self._colspec, parenttable.name, table.name, key), + table.name, key) + + self._set_target_column(_column) + + def _set_target_column(self, column): + # propagate TypeEngine to parent if it didn't have one + if self.parent.type._isnull: + self.parent.type = column.type + + # super-edgy case, if other FKs point to our column, + # they'd get the type propagated out also. + if isinstance(self.parent.table, Table): + fk_key = (self.parent.table.key, self.parent.key) + if fk_key in self.parent.table.metadata._fk_memos: + for fk in self.parent.table.metadata._fk_memos[fk_key]: + if fk.parent.type._isnull: + fk.parent.type = column.type + + self.column = column + + @util.memoized_property + def column(self): + """Return the target :class:`.Column` referenced by this + :class:`.ForeignKey`. + + If no target column has been established, an exception + is raised. + + .. versionchanged:: 0.9.0 + Foreign key target column resolution now occurs as soon as both + the ForeignKey object and the remote Column to which it refers + are both associated with the same MetaData object. + + """ + + if isinstance(self._colspec, util.string_types): + + parenttable, tablekey, colname = self._resolve_col_tokens() + + if tablekey not in parenttable.metadata: + raise exc.NoReferencedTableError( + "Foreign key associated with column '%s' could not find " + "table '%s' with which to generate a " + "foreign key to target column '%s'" % + (self.parent, tablekey, colname), + tablekey) + elif parenttable.key not in parenttable.metadata: + raise exc.InvalidRequestError( + "Table %s is no longer associated with its " + "parent MetaData" % parenttable) + else: + raise exc.NoReferencedColumnError( + "Could not initialize target column for " + "ForeignKey '%s' on table '%s': " + "table '%s' has no column named '%s'" % ( + self._colspec, parenttable.name, tablekey, colname), + tablekey, colname) + elif hasattr(self._colspec, '__clause_element__'): + _column = self._colspec.__clause_element__() + return _column + else: + _column = self._colspec + return _column + + def _set_parent(self, column): + if self.parent is not None and self.parent is not column: + raise exc.InvalidRequestError( + "This ForeignKey already has a parent !") + self.parent = column + self.parent.foreign_keys.add(self) + self.parent._on_table_attach(self._set_table) + + def _set_remote_table(self, table): + parenttable, tablekey, colname = self._resolve_col_tokens() + self._link_to_col_by_colstring(parenttable, table, colname) + self.constraint._validate_dest_table(table) + + def _remove_from_metadata(self, metadata): + parenttable, table_key, colname = self._resolve_col_tokens() + fk_key = (table_key, colname) + + if self in metadata._fk_memos[fk_key]: + # TODO: no test coverage for self not in memos + metadata._fk_memos[fk_key].remove(self) + + def _set_table(self, column, table): + # standalone ForeignKey - create ForeignKeyConstraint + # on the hosting Table when attached to the Table. + if self.constraint is None and isinstance(table, Table): + self.constraint = ForeignKeyConstraint( + [], [], use_alter=self.use_alter, name=self.name, + onupdate=self.onupdate, ondelete=self.ondelete, + deferrable=self.deferrable, initially=self.initially, + match=self.match, + **self._unvalidated_dialect_kw + ) + self.constraint._elements[self.parent] = self + self.constraint._set_parent_with_dispatch(table) + table.foreign_keys.add(self) + + # set up remote ".column" attribute, or a note to pick it + # up when the other Table/Column shows up + if isinstance(self._colspec, util.string_types): + parenttable, table_key, colname = self._resolve_col_tokens() + fk_key = (table_key, colname) + if table_key in parenttable.metadata.tables: + table = parenttable.metadata.tables[table_key] + try: + self._link_to_col_by_colstring(parenttable, table, colname) + except exc.NoReferencedColumnError: + # this is OK, we'll try later + pass + parenttable.metadata._fk_memos[fk_key].append(self) + elif hasattr(self._colspec, '__clause_element__'): + _column = self._colspec.__clause_element__() + self._set_target_column(_column) + else: + _column = self._colspec + self._set_target_column(_column) + + + +class _NotAColumnExpr(object): + def _not_a_column_expr(self): + raise exc.InvalidRequestError( + "This %s cannot be used directly " + "as a column expression." % self.__class__.__name__) + + __clause_element__ = self_group = lambda self: self._not_a_column_expr() + _from_objects = property(lambda self: self._not_a_column_expr()) + + +class DefaultGenerator(_NotAColumnExpr, SchemaItem): + """Base class for column *default* values.""" + + __visit_name__ = 'default_generator' + + is_sequence = False + is_server_default = False + column = None + + def __init__(self, for_update=False): + self.for_update = for_update + + def _set_parent(self, column): + self.column = column + if self.for_update: + self.column.onupdate = self + else: + self.column.default = self + + def execute(self, bind=None, **kwargs): + if bind is None: + bind = _bind_or_error(self) + return bind._execute_default(self, **kwargs) + + @property + def bind(self): + """Return the connectable associated with this default.""" + if getattr(self, 'column', None) is not None: + return self.column.table.bind + else: + return None + + +class ColumnDefault(DefaultGenerator): + """A plain default value on a column. + + This could correspond to a constant, a callable function, + or a SQL clause. + + :class:`.ColumnDefault` is generated automatically + whenever the ``default``, ``onupdate`` arguments of + :class:`.Column` are used. A :class:`.ColumnDefault` + can be passed positionally as well. + + For example, the following:: + + Column('foo', Integer, default=50) + + Is equivalent to:: + + Column('foo', Integer, ColumnDefault(50)) + + + """ + + def __init__(self, arg, **kwargs): + """"Construct a new :class:`.ColumnDefault`. + + + :param arg: argument representing the default value. + May be one of the following: + + * a plain non-callable Python value, such as a + string, integer, boolean, or other simple type. + The default value will be used as is each time. + * a SQL expression, that is one which derives from + :class:`.ColumnElement`. The SQL expression will + be rendered into the INSERT or UPDATE statement, + or in the case of a primary key column when + RETURNING is not used may be + pre-executed before an INSERT within a SELECT. + * A Python callable. The function will be invoked for each + new row subject to an INSERT or UPDATE. + The callable must accept exactly + zero or one positional arguments. The one-argument form + will receive an instance of the :class:`.ExecutionContext`, + which provides contextual information as to the current + :class:`.Connection` in use as well as the current + statement and parameters. + + """ + super(ColumnDefault, self).__init__(**kwargs) + if isinstance(arg, FetchedValue): + raise exc.ArgumentError( + "ColumnDefault may not be a server-side default type.") + if util.callable(arg): + arg = self._maybe_wrap_callable(arg) + self.arg = arg + + @util.memoized_property + def is_callable(self): + return util.callable(self.arg) + + @util.memoized_property + def is_clause_element(self): + return isinstance(self.arg, ClauseElement) + + @util.memoized_property + def is_scalar(self): + return not self.is_callable and \ + not self.is_clause_element and \ + not self.is_sequence + + def _maybe_wrap_callable(self, fn): + """Wrap callables that don't accept a context. + + This is to allow easy compatiblity with default callables + that aren't specific to accepting of a context. + + """ + try: + argspec = util.get_callable_argspec(fn, no_self=True) + except TypeError: + return lambda ctx: fn() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + positionals = len(argspec[0]) - defaulted + + if positionals == 0: + return lambda ctx: fn() + elif positionals == 1: + return fn + else: + raise exc.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + + def _visit_name(self): + if self.for_update: + return "column_onupdate" + else: + return "column_default" + __visit_name__ = property(_visit_name) + + def __repr__(self): + return "ColumnDefault(%r)" % self.arg + + +class Sequence(DefaultGenerator): + """Represents a named database sequence. + + The :class:`.Sequence` object represents the name and configurational + parameters of a database sequence. It also represents + a construct that can be "executed" by a SQLAlchemy :class:`.Engine` + or :class:`.Connection`, rendering the appropriate "next value" function + for the target database and returning a result. + + The :class:`.Sequence` is typically associated with a primary key column:: + + some_table = Table('some_table', metadata, + Column('id', Integer, Sequence('some_table_seq'), primary_key=True) + ) + + When CREATE TABLE is emitted for the above :class:`.Table`, if the + target platform supports sequences, a CREATE SEQUENCE statement will + be emitted as well. For platforms that don't support sequences, + the :class:`.Sequence` construct is ignored. + + .. seealso:: + + :class:`.CreateSequence` + + :class:`.DropSequence` + + """ + + __visit_name__ = 'sequence' + + is_sequence = True + + def __init__(self, name, start=None, increment=None, schema=None, + optional=False, quote=None, metadata=None, + quote_schema=None, + for_update=False): + """Construct a :class:`.Sequence` object. + + :param name: The name of the sequence. + :param start: the starting index of the sequence. This value is + used when the CREATE SEQUENCE command is emitted to the database + as the value of the "START WITH" clause. If ``None``, the + clause is omitted, which on most platforms indicates a starting + value of 1. + :param increment: the increment value of the sequence. This + value is used when the CREATE SEQUENCE command is emitted to + the database as the value of the "INCREMENT BY" clause. If ``None``, + the clause is omitted, which on most platforms indicates an + increment of 1. + :param schema: Optional schema name for the sequence, if located + in a schema other than the default. + :param optional: boolean value, when ``True``, indicates that this + :class:`.Sequence` object only needs to be explicitly generated + on backends that don't provide another way to generate primary + key identifiers. Currently, it essentially means, "don't create + this sequence on the Postgresql backend, where the SERIAL keyword + creates a sequence for us automatically". + :param quote: boolean value, when ``True`` or ``False``, explicitly + forces quoting of the schema name on or off. When left at its + default of ``None``, normal quoting rules based on casing and reserved + words take place. + :param quote_schema: set the quoting preferences for the ``schema`` + name. + :param metadata: optional :class:`.MetaData` object which will be + associated with this :class:`.Sequence`. A :class:`.Sequence` + that is associated with a :class:`.MetaData` gains access to the + ``bind`` of that :class:`.MetaData`, meaning the + :meth:`.Sequence.create` and :meth:`.Sequence.drop` methods will + make usage of that engine automatically. + + .. versionchanged:: 0.7 + Additionally, the appropriate CREATE SEQUENCE/ + DROP SEQUENCE DDL commands will be emitted corresponding to this + :class:`.Sequence` when :meth:`.MetaData.create_all` and + :meth:`.MetaData.drop_all` are invoked. + + Note that when a :class:`.Sequence` is applied to a :class:`.Column`, + the :class:`.Sequence` is automatically associated with the + :class:`.MetaData` object of that column's parent :class:`.Table`, + when that association is made. The :class:`.Sequence` will then + be subject to automatic CREATE SEQUENCE/DROP SEQUENCE corresponding + to when the :class:`.Table` object itself is created or dropped, + rather than that of the :class:`.MetaData` object overall. + :param for_update: Indicates this :class:`.Sequence`, when associated + with a :class:`.Column`, should be invoked for UPDATE statements + on that column's table, rather than for INSERT statements, when + no value is otherwise present for that column in the statement. + + """ + super(Sequence, self).__init__(for_update=for_update) + self.name = quoted_name(name, quote) + self.start = start + self.increment = increment + self.optional = optional + if metadata is not None and schema is None and metadata.schema: + self.schema = schema = metadata.schema + else: + self.schema = quoted_name(schema, quote_schema) + self.metadata = metadata + self._key = _get_table_key(name, schema) + if metadata: + self._set_metadata(metadata) + + @util.memoized_property + def is_callable(self): + return False + + @util.memoized_property + def is_clause_element(self): + return False + + @util.dependencies("sqlalchemy.sql.functions.func") + def next_value(self, func): + """Return a :class:`.next_value` function element + which will render the appropriate increment function + for this :class:`.Sequence` within any SQL expression. + + """ + return func.next_value(self, bind=self.bind) + + def _set_parent(self, column): + super(Sequence, self)._set_parent(column) + column._on_table_attach(self._set_table) + + def _set_table(self, column, table): + self._set_metadata(table.metadata) + + def _set_metadata(self, metadata): + self.metadata = metadata + self.metadata._sequences[self._key] = self + + @property + def bind(self): + if self.metadata: + return self.metadata.bind + else: + return None + + def create(self, bind=None, checkfirst=True): + """Creates this sequence in the database.""" + + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaGenerator, + self, + checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Drops this sequence from the database.""" + + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaDropper, + self, + checkfirst=checkfirst) + + def _not_a_column_expr(self): + raise exc.InvalidRequestError( + "This %s cannot be used directly " + "as a column expression. Use func.next_value(sequence) " + "to produce a 'next value' function that's usable " + "as a column element." + % self.__class__.__name__) + + +@inspection._self_inspects +class FetchedValue(_NotAColumnExpr, SchemaEventTarget): + """A marker for a transparent database-side default. + + Use :class:`.FetchedValue` when the database is configured + to provide some automatic default for a column. + + E.g.:: + + Column('foo', Integer, FetchedValue()) + + Would indicate that some trigger or default generator + will create a new value for the ``foo`` column during an + INSERT. + + .. seealso:: + + :ref:`triggered_columns` + + """ + is_server_default = True + reflected = False + has_argument = False + + def __init__(self, for_update=False): + self.for_update = for_update + + def _as_for_update(self, for_update): + if for_update == self.for_update: + return self + else: + return self._clone(for_update) + + def _clone(self, for_update): + n = self.__class__.__new__(self.__class__) + n.__dict__.update(self.__dict__) + n.__dict__.pop('column', None) + n.for_update = for_update + return n + + def _set_parent(self, column): + self.column = column + if self.for_update: + self.column.server_onupdate = self + else: + self.column.server_default = self + + def __repr__(self): + return util.generic_repr(self) + + +class DefaultClause(FetchedValue): + """A DDL-specified DEFAULT column value. + + :class:`.DefaultClause` is a :class:`.FetchedValue` + that also generates a "DEFAULT" clause when + "CREATE TABLE" is emitted. + + :class:`.DefaultClause` is generated automatically + whenever the ``server_default``, ``server_onupdate`` arguments of + :class:`.Column` are used. A :class:`.DefaultClause` + can be passed positionally as well. + + For example, the following:: + + Column('foo', Integer, server_default="50") + + Is equivalent to:: + + Column('foo', Integer, DefaultClause("50")) + + """ + + has_argument = True + + def __init__(self, arg, for_update=False, _reflected=False): + util.assert_arg_type(arg, (util.string_types[0], + ClauseElement, + TextClause), 'arg') + super(DefaultClause, self).__init__(for_update) + self.arg = arg + self.reflected = _reflected + + def __repr__(self): + return "DefaultClause(%r, for_update=%r)" % \ + (self.arg, self.for_update) + + +class PassiveDefault(DefaultClause): + """A DDL-specified DEFAULT column value. + + .. deprecated:: 0.6 + :class:`.PassiveDefault` is deprecated. + Use :class:`.DefaultClause`. + """ + @util.deprecated("0.6", + ":class:`.PassiveDefault` is deprecated. " + "Use :class:`.DefaultClause`.", + False) + def __init__(self, *arg, **kw): + DefaultClause.__init__(self, *arg, **kw) + + +class Constraint(DialectKWArgs, SchemaItem): + """A table-level SQL constraint.""" + + __visit_name__ = 'constraint' + + def __init__(self, name=None, deferrable=None, initially=None, + _create_rule=None, + **dialect_kw): + """Create a SQL constraint. + + :param name: + Optional, the in-database name of this ``Constraint``. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param _create_rule: + a callable which is passed the DDLCompiler object during + compilation. Returns True or False to signal inline generation of + this Constraint. + + The AddConstraint and DropConstraint DDL constructs provide + DDLElement's more comprehensive "conditional DDL" approach that is + passed a database connection when DDL is being issued. _create_rule + is instead called during any CREATE TABLE compilation, where there + may not be any transaction/connection in progress. However, it + allows conditional compilation of the constraint even for backends + which do not support addition of constraints through ALTER TABLE, + which currently includes SQLite. + + _create_rule is used by some types to create constraints. + Currently, its call signature is subject to change at any time. + + :param \**dialect_kw: Additional keyword arguments are dialect specific, + and passed in the form ``_``. See the + documentation regarding an individual dialect at :ref:`dialect_toplevel` + for detail on documented arguments. + + """ + + self.name = name + self.deferrable = deferrable + self.initially = initially + self._create_rule = _create_rule + util.set_creation_order(self) + self._validate_dialect_kwargs(dialect_kw) + + @property + def table(self): + try: + if isinstance(self.parent, Table): + return self.parent + except AttributeError: + pass + raise exc.InvalidRequestError( + "This constraint is not bound to a table. Did you " + "mean to call table.append_constraint(constraint) ?") + + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + + def copy(self, **kw): + raise NotImplementedError() + + +def _to_schema_column(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, Column): + raise exc.ArgumentError("schema.Column object expected") + return element + + +def _to_schema_column_or_string(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, util.string_types + (ColumnElement, )): + msg = "Element %r is not a string name or column element" + raise exc.ArgumentError(msg % element) + return element + + +class ColumnCollectionMixin(object): + def __init__(self, *columns): + self.columns = ColumnCollection() + self._pending_colargs = [_to_schema_column_or_string(c) + for c in columns] + if self._pending_colargs and \ + isinstance(self._pending_colargs[0], Column) and \ + isinstance(self._pending_colargs[0].table, Table): + self._set_parent_with_dispatch(self._pending_colargs[0].table) + + def _set_parent(self, table): + for col in self._pending_colargs: + if isinstance(col, util.string_types): + col = table.c[col] + self.columns.add(col) + + +class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): + """A constraint that proxies a ColumnCollection.""" + + def __init__(self, *columns, **kw): + """ + :param \*columns: + A sequence of column names or Column objects. + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param \**kw: other keyword arguments including dialect-specific + arguments are propagated to the :class:`.Constraint` superclass. + + """ + Constraint.__init__(self, **kw) + ColumnCollectionMixin.__init__(self, *columns) + + def _set_parent(self, table): + Constraint._set_parent(self, table) + ColumnCollectionMixin._set_parent(self, table) + + def __contains__(self, x): + return x in self.columns + + def copy(self, **kw): + c = self.__class__(name=self.name, deferrable=self.deferrable, + initially=self.initially, *self.columns.keys()) + return self._schema_item_copy(c) + + def contains_column(self, col): + return self.columns.contains_column(col) + + def __iter__(self): + # inlining of + # return iter(self.columns) + # ColumnCollection->OrderedProperties->OrderedDict + ordered_dict = self.columns._data + return (ordered_dict[key] for key in ordered_dict._list) + + def __len__(self): + return len(self.columns._data) + + +class CheckConstraint(Constraint): + """A table- or column-level CHECK constraint. + + Can be included in the definition of a Table or Column. + """ + + def __init__(self, sqltext, name=None, deferrable=None, + initially=None, table=None, _create_rule=None, + _autoattach=True): + """Construct a CHECK constraint. + + :param sqltext: + A string containing the constraint definition, which will be used + verbatim, or a SQL expression construct. If given as a string, + the object is converted to a :class:`.Text` object. If the textual + string includes a colon character, escape this using a backslash:: + + CheckConstraint(r"foo ~ E'a(?\:b|c)d") + + :param name: + Optional, the in-database name of the constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + """ + + super(CheckConstraint, self).\ + __init__(name, deferrable, initially, _create_rule) + self.sqltext = _literal_as_text(sqltext) + if table is not None: + self._set_parent_with_dispatch(table) + elif _autoattach: + cols = _find_columns(self.sqltext) + tables = set([c.table for c in cols + if isinstance(c.table, Table)]) + if len(tables) == 1: + self._set_parent_with_dispatch( + tables.pop()) + + def __visit_name__(self): + if isinstance(self.parent, Table): + return "check_constraint" + else: + return "column_check_constraint" + __visit_name__ = property(__visit_name__) + + def copy(self, target_table=None, **kw): + if target_table is not None: + def replace(col): + if self.table.c.contains_column(col): + return target_table.c[col.key] + else: + return None + sqltext = visitors.replacement_traverse(self.sqltext, {}, replace) + else: + sqltext = self.sqltext + c = CheckConstraint(sqltext, + name=self.name, + initially=self.initially, + deferrable=self.deferrable, + _create_rule=self._create_rule, + table=target_table, + _autoattach=False) + return self._schema_item_copy(c) + + +class ForeignKeyConstraint(Constraint): + """A table-level FOREIGN KEY constraint. + + Defines a single column or composite FOREIGN KEY ... REFERENCES + constraint. For a no-frills, single column foreign key, adding a + :class:`.ForeignKey` to the definition of a :class:`.Column` is a shorthand + equivalent for an unnamed, single column :class:`.ForeignKeyConstraint`. + + Examples of foreign key configuration are in :ref:`metadata_foreignkeys`. + + """ + __visit_name__ = 'foreign_key_constraint' + + def __init__(self, columns, refcolumns, name=None, onupdate=None, + ondelete=None, deferrable=None, initially=None, use_alter=False, + link_to_name=False, match=None, table=None, **dialect_kw): + """Construct a composite-capable FOREIGN KEY. + + :param columns: A sequence of local column names. The named columns + must be defined and present in the parent Table. The names should + match the ``key`` given to each column (defaults to the name) unless + ``link_to_name`` is True. + + :param refcolumns: A sequence of foreign column names or Column + objects. The columns must all be located within the same Table. + + :param name: Optional, the in-database name of the key. + + :param onupdate: Optional string. If set, emit ON UPDATE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param ondelete: Optional string. If set, emit ON DELETE when + issuing DDL for this constraint. Typical values include CASCADE, + DELETE and RESTRICT. + + :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT + DEFERRABLE when issuing DDL for this constraint. + + :param initially: Optional string. If set, emit INITIALLY when + issuing DDL for this constraint. + + :param link_to_name: if True, the string name given in ``column`` is + the rendered name of the referenced column, not its locally assigned + ``key``. + + :param use_alter: If True, do not emit the DDL for this constraint as + part of the CREATE TABLE definition. Instead, generate it via an + ALTER TABLE statement issued after the full collection of tables + have been created, and drop it via an ALTER TABLE statement before + the full collection of tables are dropped. This is shorthand for the + usage of :class:`.AddConstraint` and :class:`.DropConstraint` applied + as "after-create" and "before-drop" events on the MetaData object. + This is normally used to generate/drop constraints on objects that + are mutually dependent on each other. + + :param match: Optional string. If set, emit MATCH when issuing + DDL for this constraint. Typical values include SIMPLE, PARTIAL + and FULL. + + :param \**dialect_kw: Additional keyword arguments are dialect specific, + and passed in the form ``_``. See the + documentation regarding an individual dialect at :ref:`dialect_toplevel` + for detail on documented arguments. + + .. versionadded:: 0.9.2 + + """ + super(ForeignKeyConstraint, self).\ + __init__(name, deferrable, initially, **dialect_kw) + + self.onupdate = onupdate + self.ondelete = ondelete + self.link_to_name = link_to_name + if self.name is None and use_alter: + raise exc.ArgumentError("Alterable Constraint requires a name") + self.use_alter = use_alter + self.match = match + + self._elements = util.OrderedDict() + + # standalone ForeignKeyConstraint - create + # associated ForeignKey objects which will be applied to hosted + # Column objects (in col.foreign_keys), either now or when attached + # to the Table for string-specified names + for col, refcol in zip(columns, refcolumns): + self._elements[col] = ForeignKey( + refcol, + _constraint=self, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter, + link_to_name=self.link_to_name, + match=self.match, + deferrable=self.deferrable, + initially=self.initially, + **self.dialect_kwargs + ) + + if table is not None: + self._set_parent_with_dispatch(table) + elif columns and \ + isinstance(columns[0], Column) and \ + columns[0].table is not None: + self._set_parent_with_dispatch(columns[0].table) + + @property + def _referred_schema(self): + for elem in self._elements.values(): + return elem._referred_schema + else: + return None + + def _validate_dest_table(self, table): + table_keys = set([elem._table_key() for elem in self._elements.values()]) + if None not in table_keys and len(table_keys) > 1: + elem0, elem1 = sorted(table_keys)[0:2] + raise exc.ArgumentError( + 'ForeignKeyConstraint on %s(%s) refers to ' + 'multiple remote tables: %s and %s' % ( + table.fullname, + self._col_description, + elem0, + elem1 + )) + + @property + def _col_description(self): + return ", ".join(self._elements) + + @property + def columns(self): + return list(self._elements) + + @property + def elements(self): + return list(self._elements.values()) + + def _set_parent(self, table): + super(ForeignKeyConstraint, self)._set_parent(table) + + self._validate_dest_table(table) + + for col, fk in self._elements.items(): + # string-specified column names now get + # resolved to Column objects + if isinstance(col, util.string_types): + try: + col = table.c[col] + except KeyError: + raise exc.ArgumentError( + "Can't create ForeignKeyConstraint " + "on table '%s': no column " + "named '%s' is present." % (table.description, col)) + + if not hasattr(fk, 'parent') or \ + fk.parent is not col: + fk._set_parent_with_dispatch(col) + + if self.use_alter: + def supports_alter(ddl, event, schema_item, bind, **kw): + return table in set(kw['tables']) and \ + bind.dialect.supports_alter + + event.listen(table.metadata, "after_create", + ddl.AddConstraint(self, on=supports_alter)) + event.listen(table.metadata, "before_drop", + ddl.DropConstraint(self, on=supports_alter)) + + def copy(self, schema=None, **kw): + fkc = ForeignKeyConstraint( + [x.parent.key for x in self._elements.values()], + [x._get_colspec(schema=schema) for x in self._elements.values()], + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter, + deferrable=self.deferrable, + initially=self.initially, + link_to_name=self.link_to_name, + match=self.match + ) + for self_fk, other_fk in zip( + self._elements.values(), + fkc._elements.values()): + self_fk._schema_item_copy(other_fk) + return self._schema_item_copy(fkc) + + +class PrimaryKeyConstraint(ColumnCollectionConstraint): + """A table-level PRIMARY KEY constraint. + + The :class:`.PrimaryKeyConstraint` object is present automatically + on any :class:`.Table` object; it is assigned a set of + :class:`.Column` objects corresponding to those marked with + the :paramref:`.Column.primary_key` flag:: + + >>> my_table = Table('mytable', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('version_id', Integer, primary_key=True), + ... Column('data', String(50)) + ... ) + >>> my_table.primary_key + PrimaryKeyConstraint( + Column('id', Integer(), table=, primary_key=True, nullable=False), + Column('version_id', Integer(), table=, primary_key=True, nullable=False) + ) + + The primary key of a :class:`.Table` can also be specified by using + a :class:`.PrimaryKeyConstraint` object explicitly; in this mode of usage, + the "name" of the constraint can also be specified, as well as other + options which may be recognized by dialects:: + + my_table = Table('mytable', metadata, + Column('id', Integer), + Column('version_id', Integer), + Column('data', String(50)), + PrimaryKeyConstraint('id', 'version_id', name='mytable_pk') + ) + + The two styles of column-specification should generally not be mixed. + An warning is emitted if the columns present in the + :class:`.PrimaryKeyConstraint` + don't match the columns that were marked as ``primary_key=True``, if both + are present; in this case, the columns are taken strictly from the + :class:`.PrimaryKeyConstraint` declaration, and those columns otherwise marked + as ``primary_key=True`` are ignored. This behavior is intended to be + backwards compatible with previous behavior. + + .. versionchanged:: 0.9.2 Using a mixture of columns within a + :class:`.PrimaryKeyConstraint` in addition to columns marked as + ``primary_key=True`` now emits a warning if the lists don't match. + The ultimate behavior of ignoring those columns marked with the flag + only is currently maintained for backwards compatibility; this warning + may raise an exception in a future release. + + For the use case where specific options are to be specified on the + :class:`.PrimaryKeyConstraint`, but the usual style of using ``primary_key=True`` + flags is still desirable, an empty :class:`.PrimaryKeyConstraint` may be + specified, which will take on the primary key column collection from + the :class:`.Table` based on the flags:: + + my_table = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('version_id', Integer, primary_key=True), + Column('data', String(50)), + PrimaryKeyConstraint(name='mytable_pk', mssql_clustered=True) + ) + + .. versionadded:: 0.9.2 an empty :class:`.PrimaryKeyConstraint` may now + be specified for the purposes of establishing keyword arguments with the + constraint, independently of the specification of "primary key" columns + within the :class:`.Table` itself; columns marked as ``primary_key=True`` + will be gathered into the empty constraint's column collection. + + """ + + __visit_name__ = 'primary_key_constraint' + + def _set_parent(self, table): + super(PrimaryKeyConstraint, self)._set_parent(table) + + if table.primary_key is not self: + table.constraints.discard(table.primary_key) + table.primary_key = self + table.constraints.add(self) + + table_pks = [c for c in table.c if c.primary_key] + if self.columns and table_pks and \ + set(table_pks) != set(self.columns.values()): + util.warn( + "Table '%s' specifies columns %s as primary_key=True, " + "not matching locally specified columns %s; setting the " + "current primary key columns to %s. This warning " + "may become an exception in a future release" % + ( + table.name, + ", ".join("'%s'" % c.name for c in table_pks), + ", ".join("'%s'" % c.name for c in self.columns), + ", ".join("'%s'" % c.name for c in self.columns) + ) + ) + table_pks[:] = [] + + for c in self.columns: + c.primary_key = True + self.columns.extend(table_pks) + + def _reload(self, columns): + """repopulate this :class:`.PrimaryKeyConstraint` given + a set of columns. + + Existing columns in the table that are marked as primary_key=True + are maintained. + + Also fires a new event. + + This is basically like putting a whole new + :class:`.PrimaryKeyConstraint` object on the parent + :class:`.Table` object without actually replacing the object. + + The ordering of the given list of columns is also maintained; these + columns will be appended to the list of columns after any which + are already present. + + """ + + # set the primary key flag on new columns. + # note any existing PK cols on the table also have their + # flag still set. + for col in columns: + col.primary_key = True + + self.columns.extend(columns) + + self._set_parent_with_dispatch(self.table) + + def _replace(self, col): + self.columns.replace(col) + + +class UniqueConstraint(ColumnCollectionConstraint): + """A table-level UNIQUE constraint. + + Defines a single column or composite UNIQUE constraint. For a no-frills, + single column constraint, adding ``unique=True`` to the ``Column`` + definition is a shorthand equivalent for an unnamed, single column + UniqueConstraint. + """ + + __visit_name__ = 'unique_constraint' + + +class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): + """A table-level INDEX. + + Defines a composite (one or more column) INDEX. + + E.g.:: + + sometable = Table("sometable", metadata, + Column("name", String(50)), + Column("address", String(100)) + ) + + Index("some_index", sometable.c.name) + + For a no-frills, single column index, adding + :class:`.Column` also supports ``index=True``:: + + sometable = Table("sometable", metadata, + Column("name", String(50), index=True) + ) + + For a composite index, multiple columns can be specified:: + + Index("some_index", sometable.c.name, sometable.c.address) + + Functional indexes are supported as well, keeping in mind that at least + one :class:`.Column` must be present:: + + Index("some_index", func.lower(sometable.c.name)) + + .. versionadded:: 0.8 support for functional and expression-based indexes. + + .. seealso:: + + :ref:`schema_indexes` - General information on :class:`.Index`. + + :ref:`postgresql_indexes` - PostgreSQL-specific options available for the + :class:`.Index` construct. + + :ref:`mysql_indexes` - MySQL-specific options available for the + :class:`.Index` construct. + + :ref:`mssql_indexes` - MSSQL-specific options available for the + :class:`.Index` construct. + + """ + + __visit_name__ = 'index' + + def __init__(self, name, *expressions, **kw): + """Construct an index object. + + :param name: + The name of the index + + :param \*expressions: + Column expressions to include in the index. The expressions + are normally instances of :class:`.Column`, but may also + be arbitrary SQL expressions which ultmately refer to a + :class:`.Column`. + + :param unique=False: + Keyword only argument; if True, create a unique index. + + :param quote=None: + Keyword only argument; whether to apply quoting to the name of + the index. Works in the same manner as that of + :paramref:`.Column.quote`. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ + self.table = None + + columns = [] + for expr in expressions: + if not isinstance(expr, ClauseElement): + columns.append(expr) + else: + cols = [] + visitors.traverse(expr, {}, {'column': cols.append}) + if cols: + columns.append(cols[0]) + else: + columns.append(expr) + + self.expressions = expressions + self.name = quoted_name(name, kw.pop("quote", None)) + self.unique = kw.pop('unique', False) + self._validate_dialect_kwargs(kw) + + # will call _set_parent() if table-bound column + # objects are present + ColumnCollectionMixin.__init__(self, *columns) + + + + def _set_parent(self, table): + ColumnCollectionMixin._set_parent(self, table) + + if self.table is not None and table is not self.table: + raise exc.ArgumentError( + "Index '%s' is against table '%s', and " + "cannot be associated with table '%s'." % ( + self.name, + self.table.description, + table.description + ) + ) + self.table = table + for c in self.columns: + if c.table != self.table: + raise exc.ArgumentError( + "Column '%s' is not part of table '%s'." % + (c, self.table.description) + ) + table.indexes.add(self) + + self.expressions = [ + expr if isinstance(expr, ClauseElement) + else colexpr + for expr, colexpr in zip(self.expressions, self.columns) + ] + + @property + def bind(self): + """Return the connectable associated with this Index.""" + + return self.table.bind + + def create(self, bind=None): + """Issue a ``CREATE`` statement for this + :class:`.Index`, using the given :class:`.Connectable` + for connectivity. + + .. seealso:: + + :meth:`.MetaData.create_all`. + + """ + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaGenerator, self) + return self + + def drop(self, bind=None): + """Issue a ``DROP`` statement for this + :class:`.Index`, using the given :class:`.Connectable` + for connectivity. + + .. seealso:: + + :meth:`.MetaData.drop_all`. + + """ + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaDropper, self) + + def __repr__(self): + return 'Index(%s)' % ( + ", ".join( + [repr(self.name)] + + [repr(c) for c in self.columns] + + (self.unique and ["unique=True"] or []) + )) + + +DEFAULT_NAMING_CONVENTION = util.immutabledict({ + "ix": 'ix_%(column_0_label)s' +}) + + +class MetaData(SchemaItem): + """A collection of :class:`.Table` objects and their associated schema + constructs. + + Holds a collection of :class:`.Table` objects as well as + an optional binding to an :class:`.Engine` or + :class:`.Connection`. If bound, the :class:`.Table` objects + in the collection and their columns may participate in implicit SQL + execution. + + The :class:`.Table` objects themselves are stored in the + :attr:`.MetaData.tables` dictionary. + + :class:`.MetaData` is a thread-safe object for read operations. Construction + of new tables within a single :class:`.MetaData` object, either explicitly + or via reflection, may not be completely thread-safe. + + .. seealso:: + + :ref:`metadata_describing` - Introduction to database metadata + + """ + + __visit_name__ = 'metadata' + + def __init__(self, bind=None, reflect=False, schema=None, + quote_schema=None, + naming_convention=DEFAULT_NAMING_CONVENTION + ): + """Create a new MetaData object. + + :param bind: + An Engine or Connection to bind to. May also be a string or URL + instance, these are passed to create_engine() and this MetaData will + be bound to the resulting engine. + + :param reflect: + Optional, automatically load all tables from the bound database. + Defaults to False. ``bind`` is required when this option is set. + + .. deprecated:: 0.8 + Please use the :meth:`.MetaData.reflect` method. + + :param schema: + The default schema to use for the :class:`.Table`, + :class:`.Sequence`, and other objects associated with this + :class:`.MetaData`. Defaults to ``None``. + + :param quote_schema: + Sets the ``quote_schema`` flag for those :class:`.Table`, + :class:`.Sequence`, and other objects which make usage of the + local ``schema`` name. + + :param naming_convention: a dictionary referring to values which + will establish default naming conventions for :class:`.Constraint` + and :class:`.Index` objects, for those objects which are not given + a name explicitly. + + The keys of this dictionary may be: + + * a constraint or Index class, e.g. the :class:`.UniqueConstraint`, + :class:`.ForeignKeyConstraint` class, the :class:`.Index` class + + * a string mnemonic for one of the known constraint classes; + ``"fk"``, ``"pk"``, ``"ix"``, ``"ck"``, ``"uq"`` for foreign key, + primary key, index, check, and unique constraint, respectively. + + * the string name of a user-defined "token" that can be used + to define new naming tokens. + + The values associated with each "constraint class" or "constraint + mnemonic" key are string naming templates, such as + ``"uq_%(table_name)s_%(column_0_name)s"``, + which decribe how the name should be composed. The values associated + with user-defined "token" keys should be callables of the form + ``fn(constraint, table)``, which accepts the constraint/index + object and :class:`.Table` as arguments, returning a string + result. + + The built-in names are as follows, some of which may only be + available for certain types of constraint: + + * ``%(table_name)s`` - the name of the :class:`.Table` object + associated with the constraint. + + * ``%(referred_table_name)s`` - the name of the :class:`.Table` + object associated with the referencing target of a + :class:`.ForeignKeyConstraint`. + + * ``%(column_0_name)s`` - the name of the :class:`.Column` at + index position "0" within the constraint. + + * ``%(column_0_label)s`` - the label of the :class:`.Column` at + index position "0", e.g. :attr:`.Column.label` + + * ``%(column_0_key)s`` - the key of the :class:`.Column` at + index position "0", e.g. :attr:`.Column.key` + + * ``%(referred_column_0_name)s`` - the name of a :class:`.Column` + at index position "0" referenced by a :class:`.ForeignKeyConstraint`. + + * ``%(constraint_name)s`` - a special key that refers to the existing + name given to the constraint. When this key is present, the + :class:`.Constraint` object's existing name will be replaced with + one that is composed from template string that uses this token. + When this token is present, it is required that the :class:`.Constraint` + is given an expicit name ahead of time. + + * user-defined: any additional token may be implemented by passing + it along with a ``fn(constraint, table)`` callable to the + naming_convention dictionary. + + .. versionadded:: 0.9.2 + + .. seealso:: + + :ref:`constraint_naming_conventions` - for detailed usage + examples. + + """ + self.tables = util.immutabledict() + self.schema = quoted_name(schema, quote_schema) + self.naming_convention = naming_convention + self._schemas = set() + self._sequences = {} + self._fk_memos = collections.defaultdict(list) + + self.bind = bind + if reflect: + util.warn_deprecated("reflect=True is deprecate; please " + "use the reflect() method.") + if not bind: + raise exc.ArgumentError( + "A bind must be supplied in conjunction " + "with reflect=True") + self.reflect() + + tables = None + """A dictionary of :class:`.Table` objects keyed to their name or "table key". + + The exact key is that determined by the :attr:`.Table.key` attribute; + for a table with no :attr:`.Table.schema` attribute, this is the same + as :attr:`.Table.name`. For a table with a schema, it is typically of the + form ``schemaname.tablename``. + + .. seealso:: + + :attr:`.MetaData.sorted_tables` + + """ + + def __repr__(self): + return 'MetaData(bind=%r)' % self.bind + + def __contains__(self, table_or_key): + if not isinstance(table_or_key, util.string_types): + table_or_key = table_or_key.key + return table_or_key in self.tables + + def _add_table(self, name, schema, table): + key = _get_table_key(name, schema) + dict.__setitem__(self.tables, key, table) + if schema: + self._schemas.add(schema) + + + + def _remove_table(self, name, schema): + key = _get_table_key(name, schema) + removed = dict.pop(self.tables, key, None) + if removed is not None: + for fk in removed.foreign_keys: + fk._remove_from_metadata(self) + if self._schemas: + self._schemas = set([t.schema + for t in self.tables.values() + if t.schema is not None]) + + + def __getstate__(self): + return {'tables': self.tables, + 'schema': self.schema, + 'schemas': self._schemas, + 'sequences': self._sequences, + 'fk_memos': self._fk_memos} + + def __setstate__(self, state): + self.tables = state['tables'] + self.schema = state['schema'] + self._bind = None + self._sequences = state['sequences'] + self._schemas = state['schemas'] + self._fk_memos = state['fk_memos'] + + def is_bound(self): + """True if this MetaData is bound to an Engine or Connection.""" + + return self._bind is not None + + def bind(self): + """An :class:`.Engine` or :class:`.Connection` to which this + :class:`.MetaData` is bound. + + Typically, a :class:`.Engine` is assigned to this attribute + so that "implicit execution" may be used, or alternatively + as a means of providing engine binding information to an + ORM :class:`.Session` object:: + + engine = create_engine("someurl://") + metadata.bind = engine + + .. seealso:: + + :ref:`dbengine_implicit` - background on "bound metadata" + + """ + return self._bind + + @util.dependencies("sqlalchemy.engine.url") + def _bind_to(self, url, bind): + """Bind this MetaData to an Engine, Connection, string or URL.""" + + if isinstance(bind, util.string_types + (url.URL, )): + self._bind = sqlalchemy.create_engine(bind) + else: + self._bind = bind + bind = property(bind, _bind_to) + + def clear(self): + """Clear all Table objects from this MetaData.""" + + dict.clear(self.tables) + self._schemas.clear() + self._fk_memos.clear() + + def remove(self, table): + """Remove the given Table object from this MetaData.""" + + self._remove_table(table.name, table.schema) + + @property + def sorted_tables(self): + """Returns a list of :class:`.Table` objects sorted in order of + foreign key dependency. + + The sorting will place :class:`.Table` objects that have dependencies + first, before the dependencies themselves, representing the + order in which they can be created. To get the order in which + the tables would be dropped, use the ``reversed()`` Python built-in. + + .. seealso:: + + :attr:`.MetaData.tables` + + :meth:`.Inspector.get_table_names` + + """ + return ddl.sort_tables(self.tables.values()) + + def reflect(self, bind=None, schema=None, views=False, only=None, + extend_existing=False, + autoload_replace=True, + **dialect_kwargs): + """Load all available table definitions from the database. + + Automatically creates ``Table`` entries in this ``MetaData`` for any + table available in the database but not yet present in the + ``MetaData``. May be called multiple times to pick up tables recently + added to the database, however no special action is taken if a table + in this ``MetaData`` no longer exists in the database. + + :param bind: + A :class:`.Connectable` used to access the database; if None, uses + the existing bind on this ``MetaData``, if any. + + :param schema: + Optional, query and reflect tables from an alterate schema. + If None, the schema associated with this :class:`.MetaData` + is used, if any. + + :param views: + If True, also reflect views. + + :param only: + Optional. Load only a sub-set of available named tables. May be + specified as a sequence of names or a callable. + + If a sequence of names is provided, only those tables will be + reflected. An error is raised if a table is requested but not + available. Named tables already present in this ``MetaData`` are + ignored. + + If a callable is provided, it will be used as a boolean predicate to + filter the list of potential table names. The callable is called + with a table name and this ``MetaData`` instance as positional + arguments and should return a true value for any table to reflect. + + :param extend_existing: Passed along to each :class:`.Table` as + :paramref:`.Table.extend_existing`. + + .. versionadded:: 0.9.1 + + :param autoload_replace: Passed along to each :class:`.Table` as + :paramref:`.Table.autoload_replace`. + + .. versionadded:: 0.9.1 + + :param \**dialect_kwargs: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + .. versionadded:: 0.9.2 - Added :paramref:`.MetaData.reflect.**dialect_kwargs` + to support dialect-level reflection options for all :class:`.Table` + objects reflected. + + """ + if bind is None: + bind = _bind_or_error(self) + + with bind.connect() as conn: + + reflect_opts = { + 'autoload': True, + 'autoload_with': conn, + 'extend_existing': extend_existing, + 'autoload_replace': autoload_replace + } + + reflect_opts.update(dialect_kwargs) + + if schema is None: + schema = self.schema + + if schema is not None: + reflect_opts['schema'] = schema + + available = util.OrderedSet(bind.engine.table_names(schema, + connection=conn)) + if views: + available.update( + bind.dialect.get_view_names(conn, schema) + ) + + if schema is not None: + available_w_schema = util.OrderedSet(["%s.%s" % (schema, name) + for name in available]) + else: + available_w_schema = available + + current = set(self.tables) + + if only is None: + load = [name for name, schname in + zip(available, available_w_schema) + if extend_existing or schname not in current] + elif util.callable(only): + load = [name for name, schname in + zip(available, available_w_schema) + if (extend_existing or schname not in current) + and only(name, self)] + else: + missing = [name for name in only if name not in available] + if missing: + s = schema and (" schema '%s'" % schema) or '' + raise exc.InvalidRequestError( + 'Could not reflect: requested table(s) not available ' + 'in %s%s: (%s)' % + (bind.engine.url, s, ', '.join(missing))) + load = [name for name in only if extend_existing or + name not in current] + + for name in load: + Table(name, self, **reflect_opts) + + def append_ddl_listener(self, event_name, listener): + """Append a DDL event listener to this ``MetaData``. + + .. deprecated:: 0.7 + See :class:`.DDLEvents`. + + """ + def adapt_listener(target, connection, **kw): + tables = kw['tables'] + listener(event, target, connection, tables=tables) + + event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + + def create_all(self, bind=None, tables=None, checkfirst=True): + """Create all tables stored in this metadata. + + Conditional by default, will not attempt to recreate tables already + present in the target database. + + :param bind: + A :class:`.Connectable` used to access the + database; if None, uses the existing bind on this ``MetaData``, if + any. + + :param tables: + Optional list of ``Table`` objects, which is a subset of the total + tables in the ``MetaData`` (others are ignored). + + :param checkfirst: + Defaults to True, don't issue CREATEs for tables already present + in the target database. + + """ + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaGenerator, + self, + checkfirst=checkfirst, + tables=tables) + + def drop_all(self, bind=None, tables=None, checkfirst=True): + """Drop all tables stored in this metadata. + + Conditional by default, will not attempt to drop tables not present in + the target database. + + :param bind: + A :class:`.Connectable` used to access the + database; if None, uses the existing bind on this ``MetaData``, if + any. + + :param tables: + Optional list of ``Table`` objects, which is a subset of the + total tables in the ``MetaData`` (others are ignored). + + :param checkfirst: + Defaults to True, only issue DROPs for tables confirmed to be + present in the target database. + + """ + if bind is None: + bind = _bind_or_error(self) + bind._run_visitor(ddl.SchemaDropper, + self, + checkfirst=checkfirst, + tables=tables) + + +class ThreadLocalMetaData(MetaData): + """A MetaData variant that presents a different ``bind`` in every thread. + + Makes the ``bind`` property of the MetaData a thread-local value, allowing + this collection of tables to be bound to different ``Engine`` + implementations or connections in each thread. + + The ThreadLocalMetaData starts off bound to None in each thread. Binds + must be made explicitly by assigning to the ``bind`` property or using + ``connect()``. You can also re-bind dynamically multiple times per + thread, just like a regular ``MetaData``. + + """ + + __visit_name__ = 'metadata' + + def __init__(self): + """Construct a ThreadLocalMetaData.""" + + self.context = util.threading.local() + self.__engines = {} + super(ThreadLocalMetaData, self).__init__() + + def bind(self): + """The bound Engine or Connection for this thread. + + This property may be assigned an Engine or Connection, or assigned a + string or URL to automatically create a basic Engine for this bind + with ``create_engine()``.""" + + return getattr(self.context, '_engine', None) + + @util.dependencies("sqlalchemy.engine.url") + def _bind_to(self, url, bind): + """Bind to a Connectable in the caller's thread.""" + + if isinstance(bind, util.string_types + (url.URL, )): + try: + self.context._engine = self.__engines[bind] + except KeyError: + e = sqlalchemy.create_engine(bind) + self.__engines[bind] = e + self.context._engine = e + else: + # TODO: this is squirrely. we shouldnt have to hold onto engines + # in a case like this + if bind not in self.__engines: + self.__engines[bind] = bind + self.context._engine = bind + + bind = property(bind, _bind_to) + + def is_bound(self): + """True if there is a bind for this thread.""" + return (hasattr(self.context, '_engine') and + self.context._engine is not None) + + def dispose(self): + """Dispose all bound engines, in all thread contexts.""" + + for e in self.__engines.values(): + if hasattr(e, 'dispose'): + e.dispose() + + + diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py new file mode 100644 index 00000000..f64a70ec --- /dev/null +++ b/lib/sqlalchemy/sql/selectable.py @@ -0,0 +1,3115 @@ +# sql/selectable.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""The :class:`.FromClause` class of SQL expression elements, representing +SQL tables and derived rowsets. + +""" + +from .elements import ClauseElement, TextClause, ClauseList, \ + and_, Grouping, UnaryExpression, literal_column +from .elements import _clone, \ + _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\ + _select_iterables, _anonymous_label, _clause_element_as_expr,\ + _cloned_intersection, _cloned_difference, True_, _only_column_elements +from .base import Immutable, Executable, _generative, \ + ColumnCollection, ColumnSet, _from_objects, Generative +from . import type_api +from .. import inspection +from .. import util +from .. import exc +from operator import attrgetter +from . import operators +import operator +import collections +from .annotation import Annotated +import itertools + +def _interpret_as_from(element): + insp = inspection.inspect(element, raiseerr=False) + if insp is None: + if isinstance(element, util.string_types): + return TextClause(util.text_type(element)) + elif hasattr(insp, "selectable"): + return insp.selectable + raise exc.ArgumentError("FROM expression expected") + +def _interpret_as_select(element): + element = _interpret_as_from(element) + if isinstance(element, Alias): + element = element.original + if not isinstance(element, Select): + element = element.select() + return element + +def subquery(alias, *args, **kwargs): + """Return an :class:`.Alias` object derived + from a :class:`.Select`. + + name + alias name + + \*args, \**kwargs + + all other arguments are delivered to the + :func:`select` function. + + """ + return Select(*args, **kwargs).alias(alias) + + + +def alias(selectable, name=None, flat=False): + """Return an :class:`.Alias` object. + + An :class:`.Alias` represents any :class:`.FromClause` + with an alternate name assigned within SQL, typically using the ``AS`` + clause when generated, e.g. ``SELECT * FROM table AS aliasname``. + + Similar functionality is available via the + :meth:`~.FromClause.alias` method + available on all :class:`.FromClause` subclasses. + + When an :class:`.Alias` is created from a :class:`.Table` object, + this has the effect of the table being rendered + as ``tablename AS aliasname`` in a SELECT statement. + + For :func:`.select` objects, the effect is that of creating a named + subquery, i.e. ``(select ...) AS aliasname``. + + The ``name`` parameter is optional, and provides the name + to use in the rendered SQL. If blank, an "anonymous" name + will be deterministically generated at compile time. + Deterministic means the name is guaranteed to be unique against + other constructs used in the same statement, and will also be the + same name for each successive compilation of the same statement + object. + + :param selectable: any :class:`.FromClause` subclass, + such as a table, select statement, etc. + + :param name: string name to be assigned as the alias. + If ``None``, a name will be deterministically generated + at compile time. + + :param flat: Will be passed through to if the given selectable + is an instance of :class:`.Join` - see :meth:`.Join.alias` + for details. + + .. versionadded:: 0.9.0 + + """ + return selectable.alias(name=name, flat=flat) + + +class Selectable(ClauseElement): + """mark a class as being selectable""" + __visit_name__ = 'selectable' + + is_selectable = True + + @property + def selectable(self): + return self + + +class FromClause(Selectable): + """Represent an element that can be used within the ``FROM`` + clause of a ``SELECT`` statement. + + The most common forms of :class:`.FromClause` are the + :class:`.Table` and the :func:`.select` constructs. Key + features common to all :class:`.FromClause` objects include: + + * a :attr:`.c` collection, which provides per-name access to a collection + of :class:`.ColumnElement` objects. + * a :attr:`.primary_key` attribute, which is a collection of all those + :class:`.ColumnElement` objects that indicate the ``primary_key`` flag. + * Methods to generate various derivations of a "from" clause, including + :meth:`.FromClause.alias`, :meth:`.FromClause.join`, + :meth:`.FromClause.select`. + + + """ + __visit_name__ = 'fromclause' + named_with_column = False + _hide_froms = [] + + _is_join = False + _is_select = False + _is_from_container = False + + _textual = False + """a marker that allows us to easily distinguish a :class:`.TextAsFrom` + or similar object from other kinds of :class:`.FromClause` objects.""" + + schema = None + """Define the 'schema' attribute for this :class:`.FromClause`. + + This is typically ``None`` for most objects except that of :class:`.Table`, + where it is taken as the value of the :paramref:`.Table.schema` argument. + + """ + + _memoized_property = util.group_expirable_memoized_property(["_columns"]) + + @util.dependencies("sqlalchemy.sql.functions") + def count(self, functions, whereclause=None, **params): + """return a SELECT COUNT generated against this + :class:`.FromClause`.""" + + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return Select( + [functions.func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) + + def select(self, whereclause=None, **params): + """return a SELECT of this :class:`.FromClause`. + + .. seealso:: + + :func:`~.sql.expression.select` - general purpose + method which allows for arbitrary column lists. + + """ + + return Select([self], whereclause, **params) + + def join(self, right, onclause=None, isouter=False): + """Return a :class:`.Join` from this :class:`.FromClause` + to another :class:`FromClause`. + + E.g.:: + + from sqlalchemy import join + + j = user_table.join(address_table, + user_table.c.id == address_table.c.user_id) + stmt = select([user_table]).select_from(j) + + would emit SQL along the lines of:: + + SELECT user.id, user.name FROM user + JOIN address ON user.id = address.user_id + + :param right: the right side of the join; this is any :class:`.FromClause` + object such as a :class:`.Table` object, and may also be a selectable-compatible + object such as an ORM-mapped class. + + :param onclause: a SQL expression representing the ON clause of the + join. If left at ``None``, :meth:`.FromClause.join` will attempt to + join the two tables based on a foreign key relationship. + + :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN. + + .. seealso:: + + :func:`.join` - standalone function + + :class:`.Join` - the type of object produced + + """ + + return Join(self, right, onclause, isouter) + + def outerjoin(self, right, onclause=None): + """Return a :class:`.Join` from this :class:`.FromClause` + to another :class:`FromClause`, with the "isouter" flag set to + True. + + E.g.:: + + from sqlalchemy import outerjoin + + j = user_table.outerjoin(address_table, + user_table.c.id == address_table.c.user_id) + + The above is equivalent to:: + + j = user_table.join(address_table, + user_table.c.id == address_table.c.user_id, isouter=True) + + :param right: the right side of the join; this is any :class:`.FromClause` + object such as a :class:`.Table` object, and may also be a selectable-compatible + object such as an ORM-mapped class. + + :param onclause: a SQL expression representing the ON clause of the + join. If left at ``None``, :meth:`.FromClause.join` will attempt to + join the two tables based on a foreign key relationship. + + .. seealso:: + + :meth:`.FromClause.join` + + :class:`.Join` + + """ + + return Join(self, right, onclause, True) + + def alias(self, name=None, flat=False): + """return an alias of this :class:`.FromClause`. + + This is shorthand for calling:: + + from sqlalchemy import alias + a = alias(self, name=name) + + See :func:`~.expression.alias` for details. + + """ + + return Alias(self, name) + + def is_derived_from(self, fromclause): + """Return True if this FromClause is 'derived' from the given + FromClause. + + An example would be an Alias of a Table is derived from that Table. + + """ + # this is essentially an "identity" check in the base class. + # Other constructs override this to traverse through + # contained elements. + return fromclause in self._cloned_set + + def _is_lexical_equivalent(self, other): + """Return True if this FromClause and the other represent + the same lexical identity. + + This tests if either one is a copy of the other, or + if they are the same via annotation identity. + + """ + return self._cloned_set.intersection(other._cloned_set) + + @util.dependencies("sqlalchemy.sql.util") + def replace_selectable(self, sqlutil, old, alias): + """replace all occurrences of FromClause 'old' with the given Alias + object, returning a copy of this :class:`.FromClause`. + + """ + + return sqlutil.ClauseAdapter(alias).traverse(self) + + def correspond_on_equivalents(self, column, equivalents): + """Return corresponding_column for the given column, or if None + search for a match in the given dictionary. + + """ + col = self.corresponding_column(column, require_embedded=True) + if col is None and col in equivalents: + for equiv in equivalents[col]: + nc = self.corresponding_column(equiv, require_embedded=True) + if nc: + return nc + return col + + def corresponding_column(self, column, require_embedded=False): + """Given a :class:`.ColumnElement`, return the exported + :class:`.ColumnElement` object from this :class:`.Selectable` + which corresponds to that original + :class:`~sqlalchemy.schema.Column` via a common ancestor + column. + + :param column: the target :class:`.ColumnElement` to be matched + + :param require_embedded: only return corresponding columns for + the given :class:`.ColumnElement`, if the given :class:`.ColumnElement` + is actually present within a sub-element + of this :class:`.FromClause`. Normally the column will match if + it merely shares a common ancestor with one of the exported + columns of this :class:`.FromClause`. + + """ + + def embedded(expanded_proxy_set, target_set): + for t in target_set.difference(expanded_proxy_set): + if not set(_expand_cloned([t]) + ).intersection(expanded_proxy_set): + return False + return True + + # don't dig around if the column is locally present + if self.c.contains_column(column): + return column + col, intersect = None, None + target_set = column.proxy_set + cols = self.c._all_columns + for c in cols: + expanded_proxy_set = set(_expand_cloned(c.proxy_set)) + i = target_set.intersection(expanded_proxy_set) + if i and (not require_embedded + or embedded(expanded_proxy_set, target_set)): + if col is None: + + # no corresponding column yet, pick this one. + + col, intersect = c, i + elif len(i) > len(intersect): + + # 'c' has a larger field of correspondence than + # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x + # matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + + col, intersect = c, i + elif i == intersect: + + # they have the same field of correspondence. see + # which proxy_set has fewer columns in it, which + # indicates a closer relationship with the root + # column. Also take into account the "weight" + # attribute which CompoundSelect() uses to give + # higher precedence to columns based on vertical + # position in the compound statement, and discard + # columns that have no reference to the target + # column (also occurs with CompoundSelect) + + col_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) for sc in + col.proxy_set if sc.shares_lineage(column)]) + c_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) for sc in + c.proxy_set if sc.shares_lineage(column)]) + if c_distance < col_distance: + col, intersect = c, i + return col + + @property + def description(self): + """a brief description of this FromClause. + + Used primarily for error message formatting. + + """ + return getattr(self, 'name', self.__class__.__name__ + " object") + + def _reset_exported(self): + """delete memoized collections when a FromClause is cloned.""" + + self._memoized_property.expire_instance(self) + + @_memoized_property + def columns(self): + """A named-based collection of :class:`.ColumnElement` objects + maintained by this :class:`.FromClause`. + + The :attr:`.columns`, or :attr:`.c` collection, is the gateway + to the construction of SQL expressions using table-bound or + other selectable-bound columns:: + + select([mytable]).where(mytable.c.somecolumn == 5) + + """ + + if '_columns' not in self.__dict__: + self._init_collections() + self._populate_column_collection() + return self._columns.as_immutable() + + @_memoized_property + def primary_key(self): + """Return the collection of Column objects which comprise the + primary key of this FromClause.""" + + self._init_collections() + self._populate_column_collection() + return self.primary_key + + @_memoized_property + def foreign_keys(self): + """Return the collection of ForeignKey objects which this + FromClause references.""" + + self._init_collections() + self._populate_column_collection() + return self.foreign_keys + + c = property(attrgetter('columns'), + doc="An alias for the :attr:`.columns` attribute.") + _select_iterable = property(attrgetter('columns')) + + def _init_collections(self): + assert '_columns' not in self.__dict__ + assert 'primary_key' not in self.__dict__ + assert 'foreign_keys' not in self.__dict__ + + self._columns = ColumnCollection() + self.primary_key = ColumnSet() + self.foreign_keys = set() + + @property + def _cols_populated(self): + return '_columns' in self.__dict__ + + def _populate_column_collection(self): + """Called on subclasses to establish the .c collection. + + Each implementation has a different way of establishing + this collection. + + """ + + def _refresh_for_new_column(self, column): + """Given a column added to the .c collection of an underlying + selectable, produce the local version of that column, assuming this + selectable ultimately should proxy this column. + + this is used to "ping" a derived selectable to add a new column + to its .c. collection when a Column has been added to one of the + Table objects it ultimtely derives from. + + If the given selectable hasn't populated it's .c. collection yet, + it should at least pass on the message to the contained selectables, + but it will return None. + + This method is currently used by Declarative to allow Table + columns to be added to a partially constructed inheritance + mapping that may have already produced joins. The method + isn't public right now, as the full span of implications + and/or caveats aren't yet clear. + + It's also possible that this functionality could be invoked by + default via an event, which would require that + selectables maintain a weak referencing collection of all + derivations. + + """ + if not self._cols_populated: + return None + elif column.key in self.columns and self.columns[column.key] is column: + return column + else: + return None + + +class Join(FromClause): + """represent a ``JOIN`` construct between two :class:`.FromClause` + elements. + + The public constructor function for :class:`.Join` is the module-level + :func:`.join()` function, as well as the :meth:`.FromClause.join` method + of any :class:`.FromClause` (e.g. such as :class:`.Table`). + + .. seealso:: + + :func:`.join` + + :meth:`.FromClause.join` + + """ + __visit_name__ = 'join' + + _is_join = True + + def __init__(self, left, right, onclause=None, isouter=False): + """Construct a new :class:`.Join`. + + The usual entrypoint here is the :func:`~.expression.join` + function or the :meth:`.FromClause.join` method of any + :class:`.FromClause` object. + + """ + self.left = _interpret_as_from(left) + self.right = _interpret_as_from(right).self_group() + + if onclause is None: + self.onclause = self._match_primaries(self.left, self.right) + else: + self.onclause = onclause + + self.isouter = isouter + + @classmethod + def _create_outerjoin(cls, left, right, onclause=None): + """Return an ``OUTER JOIN`` clause element. + + The returned object is an instance of :class:`.Join`. + + Similar functionality is also available via the + :meth:`~.FromClause.outerjoin()` method on any + :class:`.FromClause`. + + :param left: The left side of the join. + + :param right: The right side of the join. + + :param onclause: Optional criterion for the ``ON`` clause, is + derived from foreign key relationships established between + left and right otherwise. + + To chain joins together, use the :meth:`.FromClause.join` or + :meth:`.FromClause.outerjoin` methods on the resulting + :class:`.Join` object. + + """ + return cls(left, right, onclause, isouter=True) + + + @classmethod + def _create_join(cls, left, right, onclause=None, isouter=False): + """Produce a :class:`.Join` object, given two :class:`.FromClause` + expressions. + + E.g.:: + + j = join(user_table, address_table, user_table.c.id == address_table.c.user_id) + stmt = select([user_table]).select_from(j) + + would emit SQL along the lines of:: + + SELECT user.id, user.name FROM user + JOIN address ON user.id = address.user_id + + Similar functionality is available given any :class:`.FromClause` object + (e.g. such as a :class:`.Table`) using the :meth:`.FromClause.join` + method. + + :param left: The left side of the join. + + :param right: the right side of the join; this is any :class:`.FromClause` + object such as a :class:`.Table` object, and may also be a selectable-compatible + object such as an ORM-mapped class. + + :param onclause: a SQL expression representing the ON clause of the + join. If left at ``None``, :meth:`.FromClause.join` will attempt to + join the two tables based on a foreign key relationship. + + :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN. + + .. seealso:: + + :meth:`.FromClause.join` - method form, based on a given left side + + :class:`.Join` - the type of object produced + + """ + + return cls(left, right, onclause, isouter) + + + @property + def description(self): + return "Join object on %s(%d) and %s(%d)" % ( + self.left.description, + id(self.left), + self.right.description, + id(self.right)) + + def is_derived_from(self, fromclause): + return fromclause is self or \ + self.left.is_derived_from(fromclause) or \ + self.right.is_derived_from(fromclause) + + def self_group(self, against=None): + return FromGrouping(self) + + @util.dependencies("sqlalchemy.sql.util") + def _populate_column_collection(self, sqlutil): + columns = [c for c in self.left.columns] + \ + [c for c in self.right.columns] + + self.primary_key.extend(sqlutil.reduce_columns( + (c for c in columns if c.primary_key), self.onclause)) + self._columns.update((col._label, col) for col in columns) + self.foreign_keys.update(itertools.chain( + *[col.foreign_keys for col in columns])) + + def _refresh_for_new_column(self, column): + col = self.left._refresh_for_new_column(column) + if col is None: + col = self.right._refresh_for_new_column(column) + if col is not None: + if self._cols_populated: + self._columns[col._label] = col + self.foreign_keys.add(col) + if col.primary_key: + self.primary_key.add(col) + return col + return None + + def _copy_internals(self, clone=_clone, **kw): + self._reset_exported() + self.left = clone(self.left, **kw) + self.right = clone(self.right, **kw) + self.onclause = clone(self.onclause, **kw) + + def get_children(self, **kwargs): + return self.left, self.right, self.onclause + + def _match_primaries(self, left, right): + if isinstance(left, Join): + left_right = left.right + else: + left_right = None + return self._join_condition(left, right, a_subset=left_right) + + @classmethod + def _join_condition(cls, a, b, ignore_nonexistent_tables=False, + a_subset=None, + consider_as_foreign_keys=None): + """create a join condition between two tables or selectables. + + e.g.:: + + join_condition(tablea, tableb) + + would produce an expression along the lines of:: + + tablea.c.id==tableb.c.tablea_id + + The join is determined based on the foreign key relationships + between the two selectables. If there are multiple ways + to join, or no way to join, an error is raised. + + :param ignore_nonexistent_tables: Deprecated - this + flag is no longer used. Only resolution errors regarding + the two given tables are propagated. + + :param a_subset: An optional expression that is a sub-component + of ``a``. An attempt will be made to join to just this sub-component + first before looking at the full ``a`` construct, and if found + will be successful even if there are other ways to join to ``a``. + This allows the "right side" of a join to be passed thereby + providing a "natural join". + + """ + constraints = collections.defaultdict(list) + + for left in (a_subset, a): + if left is None: + continue + for fk in sorted( + b.foreign_keys, + key=lambda fk: fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue + try: + col = fk.get_referent(left) + except exc.NoReferenceError as nrte: + if nrte.table_name == left.name: + raise + else: + continue + + if col is not None: + constraints[fk.constraint].append((col, fk.parent)) + if left is not b: + for fk in sorted( + left.foreign_keys, + key=lambda fk: fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue + try: + col = fk.get_referent(b) + except exc.NoReferenceError as nrte: + if nrte.table_name == b.name: + raise + else: + # this is totally covered. can't get + # coverage to mark it. + continue + + if col is not None: + constraints[fk.constraint].append((col, fk.parent)) + if constraints: + break + + if len(constraints) > 1: + # more than one constraint matched. narrow down the list + # to include just those FKCs that match exactly to + # "consider_as_foreign_keys". + if consider_as_foreign_keys: + for const in list(constraints): + if set(f.parent for f in const.elements) != set(consider_as_foreign_keys): + del constraints[const] + + # if still multiple constraints, but + # they all refer to the exact same end result, use it. + if len(constraints) > 1: + dedupe = set(tuple(crit) for crit in constraints.values()) + if len(dedupe) == 1: + key = list(constraints)[0] + constraints = {key: constraints[key]} + + if len(constraints) != 1: + raise exc.AmbiguousForeignKeysError( + "Can't determine join between '%s' and '%s'; " + "tables have more than one foreign key " + "constraint relationship between them. " + "Please specify the 'onclause' of this " + "join explicitly." % (a.description, b.description)) + + if len(constraints) == 0: + if isinstance(b, FromGrouping): + hint = " Perhaps you meant to convert the right side to a "\ + "subquery using alias()?" + else: + hint = "" + raise exc.NoForeignKeysError( + "Can't find any foreign key relationships " + "between '%s' and '%s'.%s" % (a.description, b.description, hint)) + + crit = [(x == y) for x, y in list(constraints.values())[0]] + if len(crit) == 1: + return (crit[0]) + else: + return and_(*crit) + + + def select(self, whereclause=None, **kwargs): + """Create a :class:`.Select` from this :class:`.Join`. + + The equivalent long-hand form, given a :class:`.Join` object + ``j``, is:: + + from sqlalchemy import select + j = select([j.left, j.right], **kw).\\ + where(whereclause).\\ + select_from(j) + + :param whereclause: the WHERE criterion that will be sent to + the :func:`select()` function + + :param \**kwargs: all other kwargs are sent to the + underlying :func:`select()` function. + + """ + collist = [self.left, self.right] + + return Select(collist, whereclause, from_obj=[self], **kwargs) + + @property + def bind(self): + return self.left.bind or self.right.bind + + @util.dependencies("sqlalchemy.sql.util") + def alias(self, sqlutil, name=None, flat=False): + """return an alias of this :class:`.Join`. + + The default behavior here is to first produce a SELECT + construct from this :class:`.Join`, then to produce a + :class:`.Alias` from that. So given a join of the form:: + + j = table_a.join(table_b, table_a.c.id == table_b.c.a_id) + + The JOIN by itself would look like:: + + table_a JOIN table_b ON table_a.id = table_b.a_id + + Whereas the alias of the above, ``j.alias()``, would in a + SELECT context look like:: + + (SELECT table_a.id AS table_a_id, table_b.id AS table_b_id, + table_b.a_id AS table_b_a_id + FROM table_a + JOIN table_b ON table_a.id = table_b.a_id) AS anon_1 + + The equivalent long-hand form, given a :class:`.Join` object + ``j``, is:: + + from sqlalchemy import select, alias + j = alias( + select([j.left, j.right]).\\ + select_from(j).\\ + with_labels(True).\\ + correlate(False), + name=name + ) + + The selectable produced by :meth:`.Join.alias` features the same + columns as that of the two individual selectables presented under + a single name - the individual columns are "auto-labeled", meaning + the ``.c.`` collection of the resulting :class:`.Alias` represents + the names of the individual columns using a ``_`` + scheme:: + + j.c.table_a_id + j.c.table_b_a_id + + :meth:`.Join.alias` also features an alternate + option for aliasing joins which produces no enclosing SELECT and + does not normally apply labels to the column names. The + ``flat=True`` option will call :meth:`.FromClause.alias` + against the left and right sides individually. + Using this option, no new ``SELECT`` is produced; + we instead, from a construct as below:: + + j = table_a.join(table_b, table_a.c.id == table_b.c.a_id) + j = j.alias(flat=True) + + we get a result like this:: + + table_a AS table_a_1 JOIN table_b AS table_b_1 ON + table_a_1.id = table_b_1.a_id + + The ``flat=True`` argument is also propagated to the contained + selectables, so that a composite join such as:: + + j = table_a.join( + table_b.join(table_c, + table_b.c.id == table_c.c.b_id), + table_b.c.a_id == table_a.c.id + ).alias(flat=True) + + Will produce an expression like:: + + table_a AS table_a_1 JOIN ( + table_b AS table_b_1 JOIN table_c AS table_c_1 + ON table_b_1.id = table_c_1.b_id + ) ON table_a_1.id = table_b_1.a_id + + The standalone :func:`~.expression.alias` function as well as the + base :meth:`.FromClause.alias` method also support the ``flat=True`` + argument as a no-op, so that the argument can be passed to the + ``alias()`` method of any selectable. + + .. versionadded:: 0.9.0 Added the ``flat=True`` option to create + "aliases" of joins without enclosing inside of a SELECT + subquery. + + :param name: name given to the alias. + + :param flat: if True, produce an alias of the left and right + sides of this :class:`.Join` and return the join of those + two selectables. This produces join expression that does not + include an enclosing SELECT. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :func:`~.expression.alias` + + """ + if flat: + assert name is None, "Can't send name argument with flat" + left_a, right_a = self.left.alias(flat=True), \ + self.right.alias(flat=True) + adapter = sqlutil.ClauseAdapter(left_a).\ + chain(sqlutil.ClauseAdapter(right_a)) + + return left_a.join(right_a, + adapter.traverse(self.onclause), isouter=self.isouter) + else: + return self.select(use_labels=True, correlate=False).alias(name) + + @property + def _hide_froms(self): + return itertools.chain(*[_from_objects(x.left, x.right) + for x in self._cloned_set]) + + @property + def _from_objects(self): + return [self] + \ + self.onclause._from_objects + \ + self.left._from_objects + \ + self.right._from_objects + + +class Alias(FromClause): + """Represents an table or selectable alias (AS). + + Represents an alias, as typically applied to any table or + sub-select within a SQL statement using the ``AS`` keyword (or + without the keyword on certain databases such as Oracle). + + This object is constructed from the :func:`~.expression.alias` module level + function as well as the :meth:`.FromClause.alias` method available on all + :class:`.FromClause` subclasses. + + """ + + __visit_name__ = 'alias' + named_with_column = True + + _is_from_container = True + + def __init__(self, selectable, name=None): + baseselectable = selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.element + self.original = baseselectable + self.supports_execution = baseselectable.supports_execution + if self.supports_execution: + self._execution_options = baseselectable._execution_options + self.element = selectable + if name is None: + if self.original.named_with_column: + name = getattr(self.original, 'name', None) + name = _anonymous_label('%%(%d %s)s' % (id(self), name + or 'anon')) + self.name = name + + + @property + def description(self): + if util.py3k: + return self.name + else: + return self.name.encode('ascii', 'backslashreplace') + + def as_scalar(self): + try: + return self.element.as_scalar() + except AttributeError: + raise AttributeError("Element %s does not support " + "'as_scalar()'" % self.element) + + def is_derived_from(self, fromclause): + if fromclause in self._cloned_set: + return True + return self.element.is_derived_from(fromclause) + + def _populate_column_collection(self): + for col in self.element.columns._all_columns: + col._make_proxy(self) + + def _refresh_for_new_column(self, column): + col = self.element._refresh_for_new_column(column) + if col is not None: + if not self._cols_populated: + return None + else: + return col._make_proxy(self) + else: + return None + + def _copy_internals(self, clone=_clone, **kw): + # don't apply anything to an aliased Table + # for now. May want to drive this from + # the given **kw. + if isinstance(self.element, TableClause): + return + self._reset_exported() + self.element = clone(self.element, **kw) + baseselectable = self.element + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.element + self.original = baseselectable + + def get_children(self, column_collections=True, **kw): + if column_collections: + for c in self.c: + yield c + yield self.element + + @property + def _from_objects(self): + return [self] + + @property + def bind(self): + return self.element.bind + + +class CTE(Alias): + """Represent a Common Table Expression. + + The :class:`.CTE` object is obtained using the + :meth:`.SelectBase.cte` method from any selectable. + See that method for complete examples. + + .. versionadded:: 0.7.6 + + """ + __visit_name__ = 'cte' + + def __init__(self, selectable, + name=None, + recursive=False, + _cte_alias=None, + _restates=frozenset()): + self.recursive = recursive + self._cte_alias = _cte_alias + self._restates = _restates + super(CTE, self).__init__(selectable, name=name) + + def alias(self, name=None, flat=False): + return CTE( + self.original, + name=name, + recursive=self.recursive, + _cte_alias=self, + ) + + def union(self, other): + return CTE( + self.original.union(other), + name=self.name, + recursive=self.recursive, + _restates=self._restates.union([self]) + ) + + def union_all(self, other): + return CTE( + self.original.union_all(other), + name=self.name, + recursive=self.recursive, + _restates=self._restates.union([self]) + ) + + + + +class FromGrouping(FromClause): + """Represent a grouping of a FROM clause""" + __visit_name__ = 'grouping' + + def __init__(self, element): + self.element = element + + def _init_collections(self): + pass + + @property + def columns(self): + return self.element.columns + + @property + def primary_key(self): + return self.element.primary_key + + @property + def foreign_keys(self): + return self.element.foreign_keys + + def is_derived_from(self, element): + return self.element.is_derived_from(element) + + def alias(self, **kw): + return FromGrouping(self.element.alias(**kw)) + + @property + def _hide_froms(self): + return self.element._hide_froms + + def get_children(self, **kwargs): + return self.element, + + def _copy_internals(self, clone=_clone, **kw): + self.element = clone(self.element, **kw) + + @property + def _from_objects(self): + return self.element._from_objects + + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {'element': self.element} + + def __setstate__(self, state): + self.element = state['element'] + +class TableClause(Immutable, FromClause): + """Represents a minimal "table" construct. + + This is a lightweight table object that has only a name and a + collection of columns, which are typically produced + by the :func:`.expression.column` function:: + + from sqlalchemy.sql import table, column + + user = table("user", + column("id"), + column("name"), + column("description"), + ) + + The :class:`.TableClause` construct serves as the base for + the more commonly used :class:`~.schema.Table` object, providing + the usual set of :class:`~.expression.FromClause` services including + the ``.c.`` collection and statement generation methods. + + It does **not** provide all the additional schema-level services + of :class:`~.schema.Table`, including constraints, references to other + tables, or support for :class:`.MetaData`-level services. It's useful + on its own as an ad-hoc construct used to generate quick SQL + statements when a more fully fledged :class:`~.schema.Table` + is not on hand. + + """ + + __visit_name__ = 'table' + + named_with_column = True + + implicit_returning = False + """:class:`.TableClause` doesn't support having a primary key or column + -level defaults, so implicit returning doesn't apply.""" + + _autoincrement_column = None + """No PK or default support so no autoincrement column.""" + + def __init__(self, name, *columns): + """Produce a new :class:`.TableClause`. + + The object returned is an instance of :class:`.TableClause`, which + represents the "syntactical" portion of the schema-level + :class:`~.schema.Table` object. + It may be used to construct lightweight table constructs. + + Note that the :func:`.expression.table` function is not part of + the ``sqlalchemy`` namespace. It must be imported from the + ``sql`` package:: + + from sqlalchemy.sql import table, column + + :param name: Name of the table. + + :param columns: A collection of :func:`.expression.column` constructs. + + """ + + super(TableClause, self).__init__() + self.name = self.fullname = name + self._columns = ColumnCollection() + self.primary_key = ColumnSet() + self.foreign_keys = set() + for c in columns: + self.append_column(c) + + def _init_collections(self): + pass + + @util.memoized_property + def description(self): + if util.py3k: + return self.name + else: + return self.name.encode('ascii', 'backslashreplace') + + def append_column(self, c): + self._columns[c.key] = c + c.table = self + + def get_children(self, column_collections=True, **kwargs): + if column_collections: + return [c for c in self.c] + else: + return [] + + @util.dependencies("sqlalchemy.sql.functions") + def count(self, functions, whereclause=None, **params): + """return a SELECT COUNT generated against this + :class:`.TableClause`.""" + + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return Select( + [functions.func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) + + @util.dependencies("sqlalchemy.sql.dml") + def insert(self, dml, values=None, inline=False, **kwargs): + """Generate an :func:`.insert` construct against this + :class:`.TableClause`. + + E.g.:: + + table.insert().values(name='foo') + + See :func:`.insert` for argument and usage information. + + """ + + return dml.Insert(self, values=values, inline=inline, **kwargs) + + @util.dependencies("sqlalchemy.sql.dml") + def update(self, dml, whereclause=None, values=None, inline=False, **kwargs): + """Generate an :func:`.update` construct against this + :class:`.TableClause`. + + E.g.:: + + table.update().where(table.c.id==7).values(name='foo') + + See :func:`.update` for argument and usage information. + + """ + + return dml.Update(self, whereclause=whereclause, + values=values, inline=inline, **kwargs) + + @util.dependencies("sqlalchemy.sql.dml") + def delete(self, dml, whereclause=None, **kwargs): + """Generate a :func:`.delete` construct against this + :class:`.TableClause`. + + E.g.:: + + table.delete().where(table.c.id==7) + + See :func:`.delete` for argument and usage information. + + """ + + return dml.Delete(self, whereclause, **kwargs) + + @property + def _from_objects(self): + return [self] + + +class ForUpdateArg(ClauseElement): + + @classmethod + def parse_legacy_select(self, arg): + """Parse the for_update arugment of :func:`.select`. + + :param mode: Defines the lockmode to use. + + ``None`` - translates to no lockmode + + ``'update'`` - translates to ``FOR UPDATE`` + (standard SQL, supported by most dialects) + + ``'nowait'`` - translates to ``FOR UPDATE NOWAIT`` + (supported by Oracle, PostgreSQL 8.1 upwards) + + ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL), + and ``FOR SHARE`` (for PostgreSQL) + + ``'read_nowait'`` - translates to ``FOR SHARE NOWAIT`` + (supported by PostgreSQL). ``FOR SHARE`` and + ``FOR SHARE NOWAIT`` (PostgreSQL). + + """ + if arg in (None, False): + return None + + nowait = read = False + if arg == 'nowait': + nowait = True + elif arg == 'read': + read = True + elif arg == 'read_nowait': + read = nowait = True + elif arg is not True: + raise exc.ArgumentError("Unknown for_update argument: %r" % arg) + + return ForUpdateArg(read=read, nowait=nowait) + + @property + def legacy_for_update_value(self): + if self.read and not self.nowait: + return "read" + elif self.read and self.nowait: + return "read_nowait" + elif self.nowait: + return "nowait" + else: + return True + + def _copy_internals(self, clone=_clone, **kw): + if self.of is not None: + self.of = [clone(col, **kw) for col in self.of] + + def __init__(self, nowait=False, read=False, of=None): + """Represents arguments specified to :meth:`.Select.for_update`. + + .. versionadded:: 0.9.0 + """ + + self.nowait = nowait + self.read = read + if of is not None: + self.of = [_interpret_as_column_or_from(elem) + for elem in util.to_list(of)] + else: + self.of = None + + +class SelectBase(Executable, FromClause): + """Base class for SELECT statements. + + + This includes :class:`.Select`, :class:`.CompoundSelect` and + :class:`.TextAsFrom`. + + + """ + + def as_scalar(self): + """return a 'scalar' representation of this selectable, which can be + used as a column expression. + + Typically, a select statement which has only one column in its columns + clause is eligible to be used as a scalar expression. + + The returned object is an instance of + :class:`ScalarSelect`. + + """ + return ScalarSelect(self) + + + def label(self, name): + """return a 'scalar' representation of this selectable, embedded as a + subquery with a label. + + .. seealso:: + + :meth:`~.SelectBase.as_scalar`. + + """ + return self.as_scalar().label(name) + + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + .. versionadded:: 0.7.6 + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples illustrate two examples from + Postgresql's documentation at + http://www.postgresql.org/docs/8.4/static/queries-with.html. + + Example 1, non recursive:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ]).\ + select_from(included_parts.join(parts, + included_parts.c.part==parts.c.part)).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + + .. seealso:: + + :meth:`.orm.query.Query.cte` - ORM version of :meth:`.SelectBase.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + + @_generative + @util.deprecated('0.6', + message="``autocommit()`` is deprecated. Use " + ":meth:`.Executable.execution_options` with the " + "'autocommit' flag.") + def autocommit(self): + """return a new selectable with the 'autocommit' flag set to + True. + """ + + self._execution_options = \ + self._execution_options.union({'autocommit': True}) + + def _generate(self): + """Override the default _generate() method to also clear out + exported collections.""" + + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + s._reset_exported() + return s + + @property + def _from_objects(self): + return [self] + +class GenerativeSelect(SelectBase): + """Base class for SELECT statements where additional elements can be + added. + + This serves as the base for :class:`.Select` and :class:`.CompoundSelect` + where elements such as ORDER BY, GROUP BY can be added and column rendering + can be controlled. Compare to :class:`.TextAsFrom`, which, while it + subclasses :class:`.SelectBase` and is also a SELECT construct, represents + a fixed textual string which cannot be altered at this level, only + wrapped as a subquery. + + .. versionadded:: 0.9.0 :class:`.GenerativeSelect` was added to + provide functionality specific to :class:`.Select` and :class:`.CompoundSelect` + while allowing :class:`.SelectBase` to be used for other SELECT-like + objects, e.g. :class:`.TextAsFrom`. + + """ + _order_by_clause = ClauseList() + _group_by_clause = ClauseList() + _limit = None + _offset = None + _for_update_arg = None + + def __init__(self, + use_labels=False, + for_update=False, + limit=None, + offset=None, + order_by=None, + group_by=None, + bind=None, + autocommit=None): + self.use_labels = use_labels + + if for_update is not False: + self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update) + + if autocommit is not None: + util.warn_deprecated('autocommit on select() is ' + 'deprecated. Use .execution_options(a' + 'utocommit=True)') + self._execution_options = \ + self._execution_options.union( + {'autocommit': autocommit}) + if limit is not None: + self._limit = util.asint(limit) + if offset is not None: + self._offset = util.asint(offset) + self._bind = bind + + if order_by is not None: + self._order_by_clause = ClauseList(*util.to_list(order_by)) + if group_by is not None: + self._group_by_clause = ClauseList(*util.to_list(group_by)) + + @property + def for_update(self): + """Provide legacy dialect support for the ``for_update`` attribute. + """ + if self._for_update_arg is not None: + return self._for_update_arg.legacy_for_update_value + else: + return None + + @for_update.setter + def for_update(self, value): + self._for_update_arg = ForUpdateArg.parse_legacy_select(value) + + @_generative + def with_for_update(self, nowait=False, read=False, of=None): + """Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`. + + E.g.:: + + stmt = select([table]).with_for_update(nowait=True) + + On a database like Postgresql or Oracle, the above would render a + statement like:: + + SELECT table.a, table.b FROM table FOR UPDATE NOWAIT + + on other backends, the ``nowait`` option is ignored and instead + would produce:: + + SELECT table.a, table.b FROM table FOR UPDATE + + When called with no arguments, the statement will render with + the suffix ``FOR UPDATE``. Additional arguments can then be + provided which allow for common database-specific + variants. + + :param nowait: boolean; will render ``FOR UPDATE NOWAIT`` on Oracle and + Postgresql dialects. + + :param read: boolean; will render ``LOCK IN SHARE MODE`` on MySQL, + ``FOR SHARE`` on Postgresql. On Postgresql, when combined with + ``nowait``, will render ``FOR SHARE NOWAIT``. + + :param of: SQL expression or list of SQL expression elements + (typically :class:`.Column` objects or a compatible expression) which + will render into a ``FOR UPDATE OF`` clause; supported by PostgreSQL + and Oracle. May render as a table or as a column depending on + backend. + + .. versionadded:: 0.9.0 + + """ + self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of) + + @_generative + def apply_labels(self): + """return a new selectable with the 'use_labels' flag set to True. + + This will result in column expressions being generated using labels + against their table name, such as "SELECT somecolumn AS + tablename_somecolumn". This allows selectables which contain multiple + FROM clauses to produce a unique set of column names regardless of + name conflicts among the individual FROM clauses. + + """ + self.use_labels = True + + @_generative + def limit(self, limit): + """return a new selectable with the given LIMIT criterion + applied.""" + + self._limit = util.asint(limit) + + @_generative + def offset(self, offset): + """return a new selectable with the given OFFSET criterion + applied.""" + + self._offset = util.asint(offset) + + @_generative + def order_by(self, *clauses): + """return a new selectable with the given list of ORDER BY + criterion applied. + + The criterion will be appended to any pre-existing ORDER BY + criterion. + + """ + + self.append_order_by(*clauses) + + @_generative + def group_by(self, *clauses): + """return a new selectable with the given list of GROUP BY + criterion applied. + + The criterion will be appended to any pre-existing GROUP BY + criterion. + + """ + + self.append_group_by(*clauses) + + def append_order_by(self, *clauses): + """Append the given ORDER BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing ORDER BY criterion. + + This is an **in-place** mutation method; the + :meth:`~.GenerativeSelect.order_by` method is preferred, as it provides standard + :term:`method chaining`. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._order_by_clause = ClauseList() + else: + if getattr(self, '_order_by_clause', None) is not None: + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + + def append_group_by(self, *clauses): + """Append the given GROUP BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing GROUP BY criterion. + + This is an **in-place** mutation method; the + :meth:`~.GenerativeSelect.group_by` method is preferred, as it provides standard + :term:`method chaining`. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None) is not None: + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + + +class CompoundSelect(GenerativeSelect): + """Forms the basis of ``UNION``, ``UNION ALL``, and other + SELECT-based set operations. + + + .. seealso:: + + :func:`.union` + + :func:`.union_all` + + :func:`.intersect` + + :func:`.intersect_all` + + :func:`.except` + + :func:`.except_all` + + """ + + __visit_name__ = 'compound_select' + + UNION = util.symbol('UNION') + UNION_ALL = util.symbol('UNION ALL') + EXCEPT = util.symbol('EXCEPT') + EXCEPT_ALL = util.symbol('EXCEPT ALL') + INTERSECT = util.symbol('INTERSECT') + INTERSECT_ALL = util.symbol('INTERSECT ALL') + + _is_from_container = True + + def __init__(self, keyword, *selects, **kwargs): + self._auto_correlate = kwargs.pop('correlate', False) + self.keyword = keyword + self.selects = [] + + numcols = None + + # some DBs do not like ORDER BY in the inner queries of a UNION, etc. + for n, s in enumerate(selects): + s = _clause_element_as_expr(s) + + if not numcols: + numcols = len(s.c._all_columns) + elif len(s.c._all_columns) != numcols: + raise exc.ArgumentError('All selectables passed to ' + 'CompoundSelect must have identical numbers of ' + 'columns; select #%d has %d columns, select ' + '#%d has %d' % (1, len(self.selects[0].c._all_columns), n + + 1, len(s.c._all_columns))) + + self.selects.append(s.self_group(self)) + + GenerativeSelect.__init__(self, **kwargs) + + @classmethod + def _create_union(cls, *selects, **kwargs): + """Return a ``UNION`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + A similar :func:`union()` method is available on all + :class:`.FromClause` subclasses. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + + @classmethod + def _create_union_all(cls, *selects, **kwargs): + """Return a ``UNION ALL`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + A similar :func:`union_all()` method is available on all + :class:`.FromClause` subclasses. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs) + + + @classmethod + def _create_except(cls, *selects, **kwargs): + """Return an ``EXCEPT`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs) + + + @classmethod + def _create_except_all(cls, *selects, **kwargs): + """Return an ``EXCEPT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs) + + + @classmethod + def _create_intersect(cls, *selects, **kwargs): + """Return an ``INTERSECT`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs) + + + @classmethod + def _create_intersect_all(cls, *selects, **kwargs): + """Return an ``INTERSECT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`.CompoundSelect`. + + \*selects + a list of :class:`.Select` instances. + + \**kwargs + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) + + + def _scalar_type(self): + return self.selects[0]._scalar_type() + + def self_group(self, against=None): + return FromGrouping(self) + + def is_derived_from(self, fromclause): + for s in self.selects: + if s.is_derived_from(fromclause): + return True + return False + + def _populate_column_collection(self): + for cols in zip(*[s.c._all_columns for s in self.selects]): + + # this is a slightly hacky thing - the union exports a + # column that resembles just that of the *first* selectable. + # to get at a "composite" column, particularly foreign keys, + # you have to dig through the proxies collection which we + # generate below. We may want to improve upon this, such as + # perhaps _make_proxy can accept a list of other columns + # that are "shared" - schema.column can then copy all the + # ForeignKeys in. this would allow the union() to have all + # those fks too. + + proxy = cols[0]._make_proxy(self, + name=cols[0]._label if self.use_labels else None, + key=cols[0]._key_label if self.use_labels else None) + + # hand-construct the "_proxies" collection to include all + # derived columns place a 'weight' annotation corresponding + # to how low in the list of select()s the column occurs, so + # that the corresponding_column() operation can resolve + # conflicts + + proxy._proxies = [c._annotate({'weight': i + 1}) for (i, + c) in enumerate(cols)] + + def _refresh_for_new_column(self, column): + for s in self.selects: + s._refresh_for_new_column(column) + + if not self._cols_populated: + return None + + raise NotImplementedError("CompoundSelect constructs don't support " + "addition of columns to underlying selectables") + + def _copy_internals(self, clone=_clone, **kw): + self._reset_exported() + self.selects = [clone(s, **kw) for s in self.selects] + if hasattr(self, '_col_map'): + del self._col_map + for attr in ('_order_by_clause', '_group_by_clause', '_for_update_arg'): + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr), **kw)) + + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.c) or []) \ + + [self._order_by_clause, self._group_by_clause] \ + + list(self.selects) + + def bind(self): + if self._bind: + return self._bind + for s in self.selects: + e = s.bind + if e: + return e + else: + return None + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + +class HasPrefixes(object): + _prefixes = () + + @_generative + def prefix_with(self, *expr, **kw): + """Add one or more expressions following the statement keyword, i.e. + SELECT, INSERT, UPDATE, or DELETE. Generative. + + This is used to support backend-specific prefix keywords such as those + provided by MySQL. + + E.g.:: + + stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") + + Multiple prefixes can be specified by multiple calls + to :meth:`.prefix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the INSERT, UPDATE, or DELETE + keyword. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this prefix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_prefixes(expr, dialect) + + def _setup_prefixes(self, prefixes, dialect=None): + self._prefixes = self._prefixes + tuple( + [(_literal_as_text(p), dialect) for p in prefixes]) + + + +class Select(HasPrefixes, GenerativeSelect): + """Represents a ``SELECT`` statement. + + """ + + __visit_name__ = 'select' + + _prefixes = () + _hints = util.immutabledict() + _distinct = False + _from_cloned = None + _correlate = () + _correlate_except = None + _memoized_property = SelectBase._memoized_property + _is_select = True + + def __init__(self, + columns=None, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + **kwargs): + """Construct a new :class:`.Select`. + + Similar functionality is also available via the :meth:`.FromClause.select` + method on any :class:`.FromClause`. + + All arguments which accept :class:`.ClauseElement` arguments also accept + string arguments, which will be converted as appropriate into + either :func:`text()` or :func:`literal_column()` constructs. + + .. seealso:: + + :ref:`coretutorial_selecting` - Core Tutorial description of + :func:`.select`. + + :param columns: + A list of :class:`.ClauseElement` objects, typically + :class:`.ColumnElement` objects or subclasses, which will form the + columns clause of the resulting statement. For all members which are + instances of :class:`.Selectable`, the individual :class:`.ColumnElement` + members of the :class:`.Selectable` will be added individually to the + columns clause. For example, specifying a + :class:`~sqlalchemy.schema.Table` instance will result in all the + contained :class:`~sqlalchemy.schema.Column` objects within to be added + to the columns clause. + + This argument is not present on the form of :func:`select()` + available on :class:`~sqlalchemy.schema.Table`. + + :param whereclause: + A :class:`.ClauseElement` expression which will be used to form the + ``WHERE`` clause. + + :param from_obj: + A list of :class:`.ClauseElement` objects which will be added to the + ``FROM`` clause of the resulting statement. Note that "from" objects are + automatically located within the columns and whereclause ClauseElements. + Use this parameter to explicitly specify "from" objects which are not + automatically locatable. This could include + :class:`~sqlalchemy.schema.Table` objects that aren't otherwise present, + or :class:`.Join` objects whose presence will supercede that of the + :class:`~sqlalchemy.schema.Table` objects already located in the other + clauses. + + :param autocommit: + Deprecated. Use .execution_options(autocommit=) + to set the autocommit option. + + :param bind=None: + an :class:`~.Engine` or :class:`~.Connection` instance + to which the + resulting :class:`.Select` object will be bound. The :class:`.Select` + object will otherwise automatically bind to whatever + :class:`~.base.Connectable` instances can be located within its contained + :class:`.ClauseElement` members. + + :param correlate=True: + indicates that this :class:`.Select` object should have its + contained :class:`.FromClause` elements "correlated" to an enclosing + :class:`.Select` object. This means that any :class:`.ClauseElement` + instance within the "froms" collection of this :class:`.Select` + which is also present in the "froms" collection of an + enclosing select will not be rendered in the ``FROM`` clause + of this select statement. + + :param distinct=False: + when ``True``, applies a ``DISTINCT`` qualifier to the columns + clause of the resulting statement. + + The boolean argument may also be a column expression or list + of column expressions - this is a special calling form which + is understood by the Postgresql dialect to render the + ``DISTINCT ON ()`` syntax. + + ``distinct`` is also available via the :meth:`~.Select.distinct` + generative method. + + :param for_update=False: + when ``True``, applies ``FOR UPDATE`` to the end of the + resulting statement. + + .. deprecated:: 0.9.0 - use :meth:`.GenerativeSelect.with_for_update` + to specify the structure of the ``FOR UPDATE`` clause. + + ``for_update`` accepts various string values interpreted by + specific backends, including: + + * ``"read"`` - on MySQL, translates to ``LOCK IN SHARE MODE``; + on Postgresql, translates to ``FOR SHARE``. + * ``"nowait"`` - on Postgresql and Oracle, translates to + ``FOR UPDATE NOWAIT``. + * ``"read_nowait"`` - on Postgresql, translates to + ``FOR SHARE NOWAIT``. + + .. seealso:: + + :meth:`.GenerativeSelect.with_for_update` - improved API for + specifying the ``FOR UPDATE`` clause. + + :param group_by: + a list of :class:`.ClauseElement` objects which will comprise the + ``GROUP BY`` clause of the resulting select. + + :param having: + a :class:`.ClauseElement` that will comprise the ``HAVING`` clause + of the resulting select when ``GROUP BY`` is used. + + :param limit=None: + a numerical value which usually compiles to a ``LIMIT`` + expression in the resulting select. Databases that don't + support ``LIMIT`` will attempt to provide similar + functionality. + + :param offset=None: + a numeric value which usually compiles to an ``OFFSET`` + expression in the resulting select. Databases that don't + support ``OFFSET`` will attempt to provide similar + functionality. + + :param order_by: + a scalar or list of :class:`.ClauseElement` objects which will + comprise the ``ORDER BY`` clause of the resulting select. + + :param use_labels=False: + when ``True``, the statement will be generated using labels + for each column in the columns clause, which qualify each + column with its parent table's (or aliases) name so that name + conflicts between columns in different tables don't occur. + The format of the label is _. The "c" + collection of the resulting :class:`.Select` object will use these + names as well for targeting column members. + + use_labels is also available via the :meth:`~.GenerativeSelect.apply_labels` + generative method. + + """ + self._auto_correlate = correlate + if distinct is not False: + if distinct is True: + self._distinct = True + else: + self._distinct = [ + _literal_as_text(e) + for e in util.to_list(distinct) + ] + + if from_obj is not None: + self._from_obj = util.OrderedSet( + _interpret_as_from(f) + for f in util.to_list(from_obj)) + else: + self._from_obj = util.OrderedSet() + + try: + cols_present = bool(columns) + except TypeError: + raise exc.ArgumentError("columns argument to select() must " + "be a Python list or other iterable") + + if cols_present: + self._raw_columns = [] + for c in columns: + c = _interpret_as_column_or_from(c) + if isinstance(c, ScalarSelect): + c = c.self_group(against=operators.comma_op) + self._raw_columns.append(c) + else: + self._raw_columns = [] + + if whereclause is not None: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + + if having is not None: + self._having = _literal_as_text(having) + else: + self._having = None + + if prefixes: + self._setup_prefixes(prefixes) + + GenerativeSelect.__init__(self, **kwargs) + + @property + def _froms(self): + # would love to cache this, + # but there's just enough edge cases, particularly now that + # declarative encourages construction of SQL expressions + # without tables present, to just regen this each time. + froms = [] + seen = set() + translate = self._from_cloned + + def add(items): + for item in items: + if item is self: + raise exc.InvalidRequestError( + "select() construct refers to itself as a FROM") + if translate and item in translate: + item = translate[item] + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + add(_from_objects(*self._raw_columns)) + if self._whereclause is not None: + add(_from_objects(self._whereclause)) + add(self._from_obj) + + return froms + + def _get_display_froms(self, explicit_correlate_froms=None, + implicit_correlate_froms=None): + """Return the full list of 'from' clauses to be displayed. + + Takes into account a set of existing froms which may be + rendered in the FROM clause of enclosing selects; this Select + may want to leave those absent if it is automatically + correlating. + + """ + froms = self._froms + + toremove = set(itertools.chain(*[ + _expand_cloned(f._hide_froms) + for f in froms])) + if toremove: + # if we're maintaining clones of froms, + # add the copies out to the toremove list. only include + # clones that are lexical equivalents. + if self._from_cloned: + toremove.update( + self._from_cloned[f] for f in + toremove.intersection(self._from_cloned) + if self._from_cloned[f]._is_lexical_equivalent(f) + ) + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + + if self._correlate: + to_correlate = self._correlate + if to_correlate: + froms = [ + f for f in froms if f not in + _cloned_intersection( + _cloned_intersection(froms, explicit_correlate_froms or ()), + to_correlate + ) + ] + + if self._correlate_except is not None: + + froms = [ + f for f in froms if f not in + _cloned_difference( + _cloned_intersection(froms, explicit_correlate_froms or ()), + self._correlate_except + ) + ] + + if self._auto_correlate and \ + implicit_correlate_froms and \ + len(froms) > 1: + + froms = [ + f for f in froms if f not in + _cloned_intersection(froms, implicit_correlate_froms) + ] + + if not len(froms): + raise exc.InvalidRequestError("Select statement '%s" + "' returned no FROM clauses due to " + "auto-correlation; specify " + "correlate() to control " + "correlation manually." % self) + + return froms + + def _scalar_type(self): + elem = self._raw_columns[0] + cols = list(elem._select_iterable) + return cols[0].type + + @property + def froms(self): + """Return the displayed list of FromClause elements.""" + + return self._get_display_froms() + + @_generative + def with_hint(self, selectable, text, dialect_name='*'): + """Add an indexing hint for the given selectable to this + :class:`.Select`. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the given :class:`.Table` or :class:`.Alias` passed as the + ``selectable`` argument. The dialect implementation + typically uses Python string substitution syntax + with the token ``%(name)s`` to render the name of + the table or alias. E.g. when using Oracle, the + following:: + + select([mytable]).\\ + with_hint(mytable, "+ index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add hints for both Oracle + and Sybase simultaneously:: + + select([mytable]).\\ + with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + """ + self._hints = self._hints.union( + {(selectable, dialect_name): text}) + + @property + def type(self): + raise exc.InvalidRequestError("Select objects don't have a type. " + "Call as_scalar() on this Select object " + "to return a 'scalar' version of this Select.") + + @_memoized_property.method + def locate_all_froms(self): + """return a Set of all FromClause elements referenced by this Select. + + This set is a superset of that returned by the ``froms`` property, + which is specifically for those FromClause elements that would + actually be rendered. + + """ + froms = self._froms + return froms + list(_from_objects(*froms)) + + @property + def inner_columns(self): + """an iterator of all ColumnElement expressions which would + be rendered into the columns clause of the resulting SELECT statement. + + """ + return _select_iterables(self._raw_columns) + + def is_derived_from(self, fromclause): + if self in fromclause._cloned_set: + return True + + for f in self.locate_all_froms(): + if f.is_derived_from(fromclause): + return True + return False + + def _copy_internals(self, clone=_clone, **kw): + + # Select() object has been cloned and probably adapted by the + # given clone function. Apply the cloning function to internal + # objects + + # 1. keep a dictionary of the froms we've cloned, and what + # they've become. This is consulted later when we derive + # additional froms from "whereclause" and the columns clause, + # which may still reference the uncloned parent table. + # as of 0.7.4 we also put the current version of _froms, which + # gets cleared on each generation. previously we were "baking" + # _froms into self._from_obj. + self._from_cloned = from_cloned = dict((f, clone(f, **kw)) + for f in self._from_obj.union(self._froms)) + + # 3. update persistent _from_obj with the cloned versions. + self._from_obj = util.OrderedSet(from_cloned[f] for f in + self._from_obj) + + # the _correlate collection is done separately, what can happen + # here is the same item is _correlate as in _from_obj but the + # _correlate version has an annotation on it - (specifically + # RelationshipProperty.Comparator._criterion_exists() does + # this). Also keep _correlate liberally open with it's previous + # contents, as this set is used for matching, not rendering. + self._correlate = set(clone(f) for f in + self._correlate).union(self._correlate) + + # 4. clone other things. The difficulty here is that Column + # objects are not actually cloned, and refer to their original + # .table, resulting in the wrong "from" parent after a clone + # operation. Hence _from_cloned and _from_obj supercede what is + # present here. + self._raw_columns = [clone(c, **kw) for c in self._raw_columns] + for attr in '_whereclause', '_having', '_order_by_clause', \ + '_group_by_clause', '_for_update_arg': + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr), **kw)) + + # erase exported column list, _froms collection, + # etc. + self._reset_exported() + + def get_children(self, column_collections=True, **kwargs): + """return child elements as per the ClauseElement specification.""" + + return (column_collections and list(self.columns) or []) + \ + self._raw_columns + list(self._froms) + \ + [x for x in + (self._whereclause, self._having, + self._order_by_clause, self._group_by_clause) + if x is not None] + + @_generative + def column(self, column): + """return a new select() construct with the given column expression + added to its columns clause. + + """ + self.append_column(column) + + @util.dependencies("sqlalchemy.sql.util") + def reduce_columns(self, sqlutil, only_synonyms=True): + """Return a new :func`.select` construct with redundantly + named, equivalently-valued columns removed from the columns clause. + + "Redundant" here means two columns where one refers to the + other either based on foreign key, or via a simple equality + comparison in the WHERE clause of the statement. The primary purpose + of this method is to automatically construct a select statement + with all uniquely-named columns, without the need to use + table-qualified labels as :meth:`.apply_labels` does. + + When columns are omitted based on foreign key, the referred-to + column is the one that's kept. When columns are omitted based on + WHERE eqivalence, the first column in the columns clause is the + one that's kept. + + :param only_synonyms: when True, limit the removal of columns + to those which have the same name as the equivalent. Otherwise, + all columns that are equivalent to another are removed. + + .. versionadded:: 0.8 + + """ + return self.with_only_columns( + sqlutil.reduce_columns( + self.inner_columns, + only_synonyms=only_synonyms, + *(self._whereclause, ) + tuple(self._from_obj) + ) + ) + + @_generative + def with_only_columns(self, columns): + """Return a new :func:`.select` construct with its columns + clause replaced with the given columns. + + .. versionchanged:: 0.7.3 + Due to a bug fix, this method has a slight + behavioral change as of version 0.7.3. + Prior to version 0.7.3, the FROM clause of + a :func:`.select` was calculated upfront and as new columns + were added; in 0.7.3 and later it's calculated + at compile time, fixing an issue regarding late binding + of columns to parent tables. This changes the behavior of + :meth:`.Select.with_only_columns` in that FROM clauses no + longer represented in the new list are dropped, + but this behavior is more consistent in + that the FROM clauses are consistently derived from the + current columns clause. The original intent of this method + is to allow trimming of the existing columns list to be fewer + columns than originally present; the use case of replacing + the columns list with an entirely different one hadn't + been anticipated until 0.7.3 was released; the usage + guidelines below illustrate how this should be done. + + This method is exactly equivalent to as if the original + :func:`.select` had been called with the given columns + clause. I.e. a statement:: + + s = select([table1.c.a, table1.c.b]) + s = s.with_only_columns([table1.c.b]) + + should be exactly equivalent to:: + + s = select([table1.c.b]) + + This means that FROM clauses which are only derived + from the column list will be discarded if the new column + list no longer contains that FROM:: + + >>> table1 = table('t1', column('a'), column('b')) + >>> table2 = table('t2', column('a'), column('b')) + >>> s1 = select([table1.c.a, table2.c.b]) + >>> print s1 + SELECT t1.a, t2.b FROM t1, t2 + >>> s2 = s1.with_only_columns([table2.c.b]) + >>> print s2 + SELECT t2.b FROM t1 + + The preferred way to maintain a specific FROM clause + in the construct, assuming it won't be represented anywhere + else (i.e. not in the WHERE clause, etc.) is to set it using + :meth:`.Select.select_from`:: + + >>> s1 = select([table1.c.a, table2.c.b]).\\ + ... select_from(table1.join(table2, + ... table1.c.a==table2.c.a)) + >>> s2 = s1.with_only_columns([table2.c.b]) + >>> print s2 + SELECT t2.b FROM t1 JOIN t2 ON t1.a=t2.a + + Care should also be taken to use the correct + set of column objects passed to :meth:`.Select.with_only_columns`. + Since the method is essentially equivalent to calling the + :func:`.select` construct in the first place with the given + columns, the columns passed to :meth:`.Select.with_only_columns` + should usually be a subset of those which were passed + to the :func:`.select` construct, not those which are available + from the ``.c`` collection of that :func:`.select`. That + is:: + + s = select([table1.c.a, table1.c.b]).select_from(table1) + s = s.with_only_columns([table1.c.b]) + + and **not**:: + + # usually incorrect + s = s.with_only_columns([s.c.b]) + + The latter would produce the SQL:: + + SELECT b + FROM (SELECT t1.a AS a, t1.b AS b + FROM t1), t1 + + Since the :func:`.select` construct is essentially being + asked to select both from ``table1`` as well as itself. + + """ + self._reset_exported() + rc = [] + for c in columns: + c = _interpret_as_column_or_from(c) + if isinstance(c, ScalarSelect): + c = c.self_group(against=operators.comma_op) + rc.append(c) + self._raw_columns = rc + + @_generative + def where(self, whereclause): + """return a new select() construct with the given expression added to + its WHERE clause, joined to the existing clause via AND, if any. + + """ + + self.append_whereclause(whereclause) + + @_generative + def having(self, having): + """return a new select() construct with the given expression added to + its HAVING clause, joined to the existing clause via AND, if any. + + """ + self.append_having(having) + + @_generative + def distinct(self, *expr): + """Return a new select() construct which will apply DISTINCT to its + columns clause. + + :param \*expr: optional column expressions. When present, + the Postgresql dialect will render a ``DISTINCT ON (>)`` + construct. + + """ + if expr: + expr = [_literal_as_text(e) for e in expr] + if isinstance(self._distinct, list): + self._distinct = self._distinct + expr + else: + self._distinct = expr + else: + self._distinct = True + + @_generative + def select_from(self, fromclause): + """return a new :func:`.select` construct with the + given FROM expression + merged into its list of FROM objects. + + E.g.:: + + table1 = table('t1', column('a')) + table2 = table('t2', column('b')) + s = select([table1.c.a]).\\ + select_from( + table1.join(table2, table1.c.a==table2.c.b) + ) + + The "from" list is a unique set on the identity of each element, + so adding an already present :class:`.Table` or other selectable + will have no effect. Passing a :class:`.Join` that refers + to an already present :class:`.Table` or other selectable will have + the effect of concealing the presence of that selectable as + an individual element in the rendered FROM list, instead + rendering it into a JOIN clause. + + While the typical purpose of :meth:`.Select.select_from` is to + replace the default, derived FROM clause with a join, it can + also be called with individual table elements, multiple times + if desired, in the case that the FROM clause cannot be fully + derived from the columns clause:: + + select([func.count('*')]).select_from(table1) + + """ + self.append_from(fromclause) + + @_generative + def correlate(self, *fromclauses): + """return a new :class:`.Select` which will correlate the given FROM + clauses to that of an enclosing :class:`.Select`. + + Calling this method turns off the :class:`.Select` object's + default behavior of "auto-correlation". Normally, FROM elements + which appear in a :class:`.Select` that encloses this one via + its :term:`WHERE clause`, ORDER BY, HAVING or + :term:`columns clause` will be omitted from this :class:`.Select` + object's :term:`FROM clause`. + Setting an explicit correlation collection using the + :meth:`.Select.correlate` method provides a fixed list of FROM objects + that can potentially take place in this process. + + When :meth:`.Select.correlate` is used to apply specific FROM clauses + for correlation, the FROM elements become candidates for + correlation regardless of how deeply nested this :class:`.Select` + object is, relative to an enclosing :class:`.Select` which refers to + the same FROM object. This is in contrast to the behavior of + "auto-correlation" which only correlates to an immediate enclosing + :class:`.Select`. Multi-level correlation ensures that the link + between enclosed and enclosing :class:`.Select` is always via + at least one WHERE/ORDER BY/HAVING/columns clause in order for + correlation to take place. + + If ``None`` is passed, the :class:`.Select` object will correlate + none of its FROM entries, and all will render unconditionally + in the local FROM clause. + + :param \*fromclauses: a list of one or more :class:`.FromClause` + constructs, or other compatible constructs (i.e. ORM-mapped + classes) to become part of the correlate collection. + + .. versionchanged:: 0.8.0 ORM-mapped classes are accepted by + :meth:`.Select.correlate`. + + .. versionchanged:: 0.8.0 The :meth:`.Select.correlate` method no + longer unconditionally removes entries from the FROM clause; instead, + the candidate FROM entries must also be matched by a FROM entry + located in an enclosing :class:`.Select`, which ultimately encloses + this one as present in the WHERE clause, ORDER BY clause, HAVING + clause, or columns clause of an enclosing :meth:`.Select`. + + .. versionchanged:: 0.8.2 explicit correlation takes place + via any level of nesting of :class:`.Select` objects; in previous + 0.8 versions, correlation would only occur relative to the immediate + enclosing :class:`.Select` construct. + + .. seealso:: + + :meth:`.Select.correlate_except` + + :ref:`correlated_subqueries` + + """ + self._auto_correlate = False + if fromclauses and fromclauses[0] is None: + self._correlate = () + else: + self._correlate = set(self._correlate).union( + _interpret_as_from(f) for f in fromclauses) + + @_generative + def correlate_except(self, *fromclauses): + """return a new :class:`.Select` which will omit the given FROM + clauses from the auto-correlation process. + + Calling :meth:`.Select.correlate_except` turns off the + :class:`.Select` object's default behavior of + "auto-correlation" for the given FROM elements. An element + specified here will unconditionally appear in the FROM list, while + all other FROM elements remain subject to normal auto-correlation + behaviors. + + .. versionchanged:: 0.8.2 The :meth:`.Select.correlate_except` + method was improved to fully prevent FROM clauses specified here + from being omitted from the immediate FROM clause of this + :class:`.Select`. + + If ``None`` is passed, the :class:`.Select` object will correlate + all of its FROM entries. + + .. versionchanged:: 0.8.2 calling ``correlate_except(None)`` will + correctly auto-correlate all FROM clauses. + + :param \*fromclauses: a list of one or more :class:`.FromClause` + constructs, or other compatible constructs (i.e. ORM-mapped + classes) to become part of the correlate-exception collection. + + .. seealso:: + + :meth:`.Select.correlate` + + :ref:`correlated_subqueries` + + """ + + self._auto_correlate = False + if fromclauses and fromclauses[0] is None: + self._correlate_except = () + else: + self._correlate_except = set(self._correlate_except or ()).union( + _interpret_as_from(f) for f in fromclauses) + + def append_correlation(self, fromclause): + """append the given correlation expression to this select() + construct. + + This is an **in-place** mutation method; the + :meth:`~.Select.correlate` method is preferred, as it provides standard + :term:`method chaining`. + + """ + + self._auto_correlate = False + self._correlate = set(self._correlate).union( + _interpret_as_from(f) for f in fromclause) + + def append_column(self, column): + """append the given column expression to the columns clause of this + select() construct. + + This is an **in-place** mutation method; the + :meth:`~.Select.column` method is preferred, as it provides standard + :term:`method chaining`. + + """ + self._reset_exported() + column = _interpret_as_column_or_from(column) + + if isinstance(column, ScalarSelect): + column = column.self_group(against=operators.comma_op) + + self._raw_columns = self._raw_columns + [column] + + def append_prefix(self, clause): + """append the given columns clause prefix expression to this select() + construct. + + This is an **in-place** mutation method; the + :meth:`~.Select.prefix_with` method is preferred, as it provides standard + :term:`method chaining`. + + """ + clause = _literal_as_text(clause) + self._prefixes = self._prefixes + (clause,) + + def append_whereclause(self, whereclause): + """append the given expression to this select() construct's WHERE + criterion. + + The expression will be joined to existing WHERE criterion via AND. + + This is an **in-place** mutation method; the + :meth:`~.Select.where` method is preferred, as it provides standard + :term:`method chaining`. + + """ + + self._reset_exported() + self._whereclause = and_(True_._ifnone(self._whereclause), whereclause) + + def append_having(self, having): + """append the given expression to this select() construct's HAVING + criterion. + + The expression will be joined to existing HAVING criterion via AND. + + This is an **in-place** mutation method; the + :meth:`~.Select.having` method is preferred, as it provides standard + :term:`method chaining`. + + """ + self._reset_exported() + self._having = and_(True_._ifnone(self._having), having) + + def append_from(self, fromclause): + """append the given FromClause expression to this select() construct's + FROM clause. + + This is an **in-place** mutation method; the + :meth:`~.Select.select_from` method is preferred, as it provides standard + :term:`method chaining`. + + """ + self._reset_exported() + fromclause = _interpret_as_from(fromclause) + self._from_obj = self._from_obj.union([fromclause]) + + + @_memoized_property + def _columns_plus_names(self): + if self.use_labels: + names = set() + def name_for_col(c): + if c._label is None: + return (None, c) + name = c._label + if name in names: + name = c.anon_label + else: + names.add(name) + return name, c + + return [ + name_for_col(c) + for c in util.unique_list(_select_iterables(self._raw_columns)) + ] + else: + return [ + (None, c) + for c in util.unique_list(_select_iterables(self._raw_columns)) + ] + + def _populate_column_collection(self): + for name, c in self._columns_plus_names: + if not hasattr(c, '_make_proxy'): + continue + if name is None: + key = None + elif self.use_labels: + key = c._key_label + if key is not None and key in self.c: + key = c.anon_label + else: + key = None + + c._make_proxy(self, key=key, + name=name, + name_is_truncatable=True) + + def _refresh_for_new_column(self, column): + for fromclause in self._froms: + col = fromclause._refresh_for_new_column(column) + if col is not None: + if col in self.inner_columns and self._cols_populated: + our_label = col._key_label if self.use_labels else col.key + if our_label not in self.c: + return col._make_proxy(self, + name=col._label if self.use_labels else None, + key=col._key_label if self.use_labels else None, + name_is_truncatable=True) + return None + return None + + def self_group(self, against=None): + """return a 'grouping' construct as per the ClauseElement + specification. + + This produces an element that can be embedded in an expression. Note + that this method is called automatically as needed when constructing + expressions and should not require explicit use. + + """ + if isinstance(against, CompoundSelect): + return self + return FromGrouping(self) + + def union(self, other, **kwargs): + """return a SQL UNION of this select() construct against the given + selectable.""" + + return CompoundSelect._create_union(self, other, **kwargs) + + def union_all(self, other, **kwargs): + """return a SQL UNION ALL of this select() construct against the given + selectable. + + """ + return CompoundSelect._create_union_all(self, other, **kwargs) + + def except_(self, other, **kwargs): + """return a SQL EXCEPT of this select() construct against the given + selectable.""" + + return CompoundSelect._create_except(self, other, **kwargs) + + def except_all(self, other, **kwargs): + """return a SQL EXCEPT ALL of this select() construct against the + given selectable. + + """ + return CompoundSelect._create_except_all(self, other, **kwargs) + + def intersect(self, other, **kwargs): + """return a SQL INTERSECT of this select() construct against the given + selectable. + + """ + return CompoundSelect._create_intersect(self, other, **kwargs) + + def intersect_all(self, other, **kwargs): + """return a SQL INTERSECT ALL of this select() construct against the + given selectable. + + """ + return CompoundSelect._create_intersect_all(self, other, **kwargs) + + def bind(self): + if self._bind: + return self._bind + froms = self._froms + if not froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(froms)[0].bind + if e: + self._bind = e + return e + + return None + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + +class ScalarSelect(Generative, Grouping): + _from_objects = [] + + def __init__(self, element): + self.element = element + self.type = element._scalar_type() + + @property + def columns(self): + raise exc.InvalidRequestError('Scalar Select expression has no ' + 'columns; use this object directly within a ' + 'column-level expression.') + c = columns + + @_generative + def where(self, crit): + """Apply a WHERE clause to the SELECT statement referred to + by this :class:`.ScalarSelect`. + + """ + self.element = self.element.where(crit) + + def self_group(self, **kwargs): + return self + + +class Exists(UnaryExpression): + """Represent an ``EXISTS`` clause. + + """ + __visit_name__ = UnaryExpression.__visit_name__ + _from_objects = [] + + + def __init__(self, *args, **kwargs): + """Construct a new :class:`.Exists` against an existing + :class:`.Select` object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([table.c.col1]).where(table.c.col2==5) + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where(criterion) + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where(table.c.col2==5) + + """ + if args and isinstance(args[0], (SelectBase, ScalarSelect)): + s = args[0] + else: + if not args: + args = ([literal_column('*')],) + s = Select(*args, **kwargs).as_scalar().self_group() + + UnaryExpression.__init__(self, s, operator=operators.exists, + type_=type_api.BOOLEANTYPE) + + def select(self, whereclause=None, **params): + return Select([self], whereclause, **params) + + def correlate(self, *fromclause): + e = self._clone() + e.element = self.element.correlate(*fromclause).self_group() + return e + + def correlate_except(self, *fromclause): + e = self._clone() + e.element = self.element.correlate_except(*fromclause).self_group() + return e + + def select_from(self, clause): + """return a new :class:`.Exists` construct, applying the given + expression to the :meth:`.Select.select_from` method of the select + statement contained. + + """ + e = self._clone() + e.element = self.element.select_from(clause).self_group() + return e + + def where(self, clause): + """return a new exists() construct with the given expression added to + its WHERE clause, joined to the existing clause via AND, if any. + + """ + e = self._clone() + e.element = self.element.where(clause).self_group() + return e + + +class TextAsFrom(SelectBase): + """Wrap a :class:`.TextClause` construct within a :class:`.SelectBase` + interface. + + This allows the :class:`.TextClause` object to gain a ``.c`` collection and + other FROM-like capabilities such as :meth:`.FromClause.alias`, + :meth:`.SelectBase.cte`, etc. + + The :class:`.TextAsFrom` construct is produced via the + :meth:`.TextClause.columns` method - see that method for details. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :func:`.text` + + :meth:`.TextClause.columns` + + """ + __visit_name__ = "text_as_from" + + _textual = True + + def __init__(self, text, columns): + self.element = text + self.column_args = columns + + @property + def _bind(self): + return self.element._bind + + @_generative + def bindparams(self, *binds, **bind_as_values): + self.element = self.element.bindparams(*binds, **bind_as_values) + + def _populate_column_collection(self): + for c in self.column_args: + c._make_proxy(self) + + def _copy_internals(self, clone=_clone, **kw): + self._reset_exported() + self.element = clone(self.element, **kw) + + def _scalar_type(self): + return self.column_args[0].type + +class AnnotatedFromClause(Annotated): + def __init__(self, element, values): + # force FromClause to generate their internal + # collections into __dict__ + element.c + Annotated.__init__(self, element, values) + + diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py new file mode 100644 index 00000000..f3468ebc --- /dev/null +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -0,0 +1,1639 @@ +# sql/sqltypes.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL specific types. + +""" + +import datetime as dt +import codecs + +from .type_api import TypeEngine, TypeDecorator, to_instance +from .elements import quoted_name, type_coerce +from .default_comparator import _DefaultColumnComparator +from .. import exc, util, processors +from .base import _bind_or_error, SchemaEventTarget +from . import operators +from .. import event +from ..util import pickle +import decimal + +if util.jython: + import array + +class _DateAffinity(object): + """Mixin date/time specific expression adaptations. + + Rules are implemented within Date,Time,Interval,DateTime, Numeric, + Integer. Based on http://www.postgresql.org/docs/current/static + /functions-datetime.html. + + """ + + @property + def _expression_adaptations(self): + raise NotImplementedError() + + class Comparator(TypeEngine.Comparator): + _blank_dict = util.immutabledict() + + def _adapt_expression(self, op, other_comparator): + othertype = other_comparator.type._type_affinity + return op, \ + to_instance(self.type._expression_adaptations.get(op, self._blank_dict).\ + get(othertype, NULLTYPE)) + comparator_factory = Comparator + +class Concatenable(object): + """A mixin that marks a type as supporting 'concatenation', + typically strings.""" + + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if op is operators.add and isinstance(other_comparator, + (Concatenable.Comparator, NullType.Comparator)): + return operators.concat_op, self.expr.type + else: + return op, self.expr.type + + comparator_factory = Comparator + + +class String(Concatenable, TypeEngine): + """The base for all string and character types. + + In SQL, corresponds to VARCHAR. Can also take Python unicode objects + and encode to the database's encoding in bind params (and the reverse for + result sets.) + + The `length` field is usually required when the `String` type is + used within a CREATE TABLE statement, as VARCHAR requires a length + on most databases. + + """ + + __visit_name__ = 'string' + + def __init__(self, length=None, collation=None, + convert_unicode=False, + unicode_error=None, + _warn_on_bytestring=False + ): + """ + Create a string-holding type. + + :param length: optional, a length for the column for use in + DDL and CAST expressions. May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + ``length`` for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued if a ``VARCHAR`` + with no length is included. Whether the value is + interpreted as bytes or characters is database specific. + + :param collation: Optional, a column-level collation for + use in DDL and CAST expressions. Renders using the + COLLATE keyword supported by SQLite, MySQL, and Postgresql. + E.g.:: + + >>> from sqlalchemy import cast, select, String + >>> print select([cast('some string', String(collation='utf8'))]) + SELECT CAST(:param_1 AS VARCHAR COLLATE utf8) AS anon_1 + + .. versionadded:: 0.8 Added support for COLLATE to all + string types. + + :param convert_unicode: When set to ``True``, the + :class:`.String` type will assume that + input is to be passed as Python ``unicode`` objects, + and results returned as Python ``unicode`` objects. + If the DBAPI in use does not support Python unicode + (which is fewer and fewer these days), SQLAlchemy + will encode/decode the value, using the + value of the ``encoding`` parameter passed to + :func:`.create_engine` as the encoding. + + When using a DBAPI that natively supports Python + unicode objects, this flag generally does not + need to be set. For columns that are explicitly + intended to store non-ASCII data, the :class:`.Unicode` + or :class:`.UnicodeText` + types should be used regardless, which feature + the same behavior of ``convert_unicode`` but + also indicate an underlying column type that + directly supports unicode, such as ``NVARCHAR``. + + For the extremely rare case that Python ``unicode`` + is to be encoded/decoded by SQLAlchemy on a backend + that does natively support Python ``unicode``, + the value ``force`` can be passed here which will + cause SQLAlchemy's encode/decode services to be + used unconditionally. + + :param unicode_error: Optional, a method to use to handle Unicode + conversion errors. Behaves like the ``errors`` keyword argument to + the standard library's ``string.decode()`` functions. This flag + requires that ``convert_unicode`` is set to ``force`` - otherwise, + SQLAlchemy is not guaranteed to handle the task of unicode + conversion. Note that this flag adds significant performance + overhead to row-fetching operations for backends that already + return unicode objects natively (which most DBAPIs do). This + flag should only be used as a last resort for reading + strings from a column with varied or corrupted encodings. + + """ + if unicode_error is not None and convert_unicode != 'force': + raise exc.ArgumentError("convert_unicode must be 'force' " + "when unicode_error is set.") + + self.length = length + self.collation = collation + self.convert_unicode = convert_unicode + self.unicode_error = unicode_error + self._warn_on_bytestring = _warn_on_bytestring + + def literal_processor(self, dialect): + def process(value): + value = value.replace("'", "''") + return "'%s'" % value + return process + + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + if dialect.supports_unicode_binds and \ + self.convert_unicode != 'force': + if self._warn_on_bytestring: + def process(value): + if isinstance(value, util.binary_type): + util.warn("Unicode type received non-unicode bind " + "param value.") + return value + return process + else: + return None + else: + encoder = codecs.getencoder(dialect.encoding) + warn_on_bytestring = self._warn_on_bytestring + + def process(value): + if isinstance(value, util.text_type): + return encoder(value, self.unicode_error)[0] + elif warn_on_bytestring and value is not None: + util.warn("Unicode type received non-unicode bind " + "param value") + return value + return process + else: + return None + + def result_processor(self, dialect, coltype): + wants_unicode = self.convert_unicode or dialect.convert_unicode + needs_convert = wants_unicode and \ + (dialect.returns_unicode_strings is not True or + self.convert_unicode in ('force', 'force_nocheck')) + needs_isinstance = ( + needs_convert and + dialect.returns_unicode_strings and + self.convert_unicode != 'force_nocheck' + ) + if needs_convert: + to_unicode = processors.to_unicode_processor_factory( + dialect.encoding, self.unicode_error) + + if needs_isinstance: + return processors.to_conditional_unicode_processor_factory( + dialect.encoding, self.unicode_error) + else: + return processors.to_unicode_processor_factory( + dialect.encoding, self.unicode_error) + else: + return None + + @property + def python_type(self): + if self.convert_unicode: + return util.text_type + else: + return str + + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class Text(String): + """A variably sized string type. + + In SQL, usually corresponds to CLOB or TEXT. Can also take Python + unicode objects and encode to the database's encoding in bind + params (and the reverse for result sets.) In general, TEXT objects + do not have a length; while some databases will accept a length + argument here, it will be rejected by others. + + """ + __visit_name__ = 'text' + + +class Unicode(String): + """A variable length Unicode string type. + + The :class:`.Unicode` type is a :class:`.String` subclass + that assumes input and output as Python ``unicode`` data, + and in that regard is equivalent to the usage of the + ``convert_unicode`` flag with the :class:`.String` type. + However, unlike plain :class:`.String`, it also implies an + underlying column type that is explicitly supporting of non-ASCII + data, such as ``NVARCHAR`` on Oracle and SQL Server. + This can impact the output of ``CREATE TABLE`` statements + and ``CAST`` functions at the dialect level, and can + also affect the handling of bound parameters in some + specific DBAPI scenarios. + + The encoding used by the :class:`.Unicode` type is usually + determined by the DBAPI itself; most modern DBAPIs + feature support for Python ``unicode`` objects as bound + values and result set values, and the encoding should + be configured as detailed in the notes for the target + DBAPI in the :ref:`dialect_toplevel` section. + + For those DBAPIs which do not support, or are not configured + to accommodate Python ``unicode`` objects + directly, SQLAlchemy does the encoding and decoding + outside of the DBAPI. The encoding in this scenario + is determined by the ``encoding`` flag passed to + :func:`.create_engine`. + + When using the :class:`.Unicode` type, it is only appropriate + to pass Python ``unicode`` objects, and not plain ``str``. + If a plain ``str`` is passed under Python 2, a warning + is emitted. If you notice your application emitting these warnings but + you're not sure of the source of them, the Python + ``warnings`` filter, documented at + http://docs.python.org/library/warnings.html, + can be used to turn these warnings into exceptions + which will illustrate a stack trace:: + + import warnings + warnings.simplefilter('error') + + For an application that wishes to pass plain bytestrings + and Python ``unicode`` objects to the ``Unicode`` type + equally, the bytestrings must first be decoded into + unicode. The recipe at :ref:`coerce_to_unicode` illustrates + how this is done. + + See also: + + :class:`.UnicodeText` - unlengthed textual counterpart + to :class:`.Unicode`. + + """ + + __visit_name__ = 'unicode' + + def __init__(self, length=None, **kwargs): + """ + Create a :class:`.Unicode` object. + + Parameters are the same as that of :class:`.String`, + with the exception that ``convert_unicode`` + defaults to ``True``. + + """ + kwargs.setdefault('convert_unicode', True) + kwargs.setdefault('_warn_on_bytestring', True) + super(Unicode, self).__init__(length=length, **kwargs) + + +class UnicodeText(Text): + """An unbounded-length Unicode string type. + + See :class:`.Unicode` for details on the unicode + behavior of this object. + + Like :class:`.Unicode`, usage the :class:`.UnicodeText` type implies a + unicode-capable type being used on the backend, such as + ``NCLOB``, ``NTEXT``. + + """ + + __visit_name__ = 'unicode_text' + + def __init__(self, length=None, **kwargs): + """ + Create a Unicode-converting Text type. + + Parameters are the same as that of :class:`.Text`, + with the exception that ``convert_unicode`` + defaults to ``True``. + + """ + kwargs.setdefault('convert_unicode', True) + kwargs.setdefault('_warn_on_bytestring', True) + super(UnicodeText, self).__init__(length=length, **kwargs) + + +class Integer(_DateAffinity, TypeEngine): + """A type for ``int`` integers.""" + + __visit_name__ = 'integer' + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + @property + def python_type(self): + return int + + def literal_processor(self, dialect): + def process(value): + return str(value) + return process + + @util.memoized_property + def _expression_adaptations(self): + # TODO: need a dictionary object that will + # handle operators generically here, this is incomplete + return { + operators.add: { + Date: Date, + Integer: self.__class__, + Numeric: Numeric, + }, + operators.mul: { + Interval: Interval, + Integer: self.__class__, + Numeric: Numeric, + }, + operators.div: { + Integer: self.__class__, + Numeric: Numeric, + }, + operators.truediv: { + Integer: self.__class__, + Numeric: Numeric, + }, + operators.sub: { + Integer: self.__class__, + Numeric: Numeric, + }, + } + + + +class SmallInteger(Integer): + """A type for smaller ``int`` integers. + + Typically generates a ``SMALLINT`` in DDL, and otherwise acts like + a normal :class:`.Integer` on the Python side. + + """ + + __visit_name__ = 'small_integer' + + +class BigInteger(Integer): + """A type for bigger ``int`` integers. + + Typically generates a ``BIGINT`` in DDL, and otherwise acts like + a normal :class:`.Integer` on the Python side. + + """ + + __visit_name__ = 'big_integer' + + + +class Numeric(_DateAffinity, TypeEngine): + """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``. + + This type returns Python ``decimal.Decimal`` objects by default, unless the + :paramref:`.Numeric.asdecimal` flag is set to False, in which case they + are coerced to Python ``float`` objects. + + .. note:: + + The :class:`.Numeric` type is designed to receive data from a database + type that is explicitly known to be a decimal type + (e.g. ``DECIMAL``, ``NUMERIC``, others) and not a floating point + type (e.g. ``FLOAT``, ``REAL``, others). + If the database column on the server is in fact a floating-point type + type, such as ``FLOAT`` or ``REAL``, use the :class:`.Float` + type or a subclass, otherwise numeric coercion between ``float``/``Decimal`` + may or may not function as expected. + + .. note:: + + The Python ``decimal.Decimal`` class is generally slow + performing; cPython 3.3 has now switched to use the `cdecimal + `_ library natively. For + older Python versions, the ``cdecimal`` library can be patched + into any application where it will replace the ``decimal`` + library fully, however this needs to be applied globally and + before any other modules have been imported, as follows:: + + import sys + import cdecimal + sys.modules["decimal"] = cdecimal + + Note that the ``cdecimal`` and ``decimal`` libraries are **not + compatible with each other**, so patching ``cdecimal`` at the + global level is the only way it can be used effectively with + various DBAPIs that hardcode to import the ``decimal`` library. + + """ + + __visit_name__ = 'numeric' + + _default_decimal_return_scale = 10 + + def __init__(self, precision=None, scale=None, + decimal_return_scale=None, asdecimal=True): + """ + Construct a Numeric. + + :param precision: the numeric precision for use in DDL ``CREATE + TABLE``. + + :param scale: the numeric scale for use in DDL ``CREATE TABLE``. + + :param asdecimal: default True. Return whether or not + values should be sent as Python Decimal objects, or + as floats. Different DBAPIs send one or the other based on + datatypes - the Numeric type will ensure that return values + are one or the other across DBAPIs consistently. + + :param decimal_return_scale: Default scale to use when converting + from floats to Python decimals. Floating point values will typically + be much longer due to decimal inaccuracy, and most floating point + database types don't have a notion of "scale", so by default the + float type looks for the first ten decimal places when converting. + Specfiying this value will override that length. Types which + do include an explicit ".scale" value, such as the base :class:`.Numeric` + as well as the MySQL float types, will use the value of ".scale" + as the default for decimal_return_scale, if not otherwise specified. + + .. versionadded:: 0.9.0 + + When using the ``Numeric`` type, care should be taken to ensure + that the asdecimal setting is apppropriate for the DBAPI in use - + when Numeric applies a conversion from Decimal->float or float-> + Decimal, this conversion incurs an additional performance overhead + for all result columns received. + + DBAPIs that return Decimal natively (e.g. psycopg2) will have + better accuracy and higher performance with a setting of ``True``, + as the native translation to Decimal reduces the amount of floating- + point issues at play, and the Numeric type itself doesn't need + to apply any further conversions. However, another DBAPI which + returns floats natively *will* incur an additional conversion + overhead, and is still subject to floating point data loss - in + which case ``asdecimal=False`` will at least remove the extra + conversion overhead. + + """ + self.precision = precision + self.scale = scale + self.decimal_return_scale = decimal_return_scale + self.asdecimal = asdecimal + + @property + def _effective_decimal_return_scale(self): + if self.decimal_return_scale is not None: + return self.decimal_return_scale + elif getattr(self, "scale", None) is not None: + return self.scale + else: + return self._default_decimal_return_scale + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def literal_processor(self, dialect): + def process(value): + return str(value) + return process + + @property + def python_type(self): + if self.asdecimal: + return decimal.Decimal + else: + return float + + def bind_processor(self, dialect): + if dialect.supports_native_decimal: + return None + else: + return processors.to_float + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if dialect.supports_native_decimal: + # we're a "numeric", DBAPI will give us Decimal directly + return None + else: + util.warn('Dialect %s+%s does *not* support Decimal ' + 'objects natively, and SQLAlchemy must ' + 'convert from floating point - rounding ' + 'errors and other issues may occur. Please ' + 'consider storing Decimal numbers as strings ' + 'or integers on this platform for lossless ' + 'storage.' % (dialect.name, dialect.driver)) + + # we're a "numeric", DBAPI returns floats, convert. + return processors.to_decimal_processor_factory( + decimal.Decimal, + self.scale if self.scale is not None + else self._default_decimal_return_scale) + else: + if dialect.supports_native_decimal: + return processors.to_float + else: + return None + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.mul: { + Interval: Interval, + Numeric: self.__class__, + Integer: self.__class__, + }, + operators.div: { + Numeric: self.__class__, + Integer: self.__class__, + }, + operators.truediv: { + Numeric: self.__class__, + Integer: self.__class__, + }, + operators.add: { + Numeric: self.__class__, + Integer: self.__class__, + }, + operators.sub: { + Numeric: self.__class__, + Integer: self.__class__, + } + } + + +class Float(Numeric): + """Type representing floating point types, such as ``FLOAT`` or ``REAL``. + + This type returns Python ``float`` objects by default, unless the + :paramref:`.Float.asdecimal` flag is set to True, in which case they + are coerced to ``decimal.Decimal`` objects. + + .. note:: + + The :class:`.Float` type is designed to receive data from a database + type that is explicitly known to be a floating point type + (e.g. ``FLOAT``, ``REAL``, others) + and not a decimal type (e.g. ``DECIMAL``, ``NUMERIC``, others). + If the database column on the server is in fact a Numeric + type, such as ``DECIMAL`` or ``NUMERIC``, use the :class:`.Numeric` + type or a subclass, otherwise numeric coercion between ``float``/``Decimal`` + may or may not function as expected. + + """ + + __visit_name__ = 'float' + + scale = None + + def __init__(self, precision=None, asdecimal=False, + decimal_return_scale=None, **kwargs): + """ + Construct a Float. + + :param precision: the numeric precision for use in DDL ``CREATE + TABLE``. + + :param asdecimal: the same flag as that of :class:`.Numeric`, but + defaults to ``False``. Note that setting this flag to ``True`` + results in floating point conversion. + + :param decimal_return_scale: Default scale to use when converting + from floats to Python decimals. Floating point values will typically + be much longer due to decimal inaccuracy, and most floating point + database types don't have a notion of "scale", so by default the + float type looks for the first ten decimal places when converting. + Specfiying this value will override that length. Note that the + MySQL float types, which do include "scale", will use "scale" + as the default for decimal_return_scale, if not otherwise specified. + + .. versionadded:: 0.9.0 + + :param \**kwargs: deprecated. Additional arguments here are ignored + by the default :class:`.Float` type. For database specific + floats that support additional arguments, see that dialect's + documentation for details, such as + :class:`sqlalchemy.dialects.mysql.FLOAT`. + + """ + self.precision = precision + self.asdecimal = asdecimal + self.decimal_return_scale = decimal_return_scale + if kwargs: + util.warn_deprecated("Additional keyword arguments " + "passed to Float ignored.") + + def result_processor(self, dialect, coltype): + if self.asdecimal: + return processors.to_decimal_processor_factory( + decimal.Decimal, + self._effective_decimal_return_scale) + else: + return None + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.mul: { + Interval: Interval, + Numeric: self.__class__, + }, + operators.div: { + Numeric: self.__class__, + }, + operators.truediv: { + Numeric: self.__class__, + }, + operators.add: { + Numeric: self.__class__, + }, + operators.sub: { + Numeric: self.__class__, + } + } + + +class DateTime(_DateAffinity, TypeEngine): + """A type for ``datetime.datetime()`` objects. + + Date and time types return objects from the Python ``datetime`` + module. Most DBAPIs have built in support for the datetime + module, with the noted exception of SQLite. In the case of + SQLite, date and time types are stored as strings which are then + converted back to datetime objects when rows are returned. + + """ + + __visit_name__ = 'datetime' + + def __init__(self, timezone=False): + """Construct a new :class:`.DateTime`. + + :param timezone: boolean. If True, and supported by the + backend, will produce 'TIMESTAMP WITH TIMEZONE'. For backends + that don't support timezone aware timestamps, has no + effect. + + """ + self.timezone = timezone + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @property + def python_type(self): + return dt.datetime + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add: { + Interval: self.__class__, + }, + operators.sub: { + Interval: self.__class__, + DateTime: Interval, + }, + } + + +class Date(_DateAffinity, TypeEngine): + """A type for ``datetime.date()`` objects.""" + + __visit_name__ = 'date' + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @property + def python_type(self): + return dt.date + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add: { + Integer: self.__class__, + Interval: DateTime, + Time: DateTime, + }, + operators.sub: { + # date - integer = date + Integer: self.__class__, + + # date - date = integer. + Date: Integer, + + Interval: DateTime, + + # date - datetime = interval, + # this one is not in the PG docs + # but works + DateTime: Interval, + }, + } + + +class Time(_DateAffinity, TypeEngine): + """A type for ``datetime.time()`` objects.""" + + __visit_name__ = 'time' + + def __init__(self, timezone=False): + self.timezone = timezone + + def get_dbapi_type(self, dbapi): + return dbapi.DATETIME + + @property + def python_type(self): + return dt.time + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add: { + Date: DateTime, + Interval: self.__class__ + }, + operators.sub: { + Time: Interval, + Interval: self.__class__, + }, + } + + +class _Binary(TypeEngine): + """Define base behavior for binary types.""" + + def __init__(self, length=None): + self.length = length + + def literal_processor(self, dialect): + def process(value): + value = value.decode(dialect.encoding).replace("'", "''") + return "'%s'" % value + return process + + @property + def python_type(self): + return util.binary_type + + # Python 3 - sqlite3 doesn't need the `Binary` conversion + # here, though pg8000 does to indicate "bytea" + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + return None + return process + + # Python 3 has native bytes() type + # both sqlite3 and pg8000 seem to return it, + # psycopg2 as of 2.5 returns 'memoryview' + if util.py2k: + def result_processor(self, dialect, coltype): + if util.jython: + def process(value): + if value is not None: + if isinstance(value, array.array): + return value.tostring() + return str(value) + else: + return None + else: + process = processors.to_str + return process + else: + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + value = bytes(value) + return value + return process + + def coerce_compared_value(self, op, value): + """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" + + if isinstance(value, util.string_types): + return self + else: + return super(_Binary, self).coerce_compared_value(op, value) + + def get_dbapi_type(self, dbapi): + return dbapi.BINARY + + +class LargeBinary(_Binary): + """A type for large binary byte data. + + The Binary type generates BLOB or BYTEA when tables are created, + and also converts incoming values using the ``Binary`` callable + provided by each DB-API. + + """ + + __visit_name__ = 'large_binary' + + def __init__(self, length=None): + """ + Construct a LargeBinary type. + + :param length: optional, a length for the column for use in + DDL statements, for those BLOB types that accept a length + (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY + type - use the BINARY/VARBINARY types specifically for those. + May be safely omitted if no ``CREATE + TABLE`` will be issued. Certain databases may require a + *length* for use in DDL, and will raise an exception when + the ``CREATE TABLE`` DDL is issued. + + """ + _Binary.__init__(self, length=length) + + +class Binary(LargeBinary): + """Deprecated. Renamed to LargeBinary.""" + + def __init__(self, *arg, **kw): + util.warn_deprecated('The Binary type has been renamed to ' + 'LargeBinary.') + LargeBinary.__init__(self, *arg, **kw) + + + +class SchemaType(SchemaEventTarget): + """Mark a type as possibly requiring schema-level DDL for usage. + + Supports types that must be explicitly created/dropped (i.e. PG ENUM type) + as well as types that are complimented by table or schema level + constraints, triggers, and other rules. + + :class:`.SchemaType` classes can also be targets for the + :meth:`.DDLEvents.before_parent_attach` and + :meth:`.DDLEvents.after_parent_attach` events, where the events fire off + surrounding the association of the type object with a parent + :class:`.Column`. + + .. seealso:: + + :class:`.Enum` + + :class:`.Boolean` + + + """ + + def __init__(self, name=None, schema=None, metadata=None, + inherit_schema=False, quote=None): + if name is not None: + self.name = quoted_name(name, quote) + else: + self.name = None + self.schema = schema + self.metadata = metadata + self.inherit_schema = inherit_schema + if self.metadata: + event.listen( + self.metadata, + "before_create", + util.portable_instancemethod(self._on_metadata_create) + ) + event.listen( + self.metadata, + "after_drop", + util.portable_instancemethod(self._on_metadata_drop) + ) + + def _set_parent(self, column): + column._on_table_attach(util.portable_instancemethod(self._set_table)) + + def _set_table(self, column, table): + if self.inherit_schema: + self.schema = table.schema + + event.listen( + table, + "before_create", + util.portable_instancemethod( + self._on_table_create) + ) + event.listen( + table, + "after_drop", + util.portable_instancemethod(self._on_table_drop) + ) + if self.metadata is None: + # TODO: what's the difference between self.metadata + # and table.metadata here ? + event.listen( + table.metadata, + "before_create", + util.portable_instancemethod(self._on_metadata_create) + ) + event.listen( + table.metadata, + "after_drop", + util.portable_instancemethod(self._on_metadata_drop) + ) + + def copy(self, **kw): + return self.adapt(self.__class__) + + def adapt(self, impltype, **kw): + schema = kw.pop('schema', self.schema) + metadata = kw.pop('metadata', self.metadata) + return impltype(name=self.name, + schema=schema, + metadata=metadata, + inherit_schema=self.inherit_schema, + **kw + ) + + @property + def bind(self): + return self.metadata and self.metadata.bind or None + + def create(self, bind=None, checkfirst=False): + """Issue CREATE ddl for this type, if applicable.""" + + if bind is None: + bind = _bind_or_error(self) + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t.create(bind=bind, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=False): + """Issue DROP ddl for this type, if applicable.""" + + if bind is None: + bind = _bind_or_error(self) + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t.drop(bind=bind, checkfirst=checkfirst) + + def _on_table_create(self, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t._on_table_create(target, bind, **kw) + + def _on_table_drop(self, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t._on_table_drop(target, bind, **kw) + + def _on_metadata_create(self, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t._on_metadata_create(target, bind, **kw) + + def _on_metadata_drop(self, target, bind, **kw): + t = self.dialect_impl(bind.dialect) + if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + t._on_metadata_drop(target, bind, **kw) + +class Enum(String, SchemaType): + """Generic Enum Type. + + The Enum type provides a set of possible string values which the + column is constrained towards. + + By default, uses the backend's native ENUM type if available, + else uses VARCHAR + a CHECK constraint. + + .. seealso:: + + :class:`~.postgresql.ENUM` - PostgreSQL-specific type, + which has additional functionality. + + """ + + __visit_name__ = 'enum' + + def __init__(self, *enums, **kw): + """Construct an enum. + + Keyword arguments which don't apply to a specific backend are ignored + by that backend. + + :param \*enums: string or unicode enumeration labels. If unicode + labels are present, the `convert_unicode` flag is auto-enabled. + + :param convert_unicode: Enable unicode-aware bind parameter and + result-set processing for this Enum's data. This is set + automatically based on the presence of unicode label strings. + + :param metadata: Associate this type directly with a ``MetaData`` + object. For types that exist on the target database as an + independent schema construct (Postgresql), this type will be + created and dropped within ``create_all()`` and ``drop_all()`` + operations. If the type is not associated with any ``MetaData`` + object, it will associate itself with each ``Table`` in which it is + used, and will be created when any of those individual tables are + created, after a check is performed for it's existence. The type is + only dropped when ``drop_all()`` is called for that ``Table`` + object's metadata, however. + + :param name: The name of this type. This is required for Postgresql + and any future supported database which requires an explicitly + named type, or an explicitly named constraint in order to generate + the type and/or a table that uses it. + + :param native_enum: Use the database's native ENUM type when + available. Defaults to True. When False, uses VARCHAR + check + constraint for all backends. + + :param schema: Schema name of this type. For types that exist on the + target database as an independent schema construct (Postgresql), + this parameter specifies the named schema in which the type is + present. + + .. note:: + + The ``schema`` of the :class:`.Enum` type does not + by default make use of the ``schema`` established on the + owning :class:`.Table`. If this behavior is desired, + set the ``inherit_schema`` flag to ``True``. + + :param quote: Set explicit quoting preferences for the type's name. + + :param inherit_schema: When ``True``, the "schema" from the owning + :class:`.Table` will be copied to the "schema" attribute of this + :class:`.Enum`, replacing whatever value was passed for the + ``schema`` attribute. This also takes effect when using the + :meth:`.Table.tometadata` operation. + + .. versionadded:: 0.8 + + """ + self.enums = enums + self.native_enum = kw.pop('native_enum', True) + convert_unicode = kw.pop('convert_unicode', None) + if convert_unicode is None: + for e in enums: + if isinstance(e, util.text_type): + convert_unicode = True + break + else: + convert_unicode = False + + if self.enums: + length = max(len(x) for x in self.enums) + else: + length = 0 + String.__init__(self, + length=length, + convert_unicode=convert_unicode, + ) + SchemaType.__init__(self, **kw) + + def __repr__(self): + return util.generic_repr(self, + to_inspect=[Enum, SchemaType], + ) + + def _should_create_constraint(self, compiler): + return not self.native_enum or \ + not compiler.dialect.supports_native_enum + + @util.dependencies("sqlalchemy.sql.schema") + def _set_table(self, schema, column, table): + if self.native_enum: + SchemaType._set_table(self, column, table) + + e = schema.CheckConstraint( + type_coerce(column, self).in_(self.enums), + name=self.name, + _create_rule=util.portable_instancemethod( + self._should_create_constraint) + ) + assert e.table is table + + def adapt(self, impltype, **kw): + schema = kw.pop('schema', self.schema) + metadata = kw.pop('metadata', self.metadata) + if issubclass(impltype, Enum): + return impltype(name=self.name, + schema=schema, + metadata=metadata, + convert_unicode=self.convert_unicode, + native_enum=self.native_enum, + inherit_schema=self.inherit_schema, + *self.enums, + **kw + ) + else: + return super(Enum, self).adapt(impltype, **kw) + + +class PickleType(TypeDecorator): + """Holds Python objects, which are serialized using pickle. + + PickleType builds upon the Binary type to apply Python's + ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on + the way out, allowing any pickleable Python object to be stored as + a serialized binary field. + + To allow ORM change events to propagate for elements associated + with :class:`.PickleType`, see :ref:`mutable_toplevel`. + + """ + + impl = LargeBinary + + def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, + pickler=None, comparator=None): + """ + Construct a PickleType. + + :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``. + + :param pickler: defaults to cPickle.pickle or pickle.pickle if + cPickle is not available. May be any object with + pickle-compatible ``dumps` and ``loads`` methods. + + :param comparator: a 2-arg callable predicate used + to compare values of this type. If left as ``None``, + the Python "equals" operator is used to compare values. + + """ + self.protocol = protocol + self.pickler = pickler or pickle + self.comparator = comparator + super(PickleType, self).__init__() + + def __reduce__(self): + return PickleType, (self.protocol, + None, + self.comparator) + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + dumps = self.pickler.dumps + protocol = self.protocol + if impl_processor: + def process(value): + if value is not None: + value = dumps(value, protocol) + return impl_processor(value) + else: + def process(value): + if value is not None: + value = dumps(value, protocol) + return value + return process + + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) + loads = self.pickler.loads + if impl_processor: + def process(value): + value = impl_processor(value) + if value is None: + return None + return loads(value) + else: + def process(value): + if value is None: + return None + return loads(value) + return process + + def compare_values(self, x, y): + if self.comparator: + return self.comparator(x, y) + else: + return x == y + + +class Boolean(TypeEngine, SchemaType): + """A bool datatype. + + Boolean typically uses BOOLEAN or SMALLINT on the DDL side, and on + the Python side deals in ``True`` or ``False``. + + """ + + __visit_name__ = 'boolean' + + def __init__(self, create_constraint=True, name=None): + """Construct a Boolean. + + :param create_constraint: defaults to True. If the boolean + is generated as an int/smallint, also create a CHECK constraint + on the table that ensures 1 or 0 as a value. + + :param name: if a CHECK constraint is generated, specify + the name of the constraint. + + """ + self.create_constraint = create_constraint + self.name = name + + def _should_create_constraint(self, compiler): + return not compiler.dialect.supports_native_boolean + + @util.dependencies("sqlalchemy.sql.schema") + def _set_table(self, schema, column, table): + if not self.create_constraint: + return + + e = schema.CheckConstraint( + type_coerce(column, self).in_([0, 1]), + name=self.name, + _create_rule=util.portable_instancemethod( + self._should_create_constraint) + ) + assert e.table is table + + @property + def python_type(self): + return bool + + def literal_processor(self, dialect): + if dialect.supports_native_boolean: + def process(value): + return "true" if value else "false" + else: + def process(value): + return str(1 if value else 0) + return process + + def bind_processor(self, dialect): + if dialect.supports_native_boolean: + return None + else: + return processors.boolean_to_int + + def result_processor(self, dialect, coltype): + if dialect.supports_native_boolean: + return None + else: + return processors.int_to_boolean + + +class Interval(_DateAffinity, TypeDecorator): + """A type for ``datetime.timedelta()`` objects. + + The Interval type deals with ``datetime.timedelta`` objects. In + PostgreSQL, the native ``INTERVAL`` type is used; for others, the + value is stored as a date which is relative to the "epoch" + (Jan. 1, 1970). + + Note that the ``Interval`` type does not currently provide date arithmetic + operations on platforms which do not support interval types natively. Such + operations usually require transformation of both sides of the expression + (such as, conversion of both sides into integer epoch values first) which + currently is a manual procedure (such as via + :attr:`~sqlalchemy.sql.expression.func`). + + """ + + impl = DateTime + epoch = dt.datetime.utcfromtimestamp(0) + + def __init__(self, native=True, + second_precision=None, + day_precision=None): + """Construct an Interval object. + + :param native: when True, use the actual + INTERVAL type provided by the database, if + supported (currently Postgresql, Oracle). + Otherwise, represent the interval data as + an epoch value regardless. + + :param second_precision: For native interval types + which support a "fractional seconds precision" parameter, + i.e. Oracle and Postgresql + + :param day_precision: for native interval types which + support a "day precision" parameter, i.e. Oracle. + + """ + super(Interval, self).__init__() + self.native = native + self.second_precision = second_precision + self.day_precision = day_precision + + def adapt(self, cls, **kw): + if self.native and hasattr(cls, '_adapt_from_generic_interval'): + return cls._adapt_from_generic_interval(self, **kw) + else: + return self.__class__( + native=self.native, + second_precision=self.second_precision, + day_precision=self.day_precision, + **kw) + + @property + def python_type(self): + return dt.timedelta + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + epoch = self.epoch + if impl_processor: + def process(value): + if value is not None: + value = epoch + value + return impl_processor(value) + else: + def process(value): + if value is not None: + value = epoch + value + return value + return process + + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) + epoch = self.epoch + if impl_processor: + def process(value): + value = impl_processor(value) + if value is None: + return None + return value - epoch + else: + def process(value): + if value is None: + return None + return value - epoch + return process + + @util.memoized_property + def _expression_adaptations(self): + return { + operators.add: { + Date: DateTime, + Interval: self.__class__, + DateTime: DateTime, + Time: Time, + }, + operators.sub: { + Interval: self.__class__ + }, + operators.mul: { + Numeric: self.__class__ + }, + operators.truediv: { + Numeric: self.__class__ + }, + operators.div: { + Numeric: self.__class__ + } + } + + @property + def _type_affinity(self): + return Interval + + def coerce_compared_value(self, op, value): + """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" + + return self.impl.coerce_compared_value(op, value) + + +class REAL(Float): + """The SQL REAL type.""" + + __visit_name__ = 'REAL' + + +class FLOAT(Float): + """The SQL FLOAT type.""" + + __visit_name__ = 'FLOAT' + + +class NUMERIC(Numeric): + """The SQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + +class DECIMAL(Numeric): + """The SQL DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + +class INTEGER(Integer): + """The SQL INT or INTEGER type.""" + + __visit_name__ = 'INTEGER' +INT = INTEGER + + +class SMALLINT(SmallInteger): + """The SQL SMALLINT type.""" + + __visit_name__ = 'SMALLINT' + + +class BIGINT(BigInteger): + """The SQL BIGINT type.""" + + __visit_name__ = 'BIGINT' + + +class TIMESTAMP(DateTime): + """The SQL TIMESTAMP type.""" + + __visit_name__ = 'TIMESTAMP' + + def get_dbapi_type(self, dbapi): + return dbapi.TIMESTAMP + + +class DATETIME(DateTime): + """The SQL DATETIME type.""" + + __visit_name__ = 'DATETIME' + + +class DATE(Date): + """The SQL DATE type.""" + + __visit_name__ = 'DATE' + + +class TIME(Time): + """The SQL TIME type.""" + + __visit_name__ = 'TIME' + + +class TEXT(Text): + """The SQL TEXT type.""" + + __visit_name__ = 'TEXT' + + +class CLOB(Text): + """The CLOB type. + + This type is found in Oracle and Informix. + """ + + __visit_name__ = 'CLOB' + + +class VARCHAR(String): + """The SQL VARCHAR type.""" + + __visit_name__ = 'VARCHAR' + + +class NVARCHAR(Unicode): + """The SQL NVARCHAR type.""" + + __visit_name__ = 'NVARCHAR' + + +class CHAR(String): + """The SQL CHAR type.""" + + __visit_name__ = 'CHAR' + + +class NCHAR(Unicode): + """The SQL NCHAR type.""" + + __visit_name__ = 'NCHAR' + + +class BLOB(LargeBinary): + """The SQL BLOB type.""" + + __visit_name__ = 'BLOB' + + +class BINARY(_Binary): + """The SQL BINARY type.""" + + __visit_name__ = 'BINARY' + + +class VARBINARY(_Binary): + """The SQL VARBINARY type.""" + + __visit_name__ = 'VARBINARY' + + +class BOOLEAN(Boolean): + """The SQL BOOLEAN type.""" + + __visit_name__ = 'BOOLEAN' + +class NullType(TypeEngine): + """An unknown type. + + :class:`.NullType` is used as a default type for those cases where + a type cannot be determined, including: + + * During table reflection, when the type of a column is not recognized + by the :class:`.Dialect` + * When constructing SQL expressions using plain Python objects of + unknown types (e.g. ``somecolumn == my_special_object``) + * When a new :class:`.Column` is created, and the given type is passed + as ``None`` or is not passed at all. + + The :class:`.NullType` can be used within SQL expression invocation + without issue, it just has no behavior either at the expression construction + level or at the bind-parameter/result processing level. :class:`.NullType` + will result in a :exc:`.CompileError` if the compiler is asked to render + the type itself, such as if it is used in a :func:`.cast` operation + or within a schema creation operation such as that invoked by + :meth:`.MetaData.create_all` or the :class:`.CreateTable` construct. + + """ + __visit_name__ = 'null' + + _isnull = True + + def literal_processor(self, dialect): + def process(value): + return "NULL" + return process + + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if isinstance(other_comparator, NullType.Comparator) or \ + not operators.is_commutative(op): + return op, self.expr.type + else: + return other_comparator._adapt_expression(op, self) + comparator_factory = Comparator + + +NULLTYPE = NullType() +BOOLEANTYPE = Boolean() +STRINGTYPE = String() +INTEGERTYPE = Integer() + +_type_map = { + int: Integer(), + float: Numeric(), + bool: BOOLEANTYPE, + decimal.Decimal: Numeric(), + dt.date: Date(), + dt.datetime: DateTime(), + dt.time: Time(), + dt.timedelta: Interval(), + util.NoneType: NULLTYPE +} + +if util.py3k: + _type_map[bytes] = LargeBinary() + _type_map[str] = Unicode() +else: + _type_map[unicode] = Unicode() + _type_map[str] = String() + + +# back-assign to type_api +from . import type_api +type_api.BOOLEANTYPE = BOOLEANTYPE +type_api.STRINGTYPE = STRINGTYPE +type_api.INTEGERTYPE = INTEGERTYPE +type_api.NULLTYPE = NULLTYPE +type_api._type_map = _type_map + +# this one, there's all kinds of ways to play it, but at the EOD +# there's just a giant dependency cycle between the typing system and +# the expression element system, as you might expect. We can use +# importlaters or whatnot, but the typing system just necessarily has +# to have some kind of connection like this. right now we're injecting the +# _DefaultColumnComparator implementation into the TypeEngine.Comparator interface. +# Alternatively TypeEngine.Comparator could have an "impl" injected, though +# just injecting the base is simpler, error free, and more performant. +class Comparator(_DefaultColumnComparator): + BOOLEANTYPE = BOOLEANTYPE + +TypeEngine.Comparator.__bases__ = (Comparator, ) + TypeEngine.Comparator.__bases__ + diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py new file mode 100644 index 00000000..1f534c55 --- /dev/null +++ b/lib/sqlalchemy/sql/type_api.py @@ -0,0 +1,1064 @@ +# sql/types_api.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base types API. + +""" + + +from .. import exc, util +from . import operators +from .visitors import Visitable + +# these are back-assigned by sqltypes. +BOOLEANTYPE = None +INTEGERTYPE = None +NULLTYPE = None +STRINGTYPE = None + +class TypeEngine(Visitable): + """The ultimate base class for all SQL datatypes. + + Common subclasses of :class:`.TypeEngine` include + :class:`.String`, :class:`.Integer`, and :class:`.Boolean`. + + For an overview of the SQLAlchemy typing system, see :ref:`types_toplevel`. + + .. seealso:: + + :ref:`types_toplevel` + + """ + + _sqla_type = True + _isnull = False + + class Comparator(operators.ColumnOperators): + """Base class for custom comparison operations defined at the + type level. See :attr:`.TypeEngine.comparator_factory`. + + + """ + + def __init__(self, expr): + self.expr = expr + + def __reduce__(self): + return _reconstitute_comparator, (self.expr, ) + + + hashable = True + """Flag, if False, means values from this type aren't hashable. + + Used by the ORM when uniquing result lists. + + """ + + comparator_factory = Comparator + """A :class:`.TypeEngine.Comparator` class which will apply + to operations performed by owning :class:`.ColumnElement` objects. + + The :attr:`.comparator_factory` attribute is a hook consulted by + the core expression system when column and SQL expression operations + are performed. When a :class:`.TypeEngine.Comparator` class is + associated with this attribute, it allows custom re-definition of + all existing operators, as well as definition of new operators. + Existing operators include those provided by Python operator overloading + such as :meth:`.operators.ColumnOperators.__add__` and + :meth:`.operators.ColumnOperators.__eq__`, + those provided as standard + attributes of :class:`.operators.ColumnOperators` such as + :meth:`.operators.ColumnOperators.like` + and :meth:`.operators.ColumnOperators.in_`. + + Rudimentary usage of this hook is allowed through simple subclassing + of existing types, or alternatively by using :class:`.TypeDecorator`. + See the documentation section :ref:`types_operators` for examples. + + .. versionadded:: 0.8 The expression system was enhanced to support + customization of operators on a per-type level. + + """ + + def copy_value(self, value): + return value + + def literal_processor(self, dialect): + """Return a conversion function for processing literal values that are + to be rendered directly without using binds. + + This function is used when the compiler makes use of the + "literal_binds" flag, typically used in DDL generation as well + as in certain scenarios where backends don't accept bound parameters. + + .. versionadded:: 0.9.0 + + """ + return None + + def bind_processor(self, dialect): + """Return a conversion function for processing bind values. + + Returns a callable which will receive a bind parameter value + as the sole positional argument and will return a value to + send to the DB-API. + + If processing is not necessary, the method should return ``None``. + + :param dialect: Dialect instance in use. + + """ + return None + + def result_processor(self, dialect, coltype): + """Return a conversion function for processing result row values. + + Returns a callable which will receive a result row column + value as the sole positional argument and will return a value + to return to the user. + + If processing is not necessary, the method should return ``None``. + + :param dialect: Dialect instance in use. + + :param coltype: DBAPI coltype argument received in cursor.description. + + """ + return None + + def column_expression(self, colexpr): + """Given a SELECT column expression, return a wrapping SQL expression. + + This is typically a SQL function that wraps a column expression + as rendered in the columns clause of a SELECT statement. + It is used for special data types that require + columns to be wrapped in some special database function in order + to coerce the value before being sent back to the application. + It is the SQL analogue of the :meth:`.TypeEngine.result_processor` + method. + + The method is evaluated at statement compile time, as opposed + to statement construction time. + + See also: + + :ref:`types_sql_value_processing` + + """ + + return None + + @util.memoized_property + def _has_column_expression(self): + """memoized boolean, check if column_expression is implemented. + + Allows the method to be skipped for the vast majority of expression + types that don't use this feature. + + """ + + return self.__class__.column_expression.__code__ \ + is not TypeEngine.column_expression.__code__ + + def bind_expression(self, bindvalue): + """"Given a bind value (i.e. a :class:`.BindParameter` instance), + return a SQL expression in its place. + + This is typically a SQL function that wraps the existing bound + parameter within the statement. It is used for special data types + that require literals being wrapped in some special database function + in order to coerce an application-level value into a database-specific + format. It is the SQL analogue of the + :meth:`.TypeEngine.bind_processor` method. + + The method is evaluated at statement compile time, as opposed + to statement construction time. + + Note that this method, when implemented, should always return + the exact same structure, without any conditional logic, as it + may be used in an executemany() call against an arbitrary number + of bound parameter sets. + + See also: + + :ref:`types_sql_value_processing` + + """ + return None + + @util.memoized_property + def _has_bind_expression(self): + """memoized boolean, check if bind_expression is implemented. + + Allows the method to be skipped for the vast majority of expression + types that don't use this feature. + + """ + + return self.__class__.bind_expression.__code__ \ + is not TypeEngine.bind_expression.__code__ + + def compare_values(self, x, y): + """Compare two values for equality.""" + + return x == y + + def get_dbapi_type(self, dbapi): + """Return the corresponding type object from the underlying DB-API, if + any. + + This can be useful for calling ``setinputsizes()``, for example. + + """ + return None + + @property + def python_type(self): + """Return the Python type object expected to be returned + by instances of this type, if known. + + Basically, for those types which enforce a return type, + or are known across the board to do such for all common + DBAPIs (like ``int`` for example), will return that type. + + If a return type is not defined, raises + ``NotImplementedError``. + + Note that any type also accommodates NULL in SQL which + means you can also get back ``None`` from any type + in practice. + + """ + raise NotImplementedError() + + def with_variant(self, type_, dialect_name): + """Produce a new type object that will utilize the given + type when applied to the dialect of the given name. + + e.g.:: + + from sqlalchemy.types import String + from sqlalchemy.dialects import mysql + + s = String() + + s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql') + + The construction of :meth:`.TypeEngine.with_variant` is always + from the "fallback" type to that which is dialect specific. + The returned type is an instance of :class:`.Variant`, which + itself provides a :meth:`~sqlalchemy.types.Variant.with_variant` + that can be called repeatedly. + + :param type_: a :class:`.TypeEngine` that will be selected + as a variant from the originating type, when a dialect + of the given name is in use. + :param dialect_name: base name of the dialect which uses + this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + + .. versionadded:: 0.7.2 + + """ + return Variant(self, {dialect_name: type_}) + + + @util.memoized_property + def _type_affinity(self): + """Return a rudimental 'affinity' value expressing the general class + of type.""" + + typ = None + for t in self.__class__.__mro__: + if t in (TypeEngine, UserDefinedType): + return typ + elif issubclass(t, (TypeEngine, UserDefinedType)): + typ = t + else: + return self.__class__ + + def dialect_impl(self, dialect): + """Return a dialect-specific implementation for this + :class:`.TypeEngine`. + + """ + try: + return dialect._type_memos[self]['impl'] + except KeyError: + return self._dialect_info(dialect)['impl'] + + + def _cached_literal_processor(self, dialect): + """Return a dialect-specific literal processor for this type.""" + try: + return dialect._type_memos[self]['literal'] + except KeyError: + d = self._dialect_info(dialect) + d['literal'] = lp = d['impl'].literal_processor(dialect) + return lp + + def _cached_bind_processor(self, dialect): + """Return a dialect-specific bind processor for this type.""" + + try: + return dialect._type_memos[self]['bind'] + except KeyError: + d = self._dialect_info(dialect) + d['bind'] = bp = d['impl'].bind_processor(dialect) + return bp + + def _cached_result_processor(self, dialect, coltype): + """Return a dialect-specific result processor for this type.""" + + try: + return dialect._type_memos[self][coltype] + except KeyError: + d = self._dialect_info(dialect) + # key assumption: DBAPI type codes are + # constants. Else this dictionary would + # grow unbounded. + d[coltype] = rp = d['impl'].result_processor(dialect, coltype) + return rp + + def _dialect_info(self, dialect): + """Return a dialect-specific registry which + caches a dialect-specific implementation, bind processing + function, and one or more result processing functions.""" + + if self in dialect._type_memos: + return dialect._type_memos[self] + else: + impl = self._gen_dialect_impl(dialect) + if impl is self: + impl = self.adapt(type(self)) + # this can't be self, else we create a cycle + assert impl is not self + dialect._type_memos[self] = d = {'impl': impl} + return d + + def _gen_dialect_impl(self, dialect): + return dialect.type_descriptor(self) + + def adapt(self, cls, **kw): + """Produce an "adapted" form of this type, given an "impl" class + to work with. + + This method is used internally to associate generic + types with "implementation" types that are specific to a particular + dialect. + """ + return util.constructor_copy(self, cls, **kw) + + + def coerce_compared_value(self, op, value): + """Suggest a type for a 'coerced' Python value in an expression. + + Given an operator and value, gives the type a chance + to return a type which the value should be coerced into. + + The default behavior here is conservative; if the right-hand + side is already coerced into a SQL type based on its + Python type, it is usually left alone. + + End-user functionality extension here should generally be via + :class:`.TypeDecorator`, which provides more liberal behavior in that + it defaults to coercing the other side of the expression into this + type, thus applying special Python conversions above and beyond those + needed by the DBAPI to both ides. It also provides the public method + :meth:`.TypeDecorator.coerce_compared_value` which is intended for + end-user customization of this behavior. + + """ + _coerced_type = _type_map.get(type(value), NULLTYPE) + if _coerced_type is NULLTYPE or _coerced_type._type_affinity \ + is self._type_affinity: + return self + else: + return _coerced_type + + def _compare_type_affinity(self, other): + return self._type_affinity is other._type_affinity + + def compile(self, dialect=None): + """Produce a string-compiled form of this :class:`.TypeEngine`. + + When called with no arguments, uses a "default" dialect + to produce a string result. + + :param dialect: a :class:`.Dialect` instance. + + """ + # arg, return value is inconsistent with + # ClauseElement.compile()....this is a mistake. + + if not dialect: + dialect = self._default_dialect() + + return dialect.type_compiler.process(self) + + @util.dependencies("sqlalchemy.engine.default") + def _default_dialect(self, default): + if self.__class__.__module__.startswith("sqlalchemy.dialects"): + tokens = self.__class__.__module__.split(".")[0:3] + mod = ".".join(tokens) + return getattr(__import__(mod).dialects, tokens[-1]).dialect() + else: + return default.DefaultDialect() + + def __str__(self): + if util.py2k: + return unicode(self.compile()).\ + encode('ascii', 'backslashreplace') + else: + return str(self.compile()) + + def __repr__(self): + return util.generic_repr(self) + +class UserDefinedType(TypeEngine): + """Base for user defined types. + + This should be the base of new types. Note that + for most cases, :class:`.TypeDecorator` is probably + more appropriate:: + + import sqlalchemy.types as types + + class MyType(types.UserDefinedType): + def __init__(self, precision = 8): + self.precision = precision + + def get_col_spec(self): + return "MYTYPE(%s)" % self.precision + + def bind_processor(self, dialect): + def process(value): + return value + return process + + def result_processor(self, dialect, coltype): + def process(value): + return value + return process + + Once the type is made, it's immediately usable:: + + table = Table('foo', meta, + Column('id', Integer, primary_key=True), + Column('data', MyType(16)) + ) + + """ + __visit_name__ = "user_defined" + + + class Comparator(TypeEngine.Comparator): + def _adapt_expression(self, op, other_comparator): + if hasattr(self.type, 'adapt_operator'): + util.warn_deprecated( + "UserDefinedType.adapt_operator is deprecated. Create " + "a UserDefinedType.Comparator subclass instead which " + "generates the desired expression constructs, given a " + "particular operator." + ) + return self.type.adapt_operator(op), self.type + else: + return op, self.type + + comparator_factory = Comparator + + def coerce_compared_value(self, op, value): + """Suggest a type for a 'coerced' Python value in an expression. + + Default behavior for :class:`.UserDefinedType` is the + same as that of :class:`.TypeDecorator`; by default it returns + ``self``, assuming the compared value should be coerced into + the same type as this one. See + :meth:`.TypeDecorator.coerce_compared_value` for more detail. + + .. versionchanged:: 0.8 :meth:`.UserDefinedType.coerce_compared_value` + now returns ``self`` by default, rather than falling onto the + more fundamental behavior of + :meth:`.TypeEngine.coerce_compared_value`. + + """ + + return self + + +class TypeDecorator(TypeEngine): + """Allows the creation of types which add additional functionality + to an existing type. + + This method is preferred to direct subclassing of SQLAlchemy's + built-in types as it ensures that all required functionality of + the underlying type is kept in place. + + Typical usage:: + + import sqlalchemy.types as types + + class MyType(types.TypeDecorator): + '''Prefixes Unicode values with "PREFIX:" on the way in and + strips it off on the way out. + ''' + + impl = types.Unicode + + def process_bind_param(self, value, dialect): + return "PREFIX:" + value + + def process_result_value(self, value, dialect): + return value[7:] + + def copy(self): + return MyType(self.impl.length) + + The class-level "impl" attribute is required, and can reference any + TypeEngine class. Alternatively, the load_dialect_impl() method + can be used to provide different type classes based on the dialect + given; in this case, the "impl" variable can reference + ``TypeEngine`` as a placeholder. + + Types that receive a Python type that isn't similar to the ultimate type + used may want to define the :meth:`TypeDecorator.coerce_compared_value` + method. This is used to give the expression system a hint when coercing + Python objects into bind parameters within expressions. Consider this + expression:: + + mytable.c.somecol + datetime.date(2009, 5, 15) + + Above, if "somecol" is an ``Integer`` variant, it makes sense that + we're doing date arithmetic, where above is usually interpreted + by databases as adding a number of days to the given date. + The expression system does the right thing by not attempting to + coerce the "date()" value into an integer-oriented bind parameter. + + However, in the case of ``TypeDecorator``, we are usually changing an + incoming Python type to something new - ``TypeDecorator`` by default will + "coerce" the non-typed side to be the same type as itself. Such as below, + we define an "epoch" type that stores a date value as an integer:: + + class MyEpochType(types.TypeDecorator): + impl = types.Integer + + epoch = datetime.date(1970, 1, 1) + + def process_bind_param(self, value, dialect): + return (value - self.epoch).days + + def process_result_value(self, value, dialect): + return self.epoch + timedelta(days=value) + + Our expression of ``somecol + date`` with the above type will coerce the + "date" on the right side to also be treated as ``MyEpochType``. + + This behavior can be overridden via the + :meth:`~TypeDecorator.coerce_compared_value` method, which returns a type + that should be used for the value of the expression. Below we set it such + that an integer value will be treated as an ``Integer``, and any other + value is assumed to be a date and will be treated as a ``MyEpochType``:: + + def coerce_compared_value(self, op, value): + if isinstance(value, int): + return Integer() + else: + return self + + """ + + __visit_name__ = "type_decorator" + + def __init__(self, *args, **kwargs): + """Construct a :class:`.TypeDecorator`. + + Arguments sent here are passed to the constructor + of the class assigned to the ``impl`` class level attribute, + assuming the ``impl`` is a callable, and the resulting + object is assigned to the ``self.impl`` instance attribute + (thus overriding the class attribute of the same name). + + If the class level ``impl`` is not a callable (the unusual case), + it will be assigned to the same instance attribute 'as-is', + ignoring those arguments passed to the constructor. + + Subclasses can override this to customize the generation + of ``self.impl`` entirely. + + """ + + if not hasattr(self.__class__, 'impl'): + raise AssertionError("TypeDecorator implementations " + "require a class-level variable " + "'impl' which refers to the class of " + "type being decorated") + self.impl = to_instance(self.__class__.impl, *args, **kwargs) + + coerce_to_is_types = (util.NoneType, ) + """Specify those Python types which should be coerced at the expression + level to "IS " when compared using ``==`` (and same for + ``IS NOT`` in conjunction with ``!=``. + + For most SQLAlchemy types, this includes ``NoneType``, as well as ``bool``. + + :class:`.TypeDecorator` modifies this list to only include ``NoneType``, + as typedecorator implementations that deal with boolean types are common. + + Custom :class:`.TypeDecorator` classes can override this attribute to + return an empty tuple, in which case no values will be coerced to + constants. + + ..versionadded:: 0.8.2 + Added :attr:`.TypeDecorator.coerce_to_is_types` to allow for easier + control of ``__eq__()`` ``__ne__()`` operations. + + """ + + class Comparator(TypeEngine.Comparator): + + def operate(self, op, *other, **kwargs): + kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + return super(TypeDecorator.Comparator, self).operate( + op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + return super(TypeDecorator.Comparator, self).reverse_operate( + op, other, **kwargs) + + @property + def comparator_factory(self): + return type("TDComparator", + (TypeDecorator.Comparator, self.impl.comparator_factory), + {}) + + def _gen_dialect_impl(self, dialect): + """ + #todo + """ + adapted = dialect.type_descriptor(self) + if adapted is not self: + return adapted + + # otherwise adapt the impl type, link + # to a copy of this TypeDecorator and return + # that. + typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect) + tt = self.copy() + if not isinstance(tt, self.__class__): + raise AssertionError('Type object %s does not properly ' + 'implement the copy() method, it must ' + 'return an object of type %s' % (self, + self.__class__)) + tt.impl = typedesc + return tt + + @property + def _type_affinity(self): + """ + #todo + """ + return self.impl._type_affinity + + def type_engine(self, dialect): + """Return a dialect-specific :class:`.TypeEngine` instance + for this :class:`.TypeDecorator`. + + In most cases this returns a dialect-adapted form of + the :class:`.TypeEngine` type represented by ``self.impl``. + Makes usage of :meth:`dialect_impl` but also traverses + into wrapped :class:`.TypeDecorator` instances. + Behavior can be customized here by overriding + :meth:`load_dialect_impl`. + + """ + adapted = dialect.type_descriptor(self) + if type(adapted) is not type(self): + return adapted + elif isinstance(self.impl, TypeDecorator): + return self.impl.type_engine(dialect) + else: + return self.load_dialect_impl(dialect) + + def load_dialect_impl(self, dialect): + """Return a :class:`.TypeEngine` object corresponding to a dialect. + + This is an end-user override hook that can be used to provide + differing types depending on the given dialect. It is used + by the :class:`.TypeDecorator` implementation of :meth:`type_engine` + to help determine what type should ultimately be returned + for a given :class:`.TypeDecorator`. + + By default returns ``self.impl``. + + """ + return self.impl + + def __getattr__(self, key): + """Proxy all other undefined accessors to the underlying + implementation.""" + return getattr(self.impl, key) + + def process_literal_param(self, value, dialect): + """Receive a literal parameter value to be rendered inline within + a statement. + + This method is used when the compiler renders a + literal value without using binds, typically within DDL + such as in the "server default" of a column or an expression + within a CHECK constraint. + + The returned string will be rendered into the output string. + + .. versionadded:: 0.9.0 + + """ + raise NotImplementedError() + + def process_bind_param(self, value, dialect): + """Receive a bound parameter value to be converted. + + Subclasses override this method to return the + value that should be passed along to the underlying + :class:`.TypeEngine` object, and from there to the + DBAPI ``execute()`` method. + + The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + This operation should be designed with the reverse operation + in mind, which would be the process_result_value method of + this class. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. + :param dialect: the :class:`.Dialect` in use. + + """ + + raise NotImplementedError() + + def process_result_value(self, value, dialect): + """Receive a result-row column value to be converted. + + Subclasses should implement this method to operate on data + fetched from the database. + + Subclasses override this method to return the + value that should be passed back to the application, + given a value that is already processed by + the underlying :class:`.TypeEngine` object, originally + from the DBAPI cursor method ``fetchone()`` or similar. + + The operation could be anything desired to perform custom + behavior, such as transforming or serializing data. + This could also be used as a hook for validating logic. + + :param value: Data to operate upon, of any type expected by + this method in the subclass. Can be ``None``. + :param dialect: the :class:`.Dialect` in use. + + This operation should be designed to be reversible by + the "process_bind_param" method of this class. + + """ + + raise NotImplementedError() + + @util.memoized_property + def _has_bind_processor(self): + """memoized boolean, check if process_bind_param is implemented. + + Allows the base process_bind_param to raise + NotImplementedError without needing to test an expensive + exception throw. + + """ + + return self.__class__.process_bind_param.__code__ \ + is not TypeDecorator.process_bind_param.__code__ + + @util.memoized_property + def _has_literal_processor(self): + """memoized boolean, check if process_literal_param is implemented. + + + """ + + return self.__class__.process_literal_param.__code__ \ + is not TypeDecorator.process_literal_param.__code__ + + def literal_processor(self, dialect): + """Provide a literal processing function for the given + :class:`.Dialect`. + + Subclasses here will typically override :meth:`.TypeDecorator.process_literal_param` + instead of this method directly. + + By default, this method makes use of :meth:`.TypeDecorator.process_bind_param` + if that method is implemented, where :meth:`.TypeDecorator.process_literal_param` + is not. The rationale here is that :class:`.TypeDecorator` typically deals + with Python conversions of data that are above the layer of database + presentation. With the value converted by :meth:`.TypeDecorator.process_bind_param`, + the underlying type will then handle whether it needs to be presented to the + DBAPI as a bound parameter or to the database as an inline SQL value. + + .. versionadded:: 0.9.0 + + """ + if self._has_literal_processor: + process_param = self.process_literal_param + elif self._has_bind_processor: + # the bind processor should normally be OK + # for TypeDecorator since it isn't doing DB-level + # handling, the handling here won't be different for bound vs. + # literals. + process_param = self.process_bind_param + else: + process_param = None + + if process_param: + impl_processor = self.impl.literal_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(process_param(value, dialect)) + else: + def process(value): + return process_param(value, dialect) + + return process + else: + return self.impl.literal_processor(dialect) + + def bind_processor(self, dialect): + """Provide a bound value processing function for the + given :class:`.Dialect`. + + This is the method that fulfills the :class:`.TypeEngine` + contract for bound value conversion. :class:`.TypeDecorator` + will wrap a user-defined implementation of + :meth:`process_bind_param` here. + + User-defined code can override this method directly, + though its likely best to use :meth:`process_bind_param` so that + the processing provided by ``self.impl`` is maintained. + + :param dialect: Dialect instance in use. + + This method is the reverse counterpart to the + :meth:`result_processor` method of this class. + + """ + if self._has_bind_processor: + process_param = self.process_bind_param + impl_processor = self.impl.bind_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(process_param(value, dialect)) + + else: + def process(value): + return process_param(value, dialect) + + return process + else: + return self.impl.bind_processor(dialect) + + @util.memoized_property + def _has_result_processor(self): + """memoized boolean, check if process_result_value is implemented. + + Allows the base process_result_value to raise + NotImplementedError without needing to test an expensive + exception throw. + + """ + return self.__class__.process_result_value.__code__ \ + is not TypeDecorator.process_result_value.__code__ + + def result_processor(self, dialect, coltype): + """Provide a result value processing function for the given + :class:`.Dialect`. + + This is the method that fulfills the :class:`.TypeEngine` + contract for result value conversion. :class:`.TypeDecorator` + will wrap a user-defined implementation of + :meth:`process_result_value` here. + + User-defined code can override this method directly, + though its likely best to use :meth:`process_result_value` so that + the processing provided by ``self.impl`` is maintained. + + :param dialect: Dialect instance in use. + :param coltype: An SQLAlchemy data type + + This method is the reverse counterpart to the + :meth:`bind_processor` method of this class. + + """ + if self._has_result_processor: + process_value = self.process_result_value + impl_processor = self.impl.result_processor(dialect, + coltype) + if impl_processor: + def process(value): + return process_value(impl_processor(value), dialect) + + else: + def process(value): + return process_value(value, dialect) + + return process + else: + return self.impl.result_processor(dialect, coltype) + + def coerce_compared_value(self, op, value): + """Suggest a type for a 'coerced' Python value in an expression. + + By default, returns self. This method is called by + the expression system when an object using this type is + on the left or right side of an expression against a plain Python + object which does not yet have a SQLAlchemy type assigned:: + + expr = table.c.somecolumn + 35 + + Where above, if ``somecolumn`` uses this type, this method will + be called with the value ``operator.add`` + and ``35``. The return value is whatever SQLAlchemy type should + be used for ``35`` for this particular operation. + + """ + return self + + def copy(self): + """Produce a copy of this :class:`.TypeDecorator` instance. + + This is a shallow copy and is provided to fulfill part of + the :class:`.TypeEngine` contract. It usually does not + need to be overridden unless the user-defined :class:`.TypeDecorator` + has local state that should be deep-copied. + + """ + + instance = self.__class__.__new__(self.__class__) + instance.__dict__.update(self.__dict__) + return instance + + def get_dbapi_type(self, dbapi): + """Return the DBAPI type object represented by this + :class:`.TypeDecorator`. + + By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the + underlying "impl". + """ + return self.impl.get_dbapi_type(dbapi) + + def compare_values(self, x, y): + """Given two values, compare them for equality. + + By default this calls upon :meth:`.TypeEngine.compare_values` + of the underlying "impl", which in turn usually + uses the Python equals operator ``==``. + + This function is used by the ORM to compare + an original-loaded value with an intercepted + "changed" value, to determine if a net change + has occurred. + + """ + return self.impl.compare_values(x, y) + + def __repr__(self): + return util.generic_repr(self, to_inspect=self.impl) + + +class Variant(TypeDecorator): + """A wrapping type that selects among a variety of + implementations based on dialect in use. + + The :class:`.Variant` type is typically constructed + using the :meth:`.TypeEngine.with_variant` method. + + .. versionadded:: 0.7.2 + + .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use. + + """ + + def __init__(self, base, mapping): + """Construct a new :class:`.Variant`. + + :param base: the base 'fallback' type + :param mapping: dictionary of string dialect names to + :class:`.TypeEngine` instances. + + """ + self.impl = base + self.mapping = mapping + + def load_dialect_impl(self, dialect): + if dialect.name in self.mapping: + return self.mapping[dialect.name] + else: + return self.impl + + def with_variant(self, type_, dialect_name): + """Return a new :class:`.Variant` which adds the given + type + dialect name to the mapping, in addition to the + mapping present in this :class:`.Variant`. + + :param type_: a :class:`.TypeEngine` that will be selected + as a variant from the originating type, when a dialect + of the given name is in use. + :param dialect_name: base name of the dialect which uses + this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + + """ + + if dialect_name in self.mapping: + raise exc.ArgumentError( + "Dialect '%s' is already present in " + "the mapping for this Variant" % dialect_name) + mapping = self.mapping.copy() + mapping[dialect_name] = type_ + return Variant(self.impl, mapping) + +def _reconstitute_comparator(expression): + return expression.comparator + + +def to_instance(typeobj, *arg, **kw): + if typeobj is None: + return NULLTYPE + + if util.callable(typeobj): + return typeobj(*arg, **kw) + else: + return typeobj + + +def adapt_type(typeobj, colspecs): + if isinstance(typeobj, type): + typeobj = typeobj() + for t in typeobj.__class__.__mro__[0:-1]: + try: + impltype = colspecs[t] + break + except KeyError: + pass + else: + # couldnt adapt - so just return the type itself + # (it may be a user-defined type) + return typeobj + # if we adapted the given generic type to a database-specific type, + # but it turns out the originally given "generic" type + # is actually a subclass of our resulting type, then we were already + # given a more specific type than that required; so use that. + if (issubclass(typeobj.__class__, impltype)): + return typeobj + return typeobj.adapt(impltype) + + diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py new file mode 100644 index 00000000..50ce30aa --- /dev/null +++ b/lib/sqlalchemy/sql/util.py @@ -0,0 +1,601 @@ +# sql/util.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""High level utilities which build upon other modules here. + +""" + +from .. import exc, util +from .base import _from_objects, ColumnSet +from . import operators, visitors +from itertools import chain +from collections import deque + +from .elements import BindParameter, ColumnClause, ColumnElement, \ + Null, UnaryExpression, literal_column, Label +from .selectable import ScalarSelect, Join, FromClause, FromGrouping +from .schema import Column + +join_condition = util.langhelpers.public_factory( + Join._join_condition, + ".sql.util.join_condition") + +# names that are still being imported from the outside +from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate +from .elements import _find_columns +from .ddl import sort_tables + + +def find_join_source(clauses, join_to): + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns + None, None if no match is found. + + e.g.:: + + clause1 = table1.join(table2) + clause2 = table4.join(table5) + + join_to = table2.join(table3) + + find_join_source([clause1, clause2], join_to) == clause1 + + """ + + selectables = list(_from_objects(join_to)) + for i, f in enumerate(clauses): + for s in selectables: + if f.is_derived_from(s): + return i, f + else: + return None, None + + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a q + a e + a f + b q + b e + b f + j r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + + def visit(element): + if isinstance(element, ScalarSelect): + # we dont want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + + +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, + include_selects=False, include_crud=False): + """locate Table objects within the given expression.""" + + tables = [] + _visitors = {} + + if include_selects: + _visitors['select'] = _visitors['compound_select'] = tables.append + + if include_joins: + _visitors['join'] = tables.append + + if include_aliases: + _visitors['alias'] = tables.append + + if include_crud: + _visitors['insert'] = _visitors['update'] = \ + _visitors['delete'] = lambda ent: tables.append(ent.table) + + if check_columns: + def visit_column(column): + tables.append(column.table) + _visitors['column'] = visit_column + + _visitors['table'] = tables.append + + visitors.traverse(clause, {'column_collections': False}, _visitors) + return tables + + + +def unwrap_order_by(clause): + """Break up an 'order by' expression into individual column-expressions, + without DESC/ASC/NULLS FIRST/NULLS LAST""" + + cols = util.column_set() + stack = deque([clause]) + while stack: + t = stack.popleft() + if isinstance(t, ColumnElement) and \ + ( + not isinstance(t, UnaryExpression) or \ + not operators.is_ordering_modifier(t.modifier) + ): + cols.add(t) + else: + for c in t.get_children(): + stack.append(c) + return cols + + +def clause_is_present(clause, search): + """Given a target clause and a second to search within, return True + if the target is plainly present in the search without any + subqueries or aliases involved. + + Basically descends through Joins. + + """ + + for elem in surface_selectables(search): + if clause == elem: # use == here so that Annotated's compare + return True + else: + return False + +def surface_selectables(clause): + stack = [clause] + while stack: + elem = stack.pop() + yield elem + if isinstance(elem, Join): + stack.extend((elem.left, elem.right)) + elif isinstance(elem, FromGrouping): + stack.append(elem.element) + +def selectables_overlap(left, right): + """Return True if left/right have some overlapping selectable""" + + return bool( + set(surface_selectables(left)).intersection( + surface_selectables(right) + ) + ) + +def bind_values(clause): + """Return an ordered list of "bound" values in the given clause. + + E.g.:: + + >>> expr = and_( + ... table.c.foo==5, table.c.foo==7 + ... ) + >>> bind_values(expr) + [5, 7] + """ + + v = [] + + def visit_bindparam(bind): + v.append(bind.effective_value) + + visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + return v + + +def _quote_ddl_expr(element): + if isinstance(element, util.string_types): + element = element.replace("'", "''") + return "'%s'" % element + else: + return repr(element) + + +class _repr_params(object): + """A string view of bound parameters, truncating + display to the given number of 'multi' parameter sets. + + """ + def __init__(self, params, batches): + self.params = params + self.batches = batches + + def __repr__(self): + if isinstance(self.params, (list, tuple)) and \ + len(self.params) > self.batches and \ + isinstance(self.params[0], (list, dict, tuple)): + msg = " ... displaying %i of %i total bound parameter sets ... " + return ' '.join(( + repr(self.params[:self.batches - 2])[0:-1], + msg % (self.batches, len(self.params)), + repr(self.params[-2:])[1:] + )) + else: + return repr(self.params) + + + + +def adapt_criterion_to_null(crit, nulls): + """given criterion containing bind params, convert selected elements + to IS NULL. + + """ + + def visit_binary(binary): + if isinstance(binary.left, BindParameter) \ + and binary.left._identifying_key in nulls: + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = Null() + binary.operator = operators.is_ + binary.negate = operators.isnot + elif isinstance(binary.right, BindParameter) \ + and binary.right._identifying_key in nulls: + binary.right = Null() + binary.operator = operators.is_ + binary.negate = operators.isnot + + return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) + + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, Join) and right is not stop_on: + right = right._clone() + right._reset_exported() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright is not None: + prevright.left = right + if ret is None: + ret = right + + return ret + + +def reduce_columns(columns, *clauses, **kw): + """given a list of columns, return a 'reduced' set based on natural + equivalents. + + the set is reduced to the smallest list of columns which have no natural + equivalent present in the list. A "natural equivalent" means that two + columns will ultimately represent the same value because they are related + by a foreign key. + + \*clauses is an optional list of join clauses which will be traversed + to further identify columns that are "equivalent". + + \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys + whose tables are not yet configured, or columns that aren't yet present. + + This function is primarily used to determine the most minimal "primary key" + from a selectable, by reducing the set of primary key columns present + in the the selectable to just those that are not repeated. + + """ + ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) + only_synonyms = kw.pop('only_synonyms', False) + + columns = util.ordered_column_set(columns) + + omit = util.column_set() + for col in columns: + for fk in chain(*[c.foreign_keys for c in col.proxy_set]): + for c in columns: + if c is col: + continue + try: + fk_col = fk.column + except exc.NoReferencedColumnError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest + if ignore_nonexistent_tables: + continue + else: + raise + except exc.NoReferencedTableError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest + if ignore_nonexistent_tables: + continue + else: + raise + if fk_col.shares_lineage(c) and \ + (not only_synonyms or \ + c.name == col.name): + omit.add(col) + break + + if clauses: + def visit_binary(binary): + if binary.operator == operators.eq: + cols = util.column_set(chain(*[c.proxy_set + for c in columns.difference(omit)])) + if binary.left in cols and binary.right in cols: + for c in reversed(columns): + if c.shares_lineage(binary.right) and \ + (not only_synonyms or \ + c.name == binary.left.name): + omit.add(c) + break + for clause in clauses: + if clause is not None: + visitors.traverse(clause, {}, {'binary': visit_binary}) + + return ColumnSet(columns.difference(omit)) + + +def criterion_as_pairs(expression, consider_as_foreign_keys=None, + consider_as_referenced_keys=None, any_operator=False): + """traverse an expression and locate binary criterion pairs.""" + + if consider_as_foreign_keys and consider_as_referenced_keys: + raise exc.ArgumentError("Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'") + + def col_is(a, b): + #return a is b + return a.compare(b) + + def visit_binary(binary): + if not any_operator and binary.operator is not operators.eq: + return + if not isinstance(binary.left, ColumnElement) or \ + not isinstance(binary.right, ColumnElement): + return + + if consider_as_foreign_keys: + if binary.left in consider_as_foreign_keys and \ + (col_is(binary.right, binary.left) or + binary.right not in consider_as_foreign_keys): + pairs.append((binary.right, binary.left)) + elif binary.right in consider_as_foreign_keys and \ + (col_is(binary.left, binary.right) or + binary.left not in consider_as_foreign_keys): + pairs.append((binary.left, binary.right)) + elif consider_as_referenced_keys: + if binary.left in consider_as_referenced_keys and \ + (col_is(binary.right, binary.left) or + binary.right not in consider_as_referenced_keys): + pairs.append((binary.left, binary.right)) + elif binary.right in consider_as_referenced_keys and \ + (col_is(binary.left, binary.right) or + binary.left not in consider_as_referenced_keys): + pairs.append((binary.right, binary.left)) + else: + if isinstance(binary.left, Column) and \ + isinstance(binary.right, Column): + if binary.left.references(binary.right): + pairs.append((binary.right, binary.left)) + elif binary.right.references(binary.left): + pairs.append((binary.left, binary.right)) + pairs = [] + visitors.traverse(expression, {}, {'binary': visit_binary}) + return pairs + + + +class AliasedRow(object): + """Wrap a RowProxy with a translation map. + + This object allows a set of keys to be translated + to those present in a RowProxy. + + """ + def __init__(self, row, map): + # AliasedRow objects don't nest, so un-nest + # if another AliasedRow was passed + if isinstance(row, AliasedRow): + self.row = row.row + else: + self.row = row + self.map = map + + def __contains__(self, key): + return self.map[key] in self.row + + def has_key(self, key): + return key in self + + def __getitem__(self, key): + return self.row[self.map[key]] + + def keys(self): + return self.row.keys() + + +class ClauseAdapter(visitors.ReplacingCloningVisitor): + """Clones and modifies clauses based on column correspondence. + + E.g.:: + + table1 = Table('sometable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + table2 = Table('someothertable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + + condition = table1.c.col1 == table2.c.col1 + + make an alias of table1:: + + s = table1.alias('foo') + + calling ``ClauseAdapter(s).traverse(condition)`` converts + condition to read:: + + s.c.col1 == table2.c.col1 + + """ + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, + adapt_on_names=False): + self.__traverse_options__ = {'stop_on': [selectable]} + self.selectable = selectable + if include: + assert not include_fn + self.include_fn = lambda e: e in include + else: + self.include_fn = include_fn + if exclude: + assert not exclude_fn + self.exclude_fn = lambda e: e in exclude + else: + self.exclude_fn = exclude_fn + self.equivalents = util.column_dict(equivalents or {}) + self.adapt_on_names = adapt_on_names + + def _corresponding_column(self, col, require_embedded, + _seen=util.EMPTY_SET): + newcol = self.selectable.corresponding_column( + col, + require_embedded=require_embedded) + if newcol is None and col in self.equivalents and col not in _seen: + for equiv in self.equivalents[col]: + newcol = self._corresponding_column(equiv, + require_embedded=require_embedded, + _seen=_seen.union([col])) + if newcol is not None: + return newcol + if self.adapt_on_names and newcol is None: + newcol = self.selectable.c.get(col.name) + return newcol + + magic_flag = False + def replace(self, col): + if not self.magic_flag and isinstance(col, FromClause) and \ + self.selectable.is_derived_from(col): + return self.selectable + elif not isinstance(col, ColumnElement): + return None + elif self.include_fn and not self.include_fn(col): + return None + elif self.exclude_fn and self.exclude_fn(col): + return None + else: + return self._corresponding_column(col, True) + + +class ColumnAdapter(ClauseAdapter): + """Extends ClauseAdapter with extra utility functions. + + Provides the ability to "wrap" this ClauseAdapter + around another, a columns dictionary which returns + adapted elements given an original, and an + adapted_row() factory. + + """ + def __init__(self, selectable, equivalents=None, + chain_to=None, include=None, + exclude=None, adapt_required=False): + ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) + if chain_to: + self.chain(chain_to) + self.columns = util.populate_column_dict(self._locate_col) + self.adapt_required = adapt_required + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__ = self.__dict__.copy() + ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) + ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) + ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) + ac.columns = util.populate_column_dict(ac._locate_col) + return ac + + adapt_clause = ClauseAdapter.traverse + adapt_list = ClauseAdapter.copy_and_process + + def _wrap(self, local, wrapped): + def locate(col): + col = local(col) + return wrapped(col) + return locate + + def _locate_col(self, col): + c = self._corresponding_column(col, True) + if c is None: + c = self.adapt_clause(col) + + # anonymize labels in case they have a hardcoded name + if isinstance(c, Label): + c = c.label(None) + + # adapt_required used by eager loading to indicate that + # we don't trust a result row column that is not translated. + # this is to prevent a column from being interpreted as that + # of the child row in a self-referential scenario, see + # inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency + if self.adapt_required and c is col: + return None + + return c + + def adapted_row(self, row): + return AliasedRow(row, self.columns) + + def __getstate__(self): + d = self.__dict__.copy() + del d['columns'] + return d + + def __setstate__(self, state): + self.__dict__.update(state) + self.columns = util.PopulateDict(self._locate_col) + diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py new file mode 100644 index 00000000..d9ad04fc --- /dev/null +++ b/lib/sqlalchemy/sql/visitors.py @@ -0,0 +1,314 @@ +# sql/visitors.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Visitor/traversal interface and library functions. + +SQLAlchemy schema and expression constructs rely on a Python-centric +version of the classic "visitor" pattern as the primary way in which +they apply functionality. The most common use of this pattern +is statement compilation, where individual expression classes match +up to rendering methods that produce a string result. Beyond this, +the visitor system is also used to inspect expressions for various +information and patterns, as well as for usage in +some kinds of expression transformation. Other kinds of transformation +use a non-visitor traversal system. + +For many examples of how the visit system is used, see the +sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules. +For an introduction to clause adaption, see +http://techspot.zzzeek.org/2008/01/23/expression-transformations/ + +""" + +from collections import deque +from .. import util +import operator +from .. import exc + +__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', + 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', + 'iterate_depthfirst', 'traverse_using', 'traverse', + 'traverse_depthfirst', + 'cloned_traverse', 'replacement_traverse'] + + +class VisitableType(type): + """Metaclass which assigns a `_compiler_dispatch` method to classes + having a `__visit_name__` attribute. + + The _compiler_dispatch attribute becomes an instance method which + looks approximately like the following:: + + def _compiler_dispatch (self, visitor, **kw): + '''Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params.''' + visit_attr = 'visit_%s' % self.__visit_name__ + return getattr(visitor, visit_attr)(self, **kw) + + Classes having no __visit_name__ attribute will remain unaffected. + """ + def __init__(cls, clsname, bases, clsdict): + if clsname != 'Visitable' and \ + hasattr(cls, '__visit_name__'): + _generate_dispatch(cls) + + super(VisitableType, cls).__init__(clsname, bases, clsdict) + + +def _generate_dispatch(cls): + """Return an optimized visit dispatch function for the cls + for use by the compiler. + """ + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. + getter = operator.attrgetter("visit_%s" % visit_name) + + def _compiler_dispatch(self, visitor, **kw): + try: + meth = getter(visitor) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. + def _compiler_dispatch(self, visitor, **kw): + visit_attr = 'visit_%s' % self.__visit_name__ + try: + meth = getattr(visitor, visit_attr) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + _compiler_dispatch.__doc__ = \ + """Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params. + """ + cls._compiler_dispatch = _compiler_dispatch + + +class Visitable(util.with_metaclass(VisitableType, object)): + """Base class for visitable objects, applies the + ``VisitableType`` metaclass. + + """ + + +class ClauseVisitor(object): + """Base class for visitor objects which can traverse using + the traverse() function. + + """ + + __traverse_options__ = {} + + def traverse_single(self, obj, **kw): + for v in self._visitor_iterator: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kw) + + def iterate(self, obj): + """traverse the given expression structure, returning an iterator + of all elements. + + """ + return iterate(obj, self.__traverse_options__) + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + return traverse(obj, self.__traverse_options__, self._visitor_dict) + + @util.memoized_property + def _visitor_dict(self): + visitors = {} + + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + return visitors + + @property + def _visitor_iterator(self): + """iterate through this visitor and each 'chained' visitor.""" + + v = self + while v: + yield v + v = getattr(v, '_next', None) + + def chain(self, visitor): + """'chain' an additional ClauseVisitor onto this ClauseVisitor. + + the chained visitor will receive all visit events after this one. + + """ + tail = list(self._visitor_iterator)[-1] + tail._next = visitor + return self + + +class CloningVisitor(ClauseVisitor): + """Base class for visitor objects which can traverse using + the cloned_traverse() function. + + """ + + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return + the new list. + + """ + return [self.traverse(x) for x in list_] + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + return cloned_traverse( + obj, self.__traverse_options__, self._visitor_dict) + + +class ReplacingCloningVisitor(CloningVisitor): + """Base class for visitor objects which can traverse using + the replacement_traverse() function. + + """ + + def replace(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + def replace(elem): + for v in self._visitor_iterator: + e = v.replace(elem) + if e is not None: + return e + return replacement_traverse(obj, self.__traverse_options__, replace) + + +def iterate(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be breadth-first. + + """ + stack = deque([obj]) + while stack: + t = stack.popleft() + yield t + for c in t.get_children(**opts): + stack.append(c) + + +def iterate_depthfirst(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be depth-first. + + """ + stack = deque([obj]) + traversal = deque() + while stack: + t = stack.pop() + traversal.appendleft(t) + for c in t.get_children(**opts): + stack.append(c) + return iter(traversal) + + +def traverse_using(iterator, obj, visitors): + """visit the given expression structure using the given iterator of + objects. + + """ + for target in iterator: + meth = visitors.get(target.__visit_name__, None) + if meth: + meth(target) + return obj + + +def traverse(obj, opts, visitors): + """traverse and visit the given expression structure using the default + iterator. + + """ + return traverse_using(iterate(obj, opts), obj, visitors) + + +def traverse_depthfirst(obj, opts, visitors): + """traverse and visit the given expression structure using the + depth-first iterator. + + """ + return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + + +def cloned_traverse(obj, opts, visitors): + """clone the given expression structure, allowing + modifications by visitors.""" + + cloned = {} + stop_on = set(opts.get('stop_on', [])) + + def clone(elem): + if elem in stop_on: + return elem + else: + if id(elem) not in cloned: + cloned[id(elem)] = newelem = elem._clone() + newelem._copy_internals(clone=clone) + meth = visitors.get(newelem.__visit_name__, None) + if meth: + meth(newelem) + return cloned[id(elem)] + + if obj is not None: + obj = clone(obj) + return obj + + +def replacement_traverse(obj, opts, replace): + """clone the given expression structure, allowing element + replacement by a given replacement function.""" + + cloned = {} + stop_on = set([id(x) for x in opts.get('stop_on', [])]) + + def clone(elem, **kw): + if id(elem) in stop_on or \ + 'no_replacement_traverse' in elem._annotations: + return elem + else: + newelem = replace(elem) + if newelem is not None: + stop_on.add(id(newelem)) + return newelem + else: + if elem not in cloned: + cloned[elem] = newelem = elem._clone() + newelem._copy_internals(clone=clone, **kw) + return cloned[elem] + + if obj is not None: + obj = clone(obj, **opts) + return obj diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py new file mode 100644 index 00000000..95490643 --- /dev/null +++ b/lib/sqlalchemy/testing/__init__.py @@ -0,0 +1,32 @@ +# testing/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +from .warnings import testing_warn, assert_warnings, resetwarnings + +from . import config + +from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ + fails_on, fails_on_everything_except, skip, only_on, exclude, \ + against as _against, _server_version, only_if + + +def against(*queries): + return _against(config._current, *queries) + +from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ + eq_, ne_, is_, is_not_, startswith_, assert_raises, \ + assert_raises_message, AssertsCompiledSQL, ComparesTables, \ + AssertsExecutionResults, expect_deprecated + +from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict + +crashes = skip + +from .config import db +from .config import requirements as requires + +from . import mock \ No newline at end of file diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py new file mode 100644 index 00000000..edc9df04 --- /dev/null +++ b/lib/sqlalchemy/testing/assertions.py @@ -0,0 +1,439 @@ +# testing/assertions.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +from . import util as testutil +from sqlalchemy import pool, orm, util +from sqlalchemy.engine import default, create_engine, url +from sqlalchemy import exc as sa_exc +from sqlalchemy.util import decorator +from sqlalchemy import types as sqltypes, schema +import warnings +import re +from .warnings import resetwarnings +from .exclusions import db_spec, _is_excluded +from . import assertsql +from . import config +import itertools +from .util import fail +import contextlib + + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + + @decorator + def decorate(fn, *args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) + else: + filters.extend(dict(action='ignore', + message=message, + category=sa_exc.SAWarning) + for message in messages) + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return decorate + + +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + spec = db_spec(db) + + @decorator + def decorate(fn, *args, **kw): + if isinstance(db, util.string_types): + if not spec(config._current): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_deprecated(*messages): + return fn(*args, **kw) + return decorate + +@contextlib.contextmanager +def expect_deprecated(*messages): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [(m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) + + for f in filters: + warnings.filterwarnings(**f) + try: + yield + finally: + resetwarnings() + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + _assert_no_stray_pool_connections() + +_STRAY_CONNECTION_FAILURES = 0 +def _assert_no_stray_pool_connections(): + global _STRAY_CONNECTION_FAILURES + + # lazy gc on cPython means "do nothing." pool connections + # shouldn't be in cycles, should go away. + testutil.lazy_gc() + + # however, once in awhile, on an EC2 machine usually, + # there's a ref in there. usually just one. + if pool._refs: + + # OK, let's be somewhat forgiving. Increment a counter, + # we'll allow a couple of these at most. + _STRAY_CONNECTION_FAILURES += 1 + + print("Encountered a stray connection in test cleanup: %s" + % str(pool._refs)) + # then do a real GC sweep. We shouldn't even be here + # so a single sweep should really be doing it, otherwise + # there's probably a real unreachable cycle somewhere. + testutil.gc_collect() + + # if we've already had two of these occurrences, or + # after a hard gc sweep we still have pool._refs?! + # now we have to raise. + if _STRAY_CONNECTION_FAILURES >= 2 or pool._refs: + err = str(pool._refs) + + # but clean out the pool refs collection directly, + # reset the counter, + # so the error doesn't at least keep happening. + pool._refs.clear() + _STRAY_CONNECTION_FAILURES = 0 + assert False, "Stray conections in cleanup: %s" % err + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, fragment) + + +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + success = False + except except_cls: + success = True + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls as e: + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) + print(util.text_type(e).encode('utf-8')) + + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, + checkparams=None, dialect=None, + checkpositional=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False): + if use_default_dialect: + dialect = default.DefaultDialect() + elif allow_dialect_select: + dialect = None + else: + if dialect is None: + dialect = getattr(self, '__dialect__', None) + + if dialect is None: + dialect = config.db.dialect + elif dialect == 'default': + dialect = default.DefaultDialect() + elif isinstance(dialect, util.string_types): + dialect = url.URL(dialect).get_dialect()() + + + kw = {} + compile_kwargs = {} + + if params is not None: + kw['column_keys'] = list(params) + + if literal_binds: + compile_kwargs['literal_binds'] = True + + if isinstance(clause, orm.Query): + context = clause._compile_context() + context.statement.use_labels = True + clause = context.statement + + if compile_kwargs: + kw['compile_kwargs'] = compile_kwargs + + c = clause.compile(dialect=dialect, **kw) + + param_str = repr(getattr(c, 'params', {})) + + if util.py3k: + param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + print(("\nSQL String:\n" + util.text_type(c) + param_str).encode('utf-8')) + else: + print("\nSQL String:\n" + util.text_type(c).encode('utf-8') + param_str) + + + cc = re.sub(r'[\n\t]', '', util.text_type(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + + +class ComparesTables(object): + + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" + assert type(reflected_c.type) is type(c.type), \ + msg % (reflected_c.type, c.type) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_( + set([f.column.name for f in c.foreign_keys]), + set([f.column.name for f in reflected_c.foreign_keys]) + ) + if c.server_default: + assert isinstance(reflected_c.server_default, + schema.FetchedValue) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity(c2.type),\ + "On column %r, type '%s' doesn't correspond to type '%s'" % \ + (c1.name, c1.type, c2.type) + + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print(repr(result)) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.items(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([immutabledict(e) for e in expected]) + + for wrong in util.itertools_filterfalse(lambda o: type(o) == cls, found): + fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + + def _compare_item(obj, spec): + for key, value in spec.items(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) + try: + callable_() + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.dialect.supports_sequences: + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.items() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count)) + + @contextlib.contextmanager + def assert_execution(self, *rules): + assertsql.asserter.add_rules(rules) + try: + yield + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_statement_count(self, count): + return self.assert_execution(assertsql.CountStatements(count)) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py new file mode 100644 index 00000000..3e0d4c9d --- /dev/null +++ b/lib/sqlalchemy/testing/assertsql.py @@ -0,0 +1,333 @@ +# testing/assertsql.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from ..engine.default import DefaultDialect +from .. import util +import re + + +class AssertRule(object): + + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has + definitely failed. + + """ + + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if + failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not + been consumed or has failed. + + """ + + if self._result is None: + assert False, 'Rule has not been consumed' + return self.is_consumed() + + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + + +class ExactSQL(SQLMatchRule): + + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests are migrated, as + # ExactSQL should really be *exact* SQL + + sql = _process_assertion_statement(self.sql, context) + equivalent = _received_statement == sql + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params \ + == context.compiled_parameters + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for exact statement %r exact params %r, '\ + 'received %r with params %r' % (sql, params, + _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + + # do a positive compare only + + for param, received in zip(params, _received_parameters): + for k, v in param.items(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for regex %r partial params %r, received %r '\ + 'with params %r' % (self.orig_regex, params, + _received_statement, + _received_parameters) + + +class CompiledSQL(SQLMatchRule): + + def __init__(self, statement, params=None): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + from sqlalchemy.schema import _DDLCompiles + _received_parameters = list(context.compiled_parameters) + + # recompile from the context, using the default dialect + + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect()) + else: + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect(), + column_keys=context.compiled.column_keys) + _received_statement = re.sub(r'[\n\t]', '', str(compiled)) + equivalent = self.statement == _received_statement + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + else: + params = list(params) + all_params = list(params) + all_received = list(_received_parameters) + while params: + param = dict(params.pop(0)) + for k, v in context.compiled.params.items(): + param.setdefault(k, v) + if param not in _received_parameters: + equivalent = False + break + else: + _received_parameters.remove(param) + if _received_parameters: + equivalent = False + else: + params = {} + all_params = {} + all_received = [] + self._result = equivalent + if not self._result: + print('Testing for compiled statement %r partial params '\ + '%r, received %r with params %r' % (self.statement, + all_params, _received_statement, all_received)) + self._errmsg = \ + 'Testing for compiled statement %r partial params %r, '\ + 'received %r with params %r' % (self.statement, + all_params, _received_statement, all_received) + + + # print self._errmsg + +class CountStatements(AssertRule): + + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, \ + 'desired statement count %d does not match %d' \ + % (self.count, self._statement_count) + return True + + +class AllOf(AssertRule): + + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, + executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, + executemany) + + def is_consumed(self): + if not self.rules: + return True + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + assert False, 'No assertion rules were satisfied for statement' + + def consume_final(self): + return len(self.rules) == 0 + + +def _process_engine_statement(query, context): + if util.jython: + + # oracle+zxjdbc passes a PyStatement when returning into + + query = str(query) + if context.engine.name == 'mssql' \ + and query.endswith('; select scope_identity()'): + query = query[:-25] + query = re.sub(r'\n', '', query) + return query + + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle == 'pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle == 'qmark': + repl = "?" + elif paramstyle == 'format': + repl = r"%s" + elif paramstyle == 'numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + + +class SQLAssert(object): + + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, \ + 'All statements are complete, but pending '\ + 'assertion rules remain' + + def clear_rules(self): + del self.rules + + def execute(self, conn, clauseelement, multiparams, params, result): + if self.rules is not None: + if not self.rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + def cursor_execute(self, conn, cursor, statement, parameters, + context, executemany): + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, + executemany) + +asserter = SQLAssert() diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py new file mode 100644 index 00000000..20af3dd2 --- /dev/null +++ b/lib/sqlalchemy/testing/config.py @@ -0,0 +1,77 @@ +# testing/config.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import collections + +requirements = None +db = None +db_url = None +db_opts = None +file_config = None + +_current = None + +class Config(object): + def __init__(self, db, db_opts, options, file_config): + self.db = db + self.db_opts = db_opts + self.options = options + self.file_config = file_config + + _stack = collections.deque() + _configs = {} + + @classmethod + def register(cls, db, db_opts, options, file_config, namespace): + """add a config as one of the global configs. + + If there are no configs set up yet, this config also + gets set as the "_current". + """ + cfg = Config(db, db_opts, options, file_config) + + global _current + if not _current: + cls.set_as_current(cfg, namespace) + cls._configs[cfg.db.name] = cfg + cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg + cls._configs[cfg.db] = cfg + + @classmethod + def set_as_current(cls, config, namespace): + global db, _current, db_url + _current = config + db_url = config.db.url + namespace.db = db = config.db + + @classmethod + def push_engine(cls, db, namespace): + assert _current, "Can't push without a default Config set up" + cls.push( + Config(db, _current.db_opts, _current.options, _current.file_config), + namespace + ) + + @classmethod + def push(cls, config, namespace): + cls._stack.append(_current) + cls.set_as_current(config, namespace) + + @classmethod + def reset(cls, namespace): + if cls._stack: + cls.set_as_current(cls._stack[0], namespace) + cls._stack.clear() + + @classmethod + def all_configs(cls): + for cfg in set(cls._configs.values()): + yield cfg + + @classmethod + def all_dbs(cls): + for cfg in cls.all_configs(): + yield cfg.db diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py new file mode 100644 index 00000000..d27be3cd --- /dev/null +++ b/lib/sqlalchemy/testing/engines.py @@ -0,0 +1,448 @@ +# testing/engines.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import types +import weakref +from collections import deque +from . import config +from .util import decorator +from .. import event, pool +import re +import warnings +from .. import util + +class ConnectionKiller(object): + + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = weakref.WeakKeyDictionary() + self.conns = set() + + def add_engine(self, engine): + self.testing_engines[engine] = True + + def connect(self, dbapi_conn, con_record): + self.conns.add((dbapi_conn, con_record)) + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def invalidate(self, dbapi_con, con_record, exception): + self.conns.discard((dbapi_con, con_record)) + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception as e: + warnings.warn( + "testing_reaper couldn't " + "rollback/close connection: %s" % e) + + def rollback_all(self): + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def close_all(self): + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self._safe(rec._close) + + def _after_test_ctx(self): + # this can cause a deadlock with pg8000 - pg8000 acquires + # prepared statment lock inside of rollback() - if async gc + # is collecting in finalize_fairy, deadlock. + # not sure if this should be if pypy/jython only. + # note that firebird/fdb definitely needs this though + for conn, rec in list(self.conns): + self._safe(conn.rollback) + + def _stop_test_ctx(self): + if config.options.low_connections: + self._stop_test_ctx_minimal() + else: + self._stop_test_ctx_aggressive() + + def _stop_test_ctx_minimal(self): + self.close_all() + + self.conns = set() + + for rec in list(self.testing_engines): + if rec is not config.db: + rec.dispose() + + def _stop_test_ctx_aggressive(self): + self.close_all() + for conn, rec in list(self.conns): + self._safe(conn.close) + rec.connection = None + + self.conns = set() + for rec in list(self.testing_engines): + rec.dispose() + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + + +def drop_all_tables(metadata, bind): + testing_reaper.close_all() + if hasattr(bind, 'close'): + bind.close() + metadata.drop_all(bind) + + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.close_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + + +def all_dialects(exclude=None): + import sqlalchemy.databases as d + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__( + 'sqlalchemy.databases.%s' % name).databases, name) + yield mod.dialect() + + +class ReconnectFixture(object): + + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception as e: + warnings.warn( + "ReconnectFixture couldn't " + "close connection: %s" % e) + + def shutdown(self): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + + +def reconnecting_engine(url=None, options=None): + url = url or config.db.url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options['module'] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + + def dispose(): + engine.dialect.dbapi.shutdown() + _dispose() + + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.dispose = dispose + return engine + + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from .assertsql import asserter + + if not options: + use_reaper = True + else: + use_reaper = options.pop('use_reaper', True) + + url = url or config.db.url + if options is None: + options = config.db_opts + + engine = create_engine(url, **options) + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + event.listen(engine, 'after_execute', asserter.execute) + event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) + if use_reaper: + event.listen(engine.pool, 'connect', testing_reaper.connect) + event.listen(engine.pool, 'checkout', testing_reaper.checkout) + event.listen(engine.pool, 'invalidate', testing_reaper.invalidate) + testing_reaper.add_engine(engine) + + return engine + + + + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + + def executor(sql, *a, **kw): + buffer.append(sql) + + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + + def print_sql(): + d = engine.dialect + return "\n".join( + str(s.compile(dialect=d)) + for s in engine.mock + ) + + engine = create_engine(dialect_name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + + +class DBAPIProxyCursor(object): + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + def __init__(self, engine, conn): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor() + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __getattr__(self, key): + return getattr(self.cursor, key) + + +class DBAPIProxyConnection(object): + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + def __init__(self, engine, cursor_cls): + self.conn = self._sqla_unwrap = engine.pool._creator() + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self): + return self.cursor_cls(self.engine, self.conn) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) + + +def proxying_engine(conn_cls=DBAPIProxyConnection, + cursor_cls=DBAPIProxyCursor): + """Produce an engine that provides proxy hooks for + common methods. + + """ + def mock_conn(): + return conn_cls(config.db, cursor_cls) + return testing_engine(options={'creator': mock_conn}) + + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + + if util.py2k: + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]).\ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + else: + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]).\ + union([type(t) if not isinstance(t, type) + else t for t in __builtins__.values()]).\ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', )]) + + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + @property + def _sqla_unwrap(self): + return self._subject + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + @property + def _sqla_unwrap(self): + return None + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py new file mode 100644 index 00000000..9309abfd --- /dev/null +++ b/lib/sqlalchemy/testing/entities.py @@ -0,0 +1,99 @@ +# testing/entities.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc + +_repr_stack = set() + + +class BasicEntity(object): + + def __init__(self, **kw): + for key, value in kw.items(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ', '.join(["%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith('_')])) + finally: + _repr_stack.remove(id(self)) + +_recursion_stack = set() + + +class ComparableEntity(BasicEntity): + + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity thats not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in list(a.__dict__): + if attr.startswith('_'): + continue + value = getattr(a, attr) + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): + if hasattr(value, '__getitem__') and not hasattr(value, 'keys'): + if list(value) != list(battr): + return False + else: + if set(value) != set(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py new file mode 100644 index 00000000..00ca2842 --- /dev/null +++ b/lib/sqlalchemy/testing/exclusions.py @@ -0,0 +1,363 @@ +# testing/exclusions.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +import operator +from .plugin.plugin_base import SkipTest +from ..util import decorator +from . import config +from .. import util +import contextlib +import inspect + +class skip_if(object): + def __init__(self, predicate, reason=None): + self.predicate = _as_predicate(predicate) + self.reason = reason + + _fails_on = None + + def __add__(self, other): + def decorate(fn): + return other(self(fn)) + return decorate + + @property + def enabled(self): + return self.enabled_for_config(config._current) + + def enabled_for_config(self, config): + return not self.predicate(config) + + @contextlib.contextmanager + def fail_if(self, name='block'): + try: + yield + except Exception as ex: + if self.predicate(config._current): + print(("%s failed as expected (%s): %s " % ( + name, self.predicate, str(ex)))) + else: + raise + else: + if self.predicate(config._current): + raise AssertionError( + "Unexpected success for '%s' (%s)" % + (name, self.predicate)) + + def __call__(self, fn): + @decorator + def decorate(fn, *args, **kw): + if self.predicate(config._current): + if self.reason: + msg = "'%s' : %s" % ( + fn.__name__, + self.reason + ) + else: + msg = "'%s': %s" % ( + fn.__name__, self.predicate + ) + raise SkipTest(msg) + else: + if self._fails_on: + with self._fails_on.fail_if(name=fn.__name__): + return fn(*args, **kw) + else: + return fn(*args, **kw) + return decorate(fn) + + def fails_on(self, other, reason=None): + self._fails_on = skip_if(other, reason) + return self + + def fails_on_everything_except(self, *dbs): + self._fails_on = skip_if(fails_on_everything_except(*dbs)) + return self + +class fails_if(skip_if): + def __call__(self, fn): + @decorator + def decorate(fn, *args, **kw): + with self.fail_if(name=fn.__name__): + return fn(*args, **kw) + return decorate(fn) + + +def only_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return skip_if(NotPredicate(predicate), reason) + + +def succeeds_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return fails_if(NotPredicate(predicate), reason) + + +class Predicate(object): + @classmethod + def as_predicate(cls, predicate): + if isinstance(predicate, skip_if): + return NotPredicate(predicate.predicate) + elif isinstance(predicate, Predicate): + return predicate + elif isinstance(predicate, list): + return OrPredicate([cls.as_predicate(pred) for pred in predicate]) + elif isinstance(predicate, tuple): + return SpecPredicate(*predicate) + elif isinstance(predicate, util.string_types): + tokens = predicate.split(" ", 2) + op = spec = None + db = tokens.pop(0) + if tokens: + op = tokens.pop(0) + if tokens: + spec = tuple(int(d) for d in tokens.pop(0).split(".")) + return SpecPredicate(db, op, spec) + elif util.callable(predicate): + return LambdaPredicate(predicate) + else: + assert False, "unknown predicate type: %s" % predicate + + +class BooleanPredicate(Predicate): + def __init__(self, value, description=None): + self.value = value + self.description = description or "boolean %s" % value + + def __call__(self, config): + return self.value + + def _as_string(self, negate=False): + if negate: + return "not " + self.description + else: + return self.description + + def __str__(self): + return self._as_string() + + +class SpecPredicate(Predicate): + def __init__(self, db, op=None, spec=None, description=None): + self.db = db + self.op = op + self.spec = spec + self.description = description + + _ops = { + '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + + def __call__(self, config): + engine = config.db + + if "+" in self.db: + dialect, driver = self.db.split('+') + else: + dialect, driver = self.db, None + + if dialect and engine.name != dialect: + return False + if driver is not None and engine.driver != driver: + return False + + if self.op is not None: + assert driver is None, "DBAPI version specs not supported yet" + + version = _server_version(engine) + oper = hasattr(self.op, '__call__') and self.op \ + or self._ops[self.op] + return oper(version, self.spec) + else: + return True + + def _as_string(self, negate=False): + if self.description is not None: + return self.description + elif self.op is None: + if negate: + return "not %s" % self.db + else: + return "%s" % self.db + else: + if negate: + return "not %s %s %s" % ( + self.db, + self.op, + self.spec + ) + else: + return "%s %s %s" % ( + self.db, + self.op, + self.spec + ) + + def __str__(self): + return self._as_string() + + +class LambdaPredicate(Predicate): + def __init__(self, lambda_, description=None, args=None, kw=None): + spec = inspect.getargspec(lambda_) + if not spec[0]: + self.lambda_ = lambda db: lambda_() + else: + self.lambda_ = lambda_ + self.args = args or () + self.kw = kw or {} + if description: + self.description = description + elif lambda_.__doc__: + self.description = lambda_.__doc__ + else: + self.description = "custom function" + + def __call__(self, config): + return self.lambda_(config) + + def _as_string(self, negate=False): + if negate: + return "not " + self.description + else: + return self.description + + def __str__(self): + return self._as_string() + + +class NotPredicate(Predicate): + def __init__(self, predicate): + self.predicate = predicate + + def __call__(self, config): + return not self.predicate(config) + + def __str__(self): + return self.predicate._as_string(True) + + +class OrPredicate(Predicate): + def __init__(self, predicates, description=None): + self.predicates = predicates + self.description = description + + def __call__(self, config): + for pred in self.predicates: + if pred(config): + self._str = pred + return True + return False + + _str = None + + def _eval_str(self, negate=False): + if self._str is None: + if negate: + conjunction = " and " + else: + conjunction = " or " + return conjunction.join(p._as_string(negate=negate) + for p in self.predicates) + else: + return self._str._as_string(negate=negate) + + def _negation_str(self): + if self.description is not None: + return "Not " + (self.description % {"spec": self._str}) + else: + return self._eval_str(negate=True) + + def _as_string(self, negate=False): + if negate: + return self._negation_str() + else: + if self.description is not None: + return self.description % {"spec": self._str} + else: + return self._eval_str() + + def __str__(self): + return self._as_string() + +_as_predicate = Predicate.as_predicate + + +def _is_excluded(db, op, spec): + return SpecPredicate(db, op, spec)(config._current) + + +def _server_version(engine): + """Return a server_version_info tuple.""" + + # force metadata to be retrieved + conn = engine.connect() + version = getattr(engine.dialect, 'server_version_info', ()) + conn.close() + return version + + +def db_spec(*dbs): + return OrPredicate( + [Predicate.as_predicate(db) for db in dbs] + ) + + +def open(): + return skip_if(BooleanPredicate(False, "mark as execute")) + + +def closed(): + return skip_if(BooleanPredicate(True, "marked as skip")) + +def fails(): + return fails_if(BooleanPredicate(True, "expected to fail")) + +@decorator +def future(fn, *arg): + return fails_if(LambdaPredicate(fn), "Future feature") + + +def fails_on(db, reason=None): + return fails_if(SpecPredicate(db), reason) + + +def fails_on_everything_except(*dbs): + return succeeds_if( + OrPredicate([ + SpecPredicate(db) for db in dbs + ]) + ) + + +def skip(db, reason=None): + return skip_if(SpecPredicate(db), reason) + + +def only_on(dbs, reason=None): + return only_if( + OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + ) + + +def exclude(db, op, spec, reason=None): + return skip_if(SpecPredicate(db, op, spec), reason) + + +def against(config, *queries): + assert queries, "no queries sent!" + return OrPredicate([ + Predicate.as_predicate(query) + for query in queries + ])(config) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py new file mode 100644 index 00000000..7941bf0f --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures.py @@ -0,0 +1,380 @@ +# testing/fixtures.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import config +from . import assertions, schema +from .util import adict +from .. import util +from .engines import drop_all_tables +from .entities import BasicEntity, ComparableEntity +import sys +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta + +# whether or not we use unittest changes things dramatically, +# as far as how py.test collection works. + +class TestBase(object): + # A sequence of database names to always run, regardless of the + # constraints below. + __whitelist__ = () + + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + def assert_(self, val, msg=None): + assert val, msg + + # apparently a handful of tests are doing this....OK + def setup(self): + if hasattr(self, "setUp"): + self.setUp() + + def teardown(self): + if hasattr(self, "tearDown"): + self.tearDown() + +class TablesTest(TestBase): + + # 'once', None + run_setup_bind = 'once' + + # 'once', 'each', None + run_define_tables = 'once' + + # 'once', 'each', None + run_create_tables = 'once' + + # 'once', 'each', None + run_inserts = 'each' + + # 'each', None + run_deletes = 'each' + + # 'once', None + run_dispose_bind = None + + bind = None + metadata = None + tables = None + other = None + + @classmethod + def setup_class(cls): + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + @classmethod + def _init_class(cls): + if cls.run_define_tables == 'each': + if cls.run_create_tables == 'once': + cls.run_create_tables = 'each' + assert cls.run_inserts in ('each', None) + + if cls.other is None: + cls.other = adict() + + if cls.tables is None: + cls.tables = adict() + + if cls.bind is None: + setattr(cls, 'bind', cls.setup_bind()) + + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) + + if cls.metadata.bind is None: + cls.metadata.bind = cls.bind + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + if cls.run_create_tables == 'once': + cls.metadata.create_all(cls.bind) + cls.tables.update(cls.metadata.tables) + + def _setup_each_tables(self): + if self.run_define_tables == 'each': + self.tables.clear() + if self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.clear() + self.define_tables(self.metadata) + if self.run_create_tables == 'each': + self.metadata.create_all(self.bind) + self.tables.update(self.metadata.tables) + elif self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == 'each': + self._load_fixtures() + self.insert_data() + + def _teardown_each_tables(self): + # no need to run deletes if tables are recreated on setup + if self.run_define_tables != 'each' and self.run_deletes == 'each': + for table in reversed(self.metadata.sorted_tables): + try: + table.delete().execute().close() + except sa.exc.DBAPIError as ex: + util.print_( + ("Error emptying table %s: %r" % (table, ex)), + file=sys.stderr) + + def setup(self): + self._setup_each_tables() + self._setup_each_inserts() + + def teardown(self): + self._teardown_each_tables() + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables(cls.metadata, cls.bind) + + if cls.run_dispose_bind == 'once': + cls.dispose_bind(cls.bind) + + cls.metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def teardown_class(cls): + cls._teardown_once_metadata_bind() + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, 'dispose'): + bind.dispose() + elif hasattr(bind, 'close'): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements, with_sequences=None): + self.assert_sql(self.bind, + callable_, statements, with_sequences) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().items(): + if len(data) < 2: + continue + if isinstance(table, util.string_types): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table in cls.metadata.sorted_tables: + if table not in headers: + continue + cls.bind.execute( + table.insert(), + [dict(zip(headers[table], column_values)) + for column_values in rows[table]]) + +from sqlalchemy import event +class RemovesEvents(object): + @util.memoized_property + def _event_fns(self): + return set() + + def event_listen(self, target, name, fn): + self._event_fns.add((target, name, fn)) + event.listen(target, name, fn) + + def teardown(self): + for key in self._event_fns: + event.remove(*key) + super_ = super(RemovesEvents, self) + if hasattr(super_, "teardown"): + super_.teardown() + + + +class _ORMTest(object): + + @classmethod + def teardown_class(cls): + sa.orm.session.Session.close_all() + sa.orm.clear_mappers() + + +class ORMTest(_ORMTest, TestBase): + pass + + +class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = 'once' + + # 'once', 'each', None + run_setup_mappers = 'each' + + classes = None + + @classmethod + def setup_class(cls): + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + @classmethod + def teardown_class(cls): + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + def setup(self): + self._setup_each_tables() + self._setup_each_mappers() + self._setup_each_inserts() + + def teardown(self): + sa.orm.session.Session.close_all() + self._teardown_each_mappers() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + _ORMTest.teardown_class() + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == 'once': + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == 'once': + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers == 'each': + self._with_register_classes(self.setup_mappers) + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + + class FindFixture(type): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return type.__init__(cls, classname, bases, dict_) + + class _Base(util.with_metaclass(FindFixture, object)): + pass + + class Basic(BasicEntity, _Base): + pass + + class Comparable(ComparableEntity, _Base): + pass + + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != 'once': + sa.orm.clear_mappers() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = 'once' + run_setup_mappers = 'once' + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + + class FindFixtureDeclarative(DeclarativeMeta): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return DeclarativeMeta.__init__( + cls, classname, bases, dict_) + + class DeclarativeBasic(object): + __table_cls__ = schema.Table + + _DeclBase = declarative_base(metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic) + cls.DeclarativeBasic = _DeclBase + fn() + + if cls.metadata.tables and cls.run_create_tables: + cls.metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py new file mode 100644 index 00000000..18ba053e --- /dev/null +++ b/lib/sqlalchemy/testing/mock.py @@ -0,0 +1,21 @@ +# testing/mock.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Import stub for mock library. +""" +from __future__ import absolute_import +from ..util import py33 + +if py33: + from unittest.mock import MagicMock, Mock, call, patch +else: + try: + from mock import MagicMock, Mock, call, patch + except ImportError: + raise ImportError( + "SQLAlchemy's test suite requires the " + "'mock' library as of 0.8.2.") + diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py new file mode 100644 index 00000000..9a41034b --- /dev/null +++ b/lib/sqlalchemy/testing/pickleable.py @@ -0,0 +1,142 @@ +# testing/pickleable.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Classes used in pickling tests, need to be at the module level for +unpickling. +""" + +from . import fixtures + + +class User(fixtures.ComparableEntity): + pass + + +class Order(fixtures.ComparableEntity): + pass + + +class Dingaling(fixtures.ComparableEntity): + pass + + +class EmailUser(User): + pass + + +class Address(fixtures.ComparableEntity): + pass + + +# TODO: these are kind of arbitrary.... +class Child1(fixtures.ComparableEntity): + pass + + +class Child2(fixtures.ComparableEntity): + pass + + +class Parent(fixtures.ComparableEntity): + pass + + +class Screen(object): + + def __init__(self, obj, parent=None): + self.obj = obj + self.parent = parent + + +class Foo(object): + + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + + __hash__ = object.__hash__ + + def __eq__(self, other): + return other.data == self.data and \ + other.stuff == self.stuff and \ + other.moredata == self.moredata + + +class Bar(object): + + def __init__(self, x, y): + self.x = x + self.y = y + + __hash__ = object.__hash__ + + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class OldSchool: + + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + + +class OldSchoolWithoutCompare: + + def __init__(self, x, y): + self.x = x + self.y = y + + +class BarWithoutCompare(object): + + def __init__(self, x, y): + self.x = x + self.y = y + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py new file mode 100644 index 00000000..18a1178a --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -0,0 +1,89 @@ +# plugin/noseplugin.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Enhance nose with extra options and behaviors for running SQLAlchemy tests. + +Must be run via ./sqla_nose.py so that it is imported in the expected +way (e.g. as a package-less import). + +""" + +import os + +from nose.plugins import Plugin +fixtures = None + +# no package imports yet! this prevents us from tripping coverage +# too soon. +import imp +path = os.path.join(os.path.dirname(__file__), "plugin_base.py") +plugin_base = imp.load_source("plugin_base", path) + + +class NoseSQLAlchemy(Plugin): + enabled = True + + name = 'sqla_testing' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + + def make_option(name, **kw): + callback_ = kw.pop("callback", None) + if callback_: + def wrap_(option, opt_str, value, parser): + callback_(opt_str, value, parser) + kw["callback"] = wrap_ + opt(name, **kw) + + plugin_base.setup_options(make_option) + plugin_base.read_config() + + def configure(self, options, conf): + super(NoseSQLAlchemy, self).configure(options, conf) + plugin_base.pre_begin(options) + + plugin_base.set_coverage_flag(options.enable_plugin_coverage) + + global fixtures + from sqlalchemy.testing import fixtures + + def begin(self): + plugin_base.post_begin() + + def describeTest(self, test): + return "" + + def wantFunction(self, fn): + if fn.__module__ is None: + return False + if fn.__module__.startswith('sqlalchemy.testing'): + return False + + def wantClass(self, cls): + return plugin_base.want_class(cls) + + def beforeTest(self, test): + plugin_base.before_test(test, + test.test.cls.__module__, + test.test.cls, test.test.method.__name__) + + def afterTest(self, test): + plugin_base.after_test(test) + + def startContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + plugin_base.start_test_class(ctx) + + def stopContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + plugin_base.stop_test_class(ctx) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py new file mode 100644 index 00000000..061848e2 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -0,0 +1,455 @@ +# plugin/plugin_base.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Testing extensions. + +this module is designed to work as a testing-framework-agnostic library, +so that we can continue to support nose and also begin adding new functionality +via py.test. + +""" + +from __future__ import absolute_import +try: + # unitttest has a SkipTest also but pytest doesn't + # honor it unless nose is imported too... + from nose import SkipTest +except ImportError: + from _pytest.runner import Skipped as SkipTest + +import sys +import re + +py3k = sys.version_info >= (3, 0) + +if py3k: + import configparser +else: + import ConfigParser as configparser + + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +assertions = None +requirements = None +config = None +testing = None +util = None +file_config = None + + +logging = None +db_opts = {} +options = None + +def setup_options(make_option): + make_option("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for (multiple OK)") + make_option("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for (multiple OK)") + make_option("--db", action="append", type="string", dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.") + make_option('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + make_option("--dburi", action="append", type="string", dest="dburi", + help="Database uri. Multiple OK, first one is run by default.") + make_option("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first") + make_option("--backend-only", action="store_true", dest="backend_only", + help="Run only tests marked with __backend__") + make_option("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + make_option("--low-connections", action="store_true", dest="low_connections", + help="Use a low number of distinct connections - i.e. for Oracle TNS" + ) + make_option("--reversetop", action="store_true", dest="reversetop", default=False, + help="Use a random-ordering set implementation in the ORM (helps " + "reveal dependency issues)") + make_option("--requirements", action="callback", type="string", + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg") + make_option("--with-cdecimal", action="store_true", dest="cdecimal", default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' for all tests") + make_option("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + make_option("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + make_option("--tableopts", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + make_option("--write-profiles", action="store_true", dest="write_profiles", default=False, + help="Write/update profiling data.") + +def read_config(): + global file_config + file_config = configparser.ConfigParser() + file_config.read(['setup.cfg', 'test.cfg']) + +def pre_begin(opt): + """things to set up early, before coverage might be setup.""" + global options + options = opt + for fn in pre_configure: + fn(options, file_config) + +def set_coverage_flag(value): + options.has_coverage = value + +def post_begin(): + """things to set up later, once we know coverage is running.""" + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + # late imports, has to happen after config as well + # as nose plugins like coverage + global util, fixtures, engines, exclusions, \ + assertions, warnings, profiling,\ + config, testing + from sqlalchemy import testing + from sqlalchemy.testing import fixtures, engines, exclusions, \ + assertions, warnings, profiling, config + from sqlalchemy import util + + +def _log(opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print("Available --db options (use --dburi to override)") + for macro in sorted(file_config.options('db')): + print("%20s\t%s" % (macro, file_config.get('db', macro))) + sys.exit(0) + + +def _server_side_cursors(opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _requirements_opt(opt_str, value, parser): + _setup_requirements(value) + + +pre_configure = [] +post_configure = [] + + +def pre(fn): + pre_configure.append(fn) + return fn + + +def post(fn): + post_configure.append(fn) + return fn + + +@pre +def _setup_options(opt, file_config): + global options + options = opt + + +@pre +def _monkeypatch_cdecimal(options, file_config): + if options.cdecimal: + import cdecimal + sys.modules['decimal'] = cdecimal + + +@post +def _engine_uri(options, file_config): + from sqlalchemy.testing import engines, config + from sqlalchemy import testing + + if options.dburi: + db_urls = list(options.dburi) + else: + db_urls = [] + + if options.db: + for db_token in options.db: + for db in re.split(r'[,\s]+', db_token): + if db not in file_config.options('db'): + raise RuntimeError( + "Unknown URI specifier '%s'. Specify --dbs for known uris." + % db) + else: + db_urls.append(file_config.get('db', db)) + + if not db_urls: + db_urls.append(file_config.get('db', 'default')) + + for db_url in db_urls: + eng = engines.testing_engine(db_url, db_opts) + eng.connect().close() + config.Config.register(eng, db_opts, options, file_config, testing) + + config.db_opts = db_opts + + +@post +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool + +@post +def _requirements(options, file_config): + + requirement_cls = file_config.get('sqla_testing', "requirement_cls") + _setup_requirements(requirement_cls) + +def _setup_requirements(argument): + from sqlalchemy.testing import config + from sqlalchemy import testing + + if config.requirements is not None: + return + + modname, clsname = argument.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + + config.requirements = testing.requires = req_cls() + +@post +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import config + from sqlalchemy import schema, inspect + + if options.dropfirst: + for cfg in config.Config.all_configs(): + e = cfg.db + inspector = inspect(e) + try: + view_names = inspector.get_view_names() + except NotImplementedError: + pass + else: + for vname in view_names: + e.execute(schema._DropView(schema.Table(vname, schema.MetaData()))) + + if config.requirements.schemas.enabled_for_config(cfg): + try: + view_names = inspector.get_view_names(schema="test_schema") + except NotImplementedError: + pass + else: + for vname in view_names: + e.execute(schema._DropView( + schema.Table(vname, + schema.MetaData(), schema="test_schema"))) + + for tname in reversed(inspector.get_table_names(order_by="foreign_key")): + e.execute(schema.DropTable(schema.Table(tname, schema.MetaData()))) + + if config.requirements.schemas.enabled_for_config(cfg): + for tname in reversed(inspector.get_table_names( + order_by="foreign_key", schema="test_schema")): + e.execute(schema.DropTable( + schema.Table(tname, schema.MetaData(), schema="test_schema"))) + + +@post +def _set_table_options(options, file_config): + from sqlalchemy.testing import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine + + +@post +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm.util import randomize_unitofwork + randomize_unitofwork() + + +@post +def _post_setup_options(opt, file_config): + from sqlalchemy.testing import config + config.options = options + config.file_config = file_config + + +@post +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get('sqla_testing', 'profile_file')) + + +def want_class(cls): + if not issubclass(cls, fixtures.TestBase): + return False + elif cls.__name__.startswith('_'): + return False + elif config.options.backend_only and not getattr(cls, '__backend__', False): + return False + else: + return True + +def generate_sub_tests(cls, module): + if getattr(cls, '__backend__', False): + for cfg in config.Config.all_configs(): + name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver) + subcls = type( + name, + (cls, ), + { + "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)), + "__backend__": False} + ) + setattr(module, name, subcls) + yield subcls + else: + yield cls + + +def start_test_class(cls): + _do_skips(cls) + _setup_engine(cls) + +def stop_test_class(cls): + engines.testing_reaper._stop_test_ctx() + if not options.low_connections: + assertions.global_cleanup_assertions() + _restore_engine() + +def _restore_engine(): + config._current.reset(testing) + +def _setup_engine(cls): + if getattr(cls, '__engine_options__', None): + eng = engines.testing_engine(options=cls.__engine_options__) + config._current.push_engine(eng, testing) + +def before_test(test, test_module_name, test_class, test_name): + + # like a nose id, e.g.: + # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" + name = test_class.__name__ + + suffix = "_%s_%s" % (config.db.name, config.db.driver) + if name.endswith(suffix): + name = name[0:-(len(suffix))] + + id_ = "%s.%s.%s" % (test_module_name, name, test_name) + + warnings.resetwarnings() + profiling._current_test = id_ + +def after_test(test): + engines.testing_reaper._after_test_ctx() + warnings.resetwarnings() + +def _do_skips(cls): + all_configs = set(config.Config.all_configs()) + reasons = [] + + if hasattr(cls, '__requires__'): + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__requires__: + check = getattr(requirements, requirement) + + if check.predicate(config_obj): + all_configs.remove(config_obj) + if check.reason: + reasons.append(check.reason) + break + + if hasattr(cls, '__prefer_requires__'): + non_preferred = set() + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__prefer_requires__: + check = getattr(requirements, requirement) + + if check.predicate(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + for config_obj in list(all_configs): + if spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, '__only_on__', None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + for config_obj in list(all_configs): + if not spec(config_obj): + all_configs.remove(config_obj) + + + if getattr(cls, '__skip_if__', False): + for c in getattr(cls, '__skip_if__'): + if c(): + raise SkipTest("'%s' skipped by %s" % ( + cls.__name__, c.__name__) + ) + + for db_spec, op, spec in getattr(cls, '__excluded_on__', ()): + for config_obj in list(all_configs): + if exclusions.skip_if( + exclusions.SpecPredicate(db_spec, op, spec) + ).predicate(config_obj): + all_configs.remove(config_obj) + + + if not all_configs: + raise SkipTest( + "'%s' unsupported on DB implementation %s%s" % ( + cls.__name__, + ", ".join("'%s' = %s" % ( + config_obj.db.name, + config_obj.db.dialect.server_version_info) + for config_obj in config.Config.all_configs() + ), + ", ".join(reasons) + ) + ) + elif hasattr(cls, '__prefer_backends__'): + non_preferred = set() + spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) + for config_obj in all_configs: + if not spec(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if config._current not in all_configs: + _setup_config(all_configs.pop(), cls) + +def _setup_config(config_obj, ctx): + config._current.push(config_obj, testing) + diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py new file mode 100644 index 00000000..74d5cc08 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -0,0 +1,125 @@ +import pytest +import argparse +import inspect +from . import plugin_base +import collections + +def pytest_addoption(parser): + group = parser.getgroup("sqlalchemy") + + def make_option(name, **kw): + callback_ = kw.pop("callback", None) + if callback_: + class CallableAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + callback_(option_string, values, parser) + kw["action"] = CallableAction + + group.addoption(name, **kw) + + plugin_base.setup_options(make_option) + plugin_base.read_config() + +def pytest_configure(config): + plugin_base.pre_begin(config.option) + + plugin_base.set_coverage_flag(bool(getattr(config.option, "cov_source", False))) + + plugin_base.post_begin() + + +def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and + # expand them out into per-database test cases. + + # this is much easier to do within pytest_pycollect_makeitem, however + # pytest is iterating through cls.__dict__ as makeitem is + # called which causes a "dictionary changed size" error on py3k. + # I'd submit a pullreq for them to turn it into a list first, but + # it's to suit the rather odd use case here which is that we are adding + # new classes to a module on the fly. + + rebuilt_items = collections.defaultdict(list) + test_classes = set(item.parent for item in items) + for test_class in test_classes: + for sub_cls in plugin_base.generate_sub_tests(test_class.cls, test_class.parent.module): + if sub_cls is not test_class.cls: + list_ = rebuilt_items[test_class.cls] + + for inst in pytest.Class(sub_cls.__name__, + parent=test_class.parent.parent).collect(): + list_.extend(inst.collect()) + + newitems = [] + for item in items: + if item.parent.cls in rebuilt_items: + newitems.extend(rebuilt_items[item.parent.cls]) + rebuilt_items[item.parent.cls][:] = [] + else: + newitems.append(item) + + # seems like the functions attached to a test class aren't sorted already? + # is that true and why's that? (when using unittest, they're sorted) + items[:] = sorted(newitems, key=lambda item: ( + item.parent.parent.parent.name, + item.parent.parent.name, + item.name + ) + ) + + + +def pytest_pycollect_makeitem(collector, name, obj): + + if inspect.isclass(obj) and plugin_base.want_class(obj): + return pytest.Class(name, parent=collector) + elif inspect.isfunction(obj) and \ + name.startswith("test_") and \ + isinstance(collector, pytest.Instance): + return pytest.Function(name, parent=collector) + else: + return [] + +_current_class = None + +def pytest_runtest_setup(item): + # here we seem to get called only based on what we collected + # in pytest_collection_modifyitems. So to do class-based stuff + # we have to tear that out. + global _current_class + + if not isinstance(item, pytest.Function): + return + + # ... so we're doing a little dance here to figure it out... + if item.parent.parent is not _current_class: + + class_setup(item.parent.parent) + _current_class = item.parent.parent + + # this is needed for the class-level, to ensure that the + # teardown runs after the class is completed with its own + # class-level teardown... + item.parent.parent.addfinalizer(lambda: class_teardown(item.parent.parent)) + + test_setup(item) + +def pytest_runtest_teardown(item): + # ...but this works better as the hook here rather than + # using a finalizer, as the finalizer seems to get in the way + # of the test reporting failures correctly (you get a bunch of + # py.test assertion stuff instead) + test_teardown(item) + +def test_setup(item): + plugin_base.before_test(item, + item.parent.module.__name__, item.parent.cls, item.name) + +def test_teardown(item): + plugin_base.after_test(item) + +def class_setup(item): + plugin_base.start_test_class(item.cls) + +def class_teardown(item): + plugin_base.stop_test_class(item.cls) diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py new file mode 100644 index 00000000..2f92527e --- /dev/null +++ b/lib/sqlalchemy/testing/profiling.py @@ -0,0 +1,309 @@ +# testing/profiling.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import os +import sys +from .util import gc_collect, decorator +from . import config +from .plugin.plugin_base import SkipTest +import pstats +import time +import collections +from .. import util + +try: + import cProfile +except ImportError: + cProfile = None +from ..util import jython, pypy, win32, update_wrapper + +_current_test = None + + +def profiled(target=None, **target_opts): + """Function profiling. + + @profiled() + or + @profiled(report=True, sort=('calls',), limit=20) + + Outputs profiling info for a decorated function. + + """ + + profile_config = {'targets': set(), + 'report': True, + 'print_callers': False, + 'print_callees': False, + 'graphic': False, + 'sort': ('time', 'calls'), + 'limit': None} + if target is None: + target = 'anonymous_target' + + @decorator + def decorate(fn, *args, **kw): + elapsed, load_stats, result = _profile( + fn, *args, **kw) + + graphic = target_opts.get('graphic', profile_config['graphic']) + if graphic: + os.system("runsnake %s" % filename) + else: + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print(("Profile report for target '%s'" % ( + target, ) + )) + + stats = load_stats() + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + + print_callers = target_opts.get( + 'print_callers', profile_config['print_callers']) + if print_callers: + stats.print_callers() + + print_callees = target_opts.get( + 'print_callees', profile_config['print_callees']) + if print_callees: + stats.print_callees() + + return result + return decorate + + +class ProfileStatsFile(object): + """"Store per-platform/fn profiling results in a file. + + We're still targeting Py2.5, 2.4 on 0.7 with no dependencies, + so no json lib :( need to roll something silly + + """ + def __init__(self, filename): + self.write = ( + config.options is not None and + config.options.write_profiles + ) + self.fname = os.path.abspath(filename) + self.short_fname = os.path.split(self.fname)[-1] + self.data = collections.defaultdict( + lambda: collections.defaultdict(dict)) + self._read() + if self.write: + # rewrite for the case where features changed, + # etc. + self._write() + + @property + def platform_key(self): + + dbapi_key = config.db.name + "_" + config.db.driver + + # keep it at 2.7, 3.1, 3.2, etc. for now. + py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + + platform_tokens = [py_version] + platform_tokens.append(dbapi_key) + if jython: + platform_tokens.append("jython") + if pypy: + platform_tokens.append("pypy") + if win32: + platform_tokens.append("win") + _has_cext = config.requirements._has_cextensions() + platform_tokens.append(_has_cext and "cextensions" or "nocextensions") + return "_".join(platform_tokens) + + def has_stats(self): + test_key = _current_test + return ( + test_key in self.data and + self.platform_key in self.data[test_key] + ) + + def result(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + + if 'counts' not in per_platform: + per_platform['counts'] = counts = [] + else: + counts = per_platform['counts'] + + if 'current_count' not in per_platform: + per_platform['current_count'] = current_count = 0 + else: + current_count = per_platform['current_count'] + + has_count = len(counts) > current_count + + if not has_count: + counts.append(callcount) + if self.write: + self._write() + result = None + else: + result = per_platform['lineno'], counts[current_count] + per_platform['current_count'] += 1 + return result + + def replace(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + counts = per_platform['counts'] + counts[-1] = callcount + if self.write: + self._write() + + def _header(self): + return \ + "# %s\n"\ + "# This file is written out on a per-environment basis.\n"\ + "# For each test in aaa_profiling, the corresponding function and \n"\ + "# environment is located within this file. If it doesn't exist,\n"\ + "# the test is skipped.\n"\ + "# If a callcount does exist, it is compared to what we received. \n"\ + "# assertions are raised if the counts do not match.\n"\ + "# \n"\ + "# To add a new callcount test, apply the function_call_count \n"\ + "# decorator and re-run the tests using the --write-profiles \n"\ + "# option - this file will be rewritten including the new count.\n"\ + "# \n"\ + "" % (self.fname) + + def _read(self): + try: + profile_f = open(self.fname) + except IOError: + return + for lineno, line in enumerate(profile_f): + line = line.strip() + if not line or line.startswith("#"): + continue + + test_key, platform_key, counts = line.split() + per_fn = self.data[test_key] + per_platform = per_fn[platform_key] + c = [int(count) for count in counts.split(",")] + per_platform['counts'] = c + per_platform['lineno'] = lineno + 1 + per_platform['current_count'] = 0 + profile_f.close() + + def _write(self): + print(("Writing profile file %s" % self.fname)) + profile_f = open(self.fname, "w") + profile_f.write(self._header()) + for test_key in sorted(self.data): + + per_fn = self.data[test_key] + profile_f.write("\n# TEST: %s\n\n" % test_key) + for platform_key in sorted(per_fn): + per_platform = per_fn[platform_key] + c = ",".join(str(count) for count in per_platform['counts']) + profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) + profile_f.close() + + + +def function_call_count(variance=0.05): + """Assert a target for a test case's function call count. + + The main purpose of this assertion is to detect changes in + callcounts for various functions - the actual number is not as important. + Callcounts are stored in a file keyed to Python version and OS platform + information. This file is generated automatically for new tests, + and versioned so that unexpected changes in callcounts will be detected. + + """ + + def decorate(fn): + def wrap(*args, **kw): + + if cProfile is None: + raise SkipTest("cProfile is not installed") + + if not _profile_stats.has_stats() and not _profile_stats.write: + # run the function anyway, to support dependent tests + # (not a great idea but we have these in test_zoomark) + fn(*args, **kw) + raise SkipTest("No profiling stats available on this " + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname) + + gc_collect() + + timespent, load_stats, fn_result = _profile( + fn, *args, **kw + ) + stats = load_stats() + callcount = stats.total_calls + + expected = _profile_stats.result(callcount) + if expected is None: + expected_count = None + else: + line_no, expected_count = expected + + print(("Pstats calls: %d Expected %s" % ( + callcount, + expected_count + ) + )) + stats.print_stats() + #stats.print_callers() + + if expected_count: + deviance = int(callcount * variance) + failed = abs(callcount - expected_count) > deviance + + if failed: + if _profile_stats.write: + _profile_stats.replace(callcount) + else: + raise AssertionError( + "Adjusted function call count %s not within %s%% " + "of expected %s. Rerun with --write-profiles to " + "regenerate this callcount." + % ( + callcount, (variance * 100), + expected_count)) + return fn_result + return update_wrapper(wrap, fn) + return decorate + + +def _profile(fn, *args, **kw): + filename = "%s.prof" % fn.__name__ + + def load_stats(): + st = pstats.Stats(filename) + os.unlink(filename) + return st + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py new file mode 100644 index 00000000..07b5697e --- /dev/null +++ b/lib/sqlalchemy/testing/requirements.py @@ -0,0 +1,631 @@ +# testing/requirements.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +External dialect test suites should subclass SuiteRequirements +to provide specific inclusion/exlusions. + +""" + +from . import exclusions + + +class Requirements(object): + pass + +class SuiteRequirements(Requirements): + + @property + def create_table(self): + """target platform can emit basic CreateTable DDL.""" + + return exclusions.open() + + @property + def drop_table(self): + """target platform can emit basic DropTable DDL.""" + + return exclusions.open() + + @property + def foreign_keys(self): + """Target database must support foreign keys.""" + + return exclusions.open() + + @property + def on_update_cascade(self): + """"target database must support ON UPDATE..CASCADE behavior in + foreign keys.""" + + return exclusions.open() + + @property + def non_updating_cascade(self): + """target database must *not* support ON UPDATE..CASCADE behavior in + foreign keys.""" + return exclusions.closed() + + @property + def deferrable_fks(self): + return exclusions.closed() + + @property + def on_update_or_deferrable_fks(self): + # TODO: exclusions should be composable, + # somehow only_if([x, y]) isn't working here, negation/conjunctions + # getting confused. + return exclusions.only_if( + lambda: self.on_update_cascade.enabled or self.deferrable_fks.enabled + ) + + + @property + def self_referential_foreign_keys(self): + """Target database must support self-referential foreign keys.""" + + return exclusions.open() + + @property + def foreign_key_ddl(self): + """Target database must support the DDL phrases for FOREIGN KEY.""" + + return exclusions.open() + + @property + def named_constraints(self): + """target database must support names for constraints.""" + + return exclusions.open() + + @property + def subqueries(self): + """Target database must support subqueries.""" + + return exclusions.open() + + @property + def offset(self): + """target database can render OFFSET, or an equivalent, in a SELECT.""" + + return exclusions.open() + + @property + def boolean_col_expressions(self): + """Target database must support boolean expressions as columns""" + + return exclusions.closed() + + @property + def nullsordering(self): + """Target backends that support nulls ordering.""" + + return exclusions.closed() + + @property + def standalone_binds(self): + """target database/driver supports bound parameters as column expressions + without being in the context of a typed column. + + """ + return exclusions.closed() + + @property + def intersect(self): + """Target database must support INTERSECT or equivalent.""" + return exclusions.closed() + + @property + def except_(self): + """Target database must support EXCEPT or equivalent (i.e. MINUS).""" + return exclusions.closed() + + @property + def window_functions(self): + """Target database must support window functions.""" + return exclusions.closed() + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + + return exclusions.open() + + @property + def fetch_rows_post_commit(self): + """target platform will allow cursor.fetchone() to proceed after a + COMMIT. + + Typically this refers to an INSERT statement with RETURNING which + is invoked within "autocommit". If the row can be returned + after the autocommit, then this rule can be open. + + """ + + return exclusions.open() + + + @property + def empty_inserts(self): + """target platform supports INSERT with no values, i.e. + INSERT DEFAULT VALUES or equivalent.""" + + return exclusions.only_if( + lambda config: config.db.dialect.supports_empty_insert or \ + config.db.dialect.supports_default_values, + "empty inserts not supported" + ) + + @property + def insert_from_select(self): + """target platform supports INSERT from a SELECT.""" + + return exclusions.open() + + @property + def returning(self): + """target platform supports RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.implicit_returning, + "'returning' not supported by database" + ) + + @property + def duplicate_names_in_cursor_description(self): + """target platform supports a SELECT statement that has + the same name repeated more than once in the columns list.""" + + return exclusions.open() + + @property + def denormalized_names(self): + """Target database must have 'denormalized', i.e. + UPPERCASE as case insensitive names.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.requires_name_normalize, + "Backend does not require denormalized names." + ) + + @property + def multivalues_inserts(self): + """target database must support multiple VALUES clauses in an + INSERT statement.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_multivalues_insert, + "Backend does not support multirow inserts." + ) + + + @property + def implements_get_lastrowid(self): + """"target dialect implements the executioncontext.get_lastrowid() + method without reliance on RETURNING. + + """ + return exclusions.open() + + @property + def emulated_lastrowid(self): + """"target dialect retrieves cursor.lastrowid, or fetches + from a database-side function after an insert() construct executes, + within the get_lastrowid() method. + + Only dialects that "pre-execute", or need RETURNING to get last + inserted id, would return closed/fail/skip for this. + + """ + return exclusions.closed() + + @property + def dbapi_lastrowid(self): + """"target platform includes a 'lastrowid' accessor on the DBAPI + cursor object. + + """ + return exclusions.closed() + + @property + def views(self): + """Target database must support VIEWs.""" + + return exclusions.closed() + + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return exclusions.closed() + + @property + def sequences(self): + """Target database must support SEQUENCEs.""" + + return exclusions.only_if([ + lambda config: config.db.dialect.supports_sequences + ], "no sequence support") + + @property + def sequences_optional(self): + """Target database supports sequences, but also optionally + as a means of generating new PK values.""" + + return exclusions.only_if([ + lambda config: config.db.dialect.supports_sequences and \ + config.db.dialect.sequences_optional + ], "no sequence support, or sequences not optional") + + + + + + @property + def reflects_pk_names(self): + return exclusions.closed() + + @property + def table_reflection(self): + return exclusions.open() + + @property + def view_column_reflection(self): + """target database must support retrieval of the columns in a view, + similarly to how a table is inspected. + + This does not include the full CREATE VIEW definition. + + """ + return self.views + + @property + def view_reflection(self): + """target database must support inspection of the full CREATE VIEW definition. + """ + return self.views + + @property + def schema_reflection(self): + return self.schemas + + @property + def primary_key_constraint_reflection(self): + return exclusions.open() + + @property + def foreign_key_constraint_reflection(self): + return exclusions.open() + + @property + def index_reflection(self): + return exclusions.open() + + @property + def unique_constraint_reflection(self): + """target dialect supports reflection of unique constraints""" + return exclusions.open() + + @property + def unbounded_varchar(self): + """Target database must support VARCHAR with no length""" + + return exclusions.open() + + @property + def unicode_data(self): + """Target database/dialect must support Python unicode objects with + non-ASCII characters represented, delivered as bound parameters + as well as in result rows. + + """ + return exclusions.open() + + @property + def unicode_ddl(self): + """Target driver must support some degree of non-ascii symbol names.""" + return exclusions.closed() + + @property + def datetime_literals(self): + """target dialect supports rendering of a date, time, or datetime as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + + return exclusions.closed() + + @property + def datetime(self): + """target dialect supports representation of Python + datetime.datetime() objects.""" + + return exclusions.open() + + @property + def datetime_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects.""" + + return exclusions.open() + + @property + def datetime_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def date(self): + """target dialect supports representation of Python + datetime.date() objects.""" + + return exclusions.open() + + @property + def date_coerces_from_datetime(self): + """target dialect accepts a datetime object as the target + of a date column.""" + + return exclusions.open() + + @property + def date_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def time(self): + """target dialect supports representation of Python + datetime.time() objects.""" + + return exclusions.open() + + @property + def time_microseconds(self): + """target dialect supports representation of Python + datetime.time() with microsecond objects.""" + + return exclusions.open() + + @property + def binary_comparisons(self): + """target database/driver can allow BLOB/BINARY fields to be compared + against a bound parameter value. + """ + + return exclusions.open() + + @property + def binary_literals(self): + """target backend supports simple binary literals, e.g. an + expression like:: + + SELECT CAST('foo' AS BINARY) + + Where ``BINARY`` is the type emitted from :class:`.LargeBinary`, + e.g. it could be ``BLOB`` or similar. + + Basically fails on Oracle. + + """ + + return exclusions.open() + + @property + def precision_numerics_general(self): + """target backend has general support for moderately high-precision + numerics.""" + return exclusions.open() + + @property + def precision_numerics_enotation_small(self): + """target backend supports Decimal() objects using E notation + to represent very small values.""" + return exclusions.closed() + + @property + def precision_numerics_enotation_large(self): + """target backend supports Decimal() objects using E notation + to represent very large values.""" + return exclusions.closed() + + @property + def precision_numerics_many_significant_digits(self): + """target backend supports values with many digits on both sides, + such as 319438950232418390.273596, 87673.594069654243 + + """ + return exclusions.closed() + + @property + def precision_numerics_retains_significant_digits(self): + """A precision numeric type will return empty significant digits, + i.e. a value such as 10.000 will come back in Decimal form with + the .000 maintained.""" + + return exclusions.closed() + + @property + def precision_generic_float_type(self): + """target backend will return native floating point numbers with at + least seven decimal places when using the generic Float type. + + """ + return exclusions.open() + + @property + def floats_to_four_decimals(self): + """target backend can return a floating-point number with four + significant digits (such as 15.7563) accurately + (i.e. without FP inaccuracies, such as 15.75629997253418). + + """ + return exclusions.open() + + @property + def fetch_null_from_numeric(self): + """target backend doesn't crash when you try to select a NUMERIC + value that has a value of NULL. + + Added to support Pyodbc bug #351. + """ + + return exclusions.open() + + @property + def text_type(self): + """Target database must support an unbounded Text() " + "type such as TEXT or CLOB""" + + return exclusions.open() + + @property + def empty_strings_varchar(self): + """target database can persist/return an empty string with a + varchar. + + """ + return exclusions.open() + + @property + def empty_strings_text(self): + """target database can persist/return an empty string with an + unbounded text.""" + + return exclusions.open() + + @property + def selectone(self): + """target driver must support the literal statement 'select 1'""" + return exclusions.open() + + @property + def savepoints(self): + """Target database must support savepoints.""" + + return exclusions.closed() + + @property + def two_phase_transactions(self): + """Target database must support two-phase transactions.""" + + return exclusions.closed() + + + @property + def update_from(self): + """Target must support UPDATE..FROM syntax""" + return exclusions.closed() + + @property + def update_where_target_in_subquery(self): + """Target must support UPDATE where the same table is present in a + subquery in the WHERE clause. + + This is an ANSI-standard syntax that apparently MySQL can't handle, + such as: + + UPDATE documents SET flag=1 WHERE documents.title IN + (SELECT max(documents.title) AS title + FROM documents GROUP BY documents.user_id + ) + """ + return exclusions.open() + + @property + def mod_operator_as_percent_sign(self): + """target database must use a plain percent '%' as the 'modulus' + operator.""" + return exclusions.closed() + + @property + def percent_schema_names(self): + """target backend supports weird identifiers with percent signs + in them, e.g. 'some % column'. + + this is a very weird use case but often has problems because of + DBAPIs that use python formatting. It's not a critical use + case either. + + """ + return exclusions.closed() + + @property + def order_by_label_with_expression(self): + """target backend supports ORDER BY a column label within an + expression. + + Basically this:: + + select data as foo from test order by foo || 'bar' + + Lots of databases including Postgresql don't support this, + so this is off by default. + + """ + return exclusions.closed() + + @property + def unicode_connections(self): + """Target driver must support non-ASCII characters being passed at all.""" + return exclusions.open() + + @property + def skip_mysql_on_windows(self): + """Catchall for a large variety of MySQL on Windows failures""" + return exclusions.open() + + @property + def ad_hoc_engines(self): + """Test environment must allow ad-hoc engine/connection creation. + + DBs that scale poorly for many connections, even when closed, i.e. + Oracle, may use the "--low-connections" option which flags this requirement + as not present. + + """ + return exclusions.skip_if(lambda config: config.options.low_connections) + + def _has_mysql_on_windows(self, config): + return False + + def _has_mysql_fully_case_sensitive(self, config): + return False + + @property + def sqlite(self): + return exclusions.skip_if(lambda: not self._has_sqlite()) + + @property + def cextensions(self): + return exclusions.skip_if( + lambda: not self._has_cextensions(), "C extensions not installed" + ) + + def _has_sqlite(self): + from sqlalchemy import create_engine + try: + create_engine('sqlite://') + return True + except ImportError: + return False + + def _has_cextensions(self): + try: + from sqlalchemy import cresultproxy, cprocessors + return True + except ImportError: + return False diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py new file mode 100644 index 00000000..19aba53d --- /dev/null +++ b/lib/sqlalchemy/testing/runner.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# testing/runner.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +Nose test runner module. + +This script is a front-end to "nosetests" which +installs SQLAlchemy's testing plugin into the local environment. + +The script is intended to be used by third-party dialects and extensions +that run within SQLAlchemy's testing framework. The runner can +be invoked via:: + + python -m sqlalchemy.testing.runner + +The script is then essentially the same as the "nosetests" script, including +all of the usual Nose options. The test environment requires that a +setup.cfg is locally present including various required options. + +Note that when using this runner, Nose's "coverage" plugin will not be +able to provide coverage for SQLAlchemy itself, since SQLAlchemy is +imported into sys.modules before coverage is started. The special +script sqla_nose.py is provided as a top-level script which loads the +plugin in a special (somewhat hacky) way so that coverage against +SQLAlchemy itself is possible. + +""" + +from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy + +import nose + + +def main(): + nose.main(addplugins=[NoseSQLAlchemy()]) + +def setup_py_test(): + """Runner to use for the 'test_suite' entry of your setup.py. + + Prevents any name clash shenanigans from the command line + argument "test" that the "setup.py test" command sends + to nose. + + """ + nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner']) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py new file mode 100644 index 00000000..4766af18 --- /dev/null +++ b/lib/sqlalchemy/testing/schema.py @@ -0,0 +1,100 @@ +# testing/schema.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import exclusions +from .. import schema, event +from . import config + +__all__ = 'Table', 'Column', + +table_options = {} + + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in list(kw) + if k.startswith('test_')]) + + kw.update(table_options) + + if exclusions.against(config._current, 'mysql'): + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + else: + kw['mysql_engine'] = 'MyISAM' + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if exclusions.against(config._current, 'firebird'): + table_name = args[0] + unpack = (config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.foreign_keys] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in list(kw) + if k.startswith('test_')]) + + if config.requirements.foreign_key_ddl.predicate(config): + args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] + + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False): + + # allow any test suite to pick up on this + col.info['test_needs_autoincrement'] = True + + # hardcoded rule for firebird, oracle; this should + # be moved out + if exclusions.against(config._current, 'firebird', 'oracle'): + def add_seq(c, tbl): + c._init_items( + schema.Sequence(_truncate_name( + config.db.dialect, tbl.name + '_' + c.name + '_seq'), + optional=True) + ) + event.listen(col, 'after_parent_attach', add_seq, propagate=True) + return col + + + + + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + \ + "_" + hex(hash(name) % 64)[2:] + else: + return name diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 00000000..780aa40a --- /dev/null +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -0,0 +1,9 @@ + +from sqlalchemy.testing.suite.test_ddl import * +from sqlalchemy.testing.suite.test_insert import * +from sqlalchemy.testing.suite.test_sequence import * +from sqlalchemy.testing.suite.test_select import * +from sqlalchemy.testing.suite.test_results import * +from sqlalchemy.testing.suite.test_update_delete import * +from sqlalchemy.testing.suite.test_reflection import * +from sqlalchemy.testing.suite.test_types import * diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 00000000..2dca1443 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,63 @@ + + +from .. import fixtures, config, util +from ..config import requirements +from ..assertions import eq_ + +from sqlalchemy import Table, Column, Integer, String + + +class TableDDLTest(fixtures.TestBase): + __backend__ = True + + def _simple_fixture(self): + return Table('test_table', self.metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(50)) + ) + + def _underscore_fixture(self): + return Table('_test_table', self.metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('_data', String(50)) + ) + + def _simple_roundtrip(self, table): + with config.db.begin() as conn: + conn.execute(table.insert().values((1, 'some data'))) + result = conn.execute(table.select()) + eq_( + result.first(), + (1, 'some data') + ) + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + self._simple_roundtrip(table) + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + table.drop( + config.db, checkfirst=False + ) + + @requirements.create_table + @util.provide_metadata + def test_underscore_names(self): + table = self._underscore_fixture() + table.create( + config.db, checkfirst=False + ) + self._simple_roundtrip(table) + +__all__ = ('TableDDLTest', ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py new file mode 100644 index 00000000..3444e15c --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -0,0 +1,230 @@ +from .. import fixtures, config +from ..config import requirements +from .. import exclusions +from ..assertions import eq_ +from .. import engines + +from sqlalchemy import Integer, String, select, util + +from ..schema import Table, Column + + +class LastrowidTest(fixtures.TablesTest): + run_deletes = 'each' + + __backend__ = True + + __requires__ = 'implements_get_lastrowid', 'autoincrement_insert' + + __engine_options__ = {"implicit_returning": False} + + @classmethod + def define_tables(cls, metadata): + Table('autoinc_pk', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', String(50)) + ) + + Table('manual_pk', metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(50)) + ) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + (config.db.dialect.default_sequence_base, "some data") + ) + + def test_autoincrement_on_insert(self): + + config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + self._assert_round_trip(self.tables.autoinc_pk, config.db) + + def test_last_inserted_id(self): + + r = config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) + eq_( + r.inserted_primary_key, + [pk] + ) + + # failed on pypy1.9 but seems to be OK on pypy 2.1 + #@exclusions.fails_if(lambda: util.pypy, "lastrowid not maintained after " + # "connection close") + @requirements.dbapi_lastrowid + def test_native_lastrowid_autoinc(self): + r = config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + lastrowid = r.lastrowid + pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) + eq_( + lastrowid, pk + ) + + +class InsertBehaviorTest(fixtures.TablesTest): + run_deletes = 'each' + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('autoinc_pk', metadata, + Column('id', Integer, primary_key=True, \ + test_needs_autoincrement=True), + Column('data', String(50)) + ) + Table('manual_pk', metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(50)) + ) + + def test_autoclose_on_insert(self): + if requirements.returning.enabled: + engine = engines.testing_engine( + options={'implicit_returning': False}) + else: + engine = config.db + + r = engine.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + assert r.closed + assert r.is_insert + assert not r.returns_rows + + @requirements.returning + def test_autoclose_on_insert_implicit_returning(self): + r = config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + assert r.closed + assert r.is_insert + assert not r.returns_rows + + @requirements.empty_inserts + def test_empty_insert(self): + r = config.db.execute( + self.tables.autoinc_pk.insert(), + ) + assert r.closed + + r = config.db.execute( + self.tables.autoinc_pk.select().\ + where(self.tables.autoinc_pk.c.id != None) + ) + + assert len(r.fetchall()) + + @requirements.insert_from_select + def test_insert_from_select(self): + table = self.tables.manual_pk + config.db.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ] + ) + + + config.db.execute( + table.insert(inline=True). + from_select( + ("id", "data",), select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"])) + ), + ) + + eq_( + config.db.execute( + select([table.c.data]).order_by(table.c.data) + ).fetchall(), + [("data1", ), ("data2", ), ("data2", ), + ("data3", ), ("data3", )] + ) + +class ReturningTest(fixtures.TablesTest): + run_create_tables = 'each' + __requires__ = 'returning', 'autoincrement_insert' + __backend__ = True + + __engine_options__ = {"implicit_returning": True} + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + (config.db.dialect.default_sequence_base, "some data") + ) + + @classmethod + def define_tables(cls, metadata): + Table('autoinc_pk', metadata, + Column('id', Integer, primary_key=True, \ + test_needs_autoincrement=True), + Column('data', String(50)) + ) + + @requirements.fetch_rows_post_commit + def test_explicit_returning_pk_autocommit(self): + engine = config.db + table = self.tables.autoinc_pk + r = engine.execute( + table.insert().returning( + table.c.id), + data="some data" + ) + pk = r.first()[0] + fetched_pk = config.db.scalar(select([table.c.id])) + eq_(fetched_pk, pk) + + def test_explicit_returning_pk_no_autocommit(self): + engine = config.db + table = self.tables.autoinc_pk + with engine.begin() as conn: + r = conn.execute( + table.insert().returning( + table.c.id), + data="some data" + ) + pk = r.first()[0] + fetched_pk = config.db.scalar(select([table.c.id])) + eq_(fetched_pk, pk) + + def test_autoincrement_on_insert_implcit_returning(self): + + config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + self._assert_round_trip(self.tables.autoinc_pk, config.db) + + def test_last_inserted_id_implicit_returning(self): + + r = config.db.execute( + self.tables.autoinc_pk.insert(), + data="some data" + ) + pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) + eq_( + r.inserted_primary_key, + [pk] + ) + + +__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest') diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 00000000..762c9955 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -0,0 +1,540 @@ + + +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc +from sqlalchemy import types as sql_types +from sqlalchemy import inspect +from sqlalchemy import MetaData, Integer, String +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.testing import engines, fixtures +from sqlalchemy.testing.schema import Table, Column +from sqlalchemy.testing import eq_, assert_raises_message +from sqlalchemy import testing +from .. import config +import operator +from sqlalchemy.schema import DDL, Index +from sqlalchemy import event + +metadata, users = None, None + + +class HasTableTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('test_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def test_has_table(self): + with config.db.begin() as conn: + assert config.db.dialect.has_table(conn, "test_table") + assert not config.db.dialect.has_table(conn, "nonexistent_table") + + + + +class ComponentReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + if testing.requires.schemas.enabled: + cls.define_reflected_tables(metadata, "test_schema") + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." + else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + users = Table('users', metadata, + Column('user_id', sa.INT, primary_key=True), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('parent_user_id', sa.Integer, + sa.ForeignKey('%susers.user_id' % schema_prefix)), + schema=schema, + test_needs_fk=True, + ) + else: + users = Table('users', metadata, + Column('user_id', sa.INT, primary_key=True), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table("dingalings", metadata, + Column('dingaling_id', sa.Integer, primary_key=True), + Column('address_id', sa.Integer, + sa.ForeignKey('%semail_addresses.address_id' % + schema_prefix)), + Column('data', sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table('email_addresses', metadata, + Column('address_id', sa.Integer), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + cls.define_index(metadata, users) + if testing.requires.view_column_reflection.enabled: + cls.define_views(metadata, schema) + + @classmethod + def define_index(cls, metadata, users): + Index("users_t_idx", users.c.test1, users.c.test2) + Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) + + @classmethod + def define_views(cls, metadata, schema): + for table_name in ('users', 'email_addresses'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + view_name, fullname) + + event.listen( + metadata, + "after_create", + DDL(query) + ) + event.listen( + metadata, + "before_drop", + DDL("DROP VIEW %s" % view_name) + ) + + @testing.requires.schema_reflection + def test_get_schema_names(self): + insp = inspect(testing.db) + + self.assert_('test_schema' in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_dialect_initialize(self): + engine = engines.testing_engine() + assert not hasattr(engine.dialect, 'default_schema_name') + inspect(engine) + assert hasattr(engine.dialect, 'default_schema_name') + + @testing.requires.schema_reflection + def test_get_default_schema_name(self): + insp = inspect(testing.db) + eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) + + @testing.provide_metadata + def _test_get_table_names(self, schema=None, table_type='table', + order_by=None): + meta = self.metadata + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + insp = inspect(meta.bind) + if table_type == 'view': + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ['email_addresses_v', 'users_v'] + eq_(sorted(table_names), answer) + else: + table_names = insp.get_table_names(schema, + order_by=order_by) + if order_by == 'foreign_key': + answer = ['users', 'email_addresses', 'dingalings'] + eq_(table_names, answer) + else: + answer = ['dingalings', 'email_addresses', 'users'] + eq_(sorted(table_names), answer) + + @testing.requires.table_reflection + def test_get_table_names(self): + self._test_get_table_names() + + @testing.requires.table_reflection + @testing.requires.foreign_key_constraint_reflection + def test_get_table_names_fks(self): + self._test_get_table_names(order_by='foreign_key') + + @testing.requires.table_reflection + @testing.requires.schemas + def test_get_table_names_with_schema(self): + self._test_get_table_names('test_schema') + + @testing.requires.view_column_reflection + def test_get_view_names(self): + self._test_get_table_names(table_type='view') + + @testing.requires.view_column_reflection + @testing.requires.schemas + def test_get_view_names_with_schema(self): + self._test_get_table_names('test_schema', table_type='view') + + @testing.requires.table_reflection + @testing.requires.view_column_reflection + def test_get_tables_and_views(self): + self._test_get_table_names() + self._test_get_table_names(table_type='view') + + def _test_get_columns(self, schema=None, table_type='table'): + meta = MetaData(testing.db) + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + table_names = ['users', 'email_addresses'] + if table_type == 'view': + table_names = ['users_v', 'email_addresses_v'] + insp = inspect(meta.bind) + for table_name, table in zip(table_names, (users, + addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]['name']) + ctype = cols[i]['type'].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against('oracle') and ctype_def \ + in (sql_types.Date, sql_types.DateTime): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + self.assert_(len(set(ctype.__mro__). + intersection(ctype_def.__mro__).intersection([ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ])) > 0, '%s(%s), %s(%s)' % (col.name, + col.type, cols[i]['name'], ctype)) + + if not col.primary_key: + assert cols[i]['default'] is None + + @testing.requires.table_reflection + def test_get_columns(self): + self._test_get_columns() + + @testing.provide_metadata + def _type_round_trip(self, *types): + t = Table('t', self.metadata, + *[ + Column('t%d' % i, type_) + for i, type_ in enumerate(types) + ] + ) + t.create() + + return [ + c['type'] for c in + inspect(self.metadata.bind).get_columns('t') + ] + + @testing.requires.table_reflection + def test_numeric_reflection(self): + for typ in self._type_round_trip( + sql_types.Numeric(18, 5), + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) + + @testing.requires.table_reflection + def test_varchar_reflection(self): + typ = self._type_round_trip(sql_types.String(52))[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) + + @testing.requires.table_reflection + @testing.provide_metadata + def test_nullable_reflection(self): + t = Table('t', self.metadata, + Column('a', Integer, nullable=True), + Column('b', Integer, nullable=False)) + t.create() + eq_( + dict( + (col['name'], col['nullable']) + for col in inspect(self.metadata.bind).get_columns('t') + ), + {"a": True, "b": False} + ) + + + @testing.requires.table_reflection + @testing.requires.schemas + def test_get_columns_with_schema(self): + self._test_get_columns(schema='test_schema') + + @testing.requires.view_column_reflection + def test_get_view_columns(self): + self._test_get_columns(table_type='view') + + @testing.requires.view_column_reflection + @testing.requires.schemas + def test_get_view_columns_with_schema(self): + self._test_get_columns(schema='test_schema', table_type='view') + + @testing.provide_metadata + def _test_get_pk_constraint(self, schema=None): + meta = self.metadata + users, addresses = self.tables.users, self.tables.email_addresses + insp = inspect(meta.bind) + + users_cons = insp.get_pk_constraint(users.name, schema=schema) + users_pkeys = users_cons['constrained_columns'] + eq_(users_pkeys, ['user_id']) + + addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) + addr_pkeys = addr_cons['constrained_columns'] + eq_(addr_pkeys, ['address_id']) + + with testing.requires.reflects_pk_names.fail_if(): + eq_(addr_cons['name'], 'email_ad_pk') + + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self): + self._test_get_pk_constraint() + + @testing.requires.table_reflection + @testing.requires.primary_key_constraint_reflection + @testing.requires.schemas + def test_get_pk_constraint_with_schema(self): + self._test_get_pk_constraint(schema='test_schema') + + @testing.requires.table_reflection + @testing.provide_metadata + def test_deprecated_get_primary_keys(self): + meta = self.metadata + users = self.tables.users + insp = Inspector(meta.bind) + assert_raises_message( + sa_exc.SADeprecationWarning, + "Call to deprecated method get_primary_keys." + " Use get_pk_constraint instead.", + insp.get_primary_keys, users.name + ) + + @testing.provide_metadata + def _test_get_foreign_keys(self, schema=None): + meta = self.metadata + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + insp = inspect(meta.bind) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, + schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + self.assert_(fkey1['name'] is not None) + + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + if testing.requires.self_referential_foreign_keys.enabled: + eq_(fkey1['constrained_columns'], ['parent_user_id']) + + #addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, + schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + self.assert_(fkey1['name'] is not None) + + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['remote_user_id']) + + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self): + self._test_get_foreign_keys() + + @testing.requires.foreign_key_constraint_reflection + @testing.requires.schemas + def test_get_foreign_keys_with_schema(self): + self._test_get_foreign_keys(schema='test_schema') + + @testing.provide_metadata + def _test_get_indexes(self, schema=None): + meta = self.metadata + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = inspect(meta.bind) + indexes = insp.get_indexes('users', schema=schema) + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}, + {'unique': False, + 'column_names': ['user_id', 'test2', 'test1'], + 'name': 'users_all_idx'} + ] + index_names = [d['name'] for d in indexes] + for e_index in expected_indexes: + assert e_index['name'] in index_names + index = indexes[index_names.index(e_index['name'])] + for key in e_index: + eq_(e_index[key], index[key]) + + @testing.requires.index_reflection + def test_get_indexes(self): + self._test_get_indexes() + + @testing.requires.index_reflection + @testing.requires.schemas + def test_get_indexes_with_schema(self): + self._test_get_indexes(schema='test_schema') + + + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self): + self._test_get_unique_constraints() + + @testing.requires.unique_constraint_reflection + @testing.requires.schemas + def test_get_unique_constraints_with_schema(self): + self._test_get_unique_constraints(schema='test_schema') + + @testing.provide_metadata + def _test_get_unique_constraints(self, schema=None): + uniques = sorted( + [ + {'name': 'unique_a', 'column_names': ['a']}, + {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, + {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']}, + {'name': 'unique_asc_key', 'column_names': ['asc', 'key']}, + ], + key=operator.itemgetter('name') + ) + orig_meta = self.metadata + table = Table( + 'testtbl', orig_meta, + Column('a', sa.String(20)), + Column('b', sa.String(30)), + Column('c', sa.Integer), + # reserved identifiers + Column('asc', sa.String(30)), + Column('key', sa.String(30)), + schema=schema + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc['column_names'], name=uc['name']) + ) + orig_meta.create_all() + + inspector = inspect(orig_meta.bind) + reflected = sorted( + inspector.get_unique_constraints('testtbl', schema=schema), + key=operator.itemgetter('name') + ) + + for orig, refl in zip(uniques, reflected): + eq_(orig, refl) + + + @testing.provide_metadata + def _test_get_view_definition(self, schema=None): + meta = self.metadata + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + view_name1 = 'users_v' + view_name2 = 'email_addresses_v' + insp = inspect(meta.bind) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + @testing.requires.view_reflection + def test_get_view_definition(self): + self._test_get_view_definition() + + @testing.requires.view_reflection + @testing.requires.schemas + def test_get_view_definition_with_schema(self): + self._test_get_view_definition(schema='test_schema') + + @testing.only_on("postgresql", "PG specific feature") + @testing.provide_metadata + def _test_get_table_oid(self, table_name, schema=None): + meta = self.metadata + users, addresses, dingalings = self.tables.users, \ + self.tables.email_addresses, self.tables.dingalings + insp = inspect(meta.bind) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + def test_get_table_oid(self): + self._test_get_table_oid('users') + + @testing.requires.schemas + def test_get_table_oid_with_schema(self): + self._test_get_table_oid('users', schema='test_schema') + + @testing.requires.table_reflection + @testing.provide_metadata + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + meta = self.metadata + insp = inspect(meta.bind) + + for tname, cname in [ + ('users', 'user_id'), + ('email_addresses', 'address_id'), + ('dingalings', 'dingaling_id'), + ]: + cols = insp.get_columns(tname) + id_ = dict((c['name'], c) for c in cols)[cname] + assert id_.get('autoincrement', True) + + + +__all__ = ('ComponentReflectionTest', 'HasTableTest') diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py new file mode 100644 index 00000000..2fdab4d1 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -0,0 +1,220 @@ +from .. import fixtures, config +from ..config import requirements +from .. import exclusions +from ..assertions import eq_ +from .. import engines + +from sqlalchemy import Integer, String, select, util, sql, DateTime +import datetime +from ..schema import Table, Column + + +class RowFetchTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('plain_pk', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + Table('has_dates', metadata, + Column('id', Integer, primary_key=True), + Column('today', DateTime) + ) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ] + ) + + config.db.execute( + cls.tables.has_dates.insert(), + [ + {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)} + ] + ) + + def test_via_string(self): + row = config.db.execute( + self.tables.plain_pk.select().\ + order_by(self.tables.plain_pk.c.id) + ).first() + + eq_( + row['id'], 1 + ) + eq_( + row['data'], "d1" + ) + + def test_via_int(self): + row = config.db.execute( + self.tables.plain_pk.select().\ + order_by(self.tables.plain_pk.c.id) + ).first() + + eq_( + row[0], 1 + ) + eq_( + row[1], "d1" + ) + + def test_via_col_object(self): + row = config.db.execute( + self.tables.plain_pk.select().\ + order_by(self.tables.plain_pk.c.id) + ).first() + + eq_( + row[self.tables.plain_pk.c.id], 1 + ) + eq_( + row[self.tables.plain_pk.c.data], "d1" + ) + + @requirements.duplicate_names_in_cursor_description + def test_row_with_dupe_names(self): + result = config.db.execute( + select([self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label('data')]).\ + order_by(self.tables.plain_pk.c.id) + ) + row = result.first() + eq_(result.keys(), ['data', 'data']) + eq_(row, ('d1', 'd1')) + + + def test_row_w_scalar_select(self): + """test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) + + """ + datetable = self.tables.has_dates + s = select([datetable.alias('x').c.today]).as_scalar() + s2 = select([datetable.c.id, s.label('somelabel')]) + row = config.db.execute(s2).first() + + eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0)) + + +class PercentSchemaNamesTest(fixtures.TablesTest): + """tests using percent signs, spaces in table and column names. + + This is a very fringe use case, doesn't work for MySQL + or Postgresql. the requirement, "percent_schema_names", + is marked "skip" by default. + + """ + + __requires__ = ('percent_schema_names', ) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.percent_table = Table('percent%table', metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table('percent%table', + sql.column("percent%"), + sql.column("spaces % more spaces"), + ) + + def test_single_roundtrip(self): + percent_table = self.tables.percent_table + for params in [ + {'percent%': 5, 'spaces % more spaces': 12}, + {'percent%': 7, 'spaces % more spaces': 11}, + {'percent%': 9, 'spaces % more spaces': 10}, + {'percent%': 11, 'spaces % more spaces': 9} + ]: + config.db.execute(percent_table.insert(), params) + self._assert_table() + + def test_executemany_roundtrip(self): + percent_table = self.tables.percent_table + config.db.execute( + percent_table.insert(), + {'percent%': 5, 'spaces % more spaces': 12} + ) + config.db.execute( + percent_table.insert(), + [{'percent%': 7, 'spaces % more spaces': 11}, + {'percent%': 9, 'spaces % more spaces': 10}, + {'percent%': 11, 'spaces % more spaces': 9}] + ) + self._assert_table() + + def _assert_table(self): + percent_table = self.tables.percent_table + lightweight_percent_table = self.tables.lightweight_percent_table + + for table in ( + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias()): + eq_( + list( + config.db.execute( + table.select().order_by(table.c['percent%']) + ) + ), + [ + (5, 12), + (7, 11), + (9, 10), + (11, 9) + ] + ) + + eq_( + list( + config.db.execute( + table.select(). + where(table.c['spaces % more spaces'].in_([9, 10])). + order_by(table.c['percent%']), + ) + ), + [ + (9, 10), + (11, 9) + ] + ) + + row = config.db.execute(table.select().\ + order_by(table.c['percent%'])).first() + eq_(row['percent%'], 5) + eq_(row['spaces % more spaces'], 12) + + eq_(row[table.c['percent%']], 5) + eq_(row[table.c['spaces % more spaces']], 12) + + config.db.execute( + percent_table.update().values( + {percent_table.c['spaces % more spaces']: 15} + ) + ) + + eq_( + list( + config.db.execute( + percent_table.\ + select().\ + order_by(percent_table.c['percent%']) + ) + ), + [(5, 15), (7, 15), (9, 15), (11, 15)] + ) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 00000000..2ccff61e --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,86 @@ +from .. import fixtures, config +from ..assertions import eq_ + +from sqlalchemy import util +from sqlalchemy import Integer, String, select, func + +from ..schema import Table, Column + + +class OrderByLabelTest(fixtures.TablesTest): + """Test the dialect sends appropriate ORDER BY expressions when + labels are used. + + This essentially exercises the "supports_simple_order_by_label" + setting. + + """ + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer), + Column('q', String(50)), + Column('p', String(50)) + ) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ] + ) + + def _assert_result(self, select, result): + eq_( + config.db.execute(select).fetchall(), + result + ) + + def test_plain(self): + table = self.tables.some_table + lx = table.c.x.label('lx') + self._assert_result( + select([lx]).order_by(lx), + [(1, ), (2, ), (3, )] + ) + + def test_composed_int(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label('lx') + self._assert_result( + select([lx]).order_by(lx), + [(3, ), (5, ), (7, )] + ) + + def test_composed_multiple(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label('lx') + ly = (func.lower(table.c.q) + table.c.p).label('ly') + self._assert_result( + select([lx, ly]).order_by(lx, ly.desc()), + [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))] + ) + + def test_plain_desc(self): + table = self.tables.some_table + lx = table.c.x.label('lx') + self._assert_result( + select([lx]).order_by(lx.desc()), + [(3, ), (2, ), (1, )] + ) + + def test_composed_int_desc(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label('lx') + self._assert_result( + select([lx]).order_by(lx.desc()), + [(7, ), (5, ), (3, )] + ) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py new file mode 100644 index 00000000..6bc2822f --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -0,0 +1,128 @@ +from .. import fixtures, config +from ..config import requirements +from ..assertions import eq_ +from ... import testing + +from ... import Integer, String, Sequence, schema + +from ..schema import Table, Column + +class SequenceTest(fixtures.TablesTest): + __requires__ = ('sequences',) + __backend__ = True + + run_create_tables = 'each' + + @classmethod + def define_tables(cls, metadata): + Table('seq_pk', metadata, + Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), + Column('data', String(50)) + ) + + Table('seq_opt_pk', metadata, + Column('id', Integer, Sequence('tab_id_seq', optional=True), + primary_key=True), + Column('data', String(50)) + ) + + def test_insert_roundtrip(self): + config.db.execute( + self.tables.seq_pk.insert(), + data="some data" + ) + self._assert_round_trip(self.tables.seq_pk, config.db) + + def test_insert_lastrowid(self): + r = config.db.execute( + self.tables.seq_pk.insert(), + data="some data" + ) + eq_( + r.inserted_primary_key, + [1] + ) + + def test_nextval_direct(self): + r = config.db.execute( + self.tables.seq_pk.c.id.default + ) + eq_( + r, 1 + ) + + @requirements.sequences_optional + def test_optional_seq(self): + r = config.db.execute( + self.tables.seq_opt_pk.insert(), + data="some data" + ) + eq_( + r.inserted_primary_key, + [1] + ) + + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + (1, "some data") + ) + + +class HasSequenceTest(fixtures.TestBase): + __requires__ = 'sequences', + __backend__ = True + + def test_has_sequence(self): + s1 = Sequence('user_id_seq') + testing.db.execute(schema.CreateSequence(s1)) + try: + eq_(testing.db.dialect.has_sequence(testing.db, + 'user_id_seq'), True) + finally: + testing.db.execute(schema.DropSequence(s1)) + + @testing.requires.schemas + def test_has_sequence_schema(self): + s1 = Sequence('user_id_seq', schema="test_schema") + testing.db.execute(schema.CreateSequence(s1)) + try: + eq_(testing.db.dialect.has_sequence(testing.db, + 'user_id_seq', schema="test_schema"), True) + finally: + testing.db.execute(schema.DropSequence(s1)) + + def test_has_sequence_neg(self): + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), + False) + + @testing.requires.schemas + def test_has_sequence_schemas_neg(self): + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', + schema="test_schema"), + False) + + @testing.requires.schemas + def test_has_sequence_default_not_in_remote(self): + s1 = Sequence('user_id_seq') + testing.db.execute(schema.CreateSequence(s1)) + try: + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', + schema="test_schema"), + False) + finally: + testing.db.execute(schema.DropSequence(s1)) + + @testing.requires.schemas + def test_has_sequence_remote_not_in_default(self): + s1 = Sequence('user_id_seq', schema="test_schema") + testing.db.execute(schema.CreateSequence(s1)) + try: + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), + False) + finally: + testing.db.execute(schema.DropSequence(s1)) + + diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py new file mode 100644 index 00000000..5d8005f4 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -0,0 +1,598 @@ +# coding: utf-8 + +from .. import fixtures, config +from ..assertions import eq_ +from ..config import requirements +from sqlalchemy import Integer, Unicode, UnicodeText, select +from sqlalchemy import Date, DateTime, Time, MetaData, String, \ + Text, Numeric, Float, literal, Boolean +from ..schema import Table, Column +from ... import testing +import decimal +import datetime +from ...util import u +from ... import util + + +class _LiteralRoundTripFixture(object): + @testing.provide_metadata + def _literal_round_trip(self, type_, input_, output, filter_=None): + """test literal rendering """ + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as it's + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + t = Table('t', self.metadata, Column('x', type_)) + t.create() + + for value in input_: + ins = t.insert().values(x=literal(value)).compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True) + ) + testing.db.execute(ins) + + for row in t.select().execute(): + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + +class _UnicodeFixture(_LiteralRoundTripFixture): + __requires__ = 'unicode_data', + + data = u("Alors vous imaginez ma surprise, au lever du jour, "\ + "quand une drôle de petite voix m’a réveillé. Elle "\ + "disait: « S’il vous plaît… dessine-moi un mouton! »") + + @classmethod + def define_tables(cls, metadata): + Table('unicode_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('unicode_data', cls.datatype), + ) + + def test_round_trip(self): + unicode_table = self.tables.unicode_table + + config.db.execute( + unicode_table.insert(), + { + 'unicode_data': self.data, + } + ) + + row = config.db.execute( + select([ + unicode_table.c.unicode_data, + ]) + ).first() + + eq_( + row, + (self.data, ) + ) + assert isinstance(row[0], util.text_type) + + def test_round_trip_executemany(self): + unicode_table = self.tables.unicode_table + + config.db.execute( + unicode_table.insert(), + [ + { + 'unicode_data': self.data, + } + for i in range(3) + ] + ) + + rows = config.db.execute( + select([ + unicode_table.c.unicode_data, + ]) + ).fetchall() + eq_( + rows, + [(self.data, ) for i in range(3)] + ) + for row in rows: + assert isinstance(row[0], util.text_type) + + def _test_empty_strings(self): + unicode_table = self.tables.unicode_table + + config.db.execute( + unicode_table.insert(), + {"unicode_data": u('')} + ) + row = config.db.execute( + select([unicode_table.c.unicode_data]) + ).first() + eq_(row, (u(''),)) + + def test_literal(self): + self._literal_round_trip(self.datatype, [self.data], [self.data]) + + +class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = 'unicode_data', + __backend__ = True + + datatype = Unicode(255) + + @requirements.empty_strings_varchar + def test_empty_strings_varchar(self): + self._test_empty_strings() + + +class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = 'unicode_data', 'text_type' + __backend__ = True + + datatype = UnicodeText() + + @requirements.empty_strings_text + def test_empty_strings_text(self): + self._test_empty_strings() + +class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = 'text_type', + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('text_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('text_data', Text), + ) + + def test_text_roundtrip(self): + text_table = self.tables.text_table + + config.db.execute( + text_table.insert(), + {"text_data": 'some text'} + ) + row = config.db.execute( + select([text_table.c.text_data]) + ).first() + eq_(row, ('some text',)) + + def test_text_empty_strings(self): + text_table = self.tables.text_table + + config.db.execute( + text_table.insert(), + {"text_data": ''} + ) + row = config.db.execute( + select([text_table.c.text_data]) + ).first() + eq_(row, ('',)) + + def test_literal(self): + self._literal_round_trip(Text, ["some text"], ["some text"]) + + def test_literal_quoting(self): + data = '''some 'text' hey "hi there" that's text''' + self._literal_round_trip(Text, [data], [data]) + + def test_literal_backslashes(self): + data = r'backslash one \ backslash two \\ end' + self._literal_round_trip(Text, [data], [data]) + +class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @requirements.unbounded_varchar + def test_nolength_string(self): + metadata = MetaData() + foo = Table('foo', metadata, + Column('one', String) + ) + + foo.create(config.db) + foo.drop(config.db) + + def test_literal(self): + self._literal_round_trip(String(40), ["some text"], ["some text"]) + + def test_literal_quoting(self): + data = '''some 'text' hey "hi there" that's text''' + self._literal_round_trip(String(40), [data], [data]) + + def test_literal_backslashes(self): + data = r'backslash one \ backslash two \\ end' + self._literal_round_trip(Text, [data], [data]) + + +class _DateFixture(_LiteralRoundTripFixture): + compare = None + + @classmethod + def define_tables(cls, metadata): + Table('date_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('date_data', cls.datatype), + ) + + def test_round_trip(self): + date_table = self.tables.date_table + + config.db.execute( + date_table.insert(), + {'date_data': self.data} + ) + + row = config.db.execute( + select([ + date_table.c.date_data, + ]) + ).first() + + compare = self.compare or self.data + eq_(row, + (compare, )) + assert isinstance(row[0], type(compare)) + + def test_null(self): + date_table = self.tables.date_table + + config.db.execute( + date_table.insert(), + {'date_data': None} + ) + + row = config.db.execute( + select([ + date_table.c.date_data, + ]) + ).first() + eq_(row, (None,)) + + @testing.requires.datetime_literals + def test_literal(self): + compare = self.compare or self.data + self._literal_round_trip(self.datatype, [self.data], [compare]) + + + +class DateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'datetime', + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + + +class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'datetime_microseconds', + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + +class TimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'time', + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18) + + +class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'time_microseconds', + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18, 396) + + +class DateTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'date', + __backend__ = True + datatype = Date + data = datetime.date(2012, 10, 15) + + +class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'date', 'date_coerces_from_datetime' + __backend__ = True + datatype = Date + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + compare = datetime.date(2012, 10, 15) + + +class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'datetime_historic', + __backend__ = True + datatype = DateTime + data = datetime.datetime(1850, 11, 10, 11, 52, 35) + + +class DateHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = 'date_historic', + __backend__ = True + datatype = Date + data = datetime.date(1727, 4, 1) + + +class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + def test_literal(self): + self._literal_round_trip(Integer, [5], [5]) + +class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + @testing.provide_metadata + def _do_test(self, type_, input_, output, filter_=None, check_scale=False): + metadata = self.metadata + t = Table('t', metadata, Column('x', type_)) + t.create() + t.insert().execute([{'x':x} for x in input_]) + + result = set([row[0] for row in t.select().execute()]) + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_( + [str(x) for x in result], + [str(x) for x in output], + ) + + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric(self): + self._literal_round_trip( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric_asfloat(self): + self._literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self): + self._literal_round_trip( + Float(4), + [15.7563, decimal.Decimal("15.7563")], + [15.7563,], + filter_=lambda n: n is not None and round(n, 5) or None + ) + + + @testing.requires.precision_generic_float_type + def test_float_custom_scale(self): + self._do_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [15.7563827, decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827"),], + check_scale=True + ) + + def test_numeric_as_decimal(self): + self._do_test( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self): + self._do_test( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_decimal(self): + self._do_test( + Numeric(precision=8, scale=4), + [None], + [None], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_float(self): + self._do_test( + Numeric(precision=8, scale=4, asdecimal=False), + [None], + [None], + ) + + @testing.requires.floats_to_four_decimals + def test_float_as_decimal(self): + self._do_test( + Float(precision=8, asdecimal=True), + [15.7563, decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], + ) + + + def test_float_as_float(self): + self._do_test( + Float(precision=8), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None + ) + + + @testing.requires.precision_numerics_general + def test_precision_decimal(self): + numbers = set([ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ]) + + self._do_test( + Numeric(precision=18, scale=12), + numbers, + numbers, + ) + + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + numbers = set([ + decimal.Decimal('1E-2'), + decimal.Decimal('1E-3'), + decimal.Decimal('1E-4'), + decimal.Decimal('1E-5'), + decimal.Decimal('1E-6'), + decimal.Decimal('1E-7'), + decimal.Decimal('1E-8'), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ]) + self._do_test( + Numeric(precision=18, scale=14), + numbers, + numbers + ) + + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self): + """test exceedingly large decimals. + + """ + + numbers = set([ + decimal.Decimal('4E+8'), + decimal.Decimal("5748E+15"), + decimal.Decimal('1.521E+15'), + decimal.Decimal('00000000000000.1E+12'), + ]) + self._do_test( + Numeric(precision=25, scale=2), + numbers, + numbers + ) + + @testing.requires.precision_numerics_many_significant_digits + def test_many_significant_digits(self): + numbers = set([ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ]) + self._do_test( + Numeric(precision=38, scale=12), + numbers, + numbers + ) + + @testing.requires.precision_numerics_retains_significant_digits + def test_numeric_no_decimal(self): + numbers = set([ + decimal.Decimal("1.000") + ]) + self._do_test( + Numeric(precision=5, scale=3), + numbers, + numbers, + check_scale=True + ) + + +class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('boolean_table', metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('value', Boolean), + Column('unconstrained_value', Boolean(create_constraint=False)), + ) + + def test_render_literal_bool(self): + self._literal_round_trip( + Boolean(), + [True, False], + [True, False] + ) + + def test_round_trip(self): + boolean_table = self.tables.boolean_table + + config.db.execute( + boolean_table.insert(), + { + 'id': 1, + 'value': True, + 'unconstrained_value': False + } + ) + + row = config.db.execute( + select([ + boolean_table.c.value, + boolean_table.c.unconstrained_value + ]) + ).first() + + eq_( + row, + (True, False) + ) + assert isinstance(row[0], bool) + + def test_null(self): + boolean_table = self.tables.boolean_table + + config.db.execute( + boolean_table.insert(), + { + 'id': 1, + 'value': None, + 'unconstrained_value': None + } + ) + + row = config.db.execute( + select([ + boolean_table.c.value, + boolean_table.c.unconstrained_value + ]) + ).first() + + eq_( + row, + (None, None) + ) + + + + +__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', + 'DateTest', 'DateTimeTest', 'TextTest', + 'NumericTest', 'IntegerTest', + 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', + 'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest', + 'DateHistoricTest', 'StringTest', 'BooleanTest') diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py new file mode 100644 index 00000000..88dc9535 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -0,0 +1,63 @@ +from .. import fixtures, config +from ..assertions import eq_ + +from sqlalchemy import Integer, String +from ..schema import Table, Column + + +class SimpleUpdateDeleteTest(fixtures.TablesTest): + run_deletes = 'each' + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table('plain_pk', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.plain_pk.insert(), + [ + {"id":1, "data":"d1"}, + {"id":2, "data":"d2"}, + {"id":3, "data":"d3"}, + ] + ) + + def test_update(self): + t = self.tables.plain_pk + r = config.db.execute( + t.update().where(t.c.id == 2), + data="d2_new" + ) + assert not r.is_insert + assert not r.returns_rows + + eq_( + config.db.execute(t.select().order_by(t.c.id)).fetchall(), + [ + (1, "d1"), + (2, "d2_new"), + (3, "d3") + ] + ) + + def test_delete(self): + t = self.tables.plain_pk + r = config.db.execute( + t.delete().where(t.c.id == 2) + ) + assert not r.is_insert + assert not r.returns_rows + eq_( + config.db.execute(t.select().order_by(t.c.id)).fetchall(), + [ + (1, "d1"), + (3, "d3") + ] + ) + +__all__ = ('SimpleUpdateDeleteTest', ) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py new file mode 100644 index 00000000..bde11a35 --- /dev/null +++ b/lib/sqlalchemy/testing/util.py @@ -0,0 +1,205 @@ +# testing/util.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from ..util import jython, pypy, defaultdict, decorator, py2k +import decimal +import gc +import time +import random +import sys +import types + +if jython: + def jython_gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + gc_collect = lazy_gc = jython_gc_collect +elif pypy: + def pypy_gc_collect(*args): + gc.collect() + gc.collect() + gc_collect = lazy_gc = pypy_gc_collect +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + + def lazy_gc(): + pass + + +def picklers(): + picklers = set() + if py2k: + try: + import cPickle + picklers.add(cPickle) + except ImportError: + pass + + import pickle + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle_ in picklers: + for protocol in -1, 0, 1, 2: + yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec) + ).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec) + + +class RandomSet(set): + def __iter__(self): + l = list(set.__iter__(self)) + random.shuffle(l) + return iter(l) + + def pop(self): + index = random.randint(0, len(self) - 1) + item = list(set.__iter__(self))[index] + self.remove(item) + return item + + def union(self, other): + return RandomSet(set.union(self, other)) + + def difference(self, other): + return RandomSet(set.difference(self, other)) + + def intersection(self, other): + return RandomSet(set.intersection(self, other)) + + def copy(self): + return RandomSet(self) + + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + This function should be phased out as much as possible + in favor of @decorator. Tests that "generate" many named tests + should be modernized. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType(fn.__code__, fn.__globals__, name, + fn.__defaults__, fn.__closure__) + return fn + + +def run_as_contextmanager(ctx, fn, *arg, **kw): + """Run the given function under the given contextmanager, + simulating the behavior of 'with' to support older + Python versions. + + """ + + obj = ctx.__enter__() + try: + result = fn(obj, *arg, **kw) + ctx.__exit__(None, None, None) + return result + except: + exc_info = sys.exc_info() + raise_ = ctx.__exit__(*exc_info) + if raise_ is None: + raise + else: + return raise_ + + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return set([tuple(row) for row in results]) + + +def fail(msg): + assert False, msg + + +@decorator +def provide_metadata(fn, *args, **kw): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from . import config + from sqlalchemy import schema + + metadata = schema.MetaData(config.db) + self = args[0] + prev_meta = getattr(self, 'metadata', None) + self.metadata = metadata + try: + return fn(*args, **kw) + finally: + metadata.drop_all() + self.metadata = prev_meta + + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def get_all(self, *keys): + return tuple([self[key] for key in keys]) diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py new file mode 100644 index 00000000..849b1b5b --- /dev/null +++ b/lib/sqlalchemy/testing/warnings.py @@ -0,0 +1,56 @@ +# testing/warnings.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import warnings +from .. import exc as sa_exc +from .. import util +import re + +def testing_warn(msg, stacklevel=3): + """Replaces sqlalchemy.util.warn during tests.""" + + filename = "sqlalchemy.testing.warnings" + lineno = 1 + if isinstance(msg, util.string_types): + warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno) + else: + warnings.warn_explicit(msg, filename, lineno) + + +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + util.warn = util.langhelpers.warn = testing_warn + + warnings.filterwarnings('ignore', + category=sa_exc.SAPendingDeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) + + +def assert_warnings(fn, warnings, regex=False): + """Assert that each of the given warnings are emitted by fn.""" + + from .assertions import eq_, emits_warning + + canary = [] + orig_warn = util.warn + + def capture_warnings(*args, **kw): + orig_warn(*args, **kw) + popwarn = warnings.pop(0) + canary.append(popwarn) + if regex: + assert re.match(popwarn, args[0]) + else: + eq_(args[0], popwarn) + util.warn = util.langhelpers.warn = capture_warnings + + result = emits_warning()(fn)() + assert canary, "No warning was emitted" + return result diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py new file mode 100644 index 00000000..3994bd4a --- /dev/null +++ b/lib/sqlalchemy/types.py @@ -0,0 +1,77 @@ +# types.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Compatiblity namespace for sqlalchemy.sql.types. + +""" + +__all__ = ['TypeEngine', 'TypeDecorator', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text', + 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME', + 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', + 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', + 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', + 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', + 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum'] + +from .sql.type_api import ( + adapt_type, + TypeEngine, + TypeDecorator, + Variant, + to_instance, + UserDefinedType +) +from .sql.sqltypes import ( + BIGINT, + BINARY, + BLOB, + BOOLEAN, + BigInteger, + Binary, + _Binary, + Boolean, + CHAR, + CLOB, + Concatenable, + DATE, + DATETIME, + DECIMAL, + Date, + DateTime, + Enum, + FLOAT, + Float, + INT, + INTEGER, + Integer, + Interval, + LargeBinary, + NCHAR, + NVARCHAR, + NullType, + NULLTYPE, + NUMERIC, + Numeric, + PickleType, + REAL, + SchemaType, + SMALLINT, + SmallInteger, + String, + STRINGTYPE, + TEXT, + TIME, + TIMESTAMP, + Text, + Time, + Unicode, + UnicodeText, + VARBINARY, + VARCHAR, + _type_map + ) + diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py new file mode 100644 index 00000000..eba64ed1 --- /dev/null +++ b/lib/sqlalchemy/util/__init__.py @@ -0,0 +1,45 @@ +# util/__init__.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .compat import callable, cmp, reduce, \ + threading, py3k, py33, py2k, jython, pypy, cpython, win32, \ + pickle, dottedgetter, parse_qsl, namedtuple, next, reraise, \ + raise_from_cause, text_type, string_types, int_types, binary_type, \ + quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\ + unquote_plus, unquote, b64decode, b64encode, byte_buffer, itertools_filter,\ + iterbytes, StringIO, inspect_getargspec + +from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \ + Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ + OrderedSet, IdentitySet, OrderedIdentitySet, column_set, \ + column_dict, ordered_column_set, populate_column_dict, unique_list, \ + UniqueAppender, PopulateDict, EMPTY_SET, to_list, to_set, \ + to_column_set, update_copy, flatten_iterator, \ + LRUCache, ScopedRegistry, ThreadLocalRegistry, WeakSequence, \ + coerce_generator_arg + +from .langhelpers import iterate_attributes, class_hierarchy, \ + portable_instancemethod, unbound_method_to_callable, \ + getargspec_init, format_argspec_init, format_argspec_plus, \ + get_func_kwargs, get_cls_kwargs, decorator, as_interface, \ + memoized_property, memoized_instancemethod, md5_hex, \ + group_expirable_memoized_property, dependencies, decode_slice, \ + monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\ + duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ + classproperty, set_creation_order, warn_exception, warn, NoneType,\ + constructor_copy, methods_equivalent, chop_traceback, asint,\ + generic_repr, counter, PluginLoader, hybridmethod, safe_reraise,\ + get_callable_argspec, only_once + +from .deprecations import warn_deprecated, warn_pending_deprecation, \ + deprecated, pending_deprecation, inject_docstring_text + +# things that used to be not always available, +# but are now as of current support Python versions +from collections import defaultdict +from functools import partial +from functools import update_wrapper +from contextlib import contextmanager diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py new file mode 100644 index 00000000..c0a24ba4 --- /dev/null +++ b/lib/sqlalchemy/util/_collections.py @@ -0,0 +1,957 @@ +# util/_collections.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Collection classes and helpers.""" + +from __future__ import absolute_import +import weakref +import operator +from .compat import threading, itertools_filterfalse +from . import py2k +import types + +EMPTY_SET = frozenset() + + +class KeyedTuple(tuple): + """``tuple`` subclass that adds labeled names. + + E.g.:: + + >>> k = KeyedTuple([1, 2, 3], labels=["one", "two", "three"]) + >>> k.one + 1 + >>> k.two + 2 + + Result rows returned by :class:`.Query` that contain multiple + ORM entities and/or column expressions make use of this + class to return rows. + + The :class:`.KeyedTuple` exhibits similar behavior to the + ``collections.namedtuple()`` construct provided in the Python + standard library, however is architected very differently. + Unlike ``collections.namedtuple()``, :class:`.KeyedTuple` is + does not rely on creation of custom subtypes in order to represent + a new series of keys, instead each :class:`.KeyedTuple` instance + receives its list of keys in place. The subtype approach + of ``collections.namedtuple()`` introduces significant complexity + and performance overhead, which is not necessary for the + :class:`.Query` object's use case. + + .. versionchanged:: 0.8 + Compatibility methods with ``collections.namedtuple()`` have been + added including :attr:`.KeyedTuple._fields` and + :meth:`.KeyedTuple._asdict`. + + .. seealso:: + + :ref:`ormtutorial_querying` + + """ + + def __new__(cls, vals, labels=None): + t = tuple.__new__(cls, vals) + t._labels = [] + if labels: + t.__dict__.update(zip(labels, vals)) + t._labels = labels + return t + + def keys(self): + """Return a list of string key names for this :class:`.KeyedTuple`. + + .. seealso:: + + :attr:`.KeyedTuple._fields` + + """ + + return [l for l in self._labels if l is not None] + + @property + def _fields(self): + """Return a tuple of string key names for this :class:`.KeyedTuple`. + + This method provides compatibility with ``collections.namedtuple()``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.KeyedTuple.keys` + + """ + return tuple(self.keys()) + + def _asdict(self): + """Return the contents of this :class:`.KeyedTuple` as a dictionary. + + This method provides compatibility with ``collections.namedtuple()``, + with the exception that the dictionary returned is **not** ordered. + + .. versionadded:: 0.8 + + """ + return dict((key, self.__dict__[key]) for key in self.keys()) + + +class ImmutableContainer(object): + def _immutable(self, *arg, **kw): + raise TypeError("%s object is immutable" % self.__class__.__name__) + + __delitem__ = __setitem__ = __setattr__ = _immutable + + +class immutabledict(ImmutableContainer, dict): + + clear = pop = popitem = setdefault = \ + update = ImmutableContainer._immutable + + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new + + def __init__(self, *args): + pass + + def __reduce__(self): + return immutabledict, (dict(self), ) + + def union(self, d): + if not self: + return immutabledict(d) + else: + d2 = immutabledict(self) + dict.update(d2, d) + return d2 + + def __repr__(self): + return "immutabledict(%s)" % dict.__repr__(self) + + +class Properties(object): + """Provide a __getattr__/__setattr__ interface over a dict.""" + + def __init__(self, data): + self.__dict__['_data'] = data + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(list(self._data.values())) + + def __add__(self, other): + return list(self) + list(other) + + def __setitem__(self, key, object): + self._data[key] = object + + def __getitem__(self, key): + return self._data[key] + + def __delitem__(self, key): + del self._data[key] + + def __setattr__(self, key, object): + self._data[key] = object + + def __getstate__(self): + return {'_data': self.__dict__['_data']} + + def __setstate__(self, state): + self.__dict__['_data'] = state['_data'] + + def __getattr__(self, key): + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + return key in self._data + + def as_immutable(self): + """Return an immutable proxy for this :class:`.Properties`.""" + + return ImmutableProperties(self._data) + + def update(self, value): + self._data.update(value) + + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def keys(self): + return list(self._data) + + def values(self): + return list(self._data.values()) + + def items(self): + return list(self._data.items()) + + def has_key(self, key): + return key in self._data + + def clear(self): + self._data.clear() + + +class OrderedProperties(Properties): + """Provide a __getattr__/__setattr__ interface with an OrderedDict + as backing store.""" + def __init__(self): + Properties.__init__(self, OrderedDict()) + + +class ImmutableProperties(ImmutableContainer, Properties): + """Provide immutable dict/object attribute to an underlying dictionary.""" + + +class OrderedDict(dict): + """A dict that returns keys/values/items in the order they were added.""" + + def __init__(self, ____sequence=None, **kwargs): + self._list = [] + if ____sequence is None: + if kwargs: + self.update(**kwargs) + else: + self.update(____sequence, **kwargs) + + def clear(self): + self._list = [] + dict.clear(self) + + def copy(self): + return self.__copy__() + + def __copy__(self): + return OrderedDict(self) + + def sort(self, *arg, **kw): + self._list.sort(*arg, **kw) + + def update(self, ____sequence=None, **kwargs): + if ____sequence is not None: + if hasattr(____sequence, 'keys'): + for key in ____sequence.keys(): + self.__setitem__(key, ____sequence[key]) + else: + for key, value in ____sequence: + self[key] = value + if kwargs: + self.update(kwargs) + + def setdefault(self, key, value): + if key not in self: + self.__setitem__(key, value) + return value + else: + return self.__getitem__(key) + + def __iter__(self): + return iter(self._list) + + + if py2k: + def values(self): + return [self[key] for key in self._list] + + def keys(self): + return self._list + + def itervalues(self): + return iter([self[key] for key in self._list]) + + def iterkeys(self): + return iter(self) + + def iteritems(self): + return iter(self.items()) + + def items(self): + return [(key, self[key]) for key in self._list] + else: + def values(self): + #return (self[key] for key in self) + return (self[key] for key in self._list) + + def keys(self): + #return iter(self) + return iter(self._list) + + def items(self): + #return ((key, self[key]) for key in self) + return ((key, self[key]) for key in self._list) + + _debug_iter = False + if _debug_iter: + # normally disabled to reduce function call + # overhead + def __iter__(self): + len_ = len(self._list) + for item in self._list: + yield item + assert len_ == len(self._list), \ + "Dictionary changed size during iteration" + def values(self): + return (self[key] for key in self) + def keys(self): + return iter(self) + def items(self): + return ((key, self[key]) for key in self) + + + def __setitem__(self, key, object): + if key not in self: + try: + self._list.append(key) + except AttributeError: + # work around Python pickle loads() with + # dict subclass (seems to ignore __setstate__?) + self._list = [key] + dict.__setitem__(self, key, object) + + def __delitem__(self, key): + dict.__delitem__(self, key) + self._list.remove(key) + + def pop(self, key, *default): + present = key in self + value = dict.pop(self, key, *default) + if present: + self._list.remove(key) + return value + + def popitem(self): + item = dict.popitem(self) + self._list.remove(item[0]) + return item + + +class OrderedSet(set): + def __init__(self, d=None): + set.__init__(self) + self._list = [] + if d is not None: + self.update(d) + + def add(self, element): + if element not in self: + self._list.append(element) + set.add(self, element) + + def remove(self, element): + set.remove(self, element) + self._list.remove(element) + + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + set.add(self, element) + + def discard(self, element): + if element in self: + self._list.remove(element) + set.remove(self, element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + __ior__ = update + + def union(self, other): + result = self.__class__(self) + result.update(other) + return result + + __or__ = union + + def intersection(self, other): + other = set(other) + return self.__class__(a for a in self if a in other) + + __and__ = intersection + + def symmetric_difference(self, other): + other = set(other) + result = self.__class__(a for a in self if a not in other) + result.update(a for a in other if a not in self) + return result + + __xor__ = symmetric_difference + + def difference(self, other): + other = set(other) + return self.__class__(a for a in self if a not in other) + + __sub__ = difference + + def intersection_update(self, other): + other = set(other) + set.intersection_update(self, other) + self._list = [a for a in self._list if a in other] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [a for a in self._list if a in self] + self._list += [a for a in other._list if a in self] + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [a for a in self._list if a in self] + return self + + __isub__ = difference_update + + +class IdentitySet(object): + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + _working_set = set + + def __init__(self, iterable=None): + self._members = dict() + if iterable: + for o in iterable: + self.add(o) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError('pop from an empty set') + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError('cannot compare sets using cmp()') + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + other = type(self)(iterable) + + if len(self) > len(other): + return False + for m in itertools_filterfalse(other._members.__contains__, + iter(self._members.keys())): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + other = type(self)(iterable) + + if len(self) < len(other): + return False + + for m in itertools_filterfalse(self._members.__contains__, + iter(other._members.keys())): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + members = self._member_id_tuples() + other = _iter_id(iterable) + result._members.update(self._working_set(members).union(other)) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members = self.union(iterable)._members + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + members = self._member_id_tuples() + other = _iter_id(iterable) + result._members.update(self._working_set(members).difference(other)) + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + members = self._member_id_tuples() + other = _iter_id(iterable) + result._members.update(self._working_set(members).intersection(other)) + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + members = self._member_id_tuples() + other = _iter_id(iterable) + result._members.update( + self._working_set(members).symmetric_difference(other)) + return result + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.values()) + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + return type(self)(iter(self._members.values())) + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError('set objects are unhashable') + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, list(self._members.values())) + + +class WeakSequence(object): + def __init__(self, __elements=()): + self._storage = [ + weakref.ref(element, self._remove) for element in __elements + ] + + def append(self, item): + self._storage.append(weakref.ref(item, self._remove)) + + def _remove(self, ref): + self._storage.remove(ref) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + return (obj for obj in + (ref() for ref in self._storage) if obj is not None) + + def __getitem__(self, index): + try: + obj = self._storage[index] + except KeyError: + raise IndexError("Index %s out of range" % index) + else: + return obj() + + +class OrderedIdentitySet(IdentitySet): + class _working_set(OrderedSet): + # a testing pragma: exempt the OIDS working set from the test suite's + # "never call the user's __hash__" assertions. this is a big hammer, + # but it's safe here: IDS operates on (id, instance) tuples in the + # working set. + __sa_hash_exempt__ = True + + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + + +class PopulateDict(dict): + """A dict which populates missing values via a creation function. + + Note the creation function takes a key, unlike + collections.defaultdict. + + """ + + def __init__(self, creator): + self.creator = creator + + def __missing__(self, key): + self[key] = val = self.creator(key) + return val + +# Define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +# At this point, these are mostly historical, things +# used to be more complicated. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + +def unique_list(seq, hashfunc=None): + seen = {} + if not hashfunc: + return [x for x in seq + if x not in seen + and not seen.__setitem__(x, True)] + else: + return [x for x in seq + if hashfunc(x) not in seen + and not seen.__setitem__(hashfunc(x), True)] + + +class UniqueAppender(object): + """Appends items to a collection ensuring uniqueness. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + + def __init__(self, data, via=None): + self.data = data + self._unique = {} + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, 'append'): + self._data_appender = data.append + elif hasattr(data, 'add'): + self._data_appender = data.add + + def append(self, item): + id_ = id(item) + if id_ not in self._unique: + self._data_appender(item) + self._unique[id_] = True + + def __iter__(self): + return iter(self.data) + +def coerce_generator_arg(arg): + if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): + return list(arg[0]) + else: + return arg + +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, (list, tuple)): + return [x] + else: + return x + + +def to_set(x): + if x is None: + return set() + if not isinstance(x, set): + return set(to_list(x)) + else: + return x + + +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + + +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + + +def flatten_iterator(x): + """Given an iterator of which further sub-elements may also be + iterators, flatten the sub-elements into a single iterator. + + """ + for elem in x: + if not isinstance(elem, str) and hasattr(elem, '__iter__'): + for y in flatten_iterator(elem): + yield y + else: + yield elem + + +class LRUCache(dict): + """Dictionary with 'squishy' removal of least + recently used items. + + """ + def __init__(self, capacity=100, threshold=.5): + self.capacity = capacity + self.threshold = threshold + self._counter = 0 + + def _inc_counter(self): + self._counter += 1 + return self._counter + + def __getitem__(self, key): + item = dict.__getitem__(self, key) + item[2] = self._inc_counter() + return item[1] + + def values(self): + return [i[1] for i in dict.values(self)] + + def setdefault(self, key, value): + if key in self: + return self[key] + else: + self[key] = value + return value + + def __setitem__(self, key, value): + item = dict.get(self, key) + if item is None: + item = [key, value, self._inc_counter()] + dict.__setitem__(self, key, item) + else: + item[1] = value + self._manage_size() + + def _manage_size(self): + while len(self) > self.capacity + self.capacity * self.threshold: + by_counter = sorted(dict.values(self), + key=operator.itemgetter(2), + reverse=True) + for item in by_counter[self.capacity:]: + try: + del self[item[0]] + except KeyError: + # if we couldnt find a key, most + # likely some other thread broke in + # on us. loop around and try again + break + + +class ScopedRegistry(object): + """A Registry that can store one or multiple instances of a single + class on the basis of a "scope" function. + + The object implements ``__call__`` as the "getter", so by + calling ``myregistry()`` the contained object is returned + for the current scope. + + :param createfunc: + a callable that returns a new object to be placed in the registry + + :param scopefunc: + a callable that will return a key to store/retrieve an object. + """ + + def __init__(self, createfunc, scopefunc): + """Construct a new :class:`.ScopedRegistry`. + + :param createfunc: A creation function that will generate + a new value for the current scope, if none is present. + + :param scopefunc: A function that returns a hashable + token representing the current scope (such as, current + thread identifier). + + """ + self.createfunc = createfunc + self.scopefunc = scopefunc + self.registry = {} + + def __call__(self): + key = self.scopefunc() + try: + return self.registry[key] + except KeyError: + return self.registry.setdefault(key, self.createfunc()) + + def has(self): + """Return True if an object is present in the current scope.""" + + return self.scopefunc() in self.registry + + def set(self, obj): + """Set the value forthe current scope.""" + + self.registry[self.scopefunc()] = obj + + def clear(self): + """Clear the current scope, if any.""" + + try: + del self.registry[self.scopefunc()] + except KeyError: + pass + + +class ThreadLocalRegistry(ScopedRegistry): + """A :class:`.ScopedRegistry` that uses a ``threading.local()`` + variable for storage. + + """ + def __init__(self, createfunc): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self): + try: + return self.registry.value + except AttributeError: + val = self.registry.value = self.createfunc() + return val + + def has(self): + return hasattr(self.registry, "value") + + def set(self, obj): + self.registry.value = obj + + def clear(self): + try: + del self.registry.value + except AttributeError: + pass + + +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py new file mode 100644 index 00000000..f1346406 --- /dev/null +++ b/lib/sqlalchemy/util/compat.py @@ -0,0 +1,215 @@ +# util/compat.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Handle Python version/platform incompatibilities.""" + +import sys + +try: + import threading +except ImportError: + import dummy_threading as threading + +py33 = sys.version_info >= (3, 3) +py32 = sys.version_info >= (3, 2) +py3k = sys.version_info >= (3, 0) +py2k = sys.version_info < (3, 0) +jython = sys.platform.startswith('java') +pypy = hasattr(sys, 'pypy_version_info') +win32 = sys.platform.startswith('win') +cpython = not pypy and not jython # TODO: something better for this ? + +import collections +next = next + +if py3k: + import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle + +ArgSpec = collections.namedtuple("ArgSpec", + ["args", "varargs", "keywords", "defaults"]) + +if py3k: + import builtins + + from inspect import getfullargspec as inspect_getfullargspec + from urllib.parse import quote_plus, unquote_plus, parse_qsl, quote, unquote + import configparser + from io import StringIO + + from io import BytesIO as byte_buffer + + def inspect_getargspec(func): + return ArgSpec( + *inspect_getfullargspec(func)[0:4] + ) + + string_types = str, + binary_type = bytes + text_type = str + int_types = int, + iterbytes = iter + + def u(s): + return s + + def ue(s): + return s + + def b(s): + return s.encode("latin-1") + + if py32: + callable = callable + else: + def callable(fn): + return hasattr(fn, '__call__') + + def cmp(a, b): + return (a > b) - (a < b) + + from functools import reduce + + print_ = getattr(builtins, "print") + + import_ = getattr(builtins, '__import__') + + import itertools + itertools_filterfalse = itertools.filterfalse + itertools_filter = filter + itertools_imap = map + + import base64 + def b64encode(x): + return base64.b64encode(x).decode('ascii') + def b64decode(x): + return base64.b64decode(x.encode('ascii')) + +else: + from inspect import getargspec as inspect_getfullargspec + inspect_getargspec = inspect_getfullargspec + from urllib import quote_plus, unquote_plus, quote, unquote + from urlparse import parse_qsl + import ConfigParser as configparser + from StringIO import StringIO + from cStringIO import StringIO as byte_buffer + + string_types = basestring, + binary_type = str + text_type = unicode + int_types = int, long + def iterbytes(buf): + return (ord(byte) for byte in buf) + + def u(s): + # this differs from what six does, which doesn't support non-ASCII + # strings - we only use u() with + # literal source strings, and all our source files with non-ascii + # in them (all are tests) are utf-8 encoded. + return unicode(s, "utf-8") + + def ue(s): + return unicode(s, "unicode_escape") + + def b(s): + return s + + def import_(*args): + if len(args) == 4: + args = args[0:3] + ([str(arg) for arg in args[3]],) + return __import__(*args) + + callable = callable + cmp = cmp + reduce = reduce + + import base64 + b64encode = base64.b64encode + b64decode = base64.b64decode + + def print_(*args, **kwargs): + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + for arg in enumerate(args): + if not isinstance(arg, basestring): + arg = str(arg) + fp.write(arg) + + import itertools + itertools_filterfalse = itertools.ifilterfalse + itertools_filter = itertools.ifilter + itertools_imap = itertools.imap + + +import time +if win32 or jython: + time_func = time.clock +else: + time_func = time.time + +from collections import namedtuple +from operator import attrgetter as dottedgetter + + +if py3k: + def reraise(tp, value, tb=None, cause=None): + if cause is not None: + value.__cause__ = cause + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + def raise_from_cause(exception, exc_info=None): + if exc_info is None: + exc_info = sys.exc_info() + exc_type, exc_value, exc_tb = exc_info + reraise(type(exception), exception, tb=exc_tb, cause=exc_value) +else: + exec("def reraise(tp, value, tb=None, cause=None):\n" + " raise tp, value, tb\n") + + def raise_from_cause(exception, exc_info=None): + # not as nice as that of Py3K, but at least preserves + # the code line where the issue occurred + if exc_info is None: + exc_info = sys.exc_info() + exc_type, exc_value, exc_tb = exc_info + reraise(type(exception), exception, tb=exc_tb) + +if py3k: + exec_ = getattr(builtins, 'exec') +else: + def exec_(func_text, globals_, lcl=None): + if lcl is None: + exec('exec func_text in globals_') + else: + exec('exec func_text in globals_, lcl') + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass. + + Drops the middle class upon creation. + + Source: http://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/ + + """ + + class metaclass(meta): + __call__ = type.__call__ + __init__ = type.__init__ + def __new__(cls, name, this_bases, d): + if this_bases is None: + return type.__new__(cls, name, (), d) + return meta(name, bases, d) + return metaclass('temporary_class', None, {}) + + diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py new file mode 100644 index 00000000..c8854dc3 --- /dev/null +++ b/lib/sqlalchemy/util/deprecations.py @@ -0,0 +1,143 @@ +# util/deprecations.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Helpers related to deprecation of functions, methods, classes, other +functionality.""" + +from .. import exc +import warnings +import re +from .langhelpers import decorator + + +def warn_deprecated(msg, stacklevel=3): + warnings.warn(msg, exc.SADeprecationWarning, stacklevel=stacklevel) + + +def warn_pending_deprecation(msg, stacklevel=3): + warnings.warn(msg, exc.SAPendingDeprecationWarning, stacklevel=stacklevel) + + +def deprecated(version, message=None, add_deprecation_to_docstring=True): + """Decorates a function and issues a deprecation warning on use. + + :param message: + If provided, issue message in the warning. A sensible default + is used if not provided. + + :param add_deprecation_to_docstring: + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + + """ + + if add_deprecation_to_docstring: + header = ".. deprecated:: %s %s" % \ + (version, (message or '')) + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exc.SADeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + + +def pending_deprecation(version, message=None, + add_deprecation_to_docstring=True): + """Decorates a function and issues a pending deprecation warning on use. + + :param version: + An approximate future version at which point the pending deprecation + will become deprecated. Not used in messaging. + + :param message: + If provided, issue message in the warning. A sensible default + is used if not provided. + + :param add_deprecation_to_docstring: + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + """ + + if add_deprecation_to_docstring: + header = ".. deprecated:: %s (pending) %s" % \ + (version, (message or '')) + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exc.SAPendingDeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + + +def _sanitize_restructured_text(text): + def repl(m): + type_, name = m.group(1, 2) + if type_ in ("func", "meth"): + name += "()" + return name + return re.sub(r'\:(\w+)\:`~?\.?(.+?)`', repl, text) + + +def _decorate_with_warning(func, wtype, message, docstring_header=None): + """Wrap a function with a warnings.warn and augmented docstring.""" + + message = _sanitize_restructured_text(message) + + @decorator + def warned(fn, *args, **kwargs): + warnings.warn(wtype(message), stacklevel=3) + return fn(*args, **kwargs) + + doc = func.__doc__ is not None and func.__doc__ or '' + if docstring_header is not None: + docstring_header %= dict(func=func.__name__) + + doc = inject_docstring_text(doc, docstring_header, 1) + + decorated = warned(func) + decorated.__doc__ = doc + return decorated + +import textwrap + +def _dedent_docstring(text): + split_text = text.split("\n", 1) + if len(split_text) == 1: + return text + else: + firstline, remaining = split_text + if not firstline.startswith(" "): + return firstline + "\n" + textwrap.dedent(remaining) + else: + return textwrap.dedent(text) + +def inject_docstring_text(doctext, injecttext, pos): + doctext = _dedent_docstring(doctext or "") + lines = doctext.split('\n') + injectlines = textwrap.dedent(injecttext).split("\n") + if injectlines[0]: + injectlines.insert(0, "") + + blanks = [num for num, line in enumerate(lines) if not line.strip()] + blanks.insert(0, 0) + + inject_pos = blanks[min(pos, len(blanks) - 1)] + + lines = lines[0:inject_pos] + injectlines + lines[inject_pos:] + return "\n".join(lines) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py new file mode 100644 index 00000000..8a1164e7 --- /dev/null +++ b/lib/sqlalchemy/util/langhelpers.py @@ -0,0 +1,1231 @@ +# util/langhelpers.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Routines to help with the creation, loading and introspection of +modules, classes, hierarchies, attributes, functions, and methods. + +""" +import itertools +import inspect +import operator +import re +import sys +import types +import warnings +from functools import update_wrapper +from .. import exc +import hashlib +from . import compat +from . import _collections + +def md5_hex(x): + if compat.py3k: + x = x.encode('utf-8') + m = hashlib.md5() + m.update(x) + return m.hexdigest() + +class safe_reraise(object): + """Reraise an exception after invoking some + handler code. + + Stores the existing exception info before + invoking so that it is maintained across a potential + coroutine context switch. + + e.g.:: + + try: + sess.commit() + except: + with safe_reraise(): + sess.rollback() + + """ + + def __enter__(self): + self._exc_info = sys.exc_info() + + def __exit__(self, type_, value, traceback): + # see #2703 for notes + if type_ is None: + exc_type, exc_value, exc_tb = self._exc_info + self._exc_info = None # remove potential circular references + compat.reraise(exc_type, exc_value, exc_tb) + else: + self._exc_info = None # remove potential circular references + compat.reraise(type_, value, traceback) + +def decode_slice(slc): + """decode a slice object as sent to __getitem__. + + takes into account the 2.5 __index__() method, basically. + + """ + ret = [] + for x in slc.start, slc.stop, slc.step: + if hasattr(x, '__index__'): + x = x.__index__() + ret.append(x) + return tuple(ret) + +def _unique_symbols(used, *bases): + used = set(used) + for base in bases: + pool = itertools.chain((base,), + compat.itertools_imap(lambda i: base + str(i), + range(1000))) + for sym in pool: + if sym not in used: + used.add(sym) + yield sym + break + else: + raise NameError("exhausted namespace for symbol base %s" % base) + + +def decorator(target): + """A signature-matching decorator factory.""" + + def decorate(fn): + if not inspect.isfunction(fn): + raise Exception("not a decoratable function") + spec = compat.inspect_getfullargspec(fn) + names = tuple(spec[0]) + spec[1:3] + (fn.__name__,) + targ_name, fn_name = _unique_symbols(names, 'target', 'fn') + + metadata = dict(target=targ_name, fn=fn_name) + metadata.update(format_argspec_plus(spec, grouped=False)) + metadata['name'] = fn.__name__ + code = """\ +def %(name)s(%(args)s): + return %(target)s(%(fn)s, %(apply_kw)s) +""" % metadata + decorated = _exec_code_in_env(code, + {targ_name: target, fn_name: fn}, + fn.__name__) + decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + return update_wrapper(decorate, target) + +def _exec_code_in_env(code, env, fn_name): + exec(code, env) + return env[fn_name] + +def public_factory(target, location): + """Produce a wrapping function for the given cls or classmethod. + + Rationale here is so that the __init__ method of the + class can serve as documentation for the function. + + """ + if isinstance(target, type): + fn = target.__init__ + callable_ = target + doc = "Construct a new :class:`.%s` object. \n\n"\ + "This constructor is mirrored as a public API function; see :func:`~%s` "\ + "for a full usage and argument description." % ( + target.__name__, location, ) + else: + fn = callable_ = target + doc = "This function is mirrored; see :func:`~%s` "\ + "for a description of arguments." % location + + location_name = location.split(".")[-1] + spec = compat.inspect_getfullargspec(fn) + del spec[0][0] + metadata = format_argspec_plus(spec, grouped=False) + metadata['name'] = location_name + code = """\ +def %(name)s(%(args)s): + return cls(%(apply_kw)s) +""" % metadata + env = {'cls': callable_, 'symbol': symbol} + exec(code, env) + decorated = env[location_name] + decorated.__doc__ = fn.__doc__ + if compat.py2k or hasattr(fn, '__func__'): + fn.__func__.__doc__ = doc + else: + fn.__doc__ = doc + return decorated + + +class PluginLoader(object): + + def __init__(self, group, auto_fn=None): + self.group = group + self.impls = {} + self.auto_fn = auto_fn + + def load(self, name): + if name in self.impls: + return self.impls[name]() + + if self.auto_fn: + loader = self.auto_fn(name) + if loader: + self.impls[name] = loader + return loader() + + try: + import pkg_resources + except ImportError: + pass + else: + for impl in pkg_resources.iter_entry_points( + self.group, name): + self.impls[name] = impl.load + return impl.load() + + raise exc.NoSuchModuleError( + "Can't load plugin: %s:%s" % + (self.group, name)) + + def register(self, name, modulepath, objname): + def load(): + mod = compat.import_(modulepath) + for token in modulepath.split(".")[1:]: + mod = getattr(mod, token) + return getattr(mod, objname) + self.impls[name] = load + + +def get_cls_kwargs(cls, _set=None): + """Return the full set of inherited kwargs for the given `cls`. + + Probes a class's __init__ method, collecting all named arguments. If the + __init__ defines a \**kwargs catch-all, then the constructor is presumed to + pass along unrecognized keywords to it's base classes, and the collection + process is repeated recursively on each of the bases. + + Uses a subset of inspect.getargspec() to cut down on method overhead. + No anonymous tuple arguments please ! + + """ + toplevel = _set == None + if toplevel: + _set = set() + + ctr = cls.__dict__.get('__init__', False) + + has_init = ctr and isinstance(ctr, types.FunctionType) and \ + isinstance(ctr.__code__, types.CodeType) + + if has_init: + names, has_kw = inspect_func_args(ctr) + _set.update(names) + + if not has_kw and not toplevel: + return None + + if not has_init or has_kw: + for c in cls.__bases__: + if get_cls_kwargs(c, _set) is None: + break + + _set.discard('self') + return _set + + + +try: + # TODO: who doesn't have this constant? + from inspect import CO_VARKEYWORDS + + def inspect_func_args(fn): + co = fn.__code__ + nargs = co.co_argcount + names = co.co_varnames + args = list(names[:nargs]) + has_kw = bool(co.co_flags & CO_VARKEYWORDS) + return args, has_kw + +except ImportError: + def inspect_func_args(fn): + names, _, has_kw, _ = inspect.getargspec(fn) + return names, bool(has_kw) + + +def get_func_kwargs(func): + """Return the set of legal kwargs for the given `func`. + + Uses getargspec so is safe to call for methods, functions, + etc. + + """ + + return compat.inspect_getargspec(func)[0] + +def get_callable_argspec(fn, no_self=False, _is_init=False): + """Return the argument signature for any callable. + + All pure-Python callables are accepted, including + functions, methods, classes, objects with __call__; + builtins and other edge cases like functools.partial() objects + raise a TypeError. + + """ + if inspect.isbuiltin(fn): + raise TypeError("Can't inspect builtin: %s" % fn) + elif inspect.isfunction(fn): + if _is_init and no_self: + spec = compat.inspect_getargspec(fn) + return compat.ArgSpec(spec.args[1:], spec.varargs, + spec.keywords, spec.defaults) + else: + return compat.inspect_getargspec(fn) + elif inspect.ismethod(fn): + if no_self and (_is_init or fn.__self__): + spec = compat.inspect_getargspec(fn.__func__) + return compat.ArgSpec(spec.args[1:], spec.varargs, + spec.keywords, spec.defaults) + else: + return compat.inspect_getargspec(fn.__func__) + elif inspect.isclass(fn): + return get_callable_argspec(fn.__init__, no_self=no_self, _is_init=True) + elif hasattr(fn, '__func__'): + return compat.inspect_getargspec(fn.__func__) + elif hasattr(fn, '__call__'): + if inspect.ismethod(fn.__call__): + return get_callable_argspec(fn.__call__, no_self=no_self) + else: + raise TypeError("Can't inspect callable: %s" % fn) + else: + raise TypeError("Can't inspect callable: %s" % fn) + +def format_argspec_plus(fn, grouped=True): + """Returns a dictionary of formatted, introspected function arguments. + + A enhanced variant of inspect.formatargspec to support code generation. + + fn + An inspectable callable or tuple of inspect getargspec() results. + grouped + Defaults to True; include (parens, around, argument) lists + + Returns: + + args + Full inspect.formatargspec for fn + self_arg + The name of the first positional argument, varargs[0], or None + if the function defines no positional arguments. + apply_pos + args, re-written in calling rather than receiving syntax. Arguments are + passed positionally. + apply_kw + Like apply_pos, except keyword-ish args are passed as keywords. + + Example:: + + >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123) + {'args': '(self, a, b, c=3, **d)', + 'self_arg': 'self', + 'apply_kw': '(self, a, b, c=c, **d)', + 'apply_pos': '(self, a, b, c, **d)'} + + """ + if compat.callable(fn): + spec = compat.inspect_getfullargspec(fn) + else: + # we accept an existing argspec... + spec = fn + args = inspect.formatargspec(*spec) + if spec[0]: + self_arg = spec[0][0] + elif spec[1]: + self_arg = '%s[0]' % spec[1] + else: + self_arg = None + + if compat.py3k: + apply_pos = inspect.formatargspec(spec[0], spec[1], + spec[2], None, spec[4]) + num_defaults = 0 + if spec[3]: + num_defaults += len(spec[3]) + if spec[4]: + num_defaults += len(spec[4]) + name_args = spec[0] + spec[4] + else: + apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2]) + num_defaults = 0 + if spec[3]: + num_defaults += len(spec[3]) + name_args = spec[0] + + if num_defaults: + defaulted_vals = name_args[0 - num_defaults:] + else: + defaulted_vals = () + + apply_kw = inspect.formatargspec(name_args, spec[1], spec[2], + defaulted_vals, + formatvalue=lambda x: '=' + x) + if grouped: + return dict(args=args, self_arg=self_arg, + apply_pos=apply_pos, apply_kw=apply_kw) + else: + return dict(args=args[1:-1], self_arg=self_arg, + apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1]) + + +def format_argspec_init(method, grouped=True): + """format_argspec_plus with considerations for typical __init__ methods + + Wraps format_argspec_plus with error handling strategies for typical + __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + if method is object.__init__: + args = grouped and '(self)' or 'self' + else: + try: + return format_argspec_plus(method, grouped=grouped) + except TypeError: + args = (grouped and '(self, *args, **kwargs)' + or 'self, *args, **kwargs') + return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args) + + +def getargspec_init(method): + """inspect.getargspec with considerations for typical __init__ methods + + Wraps inspect.getargspec with error handling for typical __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return inspect.getargspec(method) + except TypeError: + if method is object.__init__: + return (['self'], None, None, None) + else: + return (['self'], 'args', 'kwargs', None) + + +def unbound_method_to_callable(func_or_cls): + """Adjust the incoming callable such that a 'self' argument is not + required. + + """ + + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__: + return func_or_cls.__func__ + else: + return func_or_cls + + +def generic_repr(obj, additional_kw=(), to_inspect=None): + """Produce a __repr__() based on direct association of the __init__() + specification vs. same-named attributes present. + + """ + if to_inspect is None: + to_inspect = [obj] + else: + to_inspect = _collections.to_list(to_inspect) + + missing = object() + + pos_args = [] + kw_args = _collections.OrderedDict() + vargs = None + for i, insp in enumerate(to_inspect): + try: + (_args, _vargs, vkw, defaults) = \ + inspect.getargspec(insp.__init__) + except TypeError: + continue + else: + default_len = defaults and len(defaults) or 0 + if i == 0: + if _vargs: + vargs = _vargs + if default_len: + pos_args.extend(_args[1:-default_len]) + else: + pos_args.extend(_args[1:]) + else: + kw_args.update([ + (arg, missing) for arg in _args[1:-default_len] + ]) + + if default_len: + kw_args.update([ + (arg, default) + for arg, default + in zip(_args[-default_len:], defaults) + ]) + output = [] + + output.extend(repr(getattr(obj, arg, None)) for arg in pos_args) + + if vargs is not None and hasattr(obj, vargs): + output.extend([repr(val) for val in getattr(obj, vargs)]) + + for arg, defval in kw_args.items(): + try: + val = getattr(obj, arg, missing) + if val is not missing and val != defval: + output.append('%s=%r' % (arg, val)) + except: + pass + + if additional_kw: + for arg, defval in additional_kw: + try: + val = getattr(obj, arg, missing) + if val is not missing and val != defval: + output.append('%s=%r' % (arg, val)) + except: + pass + + return "%s(%s)" % (obj.__class__.__name__, ", ".join(output)) + + +class portable_instancemethod(object): + """Turn an instancemethod into a (parent, name) pair + to produce a serializable callable. + + """ + def __init__(self, meth): + self.target = meth.__self__ + self.name = meth.__name__ + + def __call__(self, *arg, **kw): + return getattr(self.target, self.name)(*arg, **kw) + + +def class_hierarchy(cls): + """Return an unordered sequence of all classes related to cls. + + Traverses diamond hierarchies. + + Fibs slightly: subclasses of builtin types are not returned. Thus + class_hierarchy(class A(object)) returns (A, object), not A plus every + class systemwide that derives from object. + + Old-style classes are discarded and hierarchies rooted on them + will not be descended. + + """ + if compat.py2k: + if isinstance(cls, types.ClassType): + return list() + + hier = set([cls]) + process = list(cls.__mro__) + while process: + c = process.pop() + if compat.py2k: + if isinstance(c, types.ClassType): + continue + bases = (_ for _ in c.__bases__ + if _ not in hier and not isinstance(_, types.ClassType)) + else: + bases = (_ for _ in c.__bases__ if _ not in hier) + + for b in bases: + process.append(b) + hier.add(b) + + if compat.py3k: + if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'): + continue + else: + if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'): + continue + + for s in [_ for _ in c.__subclasses__() if _ not in hier]: + process.append(s) + hier.add(s) + return list(hier) + + +def iterate_attributes(cls): + """iterate all the keys and attributes associated + with a class, without using getattr(). + + Does not use getattr() so that class-sensitive + descriptors (i.e. property.__get__()) are not called. + + """ + keys = dir(cls) + for key in keys: + for c in cls.__mro__: + if key in c.__dict__: + yield (key, c.__dict__[key]) + break + + +def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, + name='self.proxy', from_instance=None): + """Automates delegation of __specials__ for a proxying type.""" + + if only: + dunders = only + else: + if skip is None: + skip = ('__slots__', '__del__', '__getattribute__', + '__metaclass__', '__getstate__', '__setstate__') + dunders = [m for m in dir(from_cls) + if (m.startswith('__') and m.endswith('__') and + not hasattr(into_cls, m) and m not in skip)] + + for method in dunders: + try: + fn = getattr(from_cls, method) + if not hasattr(fn, '__call__'): + continue + fn = getattr(fn, 'im_func', fn) + except AttributeError: + continue + try: + spec = inspect.getargspec(fn) + fn_args = inspect.formatargspec(spec[0]) + d_args = inspect.formatargspec(spec[0][1:]) + except TypeError: + fn_args = '(self, *args, **kw)' + d_args = '(*args, **kw)' + + py = ("def %(method)s%(fn_args)s: " + "return %(name)s.%(method)s%(d_args)s" % locals()) + + env = from_instance is not None and {name: from_instance} or {} + compat.exec_(py, env) + try: + env[method].__defaults__ = fn.__defaults__ + except AttributeError: + pass + setattr(into_cls, method, env[method]) + + +def methods_equivalent(meth1, meth2): + """Return True if the two methods are the same implementation.""" + + return getattr(meth1, '__func__', meth1) is getattr(meth2, '__func__', meth2) + + +def as_interface(obj, cls=None, methods=None, required=None): + """Ensure basic interface compliance for an instance or dict of callables. + + Checks that ``obj`` implements public methods of ``cls`` or has members + listed in ``methods``. If ``required`` is not supplied, implementing at + least one interface method is sufficient. Methods present on ``obj`` that + are not in the interface are ignored. + + If ``obj`` is a dict and ``dict`` does not meet the interface + requirements, the keys of the dictionary are inspected. Keys present in + ``obj`` that are not in the interface will raise TypeErrors. + + Raises TypeError if ``obj`` does not meet the interface criteria. + + In all passing cases, an object with callable members is returned. In the + simple case, ``obj`` is returned as-is; if dict processing kicks in then + an anonymous class is returned. + + obj + A type, instance, or dictionary of callables. + cls + Optional, a type. All public methods of cls are considered the + interface. An ``obj`` instance of cls will always pass, ignoring + ``required``.. + methods + Optional, a sequence of method names to consider as the interface. + required + Optional, a sequence of mandatory implementations. If omitted, an + ``obj`` that provides at least one interface method is considered + sufficient. As a convenience, required may be a type, in which case + all public methods of the type are required. + + """ + if not cls and not methods: + raise TypeError('a class or collection of method names are required') + + if isinstance(cls, type) and isinstance(obj, cls): + return obj + + interface = set(methods or [m for m in dir(cls) if not m.startswith('_')]) + implemented = set(dir(obj)) + + complies = operator.ge + if isinstance(required, type): + required = interface + elif not required: + required = set() + complies = operator.gt + else: + required = set(required) + + if complies(implemented.intersection(interface), required): + return obj + + # No dict duck typing here. + if not type(obj) is dict: + qualifier = complies is operator.gt and 'any of' or 'all of' + raise TypeError("%r does not implement %s: %s" % ( + obj, qualifier, ', '.join(interface))) + + class AnonymousInterface(object): + """A callable-holding shell.""" + + if cls: + AnonymousInterface.__name__ = 'Anonymous' + cls.__name__ + found = set() + + for method, impl in dictlike_iteritems(obj): + if method not in interface: + raise TypeError("%r: unknown in this interface" % method) + if not compat.callable(impl): + raise TypeError("%r=%r is not callable" % (method, impl)) + setattr(AnonymousInterface, method, staticmethod(impl)) + found.add(method) + + if complies(found, required): + return AnonymousInterface + + raise TypeError("dictionary does not contain required keys %s" % + ', '.join(required - found)) + + +class memoized_property(object): + """A read-only @property that is only evaluated once.""" + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + obj.__dict__[self.__name__] = result = self.fget(obj) + return result + + def _reset(self, obj): + memoized_property.reset(obj, self.__name__) + + @classmethod + def reset(cls, obj, name): + obj.__dict__.pop(name, None) + + +class memoized_instancemethod(object): + """Decorate a method memoize its return value. + + Best applied to no-arg methods: memoization is not sensitive to + argument values, and will always return the same value even when + called with different arguments. + + """ + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + + def oneshot(*args, **kw): + result = self.fget(obj, *args, **kw) + memo = lambda *a, **kw: result + memo.__name__ = self.__name__ + memo.__doc__ = self.__doc__ + obj.__dict__[self.__name__] = memo + return result + + oneshot.__name__ = self.__name__ + oneshot.__doc__ = self.__doc__ + return oneshot + + +class group_expirable_memoized_property(object): + """A family of @memoized_properties that can be expired in tandem.""" + + def __init__(self, attributes=()): + self.attributes = [] + if attributes: + self.attributes.extend(attributes) + + def expire_instance(self, instance): + """Expire all memoized properties for *instance*.""" + stash = instance.__dict__ + for attribute in self.attributes: + stash.pop(attribute, None) + + def __call__(self, fn): + self.attributes.append(fn.__name__) + return memoized_property(fn) + + def method(self, fn): + self.attributes.append(fn.__name__) + return memoized_instancemethod(fn) + + + +def dependency_for(modulename): + def decorate(obj): + # TODO: would be nice to improve on this import silliness, + # unfortunately importlib doesn't work that great either + tokens = modulename.split(".") + mod = compat.import_(".".join(tokens[0:-1]), globals(), locals(), tokens[-1]) + mod = getattr(mod, tokens[-1]) + setattr(mod, obj.__name__, obj) + return obj + return decorate + +class dependencies(object): + """Apply imported dependencies as arguments to a function. + + E.g.:: + + @util.dependencies( + "sqlalchemy.sql.widget", + "sqlalchemy.engine.default" + ); + def some_func(self, widget, default, arg1, arg2, **kw): + # ... + + Rationale is so that the impact of a dependency cycle can be + associated directly with the few functions that cause the cycle, + and not pollute the module-level namespace. + + """ + + def __init__(self, *deps): + self.import_deps = [] + for dep in deps: + tokens = dep.split(".") + self.import_deps.append( + dependencies._importlater( + ".".join(tokens[0:-1]), + tokens[-1] + ) + ) + + def __call__(self, fn): + import_deps = self.import_deps + spec = compat.inspect_getfullargspec(fn) + + spec_zero = list(spec[0]) + hasself = spec_zero[0] in ('self', 'cls') + + for i in range(len(import_deps)): + spec[0][i + (1 if hasself else 0)] = "import_deps[%r]" % i + + inner_spec = format_argspec_plus(spec, grouped=False) + + for impname in import_deps: + del spec_zero[1 if hasself else 0] + spec[0][:] = spec_zero + + outer_spec = format_argspec_plus(spec, grouped=False) + + code = 'lambda %(args)s: fn(%(apply_kw)s)' % { + "args": outer_spec['args'], + "apply_kw": inner_spec['apply_kw'] + } + + decorated = eval(code, locals()) + decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__ + return update_wrapper(decorated, fn) + + @classmethod + def resolve_all(cls, path): + for m in list(dependencies._unresolved): + if m._full_path.startswith(path): + m._resolve() + + _unresolved = set() + _by_key = {} + + class _importlater(object): + _unresolved = set() + + _by_key = {} + + def __new__(cls, path, addtl): + key = path + "." + addtl + if key in dependencies._by_key: + return dependencies._by_key[key] + else: + dependencies._by_key[key] = imp = object.__new__(cls) + return imp + + def __init__(self, path, addtl): + self._il_path = path + self._il_addtl = addtl + dependencies._unresolved.add(self) + + + @property + def _full_path(self): + return self._il_path + "." + self._il_addtl + + @memoized_property + def module(self): + if self in dependencies._unresolved: + raise ImportError( + "importlater.resolve_all() hasn't " + "been called (this is %s %s)" + % (self._il_path, self._il_addtl)) + + return getattr(self._initial_import, self._il_addtl) + + def _resolve(self): + dependencies._unresolved.discard(self) + self._initial_import = compat.import_( + self._il_path, globals(), locals(), + [self._il_addtl]) + + def __getattr__(self, key): + if key == 'module': + raise ImportError("Could not resolve module %s" + % self._full_path) + try: + attr = getattr(self.module, key) + except AttributeError: + raise AttributeError( + "Module %s has no attribute '%s'" % + (self._full_path, key) + ) + self.__dict__[key] = attr + return attr + + +# from paste.deploy.converters +def asbool(obj): + if isinstance(obj, compat.string_types): + obj = obj.strip().lower() + if obj in ['true', 'yes', 'on', 'y', 't', '1']: + return True + elif obj in ['false', 'no', 'off', 'n', 'f', '0']: + return False + else: + raise ValueError("String is not true/false: %r" % obj) + return bool(obj) + + +def bool_or_str(*text): + """Return a callable that will evaulate a string as + boolean, or one of a set of "alternate" string values. + + """ + def bool_or_value(obj): + if obj in text: + return obj + else: + return asbool(obj) + return bool_or_value + + +def asint(value): + """Coerce to integer.""" + + if value is None: + return value + return int(value) + + +def coerce_kw_type(kw, key, type_, flexi_bool=True): + """If 'key' is present in dict 'kw', coerce its value to type 'type\_' if + necessary. If 'flexi_bool' is True, the string '0' is considered false + when coercing to boolean. + """ + + if key in kw and type(kw[key]) is not type_ and kw[key] is not None: + if type_ is bool and flexi_bool: + kw[key] = asbool(kw[key]) + else: + kw[key] = type_(kw[key]) + + +def constructor_copy(obj, cls, **kw): + """Instantiate cls using the __dict__ of obj as constructor arguments. + + Uses inspect to match the named arguments of ``cls``. + + """ + + names = get_cls_kwargs(cls) + kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) + return cls(**kw) + + +def counter(): + """Return a threadsafe counter function.""" + + lock = compat.threading.Lock() + counter = itertools.count(1) + + # avoid the 2to3 "next" transformation... + def _next(): + lock.acquire() + try: + return next(counter) + finally: + lock.release() + + return _next + + +def duck_type_collection(specimen, default=None): + """Given an instance or class, guess if it is or is acting as one of + the basic collection types: list, set and dict. If the __emulates__ + property is present, return that preferentially. + """ + + if hasattr(specimen, '__emulates__'): + # canonicalize set vs sets.Set to a standard: the builtin set + if (specimen.__emulates__ is not None and + issubclass(specimen.__emulates__, set)): + return set + else: + return specimen.__emulates__ + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set): + return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, 'append'): + return list + elif hasattr(specimen, 'add'): + return set + elif hasattr(specimen, 'set'): + return dict + else: + return default + + +def assert_arg_type(arg, argtype, name): + if isinstance(arg, argtype): + return arg + else: + if isinstance(argtype, tuple): + raise exc.ArgumentError( + "Argument '%s' is expected to be one of type %s, got '%s'" % + (name, ' or '.join("'%s'" % a for a in argtype), type(arg))) + else: + raise exc.ArgumentError( + "Argument '%s' is expected to be of type '%s', got '%s'" % + (name, argtype, type(arg))) + + +def dictlike_iteritems(dictlike): + """Return a (key, value) iterator for almost any dict-like object.""" + + if compat.py3k: + if hasattr(dictlike, 'items'): + return list(dictlike.items()) + else: + if hasattr(dictlike, 'iteritems'): + return dictlike.iteritems() + elif hasattr(dictlike, 'items'): + return iter(dictlike.items()) + + getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None)) + if getter is None: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, 'iterkeys'): + def iterator(): + for key in dictlike.iterkeys(): + yield key, getter(key) + return iterator() + elif hasattr(dictlike, 'keys'): + return iter((key, getter(key)) for key in dictlike.keys()) + else: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + + +class classproperty(property): + """A decorator that behaves like @property except that operates + on classes rather than instances. + + The decorator is currently special when using the declarative + module, but note that the + :class:`~.sqlalchemy.ext.declarative.declared_attr` + decorator should be used for this purpose with declarative. + + """ + + def __init__(self, fget, *arg, **kw): + super(classproperty, self).__init__(fget, *arg, **kw) + self.__doc__ = fget.__doc__ + + def __get__(desc, self, cls): + return desc.fget(cls) + + +class hybridmethod(object): + """Decorate a function as cls- or instance- level.""" + def __init__(self, func, expr=None): + self.func = func + + def __get__(self, instance, owner): + if instance is None: + return self.func.__get__(owner, owner.__class__) + else: + return self.func.__get__(instance, owner) + + +class _symbol(int): + def __new__(self, name, doc=None, canonical=None): + """Construct a new named symbol.""" + assert isinstance(name, compat.string_types) + if canonical is None: + canonical = hash(name) + v = int.__new__(_symbol, canonical) + v.name = name + if doc: + v.__doc__ = doc + return v + + def __reduce__(self): + return symbol, (self.name, "x", int(self)) + + def __str__(self): + return repr(self) + + def __repr__(self): + return "symbol(%r)" % self.name + +_symbol.__name__ = 'symbol' + + +class symbol(object): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + The optional ``doc`` argument assigns to ``__doc__``. This + is strictly so that Sphinx autoattr picks up the docstring we want + (it doesn't appear to pick up the in-module docstring if the datamember + is in a different module - autoattribute also blows up completely). + If Sphinx fixes/improves this then we would no longer need + ``doc`` here. + + """ + symbols = {} + _lock = compat.threading.Lock() + + def __new__(cls, name, doc=None, canonical=None): + cls._lock.acquire() + try: + sym = cls.symbols.get(name) + if sym is None: + cls.symbols[name] = sym = _symbol(name, doc, canonical) + return sym + finally: + symbol._lock.release() + + +_creation_order = 1 + + +def set_creation_order(instance): + """Assign a '_creation_order' sequence to the given instance. + + This allows multiple instances to be sorted in order of creation + (typically within a single thread; the counter is not particularly + threadsafe). + + """ + global _creation_order + instance._creation_order = _creation_order + _creation_order += 1 + + +def warn_exception(func, *args, **kwargs): + """executes the given function, catches all exceptions and converts to + a warning. + + """ + try: + return func(*args, **kwargs) + except: + warn("%s('%s') ignored" % sys.exc_info()[0:2]) + + +def warn(msg, stacklevel=3): + """Issue a warning. + + If msg is a string, :class:`.exc.SAWarning` is used as + the category. + + .. note:: + + This function is swapped out when the test suite + runs, with a compatible version that uses + warnings.warn_explicit, so that the warnings registry can + be controlled. + + """ + if isinstance(msg, compat.string_types): + warnings.warn(msg, exc.SAWarning, stacklevel=stacklevel) + else: + warnings.warn(msg, stacklevel=stacklevel) + + +def only_once(fn): + """Decorate the given function to be a no-op after it is called exactly + once.""" + + once = [fn] + def go(*arg, **kw): + if once: + once_fn = once.pop() + return once_fn(*arg, **kw) + + return go + + +_SQLA_RE = re.compile(r'sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py') +_UNITTEST_RE = re.compile(r'unit(?:2|test2?/)') + +def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): + """Chop extraneous lines off beginning and end of a traceback. + + :param tb: + a list of traceback lines as returned by ``traceback.format_stack()`` + + :param exclude_prefix: + a regular expression object matching lines to skip at beginning of ``tb`` + + :param exclude_suffix: + a regular expression object matching lines to skip at end of ``tb`` + """ + start = 0 + end = len(tb) - 1 + while start <= end and exclude_prefix.search(tb[start]): + start += 1 + while start <= end and exclude_suffix.search(tb[end]): + end -= 1 + return tb[start:end + 1] + +NoneType = type(None) diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py new file mode 100644 index 00000000..c98aa7fd --- /dev/null +++ b/lib/sqlalchemy/util/queue.py @@ -0,0 +1,199 @@ +# util/queue.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""An adaptation of Py2.3/2.4's Queue module which supports reentrant +behavior, using RLock instead of Lock for its mutex object. The +Queue object is used exclusively by the sqlalchemy.pool.QueuePool +class. + +This is to support the connection pool's usage of weakref callbacks to return +connections to the underlying Queue, which can in extremely +rare cases be invoked within the ``get()`` method of the Queue itself, +producing a ``put()`` inside the ``get()`` and therefore a reentrant +condition. + +""" + +from collections import deque +from time import time as _time +from .compat import threading + + +__all__ = ['Empty', 'Full', 'Queue'] + + +class Empty(Exception): + "Exception raised by Queue.get(block=0)/get_nowait()." + + pass + + +class Full(Exception): + "Exception raised by Queue.put(block=0)/put_nowait()." + + pass + + +class Queue: + def __init__(self, maxsize=0): + """Initialize a queue object with a given maximum size. + + If `maxsize` is <= 0, the queue size is infinite. + """ + + self._init(maxsize) + # mutex must be held whenever the queue is mutating. All methods + # that acquire mutex must release it before returning. mutex + # is shared between the two conditions, so acquiring and + # releasing the conditions also acquires and releases mutex. + self.mutex = threading.RLock() + # Notify not_empty whenever an item is added to the queue; a + # thread waiting to get is notified then. + self.not_empty = threading.Condition(self.mutex) + # Notify not_full whenever an item is removed from the queue; + # a thread waiting to put is notified then. + self.not_full = threading.Condition(self.mutex) + + + def qsize(self): + """Return the approximate size of the queue (not reliable!).""" + + self.mutex.acquire() + n = self._qsize() + self.mutex.release() + return n + + def empty(self): + """Return True if the queue is empty, False otherwise (not + reliable!).""" + + self.mutex.acquire() + n = self._empty() + self.mutex.release() + return n + + def full(self): + """Return True if the queue is full, False otherwise (not + reliable!).""" + + self.mutex.acquire() + n = self._full() + self.mutex.release() + return n + + def put(self, item, block=True, timeout=None): + """Put an item into the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until a free slot is + available. If `timeout` is a positive number, it blocks at + most `timeout` seconds and raises the ``Full`` exception if no + free slot was available within that time. Otherwise (`block` + is false), put an item on the queue if a free slot is + immediately available, else raise the ``Full`` exception + (`timeout` is ignored in that case). + """ + + self.not_full.acquire() + try: + if not block: + if self._full(): + raise Full + elif timeout is None: + while self._full(): + self.not_full.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._full(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Full + self.not_full.wait(remaining) + self._put(item) + self.not_empty.notify() + finally: + self.not_full.release() + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + Only enqueue the item if a free slot is immediately available. + Otherwise raise the ``Full`` exception. + """ + return self.put(item, False) + + def get(self, block=True, timeout=None): + """Remove and return an item from the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until an item is available. If + `timeout` is a positive number, it blocks at most `timeout` + seconds and raises the ``Empty`` exception if no item was + available within that time. Otherwise (`block` is false), + return an item if one is immediately available, else raise the + ``Empty`` exception (`timeout` is ignored in that case). + """ + self.not_empty.acquire() + try: + if not block: + if self._empty(): + raise Empty + elif timeout is None: + while self._empty(): + self.not_empty.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._empty(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + item = self._get() + self.not_full.notify() + return item + finally: + self.not_empty.release() + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Only get an item if one is immediately available. Otherwise + raise the ``Empty`` exception. + """ + + return self.get(False) + + # Override these methods to implement other queue organizations + # (e.g. stack or priority queue). + # These will only be called with appropriate locks held + + # Initialize the queue representation + def _init(self, maxsize): + self.maxsize = maxsize + self.queue = deque() + + def _qsize(self): + return len(self.queue) + + # Check whether the queue is empty + def _empty(self): + return not self.queue + + # Check whether the queue is full + def _full(self): + return self.maxsize > 0 and len(self.queue) == self.maxsize + + # Put a new item in the queue + def _put(self, item): + self.queue.append(item) + + # Get an item from the queue + def _get(self): + return self.queue.popleft() diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py new file mode 100644 index 00000000..fe7e7689 --- /dev/null +++ b/lib/sqlalchemy/util/topological.py @@ -0,0 +1,96 @@ +# util/topological.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Topological sorting algorithms.""" + +from ..exc import CircularDependencyError +from .. import util + +__all__ = ['sort', 'sort_as_subsets', 'find_cycles'] + + +def sort_as_subsets(tuples, allitems): + + edges = util.defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + todo = set(allitems) + + while todo: + output = set() + for node in list(todo): + if not todo.intersection(edges[node]): + output.add(node) + + if not output: + raise CircularDependencyError( + "Circular dependency detected.", + find_cycles(tuples, allitems), + _gen_edges(edges) + ) + + todo.difference_update(output) + yield output + + +def sort(tuples, allitems): + """sort the given list of items by dependency. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + for set_ in sort_as_subsets(tuples, allitems): + for s in set_: + yield s + + +def find_cycles(tuples, allitems): + # adapted from: + # http://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html + + edges = util.defaultdict(set) + for parent, child in tuples: + edges[parent].add(child) + nodes_to_test = set(edges) + + output = set() + + # we'd like to find all nodes that are + # involved in cycles, so we do the full + # pass through the whole thing for each + # node in the original list. + + # we can go just through parent edge nodes. + # if a node is only a child and never a parent, + # by definition it can't be part of a cycle. same + # if it's not in the edges at all. + for node in nodes_to_test: + stack = [node] + todo = nodes_to_test.difference(stack) + while stack: + top = stack[-1] + for node in edges[top]: + if node in stack: + cyc = stack[stack.index(node):] + todo.difference_update(cyc) + output.update(cyc) + + if node in todo: + stack.append(node) + todo.remove(node) + break + else: + node = stack.pop() + return output + + +def _gen_edges(edges): + return set([ + (right, left) + for left in edges + for right in edges[left] + ]) diff --git a/sickbeard/name_parser/parser.py b/sickbeard/name_parser/parser.py index dddf9b4a..9b5ac762 100644 --- a/sickbeard/name_parser/parser.py +++ b/sickbeard/name_parser/parser.py @@ -18,18 +18,19 @@ from __future__ import with_statement +import os import time import re import datetime import os.path import regexes -import shelve import sickbeard from sickbeard import logger, helpers, scene_numbering, common, exceptions, scene_exceptions, encodingKludge as ek +from sickbeard.exceptions import ex from contextlib import closing from dateutil import parser -from sickbeard.exceptions import ex +from shove import Shove class NameParser(object): @@ -231,7 +232,8 @@ class NameParser(object): bestResult = max(sorted(matches, reverse=True, key=lambda x: x.which_regex), key=lambda x: x.score) # get quality - bestResult.quality = common.Quality.nameQuality(name, bestResult.show.is_anime if bestResult.show else False) + bestResult.quality = common.Quality.nameQuality(name, + bestResult.show.is_anime if bestResult.show else False) # scene convert result bestResult = bestResult.convert() if self.convert and not self.naming_pattern else bestResult @@ -605,38 +607,50 @@ class ParseResult(object): class NameParserCache: def __init__(self): + self.db_name = ek.ek(os.path.join, sickbeard.CACHE_DIR, 'name_parser_cache.db') self.npc_cache_size = 200 - self.db_name = ek.ek(os.path.join, sickbeard.CACHE_DIR, 'name_parser_cache') def add(self, name, parse_result): name = name.encode('utf-8', 'ignore') - try : - with closing(shelve.open(self.db_name, writeback=True)) as npc: + try: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as npc: npc[str(name)] = parse_result while len(npc.items()) > self.npc_cache_size: del npc.keys()[0] - except Exception as e: - logger.log(u"NameParser cache error: " + ex(e), logger.ERROR) - logger.log(u"NameParser cache corrupted, please delete " + self.db_name, logger.ERROR) + except: + os.remove(self.db_name) + try: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as npc: + npc[str(name)] = parse_result + + while len(npc.items()) > self.npc_cache_size: + del npc.keys()[0] + except Exception as e: + logger.log(u"NameParser cache error: " + ex(e), logger.ERROR) def get(self, name): name = name.encode('utf-8', 'ignore') try: - with closing(shelve.open(self.db_name, writeback=True)) as npc: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as npc: parse_result = npc.get(str(name), None) - except Exception as e: - logger.log(u"NameParser cache error: " + ex(e), logger.ERROR) - logger.log(u"NameParser cache corrupted, please delete " + self.db_name, logger.ERROR) - parse_result = None + except: + os.remove(self.db_name) + try: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as npc: + parse_result = npc.get(str(name), None) + except Exception as e: + logger.log(u"NameParser cache error: " + ex(e), logger.ERROR) + parse_result = None if parse_result: logger.log("Using cached parse result for: " + name, logger.DEBUG) return parse_result + class InvalidNameException(Exception): "The given release name is not valid" diff --git a/sickbeard/rssfeeds.py b/sickbeard/rssfeeds.py index 7e51d2bd..7c7ddadd 100644 --- a/sickbeard/rssfeeds.py +++ b/sickbeard/rssfeeds.py @@ -4,28 +4,33 @@ import os import urllib import urlparse import re -import shelve import sickbeard from sickbeard import logger from sickbeard import encodingKludge as ek from contextlib import closing -from lib.feedcache import cache from sickbeard.exceptions import ex +from lib.feedcache import cache +from shove import Shove class RSSFeeds: def __init__(self, db_name): - self.db_name = ek.ek(os.path.join, sickbeard.CACHE_DIR, db_name) + self.db_name = ek.ek(os.path.join, sickbeard.CACHE_DIR, db_name + '.db') def clearCache(self, age=None): try: - with closing(shelve.open(self.db_name, writeback=True)) as fs: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as fs: fc = cache.Cache(fs) fc.purge(age) - except Exception as e: - logger.log(u"RSS Error: " + ex(e), logger.ERROR) - logger.log(u"RSS cache file corrupted, please delete " + self.db_name, logger.ERROR) + except: + os.remove(self.db_name) + try: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as fs: + fc = cache.Cache(fs) + fc.purge(age) + except Exception as e: + logger.log(u"RSS cache error: " + ex(e), logger.DEBUG) def getFeed(self, url, post_data=None, request_headers=None): parsed = list(urlparse.urlparse(url)) @@ -35,13 +40,18 @@ class RSSFeeds: url += urllib.urlencode(post_data) try: - with closing(shelve.open(self.db_name, writeback=True)) as fs: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as fs: fc = cache.Cache(fs) feed = fc.fetch(url, False, False, request_headers) - except Exception as e: - logger.log(u"RSS Error: " + ex(e), logger.ERROR) - logger.log(u"RSS cache file corrupted, please delete " + self.db_name, logger.ERROR) - feed = None + except: + os.remove(self.db_name) + try: + with closing(Shove('sqlite:///' + self.db_name, compress=True)) as fs: + fc = cache.Cache(fs) + feed = fc.fetch(url, False, False, request_headers) + except Exception as e: + logger.log(u"RSS cache error: " + ex(e), logger.DEBUG) + feed = None if not feed: logger.log(u"RSS Error loading URL: " + url, logger.ERROR)