Skip to content

Commit

Permalink
Updates to collection sampling methods (#717)
Browse files Browse the repository at this point in the history
* fix: handle queries with 0 results when populating collections

* fix: better handle 0 result queries in all sampling methods

* fix: handle midnight boudary for queries with both start & end hour

* feat: support date fields for collection sampling methods

* feat: randomize collections with max_num by default

* fix: explain changes to max_num field in UI
  • Loading branch information
mihow authored Feb 6, 2025
1 parent 178231a commit 3181753
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 34 deletions.
28 changes: 28 additions & 0 deletions ami/base/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import datetime

from rest_framework import serializers


class DateStringField(serializers.CharField):
"""
Field that validates and stores dates as YYYY-MM-DD strings.
Needed for storing dates as strings in JSON fields but keep validation.
"""

def to_internal_value(self, value: str | None) -> str | None:
if value is None:
return None

try:
# Validate the date format by parsing it
datetime.datetime.strptime(value, "%Y-%m-%d")
return value
except ValueError as e:
raise serializers.ValidationError("Invalid date format. Use YYYY-MM-DD format.") from e

@classmethod
def to_date(cls, value: str | None) -> datetime.date | None:
"""Convert a YYYY-MM-DD string to a Python date object for ORM queries."""
if value is None:
return None
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
8 changes: 7 additions & 1 deletion ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models import QuerySet
from rest_framework import serializers

from ami.base.fields import DateStringField
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, get_current_user, reverse_with_params
from ami.jobs.models import Job
from ami.main.models import create_source_image_from_upload
Expand Down Expand Up @@ -1025,9 +1026,14 @@ class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
# use for the "common_combined" method
minute_interval = serializers.IntegerField(required=False, allow_null=True)
max_num = serializers.IntegerField(required=False, allow_null=True)
shuffle = serializers.BooleanField(required=False, allow_null=True)

month_start = serializers.IntegerField(required=False, allow_null=True)
month_end = serializers.IntegerField(required=False, allow_null=True)

date_start = DateStringField(required=False, allow_null=True)
date_end = DateStringField(required=False, allow_null=True)

hour_start = serializers.IntegerField(required=False, allow_null=True)
hour_end = serializers.IntegerField(required=False, allow_null=True)

Expand All @@ -1039,9 +1045,9 @@ class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
deployment_id = serializers.IntegerField(required=False, allow_null=True)
position = serializers.IntegerField(required=False, allow_null=True)

# Don't return the kwargs if they are empty
def to_representation(self, instance):
data = super().to_representation(instance)
# Don't return the kwargs if they are empty
return {key: value for key, value in data.items() if value is not None}


Expand Down
74 changes: 41 additions & 33 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import ami.tasks
import ami.utils
from ami.base.fields import DateStringField
from ami.base.models import BaseModel
from ami.main import charts
from ami.users.models import User
Expand Down Expand Up @@ -1521,8 +1522,8 @@ def set_dimensions_for_collection(


def sample_captures_by_interval(
minute_interval: int = 10,
qs: models.QuerySet[SourceImage] | None = None,
minute_interval: int,
qs: models.QuerySet[SourceImage],
max_num: int | None = None,
) -> typing.Generator[SourceImage, None, None]:
"""
Expand All @@ -1532,9 +1533,6 @@ def sample_captures_by_interval(
last_capture = None
total = 0

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

for capture in qs.all():
Expand All @@ -1555,7 +1553,7 @@ def sample_captures_by_interval(

def sample_captures_by_position(
position: int,
qs: models.QuerySet[SourceImage] | None = None,
qs: models.QuerySet[SourceImage],
) -> typing.Generator[SourceImage | None, None, None]:
"""
Return the n-th position capture from each event.
Expand All @@ -1564,9 +1562,6 @@ def sample_captures_by_position(
If position = -1, the last capture from each event will be returned.
"""

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

events = Event.objects.filter(captures__in=qs).distinct()
Expand All @@ -1593,7 +1588,7 @@ def sample_captures_by_position(

def sample_captures_by_nth(
nth: int,
qs: models.QuerySet[SourceImage] | None = None,
qs: models.QuerySet[SourceImage],
) -> typing.Generator[SourceImage, None, None]:
"""
Return every nth capture from each event.
Expand All @@ -1602,9 +1597,6 @@ def sample_captures_by_nth(
If nth = 5, every 5th capture from each event will be returned.
"""

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

events = Event.objects.filter(captures__in=qs).distinct()
Expand Down Expand Up @@ -2973,35 +2965,51 @@ def sample_manual(self, image_ids: list[int]):
def sample_common_combined(
self,
minute_interval: int | None = None,
max_num: int | None = 100,
max_num: int | None = None,
shuffle: bool = True, # This is applicable if max_num is set and minute_interval is not set
hour_start: int | None = None,
hour_end: int | None = None,
month_start: datetime.date | None = None,
month_end: datetime.date | None = None,
day_start: datetime.date | None = None,
day_end: datetime.date | None = None,
month_start: int | None = None,
month_end: int | None = None,
date_start: str | None = None,
date_end: str | None = None,
) -> models.QuerySet | typing.Generator[SourceImage, None, None]:
qs = self.get_queryset()
if month_start:

if date_start is not None:
qs = qs.filter(timestamp__date__gte=DateStringField.to_date(date_start))
if date_end is not None:
qs = qs.filter(timestamp__date__lte=DateStringField.to_date(date_end))

if month_start is not None:
qs = qs.filter(timestamp__month__gte=month_start)
if month_end:
if month_end is not None:
qs = qs.filter(timestamp__month__lte=month_end)
if day_start:
qs = qs.filter(timestamp__day__gte=day_start)
if day_end:
qs = qs.filter(timestamp__day__lte=day_end)
if hour_start:

if hour_start is not None and hour_end is not None:
if hour_start < hour_end:
# Hour range within the same day (e.g., 08:00 to 15:00)
qs = qs.filter(timestamp__hour__gte=hour_start, timestamp__hour__lte=hour_end)
else:
# Hour range has Midnight crossover: (e.g., 17:00 to 06:00)
qs = qs.filter(models.Q(timestamp__hour__gte=hour_start) | models.Q(timestamp__hour__lte=hour_end))
elif hour_start is not None:
qs = qs.filter(timestamp__hour__gte=hour_start)
if hour_end:
elif hour_end is not None:
qs = qs.filter(timestamp__hour__lte=hour_end)
if not minute_interval and max_num:
qs = qs[:max_num]
if minute_interval:

if minute_interval is not None:
# @TODO can this be done in the database and return a queryset?
# this currently returns a list of source images
# Ensure the queryset is limited to the project
qs = qs.filter(project=self.project)
qs = sample_captures_by_interval(minute_interval, qs=qs, max_num=max_num)
qs = sample_captures_by_interval(minute_interval=minute_interval, qs=qs, max_num=max_num)
else:
if max_num is not None:
if shuffle:
qs = qs.order_by("?")
qs = qs[:max_num]

return qs

def sample_interval(
Expand All @@ -3016,19 +3024,19 @@ def sample_interval(
qs = qs.exclude(event__in=exclude_events)
qs.exclude(event__in=exclude_events)
qs = qs.filter(project=self.project)
return sample_captures_by_interval(minute_interval, qs=qs)
return sample_captures_by_interval(minute_interval=minute_interval, qs=qs)

def sample_positional(self, position: int = -1):
"""Sample the single nth source image from all events in the project"""

qs = self.get_queryset()
return sample_captures_by_position(position, qs=qs)
return sample_captures_by_position(position=position, qs=qs)

def sample_nth(self, nth: int):
"""Sample every nth source image from all events in the project"""

qs = self.get_queryset()
return sample_captures_by_nth(nth, qs=qs)
return sample_captures_by_nth(nth=nth, qs=qs)

def sample_random_from_each_event(self, num_each: int = 10):
"""Sample n random source images from each event in the project."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { FormField } from 'components/form/form-field'
import { isValid } from 'date-fns'

import {
FormActions,
FormError,
Expand All @@ -24,9 +26,29 @@ type CollectionFormValues = FormValues & {
month_end: string | undefined
hour_start: number | undefined
hour_end: number | undefined
date_start: string | undefined
date_end: string | undefined
}
}

// simple date string config

const kwargs_date_config = {
label: 'Date',
description: 'Format: YYYY-MM-DD',
rules: {
validate: (value: any): string | undefined => {
if (!value) return undefined

if (!isValid(new Date(value))) {
return 'Date must be in YYYY-MM-DD format'
}

return undefined
},
},
}

const config: FormConfig = {
name: {
label: translate(STRING.FIELD_LABEL_NAME),
Expand All @@ -51,6 +73,7 @@ const config: FormConfig = {
},
'kwargs.max_num': {
label: 'Max number of images',
description: 'When set, the collection will be a random sample',
},
'kwargs.minute_interval': {
label: 'Minutes between captures',
Expand All @@ -67,6 +90,14 @@ const config: FormConfig = {
'kwargs.hour_end': {
label: 'Latest hour',
},
'kwargs.date_start': {
...kwargs_date_config,
label: 'Earliest date',
},
'kwargs.date_end': {
...kwargs_date_config,
label: 'Latest date',
},
}

export const CollectionDetailsForm = ({
Expand Down Expand Up @@ -179,6 +210,20 @@ export const CollectionDetailsForm = ({
control={control}
/>
</FormRow>
<FormRow>
<FormField
name="kwargs.date_start"
type="text"
config={config}
control={control}
/>
<FormField
name="kwargs.date_end"
type="text"
config={config}
control={control}
/>
</FormRow>
<FormRow>
<FormField
name="method"
Expand Down

0 comments on commit 3181753

Please sign in to comment.