Skip to content
11 changes: 10 additions & 1 deletion pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,16 @@

# Claude model identifiers that support extended thinking.
# Each entry needs its own trailing comma: adjacent string literals without one
# are silently concatenated into a single (wrong) model name.
CLAUDE_EXTENDED_THINKING_MODELS = [
    "anthropic/claude-3-7-sonnet-20250219",
    "claude-3-7-sonnet-20250219",
    "anthropic/claude-sonnet-4-6",
    "claude-sonnet-4-6",
    "vertex_ai/claude-sonnet-4-6",
    "bedrock/anthropic.claude-sonnet-4-6",
    "bedrock/us.anthropic.claude-sonnet-4-6",
    "bedrock/au.anthropic.claude-sonnet-4-6",
    "bedrock/eu.anthropic.claude-sonnet-4-6",
    "bedrock/jp.anthropic.claude-sonnet-4-6",
    "bedrock/global.anthropic.claude-sonnet-4-6",
]

# Models that require streaming mode
Expand Down
18 changes: 16 additions & 2 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,22 @@ def __init__(self):
# Models that support reasoning effort
self.support_reasoning_models = SUPPORT_REASONING_EFFORT_MODELS

# Models that support extended thinking
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
# Models that support extended thinking (config override replaces the built-in list when non-empty)
override = get_settings().config.get("claude_extended_thinking_models_override", []) or []
if override and not isinstance(override, list):
get_logger().warning(
"Invalid claude_extended_thinking_models_override in config; expected a list of model names. "
"Falling back to the built-in Claude extended-thinking model list."
)
override = []
elif override and not all(isinstance(model, str) and model.strip() for model in override):
get_logger().warning(
"Invalid claude_extended_thinking_models_override in config; "
"expected a list of model name strings. "
"Falling back to the built-in Claude extended-thinking model list."
)
override = []
self.claude_extended_thinking_models = list(override) if override else CLAUDE_EXTENDED_THINKING_MODELS

# Models that require streaming
self.streaming_required_models = STREAMING_REQUIRED_MODELS
Expand Down
265 changes: 265 additions & 0 deletions pr_agent/algo/repo_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
import time
from collections import OrderedDict
from html import escape

from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.log import get_logger

# Line written in place of the file content that did not fit in the line budget.
TRUNCATION_MARKER = "...(truncated)..."
# First line of every rendered instruction-files block.
INSTRUCTION_FILES_INTRO = (
    "You are being given instruction files. Follow them as project-specific guidance when reviewing code."
)
# Shortest fence placed around file content; it is lengthened when the content already contains it.
MARKDOWN_FENCE = "`````"
# Attribute name under which a per-provider cache is stored on the git provider object.
REPO_CONTEXT_CACHE_ATTRIBUTE = "_repo_context_cache"
# Default cache capacity (entries) and entry lifetime (15 minutes, in seconds).
REPO_CONTEXT_CACHE_MAX_SIZE = 256
REPO_CONTEXT_CACHE_TTL_SECONDS = 15 * 60
# Sentinel that distinguishes "nothing cached" from a cached empty string.
_REPO_CONTEXT_CACHE_MISS = object()
# Provider classes for which the "does not support repository file fetching" warning was already logged.
_unsupported_repo_context_provider_classes = set()


class _RepoContextCache:
    """Size-bounded mapping whose entries expire after a fixed time-to-live.

    Keys are kept in least-recently-used order: a successful read or any write
    moves the key to the end, and the oldest keys are evicted once the number
    of entries exceeds ``max_size``. Expiry is measured with
    ``time.monotonic()``.
    """

    def __init__(
        self,
        max_size: int = REPO_CONTEXT_CACHE_MAX_SIZE,
        ttl_seconds: int = REPO_CONTEXT_CACHE_TTL_SECONDS,
    ):
        # Clamp so the cache always holds at least one entry and never has a negative lifetime.
        self._max_size = max(1, int(max_size))
        self._ttl_seconds = max(0, int(ttl_seconds))
        # Maps key -> (value, monotonic deadline).
        self._entries = OrderedDict()

    def copy(self):
        """Return a new cache with the same limits and a shallow copy of the entries."""
        duplicate = type(self)(max_size=self._max_size, ttl_seconds=self._ttl_seconds)
        duplicate._entries = self._entries.copy()
        return duplicate

    def get(self, key, default=None):
        """Return the live value stored for ``key``, or ``default`` if absent or expired.

        An expired entry is removed as a side effect of the lookup.
        """
        try:
            value, deadline = self._entries[key]
        except KeyError:
            return default

        if deadline <= time.monotonic():
            del self._entries[key]
            return default

        # A hit counts as a use: make the key the most recently used one.
        self._entries.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` with a fresh deadline, evicting the oldest keys on overflow."""
        deadline = time.monotonic() + self._ttl_seconds
        self._entries[key] = (value, deadline)
        self._entries.move_to_end(key)

        overflow = len(self._entries) - self._max_size
        for _ in range(overflow):
            self._entries.popitem(last=False)


# Module-level cache shared by all provider instances in this process; its keys
# combine the provider class name, the PR URL and the per-provider cache key.
_repo_context_process_cache = _RepoContextCache()


def _get_markdown_fence(content: str) -> str:
    """Return a backtick fence that does not occur anywhere in ``content``.

    Starts from ``MARKDOWN_FENCE`` and appends one backtick at a time until the
    candidate is no longer a substring of ``content``.
    """
    extra_backticks = 0
    while MARKDOWN_FENCE + "`" * extra_backticks in content:
        extra_backticks += 1
    return MARKDOWN_FENCE + "`" * extra_backticks


