Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ Changelog
==========


13.3.1 (2024-12-20)
---------------------

* Initial release for DSS 13.3.1

13.2.4 (2024-12-03)
---------------------

Expand Down
5 changes: 3 additions & 2 deletions dataikuapi/apinode_admin/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ def list_keys(self):
"""Lists the Admin API keys"""
return self.client._perform_json("GET", "keys")

def add_key(self, label=None, description=None, created_by=None, expiry=None):
    """Add an Admin API key.

    :param str label: optional human-readable label for the key
    :param str description: optional description of the key
    :param str created_by: optional identifier of the key's creator
    :param expiry: optional expiry for the key; None means the key does not
        expire (exact format is defined by the API node backend — presumably
        a timestamp; confirm against the server-side handler)
    :returns: the details of the newly-created key, as a dict
    """
    key = {
        "label": label,
        "description": description,
        "createdBy": created_by,
        "expiry": expiry
    }
    return self.client._perform_json("POST", "keys", body=key)

Expand Down
7 changes: 7 additions & 0 deletions dataikuapi/dss/jupyternotebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def delete(self):
return self.client._perform_json("DELETE",
"/projects/%s/jupyter-notebooks/%s" % (self.project_key, self.notebook_name))

def clear_outputs(self):
    """Remove all stored cell outputs from this Jupyter notebook.

    :returns: the backend's JSON response for the delete-outputs call
    """
    endpoint = "/projects/%s/jupyter-notebooks/%s/outputs" % (self.project_key, self.notebook_name)
    return self.client._perform_json("DELETE", endpoint)

########################################################
# Discussions
########################################################
Expand Down
69 changes: 48 additions & 21 deletions dataikuapi/dss/langchain/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@
import asyncio
import concurrent
import logging
from typing import List, Any
import threading

from pydantic import BaseModel, Extra
from typing import List, Any

import pydantic
from langchain.embeddings.base import Embeddings
from dataikuapi.dss.llm_tracing import new_trace

from dataikuapi.dss.langchain.utils import must_use_deprecated_pydantic_config

logger = logging.getLogger(__name__)
CHUNK_SIZE = 1000


class DKUEmbeddings(BaseModel, Embeddings):
# Define a strict pydantic base model ("extra" attributes forbidden) in the
# form appropriate for the installed pydantic major version. The project
# helper must_use_deprecated_pydantic_config() presumably returns True when
# only the legacy (pydantic v1) configuration style is usable — TODO confirm
# against dataikuapi.dss.langchain.utils.
if must_use_deprecated_pydantic_config():
    class LockedDownBaseModel(pydantic.BaseModel):
        # Legacy (pydantic v1) style: configuration via an inner Config class.
        class Config:
            extra = pydantic.Extra.forbid
            underscore_attrs_are_private = True
else:
    class LockedDownBaseModel(pydantic.BaseModel):
        # Modern (pydantic v2) style: configuration via the model_config dict.
        # (v2 treats underscore-prefixed attributes as private by default,
        # so no equivalent of underscore_attrs_are_private is needed here.)
        model_config = {
            'extra': 'forbid',
        }


class DKUEmbeddings(LockedDownBaseModel, Embeddings):
"""
Langchain-compatible wrapper around Dataiku-mediated embedding LLMs

Expand All @@ -27,9 +42,12 @@ class DKUEmbeddings(BaseModel, Embeddings):
_llm_handle = None
""":class:`dataikuapi.dss.llm.DSSLLM` object to wrap."""

class Config:
extra = Extra.forbid
underscore_attrs_are_private = True
# The embeddings class of LangChain can only return raw embedding, without any additional information
# (unlike ChatModel, which supports additional information), so we cannot use this to return the last
# trace to the caller.
# So, instead, we keep a thread local with the last trace, and the caller can get it from here
# (at the moment, it's mostly done by rag_query_server.py)
_last_trace = None

def __init__(self, llm_handle=None, **data: Any):
if llm_handle is None:
Expand All @@ -45,6 +63,7 @@ def __init__(self, llm_handle=None, **data: Any):

super().__init__(**data)
self._llm_handle = llm_handle
self._last_trace = threading.local()

def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Call out to Dataiku-mediated LLM

    Args:
        texts: The list of texts to embed.

    Returns:
        List of embeddings, one for each text.
    """
    with new_trace("DKUEmbeddings") as trace:
        # Stash the current trace on the thread-local so the caller can
        # retrieve it after the call (LangChain's Embeddings interface only
        # returns raw embeddings, with no room for trace data).
        self._last_trace.trace = trace

        logger.info("Performing embedding of {num_texts} texts".format(num_texts=len(texts)))

        embeddings = []
        # Send the texts in batches of CHUNK_SIZE per backend query.
        for i in range(0, len(texts), CHUNK_SIZE):
            query = self._llm_handle.new_embeddings(text_overflow_mode="FAIL")

            for text in texts[i:i+CHUNK_SIZE]:
                query.add_text(text)

            resp = query.execute()

            # TODO
            #if not resp.success:
            #    raise Exception("LLM call failed: %s" % resp._raw.get("errorMessage", "Unknown error"))

            # If the backend returned a trace for this chunk, fold it into
            # the overall trace of this embed_documents call.
            if "responses" in resp._raw and len(resp._raw["responses"]) == 1:
                if "trace" in resp._raw["responses"][0]:
                    trace.append_trace(resp._raw["responses"][0]["trace"])

            embeddings.extend(resp.get_embeddings())

            logger.info("Finished a chunk. Embedded {num_embedded} of {num_texts} texts".format(
                num_embedded=min(i + CHUNK_SIZE, len(texts)), num_texts=len(texts)))

        logger.info("Done performing embedding of {num_texts} texts".format(num_texts=len(texts)))

        return embeddings

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
loop = asyncio.get_event_loop()
Expand Down
Loading