Skip to content

Commit 09c0174

Browse files
committed
Keep gds.alpha.graph.sample.rwr endpoint
1 parent d4507c0 commit 09c0174

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

graphdatascience/graph/graph_alpha_proc_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from .graph_alpha_project_runner import GraphAlphaProjectRunner
1212
from .graph_entity_ops_runner import GraphLabelRunner, GraphPropertyRunner
1313
from .graph_object import Graph
14-
from .graph_sample_runner import GraphSampleRunner
14+
from .graph_sample_runner import GraphAlphaSampleRunner
1515

1616

1717
class GraphAlphaProcRunner(UncallableNamespace, IllegalAttrChecker):
1818
@property
19-
def sample(self) -> GraphSampleRunner:
19+
def sample(self) -> GraphAlphaSampleRunner:
2020
self._namespace += ".sample"
21-
return GraphSampleRunner(self._query_runner, self._namespace, self._server_version)
21+
return GraphAlphaSampleRunner(self._query_runner, self._namespace, self._server_version)
2222

2323
@property
2424
def graphProperty(self) -> GraphPropertyRunner:

graphdatascience/graph/graph_sample_runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,23 @@
22

33
from pandas import Series
44

5+
from ..error.deprecation_warning import deprecation_warning
56
from ..error.illegal_attr_checker import IllegalAttrChecker
67
from ..server_version.compatible_with import compatible_with
78
from ..server_version.server_version import ServerVersion
89
from .graph_object import Graph
910
from .graph_type_check import from_graph_type_check
1011

1112

13+
class GraphAlphaSampleRunner(IllegalAttrChecker):
14+
@compatible_with("construct", min_inclusive=ServerVersion(2, 2, 0))
15+
@deprecation_warning("gds.graph.sample.rwr", ServerVersion(2, 4, 0))
16+
@from_graph_type_check
17+
def rwr(self, graph_name: str, from_G: Graph, **config: Any) -> Tuple[Graph, "Series[Any]"]:
18+
runner = RWRRunner(self._query_runner, self._namespace + ".rwr", self._server_version)
19+
return runner(graph_name, from_G, **config)
20+
21+
1222
class GraphSampleRunner(IllegalAttrChecker):
1323
@property
1424
def rwr(self) -> "RWRRunner":

graphdatascience/tests/unit/resources/example_server_endpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"gds.alpha.graph.graphProperty.stream",
2828
"gds.alpha.graph.nodeLabel.mutate",
2929
"gds.alpha.graph.nodeLabel.write",
30+
"gds.alpha.graph.sample.rwr",
3031
"gds.alpha.hits.mutate",
3132
"gds.alpha.hits.mutate.estimate",
3233
"gds.alpha.hits.stats",

graphdatascience/tests/unit/test_graph_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,18 @@ def test_graph_generate(runner: CollectingQueryRunner, gds: GraphDataScience) ->
602602
}
603603

604604

605+
def test_alpha_graph_sample_rwr(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
606+
from_G, _ = gds.graph.project("g", "*", "*")
607+
gds.alpha.graph.sample.rwr("s", from_G, samplingRatio=0.9, concurrency=7)
608+
609+
assert runner.last_query() == "CALL gds.alpha.graph.sample.rwr($graph_name, $from_graph_name, $config)"
610+
assert runner.last_params() == {
611+
"graph_name": "s",
612+
"from_graph_name": "g",
613+
"config": {"samplingRatio": 0.9, "concurrency": 7},
614+
}
615+
616+
605617
def test_graph_sample_rwr(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
606618
from_G, _ = gds.graph.project("g", "*", "*")
607619
gds.graph.sample.rwr("s", from_G, samplingRatio=0.9, concurrency=7)

0 commit comments

Comments
 (0)