Skip to content

Commit b343580

Browse files
committed
Standardize signal generation
1 parent 20161d8 commit b343580

File tree

31 files changed

+830
-703
lines changed

31 files changed

+830
-703
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
from typing import Dict, Any
2+
from datetime import datetime, timezone
3+
4+
import pandas as pd
5+
from pyindicators import ema, rsi, crossover, crossunder
6+
7+
from investing_algorithm_framework import TradingStrategy, DataSource, \
8+
TimeUnit, DataType, PositionSize, create_app, RESOURCE_DIRECTORY, \
9+
BacktestDateRange, BacktestReport
10+
11+
12+
class RSIEMACrossoverStrategy(TradingStrategy):
13+
time_unit = TimeUnit.HOUR
14+
interval = 2
15+
symbols = ["BTC"]
16+
position_sizes = [
17+
PositionSize(
18+
symbol="BTC", percentage_of_portfolio=20.0
19+
),
20+
PositionSize(
21+
symbol="ETH", percentage_of_portfolio=20.0
22+
)
23+
]
24+
25+
def __init__(
26+
self,
27+
time_unit: TimeUnit,
28+
interval: int,
29+
market: str,
30+
rsi_time_frame: str,
31+
rsi_period: int,
32+
rsi_overbought_threshold,
33+
rsi_oversold_threshold,
34+
ema_time_frame,
35+
ema_short_period,
36+
ema_long_period,
37+
ema_cross_lookback_window: int = 10
38+
):
39+
self.rsi_time_frame = rsi_time_frame
40+
self.rsi_period = rsi_period
41+
self.rsi_result_column = f"rsi_{self.rsi_period}"
42+
self.rsi_overbought_threshold = rsi_overbought_threshold
43+
self.rsi_oversold_threshold = rsi_oversold_threshold
44+
self.ema_time_frame = ema_time_frame
45+
self.ema_short_result_column = f"ema_{ema_short_period}"
46+
self.ema_long_result_column = f"ema_{ema_long_period}"
47+
self.ema_crossunder_result_column = "ema_crossunder"
48+
self.ema_crossover_result_column = "ema_crossover"
49+
self.ema_short_period = ema_short_period
50+
self.ema_long_period = ema_long_period
51+
self.ema_cross_lookback_window = ema_cross_lookback_window
52+
data_sources = []
53+
54+
for symbol in self.symbols:
55+
full_symbol = f"{symbol}/EUR"
56+
data_sources.append(
57+
DataSource(
58+
identifier=f"{symbol}_rsi_data",
59+
data_type=DataType.OHLCV,
60+
time_frame=self.rsi_time_frame,
61+
market=market,
62+
symbol=full_symbol,
63+
pandas=True,
64+
window_size=800
65+
)
66+
)
67+
data_sources.append(
68+
DataSource(
69+
identifier=f"{symbol}_ema_data",
70+
data_type=DataType.OHLCV,
71+
time_frame=self.ema_time_frame,
72+
market=market,
73+
symbol=full_symbol,
74+
pandas=True,
75+
window_size=800
76+
)
77+
)
78+
79+
super().__init__(
80+
data_sources=data_sources, time_unit=time_unit, interval=interval
81+
)
82+
83+
self.buy_signal_dates = {}
84+
self.sell_signal_dates = {}
85+
86+
for symbol in self.symbols:
87+
self.buy_signal_dates[symbol] = []
88+
self.sell_signal_dates[symbol] = []
89+
90+
def _prepare_indicators(
91+
self,
92+
rsi_data,
93+
ema_data
94+
):
95+
ema_data = ema(
96+
ema_data,
97+
period=self.ema_short_period,
98+
source_column="Close",
99+
result_column=self.ema_short_result_column
100+
)
101+
ema_data = ema(
102+
ema_data,
103+
period=self.ema_long_period,
104+
source_column="Close",
105+
result_column=self.ema_long_result_column
106+
)
107+
# Detect crossover (short EMA crosses above long EMA)
108+
ema_data = crossover(
109+
ema_data,
110+
first_column=self.ema_short_result_column,
111+
second_column=self.ema_long_result_column,
112+
result_column=self.ema_crossover_result_column
113+
)
114+
# Detect crossunder (short EMA crosses below long EMA)
115+
ema_data = crossunder(
116+
ema_data,
117+
first_column=self.ema_short_result_column,
118+
second_column=self.ema_long_result_column,
119+
result_column=self.ema_crossunder_result_column
120+
)
121+
rsi_data = rsi(
122+
rsi_data,
123+
period=self.rsi_period,
124+
source_column="Close",
125+
result_column=self.rsi_result_column
126+
)
127+
128+
return ema_data, rsi_data
129+
130+
def generate_buy_signals(self, data: Dict[str, Any]) -> Dict[str, pd.Series]:
131+
"""
132+
Generate buy signals based on the moving average crossover.
133+
134+
data (Dict[str, Any]): Dictionary containing all the data for
135+
the strategy data sources.
136+
137+
Returns:
138+
Dict[str, pd.Series]: A dictionary where keys are symbols and values
139+
are pandas Series indicating buy signals (True/False).
140+
"""
141+
142+
signals = {}
143+
144+
for symbol in self.symbols:
145+
ema_data_identifier = f"{symbol}_ema_data"
146+
rsi_data_identifier = f"{symbol}_rsi_data"
147+
ema_data, rsi_data = self._prepare_indicators(
148+
data[ema_data_identifier].copy(),
149+
data[rsi_data_identifier].copy()
150+
)
151+
152+
# crossover confirmed
153+
ema_crossover_lookback = ema_data[
154+
self.ema_crossover_result_column].rolling(
155+
window=self.ema_cross_lookback_window
156+
).max().astype(bool)
157+
158+
# use only RSI column
159+
rsi_oversold = rsi_data[self.rsi_result_column] \
160+
< self.rsi_oversold_threshold
161+
162+
buy_signal = rsi_oversold & ema_crossover_lookback
163+
buy_signals = buy_signal.fillna(False).astype(bool)
164+
signals[symbol] = buy_signals
165+
166+
# Get all dates where there is a sell signal
167+
buy_signal_dates = buy_signals[buy_signals].index.tolist()
168+
169+
if buy_signal_dates:
170+
self.buy_signal_dates[symbol] += buy_signal_dates
171+
172+
return signals
173+
174+
def generate_sell_signals(self, data: Dict[str, Any]) -> Dict[str, pd.Series]:
175+
"""
176+
Generate sell signals based on the moving average crossover.
177+
178+
Args:
179+
data (Dict[str, Any]): Dictionary containing all the data for
180+
the strategy data sources.
181+
182+
Returns:
183+
Dict[str, pd.Series]: A dictionary where keys are symbols and values
184+
are pandas Series indicating sell signals (True/False).
185+
"""
186+
187+
signals = {}
188+
for symbol in self.symbols:
189+
ema_data_identifier = f"{symbol}_ema_data"
190+
rsi_data_identifier = f"{symbol}_rsi_data"
191+
ema_data, rsi_data = self._prepare_indicators(
192+
data[ema_data_identifier].copy(),
193+
data[rsi_data_identifier].copy()
194+
)
195+
196+
# Confirmed by crossover between short-term EMA and long-term EMA
197+
# within a given lookback window
198+
ema_crossunder_lookback = ema_data[
199+
self.ema_crossunder_result_column].rolling(
200+
window=self.ema_cross_lookback_window
201+
).max().astype(bool)
202+
203+
# use only RSI column
204+
rsi_overbought = rsi_data[self.rsi_result_column] \
205+
>= self.rsi_overbought_threshold
206+
207+
# Combine both conditions
208+
sell_signal = rsi_overbought & ema_crossunder_lookback
209+
sell_signal = sell_signal.fillna(False).astype(bool)
210+
signals[symbol] = sell_signal
211+
212+
# Get all dates where there is a sell signal
213+
sell_signal_dates = sell_signal[sell_signal].index.tolist()
214+
215+
if sell_signal_dates:
216+
self.sell_signal_dates[symbol] += sell_signal_dates
217+
218+
return signals
219+
220+
221+
if __name__ == "__main__":
222+
app = create_app()
223+
app.add_strategy(
224+
RSIEMACrossoverStrategy(
225+
time_unit=TimeUnit.HOUR,
226+
interval=2,
227+
market="bitvavo",
228+
rsi_time_frame="2h",
229+
rsi_period=14,
230+
rsi_overbought_threshold=70,
231+
rsi_oversold_threshold=30,
232+
ema_time_frame="2h",
233+
ema_short_period=12,
234+
ema_long_period=26,
235+
ema_cross_lookback_window=10
236+
)
237+
)
238+
239+
# Market credentials for coinbase for both the portfolio connection and data sources will
240+
# be read from .env file, when not registering a market credential object in the app.
241+
app.add_market(
242+
market="bitvavo",
243+
trading_symbol="EUR",
244+
)
245+
backtest_range = BacktestDateRange(
246+
start_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
247+
end_date=datetime(2024, 6, 1, tzinfo=timezone.utc)
248+
)
249+
backtest = app.run_backtest(
250+
backtest_date_range=backtest_range, initial_amount=1000
251+
)
252+
report = BacktestReport(backtest)
253+
report.show(backtest_date_range=backtest_range, browser=True)

