Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __iter__(self):

# Download all images concurrently
image_map = self._load_images_threaded(tasks)

pre_batch = []
for task in tasks:
image_tensor = image_map.get(task.image_id)
errors = []
Expand All @@ -315,7 +315,11 @@ def __iter__(self):
}
if errors:
row["error"] = "; ".join(errors) if errors else None
yield row
pre_batch.append(row)
batch = rest_collate_fn(
pre_batch
) # Collate before yielding to GPU process
yield batch
Comment thread
carlosgjs marked this conversation as resolved.

logger.debug(f"Worker {worker_id}: Iterator finished")
except Exception as e:
Expand Down Expand Up @@ -360,7 +364,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
# Collate successful items
if successful:
result = {
"images": [item["image"] for item in successful],
"images": torch.stack([item["image"] for item in successful]),
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
Expand All @@ -377,6 +381,17 @@ def rest_collate_fn(batch: list[dict]) -> dict:
return result


def _no_op_collate_fn(batch: list[dict]) -> dict:
"""
A no-op collate function that returns the batch as-is.

This can be used when the dataset already returns batches in the desired format,
and no further collation is needed. It simply returns the input list of dicts
without modification.
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
"""
return batch[0]
Comment thread
carlosgjs marked this conversation as resolved.


def get_rest_dataloader(
job_id: int,
settings: "Settings",
Expand Down Expand Up @@ -410,7 +425,43 @@ def get_rest_dataloader(

return torch.utils.data.DataLoader(
dataset,
batch_size=settings.localization_batch_size,
# batch_size=settings.localization_batch_size,
batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here
num_workers=settings.num_workers,
collate_fn=rest_collate_fn,
collate_fn=_no_op_collate_fn,
pin_memory=True,
persistent_workers=True if settings.num_workers > 0 else False,
prefetch_factor=4,
)
Comment thread
carlosgjs marked this conversation as resolved.


class CUDAPrefetcher:
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.next_batch = None
self.preload()

def preload(self):
try:
batch = next(self.loader)
except StopIteration:
self.next_batch = None
return

with torch.cuda.stream(self.stream):
self.next_batch = {
k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}

def __iter__(self):
return self

def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
Comment thread
carlosgjs marked this conversation as resolved.
batch = self.next_batch
if batch is None:
raise StopIteration
self.preload()
return batch
Comment thread
carlosgjs marked this conversation as resolved.
37 changes: 25 additions & 12 deletions trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torchvision

from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
from trapdata.antenna.datasets import get_rest_dataloader
from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader
from trapdata.antenna.result_posting import ResultPoster
from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections
Expand Down Expand Up @@ -150,7 +150,7 @@ def _apply_binary_classification(
# Process binary classification crops
binary_crops = []
binary_valid_indices = []
to_pil = torchvision.transforms.ToPILImage()
# to_pil = torchvision.transforms.ToPILImage()
binary_transforms = binary_filter.get_transforms()

for idx, dresp in enumerate(detector_results):
Expand All @@ -165,8 +165,9 @@ def _apply_binary_classification(
)
continue
crop = image_tensor[:, y1:y2, x1:x2]
crop_pil = to_pil(crop)
crop_transformed = binary_transforms(crop_pil)
# crop_pil = to_pil(crop)
# crop_transformed = binary_transforms(crop_pil)
crop_transformed = binary_transforms(crop)
binary_crops.append(crop_transformed)
binary_valid_indices.append(idx)

Expand Down Expand Up @@ -242,8 +243,13 @@ def _process_job(
all_detections = []
_, t = log_time()
result_poster: ResultPoster | None = None
prefetcher = CUDAPrefetcher(loader) # if torch.cuda.is_available() else None
try:
for i, batch in enumerate(loader):
Comment thread
carlosgjs marked this conversation as resolved.
prefetcher.preload()
i, batch = 0, next(prefetcher)
_, t_total = log_time() # reset total time for this batch
# for i, batch in enumerate(loader):
while batch is not None:
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
cls_time = 0.0
det_time = 0.0
load_time, t = t()
Expand Down Expand Up @@ -300,7 +306,10 @@ def _process_job(

# output is dict of "boxes", "labels", "scores"
batch_output = []
to_gpu_time = 0.0
if len(images) > 0:
images = images.to(detector.device)
to_gpu_time, t = t()
batch_output = detector.predict_batch(images)

items += len(batch_output)
Expand Down Expand Up @@ -345,7 +354,7 @@ def _process_job(

# Run terminal classifier on filtered detections
classifier.reset(detections_for_terminal_classifier)
to_pil = torchvision.transforms.ToPILImage()
# to_pil = torchvision.transforms.ToPILImage()
classify_transforms = classifier.get_transforms()

# Collect and transform all crops for batched classification
Expand All @@ -363,8 +372,9 @@ def _process_job(
)
continue
crop = image_tensor[:, y1:y2, x1:x2]
crop_pil = to_pil(crop)
crop_transformed = classify_transforms(crop_pil)
# crop_pil = to_pil(crop)
# crop_transformed = classify_transforms(crop_pil)
crop_transformed = classify_transforms(crop)
crops.append(crop_transformed)
valid_indices.append(idx)

Expand Down Expand Up @@ -417,9 +427,7 @@ def _process_job(
)
)
except Exception as e:
logger.error(
f"Batch {i + 1} failed during processing: {e}", exc_info=True
)
logger.error(f"Batch {i + 1} failed during processing: {e}")
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
# Report errors back to Antenna so tasks aren't stuck in the queue
batch_results = []
for reply_subject, image_id in zip(
Expand Down Expand Up @@ -457,9 +465,12 @@ def _process_job(
processing_service_name,
)
_, t = log_time() # reset time to measure batch load time
batch_total, t_total = t_total()
logger.info(
f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s"
f"Total: {batch_total/(len(images)):.2f}s/image, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s, to GPU time: {to_gpu_time:.2f}s, "
)
Comment thread
carlosgjs marked this conversation as resolved.
batch = next(prefetcher)
i += 1

if result_poster:
# Wait for all async posts to complete before finishing the job
Expand All @@ -479,6 +490,8 @@ def _process_job(
f"max queue size: {post_metrics.max_queue_size})"
)
return did_work
except StopIteration:
pass
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
finally:
if result_poster:
result_poster.shutdown()
8 changes: 4 additions & 4 deletions trapdata/ml/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_transforms(self):
return torchvision.transforms.Compose(
[
torchvision.transforms.Resize((self.input_size, self.input_size)),
torchvision.transforms.ToTensor(),
# torchvision.transforms.ToTensor(),
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
self.normalization,
]
)
Comment thread
carlosgjs marked this conversation as resolved.
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_transforms(self):
[
# self._pad_to_square(),
torchvision.transforms.Resize((self.input_size, self.input_size)),
torchvision.transforms.ToTensor(),
# torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean, std),
]
)
Expand Down Expand Up @@ -189,7 +189,7 @@ def get_transforms(self):
return torchvision.transforms.Compose(
[
torchvision.transforms.Resize((self.input_size, self.input_size)),
torchvision.transforms.ToTensor(),
# torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean, std),
]
)
Expand Down Expand Up @@ -237,7 +237,7 @@ def get_transforms(self):
return torchvision.transforms.Compose(
[
self._pad_to_square(),
torchvision.transforms.ToTensor(),
# torchvision.transforms.ToTensor(),
torchvision.transforms.Resize(
(self.input_size, self.input_size), antialias=True # type: ignore
),
Expand Down
4 changes: 2 additions & 2 deletions trapdata/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ class Settings(BaseSettings):
default=ml.models.DEFAULT_FEATURE_EXTRACTOR
)
classification_threshold: float = 0.6
localization_batch_size: int = 8
localization_batch_size: int = 32
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
classification_batch_size: int = 20
num_workers: int = 4

# Antenna API worker settings
antenna_api_base_url: str = "http://localhost:8000/api/v2"
antenna_api_auth_token: str = ""
antenna_service_name: str = "AMI Data Companion"
antenna_api_batch_size: int = 16
antenna_api_batch_size: int = 24
Comment thread
carlosgjs marked this conversation as resolved.

@pydantic.field_validator("image_base_path", "user_data_path")
def validate_path(cls, v):
Expand Down
Loading