Skip to content

Commit

Permalink
Add flatten argument to python history api
Browse files Browse the repository at this point in the history
This allows users to decide whether they want fully expanded dataframes for universe and other collection data types. Else, master behavior is kept
  • Loading branch information
jhonabreul committed Oct 31, 2024
1 parent d6f6445 commit 2a67590
Show file tree
Hide file tree
Showing 9 changed files with 585 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def initialize(self):
self.universe_settings.resolution = Resolution.HOUR
universe = self.add_universe(self.universe.etf(spy, self.universe_settings, self.filter_etf_constituents))

historical_data = self.history(universe, 1)
historical_data = self.history(universe, 1, flatten=True)
if len(historical_data) < 200:
raise ValueError(f"Unexpected universe DataCollection count {len(universe_data_collection)}! Expected > 200")
raise ValueError(f"Unexpected universe DataCollection count {len(historical_data)}! Expected > 200")

### <summary>
### Filters ETF constituents
Expand Down
6 changes: 3 additions & 3 deletions Algorithm.Python/FundamentalRegressionAlgorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def initialize(self):
raise ValueError(f"Unexpected Fundamental count {len(fundamentals)}! Expected 2")

# Request historical fundamental data for symbols
history = self.history(Fundamental, TimeSpan(2, 0, 0, 0))
history = self.history(Fundamental, timedelta(days=2))
if len(history) != 4:
raise ValueError(f"Unexpected Fundamental history count {len(history)}! Expected 4")

Expand All @@ -69,11 +69,11 @@ def initialize(self):

def assert_fundamental_universe_data(self):
# Case A
universe_data = self.history(self._universe.data_type, [self._universe.symbol], TimeSpan(2, 0, 0, 0))
universe_data = self.history(self._universe.data_type, [self._universe.symbol], timedelta(days=2), flatten=True)
self.assert_fundamental_history(universe_data, "A")

# Case B (sugar on A)
universe_data_per_time = self.history(self._universe, TimeSpan(2, 0, 0, 0))
universe_data_per_time = self.history(self._universe, timedelta(days=2), flatten=True)
self.assert_fundamental_history(universe_data_per_time, "B")

# Case C: Passing through the unvierse type and symbol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def initialize(self):

option = self.add_option("GOOG").symbol

historical_options_data_df = self.history(option, 3, Resolution.DAILY)
historical_options_data_df = self.history(option, 3, flatten=True)

