diff --git a/sickbeard/db.py b/sickbeard/db.py index e0b1b49d..b9d4b303 100644 --- a/sickbeard/db.py +++ b/sickbeard/db.py @@ -30,6 +30,9 @@ from sickbeard import encodingKludge as ek from sickbeard import logger from sickbeard.exceptions import ex +db_cons = {} +db_locks = {} + def dbFilename(filename="sickbeard.db", suffix=None): """ @param filename: The sqlite database filename to use. If not specified, @@ -42,59 +45,34 @@ def dbFilename(filename="sickbeard.db", suffix=None): filename = "%s.%s" % (filename, suffix) return ek.ek(os.path.join, sickbeard.DATA_DIR, filename) - -class DBConnection(threading.Thread): +class DBConnection(object): def __init__(self, filename="sickbeard.db", suffix=None, row_type=None): self.filename = filename self.suffix = suffix self.row_type = row_type - self.connection = None - self.db_lock = threading.Lock() try: - self.reconnect() + if self.filename not in db_cons: + db_locks[self.filename] = threading.Lock() + + self.connection = sqlite3.connect(dbFilename(self.filename, self.suffix), 20, check_same_thread=False) + self.connection.text_factory = self._unicode_text_factory + self.connection.isolation_level = None + + db_cons[self.filename] = self.connection + else: + self.connection = db_cons[self.filename] + + if self.row_type == "dict": + self.connection.row_factory = self._dict_factory + else: + self.connection.row_factory = sqlite3.Row except Exception as e: logger.log(u"DB error: " + ex(e), logger.ERROR) raise - def reconnect(self): - """Closes the existing database connection and re-opens it.""" - self.close() - self.connection = sqlite3.connect(dbFilename(self.filename, self.suffix), 20, check_same_thread=False) - self.connection.execute("pragma synchronous = off") - self.connection.execute("pragma temp_store = memory") - self.connection.execute("pragma journal_mode = memory") - self.connection.execute("pragma secure_delete = false") - self.connection.execute("pragma foreign_keys = on") - self.connection.text_factory = self._unicode_text_factory - self.connection.isolation_level = None - - if self.row_type == "dict": - self.connection.row_factory = self._dict_factory - else: - self.connection.row_factory = sqlite3.Row - - def _cursor(self): - """Returns the cursor; reconnects if disconnected.""" - if self.connection is None: self.reconnect() - return self.connection.cursor() - - def execute(self, query, args=None, fetchall=False, fetchone=False): - """Executes the given query, returning the lastrowid from the query.""" - - cursor = self._cursor() - try: - if fetchall: - return self._execute(cursor, query, args).fetchall() - elif fetchone: - return self._execute(cursor, query, args).fetchone() - else: - return self._execute(cursor, query, args) - finally: - cursor.close() - - def _execute(self, cursor, query, args): + def _execute(self, query, args): def convert(x): if isinstance(x, basestring): try: @@ -104,15 +82,23 @@ class DBConnection(threading.Thread): return x try: - with self.db_lock: - if not args: - return cursor.execute(query) - #args = map(convert, args) - return cursor.execute(query, args) - except sqlite3.OperationalError as e: - logger.log(u"DB error: " + ex(e), logger.ERROR) - self.close() - raise + if not args: + return self.connection.cursor().execute(query) + # args = map(convert, args) + return self.connection.cursor().execute(query, args) + except Exception as e: + raise e + + def execute(self, query, args=None, fetchall=False, fetchone=False): + try: + if fetchall: + return self._execute(query, args).fetchall() + elif fetchone: + return self._execute(query, args).fetchone() + else: + return self._execute(query, args) + except Exception as e: + raise e def checkDBVersion(self): @@ -136,43 +122,44 @@ class DBConnection(threading.Thread): sqlResult = [] attempt = 0 - while attempt < 5: - try: - for qu in querylist: - if len(qu) == 1: - if logTransaction: - logger.log(qu[0], logger.DEBUG) - sqlResult.append(self.execute(qu[0], fetchall=fetchall)) - elif len(qu) > 1: - if logTransaction: - logger.log(qu[0] + " with args " + str(qu[1]), logger.DEBUG) - sqlResult.append(self.execute(qu[0], qu[1], fetchall=fetchall)) + with db_locks[self.filename]: + while attempt < 5: + try: + for qu in querylist: + if len(qu) == 1: + if logTransaction: + logger.log(qu[0], logger.DEBUG) + sqlResult.append(self.execute(qu[0], fetchall=fetchall)) + elif len(qu) > 1: + if logTransaction: + logger.log(qu[0] + " with args " + str(qu[1]), logger.DEBUG) + sqlResult.append(self.execute(qu[0], qu[1], fetchall=fetchall)) - logger.log(u"Transaction with " + str(len(querylist)) + u" queries executed", logger.DEBUG) + logger.log(u"Transaction with " + str(len(querylist)) + u" queries executed", logger.DEBUG) - # finished - break - except sqlite3.OperationalError, e: - sqlResult = [] - if self.connection: - self.connection.rollback() - if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]: - logger.log(u"DB error: " + ex(e), logger.WARNING) - attempt += 1 - time.sleep(1) - else: - logger.log(u"DB error: " + ex(e), logger.ERROR) + # finished + break + except sqlite3.OperationalError, e: + sqlResult = [] + if self.connection: + self.connection.rollback() + if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]: + logger.log(u"DB error: " + ex(e), logger.WARNING) + attempt += 1 + time.sleep(1) + else: + logger.log(u"DB error: " + ex(e), logger.ERROR) + raise + except sqlite3.DatabaseError, e: + sqlResult = [] + if self.connection: + self.connection.rollback() + logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR) raise - except sqlite3.DatabaseError, e: - sqlResult = [] - if self.connection: - self.connection.rollback() - logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR) - raise - #time.sleep(0.02) + #time.sleep(0.02) - return sqlResult + return sqlResult def action(self, query, args=None, fetchall=False, fetchone=False): if query == None: @@ -181,32 +168,33 @@ class DBConnection(threading.Thread): sqlResult = None attempt = 0 - while attempt < 5: - try: - if args == None: - logger.log(self.filename + ": " + query, logger.DB) - else: - logger.log(self.filename + ": " + query + " with args " + str(args), logger.DB) + with db_locks[self.filename]: + while attempt < 5: + try: + if args == None: + logger.log(self.filename + ": " + query, logger.DB) + else: + logger.log(self.filename + ": " + query + " with args " + str(args), logger.DB) - sqlResult = self.execute(query, args, fetchall=fetchall, fetchone=fetchone) + sqlResult = self.execute(query, args, fetchall=fetchall, fetchone=fetchone) - # get out of the connection attempt loop since we were successful - break - except sqlite3.OperationalError, e: - if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]: - logger.log(u"DB error: " + ex(e), logger.WARNING) - attempt += 1 - time.sleep(1) - else: - logger.log(u"DB error: " + ex(e), logger.ERROR) + # get out of the connection attempt loop since we were successful + break + except sqlite3.OperationalError, e: + if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]: + logger.log(u"DB error: " + ex(e), logger.WARNING) + attempt += 1 + time.sleep(1) + else: + logger.log(u"DB error: " + ex(e), logger.ERROR) + raise + except sqlite3.DatabaseError, e: + logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR) raise - except sqlite3.DatabaseError, e: - logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR) - raise - #time.sleep(0.02) + #time.sleep(0.02) - return sqlResult + return sqlResult def select(self, query, args=None): @@ -268,16 +256,9 @@ class DBConnection(threading.Thread): self.action("ALTER TABLE %s ADD %s %s" % (table, column, type)) self.action("UPDATE %s SET %s = ?" % (table, column), (default,)) - def close(self): - """Close database connection""" - if getattr(self, "connection", None) is not None: - self.connection.close() - self.connection = None - def sanityCheckDatabase(connection, sanity_check): sanity_check(connection).check() - class DBSanityCheck(object): def __init__(self, connection): self.connection = connection diff --git a/sickbeard/webserve.py b/sickbeard/webserve.py index 360d600b..ef51c71d 100644 --- a/sickbeard/webserve.py +++ b/sickbeard/webserve.py @@ -286,11 +286,12 @@ class WebHandler(BaseHandler): def taskFinished(self, result, route): try: - if result: - # encode result data - try:result = ek.ss(result).encode('utf-8', 'xmlcharrefreplace') - except:pass + # encode results + try:result = ek.ss(result).encode('utf-8', 'xmlcharrefreplace') if result else None + except:pass + # ignore empty results + if result: # Check JSONP callback jsonp_callback = self.get_argument('callback_func', default=None) @@ -2108,8 +2109,11 @@ class NewHomeAddShows(Home): for cur_file in file_list: - cur_path = ek.ek(os.path.normpath, ek.ek(os.path.join, root_dir, cur_file)) - if not ek.ek(os.path.isdir, cur_path): + try: + cur_path = ek.ek(os.path.normpath, ek.ek(os.path.join, root_dir, cur_file)) + if not ek.ek(os.path.isdir, cur_path): + continue + except: continue cur_dir = { diff --git a/tests/db_tests.py b/tests/db_tests.py index fe988f62..f1dedbd2 100644 --- a/tests/db_tests.py +++ b/tests/db_tests.py @@ -19,10 +19,9 @@ import unittest import test_lib as test - +import threading class DBBasicTests(test.SickbeardTestDBCase): - def setUp(self): super(DBBasicTests, self).setUp() self.db = test.db.DBConnection() @@ -30,6 +29,18 @@ class DBBasicTests(test.SickbeardTestDBCase): def test_select(self): self.db.select("SELECT * FROM tv_episodes WHERE showid = ? AND location != ''", [0000]) +class DBMultiTests(test.SickbeardTestDBCase): + def setUp(self): + super(DBMultiTests, self).setUp() + self.db = test.db.DBConnection() + + def select(self): + self.db.select("SELECT * FROM tv_episodes WHERE showid = ? AND location != ''", [0000]) + + def test_threaded(self): + for i in xrange(20): + t = threading.Thread(target=self.select) + t.start() if __name__ == '__main__': print "==================" @@ -38,3 +49,6 @@ if __name__ == '__main__': print "######################################################################" suite = unittest.TestLoader().loadTestsFromTestCase(DBBasicTests) unittest.TextTestRunner(verbosity=2).run(suite) + + suite = unittest.TestLoader().loadTestsFromTestCase(DBMultiTests) + unittest.TextTestRunner(verbosity=2).run(suite)