def _get_repo_context_cache_key(context_files: list, max_lines: int) -> tuple[tuple[tuple[str, str], ...], int]:
return tuple((type(file_path).__name__, str(file_path)) for file_path in context_files), max_lines


def _get_repo_context_process_cache_key(git_provider, context_files: list, max_lines: int) -> tuple | None:
Comment on lines +21 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Action required

1. Long type-hinted signatures 📘 Rule violation ⚙ Maintainability

pr_agent/algo/repo_context.py introduces function signatures that likely exceed the repo’s Ruff
120-character line-length limit, which can fail CI linting. Wrap these signatures across multiple
lines (or refactor the type hints) to conform to the configured style.
Agent Prompt
## Issue description
New code adds long function signatures/type annotations that likely exceed Ruff's configured `line-length = 120`, which can cause lint/CI failures.

## Issue Context
Ruff is configured in `pyproject.toml` with `line-length = 120`, and the PR adds at least two long `def ... -> ...:` lines in the new `repo_context.py` module.

## Fix Focus Areas
- pr_agent/algo/repo_context.py[21-23]
- pr_agent/algo/repo_context.py[62-66]

ⓘ Copy this prompt and use it to remediate the issue with your preferred AI generation tools

try:
pr_url = git_provider.get_pr_url()
except Exception:
pr_url = getattr(git_provider, "pr_url", None)

if not pr_url:
return None

return type(git_provider).__name__, pr_url, _get_repo_context_cache_key(context_files, max_lines)


def _get_repo_context_config() -> tuple[list, int] | None:
    """Read and normalise the repo-context settings.

    Returns:
        ``(context_files, max_lines)``, or ``None`` when ``repo_context_files``
        is empty/unset or is neither a string nor a list. A string value is
        wrapped into a one-element list (with a warning). ``max_lines`` is
        clamped to be non-negative and falls back to 500 when the configured
        value cannot be converted to an integer.
    """
    default_max_lines = 500

    context_files = get_settings().config.get("repo_context_files", [])
    if not context_files:
        return None

    if isinstance(context_files, str):
        get_logger().warning(
            "repo_context_files should be a list of file paths; treating string value as one file path",
            artifact={"repo_context_files": context_files},
        )
        context_files = [context_files]
    elif not isinstance(context_files, list):
        get_logger().warning(
            "repo_context_files should be a list of file paths; skipping repo context",
            artifact={"repo_context_files": context_files},
        )
        return None

    max_lines = get_settings().config.get("repo_context_max_lines", default_max_lines)
    try:
        max_lines = max(0, int(max_lines))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int() of a float infinity (a valid TOML value) must not crash the run.
        max_lines = default_max_lines

    return context_files, max_lines


def _provider_supports_repo_context(git_provider) -> bool:
    """Tell whether the provider's class overrides ``GitProvider.get_repo_file_content``.

    When it does not, a warning is logged the first time that provider class is
    seen and ``False`` is returned.
    """
    provider_class = type(git_provider)
    overrides_fetch = provider_class.get_repo_file_content is not GitProvider.get_repo_file_content
    if overrides_fetch:
        return True

    already_warned = provider_class in _unsupported_repo_context_provider_classes
    if not already_warned:
        # Remember the class so the warning is not repeated on later calls.
        _unsupported_repo_context_provider_classes.add(provider_class)
        get_logger().warning(
            f"repo_context_files is configured, but {provider_class.__name__} does not support repository "
            "file fetching; skipping repo context"
        )
    return False


def _get_provider_repo_context_cache(git_provider) -> _RepoContextCache:
    """Return the cache attached to ``git_provider``, attaching a new one if needed.

    A missing attribute, ``None``, or any value that is not a
    ``_RepoContextCache`` is replaced with a fresh cache.
    """
    existing = getattr(git_provider, REPO_CONTEXT_CACHE_ATTRIBUTE, None)
    if isinstance(existing, _RepoContextCache):
        return existing

    fresh_cache = _RepoContextCache()
    setattr(git_provider, REPO_CONTEXT_CACHE_ATTRIBUTE, fresh_cache)
    return fresh_cache


