-
Notifications
You must be signed in to change notification settings - Fork 14
feat(ml): propagate pipeline config through NATS pull-mode tasks #1279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -567,6 +567,57 @@ def test_tasks_endpoint_without_pipeline(self): | |
| self.assertEqual(resp.status_code, 400) | ||
| self.assertIn("pipeline", resp.json()[0].lower()) | ||
|
|
||
| def test_queue_images_to_nats_embeds_pipeline_config(self): | ||
| """Tasks queued to NATS carry the pipeline config (including project overrides).""" | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| from ami.ml.models import ProjectPipelineConfig | ||
| from ami.ml.schemas import PipelineRequestConfigParameters | ||
|
|
||
| pipeline = self._create_pipeline() | ||
| pipeline.default_config = PipelineRequestConfigParameters({"example_param": "default"}) | ||
| pipeline.save() | ||
| # _create_pipeline already called pipeline.projects.add(self.project) which | ||
| # created a ProjectPipelineConfig row; update it rather than creating a duplicate. | ||
| ProjectPipelineConfig.objects.filter(project=self.project, pipeline=pipeline).update( | ||
| config={"example_param": "project_override"} | ||
| ) | ||
|
Comment on lines
+578
to
+584
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Test does not verify default-key preservation in merge. Line 578 and Line 583 only exercise an overridden key, so Line 619 cannot catch regressions where defaults are dropped instead of merged. Add one default-only key and assert it survives in `task.config`.
Suggested test tightening:
- pipeline.default_config = PipelineRequestConfigParameters({"example_param": "default"})
+ pipeline.default_config = PipelineRequestConfigParameters(
+ {"example_param": "default", "default_only_param": 7}
+ )
@@
self.assertIsNotNone(task.config)
self.assertEqual(task.config.get("example_param"), "project_override")
+ self.assertEqual(task.config.get("default_only_param"), 7)
Also applies to: 619-619
🤖 Prompt for AI Agents |
||
|
|
||
| job = self._create_ml_job("Config propagation test", pipeline) | ||
| job.dispatch_mode = JobDispatchMode.ASYNC_API | ||
| job.status = JobState.STARTED | ||
| job.save(update_fields=["dispatch_mode", "status"]) | ||
|
|
||
| image = SourceImage.objects.create( | ||
| path="config_test.jpg", | ||
| public_base_url="http://example.com", | ||
| project=self.project, | ||
| ) | ||
|
|
||
| published_tasks = [] | ||
|
|
||
| mock_manager = AsyncMock() | ||
| mock_manager.log_async = AsyncMock() | ||
|
|
||
| async def capture_publish(job_id, data): | ||
| published_tasks.append(data) | ||
| return True | ||
|
|
||
| mock_manager.publish_task = capture_publish | ||
|
|
||
| mock_ctx = MagicMock() | ||
| mock_ctx.__aenter__ = AsyncMock(return_value=mock_manager) | ||
| mock_ctx.__aexit__ = AsyncMock(return_value=False) | ||
|
|
||
| with patch("ami.ml.orchestration.jobs.TaskQueueManager", return_value=mock_ctx): | ||
| with patch("ami.ml.orchestration.jobs.AsyncJobStateManager"): | ||
| queue_images_to_nats(job, [image]) | ||
|
|
||
| self.assertEqual(len(published_tasks), 1) | ||
| task = published_tasks[0] | ||
| self.assertIsNotNone(task.config) | ||
| self.assertEqual(task.config.get("example_param"), "project_override") | ||
|
|
||
| def test_result_endpoint_stub(self): | ||
| """Test the result endpoint accepts results (stubbed implementation).""" | ||
| pipeline = self._create_pipeline() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -513,6 +513,74 @@ def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_cla | |
| # Assert: Task was acknowledged despite missing job | ||
| mock_manager.acknowledge_task.assert_called_once_with(reply_subject) | ||
|
|
||
| @patch("ami.jobs.tasks.TaskQueueManager") | ||
| def test_pipeline_config_drift_logs_warning(self, mock_manager_class): | ||
| """ | ||
| When a worker echoes a pipeline config that doesn't match what Antenna | ||
| would resolve today (e.g. ProjectPipelineConfig was edited mid-job, or | ||
| the worker is stale), process_nats_pipeline_result must log a warning. | ||
| Drift is logged but not enforced. | ||
| """ | ||
| from ami.ml.schemas import PipelineRequestConfigParameters | ||
|
|
||
| self._setup_mock_nats(mock_manager_class) | ||
|
|
||
| self.pipeline.default_config = PipelineRequestConfigParameters({"example_config_param": 5}) | ||
| self.pipeline.save() | ||
|
|
||
| # Worker echoes a config that doesn't match the current default | ||
| success_data = PipelineResultsResponse( | ||
| pipeline="test-pipeline", | ||
| algorithms={}, | ||
| total_time=1.0, | ||
| source_images=[SourceImageResponse(id=str(self.images[0].pk), url="http://example.com/test_image_0.jpg")], | ||
| detections=[], | ||
| errors=None, | ||
| config={"example_config_param": 99}, | ||
| ).dict() | ||
|
|
||
| with self.assertLogs(level="WARNING") as cm: | ||
| process_nats_pipeline_result.apply( | ||
| kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.drift"} | ||
| ) | ||
|
|
||
| self.assertTrue( | ||
| any("Pipeline config drift" in msg for msg in cm.output), | ||
| f"Expected drift warning in logs, got: {cm.output}", | ||
| ) | ||
|
|
||
| @patch("ami.jobs.tasks.TaskQueueManager") | ||
| def test_pipeline_config_match_does_not_warn(self, mock_manager_class): | ||
| """When echoed config matches current pipeline config, no drift warning is logged.""" | ||
| from ami.ml.schemas import PipelineRequestConfigParameters | ||
|
|
||
| self._setup_mock_nats(mock_manager_class) | ||
|
|
||
| self.pipeline.default_config = PipelineRequestConfigParameters({"example_config_param": 5}) | ||
| self.pipeline.save() | ||
|
|
||
| success_data = PipelineResultsResponse( | ||
| pipeline="test-pipeline", | ||
| algorithms={}, | ||
| total_time=1.0, | ||
| source_images=[SourceImageResponse(id=str(self.images[0].pk), url="http://example.com/test_image_0.jpg")], | ||
| detections=[], | ||
| errors=None, | ||
| config={"example_config_param": 5}, | ||
| ).dict() | ||
|
|
||
| # assertLogs requires at least one log; capture INFO so the test doesn't | ||
| # spuriously fail when no WARNING is emitted (the assertion below). | ||
| with self.assertLogs(level="INFO") as cm: | ||
| process_nats_pipeline_result.apply( | ||
| kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.match"} | ||
| ) | ||
|
Comment on lines
+532
to
+577
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Drift tests currently run through a failing save path (missing detector). These two tests construct a “successful” result but the fixture pipeline has no detection algorithm, so saving the result fails and the tests exercise the failing save path rather than the success path.
Proposed fix:
def test_pipeline_config_drift_logs_warning(self, mock_manager_class):
@@
- self._setup_mock_nats(mock_manager_class)
+ mock_manager = self._setup_mock_nats(mock_manager_class)
@@
self.pipeline.default_config = PipelineRequestConfigParameters({"example_config_param": 5})
self.pipeline.save()
+ self.pipeline.algorithms.add(
+ Algorithm.objects.create(
+ name="drift-detector",
+ key="drift-detector",
+ task_type=AlgorithmTaskType.LOCALIZATION,
+ )
+ )
@@
with self.assertLogs(level="WARNING") as cm:
process_nats_pipeline_result.apply(
kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.drift"}
)
+ mock_manager.acknowledge_task.assert_called_once_with("reply.drift")
@@
def test_pipeline_config_match_does_not_warn(self, mock_manager_class):
@@
- self._setup_mock_nats(mock_manager_class)
+ mock_manager = self._setup_mock_nats(mock_manager_class)
@@
self.pipeline.default_config = PipelineRequestConfigParameters({"example_config_param": 5})
self.pipeline.save()
+ self.pipeline.algorithms.add(
+ Algorithm.objects.create(
+ name="match-detector",
+ key="match-detector",
+ task_type=AlgorithmTaskType.LOCALIZATION,
+ )
+ )
@@
with self.assertLogs(level="INFO") as cm:
process_nats_pipeline_result.apply(
kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.match"}
)
+ mock_manager.acknowledge_task.assert_called_once_with("reply.match")
🤖 Prompt for AI Agents |
||
|
|
||
| self.assertFalse( | ||
| any("Pipeline config drift" in msg for msg in cm.output), | ||
| f"Did not expect drift warning for matching config, got: {cm.output}", | ||
| ) | ||
|
|
||
|
|
||
| class TestTaskFailureGuard(TransactionTestCase): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,6 +86,8 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): | |
| job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job.pk}'") | ||
|
|
||
| # Prepare all messages outside of async context to avoid Django ORM issues | ||
| pipeline_config = job.pipeline.get_config(project_id=job.project.pk) if job.pipeline else None | ||
|
|
||
| tasks: list[tuple[int, PipelineProcessingTask]] = [] | ||
| image_ids = [] | ||
| skipped_count = 0 | ||
|
|
@@ -101,6 +103,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): | |
| id=image_id, | ||
| image_id=image_id, | ||
| image_url=image_url, | ||
| config=pipeline_config, | ||
| ) | ||
|
Comment on lines
89
to
107
|
||
| tasks.append((image.pk, task)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Proposal: move pipeline config from per-task to job-level fetch | ||
|
|
||
| **Status**: draft, not implemented. Tracks a follow-up to PRs #1279 (Antenna) + ADC #146. | ||
|
|
||
| ## Context | ||
|
|
||
| After #1279, every `PipelineProcessingTask` published to NATS carries a `config` field. All tasks within a single job share the same config — ADC's `rest_collate_fn` already encodes this assumption with `successful[0].get("config")` — so embedding the config redundantly in every task is informationally wasteful and structurally incorrect. | ||
|
|
||
| This is fine for the current shape of pipeline configs (a handful of small primitives like `example_config_param: int`). It will not stay fine once configs grow to include things like: | ||
|
|
||
| - A taxa allow-list for a CLIP-style classifier (potentially thousands of names) | ||
| - Per-stage hyperparameter overrides | ||
| - Feature-flag toggles for "roll up taxa on the ADC side", `include_features`, `include_softmax` | ||
| - Per-job model variant selection or threshold curves | ||
|
|
||
| A job with N images and a config of size M ships N×M bytes through NATS today, when it should ship M bytes once. | ||
|
|
||
| ## Proposed shape | ||
|
|
||
| ### Pull mode (NATS) | ||
|
|
||
| Add a job metadata fetch that ADC calls **once per job**, before iterating tasks: | ||
|
|
||
| ``` | ||
| GET /api/v2/jobs/{job_id}/ | ||
| ``` | ||
|
|
||
|
Comment on lines
+24
to
+27
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Add a language to the fenced code block. This block is missing a fence language and will trip markdownlint MD040 in repos where docs linting is enforced.
Proposed fix:
-```
+```http
GET /api/v2/jobs/{job_id}/
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 24-24: Fenced code blocks should have a language specified (MD040, fenced-code-language)
🤖 Prompt for AI Agents |
||
| Response includes (among existing fields) the resolved pipeline config: | ||
|
|
||
| ```json | ||
| { | ||
| "id": 7, | ||
| "pipeline_slug": "global_moths_2024", | ||
| "config": { | ||
| "include_features": true, | ||
| "taxa_allowlist": ["Lepidoptera", ...] | ||
| }, | ||
| ... | ||
| } | ||
| ``` | ||
|
|
||
| ADC fetches and caches this once per job (already does `_fetch_tasks` per-job; add a sibling `_fetch_job_meta`). `PipelineProcessingTask.config` becomes vestigial and can be deprecated in a follow-up after both sides have shipped. | ||
|
|
||
| The existing `AntennaJobsListResponse` only returns `id` and `pipeline_slug`; this would be a separate detail endpoint, not a change to the list endpoint. | ||
|
|
||
| ### Push mode (sync HTTP `POST /process`) | ||
|
|
||
| The push path is request/response, not task-fetched, so there's no equivalent "fetch once" moment for the worker. Two reasonable options: | ||
|
|
||
| 1. **Keep config on each `PipelineRequest`** (status quo). Simple. Wastes bandwidth on the wire, but most push-mode requests are small (single image or a handful), so the overhead is bounded. No worker-side change. | ||
|
|
||
| 2. **Send config in a job-init handshake**. The push API would need a notion of a "job" that workers can register against, which they don't have today — push-mode services are stateless w.r.t. jobs. Adding job state to push-mode workers is a substantially bigger change (cache invalidation, eviction, multi-tenant memory growth) and not worth it for the current config sizes. | ||
|
|
||
| Recommendation: **keep option 1 (config on each `PipelineRequest`) for push, and use the job-level fetch described under "Pull mode (NATS)" for pull.** Push-mode requests are already independent — there's no "session" to attach config to without inventing one. Pull-mode has a natural job boundary already; reuse it. | ||
|
|
||
| ## Migration | ||
|
|
||
| Pull mode: | ||
| 1. Antenna ships a `GET /api/v2/jobs/{id}/` endpoint that includes `config` in the response. | ||
| 2. ADC adds `_fetch_job_meta()` and reads `config` from there; falls back to `task.config` if the meta endpoint returns 404 (older Antenna). | ||
| 3. After ADC ≥ this version is the floor, Antenna removes the per-task `config` field. | ||
|
|
||
| Push mode: no change. | ||
|
|
||
| ## Costs of doing it now vs. later | ||
|
|
||
| Doing it now: bigger PR than #1279, but the schema isn't fossilized yet — only one consumer (ADC) and zero data persisted with the per-task shape. | ||
|
|
||
| Doing it later: every external worker that adopts the per-task `config` field becomes a backwards-compat constraint. The longer the per-task shape is "the contract," the more painful the migration. | ||
|
|
||
| The audit log added in this PR (`ami/jobs/tasks.py:process_nats_pipeline_result`) becomes simpler under job-level config: compare once at job start, not on every result. | ||
|
|
||
| ## Out of scope | ||
|
|
||
| - Authentication / permissions on the new job meta endpoint (use whatever ADC already uses for `/tasks/`) | ||
| - Schema versioning of `config` itself (separate problem; matters more once configs start carrying user-editable structures like taxa lists) | ||
| - UI for editing `ProjectPipelineConfig` (already exists in admin; no change needed) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -76,6 +76,28 @@ PipelineChoice = typing.Literal[ | |||||||||||||
| "new-pipeline-slug", | ||||||||||||||
| ] | ||||||||||||||
| ``` | ||||||||||||||
| ## NATS Pull-Mode (Async API) Contract | ||||||||||||||
|
|
||||||||||||||
| Processing services that operate in pull-mode (fetching tasks from Antenna via `POST /api/v2/jobs/{id}/tasks/`) receive `PipelineProcessingTask` objects. Each task now includes a `config` field carrying the pipeline configuration for that job: | ||||||||||||||
|
|
||||||||||||||
| ```json | ||||||||||||||
| { | ||||||||||||||
| "id": "42", | ||||||||||||||
| "image_id": "42", | ||||||||||||||
| "image_url": "https://...", | ||||||||||||||
| "reply_subject": "antenna.results.job.7.img.42", | ||||||||||||||
| "config": { | ||||||||||||||
| "example_config_param": 3 | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| ``` | ||||||||||||||
|
|
||||||||||||||
| `config` mirrors `PipelineRequest.config` from the synchronous HTTP path. It is derived from the pipeline's `default_config` merged with any per-project `ProjectPipelineConfig` override. It may be `null` if no config is set. | ||||||||||||||
|
|
||||||||||||||
| Workers should read `config` from each task and apply it to their processing. If `config` is absent or null, fall back to worker-level defaults (e.g. environment variables). | ||||||||||||||
|
Comment on lines
+95
to
+97
|
||||||||||||||
| `config` mirrors `PipelineRequest.config` from the synchronous HTTP path. It is derived from the pipeline's `default_config` merged with any per-project `ProjectPipelineConfig` override. It may be `null` if no config is set. | |
| Workers should read `config` from each task and apply it to their processing. If `config` is absent or null, fall back to worker-level defaults (e.g. environment variables). | |
| `config` mirrors `PipelineRequest.config` from the synchronous HTTP path. It is derived from the pipeline's `default_config` merged with any per-project `ProjectPipelineConfig` override. In the normal Antenna code path, it is an object; if no defaults or overrides are set, it may be an empty object. | |
| Workers should read `config` from each task and apply it to their processing. If `config` is unexpectedly absent or `null` (for example, due to a malformed or legacy payload), fall back to worker-level defaults (e.g. environment variables). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid logging full pipeline config payloads in drift warnings.
This logs complete worker/current config values on every drift event. In practice that can expose sensitive fields and create very large per-job logs. Log a compact diff summary (e.g., differing keys/count) instead of full payloads.
Proposed fix
🤖 Prompt for AI Agents