diff --git a/.circleci/config.yml b/.circleci/config.yml index 92ddc2b9..c1c8395d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,7 +20,7 @@ jobs: pip install --user python-dateutil sudo apt update sudo apt install python3-pip - pip3 install --user "sqlalchemy<2.0" + pip3 install --user sqlalchemy pip3 install --user python-dateutil - persist_to_workspace: root: ~/.local diff --git a/dbprocessing/DButils.py b/dbprocessing/DButils.py index e1a1fcfb..6103d99a 100644 --- a/dbprocessing/DButils.py +++ b/dbprocessing/DButils.py @@ -39,7 +39,7 @@ import sqlalchemy.schema import sqlalchemy.sql.expression from sqlalchemy import Table -from sqlalchemy.orm import mapper +from sqlalchemy.orm import registry from sqlalchemy.orm import sessionmaker import sqlalchemy.orm.exc from sqlalchemy.exc import IntegrityError @@ -235,7 +235,7 @@ def openDB(self, engine, db_var=None, verbose=False, echo=False): (t, v, tb) = sys.exc_info() raise DBError('Error creating engine: ' + str(v)) try: - metadata = sqlalchemy.MetaData(bind=engineIns) + metadata = sqlalchemy.MetaData() # a session is what you use to actually talk to the DB, set one up with the current engine Session = sessionmaker(bind=engineIns) session = Session() @@ -277,12 +277,13 @@ def _createTableObjects(self, verbose=False): ## pass ## missions = Table('missions', metadata, autoload=True) ## mapper(Missions, missions) + mapper_registry = registry() for val in table_dict: if verbose: print(val) if not hasattr(self, val): # then make it myclass = type(str(val), (object,), dict()) - tableobj = Table(table_dict[val], self.metadata, autoload=True) - mapper(myclass, tableobj) + tableobj = Table(table_dict[val], self.metadata, autoload_with=self.engine) + mapper_registry.map_imperatively(myclass, tableobj) setattr(self, str(val), myclass) if verbose: print("Class %s created" % (val)) if verbose: DBlogging.dblogger.debug("Class %s created" % (val)) @@ -1996,7 +1997,7 @@ def getFileFullPath(self, filename): if isinstance(filename, str_classes): filename = self.getFileID(filename) sq = self.session.query(self.File.filename, self.Product.relative_path).filter( - self.File.file_id == filename).join((self.Product, self.File.product_id == self.Product.product_id)).one() + self.File.file_id == filename).join(self.Product, self.Product.product_id == self.File.product_id).one() path = os.path.join(self.MissionDirectory, *(sq[1].split(posixpath.sep) + [sq[0]])) if '{' in path: @@ -2082,7 +2083,7 @@ def getProcessID(self, proc_name): """ try: proc_id = int(proc_name) - proc_name = self.session.query(self.Process).get(proc_id) + proc_name = self.session.get(self.Process, proc_id) if proc_name is None: raise NoResultFound('No row was found for id={0}'.format(proc_id)) except ValueError: # it is not a number @@ -2127,7 +2128,7 @@ def getInstrumentID(self, name, satellite_id=None): """ try: i_id = int(name) - sq = self.session.query(self.Instrument).get(i_id) + sq = self.session.get(self.Instrument, i_id) if sq is None: raise DBNoData("No instrument_id {0} found in the DB".format(i_id)) return sq.instrument_id @@ -2202,7 +2203,7 @@ def getFileID(self, filename): return filename.file_id try: f_id = int(filename) - sq = self.session.query(self.File).get(f_id) + sq = self.session.get(self.File, f_id) if sq is None: raise DBNoData("No file_id {0} found in the DB".format(filename)) return sq.file_id @@ -2232,7 +2233,7 @@ def getCodeID(self, codename): """ try: c_id = int(codename) - code = self.session.query(self.Code).get(c_id) + code = self.session.get(self.Code, c_id) if code is None: raise DBNoData("No code id {0} found in the DB".format(c_id)) except TypeError: # came in as list or tuple @@ -2739,7 +2740,7 @@ def getProductID(self, product_name): # no file_id found raise DBNoData("No product_name %s found in the DB" % (product_name)) # Numerical product ID, make sure it exists - sq = self.session.query(self.Product).get(product_name) + sq = self.session.get(self.Product, product_name) if sq is not None: return sq.product_id else: @@ -2765,7 +2766,7 @@ def getSatelliteID(self, """ try: sat_id = int(sat_name) - sq = self.session.query(self.Satellite).get(sat_id) + sq = self.session.get(self.Satellite, sat_id) if sq is None: raise NoResultFound("No satellite id={0} found".format(sat_id)) return sq.satellite_id @@ -3074,7 +3075,7 @@ def getMissionID(self, mission_name): """ try: m_id = int(mission_name) - ms = self.session.query(self.Mission).get(m_id) + ms = self.session.get(self.Mission, m_id) if ms is None: raise DBNoData('Invalid mission id {0}'.format(m_id)) except (ValueError, TypeError): @@ -3244,18 +3245,23 @@ def getTraceback(self, table, in_id, in_id2=None): 'instrumentproductlink', 'satellite', 'mission'] in_id = self.getFileID(in_id) - - sq = (self.session.query(self.File, self.Product, - self.Inspector, self.Instrument, - self.Instrumentproductlink, self.Satellite, - self.Mission) + sq = (self.session.query(self.File, + self.Product, + self.Inspector, + self.Instrument, + self.Instrumentproductlink, + self.Satellite, + self.Mission, + ) .filter_by(file_id=in_id) - .join((self.Product, self.File.product_id == self.Product.product_id)) - .join((self.Inspector, self.Product.product_id == self.Inspector.product)) - .join((self.Instrumentproductlink, self.Product.product_id == self.Instrumentproductlink.product_id)) - .join((self.Instrument, self.Instrumentproductlink.instrument_id == self.Instrument.instrument_id)) - .join((self.Satellite, self.Instrument.satellite_id == self.Satellite.satellite_id)) - .join((self.Mission, self.Satellite.mission_id == self.Mission.mission_id)).all()) + .join(self.Product, self.Product.product_id == self.File.product_id) + .join(self.Inspector, self.Inspector.product == self.Product.product_id) + .join(self.Instrumentproductlink, self.Instrumentproductlink.product_id == self.Product.product_id) + .join(self.Instrument, self.Instrument.instrument_id == self.Instrumentproductlink.instrument_id) + .join(self.Satellite, self.Satellite.satellite_id == self.Instrument.satellite_id) + .join(self.Mission, self.Mission.mission_id == self.Satellite.mission_id) + .all() + ) if not sq: # did not find a matchm this is a dberror raise DBError("file {0} did not have a traceback, this is a problem, fix it".format(in_id)) @@ -3273,7 +3279,7 @@ def getTraceback(self, table, in_id, in_id2=None): vars = ['code', 'process'] sq = (self.session.query(self.Code, self.Process) .filter_by(code_id=in_id) - .join((self.Process, self.Code.process_id == self.Process.process_id)).all()) + .join(self.Process, self.Process.process_id == self.Code.process_id).all()) if not sq: # did not find a match this is a dberror raise DBError("code {0} did not have a traceback, this is a problem, fix it".format(in_id)) @@ -3281,18 +3287,23 @@ def getTraceback(self, table, in_id, in_id2=None): if sq[0][1].output_timebase != 'RUN': vars = ['code', 'process', 'product', 'instrument', 'instrumentproductlink', 'satellite', 'mission'] - sq = (self.session.query(self.Code, self.Process, - self.Product, self.Instrument, - self.Instrumentproductlink, self.Satellite, + sq = (self.session.query(self.Code, + self.Process, + self.Product, + self.Instrument, + self.Instrumentproductlink, + self.Satellite, self.Mission) .filter_by(code_id=in_id) - .join((self.Process, self.Code.process_id == self.Process.process_id)) - .join((self.Product, self.Product.product_id == self.Process.output_product)) - .join((self.Inspector, self.Product.product_id == self.Inspector.product)) - .join((self.Instrumentproductlink, self.Product.product_id == self.Instrumentproductlink.product_id)) - .join((self.Instrument, self.Instrumentproductlink.instrument_id == self.Instrument.instrument_id)) - .join((self.Satellite, self.Instrument.satellite_id == self.Satellite.satellite_id)) - .join((self.Mission, self.Satellite.mission_id == self.Mission.mission_id)).all()) + .join(self.Process, self.Process.process_id == self.Code.process_id) + .join(self.Product, self.Product.product_id == self.Process.output_product) + .join(self.Inspector, self.Inspector.product == self.Product.product_id) + .join(self.Instrumentproductlink, self.Instrumentproductlink.product_id == self.Product.product_id) + .join(self.Instrument, self.Instrument.instrument_id == self.Instrumentproductlink.instrument_id) + .join(self.Satellite, self.Satellite.satellite_id == self.Instrument.satellite_id) + .join(self.Mission, self.Mission.mission_id == self.Satellite.mission_id) + .all() + ) if not sq: # did not find a match this is a dberror raise DBError("code {0} did not have a traceback, this is a problem, fix it".format(in_id)) @@ -3312,17 +3323,21 @@ def getTraceback(self, table, in_id, in_id2=None): 'instrumentproductlink', 'satellite', 'mission'] in_id = self.getProductID(in_id) - sq = (self.session.query(self.Product, - self.Inspector, self.Instrument, - self.Instrumentproductlink, self.Satellite, - self.Mission) - .filter_by(product_id=in_id) - .join((self.Inspector, self.Product.product_id == self.Inspector.product)) - .join((self.Instrumentproductlink, self.Product.product_id == self.Instrumentproductlink.product_id)) - .join((self.Instrument, self.Instrumentproductlink.instrument_id == self.Instrument.instrument_id)) - .join((self.Satellite, self.Instrument.satellite_id == self.Satellite.satellite_id)) - .join((self.Mission, self.Satellite.mission_id == self.Mission.mission_id)).all()) + self.Inspector, + self.Instrument, + self.Instrumentproductlink, + self.Satellite, + self.Mission, + ) + .filter_by(product_id=in_id) + .join(self.Inspector, self.Inspector.product == self.Product.product_id) + .join(self.Instrumentproductlink, self.Instrumentproductlink.product_id == self.Product.product_id) + .join(self.Instrument, self.Instrument.instrument_id == self.Instrumentproductlink.instrument_id) + .join(self.Satellite, self.Satellite.satellite_id == self.Instrument.satellite_id ) + .join(self.Mission, self.Mission.mission_id == self.Satellite.mission_id ) + .all() + ) if not sq: # did not find a match this is a dberror raise DBError("product {0} did not have a traceback, this is a problem, fix it".format(in_id)) @@ -3338,18 +3353,22 @@ def getTraceback(self, table, in_id, in_id2=None): 'instrumentproductlink', 'satellite', 'mission'] in_id = self.getProcessID(in_id) - sq = (self.session.query(self.Process, - self.Product, self.Instrument, - self.Instrumentproductlink, self.Satellite, - self.Mission) + self.Product, + self.Instrument, + self.Instrumentproductlink, + self.Satellite, + self.Mission, + ) .filter_by(process_id=in_id) - .join((self.Product, self.Product.product_id == self.Process.output_product)) - .join((self.Inspector, self.Product.product_id == self.Inspector.product)) - .join((self.Instrumentproductlink, self.Product.product_id == self.Instrumentproductlink.product_id)) - .join((self.Instrument, self.Instrumentproductlink.instrument_id == self.Instrument.instrument_id)) - .join((self.Satellite, self.Instrument.satellite_id == self.Satellite.satellite_id)) - .join((self.Mission, self.Satellite.mission_id == self.Mission.mission_id)).all()) + .join(self.Product, self.Product.product_id == self.Process.output_product) + .join(self.Inspector, self.Inspector.product == self.Product.product_id) + .join(self.Instrumentproductlink, self.Instrumentproductlink.product_id == self.Product.product_id) + .join(self.Instrument, self.Instrument.instrument_id == self.Instrumentproductlink.instrument_id) + .join(self.Satellite, self.Satellite.satellite_id == self.Instrument.satellite_id) + .join(self.Mission, self.Mission.mission_id == self.Satellite.mission_id) + .all() + ) if not sq: # did not find a match this is a dberror raise DBError("process {0} did not have a traceback, this is a problem, fix it".format(in_id)) @@ -3541,13 +3560,13 @@ def getEntry(self, table, args): retval = None if isinstance(args, (int, collections.abc.Iterable)) \ and not isinstance(args, str_classes): # PK: int, non-str sequence - retval = self.session.query(getattr(self, table)).get(args) + retval = self.session.get(getattr(self, table), args) if retval is None: # Not valid PK type, or PK not found # see if it was a name if ('get' + table + 'ID') in dir(self): cmd = 'get' + table + 'ID' pk = getattr(self, cmd)(args) - retval = self.session.query(getattr(self, table)).get(pk) + retval = self.session.get(getattr(self, table),pk) # This code will make it consistently raise DBNoData if nothing is found, # but codebase needs to be scrubbed for callers that expect None instead. # else: @@ -3862,7 +3881,7 @@ def addUnixTimeTable(self): raise RuntimeError('Unixtime table already seems to exist.') unixtime = sqlalchemy.Table( 'unixtime', self.metadata, *tables.definition('unixtime')) - self.metadata.create_all(tables=[unixtime]) + self.metadata.create_all(self.engine, tables=[unixtime]) # Make object for the new table definition (skips existing tables) self._createTableObjects() unx0 = datetime.datetime(1970, 1, 1) @@ -3895,6 +3914,5 @@ def create_tables(filename='dbprocessing_default.db', dialect='sqlite'): data_table = sqlalchemy.schema.Table( name, metadata, *tables.definition(name)) engine = sqlalchemy.engine.create_engine(url, echo=False) - metadata.bind = engine - metadata.create_all(checkfirst=True) + metadata.create_all(checkfirst=True, bind=engine) engine.dispose() diff --git a/dbprocessing/Utils.py b/dbprocessing/Utils.py index 499fa842..8a37a79c 100644 --- a/dbprocessing/Utils.py +++ b/dbprocessing/Utils.py @@ -560,3 +560,18 @@ def readconfig(config_filepath): else: ans[section][item] = (ans[section][item], 0, 0) return ans +def load_source(modname, filepath, module): + """ + The imp module was removed in Python 3.4, thus, adaptations were made so the imp.load_source feature can be used for later Python versions + """ + try: + import importlib.util + import importlib.machinery + loader = importlib.machinery.SourceFileLoader(modname, filepath) + spec = importlib.util.spec_from_file_location(modname, filepath, loader=loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + except ImportError: + import imp # Depracated in Python 3.4 + module = imp.load_source(modname, filepath) + return module diff --git a/dbprocessing/dbprocessing.py b/dbprocessing/dbprocessing.py index e35c38cf..37d759d3 100644 --- a/dbprocessing/dbprocessing.py +++ b/dbprocessing/dbprocessing.py @@ -5,7 +5,6 @@ from __future__ import print_function import datetime -import imp import os import shutil import sys @@ -259,7 +258,9 @@ def figureProduct(self, filename=None): claimed = [] for code, desc, arg, product in act_insp: try: - inspect = imp.load_source('inspect', code) + fname = code + inspect = None + inspect = Utils.load_source('inspect',fname, inspect) except IOError as msg: DBlogging.dblogger.error('Inspector: "{0}" not found: {1}'.format(code, msg)) if os.path.isfile(code + ' '): diff --git a/dbprocessing/runMe.py b/dbprocessing/runMe.py index 16149870..0678dab0 100644 --- a/dbprocessing/runMe.py +++ b/dbprocessing/runMe.py @@ -723,6 +723,7 @@ def moveToError(self, fname): DBlogging.dblogger.debug("Entered moveToError: {0}".format(fname)) path = self.dbu.getErrorPath() + os.makedirs(path, exist_ok=True) # Creates error directory if doesn't exist if os.path.isfile(os.path.join(path, os.path.basename(fname) ) ): #TODO do I really want to remove old version:? os.remove( os.path.join(path, os.path.basename(fname) ) ) diff --git a/developer/scripts/clean_test_db.py b/developer/scripts/clean_test_db.py index 43e731a6..13a6bf22 100644 --- a/developer/scripts/clean_test_db.py +++ b/developer/scripts/clean_test_db.py @@ -15,7 +15,7 @@ import datetime import dbprocessing.DButils - +import sqlalchemy def find_related_products(dbu, prod_ids, outputs=False): """Find all input/output products for a list of products @@ -67,7 +67,6 @@ def find_related_products(dbu, prod_ids, outputs=False): for prod_id in range(1, 190): if prod_id in keep_products: continue - files = [rec.file_id for rec in dbu.getFiles(product=prod_id)] for file_id in files: dbu._purgeFileFromDB(file_id, trust_id=True, commit=False) @@ -76,13 +75,11 @@ def find_related_products(dbu, prod_ids, outputs=False): .filter_by(product_id=prod_id) for ll in list(sq): dbu.session.delete(ll) - sq = dbu.session.query(dbu.Productprocesslink)\ .filter_by(input_product_id=prod_id) results = list(sq) for ll in results: dbu.session.delete(ll) - sq = dbu.session.query(dbu.Process).filter_by(output_product=prod_id) results = list(sq) for ll in results: @@ -99,7 +96,6 @@ def find_related_products(dbu, prod_ids, outputs=False): results = list(sq) for ll in results: dbu.session.delete(ll) - dbu.delProduct(prod_id) # performs commit # Only keep a few dates @@ -116,5 +112,5 @@ def find_related_products(dbu, prod_ids, outputs=False): for file_id in delme: dbu._purgeFileFromDB(file_id, trust_id=True, commit=False) dbu.commitDB() -dbu.session.execute('VACUUM') +dbu.session.execute(sqlalchemy.sql.text("VACUUM")) dbu.commitDB() diff --git a/docs/CONTRIBUTORS.rst b/docs/CONTRIBUTORS.rst index 03d4fbbd..0af939de 100644 --- a/docs/CONTRIBUTORS.rst +++ b/docs/CONTRIBUTORS.rst @@ -26,6 +26,7 @@ Current developers (*italics denote project administrators*) are: | Andrew Walker | Meilin Yan | Xiaoguang Yang + | Elisabeth Drakatos Acknowledgements ================ diff --git a/examples/scripts/CreateDBsabrs.py b/examples/scripts/CreateDBsabrs.py index a3bed96f..62b20eb2 100644 --- a/examples/scripts/CreateDBsabrs.py +++ b/examples/scripts/CreateDBsabrs.py @@ -37,8 +37,8 @@ def init_db(self, user, password, db, host='localhost', port=5432): url = "postgresql://{0}:{1}@{2}:{3}/{4}" url = url.format(user, password, host, port, db) self.engine = create_engine(url, echo=False, encoding='utf-8') - self.metadata = sqlalchemy.MetaData(bind=self.engine) - self.metadata.reflect() + self.metadata = sqlalchemy.MetaData() + self.metadata.reflect(bind=self.engine) def createDB(self): """ @@ -318,10 +318,7 @@ def createDB(self): # engine = create_engine('postgres:///' + self.filename, echo=False) # metadata.bind = engine - metadata.create_all(checkfirst=True) - # self.engine = engine - # self.metadata = metadata - + metadata.create_all(checkfirst=True, bind=self.engine) def addMission(self, filename): """utility to add a mission""" self.dbu = DButils.DButils(filename) diff --git a/functional_test/scripts/run_rot13_L0toL1.py b/functional_test/scripts/run_rot13_L0toL1.py index 1954ddb9..903ad6c3 100755 --- a/functional_test/scripts/run_rot13_L0toL1.py +++ b/functional_test/scripts/run_rot13_L0toL1.py @@ -19,6 +19,6 @@ def doProcess(infiles, outfile): infiles = sorted(args[:-1]) outfile = args[-1] - print "infiles", infiles - print "outfile", outfile + print("infiles ", infiles) + print("outfile ", outfile) doProcess(infiles, outfile) diff --git a/functional_test/scripts/run_rot13_L1toL2.py b/functional_test/scripts/run_rot13_L1toL2.py index d1b62909..f81f2bbd 100755 --- a/functional_test/scripts/run_rot13_L1toL2.py +++ b/functional_test/scripts/run_rot13_L1toL2.py @@ -1,11 +1,11 @@ #!/usr/bin/env python +import codecs from optparse import OptionParser def doProcess(infile, outfile): with open(outfile, 'w') as output: with open(infile) as infile: - output.write(infile.read().encode('rot13')) - + output.write(codecs.encode(infile.read(), 'rot_13')) if __name__ == '__main__': usage = "usage: %prog infile outfile" parser = OptionParser(usage=usage) @@ -18,6 +18,6 @@ def doProcess(infile, outfile): infile = args[0] outfile = args[-1] - print "infile", infile - print "outfile", outfile + print("infile ", infile) + print("outfile ", outfile) doProcess(infile, outfile) diff --git a/functional_test/scripts/run_rot13_RUN_timebase.py b/functional_test/scripts/run_rot13_RUN_timebase.py index 8dcbfede..662db9e8 100755 --- a/functional_test/scripts/run_rot13_RUN_timebase.py +++ b/functional_test/scripts/run_rot13_RUN_timebase.py @@ -16,5 +16,5 @@ def doProcess(infile): infile = args[0] - print "infile", infile + print("infile ", infile) doProcess(infile) diff --git a/scripts/scrubber.py b/scripts/scrubber.py index c415f1ac..bec0c6d4 100644 --- a/scripts/scrubber.py +++ b/scripts/scrubber.py @@ -2,6 +2,8 @@ import argparse +import sqlalchemy.sql + from dbprocessing import DButils class scrubber(object): @@ -24,7 +26,7 @@ def parents_are_newest(self): print(np.difference(n)) def version_number_check(self): - x = self.dbu.session.execute("SELECT max(interface_version), max(quality_version), max(revision_version) FROM file").fetchone() + x = self.dbu.session.execute(sqlalchemy.sql.text("SELECT max(interface_version), max(quality_version), max(revision_version) FROM file")).fetchone() if x[0] >= 1000: print("A interface version is too large") if x[1] >= 1000: diff --git a/unit_tests/dbp_testing.py b/unit_tests/dbp_testing.py index 775aa7bf..5e916a07 100644 --- a/unit_tests/dbp_testing.py +++ b/unit_tests/dbp_testing.py @@ -229,7 +229,7 @@ def removeTestDB(self): """ if self.pg: self.dbu.session.close() - self.dbu.metadata.drop_all() + self.dbu.metadata.drop_all(bind=self.dbu.engine) self.dbu.closeDB() # Before the database is removed... del self.dbu shutil.rmtree(self.td) @@ -282,7 +282,7 @@ def loadData(self, filename): # persist_selectable added 1.3 (mapped_table deprecated) tbl = insp.persist_selectable\ if hasattr(insp, 'persist_selectable') else insp.mapped_table - tbl.drop() + tbl.drop(bind=self.dbu.engine) self.dbu.metadata.remove(tbl) del self.dbu.Unixtime if data['productprocesslink']\ @@ -306,7 +306,7 @@ def loadData(self, filename): sel = "SELECT pg_catalog.setval(pg_get_serial_sequence("\ "'{table}', '{column}'), {maxid})".format( table=t, column=idcolumn, maxid=maxid) - self.dbu.session.execute(sel) + self.dbu.session.execute(sqlalchemy.sql.text(sel)) self.dbu.commitDB() # Re-reference directories since new data loaded self.dbu.MissionDirectory = self.dbu.getMissionDirectory() diff --git a/unit_tests/test_CreateDB.py b/unit_tests/test_CreateDB.py index 17cfe852..6476a5e9 100644 --- a/unit_tests/test_CreateDB.py +++ b/unit_tests/test_CreateDB.py @@ -26,7 +26,7 @@ def test1(self): dbu = DButils.DButils(testdb) if pg: dbu.session.close() - dbu.metadata.drop_all() + dbu.metadata.drop_all(bind=dbu.engine) del dbu finally: shutil.rmtree(td) diff --git a/unit_tests/test_DBRunner.py b/unit_tests/test_DBRunner.py index 5255c05b..a96850c4 100644 --- a/unit_tests/test_DBRunner.py +++ b/unit_tests/test_DBRunner.py @@ -86,10 +86,8 @@ def test_parse_dbrunner_args_bad(self): finally: sys.stderr.close() sys.stderr = oldstderr - self.assertEqual( - '{}: error: {}'.format(os.path.basename(sys.argv[0]), msg), err) - - + running_script = os.path.basename(sys.argv[0]) + self.assertTrue(err.startswith("{}: error: argument".format(running_script))) class DBRunnerCalcRunmeTests(unittest.TestCase, dbp_testing.AddtoDBMixin): """DBRunner tests of calc_runme""" diff --git a/unit_tests/test_DButils.py b/unit_tests/test_DButils.py index da58fbb9..66565736 100755 --- a/unit_tests/test_DButils.py +++ b/unit_tests/test_DButils.py @@ -60,7 +60,7 @@ def tearDown(self): super(DBUtilsEmptyTests, self).tearDown() if self.pg: self.dbu.session.close() - self.dbu.metadata.drop_all() + self.dbu.metadata.drop_all(bind=self.dbu.engine) self.dbu.closeDB() del self.dbu shutil.rmtree(self.td) diff --git a/unit_tests/test_Inspector.py b/unit_tests/test_Inspector.py index 44ac7d7f..9ea4d5ed 100755 --- a/unit_tests/test_Inspector.py +++ b/unit_tests/test_Inspector.py @@ -4,7 +4,7 @@ import datetime import unittest import tempfile -import imp +import sys import warnings import os @@ -14,6 +14,7 @@ from dbprocessing import Version from dbprocessing import DButils from dbprocessing import Diskfile +from dbprocessing import Utils class InspectorFunctions(unittest.TestCase): """Tests of the inspector functions""" @@ -59,9 +60,9 @@ def setUp(self): self.makeTestDB() self.loadData(os.path.join(dbp_testing.testsdir, 'data', 'db_dumps', 'testDB_dump.json')) - self.inspect = imp.load_source('inspect', os.path.join( - dbp_testing.testsdir, 'inspector', 'rot13_L1.py')) - + filename = os.path.join(dbp_testing.testsdir, 'inspector', 'rot13_L1.py') + self.inspect = None + self.inspect = Utils.load_source('inspect', filename, self.inspect) def tearDown(self): super(InspectorClass, self).tearDown() self.removeTestDB() @@ -80,10 +81,11 @@ def test_inspector(self): self.assertEqual(repr(Diskfile.Diskfile(goodfile, self.dbu)), repr(self.inspect.Inspector(goodfile, self.dbu, 1,)())) #self.assertEqual(None, self.inspect.Inspector(goodfile, self.dbu, 1,).extract_YYYYMMDD()) # This inspector sets the data_level - not allowed - inspect = imp.load_source('inspect', os.path.join( - dbp_testing.testsdir, 'inspector', 'rot13_L1_dlevel.py')) - with warnings.catch_warnings(record=True) as w: - self.assertEqual(repr(Diskfile.Diskfile(goodfile, self.dbu)), repr(self.inspect.Inspector(goodfile, self.dbu, 1,)())) + filename = os.path.join(dbp_testing.testsdir, 'inspector', 'rot13_L1_dlevel.py') + inspect = None + inspect = Utils.load_source('inspect', filename, inspect) + with warnings.catch_warnings(record=True) as w: + self.assertEqual(repr(Diskfile.Diskfile(goodfile, self.dbu)), repr(inspect.Inspector(goodfile, self.dbu, 1,)())) self.assertEqual(len(w), 1) self.assertTrue(isinstance(w[0].message, UserWarning)) self.assertEqual('Inspector rot13_L1_dlevel.py: set level to 2.0, ' @@ -93,8 +95,9 @@ def test_inspector(self): # The file doesn't match the inspector pattern... badfile = os.path.join( dbp_testing.testsdir, 'inspector', 'testDB_01_first.raw') - inspect = imp.load_source('inspect', os.path.join( - dbp_testing.testsdir, 'inspector', 'rot13_L1.py')) + filename = os.path.join(dbp_testing.testsdir, 'inspector', 'rot13_L1.py') + inspect = None + inspect = Utils.load_source('inspect', filename, inspect) self.assertEqual(None, inspect.Inspector(badfile, self.dbu, 1,)()) def test_inspector_regex(self): diff --git a/unit_tests/test_Utils.py b/unit_tests/test_Utils.py index 652cd4c0..b8034a6b 100755 --- a/unit_tests/test_Utils.py +++ b/unit_tests/test_Utils.py @@ -304,7 +304,14 @@ def test_toDatetime(self): self.assertEqual( datetime.datetime(2010, 1, 1, 23, 59, 59, 999999), Utils.toDatetime(datetime.date(2010, 1, 1), end=True)) - - + def test_load_source(self): + """Testing load_source in Utils.py""" + filename = "temp_testfile.py" + with open(filename, "w") as f: + f.write("def hello():\n return 'Hello, world!'\n") + inspect = None + module = Utils.load_source("temp_testfile", filename, inspect) + self.assertEqual(module.hello(), 'Hello, world!') + if __name__ == "__main__": unittest.main() diff --git a/unit_tests/test_tables.py b/unit_tests/test_tables.py index 17448273..dc60cbeb 100755 --- a/unit_tests/test_tables.py +++ b/unit_tests/test_tables.py @@ -30,7 +30,7 @@ def setUp(self): self.engine = sqlalchemy.create_engine( 'sqlite:///{}'.format(os.path.join(self.td, 'test.sqlite')), echo=False) - self.metadata = sqlalchemy.schema.MetaData(bind=self.engine) + self.metadata = sqlalchemy.schema.MetaData() def tearDown(self): """Delete test database""" @@ -54,7 +54,7 @@ def makeTables(self, *tables): name: sqlalchemy.schema.Table( name, self.metadata, *dbprocessing.tables.definition(name)) for name in tables} - self.metadata.create_all() + self.metadata.create_all(bind=self.engine) actual = sqlalchemy.inspection.inspect(self.engine)\ .get_table_names() self.assertEqual(sorted(tables), sorted(actual))