Skip to content

Commit 5865524

Browse files
taskingaijcSimsonW
authored andcommitted
test: add retrieval test
1 parent 9778177 commit 5865524

15 files changed

+1233
-869
lines changed

‎test/common/utils.py

+50
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,53 @@ def get_random():
9393

9494
def get_project_id(i: int):
9595
return (get_random() for j in range(i))
96+
97+
98+
def assume_text_embedding_result(result: list):
99+
pytest.assume(len(result) > 0)
100+
pytest.assume(all(isinstance(value, float) for value in result))
101+
102+
103+
def assume_collection_result(create_dict: dict, res_dict: dict):
104+
for key in create_dict:
105+
pytest.assume(res_dict[key] == create_dict[key])
106+
pytest.assume(res_dict["status"] == "ready")
107+
108+
109+
def assume_record_result(create_record_data: dict, res_dict: dict):
110+
for key in create_record_data:
111+
if key == "text_splitter":
112+
continue
113+
pytest.assume(res_dict[key] == create_record_data[key])
114+
pytest.assume(res_dict["status"] == "ready")
115+
116+
117+
def assume_chunk_result(chunk_dict: dict, res: dict):
118+
for key, value in chunk_dict.items():
119+
pytest.assume(res[key] == chunk_dict[key])
120+
121+
122+
def assume_query_chunk_result(query_text, chunk_dict: dict):
123+
pytest.assume(query_text in chunk_dict["content"])
124+
pytest.assume(isinstance(chunk_dict["score"], float))
125+
126+
127+
def assume_assistant_result(assistant_dict: dict, res: dict):
128+
for key, value in assistant_dict.items():
129+
if key == 'system_prompt_template' and isinstance(value, str):
130+
pytest.assume(res[key] == [assistant_dict[key]])
131+
elif key in ["memory", "tool", "retrievals"]:
132+
continue
133+
else:
134+
pytest.assume(res[key] == assistant_dict[key])
135+
136+
137+
def assume_chat_result(chat_dict: dict, res: dict):
138+
for key, value in chat_dict.items():
139+
pytest.assume(res[key] == chat_dict[key])
140+
141+
142+
def assume_message_result(message_dict: dict, res: dict):
143+
for key, value in message_dict.items():
144+
pytest.assume(res[key] == message_dict[key])
145+

‎test/config.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import os
22

3-
chat_completion_model_id = os.environ.get("CHAT_COMPLETION_MODEL_ID")
4-
if not chat_completion_model_id:
5-
raise ValueError("chat_completion_model_id is not defined")
63

7-
embedding_model_id = os.environ.get("EMBEDDING_MODEL_ID")
8-
if not chat_completion_model_id:
9-
raise ValueError("chat_completion_model_id is not defined")
4+
class Config:
105

11-
sleep_time = 1
6+
chat_completion_model_id = os.environ.get("CHAT_COMPLETION_MODEL_ID")
7+
if not chat_completion_model_id:
8+
raise ValueError("chat_completion_model_id is not defined")
9+
10+
text_embedding_model_id = os.environ.get("TEXT_EMBEDDING_MODEL_ID")
11+
if not chat_completion_model_id:
12+
raise ValueError("chat_completion_model_id is not defined")
13+
14+
sleep_time = 1
1215

1316

‎test/data/test_assistant_data.yml

-19
This file was deleted.

‎test/run_test.sh

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
set -e
2+
parent_dir="$(dirname "$(pwd)")"
3+
export PYTHONPATH="${PYTHONPATH}:${parent_dir}"
4+
5+
echo "Starting tests..."
6+
7+
pytest -q --tb=no
8+
9+
echo "Tests completed."
10+

‎test/testcase/test_async/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class Base:
2+
3+
collection_id, record_id, chunk_id, action_id, assistant_id, chat_id, message_id = None, None, None, None, None, None, None

‎test/testcase/test_async/base.py

-3
This file was deleted.

0 commit comments

Comments
 (0)