Skip to content

Commit 8ecf5d0

Browse files
committed
Fix vector backtest capital allocation
1 parent 554279f commit 8ecf5d0

File tree

9 files changed

+136
-54
lines changed

9 files changed

+136
-54
lines changed

investing_algorithm_framework/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
get_rolling_sharpe_ratio_chart, rank_results, \
88
get_monthly_returns_heatmap_chart, create_weights, \
99
get_yearly_returns_bar_chart, get_entry_and_exit_signals, \
10-
get_ohlcv_data_completeness_chart
10+
get_ohlcv_data_completeness_chart, get_equity_curve_chart
1111
from .domain import ApiException, combine_backtests, PositionSize, \
1212
OrderType, OperationalException, OrderStatus, OrderSide, \
1313
TimeUnit, TimeInterval, Order, Portfolio, Backtest, \
@@ -174,5 +174,6 @@
174174
"get_cumulative_return_series",
175175
"get_total_loss",
176176
"get_total_growth",
177-
"generate_backtest_summary_metrics"
177+
"generate_backtest_summary_metrics",
178+
"get_equity_curve_chart"
178179
]

investing_algorithm_framework/app/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_equity_curve_with_drawdown_chart, \
1212
get_rolling_sharpe_ratio_chart, \
1313
get_monthly_returns_heatmap_chart, \
14-
get_yearly_returns_bar_chart, \
14+
get_yearly_returns_bar_chart, get_equity_curve_chart, \
1515
get_ohlcv_data_completeness_chart, get_entry_and_exit_signals
1616
from .analysis import select_backtest_date_ranges, rank_results, \
1717
create_weights
@@ -40,5 +40,6 @@
4040
"get_ohlcv_data_completeness_chart",
4141
"rank_results",
4242
"create_weights",
43-
"get_entry_and_exit_signals"
43+
"get_entry_and_exit_signals",
44+
"get_equity_curve_chart"
4445
]

