Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow label based indexing in Rows (incl. test updates) #268

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ profile

# vi noise
*.swp
*~
docs/_build/*
coverage.xml
nosetests.xml
Expand Down
111 changes: 79 additions & 32 deletions tablib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from collections import OrderedDict
from copy import copy
from copy import deepcopy
from operator import itemgetter

from tablib import formats
Expand All @@ -28,13 +28,16 @@


class Row(object):
"""Internal Row object. Mainly used for filtering."""
"""Internal Row object. Mainly used for filtering. Note: To allow label
based indexing Row needs to be aware of the Dataset it belongs to. This is
passed to the constructor's `dset` argument."""

__slots__ = ['_row', 'tags']
__slots__ = ['_row', 'tags', '_dset']

def __init__(self, row=list(), tags=list()):
def __init__(self, row=list(), tags=list(), dset=None):
self._row = list(row)
self.tags = list(tags)
self._dset = dset

def __iter__(self):
return (col for col in self._row)
Expand All @@ -48,14 +51,47 @@ def __repr__(self):
def __getslice__(self, i, j):
return self._row[i:j]

def __getitem__(self, i):
return self._row[i]
def _index(self, key):
"""Returns index for ``key`` (string or int). Raises TypeError if
``key`` is string bt Dataset has no unique headers set and IndexError
if ``key`` is not in headers."""

def __setitem__(self, i, value):
self._row[i] = value
if isinstance(key, (str, unicode)):
if not self._dset._lblidx:
raise TypeError("Cannot access element by key '{0}' - Dataset"
" headers not suitable for indexing".format(key))
try:
i = self._dset.headers.index(key)
except ValueError:
raise IndexError("'{0}' not in Dataset headers".format(key))
else:
i = key

return i

def __getitem__(self, key):
return self._row[self._index(key)]

def __setitem__(self, key, value):
self._row[self._index(key)] = value

def __delitem__(self, key):
del self._row[self._index(key)]

def __add__(self, other):
"""Returns concatenation as plain list. ``other`` can be Row or a
sequence type"""
return self._row + list(other)

def __delitem__(self, i):
del self._row[i]
def __eq__(self, other):
"""Requires ``_row`` and ``tags`` attributes to be equal but not
headers of respective owning Datasets"""
if not isinstance(other, Row):
raise TypeError("Can't compare Row to %s" % type(other))
return self._row == other._row and self.tags == other.tags

def __ne__(self, other):
return not self == other

def __getstate__(self):

Expand Down Expand Up @@ -100,12 +136,8 @@ def has_tag(self, tag):

if tag == None:
return False
elif isinstance(tag, str):
return (tag in self.tags)
else:
return bool(len(set(tag) & set(self.tags)))


return (tag in self.tags)


class Dataset(object):
Expand Down Expand Up @@ -158,8 +190,9 @@ class Dataset(object):
_formats = {}

def __init__(self, *args, **kwargs):
self._data = list(Row(arg) for arg in args)
self._data = list(Row(arg, dset=self) for arg in args)
self.__headers = None
self._lblidx = False

# ('title', index) tuples
self._separators = []
Expand All @@ -173,11 +206,9 @@ def __init__(self, *args, **kwargs):

self._register_formats()


def __len__(self):
return self.height


def __getitem__(self, key):
if isinstance(key, (str, unicode)):
if key in self.headers:
Expand All @@ -188,13 +219,13 @@ def __getitem__(self, key):
else:
_results = self._data[key]
if isinstance(_results, Row):
return _results.tuple
return _results
else:
return [result.tuple for result in _results]
return [result for result in _results]

def __setitem__(self, key, value):
self._validate(value)
self._data[key] = Row(value)
self._data[key] = Row(value, dset=self)


def __delitem__(self, key):
Expand Down Expand Up @@ -340,10 +371,13 @@ def _set_headers(self, collection):
if collection:
try:
self.__headers = list(collection)
self._lblidx = (len(set(collection)) == len(collection))
except TypeError:
self._lblidx = False
raise TypeError
else:
self.__headers = None
self._lblidx = False

headers = property(_get_headers, _set_headers)

Expand Down Expand Up @@ -381,14 +415,14 @@ def _set_dict(self, pickle):
if isinstance(pickle[0], list):
self.wipe()
for row in pickle:
self.append(Row(row))
self.append(Row(row, dset=self))

# if list of objects
elif isinstance(pickle[0], dict):
self.wipe()
self.headers = list(pickle[0].keys())
for row in pickle:
self.append(Row(list(row.values())))
self.append(Row(list(row.values()), dset=self))
else:
raise UnsupportedFormat

Expand Down Expand Up @@ -675,7 +709,7 @@ def insert(self, index, row, tags=list()):
"""

self._validate(row)
self._data.insert(index, Row(row, tags=tags))
self._data.insert(index, Row(row, tags=tags, dset=self))


def rpush(self, row, tags=list()):
Expand Down Expand Up @@ -796,8 +830,7 @@ def insert_col(self, index, col=None, header=None):
row.insert(index, col[i])
self._data[i] = row
else:
self._data = [Row([row]) for row in col]

self._data = [Row([row], dset=self) for row in col]


def rpush_col(self, col, header=None):
Expand Down Expand Up @@ -880,7 +913,7 @@ def filter(self, tag):
"""Returns a new instance of the :class:`Dataset`, excluding any rows
that do not contain the given :ref:`tags <tags>`.
"""
_dset = copy(self)
_dset = self.copy()
_dset._data = [row for row in _dset._data if row.has_tag(tag)]

return _dset
Expand Down Expand Up @@ -949,11 +982,22 @@ def transpose(self):
# Adding the column name as now they're a regular column
# Use `get_col(index)` in case there are repeated values
row_data = [column] + self.get_col(index)
row_data = Row(row_data)
row_data = Row(row_data, dset=self)
_dset.append(row=row_data)
return _dset


def copy(self):
"""Return copy with each Row's Dataset reference set to the new
object"""

_dset = deepcopy(self)
for row in _dset._data:
row._dset = _dset

return _dset


def stack(self, other):
"""Stack two :class:`Dataset` instances together by
joining at the row level, and return new combined
Expand All @@ -965,14 +1009,17 @@ def stack(self, other):
if self.width != other.width:
raise InvalidDimensions

# Copy the source data
_dset = copy(self)
# Copy the source data (updates Dataset reference in Rows)
_dset = self.copy()
_dset.extend(other._data)

"""
rows_to_stack = [row for row in _dset._data]
other_rows = [row for row in other._data]

rows_to_stack.extend(other_rows)
_dset._data = rows_to_stack
"""

return _dset

Expand Down Expand Up @@ -1022,6 +1069,7 @@ def wipe(self):
"""Removes all content and headers from the :class:`Dataset` object."""
self._data = list()
self.__headers = None
self._lblidx = None


def subset(self, rows=None, cols=None):
Expand Down Expand Up @@ -1059,12 +1107,11 @@ def subset(self, rows=None, cols=None):
raise KeyError

if row_no in rows:
_dset.append(row=Row(data_row))
_dset.append(row=Row(data_row, dset=_dset))

return _dset



class Databook(object):
"""A book of :class:`Dataset` objects.
"""
Expand Down
Loading