def _get_cached_repo_context(git_provider, context_files: list, max_lines: int):
    """Look up a previously rendered repo context.

    The process-wide cache is consulted first (only when a process cache key
    can be built), then the cache attached to the provider. Returns the cached
    string, or ``_REPO_CONTEXT_CACHE_MISS`` when neither cache has a live entry.
    """
    process_key = _get_repo_context_process_cache_key(git_provider, context_files, max_lines)
    if process_key is not None:
        process_hit = _repo_context_process_cache.get(process_key, _REPO_CONTEXT_CACHE_MISS)
        if process_hit is not _REPO_CONTEXT_CACHE_MISS:
            return process_hit

    provider_key = _get_repo_context_cache_key(context_files, max_lines)
    provider_cache = _get_provider_repo_context_cache(git_provider)
    # The lookup itself yields the miss sentinel when nothing is cached.
    return provider_cache.get(provider_key, _REPO_CONTEXT_CACHE_MISS)


def _store_repo_context(git_provider, context_files: list, max_lines: int, repo_context: str) -> None:
    """Save ``repo_context`` in the provider's cache and, when possible, the process-wide cache.

    The process-wide cache is only written when a process cache key can be
    built for the provider.
    """
    provider_key = _get_repo_context_cache_key(context_files, max_lines)
    provider_cache = _get_provider_repo_context_cache(git_provider)
    provider_cache[provider_key] = repo_context

    process_key = _get_repo_context_process_cache_key(git_provider, context_files, max_lines)
    if process_key is not None:
        _repo_context_process_cache[process_key] = repo_context


def _load_repo_context_files(git_provider, context_files: list) -> tuple[dict[str, str], bool]:
files = {}
had_fetch_error = False
for file_path in context_files:
if not isinstance(file_path, str) or not file_path.strip():
get_logger().warning("Skipping invalid repo context file path", artifact={"file_path": file_path})
continue

file_path = file_path.strip()
try:
content = git_provider.get_repo_file_content(file_path)
except Exception as e:
had_fetch_error = True
get_logger().warning(f"Failed to load repo context file: {file_path}", artifact={"error": str(e)})
continue

if not content:
get_logger().debug(f"Repo context file is empty or missing: {file_path}")
continue

if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")

files[file_path] = str(content).rstrip()

return files, had_fetch_error


def render_instruction_files(files: dict[str, str]) -> str:
    """Render all files as an ``<instruction_files>`` block without any line limit.

    Each file becomes a ``<file path=... scope=...>`` element whose content is
    wrapped in a markdown fence that does not occur in the content. ``scope``
    is the directory part of the path, or ``repo-root`` for paths without a
    ``/``. Attribute values are HTML-escaped.
    """
    rendered = [INSTRUCTION_FILES_INTRO, "<instruction_files>"]

    for path, content in files.items():
        if "/" in path:
            scope = path.rsplit("/", 1)[0]
        else:
            scope = "repo-root"
        fence = _get_markdown_fence(content)
        rendered.extend(
            [
                f'<file path="{escape(path, quote=True)}" scope="{escape(scope, quote=True)}">',
                f"{fence}markdown",
                content.rstrip(),
                fence,
                "</file>",
                "",
            ]
        )

    rendered.append("</instruction_files>")
    return "\n".join(rendered)


def render_instruction_files_with_line_budget(files: dict[str, str], max_lines: int) -> str:
    """Render files as an ``<instruction_files>`` block of at most ``max_lines`` lines.

    Files are added in order. A file whose content does not fit completely is
    cut and terminated with ``TRUNCATION_MARKER``, after which no further files
    are added. A file is left out entirely when not even its wrapper lines
    (plus one content line, if it has content) fit.

    Returns:
        The rendered block, or ``""`` when ``max_lines`` cannot hold the intro,
        the opening tag and the closing tag. If no file fits, the result holds
        only those three lines.
    """
    rendered = [INSTRUCTION_FILES_INTRO, "<instruction_files>"]
    if max_lines < len(rendered) + 1:
        return ""

    for path, content in files.items():
        scope = path.rsplit("/", 1)[0] if "/" in path else "repo-root"
        fence = _get_markdown_fence(content)
        header = [
            f'<file path="{escape(path, quote=True)}" scope="{escape(scope, quote=True)}">',
            f"{fence}markdown",
        ]
        footer = [fence, "</file>", ""]
        body_lines = content.rstrip().splitlines()

        # Lines left for content once the file wrapper and the final closing tag are reserved.
        wrapper_and_closing = len(header) + len(footer) + 1
        budget = max_lines - len(rendered) - wrapper_and_closing
        cannot_fit = budget < 0 or (bool(body_lines) and budget < 1)
        if cannot_fit:
            break

        truncated = budget < len(body_lines)
        if truncated:
            # The marker itself occupies the last line of the budget.
            body_lines = body_lines[: max(budget - 1, 0)] + [TRUNCATION_MARKER]

        rendered += header + body_lines + footer
        if truncated:
            break

    rendered.append("</instruction_files>")
    return "\n".join(rendered).strip()