investing_algorithm_framework/app/app.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,9 @@ def run_vector_backtests(
798798
snapshot_interval: SnapshotInterval = SnapshotInterval.DAILY,
799799
risk_free_rate: Optional[float] = None,
800800
skip_data_sources_initialization: bool = False,
801-
show_progress: bool = True
801+
show_progress: bool = True,
802+
market: Optional[str] = None,
803+
trading_symbol: Optional[str] = None,
802804
) -> List[Backtest]:
803805
"""
804806
Run vectorized backtests for a set of strategies. The provided
@@ -840,6 +842,12 @@ def run_vector_backtests(
840842
show_progress (bool): Whether to show progress bars during
841843
data source initialization. This is useful for long-running
842844
initialization processes.
845+
market (str): The market to use for the backtest. This is used
846+
to create a portfolio configuration if no portfolio
847+
configuration is provided in the strategy.
848+
trading_symbol (str): The trading symbol to use for the backtest.
849+
This is used to create a portfolio configuration if no
850+
portfolio configuration is provided in the strategy.
843851
844852
Returns:
845853
List[Backtest]: List of Backtest instances for each strategy
@@ -887,7 +895,9 @@ def run_vector_backtests(
887895
strategy=strategy,
888896
snapshot_interval=snapshot_interval,
889897
risk_free_rate=risk_free_rate,
890-
skip_data_sources_initialization=True
898+
skip_data_sources_initialization=True,
899+
market=market,
900+
trading_symbol=trading_symbol
891901
)
892902
backtests.append(backtest)
893903
else:
@@ -938,13 +948,15 @@ def run_vector_backtests(
938948
def run_vector_backtest(
939949
self,
940950
backtest_date_range: BacktestDateRange,
941-
initial_amount,
942951
strategy: TradingStrategy,
943952
snapshot_interval: SnapshotInterval = SnapshotInterval.DAILY,
944953
metadata: Optional[Dict[str, str]] = None,
945954
risk_free_rate: Optional[float] = None,
946955
skip_data_sources_initialization: bool = False,
947-
show_data_initialization_progress: bool = True
956+
show_data_initialization_progress: bool = True,
957+
initial_amount: float = None,
958+
market: str = None,
959+
trading_symbol: str = None
948960
) -> Backtest:
949961
"""
950962
Run vectorized backtests for a strategy. The provided
@@ -982,6 +994,20 @@ def run_vector_backtest(
982994
initialized before calling this method.
983995
show_data_initialization_progress (bool): Whether to show the
984996
progress bar when initializing data sources.
997+
market (str): The market to use for the backtest. This is used
998+
to create a portfolio configuration if no portfolio
999+
configuration is found for the strategy. If not provided,
1000+
the first portfolio configuration found will be used.
1001+
trading_symbol (str): The trading symbol to use for the backtest.
1002+
This is used to create a portfolio configuration if no
1003+
portfolio configuration is found for the strategy. If not
1004+
provided, the first trading symbol found in the portfolio
1005+
configuration will be used.
1006+
initial_amount (float): The initial amount to start the
1007+
backtest with. This will be the amount of trading currency
1008+
that the portfolio will start with. If not provided,
1009+
the initial amount from the portfolio configuration will
1010+
be used.
9851011
9861012
Returns:
9871013
Backtest: Instance of Backtest
@@ -1017,8 +1043,10 @@ def run_vector_backtest(
10171043
run = backtest_service.create_vector_backtest(
10181044
strategy=strategy,
10191045
backtest_date_range=backtest_date_range,
1020-
initial_amount=initial_amount,
1021-
risk_free_rate=risk_free_rate
1046+
risk_free_rate=risk_free_rate,
1047+
market=market,
1048+
trading_symbol=trading_symbol,
1049+
initial_amount=initial_amount
10221050
)
10231051
backtest = Backtest(
10241052
backtest_runs=[run],
@@ -1256,6 +1284,8 @@ def run_permutation_test(
12561284
backtest_date_range: BacktestDateRange,
12571285
number_of_permutations: int = 100,
12581286
initial_amount: float = 1000.0,
1287+
market: str = None,
1288+
trading_symbol: str = None,
12591289
risk_free_rate: Optional[float] = None
12601290
) -> BacktestPermutationTest:
12611291
"""
@@ -1281,6 +1311,15 @@ def run_permutation_test(
12811311
risk_free_rate (Optional[float]): The risk-free rate to use for
12821312
the backtest metrics. If not provided, it will try to fetch
12831313
the risk-free rate from the US Treasury website.
1314+
market (str): The market to use for the backtest. This is used
1315+
to create a portfolio configuration if no portfolio
1316+
configuration is found for the strategy. If not provided,
1317+
the first portfolio configuration found will be used.
1318+
trading_symbol (str): The trading symbol to use for the backtest.
1319+
This is used to create a portfolio configuration if no
1320+
portfolio configuration is found for the strategy. If not
1321+
provided, the first trading symbol found in the portfolio
1322+
configuration will be used.
12841323
12851324
Raises:
12861325
OperationalException: If the risk-free rate cannot be retrieved.
@@ -1309,6 +1348,8 @@ def run_permutation_test(
13091348
strategy=strategy,
13101349
snapshot_interval=SnapshotInterval.DAILY,
13111350
risk_free_rate=risk_free_rate,
1351+
market=market,
1352+
trading_symbol=trading_symbol
13121353
)
13131354
backtest_metrics = backtest.get_backtest_metrics(backtest_date_range)
13141355

@@ -1388,7 +1429,9 @@ def run_permutation_test(
13881429
strategy=strategy,
13891430
snapshot_interval=SnapshotInterval.DAILY,
13901431
risk_free_rate=risk_free_rate,
1391-
skip_data_sources_initialization=True
1432+
skip_data_sources_initialization=True,
1433+
market=market,
1434+
trading_symbol=trading_symbol
13921435
)
13931436

13941437
# Add the results of the permuted backtest to the main backtest

investing_algorithm_framework/app/reporting/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
get_monthly_returns_heatmap_chart, \
88
get_yearly_returns_bar_chart, \
99
get_ohlcv_data_completeness_chart, \
10-
get_entry_and_exit_signals
10+
get_entry_and_exit_signals, \
11+
get_equity_curve_chart
1112

1213
__all__ = [
1314
"add_html_report",
@@ -21,5 +22,6 @@
2122
"get_monthly_returns_heatmap_chart",
2223
"get_yearly_returns_bar_chart",
2324
"get_ohlcv_data_completeness_chart",
24-
"get_entry_and_exit_signals"
25+
"get_entry_and_exit_signals",
26+
"get_equity_curve_chart"
2527
]

investing_algorithm_framework/app/reporting/charts/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .equity_curve_drawdown import get_equity_curve_with_drawdown_chart
2+
from .equity_curve import get_equity_curve_chart
23
from .rolling_sharp_ratio import get_rolling_sharpe_ratio_chart
34
from .monthly_returns_heatmap import get_monthly_returns_heatmap_chart
45
from .yearly_returns_barchart import get_yearly_returns_bar_chart
@@ -14,4 +15,5 @@
1415
"get_ohlcv_data_completeness_chart",
1516
"get_entry_and_exit_signals",
1617
"create_line_scatter",
18+
"get_equity_curve_chart"
1719
]

investing_algorithm_framework/services/backtesting/backtest_service.py

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from investing_algorithm_framework.domain import BacktestRun, OrderType, \
1414
TimeUnit, Trade, OperationalException, BacktestDateRange, TimeFrame, \
1515
Backtest, TradeStatus, PortfolioSnapshot, Order, OrderStatus, OrderSide, \
16-
Portfolio, DataType, generate_backtest_summary_metrics
16+
Portfolio, DataType, generate_backtest_summary_metrics, \
17+
PortfolioConfiguration
1718
from investing_algorithm_framework.services.data_providers import \
1819
DataProviderService
1920
from investing_algorithm_framework.services.portfolios import \
@@ -114,8 +115,10 @@ def create_vector_backtest(
114115
self,
115116
strategy,
116117
backtest_date_range: BacktestDateRange,
117-
initial_amount: float,
118118
risk_free_rate: float = 0.027,
119+
initial_amount: float = None,
120+
trading_symbol: str = None,
121+
market: str = None,
119122
) -> BacktestRun:
120123
"""
121124
Vectorized backtest for multiple assets using strategy
@@ -124,9 +127,16 @@ def create_vector_backtest(
124127
Args:
125128
strategy: The strategy to backtest.
126129
backtest_date_range: The date range for the backtest.
127-
initial_amount: The initial amount to use for the backtest.
128130
risk_free_rate: The risk-free rate to use for the backtest
129131
metrics. Default is 0.027 (2.7%).
132+
initial_amount: The initial amount to use for the backtest.
133+
If None, the initial amount will be taken from the first
134+
portfolio configuration.
135+
trading_symbol: The trading symbol to use for the backtest.
136+
If None, the trading symbol will be taken from the first
137+
portfolio configuration.
138+
market: The market to use for the backtest. If None, the market
139+
will be taken from the first portfolio configuration.
130140
131141
Returns:
132142
BacktestRun: The backtest run containing the results and metrics.
@@ -135,12 +145,26 @@ def create_vector_backtest(
135145
.get_all()
136146

137147
if (portfolio_configurations is None
138-
or len(portfolio_configurations) == 0):
148+
or len(portfolio_configurations) == 0
149+
and initial_amount is None
150+
or trading_symbol is None
151+
or market is None):
139152
raise OperationalException(
140153
"No portfolio configurations found, please register a "
141154
"portfolio configuration before running a backtest."
142155
)
143156

157+
if portfolio_configurations is None \
158+
or len(portfolio_configurations) == 0:
159+
portfolio_configurations = []
160+
portfolio_configurations.append(
161+
PortfolioConfiguration(
162+
market=market,
163+
trading_symbol=trading_symbol,
164+
initial_balance=initial_amount
165+
)
166+
)
167+
144168
trading_symbol = portfolio_configurations[0].trading_symbol
145169

146170
# Load vectorized backtest data
@@ -188,8 +212,6 @@ def create_vector_backtest(
188212
total_net_gain=0.0
189213
)
190214
]
191-
unallocated = initial_amount
192-
total_values = pd.Series(0.0, index=index)
193215

194216
for symbol in buy_signals.keys():
195217
full_symbol = f"{symbol}/{trading_symbol}"
@@ -221,7 +243,6 @@ def create_vector_backtest(
221243
returns = close.pct_change().fillna(0)
222244
returns = returns.astype(float)
223245
signal = signal.astype(float)
224-
strategy_returns = signal * returns
225246

226247
if pos_size_obj is None:
227248
raise OperationalException(
@@ -242,9 +263,6 @@ def create_vector_backtest(
242263
asset_price=close.iloc[0]
243264
)
244265

245-
holdings = (strategy_returns + 1).cumprod() * capital_for_trade
246-
total_values += holdings
247-
248266
# Trade generation
249267
last_trade = None
250268

@@ -296,7 +314,6 @@ def create_vector_backtest(
296314
)
297315
last_trade = trade
298316
trades.append(trade)
299-
unallocated -= capital_for_trade
300317

301318
# If we are in a position, and we get a sell signal
302319
if current_signal == -1 and last_trade is not None:
@@ -327,42 +344,50 @@ def create_vector_backtest(
327344
"net_gain": net_gain_val
328345
}
329346
)
330-
unallocated += last_trade.available_amount * current_price
331347
last_trade = None
332348

349+
unallocated = initial_amount
350+
total_net_gain = 0.0
351+
open_trades = []
352+
333353
# Create portfolio snapshots
334354
for ts in index:
335-
invested_value = 0.0
355+
allocated = 0
356+
interval_datetime = pd.Timestamp(ts).to_pydatetime()
357+
interval_datetime = interval_datetime.replace(tzinfo=timezone.utc)
336358

337359
for trade in trades:
338-
if trade.opened_at <= ts and (
339-
trade.closed_at is None or trade.closed_at >= ts):
340-
341-
# Trade is still open at this time
342-
ohlcv = granular_ohlcv_data_order_by_symbol[trade.symbol]
343-
344-
# Datetime is the index for pandas DataFrame, find the
345-
# closest timestamp that is less than or equal to ts
346-
# prices = ohlcv.loc[ohlcv.index <= ts, "Close"].values
347-
#
348-
# if len(prices) == 0:
349-
# # No price data for this timestamp
350-
# price = trade.open_price
351-
# else:
352-
# price = prices[-1]
353-
try:
354-
price = ohlcv.loc[:ts, "Close"].iloc[-1]
355-
except IndexError:
356-
continue # skip if no price yet
357-
358-
invested_value += trade.filled_amount * price
359-
total_value = invested_value + unallocated
360-
total_net_gain = total_value - initial_amount
360+
361+
if trade.opened_at == interval_datetime:
362+
# Snapshot taken at the moment a trade is opened
363+
unallocated -= trade.cost
364+
open_trades.append(trade)
365+
366+
if trade.closed_at == interval_datetime:
367+
# Snapshot taken at the moment a trade is closed
368+
unallocated += trade.cost + trade.net_gain
369+
total_net_gain += trade.net_gain
370+
open_trades.remove(trade)
371+
372+
for open_trade in open_trades:
373+
ohlcv = granular_ohlcv_data_order_by_symbol[
374+
f"{open_trade.target_symbol}/{trading_symbol}"
375+
]
376+
377+
try:
378+
price = ohlcv.loc[:ts, "Close"].iloc[-1]
379+
except IndexError:
380+
continue # skip if no price yet
381+
382+
allocated += open_trade.filled_amount * price
383+
384+
# total_value = invested_value + unallocated
385+
# total_net_gain = total_value - initial_amount
361386
snapshots.append(
362387
PortfolioSnapshot(
363-
created_at=pd.Timestamp(ts),
388+
created_at=interval_datetime,
364389
unallocated=unallocated,
365-
total_value=total_value,
390+
total_value=unallocated + allocated,
366391
total_net_gain=total_net_gain
367392
)
368393
)

tests/scenarios/permutation_tests/test_permutation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ def test_run(self):
260260
)
261261
backtests = app.run_permutation_test(
262262
initial_amount=1000,
263+
market="bitvavo",
264+
trading_symbol="EUR",
263265
backtest_date_range=date_range,
264266
strategy=strategy,
265267
number_of_permutations=50

tests/scenarios/test_vector_vs_event_backtest_results.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def test_run(self):
271271
backtest_date_range=date_range,
272272
strategy=strategy,
273273
snapshot_interval=SnapshotInterval.DAILY,
274-
risk_free_rate=0.027
274+
risk_free_rate=0.027,
275+
trading_symbol="EUR",
276+
market="BITVAVO"
275277
)
276278
run = vector_backtests.backtest_runs[0]
277279
end_time = time.time()

0 commit comments

Comments
 (0)