# Level 0 of the multi-index is the date, we expect 3 dates, 3 option chains
if historical_options_data_df.index.levshape[0] != 3:
Expand Down
154 changes: 118 additions & 36 deletions Algorithm/QCAlgorithm.Python.cs

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Common/Data/UniverseSelection/BaseDataCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ namespace QuantConnect.Data.UniverseSelection
/// <summary>
/// This type exists for transport of data as a single packet
/// </summary>
[PandasIgnoreMembers]
public class BaseDataCollection : BaseData, IEnumerable<BaseData>
{
/// <summary>
Expand All @@ -38,11 +37,13 @@ public class BaseDataCollection : BaseData, IEnumerable<BaseData>
/// <summary>
/// The associated underlying price data if any
/// </summary>
[PandasNonExpandable]
public BaseData Underlying { get; set; }

/// <summary>
/// Gets or sets the contracts selected by the universe
/// </summary>
[PandasIgnore]
public HashSet<Symbol> FilteredContracts { get; set; }

/// <summary>
Expand All @@ -53,6 +54,7 @@ public class BaseDataCollection : BaseData, IEnumerable<BaseData>
/// <summary>
/// Gets or sets the end time of this data
/// </summary>
[PandasIgnore]
public override DateTime EndTime
{
get
Expand Down
13 changes: 9 additions & 4 deletions Research/QuantBook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,13 @@ public IEnumerable<IEnumerable<BaseData>> UniverseHistory(Universe universe, Dat
/// <param name="end">Optionally the end date, will default to today</param>
/// <param name="func">Optionally the universe selection function</param>
/// <param name="dateRule">Date rule to apply for the history data</param>
/// <param name="flatten">Whether to flatten the resulting data frame.
/// For universe data, the each row represents a day of data, and the data is stored in a list in a cell of the data frame.
/// If flatten is true, the resulting data frame will contain one row per universe constituent,
/// and each property of the constituent will be a column in the data frame.</param>
/// <returns>Enumerable of universe selection data for each date, filtered if the func was provided</returns>
public PyObject UniverseHistory(PyObject universe, DateTime start, DateTime? end = null, PyObject func = null, IDateRule dateRule = null)
public PyObject UniverseHistory(PyObject universe, DateTime start, DateTime? end = null, PyObject func = null, IDateRule dateRule = null,
bool flatten = false)
{
if (universe.TryConvert<Universe>(out var convertedUniverse))
{
Expand All @@ -768,7 +773,7 @@ public PyObject UniverseHistory(PyObject universe, DateTime start, DateTime? end
}
var filteredUniverseSelectionData = RunUniverseSelection(convertedUniverse, start, end, dateRule);

return GetDataFrame(filteredUniverseSelectionData);
return GetDataFrame(filteredUniverseSelectionData, flatten);
}
// for backwards compatibility
if (universe.TryConvert<Type>(out var convertedType) && convertedType.IsAssignableTo(typeof(BaseDataCollection)))
Expand All @@ -777,13 +782,13 @@ public PyObject UniverseHistory(PyObject universe, DateTime start, DateTime? end
var universeSymbol = ((BaseDataCollection)convertedType.GetBaseDataInstance()).UniverseSymbol();
if (func == null)
{
return History(universe, universeSymbol, start, endDate);
return History(universe, universeSymbol, start, endDate, flatten: flatten);
}

var requests = CreateDateRangeHistoryRequests(new[] { universeSymbol }, convertedType, start, endDate);
var history = History(requests);

return GetDataFrame(GetFilteredSlice(history, func, start, endDate, dateRule), convertedType);
return GetDataFrame(GetFilteredSlice(history, func, start, endDate, dateRule), flatten, convertedType);
}

throw new ArgumentException($"Failed to convert given universe {universe}. Please provider a valid {nameof(Universe)}");
Expand Down
89 changes: 86 additions & 3 deletions Tests/Algorithm/AlgorithmHistoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3227,6 +3227,89 @@ def getHistory(algorithm, symbol, period):
}
}

[Test]
public void PythonUniverseHistoryDataFramesAreFlattened()
{
var algorithm = GetAlgorithm(new DateTime(2014, 03, 28));
var universe = algorithm.AddUniverse(x => x.Select(x => x.Symbol));

using (Py.GIL())
{
PythonInitializer.Initialize();
algorithm.SetPandasConverter();

var testModule = PyModule.FromString("PythonHistoryDataFramesAreFlattened",
@"
from AlgorithmImports import *
def getFlattenedUniverseHistory(algorithm, universe, period):
return algorithm.history(universe, period, flatten=True)
def getUnflattenedUniverseHistory(algorithm, universe, period):
return algorithm.history(universe, period)
def assertFlattenedHistoryDates(df, expected_dates):
assert df.index.levels[0].to_list() == expected_dates, f'Expected dates: {expected_dates}, actual dates: {df.index.levels[0].to_list()}'
def assertUnflattenedHistoryDates(df, expected_dates):
assert df.index.to_list() == expected_dates, f'Expected dates: {expected_dates}, actual dates: {df.index.levels[0].to_list()}'
def assertConstituents(flattened_df, unflattened_df, dates, expected_constituents_per_date):
for i, date in enumerate(dates):
unflattened_universe = unflattened_df.loc[date]
assert isinstance(unflattened_universe, list), f'Unflattened DF: expected a list, found {type(unflattened_universe)}'
assert len(unflattened_universe) == expected_constituents_per_date[i], f'Unflattened DF: expected {expected_constituents_per_date[i]} constituents for date {date}, got {len(unflattened_universe)}'
for constituent in unflattened_universe:
assert isinstance(constituent, Fundamental), f'Unflattened DF: expected a list of Fundamental, found {type(constituent)}'
flattened_sub_df = flattened_df.loc[date]
assert flattened_sub_df.shape[0] == len(unflattened_universe), f'Flattened DF: expected {len(unflattened_universe)} rows for date {date}, got {flattened_sub_df.shape[0]}'
flattened_universe_symbols = flattened_sub_df.index.to_list()
unflattened_universe_symbols = [constituent.symbol for constituent in unflattened_universe]
flattened_universe_symbols.sort()
unflattened_universe_symbols.sort()
assert flattened_universe_symbols == unflattened_universe_symbols, f'Flattened DF: flattened universe symbols are not equal to unflattened universe symbols for date {date}'
");
dynamic getFlattenedUniverseHistory = testModule.GetAttr("getFlattenedUniverseHistory");
var flattenedDf = getFlattenedUniverseHistory(algorithm, universe, 3);

dynamic getUnflattenedUniverseHistory = testModule.GetAttr("getUnflattenedUniverseHistory");
var unflattenedDf = getUnflattenedUniverseHistory(algorithm, universe, 3);
// Drop the symbol
unflattenedDf = unflattenedDf.droplevel(0);

var expectedDates = new List<DateTime>
{
new DateTime(2014, 03, 26),
new DateTime(2014, 03, 27),
new DateTime(2014, 03, 28)
};
dynamic assertFlattenedHistoryDates = testModule.GetAttr("assertFlattenedHistoryDates");
AssertDesNotThrowPythonException(() => assertFlattenedHistoryDates(flattenedDf, expectedDates));

dynamic assertUnflattenedHistoryDates = testModule.GetAttr("assertUnflattenedHistoryDates");
AssertDesNotThrowPythonException(() => assertUnflattenedHistoryDates(unflattenedDf, expectedDates));

var expectedConstituentsCounts = new[] { 7068, 7055, 7049 };
dynamic assertConstituents = testModule.GetAttr("assertConstituents");
AssertDesNotThrowPythonException(() => assertConstituents(flattenedDf, unflattenedDf, expectedDates, expectedConstituentsCounts));
}
}

private static void AssertDesNotThrowPythonException(Action action)
{
try
{
action();
}
catch (PythonException ex)
{
Assert.Fail(ex.Message);
}
}

private class ThrowingHistoryProvider : HistoryProviderBase
{
public override int DataPointCount => 0;
Expand Down Expand Up @@ -3343,10 +3426,10 @@ public override BaseData Reader(SubscriptionDataConfig config, string line, Date
}

/// <summary>
/// Represents custom data with an optional sorting functionality. The <see cref="ExampleCustomDataWithSort"/> class
/// Represents custom data with an optional sorting functionality. The <see cref="ExampleCustomDataWithSort"/> class
/// allows you to specify a static property <seealso cref="CustomDataKey"/>, which defines the name of the custom data source.
/// Sorting can be enabled or disabled by setting the <seealso cref="Sort"/> property.
/// This class overrides <see cref="GetSource(SubscriptionDataConfig, DateTime, bool)"/> to initialize the
/// This class overrides <see cref="GetSource(SubscriptionDataConfig, DateTime, bool)"/> to initialize the
/// <seealso cref="SubscriptionDataSource.Sort"/> property based on the value of <see cref="Sort"/>.
/// </summary>
public class ExampleCustomDataWithSort : BaseData
Expand All @@ -3367,7 +3450,7 @@ public class ExampleCustomDataWithSort : BaseData
public decimal Close { get; set; }

/// <summary>
/// Returns the data source for the subscription. It uses the custom data key and sets sorting based on the
/// Returns the data source for the subscription. It uses the custom data key and sets sorting based on the
/// <see cref="Sort"/> property.
/// </summary>
/// <param name="config">Subscription configuration.</param>
Expand Down
33 changes: 22 additions & 11 deletions Tests/Python/PandasConverterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public void TearDown()
}

[Test]
public void HandlesBaseDataCollection()
public void HandlesBaseDataCollection([Values] bool flatten)
{
var converter = new PandasConverter();
var data = new[]
Expand All @@ -87,20 +87,31 @@ public void HandlesBaseDataCollection()
}
};

dynamic dataFrame = converter.GetDataFrame(data);
dynamic dataFrame = converter.GetDataFrame(data, flatten: flatten);

using (Py.GIL())
{
Assert.IsFalse(dataFrame.empty.AsManagedObject(typeof(bool)));

var indexNames = dataFrame.index.names.AsManagedObject(typeof(string[]));
CollectionAssert.AreEqual(new[] { "time", "symbol" }, indexNames);
if (flatten)
{
var indexNames = dataFrame.index.names.AsManagedObject(typeof(string[]));
CollectionAssert.AreEqual(new[] { "time", "symbol" }, indexNames);

Assert.IsFalse(dataFrame.empty.AsManagedObject(typeof(bool)));
Assert.IsFalse(dataFrame.empty.AsManagedObject(typeof(bool)));

var count = dataFrame.__len__().AsManagedObject(typeof(int));
Assert.AreEqual(2, count);
AssertBaseDataCollectionDataFrameTimes(data, dataFrame);
var count = dataFrame.__len__().AsManagedObject(typeof(int));
Assert.AreEqual(2, count);
AssertFlattenBaseDataCollectionDataFrameTimes(data, dataFrame);
}
else
{
var subDataFrame = dataFrame.loc[Symbols.IBM];
Assert.IsFalse(subDataFrame.empty.AsManagedObject(typeof(bool)));

var count = subDataFrame.__len__().AsManagedObject(typeof(int));
Assert.AreEqual(1, count);
}
}
}

Expand Down Expand Up @@ -133,7 +144,7 @@ public void HandlesBaseDataCollectionWithMultipleSymbols()
}
};

dynamic dataFrame = converter.GetDataFrame(data);
dynamic dataFrame = converter.GetDataFrame(data, flatten: true);

using (Py.GIL())
{
Expand Down Expand Up @@ -168,11 +179,11 @@ public void HandlesBaseDataCollectionWithMultipleSymbols()
}
});

AssertBaseDataCollectionDataFrameTimes(data, dataFrame);
AssertFlattenBaseDataCollectionDataFrameTimes(data, dataFrame);
}
}

private static void AssertBaseDataCollectionDataFrameTimes(EnumerableData[] data, dynamic dataFrame)
private static void AssertFlattenBaseDataCollectionDataFrameTimes(EnumerableData[] data, dynamic dataFrame)
{
// For base data collections, the end time of each data point is added as a column
// And the time in the index is the collection's time
Expand Down
Loading

0 comments on commit 2a67590

Please sign in to comment.