Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions trapdata/antenna/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run_benchmark(
batch_size: int,
gpu_batch_size: int,
service_name: str,
send_acks: bool = True,
) -> None:
"""Run the benchmark with the specified parameters.

Expand Down Expand Up @@ -132,8 +133,8 @@ def run_benchmark(
ack_result = create_empty_result(reply_subject, image_id)
ack_results.append(ack_result)

logger.info(f"Sending {len(ack_results)} acknowledgment(s)")
if ack_results:
if ack_results and send_acks:
logger.info(f"Sending {len(ack_results)} acknowledgment(s)")
# Send acknowledgments asynchronously
result_poster.post_async(
base_url=base_url,
Expand All @@ -157,7 +158,7 @@ def run_benchmark(
)
error_results.append(error_result)

if error_results:
if error_results and send_acks:
result_poster.post_async(
base_url=base_url,
auth_token=auth_token,
Expand Down Expand Up @@ -280,6 +281,11 @@ def main() -> int:
default="Performance Test",
help="Processing service name",
)
parser.add_argument(
"--skip-acks",
action="store_false",
help="Skip sending acknowledgments for processed images",
)

args = parser.parse_args()

Expand All @@ -298,6 +304,7 @@ def main() -> int:
batch_size=args.batch_size,
gpu_batch_size=args.gpu_batch_size,
service_name=args.service_name,
send_acks=args.skip_acks,
)
return 0

Expand Down
71 changes: 63 additions & 8 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
localization_batch_size (default 8)
How many images the GPU processes at once (detection). Larger =
more GPU memory. These are full-resolution images (~4K).
Async workers use antenna_api_batch_size for this.

num_workers (default 4)
DataLoader subprocesses. Each independently fetches tasks and
Expand Down Expand Up @@ -254,7 +255,7 @@ def __iter__(self):

Each API fetch returns a batch of tasks. Images for the entire batch
are downloaded concurrently using threads (see _load_images_threaded),
then yielded one at a time for the DataLoader to collate.
then yielded as a single pre-collated batch.

Yields:
Dictionary containing:
Expand Down Expand Up @@ -293,7 +294,7 @@ def __iter__(self):

# Download all images concurrently
image_map = self._load_images_threaded(tasks)

pre_batch = []
for task in tasks:
image_tensor = image_map.get(task.image_id)
errors = []
Expand All @@ -315,7 +316,11 @@ def __iter__(self):
}
if errors:
row["error"] = "; ".join(errors) if errors else None
yield row
pre_batch.append(row)
batch = rest_collate_fn(
pre_batch
) # Collate before yielding to GPU process
yield batch

logger.debug(f"Worker {worker_id}: Iterator finished")
except Exception as e:
Expand Down Expand Up @@ -360,7 +365,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
# Collate successful items
if successful:
result = {
"images": [item["image"] for item in successful],
"images": torch.stack([item["image"] for item in successful]),
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
Expand All @@ -377,6 +382,17 @@ def rest_collate_fn(batch: list[dict]) -> dict:
return result


def _no_op_collate_fn(batch: list[dict]) -> dict:
"""
A no-op collate function that unwraps a single-element batch.

