Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

# Version 3.4.0 (2021-09-10)
## Added
* New `IAMIntegration` entity
* `Client.create_dataset()` compatibility with delegated access
* `Organization.get_iam_integrations()` to list all integrations available to an org
* `Organization.get_default_iam_integration()` to only get the default IAM integration

# Version 3.3.0 (2021-09-02)
## Added
* `Dataset.create_data_rows_sync()` for synchronous bulk uploads of data rows
Expand Down
2 changes: 1 addition & 1 deletion labelbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name = "labelbox"
__version__ = "3.3.0"
__version__ = "3.4.0"

from labelbox.schema.project import Project
from labelbox.client import Client
Expand Down
48 changes: 44 additions & 4 deletions labelbox/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore
from datetime import datetime, timezone
import json

import logging
import mimetypes
import os
Expand All @@ -9,8 +10,9 @@
import requests
import requests.exceptions

from labelbox import utils
import labelbox.exceptions
from labelbox import utils
from labelbox import __version__ as SDK_VERSION
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.pagination import PaginatedCollection
Expand All @@ -22,8 +24,8 @@
from labelbox.schema.organization import Organization
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
from labelbox.schema.labeling_frontend import LabelingFrontend
from labelbox.schema.iam_integration import IAMIntegration
from labelbox.schema import role
from labelbox import __version__ as SDK_VERSION

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -503,7 +505,7 @@ def _create(self, db_object_type, data):
res = res["create%s" % db_object_type.type_name()]
return db_object_type(self, res)

def create_dataset(self, **kwargs):
def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
    """ Creates a Dataset object on the server.

    Attribute values are passed as keyword arguments.

    >>> dataset = client.create_dataset(name="<dataset_name>", projects=project)

    Args:
        iam_integration (IAMIntegration) : Uses the default integration.
            Optionally specify another integration or set as None to not use delegated access
        **kwargs: Keyword arguments with Dataset attribute values.
    Returns:
        A new Dataset object.
    Raises:
        InvalidAttributeError: If the Dataset type does not contain
            any of the attribute names given in kwargs.
        TypeError: If `iam_integration` is neither None nor an
            `IAMIntegration`.
        ValueError: If the selected integration is not marked as valid.
        LabelboxError: If the server reports that the dataset failed
            validation after the integration was attached.
    """
    # Resolve and validate the integration *before* creating anything on the
    # server, so that a rejected integration does not leave an orphaned dataset.
    if iam_integration == IAMIntegration._DEFAULT:
        iam_integration = self.get_organization(
        ).get_default_iam_integration()

    if iam_integration is not None:
        if not isinstance(iam_integration, IAMIntegration):
            raise TypeError(
                f"iam integration must be a reference an `IAMIntegration` object. Found {type(iam_integration)}"
            )

        if not iam_integration.valid:
            raise ValueError("Integration is not valid. Please select another.")

    dataset = self._create(Dataset, kwargs)

    # No delegated access requested (or no default configured): nothing to attach.
    if iam_integration is None:
        return dataset

    # NOTE(review): a reviewer suggested adding an optional signerId to
    # createDatasetInput so this client-side rollback could be handled by the
    # backend instead.
    try:
        self.execute(
            """mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) {
setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}}
""", {
                'signerId': iam_integration.uid,
                'datasetId': dataset.uid
            })
        validation_result = self.execute(
            """mutation validateDatasetPyApi($id: ID!){validateDataset(where: {id : $id}){
valid checks{name, success}}}
""", {'id': dataset.uid})

        validation = validation_result['validateDataset']
        if not validation['valid']:
            # Report the raw checks payload rather than indexing into it, so
            # the error path does not depend on the payload's exact shape.
            raise labelbox.exceptions.LabelboxError(
                "IAMIntegration was not successfully added to the dataset. "
                f"Validation checks: {validation['checks']}")
    except Exception:
        # Roll back the half-configured dataset, then re-raise the original
        # error with its traceback intact.
        dataset.delete()
        raise
    return dataset

