diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index b55b58a77a..95a51636d2 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -306,7 +306,16 @@ CLAUDE_EXTENDED_THINKING_MODELS = [ "anthropic/claude-3-7-sonnet-20250219", - "claude-3-7-sonnet-20250219" + "claude-3-7-sonnet-20250219", + "anthropic/claude-sonnet-4-6", + "claude-sonnet-4-6", + "vertex_ai/claude-sonnet-4-6", + "bedrock/anthropic.claude-sonnet-4-6", + "bedrock/us.anthropic.claude-sonnet-4-6", + "bedrock/au.anthropic.claude-sonnet-4-6", + "bedrock/eu.anthropic.claude-sonnet-4-6", + "bedrock/jp.anthropic.claude-sonnet-4-6", + "bedrock/global.anthropic.claude-sonnet-4-6", ] # Models that require streaming mode diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index a6e79d7a07..bad50baf7f 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -234,8 +234,22 @@ def __init__(self): # Models that support reasoning effort self.support_reasoning_models = SUPPORT_REASONING_EFFORT_MODELS - # Models that support extended thinking - self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS + # Models that support extended thinking (config override replaces the built-in list when non-empty) + override = get_settings().config.get("claude_extended_thinking_models_override", []) or [] + if override and not isinstance(override, list): + get_logger().warning( + "Invalid claude_extended_thinking_models_override in config; expected a list of model names. " + "Falling back to the built-in Claude extended-thinking model list." + ) + override = [] + elif override and not all(isinstance(model, str) and model.strip() for model in override): + get_logger().warning( + "Invalid claude_extended_thinking_models_override in config; " + "expected a list of model name strings. " + "Falling back to the built-in Claude extended-thinking model list." + ) + override = [] + self.claude_extended_thinking_models = list(override) if override else CLAUDE_EXTENDED_THINKING_MODELS # Models that require streaming self.streaming_required_models = STREAMING_REQUIRED_MODELS diff --git a/pr_agent/algo/repo_context.py b/pr_agent/algo/repo_context.py new file mode 100644 index 0000000000..45b5b9c2a4 --- /dev/null +++ b/pr_agent/algo/repo_context.py @@ -0,0 +1,265 @@ +import time +from collections import OrderedDict +from html import escape + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers.git_provider import GitProvider +from pr_agent.log import get_logger + +TRUNCATION_MARKER = "...(truncated)..." +INSTRUCTION_FILES_INTRO = ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code." +) +MARKDOWN_FENCE = "`````" +REPO_CONTEXT_CACHE_ATTRIBUTE = "_repo_context_cache" +REPO_CONTEXT_CACHE_MAX_SIZE = 256 +REPO_CONTEXT_CACHE_TTL_SECONDS = 15 * 60 +_REPO_CONTEXT_CACHE_MISS = object() +_unsupported_repo_context_provider_classes = set() + + +class _RepoContextCache: + def __init__(self, max_size: int = REPO_CONTEXT_CACHE_MAX_SIZE, ttl_seconds: int = REPO_CONTEXT_CACHE_TTL_SECONDS): + self._max_size = max(1, int(max_size)) + self._ttl_seconds = max(0, int(ttl_seconds)) + self._entries = OrderedDict() + + def copy(self): + cache = type(self)(max_size=self._max_size, ttl_seconds=self._ttl_seconds) + cache._entries = self._entries.copy() + return cache + + def get(self, key, default=None): + entry = self._entries.get(key) + if entry is None: + return default + + value, expires_at = entry + if expires_at <= time.monotonic(): + del self._entries[key] + return default + + self._entries.move_to_end(key) + return value + + def __setitem__(self, key, value): + self._entries[key] = (value, time.monotonic() + self._ttl_seconds) + self._entries.move_to_end(key) + while len(self._entries) > self._max_size: + self._entries.popitem(last=False) + + +_repo_context_process_cache = _RepoContextCache() + + +def _get_markdown_fence(content: str) -> str: + fence = MARKDOWN_FENCE + while fence in content: + fence += "`" + return fence + + +def _get_repo_context_cache_key(context_files: list, max_lines: int) -> tuple[tuple[tuple[str, str], ...], int]: + return tuple((type(file_path).__name__, str(file_path)) for file_path in context_files), max_lines + + +def _get_repo_context_process_cache_key(git_provider, context_files: list, max_lines: int) -> tuple | None: + try: + pr_url = git_provider.get_pr_url() + except Exception: + pr_url = getattr(git_provider, "pr_url", None) + + if not pr_url: + return None + + return type(git_provider).__name__, pr_url, _get_repo_context_cache_key(context_files, max_lines) + + +def _get_repo_context_config() -> tuple[list, int] | None: + context_files = get_settings().config.get("repo_context_files", []) + if not context_files: + return None + + if isinstance(context_files, str): + get_logger().warning( + "repo_context_files should be a list of file paths; treating string value as one file path", + artifact={"repo_context_files": context_files}, + ) + context_files = [context_files] + elif not isinstance(context_files, list): + get_logger().warning( + "repo_context_files should be a list of file paths; skipping repo context", + artifact={"repo_context_files": context_files}, + ) + return None + + max_lines = get_settings().config.get("repo_context_max_lines", 500) + try: + max_lines = max(0, int(max_lines)) + except (TypeError, ValueError): + max_lines = 500 + + return context_files, max_lines + + +def _provider_supports_repo_context(git_provider) -> bool: + provider_class = type(git_provider) + if provider_class.get_repo_file_content is not GitProvider.get_repo_file_content: + return True + + if provider_class not in _unsupported_repo_context_provider_classes: + _unsupported_repo_context_provider_classes.add(provider_class) + get_logger().warning( + f"repo_context_files is configured, but {provider_class.__name__} does not support repository " + "file fetching; skipping repo context" + ) + return False + + +def _get_provider_repo_context_cache(git_provider) -> _RepoContextCache: + repo_context_cache = getattr(git_provider, REPO_CONTEXT_CACHE_ATTRIBUTE, None) + if repo_context_cache is None or not isinstance(repo_context_cache, _RepoContextCache): + repo_context_cache = _RepoContextCache() + setattr(git_provider, REPO_CONTEXT_CACHE_ATTRIBUTE, repo_context_cache) + return repo_context_cache + + +def _get_cached_repo_context(git_provider, context_files: list, max_lines: int): + process_cache_key = _get_repo_context_process_cache_key(git_provider, context_files, max_lines) + if process_cache_key is not None: + cached_repo_context = _repo_context_process_cache.get(process_cache_key, _REPO_CONTEXT_CACHE_MISS) + if cached_repo_context is not _REPO_CONTEXT_CACHE_MISS: + return cached_repo_context + + cache_key = _get_repo_context_cache_key(context_files, max_lines) + cached_repo_context = _get_provider_repo_context_cache(git_provider).get(cache_key, _REPO_CONTEXT_CACHE_MISS) + if cached_repo_context is not _REPO_CONTEXT_CACHE_MISS: + return cached_repo_context + + return _REPO_CONTEXT_CACHE_MISS + + +def _store_repo_context(git_provider, context_files: list, max_lines: int, repo_context: str) -> None: + cache_key = _get_repo_context_cache_key(context_files, max_lines) + _get_provider_repo_context_cache(git_provider)[cache_key] = repo_context + + process_cache_key = _get_repo_context_process_cache_key(git_provider, context_files, max_lines) + if process_cache_key: + _repo_context_process_cache[process_cache_key] = repo_context + + +def _load_repo_context_files(git_provider, context_files: list) -> tuple[dict[str, str], bool]: + files = {} + had_fetch_error = False + for file_path in context_files: + if not isinstance(file_path, str) or not file_path.strip(): + get_logger().warning("Skipping invalid repo context file path", artifact={"file_path": file_path}) + continue + + file_path = file_path.strip() + try: + content = git_provider.get_repo_file_content(file_path) + except Exception as e: + had_fetch_error = True + get_logger().warning(f"Failed to load repo context file: {file_path}", artifact={"error": str(e)}) + continue + + if not content: + get_logger().debug(f"Repo context file is empty or missing: {file_path}") + continue + + if isinstance(content, bytes): + content = content.decode("utf-8", errors="replace") + + files[file_path] = str(content).rstrip() + + return files, had_fetch_error + + +def render_instruction_files(files: dict[str, str]) -> str: + parts = [ + INSTRUCTION_FILES_INTRO, + "", + ] + + for path, content in files.items(): + scope = path.rsplit("/", 1)[0] if "/" in path else "repo-root" + fence = _get_markdown_fence(content) + parts.append(f'') + parts.append(f"{fence}markdown") + parts.append(content.rstrip()) + parts.append(fence) + parts.append("") + parts.append("") + + parts.append("") + return "\n".join(parts) + + +def render_instruction_files_with_line_budget(files: dict[str, str], max_lines: int) -> str: + parts = [ + INSTRUCTION_FILES_INTRO, + "", + ] + closing_tag = "" + if max_lines < len(parts) + 1: + return "" + + for path, content in files.items(): + scope = path.rsplit("/", 1)[0] if "/" in path else "repo-root" + fence = _get_markdown_fence(content) + file_header = [ + f'', + f"{fence}markdown", + ] + file_footer = [ + fence, + "", + "", + ] + content_lines = content.rstrip().splitlines() + reserved_file_and_closing_lines = len(file_header) + len(file_footer) + 1 + available_content_lines = max_lines - len(parts) - reserved_file_and_closing_lines + if available_content_lines < 0 or (content_lines and available_content_lines < 1): + break + + parts.extend(file_header) + if available_content_lines >= len(content_lines): + parts.extend(content_lines) + else: + if available_content_lines > 1: + parts.extend(content_lines[: available_content_lines - 1]) + parts.append(TRUNCATION_MARKER) + parts.extend(file_footer) + break + + parts.extend(file_footer) + + parts.append(closing_tag) + return "\n".join(parts).strip() + + +def build_repo_context(git_provider) -> str: + repo_context_config = _get_repo_context_config() + if repo_context_config is None: + return "" + + context_files, max_lines = repo_context_config + if not _provider_supports_repo_context(git_provider): + return "" + + cached_repo_context = _get_cached_repo_context(git_provider, context_files, max_lines) + if cached_repo_context is not _REPO_CONTEXT_CACHE_MISS: + return cached_repo_context + + files, had_fetch_error = _load_repo_context_files(git_provider, context_files) + if not files and had_fetch_error: + return "" + + if not files: + _store_repo_context(git_provider, context_files, max_lines, "") + return "" + + repo_context = render_instruction_files_with_line_budget(files, max_lines) + _store_repo_context(git_provider, context_files, max_lines, repo_context) + return repo_context diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 631e189c04..e831c8f9dd 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -274,6 +274,9 @@ def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool: def get_repo_settings(self): pass + def get_repo_file_content(self, file_path: str): + return "" + def get_workspace_name(self): return "" diff --git a/pr_agent/git_providers/gitea_provider.py b/pr_agent/git_providers/gitea_provider.py index 89a6248e9b..459610eea9 100644 --- a/pr_agent/git_providers/gitea_provider.py +++ b/pr_agent/git_providers/gitea_provider.py @@ -61,6 +61,8 @@ def __init__(self, url: Optional[str] = None): self.file_contents = {} self.file_diffs = {} self.sha = None + self.base_sha = "" + self.base_ref = "" self.diff_files = [] self.incremental = IncrementalPR(False) self.comments_list = [] @@ -738,6 +740,34 @@ def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: clone_url += f"{gitea_token}@{base_url}{repo_full_name}" return clone_url + def get_repo_file_content(self, file_path: str) -> str: + """Get content of a file from the repository target branch. + + This method implements the interface required by PR #2387 repo_context feature. + It retrieves file content from the PR target when available, falling back to + the PR head commit for non-PR contexts. + """ + try: + if not self.owner or not self.repo: + self.logger.warning(f"Cannot get repo file content: owner or repo not set") + return "" + + ref = self.base_sha or self.base_ref or self.sha + if not ref: + self.logger.warning(f"Cannot get repo file content: ref not set") + return "" + + content = self.repo_api.get_file_content( + owner=self.owner, + repo=self.repo, + commit_sha=ref, + filepath=file_path + ) + return content + except Exception as e: + self.logger.debug(f"Failed to load repo file: {file_path}, error: {e}") + return "" + class RepoApi(giteapy.RepositoryApi): def __init__(self, client: giteapy.ApiClient): self.repository = giteapy.RepositoryApi(client) diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 6416486d84..a3ad693d89 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -737,7 +737,18 @@ def get_repo_settings(self): # more logical to take 'pr_agent.toml' from the default branch contents = self.repo_obj.get_contents(".pr_agent.toml").decoded_content return contents - except Exception: + except Exception as e: + get_logger().warning(f"Failed to load .pr_agent.toml file, error: {e}") + return "" + + def get_repo_file_content(self, file_path: str): + try: + contents = self.repo_obj.get_contents(file_path).decoded_content + if isinstance(contents, bytes): + return contents.decode("utf-8", errors="replace") + return contents + except Exception as e: + get_logger().warning(f"Failed to load repo file: {file_path}, error: {e}") return "" def get_workspace_name(self): diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index b3f54920d0..a13a997c4a 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -670,9 +670,8 @@ def publish_code_suggestions(self, code_suggestions: list) -> bool: target_file = None for file in diff_files: if file.filename == relevant_file: - if file.filename == relevant_file: - target_file = file - break + target_file = file + break range = relevant_lines_end - relevant_lines_start # no need to add 1 body = body.replace('```suggestion', f'```suggestion:-0+{range}') lines = target_file.head_file.splitlines() @@ -797,6 +796,17 @@ def get_repo_settings(self): except Exception: return "" + def get_repo_file_content(self, file_path: str): + try: + project = self.gl.projects.get(self.id_project) + contents = project.files.get(file_path=file_path, ref=project.default_branch).decode() + return decode_if_bytes(contents) + except GitlabGetError: + return "" + except Exception as e: + get_logger().warning(f"Failed to load repo file: {file_path}, error: {e}") + return "" + def get_workspace_name(self): return self.id_project.split('/')[0] diff --git a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml index 36b4d0dcf6..1d2e402ae2 100644 --- a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml +++ b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml @@ -82,6 +82,15 @@ Extra user-provided instructions (should be addressed with high priority): ====== {%- endif %} +{%- if repo_context %} + + +Repository context: +====== +{{ repo_context }} +====== +{%- endif %} + The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions: ===== diff --git a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml index 6178ee23c0..de4cb7de05 100644 --- a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml +++ b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml @@ -71,6 +71,15 @@ Extra user-provided instructions (should be addressed with high priority): ====== {%- endif %} +{%- if repo_context %} + + +Repository context: +====== +{{ repo_context }} +====== +{%- endif %} + The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions: ===== diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index f4d63a73f2..71510dcb89 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -27,6 +27,8 @@ ai_timeout=120 # 2 minutes skip_keys = [] custom_reasoning_model = false # when true, disables system messages and temperature controls for models that don't support chat-style inputs response_language="en-US" # Language locales code for PR responses in ISO 3166 and ISO 639 format (e.g., "en-US", "it-IT", "zh-CN", ...) +repo_context_files = [] # Explicit repository-relative files to include as AI prompt context; use a list, e.g. ["AGENTS.md"] +repo_context_max_lines = 500 # Maximum total rendered lines for repo context, including wrapper tags # token limits max_description_tokens = 500 max_commits_tokens = 500 @@ -63,6 +65,10 @@ reasoning_effort = "medium" # "none", "minimal", "low", "medium", "high", "xhigh enable_claude_extended_thinking = false # Set to true to enable extended thinking feature extended_thinking_budget_tokens = 2048 extended_thinking_max_output_tokens = 4096 +# Optional: override the built-in list of Claude models that receive the extended-thinking payload. +# When non-empty, this list fully replaces the built-in defaults (see CLAUDE_EXTENDED_THINKING_MODELS +# in pr_agent/algo/__init__.py). Leave empty to use the defaults. +claude_extended_thinking_models_override = [] # Extract issue number from PR source branch name (e.g. feature/1-auth-google -> issue #1). When true, branch-derived # issue URLs are merged with tickets from the PR description for compliance. Set to false to restore description-only behaviour. # Note: Branch-name extraction is GitHub-only for now; other providers planned for later. diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 2627401ea9..9206fcf403 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -16,6 +16,13 @@ Extra instructions from the user: ===== {% endif %} +{%- if repo_context %} + +Repository context: +===== +{{ repo_context }} +===== +{% endif %} The output must be a YAML object equivalent to type $PRDescription, according to the following Pydantic definitions: ===== diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index bbe6c6d04c..8cef301618 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -69,6 +69,15 @@ Extra instructions from the user: ====== {% endif %} +{%- if repo_context %} + + +Repository context: +====== +{{ repo_context }} +====== +{% endif %} + The output must be a YAML object equivalent to type $PRReview, according to the following Pydantic definitions: ===== diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 6372396c8b..abdf4cdcce 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -17,6 +17,7 @@ from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models) +from pr_agent.algo.repo_context import build_repo_context from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags, show_relevant_configurations, get_max_tokens, clip_tokens, get_model) @@ -67,6 +68,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None, "diff_no_line_numbers": "", # empty diff for initial calculation "num_code_suggestions": num_code_suggestions, "extra_instructions": get_settings().pr_code_suggestions.extra_instructions, + "repo_context": build_repo_context(self.git_provider), "commit_messages_str": self.git_provider.get_commit_messages(), "relevant_best_practices": "", "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 26ea5d190a..1b2135fba8 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -14,6 +14,7 @@ get_pr_diff, get_pr_diff_multiple_patchs, retry_with_fallback_models) +from pr_agent.algo.repo_context import build_repo_context from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens, get_max_tokens, get_user_labels, load_yaml, @@ -67,6 +68,7 @@ def __init__(self, pr_url: str, args: list = None, "language": self.main_pr_language, "diff": "", # empty diff for initial calculation "extra_instructions": get_settings().pr_description.extra_instructions, + "repo_context": build_repo_context(self.git_provider), "commit_messages_str": self.git_provider.get_commit_messages(), "enable_custom_labels": get_settings().config.enable_custom_labels, "custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index c4917f3597..437b8f9474 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -12,6 +12,7 @@ from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, get_pr_diff, retry_with_fallback_models) +from pr_agent.algo.repo_context import build_repo_context from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import (ModelType, PRReviewHeader, convert_to_markdown_v2, github_action_output, @@ -92,6 +93,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, 'question_str': question_str, 'answer_str': answer_str, "extra_instructions": get_settings().pr_reviewer.extra_instructions, + "repo_context": build_repo_context(self.git_provider), "commit_messages_str": self.git_provider.get_commit_messages(), "custom_labels": "", "enable_custom_labels": get_settings().config.enable_custom_labels, diff --git a/tests/unittest/test_gitea_provider.py b/tests/unittest/test_gitea_provider.py index 4174b398d0..6544d2ec8c 100644 --- a/tests/unittest/test_gitea_provider.py +++ b/tests/unittest/test_gitea_provider.py @@ -1,6 +1,8 @@ from io import BytesIO from unittest.mock import MagicMock, patch +from pr_agent.git_providers.gitea_provider import GiteaProvider + class TestGiteaProvider: @patch('pr_agent.git_providers.gitea_provider.get_settings') @@ -103,3 +105,66 @@ def call_api_side_effect(path, method, **kwargs): args, kwargs = mock_api_client.call_api.call_args assert args[0] == '/repos/owner/repo/pulls/123/commits' assert kwargs.get('auth_settings') == ['AuthorizationHeaderToken'] + + def test_get_repo_file_content_loads_from_base_sha(self): + provider = GiteaProvider.__new__(GiteaProvider) + provider.owner = "owner" + provider.repo = "repo" + provider.sha = "head-sha" + provider.base_sha = "base-sha" + provider.base_ref = "main" + provider.logger = MagicMock() + provider.repo_api = MagicMock() + provider.repo_api.get_file_content.return_value = "repo context" + + content = provider.get_repo_file_content("AGENTS.md") + + assert content == "repo context" + provider.repo_api.get_file_content.assert_called_once_with( + owner="owner", + repo="repo", + commit_sha="base-sha", + filepath="AGENTS.md" + ) + + def test_get_repo_file_content_loads_from_base_ref_when_base_sha_missing(self): + provider = GiteaProvider.__new__(GiteaProvider) + provider.owner = "owner" + provider.repo = "repo" + provider.sha = "head-sha" + provider.base_sha = "" + provider.base_ref = "main" + provider.logger = MagicMock() + provider.repo_api = MagicMock() + provider.repo_api.get_file_content.return_value = "repo context" + + content = provider.get_repo_file_content("AGENTS.md") + + assert content == "repo context" + provider.repo_api.get_file_content.assert_called_once_with( + owner="owner", + repo="repo", + commit_sha="main", + filepath="AGENTS.md" + ) + + def test_get_repo_file_content_falls_back_to_head_sha_when_base_missing(self): + provider = GiteaProvider.__new__(GiteaProvider) + provider.owner = "owner" + provider.repo = "repo" + provider.sha = "head-sha" + provider.base_sha = "" + provider.base_ref = "" + provider.logger = MagicMock() + provider.repo_api = MagicMock() + provider.repo_api.get_file_content.return_value = "repo context" + + content = provider.get_repo_file_content("AGENTS.md") + + assert content == "repo context" + provider.repo_api.get_file_content.assert_called_once_with( + owner="owner", + repo="repo", + commit_sha="head-sha", + filepath="AGENTS.md" + ) diff --git a/tests/unittest/test_gitlab_provider.py b/tests/unittest/test_gitlab_provider.py index c3864264d8..951b83c540 100644 --- a/tests/unittest/test_gitlab_provider.py +++ b/tests/unittest/test_gitlab_provider.py @@ -73,6 +73,27 @@ def test_get_pr_file_content_other_exception(self, gitlab_provider, mock_project assert content == "" + def test_get_repo_file_content_loads_from_default_branch(self, gitlab_provider, mock_gitlab_client, mock_project): + mock_project.default_branch = "main" + mock_file = MagicMock(ProjectFile) + mock_file.decode.return_value = b"repo context" + mock_project.files.get.return_value = mock_file + + content = gitlab_provider.get_repo_file_content("AGENTS.md") + + assert content == "repo context" + mock_gitlab_client.projects.get.assert_called_with("test/repo") + mock_project.files.get.assert_called_once_with(file_path="AGENTS.md", ref="main") + mock_file.decode.assert_called_once() + + def test_get_repo_file_content_treats_missing_file_as_empty(self, gitlab_provider, mock_project): + mock_project.default_branch = "main" + mock_project.files.get.side_effect = GitlabGetError("404 Not Found") + + content = gitlab_provider.get_repo_file_content("AGENTS.md") + + assert content == "" + def test_create_or_update_pr_file_create_new(self, gitlab_provider, mock_project): mock_project.files.get.side_effect = GitlabGetError("404 Not Found") mock_file = MagicMock() diff --git a/tests/unittest/test_litellm_claude_extended_thinking.py b/tests/unittest/test_litellm_claude_extended_thinking.py new file mode 100644 index 0000000000..45f7f34bba --- /dev/null +++ b/tests/unittest/test_litellm_claude_extended_thinking.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler +from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler + + +def settings_with_claude_override(override): + return SimpleNamespace( + config=SimpleNamespace( + verbosity_level=0, + get=lambda key, default=None: ( + override if key == "claude_extended_thinking_models_override" else default + ), + ), + litellm=SimpleNamespace(get=lambda key, default=None: default), + get=lambda key, default=None: default, + ) + + +@pytest.fixture +def logger(): + with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.get_logger") as get_logger: + logger = MagicMock() + get_logger.return_value = logger + yield logger + + +@pytest.mark.parametrize( + "override", + [ + "claude-3-7-sonnet-latest", + ["claude-3-7-sonnet-latest", 123], + [""], + ], +) +def test_invalid_claude_extended_thinking_override_falls_back_to_built_in_models( + monkeypatch, + logger, + override, +): + monkeypatch.setattr(litellm_handler, "get_settings", lambda: settings_with_claude_override(override)) + + handler = LiteLLMAIHandler() + + assert handler.claude_extended_thinking_models == CLAUDE_EXTENDED_THINKING_MODELS + logger.warning.assert_called_once() + + +def test_valid_claude_extended_thinking_override_replaces_built_in_models(monkeypatch, logger): + override = ["custom-claude-model"] + monkeypatch.setattr(litellm_handler, "get_settings", lambda: settings_with_claude_override(override)) + + handler = LiteLLMAIHandler() + + assert handler.claude_extended_thinking_models == ["custom-claude-model"] + assert handler.claude_extended_thinking_models is not override + logger.warning.assert_not_called() diff --git a/tests/unittest/test_repo_context.py b/tests/unittest/test_repo_context.py new file mode 100644 index 0000000000..72448f8688 --- /dev/null +++ b/tests/unittest/test_repo_context.py @@ -0,0 +1,501 @@ +from unittest.mock import Mock, patch + +import pytest +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo import repo_context +from pr_agent.algo.repo_context import ( + TRUNCATION_MARKER, + build_repo_context, + render_instruction_files, + render_instruction_files_with_line_budget, +) +from pr_agent.config_loader import get_settings +from pr_agent.git_providers.git_provider import GitProvider +from pr_agent.git_providers.github_provider import GithubProvider + + +class FakeProvider: + def __init__(self, files, pr_url=None): + self.files = files + self.pr_url = pr_url + self.requested_paths = [] + + def get_repo_file_content(self, file_path: str): + self.requested_paths.append(file_path) + return self.files.get(file_path) + + +class UnsupportedProvider: + get_repo_file_content = GitProvider.get_repo_file_content + + +@pytest.fixture +def repo_context_settings(): + settings = get_settings() + original_files = settings.config.get("repo_context_files", []) + original_max_lines = settings.config.get("repo_context_max_lines", 500) + original_warned_provider_classes = repo_context._unsupported_repo_context_provider_classes.copy() + original_process_cache = repo_context._repo_context_process_cache.copy() + + yield settings + + settings.set("CONFIG.REPO_CONTEXT_FILES", original_files) + settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", original_max_lines) + repo_context._unsupported_repo_context_provider_classes = original_warned_provider_classes + repo_context._repo_context_process_cache = original_process_cache + + +def test_build_repo_context_fetches_and_formats_configured_files(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md", "CONTRIBUTING.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + provider = FakeProvider({ + "AGENTS.md": "# Agent Guide\nUse focused tests.", + "CONTRIBUTING.md": "Keep PRs small.", + }) + + context = build_repo_context(provider) + + assert context == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "`````markdown\n" + "# Agent Guide\n" + "Use focused tests.\n" + "`````\n" + "\n\n" + '\n' + "`````markdown\n" + "Keep PRs small.\n" + "`````\n" + "\n\n" + "" + ) + assert provider.requested_paths == ["AGENTS.md", "CONTRIBUTING.md"] + + +def test_build_repo_context_reuses_provider_cache_for_same_config(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md", "CONTRIBUTING.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + provider = FakeProvider({ + "AGENTS.md": "Repo purpose", + "CONTRIBUTING.md": "Keep PRs small.", + }) + + first_context = build_repo_context(provider) + second_context = build_repo_context(provider) + + assert second_context == first_context + assert provider.requested_paths == ["AGENTS.md", "CONTRIBUTING.md"] + + +def test_build_repo_context_reuses_process_cache_for_same_pr_url(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + first_provider = FakeProvider({"AGENTS.md": "Repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + second_provider = FakeProvider({"AGENTS.md": "Changed repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + + first_context = build_repo_context(first_provider) + second_context = build_repo_context(second_provider) + + assert second_context == first_context + assert "Repo purpose" in second_context + assert "Changed repo purpose" not in second_context + assert first_provider.requested_paths == ["AGENTS.md"] + assert second_provider.requested_paths == [] + + +def test_build_repo_context_refreshes_process_cache_after_ttl(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + first_provider = FakeProvider({"AGENTS.md": "Repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + second_provider = FakeProvider({"AGENTS.md": "Changed repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + + with patch("pr_agent.algo.repo_context.time.monotonic", side_effect=[100, 100, 2000, 2000, 2000]): + first_context = build_repo_context(first_provider) + second_context = build_repo_context(second_provider) + + assert "Repo purpose" in first_context + assert "Changed repo purpose" in second_context + assert first_provider.requested_paths == ["AGENTS.md"] + assert second_provider.requested_paths == ["AGENTS.md"] + + +def test_build_repo_context_refreshes_empty_process_cache_after_ttl(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + first_provider = FakeProvider({}, pr_url="https://example.com/org/repo/pull/1") + second_provider = FakeProvider({"AGENTS.md": "Repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + + with patch("pr_agent.algo.repo_context.time.monotonic", side_effect=[100, 100, 2000, 2000, 2000]): + first_context = build_repo_context(first_provider) + second_context = build_repo_context(second_provider) + + assert first_context == "" + assert "Repo purpose" in second_context + assert first_provider.requested_paths == ["AGENTS.md"] + assert second_provider.requested_paths == ["AGENTS.md"] + + +def test_repo_context_cache_evicts_oldest_entry_when_full(): + cache = repo_context._RepoContextCache(max_size=2, ttl_seconds=900) + missing = object() + + with patch("pr_agent.algo.repo_context.time.monotonic", return_value=100): + cache["first"] = "one" + cache["second"] = "two" + cache["third"] = "three" + + assert cache.get("first", missing) is missing + assert cache.get("second", missing) == "two" + assert cache.get("third", missing) == "three" + + +def test_get_repo_context_config_normalizes_inputs(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", "AGENTS.md") + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", "12") + + assert repo_context._get_repo_context_config() == (["AGENTS.md"], 12) + + +def test_get_repo_context_config_rejects_non_list_container(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", {"AGENTS.md": True}) + + assert repo_context._get_repo_context_config() is None + + +def test_provider_supports_repo_context_warns_once_for_unsupported_provider(repo_context_settings): + provider = UnsupportedProvider() + + with patch("pr_agent.algo.repo_context.get_logger") as mock_get_logger: + assert repo_context._provider_supports_repo_context(provider) is False + assert repo_context._provider_supports_repo_context(provider) is False + + mock_get_logger.return_value.warning.assert_called_once_with( + "repo_context_files is configured, but UnsupportedProvider does not support repository file fetching; " + "skipping repo context" + ) + + +def test_load_repo_context_files_normalizes_fetch_results(): + provider = FakeProvider({ + "AGENTS.md": b"Repo purpose", + "EMPTY.md": "", + "MISSING.md": None, + }) + + files, had_fetch_error = repo_context._load_repo_context_files( + provider, ["AGENTS.md", "EMPTY.md", "MISSING.md", " "] + ) + + assert files == {"AGENTS.md": "Repo purpose"} + assert had_fetch_error is False + assert provider.requested_paths == ["AGENTS.md", "EMPTY.md", "MISSING.md"] + + +def test_load_repo_context_files_reports_fetch_errors(): + provider = FakeProvider({}) + provider.get_repo_file_content = Mock(side_effect=Exception("temporary outage")) + + files, had_fetch_error = repo_context._load_repo_context_files(provider, ["AGENTS.md"]) + + assert files == {} + assert had_fetch_error is True + + +def test_build_repo_context_process_cache_invalidates_when_config_changes(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + first_provider = FakeProvider({ + "AGENTS.md": "Repo purpose", + "CONTRIBUTING.md": "Keep PRs small.", + }, pr_url="https://example.com/org/repo/pull/1") + second_provider = FakeProvider({ + "AGENTS.md": "Repo purpose", + "CONTRIBUTING.md": "Keep PRs small.", + }, pr_url="https://example.com/org/repo/pull/1") + + first_context = build_repo_context(first_provider) + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["CONTRIBUTING.md"]) + second_context = build_repo_context(second_provider) + + assert "Repo purpose" in first_context + assert "Keep PRs small." in second_context + assert first_provider.requested_paths == ["AGENTS.md"] + assert second_provider.requested_paths == ["CONTRIBUTING.md"] + + +def test_build_repo_context_does_not_cache_empty_context_after_fetch_error(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + provider = FakeProvider({"AGENTS.md": "Repo purpose"}, pr_url="https://example.com/org/repo/pull/1") + provider.get_repo_file_content = Mock(side_effect=[Exception("temporary outage"), "Repo purpose"]) + + first_context = build_repo_context(provider) + second_context = build_repo_context(provider) + + assert first_context == "" + assert "Repo purpose" in second_context + assert provider.get_repo_file_content.call_count == 2 + + +def test_build_repo_context_cache_invalidates_when_repo_context_files_change(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + provider = FakeProvider({ + "AGENTS.md": "Repo purpose", + "CONTRIBUTING.md": "Keep PRs small.", + }) + + first_context = build_repo_context(provider) + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["CONTRIBUTING.md"]) + second_context = build_repo_context(provider) + + assert "Repo purpose" in first_context + assert "Keep PRs small." in second_context + assert provider.requested_paths == ["AGENTS.md", "CONTRIBUTING.md"] + + +def test_build_repo_context_cache_invalidates_when_line_budget_changes(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 9) + provider = FakeProvider({"AGENTS.md": "one\ntwo\nthree"}) + + truncated_context = build_repo_context(provider) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 500) + full_context = build_repo_context(provider) + + assert TRUNCATION_MARKER in truncated_context + assert "one\ntwo\nthree" in full_context + assert provider.requested_paths == ["AGENTS.md", "AGENTS.md"] + + +def test_render_instruction_files_escapes_path_and_derives_scope(): + context = render_instruction_files({ + 'docs/Agent "Notes".md': "Use markers.\n", + }) + + assert context == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "`````markdown\n" + "Use markers.\n" + "`````\n" + "\n\n" + "" + ) + + +def test_render_instruction_files_uses_longer_fence_when_content_contains_default_fence(): + context = render_instruction_files({ + "AGENTS.md": "Avoid closing this fence:\n`````", + }) + + assert context == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "``````markdown\n" + "Avoid closing this fence:\n" + "`````\n" + "``````\n" + "\n\n" + "" + ) + + +def test_render_instruction_files_with_line_budget_uses_longer_fence_for_conflicting_content(): + context = render_instruction_files_with_line_budget({ + "AGENTS.md": "Avoid closing this fence:\n`````", + }, max_lines=500) + + assert context == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "``````markdown\n" + "Avoid closing this fence:\n" + "`````\n" + "``````\n" + "\n\n" + "" + ) + + +def test_build_repo_context_skips_invalid_missing_and_empty_files(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["", 7, "MISSING.md", "EMPTY.md", "AGENTS.md"]) + provider = FakeProvider({"EMPTY.md": "", "AGENTS.md": "Loaded context"}) + + assert build_repo_context(provider) == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "`````markdown\n" + "Loaded context\n" + "`````\n" + "\n\n" + "" + ) + assert provider.requested_paths == ["MISSING.md", "EMPTY.md", "AGENTS.md"] + + +def test_build_repo_context_enforces_total_line_cap(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md", "CONTRIBUTING.md"]) + repo_context_settings.set("CONFIG.REPO_CONTEXT_MAX_LINES", 4) + provider = FakeProvider({ + "AGENTS.md": "one\ntwo\nthree", + "CONTRIBUTING.md": "four\nfive", + }) + + context = build_repo_context(provider) + + assert context == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + "" + ) + assert len(context.splitlines()) <= 4 + + +def test_render_instruction_files_with_line_budget_returns_empty_when_wrapper_exceeds_budget(): + context = render_instruction_files_with_line_budget({ + "AGENTS.md": "one", + }, max_lines=2) + + assert context == "" + + +@pytest.mark.parametrize("max_lines", range(0, 12)) +def test_render_instruction_files_with_line_budget_never_exceeds_configured_budget(max_lines): + context = render_instruction_files_with_line_budget({ + "AGENTS.md": "one\ntwo\nthree", + "CONTRIBUTING.md": "four\nfive", + }, max_lines=max_lines) + + assert len(context.splitlines()) <= max_lines + + +def test_build_repo_context_returns_empty_when_no_files_configured(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", []) + + assert build_repo_context(FakeProvider({"AGENTS.md": "repo purpose"})) == "" + + +def test_build_repo_context_treats_string_config_as_single_file(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", "AGENTS.md") + provider = FakeProvider({"AGENTS.md": "repo purpose"}) + + assert build_repo_context(provider) == ( + "You are being given instruction files. Follow them as project-specific guidance when reviewing code.\n" + "\n" + '\n' + "`````markdown\n" + "repo purpose\n" + "`````\n" + "\n\n" + "" + ) + assert provider.requested_paths == ["AGENTS.md"] + + +def test_build_repo_context_skips_non_list_container(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", {"AGENTS.md": True}) + provider = FakeProvider({"AGENTS.md": "repo purpose"}) + + assert build_repo_context(provider) == "" + assert provider.requested_paths == [] + + +def test_build_repo_context_warns_once_for_provider_without_repo_file_fetching(repo_context_settings): + repo_context_settings.set("CONFIG.REPO_CONTEXT_FILES", ["AGENTS.md"]) + provider = UnsupportedProvider() + + with patch("pr_agent.algo.repo_context.get_logger") as mock_get_logger: + context = build_repo_context(provider) + second_context = build_repo_context(provider) + + assert context == "" + assert second_context == "" + mock_get_logger.return_value.warning.assert_called_once_with( + "repo_context_files is configured, but UnsupportedProvider does not support repository file fetching; " + "skipping repo context" + ) + + +def test_github_provider_decodes_repo_context_files_and_treats_failures_as_missing(): + provider = GithubProvider.__new__(GithubProvider) + provider.repo_obj = Mock() + provider.repo_obj.get_contents.return_value.decoded_content = b"repo context" + + assert provider.get_repo_file_content("AGENTS.md") == "repo context" + + provider.repo_obj.get_contents.side_effect = Exception("not found") + + assert provider.get_repo_file_content("MISSING.md") == "" + + +@pytest.mark.parametrize( + "prompt_name,variables", + [ + ( + "pr_review_prompt", + { + "extra_instructions": "", + "repo_context": render_instruction_files({"AGENTS.md": "Repo purpose"}), + "require_can_be_split_review": False, + "related_tickets": "", + "require_estimate_contribution_time_cost": False, + "require_score": False, + "require_tests": True, + "question_str": "", + "require_security_review": True, + "require_todo_scan": False, + "require_estimate_effort_to_review": True, + "num_max_findings": 3, + "num_pr_files": 1, + "is_ai_metadata": False, + }, + ), + ( + "pr_description_prompt", + { + "extra_instructions": "", + "repo_context": render_instruction_files({"AGENTS.md": "Repo purpose"}), + "enable_custom_labels": False, + "custom_labels_class": "", + "enable_semantic_files_types": True, + "include_file_summary_changes": True, + "enable_pr_diagram": False, + }, + ), + ( + "pr_code_suggestions_prompt", + { + "extra_instructions": "", + "repo_context": render_instruction_files({"AGENTS.md": "Repo purpose"}), + "focus_only_on_problems": True, + "num_code_suggestions": 3, + "is_ai_metadata": False, + }, + ), + ( + "pr_code_suggestions_prompt_not_decoupled", + { + "extra_instructions": "", + "repo_context": render_instruction_files({"AGENTS.md": "Repo purpose"}), + "focus_only_on_problems": True, + "num_code_suggestions": 3, + "is_ai_metadata": False, + }, + ), + ], +) +def test_prompt_templates_render_configured_repo_context(prompt_name, variables): + template = getattr(get_settings(), prompt_name).system + + rendered = Environment(undefined=StrictUndefined).from_string(template).render(variables) + + assert "Repository context:" in rendered + assert '' in rendered