This can be used when the dataset already returns batches in the desired format,
and no further collation is needed. It simply returns the input list of dicts
without modification.
"""
return batch[0]
Comment thread
carlosgjs marked this conversation as resolved.


def get_rest_dataloader(
job_id: int,
settings: "Settings",
Expand All @@ -395,8 +411,7 @@ def get_rest_dataloader(
job_id: Job ID to fetch tasks for
settings: Settings object. Relevant fields:
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call)
- localization_batch_size (images per GPU batch)
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
Expand All @@ -410,7 +425,47 @@ def get_rest_dataloader(

return torch.utils.data.DataLoader(
dataset,
batch_size=settings.localization_batch_size,
batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here
num_workers=settings.num_workers,
collate_fn=rest_collate_fn,
collate_fn=_no_op_collate_fn,
pin_memory=True,
persistent_workers=settings.num_workers > 0,
prefetch_factor=4 if settings.num_workers > 0 else None,
)
Comment thread
carlosgjs marked this conversation as resolved.


class CUDAPrefetcher:
    """Iterator that stages the next DataLoader batch on ``device`` ahead of use.

    While the caller works on batch N, batch N+1 has already been pulled from
    the loader and its tensors are being copied to ``device`` on a dedicated
    CUDA stream. Batches are expected to be dicts; tensor values are moved to
    ``device`` and every other value is passed through untouched.

    NOTE(review): ``non_blocking=True`` copies only overlap with compute when
    the source tensors live in pinned host memory — confirm the loader is
    created with ``pin_memory=True``.

    Args:
        loader: DataLoader yielding dict batches.
        device: Target device for tensor values.
    """

    def __init__(self, loader: torch.utils.data.DataLoader, device: torch.device):
        self.loader = iter(loader)
        # Side stream used only for host-to-device copies of upcoming batches.
        self.stream = torch.cuda.Stream()
        self.device = device
        self.next_batch = None
        # Start copying the first batch immediately.
        self._preload()

    def _preload(self) -> None:
        """Fetch the next batch and enqueue its transfer on the side stream.

        Sets ``self.next_batch`` to ``None`` once the loader is exhausted.
        """
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return

        with torch.cuda.stream(self.stream):
            self.next_batch = {
                key: (
                    value.to(self.device, non_blocking=True)
                    if isinstance(value, torch.Tensor)
                    else value
                )
                for key, value in batch.items()
            }

    def __iter__(self) -> "CUDAPrefetcher":
        return self

    def __next__(self) -> dict:
        batch = self.next_batch
        if batch is None:
            # Nothing is in flight, so there is no need to touch CUDA at all.
            raise StopIteration

        # Make the consumer stream wait until the staged copies have finished.
        torch.cuda.current_stream().wait_stream(self.stream)

        # The tensors were allocated on the side stream. Tell the caching
        # allocator that they are also used on the consumer stream so their
        # memory is not recycled while consumer kernels are still queued.
        for value in batch.values():
            if isinstance(value, torch.Tensor) and value.is_cuda:
                value.record_stream(torch.cuda.current_stream(value.device))

        # Kick off the transfer of the following batch before handing this one out.
        self._preload()
        return batch
Comment thread
carlosgjs marked this conversation as resolved.
19 changes: 15 additions & 4 deletions trapdata/antenna/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ def test_multiple_batches(self):
dataset = self._make_dataset(job_id=4, batch_size=2)
rows = list(dataset)

# Should get all 3 images (batch1: 2 images, batch2: 1 image)
assert len(rows) == 3
assert all(r["image"] is not None for r in rows)
# Dataset now yields pre-collated batches: batch1 (2 images), batch2 (1 image)
assert len(rows) == 2
total_images = sum(len(r["image_ids"]) for r in rows)
assert total_images == 3
assert all(r["images"] is not None for r in rows)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -272,6 +274,7 @@ def test_empty_queue(self):
100,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is False
Expand Down Expand Up @@ -300,6 +303,7 @@ def test_processes_batch_with_real_inference(self):
101,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

# Validate processing succeeded
Expand Down Expand Up @@ -339,6 +343,7 @@ def test_handles_failed_items(self):
102,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

posted_results = antenna_api_server.get_posted_results(102)
Expand Down Expand Up @@ -375,6 +380,7 @@ def test_mixed_batch_success_and_failures(self):
103,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is True
Expand Down Expand Up @@ -475,7 +481,11 @@ def test_full_workflow_with_real_inference(self):

# Step 3: Process job
result = _process_job(
pipeline_slug, 200, self._make_settings(), "Test Worker"
pipeline_slug,
200,
self._make_settings(),
"Test Worker",
device=torch.device("cpu"),
)
assert result is True

Expand Down Expand Up @@ -527,6 +537,7 @@ def test_multiple_batches_processed(self):
201,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is True
Expand Down
38 changes: 26 additions & 12 deletions trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision

from trapdata.antenna.client import get_full_service_name, get_jobs
from trapdata.antenna.datasets import get_rest_dataloader
from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader
from trapdata.antenna.result_posting import ResultPoster
from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections
Expand Down Expand Up @@ -80,7 +79,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
pipelines: List of pipeline slugs to poll for jobs.
"""
settings = read_settings()

