summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--redis/cluster.py47
-rw-r--r--tests/conftest.py2
-rw-r--r--tests/test_search.py56
3 files changed, 68 insertions, 37 deletions
diff --git a/redis/cluster.py b/redis/cluster.py
index 87643a7..09c9ab7 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -320,6 +320,42 @@ class RedisCluster(RedisClusterCommands):
),
)
+ SEARCH_COMMANDS = (
+ [
+ "FT.CREATE",
+ "FT.SEARCH",
+ "FT.AGGREGATE",
+ "FT.EXPLAIN",
+ "FT.EXPLAINCLI",
+ "FT,PROFILE",
+ "FT.ALTER",
+ "FT.DROPINDEX",
+ "FT.ALIASADD",
+ "FT.ALIASUPDATE",
+ "FT.ALIASDEL",
+ "FT.TAGVALS",
+ "FT.SUGADD",
+ "FT.SUGGET",
+ "FT.SUGDEL",
+ "FT.SUGLEN",
+ "FT.SYNUPDATE",
+ "FT.SYNDUMP",
+ "FT.SPELLCHECK",
+ "FT.DICTADD",
+ "FT.DICTDEL",
+ "FT.DICTDUMP",
+ "FT.INFO",
+ "FT._LIST",
+ "FT.CONFIG",
+ "FT.ADD",
+ "FT.DEL",
+ "FT.DROP",
+ "FT.GET",
+ "FT.MGET",
+ "FT.SYNADD",
+ ],
+ )
+
CLUSTER_COMMANDS_RESPONSE_CALLBACKS = {
"CLUSTER ADDSLOTS": bool,
"CLUSTER ADDSLOTSRANGE": bool,
@@ -854,6 +890,8 @@ class RedisCluster(RedisClusterCommands):
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
+ elif command in self.__class__.SEARCH_COMMANDS[0]:
+ return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
@@ -1956,17 +1994,14 @@ class ClusterPipeline(RedisCluster):
# refer to our internal node -> slot table that
# tells us where a given
# command should route to.
- slot = self.determine_slot(*c.args)
- node = self.nodes_manager.get_node_from_slot(
- slot, self.read_from_replicas and c.args[0] in READ_COMMANDS
- )
+ node = self._determine_nodes(*c.args)
# now that we know the name of the node
# ( it's just a string in the form of host:port )
# we can build a list of commands for each node.
- node_name = node.name
+ node_name = node[0].name
if node_name not in nodes:
- redis_node = self.get_redis_connection(node)
+ redis_node = self.get_redis_connection(node[0])
connection = get_connection(redis_node, c.args)
nodes[node_name] = NodeCommands(
redis_node.parse_response, redis_node.connection_pool, connection
diff --git a/tests/conftest.py b/tests/conftest.py
index b615915..3b66adf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -183,7 +183,7 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60):
while now < end_time:
try:
client = redis.RedisCluster.from_url(redis_url)
- if len(client.get_nodes()) == cluster_nodes:
+ if len(client.get_nodes()) == int(cluster_nodes):
print("All nodes are available!")
break
except RedisClusterException:
diff --git a/tests/test_search.py b/tests/test_search.py
index 67f4357..5dea739 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -10,7 +10,6 @@ import redis
import redis.commands.search
import redis.commands.search.aggregation as aggregations
import redis.commands.search.reducers as reducers
-from redis import Redis
from redis.commands.json.path import Path
from redis.commands.search import Search
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
@@ -19,10 +18,7 @@ from redis.commands.search.query import GeoFilter, NumericFilter, Query
from redis.commands.search.result import Result
from redis.commands.search.suggestion import Suggestion
-from .conftest import default_redismod_url, skip_ifmodversion_lt
-
-pytestmark = pytest.mark.onlynoncluster
-
+from .conftest import skip_ifmodversion_lt
WILL_PLAY_TEXT = os.path.abspath(
os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2")
@@ -36,7 +32,7 @@ TITLES_CSV = os.path.abspath(
def waitForIndex(env, idx, timeout=None):
delay = 0.1
while True:
- res = env.execute_command("ft.info", idx)
+ res = env.execute_command("FT.INFO", idx)
try:
res.index("indexing")
except ValueError:
@@ -52,13 +48,12 @@ def waitForIndex(env, idx, timeout=None):
break
-def getClient():
+def getClient(client):
"""
Gets a client client attached to an index name which is ready to be
created
"""
- rc = Redis.from_url(default_redismod_url, decode_responses=True)
- return rc
+ return client
def createIndex(client, num_docs=100, definition=None):
@@ -96,12 +91,6 @@ def createIndex(client, num_docs=100, definition=None):
indexer.commit()
-# override the default module client, search requires both db=0, and text
-@pytest.fixture
-def modclient():
- return Redis.from_url(default_redismod_url, db=0, decode_responses=True)
-
-
@pytest.fixture
def client(modclient):
modclient.flushdb()
@@ -234,6 +223,7 @@ def test_payloads(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def test_scores(client):
client.ft().create_index((TextField("txt"),))
@@ -356,14 +346,14 @@ def test_sort_by(client):
@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
-def test_drop_index():
+def test_drop_index(client):
"""
Ensure the index gets dropped by data remains by default
"""
for x in range(20):
for keep_docs in [[True, {}], [False, {"name": "haveit"}]]:
idx = "HaveIt"
- index = getClient()
+ index = getClient(client)
index.hset("index:haveit", mapping={"name": "haveit"})
idef = IndexDefinition(prefix=["index:"])
index.ft(idx).create_index((TextField("name"),), definition=idef)
@@ -574,9 +564,9 @@ def test_summarize(client):
@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
-def test_alias():
- index1 = getClient()
- index2 = getClient()
+def test_alias(client):
+ index1 = getClient(client)
+ index2 = getClient(client)
def1 = IndexDefinition(prefix=["index1:"])
def2 = IndexDefinition(prefix=["index2:"])
@@ -594,7 +584,7 @@ def test_alias():
# create alias and check for results
ftindex1.aliasadd("spaceballs")
- alias_client = getClient().ft("spaceballs")
+ alias_client = getClient(client).ft("spaceballs")
res = alias_client.search("*").docs[0]
assert "index1:lonestar" == res.id
@@ -604,7 +594,7 @@ def test_alias():
# update alias and ensure new results
ftindex2.aliasupdate("spaceballs")
- alias_client2 = getClient().ft("spaceballs")
+ alias_client2 = getClient(client).ft("spaceballs")
res = alias_client2.search("*").docs[0]
assert "index2:yogurt" == res.id
@@ -615,21 +605,21 @@ def test_alias():
@pytest.mark.redismod
-def test_alias_basic():
+def test_alias_basic(client):
# Creating a client with one index
- getClient().flushdb()
- index1 = getClient().ft("testAlias")
+ getClient(client).flushdb()
+ index1 = getClient(client).ft("testAlias")
index1.create_index((TextField("txt"),))
index1.add_document("doc1", txt="text goes here")
- index2 = getClient().ft("testAlias2")
+ index2 = getClient(client).ft("testAlias2")
index2.create_index((TextField("txt"),))
index2.add_document("doc2", txt="text goes here")
# add the actual alias and check
index1.aliasadd("myalias")
- alias_client = getClient().ft("myalias")
+ alias_client = getClient(client).ft("myalias")
res = sorted(alias_client.search("*").docs, key=lambda x: x.id)
assert "doc1" == res[0].id
@@ -639,7 +629,7 @@ def test_alias_basic():
# update the alias and ensure we get doc2
index2.aliasupdate("myalias")
- alias_client2 = getClient().ft("myalias")
+ alias_client2 = getClient(client).ft("myalias")
res = sorted(alias_client2.search("*").docs, key=lambda x: x.id)
assert "doc1" == res[0].id
@@ -790,6 +780,7 @@ def test_phonetic_matcher(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def test_scorer(client):
client.ft().create_index((TextField("description"),))
@@ -842,6 +833,7 @@ def test_get(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
@skip_ifmodversion_lt("2.2.0", "search")
def test_config(client):
assert client.ft().config_set("TIMEOUT", "100")
@@ -854,6 +846,7 @@ def test_config(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def test_aggregations_groupby(client):
# Creating the index definition and schema
client.ft().create_index(
@@ -1085,8 +1078,8 @@ def test_aggregations_apply(client):
CreatedDateTimeUTC="@CreatedDateTimeUTC * 10"
)
res = client.ft().aggregate(req)
- assert res.rows[0] == ["CreatedDateTimeUTC", "6373878785249699840"]
- assert res.rows[1] == ["CreatedDateTimeUTC", "6373878758592700416"]
+ res_set = set([res.rows[0][1], res.rows[1][1]])
+ assert res_set == set(["6373878785249699840", "6373878758592700416"])
@pytest.mark.redismod
@@ -1158,6 +1151,7 @@ def test_index_definition(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def testExpire(client):
client.ft().create_index((TextField("txt", sortable=True),), temporary=4)
ttl = client.execute_command("ft.debug", "TTL", "idx")
@@ -1477,6 +1471,7 @@ def test_json_with_jsonpath(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def test_profile(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
@@ -1505,6 +1500,7 @@ def test_profile(client):
@pytest.mark.redismod
+@pytest.mark.onlynoncluster
def test_profile_limited(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")