diff options
-rw-r--r-- | redis/cluster.py | 47 | ||||
-rw-r--r-- | tests/conftest.py | 2 | ||||
-rw-r--r-- | tests/test_search.py | 56 |
3 files changed, 68 insertions, 37 deletions
diff --git a/redis/cluster.py b/redis/cluster.py index 87643a7..09c9ab7 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -320,6 +320,42 @@ class RedisCluster(RedisClusterCommands): ), ) + SEARCH_COMMANDS = ( + [ + "FT.CREATE", + "FT.SEARCH", + "FT.AGGREGATE", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT,PROFILE", + "FT.ALTER", + "FT.DROPINDEX", + "FT.ALIASADD", + "FT.ALIASUPDATE", + "FT.ALIASDEL", + "FT.TAGVALS", + "FT.SUGADD", + "FT.SUGGET", + "FT.SUGDEL", + "FT.SUGLEN", + "FT.SYNUPDATE", + "FT.SYNDUMP", + "FT.SPELLCHECK", + "FT.DICTADD", + "FT.DICTDEL", + "FT.DICTDUMP", + "FT.INFO", + "FT._LIST", + "FT.CONFIG", + "FT.ADD", + "FT.DEL", + "FT.DROP", + "FT.GET", + "FT.MGET", + "FT.SYNADD", + ], + ) + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { "CLUSTER ADDSLOTS": bool, "CLUSTER ADDSLOTSRANGE": bool, @@ -854,6 +890,8 @@ class RedisCluster(RedisClusterCommands): elif command_flag == self.__class__.DEFAULT_NODE: # return the cluster's default node return [self.nodes_manager.default_node] + elif command in self.__class__.SEARCH_COMMANDS[0]: + return [self.nodes_manager.default_node] else: # get the node that holds the key's slot slot = self.determine_slot(*args) @@ -1956,17 +1994,14 @@ class ClusterPipeline(RedisCluster): # refer to our internal node -> slot table that # tells us where a given # command should route to. - slot = self.determine_slot(*c.args) - node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and c.args[0] in READ_COMMANDS - ) + node = self._determine_nodes(*c.args) # now that we know the name of the node # ( it's just a string in the form of host:port ) # we can build a list of commands for each node. - node_name = node.name + node_name = node[0].name if node_name not in nodes: - redis_node = self.get_redis_connection(node) + redis_node = self.get_redis_connection(node[0]) connection = get_connection(redis_node, c.args) nodes[node_name] = NodeCommands( redis_node.parse_response, redis_node.connection_pool, connection diff --git a/tests/conftest.py b/tests/conftest.py index b615915..3b66adf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -183,7 +183,7 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): while now < end_time: try: client = redis.RedisCluster.from_url(redis_url) - if len(client.get_nodes()) == cluster_nodes: + if len(client.get_nodes()) == int(cluster_nodes): print("All nodes are available!") break except RedisClusterException: diff --git a/tests/test_search.py b/tests/test_search.py index 67f4357..5dea739 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -10,7 +10,6 @@ import redis import redis.commands.search import redis.commands.search.aggregation as aggregations import redis.commands.search.reducers as reducers -from redis import Redis from redis.commands.json.path import Path from redis.commands.search import Search from redis.commands.search.field import GeoField, NumericField, TagField, TextField @@ -19,10 +18,7 @@ from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from .conftest import default_redismod_url, skip_ifmodversion_lt - -pytestmark = pytest.mark.onlynoncluster - +from .conftest import skip_ifmodversion_lt WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -36,7 +32,7 @@ TITLES_CSV = os.path.abspath( def waitForIndex(env, idx, timeout=None): delay = 0.1 while True: - res = env.execute_command("ft.info", idx) + res = env.execute_command("FT.INFO", idx) try: res.index("indexing") except ValueError: @@ -52,13 +48,12 @@ def waitForIndex(env, idx, timeout=None): break -def getClient(): +def getClient(client): """ Gets a client client attached to an index name which is ready to be created """ - rc = Redis.from_url(default_redismod_url, decode_responses=True) - return rc + return client def createIndex(client, num_docs=100, definition=None): @@ -96,12 +91,6 @@ def createIndex(client, num_docs=100, definition=None): indexer.commit() -# override the default module client, search requires both db=0, and text -@pytest.fixture -def modclient(): - return Redis.from_url(default_redismod_url, db=0, decode_responses=True) - - @pytest.fixture def client(modclient): modclient.flushdb() @@ -234,6 +223,7 @@ def test_payloads(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def test_scores(client): client.ft().create_index((TextField("txt"),)) @@ -356,14 +346,14 @@ def test_sort_by(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -def test_drop_index(): +def test_drop_index(client): """ Ensure the index gets dropped by data remains by default """ for x in range(20): for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: idx = "HaveIt" - index = getClient() + index = getClient(client) index.hset("index:haveit", mapping={"name": "haveit"}) idef = IndexDefinition(prefix=["index:"]) index.ft(idx).create_index((TextField("name"),), definition=idef) @@ -574,9 +564,9 @@ def test_summarize(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -def test_alias(): - index1 = getClient() - index2 = getClient() +def test_alias(client): + index1 = getClient(client) + index2 = getClient(client) def1 = IndexDefinition(prefix=["index1:"]) def2 = IndexDefinition(prefix=["index2:"]) @@ -594,7 +584,7 @@ def test_alias(): # create alias and check for results ftindex1.aliasadd("spaceballs") - alias_client = getClient().ft("spaceballs") + alias_client = getClient(client).ft("spaceballs") res = alias_client.search("*").docs[0] assert "index1:lonestar" == res.id @@ -604,7 +594,7 @@ def test_alias(): # update alias and ensure new results ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient().ft("spaceballs") + alias_client2 = getClient(client).ft("spaceballs") res = alias_client2.search("*").docs[0] assert "index2:yogurt" == res.id @@ -615,21 +605,21 @@ def test_alias(): @pytest.mark.redismod -def test_alias_basic(): +def test_alias_basic(client): # Creating a client with one index - getClient().flushdb() - index1 = getClient().ft("testAlias") + getClient(client).flushdb() + index1 = getClient(client).ft("testAlias") index1.create_index((TextField("txt"),)) index1.add_document("doc1", txt="text goes here") - index2 = getClient().ft("testAlias2") + index2 = getClient(client).ft("testAlias2") index2.create_index((TextField("txt"),)) index2.add_document("doc2", txt="text goes here") # add the actual alias and check index1.aliasadd("myalias") - alias_client = getClient().ft("myalias") + alias_client = getClient(client).ft("myalias") res = sorted(alias_client.search("*").docs, key=lambda x: x.id) assert "doc1" == res[0].id @@ -639,7 +629,7 @@ def test_alias_basic(): # update the alias and ensure we get doc2 index2.aliasupdate("myalias") - alias_client2 = getClient().ft("myalias") + alias_client2 = getClient(client).ft("myalias") res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) assert "doc1" == res[0].id @@ -790,6 +780,7 @@ def test_phonetic_matcher(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def test_scorer(client): client.ft().create_index((TextField("description"),)) @@ -842,6 +833,7 @@ def test_get(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.2.0", "search") def test_config(client): assert client.ft().config_set("TIMEOUT", "100") @@ -854,6 +846,7 @@ def test_config(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def test_aggregations_groupby(client): # Creating the index definition and schema client.ft().create_index( @@ -1085,8 +1078,8 @@ def test_aggregations_apply(client): CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" ) res = client.ft().aggregate(req) - assert res.rows[0] == ["CreatedDateTimeUTC", "6373878785249699840"] - assert res.rows[1] == ["CreatedDateTimeUTC", "6373878758592700416"] + res_set = set([res.rows[0][1], res.rows[1][1]]) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) @pytest.mark.redismod @@ -1158,6 +1151,7 @@ def test_index_definition(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def testExpire(client): client.ft().create_index((TextField("txt", sortable=True),), temporary=4) ttl = client.execute_command("ft.debug", "TTL", "idx") @@ -1477,6 +1471,7 @@ def test_json_with_jsonpath(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def test_profile(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") @@ -1505,6 +1500,7 @@ def test_profile(client): @pytest.mark.redismod +@pytest.mark.onlynoncluster def test_profile_limited(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") |