Skip to content

Commit 4874e1b

Browse files
committed
Put BotState in dedicated folder
1 parent 6dc4e76 commit 4874e1b

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from investing_bot_framework.core.states.bot_state import BotState
2+

investing_bot_framework/core/context/states/bot_state.py renamed to investing_bot_framework/core/states/bot_state.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@ class BotState(ABC):
1414
transition_state_class = None
1515

1616
# Validator for the current state
17-
state_validators = None
17+
pre_state_validators: List[StateValidator] = None
18+
post_state_validators: List[StateValidator] = None
1819

19-
def __init__(self, context, state_validator: StateValidator = None) -> None:
20+
def __init__(self, context) -> None:
2021
self._bot_context = context
21-
self._state_validator = state_validator
2222

2323
def start(self):
2424

25+
# Will stop the state if pre-conditions are not met
26+
if not self.validate_state():
27+
return
28+
2529
while True:
2630
self.run()
2731

@@ -37,12 +41,15 @@ def run(self) -> None:
3741
def context(self):
3842
return self._bot_context
3943

40-
def validate_state(self) -> bool:
44+
def validate_state(self, pre_state: bool = False) -> bool:
4145
"""
4246
Function that will validate the state
4347
"""
4448

45-
state_validators = self.get_state_validators()
49+
if pre_state:
50+
state_validators = self.get_pre_state_validators()
51+
else:
52+
state_validators = self.get_post_state_validators()
4653

4754
if state_validators is None:
4855
return True
@@ -63,12 +70,18 @@ def get_transition_state_class(self):
6370

6471
return self.transition_state_class
6572

66-
def get_state_validators(self) -> List[StateValidator]:
73+
def get_pre_state_validators(self) -> List[StateValidator]:
6774

68-
if self.state_validators is not None:
75+
if self.pre_state_validators is not None:
6976
return [
70-
state_validator() for state_validator in getattr(self, 'state_validators')
77+
state_validator() for state_validator in getattr(self, 'pre_state_validators')
7178
if issubclass(state_validator, StateValidator)
7279
]
7380

81+
def get_post_state_validators(self) -> List[StateValidator]:
7482

83+
if self.post_state_validators is not None:
84+
return [
85+
state_validator() for state_validator in getattr(self, 'post_state_validators')
86+
if issubclass(state_validator, StateValidator)
87+
]

0 commit comments

Comments
 (0)