diff --git a/.github/workflows/development.yaml b/.github/workflows/development.yaml index 39308c026..54829c9b6 100644 --- a/.github/workflows/development.yaml +++ b/.github/workflows/development.yaml @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 + pip install flake8 black - name: Run syntax tests run: flake8 datajoint --count --select=E9,F63,F7,F82 --show-source --statistics - name: Run primary tests @@ -47,5 +47,7 @@ jobs: run: docker-compose -f LNX-docker-compose.yml up --build --exit-code-from app - name: Run style tests run: | - flake8 --ignore=E121,E123,E126,E226,E24,E704,W503,W504,E722,F401,W605 datajoint \ + flake8 --ignore=E203,E722,F401,W503 datajoint \ --count --max-complexity=62 --max-line-length=127 --statistics + black datajoint --check -v + black tests --check -v diff --git a/CHANGELOG.md b/CHANGELOG.md index 465d1ccf3..029853156 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,11 @@ ## Release notes -### 0.13.4 -- March, 25 2022 +### 0.13.4 -- March, 28 2022 * Add - Allow reading blobs produced by legacy 32-bit compiled mYm library for matlab. PR #995 * Bugfix - Add missing `jobs` argument for multiprocessing PR #997 * Add - Test for multiprocessing PR #1008 * Bugfix - Fix external store key name doesn't allow '-' (#1005) PR #1006 +* Add - Adopted black formatting into code base PR #998 ### 0.13.3 -- Feb 9, 2022 * Bugfix - Fix error in listing ancestors, descendants with part tables. diff --git a/datajoint/__init__.py b/datajoint/__init__.py index d0303d2dd..70f0b408a 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -16,15 +16,40 @@ __author__ = "DataJoint Contributors" __date__ = "November 7, 2020" -__all__ = ['__author__', '__version__', - 'config', 'conn', 'Connection', - 'Schema', 'schema', 'VirtualModule', 'create_virtual_module', - 'list_schemas', 'Table', 'FreeTable', - 'Manual', 'Lookup', 'Imported', 'Computed', 'Part', - 'Not', 'AndList', 'U', 'Diagram', 'Di', 'ERD', - 'set_password', 'kill', - 'MatCell', 'MatStruct', 'AttributeAdapter', - 'errors', 'DataJointError', 'key', 'key_hash'] +__all__ = [ + "__author__", + "__version__", + "config", + "conn", + "Connection", + "Schema", + "schema", + "VirtualModule", + "create_virtual_module", + "list_schemas", + "Table", + "FreeTable", + "Manual", + "Lookup", + "Imported", + "Computed", + "Part", + "Not", + "AndList", + "U", + "Diagram", + "Di", + "ERD", + "set_password", + "kill", + "MatCell", + "MatStruct", + "AttributeAdapter", + "errors", + "DataJointError", + "key", + "key_hash", +] from .version import __version__ from .settings import config @@ -44,6 +69,6 @@ from .errors import DataJointError from .migrate import migrate_dj011_external_blob_storage_to_dj012 -ERD = Di = Diagram # Aliases for Diagram -schema = Schema # Aliases for Schema -create_virtual_module = VirtualModule # Aliases for VirtualModule +ERD = Di = Diagram # Aliases for Diagram +schema = Schema # Aliases for Schema +create_virtual_module = VirtualModule # Aliases for VirtualModule diff --git a/datajoint/admin.py b/datajoint/admin.py index db2f61ccc..a8bd75eee 100644 --- a/datajoint/admin.py +++ b/datajoint/admin.py @@ -5,19 +5,23 @@ from .utils import user_choice -def set_password(new_password=None, connection=None, update_config=None): # pragma: no cover +def set_password( + new_password=None, connection=None, update_config=None +): # pragma: no cover connection = conn() if connection is None else connection if new_password is None: - new_password = 
getpass('New password: ') - confirm_password = getpass('Confirm password: ') + new_password = getpass("New password: ") + confirm_password = getpass("Confirm password: ") if new_password != confirm_password: - print('Failed to confirm the password! Aborting password change.') + print("Failed to confirm the password! Aborting password change.") return connection.query("SET PASSWORD = PASSWORD('%s')" % new_password) - print('Password updated.') + print("Password updated.") - if update_config or (update_config is None and user_choice('Update local setting?') == 'yes'): - config['database.password'] = new_password + if update_config or ( + update_config is None and user_choice("Update local setting?") == "yes" + ): + config["database.password"] = new_password config.save_local(verbose=True) @@ -40,24 +44,32 @@ def kill(restriction=None, connection=None, order_by=None): # pragma: no cover connection = conn() if order_by is not None and not isinstance(order_by, str): - order_by = ','.join(order_by) + order_by = ",".join(order_by) - query = 'SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()' + ( - "" if restriction is None else ' AND (%s)' % restriction) + ( - ' ORDER BY %s' % (order_by or 'id')) + query = ( + "SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()" + + ("" if restriction is None else " AND (%s)" % restriction) + + (" ORDER BY %s" % (order_by or "id")) + ) while True: - print(' ID USER HOST STATE TIME INFO') - print('+--+ +----------+ +-----------+ +-----------+ +-----+') - cur = ({k.lower(): v for k, v in elem.items()} - for elem in connection.query(query, as_dict=True)) + print(" ID USER HOST STATE TIME INFO") + print("+--+ +----------+ +-----------+ +-----------+ +-----+") + cur = ( + {k.lower(): v for k, v in elem.items()} + for elem in connection.query(query, as_dict=True) + ) for process in cur: try: - print('{id:>4d} {user:<12s} {host:<12s} {state:<12s} {time:>7d} {info}'.format(**process)) + print( + "{id:>4d} {user:<12s} {host:<12s} {state:<12s} {time:>7d} {info}".format( + **process + ) + ) except TypeError: print(process) response = input('process to kill or "q" to quit > ') - if response == 'q': + if response == "q": break if response: try: @@ -66,9 +78,9 @@ def kill(restriction=None, connection=None, order_by=None): # pragma: no cover pass # ignore non-numeric input else: try: - connection.query('kill %d' % pid) + connection.query("kill %d" % pid) except pymysql.err.InternalError: - print('Process not found') + print("Process not found") def kill_quick(restriction=None, connection=None): @@ -86,13 +98,17 @@ def kill_quick(restriction=None, connection=None): if connection is None: connection = conn() - query = 'SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()' + ( - "" if restriction is None else ' AND (%s)' % restriction) + query = ( + "SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()" + + ("" if restriction is None else " AND (%s)" % restriction) + ) - cur = ({k.lower(): v for k, v in elem.items()} - for elem in connection.query(query, as_dict=True)) + cur = ( + {k.lower(): v for k, v in elem.items()} + for elem in connection.query(query, as_dict=True) + ) nkill = 0 for process in cur: - connection.query('kill %d' % process['id']) + connection.query("kill %d" % process["id"]) nkill += 1 return nkill diff --git a/datajoint/attribute_adapter.py b/datajoint/attribute_adapter.py index dc1c45706..2917791f1 100644 --- a/datajoint/attribute_adapter.py +++ 
b/datajoint/attribute_adapter.py @@ -7,12 +7,13 @@ class AttributeAdapter: """ Base class for adapter objects for user-defined attribute types. """ + @property def attribute_type(self): """ :return: a supported DataJoint attribute type to use; e.g. "longblob", "blob@store" """ - raise NotImplementedError('Undefined attribute adapter') + raise NotImplementedError("Undefined attribute adapter") def get(self, value): """ @@ -20,7 +21,7 @@ def get(self, value): :param value: value from the database :return: object of the adapted type """ - raise NotImplementedError('Undefined attribute adapter') + raise NotImplementedError("Undefined attribute adapter") def put(self, obj): """ @@ -28,7 +29,7 @@ def put(self, obj): :param obj: an object of the adapted type :return: value to store in the database """ - raise NotImplementedError('Undefined attribute adapter') + raise NotImplementedError("Undefined attribute adapter") def get_adapter(context, adapter_name): @@ -36,19 +37,32 @@ def get_adapter(context, adapter_name): Extract the AttributeAdapter object by its name from the context and validate. """ if not _support_adapted_types(): - raise DataJointError('Support for Adapted Attribute types is disabled.') - adapter_name = adapter_name.lstrip('<').rstrip('>') + raise DataJointError("Support for Adapted Attribute types is disabled.") + adapter_name = adapter_name.lstrip("<").rstrip(">") try: - adapter = (context[adapter_name] if adapter_name in context - else type_plugins[adapter_name]['object'].load()) + adapter = ( + context[adapter_name] + if adapter_name in context + else type_plugins[adapter_name]["object"].load() + ) except KeyError: raise DataJointError( - "Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name)) + "Attribute adapter '{adapter_name}' is not defined.".format( + adapter_name=adapter_name + ) + ) if not isinstance(adapter, AttributeAdapter): raise DataJointError( "Attribute adapter '{adapter_name}' must be an instance of datajoint.AttributeAdapter".format( - adapter_name=adapter_name)) - if not isinstance(adapter.attribute_type, str) or not re.match(r'^\w', adapter.attribute_type): - raise DataJointError("Invalid attribute type {type} in attribute adapter '{adapter_name}'".format( - type=adapter.attribute_type, adapter_name=adapter_name)) + adapter_name=adapter_name + ) + ) + if not isinstance(adapter.attribute_type, str) or not re.match( + r"^\w", adapter.attribute_type + ): + raise DataJointError( + "Invalid attribute type {type} in attribute adapter '{adapter_name}'".format( + type=adapter.attribute_type, adapter_name=adapter_name + ) + ) return adapter diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index f062ef7bf..4a23b9a52 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -17,6 +17,7 @@ # --- helper functions for multiprocessing -- + def _initialize_populate(table, jobs, populate_kwargs): """ Initialize the process for mulitprocessing. @@ -45,6 +46,7 @@ class AutoPopulate: Auto-populated relations must inherit from both Relation and AutoPopulate, must define the property `key_source`, and must define the callback method `make`. """ + _key_source = None _allow_insert = False @@ -57,17 +59,29 @@ def key_source(self): Subclasses may override they key_source to change the scope or the granularity of the make calls. 
""" + def _rename_attributes(table, props): - return (table.proj( - **{attr: ref for attr, ref in props['attr_map'].items() if attr != ref}) - if props['aliased'] else table.proj()) + return ( + table.proj( + **{ + attr: ref + for attr, ref in props["attr_map"].items() + if attr != ref + } + ) + if props["aliased"] + else table.proj() + ) if self._key_source is None: parents = self.target.parents( - primary=True, as_objects=True, foreign_key_info=True) + primary=True, as_objects=True, foreign_key_info=True + ) if not parents: - raise DataJointError('A table must have dependencies ' - 'from its primary key for auto-populate to work') + raise DataJointError( + "A table must have dependencies " + "from its primary key for auto-populate to work" + ) self._key_source = _rename_attributes(*parents[0]) for q in parents[1:]: self._key_source *= _rename_attributes(*q) @@ -80,7 +94,8 @@ def make(self, key): computes secondary attributes, and inserts the new tuples into self. """ raise NotImplementedError( - 'Subclasses of AutoPopulate must implement the method `make`') + "Subclasses of AutoPopulate must implement the method `make`" + ) @property def target(self): @@ -104,8 +119,10 @@ def _jobs_to_do(self, restrictions): :return: the relation containing the keys to be computed (derived from self.key_source) """ if self.restriction: - raise DataJointError('Cannot call populate on a restricted table. ' - 'Instead, pass conditions to populate() as arguments.') + raise DataJointError( + "Cannot call populate on a restricted table. " + "Instead, pass conditions to populate() as arguments." + ) todo = self.key_source # key_source is a QueryExpression subclass -- trigger instantiation @@ -113,22 +130,36 @@ def _jobs_to_do(self, restrictions): todo = todo() if not isinstance(todo, QueryExpression): - raise DataJointError('Invalid key_source value') + raise DataJointError("Invalid key_source value") try: # check if target lacks any attributes from the primary key of key_source raise DataJointError( - 'The populate target lacks attribute %s ' - 'from the primary key of key_source' % next( - name for name in todo.heading.primary_key - if name not in self.target.heading)) + "The populate target lacks attribute %s " + "from the primary key of key_source" + % next( + name + for name in todo.heading.primary_key + if name not in self.target.heading + ) + ) except StopIteration: pass return (todo & AndList(restrictions)).proj() - def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False, - reserve_jobs=False, order="original", limit=None, max_calls=None, - display_progress=False, processes=1, make_kwargs=None): + def populate( + self, + *restrictions, + suppress_errors=False, + return_exception_objects=False, + reserve_jobs=False, + order="original", + limit=None, + max_calls=None, + display_progress=False, + processes=1, + make_kwargs=None + ): """ ``table.populate()`` calls ``table.make(key)`` for every primary key in ``self.key_source`` for which there is not already a tuple in table. 
@@ -150,18 +181,24 @@ def populate(self, *restrictions, suppress_errors=False, return_exception_object :type make_kwargs: dict, optional """ if self.connection.in_transaction: - raise DataJointError('Populate cannot be called during a transaction.') + raise DataJointError("Populate cannot be called during a transaction.") - valid_order = ['original', 'reverse', 'random'] + valid_order = ["original", "reverse", "random"] if order not in valid_order: - raise DataJointError('The order argument must be one of %s' % str(valid_order)) - jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None + raise DataJointError( + "The order argument must be one of %s" % str(valid_order) + ) + jobs = ( + self.connection.schemas[self.target.database].jobs if reserve_jobs else None + ) # define and set up signal handler for SIGTERM: if reserve_jobs: + def handler(signum, frame): - logger.info('Populate terminated by SIGTERM') - raise SystemExit('SIGTERM received') + logger.info("Populate terminated by SIGTERM") + raise SystemExit("SIGTERM received") + old_handler = signal.signal(signal.SIGTERM, handler) keys = (self._jobs_to_do(restrictions) - self.target).fetch("KEY", limit=limit) @@ -170,7 +207,7 @@ def handler(signum, frame): elif order == "random": random.shuffle(keys) - logger.info('Found %d keys to populate' % len(keys)) + logger.info("Found %d keys to populate" % len(keys)) keys = keys[:max_calls] nkeys = len(keys) @@ -182,10 +219,13 @@ def handler(signum, frame): populate_kwargs = dict( suppress_errors=suppress_errors, return_exception_objects=return_exception_objects, - make_kwargs=make_kwargs) + make_kwargs=make_kwargs, + ) if processes == 1: - for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys: + for key in ( + tqdm(keys, desc=self.__class__.__name__) if display_progress else keys + ): error = self._populate1(key, jobs, **populate_kwargs) if error is not None: error_list.append(error) @@ -193,7 +233,9 @@ def handler(signum, frame): # spawn multiple processes self.connection.close() # disconnect parent process from MySQL server del self.connection._conn.ctx # SSLContext is not pickleable - with mp.Pool(processes, _initialize_populate, (self, jobs, populate_kwargs)) as pool: + with mp.Pool( + processes, _initialize_populate, (self, jobs, populate_kwargs) + ) as pool: if display_progress: with tqdm(desc="Processes: ", total=nkeys) as pbar: for error in pool.imap(_call_populate1, keys, chunksize=1): @@ -213,7 +255,9 @@ def handler(signum, frame): if suppress_errors: return error_list - def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None): + def _populate1( + self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None + ): """ populates table for one source key, calling self.make inside a transaction. 
:param jobs: the jobs table or None if not reserve_jobs @@ -222,7 +266,7 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_ :param return_exception_objects: if True, errors must be returned as objects :return: (key, error) when suppress_errors=True, otherwise None """ - make = self._make_tuples if hasattr(self, '_make_tuples') else self.make + make = self._make_tuples if hasattr(self, "_make_tuples") else self.make if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)): self.connection.start_transaction() @@ -231,7 +275,7 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_ if jobs is not None: jobs.complete(self.target.table_name, self._job_key(key)) else: - logger.info('Populating: ' + str(key)) + logger.info("Populating: " + str(key)) self.__class__._allow_insert = True try: make(dict(key), **(make_kwargs or {})) @@ -240,14 +284,18 @@ def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_ self.connection.cancel_transaction() except LostConnectionError: pass - error_message = '{exception}{msg}'.format( + error_message = "{exception}{msg}".format( exception=error.__class__.__name__, - msg=': ' + str(error) if str(error) else '') + msg=": " + str(error) if str(error) else "", + ) if jobs is not None: # show error name and error message (if any) jobs.error( - self.target.table_name, self._job_key(key), - error_message=error_message, error_stack=traceback.format_exc()) + self.target.table_name, + self._job_key(key), + error_message=error_message, + error_stack=traceback.format_exc(), + ) if not suppress_errors or isinstance(error, SystemExit): raise else: @@ -269,9 +317,17 @@ def progress(self, *restrictions, display=True): total = len(todo) remaining = len(todo - self.target) if display: - print('%-20s' % self.__class__.__name__, - 'Completed %d of %d (%2.1f%%) %s' % ( - total - remaining, total, 100 - 100 * remaining / (total+1e-12), - datetime.datetime.strftime(datetime.datetime.now(), - '%Y-%m-%d %H:%M:%S')), flush=True) + print( + "%-20s" % self.__class__.__name__, + "Completed %d of %d (%2.1f%%) %s" + % ( + total - remaining, + total, + 100 - 100 * remaining / (total + 1e-12), + datetime.datetime.strftime( + datetime.datetime.now(), "%Y-%m-%d %H:%M:%S" + ), + ), + flush=True, + ) return remaining, total diff --git a/datajoint/blob.py b/datajoint/blob.py index def068463..df51e4136 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -14,33 +14,34 @@ from .settings import config -mxClassID = dict(( - # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html - ('mxUNKNOWN_CLASS', None), - ('mxCELL_CLASS', None), - ('mxSTRUCT_CLASS', None), - ('mxLOGICAL_CLASS', np.dtype('bool')), - ('mxCHAR_CLASS', np.dtype('c')), - ('mxVOID_CLASS', np.dtype('O')), - ('mxDOUBLE_CLASS', np.dtype('float64')), - ('mxSINGLE_CLASS', np.dtype('float32')), - ('mxINT8_CLASS', np.dtype('int8')), - ('mxUINT8_CLASS', np.dtype('uint8')), - ('mxINT16_CLASS', np.dtype('int16')), - ('mxUINT16_CLASS', np.dtype('uint16')), - ('mxINT32_CLASS', np.dtype('int32')), - ('mxUINT32_CLASS', np.dtype('uint32')), - ('mxINT64_CLASS', np.dtype('int64')), - ('mxUINT64_CLASS', np.dtype('uint64')), - ('mxFUNCTION_CLASS', None))) +mxClassID = dict( + ( + # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html + ("mxUNKNOWN_CLASS", None), + ("mxCELL_CLASS", None), + ("mxSTRUCT_CLASS", None), + ("mxLOGICAL_CLASS", np.dtype("bool")), + ("mxCHAR_CLASS", np.dtype("c")), + ("mxVOID_CLASS", np.dtype("O")), + 
("mxDOUBLE_CLASS", np.dtype("float64")), + ("mxSINGLE_CLASS", np.dtype("float32")), + ("mxINT8_CLASS", np.dtype("int8")), + ("mxUINT8_CLASS", np.dtype("uint8")), + ("mxINT16_CLASS", np.dtype("int16")), + ("mxUINT16_CLASS", np.dtype("uint16")), + ("mxINT32_CLASS", np.dtype("int32")), + ("mxUINT32_CLASS", np.dtype("uint32")), + ("mxINT64_CLASS", np.dtype("int64")), + ("mxUINT64_CLASS", np.dtype("uint64")), + ("mxFUNCTION_CLASS", None), + ) +) rev_class_id = {dtype: i for i, dtype in enumerate(mxClassID.values())} dtype_list = list(mxClassID.values()) type_names = list(mxClassID) -compression = { - b'ZL123\0': zlib.decompress -} +compression = {b"ZL123\0": zlib.decompress} bypass_serialization = False # runtime setting to bypass blob (en|de)code @@ -58,12 +59,14 @@ def len_u32(obj): class MatCell(np.ndarray): - """ a numpy ndarray representing a Matlab cell array """ + """a numpy ndarray representing a Matlab cell array""" + pass class MatStruct(np.recarray): - """ numpy.recarray representing a Matlab struct array """ + """numpy.recarray representing a Matlab struct array""" + pass @@ -75,9 +78,11 @@ def __init__(self, squeeze=False): self.protocol = None def set_dj0(self): - if not config.get('enable_python_native_blobs'): - raise DataJointError("""v0.12+ python native blobs disabled. - See also: https://github.com/datajoint/datajoint-python#python-native-blobs""") + if not config.get("enable_python_native_blobs"): + raise DataJointError( + """v0.12+ python native blobs disabled. + See also: https://github.com/datajoint/datajoint-python#python-native-blobs""" + ) self.protocol = b"dj0\0" # when using new blob features @@ -95,52 +100,57 @@ def unpack(self, blob): self._blob = blob try: # decompress - prefix = next(p for p in compression if self._blob[self._pos:].startswith(p)) + prefix = next( + p for p in compression if self._blob[self._pos :].startswith(p) + ) except StopIteration: pass # assume uncompressed but could be unrecognized compression else: self._pos += len(prefix) blob_size = self.read_value() - blob = compression[prefix](self._blob[self._pos:]) + blob = compression[prefix](self._blob[self._pos :]) assert len(blob) == blob_size self._blob = blob self._pos = 0 blob_format = self.read_zero_terminated_string() - if blob_format in ('mYm', 'dj0'): + if blob_format in ("mYm", "dj0"): return self.read_blob(n_bytes=len(self._blob) - self._pos) def read_blob(self, n_bytes=None): start = self._pos - data_structure_code = chr(self.read_value('uint8')) + data_structure_code = chr(self.read_value("uint8")) try: call = { # MATLAB-compatible, inherited from original mYm - "A": self.read_array, # matlab-compatible numeric arrays and scalars with ndim==0 + "A": self.read_array, # matlab-compatible numeric arrays and scalars with ndim==0 "P": self.read_sparse_array, # matlab sparse array -- not supported yet - "S": self.read_struct, # matlab struct array - "C": self.read_cell_array, # matlab cell array + "S": self.read_struct, # matlab struct array + "C": self.read_cell_array, # matlab cell array # basic data types - "\xFF": self.read_none, # None - "\x01": self.read_tuple, # a Sequence (e.g. tuple) - "\x02": self.read_list, # a MutableSequence (e.g. list) - "\x03": self.read_set, # a Set - "\x04": self.read_dict, # a Mapping (e.g. 
dict) - "\x05": self.read_string, # a UTF8-encoded string - "\x06": self.read_bytes, # a ByteString - "\x0a": self.read_int, # unbounded scalar int - "\x0b": self.read_bool, # scalar boolean - "\x0c": self.read_complex, # scalar 128-bit complex number - "\x0d": self.read_float, # scalar 64-bit float - "F": self.read_recarray, # numpy array with fields, including recarrays - "d": self.read_decimal, # a decimal - "t": self.read_datetime, # date, time, or datetime - "u": self.read_uuid, # UUID + "\xFF": self.read_none, # None + "\x01": self.read_tuple, # a Sequence (e.g. tuple) + "\x02": self.read_list, # a MutableSequence (e.g. list) + "\x03": self.read_set, # a Set + "\x04": self.read_dict, # a Mapping (e.g. dict) + "\x05": self.read_string, # a UTF8-encoded string + "\x06": self.read_bytes, # a ByteString + "\x0a": self.read_int, # unbounded scalar int + "\x0b": self.read_bool, # scalar boolean + "\x0c": self.read_complex, # scalar 128-bit complex number + "\x0d": self.read_float, # scalar 64-bit float + "F": self.read_recarray, # numpy array with fields, including recarrays + "d": self.read_decimal, # a decimal + "t": self.read_datetime, # date, time, or datetime + "u": self.read_uuid, # UUID }[data_structure_code] except KeyError: - raise DataJointError('Unknown data structure code "%s". Upgrade datajoint.' % data_structure_code) + raise DataJointError( + 'Unknown data structure code "%s". Upgrade datajoint.' + % data_structure_code + ) v = call() if n_bytes is not None and self._pos - start != n_bytes: - raise DataJointError('Blob length check failed! Invalid blob') + raise DataJointError("Blob length check failed! Invalid blob") return v def pack_blob(self, obj): @@ -192,25 +202,33 @@ def pack_blob(self, obj): return self.pack_set(obj) if obj is None: return self.pack_none() - raise DataJointError("Packing object of type %s currently not supported!" % type(obj)) + raise DataJointError( + "Packing object of type %s currently not supported!" % type(obj) + ) def read_array(self): n_dims = int(self.read_value()) shape = self.read_value(count=n_dims) n_elem = np.prod(shape, dtype=int) - dtype_id, is_complex = self.read_value('uint32', 2) + dtype_id, is_complex = self.read_value("uint32", 2) dtype = dtype_list[dtype_id] - if type_names[dtype_id] == 'mxVOID_CLASS': + if type_names[dtype_id] == "mxVOID_CLASS": data = np.array( - list(self.read_blob(self.read_value()) for _ in range(n_elem)), dtype=np.dtype('O')) - elif type_names[dtype_id] == 'mxCHAR_CLASS': + list(self.read_blob(self.read_value()) for _ in range(n_elem)), + dtype=np.dtype("O"), + ) + elif type_names[dtype_id] == "mxCHAR_CLASS": # compensate for MATLAB packing of char arrays data = self.read_value(dtype, count=2 * n_elem) - data = data[::2].astype('U1') + data = data[::2].astype("U1") if n_dims == 2 and shape[0] == 1 or n_dims == 1: compact = data.squeeze() - data = compact if compact.shape == () else np.array(''.join(data.squeeze())) + data = ( + compact + if compact.shape == () + else np.array("".join(data.squeeze())) + ) shape = (1,) else: data = self.read_value(dtype, count=n_elem) @@ -222,21 +240,33 @@ def pack_array(self, array): """ Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0. 
""" - blob = b"A" + np.uint64(array.ndim).tobytes() + np.array(array.shape, dtype=np.uint64).tobytes() + blob = ( + b"A" + + np.uint64(array.ndim).tobytes() + + np.array(array.shape, dtype=np.uint64).tobytes() + ) is_complex = np.iscomplexobj(array) if is_complex: array, imaginary = np.real(array), np.imag(array) - type_id = (rev_class_id[array.dtype] if array.dtype.char != 'U' - else rev_class_id[np.dtype('O')]) + type_id = ( + rev_class_id[array.dtype] + if array.dtype.char != "U" + else rev_class_id[np.dtype("O")] + ) if dtype_list[type_id] is None: raise DataJointError("Type %s is ambiguous or unknown" % array.dtype) blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes() - if type_names[type_id] == 'mxVOID_CLASS': # array of dtype('O') - blob += b"".join(len_u64(it) + it for it in (self.pack_blob(e) for e in array.flatten(order="F"))) + if type_names[type_id] == "mxVOID_CLASS": # array of dtype('O') + blob += b"".join( + len_u64(it) + it + for it in (self.pack_blob(e) for e in array.flatten(order="F")) + ) self.set_dj0() # not supported by original mym - elif type_names[type_id] == 'mxCHAR_CLASS': # array of dtype('c') - blob += array.view(np.uint8).astype(np.uint16).tobytes() # convert to 16-bit chars for MATLAB + elif type_names[type_id] == "mxCHAR_CLASS": # array of dtype('c') + blob += ( + array.view(np.uint8).astype(np.uint16).tobytes() + ) # convert to 16-bit chars for MATLAB else: # numeric arrays if array.ndim == 0: # not supported by original mym self.set_dj0() @@ -249,55 +279,74 @@ def read_recarray(self): """ Serialize an np.ndarray with fields, including recarrays """ - n_fields = self.read_value('uint32') + n_fields = self.read_value("uint32") if not n_fields: return np.array(None) # empty array field_names = [self.read_zero_terminated_string() for _ in range(n_fields)] arrays = [self.read_blob() for _ in range(n_fields)] - rec = np.empty(arrays[0].shape, np.dtype([(f, t.dtype) for f, t in zip(field_names, arrays)])) + rec = np.empty( + arrays[0].shape, + np.dtype([(f, t.dtype) for f, t in zip(field_names, arrays)]), + ) for f, t in zip(field_names, arrays): rec[f] = t return rec.view(np.recarray) def pack_recarray(self, array): - """ Serialize a Matlab struct array """ - return (b"F" + len_u32(array.dtype) + # number of fields - '\0'.join(array.dtype.names).encode() + b"\0" + # field names - b"".join(self.pack_recarray(array[f]) if array[f].dtype.fields else self.pack_array(array[f]) - for f in array.dtype.names)) + """Serialize a Matlab struct array""" + return ( + b"F" + + len_u32(array.dtype) + + "\0".join(array.dtype.names).encode() # number of fields + + b"\0" + + b"".join( # field names + self.pack_recarray(array[f]) + if array[f].dtype.fields + else self.pack_array(array[f]) + for f in array.dtype.names + ) + ) def read_sparse_array(self): - raise DataJointError('datajoint-python does not yet support sparse arrays. Issue (#590)') + raise DataJointError( + "datajoint-python does not yet support sparse arrays. 
Issue (#590)" + ) def read_int(self): - return int.from_bytes(self.read_binary(self.read_value('uint16')), byteorder='little', signed=True) + return int.from_bytes( + self.read_binary(self.read_value("uint16")), byteorder="little", signed=True + ) @staticmethod def pack_int(v): n_bytes = v.bit_length() // 8 + 1 - assert 0 < n_bytes <= 0xFFFF, 'Integers are limited to 65535 bytes' - return b"\x0a" + np.uint16(n_bytes).tobytes() + v.to_bytes(n_bytes, byteorder='little', signed=True) + assert 0 < n_bytes <= 0xFFFF, "Integers are limited to 65535 bytes" + return ( + b"\x0a" + + np.uint16(n_bytes).tobytes() + + v.to_bytes(n_bytes, byteorder="little", signed=True) + ) def read_bool(self): - return bool(self.read_value('bool')) + return bool(self.read_value("bool")) @staticmethod def pack_bool(v): - return b"\x0b" + np.array(v, dtype='bool').tobytes() + return b"\x0b" + np.array(v, dtype="bool").tobytes() def read_complex(self): - return complex(self.read_value('complex128')) + return complex(self.read_value("complex128")) @staticmethod def pack_complex(v): - return b"\x0c" + np.array(v, dtype='complex128').tobytes() + return b"\x0c" + np.array(v, dtype="complex128").tobytes() def read_float(self): - return float(self.read_value('float64')) + return float(self.read_value("float64")) @staticmethod def pack_float(v): - return b"\x0d" + np.array(v, dtype='float64').tobytes() + return b"\x0d" + np.array(v, dtype="float64").tobytes() def read_decimal(self): return Decimal(self.read_string()) @@ -330,82 +379,129 @@ def pack_none(): return b"\xFF" def read_tuple(self): - return tuple(self.read_blob(self.read_value()) for _ in range(self.read_value())) + return tuple( + self.read_blob(self.read_value()) for _ in range(self.read_value()) + ) def pack_tuple(self, t): - return b"\1" + len_u64(t) + b"".join( - len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + return ( + b"\1" + + len_u64(t) + + b"".join(len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + ) def read_list(self): return list(self.read_blob(self.read_value()) for _ in range(self.read_value())) def pack_list(self, t): - return b"\2" + len_u64(t) + b"".join( - len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + return ( + b"\2" + + len_u64(t) + + b"".join(len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + ) def read_set(self): return set(self.read_blob(self.read_value()) for _ in range(self.read_value())) def pack_set(self, t): - return b"\3" + len_u64(t) + b"".join( - len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + return ( + b"\3" + + len_u64(t) + + b"".join(len_u64(it) + it for it in (self.pack_blob(i) for i in t)) + ) def read_dict(self): - return dict((self.read_blob(self.read_value()), self.read_blob(self.read_value())) - for _ in range(self.read_value())) + return dict( + (self.read_blob(self.read_value()), self.read_blob(self.read_value())) + for _ in range(self.read_value()) + ) def pack_dict(self, d): - return b"\4" + len_u64(d) + b"".join( - b"".join((len_u64(it) + it) for it in packed) - for packed in (map(self.pack_blob, pair) for pair in d.items())) + return ( + b"\4" + + len_u64(d) + + b"".join( + b"".join((len_u64(it) + it) for it in packed) + for packed in (map(self.pack_blob, pair) for pair in d.items()) + ) + ) def read_struct(self): """deserialize matlab stuct""" n_dims = self.read_value() shape = self.read_value(count=n_dims) n_elem = np.prod(shape, dtype=int) - n_fields = self.read_value('uint32') + n_fields = self.read_value("uint32") if not n_fields: return np.array(None) 
# empty array field_names = [self.read_zero_terminated_string() for _ in range(n_fields)] raw_data = [ - tuple(self.read_blob(n_bytes=int(self.read_value())) for _ in range(n_fields)) - for __ in range(n_elem)] + tuple( + self.read_blob(n_bytes=int(self.read_value())) for _ in range(n_fields) + ) + for __ in range(n_elem) + ] data = np.array(raw_data, dtype=list(zip(field_names, repeat(object)))) - return self.squeeze(data.reshape(shape, order="F"), convert_to_scalar=False).view(MatStruct) + return self.squeeze( + data.reshape(shape, order="F"), convert_to_scalar=False + ).view(MatStruct) def pack_struct(self, array): - """ Serialize a Matlab struct array """ - return (b"S" + np.array((array.ndim,) + array.shape, dtype=np.uint64).tobytes() + # dimensionality - len_u32(array.dtype.names) + # number of fields - "\0".join(array.dtype.names).encode() + b"\0" + # field names - b"".join(len_u64(it) + it for it in ( - self.pack_blob(e) for rec in array.flatten(order="F") for e in rec))) # values + """Serialize a Matlab struct array""" + return ( + b"S" + + np.array((array.ndim,) + array.shape, dtype=np.uint64).tobytes() + + len_u32(array.dtype.names) # dimensionality + + "\0".join(array.dtype.names).encode() # number of fields + + b"\0" + + b"".join( # field names + len_u64(it) + it + for it in ( + self.pack_blob(e) for rec in array.flatten(order="F") for e in rec + ) + ) + ) # values def read_cell_array(self): - """ deserialize MATLAB cell array """ + """deserialize MATLAB cell array""" n_dims = self.read_value() shape = self.read_value(count=n_dims) n_elem = int(np.prod(shape)) result = [self.read_blob(n_bytes=self.read_value()) for _ in range(n_elem)] - return (self.squeeze(np.array(result).reshape(shape, order="F"), convert_to_scalar=False)).view(MatCell) + return ( + self.squeeze( + np.array(result).reshape(shape, order="F"), convert_to_scalar=False + ) + ).view(MatCell) def pack_cell_array(self, array): - return (b"C" + np.array((array.ndim,) + array.shape, dtype=np.uint64).tobytes() + - b"".join(len_u64(it) + it for it in (self.pack_blob(e) for e in array.flatten(order="F")))) + return ( + b"C" + + np.array((array.ndim,) + array.shape, dtype=np.uint64).tobytes() + + b"".join( + len_u64(it) + it + for it in (self.pack_blob(e) for e in array.flatten(order="F")) + ) + ) def read_datetime(self): - """ deserialize datetime.date, .time, or .datetime """ - date, time = self.read_value('int32'), self.read_value('int64') - date = datetime.date( - year=date // 10000, - month=(date // 100) % 100, - day=date % 100) if date >= 0 else None - time = datetime.time( - hour=(time // 10000000000) % 100, - minute=(time // 100000000) % 100, - second=(time // 1000000) % 100, - microsecond=time % 1000000) if time >= 0 else None + """deserialize datetime.date, .time, or .datetime""" + date, time = self.read_value("int32"), self.read_value("int64") + date = ( + datetime.date(year=date // 10000, month=(date // 100) % 100, day=date % 100) + if date >= 0 + else None + ) + time = ( + datetime.time( + hour=(time // 10000000000) % 100, + minute=(time // 100000000) % 100, + second=(time // 1000000) % 100, + microsecond=time % 1000000, + ) + if time >= 0 + else None + ) return time and date and datetime.datetime.combine(date, time) or time or date @staticmethod @@ -417,9 +513,16 @@ def pack_datetime(d): else: date, time = None, d return b"t" + ( - np.int32(-1 if date is None else (date.year*100 + date.month)*100 + date.day).tobytes() + - np.int64(-1 if time is None else - ((time.hour*100 + time.minute)*100 + 
time.second)*1000000 + time.microsecond).tobytes()) + np.int32( + -1 if date is None else (date.year * 100 + date.month) * 100 + date.day + ).tobytes() + + np.int64( + -1 + if time is None + else ((time.hour * 100 + time.minute) * 100 + time.second) * 1000000 + + time.microsecond + ).tobytes() + ) def read_uuid(self): q = self.read_binary(16) @@ -430,28 +533,30 @@ def pack_uuid(obj): return b"u" + obj.bytes def read_zero_terminated_string(self): - target = self._blob.find(b'\0', self._pos) - data = self._blob[self._pos:target].decode() + target = self._blob.find(b"\0", self._pos) + data = self._blob[self._pos : target].decode() self._pos = target + 1 return data def read_value(self, dtype=None, count=1): if dtype is None: - dtype = 'uint32' if use_32bit_dims else 'uint64' + dtype = "uint32" if use_32bit_dims else "uint64" data = np.frombuffer(self._blob, dtype=dtype, count=count, offset=self._pos) self._pos += data.dtype.itemsize * data.size return data[0] if count == 1 else data def read_binary(self, size): self._pos += int(size) - return self._blob[self._pos-int(size):self._pos] + return self._blob[self._pos - int(size) : self._pos] def pack(self, obj, compress): self.protocol = b"mYm\0" # will be replaced with dj0 if new features are used - blob = self.pack_blob(obj) # this may reset the protocol and must precede protocol evaluation + blob = self.pack_blob( + obj + ) # this may reset the protocol and must precede protocol evaluation blob = self.protocol + blob if compress and len(blob) > 1000: - compressed = b'ZL123\0' + len_u64(blob) + zlib.compress(blob) + compressed = b"ZL123\0" + len_u64(blob) + zlib.compress(blob) if len(compressed) < len(blob): blob = compressed return blob @@ -460,7 +565,9 @@ def pack(self, obj, compress): def pack(obj, compress=True): if bypass_serialization: # provide a way to move blobs quickly without de/serialization - assert isinstance(obj, bytes) and obj.startswith((b'ZL123\0', b'mYm\0', b'dj0\0')) + assert isinstance(obj, bytes) and obj.startswith( + (b"ZL123\0", b"mYm\0", b"dj0\0") + ) return obj return Blob().pack(obj, compress=compress) @@ -468,7 +575,9 @@ def pack(obj, compress=True): def unpack(blob, squeeze=False): if bypass_serialization: # provide a way to move blobs quickly without de/serialization - assert isinstance(blob, bytes) and blob.startswith((b'ZL123\0', b'mYm\0', b'dj0\0')) + assert isinstance(blob, bytes) and blob.startswith( + (b"ZL123\0", b"mYm\0", b"dj0\0") + ) return blob if blob is not None: return Blob(squeeze=squeeze).unpack(blob) diff --git a/datajoint/condition.py b/datajoint/condition.py index fed138cf1..397f68b53 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -15,6 +15,7 @@ class PromiscuousOperand: """ A container for an operand to ignore join compatibility """ + def __init__(self, operand): self.operand = operand @@ -30,6 +31,7 @@ class AndList(list): is equivalent to expr2 = expr & cond1 & cond2 & cond3 """ + def append(self, restriction): if isinstance(restriction, AndList): # extend to reduce nesting @@ -39,7 +41,8 @@ def append(self, restriction): class Not: - """ invert restriction """ + """invert restriction""" + def __init__(self, restriction): self.restriction = restriction @@ -59,13 +62,21 @@ def assert_join_compatibility(expr1, expr2): for rel in (expr1, expr2): if not isinstance(rel, (U, QueryExpression)): raise DataJointError( - 'Object %r is not a QueryExpression and cannot be joined.' 
% rel) - if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible + "Object %r is not a QueryExpression and cannot be joined." % rel + ) + if not isinstance(expr1, U) and not isinstance( + expr2, U + ): # dj.U is always compatible try: raise DataJointError( - "Cannot join query expressions on dependent attribute `%s`" % next( - r for r in set(expr1.heading.secondary_attributes).intersection( - expr2.heading.secondary_attributes))) + "Cannot join query expressions on dependent attribute `%s`" + % next( + r + for r in set(expr1.heading.secondary_attributes).intersection( + expr2.heading.secondary_attributes + ) + ) + ) except StopIteration: pass # all ok @@ -90,13 +101,16 @@ def prep_value(k, v): v = uuid.UUID(v) except (AttributeError, ValueError): raise DataJointError( - 'Badly formed UUID {v} in restriction by `{k}`'.format(k=k, v=v)) + "Badly formed UUID {v} in restriction by `{k}`".format(k=k, v=v) + ) return "X'%s'" % v.bytes.hex() - if isinstance(v, (datetime.date, datetime.datetime, datetime.time, decimal.Decimal)): + if isinstance( + v, (datetime.date, datetime.datetime, datetime.time, decimal.Decimal) + ): return '"%s"' % v if isinstance(v, str): - return '"%s"' % v.replace('%', '%%') - return '%r' % v + return '"%s"' % v.replace("%", "%%") + return "%r" % v negate = False while isinstance(condition, Not): @@ -107,19 +121,25 @@ def prep_value(k, v): # restrict by string if isinstance(condition, str): columns.update(extract_column_names(condition)) - return template % condition.strip().replace("%", "%%") # escape %, see issue #376 + return template % condition.strip().replace( + "%", "%%" + ) # escape %, see issue #376 # restrict by AndList if isinstance(condition, AndList): # omit all conditions that evaluate to True - items = [item for item in (make_condition(query_expression, cond, columns) - for cond in condition) - if item is not True] + items = [ + item + for item in ( + make_condition(query_expression, cond, columns) for cond in condition + ) + if item is not True + ] if any(item is False for item in items): return negate # if any item is False, the whole thing is False if not items: return not negate # and empty AndList is True - return template % ('(' + ') AND ('.join(items) + ')') + return template % ("(" + ") AND (".join(items) + ")") # restriction by dj.U evaluates to True if isinstance(condition, U): @@ -135,20 +155,36 @@ def prep_value(k, v): if not common_attributes: return not negate # no matching attributes -> evaluates to True columns.update(common_attributes) - return template % ('(' + ') AND ('.join( - '`%s`%s' % (k, ' IS NULL' if condition[k] is None - else f'={prep_value(k, condition[k])}') - for k in common_attributes) + ')') + return template % ( + "(" + + ") AND (".join( + "`%s`%s" + % ( + k, + " IS NULL" + if condition[k] is None + else f"={prep_value(k, condition[k])}", + ) + for k in common_attributes + ) + + ")" + ) # restrict by a numpy record -- convert to an AndList of string equality conditions if isinstance(condition, numpy.void): common_attributes = set(condition.dtype.fields).intersection( - query_expression.heading.names) + query_expression.heading.names + ) if not common_attributes: - return not negate # no matching attributes -> evaluate to True + return not negate # no matching attributes -> evaluate to True columns.update(common_attributes) - return template % ('(' + ') AND ('.join( - '`%s`=%s' % (k, prep_value(k, condition[k])) for k in common_attributes) + ')') + return template % ( + "(" + + ") AND (".join( + 
"`%s`=%s" % (k, prep_value(k, condition[k])) for k in common_attributes + ) + + ")" + ) # restrict by a QueryExpression subclass -- trigger instantiation and move on if inspect.isclass(condition) and issubclass(condition, QueryExpression): @@ -163,18 +199,22 @@ def prep_value(k, v): if isinstance(condition, QueryExpression): if check_compatibility: assert_join_compatibility(query_expression, condition) - common_attributes = [q for q in condition.heading.names - if q in query_expression.heading.names] + common_attributes = [ + q for q in condition.heading.names if q in query_expression.heading.names + ] columns.update(common_attributes) if isinstance(condition, Aggregation): condition = condition.make_subquery() return ( # without common attributes, any non-empty set matches everything - (not negate if condition else negate) if not common_attributes - else '({fields}) {not_}in ({subquery})'.format( - fields='`' + '`,`'.join(common_attributes) + '`', + (not negate if condition else negate) + if not common_attributes + else "({fields}) {not_}in ({subquery})".format( + fields="`" + "`,`".join(common_attributes) + "`", not_="not " if negate else "", - subquery=condition.make_sql(common_attributes))) + subquery=condition.make_sql(common_attributes), + ) + ) # restrict by pandas.DataFrames if isinstance(condition, pandas.DataFrame): @@ -184,12 +224,14 @@ def prep_value(k, v): try: or_list = [make_condition(query_expression, q, columns) for q in condition] except TypeError: - raise DataJointError('Invalid restriction type %r' % condition) + raise DataJointError("Invalid restriction type %r" % condition) else: - or_list = [item for item in or_list if item is not False] # ignore False conditions + or_list = [ + item for item in or_list if item is not False + ] # ignore False conditions if any(item is True for item in or_list): # if any item is True, entirely True return not negate - return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate + return template % ("(%s)" % " OR ".join(or_list)) if or_list else negate def extract_column_names(sql_expression): @@ -205,21 +247,38 @@ def extract_column_names(sql_expression): result = set() s = sql_expression # for terseness # remove escaped quotes - s = re.sub(r'(\\\")|(\\\')', '', s) + s = re.sub(r"(\\\")|(\\\')", "", s) # remove quoted text s = re.sub(r"'[^']*'", "", s) - s = re.sub(r'"[^"]*"', '', s) + s = re.sub(r'"[^"]*"', "", s) # find all tokens in back quotes and remove them result.update(re.findall(r"`([a-z][a-z_0-9]*)`", s)) - s = re.sub(r"`[a-z][a-z_0-9]*`", '', s) + s = re.sub(r"`[a-z][a-z_0-9]*`", "", s) # remove space before parentheses s = re.sub(r"\s*\(", "(", s) # remove tokens followed by ( since they must be functions s = re.sub(r"(\b[a-z][a-z_0-9]*)\(", "(", s) remaining_tokens = set(re.findall(r"\b[a-z][a-z_0-9]*\b", s)) # update result removing reserved words - result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null", - "not", "interval", "second", "minute", "hour", "day", - "month", "week", "year" - }) + result.update( + remaining_tokens + - { + "is", + "in", + "between", + "like", + "and", + "or", + "null", + "not", + "interval", + "second", + "minute", + "hour", + "day", + "month", + "week", + "year", + } + ) return result diff --git a/datajoint/connection.py b/datajoint/connection.py index 3daac4bac..76fbee386 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -22,25 +22,29 @@ def get_host_hook(host_input): - if '://' in host_input: - plugin_name = 
host_input.split('://')[0] + if "://" in host_input: + plugin_name = host_input.split("://")[0] try: - return connection_plugins[plugin_name]['object'].load().get_host(host_input) + return connection_plugins[plugin_name]["object"].load().get_host(host_input) except KeyError: raise errors.DataJointError( - "Connection plugin '{}' not found.".format(plugin_name)) + "Connection plugin '{}' not found.".format(plugin_name) + ) else: return host_input def connect_host_hook(connection_obj): - if '://' in connection_obj.conn_info['host_input']: - plugin_name = connection_obj.conn_info['host_input'].split('://')[0] + if "://" in connection_obj.conn_info["host_input"]: + plugin_name = connection_obj.conn_info["host_input"].split("://")[0] try: - connection_plugins[plugin_name]['object'].load().connect_host(connection_obj) + connection_plugins[plugin_name]["object"].load().connect_host( + connection_obj + ) except KeyError: raise errors.DataJointError( - "Connection plugin '{}' not found.".format(plugin_name)) + "Connection plugin '{}' not found.".format(plugin_name) + ) else: connection_obj.connect() @@ -52,20 +56,22 @@ def translate_query_error(client_error, query): :param query: sql query with placeholders :return: an instance of the corresponding subclass of datajoint.errors.DataJointError """ - logger.debug('type: {}, args: {}'.format(type(client_error), client_error.args)) + logger.debug("type: {}, args: {}".format(type(client_error), client_error.args)) err, *args = client_error.args # Loss of connection errors if err in (0, "(0, '')"): - return errors.LostConnectionError('Server connection lost due to an interface error.', *args) + return errors.LostConnectionError( + "Server connection lost due to an interface error.", *args + ) if err == 2006: return errors.LostConnectionError("Connection timed out", *args) if err == 2013: return errors.LostConnectionError("Server connection lost", *args) # Access errors if err in (1044, 1142): - return errors.AccessError('Insufficient privileges.', args[0], query) + return errors.AccessError("Insufficient privileges.", args[0], query) # Integrity errors if err == 1062: return errors.DuplicateError(*args) @@ -87,7 +93,9 @@ def translate_query_error(client_error, query): return client_error -def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None): +def conn( + host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None +): """ Returns a persistent connection object to be shared by multiple modules. If the connection is not yet established or reset=True, a new connection is set up. @@ -106,22 +114,25 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use https://dev.mysql.com/doc/refman/5.7/en/connection-options.html #encrypted-connection-options). 
""" - if not hasattr(conn, 'connection') or reset: - host = host if host is not None else config['database.host'] - user = user if user is not None else config['database.user'] - password = password if password is not None else config['database.password'] + if not hasattr(conn, "connection") or reset: + host = host if host is not None else config["database.host"] + user = user if user is not None else config["database.user"] + password = password if password is not None else config["database.password"] if user is None: # pragma: no cover user = input("Please enter DataJoint username: ") if password is None: # pragma: no cover password = getpass(prompt="Please enter DataJoint password: ") - init_fun = init_fun if init_fun is not None else config['connection.init_function'] - use_tls = use_tls if use_tls is not None else config['database.use_tls'] + init_fun = ( + init_fun if init_fun is not None else config["connection.init_function"] + ) + use_tls = use_tls if use_tls is not None else config["database.use_tls"] conn.connection = Connection(host, user, password, None, init_fun, use_tls) return conn.connection class EmulatedCursor: """acts like a cursor""" + def __init__(self, data): self._data = data self._iter = iter(self._data) @@ -160,17 +171,19 @@ class Connection: def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None): host_input, host = (host, get_host_hook(host)) - if ':' in host: + if ":" in host: # the port in the hostname overrides the port argument - host, port = host.split(':') + host, port = host.split(":") port = int(port) elif port is None: - port = config['database.port'] + port = config["database.port"] self.conn_info = dict(host=host, port=port, user=user, passwd=password) if use_tls is not False: - self.conn_info['ssl'] = use_tls if isinstance(use_tls, dict) else {'ssl': {}} - self.conn_info['ssl_input'] = use_tls - self.conn_info['host_input'] = host_input + self.conn_info["ssl"] = ( + use_tls if isinstance(use_tls, dict) else {"ssl": {}} + ) + self.conn_info["ssl_input"] = use_tls + self.conn_info["host_input"] = host_input self.init_fun = init_fun print("Connecting {user}@{host}:{port}".format(**self.conn_info)) self._conn = None @@ -178,9 +191,9 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None) connect_host_hook(self) if self.is_connected: logger.info("Connected {user}@{host}:{port}".format(**self.conn_info)) - self.connection_id = self.query('SELECT connection_id()').fetchone()[0] + self.connection_id = self.query("SELECT connection_id()").fetchone()[0] else: - raise errors.LostConnectionError('Connection failed.') + raise errors.LostConnectionError("Connection failed.") self._in_transaction = False self.schemas = dict() self.dependencies = Dependencies(self) @@ -191,29 +204,41 @@ def __eq__(self, other): def __repr__(self): connected = "connected" if self.is_connected else "disconnected" return "DataJoint connection ({connected}) {user}@{host}:{port}".format( - connected=connected, **self.conn_info) + connected=connected, **self.conn_info + ) def connect(self): - """ Connect to the database server.""" + """Connect to the database server.""" with warnings.catch_warnings(): - warnings.filterwarnings('ignore', '.*deprecated.*') + warnings.filterwarnings("ignore", ".*deprecated.*") try: self._conn = client.connect( init_command=self.init_fun, sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", - 
charset=config['connection.charset'], - **{k: v for k, v in self.conn_info.items() - if k not in ['ssl_input', 'host_input']}) + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", + charset=config["connection.charset"], + **{ + k: v + for k, v in self.conn_info.items() + if k not in ["ssl_input", "host_input"] + } + ) except client.err.InternalError: self._conn = client.connect( init_command=self.init_fun, sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", - charset=config['connection.charset'], - **{k: v for k, v in self.conn_info.items() - if not(k in ['ssl_input', 'host_input'] or - k == 'ssl' and self.conn_info['ssl_input'] is None)}) + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", + charset=config["connection.charset"], + **{ + k: v + for k, v in self.conn_info.items() + if not ( + k in ["ssl_input", "host_input"] + or k == "ssl" + and self.conn_info["ssl_input"] is None + ) + } + ) self._conn.autocommit(True) def set_query_cache(self, query_cache=None): @@ -227,10 +252,13 @@ def set_query_cache(self, query_cache=None): self._query_cache = query_cache def purge_query_cache(self): - """ Purges all query cache. """ - if 'query_cache' in config and isinstance(config['query_cache'], str) and \ - pathlib.Path(config['query_cache']).is_dir(): - path_iter = pathlib.Path(config['query_cache']).glob('**/*') + """Purges all query cache.""" + if ( + "query_cache" in config + and isinstance(config["query_cache"], str) + and pathlib.Path(config["query_cache"]).is_dir() + ): + path_iter = pathlib.Path(config["query_cache"]).glob("**/*") for path in path_iter: path.unlink() @@ -242,12 +270,12 @@ def register(self, schema): self.dependencies.clear() def ping(self): - """ Ping the connection or raises an exception if the connection is closed. """ + """Ping the connection or raises an exception if the connection is closed.""" self._conn.ping(reconnect=False) @property def is_connected(self): - """ Return true if the object is connected to the database server. """ + """Return true if the object is connected to the database server.""" try: self.ping() except: @@ -265,7 +293,9 @@ def _execute_query(cursor, query, args, suppress_warnings): except client.err.Error as err: raise translate_query_error(err, query) - def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None): + def query( + self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None + ): """ Execute the specified query and return the tuple generator (cursor). :param query: SQL query @@ -278,21 +308,28 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn # check cache first: use_query_cache = bool(self._query_cache) if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query): - raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.") + raise errors.DataJointError( + "Only SELECT queries are allowed when query caching is on." + ) if use_query_cache: - if not config['query_cache']: - raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.") - hash_ = uuid_from_buffer((str(self._query_cache) + re.sub(r'`\$\w+`', '', query)).encode() + pack(args)) - cache_path = pathlib.Path(config['query_cache']) / str(hash_) + if not config["query_cache"]: + raise errors.DataJointError( + "Provide filepath dj.config['query_cache'] when using query caching." 
+ ) + hash_ = uuid_from_buffer( + (str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode() + + pack(args) + ) + cache_path = pathlib.Path(config["query_cache"]) / str(hash_) try: buffer = cache_path.read_bytes() except FileNotFoundError: - pass # proceed to query the database + pass # proceed to query the database else: return EmulatedCursor(unpack(buffer)) if reconnect is None: - reconnect = config['database.reconnect'] + reconnect = config["database.reconnect"] logger.debug("Executing SQL:" + query[:query_log_max_length]) cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor cursor = self._conn.cursor(cursor=cursor_class) @@ -305,7 +342,9 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn connect_host_hook(self) if self._in_transaction: self.cancel_transaction() - raise errors.LostConnectionError("Connection was lost during a transaction.") + raise errors.LostConnectionError( + "Connection was lost during a transaction." + ) logger.debug("Re-executing") cursor = self._conn.cursor(cursor=cursor_class) self._execute_query(cursor, query, args, suppress_warnings) @@ -321,7 +360,7 @@ def get_user(self): """ :return: the user name and host name provided by the client to the server. """ - return self.query('SELECT user()').fetchone()[0] + return self.query("SELECT user()").fetchone()[0] # ---------- transaction processing @property @@ -338,7 +377,7 @@ def start_transaction(self): """ if self.in_transaction: raise errors.DataJointError("Nested connections are not supported.") - self.query('START TRANSACTION WITH CONSISTENT SNAPSHOT') + self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT") self._in_transaction = True logger.info("Transaction started") @@ -346,7 +385,7 @@ def cancel_transaction(self): """ Cancels the current transaction and rolls back all changes made during the transaction. """ - self.query('ROLLBACK') + self.query("ROLLBACK") self._in_transaction = False logger.info("Transaction cancelled. Rolling back ...") @@ -355,7 +394,7 @@ def commit_transaction(self): Commit all changes made during the transaction and close it. 
""" - self.query('COMMIT') + self.query("COMMIT") self._in_transaction = False logger.info("Transaction committed and closed.") diff --git a/datajoint/declare.py b/datajoint/declare.py index 91d37e59e..0e4849403 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -9,42 +9,70 @@ from .errors import DataJointError, _support_filepath_types, FILEPATH_FEATURE_SWITCH from .attribute_adapter import get_adapter -UUID_DATA_TYPE = 'binary(16)' +UUID_DATA_TYPE = "binary(16)" MAX_TABLE_NAME_LENGTH = 64 -CONSTANT_LITERALS = {'CURRENT_TIMESTAMP', 'NULL'} # SQL literals to be used without quotes (case insensitive) -EXTERNAL_TABLE_ROOT = '~external' - -TYPE_PATTERN = {k: re.compile(v, re.I) for k, v in dict( - INTEGER=r'((tiny|small|medium|big|)int|integer)(\s*\(.+\))?(\s+unsigned)?(\s+auto_increment)?|serial$', - DECIMAL=r'(decimal|numeric)(\s*\(.+\))?(\s+unsigned)?$', - FLOAT=r'(double|float|real)(\s*\(.+\))?(\s+unsigned)?$', - STRING=r'(var)?char\s*\(.+\)$', - ENUM=r'enum\s*\(.+\)$', - BOOL=r'bool(ean)?$', # aliased to tinyint(1) - TEMPORAL=r'(date|datetime|time|timestamp|year)(\s*\(.+\))?$', - INTERNAL_BLOB=r'(tiny|small|medium|long|)blob$', - EXTERNAL_BLOB=r'blob@(?P[a-z][\-\w]*)$', - INTERNAL_ATTACH=r'attach$', - EXTERNAL_ATTACH=r'attach@(?P[a-z][\-\w]*)$', - FILEPATH=r'filepath@(?P[a-z][\-\w]*)$', - UUID=r'uuid$', - ADAPTED=r'<.+>$' -).items()} +CONSTANT_LITERALS = { + "CURRENT_TIMESTAMP", + "NULL", +} # SQL literals to be used without quotes (case insensitive) +EXTERNAL_TABLE_ROOT = "~external" + +TYPE_PATTERN = { + k: re.compile(v, re.I) + for k, v in dict( + INTEGER=r"((tiny|small|medium|big|)int|integer)(\s*\(.+\))?(\s+unsigned)?(\s+auto_increment)?|serial$", + DECIMAL=r"(decimal|numeric)(\s*\(.+\))?(\s+unsigned)?$", + FLOAT=r"(double|float|real)(\s*\(.+\))?(\s+unsigned)?$", + STRING=r"(var)?char\s*\(.+\)$", + ENUM=r"enum\s*\(.+\)$", + BOOL=r"bool(ean)?$", # aliased to tinyint(1) + TEMPORAL=r"(date|datetime|time|timestamp|year)(\s*\(.+\))?$", + INTERNAL_BLOB=r"(tiny|small|medium|long|)blob$", + EXTERNAL_BLOB=r"blob@(?P[a-z][\-\w]*)$", + INTERNAL_ATTACH=r"attach$", + EXTERNAL_ATTACH=r"attach@(?P[a-z][\-\w]*)$", + FILEPATH=r"filepath@(?P[a-z][\-\w]*)$", + UUID=r"uuid$", + ADAPTED=r"<.+>$", + ).items() +} # custom types are stored in attribute comment -SPECIAL_TYPES = {'UUID', 'INTERNAL_ATTACH', 'EXTERNAL_ATTACH', 'EXTERNAL_BLOB', 'FILEPATH', 'ADAPTED'} +SPECIAL_TYPES = { + "UUID", + "INTERNAL_ATTACH", + "EXTERNAL_ATTACH", + "EXTERNAL_BLOB", + "FILEPATH", + "ADAPTED", +} NATIVE_TYPES = set(TYPE_PATTERN) - SPECIAL_TYPES -EXTERNAL_TYPES = {'EXTERNAL_ATTACH', 'EXTERNAL_BLOB', 'FILEPATH'} # data referenced by a UUID in external tables -SERIALIZED_TYPES = {'EXTERNAL_ATTACH', 'INTERNAL_ATTACH', 'EXTERNAL_BLOB', 'INTERNAL_BLOB'} # requires packing data +EXTERNAL_TYPES = { + "EXTERNAL_ATTACH", + "EXTERNAL_BLOB", + "FILEPATH", +} # data referenced by a UUID in external tables +SERIALIZED_TYPES = { + "EXTERNAL_ATTACH", + "INTERNAL_ATTACH", + "EXTERNAL_BLOB", + "INTERNAL_BLOB", +} # requires packing data assert set().union(SPECIAL_TYPES, EXTERNAL_TYPES, SERIALIZED_TYPES) <= set(TYPE_PATTERN) def match_type(attribute_type): try: - return next(category for category, pattern in TYPE_PATTERN.items() if pattern.match(attribute_type)) + return next( + category + for category, pattern in TYPE_PATTERN.items() + if pattern.match(attribute_type) + ) except StopIteration: - raise DataJointError("Unsupported attribute type {type}".format(type=attribute_type)) + raise DataJointError( + "Unsupported attribute 
type {type}".format(type=attribute_type) + ) logger = logging.getLogger(__name__) @@ -53,48 +81,68 @@ def match_type(attribute_type): def build_foreign_key_parser_old(): # old-style foreign key parser. Superseded by expression-based syntax. See issue #436 # This will be deprecated in a future release. - left = pp.Literal('(').suppress() - right = pp.Literal(')').suppress() - attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')) - new_attrs = pp.Optional(left + pp.delimitedList(attribute_name) + right).setResultsName('new_attrs') - arrow = pp.Literal('->').suppress() - lbracket = pp.Literal('[').suppress() - rbracket = pp.Literal(']').suppress() - option = pp.Word(pp.srange('[a-zA-Z]')) - options = pp.Optional(lbracket + pp.delimitedList(option) + rbracket).setResultsName('options') - ref_table = pp.Word(pp.alphas, pp.alphanums + '._').setResultsName('ref_table') - ref_attrs = pp.Optional(left + pp.delimitedList(attribute_name) + right).setResultsName('ref_attrs') + left = pp.Literal("(").suppress() + right = pp.Literal(")").suppress() + attribute_name = pp.Word(pp.srange("[a-z]"), pp.srange("[a-z0-9_]")) + new_attrs = pp.Optional( + left + pp.delimitedList(attribute_name) + right + ).setResultsName("new_attrs") + arrow = pp.Literal("->").suppress() + lbracket = pp.Literal("[").suppress() + rbracket = pp.Literal("]").suppress() + option = pp.Word(pp.srange("[a-zA-Z]")) + options = pp.Optional( + lbracket + pp.delimitedList(option) + rbracket + ).setResultsName("options") + ref_table = pp.Word(pp.alphas, pp.alphanums + "._").setResultsName("ref_table") + ref_attrs = pp.Optional( + left + pp.delimitedList(attribute_name) + right + ).setResultsName("ref_attrs") return new_attrs + arrow + options + ref_table + ref_attrs def build_foreign_key_parser(): - arrow = pp.Literal('->').suppress() - lbracket = pp.Literal('[').suppress() - rbracket = pp.Literal(']').suppress() - option = pp.Word(pp.srange('[a-zA-Z]')) - options = pp.Optional(lbracket + pp.delimitedList(option) + rbracket).setResultsName('options') - ref_table = pp.restOfLine.setResultsName('ref_table') + arrow = pp.Literal("->").suppress() + lbracket = pp.Literal("[").suppress() + rbracket = pp.Literal("]").suppress() + option = pp.Word(pp.srange("[a-zA-Z]")) + options = pp.Optional( + lbracket + pp.delimitedList(option) + rbracket + ).setResultsName("options") + ref_table = pp.restOfLine.setResultsName("ref_table") return arrow + options + ref_table def build_attribute_parser(): quoted = pp.QuotedString('"') ^ pp.QuotedString("'") - colon = pp.Literal(':').suppress() - attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name') - data_type = (pp.Combine(pp.Word(pp.alphas) + pp.SkipTo("#", ignore=quoted)) - ^ pp.QuotedString('<', endQuoteChar='>', unquoteResults=False)).setResultsName('type') - default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default') - comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment') + colon = pp.Literal(":").suppress() + attribute_name = pp.Word(pp.srange("[a-z]"), pp.srange("[a-z0-9_]")).setResultsName( + "name" + ) + data_type = ( + pp.Combine(pp.Word(pp.alphas) + pp.SkipTo("#", ignore=quoted)) + ^ pp.QuotedString("<", endQuoteChar=">", unquoteResults=False) + ).setResultsName("type") + default = pp.Literal("=").suppress() + pp.SkipTo( + colon, ignore=quoted + ).setResultsName("default") + comment = pp.Literal("#").suppress() + pp.restOfLine.setResultsName("comment") return attribute_name + 
pp.Optional(default) + colon + data_type + comment def build_index_parser(): - left = pp.Literal('(').suppress() - right = pp.Literal(')').suppress() - unique = pp.Optional(pp.CaselessKeyword('unique')).setResultsName('unique') - index = pp.CaselessKeyword('index').suppress() - attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')) - return unique + index + left + pp.delimitedList(attribute_name).setResultsName('attr_list') + right + left = pp.Literal("(").suppress() + right = pp.Literal(")").suppress() + unique = pp.Optional(pp.CaselessKeyword("unique")).setResultsName("unique") + index = pp.CaselessKeyword("index").suppress() + attribute_name = pp.Word(pp.srange("[a-z]"), pp.srange("[a-z0-9_]")) + return ( + unique + + index + + left + + pp.delimitedList(attribute_name).setResultsName("attr_list") + + right + ) foreign_key_parser_old = build_foreign_key_parser_old() @@ -108,11 +156,13 @@ def is_foreign_key(line): :param line: a line from the table definition :return: true if the line appears to be a foreign key definition """ - arrow_position = line.find('->') - return arrow_position >= 0 and not any(c in line[:arrow_position] for c in '"#\'') + arrow_position = line.find("->") + return arrow_position >= 0 and not any(c in line[:arrow_position] for c in "\"#'") -def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql): +def compile_foreign_key( + line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql +): """ :param line: a line from a table definition :param context: namespace containing referenced objects @@ -127,7 +177,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig from .table import Table from .expression import QueryExpression - obsolete = False # See issue #436. Old style to be deprecated in a future release + obsolete = False # See issue #436. Old style to be deprecated in a future release try: result = foreign_key_parser.parseString(line) except pp.ParseException: @@ -140,44 +190,66 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig try: ref = eval(result.ref_table, context) except NameError if obsolete else Exception: - raise DataJointError('Foreign key reference %s could not be resolved' % result.ref_table) + raise DataJointError( + "Foreign key reference %s could not be resolved" % result.ref_table + ) options = [opt.upper() for opt in result.options] for opt in options: # check for invalid options - if opt not in {'NULLABLE', 'UNIQUE'}: + if opt not in {"NULLABLE", "UNIQUE"}: raise DataJointError('Invalid foreign key option "{opt}"'.format(opt=opt)) - is_nullable = 'NULLABLE' in options - is_unique = 'UNIQUE' in options + is_nullable = "NULLABLE" in options + is_unique = "UNIQUE" in options if is_nullable and primary_key is not None: - raise DataJointError('Primary dependencies cannot be nullable in line "{line}"'.format(line=line)) + raise DataJointError( + 'Primary dependencies cannot be nullable in line "{line}"'.format(line=line) + ) if obsolete: warnings.warn( 'Line "{line}" uses obsolete syntax that will no longer be supported in datajoint 0.14. 
' - 'For details, see issue #780 https://github.com/datajoint/datajoint-python/issues/780'.format(line=line)) + "For details, see issue #780 https://github.com/datajoint/datajoint-python/issues/780".format( + line=line + ) + ) if not isinstance(ref, type) or not issubclass(ref, Table): - raise DataJointError('Foreign key reference %r must be a valid query' % result.ref_table) + raise DataJointError( + "Foreign key reference %r must be a valid query" % result.ref_table + ) if isinstance(ref, type) and issubclass(ref, Table): ref = ref() # check that dependency is of a supported type - if (not isinstance(ref, QueryExpression) or len(ref.restriction) or - len(ref.support) != 1 or not isinstance(ref.support[0], str)): - raise DataJointError('Dependency "%s" is not supported (yet). Use a base table or its projection.' % - result.ref_table) + if ( + not isinstance(ref, QueryExpression) + or len(ref.restriction) + or len(ref.support) != 1 + or not isinstance(ref.support[0], str) + ): + raise DataJointError( + 'Dependency "%s" is not supported (yet). Use a base table or its projection.' + % result.ref_table + ) if obsolete: # for backward compatibility with old-style dependency declarations. See issue #436 if not isinstance(ref, Table): - DataJointError('Dependency "%s" is not supported. Check documentation.' % result.ref_table) + DataJointError( + 'Dependency "%s" is not supported. Check documentation.' + % result.ref_table + ) if not all(r in ref.primary_key for r in result.ref_attrs): raise DataJointError('Invalid foreign key attributes in "%s"' % line) try: - raise DataJointError('Duplicate attributes "{attr}" in "{line}"'.format( - attr=next(attr for attr in result.new_attrs if attr in attributes), line=line)) + raise DataJointError( + 'Duplicate attributes "{attr}" in "{line}"'.format( + attr=next(attr for attr in result.new_attrs if attr in attributes), + line=line, + ) + ) except StopIteration: - pass # the normal outcome + pass # the normal outcome # Match the primary attributes of the referenced table to local attributes new_attrs = list(result.new_attrs) @@ -186,7 +258,10 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig # special case, the renamed attribute is implicit if new_attrs and not ref_attrs: if len(new_attrs) != 1: - raise DataJointError('Renamed foreign key must be mapped to the primary key in "%s"' % line) + raise DataJointError( + 'Renamed foreign key must be mapped to the primary key in "%s"' + % line + ) if len(ref.primary_key) == 1: # if the primary key has one attribute, allow implicit renaming ref_attrs = ref.primary_key @@ -194,7 +269,10 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig # if only one primary key attribute remains, then allow implicit renaming ref_attrs = [attr for attr in ref.primary_key if attr not in attributes] if len(ref_attrs) != 1: - raise DataJointError('Could not resolve which primary key attribute should be referenced in "%s"' % line) + raise DataJointError( + 'Could not resolve which primary key attribute should be referenced in "%s"' + % line + ) if len(new_attrs) != len(ref_attrs): raise DataJointError('Mismatched attributes in foreign key "%s"' % line) @@ -210,26 +288,35 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig if primary_key is not None: primary_key.append(attr) attr_sql.append( - ref.heading[attr].sql.replace('NOT NULL ', '', int(is_nullable))) + ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable)) + ) # declare the 
foreign key foreign_key_sql.append( - 'FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT'.format( - fk='`,`'.join(ref.primary_key), - pk='`,`'.join(ref.heading[name].original_name for name in ref.primary_key), - ref=ref.support[0])) + "FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT".format( + fk="`,`".join(ref.primary_key), + pk="`,`".join(ref.heading[name].original_name for name in ref.primary_key), + ref=ref.support[0], + ) + ) # declare unique index if is_unique: - index_sql.append('UNIQUE INDEX ({attrs})'.format(attrs=','.join("`%s`" % attr for attr in ref.primary_key))) + index_sql.append( + "UNIQUE INDEX ({attrs})".format( + attrs=",".join("`%s`" % attr for attr in ref.primary_key) + ) + ) def prepare_declare(definition, context): # split definition into lines - definition = re.split(r'\s*\n\s*', definition.strip()) + definition = re.split(r"\s*\n\s*", definition.strip()) # check for optional table comment - table_comment = definition.pop(0)[1:].strip() if definition[0].startswith('#') else '' - if table_comment.startswith(':'): + table_comment = ( + definition.pop(0)[1:].strip() if definition[0].startswith("#") else "" + ) + if table_comment.startswith(":"): raise DataJointError('Table comment must not start with a colon ":"') in_key = True # parse primary keys primary_key = [] @@ -240,15 +327,21 @@ def prepare_declare(definition, context): external_stores = [] for line in definition: - if not line or line.startswith('#'): # ignore additional comments + if not line or line.startswith("#"): # ignore additional comments pass - elif line.startswith('---') or line.startswith('___'): + elif line.startswith("---") or line.startswith("___"): in_key = False # start parsing dependent attributes elif is_foreign_key(line): - compile_foreign_key(line, context, attributes, - primary_key if in_key else None, - attribute_sql, foreign_key_sql, index_sql) - elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I): # index + compile_foreign_key( + line, + context, + attributes, + primary_key if in_key else None, + attribute_sql, + foreign_key_sql, + index_sql, + ) + elif re.match(r"^(unique\s+)?index[^:]*$", line, re.I): # index compile_index(line, index_sql) else: name, sql, store = compile_attribute(line, in_key, foreign_key_sql, context) @@ -260,7 +353,14 @@ def prepare_declare(definition, context): attributes.append(name) attribute_sql.append(sql) - return table_comment, primary_key, attribute_sql, foreign_key_sql, index_sql, external_stores + return ( + table_comment, + primary_key, + attribute_sql, + foreign_key_sql, + index_sql, + external_stores, + ) def declare(full_table_name, definition, context): @@ -271,23 +371,36 @@ def declare(full_table_name, definition, context): :param context: dictionary of objects that might be referred to in the table :return: SQL CREATE TABLE statement, list of external stores used """ - table_name = full_table_name.strip('`').split('.')[1] + table_name = full_table_name.strip("`").split(".")[1] if len(table_name) > MAX_TABLE_NAME_LENGTH: raise DataJointError( - 'Table name `{name}` exceeds the max length of {max_length}'.format( - name=table_name, - max_length=MAX_TABLE_NAME_LENGTH)) - - table_comment, primary_key, attribute_sql, foreign_key_sql, index_sql, external_stores = prepare_declare( - definition, context) + "Table name `{name}` exceeds the max length of {max_length}".format( + name=table_name, max_length=MAX_TABLE_NAME_LENGTH + ) + ) + + ( + table_comment, + primary_key, + 
attribute_sql, + foreign_key_sql, + index_sql, + external_stores, + ) = prepare_declare(definition, context) if not primary_key: - raise DataJointError('Table must have a primary key') + raise DataJointError("Table must have a primary key") return ( - 'CREATE TABLE IF NOT EXISTS %s (\n' % full_table_name + - ',\n'.join(attribute_sql + ['PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)'] + foreign_key_sql + index_sql) + - '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment), external_stores + "CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name + + ",\n".join( + attribute_sql + + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + + foreign_key_sql + + index_sql + ) + + '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment + ), external_stores def _make_attribute_alter(new, old, primary_key): @@ -301,28 +414,32 @@ def _make_attribute_alter(new, old, primary_key): name_regexp = re.compile(r"^`(?P\w+)`") original_regexp = re.compile(r'COMMENT "{\s*(?P\w+)\s*}') matched = ((name_regexp.match(d), original_regexp.search(d)) for d in new) - new_names = dict((d.group('name'), n and n.group('name')) for d, n in matched) - old_names = [name_regexp.search(d).group('name') for d in old] + new_names = dict((d.group("name"), n and n.group("name")) for d, n in matched) + old_names = [name_regexp.search(d).group("name") for d in old] # verify that original names are only used once renamed = set() for v in new_names.values(): if v: if v in renamed: - raise DataJointError('Alter attempted to rename attribute {%s} twice.' % v) + raise DataJointError( + "Alter attempted to rename attribute {%s} twice." % v + ) renamed.add(v) # verify that all renamed attributes existed in the old definition try: raise DataJointError( "Attribute {} does not exist in the original definition".format( - next(attr for attr in renamed if attr not in old_names))) + next(attr for attr in renamed if attr not in old_names) + ) + ) except StopIteration: pass # dropping attributes to_drop = [n for n in old_names if n not in renamed and n not in new_names] - sql = ['DROP `%s`' % n for n in to_drop] + sql = ["DROP `%s`" % n for n in to_drop] old_names = [name for name in old_names if name not in to_drop] # add or change attributes in order @@ -339,12 +456,19 @@ def _make_attribute_alter(new, old, primary_key): if idx >= 1 and old_names[idx - 1] != (prev[1] or prev[0]): after = prev[0] if new_def not in old or after: - sql.append('{command} {new_def} {after}'.format( - command=("ADD" if (old_name or new_name) not in old_names else - "MODIFY" if not old_name else - "CHANGE `%s`" % old_name), - new_def=new_def, - after="" if after is None else "AFTER `%s`" % after)) + sql.append( + "{command} {new_def} {after}".format( + command=( + "ADD" + if (old_name or new_name) not in old_names + else "MODIFY" + if not old_name + else "CHANGE `%s`" % old_name + ), + new_def=new_def, + after="" if after is None else "AFTER `%s`" % after, + ) + ) prev = new_name, old_name return sql @@ -357,19 +481,31 @@ def alter(definition, old_definition, context): :param context: the context in which to evaluate foreign key definitions :return: string SQL ALTER command, list of new stores used for external storage """ - table_comment, primary_key, attribute_sql, foreign_key_sql, index_sql, external_stores = prepare_declare( - definition, context) - table_comment_, primary_key_, attribute_sql_, foreign_key_sql_, index_sql_, external_stores_ = prepare_declare( - old_definition, context) + ( + table_comment, + primary_key, + attribute_sql, + foreign_key_sql, + index_sql, + 
external_stores, + ) = prepare_declare(definition, context) + ( + table_comment_, + primary_key_, + attribute_sql_, + foreign_key_sql_, + index_sql_, + external_stores_, + ) = prepare_declare(old_definition, context) # analyze differences between declarations sql = list() if primary_key != primary_key_: - raise NotImplementedError('table.alter cannot alter the primary key (yet).') + raise NotImplementedError("table.alter cannot alter the primary key (yet).") if foreign_key_sql != foreign_key_sql_: - raise NotImplementedError('table.alter cannot alter foreign keys (yet).') + raise NotImplementedError("table.alter cannot alter foreign keys (yet).") if index_sql != index_sql_: - raise NotImplementedError('table.alter cannot alter indexes (yet)') + raise NotImplementedError("table.alter cannot alter indexes (yet)") if attribute_sql != attribute_sql_: sql.extend(_make_attribute_alter(attribute_sql, attribute_sql_, primary_key)) if table_comment != table_comment_: @@ -379,9 +515,11 @@ def alter(definition, old_definition, context): def compile_index(line, index_sql): match = index_parser.parseString(line) - index_sql.append('{unique} index ({attrs})'.format( - unique=match.unique, - attrs=','.join('`%s`' % a for a in match.attr_list))) + index_sql.append( + "{unique} index ({attrs})".format( + unique=match.unique, attrs=",".join("`%s`" % a for a in match.attr_list) + ) + ) def substitute_special_type(match, category, foreign_key_sql, context): @@ -391,31 +529,38 @@ def substitute_special_type(match, category, foreign_key_sql, context): :param foreign_key_sql: list of foreign key declarations to add to :param context: context for looking up user-defined attribute_type adapters """ - if category == 'UUID': - match['type'] = UUID_DATA_TYPE - elif category == 'INTERNAL_ATTACH': - match['type'] = 'LONGBLOB' + if category == "UUID": + match["type"] = UUID_DATA_TYPE + elif category == "INTERNAL_ATTACH": + match["type"] = "LONGBLOB" elif category in EXTERNAL_TYPES: - if category == 'FILEPATH' and not _support_filepath_types(): - raise DataJointError(""" + if category == "FILEPATH" and not _support_filepath_types(): + raise DataJointError( + """ The filepath data type is disabled until complete validation. To turn it on as experimental feature, set the environment variable {env} = TRUE or upgrade datajoint. - """.format(env=FILEPATH_FEATURE_SWITCH)) - match['store'] = match['type'].split('@', 1)[1] - match['type'] = UUID_DATA_TYPE + """.format( + env=FILEPATH_FEATURE_SWITCH + ) + ) + match["store"] = match["type"].split("@", 1)[1] + match["type"] = UUID_DATA_TYPE foreign_key_sql.append( "FOREIGN KEY (`{name}`) REFERENCES `{{database}}`.`{external_table_root}_{store}` (`hash`) " - "ON UPDATE RESTRICT ON DELETE RESTRICT".format(external_table_root=EXTERNAL_TABLE_ROOT, **match)) - elif category == 'ADAPTED': - adapter = get_adapter(context, match['type']) - match['type'] = adapter.attribute_type - category = match_type(match['type']) + "ON UPDATE RESTRICT ON DELETE RESTRICT".format( + external_table_root=EXTERNAL_TABLE_ROOT, **match + ) + ) + elif category == "ADAPTED": + adapter = get_adapter(context, match["type"]) + match["type"] = adapter.attribute_type + category = match_type(match["type"]) if category in SPECIAL_TYPES: # recursive redefinition from user-defined datatypes. 
substitute_special_type(match, category, foreign_key_sql, context) else: - assert False, 'Unknown special type' + assert False, "Unknown special type" def compile_attribute(line, in_key, foreign_key_sql, context): @@ -428,41 +573,67 @@ def compile_attribute(line, in_key, foreign_key_sql, context): :returns: (name, sql, is_external) -- attribute name and sql code for its declaration """ try: - match = attribute_parser.parseString(line + '#', parseAll=True) + match = attribute_parser.parseString(line + "#", parseAll=True) except pp.ParseException as err: - raise DataJointError('Declaration error in position {pos} in line:\n {line}\n{msg}'.format( - line=err.args[0], pos=err.args[1], msg=err.args[2])) - match['comment'] = match['comment'].rstrip('#') - if 'default' not in match: - match['default'] = '' + raise DataJointError( + "Declaration error in position {pos} in line:\n {line}\n{msg}".format( + line=err.args[0], pos=err.args[1], msg=err.args[2] + ) + ) + match["comment"] = match["comment"].rstrip("#") + if "default" not in match: + match["default"] = "" match = {k: v.strip() for k, v in match.items()} - match['nullable'] = match['default'].lower() == 'null' + match["nullable"] = match["default"].lower() == "null" - if match['nullable']: + if match["nullable"]: if in_key: - raise DataJointError('Primary key attributes cannot be nullable in line "%s"' % line) - match['default'] = 'DEFAULT NULL' # nullable attributes default to null + raise DataJointError( + 'Primary key attributes cannot be nullable in line "%s"' % line + ) + match["default"] = "DEFAULT NULL" # nullable attributes default to null else: - if match['default']: - quote = (match['default'].split('(')[0].upper() not in CONSTANT_LITERALS - and match['default'][0] not in '"\'') - match['default'] = 'NOT NULL DEFAULT ' + ('"%s"' if quote else "%s") % match['default'] + if match["default"]: + quote = ( + match["default"].split("(")[0].upper() not in CONSTANT_LITERALS + and match["default"][0] not in "\"'" + ) + match["default"] = ( + "NOT NULL DEFAULT " + ('"%s"' if quote else "%s") % match["default"] + ) else: - match['default'] = 'NOT NULL' + match["default"] = "NOT NULL" - match['comment'] = match['comment'].replace('"', '\\"') # escape double quotes in comment + match["comment"] = match["comment"].replace( + '"', '\\"' + ) # escape double quotes in comment - if match['comment'].startswith(':'): - raise DataJointError('An attribute comment must not start with a colon in comment "{comment}"'.format(**match)) + if match["comment"].startswith(":"): + raise DataJointError( + 'An attribute comment must not start with a colon in comment "{comment}"'.format( + **match + ) + ) - category = match_type(match['type']) + category = match_type(match["type"]) if category in SPECIAL_TYPES: - match['comment'] = ':{type}:{comment}'.format(**match) # insert custom type into comment + match["comment"] = ":{type}:{comment}".format( + **match + ) # insert custom type into comment substitute_special_type(match, category, foreign_key_sql, context) - if category in SERIALIZED_TYPES and match['default'] not in {'DEFAULT NULL', 'NOT NULL'}: + if category in SERIALIZED_TYPES and match["default"] not in { + "DEFAULT NULL", + "NOT NULL", + }: raise DataJointError( - 'The default value for a blob or attachment attributes can only be NULL in:\n{line}'.format(line=line)) - - sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '')).format(**match) - return match['name'], sql, match.get('store') + "The default value for a blob 
or attachment attributes can only be NULL in:\n{line}".format( + line=line + ) + ) + + sql = ( + "`{name}` {type} {default}" + + (' COMMENT "{comment}"' if match["comment"] else "") + ).format(**match) + return match["name"], sql, match.get("store") diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index e5e94225d..96dc8f7f4 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -18,13 +18,13 @@ def unite_master_parts(lst): """ for i in range(2, len(lst)): name = lst[i] - match = re.match(r'(?P`\w+`.`#?\w+)__\w+`', name) + match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", name) if match: # name is a part table - master = match.group('master') - for j in range(i-1, -1, -1): - if lst[j] == master + '`' or lst[j].startswith(master + '__'): + master = match.group("master") + for j in range(i - 1, -1, -1): + if lst[j] == master + "`" or lst[j].startswith(master + "__"): # move from the ith position to the (j+1)th position - lst[j+1:i+1] = [name] + lst[j+1:i] + lst[j + 1 : i + 1] = [name] + lst[j + 1 : i] break return lst @@ -38,6 +38,7 @@ class Dependencies(nx.DiGraph): internally create objects with the expectation of empty constructors. See also: https://github.com/datajoint/datajoint-python/pull/443 """ + def __init__(self, connection=None): self._conn = connection self._node_alias_count = itertools.count() @@ -60,12 +61,16 @@ def load(self, force=True): self.clear() # load primary key info - keys = self._conn.query(""" + keys = self._conn.query( + """ SELECT concat('`', table_schema, '`.`', table_name, '`') as tab, column_name FROM information_schema.key_column_usage WHERE table_name not LIKE "~%%" AND table_schema in ('{schemas}') AND constraint_name="PRIMARY" - """.format(schemas="','".join(self._conn.schemas))) + """.format( + schemas="','".join(self._conn.schemas) + ) + ) pks = defaultdict(set) for key in keys: pks[key[0]].add(key[1]) @@ -75,7 +80,10 @@ def load(self, force=True): self.add_node(n, primary_key=pk) # load foreign keys - keys = ({k.lower(): v for k, v in elem.items()} for elem in self._conn.query(""" + keys = ( + {k.lower(): v for k, v in elem.items()} + for elem in self._conn.query( + """ SELECT constraint_name, concat('`', table_schema, '`.`', table_name, '`') as referencing_table, concat('`', referenced_table_schema, '`.`', referenced_table_name, '`') as referenced_table, @@ -83,32 +91,44 @@ def load(self, force=True): FROM information_schema.key_column_usage WHERE referenced_table_name NOT LIKE "~%%" AND (referenced_table_schema in ('{schemas}') OR referenced_table_schema is not NULL AND table_schema in ('{schemas}')) - """.format(schemas="','".join(self._conn.schemas)), as_dict=True)) + """.format( + schemas="','".join(self._conn.schemas) + ), + as_dict=True, + ) + ) fks = defaultdict(lambda: dict(attr_map=dict())) for key in keys: - d = fks[(key['constraint_name'], key['referencing_table'], key['referenced_table'])] - d['referencing_table'] = key['referencing_table'] - d['referenced_table'] = key['referenced_table'] - d['attr_map'][key['column_name']] = key['referenced_column_name'] + d = fks[ + ( + key["constraint_name"], + key["referencing_table"], + key["referenced_table"], + ) + ] + d["referencing_table"] = key["referencing_table"] + d["referenced_table"] = key["referenced_table"] + d["attr_map"][key["column_name"]] = key["referenced_column_name"] # add edges to the graph for fk in fks.values(): props = dict( - primary=set(fk['attr_map']) <= set(pks[fk['referencing_table']]), - attr_map=fk['attr_map'], - aliased=any(k != v for k, v 
in fk['attr_map'].items()), - multi=set(fk['attr_map']) != set(pks[fk['referencing_table']])) - if not props['aliased']: - self.add_edge(fk['referenced_table'], fk['referencing_table'], **props) + primary=set(fk["attr_map"]) <= set(pks[fk["referencing_table"]]), + attr_map=fk["attr_map"], + aliased=any(k != v for k, v in fk["attr_map"].items()), + multi=set(fk["attr_map"]) != set(pks[fk["referencing_table"]]), + ) + if not props["aliased"]: + self.add_edge(fk["referenced_table"], fk["referencing_table"], **props) else: # for aliased dependencies, add an extra node in the format '1', '2', etc - alias_node = '%d' % next(self._node_alias_count) + alias_node = "%d" % next(self._node_alias_count) self.add_node(alias_node) - self.add_edge(fk['referenced_table'], alias_node, **props) - self.add_edge(alias_node, fk['referencing_table'], **props) + self.add_edge(fk["referenced_table"], alias_node, **props) + self.add_edge(alias_node, fk["referencing_table"], **props) if not nx.is_directed_acyclic_graph(self): # pragma: no cover - raise DataJointError('DataJoint can only work with acyclic dependencies') + raise DataJointError("DataJoint can only work with acyclic dependencies") self._loaded = True def parents(self, table_name, primary=None): @@ -120,8 +140,11 @@ def parents(self, table_name, primary=None): :return: dict of tables referenced by the foreign keys of table """ self.load(force=False) - return {p[0]: p[2] for p in self.in_edges(table_name, data=True) - if primary is None or p[2]['primary'] == primary} + return { + p[0]: p[2] + for p in self.in_edges(table_name, data=True) + if primary is None or p[2]["primary"] == primary + } def children(self, table_name, primary=None): """ @@ -132,8 +155,11 @@ def children(self, table_name, primary=None): :return: dict of tables referencing the table through foreign keys """ self.load(force=False) - return {p[1]: p[2] for p in self.out_edges(table_name, data=True) - if primary is None or p[2]['primary'] == primary} + return { + p[1]: p[2] + for p in self.out_edges(table_name, data=True) + if primary is None or p[2]["primary"] == primary + } def descendants(self, full_table_name): """ @@ -141,10 +167,10 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph( - nx.algorithms.dag.descendants(self, full_table_name)) - return unite_master_parts([full_table_name] + list( - nx.algorithms.dag.topological_sort(nodes))) + nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) + return unite_master_parts( + [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) + ) def ancestors(self, full_table_name): """ @@ -152,7 +178,11 @@ def ancestors(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. 
""" self.load(force=False) - nodes = self.subgraph( - nx.algorithms.dag.ancestors(self, full_table_name)) - return list(reversed(unite_master_parts(list( - nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]))) + nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) + return list( + reversed( + unite_master_parts( + list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] + ) + ) + ) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index fa03c123a..f5bb4cc8d 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -9,12 +9,14 @@ try: from matplotlib import pyplot as plt + plot_active = True except: plot_active = False try: from networkx.drawing.nx_pydot import pydot_layout + diagram_active = True except: diagram_active = False @@ -31,21 +33,26 @@ class _AliasNode: """ special class to indicate aliased foreign keys """ + pass def _get_tier(table_name): - if not table_name.startswith('`'): + if not table_name.startswith("`"): return _AliasNode else: try: - return next(tier for tier in user_table_classes - if re.fullmatch(tier.tier_regexp, table_name.split('`')[-2])) + return next( + tier + for tier in user_table_classes + if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) + ) except StopIteration: return None if not diagram_active: + class Diagram: """ Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz. @@ -56,9 +63,12 @@ class Diagram: """ def __init__(self, *args, **kwargs): - warnings.warn('Please install matplotlib and pygraphviz libraries to enable the Diagram feature.') + warnings.warn( + "Please install matplotlib and pygraphviz libraries to enable the Diagram feature." + ) else: + class Diagram(nx.DiGraph): """ Entity relationship diagram. @@ -81,6 +91,7 @@ class Diagram(nx.DiGraph): Note that diagram + 1 - 1 may differ from diagram - 1 + 1 and so forth. Only those tables that are loaded in the connection object are displayed """ + def __init__(self, source, context=None): if isinstance(source, Diagram): @@ -105,7 +116,9 @@ def __init__(self, source, context=None): try: connection = source.schema.connection except AttributeError: - raise DataJointError('Could not find database connection in %s' % repr(source[0])) + raise DataJointError( + "Could not find database connection in %s" % repr(source[0]) + ) # initialize graph from dependencies connection.dependencies.load() @@ -122,9 +135,11 @@ def __init__(self, source, context=None): try: database = source.schema.database except AttributeError: - raise DataJointError('Cannot plot Diagram for %s' % repr(source)) + raise DataJointError( + "Cannot plot Diagram for %s" % repr(source) + ) for node in self: - if node.startswith('`%s`' % database): + if node.startswith("`%s`" % database): self.nodes_to_show.add(node) @classmethod @@ -134,31 +149,44 @@ def from_sequence(cls, sequence): :param sequence: a sequence (e.g. list, tuple) :return: Diagram(arg1) + ... + Diagram(argn) """ - return functools.reduce(lambda x, y: x+y, map(Diagram, sequence)) + return functools.reduce(lambda x, y: x + y, map(Diagram, sequence)) def add_parts(self): """ Adds to the diagram the part tables of tables already included in the diagram :return: """ + def is_part(part, master): """ :param part: `database`.`table_name` :param master: `database`.`table_name` :return: True if part is part of master. 
""" - part = [s.strip('`') for s in part.split('.')] - master = [s.strip('`') for s in master.split('.')] - return master[0] == part[0] and master[1] + '__' == part[1][:len(master[1])+2] + part = [s.strip("`") for s in part.split(".")] + master = [s.strip("`") for s in master.split(".")] + return ( + master[0] == part[0] + and master[1] + "__" == part[1][: len(master[1]) + 2] + ) self = Diagram(self) # copy - self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) + self.nodes_to_show.update( + n + for n in self.nodes() + if any(is_part(n, m) for m in self.nodes_to_show) + ) return self def topological_sort(self): - """ :return: list of nodes in topological order """ - return unite_master_parts(list(nx.algorithms.dag.topological_sort( - nx.DiGraph(self).subgraph(self.nodes_to_show)))) + """:return: list of nodes in topological order""" + return unite_master_parts( + list( + nx.algorithms.dag.topological_sort( + nx.DiGraph(self).subgraph(self.nodes_to_show) + ) + ) + ) def __add__(self, arg): """ @@ -166,7 +194,7 @@ def __add__(self, arg): :return: Union of the diagrams when arg is another Diagram or an expansion downstream when arg is a positive integer. """ - self = Diagram(self) # copy + self = Diagram(self) # copy try: self.nodes_to_show.update(arg.nodes_to_show) except AttributeError: @@ -174,11 +202,17 @@ def __add__(self, arg): self.nodes_to_show.add(arg.full_table_name) except AttributeError: for i in range(arg): - new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + new = nx.algorithms.boundary.node_boundary( + self, self.nodes_to_show + ) if not new: break # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit()))) + new.update( + nx.algorithms.boundary.node_boundary( + self, (a for a in new if a.isdigit()) + ) + ) self.nodes_to_show.update(new) return self @@ -188,7 +222,7 @@ def __sub__(self, arg): :return: Difference of the diagrams when arg is another Diagram or an expansion upstream when arg is a positive integer. """ - self = Diagram(self) # copy + self = Diagram(self) # copy try: self.nodes_to_show.difference_update(arg.nodes_to_show) except AttributeError: @@ -197,11 +231,17 @@ def __sub__(self, arg): except AttributeError: for i in range(arg): graph = nx.DiGraph(self).reverse() - new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show) + new = nx.algorithms.boundary.node_boundary( + graph, self.nodes_to_show + ) if not new: break # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit()))) + new.update( + nx.algorithms.boundary.node_boundary( + graph, (a for a in new if a.isdigit()) + ) + ) self.nodes_to_show.update(new) return self @@ -211,7 +251,7 @@ def __mul__(self, arg): :param arg: another Diagram :return: a new Diagram comprising nodes that are present in both operands. 
""" - self = Diagram(self) # copy + self = Diagram(self) # copy self.nodes_to_show.intersection_update(arg.nodes_to_show) return self @@ -223,28 +263,39 @@ def _make_graph(self): # attributes for name in self.nodes_to_show: foreign_attributes = set( - attr for p in self.in_edges(name, data=True) - for attr in p[2]['attr_map'] if p[2]['primary']) - self.nodes[name]['distinguished'] = ( - 'primary_key' in self.nodes[name] and - foreign_attributes < self.nodes[name]['primary_key']) + attr + for p in self.in_edges(name, data=True) + for attr in p[2]["attr_map"] + if p[2]["primary"] + ) + self.nodes[name]["distinguished"] = ( + "primary_key" in self.nodes[name] + and foreign_attributes < self.nodes[name]["primary_key"] + ) # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary( - self, self.nodes_to_show)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), - self.nodes_to_show)) + gaps = set( + nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + ).intersection( + nx.algorithms.boundary.node_boundary( + nx.DiGraph(self).reverse(), self.nodes_to_show + ) + ) nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) - nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) - for n in graph}) + nx.set_node_attributes( + graph, name="node_type", values={n: _get_tier(n) for n in graph} + ) # relabel nodes to class names - mapping = {node: lookup_class_name(node, self.context) or node - for node in graph.nodes()} + mapping = { + node: lookup_class_name(node, self.context) or node + for node in graph.nodes() + } new_names = [mapping.values()] if len(new_names) > len(set(new_names)): raise DataJointError( - 'Some classes have identical names. The Diagram cannot be plotted.') + "Some classes have identical names. The Diagram cannot be plotted." 
+ ) nx.relabel_nodes(graph, mapping, copy=False) return graph @@ -253,64 +304,125 @@ def make_dot(self): graph = self._make_graph() graph.nodes() - scale = 1.2 # scaling factor for fonts and boxes + scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html - None: dict(shape='circle', color="#FFFF0040", fontcolor='yellow', fontsize=round(scale*8), - size=0.4*scale, fixed=False), - _AliasNode: dict(shape='circle', color="#FF880080", fontcolor='#FF880080', fontsize=round(scale*0), - size=0.05*scale, fixed=True), - Manual: dict(shape='box', color="#00FF0030", fontcolor='darkgreen', fontsize=round(scale*10), - size=0.4*scale, fixed=False), - Lookup: dict(shape='plaintext', color='#00000020', fontcolor='black', fontsize=round(scale*8), - size=0.4*scale, fixed=False), - Computed: dict(shape='ellipse', color='#FF000020', fontcolor='#7F0000A0', fontsize=round(scale*10), - size=0.3*scale, fixed=True), - Imported: dict(shape='ellipse', color='#00007F40', fontcolor='#00007FA0', fontsize=round(scale*10), - size=0.4*scale, fixed=False), - Part: dict(shape='plaintext', color='#0000000', fontcolor='black', fontsize=round(scale*8), - size=0.1*scale, fixed=False)} - node_props = {node: label_props[d['node_type']] for node, d in dict(graph.nodes(data=True)).items()} + None: dict( + shape="circle", + color="#FFFF0040", + fontcolor="yellow", + fontsize=round(scale * 8), + size=0.4 * scale, + fixed=False, + ), + _AliasNode: dict( + shape="circle", + color="#FF880080", + fontcolor="#FF880080", + fontsize=round(scale * 0), + size=0.05 * scale, + fixed=True, + ), + Manual: dict( + shape="box", + color="#00FF0030", + fontcolor="darkgreen", + fontsize=round(scale * 10), + size=0.4 * scale, + fixed=False, + ), + Lookup: dict( + shape="plaintext", + color="#00000020", + fontcolor="black", + fontsize=round(scale * 8), + size=0.4 * scale, + fixed=False, + ), + Computed: dict( + shape="ellipse", + color="#FF000020", + fontcolor="#7F0000A0", + fontsize=round(scale * 10), + size=0.3 * scale, + fixed=True, + ), + Imported: dict( + shape="ellipse", + color="#00007F40", + fontcolor="#00007FA0", + fontsize=round(scale * 10), + size=0.4 * scale, + fixed=False, + ), + Part: dict( + shape="plaintext", + color="#0000000", + fontcolor="black", + fontsize=round(scale * 8), + size=0.1 * scale, + fixed=False, + ), + } + node_props = { + node: label_props[d["node_type"]] + for node, d in dict(graph.nodes(data=True)).items() + } dot = nx.drawing.nx_pydot.to_pydot(graph) for node in dot.get_nodes(): - node.set_shape('circle') + node.set_shape("circle") name = node.get_name().strip('"') props = node_props[name] - node.set_fontsize(props['fontsize']) - node.set_fontcolor(props['fontcolor']) - node.set_shape(props['shape']) - node.set_fontname('arial') - node.set_fixedsize('shape' if props['fixed'] else False) - node.set_width(props['size']) - node.set_height(props['size']) - if name.split('.')[0] in self.context: + node.set_fontsize(props["fontsize"]) + node.set_fontcolor(props["fontcolor"]) + node.set_shape(props["shape"]) + node.set_fontname("arial") + node.set_fixedsize("shape" if props["fixed"] else False) + node.set_width(props["size"]) + node.set_height(props["size"]) + if name.split(".")[0] in self.context: cls = eval(name, self.context) - assert(issubclass(cls, Table)) - description = cls().describe(context=self.context, printout=False).split('\n') + assert issubclass(cls, Table) + description = ( + cls().describe(context=self.context, printout=False).split("\n") + ) 
description = ( - '-'*30 if q.startswith('---') else q.replace('->', '→') if '->' in q else q.split(':')[0] - for q in description if not q.startswith('#')) - node.set_tooltip(' '.join(description)) - node.set_label("<"+name+">" if node.get('distinguished') == 'True' else name) - node.set_color(props['color']) - node.set_style('filled') + "-" * 30 + if q.startswith("---") + else q.replace("->", "→") + if "->" in q + else q.split(":")[0] + for q in description + if not q.startswith("#") + ) + node.set_tooltip(" ".join(description)) + node.set_label( + "<" + name + ">" + if node.get("distinguished") == "True" + else name + ) + node.set_color(props["color"]) + node.set_style("filled") for edge in dot.get_edges(): # see https://graphviz.org/doc/info/attrs.html src = edge.get_source().strip('"') dest = edge.get_destination().strip('"') props = graph.get_edge_data(src, dest) - edge.set_color('#00000040') - edge.set_style('solid' if props['primary'] else 'dashed') - master_part = graph.nodes[dest]['node_type'] is Part and dest.startswith(src+'.') + edge.set_color("#00000040") + edge.set_style("solid" if props["primary"] else "dashed") + master_part = graph.nodes[dest][ + "node_type" + ] is Part and dest.startswith(src + ".") edge.set_weight(3 if master_part else 1) - edge.set_arrowhead('none') - edge.set_penwidth(.75 if props['multi'] else 2) + edge.set_arrowhead("none") + edge.set_penwidth(0.75 if props["multi"] else 2) return dot def make_svg(self): from IPython.display import SVG + return SVG(self.make_dot().create_svg()) def make_png(self): @@ -328,26 +440,26 @@ def _repr_svg_(self): def draw(self): if plot_active: plt.imshow(self.make_image()) - plt.gca().axis('off') + plt.gca().axis("off") plt.show() else: raise DataJointError("pyplot was not imported") def save(self, filename, format=None): if format is None: - if filename.lower().endswith('.png'): - format = 'png' - elif filename.lower().endswith('.svg'): - format = 'svg' - if format.lower() == 'png': - with open(filename, 'wb') as f: + if filename.lower().endswith(".png"): + format = "png" + elif filename.lower().endswith(".svg"): + format = "svg" + if format.lower() == "png": + with open(filename, "wb") as f: f.write(self.make_png().getbuffer().tobytes()) - elif format.lower() == 'svg': - with open(filename, 'w') as f: + elif format.lower() == "svg": + with open(filename, "w") as f: f.write(self.make_svg().data) else: - raise DataJointError('Unsupported file format') + raise DataJointError("Unsupported file format") @staticmethod def _layout(graph, **kwargs): - return pydot_layout(graph, prog='dot', **kwargs) + return pydot_layout(graph, prog="dot", **kwargs) diff --git a/datajoint/errors.py b/datajoint/errors.py index fe0ffc539..621ad1a3d 100644 --- a/datajoint/errors.py +++ b/datajoint/errors.py @@ -15,13 +15,21 @@ class DataJointError(Exception): """ Base class for errors specific to DataJoint internal operation. 
""" + def __init__(self, *args): from .plugin import connection_plugins, type_plugins - self.__cause__ = PluginWarning( - 'Unverified DataJoint plugin detected.') if any([any( - [not plugins[k]['verified'] for k in plugins]) - for plugins in [connection_plugins, type_plugins] - if plugins]) else None + + self.__cause__ = ( + PluginWarning("Unverified DataJoint plugin detected.") + if any( + [ + any([not plugins[k]["verified"] for k in plugins]) + for plugins in [connection_plugins, type_plugins] + if plugins + ] + ) + else None + ) def suggest(self, *args): """ diff --git a/datajoint/expression.py b/datajoint/expression.py index 624f1122b..6e4194cbc 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -7,8 +7,14 @@ from .errors import DataJointError from .fetch import Fetch, Fetch1 from .preview import preview, repr_html -from .condition import AndList, Not, \ - make_condition, assert_join_compatibility, extract_column_names, PromiscuousOperand +from .condition import ( + AndList, + Not, + make_condition, + assert_join_compatibility, + extract_column_names, + PromiscuousOperand, +) from .declare import CONSTANT_LITERALS logger = logging.getLogger(__name__) @@ -35,6 +41,7 @@ class QueryExpression: 2. A projection is applied remapping remapped attributes 3. Subclasses: Join, Aggregation, and Union have additional specific rules. """ + _restriction = None _restriction_attributes = None _left = [] # list of booleans True for left joins, False for inner joins @@ -50,36 +57,36 @@ class QueryExpression: @property def connection(self): - """ a dj.Connection object """ + """a dj.Connection object""" assert self._connection is not None return self._connection @property def support(self): - """ A list of table names or subqueries to from the FROM clause """ + """A list of table names or subqueries to from the FROM clause""" assert self._support is not None return self._support @property def heading(self): - """ a dj.Heading object, reflects the effects of the projection operator .proj """ + """a dj.Heading object, reflects the effects of the projection operator .proj""" return self._heading @property def original_heading(self): - """ a dj.Heading object reflecting the attributes before projection """ + """a dj.Heading object reflecting the attributes before projection""" return self._original_heading or self.heading @property def restriction(self): - """ a AndList object of restrictions applied to input to produce the result """ + """a AndList object of restrictions applied to input to produce the result""" if self._restriction is None: self._restriction = AndList() return self._restriction @property def restriction_attributes(self): - """ the set of attribute names invoked in the WHERE clause """ + """the set of attribute names invoked in the WHERE clause""" if self._restriction_attributes is None: self._restriction_attributes = set() return self._restriction_attributes @@ -88,36 +95,44 @@ def restriction_attributes(self): def primary_key(self): return self.heading.primary_key - _subquery_alias_count = count() # count for alias names used in the FROM clause + _subquery_alias_count = count() # count for alias names used in the FROM clause def from_clause(self): - support = ('(' + src.make_sql() + ') as `$%x`' % next( - self._subquery_alias_count) if isinstance(src, QueryExpression) - else src for src in self.support) + support = ( + "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) + if isinstance(src, QueryExpression) + else src + for src in self.support + ) 
clause = next(support) for s, left in zip(support, self._left): - clause += ' NATURAL{left} JOIN {clause}'.format( - left=" LEFT" if left else "", - clause=s) + clause += " NATURAL{left} JOIN {clause}".format( + left=" LEFT" if left else "", clause=s + ) return clause def where_clause(self): - return '' if not self.restriction else ' WHERE (%s)' % ')AND('.join( - str(s) for s in self.restriction) + return ( + "" + if not self.restriction + else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) + ) def make_sql(self, fields=None): """ Make the SQL SELECT statement. :param fields: used to explicitly set the select attributes """ - return 'SELECT {distinct}{fields} FROM {from_}{where}'.format( + return "SELECT {distinct}{fields} FROM {from_}{where}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), - from_=self.from_clause(), where=self.where_clause()) + from_=self.from_clause(), + where=self.where_clause(), + ) # --------- query operators ----------- def make_subquery(self): - """ create a new SELECT statement where self is the FROM clause """ + """create a new SELECT statement where self is the FROM clause""" result = QueryExpression() result._connection = self.connection result._support = [self] @@ -175,19 +190,24 @@ def restrict(self, restriction): return self # restriction has no effect, return the same object # check that all attributes in condition are present in the query try: - raise DataJointError("Attribute `%s` is not found in query." % next( - attr for attr in attributes if attr not in self.heading.names)) + raise DataJointError( + "Attribute `%s` is not found in query." + % next(attr for attr in attributes if attr not in self.heading.names) + ) except StopIteration: pass # all ok # If the new condition uses any new attributes, a subquery is required. # However, Aggregation's HAVING statement works fine with aliased attributes. 
need_subquery = isinstance(self, Union) or ( - not isinstance(self, Aggregation) and self.heading.new_attributes) + not isinstance(self, Aggregation) and self.heading.new_attributes + ) if need_subquery: result = self.make_subquery() else: result = copy.copy(self) - result._restriction = AndList(self.restriction) # copy to preserve the original + result._restriction = AndList( + self.restriction + ) # copy to preserve the original result.restriction.append(new_condition) result.restriction_attributes.update(attributes) return result @@ -266,14 +286,21 @@ def join(self, other, semantic_check=True, left=False): # needs subquery if self's FROM clause has common attributes with other's FROM clause need_subquery1 = need_subquery2 = bool( (set(self.original_heading.names) & set(other.original_heading.names)) - - join_attributes) + - join_attributes + ) # need subquery if any of the join attributes are derived - need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or - any(n in self.heading.new_attributes for n in join_attributes) - or isinstance(self, Union)) - need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or - any(n in other.heading.new_attributes for n in join_attributes) - or isinstance(self, Union)) + need_subquery1 = ( + need_subquery1 + or isinstance(self, Aggregation) + or any(n in self.heading.new_attributes for n in join_attributes) + or isinstance(self, Union) + ) + need_subquery2 = ( + need_subquery2 + or isinstance(other, Aggregation) + or any(n in other.heading.new_attributes for n in join_attributes) + or isinstance(self, Union) + ) if need_subquery1: self = self.make_subquery() if need_subquery2: @@ -314,66 +341,115 @@ def proj(self, *attributes, **named_attributes): Each attribute name can only be used once. 
""" # new attributes in parentheses are included again with the new name without removing original - duplication_pattern = re.compile(fr'^\s*\(\s*(?!{"|".join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*\)\s*$') + duplication_pattern = re.compile( + rf'^\s*\(\s*(?!{"|".join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*\)\s*$' + ) # attributes without parentheses renamed - rename_pattern = re.compile(fr'^\s*(?!{"|".join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*$') - replicate_map = {k: m.group('name') - for k, m in ((k, duplication_pattern.match(v)) for k, v in named_attributes.items()) if m} - rename_map = {k: m.group('name') - for k, m in ((k, rename_pattern.match(v)) for k, v in named_attributes.items()) if m} - compute_map = {k: v for k, v in named_attributes.items() - if not duplication_pattern.match(v) and not rename_pattern.match(v)} + rename_pattern = re.compile( + rf'^\s*(?!{"|".join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*$' + ) + replicate_map = { + k: m.group("name") + for k, m in ( + (k, duplication_pattern.match(v)) for k, v in named_attributes.items() + ) + if m + } + rename_map = { + k: m.group("name") + for k, m in ( + (k, rename_pattern.match(v)) for k, v in named_attributes.items() + ) + if m + } + compute_map = { + k: v + for k, v in named_attributes.items() + if not duplication_pattern.match(v) and not rename_pattern.match(v) + } attributes = set(attributes) # include primary key attributes.update((k for k in self.primary_key if k not in rename_map.values())) # include all secondary attributes with Ellipsis if Ellipsis in attributes: attributes.discard(Ellipsis) - attributes.update((a for a in self.heading.secondary_attributes - if a not in attributes and a not in rename_map.values())) + attributes.update( + ( + a + for a in self.heading.secondary_attributes + if a not in attributes and a not in rename_map.values() + ) + ) try: - raise DataJointError("%s is not a valid data type for an attribute in .proj" % next( - a for a in attributes if not isinstance(a, str))) + raise DataJointError( + "%s is not a valid data type for an attribute in .proj" + % next(a for a in attributes if not isinstance(a, str)) + ) except StopIteration: pass # normal case # remove excluded attributes, specified as `-attr' - excluded = set(a for a in attributes if a.strip().startswith('-')) + excluded = set(a for a in attributes if a.strip().startswith("-")) attributes.difference_update(excluded) - excluded = set(a.lstrip('-').strip() for a in excluded) + excluded = set(a.lstrip("-").strip() for a in excluded) attributes.difference_update(excluded) try: - raise DataJointError("Cannot exclude primary key attribute %s", next( - a for a in excluded if a in self.primary_key)) + raise DataJointError( + "Cannot exclude primary key attribute %s", + next(a for a in excluded if a in self.primary_key), + ) except StopIteration: pass # all ok # check that all attributes exist in heading try: raise DataJointError( - 'Attribute `%s` not found.' % next(a for a in attributes if a not in self.heading.names)) + "Attribute `%s` not found." + % next(a for a in attributes if a not in self.heading.names) + ) except StopIteration: pass # all ok # check that all mentioned names are present in heading mentions = attributes.union(replicate_map.values()).union(rename_map.values()) try: - raise DataJointError("Attribute '%s' not found." % next(a for a in mentions if not self.heading.names)) + raise DataJointError( + "Attribute '%s' not found." 
+ % next(a for a in mentions if not self.heading.names) + ) except StopIteration: pass # all ok # check that newly created attributes do not clash with any other selected attributes try: - raise DataJointError("Attribute `%s` already exists" % next( - a for a in rename_map if a in attributes.union(compute_map).union(replicate_map))) + raise DataJointError( + "Attribute `%s` already exists" + % next( + a + for a in rename_map + if a in attributes.union(compute_map).union(replicate_map) + ) + ) except StopIteration: pass # all ok try: - raise DataJointError("Attribute `%s` already exists" % next( - a for a in compute_map if a in attributes.union(rename_map).union(replicate_map))) + raise DataJointError( + "Attribute `%s` already exists" + % next( + a + for a in compute_map + if a in attributes.union(rename_map).union(replicate_map) + ) + ) except StopIteration: pass # all ok try: - raise DataJointError("Attribute `%s` already exists" % next( - a for a in replicate_map if a in attributes.union(rename_map).union(compute_map))) + raise DataJointError( + "Attribute `%s` already exists" + % next( + a + for a in replicate_map + if a in attributes.union(rename_map).union(compute_map) + ) + ) except StopIteration: pass # all ok @@ -383,15 +459,22 @@ def proj(self, *attributes, **named_attributes): used.update(replicate_map.values()) used.intersection_update(self.heading.names) need_subquery = isinstance(self, Union) or any( - self.heading[name].attribute_expression is not None for name in used) + self.heading[name].attribute_expression is not None for name in used + ) if not need_subquery and self.restriction: # need a subquery if the restriction applies to attributes that have been renamed - need_subquery = any(name in self.restriction_attributes for name in self.heading.new_attributes) + need_subquery = any( + name in self.restriction_attributes + for name in self.heading.new_attributes + ) result = self.make_subquery() if need_subquery else copy.copy(self) result._original_heading = result.original_heading result._heading = result.heading.select( - attributes, rename_map=dict(**rename_map, **replicate_map), compute_map=compute_map) + attributes, + rename_map=dict(**rename_map, **replicate_map), + compute_map=compute_map, + ) return result def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes): @@ -408,8 +491,9 @@ def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes): attributes = set(attributes) attributes.discard(Ellipsis) attributes.update(self.heading.secondary_attributes) - return Aggregation.create( - self, group=group, keep_all_rows=keep_all_rows).proj(*attributes, **named_attributes) + return Aggregation.create(self, group=group, keep_all_rows=keep_all_rows).proj( + *attributes, **named_attributes + ) aggregate = aggr # alias for aggr @@ -445,22 +529,33 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" return self.connection.query( - 'SELECT {select_} FROM {from_}{where}'.format( - select_=('count(*)' if any(self._left) - else 'count(DISTINCT {fields})'.format(fields=self.heading.as_sql( - self.primary_key, include_aliases=False))), + "SELECT {select_} FROM {from_}{where}".format( + select_=( + "count(*)" + if any(self._left) + else "count(DISTINCT {fields})".format( + fields=self.heading.as_sql( + self.primary_key, include_aliases=False + ) + ) + ), from_=self.from_clause(), - where=self.where_clause())).fetchone()[0] + where=self.where_clause(), + ) + ).fetchone()[0] def __bool__(self): """ :return: True if the result is not empty. Equivalent to len(self) > 0 but often faster e.g. ``bool(q1)``. """ - return bool(self.connection.query( - 'SELECT EXISTS(SELECT 1 FROM {from_}{where})'.format( - from_=self.from_clause(), - where=self.where_clause())).fetchone()[0]) + return bool( + self.connection.query( + "SELECT EXISTS(SELECT 1 FROM {from_}{where})".format( + from_=self.from_clause(), where=self.where_clause() + ) + ).fetchone()[0] + ) def __contains__(self, item): """ @@ -479,7 +574,7 @@ def __iter__(self): :param self: iterator-compatible QueryExpression object """ self._iter_only_key = all(v.in_key for v in self.heading.attributes.values()) - self._iter_keys = self.fetch('KEY') + self._iter_keys = self.fetch("KEY") return self def __next__(self): @@ -495,8 +590,10 @@ def __next__(self): key = self._iter_keys.pop(0) except AttributeError: # self._iter_keys is missing because __iter__ has not been called. - raise TypeError("A QueryExpression object is not an iterator. " - "Use iter(obj) to create an iterator.") + raise TypeError( + "A QueryExpression object is not an iterator. " + "Use iter(obj) to create an iterator." + ) except IndexError: raise StopIteration else: @@ -516,12 +613,12 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): :return: query cursor """ if offset and limit is None: - raise DataJointError('limit is required when offset is set') + raise DataJointError("limit is required when offset is set") sql = self.make_sql() if order_by is not None: - sql += ' ORDER BY ' + ', '.join(order_by) + sql += " ORDER BY " + ", ".join(order_by) if limit is not None: - sql += ' LIMIT %d' % limit + (' OFFSET %d' % offset if offset else "") + sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) @@ -533,14 +630,18 @@ def __repr__(self): :type self: :class:`QueryExpression` :rtype: str """ - return super().__repr__() if config['loglevel'].lower() == 'debug' else self.preview() + return ( + super().__repr__() + if config["loglevel"].lower() == "debug" + else self.preview() + ) def preview(self, limit=None, width=None): - """ :return: a string of preview of the contents of the query. """ + """:return: a string of preview of the contents of the query.""" return preview(self, limit, width) def _repr_html_(self): - """ :return: HTML to display table in Jupyter notebook. """ + """:return: HTML to display table in Jupyter notebook.""" return repr_html(self) @@ -553,20 +654,23 @@ class Aggregation(QueryExpression): Aggregation is used QueryExpression.aggr and U.aggr. Aggregation is a private class in DataJoint, not exposed to users. 
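    A typical way it is reached (illustrative sketch; `Session` and `Trial` are
    hypothetical tables):

        # count trials per session, keeping sessions that have no trials
        Session.aggr(Trial, n="count(*)", keep_all_rows=True)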
""" - _left_restrict = None # the pre-GROUP BY conditions for the WHERE clause + + _left_restrict = None # the pre-GROUP BY conditions for the WHERE clause _subquery_alias_count = count() @classmethod def create(cls, arg, group, keep_all_rows=False): if inspect.isclass(group) and issubclass(group, QueryExpression): - group = group() # instantiate if a class + group = group() # instantiate if a class assert isinstance(group, QueryExpression) if keep_all_rows and len(group.support) > 1 or group.heading.new_attributes: group = group.make_subquery() # subquery if left joining a join join = arg.join(group, left=keep_all_rows) # reuse the join logic result = cls() result._connection = join.connection - result._heading = join.heading.set_primary_key(arg.primary_key) # use left operand's primary key + result._heading = join.heading.set_primary_key( + arg.primary_key + ) # use left operand's primary key result._support = join.support result._left = join._left result._left_restrict = join.restriction # WHERE clause applied before GROUP BY @@ -575,37 +679,51 @@ def create(cls, arg, group, keep_all_rows=False): return result def where_clause(self): - return '' if not self._left_restrict else ' WHERE (%s)' % ')AND('.join( - str(s) for s in self._left_restrict) + return ( + "" + if not self._left_restrict + else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict) + ) def make_sql(self, fields=None): fields = self.heading.as_sql(fields or self.heading.names) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) - return 'SELECT {distinct}{fields} FROM {from_}{where}{group_by}'.format( + return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format( distinct="DISTINCT " if distinct else "", fields=fields, from_=self.from_clause(), where=self.where_clause(), - group_by="" if not self.primary_key else ( - " GROUP BY `%s`" % '`,`'.join(self._grouping_attributes) + - ("" if not self.restriction else ' HAVING (%s)' % ')AND('.join(self.restriction)))) + group_by="" + if not self.primary_key + else ( + " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) + + ( + "" + if not self.restriction + else " HAVING (%s)" % ")AND(".join(self.restriction) + ) + ), + ) def __len__(self): return self.connection.query( - 'SELECT count(1) FROM ({subquery}) `${alias:x}`'.format( - subquery=self.make_sql(), - alias=next(self._subquery_alias_count))).fetchone()[0] + "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( + subquery=self.make_sql(), alias=next(self._subquery_alias_count) + ) + ).fetchone()[0] def __bool__(self): - return bool(self.connection.query( - 'SELECT EXISTS({sql})'.format(sql=self.make_sql()))) + return bool( + self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())) + ) class Union(QueryExpression): """ Union is the private DataJoint class that implements the union operator. """ + __count = count() @classmethod @@ -614,15 +732,22 @@ def create(cls, arg1, arg2): arg2 = arg2() # instantiate if a class if not isinstance(arg2, QueryExpression): raise DataJointError( - "A QueryExpression can only be unioned with another QueryExpression") + "A QueryExpression can only be unioned with another QueryExpression" + ) if arg1.connection != arg2.connection: raise DataJointError( - "Cannot operate on QueryExpressions originating from different connections.") + "Cannot operate on QueryExpressions originating from different connections." 
+ ) if set(arg1.primary_key) != set(arg2.primary_key): - raise DataJointError("The operands of a union must share the same primary key.") - if set(arg1.heading.secondary_attributes) & set(arg2.heading.secondary_attributes): raise DataJointError( - "The operands of a union must not share any secondary attributes.") + "The operands of a union must share the same primary key." + ) + if set(arg1.heading.secondary_attributes) & set( + arg2.heading.secondary_attributes + ): + raise DataJointError( + "The operands of a union must not share any secondary attributes." + ) result = cls() result._connection = arg1.connection result._heading = arg1.heading.join(arg2.heading) @@ -631,38 +756,51 @@ def create(cls, arg1, arg2): def make_sql(self): arg1, arg2 = self._support - if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes: + if ( + not arg1.heading.secondary_attributes + and not arg2.heading.secondary_attributes + ): # no secondary attributes: use UNION DISTINCT fields = arg1.primary_key - return ("SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format( - sql1=arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields), - sql2=arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields), - alias=next(self.__count) - )) + return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format( + sql1=arg1.make_sql() + if isinstance(arg1, Union) + else arg1.make_sql(fields), + sql2=arg2.make_sql() + if isinstance(arg2, Union) + else arg2.make_sql(fields), + alias=next(self.__count), + ) # with secondary attributes, use union of left join with antijoin fields = self.heading.names sql1 = arg1.join(arg2, left=True).make_sql(fields) - sql2 = (arg2 - arg1).proj( - ..., **{k: 'NULL' for k in arg1.heading.secondary_attributes}).make_sql(fields) + sql2 = ( + (arg2 - arg1) + .proj(..., **{k: "NULL" for k in arg1.heading.secondary_attributes}) + .make_sql(fields) + ) return "({sql1}) UNION ({sql2})".format(sql1=sql1, sql2=sql2) def from_clause(self): - """ The union does not use a FROM clause """ + """The union does not use a FROM clause""" assert False def where_clause(self): - """ The union does not use a WHERE clause """ + """The union does not use a WHERE clause""" assert False def __len__(self): return self.connection.query( - 'SELECT count(1) FROM ({subquery}) `${alias:x}`'.format( + "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( subquery=self.make_sql(), - alias=next(QueryExpression._subquery_alias_count))).fetchone()[0] + alias=next(QueryExpression._subquery_alias_count), + ) + ).fetchone()[0] def __bool__(self): - return bool(self.connection.query( - 'SELECT EXISTS({sql})'.format(sql=self.make_sql()))) + return bool( + self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())) + ) class U: @@ -725,9 +863,9 @@ def primary_key(self): def __and__(self, other): if inspect.isclass(other) and issubclass(other, QueryExpression): - other = other() # instantiate if a class + other = other() # instantiate if a class if not isinstance(other, QueryExpression): - raise DataJointError('Set U can only be restricted with a QueryExpression.') + raise DataJointError("Set U can only be restricted with a QueryExpression.") result = copy.copy(other) result._distinct = True result._heading = result.heading.set_primary_key(self.primary_key) @@ -744,23 +882,25 @@ def join(self, other, left=False): :return: a copy of the other query expression with the primary key extended. 
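        Example (illustrative sketch; `Session` and `session_date` are hypothetical):

            dj.U("session_date") * Session        # session_date joins the primary key
            dj.U().aggr(Session, n="count(*)")    # total number of rows in Session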
""" if inspect.isclass(other) and issubclass(other, QueryExpression): - other = other() # instantiate if a class + other = other() # instantiate if a class if not isinstance(other, QueryExpression): - raise DataJointError('Set U can only be joined with a QueryExpression.') + raise DataJointError("Set U can only be joined with a QueryExpression.") try: raise DataJointError( - 'Attribute `%s` not found' % next(k for k in self.primary_key - if k not in other.heading.names)) + "Attribute `%s` not found" + % next(k for k in self.primary_key if k not in other.heading.names) + ) except StopIteration: pass # all ok result = copy.copy(other) result._heading = result.heading.set_primary_key( - other.primary_key + [k for k in self.primary_key - if k not in other.primary_key]) + other.primary_key + + [k for k in self.primary_key if k not in other.primary_key] + ) return result def __mul__(self, other): - """ shorthand for join """ + """shorthand for join""" return self.join(other) def aggr(self, group, **named_attributes): @@ -771,9 +911,12 @@ def aggr(self, group, **named_attributes): :param named_attributes: computations of the form new_attribute="sql expression on attributes of group" :return: The derived query expression """ - if named_attributes.get('keep_all_rows', False): + if named_attributes.get("keep_all_rows", False): raise DataJointError( - 'Cannot set keep_all_rows=True when aggregating on a universal set.') - return Aggregation.create(self, group=group, keep_all_rows=False).proj(**named_attributes) + "Cannot set keep_all_rows=True when aggregating on a universal set." + ) + return Aggregation.create(self, group=group, keep_all_rows=False).proj( + **named_attributes + ) aggregate = aggr # alias for aggr diff --git a/datajoint/external.py b/datajoint/external.py index 62120e306..197176e51 100644 --- a/datajoint/external.py +++ b/datajoint/external.py @@ -10,15 +10,22 @@ from . import s3 from .utils import safe_write, safe_copy -CACHE_SUBFOLDING = (2, 2) # (2, 2) means "0123456789abcd" will be saved as "01/23/0123456789abcd" -SUPPORT_MIGRATED_BLOBS = True # support blobs migrated from datajoint 0.11.* +CACHE_SUBFOLDING = ( + 2, + 2, +) # (2, 2) means "0123456789abcd" will be saved as "01/23/0123456789abcd" +SUPPORT_MIGRATED_BLOBS = True # support blobs migrated from datajoint 0.11.* def subfold(name, folds): """ subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde'] """ - return (name[:folds[0]].lower(),) + subfold(name[folds[0]:], folds[1:]) if folds else () + return ( + (name[: folds[0]].lower(),) + subfold(name[folds[0] :], folds[1:]) + if folds + else () + ) class ExternalTable(Table): @@ -26,24 +33,29 @@ class ExternalTable(Table): The table tracking externally stored objects. 
Declare as ExternalTable(connection, database) """ + def __init__(self, connection, store, database): self.store = store self.spec = config.get_store_spec(store) self._s3 = None self.database = database self._connection = connection - self._heading = Heading(table_info=dict( - conn=connection, - database=database, - table_name=self.table_name, - context=None)) + self._heading = Heading( + table_info=dict( + conn=connection, + database=database, + table_name=self.table_name, + context=None, + ) + ) self._support = [self.full_table_name] if not self.is_declared: self.declare() self._s3 = None - if self.spec['protocol'] == 'file' and not Path(self.spec['location']).is_dir(): - raise FileNotFoundError('Inaccessible local directory %s' % - self.spec['location']) from None + if self.spec["protocol"] == "file" and not Path(self.spec["location"]).is_dir(): + raise FileNotFoundError( + "Inaccessible local directory %s" % self.spec["location"] + ) from None @property def definition(self): @@ -60,7 +72,9 @@ def definition(self): @property def table_name(self): - return '{external_table_root}_{store}'.format(external_table_root=EXTERNAL_TABLE_ROOT, store=self.store) + return "{external_table_root}_{store}".format( + external_table_root=EXTERNAL_TABLE_ROOT, store=self.store + ) @property def s3(self): @@ -73,60 +87,66 @@ def s3(self): def _make_external_filepath(self, relative_filepath): """resolve the complete external path based on the relative path""" # Strip root - if self.spec['protocol'] == 's3': - posix_path = PurePosixPath(PureWindowsPath(self.spec['location'])) - location_path = Path( - *posix_path.parts[1:]) if len( - self.spec['location']) > 0 and any( - case in posix_path.parts[0] for case in ( - '\\', ':')) else Path(posix_path) + if self.spec["protocol"] == "s3": + posix_path = PurePosixPath(PureWindowsPath(self.spec["location"])) + location_path = ( + Path(*posix_path.parts[1:]) + if len(self.spec["location"]) > 0 + and any(case in posix_path.parts[0] for case in ("\\", ":")) + else Path(posix_path) + ) return PurePosixPath(location_path, relative_filepath) # Preserve root - elif self.spec['protocol'] == 'file': - return PurePosixPath(Path(self.spec['location']), relative_filepath) + elif self.spec["protocol"] == "file": + return PurePosixPath(Path(self.spec["location"]), relative_filepath) else: assert False - def _make_uuid_path(self, uuid, suffix=''): + def _make_uuid_path(self, uuid, suffix=""): """create external path based on the uuid hash""" - return self._make_external_filepath(PurePosixPath( - self.database, '/'.join(subfold(uuid.hex, self.spec['subfolding'])), uuid.hex).with_suffix(suffix)) + return self._make_external_filepath( + PurePosixPath( + self.database, + "/".join(subfold(uuid.hex, self.spec["subfolding"])), + uuid.hex, + ).with_suffix(suffix) + ) def _upload_file(self, local_path, external_path, metadata=None): - if self.spec['protocol'] == 's3': + if self.spec["protocol"] == "s3": self.s3.fput(local_path, external_path, metadata) - elif self.spec['protocol'] == 'file': + elif self.spec["protocol"] == "file": safe_copy(local_path, external_path, overwrite=True) else: assert False def _download_file(self, external_path, download_path): - if self.spec['protocol'] == 's3': + if self.spec["protocol"] == "s3": self.s3.fget(external_path, download_path) - elif self.spec['protocol'] == 'file': + elif self.spec["protocol"] == "file": safe_copy(external_path, download_path) else: assert False def _upload_buffer(self, buffer, external_path): - if self.spec['protocol'] == 's3': 
+ if self.spec["protocol"] == "s3": self.s3.put(external_path, buffer) - elif self.spec['protocol'] == 'file': + elif self.spec["protocol"] == "file": safe_write(external_path, buffer) else: assert False def _download_buffer(self, external_path): - if self.spec['protocol'] == 's3': + if self.spec["protocol"] == "s3": return self.s3.get(external_path) - if self.spec['protocol'] == 'file': + if self.spec["protocol"] == "file": return Path(external_path).read_bytes() assert False def _remove_external_file(self, external_path): - if self.spec['protocol'] == 's3': + if self.spec["protocol"] == "s3": self.s3.remove_object(external_path) - elif self.spec['protocol'] == 'file': + elif self.spec["protocol"] == "file": try: Path(external_path).unlink() except FileNotFoundError: @@ -136,9 +156,9 @@ def exists(self, external_filepath): """ :return: True if the external file is accessible """ - if self.spec['protocol'] == 's3': + if self.spec["protocol"] == "s3": return self.s3.exists(external_filepath) - if self.spec['protocol'] == 'file': + if self.spec["protocol"] == "file": return Path(external_filepath).is_file() assert False @@ -154,7 +174,10 @@ def put(self, blob): self.connection.query( "INSERT INTO {tab} (hash, size) VALUES (%s, {size}) ON DUPLICATE KEY " "UPDATE timestamp=CURRENT_TIMESTAMP".format( - tab=self.full_table_name, size=len(blob)), args=(uuid.bytes,)) + tab=self.full_table_name, size=len(blob) + ), + args=(uuid.bytes,), + ) return uuid def get(self, uuid): @@ -165,7 +188,7 @@ def get(self, uuid): return None # attempt to get object from cache blob = None - cache_folder = config.get('cache', None) + cache_folder = config.get("cache", None) if cache_folder: try: cache_path = Path(cache_folder, *subfold(uuid.hex, CACHE_SUBFOLDING)) @@ -181,10 +204,14 @@ def get(self, uuid): if not SUPPORT_MIGRATED_BLOBS: raise # blobs migrated from datajoint 0.11 are stored at explicitly defined filepaths - relative_filepath, contents_hash = (self & {'hash': uuid}).fetch1('filepath', 'contents_hash') + relative_filepath, contents_hash = (self & {"hash": uuid}).fetch1( + "filepath", "contents_hash" + ) if relative_filepath is None: raise - blob = self._download_buffer(self._make_external_filepath(relative_filepath)) + blob = self._download_buffer( + self._make_external_filepath(relative_filepath) + ) if cache_folder: cache_path.mkdir(parents=True, exist_ok=True) safe_write(cache_path / uuid.hex, blob) @@ -194,25 +221,29 @@ def get(self, uuid): def upload_attachment(self, local_path): attachment_name = Path(local_path).name - uuid = uuid_from_file(local_path, init_string=attachment_name + '\0') - external_path = self._make_uuid_path(uuid, '.' + attachment_name) + uuid = uuid_from_file(local_path, init_string=attachment_name + "\0") + external_path = self._make_uuid_path(uuid, "." 
+ attachment_name) self._upload_file(local_path, external_path) # insert tracking info - self.connection.query(""" + self.connection.query( + """ INSERT INTO {tab} (hash, size, attachment_name) VALUES (%s, {size}, "{attachment_name}") ON DUPLICATE KEY UPDATE timestamp=CURRENT_TIMESTAMP""".format( tab=self.full_table_name, size=Path(local_path).stat().st_size, - attachment_name=attachment_name), args=[uuid.bytes]) + attachment_name=attachment_name, + ), + args=[uuid.bytes], + ) return uuid def get_attachment_name(self, uuid): - return (self & {'hash': uuid}).fetch1('attachment_name') + return (self & {"hash": uuid}).fetch1("attachment_name") def download_attachment(self, uuid, attachment_name, download_path): - """ save attachment from memory buffer into the save_path """ - external_path = self._make_uuid_path(uuid, '.' + attachment_name) + """save attachment from memory buffer into the save_path""" + external_path = self._make_uuid_path(uuid, "." + attachment_name) self._download_file(external_path, download_path) # --- FILEPATH --- @@ -225,28 +256,45 @@ def upload_filepath(self, local_filepath): """ local_filepath = Path(local_filepath) try: - relative_filepath = str(local_filepath.relative_to(self.spec['stage']).as_posix()) + relative_filepath = str( + local_filepath.relative_to(self.spec["stage"]).as_posix() + ) except ValueError: - raise DataJointError('The path {path} is not in stage {stage}'.format( - path=local_filepath.parent, **self.spec)) - uuid = uuid_from_buffer(init_string=relative_filepath) # hash relative path, not contents + raise DataJointError( + "The path {path} is not in stage {stage}".format( + path=local_filepath.parent, **self.spec + ) + ) + uuid = uuid_from_buffer( + init_string=relative_filepath + ) # hash relative path, not contents contents_hash = uuid_from_file(local_filepath) # check if the remote file already exists and verify that it matches - check_hash = (self & {'hash': uuid}).fetch('contents_hash') + check_hash = (self & {"hash": uuid}).fetch("contents_hash") if check_hash: # the tracking entry exists, check that it's the same file as before if contents_hash != check_hash[0]: raise DataJointError( - "A different version of '{file}' has already been placed.".format(file=relative_filepath)) + "A different version of '{file}' has already been placed.".format( + file=relative_filepath + ) + ) else: # upload the file and create its tracking entry - self._upload_file(local_filepath, self._make_external_filepath(relative_filepath), - metadata={'contents_hash': str(contents_hash)}) + self._upload_file( + local_filepath, + self._make_external_filepath(relative_filepath), + metadata={"contents_hash": str(contents_hash)}, + ) self.connection.query( "INSERT INTO {tab} (hash, size, filepath, contents_hash) VALUES (%s, {size}, '{filepath}', %s)".format( - tab=self.full_table_name, size=Path(local_filepath).stat().st_size, - filepath=relative_filepath), args=(uuid.bytes, contents_hash.bytes)) + tab=self.full_table_name, + size=Path(local_filepath).stat().st_size, + filepath=relative_filepath, + ), + args=(uuid.bytes, contents_hash.bytes), + ) return uuid def download_filepath(self, filepath_hash): @@ -256,15 +304,26 @@ def download_filepath(self, filepath_hash): :return: hash (UUID) of the contents of the downloaded file or Nones """ if filepath_hash is not None: - relative_filepath, contents_hash = (self & {'hash': filepath_hash}).fetch1('filepath', 'contents_hash') + relative_filepath, contents_hash = (self & {"hash": filepath_hash}).fetch1( + "filepath", 
"contents_hash" + ) external_path = self._make_external_filepath(relative_filepath) - local_filepath = Path(self.spec['stage']).absolute() / relative_filepath - file_exists = Path(local_filepath).is_file() and uuid_from_file(local_filepath) == contents_hash + local_filepath = Path(self.spec["stage"]).absolute() / relative_filepath + file_exists = ( + Path(local_filepath).is_file() + and uuid_from_file(local_filepath) == contents_hash + ) if not file_exists: self._download_file(external_path, local_filepath) checksum = uuid_from_file(local_filepath) - if checksum != contents_hash: # this should never happen without outside interference - raise DataJointError("'{file}' downloaded but did not pass checksum'".format(file=local_filepath)) + if ( + checksum != contents_hash + ): # this should never happen without outside interference + raise DataJointError( + "'{file}' downloaded but did not pass checksum'".format( + file=local_filepath + ) + ) return str(local_filepath), contents_hash # --- UTILITIES --- @@ -274,11 +333,19 @@ def references(self): """ :return: generator of referencing table names and their referencing columns """ - return ({k.lower(): v for k, v in elem.items()} for elem in self.connection.query(""" + return ( + {k.lower(): v for k, v in elem.items()} + for elem in self.connection.query( + """ SELECT concat('`', table_schema, '`.`', table_name, '`') as referencing_table, column_name FROM information_schema.key_column_usage WHERE referenced_table_name="{tab}" and referenced_table_schema="{db}" - """.format(tab=self.table_name, db=self.database), as_dict=True)) + """.format( + tab=self.table_name, db=self.database + ), + as_dict=True, + ) + ) def fetch_external_paths(self, **fetch_kwargs): """ @@ -288,17 +355,17 @@ def fetch_external_paths(self, **fetch_kwargs): """ fetch_kwargs.update(as_dict=True) paths = [] - for item in self.fetch('hash', 'attachment_name', 'filepath', **fetch_kwargs): - if item['attachment_name']: + for item in self.fetch("hash", "attachment_name", "filepath", **fetch_kwargs): + if item["attachment_name"]: # attachments - path = self._make_uuid_path(item['hash'], '.' + item['attachment_name']) - elif item['filepath']: + path = self._make_uuid_path(item["hash"], "." 
+ item["attachment_name"]) + elif item["filepath"]: # external filepaths - path = self._make_external_filepath(item['filepath']) + path = self._make_external_filepath(item["filepath"]) else: # blobs - path = self._make_uuid_path(item['hash']) - paths.append((item['hash'], path)) + path = self._make_uuid_path(item["hash"]) + paths.append((item["hash"], path)) return paths def unused(self): @@ -306,18 +373,33 @@ def unused(self): query expression for unused hashes :return: self restricted to elements that are not in use by any tables in the schema """ - return self - [FreeTable(self.connection, ref['referencing_table']).proj(hash=ref['column_name']) - for ref in self.references] + return self - [ + FreeTable(self.connection, ref["referencing_table"]).proj( + hash=ref["column_name"] + ) + for ref in self.references + ] def used(self): """ query expression for used hashes :return: self restricted to elements that in use by tables in the schema """ - return self & [FreeTable(self.connection, ref['referencing_table']).proj(hash=ref['column_name']) - for ref in self.references] - - def delete(self, *, delete_external_files=None, limit=None, display_progress=True, errors_as_string=True): + return self & [ + FreeTable(self.connection, ref["referencing_table"]).proj( + hash=ref["column_name"] + ) + for ref in self.references + ] + + def delete( + self, + *, + delete_external_files=None, + limit=None, + display_progress=True, + errors_as_string=True + ): """ :param delete_external_files: True or False. If False, only the tracking info is removed from the external store table but the external files remain intact. If True, then the external files @@ -330,7 +412,8 @@ def delete(self, *, delete_external_files=None, limit=None, display_progress=Tru if delete_external_files not in (True, False): raise DataJointError( "The delete_external_files argument must be set to either " - "True or False in delete()") + "True or False in delete()" + ) if not delete_external_files: self.unused().delete_quick() @@ -341,20 +424,25 @@ def delete(self, *, delete_external_files=None, limit=None, display_progress=Tru # delete items one by one, close to transaction-safe error_list = [] for uuid, external_path in items: - row = (self & {'hash': uuid}).fetch() + row = (self & {"hash": uuid}).fetch() if row.size: try: - (self & {'hash': uuid}).delete_quick() + (self & {"hash": uuid}).delete_quick() except Exception: - pass # if delete failed, do not remove the external file + pass # if delete failed, do not remove the external file else: try: self._remove_external_file(external_path) except Exception as error: # adding row back into table after failed delete self.insert1(row[0], skip_duplicates=True) - error_list.append((uuid, external_path, - str(error) if errors_as_string else error)) + error_list.append( + ( + uuid, + external_path, + str(error) if errors_as_string else error, + ) + ) return error_list @@ -365,14 +453,18 @@ class ExternalMapping(Mapping): e = ExternalMapping(schema) external_table = e[store] """ + def __init__(self, schema): self.schema = schema self._tables = {} def __repr__(self): - return ("External file tables for schema `{schema}`:\n ".format(schema=self.schema.database) - + "\n ".join('"{store}" {protocol}:{location}'.format( - store=k, **v.spec) for k, v in self.items())) + return "External file tables for schema `{schema}`:\n ".format( + schema=self.schema.database + ) + "\n ".join( + '"{store}" {protocol}:{location}'.format(store=k, **v.spec) + for k, v in self.items() + ) def __getitem__(self, 
store): """ @@ -383,7 +475,10 @@ def __getitem__(self, store): """ if store not in self._tables: self._tables[store] = ExternalTable( - connection=self.schema.connection, store=store, database=self.schema.database) + connection=self.schema.connection, + store=store, + database=self.schema.database, + ) return self._tables[store] def __len__(self): diff --git a/datajoint/fetch.py b/datajoint/fetch.py index caf17d5ac..8c22ddee0 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -18,11 +18,12 @@ class key: object that allows requesting the primary key as an argument in expression.fetch() The string "KEY" can be used instead of the class key """ + pass def is_key(attr): - return attr is key or attr == 'KEY' + return attr is key or attr == "KEY" def to_dicts(recarray): @@ -45,7 +46,11 @@ def _get(connection, attr, data, squeeze, download_path): if data is None: return - extern = connection.schemas[attr.database].external[attr.store] if attr.is_external else None + extern = ( + connection.schemas[attr.database].external[attr.store] + if attr.is_external + else None + ) # apply attribute adapter if present adapt = attr.adapter.get if attr.adapter else lambda x: x @@ -60,20 +65,33 @@ def _get(connection, attr, data, squeeze, download_path): # 3. if exists and checksum passes then return the local filepath # 4. Otherwise, download the remote file and return the new filepath _uuid = uuid.UUID(bytes=data) if attr.is_external else None - attachment_name = (extern.get_attachment_name(_uuid) if attr.is_external - else data.split(b"\0", 1)[0].decode()) + attachment_name = ( + extern.get_attachment_name(_uuid) + if attr.is_external + else data.split(b"\0", 1)[0].decode() + ) local_filepath = Path(download_path) / attachment_name if local_filepath.is_file(): - attachment_checksum = _uuid if attr.is_external else hash.uuid_from_buffer(data) - if attachment_checksum == hash.uuid_from_file(local_filepath, init_string=attachment_name + '\0'): - return adapt(str(local_filepath)) # checksum passed, no need to download again + attachment_checksum = ( + _uuid if attr.is_external else hash.uuid_from_buffer(data) + ) + if attachment_checksum == hash.uuid_from_file( + local_filepath, init_string=attachment_name + "\0" + ): + return adapt( + str(local_filepath) + ) # checksum passed, no need to download again # generate the next available alias filename for n in itertools.count(): - f = local_filepath.parent / (local_filepath.stem + '_%04x' % n + local_filepath.suffix) + f = local_filepath.parent / ( + local_filepath.stem + "_%04x" % n + local_filepath.suffix + ) if not f.is_file(): local_filepath = f break - if attachment_checksum == hash.uuid_from_file(f, init_string=attachment_name + '\0'): + if attachment_checksum == hash.uuid_from_file( + f, init_string=attachment_name + "\0" + ): return adapt(str(f)) # checksum passed, no need to download again # Save attachment if attr.is_external: @@ -83,9 +101,18 @@ def _get(connection, attr, data, squeeze, download_path): safe_write(local_filepath, data.split(b"\0", 1)[1]) return adapt(str(local_filepath)) # download file from remote store - return adapt(uuid.UUID(bytes=data) if attr.uuid else ( - blob.unpack(extern.get(uuid.UUID(bytes=data)) if attr.is_external else data, squeeze=squeeze) - if attr.is_blob else data)) + return adapt( + uuid.UUID(bytes=data) + if attr.uuid + else ( + blob.unpack( + extern.get(uuid.UUID(bytes=data)) if attr.is_external else data, + squeeze=squeeze, + ) + if attr.is_blob + else data + ) + ) def _flatten_attribute_list(primary_key, 
attrs): @@ -95,10 +122,10 @@ def _flatten_attribute_list(primary_key, attrs): :return: generator of attributes where "KEY" is replaces with its component attributes """ for a in attrs: - if re.match(r'^\s*KEY(\s+[aA][Ss][Cc])?\s*$', a): + if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): yield from primary_key - elif re.match(r'^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$', a): - yield from (q + ' DESC' for q in primary_key) + elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): + yield from (q + " DESC" for q in primary_key) else: yield a @@ -112,8 +139,17 @@ class Fetch: def __init__(self, expression): self._expression = expression - def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None, as_dict=None, - squeeze=False, download_path='.'): + def __call__( + self, + *attrs, + offset=None, + limit=None, + order_by=None, + format=None, + as_dict=None, + squeeze=False, + download_path="." + ): """ Fetches the expression results from the database into an np.array or list of dictionaries and unpacks blob attributes. @@ -140,68 +176,109 @@ def __call__(self, *attrs, offset=None, limit=None, order_by=None, format=None, if isinstance(order_by, str): order_by = [order_by] # expand "KEY" or "KEY DESC" - order_by = list(_flatten_attribute_list(self._expression.primary_key, order_by)) + order_by = list( + _flatten_attribute_list(self._expression.primary_key, order_by) + ) attrs_as_dict = as_dict and attrs if attrs_as_dict: # absorb KEY into attrs and prepare to return attributes as dict (issue #595) if any(is_key(k) for k in attrs): attrs = list(self._expression.primary_key) + [ - a for a in attrs if a not in self._expression.primary_key] + a for a in attrs if a not in self._expression.primary_key + ] if as_dict is None: as_dict = bool(attrs) # default to True for "KEY" and False otherwise # format should not be specified with attrs or is_dict=True if format is not None and (as_dict or attrs): - raise DataJointError('Cannot specify output format when as_dict=True or ' - 'when attributes are selected to be fetched separately.') + raise DataJointError( + "Cannot specify output format when as_dict=True or " + "when attributes are selected to be fetched separately." + ) if format not in {None, "array", "frame"}: raise DataJointError( - 'Fetch output format must be in ' - '{{"array", "frame"}} but "{}" was given'.format(format)) + "Fetch output format must be in " + '{{"array", "frame"}} but "{}" was given'.format(format) + ) if not (attrs or as_dict) and format is None: - format = config['fetch_format'] # default to array + format = config["fetch_format"] # default to array if format not in {"array", "frame"}: raise DataJointError( 'Invalid entry "{}" in datajoint.config["fetch_format"]: ' - 'use "array" or "frame"'.format(format)) + 'use "array" or "frame"'.format(format) + ) if limit is None and offset is not None: - warnings.warn('Offset set, but no limit. Setting limit to a large number. ' - 'Consider setting a limit explicitly.') + warnings.warn( + "Offset set, but no limit. Setting limit to a large number. " + "Consider setting a limit explicitly." 
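# Illustrative sketch of the "KEY" expansion above (attribute names are
# hypothetical): with a primary key of ("subject_id", "session_id"),
#
#     Session.fetch(order_by=["KEY DESC"], limit=10)
#
# sorts by "subject_id DESC, session_id DESC" before the limit is applied.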
+ ) limit = 8000000000 # just a very large number to effect no limit - get = partial(_get, self._expression.connection, - squeeze=squeeze, download_path=download_path) + get = partial( + _get, + self._expression.connection, + squeeze=squeeze, + download_path=download_path, + ) if attrs: # a list of attributes provided attributes = [a for a in attrs if not is_key(a)] ret = self._expression.proj(*attributes) ret = ret.fetch( - offset=offset, limit=limit, order_by=order_by, - as_dict=False, squeeze=squeeze, download_path=download_path, - format='array') + offset=offset, + limit=limit, + order_by=order_by, + as_dict=False, + squeeze=squeeze, + download_path=download_path, + format="array", + ) if attrs_as_dict: - ret = [{k: v for k, v in zip(ret.dtype.names, x) if k in attrs} for x in ret] + ret = [ + {k: v for k, v in zip(ret.dtype.names, x) if k in attrs} + for x in ret + ] else: - return_values = [list( - (to_dicts if as_dict else lambda x: x)(ret[self._expression.primary_key])) - if is_key(attribute) else ret[attribute] - for attribute in attrs] + return_values = [ + list( + (to_dicts if as_dict else lambda x: x)( + ret[self._expression.primary_key] + ) + ) + if is_key(attribute) + else ret[attribute] + for attribute in attrs + ] ret = return_values[0] if len(attrs) == 1 else return_values else: # fetch all attributes as a numpy.record_array or pandas.DataFrame cur = self._expression.cursor( - as_dict=as_dict, limit=limit, offset=offset, order_by=order_by) + as_dict=as_dict, limit=limit, offset=offset, order_by=order_by + ) heading = self._expression.heading if as_dict: - ret = [dict((name, get(heading[name], d[name])) - for name in heading.names) for d in cur] + ret = [ + dict((name, get(heading[name], d[name])) for name in heading.names) + for d in cur + ] else: ret = list(cur.fetchall()) - record_type = (heading.as_dtype if not ret else np.dtype( - [(name, type(value)) # use the first element to determine blob type - if heading[name].is_blob and isinstance(value, numbers.Number) - else (name, heading.as_dtype[name]) - for value, name in zip(ret[0], heading.as_dtype.names)])) + record_type = ( + heading.as_dtype + if not ret + else np.dtype( + [ + ( + name, + type(value), + ) # use the first element to determine blob type + if heading[name].is_blob + and isinstance(value, numbers.Number) + else (name, heading.as_dtype[name]) + for value, name in zip(ret[0], heading.as_dtype.names) + ] + ) + ) try: ret = np.array(ret, dtype=record_type) except Exception as e: @@ -219,10 +296,11 @@ class Fetch1: Fetch object for fetching the result of a query yielding one row. :param expression: a query expression to fetch from. """ + def __init__(self, expression): self._expression = expression - def __call__(self, *attrs, squeeze=False, download_path='.'): + def __call__(self, *attrs, squeeze=False, download_path="."): """ Fetches the result of a query expression that yields one entry. @@ -245,20 +323,36 @@ def __call__(self, *attrs, squeeze=False, download_path='.'): cur = self._expression.cursor(as_dict=True) ret = cur.fetchone() if not ret or cur.fetchone(): - raise DataJointError('fetch1 requires exactly one tuple in the input set.') - ret = dict((name, _get(self._expression.connection, heading[name], ret[name], - squeeze=squeeze, download_path=download_path)) - for name in heading.names) + raise DataJointError( + "fetch1 requires exactly one tuple in the input set." 
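# Illustrative sketch of the two fetch1 call forms handled here (table and
# attribute names are hypothetical):
#
#     row = (Session & key).fetch1()                    # dict of all attributes
#     ts, note = (Session & key).fetch1("ts", "note")   # tuple of selected values
#
# Either form requires the restricted query to contain exactly one row.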
+ ) + ret = dict( + ( + name, + _get( + self._expression.connection, + heading[name], + ret[name], + squeeze=squeeze, + download_path=download_path, + ), + ) + for name in heading.names + ) else: # fetch some attributes, return as tuple attributes = [a for a in attrs if not is_key(a)] result = self._expression.proj(*attributes).fetch( - squeeze=squeeze, download_path=download_path, format="array") + squeeze=squeeze, download_path=download_path, format="array" + ) if len(result) != 1: raise DataJointError( - 'fetch1 should only return one tuple. %d tuples found' % len(result)) + "fetch1 should only return one tuple. %d tuples found" % len(result) + ) return_values = tuple( next(to_dicts(result[self._expression.primary_key])) - if is_key(attribute) else result[attribute][0] - for attribute in attrs) + if is_key(attribute) + else result[attribute][0] + for attribute in attrs + ) ret = return_values[0] if len(attrs) == 1 else return_values return ret diff --git a/datajoint/heading.py b/datajoint/heading.py index 076a2204e..14f54e0cf 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -4,36 +4,62 @@ import re import logging from .errors import DataJointError, _support_filepath_types, FILEPATH_FEATURE_SWITCH -from .declare import UUID_DATA_TYPE, SPECIAL_TYPES, TYPE_PATTERN, EXTERNAL_TYPES, NATIVE_TYPES +from .declare import ( + UUID_DATA_TYPE, + SPECIAL_TYPES, + TYPE_PATTERN, + EXTERNAL_TYPES, + NATIVE_TYPES, +) from .attribute_adapter import get_adapter, AttributeAdapter logger = logging.getLogger(__name__) -default_attribute_properties = dict( # these default values are set in computed attributes - name=None, type='expression', in_key=False, nullable=False, default=None, comment='calculated attribute', - autoincrement=False, numeric=None, string=None, uuid=False, is_blob=False, is_attachment=False, is_filepath=False, - is_external=False, adapter=None, - store=None, unsupported=False, attribute_expression=None, database=None, dtype=object) - - -class Attribute(namedtuple('_Attribute', default_attribute_properties)): +default_attribute_properties = ( + dict( # these default values are set in computed attributes + name=None, + type="expression", + in_key=False, + nullable=False, + default=None, + comment="calculated attribute", + autoincrement=False, + numeric=None, + string=None, + uuid=False, + is_blob=False, + is_attachment=False, + is_filepath=False, + is_external=False, + adapter=None, + store=None, + unsupported=False, + attribute_expression=None, + database=None, + dtype=object, + ) +) + + +class Attribute(namedtuple("_Attribute", default_attribute_properties)): """ Properties of a table column (attribute) """ + def todict(self): """Convert namedtuple to dict.""" return dict((name, self[i]) for i, name in enumerate(self._fields)) @property def sql_type(self): - """ :return: datatype (as string) in database. In most cases, it is the same as self.type """ + """:return: datatype (as string) in database. In most cases, it is the same as self.type""" return UUID_DATA_TYPE if self.uuid else self.type @property def sql_comment(self): - """ :return: full comment for the SQL declaration. Includes custom type specification """ - return (':uuid:' if self.uuid else '') + self.comment + """:return: full comment for the SQL declaration. 
Includes custom type specification""" + return (":uuid:" if self.uuid else "") + self.comment @property def sql(self): @@ -44,14 +70,15 @@ def sql(self): :return: SQL code for attribute declaration """ return '`{name}` {type} NOT NULL COMMENT "{comment}"'.format( - name=self.name, type=self.sql_type, comment=self.sql_comment) + name=self.name, type=self.sql_type, comment=self.sql_comment + ) @property def original_name(self): if self.attribute_expression is None: return self.name - assert self.attribute_expression.startswith('`') - return self.attribute_expression.strip('`') + assert self.attribute_expression.startswith("`") + return self.attribute_expression.strip("`") class Heading: @@ -69,8 +96,11 @@ def __init__(self, attribute_specs=None, table_info=None): self.indexes = None self.table_info = table_info self._table_status = None - self._attributes = None if attribute_specs is None else dict( - (q['name'], Attribute(**q)) for q in attribute_specs) + self._attributes = ( + None + if attribute_specs is None + else dict((q["name"], Attribute(**q)) for q in attribute_specs) + ) def __len__(self): return 0 if self.attributes is None else len(self.attributes) @@ -86,7 +116,7 @@ def table_status(self): @property def attributes(self): if self._attributes is None: - self._init_from_database() # lazy loading from database + self._init_from_database() # lazy loading from database return self._attributes @property @@ -107,11 +137,17 @@ def blobs(self): @property def non_blobs(self): - return [k for k, v in self.attributes.items() if not v.is_blob and not v.is_attachment and not v.is_filepath] + return [ + k + for k, v in self.attributes.items() + if not v.is_blob and not v.is_attachment and not v.is_filepath + ] @property def new_attributes(self): - return [k for k, v in self.attributes.items() if v.attribute_expression is not None] + return [ + k for k, v in self.attributes.items() if v.attribute_expression is not None + ] def __getitem__(self, name): """shortcut to the attribute""" @@ -122,16 +158,18 @@ def __repr__(self): :return: heading representation in DataJoint declaration format but without foreign key expansion """ in_key = True - ret = '' + ret = "" if self._table_status is not None: - ret += '# ' + self.table_status['comment'] + '\n' + ret += "# " + self.table_status["comment"] + "\n" for v in self.attributes.values(): if in_key and not v.in_key: - ret += '---\n' + ret += "---\n" in_key = False - ret += '%-20s : %-28s # %s\n' % ( - v.name if v.default is None else '%s=%s' % (v.name, v.default), - '%s%s' % (v.type, 'auto_increment' if v.autoincrement else ''), v.comment) + ret += "%-20s : %-28s # %s\n" % ( + v.name if v.default is None else "%s=%s" % (v.name, v.default), + "%s%s" % (v.type, "auto_increment" if v.autoincrement else ""), + v.comment, + ) return ret @property @@ -143,180 +181,254 @@ def as_dtype(self): """ represent the heading as a numpy dtype """ - return np.dtype(dict( - names=self.names, - formats=[v.dtype for v in self.attributes.values()])) + return np.dtype( + dict(names=self.names, formats=[v.dtype for v in self.attributes.values()]) + ) def as_sql(self, fields, include_aliases=True): """ represent heading as the SQL SELECT clause. 
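        For example (illustrative; attribute names are hypothetical), a stored
        attribute renders as `mouse_id`, while a renamed attribute whose
        attribute_expression is `ts` renders as `ts` as `session_ts` when
        include_aliases=True and as `ts` alone otherwise.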
""" - return ','.join( - '`%s`' % name if self.attributes[name].attribute_expression is None - else self.attributes[name].attribute_expression + (' as `%s`' % name if include_aliases else '') - for name in fields) + return ",".join( + "`%s`" % name + if self.attributes[name].attribute_expression is None + else self.attributes[name].attribute_expression + + (" as `%s`" % name if include_aliases else "") + for name in fields + ) def __iter__(self): return iter(self.attributes) def _init_from_database(self): - """ initialize heading from an existing database table. """ + """initialize heading from an existing database table.""" conn, database, table_name, context = ( - self.table_info[k] for k in ('conn', 'database', 'table_name', 'context')) - info = conn.query('SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( - table_name=table_name, database=database), as_dict=True).fetchone() + self.table_info[k] for k in ("conn", "database", "table_name", "context") + ) + info = conn.query( + 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format( + table_name=table_name, database=database + ), + as_dict=True, + ).fetchone() if info is None: - if table_name == '~log': - logger.warning('Could not create the ~log table') + if table_name == "~log": + logger.warning("Could not create the ~log table") return - raise DataJointError('The table `{database}`.`{table_name}` is not defined.'.format( - table_name=table_name, database=database)) + raise DataJointError( + "The table `{database}`.`{table_name}` is not defined.".format( + table_name=table_name, database=database + ) + ) self._table_status = {k.lower(): v for k, v in info.items()} cur = conn.query( - 'SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`'.format( - table_name=table_name, database=database), as_dict=True) + "SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`".format( + table_name=table_name, database=database + ), + as_dict=True, + ) attributes = cur.fetchall() rename_map = { - 'Field': 'name', - 'Type': 'type', - 'Null': 'nullable', - 'Default': 'default', - 'Key': 'in_key', - 'Comment': 'comment'} + "Field": "name", + "Type": "type", + "Null": "nullable", + "Default": "default", + "Key": "in_key", + "Comment": "comment", + } - fields_to_drop = ('Privileges', 'Collation') + fields_to_drop = ("Privileges", "Collation") # rename and drop attributes - attributes = [{rename_map[k] if k in rename_map else k: v - for k, v in x.items() if k not in fields_to_drop} - for x in attributes] + attributes = [ + { + rename_map[k] if k in rename_map else k: v + for k, v in x.items() + if k not in fields_to_drop + } + for x in attributes + ] numeric_types = { - ('float', False): np.float64, - ('float', True): np.float64, - ('double', False): np.float64, - ('double', True): np.float64, - ('tinyint', False): np.int64, - ('tinyint', True): np.int64, - ('smallint', False): np.int64, - ('smallint', True): np.int64, - ('mediumint', False): np.int64, - ('mediumint', True): np.int64, - ('int', False): np.int64, - ('int', True): np.int64, - ('bigint', False): np.int64, - ('bigint', True): np.uint64} - - sql_literals = ['CURRENT_TIMESTAMP'] + ("float", False): np.float64, + ("float", True): np.float64, + ("double", False): np.float64, + ("double", True): np.float64, + ("tinyint", False): np.int64, + ("tinyint", True): np.int64, + ("smallint", False): np.int64, + ("smallint", True): np.int64, + ("mediumint", False): np.int64, + ("mediumint", True): np.int64, + ("int", False): np.int64, + ("int", True): np.int64, + ("bigint", 
False): np.int64,
+            ("bigint", True): np.uint64,
+        }
+
+        sql_literals = ["CURRENT_TIMESTAMP"]
        # additional attribute properties
        for attr in attributes:
            attr.update(
-                in_key=(attr['in_key'] == 'PRI'),
+                in_key=(attr["in_key"] == "PRI"),
                database=database,
-                nullable=attr['nullable'] == 'YES',
-                autoincrement=bool(re.search(r'auto_increment', attr['Extra'], flags=re.I)),
-                numeric=any(TYPE_PATTERN[t].match(attr['type']) for t in ('DECIMAL', 'INTEGER', 'FLOAT')),
-                string=any(TYPE_PATTERN[t].match(attr['type']) for t in ('ENUM', 'TEMPORAL', 'STRING')),
-                is_blob=bool(TYPE_PATTERN['INTERNAL_BLOB'].match(attr['type'])),
-                uuid=False, is_attachment=False, is_filepath=False, adapter=None,
-                store=None, is_external=False, attribute_expression=None)
-
-            if any(TYPE_PATTERN[t].match(attr['type']) for t in ('INTEGER', 'FLOAT')):
-                attr['type'] = re.sub(r'\(\d+\)', '', attr['type'], count=1)  # strip size off integers and floats
-            attr['unsupported'] = not any((attr['is_blob'], attr['numeric'], attr['numeric']))
-            attr.pop('Extra')
+                nullable=attr["nullable"] == "YES",
+                autoincrement=bool(
+                    re.search(r"auto_increment", attr["Extra"], flags=re.I)
+                ),
+                numeric=any(
+                    TYPE_PATTERN[t].match(attr["type"])
+                    for t in ("DECIMAL", "INTEGER", "FLOAT")
+                ),
+                string=any(
+                    TYPE_PATTERN[t].match(attr["type"])
+                    for t in ("ENUM", "TEMPORAL", "STRING")
+                ),
+                is_blob=bool(TYPE_PATTERN["INTERNAL_BLOB"].match(attr["type"])),
+                uuid=False,
+                is_attachment=False,
+                is_filepath=False,
+                adapter=None,
+                store=None,
+                is_external=False,
+                attribute_expression=None,
+            )
+
+            if any(TYPE_PATTERN[t].match(attr["type"]) for t in ("INTEGER", "FLOAT")):
+                attr["type"] = re.sub(
+                    r"\(\d+\)", "", attr["type"], count=1
+                )  # strip size off integers and floats
+            attr["unsupported"] = not any(
+                (attr["is_blob"], attr["numeric"], attr["numeric"])
+            )
+            attr.pop("Extra")
            # process custom DataJoint types
-            special = re.match(r':(?P<type>[^:]+):(?P<comment>.*)', attr['comment'])
+            special = re.match(r":(?P<type>[^:]+):(?P<comment>.*)", attr["comment"])
            if special:
                special = special.groupdict()
                attr.update(special)
            # process adapted attribute types
-            if special and TYPE_PATTERN['ADAPTED'].match(attr['type']):
-                assert context is not None, 'Declaration context is not set'
-                adapter_name = special['type']
+            if special and TYPE_PATTERN["ADAPTED"].match(attr["type"]):
+                assert context is not None, "Declaration context is not set"
+                adapter_name = special["type"]
                try:
                    attr.update(adapter=get_adapter(context, adapter_name))
                except DataJointError:
                    # if no adapter, then delay the error until the first invocation
                    attr.update(adapter=AttributeAdapter())
                else:
-                    attr.update(type=attr['adapter'].attribute_type)
-                    if not any(r.match(attr['type']) for r in TYPE_PATTERN.values()):
+                    attr.update(type=attr["adapter"].attribute_type)
+                    if not any(r.match(attr["type"]) for r in TYPE_PATTERN.values()):
                        raise DataJointError(
                            "Invalid attribute type '{type}' in adapter object <{adapter_name}>.".format(
-                                adapter_name=adapter_name, **attr))
-            special = not any(TYPE_PATTERN[c].match(attr['type']) for c in NATIVE_TYPES)
+                                adapter_name=adapter_name, **attr
+                            )
+                        )
+            special = not any(
+                TYPE_PATTERN[c].match(attr["type"]) for c in NATIVE_TYPES
+            )
            if special:
                try:
-                    category = next(c for c in SPECIAL_TYPES if TYPE_PATTERN[c].match(attr['type']))
+                    category = next(
+                        c for c in SPECIAL_TYPES if TYPE_PATTERN[c].match(attr["type"])
+                    )
                except StopIteration:
-                    if attr['type'].startswith('external'):
-                        url = "https://docs.datajoint.io/python/admin/5-blob-config.html" \
"#migration-between-datajoint-v0-11-and-v0-12" - raise DataJointError('Legacy datatype `{type}`. Migrate your external stores to ' - 'datajoint 0.12: {url}'.format(url=url, **attr)) - raise DataJointError('Unknown attribute type `{type}`'.format(**attr)) - if category == 'FILEPATH' and not _support_filepath_types(): - raise DataJointError(""" + if attr["type"].startswith("external"): + url = ( + "https://docs.datajoint.io/python/admin/5-blob-config.html" + "#migration-between-datajoint-v0-11-and-v0-12" + ) + raise DataJointError( + "Legacy datatype `{type}`. Migrate your external stores to " + "datajoint 0.12: {url}".format(url=url, **attr) + ) + raise DataJointError( + "Unknown attribute type `{type}`".format(**attr) + ) + if category == "FILEPATH" and not _support_filepath_types(): + raise DataJointError( + """ The filepath data type is disabled until complete validation. To turn it on as experimental feature, set the environment variable {env} = TRUE or upgrade datajoint. - """.format(env=FILEPATH_FEATURE_SWITCH)) + """.format( + env=FILEPATH_FEATURE_SWITCH + ) + ) attr.update( unsupported=False, - is_attachment=category in ('INTERNAL_ATTACH', 'EXTERNAL_ATTACH'), - is_filepath=category == 'FILEPATH', + is_attachment=category in ("INTERNAL_ATTACH", "EXTERNAL_ATTACH"), + is_filepath=category == "FILEPATH", # INTERNAL_BLOB is not a custom type but is included for completeness - is_blob=category in ('INTERNAL_BLOB', 'EXTERNAL_BLOB'), - uuid=category == 'UUID', + is_blob=category in ("INTERNAL_BLOB", "EXTERNAL_BLOB"), + uuid=category == "UUID", is_external=category in EXTERNAL_TYPES, - store=attr['type'].split('@')[1] if category in EXTERNAL_TYPES else None) - - if attr['in_key'] and any((attr['is_blob'], attr['is_attachment'], attr['is_filepath'])): - raise DataJointError('Blob, attachment, or filepath attributes are not allowed in the primary key') - - if attr['string'] and attr['default'] is not None and attr['default'] not in sql_literals: - attr['default'] = '"%s"' % attr['default'] - - if attr['nullable']: # nullable fields always default to null - attr['default'] = 'null' + store=attr["type"].split("@")[1] + if category in EXTERNAL_TYPES + else None, + ) + + if attr["in_key"] and any( + (attr["is_blob"], attr["is_attachment"], attr["is_filepath"]) + ): + raise DataJointError( + "Blob, attachment, or filepath attributes are not allowed in the primary key" + ) + + if ( + attr["string"] + and attr["default"] is not None + and attr["default"] not in sql_literals + ): + attr["default"] = '"%s"' % attr["default"] + + if attr["nullable"]: # nullable fields always default to null + attr["default"] = "null" # fill out dtype. 
All floats and non-nullable integers are turned into specific dtypes - attr['dtype'] = object - if attr['numeric'] and not attr['adapter']: - is_integer = TYPE_PATTERN['INTEGER'].match(attr['type']) - is_float = TYPE_PATTERN['FLOAT'].match(attr['type']) - if is_integer and not attr['nullable'] or is_float: - is_unsigned = bool(re.match('sunsigned', attr['type'], flags=re.I)) - t = re.sub(r'\(.*\)', '', attr['type']) # remove parentheses - t = re.sub(r' unsigned$', '', t) # remove unsigned - assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t - attr['dtype'] = numeric_types[(t, is_unsigned)] - - if attr['adapter']: + attr["dtype"] = object + if attr["numeric"] and not attr["adapter"]: + is_integer = TYPE_PATTERN["INTEGER"].match(attr["type"]) + is_float = TYPE_PATTERN["FLOAT"].match(attr["type"]) + if is_integer and not attr["nullable"] or is_float: + is_unsigned = bool(re.match("sunsigned", attr["type"], flags=re.I)) + t = re.sub(r"\(.*\)", "", attr["type"]) # remove parentheses + t = re.sub(r" unsigned$", "", t) # remove unsigned + assert (t, is_unsigned) in numeric_types, ( + "dtype not found for type %s" % t + ) + attr["dtype"] = numeric_types[(t, is_unsigned)] + + if attr["adapter"]: # restore adapted type name - attr['type'] = adapter_name + attr["type"] = adapter_name - self._attributes = dict(((q['name'], Attribute(**q)) for q in attributes)) + self._attributes = dict(((q["name"], Attribute(**q)) for q in attributes)) # Read and tabulate secondary indexes keys = defaultdict(dict) - for item in conn.query('SHOW KEYS FROM `{db}`.`{tab}`'.format(db=database, tab=table_name), as_dict=True): - if item['Key_name'] != 'PRIMARY': - keys[item['Key_name']][item['Seq_in_index']] = dict( - column=item['Column_name'], - unique=(item['Non_unique'] == 0), - nullable=item['Null'].lower() == 'yes') + for item in conn.query( + "SHOW KEYS FROM `{db}`.`{tab}`".format(db=database, tab=table_name), + as_dict=True, + ): + if item["Key_name"] != "PRIMARY": + keys[item["Key_name"]][item["Seq_in_index"]] = dict( + column=item["Column_name"], + unique=(item["Non_unique"] == 0), + nullable=item["Null"].lower() == "yes", + ) self.indexes = { - tuple(item[k]['column'] for k in sorted(item.keys())): - dict(unique=item[1]['unique'], - nullable=any(v['nullable'] for v in item.values())) - for item in keys.values()} + tuple(item[k]["column"] for k in sorted(item.keys())): dict( + unique=item[1]["unique"], + nullable=any(v["nullable"] for v in item.values()), + ) + for item in keys.values() + } def select(self, select_list, rename_map=None, compute_map=None): """ @@ -333,11 +445,21 @@ def select(self, select_list, rename_map=None, compute_map=None): for name in self.attributes: if name in select_list: copy_attrs.append(self.attributes[name].todict()) - copy_attrs.extend(( - dict(self.attributes[old_name].todict(), name=new_name, attribute_expression='`%s`' % old_name) - for new_name, old_name in rename_map.items() if old_name == name)) - compute_attrs = (dict(default_attribute_properties, name=new_name, attribute_expression=expr) - for new_name, expr in compute_map.items()) + copy_attrs.extend( + ( + dict( + self.attributes[old_name].todict(), + name=new_name, + attribute_expression="`%s`" % old_name, + ) + for new_name, old_name in rename_map.items() + if old_name == name + ) + ) + compute_attrs = ( + dict(default_attribute_properties, name=new_name, attribute_expression=expr) + for new_name, expr in compute_map.items() + ) return Heading(chain(copy_attrs, compute_attrs)) def join(self, 
other): @@ -346,23 +468,49 @@ def join(self, other): It assumes that self and other are headings that share no common dependent attributes. """ return Heading( - [self.attributes[name].todict() for name in self.primary_key] + - [other.attributes[name].todict() for name in other.primary_key if name not in self.primary_key] + - [self.attributes[name].todict() for name in self.secondary_attributes if name not in other.primary_key] + - [other.attributes[name].todict() for name in other.secondary_attributes if name not in self.primary_key]) + [self.attributes[name].todict() for name in self.primary_key] + + [ + other.attributes[name].todict() + for name in other.primary_key + if name not in self.primary_key + ] + + [ + self.attributes[name].todict() + for name in self.secondary_attributes + if name not in other.primary_key + ] + + [ + other.attributes[name].todict() + for name in other.secondary_attributes + if name not in self.primary_key + ] + ) def set_primary_key(self, primary_key): """ Create a new heading with the specified primary key. This low-level method performs no error checking. """ - return Heading(chain( - (dict(self.attributes[name].todict(), in_key=True) for name in primary_key), - (dict(self.attributes[name].todict(), in_key=False) for name in self.names if name not in primary_key))) + return Heading( + chain( + ( + dict(self.attributes[name].todict(), in_key=True) + for name in primary_key + ), + ( + dict(self.attributes[name].todict(), in_key=False) + for name in self.names + if name not in primary_key + ), + ) + ) def make_subquery_heading(self): """ Create a new heading with removed attribute sql_expressions. Used by subqueries, which resolve the sql_expressions. """ - return Heading(dict(v.todict(), attribute_expression=None) for v in self.attributes.values()) + return Heading( + dict(v.todict(), attribute_expression=None) + for v in self.attributes.values() + ) diff --git a/datajoint/jobs.py b/datajoint/jobs.py index 571270931..3e06add4e 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -7,22 +7,22 @@ from .heading import Heading ERROR_MESSAGE_LENGTH = 2047 -TRUNCATION_APPENDIX = '...truncated' +TRUNCATION_APPENDIX = "...truncated" class JobTable(Table): """ A base relation with no definition. 
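Illustrative usage sketch (not part of the PR diff): JobTable backs the ~jobs reservation table that coordinates distributed populate() calls; the schema and table names below are placeholders.
import datajoint as dj

schema = dj.Schema("my_pipeline")  # hypothetical schema name
MyComputedTable.populate(reserve_jobs=True)  # workers claim pending keys through ~jobs
(schema.jobs & 'status="error"').fetch("table_name", "error_message")  # inspect failures
(schema.jobs & 'status="error"').delete()  # clear errored reservations so they can be retried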
Allows reserving jobs """ + def __init__(self, conn, database): self.database = database self._connection = conn - self._heading = Heading(table_info=dict( - conn=conn, - database=database, - table_name=self.table_name, - context=None - )) + self._heading = Heading( + table_info=dict( + conn=conn, database=database, table_name=self.table_name, context=None + ) + ) self._support = [self.full_table_name] self._definition = """ # job reservation table for `{database}` @@ -38,7 +38,9 @@ def __init__(self, conn, database): pid=0 :int unsigned # system process id connection_id = 0 : bigint unsigned # connection_id() timestamp=CURRENT_TIMESTAMP :timestamp # automatic timestamp - """.format(database=database, error_message_length=ERROR_MESSAGE_LENGTH) + """.format( + database=database, error_message_length=ERROR_MESSAGE_LENGTH + ) if not self.is_declared: self.declare() self._user = self.connection.get_user() @@ -49,7 +51,7 @@ def definition(self): @property def table_name(self): - return '~jobs' + return "~jobs" def delete(self): """bypass interactive prompts and dependencies""" @@ -70,12 +72,13 @@ def reserve(self, table_name, key): job = dict( table_name=table_name, key_hash=key_hash(key), - status='reserved', + status="reserved", host=platform.node(), pid=os.getpid(), connection_id=self.connection.connection_id, key=key, - user=self._user) + user=self._user, + ) try: with config(enable_python_native_blobs=True): self.insert1(job, ignore_extra_fields=True) @@ -102,7 +105,10 @@ def error(self, table_name, key, error_message, error_stack=None): :param error_stack: stack trace """ if len(error_message) > ERROR_MESSAGE_LENGTH: - error_message = error_message[:ERROR_MESSAGE_LENGTH-len(TRUNCATION_APPENDIX)] + TRUNCATION_APPENDIX + error_message = ( + error_message[: ERROR_MESSAGE_LENGTH - len(TRUNCATION_APPENDIX)] + + TRUNCATION_APPENDIX + ) with config(enable_python_native_blobs=True): self.insert1( dict( @@ -115,5 +121,8 @@ def error(self, table_name, key, error_message, error_stack=None): user=self._user, key=key, error_message=error_message, - error_stack=error_stack), - replace=True, ignore_extra_fields=True) + error_stack=error_stack, + ), + replace=True, + ignore_extra_fields=True, + ) diff --git a/datajoint/migrate.py b/datajoint/migrate.py index 445bc317c..9d38dff5c 100644 --- a/datajoint/migrate.py +++ b/datajoint/migrate.py @@ -12,24 +12,34 @@ def migrate_dj011_external_blob_storage_to_dj012(migration_schema, store): """ if not isinstance(migration_schema, str): raise ValueError( - 'Expected type {} for migration_schema, not {}.'.format( - str, type(migration_schema))) + "Expected type {} for migration_schema, not {}.".format( + str, type(migration_schema) + ) + ) do_migration = False - do_migration = user_choice( + do_migration = ( + user_choice( """ Warning: Ensure the following are completed before proceeding. - Appropriate backups have been taken, - Any existing DJ 0.11.X connections are suspended, and - External config has been updated to new dj.config['stores'] structure. Proceed? 
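Illustrative call sketch (not part of the PR diff): how this migration helper is typically invoked; the schema and store names are placeholders, and the store must already be defined in dj.config["stores"].
import datajoint as dj

# triggers the confirmation prompt shown above, then copies each legacy external hash
dj.migrate_dj011_external_blob_storage_to_dj012("my_pipeline", "raw_store")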
- """, default='no') == 'yes' + """, + default="no", + ) + == "yes" + ) if do_migration: _migrate_dj011_blob(dj.Schema(migration_schema), store) - print('Migration completed for schema: {}, store: {}.'.format( - migration_schema, store)) + print( + "Migration completed for schema: {}, store: {}.".format( + migration_schema, store + ) + ) return - print('No migration performed.') + print("No migration performed.") def _migrate_dj011_blob(schema, default_store): @@ -38,34 +48,44 @@ def _migrate_dj011_blob(schema, default_store): LEGACY_HASH_SIZE = 43 legacy_external = dj.FreeTable( - schema.connection, - '`{db}`.`~external`'.format(db=schema.database)) + schema.connection, "`{db}`.`~external`".format(db=schema.database) + ) # get referencing tables - refs = [{k.lower(): v for k, v in elem.items()} for elem in query(""" + refs = [ + {k.lower(): v for k, v in elem.items()} + for elem in query( + """ SELECT concat('`', table_schema, '`.`', table_name, '`') as referencing_table, column_name, constraint_name FROM information_schema.key_column_usage WHERE referenced_table_name="{tab}" and referenced_table_schema="{db}" """.format( - tab=legacy_external.table_name, - db=legacy_external.database), as_dict=True).fetchall()] + tab=legacy_external.table_name, db=legacy_external.database + ), + as_dict=True, + ).fetchall() + ] for ref in refs: # get comment column = query( - 'SHOW FULL COLUMNS FROM {referencing_table}' - 'WHERE Field="{column_name}"'.format( - **ref), as_dict=True).fetchone() + "SHOW FULL COLUMNS FROM {referencing_table}" + 'WHERE Field="{column_name}"'.format(**ref), + as_dict=True, + ).fetchone() store, comment = re.match( - r':external(-(?P.+))?:(?P.*)', - column['Comment']).group('store', 'comment') + r":external(-(?P.+))?:(?P.*)", column["Comment"] + ).group("store", "comment") # get all the hashes from the reference - hashes = {x[0] for x in query( - 'SELECT `{column_name}` FROM {referencing_table}'.format( - **ref))} + hashes = { + x[0] + for x in query( + "SELECT `{column_name}` FROM {referencing_table}".format(**ref) + ) + } # sanity check make sure that store suffixes match if store is None: @@ -77,55 +97,69 @@ def _migrate_dj011_blob(schema, default_store): ext = schema.external[store or default_store] # add the new-style reference field - temp_suffix = 'tempsub' + temp_suffix = "tempsub" try: - query("""ALTER TABLE {referencing_table} + query( + """ALTER TABLE {referencing_table} ADD COLUMN `{column_name}_{temp_suffix}` {type} DEFAULT NULL COMMENT ":blob@{store}:{comment}" """.format( - type=dj.declare.UUID_DATA_TYPE, - temp_suffix=temp_suffix, - store=(store or default_store), comment=comment, **ref)) + type=dj.declare.UUID_DATA_TYPE, + temp_suffix=temp_suffix, + store=(store or default_store), + comment=comment, + **ref + ) + ) except: - print('Column already added') + print("Column already added") pass - for _hash, size in zip(*legacy_external.fetch('hash', 'size')): + for _hash, size in zip(*legacy_external.fetch("hash", "size")): if _hash in hashes: relative_path = str(Path(schema.database, _hash).as_posix()) uuid = dj.hash.uuid_from_buffer(init_string=relative_path) external_path = ext._make_external_filepath(relative_path) - if ext.spec['protocol'] == 's3': - contents_hash = dj.hash.uuid_from_buffer(ext._download_buffer(external_path)) + if ext.spec["protocol"] == "s3": + contents_hash = dj.hash.uuid_from_buffer( + ext._download_buffer(external_path) + ) else: contents_hash = dj.hash.uuid_from_file(external_path) - ext.insert1(dict( - filepath=relative_path, - 
size=size, - contents_hash=contents_hash, - hash=uuid - ), skip_duplicates=True) + ext.insert1( + dict( + filepath=relative_path, + size=size, + contents_hash=contents_hash, + hash=uuid, + ), + skip_duplicates=True, + ) query( - 'UPDATE {referencing_table} ' - 'SET `{column_name}_{temp_suffix}`=%s ' - 'WHERE `{column_name}` = "{_hash}"' - .format( - _hash=_hash, - temp_suffix=temp_suffix, **ref), uuid.bytes) + "UPDATE {referencing_table} " + "SET `{column_name}_{temp_suffix}`=%s " + 'WHERE `{column_name}` = "{_hash}"'.format( + _hash=_hash, temp_suffix=temp_suffix, **ref + ), + uuid.bytes, + ) # check that all have been copied check = query( - 'SELECT * FROM {referencing_table} ' - 'WHERE `{column_name}` IS NOT NULL' - ' AND `{column_name}_{temp_suffix}` IS NULL' - .format(temp_suffix=temp_suffix, **ref)).fetchall() + "SELECT * FROM {referencing_table} " + "WHERE `{column_name}` IS NOT NULL" + " AND `{column_name}_{temp_suffix}` IS NULL".format( + temp_suffix=temp_suffix, **ref + ) + ).fetchall() - assert len(check) == 0, 'Some hashes havent been migrated' + assert len(check) == 0, "Some hashes havent been migrated" # drop old foreign key, rename, and create new foreign key - query(""" + query( + """ ALTER TABLE {referencing_table} DROP FOREIGN KEY `{constraint_name}`, DROP COLUMN `{column_name}`, @@ -138,20 +172,30 @@ def _migrate_dj011_blob(schema, default_store): temp_suffix=temp_suffix, ext_table_name=ext.full_table_name, type=dj.declare.UUID_DATA_TYPE, - store=(store or default_store), comment=comment, **ref)) + store=(store or default_store), + comment=comment, + **ref + ) + ) # Drop the old external table but make sure it's no longer referenced # get referencing tables - refs = [{k.lower(): v for k, v in elem.items()} for elem in query(""" + refs = [ + {k.lower(): v for k, v in elem.items()} + for elem in query( + """ SELECT concat('`', table_schema, '`.`', table_name, '`') as referencing_table, column_name, constraint_name FROM information_schema.key_column_usage WHERE referenced_table_name="{tab}" and referenced_table_schema="{db}" """.format( - tab=legacy_external.table_name, - db=legacy_external.database), as_dict=True).fetchall()] + tab=legacy_external.table_name, db=legacy_external.database + ), + as_dict=True, + ).fetchall() + ] - assert not refs, 'Some references still exist' + assert not refs, "Some references still exist" # drop old external table legacy_external.drop_quick() diff --git a/datajoint/plugin.py b/datajoint/plugin.py index d82e457d1..96f388089 100644 --- a/datajoint/plugin.py +++ b/datajoint/plugin.py @@ -7,32 +7,35 @@ def _update_error_stack(plugin_name): try: - base_name = 'datajoint' + base_name = "datajoint" base_meta = pkg_resources.get_distribution(base_name) plugin_meta = pkg_resources.get_distribution(plugin_name) data = hash_pkg(pkgpath=str(Path(plugin_meta.module_path, plugin_name))) - signature = plugin_meta.get_metadata('{}.sig'.format(plugin_name)) - pubkey_path = str(Path(base_meta.egg_info, '{}.pub'.format(base_name))) + signature = plugin_meta.get_metadata("{}.sig".format(plugin_name)) + pubkey_path = str(Path(base_meta.egg_info, "{}.pub".format(base_name))) verify(pubkey_path=pubkey_path, data=data, signature=signature) - print('DataJoint verified plugin `{}` detected.'.format(plugin_name)) + print("DataJoint verified plugin `{}` detected.".format(plugin_name)) return True except (FileNotFoundError, InvalidSignature): - print('Unverified plugin `{}` detected.'.format(plugin_name)) + print("Unverified plugin `{}` detected.".format(plugin_name)) 
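Illustrative sketch (not part of the PR diff): _import_plugins below honors an optional allow-list in dj.config["plugin"]; the plugin name is a placeholder, and the saved setting only takes effect the next time datajoint is imported.
import datajoint as dj

dj.config["plugin"] = {"connection": ["my_connection_plugin"]}
dj.config.save_local()  # persists to dj_local_conf.json so plugin discovery is filtered on the next import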
return False def _import_plugins(category): return { - entry_point.name: dict(object=entry_point, - verified=_update_error_stack( - entry_point.module_name.split('.')[0])) - for entry_point - in pkg_resources.iter_entry_points('datajoint_plugins.{}'.format(category)) - if 'plugin' not in config or category not in config['plugin'] or - entry_point.module_name.split('.')[0] in config['plugin'][category] - } + entry_point.name: dict( + object=entry_point, + verified=_update_error_stack(entry_point.module_name.split(".")[0]), + ) + for entry_point in pkg_resources.iter_entry_points( + "datajoint_plugins.{}".format(category) + ) + if "plugin" not in config + or category not in config["plugin"] + or entry_point.module_name.split(".")[0] in config["plugin"][category] + } -connection_plugins = _import_plugins('connection') -type_plugins = _import_plugins('datatype') +connection_plugins = _import_plugins("connection") +type_plugins = _import_plugins("datatype") diff --git a/datajoint/preview.py b/datajoint/preview.py index f3daeebf5..f761cf533 100644 --- a/datajoint/preview.py +++ b/datajoint/preview.py @@ -7,33 +7,52 @@ def preview(query_expression, limit, width): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) if limit is None: - limit = config['display.limit'] + limit = config["display.limit"] if width is None: - width = config['display.width'] + width = config["display.width"] tuples = rel.fetch(limit=limit + 1, format="array") has_more = len(tuples) > limit tuples = tuples[:limit] columns = heading.names - widths = {f: min(max([len(f)] + - [len(str(e)) for e in tuples[f]] if f in tuples.dtype.names else [len('=BLOB=')]) + 4, width) for f - in columns} - templates = {f: '%%-%d.%ds' % (widths[f], widths[f]) for f in columns} + widths = { + f: min( + max( + [len(f)] + [len(str(e)) for e in tuples[f]] + if f in tuples.dtype.names + else [len("=BLOB=")] + ) + + 4, + width, + ) + for f in columns + } + templates = {f: "%%-%d.%ds" % (widths[f], widths[f]) for f in columns} return ( - ' '.join([templates[f] % ('*' + f if f in rel.primary_key else f) for f in columns]) + '\n' + - ' '.join(['+' + '-' * (widths[column] - 2) + '+' for column in columns]) + '\n' + - '\n'.join(' '.join(templates[f] % (tup[f] if f in tup.dtype.names else '=BLOB=') - for f in columns) for tup in tuples) + - ('\n ...\n' if has_more else '\n') + - (' (Total: %d)\n' % len(rel) if config['display.show_tuple_count'] else '')) + " ".join( + [templates[f] % ("*" + f if f in rel.primary_key else f) for f in columns] + ) + + "\n" + + " ".join(["+" + "-" * (widths[column] - 2) + "+" for column in columns]) + + "\n" + + "\n".join( + " ".join( + templates[f] % (tup[f] if f in tup.dtype.names else "=BLOB=") + for f in columns + ) + for tup in tuples + ) + + ("\n ...\n" if has_more else "\n") + + (" (Total: %d)\n" % len(rel) if config["display.show_tuple_count"] else "") + ) def repr_html(query_expression): heading = query_expression.heading rel = query_expression.proj(*heading.non_blobs) info = heading.table_status - tuples = rel.fetch(limit=config['display.limit'] + 1, format='array') - has_more = len(tuples) > config['display.limit'] - tuples = tuples[0:config['display.limit']] + tuples = rel.fetch(limit=config["display.limit"] + 1, format="array") + has_more = len(tuples) > config["display.limit"] + tuples = tuples[0 : config["display.limit"]] css = """
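Illustrative sketch (not part of the PR diff): the display settings consulted by preview() and repr_html() above; the table name is a placeholder.
import datajoint as dj

dj.config["display.limit"] = 20  # rows fetched for the preview (one extra row detects truncation)
dj.config["display.width"] = 25  # maximum column width in the ASCII preview
dj.config["display.show_tuple_count"] = True
MyTable()  # __repr__ renders through preview() using these settings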