def build_repo_context(git_provider) -> str:
    """Return the rendered instruction-files context for ``git_provider``.

    Returns ``""`` when no context files are configured, the provider does not
    support file fetching, or no file could be loaded. Results, including an
    empty result caused by missing files, are cached; an empty result that
    coincides with a fetch error is not cached.
    """
    config = _get_repo_context_config()
    if config is None:
        return ""
    context_files, max_lines = config

    if not _provider_supports_repo_context(git_provider):
        return ""

    cached = _get_cached_repo_context(git_provider, context_files, max_lines)
    if cached is not _REPO_CONTEXT_CACHE_MISS:
        return cached

    files, had_fetch_error = _load_repo_context_files(git_provider, context_files)
    if files:
        repo_context = render_instruction_files_with_line_budget(files, max_lines)
    elif had_fetch_error:
        # Nothing loaded because of an error: skip caching so a later call fetches again.
        return ""
    else:
        repo_context = ""

    _store_repo_context(git_provider, context_files, max_lines, repo_context)
    return repo_context
3 changes: 3 additions & 0 deletions pr_agent/git_providers/git_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
def get_repo_settings(self):
    """Base implementation: does nothing and returns ``None``."""
    pass

def get_repo_file_content(self, file_path: str):
    """Base implementation: ignore ``file_path`` and return an empty string."""
    return ""

def get_workspace_name(self):
    """Base implementation: return an empty string."""
    return ""

Expand Down
30 changes: 30 additions & 0 deletions pr_agent/git_providers/gitea_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(self, url: Optional[str] = None):
self.file_contents = {}
self.file_diffs = {}
self.sha = None
self.base_sha = ""
self.base_ref = ""
self.diff_files = []
self.incremental = IncrementalPR(False)
self.comments_list = []
Expand Down Expand Up @@ -738,6 +740,34 @@ def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None:
clone_url += f"{gitea_token}@{base_url}{repo_full_name}"
return clone_url

def get_repo_file_content(self, file_path: str) -> str:
    """Return the text of ``file_path`` from the repository.

    The ref is the first non-empty value of ``self.base_sha``,
    ``self.base_ref`` and ``self.sha``, so the PR target is preferred and the
    head commit is the fallback.

    Args:
        file_path: Path of the file inside the repository.

    Returns:
        The file content as ``str``. An empty string is returned when owner,
        repo or ref is not set, when the API returns nothing, or when the
        lookup raises. Bytes are decoded as UTF-8 with replacement characters.
    """
    try:
        if not self.owner or not self.repo:
            self.logger.warning("Cannot get repo file content: owner or repo not set")
            return ""

        ref = self.base_sha or self.base_ref or self.sha
        if not ref:
            self.logger.warning("Cannot get repo file content: ref not set")
            return ""

        content = self.repo_api.get_file_content(
            owner=self.owner,
            repo=self.repo,
            commit_sha=ref,
            filepath=file_path,
        )
        if isinstance(content, bytes):
            return content.decode("utf-8", errors="replace")
        # Honour the declared ``str`` return type even if the API yields None.
        return content or ""
    except Exception as e:
        self.logger.debug(f"Failed to load repo file: {file_path}, error: {e}")
        return ""
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

class RepoApi(giteapy.RepositoryApi):
def __init__(self, client: giteapy.ApiClient):
self.repository = giteapy.RepositoryApi(client)
Expand Down
13 changes: 12 additions & 1 deletion pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,18 @@ def get_repo_settings(self):
# more logical to take 'pr_agent.toml' from the default branch
contents = self.repo_obj.get_contents(".pr_agent.toml").decoded_content
return contents
except Exception:
except Exception as e:
get_logger().warning(f"Failed to load .pr_agent.toml file, error: {e}")
return ""

def get_repo_file_content(self, file_path: str):
    """Return the content of ``file_path`` via ``self.repo_obj.get_contents``.

    Bytes are decoded as UTF-8 with replacement characters; any other value is
    returned unchanged. If the lookup raises, a warning is logged and an empty
    string is returned.
    """
    try:
        raw = self.repo_obj.get_contents(file_path).decoded_content
        return raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else raw
    except Exception as e:
        get_logger().warning(f"Failed to load repo file: {file_path}, error: {e}")
        return ""

def get_workspace_name(self):
Expand Down
Loading
Loading