device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
torch.cuda.set_device(gpu_id)
logger.info(
Expand Down Expand Up @@ -115,6 +114,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
job_id=job_id,
settings=settings,
processing_service_name=full_service_name,
device=device,
)
any_jobs = any_jobs or any_work_done
except Exception as e:
Expand Down Expand Up @@ -153,7 +153,6 @@ def _apply_binary_classification(
# Process binary classification crops
binary_crops = []
binary_valid_indices = []
to_pil = torchvision.transforms.ToPILImage()
binary_transforms = binary_filter.get_transforms()

for idx, dresp in enumerate(detector_results):
Expand All @@ -168,8 +167,7 @@ def _apply_binary_classification(
)
continue
crop = image_tensor[:, y1:y2, x1:x2]
crop_pil = to_pil(crop)
crop_transformed = binary_transforms(crop_pil)
crop_transformed = binary_transforms(crop)
binary_crops.append(crop_transformed)
binary_valid_indices.append(idx)

Expand Down Expand Up @@ -298,7 +296,6 @@ def _process_batch(

# Run terminal classifier on filtered detections
classifier.reset(detections_for_terminal_classifier)
to_pil = torchvision.transforms.ToPILImage()
classify_transforms = classifier.get_transforms()

# Collect and transform all crops for batched classification
Expand All @@ -317,8 +314,7 @@ def _process_batch(
)
continue
crop = image_tensor[:, y1:y2, x1:x2]
crop_pil = to_pil(crop)
crop_transformed = classify_transforms(crop_pil)
crop_transformed = classify_transforms(crop)
crops.append(crop_transformed)
valid_indices.append(idx)

Expand Down Expand Up @@ -408,6 +404,7 @@ def _process_job(
job_id: int,
settings: Settings,
processing_service_name: str,
device: torch.device | None = None,
on_batch_complete: Callable | None = None,
) -> bool:
"""Run the worker to process images from the REST API queue.
Expand All @@ -417,6 +414,7 @@ def _process_job(
job_id: Job ID to process
settings: Settings object with antenna_api_* configuration
processing_service_name: Name of the processing service
device: The device to use for processing. Auto-detected if None.
on_batch_complete: Optional callback invoked after each batch, with kwargs
batch_num (int) and items (int, cumulative items processed so far).
Returns:
Expand All @@ -436,6 +434,9 @@ def _process_job(
use_binary_filter = should_filter_detections(classifier_class)
binary_filter = None

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
torch.cuda.empty_cache()
items = 0
Expand All @@ -446,8 +447,17 @@ def _process_job(
total_detections = 0
_, t = log_time()
result_poster: ResultPoster | None = None
# Conditionally use CUDA prefetcher; fall back to plain iterator on CPU
if torch.cuda.is_available():
batch_source = CUDAPrefetcher(
loader, device
) # __init__ already calls preload()
else:
batch_source = iter(loader)

_, t_total = log_time()
try:
for i, batch in enumerate(loader):
Comment thread
carlosgjs marked this conversation as resolved.
for i, batch in enumerate(batch_source):
cls_time = 0.0
det_time = 0.0
load_time, t = t()
Expand Down Expand Up @@ -500,12 +510,16 @@ def _process_job(
batch_results,
processing_service_name,
)
_, t = log_time() # reset time to measure batch load time
batch_total, t_total = t_total()
logger.info(
f"Finished batch {i + 1}. Total items: {items}, "
f"Batch {i + 1}: {batch_total/max(n_items, 1):.2f}s/image, "
f"Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, "
f"Load time: {load_time:.2f}s"
)
Comment thread
carlosgjs marked this conversation as resolved.
(
_,
t,
) = log_time() # reset before next() call to measure next batch's load time

if on_batch_complete:
on_batch_complete(batch_num=i, items=items)
Expand Down
Loading
Loading