diff --git a/pr_agent/git_providers/utils.py b/pr_agent/git_providers/utils.py index 1e64b9578d..49dad2cba4 100644 --- a/pr_agent/git_providers/utils.py +++ b/pr_agent/git_providers/utils.py @@ -35,7 +35,10 @@ def apply_repo_settings(pr_url): category = 'local' try: fd, repo_settings_file = tempfile.mkstemp(suffix='.toml') - os.write(fd, repo_settings) + try: + os.write(fd, repo_settings) + finally: + os.close(fd) try: dynconf_kwargs = {'core_loaders': [], # DISABLE default loaders, otherwise will load toml files more than once. diff --git a/tests/unittest/_settings_helpers.py b/tests/unittest/_settings_helpers.py new file mode 100644 index 0000000000..99bb3099b1 --- /dev/null +++ b/tests/unittest/_settings_helpers.py @@ -0,0 +1,62 @@ +"""Test-only helpers for snapshotting/restoring Dynaconf settings. + +Dynaconf's ``settings.unset(key, force=True)`` does not reliably remove a +dotted leaf (e.g. ``"openai.deployment_id"``) after that leaf was created +via ``settings.set(...)``. The leaf survives inside the parent section's +``DynaBox``, which causes state to leak between tests that share the global +settings singleton. + +These helpers provide a SENTINEL-based snapshot so that keys which were +originally absent are truly removed (not just rewritten to ``None``) during +restore, including for dotted keys. +""" + +from contextlib import suppress + +from pr_agent.config_loader import get_settings + +SENTINEL = object() + + +def snapshot_settings(keys): + """Capture current values for ``keys``; missing keys map to ``SENTINEL``.""" + settings = get_settings() + return {key: settings.get(key, SENTINEL) for key in keys} + + +def _remove_key(settings, key): + """Best-effort removal of ``key`` (supports dotted leaves).""" + if "." in key: + section_name, leaf = key.split(".", 1) + container = getattr(settings, section_name, None) + if container is None: + return + # DynaBox lookups are case-insensitive, but ``pop`` requires the + # stored casing. 
Find the matching key and pop it. + for stored in list(container.keys()): + if stored.lower() == leaf.lower(): + # ``pop`` with a default never raises KeyError; we narrow to + # ``(AttributeError, TypeError)`` to tolerate exotic container + # types that don't implement a dict-like ``pop``. Any other + # exception is unexpected and should surface so tests fail + # loudly rather than mask Dynaconf state leaks. + with suppress(AttributeError, TypeError): + container.pop(stored, None) + return + return + # ``settings.unset`` raises ``KeyError`` for keys that were never set; + # that case is benign for restore. Any other exception (e.g. a Dynaconf + # internal error) must propagate so that a broken cleanup is visible + # instead of silently leaking state across tests. + with suppress(KeyError): + settings.unset(key, force=True) + + +def restore_settings(snapshot): + """Restore ``snapshot``; truly remove entries whose snapshot is SENTINEL.""" + settings = get_settings() + for key, value in snapshot.items(): + if value is SENTINEL: + _remove_key(settings, key) + else: + settings.set(key, value) diff --git a/tests/unittest/test_apply_repo_settings_security.py b/tests/unittest/test_apply_repo_settings_security.py new file mode 100644 index 0000000000..eaf65ee3de --- /dev/null +++ b/tests/unittest/test_apply_repo_settings_security.py @@ -0,0 +1,338 @@ +""" +Security-focused tests for pr_agent.git_providers.utils.apply_repo_settings. + +These tests verify: +- The repo settings fetch path is skipped when use_repo_settings_file is disabled. +- Valid repo TOML overrides only the specified keys and preserves siblings. +- Invalid TOML does not pollute global settings; publishing exactly one + local-category configuration error is tracked only by a strict xfail test. +- Forbidden directives (e.g. dynaconf_include) do not pollute settings; the + matching local-category configuration error is likewise a strict xfail. 
+- The temporary file created from the repo settings bytes is removed after + apply_repo_settings, both on success and on failure. +""" + +import copy +import os +import tempfile +from contextlib import suppress + +import pytest + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import utils as git_utils +from pr_agent.git_providers.utils import apply_repo_settings + + +class FakeGitProvider: + """Minimal fake provider exposing the methods apply_repo_settings touches.""" + + def __init__(self, repo_settings_bytes=b""): + self._repo_settings = repo_settings_bytes + self.persistent_comments = [] + self.comments = [] + self.get_repo_settings_calls = 0 + + def get_repo_settings(self): + self.get_repo_settings_calls += 1 + return self._repo_settings + + def is_supported(self, capability): + return capability == "gfm_markdown" + + def publish_persistent_comment(self, body, initial_header, update_header, final_update_message): + self.persistent_comments.append( + { + "body": body, + "initial_header": initial_header, + "update_header": update_header, + "final_update_message": final_update_message, + } + ) + + def publish_comment(self, body): + self.comments.append(body) + + +SNAPSHOT_SECTIONS = ("CONFIG", "PR_REVIEWER", "CUSTOM_SECTION_FOR_TEST") + + +def _snapshot_settings_sections(settings): + return {section: copy.deepcopy(settings.as_dict().get(section)) for section in SNAPSHOT_SECTIONS} + + +def _restore_settings_sections(settings, snapshot): + for section, data in snapshot.items(): + # ``unset`` raises ``KeyError`` if the section was never set during + # the test; that's expected and safe to ignore. Anything else (e.g. + # a Dynaconf internal error) should propagate so a broken teardown + # surfaces instead of silently leaking state into other tests. 
+ with suppress(KeyError): + settings.unset(section, force=True) + if data is not None: + settings.set(section, copy.deepcopy(data), merge=False) + + +_ENV_ABSENT = object() + + +@pytest.fixture +def settings_snapshot(): + """Snapshot the keys mutated by these tests and restore them afterwards. + + Also snapshots the ``AUTO_CAST_FOR_DYNACONF`` environment variable, which + ``apply_repo_settings`` unconditionally sets to ``"false"``. Using a + sentinel for "originally absent" ensures the env restore is exact: + keys that were absent are deleted, never left as a stray ``None``-like + string that could leak into other tests. + """ + settings = get_settings() + snapshot = _snapshot_settings_sections(settings) + env_before = os.environ.get("AUTO_CAST_FOR_DYNACONF", _ENV_ABSENT) + try: + yield + finally: + _restore_settings_sections(settings, snapshot) + if env_before is _ENV_ABSENT: + os.environ.pop("AUTO_CAST_FOR_DYNACONF", None) + else: + os.environ["AUTO_CAST_FOR_DYNACONF"] = env_before + + +def _install_provider(monkeypatch, provider): + captured = {"errors": None, "git_provider": None} + + def fake_get_git_provider_with_context(pr_url): + return provider + + def fake_handle_configurations_errors(errors, git_provider): + captured["errors"] = errors + captured["git_provider"] = git_provider + + monkeypatch.setattr(git_utils, "get_git_provider_with_context", fake_get_git_provider_with_context) + monkeypatch.setattr(git_utils, "handle_configurations_errors", fake_handle_configurations_errors) + return captured + + +def test_disabled_repo_settings_skips_provider_fetch(monkeypatch, settings_snapshot): + provider = FakeGitProvider(repo_settings_bytes=b"[pr_reviewer]\nnum_max_findings = 99\n") + captured = _install_provider(monkeypatch, provider) + + get_settings().set("config.use_repo_settings_file", False) + original_num = get_settings().as_dict().get("PR_REVIEWER", {}).get("num_max_findings") + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert 
provider.get_repo_settings_calls == 0 + assert captured["errors"] is None + # Settings were not touched. + assert get_settings().as_dict().get("PR_REVIEWER", {}).get("num_max_findings") == original_num + + +def _section(settings, name): + """Return a section dict from settings using a case-insensitive lookup.""" + data = settings.as_dict() + upper = name.upper() + for key, value in data.items(): + if key.upper() == upper: + return value if isinstance(value, dict) else {} + return {} + + +def test_valid_repo_settings_merge_overrides_key_and_preserves_siblings(monkeypatch, settings_snapshot): + provider = FakeGitProvider(repo_settings_bytes=b"[pr_reviewer]\nnum_max_findings = 11\n") + captured = _install_provider(monkeypatch, provider) + + get_settings().set("config.use_repo_settings_file", True) + settings = get_settings() + sibling_before = _section(settings, "pr_reviewer").get("require_tests_review") + assert sibling_before is not None, "Test precondition: sibling key should already be present" + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert provider.get_repo_settings_calls == 1 + assert captured["errors"] is None, f"Unexpected configuration errors: {captured['errors']}" + + pr_reviewer = _section(settings, "pr_reviewer") + assert pr_reviewer.get("num_max_findings") == 11 + # Unrelated sibling key in the same section must be preserved by the merge logic. + assert pr_reviewer.get("require_tests_review") == sibling_before + + +def test_invalid_toml_does_not_pollute_settings(monkeypatch, settings_snapshot): + """ + Malformed TOML must never leak into the live settings. The custom loader + currently swallows the TOMLDecodeError and logs it, so no local error is + propagated to handle_configurations_errors; the surviving security + guarantee is that the existing settings are untouched. 
+ """ + malformed = b"[pr_reviewer\nnum_max_findings = 7\n" + provider = FakeGitProvider(repo_settings_bytes=malformed) + _install_provider(monkeypatch, provider) + + get_settings().set("config.use_repo_settings_file", True) + settings = get_settings() + before = copy.deepcopy(_section(settings, "pr_reviewer")) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + after = _section(settings, "pr_reviewer") + assert after == before + # Whatever errors may or may not be published, the malformed payload must + # never be merged silently into pr_reviewer. + assert after.get("num_max_findings") != 7 + + +@pytest.mark.xfail( + reason=( + "Behavior gap: pr_agent.custom_merge_loader is invoked with silent=True, " + "so TOMLDecodeError is logged and swallowed instead of being surfaced to " + "handle_configurations_errors. apply_repo_settings therefore never publishes " + "a 'local' configuration-error comment for malformed TOML." + ), + strict=True, +) +def test_invalid_toml_publishes_one_local_error(monkeypatch, settings_snapshot): + malformed = b"[pr_reviewer\nnum_max_findings = 7\n" + provider = FakeGitProvider(repo_settings_bytes=malformed) + captured = _install_provider(monkeypatch, provider) + get_settings().set("config.use_repo_settings_file", True) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert captured["errors"] is not None + assert len(captured["errors"]) == 1 + assert captured["errors"][0]["category"] == "local" + assert captured["errors"][0]["settings"] == malformed + + +def test_forbidden_directive_does_not_pollute_settings(monkeypatch, settings_snapshot): + """ + A repo TOML containing forbidden directives (e.g. dynaconf_include) must + not leak into the live settings. As with malformed TOML, the loader's + SecurityError is currently swallowed silently; the security guarantee + checked here is that no part of the payload (including the legitimate + pr_reviewer override) reaches the settings. 
+ """ + forbidden_toml = b"dynaconf_include = ['evil.toml']\n[pr_reviewer]\nnum_max_findings = 42\n" + provider = FakeGitProvider(repo_settings_bytes=forbidden_toml) + _install_provider(monkeypatch, provider) + + get_settings().set("config.use_repo_settings_file", True) + settings = get_settings() + before = copy.deepcopy(_section(settings, "pr_reviewer")) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + after = _section(settings, "pr_reviewer") + # The forbidden file must be rejected wholesale; neither the directive + # nor the piggy-backed pr_reviewer override should be applied. + assert after == before + assert after.get("num_max_findings") != 42 + assert "dynaconf_include" not in {k.lower() for k in settings.as_dict().keys()} + + +@pytest.mark.xfail( + reason=( + "Behavior gap: forbidden-directive SecurityError raised by " + "validate_file_security is swallowed by the silent-loader path, so " + "apply_repo_settings does not publish a 'local' configuration-error " + "comment for forbidden TOML directives." 
+ ), + strict=True, +) +def test_forbidden_directive_publishes_one_local_error(monkeypatch, settings_snapshot): + forbidden_toml = b"dynaconf_include = ['evil.toml']\n[pr_reviewer]\nnum_max_findings = 42\n" + provider = FakeGitProvider(repo_settings_bytes=forbidden_toml) + captured = _install_provider(monkeypatch, provider) + get_settings().set("config.use_repo_settings_file", True) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert captured["errors"] is not None + assert len(captured["errors"]) == 1 + assert captured["errors"][0]["category"] == "local" + assert captured["errors"][0]["settings"] == forbidden_toml + + +def test_temp_file_is_removed_after_successful_apply(monkeypatch, tmp_path, settings_snapshot): + provider = FakeGitProvider(repo_settings_bytes=b"[pr_reviewer]\nnum_max_findings = 5\n") + _install_provider(monkeypatch, provider) + get_settings().set("config.use_repo_settings_file", True) + + known_path = tmp_path / "repo_settings_success.toml" + + def fake_mkstemp(suffix=None, prefix=None, dir=None, text=False): + fd = os.open(str(known_path), os.O_RDWR | os.O_CREAT | os.O_TRUNC) + return fd, str(known_path) + + monkeypatch.setattr(tempfile, "mkstemp", fake_mkstemp) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert not known_path.exists(), "Temp settings file must be removed after successful apply" + + +def test_temp_file_is_removed_after_failed_apply(monkeypatch, tmp_path, settings_snapshot): + """The temp file must be removed even when the Dynaconf load step raises. + + We use *valid* TOML bytes (so the failure cannot be confused with the + silent-swallow malformed-TOML path) and force the failure by replacing + the ``Dynaconf`` symbol bound inside ``pr_agent.git_providers.utils`` + with a stub that raises *after* ``mkstemp`` has been called. We do not + patch the external ``dynaconf`` module — only the imported reference + that ``apply_repo_settings`` actually uses. 
+ """ + valid_toml = b"[pr_reviewer]\nnum_max_findings = 3\n" + provider = FakeGitProvider(repo_settings_bytes=valid_toml) + captured = _install_provider(monkeypatch, provider) + get_settings().set("config.use_repo_settings_file", True) + + known_path = tmp_path / "repo_settings_failure.toml" + mkstemp_calls = {"n": 0} + + def fake_mkstemp(suffix=None, prefix=None, dir=None, text=False): + mkstemp_calls["n"] += 1 + fd = os.open(str(known_path), os.O_RDWR | os.O_CREAT | os.O_TRUNC) + return fd, str(known_path) + + monkeypatch.setattr(tempfile, "mkstemp", fake_mkstemp) + + def exploding_dynaconf(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(git_utils, "Dynaconf", exploding_dynaconf) + + apply_repo_settings("https://example.com/owner/repo/pull/1") + + assert mkstemp_calls["n"] == 1, "mkstemp must have run before the failure" + assert not known_path.exists(), "Temp settings file must be removed even after a failed apply" + + # The local-category configuration error path must have been exercised. 
+ assert captured["errors"] is not None, "handle_configurations_errors should have been called" + assert len(captured["errors"]) == 1 + err = captured["errors"][0] + assert err["category"] == "local" + assert err["settings"] == valid_toml + assert "boom" in err["error"] + + +def test_restore_settings_sections_removes_section_created_after_snapshot(): + settings = get_settings() + original_snapshot = _snapshot_settings_sections(settings) + + try: + settings.unset("CUSTOM_SECTION_FOR_TEST", force=True) + assert "CUSTOM_SECTION_FOR_TEST" not in settings.as_dict() + + snapshot = _snapshot_settings_sections(settings) + assert snapshot["CUSTOM_SECTION_FOR_TEST"] is None + + settings.set("CUSTOM_SECTION_FOR_TEST", {"foo": "bar"}, merge=False) + assert settings.as_dict()["CUSTOM_SECTION_FOR_TEST"] == {"foo": "bar"} + + _restore_settings_sections(settings, snapshot) + + assert "CUSTOM_SECTION_FOR_TEST" not in settings.as_dict() + finally: + _restore_settings_sections(settings, original_snapshot) diff --git a/tests/unittest/test_cli_args_security.py b/tests/unittest/test_cli_args_security.py new file mode 100644 index 0000000000..696d653f6b --- /dev/null +++ b/tests/unittest/test_cli_args_security.py @@ -0,0 +1,145 @@ +from unittest.mock import Mock + +import pytest + +import pr_agent.agent.pr_agent as pr_agent_module +from pr_agent.algo.cli_args import CliArgs + +FORBIDDEN_ARGS = [ + # section-qualified key forms + "--openai.key=secret", + "--OPENAI.KEY=secret", + "--config.openai.key=secret", + # double-underscore form is normalized to dot before matching + "--openai__key=secret", + "--OPENAI__KEY=secret", + # webhook / app secrets via section-qualified prefix + "--github.webhook_secret=secret", + "--github_app.private_key=---BEGIN---", + "--github_app.app_id=123", + "--github_app.webhook_secret=secret", + # base/api URLs (SSRF / redirection style abuses) + "--github.base_url=https://evil.example", + "--litellm.api_base=https://evil.example", + "--litellm.api_type=azure", 
+ "--litellm.api_version=2024-01-01", + "--jira.jira_base_url=https://evil.example", + "--config.url=https://evil.example", + "--config.uri=https://evil.example", + # provider / auth selection and skip lists + "--config.secret_provider=aws", + "--config.git_provider=github", + "--config.skip_keys=foo", + "--auth.bearer_token=abc", + "--provider.personal_access_token=ghp_xxx", + "--provider.PERSONAL_ACCESS_TOKEN=ghp_xxx", + # approval / deployment toggles + "--config.enable_auto_approval=true", + "--config.enable_manual_approval=true", + "--config.enable_comment_approval=true", + "--config.approve_pr_on_self_review=true", + "--config.override_deployment_type=app", + # local cache + "--config.enable_local_cache=true", + "--config.local_cache_path=/etc", + # misc + "--config.shared_secret=xxx", + "--config.app_name=evil", + "--config.analytics_folder=/tmp", + # double-underscore variants of the above + "--github__webhook_secret=secret", + "--github_app__private_key=xxx", + "--litellm__api_base=https://evil.example", +] + + +ALLOWED_ARGS_SINGLE = [ + "--pr_reviewer.num_code_suggestions=3", + "--pr_reviewer.require_tests_review=true", + "--config.response_language=zh-tw", + "--pr_description.publish_labels=false", + # non-flag arguments are not validated against the forbidden list + "some-positional-arg", + "yes", + "because prod is broken", + "", +] + + +@pytest.mark.parametrize("forbidden", FORBIDDEN_ARGS) +def test_validate_user_args_rejects_forbidden(forbidden): + ok, offending = CliArgs.validate_user_args([forbidden]) + assert ok is False, f"Expected {forbidden!r} to be rejected" + assert isinstance(offending, str) and offending, ( + f"Expected an offending-token string for {forbidden!r}, got {offending!r}" + ) + + +@pytest.mark.parametrize("allowed", ALLOWED_ARGS_SINGLE) +def test_validate_user_args_accepts_allowed_single(allowed): + ok, offending = CliArgs.validate_user_args([allowed]) + assert ok is True, ( + f"Expected {allowed!r} to be accepted, but it was 
rejected as {offending!r}" + ) + assert offending == "" + + +def test_validate_user_args_empty_list_is_allowed(): + assert CliArgs.validate_user_args([]) == (True, "") + + +def test_validate_user_args_none_is_allowed(): + # falsy args short-circuit to allowed + assert CliArgs.validate_user_args(None) == (True, "") + + +def test_validate_user_args_mixed_allowed_then_forbidden(): + ok, offending = CliArgs.validate_user_args( + ["--pr_reviewer.num_code_suggestions=3", "--github.webhook_secret=secret"] + ) + assert ok is False + assert "webhook_secret" in offending + + +def test_validate_user_args_all_allowed_together(): + ok, offending = CliArgs.validate_user_args(ALLOWED_ARGS_SINGLE) + assert ok is True, f"Allowed batch unexpectedly rejected at {offending!r}" + assert offending == "" + + +@pytest.mark.asyncio +async def test_handle_request_uses_real_validator_to_block_forbidden(monkeypatch): + """Integration test: forbidden CLI arg should be rejected by the real + CliArgs.validate_user_args, before any settings update, tool + instantiation, tool run, or notify call happens.""" + + notify = Mock() + + monkeypatch.setattr(pr_agent_module, "apply_repo_settings", lambda pr_url: None) + + def _fail_update_settings(args): + raise AssertionError( + "update_settings_from_args must not be called when validation fails" + ) + + monkeypatch.setattr( + pr_agent_module, "update_settings_from_args", _fail_update_settings + ) + + class FakeTool: + def __init__(self, *args, **kwargs): + raise AssertionError("tool must not be instantiated for forbidden args") + + async def run(self): + raise AssertionError("tool must not run for forbidden args") + + monkeypatch.setitem(pr_agent_module.command2class, "custom", FakeTool) + + handled = await pr_agent_module.PRAgent(ai_handler="fake-ai")._handle_request( + "https://example/pr/1", + "/custom --github.webhook_secret=secret", + notify, + ) + + assert handled is False + notify.assert_not_called() diff --git 
a/tests/unittest/test_custom_merge_loader_security.py b/tests/unittest/test_custom_merge_loader_security.py new file mode 100644 index 0000000000..3cb75428cc --- /dev/null +++ b/tests/unittest/test_custom_merge_loader_security.py @@ -0,0 +1,251 @@ +""" +Security and behavior tests for pr_agent.custom_merge_loader. + +These tests exercise validate_file_security directly with forbidden directives at +various nesting positions, deep-nesting guard, and a representative safe config. +They also exercise the load() entry point against a minimal fake Dynaconf-like +object to verify file-skipping behavior, security enforcement, and single-key +loading semantics. +""" + +import importlib +from pathlib import Path + +import pytest +from jinja2.exceptions import SecurityError + +# Import pr_agent.config_loader first (for its module-level side effects) to +# complete the config_loader import chain and avoid the circular import between +# pr_agent.log and pr_agent.custom_merge_loader. +importlib.import_module("pr_agent.config_loader") +custom_merge_loader = importlib.import_module("pr_agent.custom_merge_loader") +load = custom_merge_loader.load +validate_file_security = custom_merge_loader.validate_file_security + +FORBIDDEN_DIRECTIVES = [ + "dynaconf_include", + "dynaconf_includes", + "includes", + "preload", + "preload_for_dynaconf", + "preloads", + "dynaconf_merge", + "dynaconf_merge_enabled", + "merge_enabled", + "loaders", + "loaders_for_dynaconf", + "core_loaders", + "core_loaders_for_dynaconf", + "settings_module", + "settings_file_for_dynaconf", + "settings_files_for_dynaconf", + "envvar_prefix", + "envvar_prefix_for_dynaconf", +] + + +class FakeDynaconf: + """Minimal Dynaconf-like object exposing settings_files and a .set() recorder.""" + + def __init__(self, settings_files, includes=None, preload=None): + self.settings_files = settings_files + if includes is not None: + self.includes = includes + if preload is not None: + self.preload = preload + self._store = {} + + 
def set(self, key, value): + self._store[key] = value + + +# --------------------------------------------------------------------------- +# validate_file_security: forbidden directives at varying positions +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("directive", FORBIDDEN_DIRECTIVES) +def test_forbidden_directive_at_top_level_raises(directive): + data = {directive: "anything"} + with pytest.raises(SecurityError): + validate_file_security(data, "test.toml") + + +@pytest.mark.parametrize("directive", FORBIDDEN_DIRECTIVES) +def test_forbidden_directive_inside_section_raises(directive): + data = {"config": {"some_key": "ok", directive: "bad"}} + with pytest.raises(SecurityError): + validate_file_security(data, "test.toml") + + +@pytest.mark.parametrize("directive", FORBIDDEN_DIRECTIVES) +def test_forbidden_directive_deeply_nested_raises(directive): + data = { + "config": { + "subsection": { + "deeper": { + "evendeeper": {directive: True}, + }, + }, + }, + } + with pytest.raises(SecurityError): + validate_file_security(data, "test.toml") + + +@pytest.mark.parametrize("directive", FORBIDDEN_DIRECTIVES) +def test_forbidden_directive_mixed_case_raises(directive): + # The implementation lowercases keys before comparison; ensure mixed case is caught. + mixed = directive.upper() if directive.islower() else directive.swapcase() + # Ensure case is actually mixed/different + if mixed == directive: + mixed = directive.upper() + data = {"config": {mixed: "bad"}} + with pytest.raises(SecurityError): + validate_file_security(data, "test.toml") + + +# --------------------------------------------------------------------------- +# validate_file_security: max depth guard +# --------------------------------------------------------------------------- + +def test_excessive_nesting_raises_security_error(): + # Build a dict deeper than MAX_DEPTH (50) so the guard trips. 
+ data = current = {} + for _ in range(120): + nxt = {} + current["nested"] = nxt + current = nxt + current["leaf"] = "value" + with pytest.raises(SecurityError): + validate_file_security(data, "deep.toml") + + +# --------------------------------------------------------------------------- +# validate_file_security: representative safe PR-Agent config does not raise +# --------------------------------------------------------------------------- + +def test_safe_pr_agent_config_does_not_raise(): + data = { + "config": { + "model": "gpt-4", + "fallback_models": ["gpt-3.5-turbo"], + "git_provider": "github", + "publish_output": True, + "verbosity_level": 0, + }, + "pr_reviewer": { + "require_score_review": False, + "num_code_suggestions": 4, + "extra_instructions": "", + }, + "pr_description": { + "publish_labels": True, + "add_original_user_description": True, + }, + "github": { + "deployment_type": "user", + "ratelimit_retries": 5, + }, + } + # Should not raise. + validate_file_security(data, "safe.toml") + + +# --------------------------------------------------------------------------- +# load(): behavior tests using a minimal fake Dynaconf-like object +# --------------------------------------------------------------------------- + +def _write(tmp_path, name, content): + p = Path(tmp_path) / name + p.write_text(content, encoding="utf-8") + return str(p) + + +def test_load_skips_non_toml_files(tmp_path): + non_toml = _write(tmp_path, "settings.yaml", "config:\n model: foo\n") + obj = FakeDynaconf(settings_files=[non_toml]) + load(obj) + assert obj._store == {} + + +def test_load_skips_missing_files(tmp_path): + missing = str(Path(tmp_path) / "does_not_exist.toml") + obj = FakeDynaconf(settings_files=[missing]) + load(obj) + assert obj._store == {} + + +def test_load_silent_true_skips_on_forbidden_directive(tmp_path): + bad = _write( + tmp_path, + "bad.toml", + "[config]\nmodel = \"gpt-4\"\ndynaconf_include = [\"other.toml\"]\n", + ) + obj = 
FakeDynaconf(settings_files=[bad]) + # silent=True: exception is swallowed; no values should be set + load(obj, silent=True) + assert obj._store == {} + + +def test_load_silent_false_raises_on_forbidden_directive(tmp_path): + bad = _write( + tmp_path, + "bad.toml", + "[config]\nmodel = \"gpt-4\"\nincludes = [\"other.toml\"]\n", + ) + obj = FakeDynaconf(settings_files=[bad]) + with pytest.raises(SecurityError): + load(obj, silent=False) + + +def test_load_silent_false_raises_on_top_level_includes_attr(tmp_path): + # The loader also checks the object's own .includes attribute. + good = _write(tmp_path, "ok.toml", "[config]\nmodel = \"gpt-4\"\n") + obj = FakeDynaconf(settings_files=[good], includes=["something.toml"]) + with pytest.raises(SecurityError): + load(obj, silent=False) + + +def test_load_silent_false_raises_on_top_level_preload_attr(tmp_path): + good = _write(tmp_path, "ok.toml", "[config]\nmodel = \"gpt-4\"\n") + obj = FakeDynaconf(settings_files=[good], preload=["something.toml"]) + with pytest.raises(SecurityError): + load(obj, silent=False) + + +def test_load_valid_toml_sets_expected_sections(tmp_path): + a = _write( + tmp_path, + "a.toml", + "[config]\nmodel = \"gpt-4\"\nverbosity_level = 1\n\n[pr_reviewer]\nnum_code_suggestions = 4\n", + ) + obj = FakeDynaconf(settings_files=[a]) + load(obj) + assert "config" in obj._store + assert obj._store["config"]["model"] == "gpt-4" + assert obj._store["config"]["verbosity_level"] == 1 + assert "pr_reviewer" in obj._store + assert obj._store["pr_reviewer"]["num_code_suggestions"] == 4 + + +def test_load_respects_single_key_loading(tmp_path): + a = _write( + tmp_path, + "a.toml", + "[config]\nmodel = \"gpt-4\"\n\n[pr_reviewer]\nnum_code_suggestions = 4\n", + ) + obj = FakeDynaconf(settings_files=[a]) + # key matching is case-insensitive in the loader + load(obj, key="CONFIG") + assert "config" in obj._store + assert "pr_reviewer" not in obj._store + + +def test_load_later_file_replaces_earlier_field(tmp_path): + 
a = _write(tmp_path, "a.toml", "[config]\nmodel = \"gpt-4\"\nshared = \"from_a\"\n") + b = _write(tmp_path, "b.toml", "[config]\nshared = \"from_b\"\n") + obj = FakeDynaconf(settings_files=[a, b]) + load(obj) + assert obj._store["config"]["shared"] == "from_b" + # earlier-only field is preserved (accumulated, not replaced wholesale at section level) + assert obj._store["config"]["model"] == "gpt-4" diff --git a/tests/unittest/test_diff_pipeline_core.py b/tests/unittest/test_diff_pipeline_core.py new file mode 100644 index 0000000000..508418d791 --- /dev/null +++ b/tests/unittest/test_diff_pipeline_core.py @@ -0,0 +1,475 @@ +"""Tests for the diff / hunk pipeline. + +Covers: +- pr_agent.algo.git_patch_processing.decouple_and_convert_to_hunks_with_lines_numbers +- pr_agent.algo.git_patch_processing.extract_hunk_lines_from_patch +- pr_agent.algo.pr_processing.generate_full_patch +- pr_agent.algo.pr_processing.pr_generate_compressed_diff + +The tests document current behavior and assert on the key structural +markers (hunk headers, line numbers, selected lines, returned lists) +rather than on full golden strings, so they remain robust to minor +formatting tweaks. 
+""" + +import pr_agent.algo.pr_processing as pr_processing +from pr_agent.algo.git_patch_processing import ( + decouple_and_convert_to_hunks_with_lines_numbers, + extract_hunk_lines_from_patch, +) +from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class FakeTokenHandler: + """Deterministic token handler: 1 token per whitespace-split word.""" + + def __init__(self, prompt_tokens: int = 100): + self.prompt_tokens = prompt_tokens + + def count_tokens(self, patch: str) -> int: + return len(patch.split()) + + +MULTI_HUNK_PATCH = ( + "@@ -1,3 +1,4 @@\n" + " line1\n" + "-line2\n" + "+line2_new\n" + "+line2b\n" + " line3\n" + "@@ -10,3 +11,3 @@\n" + " ctx_a\n" + "-removed\n" + "+added\n" + " ctx_b\n" +) + + +def _make_file(filename="src/sample.py", patch=MULTI_HUNK_PATCH, + edit_type=EDIT_TYPE.MODIFIED, tokens=10, + base="orig", head="new"): + return FilePatchInfo( + base_file=base, + head_file=head, + patch=patch, + filename=filename, + tokens=tokens, + edit_type=edit_type, + ) + + +# --------------------------------------------------------------------------- +# decouple_and_convert_to_hunks_with_lines_numbers +# --------------------------------------------------------------------------- + + +class TestDecoupleAndConvertToHunks: + def test_multi_hunk_emits_both_hunks_with_new_and_old_sections(self): + file = _make_file() + out = decouple_and_convert_to_hunks_with_lines_numbers(MULTI_HUNK_PATCH, file) + + # File header is present. + assert "## File: 'src/sample.py'" in out + # Both hunk headers are preserved. + assert "@@ -1,3 +1,4 @@" in out + assert "@@ -10,3 +11,3 @@" in out + # Each hunk produces a __new hunk__ section; modified hunks produce __old hunk__ too. 
+ assert out.count("__new hunk__") == 2 + assert out.count("__old hunk__") == 2 + + def test_new_hunk_lines_are_numbered_starting_at_start2(self): + file = _make_file() + out = decouple_and_convert_to_hunks_with_lines_numbers(MULTI_HUNK_PATCH, file) + + # First hunk starts at +1 in the new file; context line " line1" -> "1 line1". + assert "1 line1" in out + # The inserted lines get the next numbers (2, 3) in the new file. + assert "2 +line2_new" in out + assert "3 +line2b" in out + assert "4 line3" in out + + # Second hunk starts at +11 in the new file (we set start2=11). + # Layout per implementation: context, then '+added' (replacing '-removed'), + # then trailing context. The '-' line is not numbered in __new hunk__. + assert "11 ctx_a" in out + assert "12 +added" in out + + def test_old_hunk_contains_removed_and_context_lines_unnumbered(self): + file = _make_file() + out = decouple_and_convert_to_hunks_with_lines_numbers(MULTI_HUNK_PATCH, file) + + old_section = out.split("__old hunk__", 1)[1] + # The first __old hunk__ contains the removed line and context. + assert "-line2" in old_section + assert " line1" in old_section + # Old hunk lines do NOT have numeric prefixes — assert no "1 -line2" style line. + for line in old_section.splitlines(): + stripped = line.lstrip() + if stripped.startswith(("-", "+", " ")) and stripped == line: + # current implementation does not prefix old-hunk lines with numbers + assert not line[:1].isdigit() + + def test_deleted_file_short_circuits_with_message(self): + file = _make_file(edit_type=EDIT_TYPE.DELETED) + out = decouple_and_convert_to_hunks_with_lines_numbers(MULTI_HUNK_PATCH, file) + assert "was deleted" in out + assert "src/sample.py" in out + # No hunk content should be emitted for deleted files. 
+ assert "__new hunk__" not in out + assert "__old hunk__" not in out + + def test_pure_addition_hunk_emits_only_new_hunk_section(self): + patch = ( + "@@ -0,0 +1,2 @@\n" + "+brand new line 1\n" + "+brand new line 2\n" + ) + file = _make_file(patch=patch, edit_type=EDIT_TYPE.ADDED) + out = decouple_and_convert_to_hunks_with_lines_numbers(patch, file) + + assert "__new hunk__" in out + assert "__old hunk__" not in out + # Numbering: implementation falls back to start2=0 for "@@ -0,0 +1 @@"-style + # headers, but here header has explicit "+1,2" so start2=1. + assert "1 +brand new line 1" in out + assert "2 +brand new line 2" in out + + def test_no_file_arg_omits_file_header(self): + out = decouple_and_convert_to_hunks_with_lines_numbers(MULTI_HUNK_PATCH, file=None) + assert "## File:" not in out + assert "@@ -1,3 +1,4 @@" in out + assert "__new hunk__" in out + + +# --------------------------------------------------------------------------- +# extract_hunk_lines_from_patch +# --------------------------------------------------------------------------- + + +class TestExtractHunkLinesFromPatch: + def test_right_side_single_line_selection_in_first_hunk(self): + # In MULTI_HUNK_PATCH new-file numbering: + # 1 " line1" + # 2 "+line2_new" + # 3 "+line2b" + # 4 " line3" + full, selected = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", line_start=2, line_end=2, side="right" + ) + assert "## File: 'src/sample.py'" in full + assert "@@ -1,3 +1,4 @@" in full + # The second hunk's header should NOT be in `full` since its range + # does not contain line_start=2. + assert "@@ -10,3 +11,3 @@" not in full + # Current production behavior includes the paired removed line before + # the targeted inserted line; assert exactly so this test cannot pass + # while silently selecting additional wrong-side/context lines. 
+ assert selected == "-line2\n+line2_new" + + def test_right_side_range_selection_across_consecutive_lines(self): + full, selected = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", line_start=2, line_end=3, side="right" + ) + assert selected == "-line2\n+line2_new\n+line2b" + + def test_left_side_selects_from_old_line_numbers(self): + # Old file numbering in first hunk starts at 1; "-line2" is old-line 2. + full, selected = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", line_start=2, line_end=2, side="left" + ) + assert "@@ -1,3 +1,4 @@" in full + # Current production behavior includes adjacent context/paired new + # lines around the deleted line; assert exactly so this test documents + # the full selected payload rather than only a partial match. + assert selected == " line1\n-line2\n+line2_new" + + def test_targets_second_hunk_when_line_in_its_range(self): + # Second hunk new-file range: start2=11, size2=3 -> lines 11..14. + full, selected = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", line_start=12, line_end=12, side="right" + ) + assert "@@ -10,3 +11,3 @@" in full + assert "@@ -1,3 +1,4 @@" not in full + assert selected == "-removed\n+added" + + def test_out_of_range_returns_only_header_and_empty_selection(self): + full, selected = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", line_start=999, line_end=1000, side="right" + ) + # Neither hunk matched, so no hunk headers are emitted. + assert "@@" not in full + assert "## File: 'src/sample.py'" in full + assert selected == "" + + def test_malformed_patch_returns_empty_tuple(self): + # An '@@' line that does not match RE_HUNK_HEADER causes the + # implementation to raise inside extract_hunk_headers; the function + # catches it and returns ("", ""). 
+ bad_patch = "@@ not a real header @@\n+something\n" + full, selected = extract_hunk_lines_from_patch( + bad_patch, "src/sample.py", line_start=1, line_end=1, side="right" + ) + assert full == "" + assert selected == "" + + def test_remove_trailing_chars_false_preserves_trailing_newlines(self): + full_stripped, sel_stripped = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", 2, 2, "right", remove_trailing_chars=True + ) + full_raw, sel_raw = extract_hunk_lines_from_patch( + MULTI_HUNK_PATCH, "src/sample.py", 2, 2, "right", remove_trailing_chars=False + ) + assert full_raw.endswith("\n") + assert sel_raw.endswith("\n") + # Trimmed variants equal the raw output with trailing whitespace removed. + assert full_stripped == full_raw.rstrip() + assert sel_stripped == sel_raw.rstrip() + + +# --------------------------------------------------------------------------- +# generate_full_patch +# --------------------------------------------------------------------------- + + +class TestGenerateFullPatch: + def test_files_within_budget_are_all_included(self, monkeypatch): + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: 10_000) + token_handler = FakeTokenHandler(prompt_tokens=10) + file_dict = { + "a.py": {"patch": "+ change a", "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + "b.py": {"patch": "+ change b", "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + } + total, patches, remaining, files_in = pr_processing.generate_full_patch( + convert_hunks_to_line_numbers=False, + file_dict=file_dict, + max_tokens_model=10_000, + remaining_files_list_prev=list(file_dict), + token_handler=token_handler, + ) + assert files_in == ["a.py", "b.py"] + assert remaining == [] + assert len(patches) == 2 + # File header format for non-line-numbered patches: + assert any("## File: 'a.py'" in p for p in patches) + assert any("## File: 'b.py'" in p for p in patches) + assert total > token_handler.prompt_tokens + + def
test_oversized_patch_is_deferred_to_remaining_list(self): + token_handler = FakeTokenHandler(prompt_tokens=10) + big_tokens = 5000 # exceeds (max_tokens - SOFT=1500) when added on top of prompt + file_dict = { + "small.py": {"patch": "+ small", "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + "huge.py": {"patch": "+ huge", "tokens": big_tokens, "edit_type": EDIT_TYPE.MODIFIED}, + } + total, patches, remaining, files_in = pr_processing.generate_full_patch( + convert_hunks_to_line_numbers=False, + file_dict=file_dict, + max_tokens_model=4_000, # SOFT=1500, HARD=1000 + remaining_files_list_prev=list(file_dict), + token_handler=token_handler, + ) + assert "small.py" in files_in + assert "huge.py" not in files_in + assert remaining == ["huge.py"] + + def test_remaining_files_list_prev_filters_input(self): + token_handler = FakeTokenHandler(prompt_tokens=10) + file_dict = { + "a.py": {"patch": "+ change a", "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + "b.py": {"patch": "+ change b", "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + } + total, patches, remaining, files_in = pr_processing.generate_full_patch( + convert_hunks_to_line_numbers=False, + file_dict=file_dict, + max_tokens_model=10_000, + remaining_files_list_prev=["b.py"], # only b.py is eligible this round + token_handler=token_handler, + ) + assert files_in == ["b.py"] + assert remaining == [] + + def test_line_numbered_mode_omits_extra_file_header(self): + token_handler = FakeTokenHandler(prompt_tokens=10) + prebuilt = "## File: 'a.py'\n@@ -1 +1 @@\n+x" + file_dict = { + "a.py": {"patch": prebuilt, "tokens": 5, "edit_type": EDIT_TYPE.MODIFIED}, + } + _, patches, _, _ = pr_processing.generate_full_patch( + convert_hunks_to_line_numbers=True, + file_dict=file_dict, + max_tokens_model=10_000, + remaining_files_list_prev=["a.py"], + token_handler=token_handler, + ) + # In line-numbered mode, the function does not wrap with another header. 
+ assert patches[0].count("## File: 'a.py'") == 1 + + +# --------------------------------------------------------------------------- +# pr_generate_compressed_diff +# --------------------------------------------------------------------------- + + +class TestPrGenerateCompressedDiff: + def _settings(self): + from pr_agent.config_loader import get_settings + return get_settings() + + def test_deleted_files_collected_and_excluded_from_patches(self, monkeypatch): + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: 10_000) + + deleted = FilePatchInfo( + base_file="old content", + head_file="", + patch="@@ -1,2 +0,0 @@\n-old1\n-old2\n", + filename="gone.py", + tokens=5, + edit_type=EDIT_TYPE.DELETED, + ) + kept = _make_file(filename="kept.py", tokens=5) + top_langs = [{"files": [deleted, kept]}] + + (patches_list, total_tokens_list, deleted_files_list, + remaining_files_list, file_dict, files_in_patches_list) = \ + pr_processing.pr_generate_compressed_diff( + top_langs=top_langs, + token_handler=FakeTokenHandler(prompt_tokens=10), + model="some-model", + convert_hunks_to_line_numbers=False, + large_pr_handling=False, + ) + + assert "gone.py" in deleted_files_list + assert "gone.py" not in file_dict + assert "kept.py" in file_dict + # First (and only) iteration carries kept.py and no remaining files. + assert files_in_patches_list[0] == ["kept.py"] + assert remaining_files_list == [] + assert len(patches_list) == 1 + assert len(total_tokens_list) == 1 + + def test_large_pr_handling_paginates_across_iterations(self, monkeypatch): + # Build patches large enough that exactly one fits per iteration. The + # per-iteration budget in generate_full_patch is + # max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - prompt_tokens + # so we derive max_tokens from the actual token count of a patch (since + # pr_generate_compressed_diff recomputes tokens from patch content via + # token_handler.count_tokens, ignoring FilePatchInfo.tokens). 
+ prompt_tokens = 100 + token_handler = FakeTokenHandler(prompt_tokens=prompt_tokens) + patch_str = "@@ -1,1 +1,1 @@\n+" + " ".join(["tok"] * 100) + "\n" + patch_tokens = token_handler.count_tokens(patch_str) + soft_threshold = pr_processing.OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD + # Budget allows exactly one patch per iteration (prompt + patch fits, + # but prompt + 2*patch does not): + # prompt + patch_tokens <= max - SOFT + # prompt + 2 * patch_tokens > max - SOFT + max_tokens = soft_threshold + prompt_tokens + patch_tokens + 1 + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: max_tokens) + + settings = self._settings() + original_max_ai_calls = settings.pr_description.max_ai_calls + # NUMBER_OF_ALLOWED_ITERATIONS = max_ai_calls - 1; loop runs range(that - 1). + # We want 3 total iterations (1 mandatory + 2 in the loop) -> max_ai_calls = 4. + settings.pr_description.max_ai_calls = 4 + + try: + files = [ + _make_file(filename=f"f{i}.py", tokens=5, patch=patch_str) + for i in range(3) + ] + + top_langs = [{"files": files}] + (patches_list, total_tokens_list, deleted_files_list, + remaining_files_list, file_dict, files_in_patches_list) = \ + pr_processing.pr_generate_compressed_diff( + top_langs=top_langs, + token_handler=token_handler, + model="some-model", + convert_hunks_to_line_numbers=False, + large_pr_handling=True, + ) + + # Pagination actually fired: 3 batches, one file each, nothing left over. 
+ assert len(patches_list) == 3 + assert len(patches_list) == len(total_tokens_list) == len(files_in_patches_list) + assert files_in_patches_list == [["f0.py"], ["f1.py"], ["f2.py"]] + assert remaining_files_list == [] + assert deleted_files_list == [] + finally: + settings.pr_description.max_ai_calls = original_max_ai_calls + + def test_files_with_empty_patch_are_skipped(self, monkeypatch): + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: 10_000) + + empty = _make_file(filename="empty.py", patch="", tokens=0) + kept = _make_file(filename="kept.py", tokens=5) + top_langs = [{"files": [empty, kept]}] + + (patches_list, _, deleted_files_list, remaining_files_list, + file_dict, files_in_patches_list) = \ + pr_processing.pr_generate_compressed_diff( + top_langs=top_langs, + token_handler=FakeTokenHandler(prompt_tokens=10), + model="some-model", + convert_hunks_to_line_numbers=False, + large_pr_handling=False, + ) + + assert "empty.py" not in file_dict + assert "empty.py" not in deleted_files_list + assert "kept.py" in file_dict + assert files_in_patches_list[0] == ["kept.py"] + + def test_convert_hunks_to_line_numbers_runs_decouple_per_file(self, monkeypatch): + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: 10_000) + kept = _make_file(filename="kept.py", tokens=5) + top_langs = [{"files": [kept]}] + + (patches_list, _, _, _, file_dict, files_in_patches_list) = \ + pr_processing.pr_generate_compressed_diff( + top_langs=top_langs, + token_handler=FakeTokenHandler(prompt_tokens=10), + model="some-model", + convert_hunks_to_line_numbers=True, + large_pr_handling=False, + ) + # Decoupled output marker should be present in the stored patch. 
+ assert "__new hunk__" in file_dict["kept.py"]["patch"] + assert files_in_patches_list[0] == ["kept.py"] + + def test_max_ai_calls_boundary_caps_iterations(self, monkeypatch): + # Force every patch to be too big to fit so each iteration defers + # everything to the next round; this isolates the iteration cap. + monkeypatch.setattr(pr_processing, "get_max_tokens", lambda model: 1_000) + settings = self._settings() + original_max_ai_calls = settings.pr_description.max_ai_calls + settings.pr_description.max_ai_calls = 2 # NUMBER_OF_ALLOWED_ITERATIONS=1, so the extra loop is range(0): no extra iterations + + try: + files = [_make_file(filename=f"f{i}.py", tokens=5, + patch=f"@@ -1 +1 @@\n+x_{i}\n") + for i in range(3)] + top_langs = [{"files": files}] + + (patches_list, _, _, remaining_files_list, _, files_in_patches_list) = \ + pr_processing.pr_generate_compressed_diff( + top_langs=top_langs, + token_handler=FakeTokenHandler(prompt_tokens=10_000), + model="some-model", + convert_hunks_to_line_numbers=False, + large_pr_handling=True, + ) + + # The first (mandatory) iteration always appends one batch, even if empty. + assert len(patches_list) >= 1 + # With max_ai_calls=2, NUMBER_OF_ALLOWED_ITERATIONS=1 and the loop body + # executes range(0) -> zero times. So only the first iteration runs. + assert len(patches_list) == 1 + assert len(files_in_patches_list) == 1 + finally: + settings.pr_description.max_ai_calls = original_max_ai_calls diff --git a/tests/unittest/test_github_app_timeout_core.py b/tests/unittest/test_github_app_timeout_core.py new file mode 100644 index 0000000000..aedc965404 --- /dev/null +++ b/tests/unittest/test_github_app_timeout_core.py @@ -0,0 +1,512 @@ +"""Unit tests for ``pr_agent.servers.utils.DefaultDictWithTimeout`` and a few +helper functions in ``pr_agent.servers.github_app``. + +These tests intentionally avoid network, external credentials, and real sleeps. +Time-dependent behavior is exercised by monkeypatching ``time.monotonic`` on the +``pr_agent.servers.utils`` module.
+""" + +import asyncio +from types import SimpleNamespace + +import pytest + +from pr_agent.servers import github_app +from pr_agent.servers import utils as servers_utils +from pr_agent.servers.utils import DefaultDictWithTimeout + +# --------------------------------------------------------------------------- +# Shared test helpers +# --------------------------------------------------------------------------- + + +def _snapshot_ask_diff_hunk(settings): + """Snapshot the ``ask_diff_hunk`` setting. + + Returns ``(sentinel, original)`` where ``original is sentinel`` indicates + the key was absent prior to the snapshot. The sentinel is a fresh object + so callers can distinguish "absent" from "present-as-None". + """ + sentinel = object() + original = settings.get("ask_diff_hunk", sentinel) + return sentinel, original + + +def _restore_ask_diff_hunk(settings, original, sentinel): + """Restore ``ask_diff_hunk`` to the state captured by ``_snapshot_ask_diff_hunk``. + + When the baseline was absent (``original is sentinel``), the key is truly + removed via ``unset(force=True)`` rather than being set to ``None`` — + Dynaconf's ``LazySettings`` does not support ``del settings[key]``. + """ + if original is sentinel: + settings.unset("ask_diff_hunk", force=True) + else: + settings.set("ask_diff_hunk", original) + + +# --------------------------------------------------------------------------- +# DefaultDictWithTimeout +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fake_clock(monkeypatch): + """Replace ``time.monotonic`` in pr_agent.servers.utils with a controllable clock.""" + state = {"t": 1_000_000.0} + monkeypatch.setattr(servers_utils.time, "monotonic", lambda: state["t"]) + return state + + +def _key_times(d): + # Access the name-mangled private attribute used for testing internals. 
+ return d._DefaultDictWithTimeout__key_times + + +class TestDefaultDictWithTimeout: + def test_update_key_time_on_get_true_refreshes_access_time(self, fake_clock): + # Use a large refresh_interval to keep __refresh from interfering. + d = DefaultDictWithTimeout( + lambda: 0, ttl=1000, refresh_interval=1000, update_key_time_on_get=True + ) + d["a"] = 1 + t_set = fake_clock["t"] + assert _key_times(d)["a"] == t_set + + fake_clock["t"] += 7.5 + _ = d["a"] + assert _key_times(d)["a"] == t_set + 7.5 + + def test_update_key_time_on_get_false_keeps_original_time(self, fake_clock): + d = DefaultDictWithTimeout( + lambda: 0, ttl=1000, refresh_interval=1000, update_key_time_on_get=False + ) + d["a"] = 1 + original = _key_times(d)["a"] + + fake_clock["t"] += 7.5 + _ = d["a"] + assert _key_times(d)["a"] == original + + def test_setitem_always_updates_key_time(self, fake_clock): + d = DefaultDictWithTimeout( + lambda: 0, ttl=1000, refresh_interval=1000, update_key_time_on_get=False + ) + d["a"] = 1 + first = _key_times(d)["a"] + fake_clock["t"] += 3 + d["a"] = 2 + assert _key_times(d)["a"] == first + 3 + + def test_expires_keys_older_than_ttl_when_refresh_runs(self, fake_clock): + # refresh_interval is wide enough that the cleanup branch runs on access. + d = DefaultDictWithTimeout( + lambda: 0, ttl=2, refresh_interval=1000, update_key_time_on_get=False + ) + d["a"] = 1 + d["b"] = 2 + + # Warm-up access: __last_refresh is seeded in __init__ to (now - + # refresh_interval), so the very first __getitem__ in the same tick + # is required to bring it forward. Without this, a later access can + # exceed refresh_interval and trip the early-return branch. + _ = d["warm"] + + # Advance past TTL but not past refresh_interval. + fake_clock["t"] += 5 + + # Touching a different (new) key triggers __refresh which should + # purge stale entries. defaultdict.__missing__ will route through our + # __setitem__ for the brand-new key, so it gets a fresh timestamp. 
+ _ = d["fresh"] + + assert "a" not in d + assert "b" not in d + assert "fresh" in d + assert "a" not in _key_times(d) + assert "b" not in _key_times(d) + + def test_no_ttl_means_no_expiration(self, fake_clock): + d = DefaultDictWithTimeout(lambda: 0, ttl=None, refresh_interval=1) + d["a"] = 1 + fake_clock["t"] += 10_000 + _ = d["a"] + assert d["a"] == 1 + assert "a" in _key_times(d) + + def test_delitem_removes_key_time(self, fake_clock): + d = DefaultDictWithTimeout(lambda: 0, ttl=10, refresh_interval=1000) + d["a"] = 1 + del d["a"] + assert "a" not in d + assert "a" not in _key_times(d) + + @pytest.mark.xfail( + strict=True, + reason=( + "Documents a behavior gap in DefaultDictWithTimeout.__refresh: when " + "the elapsed time since the last refresh exceeds refresh_interval, " + "the method returns early *before* expiring stale keys, so long " + "idle periods skip cleanup entirely. Kept as strict xfail per " + "instructions to avoid changing production logic in test scope." + ), + ) + def test_refresh_runs_after_long_idle_period(self, fake_clock): + d = DefaultDictWithTimeout( + lambda: 0, ttl=2, refresh_interval=5, update_key_time_on_get=False + ) + d["a"] = 1 + # Idle long enough that delta > refresh_interval; current implementation + # skips cleanup in that case. We assert the *intended* behavior: + # accessing/inserting any key should still trigger expiration of + # stale entries. 
+ fake_clock["t"] += 100 + _ = d["fresh"] + assert "a" not in d + + +# --------------------------------------------------------------------------- +# handle_line_comments +# --------------------------------------------------------------------------- + + +class TestHandleLineComments: + @pytest.fixture(autouse=True) + def restore_ask_diff_hunk_after_each_test(self): + from pr_agent.config_loader import get_settings + + settings = get_settings() + sentinel, original = _snapshot_ask_diff_hunk(settings) + try: + yield + finally: + _restore_ask_diff_hunk(settings, original, sentinel) + + def _payload(self, **overrides): + comment = { + "start_line": 10, + "line": 14, + "diff_hunk": "@@ -1,3 +1,4 @@\n+new line", + "path": "src/file.py", + "side": "RIGHT", + "id": 987654, + } + comment.update(overrides) + return {"comment": comment} + + def test_returns_empty_string_for_empty_body(self): + assert github_app.handle_line_comments({}, "") == "" + + def test_converts_ask_to_ask_line_with_metadata(self): + body = self._payload() + result = github_app.handle_line_comments(body, "/ask Why this change?") + assert result.startswith("/ask_line ") + assert "--line_start=10" in result + assert "--line_end=14" in result + assert "--side=RIGHT" in result + assert "--file_name=src/file.py" in result + assert "--comment_id=987654" in result + assert result.endswith("Why this change?") + + def test_missing_start_line_falls_back_to_line(self): + body = self._payload(start_line=None) + result = github_app.handle_line_comments(body, "/ask anything") + assert "--line_start=14" in result + assert "--line_end=14" in result + + def test_sets_ask_diff_hunk_in_settings(self): + from pr_agent.config_loader import get_settings + + settings = get_settings() + body = self._payload(diff_hunk="DIFF_HUNK_SENTINEL") + github_app.handle_line_comments(body, "/ask hi") + assert settings.get("ask_diff_hunk") == "DIFF_HUNK_SENTINEL" + + def test_restore_ask_diff_hunk_missing_baseline_truly_absent(self): + 
"""The cleanup helper must leave ``ask_diff_hunk`` truly absent (not + present-as-None) when the key did not exist before the test.""" + from pr_agent.config_loader import get_settings + + settings = get_settings() + # Snapshot the pre-existing outer state so we don't leak our forced + # absence to other tests. + outer_sentinel, outer_original = _snapshot_ask_diff_hunk(settings) + try: + # Force a known-absent baseline regardless of leaks from elsewhere. + settings.unset("ask_diff_hunk", force=True) + + sentinel, original = _snapshot_ask_diff_hunk(settings) + assert original is sentinel + + # Simulate the test body mutating the setting. + settings.set("ask_diff_hunk", "HUNK") + assert settings.get("ask_diff_hunk") == "HUNK" + + _restore_ask_diff_hunk(settings, original, sentinel) + + # Key must be absent, not merely None. + probe = object() + assert settings.get("ask_diff_hunk", probe) is probe + assert "ask_diff_hunk" not in settings + finally: + _restore_ask_diff_hunk(settings, outer_original, outer_sentinel) + + def test_restore_ask_diff_hunk_existing_value_is_restored(self): + """If a non-None baseline value existed, the helper restores it.""" + from pr_agent.config_loader import get_settings + + settings = get_settings() + outer_sentinel, outer_original = _snapshot_ask_diff_hunk(settings) + try: + settings.set("ask_diff_hunk", "BASELINE") + + sentinel, original = _snapshot_ask_diff_hunk(settings) + assert original == "BASELINE" + + settings.set("ask_diff_hunk", "MUTATED") + _restore_ask_diff_hunk(settings, original, sentinel) + + assert settings.get("ask_diff_hunk") == "BASELINE" + finally: + _restore_ask_diff_hunk(settings, outer_original, outer_sentinel) + + def test_restore_ask_diff_hunk_existing_none_baseline_is_preserved(self): + """If the baseline value was explicitly ``None``, the helper restores + ``None`` rather than removing the key.""" + from pr_agent.config_loader import get_settings + + settings = get_settings() + outer_sentinel, outer_original = 
_snapshot_ask_diff_hunk(settings) + try: + settings.set("ask_diff_hunk", None) + # Sanity: key is present with value None. + assert "ask_diff_hunk" in settings + assert settings.get("ask_diff_hunk") is None + + sentinel, original = _snapshot_ask_diff_hunk(settings) + assert original is None + assert original is not sentinel + + settings.set("ask_diff_hunk", "MUTATED") + _restore_ask_diff_hunk(settings, original, sentinel) + + assert settings.get("ask_diff_hunk") is None + finally: + _restore_ask_diff_hunk(settings, outer_original, outer_sentinel) + + def test_non_ask_comment_returned_unchanged(self): + body = self._payload() + result = github_app.handle_line_comments(body, "just a comment") + assert result == "just a comment" + + +# --------------------------------------------------------------------------- +# _check_pull_request_event +# --------------------------------------------------------------------------- + + +class TestCheckPullRequestEvent: + def _pr(self, **overrides): + pr = { + "url": "https://api.github.com/repos/o/r/pulls/1", + "state": "open", + "draft": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + } + pr.update(overrides) + return pr + + def test_accepts_valid_open_non_draft_pr(self): + body = {"pull_request": self._pr()} + log_context = {} + pr, api_url = github_app._check_pull_request_event("opened", body, log_context) + assert pr is body["pull_request"] + assert api_url == "https://api.github.com/repos/o/r/pulls/1" + assert log_context["api_url"] == api_url + + def test_rejects_missing_pull_request(self): + assert github_app._check_pull_request_event("opened", {}, {}) == ({}, "") + + def test_rejects_missing_url(self): + body = {"pull_request": self._pr(url=None)} + assert github_app._check_pull_request_event("opened", body, {}) == ({}, "") + + def test_rejects_closed_pr(self): + body = {"pull_request": self._pr(state="closed")} + assert github_app._check_pull_request_event("opened", body, {}) == ({}, 
"") + + def test_rejects_draft_pr(self): + body = {"pull_request": self._pr(draft=True)} + assert github_app._check_pull_request_event("opened", body, {}) == ({}, "") + + def test_rejects_when_draft_field_missing(self): + # pull_request.get("draft", True) defaults to True, so a missing draft + # field is treated as draft and rejected. + pr = self._pr() + pr.pop("draft") + body = {"pull_request": pr} + assert github_app._check_pull_request_event("opened", body, {}) == ({}, "") + + def test_rejects_synchronize_when_created_equals_updated(self): + body = { + "pull_request": self._pr( + created_at="2024-01-01T00:00:00Z", updated_at="2024-01-01T00:00:00Z" + ) + } + assert ( + github_app._check_pull_request_event("synchronize", body, {}) == ({}, "") + ) + + def test_accepts_synchronize_when_timestamps_differ(self): + body = {"pull_request": self._pr()} + pr, api_url = github_app._check_pull_request_event("synchronize", body, {}) + assert api_url.endswith("/pulls/1") + assert pr["state"] == "open" + + +# --------------------------------------------------------------------------- +# handle_push_trigger_for_new_commits dedupe behavior +# --------------------------------------------------------------------------- + + +def _run(coro): + return asyncio.run(coro) + + +@pytest.fixture +def push_trigger_env(monkeypatch): + """Set up minimal mocks so handle_push_trigger_for_new_commits can run.""" + # Swap module-level dedupe state with fresh test-local instances so we + # don't leak entries into other tests and don't depend on prior state. 
+ fresh_duplicate_push_triggers = DefaultDictWithTimeout(ttl=None) + fresh_pending_conditions = DefaultDictWithTimeout( + asyncio.locks.Condition, ttl=None + ) + monkeypatch.setattr( + github_app, "_duplicate_push_triggers", fresh_duplicate_push_triggers + ) + monkeypatch.setattr( + github_app, + "_pending_task_duplicate_push_conditions", + fresh_pending_conditions, + ) + + settings = SimpleNamespace( + github_app=SimpleNamespace( + handle_push_trigger=True, + push_trigger_ignore_merge_commits=False, + push_trigger_pending_tasks_backlog=False, + ) + ) + monkeypatch.setattr(github_app, "get_settings", lambda: settings) + monkeypatch.setattr(github_app, "apply_repo_settings", lambda api_url: None) + + eligible_provider = SimpleNamespace( + verify_eligibility=lambda *a, **kw: github_app.Eligibility.ELIGIBLE + ) + monkeypatch.setattr(github_app, "get_identity_provider", lambda: eligible_provider) + + calls = {"count": 0} + + async def fake_perform(*args, **kwargs): + calls["count"] += 1 + + monkeypatch.setattr(github_app, "_perform_auto_commands_github", fake_perform) + yield calls + + +def _push_body(): + return { + "pull_request": { + "url": "https://api.github.com/repos/o/r/pulls/42", + "state": "open", + "draft": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "merge_commit_sha": "merge-sha", + }, + "before": "sha-before", + "after": "sha-after", + } + + +class TestPushTriggerDedupe: + def test_first_event_runs_perform_and_decrements_counter(self, push_trigger_env): + body = _push_body() + api_url = body["pull_request"]["url"] + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + body, "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + assert push_trigger_env["count"] == 1 + # Counter incremented then decremented back to 0. 
+ assert github_app._duplicate_push_triggers[api_url] == 0 + + def test_skips_when_before_equals_after(self, push_trigger_env): + body = _push_body() + body["before"] = body["after"] + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + body, "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + assert push_trigger_env["count"] == 0 + + def test_skips_merge_commit_when_configured(self, push_trigger_env, monkeypatch): + body = _push_body() + body["after"] = body["pull_request"]["merge_commit_sha"] + settings = github_app.get_settings() + settings.github_app.push_trigger_ignore_merge_commits = True + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + body, "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + assert push_trigger_env["count"] == 0 + + def test_skips_when_handle_push_trigger_disabled(self, push_trigger_env): + github_app.get_settings().github_app.handle_push_trigger = False + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + _push_body(), "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + assert push_trigger_env["count"] == 0 + + def test_discards_when_max_active_tasks_reached(self, push_trigger_env): + body = _push_body() + api_url = body["pull_request"]["url"] + # Simulate an already-running task with backlog disabled (max=1). + github_app._duplicate_push_triggers[api_url] = 1 + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + body, "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + # Third path: counter is left untouched, perform never runs. 
+ assert push_trigger_env["count"] == 0 + assert github_app._duplicate_push_triggers[api_url] == 1 + + def test_invalid_pr_event_short_circuits(self, push_trigger_env): + body = _push_body() + body["pull_request"]["state"] = "closed" + + asyncio.run( + github_app.handle_push_trigger_for_new_commits( + body, "push", "alice", "1", "synchronize", {}, agent=None + ) + ) + + assert push_trigger_env["count"] == 0 diff --git a/tests/unittest/test_github_provider_comments.py b/tests/unittest/test_github_provider_comments.py new file mode 100644 index 0000000000..5a3ed571de --- /dev/null +++ b/tests/unittest/test_github_provider_comments.py @@ -0,0 +1,376 @@ +""" +Tests for GitHub provider inline comment creation, publishing fallback, +and multi-line code suggestion payload shape. + +These tests use ``GithubProvider.__new__(GithubProvider)`` to bypass network-bound +``__init__`` and inject minimal fake collaborators. No real GitHub API access. +""" + +from types import SimpleNamespace + +import pytest + +from pr_agent.git_providers import github_provider as gh_module +from pr_agent.git_providers.github_provider import GithubProvider + + +class _FakeGithubException(Exception): + """Mimics github.GithubException enough for the provider's ``e.status`` check.""" + + def __init__(self, status, data=None): + super().__init__(f"GithubException status={status}") + self.status = status + self.data = data or {} + + +class _FakePR: + """Captures create_review calls; can be configured to raise on the first call.""" + + def __init__(self, raise_on_first=None): + self.create_review_calls = [] + self._raise_on_first = raise_on_first + self._calls = 0 + + def create_review(self, commit=None, comments=None): + self._calls += 1 + self.create_review_calls.append({"commit": commit, "comments": comments}) + if self._raise_on_first is not None and self._calls == 1: + exc = self._raise_on_first + self._raise_on_first = None + raise exc + return SimpleNamespace(id=1) + + +def 
_make_provider(pr=None, max_chars=65000): + p = GithubProvider.__new__(GithubProvider) + p.pr = pr if pr is not None else _FakePR() + p.repo = "owner/repo" + p.pr_num = 1 + p.max_comment_chars = max_chars + p.last_commit_id = SimpleNamespace(sha="deadbeef") + p.diff_files = [] + p.base_url = "https://api.github.com" + return p + + +# --------------------------------------------------------------------------- +# create_inline_comment +# --------------------------------------------------------------------------- + +def test_create_inline_comment_returns_line_payload(monkeypatch): + """When a position is resolved, payload must include body/path/position.""" + provider = _make_provider() + + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda diff_files, rel_file, rel_line, abs_pos: (5, 42), + ) + + payload = provider.create_inline_comment("LGTM", "src/foo.py", "x = 1") + + assert payload == {"body": "LGTM", "path": "src/foo.py", "position": 5} + + +def test_create_inline_comment_returns_empty_when_position_unresolved(monkeypatch): + """If no position can be resolved (position == -1) current behavior returns {}.""" + provider = _make_provider() + + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda *a, **kw: (-1, -1), + ) + + payload = provider.create_inline_comment("body", "src/foo.py", "x = 1") + assert payload == {} + + +def test_create_inline_comment_lookup_strips_backticks_but_payload_preserves_them(monkeypatch): + """Backtick handling is asymmetric in current production code. + + ``find_line_number_of_relevant_line_in_file`` is called with + ``relevant_file.strip('`')`` (so the *lookup* sees the un-backticked + path), but the payload ``path`` only has ``.strip()`` applied — so any + surrounding backticks survive into the resulting comment payload. This + test documents that asymmetry; it does not endorse it. 
+ """ + provider = _make_provider() + recorded = {} + + def recording_resolver(diff_files, rel_file, rel_line, abs_pos): + recorded["rel_file"] = rel_file + return (3, 9) + + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + recording_resolver, + ) + + payload = provider.create_inline_comment("b", "`src/foo.py`", "x = 1") + + # Lookup arg has backticks stripped. + assert recorded["rel_file"] == "src/foo.py" + # Payload path preserves backticks (only .strip() runs on it). + assert payload["path"] == "`src/foo.py`" + + +def test_create_inline_comment_payload_strips_surrounding_whitespace(monkeypatch): + """Whitespace-only test: payload path is .strip()'d before being returned.""" + provider = _make_provider() + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda *a, **kw: (3, 9), + ) + + payload = provider.create_inline_comment("b", " src/foo.py ", "x = 1") + assert payload["path"] == "src/foo.py" + + +def test_create_inline_comment_limits_body_length(monkeypatch): + """Body longer than max_comment_chars must be truncated with trailing '...'.""" + provider = _make_provider(max_chars=10) + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda *a, **kw: (1, 1), + ) + + long_body = "A" * 50 + payload = provider.create_inline_comment(long_body, "f.py", "line") + + assert payload["body"].endswith("...") + # limit_output_characters: output[:max_chars] + '...' + assert payload["body"] == "A" * 10 + "..." 
+ + +def test_create_inline_comment_does_not_truncate_short_body(monkeypatch): + provider = _make_provider(max_chars=100) + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda *a, **kw: (1, 1), + ) + + payload = provider.create_inline_comment("short", "f.py", "line") + assert payload["body"] == "short" + + +# --------------------------------------------------------------------------- +# publish_inline_comment(s) +# --------------------------------------------------------------------------- + +def test_publish_inline_comment_delegates_to_create_review(monkeypatch): + """Single-comment publish path should result in a create_review call.""" + fake_pr = _FakePR() + provider = _make_provider(pr=fake_pr) + monkeypatch.setattr( + gh_module, + "find_line_number_of_relevant_line_in_file", + lambda *a, **kw: (2, 7), + ) + + provider.publish_inline_comment("hi", "src/foo.py", "x = 1") + + assert len(fake_pr.create_review_calls) == 1 + call = fake_pr.create_review_calls[0] + assert call["commit"].sha == "deadbeef" + assert call["comments"] == [{"body": "hi", "path": "src/foo.py", "position": 2}] + + +def test_publish_inline_comments_non_422_reraises(): + """Non-422 exceptions during create_review must propagate (no fallback).""" + fake_pr = _FakePR(raise_on_first=_FakeGithubException(status=500)) + provider = _make_provider(pr=fake_pr) + + with pytest.raises(_FakeGithubException) as excinfo: + provider.publish_inline_comments( + [{"body": "b", "path": "f.py", "position": 1}] + ) + assert excinfo.value.status == 500 + # Only the original failing call was attempted - no fallback create_review. 
+ assert len(fake_pr.create_review_calls) == 1 + + +def test_publish_inline_comments_disable_fallback_reraises_422(): + """When disable_fallback=True even a 422 must not trigger the fallback path.""" + fake_pr = _FakePR(raise_on_first=_FakeGithubException(status=422)) + provider = _make_provider(pr=fake_pr) + + with pytest.raises(_FakeGithubException): + provider.publish_inline_comments( + [{"body": "b", "path": "f.py", "position": 1}], + disable_fallback=True, + ) + assert len(fake_pr.create_review_calls) == 1 + + +def test_publish_inline_comments_422_triggers_fallback(monkeypatch): + """On 422 the provider should invoke the verification-based fallback.""" + fake_pr = _FakePR(raise_on_first=_FakeGithubException(status=422)) + provider = _make_provider(pr=fake_pr) + + called = {"n": 0, "args": None} + + def fake_fallback(comments): + called["n"] += 1 + called["args"] = comments + + provider._publish_inline_comments_fallback_with_verification = fake_fallback + + comments = [{"body": "b", "path": "f.py", "position": 1}] + provider.publish_inline_comments(comments) + + assert called["n"] == 1 + assert called["args"] == comments + # The initial create_review attempt is the only one made directly here; + # the fallback is stubbed out and would normally do further work. 
+ assert len(fake_pr.create_review_calls) == 1 + + +def test_publish_inline_comments_fallback_failure_propagates(monkeypatch): + fake_pr = _FakePR(raise_on_first=_FakeGithubException(status=422)) + provider = _make_provider(pr=fake_pr) + + def broken_fallback(comments): + raise RuntimeError("fallback boom") + + provider._publish_inline_comments_fallback_with_verification = broken_fallback + + with pytest.raises(RuntimeError, match="fallback boom"): + provider.publish_inline_comments( + [{"body": "b", "path": "f.py", "position": 1}] + ) + + +def test_publish_inline_comments_success_no_fallback(): + """On a clean create_review call no fallback should be invoked.""" + fake_pr = _FakePR() + provider = _make_provider(pr=fake_pr) + + sentinel = {"called": False} + + def should_not_run(_): + sentinel["called"] = True + + provider._publish_inline_comments_fallback_with_verification = should_not_run + + provider.publish_inline_comments([{"body": "b", "path": "f.py", "position": 1}]) + + assert sentinel["called"] is False + assert len(fake_pr.create_review_calls) == 1 + + +# --------------------------------------------------------------------------- +# publish_code_suggestions - multi-line vs single-line payload shape +# --------------------------------------------------------------------------- + +def _stub_validation_passthrough(provider): + """Bypass hunk-validation so we can directly assert the constructed payload.""" + provider.validate_comments_inside_hunks = lambda suggestions: suggestions + + +def test_publish_code_suggestions_multi_line_payload_shape(): + """Multi-line suggestions (end > start) must use start_line/start_side fields.""" + fake_pr = _FakePR() + provider = _make_provider(pr=fake_pr) + _stub_validation_passthrough(provider) + + captured = {} + + def capture(comments, disable_fallback=False): + captured["comments"] = comments + + provider.publish_inline_comments = capture + + suggestions = [{ + "body": "```suggestion\nnew\n```", + "relevant_file": 
"src/foo.py", + "relevant_lines_start": 10, + "relevant_lines_end": 14, + }] + + assert provider.publish_code_suggestions(suggestions) is True + + assert "comments" in captured + payload = captured["comments"][0] + assert payload == { + "body": "```suggestion\nnew\n```", + "path": "src/foo.py", + "line": 14, + "start_line": 10, + "start_side": "RIGHT", + } + # Multi-line payloads must NOT carry a top-level 'side'; GitHub infers it. + assert "side" not in payload + + +def test_publish_code_suggestions_single_line_payload_shape(): + """When start == end the API shape differs: no start_line/start_side, side only.""" + fake_pr = _FakePR() + provider = _make_provider(pr=fake_pr) + _stub_validation_passthrough(provider) + + captured = {} + provider.publish_inline_comments = lambda comments, disable_fallback=False: captured.setdefault("c", comments) + + suggestions = [{ + "body": "fix", + "relevant_file": "src/foo.py", + "relevant_lines_start": 7, + "relevant_lines_end": 7, + }] + + assert provider.publish_code_suggestions(suggestions) is True + payload = captured["c"][0] + assert payload == { + "body": "fix", + "path": "src/foo.py", + "line": 7, + "side": "RIGHT", + } + assert "start_line" not in payload and "start_side" not in payload + + +def test_publish_code_suggestions_skips_invalid_ranges(): + """Suggestions with missing/negative start, or end empty string, + # logged as an error but does not raise. 
+ assert p._get_owner_and_repo_path("https://github.com/owner/repo") == "" + + def test_get_git_repo_url_uses_html_base(self): + p = _bare_provider() + p.base_url_html = "https://github.com" + assert ( + p.get_git_repo_url("https://github.com/owner/repo/pull/1") + == "https://github.com/owner/repo.git" + ) + + def test_get_git_repo_url_uses_ghes_html_base(self): + p = _bare_provider() + p.base_url_html = "https://ghes.example.com" + assert ( + p.get_git_repo_url("https://ghes.example.com/owner/repo/pull/1") + == "https://ghes.example.com/owner/repo.git" + ) + + def test_get_git_repo_url_mismatch_returns_empty(self): + """If derived owner/repo doesn't appear in the input URL, return ''.""" + p = _bare_provider() + p.base_url_html = "https://github.com" + # _get_owner_and_repo_path returns "" for this input, so the guard + # `repo_path not in issues_or_pr_url` triggers the empty-string return. + assert p.get_git_repo_url("https://github.com/owner/repo") == "" + + +# --------------------------------------------------------------------------- +# get_diff_files edit_type mapping +# --------------------------------------------------------------------------- +def _make_file( + filename: str, + status: str, + patch: str = "@@ -0,0 +1 @@\n+new", + additions: int = 1, + deletions: int = 0, +): + return SimpleNamespace( + filename=filename, + status=status, + patch=patch, + additions=additions, + deletions=deletions, + ) + + +def _make_provider_for_diff(files): + p = _bare_provider() + p.diff_files = None + p.git_files = None + p.incremental = SimpleNamespace(is_incremental=False) + p.unreviewed_files_set = {} + # pr.base/head shas drive repo.compare which we stub out below. + p.pr = SimpleNamespace( + base=SimpleNamespace(sha="base-sha"), + head=SimpleNamespace(sha="head-sha"), + get_files=lambda: files, + ) + # repo_obj.compare returns an object with a merge_base_commit. 
+ p.repo_obj = SimpleNamespace( + compare=lambda b, h: SimpleNamespace( + merge_base_commit=SimpleNamespace(sha="base-sha") + ) + ) + return p + + +@pytest.fixture +def patched_helpers(): + """Patch module-level helpers used by get_diff_files.""" + mod = "pr_agent.git_providers.github_provider" + with patch(f"{mod}.filter_ignored", side_effect=lambda fs: fs), patch( + f"{mod}.is_valid_file", return_value=True + ), patch(f"{mod}.load_large_diff", return_value="LARGE_DIFF"): + yield + + +class TestGetDiffFilesEditTypes: + @pytest.mark.parametrize( + "status,expected", + [ + ("added", EDIT_TYPE.ADDED), + ("removed", EDIT_TYPE.DELETED), + ("renamed", EDIT_TYPE.RENAMED), + ("modified", EDIT_TYPE.MODIFIED), + ("copied", EDIT_TYPE.UNKNOWN), # any unrecognized status + ], + ) + def test_status_to_edit_type(self, patched_helpers, status, expected): + f = _make_file(f"{status}.py", status) + p = _make_provider_for_diff([f]) + # Avoid reaching real GitHub for file content. + p._get_pr_file_content = lambda file, sha: "content" + + diffs = p.get_diff_files() + + assert len(diffs) == 1 + assert isinstance(diffs[0], FilePatchInfo) + assert diffs[0].edit_type == expected + assert diffs[0].filename == f.filename + + def test_missing_patch_triggers_load_large_diff(self, patched_helpers): + """When file.patch is falsy, load_large_diff fills it in.""" + f = _make_file("big.py", "modified", patch="") + p = _make_provider_for_diff([f]) + p._get_pr_file_content = lambda file, sha: "content" + + diffs = p.get_diff_files() + + assert len(diffs) == 1 + assert diffs[0].patch == "LARGE_DIFF" + assert diffs[0].edit_type == EDIT_TYPE.MODIFIED + + def test_existing_patch_preserved(self, patched_helpers): + f = _make_file("ok.py", "modified", patch="@@ -1 +1 @@\n-a\n+b") + p = _make_provider_for_diff([f]) + p._get_pr_file_content = lambda file, sha: "content" + + diffs = p.get_diff_files() + + assert diffs[0].patch == "@@ -1 +1 @@\n-a\n+b" + + def test_cached_diff_files_short_circuits(self, 
patched_helpers): + p = _make_provider_for_diff([]) + sentinel = [FilePatchInfo("a", "b", "p", "f.py")] + p.diff_files = sentinel + # No fake _get_pr_file_content needed because it should not be called. + assert p.get_diff_files() is sentinel + + def test_additions_deletions_propagated(self, patched_helpers): + f = _make_file("x.py", "modified", additions=5, deletions=2) + p = _make_provider_for_diff([f]) + p._get_pr_file_content = lambda file, sha: "content" + + diffs = p.get_diff_files() + + assert diffs[0].num_plus_lines == 5 + assert diffs[0].num_minus_lines == 2 diff --git a/tests/unittest/test_markdown_ticket_output_core.py b/tests/unittest/test_markdown_ticket_output_core.py new file mode 100644 index 0000000000..c842ab8335 --- /dev/null +++ b/tests/unittest/test_markdown_ticket_output_core.py @@ -0,0 +1,682 @@ +""" +Focused unit tests for Markdown / parser / ticket-visible output helpers. + +These tests document current behavior of helper seams that render +user-visible Markdown / HTML for review output, ticket compliance, +TODO sections, and PR description ticket extraction. The goal is to +lock in branches that are not already covered by: + +- tests/unittest/test_convert_to_markdown.py +- tests/unittest/test_parse_code_suggestion.py +- tests/unittest/test_extract_issue_from_branch.py +- tests/unittest/test_pr_description.py + +Assertions intentionally key off structural markers (emoji, header +text, anchor/href substrings, list bullets) rather than full golden +strings to remain robust against trivial whitespace changes. 
+""" + +from unittest.mock import Mock + +import pytest + +from pr_agent.algo.utils import ( + convert_to_markdown_v2, + emphasize_header, + format_todo_item, + format_todo_items, + is_value_no, + parse_code_suggestion, + process_can_be_split, + ticket_markdown_logic, +) +from pr_agent.tools.pr_description import insert_br_after_x_chars +from pr_agent.tools.ticket_pr_compliance_check import ( + extract_ticket_links_from_pr_description, + find_jira_tickets, +) + +# --------------------------------------------------------------------------- +# is_value_no / emphasize_header +# --------------------------------------------------------------------------- + + +class TestIsValueNo: + @pytest.mark.parametrize( + "value", + ["No", "no", "NONE", " false ", "", None, 0, [], {}], + ) + def test_truthy_no_values(self, value): + assert is_value_no(value) is True + + @pytest.mark.parametrize("value", ["yes", "Yes", "true", "maybe", "123"]) + def test_other_values_are_not_no(self, value): + assert is_value_no(value) is False + + +class TestEmphasizeHeader: + def test_html_emphasis_with_colon(self): + out = emphasize_header("Header: details continue here") + # First segment is wrapped in and split with
. + assert out.startswith("Header:") + assert "
" in out + assert "details continue here" in out + + def test_markdown_only_emphasis_with_colon(self): + out = emphasize_header("Header: rest", only_markdown=True) + assert out.startswith("**Header:**") + # Newline-separated rest of text (no
). + assert "\n rest" in out + assert "
" not in out + + def test_reference_link_html(self): + out = emphasize_header( + "Header: rest", reference_link="https://example.com/x" + ) + assert "Header:" in out + assert out.startswith("") + + def test_reference_link_markdown(self): + out = emphasize_header( + "Header: rest", + only_markdown=True, + reference_link="https://example.com/x", + ) + assert "[**Header:**](https://example.com/x)" in out + + def test_no_colon_returns_unchanged(self): + text = "Plain text without a delimiter" + assert emphasize_header(text) == text + + +# --------------------------------------------------------------------------- +# convert_to_markdown_v2 — branches not covered elsewhere +# --------------------------------------------------------------------------- + + +class TestConvertToMarkdownV2Branches: + def test_empty_review_returns_empty(self): + # When the review dict exists but is missing, output is empty. + assert convert_to_markdown_v2({}).strip() == "" + assert convert_to_markdown_v2({"review": None}).strip() == "" + + def test_incremental_review_header_and_note(self): + out = convert_to_markdown_v2( + {"review": {"security_concerns": "No"}}, + incremental_review="2 commits", + ) + assert "Incremental PR Reviewer Guide" in out + assert "Review for commits since previous PR-Agent review 2 commits" in out + + def test_relevant_tests_yes_branch_gfm(self): + out = convert_to_markdown_v2( + {"review": {"relevant_tests": "Yes"}} + ) + assert "PR contains tests" in out + assert "" in out and "
" in out + + def test_relevant_tests_yes_branch_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"relevant_tests": "Yes"}}, gfm_supported=False + ) + assert "### 🧪 PR contains tests" in out + assert "" not in out + + def test_relevant_tests_no_branch_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"relevant_tests": "No"}}, gfm_supported=False + ) + assert "### 🧪 No relevant tests" in out + + def test_security_concerns_with_details_gfm(self): + out = convert_to_markdown_v2( + {"review": {"security_concerns": "SQL injection: details follow"}} + ) + assert "Security concerns" in out + # emphasize_header wraps the part before ':' in . + assert "SQL injection:" in out + + def test_security_concerns_with_details_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"security_concerns": "SQL injection: details"}}, + gfm_supported=False, + ) + assert "### 🔒 Security concerns" in out + assert "**SQL injection:**" in out + + def test_key_issues_no_major_issues_gfm(self): + out = convert_to_markdown_v2( + {"review": {"key_issues_to_review": "No"}} + ) + assert "No major issues detected" in out + + def test_key_issues_no_major_issues_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"key_issues_to_review": "No"}}, gfm_supported=False + ) + assert "### ⚡ No major issues detected" in out + + def test_key_issues_possible_bug_header_softened(self): + mock_provider = Mock() + mock_provider.get_line_link.return_value = "https://example.com/diff" + out = convert_to_markdown_v2( + { + "review": { + "key_issues_to_review": [ + { + "relevant_file": "src/x.py", + "issue_header": "possible bug", + "issue_content": "may explode", + "start_line": 1, + "end_line": 2, + } + ] + } + }, + git_provider=mock_provider, + ) + # 'possible bug' is rewritten to the less alarming 'Possible Issue'. 
+ assert "Possible Issue" in out + assert "possible bug" not in out + + def test_key_issues_without_provider_renders_strong_header(self): + out = convert_to_markdown_v2( + { + "review": { + "key_issues_to_review": [ + { + "relevant_file": "src/x.py", + "issue_header": "Code Smell", + "issue_content": "long", + "start_line": 1, + "end_line": 2, + } + ] + } + } + ) + # No reference link → plain header, no anchor. + assert "Code Smell" in out + assert ": 3 🔵🔵🔵⚪⚪" in out + + def test_estimated_effort_invalid_value_is_skipped(self): + # Completely unparsable value falls through `continue` and is omitted. + out = convert_to_markdown_v2( + {"review": {"estimated_effort_to_review_[1-5]": "not-a-number"}} + ) + assert "Estimated effort to review" not in out + + def test_can_be_split_single_item_renders_no_themes(self): + out = convert_to_markdown_v2( + { + "review": { + "can_be_split": [ + {"relevant_files": ["a.py"], "title": "Only one"} + ] + } + } + ) + assert "No multiple PR themes" in out + + def test_can_be_split_empty_renders_no_themes(self): + out = convert_to_markdown_v2({"review": {"can_be_split": []}}) + assert "No multiple PR themes" in out + + def test_default_branch_unknown_key_gfm(self): + out = convert_to_markdown_v2( + {"review": {"some_other_field": "interesting value"}} + ) + # Fallback formatting capitalizes & joins with ': '. 
+ assert "Some other field: interesting value" in out + + def test_default_branch_unknown_key_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"some_other_field": "interesting value"}}, + gfm_supported=False, + ) + assert "### Some other field: interesting value" in out + + def test_todo_sections_no_value_gfm(self): + out = convert_to_markdown_v2({"review": {"todo_sections": "No"}}) + assert "No TODO sections" in out + assert "✅" in out + + def test_todo_sections_no_value_non_gfm(self): + out = convert_to_markdown_v2( + {"review": {"todo_sections": "No"}}, gfm_supported=False + ) + assert "### ✅ No TODO sections" in out + + def test_todo_sections_list_with_provider_gfm(self): + provider = Mock() + provider.get_line_link.return_value = "https://example.com/L10" + out = convert_to_markdown_v2( + { + "review": { + "todo_sections": [ + { + "relevant_file": "src/x.py", + "line_number": 10, + "content": "finish refactor", + } + ] + } + }, + git_provider=provider, + ) + assert "TODO sections" in out + assert "
    " in out + assert "src/x.py [10]" in out + assert "finish refactor" in out + + +# --------------------------------------------------------------------------- +# ticket_markdown_logic — covers branches the existing single-ticket test +# in test_convert_to_markdown.py does not (mixed, not-compliant, partial, +# empty list, non-gfm rendering). +# --------------------------------------------------------------------------- + + +class TestTicketMarkdownLogic: + @pytest.fixture(autouse=True) + def _cleanup_extra_statistics(self): + """``ticket_markdown_logic`` writes ``config.extra_statistics`` as a + side effect; snapshot and restore it so these tests don't leak + ``compliance_level`` state into other tests sharing the settings + singleton. + """ + from tests.unittest._settings_helpers import ( + restore_settings, + snapshot_settings, + ) + + snapshot = snapshot_settings(["config.extra_statistics"]) + try: + yield + finally: + restore_settings(snapshot) + + def _ticket(self, **overrides): + base = { + "ticket_url": "https://example.com/ticket/42", + "ticket_requirements": "- R1\n", + "fully_compliant_requirements": "", + "not_compliant_requirements": "", + "requires_further_human_verification": "", + } + base.update(overrides) + return base + + def test_not_a_list_returns_unchanged(self): + # Defensive branch: non-list values are ignored. + out = ticket_markdown_logic("🎫", "PREFIX", "not-a-list", True) + assert out == "PREFIX" + + def test_empty_list_still_renders_header_without_compliance_emoji(self): + # Current behavior: even with an empty ticket list the gfm branch + # appends a header row, but with an empty compliance emoji and no + # body. This documents the quirk so a future refactor that skips + # the header in this case will surface as a deliberate change. + out = ticket_markdown_logic("🎫", "PREFIX", [], True) + assert out.startswith("PREFIX") + assert "Ticket compliance analysis" in out + # No compliance emoji is rendered after the heading text. 
+ assert "Ticket compliance analysis **" in out + + def test_not_compliant_only_renders_red_x(self): + tickets = [ + self._ticket(not_compliant_requirements="- broken\n") + ] + out = ticket_markdown_logic("🎫", "", tickets, True) + assert "Ticket compliance analysis ❌" in out + assert "Not compliant" in out + assert "Non-compliant requirements:" in out + + def test_partially_compliant_renders_orange_diamond(self): + tickets = [ + self._ticket( + fully_compliant_requirements="- ok\n", + not_compliant_requirements="- broken\n", + ) + ] + out = ticket_markdown_logic("🎫", "", tickets, True) + assert "Ticket compliance analysis 🔶" in out + assert "Partially compliant" in out + # Both sections are rendered. + assert "Compliant requirements:" in out + assert "Non-compliant requirements:" in out + + def test_mixed_full_and_not_compliant_renders_partial(self): + tickets = [ + self._ticket(fully_compliant_requirements="- ok\n"), + self._ticket( + ticket_url="https://example.com/ticket/43", + not_compliant_requirements="- broken\n", + ), + ] + out = ticket_markdown_logic("🎫", "", tickets, True) + # Mix of Fully compliant + Not compliant ⇒ overall Partially compliant 🔶. + assert "Ticket compliance analysis 🔶" in out + # Both ticket id slugs are rendered. + assert "[42](https://example.com/ticket/42)" in out + assert "[43](https://example.com/ticket/43)" in out + + def test_requires_further_human_verification_marks_pr_code_verified(self): + tickets = [ + self._ticket( + fully_compliant_requirements="- ok\n", + requires_further_human_verification="- check infra\n", + ) + ] + out = ticket_markdown_logic("🎫", "", tickets, True) + assert "PR Code Verified" in out + assert "Requires further human verification:" in out + # All tickets verified ⇒ green check. 
+ assert "Ticket compliance analysis ✅" in out + + def test_ticket_with_no_requirements_renders_header_only(self): + # Tickets that have neither compliant nor non-compliant requirements + # are skipped in the per-ticket loop, but the gfm branch still + # emits an (empty-body) header row. This documents that current + # behavior — no compliance level or per-ticket detail is shown. + tickets = [self._ticket()] + out = ticket_markdown_logic("🎫", "", tickets, True) + assert "Ticket compliance analysis" in out + # No per-ticket body rendered. + assert "https://example.com/ticket/42" not in out + + def test_non_gfm_renders_markdown_heading(self): + tickets = [self._ticket(fully_compliant_requirements="- ok\n")] + out = ticket_markdown_logic("🎫", "", tickets, gfm_supported=False) + assert out.startswith("### 🎫 Ticket compliance analysis ✅") + assert "
" not in out + + +# --------------------------------------------------------------------------- +# process_can_be_split — direct helper tests for edge inputs. +# --------------------------------------------------------------------------- + + +class TestProcessCanBeSplit: + def test_empty_value_returns_no_themes(self): + out = process_can_be_split("🔀", []) + assert "No multiple PR themes" in out + + def test_single_element_list_returns_no_themes(self): + out = process_can_be_split( + "🔀", [{"title": "only one", "relevant_files": ["a.py"]}] + ) + assert "No multiple PR themes" in out + + def test_multiple_themes_render_details(self): + out = process_can_be_split( + "🔀", + [ + {"title": "Refactor", "relevant_files": ["a.py", "b.py"]}, + {"title": "Fix", "relevant_files": ["c.py"]}, + ], + ) + assert "
" in out + # Each theme title is rendered. + assert "Refactor" in out + assert "Fix" in out + # Relevant files are bullet-listed. + assert "- a.py" in out + assert "- b.py" in out + assert "- c.py" in out + + +# --------------------------------------------------------------------------- +# format_todo_item / format_todo_items +# --------------------------------------------------------------------------- + + +class TestFormatTodoItem: + def _provider(self, link="https://example.com/L5"): + p = Mock() + p.get_line_link.return_value = link + return p + + def test_gfm_with_content_uses_anchor(self): + out = format_todo_item( + {"relevant_file": "src/a.py", "line_number": 5, "content": "do it"}, + self._provider(), + gfm_supported=True, + ) + assert "src/a.py [5]: do it" == out + + def test_non_gfm_with_content_uses_markdown_link(self): + out = format_todo_item( + {"relevant_file": "src/a.py", "line_number": 5, "content": "do it"}, + self._provider(), + gfm_supported=False, + ) + assert out == "[src/a.py [5]](https://example.com/L5): do it" + + def test_empty_content_returns_only_file_ref(self): + out = format_todo_item( + {"relevant_file": "src/a.py", "line_number": 5, "content": ""}, + self._provider(), + gfm_supported=True, + ) + assert out.endswith("src/a.py [5]") + assert ":" not in out.split("")[-1] # no trailing ": " + + def test_no_reference_link_plain_file_ref(self): + out = format_todo_item( + {"relevant_file": "src/a.py", "line_number": 5, "content": "x"}, + self._provider(link=""), + gfm_supported=True, + ) + # Falsy reference_link → no anchor tag. + assert "") and out.rstrip().endswith("

") + + def test_single_item_non_gfm_uses_bullet(self): + out = format_todo_items( + {"relevant_file": "f.py", "line_number": 1, "content": "x"}, + self._provider(), + gfm_supported=False, + ) + assert out.startswith("- ") + + def test_list_truncates_to_max_items_gfm(self): + items = [ + {"relevant_file": f"f{i}.py", "line_number": i, "content": "x"} + for i in range(10) + ] + out = format_todo_items(items, self._provider(), gfm_supported=True) + # MAX_ITEMS is 5 — only the first five files appear, the rest are dropped. + for i in range(5): + assert f"f{i}.py" in out + for i in range(5, 10): + assert f"f{i}.py" not in out + assert out.count("
  • ") == 5 + + def test_list_truncates_to_max_items_non_gfm(self): + items = [ + {"relevant_file": f"f{i}.py", "line_number": i, "content": "x"} + for i in range(7) + ] + out = format_todo_items(items, self._provider(), gfm_supported=False) + # Counts bullet rows. + assert out.count("\n- ") + (1 if out.startswith("- ") else 0) == 5 + + +# --------------------------------------------------------------------------- +# parse_code_suggestion — gfm branch with relevant_line is not exercised +# by tests/unittest/test_parse_code_suggestion.py. +# --------------------------------------------------------------------------- + + +class TestParseCodeSuggestionGfm: + def test_relevant_line_with_markdown_link(self): + suggestion = { + "relevant_file": "src/app.py", + "suggestion": "Use a constant", + "relevant_line": "[`foo = 1`](https://example.com/diff#L10)", + } + out = parse_code_suggestion(suggestion, gfm_supported=True) + assert out.startswith("
  • ") + assert "" in out + assert "" in out and "Use a constant" in out + assert "" in out + assert out.rstrip().endswith("
    ") + + def test_relevant_line_without_link(self): + suggestion = { + "relevant_file": "src/app.py", + "suggestion": "Use a constant", + "relevant_line": "`foo = 1`", + } + out = parse_code_suggestion(suggestion, gfm_supported=True) + # No "](" link delimiter → no anchor, just the (leading-backtick + # stripped) literal line. + assert "
    " in out + assert "foo = 1" in out + + @pytest.mark.xfail( + strict=True, + reason=( + "parse_code_suggestion only left-strips a leading backtick from " + "relevant_line; the trailing backtick is not stripped. This xfail " + "encodes the desired symmetric stripping behavior." + ), + ) + def test_relevant_line_strips_both_backticks(self): + suggestion = { + "relevant_file": "src/app.py", + "suggestion": "Use a constant", + "relevant_line": "`foo = 1`", + } + out = parse_code_suggestion(suggestion, gfm_supported=True) + assert "" in out + + def test_falls_back_to_non_gfm_when_no_relevant_line(self): + # Without 'relevant_line', the function takes the non-gfm code path + # even when gfm_supported=True. + suggestion = {"suggestion": "S", "description": "D"} + out = parse_code_suggestion(suggestion, gfm_supported=True) + assert "
    relevant filesrc/app.py
    relevant linefoo = 1
    " not in out + assert "**suggestion:**" in out + assert "**description:**" in out + + +# --------------------------------------------------------------------------- +# insert_br_after_x_chars — edges around its very short-circuit branches. +# --------------------------------------------------------------------------- + + +class TestInsertBrAfterXChars: + def test_empty_returns_empty_string(self): + assert insert_br_after_x_chars("") == "" + assert insert_br_after_x_chars(None) == "" + + def test_short_text_returned_unchanged(self): + text = "short text" + assert insert_br_after_x_chars(text) == text + + def test_long_text_inserts_br(self): + text = "word " * 30 # well over default x=70 + out = insert_br_after_x_chars(text) + assert "
    " in out + + def test_bullet_list_starts_with_li(self): + text = ( + "- first bullet with a fair amount of text " + "that should clearly exceed the seventy character limit\n" + "- second bullet" + ) + out = insert_br_after_x_chars(text) + assert "
  • " in out + + +# --------------------------------------------------------------------------- +# Ticket extraction from PR description / Jira ticket detection. +# --------------------------------------------------------------------------- + + +class TestFindJiraTickets: + def test_finds_standard_jira_id(self): + assert "PROJ-123" in find_jira_tickets("Fixes PROJ-123 today") + + def test_finds_jira_via_url(self): + text = "See https://company.atlassian.net/browse/ABC-9 for details" + tickets = find_jira_tickets(text) + assert "ABC-9" in tickets + + def test_no_match_returns_empty(self): + assert find_jira_tickets("nothing here") == [] + + def test_short_uppercase_prefix_not_matched(self): + # Requires at least 2 uppercase letters; single-letter prefixes ignored. + assert find_jira_tickets("A-1 should not match") == [] + + def test_deduplicates_repeated_tickets(self): + tickets = find_jira_tickets("PROJ-1 PROJ-1 PROJ-1") + assert tickets == ["PROJ-1"] + + +class TestExtractTicketLinksFromPRDescription: + def test_full_url_extracted(self): + desc = "Closes https://github.com/foo/bar/issues/7 and more" + out = extract_ticket_links_from_pr_description(desc, "foo/bar") + assert "https://github.com/foo/bar/issues/7" in out + + def test_shorthand_owner_repo_issue(self): + desc = "See foo/bar#42 for context" + out = extract_ticket_links_from_pr_description( + desc, "foo/bar", base_url_html="https://github.com" + ) + assert "https://github.com/foo/bar/issues/42" in out + + def test_hash_only_uses_repo_path(self): + desc = "Fixes #5" + out = extract_ticket_links_from_pr_description(desc, "foo/bar") + assert "https://github.com/foo/bar/issues/5" in out + + def test_hash_only_requires_repo_path(self): + desc = "Fixes #5" + # Without repo_path, '#5'-only references cannot be resolved. 
+ out = extract_ticket_links_from_pr_description(desc, "") + assert out == [] + + def test_hash_only_rejects_long_numbers(self): + desc = "Fixes #12345 (5 digits, looks like a code, not an issue)" + out = extract_ticket_links_from_pr_description(desc, "foo/bar") + assert out == [] + + def test_results_capped_at_three(self): + desc = " ".join(f"foo/bar#{i}" for i in range(1, 8)) + out = extract_ticket_links_from_pr_description(desc, "foo/bar") + assert len(out) == 3 + + def test_base_url_trailing_slash_is_stripped(self): + desc = "See foo/bar#1" + out = extract_ticket_links_from_pr_description( + desc, "foo/bar", base_url_html="https://ghe.example.com/" + ) + assert out == ["https://ghe.example.com/foo/bar/issues/1"] diff --git a/tests/unittest/test_pr_code_suggestions_filtering.py b/tests/unittest/test_pr_code_suggestions_filtering.py new file mode 100644 index 0000000000..1d74c7c40d --- /dev/null +++ b/tests/unittest/test_pr_code_suggestions_filtering.py @@ -0,0 +1,450 @@ +"""Focused tests for /improve filtering and quality-guard helpers. + +These tests exercise pure-Python helpers on PRCodeSuggestions without +invoking any LLM or git provider network calls. The tool is constructed +via ``__new__`` and only the attributes touched by each helper are set. 
+""" +from unittest.mock import MagicMock + +import pytest + +from pr_agent.algo.types import FilePatchInfo +from pr_agent.config_loader import get_settings +from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions +from tests.unittest._settings_helpers import restore_settings, snapshot_settings + +TRUNCATION_SETTINGS = ( + "pr_code_suggestions.max_code_suggestion_length", + "pr_code_suggestions.suggestion_truncation_message", +) + + +def _make_tool(git_provider=None): + tool = PRCodeSuggestions.__new__(PRCodeSuggestions) + tool.git_provider = git_provider or MagicMock() + tool.progress_response = None + return tool + + +def _valid_suggestion(**overrides): + suggestion = { + "one_sentence_summary": "Avoid duplicated work", + "label": "maintainability", + "relevant_file": "app.py", + "relevant_lines_start": 1, + "relevant_lines_end": 1, + "suggestion_content": "Use the shared helper.", + "existing_code": "old()", + "improved_code": "new()", + } + suggestion.update(overrides) + return suggestion + + +# --------------------------------------------------------------------------- +# _truncate_if_needed +# --------------------------------------------------------------------------- + +def test_truncate_if_needed_noop_when_threshold_disabled(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 0) + try: + suggestion = _valid_suggestion(improved_code="x" * 5000) + result = PRCodeSuggestions._truncate_if_needed(suggestion) + assert result["improved_code"] == "x" * 5000 + finally: + restore_settings(snapshot) + + +def test_truncate_if_needed_truncates_and_appends_message(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 10) + settings.set("pr_code_suggestions.suggestion_truncation_message", "[...truncated]") + try: + suggestion = 
_valid_suggestion(improved_code="abcdefghijKLMNOP") + result = PRCodeSuggestions._truncate_if_needed(suggestion) + assert result["improved_code"].startswith("abcdefghij") + assert "[...truncated]" in result["improved_code"] + # Truncated body is exactly the first max_code_suggestion_length chars + assert result["improved_code"].split("\n")[0] == "abcdefghij" + finally: + restore_settings(snapshot) + + +def test_truncate_if_needed_keeps_short_code_unchanged(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 1000) + try: + suggestion = _valid_suggestion(improved_code="short()") + result = PRCodeSuggestions._truncate_if_needed(suggestion) + assert result["improved_code"] == "short()" + finally: + restore_settings(snapshot) + + +# --------------------------------------------------------------------------- +# validate_one_liner_suggestion_not_repeating_code (stale-suggestion guard) +# --------------------------------------------------------------------------- + +def _patch_files(base_file, head_file, filename="app.py"): + return [FilePatchInfo(base_file=base_file, head_file=head_file, patch="", filename=filename)] + + +def test_validate_one_liner_marks_stale_suggestion_as_score_zero(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = _patch_files( + base_file="return old()\n", head_file="return new()\n" + ) + tool = _make_tool(git_provider) + suggestion = _valid_suggestion( + existing_code="return old()", + improved_code="return new()", + score=8, + ) + + result = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert result["score"] == 0 + + +def test_validate_one_liner_skips_when_existing_code_uses_ellipsis(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = _patch_files( + base_file="return old()\n", head_file="return new()\n" + ) + tool = _make_tool(git_provider) + suggestion = _valid_suggestion( + 
existing_code="... old() ...", + improved_code="return new()", + score=8, + ) + + result = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + # '...' short-circuits the check; original score is preserved. + assert result["score"] == 8 + + +def test_validate_one_liner_preserves_score_when_code_not_yet_applied(): + git_provider = MagicMock() + # head still contains the old code: the patch hasn't applied the new code yet. + git_provider.get_diff_files.return_value = _patch_files( + base_file="return old()\n", head_file="return old()\n" + ) + tool = _make_tool(git_provider) + suggestion = _valid_suggestion( + existing_code="return old()", + improved_code="return new()", + score=8, + ) + + result = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert result["score"] == 8 + + +def test_validate_one_liner_preserves_score_when_filename_not_in_diff(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = _patch_files( + base_file="return old()\n", + head_file="return new()\n", + filename="other.py", + ) + tool = _make_tool(git_provider) + suggestion = _valid_suggestion( + existing_code="return old()", + improved_code="return new()", + score=8, + ) + + result = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert result["score"] == 8 + + +def test_validate_one_liner_handles_empty_head_file_gracefully(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = _patch_files( + base_file="return old()\n", head_file="" + ) + tool = _make_tool(git_provider) + suggestion = _valid_suggestion( + existing_code="return old()", + improved_code="return new()", + score=8, + ) + + result = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert result["score"] == 8 + + +# --------------------------------------------------------------------------- +# remove_line_numbers +# --------------------------------------------------------------------------- + +def 
test_remove_line_numbers_strips_leading_digits_and_separator(): + tool = _make_tool() + tool.patches_diff_list = [ + "## File: app.py\n" + "1 def f():\n" + "2 return old()\n" + "10 return other()\n" + ] + + result = tool.remove_line_numbers(tool.patches_diff_list) + + assert len(result) == 1 + lines = result[0].splitlines() + # Header (no leading digit) is preserved unchanged. + assert lines[0] == "## File: app.py" + # Each numbered code line has the "" prefix removed. + assert lines[1] == "def f():" + assert lines[2] == " return old()" + assert lines[3] == " return other()" + + +def test_remove_line_numbers_clears_pure_numeric_lines(): + tool = _make_tool() + tool.patches_diff_list = ["42\nkeep me\n7\ntail"] + + result = tool.remove_line_numbers(tool.patches_diff_list) + + lines = result[0].splitlines() + assert lines == ["", "keep me", "", "tail"] + + +def test_remove_line_numbers_preserves_blank_lines(): + tool = _make_tool() + tool.patches_diff_list = ["1 alpha\n\n2 beta"] + + result = tool.remove_line_numbers(tool.patches_diff_list) + + lines = result[0].splitlines() + assert lines == ["alpha", "", "beta"] + + +def test_remove_line_numbers_returns_original_on_exception(): + tool = _make_tool() + # Exercise the broad ``except`` fallback by putting invalid data in the + # instance list itself (``None`` has no ``splitlines``), and assert the + # parameter object is returned untouched. 
+ tool.patches_diff_list = [None] + original_input = ["1 alpha"] + + result = tool.remove_line_numbers(original_input) + + assert result is original_input + + +# --------------------------------------------------------------------------- +# analyze_self_reflection_response (reflection-mismatch / invalid output) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_analyze_self_reflection_length_mismatch_leaves_data_untouched(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [] + tool = _make_tool(git_provider) + settings = get_settings() + original_publish_output = settings.config.publish_output + settings.config.publish_output = False + try: + data = {"code_suggestions": [_valid_suggestion(), _valid_suggestion(one_sentence_summary="Second")]} + # Only one feedback item for two suggestions -> mismatch, all skipped. + response_reflect = """ +code_suggestions: + - suggestion_score: 9 + why: only one feedback entry +""" + + await tool.analyze_self_reflection_response(data, response_reflect) + + for suggestion in data["code_suggestions"]: + assert "score" not in suggestion + assert "score_why" not in suggestion + finally: + settings.config.publish_output = original_publish_output + + +@pytest.mark.asyncio +async def test_analyze_self_reflection_invalid_feedback_assigns_default_score_seven(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [] + tool = _make_tool(git_provider) + settings = get_settings() + original_publish_output = settings.config.publish_output + settings.config.publish_output = False + try: + data = {"code_suggestions": [_valid_suggestion()]} + # Missing required keys ('suggestion_score', 'why') triggers the + # fallback branch which assigns score=7 and clears score_why. 
+ response_reflect = """ +code_suggestions: + - irrelevant_key: 1 +""" + + await tool.analyze_self_reflection_response(data, response_reflect) + + assert data["code_suggestions"][0]["score"] == 7 + assert data["code_suggestions"][0]["score_why"] == "" + finally: + settings.config.publish_output = original_publish_output + + +@pytest.mark.asyncio +async def test_analyze_self_reflection_clears_existing_code_when_equal_to_improved(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [] + tool = _make_tool(git_provider) + settings = get_settings() + original_publish_output = settings.config.publish_output + snapshot = snapshot_settings(["pr_code_suggestions.commitable_code_suggestions"]) + settings.config.publish_output = False + settings.set("pr_code_suggestions.commitable_code_suggestions", False) + try: + data = {"code_suggestions": [_valid_suggestion(existing_code="same()", improved_code="same()")]} + response_reflect = """ +code_suggestions: + - suggestion_score: 6 + why: equal codes +""" + + await tool.analyze_self_reflection_response(data, response_reflect) + + suggestion = data["code_suggestions"][0] + assert suggestion["score"] == 6 + # Non-commitable mode clears existing_code so the rendered suggestion + # doesn't show an identical before/after block. 
+ assert suggestion["existing_code"] == "" + assert suggestion["improved_code"] == "same()" + finally: + settings.config.publish_output = original_publish_output + restore_settings(snapshot) + + +@pytest.mark.asyncio +async def test_analyze_self_reflection_clears_improved_code_in_commitable_mode(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [] + tool = _make_tool(git_provider) + settings = get_settings() + original_publish_output = settings.config.publish_output + snapshot = snapshot_settings(["pr_code_suggestions.commitable_code_suggestions"]) + settings.config.publish_output = False + settings.set("pr_code_suggestions.commitable_code_suggestions", True) + try: + data = {"code_suggestions": [_valid_suggestion(existing_code="same()", improved_code="same()")]} + response_reflect = """ +code_suggestions: + - suggestion_score: 6 + why: equal codes +""" + + await tool.analyze_self_reflection_response(data, response_reflect) + + suggestion = data["code_suggestions"][0] + # Commitable mode keeps existing_code (used to locate the line in PR) + # and clears improved_code instead. + assert suggestion["existing_code"] == "same()" + assert suggestion["improved_code"] == "" + finally: + settings.config.publish_output = original_publish_output + restore_settings(snapshot) + + +# --------------------------------------------------------------------------- +# _prepare_pr_code_suggestions filtering +# --------------------------------------------------------------------------- + +def test_prepare_pr_code_suggestions_drops_const_instead_let_suggestion(): + tool = _make_tool() + prediction = """ +code_suggestions: + - one_sentence_summary: Prefer const + label: best practice + relevant_file: app.js + suggestion_content: Use const instead of let when not reassigning. + existing_code: let x = 1; + improved_code: const x = 1; + - one_sentence_summary: Keep this one + label: maintainability + relevant_file: app.js + suggestion_content: Extract helper. 
+ existing_code: a() + improved_code: helper() +""" + + data = tool._prepare_pr_code_suggestions(prediction) + + summaries = [s["one_sentence_summary"] for s in data["code_suggestions"]] + assert summaries == ["Keep this one"] + + +def test_prepare_pr_code_suggestions_skips_suggestion_missing_improved_code(): + tool = _make_tool() + prediction = """ +code_suggestions: + - one_sentence_summary: Missing improved_code + label: maintainability + relevant_file: app.py + suggestion_content: Refactor. + existing_code: a() + - one_sentence_summary: Complete + label: maintainability + relevant_file: app.py + suggestion_content: Refactor. + existing_code: a() + improved_code: b() +""" + + data = tool._prepare_pr_code_suggestions(prediction) + + assert len(data["code_suggestions"]) == 1 + assert data["code_suggestions"][0]["one_sentence_summary"] == "Complete" + + +def test_prepare_pr_code_suggestions_accepts_list_payload(): + tool = _make_tool() + # Some prompt variants return a bare list rather than a mapping. + prediction = """ +- one_sentence_summary: Only suggestion + label: maintainability + relevant_file: app.py + suggestion_content: Refactor. + existing_code: a() + improved_code: b() +""" + + data = tool._prepare_pr_code_suggestions(prediction) + + assert isinstance(data, dict) + assert len(data["code_suggestions"]) == 1 + assert data["code_suggestions"][0]["improved_code"] == "b()" + + +def test_prepare_pr_code_suggestions_truncates_long_improved_code(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 8) + settings.set("pr_code_suggestions.suggestion_truncation_message", "[cut]") + tool = _make_tool() + prediction = """ +code_suggestions: + - one_sentence_summary: Long + label: maintainability + relevant_file: app.py + suggestion_content: Refactor. 
+ existing_code: a() + improved_code: ABCDEFGHIJKLMNOP +""" + try: + data = tool._prepare_pr_code_suggestions(prediction) + improved = data["code_suggestions"][0]["improved_code"] + assert improved.startswith("ABCDEFGH") + assert "[cut]" in improved + finally: + restore_settings(snapshot) diff --git a/tests/unittest/test_pr_code_suggestions_rendering.py b/tests/unittest/test_pr_code_suggestions_rendering.py new file mode 100644 index 0000000000..c3eeebc0f3 --- /dev/null +++ b/tests/unittest/test_pr_code_suggestions_rendering.py @@ -0,0 +1,348 @@ +from unittest.mock import MagicMock + +import pytest + +from pr_agent.algo.types import FilePatchInfo +from pr_agent.config_loader import get_settings +from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions +from tests.unittest._settings_helpers import restore_settings, snapshot_settings + +TRUNCATION_SETTINGS = ( + "pr_code_suggestions.max_code_suggestion_length", + "pr_code_suggestions.suggestion_truncation_message", +) + + +def _make_tool(git_provider=None): + tool = PRCodeSuggestions.__new__(PRCodeSuggestions) + tool.git_provider = git_provider or MagicMock() + tool.progress_response = None + return tool + + +def _suggestion(**overrides): + base = { + "one_sentence_summary": "Use the shared helper", + "label": "maintainability", + "relevant_file": "app.py", + "relevant_lines_start": 2, + "relevant_lines_end": 2, + "suggestion_content": "Use the shared helper.", + "existing_code": "return old()", + "improved_code": "return new()", + "score": 7, + } + base.update(overrides) + return base + + +# --------------------------------------------------------------------------- +# _truncate_if_needed +# --------------------------------------------------------------------------- + +def test_truncate_if_needed_appends_message_when_over_limit(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 10) + 
settings.set("pr_code_suggestions.suggestion_truncation_message", "[truncated]") + try: + suggestion = _suggestion(improved_code="a" * 50) + out = PRCodeSuggestions._truncate_if_needed(suggestion) + # Truncated content + truncation message on a new line + assert out["improved_code"].startswith("a" * 10) + assert out["improved_code"].endswith("\n[truncated]") + assert "a" * 11 not in out["improved_code"] + finally: + restore_settings(snapshot) + + +def test_truncate_if_needed_noop_when_under_limit_or_disabled(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 100) + settings.set("pr_code_suggestions.suggestion_truncation_message", "[truncated]") + try: + short = _suggestion(improved_code="short()") + out = PRCodeSuggestions._truncate_if_needed(short) + assert out["improved_code"] == "short()" + + # Disabled (0) leaves long content untouched + settings.set("pr_code_suggestions.max_code_suggestion_length", 0) + long_suggestion = _suggestion(improved_code="x" * 500) + out = PRCodeSuggestions._truncate_if_needed(long_suggestion) + assert out["improved_code"] == "x" * 500 + assert "[truncated]" not in out["improved_code"] + finally: + restore_settings(snapshot) + + +def test_prepare_pr_code_suggestions_applies_truncation_inline(): + settings = get_settings() + snapshot = snapshot_settings(TRUNCATION_SETTINGS) + settings.set("pr_code_suggestions.max_code_suggestion_length", 5) + settings.set("pr_code_suggestions.suggestion_truncation_message", "[cut]") + try: + tool = _make_tool() + prediction = """ +code_suggestions: + - one_sentence_summary: Inline truncation + label: maintainability + relevant_file: app.py + suggestion_content: Trim me. 
+ existing_code: old() + improved_code: aaaaaaaaaaaaaaaaaaaa +""" + data = tool._prepare_pr_code_suggestions(prediction) + assert len(data["code_suggestions"]) == 1 + improved = data["code_suggestions"][0]["improved_code"] + assert improved.startswith("aaaaa") + assert improved.endswith("\n[cut]") + finally: + restore_settings(snapshot) + + +# --------------------------------------------------------------------------- +# push_inline_code_suggestions: rendered body shape +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_push_inline_renders_body_with_score_and_label(): + git_provider = MagicMock() + git_provider.diff_files = [ + FilePatchInfo( + base_file="", + head_file="def f():\n return old()\n", + patch="", + filename="app.py", + ) + ] + git_provider.publish_code_suggestions.return_value = True + tool = _make_tool(git_provider) + data = {"code_suggestions": [_suggestion(score=8)]} + + await tool.push_inline_code_suggestions(data) + + args = git_provider.publish_code_suggestions.call_args.args[0] + assert len(args) == 1 + body = args[0]["body"] + assert body.startswith("**Suggestion:** Use the shared helper.") + assert "[maintainability, importance: 8]" in body + assert "```suggestion\n return new()\n```" in body + # original_suggestion is the unmodified dict + assert args[0]["original_suggestion"]["one_sentence_summary"] == "Use the shared helper" + + +@pytest.mark.asyncio +async def test_push_inline_renders_body_without_score_when_missing_or_zero(): + git_provider = MagicMock() + git_provider.diff_files = [ + FilePatchInfo( + base_file="", + head_file="def f():\n return old()\n", + patch="", + filename="app.py", + ) + ] + git_provider.publish_code_suggestions.return_value = True + tool = _make_tool(git_provider) + suggestion = _suggestion() + suggestion.pop("score") + data = {"code_suggestions": [suggestion]} + + await tool.push_inline_code_suggestions(data) + + body = 
git_provider.publish_code_suggestions.call_args.args[0][0]["body"] + assert "[maintainability]" in body + assert "importance" not in body + + +@pytest.mark.asyncio +async def test_push_inline_publishes_no_suggestions_comment_when_empty(): + git_provider = MagicMock() + tool = _make_tool(git_provider) + + await tool.push_inline_code_suggestions({"code_suggestions": []}) + + git_provider.publish_comment.assert_called_once_with( + "No suggestions found to improve this PR." + ) + git_provider.publish_code_suggestions.assert_not_called() + + +# --------------------------------------------------------------------------- +# generate_summarized_suggestions +# --------------------------------------------------------------------------- + +def test_generate_summarized_suggestions_empty_returns_placeholder(): + tool = _make_tool() + out = tool.generate_summarized_suggestions({"code_suggestions": []}) + assert "PR Code Suggestions" in out + assert "No suggestions found to improve this PR." in out + # No table is rendered when empty + assert "
  • " not in out + + +def test_generate_summarized_suggestions_renders_table_and_sorts_by_score(): + git_provider = MagicMock() + git_provider.get_line_link.return_value = "https://example.test/app.py#L2" + tool = _make_tool(git_provider) + settings = get_settings() + snapshot = snapshot_settings(["pr_code_suggestions.new_score_mechanism"]) + settings.set("pr_code_suggestions.new_score_mechanism", False) + try: + low = _suggestion(one_sentence_summary="Lower scored tweak", score=3, label="maintainability") + high = _suggestion( + one_sentence_summary="Higher scored tweak", + score=9, + label="security", + relevant_file="auth.py", + ) + out = tool.generate_summarized_suggestions({"code_suggestions": [low, high]}) + + assert "
    " in out and "
    " in out + assert "" in out + # Labels are capitalized in the rendered category column + assert "Security" in out + assert "Maintainability" in out + # Higher score group appears before lower score group + assert out.index("Security") < out.index("Maintainability") + # Both suggestion summaries appear + assert "Higher scored tweak" in out + assert "Lower scored tweak" in out + # Numeric score shown (new_score_mechanism disabled) + assert ">9\n\n" in out + assert ">3\n\n" in out + # Diff block is rendered + assert "```diff" in out + finally: + restore_settings(snapshot) + + +def test_generate_summarized_suggestions_uses_score_string_when_new_mechanism_enabled(): + git_provider = MagicMock() + git_provider.get_line_link.return_value = "" + tool = _make_tool(git_provider) + settings = get_settings() + snapshot = snapshot_settings(["pr_code_suggestions.new_score_mechanism"]) + settings.set("pr_code_suggestions.new_score_mechanism", True) + try: + out = tool.generate_summarized_suggestions({ + "code_suggestions": [_suggestion(score=9, one_sentence_summary="High one")] + }) + # The new mechanism replaces numeric score with bucket label + assert "High" in out + # Plain numeric "9" should not be shown in the impact column + assert ">9\n\n" not in out + finally: + restore_settings(snapshot) + + +def test_generate_summarized_suggestions_escapes_angle_bracket_strings_in_summary(): + git_provider = MagicMock() + git_provider.get_line_link.return_value = "" + tool = _make_tool(git_provider) + suggestion = _suggestion(one_sentence_summary="Replace '' with new_name") + out = tool.generate_summarized_suggestions({"code_suggestions": [suggestion]}) + # The "'<...>'" pattern is rewritten with backticks, which replace_code_tags + # then turns into an HTML span with escaped angle brackets so it isn't + # parsed as an HTML tag. 
+ assert "''" not in out + assert "<old_name>" in out + + +def test_generate_summarized_suggestions_includes_score_why_block_when_present(): + git_provider = MagicMock() + git_provider.get_line_link.return_value = "" + tool = _make_tool(git_provider) + suggestion = _suggestion(score_why="Catches a real bug.") + out = tool.generate_summarized_suggestions({"code_suggestions": [suggestion]}) + assert "Suggestion importance[1-10]: 7" in out + assert "Why: Catches a real bug." in out + + +# --------------------------------------------------------------------------- +# Stale one-liner validation +# --------------------------------------------------------------------------- + +def test_validate_one_liner_zeroes_score_when_change_already_applied(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [ + FilePatchInfo( + base_file="def f():\n return old()\n", + head_file="def f():\n return new()\n", + patch="", + filename="app.py", + ) + ] + tool = _make_tool(git_provider) + suggestion = _suggestion(score=8, existing_code="return old()", improved_code="return new()") + + out = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert out["score"] == 0 + + +def test_validate_one_liner_keeps_score_when_existing_code_still_present(): + git_provider = MagicMock() + git_provider.get_diff_files.return_value = [ + FilePatchInfo( + base_file="def f():\n return old()\n", + head_file="def f():\n return old()\n", + patch="", + filename="app.py", + ) + ] + tool = _make_tool(git_provider) + suggestion = _suggestion(score=8, existing_code="return old()", improved_code="return new()") + + out = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + assert out["score"] == 8 + + +def test_validate_one_liner_skips_when_existing_code_contains_ellipsis(): + git_provider = MagicMock() + # Provide a diff_files target that would otherwise trigger the stale guard, + # to confirm the early-return for "..." takes precedence. 
+ git_provider.get_diff_files.return_value = [ + FilePatchInfo( + base_file="def f():\n return old()\n", + head_file="def f():\n return new()\n", + patch="", + filename="app.py", + ) + ] + tool = _make_tool(git_provider) + suggestion = _suggestion( + score=8, + existing_code="...\nreturn old()\n...", + improved_code="return new()", + ) + + out = tool.validate_one_liner_suggestion_not_repeating_code(suggestion) + + # Score must remain untouched because the ellipsis early-return runs first. + assert out["score"] == 8 + + +# --------------------------------------------------------------------------- +# get_score_str thresholds +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "score,expected", + [(10, "High"), (9, "High"), (8, "Medium"), (7, "Medium"), (6, "Low"), (0, "Low")], +) +def test_get_score_str_returns_bucket_for_default_thresholds(score, expected): + settings = get_settings() + snapshot = snapshot_settings([ + "pr_code_suggestions.new_score_mechanism_th_high", + "pr_code_suggestions.new_score_mechanism_th_medium", + ]) + settings.set("pr_code_suggestions.new_score_mechanism_th_high", 9) + settings.set("pr_code_suggestions.new_score_mechanism_th_medium", 7) + try: + tool = _make_tool() + assert tool.get_score_str(score) == expected + finally: + restore_settings(snapshot) diff --git a/tests/unittest/test_pr_description_output_core.py b/tests/unittest/test_pr_description_output_core.py new file mode 100644 index 0000000000..ca46e02ee7 --- /dev/null +++ b/tests/unittest/test_pr_description_output_core.py @@ -0,0 +1,479 @@ +"""Focused unit tests for /describe output behavior. + +These tests target stable helper seams on ``PRDescription`` and the +``process_description`` helper. They avoid LLM/network calls by bypassing +``__init__`` and providing minimal in-memory state. + +Coverage: +* ``_prepare_data`` key reordering, diagram sanitization removal, and + ``add_original_user_description`` injection. 
+* ``_prepare_labels`` list/string parsing, fallback-to-type behavior, and + ``labels_minimal_to_labels_dict`` re-casing. +* ``_prepare_pr_answer_with_markers`` HTML-comment guards, generated-by + header injection, list-type joining, and the diagram marker dual-format. +* ``_prepare_pr_answer`` non-gfm vs gfm branching, ``enable_pr_type`` + toggling, ``get_labels`` removal, and description bullet formatting. +* ``process_pr_files_prediction`` gfm-only table rendering. +* Round-trip: ``process_description`` recovers files from a rendered + walkthrough produced by ``process_pr_files_prediction``. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import yaml + +from pr_agent.algo.types import FilePatchInfo +from pr_agent.algo.utils import PRDescriptionHeader, process_description +from pr_agent.tools.pr_description import PRDescription + +KEYS_FIX = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"] + + +def _make_instance(prediction_yaml: str = "") -> PRDescription: + """Construct a ``PRDescription`` instance without running ``__init__``.""" + with patch.object(PRDescription, "__init__", lambda self, *a, **kw: None): + obj = PRDescription.__new__(PRDescription) + obj.prediction = prediction_yaml + obj.keys_fix = KEYS_FIX + obj.user_description = "" + obj.vars = {} + obj.data = {} + obj.pr_id = "1" + obj.file_label_dict = {} + obj.COLLAPSIBLE_FILE_LIST_THRESHOLD = 8 + return obj + + +def _settings( + *, + add_original_user_description: bool = False, + publish_labels: bool = False, + enable_pr_type: bool = True, + generate_ai_title: bool = True, + include_generated_by_header: bool = False, + enable_semantic_files_types: bool = True, + collapsible_file_list: str = "adaptive", + file_table_collapsible_open_by_default: bool = False, +) -> MagicMock: + """Build a settings mock with all PR-description knobs the SUT reads.""" + settings = MagicMock() + pd = settings.pr_description + 
pd.add_original_user_description = add_original_user_description + pd.publish_labels = publish_labels + pd.enable_pr_type = enable_pr_type + pd.generate_ai_title = generate_ai_title + pd.include_generated_by_header = include_generated_by_header + pd.enable_semantic_files_types = enable_semantic_files_types + pd.collapsible_file_list = collapsible_file_list + pd.get.side_effect = lambda key, default=None: { + "file_table_collapsible_open_by_default": file_table_collapsible_open_by_default, + }.get(key, default) + return settings + + +# --------------------------------------------------------------------------- +# _prepare_data +# --------------------------------------------------------------------------- +class TestPrepareData: + @patch("pr_agent.tools.pr_description.get_settings") + def test_keys_are_reordered_in_canonical_sequence(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = _make_instance(yaml.dump({ + "pr_files": [], + "description": "desc", + "labels": ["bug"], + "type": "Bug fix", + "title": "AI title", + })) + + obj._prepare_data() + + # Order matters: title, type, labels, description, pr_files + assert list(obj.data.keys()) == ["title", "type", "labels", "description", "pr_files"] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_empty_diagram_key_is_dropped(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = _make_instance(yaml.dump({ + "title": "t", + "description": "d", + "changes_diagram": "graph LR\nA --> B", # no mermaid fence -> sanitized to '' + })) + + obj._prepare_data() + + assert "changes_diagram" not in obj.data + + @patch("pr_agent.tools.pr_description.get_settings") + def test_user_description_is_injected_when_enabled(self, mock_get_settings): + mock_get_settings.return_value = _settings(add_original_user_description=True) + obj = _make_instance(yaml.dump({"title": "t", "description": "d"})) + obj.user_description = "Original body from user" + + 
obj._prepare_data() + + assert obj.data["User Description"] == "Original body from user" + + +# --------------------------------------------------------------------------- +# _prepare_labels +# --------------------------------------------------------------------------- +class TestPrepareLabels: + @patch("pr_agent.tools.pr_description.get_settings") + def test_labels_list_is_returned_stripped(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = _make_instance() + obj.data = {"labels": [" bug ", "perf"]} + obj.variables = {} + + assert obj._prepare_labels() == ["bug", "perf"] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_labels_comma_string_is_split(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = _make_instance() + obj.data = {"labels": "bug, perf , docs"} + obj.variables = {} + + assert obj._prepare_labels() == ["bug", "perf", "docs"] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_falls_back_to_type_only_when_publish_labels_enabled(self, mock_get_settings): + mock_get_settings.return_value = _settings(publish_labels=True) + obj = _make_instance() + obj.data = {"type": "Bug fix, Refactor"} + obj.variables = {} + + assert obj._prepare_labels() == ["Bug fix", "Refactor"] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_does_not_fall_back_to_type_when_publish_labels_disabled(self, mock_get_settings): + mock_get_settings.return_value = _settings(publish_labels=False) + obj = _make_instance() + obj.data = {"type": "Bug fix"} + obj.variables = {} + + assert obj._prepare_labels() == [] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_labels_minimal_dict_remaps_case(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = _make_instance() + obj.data = {"labels": ["bug fix", "perf"]} + obj.variables = {"labels_minimal_to_labels_dict": {"bug fix": "Bug Fix"}} + + assert obj._prepare_labels() == ["Bug Fix", "perf"] 
+ + +# --------------------------------------------------------------------------- +# _prepare_pr_answer_with_markers +# --------------------------------------------------------------------------- +class TestPrepareAnswerWithMarkers: + def _obj_with_user_description(self, user_description: str, data: dict) -> PRDescription: + obj = _make_instance() + obj.vars = {"title": "Original title"} + obj.user_description = user_description + obj.data = data + obj.git_provider = MagicMock() + obj.git_provider.last_commit_id.sha = "deadbeef" + return obj + + @patch("pr_agent.tools.pr_description.get_settings") + def test_html_comment_guard_prevents_type_replacement(self, mock_get_settings): + mock_get_settings.return_value = _settings() + body_in = "\npr_agent:type stays raw" + obj = self._obj_with_user_description(body_in, {"title": "AI", "type": "Bug fix"}) + + _, body, _, _ = obj._prepare_pr_answer_with_markers() + + # Guard present -> the plain marker is NOT replaced. + assert "pr_agent:type stays raw" in body + assert "Bug fix" not in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_plain_summary_marker_is_replaced(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = self._obj_with_user_description( + "Intro\npr_agent:summary\nOutro", + {"title": "AI", "description": "Adds caching layer."}, + ) + + _, body, _, _ = obj._prepare_pr_answer_with_markers() + + assert "Adds caching layer." 
in body + assert "pr_agent:summary" not in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_generated_by_header_prefixes_replacements(self, mock_get_settings): + mock_get_settings.return_value = _settings(include_generated_by_header=True) + obj = self._obj_with_user_description( + "pr_agent:type\npr_agent:summary", + {"title": "AI", "type": "Bug fix", "description": "Fix bug."}, + ) + + _, body, _, _ = obj._prepare_pr_answer_with_markers() + + assert "### 🤖 Generated by PR Agent at deadbeef" in body + # Header appears for both replaced markers. + assert body.count("### 🤖 Generated by PR Agent at deadbeef") == 2 + + @patch("pr_agent.tools.pr_description.get_settings") + def test_list_type_is_joined_with_comma(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = self._obj_with_user_description( + "pr_agent:type", + {"title": "AI", "type": ["Bug fix", "Refactor"]}, + ) + + _, body, _, _ = obj._prepare_pr_answer_with_markers() + + assert "Bug fix, Refactor" in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_diagram_marker_replaces_both_plain_and_html_comment(self, mock_get_settings): + mock_get_settings.return_value = _settings() + diagram = "\n```mermaid\ngraph LR\nA --> B\n```" + obj = self._obj_with_user_description( + "First: pr_agent:diagram\nSecond: ", + {"title": "AI", "changes_diagram": diagram}, + ) + + _, body, _, _ = obj._prepare_pr_answer_with_markers() + + # Both forms are substituted with the diagram. 
+ assert body.count("```mermaid") == 2 + assert "" not in body + assert "pr_agent:diagram" not in body.replace("```mermaid", "") + + @patch("pr_agent.tools.pr_description.get_settings") + def test_title_falls_back_when_generate_ai_title_disabled(self, mock_get_settings): + mock_get_settings.return_value = _settings(generate_ai_title=False) + obj = self._obj_with_user_description( + "pr_agent:summary", + {"title": "AI Title", "description": "x"}, + ) + + title, _, _, _ = obj._prepare_pr_answer_with_markers() + + assert title == "Original title" + + +# --------------------------------------------------------------------------- +# _prepare_pr_answer (non-marker rendering path) +# --------------------------------------------------------------------------- +class TestPrepareAnswer: + def _obj(self, data: dict, *, gfm: bool = True) -> PRDescription: + obj = _make_instance() + obj.vars = {"title": "Original title"} + obj.data = data + obj.file_label_dict = {} + obj.git_provider = MagicMock() + obj.git_provider.is_supported.side_effect = lambda cap: { + "gfm_markdown": gfm, + "get_labels": False, + }.get(cap, False) + obj.git_provider.get_diff_files.return_value = [] + obj.git_provider.get_line_link.return_value = "" + return obj + + @patch("pr_agent.tools.pr_description.get_settings") + def test_labels_removed_when_provider_supports_get_labels(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = self._obj({"title": "t", "labels": ["bug"], "description": "d"}) + obj.git_provider.is_supported.side_effect = lambda cap: cap in {"gfm_markdown", "get_labels"} + + _, body, _, _ = obj._prepare_pr_answer() + + # The Labels section is suppressed for providers with native label support. 
+ assert "Labels" not in body + assert "bug" not in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_type_section_removed_when_disabled(self, mock_get_settings): + mock_get_settings.return_value = _settings(enable_pr_type=False) + obj = self._obj({"title": "t", "type": "Bug fix", "description": "d"}) + + _, body, _, _ = obj._prepare_pr_answer() + + assert "PR Type" not in body + assert "Bug fix" not in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_description_list_value_is_joined_and_bullets_spaced(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = self._obj({ + "title": "t", + "description": "Intro\n- one\n- two", + }) + + _, body, _, _ = obj._prepare_pr_answer() + + # Bullet readability: single newline before "-" becomes double newline. + assert "Intro\n\n- one\n\n- two" in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_diagram_section_uses_header_enum(self, mock_get_settings): + mock_get_settings.return_value = _settings() + diagram = "\n```mermaid\ngraph LR\nA --> B\n```" + obj = self._obj({"title": "t", "description": "d", "changes_diagram": diagram}) + + _, body, _, _ = obj._prepare_pr_answer() + + assert f"### {PRDescriptionHeader.DIAGRAM_WALKTHROUGH.value}" in body + assert "```mermaid" in body + + @patch("pr_agent.tools.pr_description.get_settings") + def test_title_uses_vars_title_when_data_has_no_title(self, mock_get_settings): + mock_get_settings.return_value = _settings(generate_ai_title=False) + obj = self._obj({"description": "d"}) + + title, _, _, _ = obj._prepare_pr_answer() + + assert title == "Original title" + + +# --------------------------------------------------------------------------- +# process_pr_files_prediction (gfm vs non-gfm) +# --------------------------------------------------------------------------- +class TestProcessPRFilesPrediction: + def _obj(self, *, gfm: bool, diff_files=None) -> PRDescription: + obj = _make_instance() 
+ obj.git_provider = MagicMock() + obj.git_provider.is_supported.side_effect = lambda cap: cap == "gfm_markdown" and gfm + obj.git_provider.get_diff_files.return_value = diff_files or [] + obj.git_provider.get_line_link.return_value = "https://example/blob/main/src/app.py#L1" + return obj + + @patch("pr_agent.tools.pr_description.get_settings") + def test_non_gfm_provider_skips_table_rendering(self, mock_get_settings): + mock_get_settings.return_value = _settings() + obj = self._obj(gfm=False) + value = {"backend": [("src/app.py", "Add cache", "Adds a bounded cache.")]} + + body, comments = obj.process_pr_files_prediction("PRE", value) + + assert body == "PRE" + assert comments == [] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_gfm_provider_emits_table_with_file_row(self, mock_get_settings): + mock_get_settings.return_value = _settings() + diff = FilePatchInfo("", "", "", "src/app.py") + diff.num_plus_lines = 5 + diff.num_minus_lines = 2 + obj = self._obj(gfm=True, diff_files=[diff]) + value = {"backend": [("src/app.py", "Add cache", "Adds a bounded cache.")]} + + body, comments = obj.process_pr_files_prediction("", value) + + assert body.startswith("") + assert body.rstrip().endswith("
    ") + assert "Backend" in body + assert "app.py" in body + assert "+5/-2" in body + assert comments == [] + + @patch("pr_agent.tools.pr_description.get_settings") + def test_adaptive_collapsible_triggers_above_threshold(self, mock_get_settings): + mock_get_settings.return_value = _settings(collapsible_file_list="adaptive") + obj = self._obj(gfm=True) + obj.COLLAPSIBLE_FILE_LIST_THRESHOLD = 1 # force collapsible behavior with 2 files + value = { + "backend": [ + ("a.py", "t1", "s1"), + ("b.py", "t2", "s2"), + ] + } + + body, _ = obj.process_pr_files_prediction("", value) + + assert "
    2 files" in body + + +# --------------------------------------------------------------------------- +# Round-trip: process_description recovers structured files from rendering +# --------------------------------------------------------------------------- +class TestRoundTripWithProcessDescription: + @patch("pr_agent.tools.pr_description.get_settings") + def test_walkthrough_table_round_trips_through_process_description(self, mock_get_settings): + mock_get_settings.return_value = _settings(collapsible_file_list=False) + obj = _make_instance() + diff = FilePatchInfo("", "", "", "src/app.py") + diff.num_plus_lines = 3 + diff.num_minus_lines = 1 + obj.git_provider = MagicMock() + obj.git_provider.is_supported.side_effect = lambda cap: cap == "gfm_markdown" + obj.git_provider.get_diff_files.return_value = [diff] + obj.git_provider.get_line_link.return_value = "https://example/blob/main/src/app.py#L1" + + value = {"backend": [("src/app.py", "Add cache", "Adds a bounded cache.")]} + table, _ = obj.process_pr_files_prediction("", value) + + full_description = ( + "Some intro text.\n\n___\n\n" + f"

    {PRDescriptionHeader.FILE_WALKTHROUGH.value}

    \n\n" + f"{table}\n\n
    \n\n___\n\nFooter" + ) + + base, files = process_description(full_description) + + assert base.startswith("Some intro text.") + # At least one structured file entry was recovered. + assert files, "expected process_description to recover at least one file entry" + recovered = files[0] + assert recovered["short_file_name"] == "app.py" + assert recovered["full_file_name"] == "src/app.py" + assert "Add cache" in recovered["short_summary"] + + def test_process_description_returns_empty_on_empty_input(self): + assert process_description("") == ("", []) + + def test_process_description_without_walkthrough_returns_full_text(self): + text = "Just a description without any walkthrough section." + base, files = process_description(text) + assert base == text + assert files == [] + + +# --------------------------------------------------------------------------- +# _prepare_file_labels edge cases not covered elsewhere +# --------------------------------------------------------------------------- +class TestPrepareFileLabelsEdgeCases: + def test_returns_empty_when_data_missing_pr_files(self): + obj = _make_instance() + obj.data = {"title": "t"} + assert obj._prepare_file_labels() == {} + + def test_returns_empty_when_data_is_not_a_dict(self): + obj = _make_instance() + obj.data = None + assert obj._prepare_file_labels() == {} + + def test_filename_quotes_are_normalized(self): + obj = _make_instance() + obj.vars = {"include_file_summary_changes": True} + obj.data = { + "pr_files": [ + { + "filename": "src/it's a \"file\".py", + "changes_title": "T", + "changes_summary": "S", + "label": "Backend", + }, + ] + } + + labels = obj._prepare_file_labels() + + # Single and double quotes in filenames are replaced with backticks; + # labels are lower-cased for grouping. 
+ assert list(labels.keys()) == ["backend"] + recovered_name = labels["backend"][0][0] + assert "'" not in recovered_name + assert '"' not in recovered_name + + +# Ensure SimpleNamespace import is used (kept for potential future fixtures); +# referenced here to avoid an unused-import warning without changing semantics. +_ = SimpleNamespace diff --git a/tests/unittest/test_pr_questions_helpers.py b/tests/unittest/test_pr_questions_helpers.py new file mode 100644 index 0000000000..f045d3fe70 --- /dev/null +++ b/tests/unittest/test_pr_questions_helpers.py @@ -0,0 +1,342 @@ +"""Focused unit tests for PRQuestions / PR_LineQuestions pure helpers. + +These tests avoid constructing the tool objects through their public +``__init__`` (which would create real git providers and a TokenHandler). +Instead, instances are built with ``__new__`` and only the attributes needed +by the method under test are populated. No live providers and no AI calls. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers.gitlab_provider import GitLabProvider +from pr_agent.tools.pr_line_questions import PR_LineQuestions +from pr_agent.tools.pr_questions import PRQuestions +from tests.unittest._settings_helpers import SENTINEL, restore_settings, snapshot_settings + + +def _make_pr_questions(question_str: str = "", prediction: str = "", git_provider=None) -> PRQuestions: + obj = PRQuestions.__new__(PRQuestions) + obj.question_str = question_str + obj.prediction = prediction + obj.vars = {} + obj.git_provider = git_provider if git_provider is not None else MagicMock() + return obj + + +def _make_line_questions() -> PR_LineQuestions: + obj = PR_LineQuestions.__new__(PR_LineQuestions) + obj.vars = {} + obj.git_provider = MagicMock() + return obj + + +# --------------------------------------------------------------------------- +# PRQuestions.parse_args +# 
--------------------------------------------------------------------------- + +class TestPRQuestionsParseArgs: + def test_joins_multiple_args(self): + pr = _make_pr_questions() + assert pr.parse_args(["why", "is", "the", "sky", "blue?"]) == "why is the sky blue?" + + def test_empty_args_returns_empty_string(self): + pr = _make_pr_questions() + assert pr.parse_args([]) == "" + assert pr.parse_args(None) == "" + + def test_single_arg(self): + pr = _make_pr_questions() + assert pr.parse_args(["hello"]) == "hello" + + +# --------------------------------------------------------------------------- +# PRQuestions.identify_image_in_comment +# --------------------------------------------------------------------------- + +class TestIdentifyImageInComment: + def test_markdown_image_extracts_url_and_sets_vars(self): + pr = _make_pr_questions( + question_str="explain this ![image](https://example.com/foo.png)" + ) + result = pr.identify_image_in_comment() + # Current contract: parses out content between the parentheses after + # the literal "![image]" marker (strips surrounding parens). + assert result == "https://example.com/foo.png" + assert pr.vars["img_path"] == "https://example.com/foo.png" + + def test_direct_image_url_png(self): + pr = _make_pr_questions( + question_str="please look at https://example.com/diagram.png and answer" + ) + result = pr.identify_image_in_comment() + # Current behavior captures everything from "https://" to end of string + # (including any trailing text). We assert the prefix / contains the URL, + # rather than the exact full match, to remain robust to that quirk. 
+ assert result.startswith("https://example.com/diagram.png") + assert pr.vars["img_path"] == result + + def test_direct_image_url_jpg(self): + pr = _make_pr_questions( + question_str="see https://example.com/screen.jpg" + ) + result = pr.identify_image_in_comment() + assert result.startswith("https://example.com/screen.jpg") + assert "img_path" in pr.vars + + def test_no_image_returns_empty_and_does_not_set_vars(self): + pr = _make_pr_questions(question_str="just a plain text question") + result = pr.identify_image_in_comment() + assert result == "" + assert "img_path" not in pr.vars + + def test_https_without_image_extension_returns_empty(self): + pr = _make_pr_questions(question_str="see https://example.com/docs.html") + result = pr.identify_image_in_comment() + assert result == "" + assert "img_path" not in pr.vars + + +# --------------------------------------------------------------------------- +# PRQuestions._prepare_pr_answer +# --------------------------------------------------------------------------- + +class TestPreparePrAnswer: + def test_wraps_answer_with_ask_answer_headers(self): + pr = _make_pr_questions( + question_str="why?", + prediction="because reasons", + git_provider=MagicMock(), # not GitLab + ) + out = pr._prepare_pr_answer() + assert "### **Ask**❓" in out + assert "why?" in out + assert "### **Answer:**" in out + assert "because reasons" in out + + def test_sanitizes_leading_slash(self): + pr = _make_pr_questions( + question_str="q", prediction="/merge looks fine", git_provider=MagicMock() + ) + out = pr._prepare_pr_answer() + # Leading "/" should have been prefixed with a space so the answer + # does not look like a slash command to the host platform. 
+ assert "\n /merge looks fine" in out + assert "\n/merge" not in out + + def test_sanitizes_newline_slash(self): + pr = _make_pr_questions( + question_str="q", prediction="hello\n/close now", git_provider=MagicMock() + ) + out = pr._prepare_pr_answer() + assert "\n /close now" in out + assert "\n/close" not in out + + def test_sanitizes_carriage_return_slash(self): + pr = _make_pr_questions( + question_str="q", prediction="hello\r/close", git_provider=MagicMock() + ) + out = pr._prepare_pr_answer() + assert "\r /close" in out + assert "\r/close" not in out + + def test_non_gitlab_provider_does_not_apply_gitlab_protections(self): + # Use a non-GitLab provider; a model answer that *does* contain a + # quick-action substring like "/merge" must still come through as a + # (sanitized) answer, NOT be replaced with the GitLab error string. + pr = _make_pr_questions( + question_str="q", prediction="/merge would be premature", git_provider=MagicMock() + ) + out = pr._prepare_pr_answer() + assert "Model answer contains GitHub quick actions" not in out + assert "would be premature" in out + + def test_gitlab_provider_blocks_quick_actions(self): + gitlab_provider = GitLabProvider.__new__(GitLabProvider) + pr = _make_pr_questions( + question_str="q", + prediction="/merge this please", + git_provider=gitlab_provider, + ) + out = pr._prepare_pr_answer() + assert "Model answer contains GitHub quick actions" in out + + def test_gitlab_provider_passes_through_safe_text(self): + gitlab_provider = GitLabProvider.__new__(GitLabProvider) + pr = _make_pr_questions( + question_str="q", + prediction="this change looks correct", + git_provider=gitlab_provider, + ) + out = pr._prepare_pr_answer() + assert "this change looks correct" in out + assert "Model answer contains GitHub quick actions" not in out + + +# --------------------------------------------------------------------------- +# PRQuestions.gitlab_protections +# 
--------------------------------------------------------------------------- + +class TestGitlabProtections: + @pytest.mark.parametrize( + "quick_action", + ["/approve", "/close", "/merge", "/reopen", "/unapprove", + "/title", "/assign", "/copy_metadata", "/target_branch"], + ) + def test_detects_each_quick_action(self, quick_action): + pr = _make_pr_questions() + result = pr.gitlab_protections(f"prefix {quick_action} suffix") + assert "GitHub quick actions" in result + + def test_passthrough_for_safe_text(self): + pr = _make_pr_questions() + safe = "everything is fine here" + assert pr.gitlab_protections(safe) == safe + + +# --------------------------------------------------------------------------- +# PR_LineQuestions.parse_args +# --------------------------------------------------------------------------- + +class TestLineQuestionsParseArgs: + def test_joins_multiple_args(self): + lq = _make_line_questions() + assert lq.parse_args(["what", "does", "this", "do"]) == "what does this do" + + def test_empty_args(self): + lq = _make_line_questions() + assert lq.parse_args([]) == "" + assert lq.parse_args(None) == "" + + +# --------------------------------------------------------------------------- +# PR_LineQuestions._load_conversation_history +# --------------------------------------------------------------------------- + +@pytest.fixture +def line_question_settings(): + """Snapshot and restore the dynaconf keys touched by these tests. + + Uses a SENTINEL-based snapshot so keys that were originally absent are + truly removed during teardown, rather than being restored as ``None``. 
+ """ + settings = get_settings() + keys = ("comment_id", "file_name", "line_end") + saved = snapshot_settings(keys) + try: + yield settings + finally: + restore_settings(saved) + + +class TestLoadConversationHistory: + def _set_required(self, settings, *, comment_id=42, file_name="src/foo.py", line_end=10): + settings.set("comment_id", comment_id) + settings.set("file_name", file_name) + settings.set("line_end", line_end) + + def test_returns_empty_when_settings_missing(self, line_question_settings): + # explicitly clear all required settings + line_question_settings.set("comment_id", "") + line_question_settings.set("file_name", "") + line_question_settings.set("line_end", "") + + lq = _make_line_questions() + # provider should not be consulted at all + lq.git_provider.get_review_thread_comments = MagicMock( + side_effect=AssertionError("provider must not be called") + ) + assert lq._load_conversation_history() == "" + + def test_returns_empty_when_only_one_required_setting_missing(self, line_question_settings): + line_question_settings.set("comment_id", 7) + line_question_settings.set("file_name", "") # missing + line_question_settings.set("line_end", 5) + + lq = _make_line_questions() + lq.git_provider.get_review_thread_comments = MagicMock( + side_effect=AssertionError("provider must not be called") + ) + assert lq._load_conversation_history() == "" + + def test_filters_empty_and_current_comment_and_formats(self, line_question_settings): + self._set_required(line_question_settings, comment_id=100) + + current = SimpleNamespace(id=100, body="this is the current comment", + user=SimpleNamespace(login="alice")) + empty = SimpleNamespace(id=101, body="", user=SimpleNamespace(login="bob")) + whitespace = SimpleNamespace(id=102, body=" \n ", + user=SimpleNamespace(login="carol")) + good1 = SimpleNamespace(id=103, body="first reply", + user=SimpleNamespace(login="dave")) + good2 = SimpleNamespace(id=104, body="second reply", + user=SimpleNamespace(login="erin")) + + 
lq = _make_line_questions() + lq.git_provider.get_review_thread_comments = MagicMock( + return_value=[current, empty, whitespace, good1, good2] + ) + + out = lq._load_conversation_history() + assert out == "1. dave: first reply\n2. erin: second reply" + + def test_user_without_login_attribute_is_unknown(self, line_question_settings): + self._set_required(line_question_settings, comment_id=1) + + # user object that has no 'login' attribute at all + class _NoLoginUser: + pass + + comment = SimpleNamespace(id=2, body="anonymous reply", user=_NoLoginUser()) + + lq = _make_line_questions() + lq.git_provider.get_review_thread_comments = MagicMock(return_value=[comment]) + + out = lq._load_conversation_history() + assert out == "1. Unknown: anonymous reply" + + def test_provider_exception_returns_empty_without_raising(self, line_question_settings): + self._set_required(line_question_settings, comment_id=1) + + lq = _make_line_questions() + lq.git_provider.get_review_thread_comments = MagicMock( + side_effect=RuntimeError("boom") + ) + + # must not propagate the exception + assert lq._load_conversation_history() == "" + + def test_only_filtered_comments_returns_empty(self, line_question_settings): + self._set_required(line_question_settings, comment_id=10) + + # everything in the thread is either the current comment or empty + current = SimpleNamespace(id=10, body="current", user=SimpleNamespace(login="u")) + blank = SimpleNamespace(id=11, body="", user=SimpleNamespace(login="u")) + + lq = _make_line_questions() + lq.git_provider.get_review_thread_comments = MagicMock( + return_value=[current, blank] + ) + assert lq._load_conversation_history() == "" + + +def test_line_question_settings_teardown_restores_sentinel_for_missing_keys(): + """Run the fixture manually and verify keys absent before are absent after.""" + settings = get_settings() + key = "comment_id" + # Make sure key is genuinely absent on entry. 
+ if settings.get(key, SENTINEL) is not SENTINEL: + restore_settings({key: SENTINEL}) + assert settings.get(key, SENTINEL) is SENTINEL + + saved = snapshot_settings((key,)) + try: + settings.set(key, 999) + assert settings.get(key) == 999 + finally: + restore_settings(saved) + + assert settings.get(key, SENTINEL) is SENTINEL diff --git a/tests/unittest/test_retry_with_fallback_models.py b/tests/unittest/test_retry_with_fallback_models.py new file mode 100644 index 0000000000..e0dd3dfc8a --- /dev/null +++ b/tests/unittest/test_retry_with_fallback_models.py @@ -0,0 +1,200 @@ +import asyncio + +import pytest + +from pr_agent.algo.pr_processing import retry_with_fallback_models +from pr_agent.algo.utils import ModelType +from pr_agent.config_loader import get_settings +from tests.unittest._settings_helpers import SENTINEL, restore_settings, snapshot_settings + +_TRACKED_KEYS = ( + "config.model", + "config.model_weak", + "config.model_reasoning", + "config.fallback_models", + "openai.deployment_id", + "openai.fallback_deployments", +) + + +def _snapshot_settings(): + return snapshot_settings(_TRACKED_KEYS) + + +def _restore_settings(snapshot): + restore_settings(snapshot) + + +def test_primary_model_success_invoked_once_and_returns_value(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "primary-model") + get_settings().set("config.fallback_models", ["fallback-1", "fallback-2"]) + get_settings().set("openai.deployment_id", None) + get_settings().set("openai.fallback_deployments", []) + + calls = [] + + async def fake_f(model): + calls.append(model) + return "primary-result" + + result = asyncio.run(retry_with_fallback_models(fake_f)) + + assert result == "primary-result" + assert calls == ["primary-model"] + finally: + _restore_settings(snapshot) + + +def test_primary_fails_fallback_succeeds(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "primary-model") + get_settings().set("config.fallback_models", 
["fallback-1", "fallback-2"]) + get_settings().set("openai.deployment_id", None) + get_settings().set("openai.fallback_deployments", []) + + calls = [] + + async def fake_f(model): + calls.append(model) + if model == "primary-model": + raise RuntimeError("primary failed") + return f"ok:{model}" + + result = asyncio.run(retry_with_fallback_models(fake_f)) + + assert result == "ok:fallback-1" + assert calls == ["primary-model", "fallback-1"] + finally: + _restore_settings(snapshot) + + +def test_all_models_fail_raises_with_aggregate_message_and_cause(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "primary-model") + get_settings().set("config.fallback_models", ["fallback-1"]) + get_settings().set("openai.deployment_id", None) + get_settings().set("openai.fallback_deployments", []) + + last_error = ValueError("last failure") + attempted = [] + + async def fake_f(model): + attempted.append(model) + if model == "fallback-1": + raise last_error + raise RuntimeError("primary failure") + + with pytest.raises(Exception) as exc_info: + asyncio.run(retry_with_fallback_models(fake_f)) + + assert attempted == ["primary-model", "fallback-1"] + assert "Failed to generate prediction with any model" in str(exc_info.value) + # Production code uses `raise ... from e`, so the last failure should be chained. 
+ assert exc_info.value.__cause__ is last_error + finally: + _restore_settings(snapshot) + + +def test_deployment_id_updated_per_attempt(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "primary-model") + get_settings().set("config.fallback_models", ["fallback-1", "fallback-2"]) + get_settings().set("openai.deployment_id", "deployment-primary") + get_settings().set( + "openai.fallback_deployments", + ["deployment-fb1", "deployment-fb2"], + ) + + observed = [] + + async def fake_f(model): + observed.append( + (model, get_settings().get("openai.deployment_id", None)) + ) + if model != "fallback-1": + raise RuntimeError(f"fail for {model}") + return "fallback-ok" + + result = asyncio.run(retry_with_fallback_models(fake_f)) + + assert result == "fallback-ok" + assert observed == [ + ("primary-model", "deployment-primary"), + ("fallback-1", "deployment-fb1"), + ] + finally: + _restore_settings(snapshot) + + +def test_weak_model_type_uses_weak_setting_and_forwards_identifier(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "regular-model") + get_settings().set("config.model_weak", "weak-model-id") + get_settings().set("config.fallback_models", []) + get_settings().set("openai.deployment_id", None) + get_settings().set("openai.fallback_deployments", []) + + calls = [] + + async def fake_f(model): + calls.append(model) + return model + + result = asyncio.run( + retry_with_fallback_models(fake_f, model_type=ModelType.WEAK) + ) + + assert result == "weak-model-id" + assert calls == ["weak-model-id"] + finally: + _restore_settings(snapshot) + + +def test_reasoning_model_type_uses_reasoning_setting(): + snapshot = _snapshot_settings() + try: + get_settings().set("config.model", "regular-model") + get_settings().set("config.model_reasoning", "reasoning-model-id") + get_settings().set("config.fallback_models", []) + get_settings().set("openai.deployment_id", None) + get_settings().set("openai.fallback_deployments", 
[]) + + calls = [] + + async def fake_f(model): + calls.append(model) + return model + + result = asyncio.run( + retry_with_fallback_models(fake_f, model_type=ModelType.REASONING) + ) + + assert result == "reasoning-model-id" + assert calls == ["reasoning-model-id"] + finally: + _restore_settings(snapshot) + + +def test_restore_settings_truly_removes_originally_missing_dotted_keys(): + """Regression: SENTINEL-snapshotted dotted leaves must be removed, not left behind.""" + settings = get_settings() + key = "openai.fallback_deployments" + # Ensure key is absent on entry; if a previous test leaked it, clean it. + if settings.get(key, SENTINEL) is not SENTINEL: + _restore_settings({key: SENTINEL}) + assert settings.get(key, SENTINEL) is SENTINEL + + snapshot = _snapshot_settings() + try: + settings.set(key, ["leaked-deployment"]) + assert settings.get(key) == ["leaked-deployment"] + finally: + _restore_settings(snapshot) + + assert settings.get(key, SENTINEL) is SENTINEL diff --git a/tests/unittest/test_ticket_extraction_async.py b/tests/unittest/test_ticket_extraction_async.py new file mode 100644 index 0000000000..b6d819b859 --- /dev/null +++ b/tests/unittest/test_ticket_extraction_async.py @@ -0,0 +1,485 @@ +""" +Unit tests for async ticket extraction & caching in +``pr_agent.tools.ticket_pr_compliance_check``. + +These tests are deterministic and fake-provider based — no live API or +network access is performed. 
+""" +import asyncio + +import pytest + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import AzureDevopsProvider, GithubProvider +from pr_agent.tools import ticket_pr_compliance_check as tpc +from pr_agent.tools.ticket_pr_compliance_check import ( + extract_and_cache_pr_tickets, + extract_tickets, +) +from tests.unittest._settings_helpers import restore_settings, snapshot_settings + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + +class _FakeLabel: + def __init__(self, name): + self.name = name + + +class _FakeIssue: + def __init__(self, number, title="t", body="b", labels=None): + self.number = number + self.title = title + self.body = body + self.labels = labels if labels is not None else [] + + +class _FakeRepoObj: + """Mimics PyGithub Repository.get_issue lookup behaviour.""" + + def __init__(self, issues_by_number=None, raise_for=None): + self._issues = issues_by_number or {} + self._raise_for = raise_for or set() + + def get_issue(self, number): + if number in self._raise_for: + raise RuntimeError(f"boom for issue {number}") + if number not in self._issues: + raise KeyError(f"unknown issue {number}") + return self._issues[number] + + +def _make_github_provider( + *, + user_description="", + branch="main", + repo="org/repo", + base_url_html="https://github.com", + repo_obj=None, + sub_issues_map=None, + sub_issues_raises=False, +): + """Build a GithubProvider that passes ``isinstance`` checks without __init__.""" + provider = GithubProvider.__new__(GithubProvider) + provider.repo = repo + provider.base_url_html = base_url_html + provider.repo_obj = repo_obj + provider.get_user_description = lambda: user_description + provider.get_pr_branch = lambda: branch + + sub_issues_map = sub_issues_map or {} + + def _fetch_sub_issues(ticket_url): + if sub_issues_raises: + raise RuntimeError("sub-issue fetch 
failed") + return sub_issues_map.get(ticket_url, []) + + provider.fetch_sub_issues = _fetch_sub_issues + return provider + + +def _make_azure_provider(work_items): + provider = AzureDevopsProvider.__new__(AzureDevopsProvider) + provider.get_linked_work_items = lambda: work_items + return provider + + +# --------------------------------------------------------------------------- +# Settings snapshot helper +# --------------------------------------------------------------------------- + +@pytest.fixture +def settings_snapshot(): + """Snapshot and restore settings keys mutated by these tests. + + Uses the shared sentinel-based helpers so that keys originally absent + (including the dotted ``pr_reviewer.require_ticket_analysis_review`` + leaf) are truly removed on restore — never left as a ``None`` value + that would leak into subsequent tests. + """ + s = get_settings() + snapshot = snapshot_settings( + ["related_tickets", "pr_reviewer.require_ticket_analysis_review"] + ) + # Reset to known defaults for each test + s.set("related_tickets", []) + s.set("pr_reviewer.require_ticket_analysis_review", False) + try: + yield s + finally: + restore_settings(snapshot) + + +# --------------------------------------------------------------------------- +# Scenario 1: GitHub extraction merges description + branch, dedupes, caps +# --------------------------------------------------------------------------- + +class TestGithubExtractionMerging: + def test_branch_extraction_contributes_ticket_not_in_description(self, settings_snapshot): + # Description mentions only #1; branch contributes #2. Without branch + # extraction the result would be [1]; with it, [1, 2] (description first). 
+ desc = "Fixes #1" + repo_obj = _FakeRepoObj({ + 1: _FakeIssue(1, title="One", body="body1"), + 2: _FakeIssue(2, title="Two", body="body2"), + }) + provider = _make_github_provider( + user_description=desc, + branch="feature/2-dup", + repo_obj=repo_obj, + ) + result = asyncio.run(extract_tickets(provider)) + assert result is not None + ids = [t["ticket_id"] for t in result] + # Order is meaningful: description-derived ticket first, then branch. + assert ids == [1, 2] + + def test_branch_duplicate_is_deduped_against_description(self, settings_snapshot): + # Description references both #1 and #2; branch also points at #2. + # The branch duplicate must not produce a second entry for #2. + desc = "Fixes #1 and addresses #2" + repo_obj = _FakeRepoObj({ + 1: _FakeIssue(1, title="One", body="body1"), + 2: _FakeIssue(2, title="Two", body="body2"), + }) + provider = _make_github_provider( + user_description=desc, + branch="feature/2-dup", + repo_obj=repo_obj, + ) + result = asyncio.run(extract_tickets(provider)) + assert result is not None + urls = [t["ticket_url"] for t in result] + assert len(urls) == len(set(urls)) + ids = sorted(t["ticket_id"] for t in result) + assert ids == [1, 2] + + def test_branch_only_extraction_produces_single_ticket(self, settings_snapshot): + # Description carries no ticket references — the branch must still + # surface its issue number on its own. 
+ repo_obj = _FakeRepoObj({ + 77: _FakeIssue(77, title="From branch", body="bb"), + }) + provider = _make_github_provider( + user_description="No ticket reference here.", + branch="feature/77-add-thing", + repo_obj=repo_obj, + ) + result = asyncio.run(extract_tickets(provider)) + assert result is not None + assert len(result) == 1 + assert result[0]["ticket_id"] == 77 + assert result[0]["ticket_url"].endswith("/issues/77") + + def test_caps_total_tickets_to_three(self, settings_snapshot): + # Description has 3 explicit URLs; branch adds a 4th — total must be + # capped at 3 and the dropped one must be the branch-derived #13. + desc = ( + "See https://github.com/org/repo/issues/10 " + "and https://github.com/org/repo/issues/11 " + "and https://github.com/org/repo/issues/12" + ) + repo_obj = _FakeRepoObj({ + 10: _FakeIssue(10), + 11: _FakeIssue(11), + 12: _FakeIssue(12), + 13: _FakeIssue(13), + }) + provider = _make_github_provider( + user_description=desc, + branch="feature/13-extra", + repo_obj=repo_obj, + ) + result = asyncio.run(extract_tickets(provider)) + assert result is not None + assert len(result) == 3 + ids = sorted(t["ticket_id"] for t in result) + # The branch-derived #13 must be the one dropped: description tickets + # come first in the merge order, so the cap drops the trailing entry. 
+ assert ids == [10, 11, 12] + + +# --------------------------------------------------------------------------- +# Scenario 2: Long body truncation +# --------------------------------------------------------------------------- + +class TestBodyTruncation: + def test_main_issue_body_truncated_to_10000_chars_plus_ellipsis(self, settings_snapshot): + long_body = "x" * 10500 + repo_obj = _FakeRepoObj({1: _FakeIssue(1, body=long_body)}) + provider = _make_github_provider( + user_description="Fixes #1", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result and len(result) == 1 + body = result[0]["body"] + assert body.endswith("...") + assert len(body) == 10000 + len("...") + + def test_short_body_not_truncated(self, settings_snapshot): + repo_obj = _FakeRepoObj({1: _FakeIssue(1, body="short")}) + provider = _make_github_provider( + user_description="Fixes #1", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result[0]["body"] == "short" + + +# --------------------------------------------------------------------------- +# Scenario 3: get_issue failure on one ticket does not block others +# --------------------------------------------------------------------------- + +class TestGetIssueFailureIsolated: + def test_failure_on_one_issue_does_not_break_others(self, settings_snapshot): + repo_obj = _FakeRepoObj( + issues_by_number={2: _FakeIssue(2, title="Two")}, + raise_for={1}, + ) + provider = _make_github_provider( + user_description="Fixes #1 and #2", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result is not None + ids = [t["ticket_id"] for t in result] + assert ids == [2] + + +# --------------------------------------------------------------------------- +# Scenario 4 + 5: sub-issue fetch success and exception handling +# --------------------------------------------------------------------------- + +class TestSubIssues: + def 
test_sub_issue_success_populates_and_truncates(self, settings_snapshot): + long_sub_body = "y" * 10500 + repo_obj = _FakeRepoObj({ + 1: _FakeIssue(1, title="Main", body="m"), + 99: _FakeIssue(99, title="Sub", body=long_sub_body), + }) + sub_url = "https://github.com/org/repo/issues/99" + provider = _make_github_provider( + user_description="Fixes #1", + repo_obj=repo_obj, + sub_issues_map={"https://github.com/org/repo/issues/1": [sub_url]}, + ) + result = asyncio.run(extract_tickets(provider)) + assert result and len(result) == 1 + subs = result[0]["sub_issues"] + assert len(subs) == 1 + assert subs[0]["ticket_url"] == sub_url + assert subs[0]["title"] == "Sub" + assert subs[0]["body"].endswith("...") + assert len(subs[0]["body"]) == 10000 + len("...") + + def test_sub_issue_fetch_exception_yields_empty_sub_issues(self, settings_snapshot): + repo_obj = _FakeRepoObj({1: _FakeIssue(1, title="Main", body="m")}) + provider = _make_github_provider( + user_description="Fixes #1", + repo_obj=repo_obj, + sub_issues_raises=True, + ) + result = asyncio.run(extract_tickets(provider)) + assert result and len(result) == 1 + assert result[0]["sub_issues"] == [] + + def test_single_sub_issue_failure_does_not_break_others(self, settings_snapshot): + repo_obj = _FakeRepoObj( + issues_by_number={ + 1: _FakeIssue(1, title="Main"), + 99: _FakeIssue(99, title="OK", body="ok"), + }, + raise_for={50}, + ) + sub_bad = "https://github.com/org/repo/issues/50" + sub_good = "https://github.com/org/repo/issues/99" + provider = _make_github_provider( + user_description="Fixes #1", + repo_obj=repo_obj, + sub_issues_map={ + "https://github.com/org/repo/issues/1": [sub_bad, sub_good] + }, + ) + result = asyncio.run(extract_tickets(provider)) + subs = result[0]["sub_issues"] + assert [s["ticket_url"] for s in subs] == [sub_good] + + +# --------------------------------------------------------------------------- +# Scenario 6: labels — supports both object-style and string-style +# 
--------------------------------------------------------------------------- + +class TestLabelExtraction: + def test_object_labels_extracted_by_name(self, settings_snapshot): + repo_obj = _FakeRepoObj({ + 1: _FakeIssue(1, labels=[_FakeLabel("bug"), _FakeLabel("urgent")]), + }) + provider = _make_github_provider( + user_description="Fixes #1", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result[0]["labels"] == "bug, urgent" + + def test_string_labels_also_supported(self, settings_snapshot): + repo_obj = _FakeRepoObj({ + 1: _FakeIssue(1, labels=["bug", "urgent"]), + }) + provider = _make_github_provider( + user_description="Fixes #1", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result[0]["labels"] == "bug, urgent" + + def test_label_iteration_failure_yields_empty_labels(self, settings_snapshot): + class _Boom: + def __iter__(self): + raise RuntimeError("nope") + + issue = _FakeIssue(1) + issue.labels = _Boom() + repo_obj = _FakeRepoObj({1: issue}) + provider = _make_github_provider( + user_description="Fixes #1", repo_obj=repo_obj + ) + result = asyncio.run(extract_tickets(provider)) + assert result[0]["labels"] == "" + + +# --------------------------------------------------------------------------- +# Scenario 7: Azure DevOps linked work items mapping +# --------------------------------------------------------------------------- + +class TestAzureDevopsExtraction: + def test_linked_work_items_mapped_with_truncation(self, settings_snapshot): + long_body = "z" * 10500 + work_items = [ + { + "id": 1, + "url": "https://dev.azure.com/o/p/_workitems/edit/1", + "title": "WI 1", + "body": long_body, + "acceptance_criteria": "AC1", + "labels": ["a", "b"], + }, + { + "id": 2, + "url": "https://dev.azure.com/o/p/_workitems/edit/2", + "title": "WI 2", + "body": "short", + "labels": [], + }, + ] + provider = _make_azure_provider(work_items) + result = asyncio.run(extract_tickets(provider)) + assert 
isinstance(result, list) + assert len(result) == 2 + assert result[0]["ticket_id"] == 1 + assert result[0]["title"] == "WI 1" + assert result[0]["body"].endswith("...") + assert len(result[0]["body"]) == 10000 + len("...") + assert result[0]["requirements"] == "AC1" + assert result[0]["labels"] == "a, b" + assert result[1]["body"] == "short" + assert result[1]["labels"] == "" + assert result[1].get("requirements", "") == "" + + +# --------------------------------------------------------------------------- +# Scenario 11: Unsupported provider returns None per current contract +# --------------------------------------------------------------------------- + +class TestUnsupportedProvider: + def test_non_github_non_azure_provider_returns_none(self, settings_snapshot): + class _OtherProvider: + pass + + result = asyncio.run(extract_tickets(_OtherProvider())) + # Current contract: function returns implicit None for unsupported providers + assert result is None + + +# --------------------------------------------------------------------------- +# Scenarios 8-10: extract_and_cache_pr_tickets behavior +# --------------------------------------------------------------------------- + +class TestExtractAndCachePrTickets: + def test_review_setting_disabled_returns_without_provider_calls( + self, settings_snapshot + ): + settings_snapshot.set("pr_reviewer.require_ticket_analysis_review", False) + calls = {"n": 0} + + class _Tripwire: + def __getattr__(self, name): + calls["n"] += 1 + raise AttributeError( + f"Provider should not be touched (attr={name})" + ) + + vars_ = {} + result = asyncio.run(extract_and_cache_pr_tickets(_Tripwire(), vars_)) + assert result is None + assert calls["n"] == 0 + assert "related_tickets" not in vars_ + + def test_uses_existing_related_tickets_cache_without_extract( + self, settings_snapshot, monkeypatch + ): + settings_snapshot.set("pr_reviewer.require_ticket_analysis_review", True) + cached = [{"ticket_id": 42, "title": "cached"}] + 
settings_snapshot.set("related_tickets", cached) + + async def _boom(_): + raise AssertionError("extract_tickets should not be called when cache is set") + + monkeypatch.setattr(tpc, "extract_tickets", _boom) + + vars_ = {} + # Provider value irrelevant — should never be used + asyncio.run(extract_and_cache_pr_tickets(object(), vars_)) + assert vars_["related_tickets"] == cached + + def test_stores_sub_issues_before_main_issue_in_related_tickets( + self, settings_snapshot, monkeypatch + ): + settings_snapshot.set("pr_reviewer.require_ticket_analysis_review", True) + settings_snapshot.set("related_tickets", []) + + sub_a = {"ticket_url": "u/sub_a", "title": "sub_a", "body": "s1"} + sub_b = {"ticket_url": "u/sub_b", "title": "sub_b", "body": "s2"} + main_ticket = { + "ticket_id": 1, + "ticket_url": "u/main", + "title": "main", + "body": "m", + "labels": "", + "sub_issues": [sub_a, sub_b], + } + + async def _fake_extract(_): + return [main_ticket] + + monkeypatch.setattr(tpc, "extract_tickets", _fake_extract) + + vars_ = {} + asyncio.run(extract_and_cache_pr_tickets(object(), vars_)) + + # Per current production order: sub-issues are appended first, then main. 
+ stored = vars_["related_tickets"] + assert stored == [sub_a, sub_b, main_ticket] + # Settings cache is also populated + assert get_settings().get("related_tickets") == stored + + def test_no_tickets_extracted_leaves_vars_untouched( + self, settings_snapshot, monkeypatch + ): + settings_snapshot.set("pr_reviewer.require_ticket_analysis_review", True) + settings_snapshot.set("related_tickets", []) + + async def _empty(_): + return [] + + monkeypatch.setattr(tpc, "extract_tickets", _empty) + + vars_ = {} + asyncio.run(extract_and_cache_pr_tickets(object(), vars_)) + assert "related_tickets" not in vars_ diff --git a/tests/unittest/test_verify_signature.py b/tests/unittest/test_verify_signature.py new file mode 100644 index 0000000000..4de419fa93 --- /dev/null +++ b/tests/unittest/test_verify_signature.py @@ -0,0 +1,48 @@ +import hashlib +import hmac + +import pytest +from fastapi import HTTPException + +from pr_agent.servers.utils import verify_signature + + +def _sign(payload: bytes, secret: str) -> str: + return "sha256=" + hmac.new(secret.encode("utf-8"), msg=payload, digestmod=hashlib.sha256).hexdigest() + + +class TestVerifySignature: + secret = "unit-test-signing-value" + payload = b'{"action":"opened","number":1}' + + def test_valid_signature_does_not_raise(self): + signature = _sign(self.payload, self.secret) + # Should return None without raising + assert verify_signature(self.payload, self.secret, signature) is None + + @pytest.mark.parametrize("missing", [None, ""]) + def test_missing_signature_raises_403(self, missing): + with pytest.raises(HTTPException) as exc_info: + verify_signature(self.payload, self.secret, missing) + assert exc_info.value.status_code == 403 + assert "x-hub-signature-256" in exc_info.value.detail + + def test_invalid_signature_raises_403(self): + bad_signature = "sha256=" + "0" * 64 + with pytest.raises(HTTPException) as exc_info: + verify_signature(self.payload, self.secret, bad_signature) + assert exc_info.value.status_code == 
403 + assert "didn't match" in exc_info.value.detail + + def test_signature_for_different_payload_is_rejected(self): + other_payload = b'{"action":"closed","number":2}' + signature_for_other = _sign(other_payload, self.secret) + with pytest.raises(HTTPException) as exc_info: + verify_signature(self.payload, self.secret, signature_for_other) + assert exc_info.value.status_code == 403 + + def test_signature_with_wrong_secret_is_rejected(self): + signature_wrong_secret = _sign(self.payload, "other-signing-value") + with pytest.raises(HTTPException) as exc_info: + verify_signature(self.payload, self.secret, signature_wrong_secret) + assert exc_info.value.status_code == 403