Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions labelbox/schema/data_row_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from labelbox.schema.ontology import SchemaId
from labelbox.utils import camel_case

_MAX_METADATA_FIELDS = 5


class DataRowMetadataKind(Enum):
number = "CustomMetadataNumber"
Expand Down Expand Up @@ -40,16 +38,8 @@ def id(self):

DataRowMetadataSchema.update_forward_refs()

# Constraints for metadata values
Embedding: Type[List[float]] = conlist(float, min_items=128, max_items=128)
DateTime: Type[datetime] = datetime # must be in UTC
String: Type[str] = constr(max_length=500)
OptionId: Type[SchemaId] = SchemaId # enum option
Number: Type[float] = float

DataRowMetadataValue = Union[Embedding, Number, DateTime, String, OptionId]
# primitives used in uploads
_DataRowMetadataValuePrimitives = Union[str, List, dict, float]


class _CamelCaseMixin(BaseModel):
Expand All @@ -62,7 +52,9 @@ class Config:
# Metadata base class
class DataRowMetadataField(_CamelCaseMixin):
schema_id: SchemaId
value: Union[DataRowMetadataValue, _DataRowMetadataValuePrimitives]
# value is of type `Any` so that we do not improperly coerce the value to the wrong tpye
# Additional validation is performed before upload using the schema information
value: Any


class DataRowMetadata(_CamelCaseMixin):
Expand Down Expand Up @@ -241,10 +233,14 @@ def parse_metadata(
elif schema.kind == DataRowMetadataKind.option:
field = DataRowMetadataField(schema_id=schema.parent,
value=schema.uid)
elif schema.kind == DataRowMetadataKind.datetime:
field = DataRowMetadataField(
schema_id=schema.uid,
value=datetime.fromisoformat(f["value"][:-1] +
"+00:00"))
else:
field = DataRowMetadataField(schema_id=schema.uid,
value=f["value"])

fields.append(field)
parsed.append(
DataRowMetadata(data_row_id=dr["dataRowId"], fields=fields))
Expand Down Expand Up @@ -300,10 +296,6 @@ def _batch_upsert(

items = []
for m in metadata:
if len(m.fields) > _MAX_METADATA_FIELDS:
raise ValueError(
f"Cannot upload {len(m.fields)}, the max number is {_MAX_METADATA_FIELDS}"
)
items.append(
_UpsertBatchDataRowMetadata(
data_row_id=m.data_row_id,
Expand Down Expand Up @@ -478,17 +470,39 @@ def _batch_operations(
def _validate_parse_embedding(
field: DataRowMetadataField
) -> List[Dict[str, Union[SchemaId, Embedding]]]:

if isinstance(field.value, list):
if not (Embedding.min_items <= len(field.value) <= Embedding.max_items):
raise ValueError(
"Embedding length invalid. "
"Must have length within the interval "
f"[{Embedding.min_items},{Embedding.max_items}]. Found {len(field.value)}"
)
field.value = [float(x) for x in field.value]
else:
raise ValueError(
f"Expected a list for embedding. Found {type(field.value)}")
return [field.dict(by_alias=True)]


def _validate_parse_number(
field: DataRowMetadataField
) -> List[Dict[str, Union[SchemaId, Number]]]:
field: DataRowMetadataField
) -> List[Dict[str, Union[SchemaId, str, float, int]]]:
field.value = float(field.value)
return [field.dict(by_alias=True)]


def _validate_parse_datetime(
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
if isinstance(field.value, str):
if field.value.endswith("Z"):
field.value = field.value[:-1]
field.value = datetime.fromisoformat(field.value)
elif not isinstance(field.value, datetime):
raise TypeError(
f"value for datetime fields must be either a string or datetime object. Found {type(field.value)}"
)

return [{
"schemaId": field.schema_id,
"value": field.value.isoformat() + "Z", # needs to be UTC
Expand All @@ -497,6 +511,15 @@ def _validate_parse_datetime(

def _validate_parse_text(
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
if not isinstance(field.value, str):
raise ValueError(
f"Expected a string type for the text field. Found {type(field.value)}"
)

if len(field.value) > String.max_length:
raise ValueError(
f"string fields cannot exceed {String.max_length} characters.")

return [field.dict(by_alias=True)]


Expand Down