def create_project(self, **kwargs):
""" Creates a Project object on the server.
Expand Down
1 change: 1 addition & 0 deletions labelbox/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
import labelbox.schema.user
import labelbox.schema.webhook
import labelbox.schema.data_row_metadata
import labelbox.schema.iam_integration
3 changes: 3 additions & 0 deletions labelbox/schema/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from labelbox.schema import iam_integration
from labelbox import utils
import os
import json
Expand Down Expand Up @@ -43,6 +44,8 @@ class Dataset(DbObject, Updateable, Deletable):
data_rows = Relationship.ToMany("DataRow", False)
created_by = Relationship.ToOne("User", False, "created_by")
organization = Relationship.ToOne("Organization", False)
iam_integration = Relationship.ToOne("IAMIntegration", False,
"iam_integration", "signer")

def create_data_row(self, **kwargs):
""" Creates a single DataRow belonging to this dataset.
Expand Down
56 changes: 56 additions & 0 deletions labelbox/schema/iam_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass

from labelbox.utils import snake_case
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field


@dataclass
class AwsIamIntegrationSettings:
    """Provider-specific settings attached to an AWS IAM integration.

    Attributes:
        role_arn (str)
    """

    role_arn: str


@dataclass
class GcpIamIntegrationSettings:
    """Provider-specific settings attached to a GCP IAM integration.

    Attributes:
        service_account_email_id (str)
        read_bucket (str)
    """

    service_account_email_id: str
    read_bucket: str


class IAMIntegration(DbObject):
    """ Represents an IAM integration for delegated access

    Attributes:
        name (str)
        updated_at (datetime)
        created_at (datetime)
        provider (str)
        valid (bool)
        last_valid_at (datetime)
        is_org_default (boolean)
        settings (AwsIamIntegrationSettings, GcpIamIntegrationSettings or None):
            Parsed from the optional `settings` entry of the payload. None when
            the entry is absent or its `__typename` is not recognized.

    """

    def __init__(self, client, data):
        # Work on shallow copies so the caller's payload (and its nested
        # settings dict) is not modified as a side effect of construction.
        data = dict(data)
        settings = data.pop('settings', None)
        self.settings = None
        if settings is not None:
            settings = dict(settings)
            type_name = settings.pop('__typename')
            # Payload keys are converted with snake_case so they line up with
            # the dataclass field names.
            settings = {snake_case(k): v for k, v in settings.items()}
            settings_types = {
                "GcpIamIntegrationSettings": GcpIamIntegrationSettings,
                "AwsIamIntegrationSettings": AwsIamIntegrationSettings,
            }
            settings_cls = settings_types.get(type_name)
            if settings_cls is not None:
                self.settings = settings_cls(**settings)
        super().__init__(client, data)

    # Marker used as a default argument value meaning "use the organization's
    # default integration"; distinct from None.
    _DEFAULT = "DEFAULT"

    name = Field.String("name")
    created_at = Field.DateTime("created_at")
    updated_at = Field.DateTime("updated_at")
    provider = Field.String("provider")
    valid = Field.Boolean("valid")
    last_valid_at = Field.DateTime("last_valid_at")
    is_org_default = Field.Boolean("is_org_default")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like we're missing settings field?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ugh - I know. The union thing is a nightmare. But users might want to know which bucket they have access to. I'll try to figure something out.

38 changes: 37 additions & 1 deletion labelbox/schema/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from labelbox.exceptions import LabelboxError
from labelbox import utils
from labelbox.orm.db_object import DbObject, experimental, query
from labelbox.orm.db_object import DbObject, experimental, query, Entity
from labelbox.orm.model import Field, Relationship
from labelbox.schema.invite import Invite, InviteLimit, ProjectRole
from labelbox.schema.user import User
Expand Down Expand Up @@ -129,3 +129,39 @@ def remove_user(self, user: User):
"""mutation DeleteMemberPyApi($%s: ID!) {
updateUser(where: {id: $%s}, data: {deleted: true}) { id deleted }
}""" % (user_id_param, user_id_param), {user_id_param: user.uid})

