Skip to content

Commit f46757c

Browse files
committed
Fix tests and flake8 warnings
1 parent 6f0cade commit f46757c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+57772
-11667
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[flake8]
22
exclude =
3-
investing_algorithm_framework/domain/utils/backtesting.py
3+
investing_algorithm_framework/app/reporting/
44
investing_algorithm_framework/infrastructure/database/sql_alchemy.py
55
examples
66
investing_algorithm_framework/metrics

README.md

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ The following example connects to Binance and buys BTC every 2 hours.
7979
import logging.config
8080
from dotenv import load_dotenv
8181

82-
from investing_algorithm_framework import create_app, TimeUnit, Context, \
83-
CCXTOHLCVMarketDataSource, CCXTTickerMarketDataSource, DEFAULT_LOGGING_CONFIG
82+
from investing_algorithm_framework import create_app, TimeUnit, Context, BacktestDateRange, \
83+
CCXTOHLCVMarketDataSource, CCXTTickerMarketDataSource, DEFAULT_LOGGING_CONFIG, \
84+
TradingStrategy, SnapshotInterval
8485

8586
load_dotenv()
8687
logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
@@ -104,24 +105,27 @@ app = create_app()
104105
# Registered bitvavo market, credentials are read from .env file by default
105106
app.add_market(market="BITVAVO", trading_symbol="EUR", initial_balance=100)
106107

