Skip to content

Commit c99d116

Browse files
committed
fix:improve save
1 parent 16fef73 commit c99d116

File tree

1 file changed

+32
-21
lines changed

1 file changed

+32
-21
lines changed

scheduler/redis_models/base.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def deserialize(cls, data: Dict[str, Any]) -> Self:
108108
logger.warning(f"Unknown field {k} in {cls.__name__}")
109109
continue
110110
data[k] = _deserialize(data[k], types[k])
111-
return cls(**data)
111+
res= cls(**data)
112+
return res
112113

113114

114115
@dataclasses.dataclass(slots=True, kw_only=True)
@@ -125,10 +126,17 @@ def __post_init__(self):
125126
self._save_all = True
126127

127128
def __setattr__(self, key, value):
128-
if key != "_dirty_fields" and hasattr(self, "_dirty_fields"):
129+
if not key.startswith("_") and hasattr(self, "_dirty_fields"):
129130
self._dirty_fields.add(key)
130131
super(HashModel, self).__setattr__(key, value)
131132

133+
@classmethod
134+
def deserialize(cls, data: Dict[str, Any]) -> Self:
135+
instance = super(HashModel, cls).deserialize(data)
136+
instance._dirty_fields = set()
137+
instance._save_all = False
138+
return instance
139+
132140
@property
133141
def _parent_key(self) -> Optional[str]:
134142
if self.parent is None:
@@ -171,27 +179,30 @@ def get(cls, name: str, connection: ConnectionType) -> Optional[Self]:
171179

172180
@classmethod
173181
def get_many(cls, names: Sequence[str], connection: ConnectionType) -> List[Optional[Self]]:
174-
pipeline = connection.pipeline()
175-
for name in names:
176-
pipeline.hgetall(cls._element_key_template.format(name))
177-
values = pipeline.execute()
182+
with connection.pipeline() as pipeline:
183+
for name in names:
184+
pipeline.hgetall(cls._element_key_template.format(name))
185+
values = pipeline.execute()
178186
return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values]
179187

180-
def save(self, connection: ConnectionType) -> None:
181-
connection.sadd(self._list_key, self.name)
182-
if self._parent_key is not None:
183-
connection.sadd(self._parent_key, self.name)
184-
mapping = self.serialize(with_nones=True)
185-
if not self._save_all and len(self._dirty_fields) > 0:
186-
mapping = {k: v for k, v in mapping.items() if k in self._dirty_fields}
187-
none_values = {k for k, v in mapping.items() if v is None}
188-
if none_values:
189-
connection.hdel(self._key, *none_values)
190-
mapping = {k: v for k, v in mapping.items() if v is not None}
191-
if mapping:
192-
connection.hset(self._key, mapping=mapping)
193-
self._dirty_fields = set()
194-
self._save_all = False
188+
def save(self, connection: ConnectionType, save_all:bool=False) -> None:
189+
save_all = save_all or self._save_all
190+
with connection.pipeline() as pipeline:
191+
pipeline.sadd(self._list_key, self.name)
192+
if self._parent_key is not None:
193+
pipeline.sadd(self._parent_key, self.name)
194+
mapping = self.serialize(with_nones=True)
195+
if not save_all:
196+
mapping = {k: v for k, v in mapping.items() if k in self._dirty_fields}
197+
none_values = {k for k, v in mapping.items() if v is None}
198+
if none_values:
199+
pipeline.hdel(self._key, *none_values)
200+
mapping = {k: v for k, v in mapping.items() if v is not None}
201+
if mapping:
202+
pipeline.hset(self._key, mapping=mapping)
203+
pipeline.execute()
204+
self._dirty_fields = set()
205+
self._save_all = False
195206

196207
def delete(self, connection: ConnectionType) -> None:
197208
connection.srem(self._list_key, self._key)

0 commit comments

Comments
 (0)