Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ llama-index-llms-ollama = "^0.9.1"
llama-index-llms-openai = "^0.6.12"
llama-index-llms-openai-like = "^0.5.3"
llama-index-llms-perplexity = "^0.4.1"
litellm = ">=1.55,<2.0"
llama-index-multi-modal-llms-openai = "^0.6.1"
llama-index-vector-stores-chroma = "^0.5.3"
llama-index-vector-stores-elasticsearch = "0.5.1"
Expand Down
2 changes: 2 additions & 0 deletions src/pygpt_net/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def run(**kwargs):
from pygpt_net.provider.llms.perplexity import PerplexityLLM
from pygpt_net.provider.llms.x_ai import xAILLM
from pygpt_net.provider.llms.open_router import OpenRouterLLM
from pygpt_net.provider.llms.litellm import LiteLLMProvider

# vector store providers (llama-index)
from pygpt_net.provider.vector_stores.chroma import ChromaProvider
Expand Down Expand Up @@ -487,6 +488,7 @@ def run(**kwargs):
launcher.add_llm(PerplexityLLM())
launcher.add_llm(xAILLM())
launcher.add_llm(OpenRouterLLM())
launcher.add_llm(LiteLLMProvider())

# register LLMs
llms = kwargs.get('llms', None)
Expand Down
135 changes: 135 additions & 0 deletions src/pygpt_net/provider/llms/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ================================================== #
# This file is a part of PYGPT package #
# Website: https://pygpt.net #
# GitHub: https://github.com/szczyglis-dev/py-gpt #
# MIT License #
# Created By : RheagalFire #
# Updated Date: 2026.04.24 00:00:00 #
# ================================================== #

from typing import Any, List, Dict, Optional, Sequence

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms import (
CompletionResponse,
CompletionResponseGen,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core.llms.llm import BaseLLM as LlamaBaseLLM

from pygpt_net.core.types import MODE_LLAMA_INDEX
from pygpt_net.item.model import ModelItem
from pygpt_net.provider.llms.base import BaseLLM


class LiteLLMIndex(CustomLLM):
"""LlamaIndex CustomLLM that routes to 100+ providers via litellm.completion()."""

model_name: str = "openai/gpt-4o-mini"
temperature: float = 0.7
max_tokens: int = 1024
api_key: Optional[str] = None
api_base: Optional[str] = None

@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
model_name=self.model_name,
num_output=self.max_tokens,
)

@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
import litellm

completion_kwargs: Dict[str, Any] = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": self.max_tokens,
# drop_params silently drops provider-unsupported kwargs
# to prevent cross-provider errors
"drop_params": True,
}
if self.api_key:
completion_kwargs["api_key"] = self.api_key
if self.api_base:
completion_kwargs["api_base"] = self.api_base

response = litellm.completion(**completion_kwargs)
text = response.choices[0].message.content or ""
return CompletionResponse(text=text, raw=response.model_dump())

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
import litellm

completion_kwargs: Dict[str, Any] = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"stream": True,
"drop_params": True,
}
if self.api_key:
completion_kwargs["api_key"] = self.api_key
if self.api_base:
completion_kwargs["api_base"] = self.api_base

def gen() -> CompletionResponseGen:
text = ""
stream = litellm.completion(**completion_kwargs)
for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
content = getattr(delta, "content", "") or ""
text += content
yield CompletionResponse(
delta=content, text=text, raw=chunk.model_dump()
)

return gen()


class LiteLLMProvider(BaseLLM):
"""PyGPT LLM provider that routes to 100+ providers via LiteLLM."""

def __init__(self, *args, **kwargs):
super(LiteLLMProvider, self).__init__(*args, **kwargs)
self.id = "litellm"
self.name = "LiteLLM"
self.type = [MODE_LLAMA_INDEX]

def llama(
self,
window,
model: ModelItem,
stream: bool = False
) -> LlamaBaseLLM:
"""
Return LLM provider instance for llama

:param window: window instance
:param model: model instance
:param stream: stream mode
:return: LLM provider instance
"""
args = self.parse_args(model.llama_index, window)
model_name = args.pop("model", model.id)
temperature = float(args.pop("temperature", 0.7))
max_tokens = int(args.pop("max_tokens", 1024))
api_key = args.pop("api_key", "")
api_base = args.pop("api_base", "")
return LiteLLMIndex(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
api_key=api_key or None,
api_base=api_base or None,
)