107-
# Define a strategy for the algorithm that will run every 10 seconds
108-
@app.strategy(
109-
time_unit=TimeUnit.SECOND,
110-
interval=10,
111-
market_data_sources=[bitvavo_btc_eur_ticker, bitvavo_btc_eur_ohlcv_2h]
108+
class MyStrategy(TradingStrategy):
109+
interval = 2
110+
time_unit = TimeUnit.HOUR
111+
data_sources = [bitvavo_btc_eur_ohlcv_2h, bitvavo_btc_eur_ticker]
112+
113+
def run_strategy(self, context: Context, market_data):
114+
# Access the data sources with the indentifier
115+
polars_df = market_data["BTC-ohlcv"]
116+
ticker_data = market_data["BTC-ticker"]
117+
unallocated_balance = context.get_unallocated()
118+
positions = context.get_positions()
119+
trades = context.get_trades()
120+
open_trades = context.get_open_trades()
121+
closed_trades = context.get_closed_trades()
122+
123+
date_range = BacktestDateRange(
124+
start_date="2023-08-24 00:00:00",
125+
end_date="2023-12-02 00:00:00"
112126
)
113-
def perform_strategy(context: Context, market_data: dict):
114-
# Access the data sources with the indentifier
115-
polars_df = market_data["BTC-ohlcv"]
116-
ticker_data = market_data["BTC-ticker"]
117-
unallocated_balance = context.get_unallocated()
118-
positions = context.get_positions()
119-
trades = context.get_trades()
120-
open_trades = context.get_open_trades()
121-
closed_trades = context.get_closed_trades()
122-
123-
if __name__ == "__main__":
124-
app.run()
127+
backtest_report = app.run_backtest(backtest_date_range=date_range, initial_amount=100, snapshot_interval=SnapshotInterval.STRATEGY_ITERATION)
128+
backtest_report.show()
125129
```
126130

127131
> You can find more examples [here](./examples) folder.

investing_algorithm_framework/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from investing_algorithm_framework.app import App, Algorithm, \
22
TradingStrategy, StatelessAction, Task, AppHook, Context, \
33
add_html_report, add_metrics, generate_report, BacktestReport, \
4-
BacktestReportsEvaluation
4+
BacktestReportsEvaluation, pretty_print_trades, pretty_print_positions, \
5+
pretty_print_orders, pretty_print_backtest
56
from investing_algorithm_framework.domain import ApiException, \
67
TradingDataType, TradingTimeFrame, OrderType, OperationalException, \
78
OrderStatus, OrderSide, TimeUnit, TimeInterval, Order, Portfolio, \
@@ -59,9 +60,7 @@
5960
"MarketCredential",
6061
"MarketService",
6162
"OperationalException",
62-
"pretty_print_backtest_reports_evaluation",
6363
"BacktestReportsEvaluation",
64-
"load_backtest_reports",
6564
"SYMBOLS",
6665
"RESERVED_BALANCES",
6766
"APP_MODE",
@@ -71,7 +70,6 @@
7170
"BacktestDateRange",
7271
"convert_polars_to_pandas",
7372
"DateRange",
74-
"get_backtest_report",
7573
"AzureBlobStorageStateHandler",
7674
"DEFAULT_LOGGING_CONFIG",
7775
"BacktestReport",

investing_algorithm_framework/app/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from .algorithm import Algorithm
77
from .context import Context
88
from .reporting import generate_report, add_html_report, add_metrics, \
9-
BacktestReport, pretty_print_backtest, BacktestReportsEvaluation
9+
BacktestReport, pretty_print_backtest, BacktestReportsEvaluation, \
10+
pretty_print_trades, pretty_print_positions, pretty_print_orders
1011

1112

1213
__all__ = [
@@ -23,5 +24,8 @@
2324
"add_metrics",
2425
"BacktestReport",
2526
"pretty_print_backtest",
26-
"BacktestReportsEvaluation"
27+
"BacktestReportsEvaluation",
28+
"pretty_print_trades",
29+
"pretty_print_positions",
30+
"pretty_print_orders"
2731
]

investing_algorithm_framework/app/app.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import threading
55
from time import sleep
66
from typing import List, Optional, Any
7-
from datetime import timedelta
87

98
from flask import Flask
109

@@ -28,7 +27,6 @@
2827
BacktestMarketDataSourceService, BacktestPortfolioService
2928
from .app_hook import AppHook
3029
from .reporting import BacktestReport
31-
from investing_algorithm_framework.domain.models.backtesting.backtest_results import BacktestResult
3230

3331
logger = logging.getLogger("investing_algorithm_framework")
3432
COLOR_RESET = '\033[0m'
@@ -807,13 +805,14 @@ def run_backtest(
807805

808806
# Run the backtest with the backtest_service and collect and
809807
# save the report
810-
report = backtest_service.run_backtest(
808+
results = backtest_service.run_backtest(
811809
algorithm=algorithm,
812810
context=context,
813811
strategy_orchestrator_service=strategy_orchestrator_service,
814812
initial_amount=initial_amount,
815813
backtest_date_range=backtest_date_range
816814
)
815+
report = BacktestReport(results=results)
817816
backtest_service.save_report(
818817
report=report,
819818
algorithm=algorithm,
@@ -1570,41 +1569,3 @@ def get_algorithm(self):
15701569
data_sources=self._market_data_sources,
15711570
on_strategy_run_hooks=self._on_strategy_run_hooks,
15721571
)
1573-
1574-
def run_backtest(self, backtest_date_range, algorithm, snapshot_interval):
1575-
"""
1576-
Run a backtest for the given date range and algorithm.
1577-
1578-
Args:
1579-
backtest_date_range (BacktestDateRange): The date range for the backtest.
1580-
algorithm (Algorithm): The algorithm to use for the backtest.
1581-
snapshot_interval (str): The interval for portfolio snapshots.
1582-
1583-
Returns:
1584-
BacktestResult: The result of the backtest.
1585-
"""
1586-
current_date = backtest_date_range.start_date
1587-
end_date = backtest_date_range.end_date
1588-
last_snapshot_date = None
1589-
1590-
while current_date <= end_date:
1591-
# Run the algorithm for the current date
1592-
algorithm.run(current_date)
1593-
1594-
# Check if a snapshot should be created
1595-
if snapshot_interval == "daily":
1596-
if last_snapshot_date is None or current_date.date() > last_snapshot_date.date():
1597-
self.container.portfolio_snapshot_service().create_snapshot(
1598-
portfolio=self.container.portfolio_service().get_default_portfolio(),
1599-
created_at=current_date
1600-
)
1601-
last_snapshot_date = current_date
1602-
1603-
# Increment the current date
1604-
current_date += timedelta(days=1)
1605-
1606-
return BacktestResult(
1607-
backtest_date_range=backtest_date_range,
1608-
algorithm=algorithm,
1609-
snapshot_interval=snapshot_interval
1610-
)
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .generate import add_html_report, add_metrics, generate_report
22
from .backtest_report import BacktestReport
3-
from .ascii import pretty_print_backtest
3+
from .ascii import pretty_print_backtest, pretty_print_positions, \
4+
pretty_print_trades, pretty_print_orders
45
from .evaluation import BacktestReportsEvaluation
56

67
__all__ = [
@@ -9,5 +10,8 @@
910
"generate_report",
1011
"BacktestReport",
1112
"pretty_print_backtest",
12-
"BacktestReportsEvaluation"
13+
"BacktestReportsEvaluation",
14+
"pretty_print_positions",
15+
"pretty_print_trades",
16+
"pretty_print_orders"
1317
]

0 commit comments

Comments
 (0)