def get_iam_integrations(self):
    """
    Fetches every IAM integration returned by the `iamIntegrations` query.

    Returns:
        list: one `IAMIntegration` built from each entry of the query result.
    """
    integration_fields = query.results_query_part(Entity.IAMIntegration)
    response = self.client.execute(
        """query getAllIntegrationsPyApi { iamIntegrations {
%s
settings {
__typename
... on AwsIamIntegrationSettings {roleArn}
... on GcpIamIntegrationSettings {serviceAccountEmailId readBucket}
}

} } """ % integration_fields)

    integrations = []
    for integration_data in response['iamIntegrations']:
        integrations.append(
            Entity.IAMIntegration(self.client, integration_data))
    return integrations

def get_default_iam_integration(self):
    """
    Picks the organization's default IAM integration.

    Returns:
        The single integration flagged `is_org_default`, or None when no
        integration carries that flag.
    Raises:
        ValueError: If more than one integration is flagged as default.
    """
    defaults = []
    for candidate in self.get_iam_integrations():
        if candidate.is_org_default:
            defaults.append(candidate)

    if len(defaults) > 1:
        raise ValueError(
            "Found more than one default signer. Please contact Labelbox to resolve"
        )

    if defaults:
        return defaults[-1]
    return None
1 change: 1 addition & 0 deletions tests/integration/test_data_row_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def test_delete_non_existent_schema_id(datarow, mdo):


@pytest.mark.slow
@pytest.mark.skip("Test is inconsistent.")
def test_large_bulk_delete_non_existent_schema_id(big_dataset, mdo):
deletes = []
n_fields_start = 0
Expand Down
43 changes: 43 additions & 0 deletions tests/integration/test_delegated_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import requests
import pytest


@pytest.mark.skip("Can only be tested in specific organizations.")
def test_default_integration(client):
    """Create a dataset with the org's default integration and read a delegated-access data row."""
    # This test assumes the following:
    # 1. gcp delegated access is configured to work with utkarsh-da-test-bucket
    # 2. the integration name is gcp test
    # 3. This integration is the default
    ds = client.create_dataset(name="new_ds")
    try:
        dr = ds.create_data_row(
            row_data=
            "gs://utkarsh-da-test-bucket/mathew-schwartz-8rj4sz9YLCI-unsplash.jpg")
        assert requests.get(dr.row_data).status_code == 200
        assert ds.iam_integration().name == "GCP Test"
    finally:
        # Always remove the dataset, even when an assertion above fails.
        ds.delete()


@pytest.mark.skip("Can only be tested in specific organizations.")
def test_non_default_integration(client):
    """Create a dataset with an explicitly chosen integration and read a delegated-access data row."""
    # This test assumes the following:
    # 1. aws delegated access is configured to work with lbox-test-bucket
    # 2. an integration called aws is available to the org
    integrations = client.get_organization().get_iam_integrations()
    integration = next(
        (inte for inte in integrations if 'aws' in inte.name), None)
    # Fail with a readable message instead of an IndexError when the org has
    # no matching integration.
    assert integration is not None, "No integration with 'aws' in its name was found"
    assert integration.valid
    ds = client.create_dataset(iam_integration=integration, name="new_ds")
    try:
        assert ds.iam_integration().name == "aws"
        dr = ds.create_data_row(
            row_data=
            "https://lbox-test-bucket.s3.us-east-1.amazonaws.com/2021_09_08_0hz_Kleki.png"
        )
        assert requests.get(dr.row_data).status_code == 200
    finally:
        # Always remove the dataset, even when an assertion above fails.
        ds.delete()


def test_no_integration(client, image_url):
    """Create a dataset with delegated access disabled and read a data row from it."""
    ds = client.create_dataset(iam_integration=None, name="new_ds")
    try:
        assert ds.iam_integration() is None
        dr = ds.create_data_row(row_data=image_url)
        assert requests.get(dr.row_data).status_code == 200
    finally:
        # Always remove the dataset, even when an assertion above fails.
        ds.delete()