Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
self,
dependencies: HandlerDependencies,
chat_llms: dict[str, Any],
playground_chat_llms: dict[str, Any] | None = None,
search_handler=None,
add_handler=None,
online_bot=None,
Expand All @@ -70,6 +71,7 @@ def __init__(
Args:
dependencies: HandlerDependencies instance
chat_llms: Dictionary mapping model names to LLM instances
playground_chat_llms: Optional model map for /chat/stream/playground
search_handler: Optional SearchHandler instance (created if not provided)
add_handler: Optional AddHandler instance (created if not provided)
online_bot: Optional DingDing bot function for notifications
Expand All @@ -89,6 +91,7 @@ def __init__(
add_handler = AddHandler(dependencies)

self.chat_llms = chat_llms
self.playground_chat_llms = playground_chat_llms or chat_llms
self.search_handler = search_handler
self.add_handler = add_handler
self.online_bot = online_bot
Expand Down Expand Up @@ -630,10 +633,11 @@ def generate_chat_response() -> Generator[str, None, None]:

# Step 3: Generate streaming response from LLM
try:
model = next(iter(self.chat_llms.keys()))
chat_llms = self.playground_chat_llms
model = next(iter(chat_llms.keys()))
self.logger.info(f"[PLAYGROUND CHAT] Chat Playground Stream Model: {model}")
start = time.time()
response_stream = self.chat_llms[model].generate_stream(
response_stream = chat_llms[model].generate_stream(
current_messages, model_name_or_path=model
)

Expand Down
7 changes: 7 additions & 0 deletions src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def init_server() -> dict[str, Any]:
graph_db_config = build_graph_db_config()
llm_config = build_llm_config()
chat_llm_config = build_chat_llm_config()
playground_chat_llm_config = build_chat_llm_config("PLAYGROUND_CHAT_MODEL_LIST")
embedder_config = build_embedder_config()
nli_client_config = build_nli_client_config()
mem_reader_config = build_mem_reader_config()
Expand All @@ -174,6 +175,11 @@ def init_server() -> dict[str, Any]:
if os.getenv("ENABLE_CHAT_API", "false") == "true"
else None
)
playground_chat_llms = (
_init_chat_llms(playground_chat_llm_config)
if os.getenv("ENABLE_CHAT_API", "false") == "true" and playground_chat_llm_config
else chat_llms
)
embedder = EmbedderFactory.from_config(embedder_config)

plugin_context = build_plugin_context(
Expand Down Expand Up @@ -317,6 +323,7 @@ def init_server() -> dict[str, Any]:
"mem_reader": mem_reader,
"llm": llm,
"chat_llms": chat_llms,
"playground_chat_llms": playground_chat_llms,
"embedder": embedder,
"reranker": reranker,
"internet_retriever": internet_retriever,
Expand Down
7 changes: 5 additions & 2 deletions src/memos/api/handlers/config_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,17 @@ def build_llm_config() -> dict[str, Any]:
)


def build_chat_llm_config() -> list[dict[str, Any]]:
def build_chat_llm_config(env_name: str = "CHAT_MODEL_LIST") -> list[dict[str, Any]]:
"""
Build chat LLM configuration.

Returns:
Validated chat LLM configuration dictionary
Args:
env_name: Environment variable that contains the JSON chat model list.

"""
configs = json.loads(os.getenv("CHAT_MODEL_LIST", "[]"))
configs = json.loads(os.getenv(env_name, "[]"))
return [
{
"config_class": LLMConfigFactory.model_validate(
Expand Down
9 changes: 5 additions & 4 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@
add_handler = AddHandler(dependencies)
chat_handler = (
ChatHandler(
dependencies,
components["chat_llms"],
search_handler,
add_handler,
dependencies=dependencies,
chat_llms=components["chat_llms"],
playground_chat_llms=components.get("playground_chat_llms"),
search_handler=search_handler,
add_handler=add_handler,
online_bot=components.get("online_bot"),
)
if os.getenv("ENABLE_CHAT_API", "false") == "true"
Expand Down
5 changes: 5 additions & 0 deletions src/memos/mem_scheduler/base_mixins/queue_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from memos.context.context import (
ContextThread,
RequestContext,
get_current_api_path,
get_current_context,
get_current_trace_id,
set_request_context,
Expand Down Expand Up @@ -38,13 +39,16 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
return

current_trace_id = get_current_trace_id()
current_api_path = get_current_api_path()

immediate_msgs: list[ScheduleMessageItem] = []
queued_msgs: list[ScheduleMessageItem] = []

for msg in messages:
if current_trace_id:
msg.trace_id = current_trace_id
if current_api_path and not getattr(msg, "api_path", None):
msg.api_path = current_api_path

with suppress(Exception):
self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label)
Expand Down Expand Up @@ -173,6 +177,7 @@ def _message_consumer(self) -> None:
try:
msg_context = RequestContext(
trace_id=msg.trace_id,
api_path=msg.api_path,
user_name=msg.user_name,
)
set_request_context(msg_context)
Expand Down
4 changes: 4 additions & 0 deletions src/memos/mem_scheduler/base_mixins/web_log_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from memos.log import get_logger
from memos.context.context import get_current_api_path
from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
from memos.mem_scheduler.schemas.task_schemas import (
ADD_TASK_LABEL,
Expand Down Expand Up @@ -28,6 +29,9 @@ def _submit_web_logs(
if self.rabbitmq_config is None:
return
try:
current_api_path = get_current_api_path()
if current_api_path and not getattr(message, "api_path", None):
message.api_path = current_api_path
logger.info(
"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish %s",
message.model_dump_json(indent=2),
Expand Down
2 changes: 2 additions & 0 deletions src/memos/mem_scheduler/general_modules/scheduler_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable

from memos.log import get_logger
from memos.context.context import get_current_api_path
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
Expand Down Expand Up @@ -125,6 +126,7 @@ def create_autofilled_log_item(
log_content=log_content,
current_memory_sizes=current_memory_sizes,
memory_capacities=memory_capacities,
api_path=get_current_api_path(),
)
return log_message

Expand Down
4 changes: 4 additions & 0 deletions src/memos/mem_scheduler/schemas/message_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
description="user name / display name (optional)",
)
info: dict | None = Field(default=None, description="user custom info")
api_path: str | None = Field(default=None, description="source HTTP API path")
task_id: str | None = Field(
default=None,
description="Optional business-level task ID. Multiple items can share the same task_id.",
Expand Down Expand Up @@ -94,6 +95,7 @@ def to_dict(self) -> dict:
"timestamp": self.timestamp.isoformat(),
"user_name": self.user_name,
"task_id": self.task_id if self.task_id is not None else "",
"api_path": self.api_path if self.api_path is not None else "",
"chat_history": self.chat_history if self.chat_history is not None else [],
"user_context": self.user_context.model_dump(exclude_none=True)
if self.user_context
Expand Down Expand Up @@ -152,6 +154,7 @@ def _decode(val: Any) -> Any:
timestamp=timestamp,
user_name=_decode(data.get("user_name")),
task_id=_decode(data.get("task_id")),
api_path=_decode(data.get("api_path")),
chat_history=chat_history,
user_context=UserContext.model_validate(raw_user_context) if raw_user_context else None,
)
Expand Down Expand Up @@ -209,6 +212,7 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin):
)
source_doc_id: str | None = Field(default=None, description="Source document ID")
chat_history: list | None = Field(default=None, description="user chat history")
api_path: str | None = Field(default=None, description="source HTTP API path")

def debug_info(self) -> dict[str, Any]:
"""Return structured debug information for logging purposes."""
Expand Down
8 changes: 5 additions & 3 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_playground_api
from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

Expand Down Expand Up @@ -140,6 +140,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
# Propagate trace_id and user info to logging context for this handler execution
ctx = RequestContext(
trace_id=trace_id,
api_path=getattr(first_msg, "api_path", None),
user_name=getattr(first_msg, "user_name", None),
user_type=None,
)
Expand Down Expand Up @@ -317,8 +318,7 @@ def _maybe_emit_task_completion(
mem_cube_id = first.mem_cube_id

try:
cloud_env = is_cloud_env()
if not cloud_env:
if is_playground_api():
return

for task_id in task_ids:
Expand All @@ -345,6 +345,7 @@ def _maybe_emit_task_completion(
log_content=f"Task {task_id} completed",
status="completed",
source_doc_id=source_doc_id,
api_path=getattr(messages[0], "api_path", None) if messages else None,
)
self.submit_web_logs(event)

Expand All @@ -369,6 +370,7 @@ def _maybe_emit_task_completion(
log_content=f"Task {task_id} failed: {error_msg}",
status="failed",
source_doc_id=source_doc_id,
api_path=getattr(messages[0], "api_path", None) if messages else None,
)
self.submit_web_logs(event)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
from memos.mem_scheduler.utils.misc_utils import is_cloud_env
from memos.mem_scheduler.utils.misc_utils import is_playground_api


if TYPE_CHECKING:
Expand All @@ -38,14 +38,14 @@ def batch_handler(
prepared_add_items,
prepared_update_items_with_original,
)
cloud_env = is_cloud_env()
playground_api = is_playground_api()

if cloud_env:
self.send_add_log_messages_to_cloud_env(
if playground_api:
self.send_add_log_messages_to_local_env(
msg, prepared_add_items, prepared_update_items_with_original
)
else:
self.send_add_log_messages_to_local_env(
self.send_add_log_messages_to_memory_change(
msg, prepared_add_items, prepared_update_items_with_original
)

Expand Down Expand Up @@ -231,10 +231,10 @@ def send_add_log_messages_to_local_env(
logger.info("send_add_log_messages_to_local_env: %s", len(events))
if events:
self.scheduler_context.services.submit_web_logs(
events, additional_log_info="send_add_log_messages_to_cloud_env"
events, additional_log_info="send_add_log_messages_to_local_env"
)

def send_add_log_messages_to_cloud_env(
def send_add_log_messages_to_memory_change(
self,
msg: ScheduleMessageItem,
prepared_add_items,
Expand Down Expand Up @@ -278,7 +278,7 @@ def send_add_log_messages_to_cloud_env(

if kb_log_content:
logger.info(
"[DIAGNOSTIC] add_handler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
"[DIAGNOSTIC] add_handler.send_add_log_messages_to_memory_change: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
msg.user_id,
msg.mem_cube_id,
msg.task_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
USER_INPUT_TYPE,
)
from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
from memos.mem_scheduler.utils.misc_utils import is_cloud_env
from memos.mem_scheduler.utils.misc_utils import is_playground_api


logger = get_logger(__name__)
Expand Down Expand Up @@ -75,8 +75,8 @@ def process_single_feedback(self, message: ScheduleMessageItem) -> None:
mem_cube_id,
)

cloud_env = is_cloud_env()
if cloud_env:
playground_api = is_playground_api()
if not playground_api:
record = feedback_result.get("record") if isinstance(feedback_result, dict) else {}
add_records = record.get("add") if isinstance(record, dict) else []
update_records = record.get("update") if isinstance(record, dict) else []
Expand Down Expand Up @@ -191,6 +191,7 @@ def _extract_fields(mem_item):
)
else:
logger.info(
"Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
cloud_env,
"Skipping memory-change web log for feedback on playground API "
"(is_playground_api=%s)",
playground_api,
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
from memos.mem_scheduler.utils.misc_utils import is_cloud_env
from memos.mem_scheduler.utils.misc_utils import is_playground_api
from memos.memories.textual.tree import TreeTextMemory


Expand Down Expand Up @@ -268,8 +268,8 @@ def _process_memories_with_reader(
"[Scheduler] merged_from provided but graph_db is unavailable; skip archiving."
)

cloud_env = is_cloud_env()
if cloud_env:
playground_api = is_playground_api()
if not playground_api:
kb_log_content = []
for item in flattened_memories:
metadata = getattr(item, "metadata", None)
Expand Down Expand Up @@ -448,8 +448,8 @@ def _process_memories_with_reader(
exc_info=True,
)
with contextlib.suppress(Exception):
cloud_env = is_cloud_env()
if cloud_env:
playground_api = is_playground_api()
if not playground_api:
if not kb_log_content:
trigger_source = (
info.get("trigger_source", "Messages") if info else "Messages"
Expand Down
5 changes: 4 additions & 1 deletion src/memos/mem_scheduler/task_schedule_modules/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
the local memos_message_queue functionality in BaseScheduler.
"""

from memos.context.context import get_current_trace_id
from memos.context.context import get_current_api_path, get_current_trace_id
from memos.log import get_logger
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
Expand Down Expand Up @@ -104,11 +104,14 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
return

current_trace_id = get_current_trace_id()
current_api_path = get_current_api_path()

for msg in messages:
if current_trace_id:
# Prefer current request trace_id so logs can be correlated
msg.trace_id = current_trace_id
if current_api_path and not getattr(msg, "api_path", None):
msg.api_path = current_api_path
msg.stream_key = self.memos_message_queue.get_stream_key(
user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, task_label=msg.label
)
Expand Down
Loading
Loading