investing_algorithm_framework/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
get_monthly_returns_heatmap_chart, create_weights, \
99
get_yearly_returns_bar_chart, get_entry_and_exit_signals, \
1010
get_ohlcv_data_completeness_chart
11-
from .domain import ApiException, combine_backtests, \
11+
from .domain import ApiException, combine_backtests, PositionSize, \
1212
OrderType, OperationalException, OrderStatus, OrderSide, \
1313
TimeUnit, TimeInterval, Order, Portfolio, Backtest, \
1414
Position, TimeFrame, INDEX_DATETIME, MarketCredential, \
@@ -165,4 +165,5 @@
165165
"get_growth_percentage",
166166
"BacktestEvaluationFocus",
167167
"combine_backtests",
168+
"PositionSize"
168169
]

investing_algorithm_framework/app/app.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
PortfolioConfiguration, SnapshotInterval, DataType, \
2121
PortfolioProvider, OrderExecutor, ImproperlyConfigured, \
2222
DataProvider, INDEX_DATETIME, tqdm, BacktestPermutationTest, \
23-
LAST_SNAPSHOT_DATETIME
23+
LAST_SNAPSHOT_DATETIME, BACKTESTING_FLAG
2424
from investing_algorithm_framework.infrastructure import setup_sqlalchemy, \
2525
create_all_tables, CCXTOrderExecutor, CCXTPortfolioProvider, \
2626
BacktestOrderExecutor, CCXTOHLCVDataProvider, clear_db, \
@@ -312,7 +312,8 @@ def initialize_backtest_config(
312312
),
313313
BACKTESTING_INITIAL_AMOUNT: initial_amount,
314314
INDEX_DATETIME: backtest_date_range.start_date,
315-
LAST_SNAPSHOT_DATETIME: None
315+
LAST_SNAPSHOT_DATETIME: None,
316+
BACKTESTING_FLAG: True
316317
}
317318
configuration_service = self.container.configuration_service()
318319
configuration_service.initialize_from_dict(data)
@@ -451,12 +452,15 @@ def initialize_data_sources_backtest(
451452
if not show_progress:
452453
for _, data_provider in data_providers:
453454
data_provider.prepare_backtest_data(
454-
backtest_date_range.start_date,
455-
backtest_date_range.end_date
455+
backtest_start_date=backtest_date_range.start_date,
456+
backtest_end_date=backtest_date_range.end_date
456457
)
457458
else:
458459
for _, data_provider in \
459-
tqdm(data_providers, desc=description, colour="green"):
460+
tqdm(
461+
data_providers, desc=description, colour="green"
462+
):
463+
460464
data_provider.prepare_backtest_data(
461465
backtest_start_date=backtest_date_range.start_date,
462466
backtest_end_date=backtest_date_range.end_date

investing_algorithm_framework/app/context.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from datetime import datetime, timezone
23
from typing import List
34

45
from investing_algorithm_framework.services import ConfigurationService, \
@@ -420,9 +421,52 @@ def get_portfolio(self, market=None) -> Portfolio:
420421
"""
421422

422423
if market is None:
423-
return self.portfolio_service.get_all()[0]
424+
portfolio = self.portfolio_service.get_all()[0]
425+
else:
426+
portfolio = self.portfolio_service.find({"market": market})
427+
428+
# Retrieve positions
429+
positions = self.position_service.get_all(
430+
{"portfolio": portfolio.id}
431+
)
432+
433+
if BACKTESTING_FLAG in self.configuration_service.config \
434+
and self.configuration_service.config[BACKTESTING_FLAG]:
435+
date = self.configuration_service.config[INDEX_DATETIME]
436+
else:
437+
date = datetime.now(tz=timezone.utc)
438+
439+
allocated = 0.0
440+
441+
for position in positions:
442+
443+
if position.symbol != portfolio.trading_symbol:
444+
ticker = self.data_provider_service.get_ticker_data(
445+
symbol=f"{position.symbol}/{portfolio.trading_symbol}",
446+
market=portfolio.market,
447+
date=date
448+
)
449+
if ticker is not None and "bid" in ticker:
450+
allocated += position.get_amount() * ticker["bid"]
451+
452+
portfolio.allocated = allocated
453+
return portfolio
454+
455+
def get_latest_price(self, symbol, market=None):
456+
457+
if BACKTESTING_FLAG in self.configuration_service.config \
458+
and self.configuration_service.config[BACKTESTING_FLAG]:
459+
date = self.configuration_service.config[INDEX_DATETIME]
460+
else:
461+
date = datetime.now(tz=timezone.utc)
462+
463+
ticker = self.data_provider_service.get_ticker_data(
464+
symbol=symbol,
465+
market=market,
466+
date=date
467+
)
424468

425-
return self.portfolio_service.find({"market": market})
469+
return ticker['bid'] if ticker and 'bid' in ticker else None
426470

427471
def get_portfolios(self):
428472
"""
@@ -1597,3 +1641,27 @@ def get_market_credentials(self) -> List[MarketCredential]:
15971641
List[MarketCredential]: A list of all market credentials
15981642
"""
15991643
return self.market_credential_service.get_all()
1644+
1645+
def get_trading_symbol(self, portfolio_id=None):
1646+
"""
1647+
Function to get the trading symbol of a portfolio. If the
1648+
portfolio_id parameter is specified, the function will return
1649+
the trading symbol of the portfolio with the specified id.
1650+
1651+
Args:
1652+
portfolio_id: The id of the portfolio to get the trading symbol for
1653+
1654+
Returns:
1655+
str: The trading symbol of the portfolio
1656+
"""
1657+
if portfolio_id is None:
1658+
if self.portfolio_service.count() > 1:
1659+
raise OperationalException(
1660+
"Multiple portfolios found. Please specify a "
1661+
"portfolio identifier."
1662+
)
1663+
portfolio = self.portfolio_service.get_all()[0]
1664+
else:
1665+
portfolio = self.portfolio_service.get(portfolio_id)
1666+
1667+
return portfolio.trading_symbol

0 commit comments

Comments
 (0)