Skip to content

Commit c8b48d9

Browse files
jaycee-licopybara-github
authored andcommitted
feat: [vertexai] add generateContentAsync methods to GenerativeModel
PiperOrigin-RevId: 617951189
1 parent b5e8e3d commit c8b48d9

4 files changed

Lines changed: 411 additions & 66 deletions

File tree

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java

Lines changed: 162 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,97 @@
2424
import com.google.cloud.vertexai.api.Candidate.FinishReason;
2525
import com.google.cloud.vertexai.api.Content;
2626
import com.google.cloud.vertexai.api.GenerateContentResponse;
27+
import com.google.cloud.vertexai.api.GenerationConfig;
28+
import com.google.cloud.vertexai.api.SafetySetting;
29+
import com.google.cloud.vertexai.api.Tool;
2730
import java.io.IOException;
2831
import java.util.ArrayList;
2932
import java.util.Collections;
3033
import java.util.List;
34+
import java.util.Optional;
3135

3236
/** Represents a conversation between the user and the model */
3337
public final class ChatSession {
3438
private final GenerativeModel model;
39+
private final Optional<ChatSession> rootChatSession;
3540
private List<Content> history = new ArrayList<>();
36-
private ResponseStream<GenerateContentResponse> currentResponseStream = null;
37-
private GenerateContentResponse currentResponse = null;
41+
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
42+
private Optional<GenerateContentResponse> currentResponse;
3843

44+
/**
45+
* Creates a new chat session given a GenerativeModel instance. Configurations of the chat (e.g.,
46+
* GenerationConfig) inherits from the model.
47+
*/
3948
@BetaApi
4049
public ChatSession(GenerativeModel model) {
50+
this(model, Optional.empty());
51+
}
52+
53+
/**
54+
* Creates a new chat session given a GenerativeModel instance and a root chat session.
55+
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
56+
*
57+
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
58+
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
59+
* chat session will be merged to the root chat session.
60+
* @return a {@link ChatSession} instance.
61+
*/
62+
@BetaApi
63+
private ChatSession(GenerativeModel model, Optional<ChatSession> rootChatSession) {
4164
if (model == null) {
4265
throw new IllegalArgumentException("model should not be null");
4366
}
4467
this.model = model;
68+
this.rootChatSession = rootChatSession;
69+
currentResponseStream = Optional.empty();
70+
currentResponse = Optional.empty();
71+
}
72+
73+
/**
74+
* Creates a copy of the current ChatSession with updated GenerationConfig.
75+
*
76+
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
77+
* used in the new ChatSession.
78+
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
79+
*/
80+
@BetaApi
81+
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
82+
ChatSession rootChat = rootChatSession.orElse(this);
83+
ChatSession newChatSession =
84+
new ChatSession(model.withGenerationConfig(generationConfig), Optional.of(rootChat));
85+
newChatSession.setHistory(history);
86+
return newChatSession;
87+
}
88+
89+
/**
90+
* Creates a copy of the current ChatSession with updated SafetySettings.
91+
*
92+
* @param safetySettings a {@link com.google.cloud.vertexai.api.SafetySetting} that will be used
93+
* in the new ChatSession.
94+
* @return a new {@link ChatSession} instance with the specified SafetySettings.
95+
*/
96+
@BetaApi
97+
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
98+
ChatSession rootChat = rootChatSession.orElse(this);
99+
ChatSession newChatSession =
100+
new ChatSession(model.withSafetySettings(safetySettings), Optional.of(rootChat));
101+
newChatSession.setHistory(history);
102+
return newChatSession;
103+
}
104+
105+
/**
106+
* Creates a copy of the current ChatSession with updated Tools.
107+
*
108+
* @param tools a {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
109+
* ChatSession.
110+
* @return a new {@link ChatSession} instance with the specified Tools.
111+
*/
112+
@BetaApi
113+
public ChatSession withTools(List<Tool> tools) {
114+
ChatSession rootChat = rootChatSession.orElse(this);
115+
ChatSession newChatSession = new ChatSession(model.withTools(tools), Optional.of(rootChat));
116+
newChatSession.setHistory(history);
117+
return newChatSession;
45118
}
46119

47120
/**
@@ -69,8 +142,8 @@ public ResponseStream<GenerateContentResponse> sendMessageStream(Content content
69142
checkLastResponseAndEditHistory();
70143
history.add(content);
71144
ResponseStream<GenerateContentResponse> respStream = model.generateContentStream(history);
72-
currentResponseStream = respStream;
73-
currentResponse = null;
145+
setCurrentResponseStream(Optional.of(respStream));
146+
74147
return respStream;
75148
}
76149

@@ -96,8 +169,7 @@ public GenerateContentResponse sendMessage(Content content) throws IOException {
96169
checkLastResponseAndEditHistory();
97170
history.add(content);
98171
GenerateContentResponse response = model.generateContent(history);
99-
currentResponse = response;
100-
currentResponseStream = null;
172+
setCurrentResponse(Optional.of(response));
101173
return response;
102174
}
103175

@@ -112,38 +184,37 @@ private void removeLastContent() {
112184
* @throws IllegalStateException if the response stream is not finished.
113185
*/
114186
private void checkLastResponseAndEditHistory() {
115-
if (currentResponseStream == null && currentResponse == null) {
116-
return;
117-
} else if (currentResponseStream != null && !currentResponseStream.isConsumed()) {
118-
throw new IllegalStateException("Response stream is not consumed");
119-
} else if (currentResponseStream != null && currentResponseStream.isConsumed()) {
120-
GenerateContentResponse response = aggregateStreamIntoResponse(currentResponseStream);
121-
FinishReason finishReason = getFinishReason(response);
122-
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
123-
// We also remove the request from the history.
124-
removeLastContent();
125-
currentResponseStream = null;
126-
throw new IllegalStateException(
127-
String.format(
128-
"The last round of conversation will not be added to history because response"
129-
+ " stream did not finish normally. Finish reason is %s.",
130-
finishReason));
131-
}
132-
history.add(getContent(response));
133-
} else if (currentResponseStream == null && currentResponse != null) {
134-
FinishReason finishReason = getFinishReason(currentResponse);
135-
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
136-
// We also remove the request from the history.
137-
removeLastContent();
138-
currentResponse = null;
139-
throw new IllegalStateException(
140-
String.format(
141-
"The last round of conversation will not be added to history because response did"
142-
+ " not finish normally. Finish reason is %s.",
143-
finishReason));
144-
}
145-
history.add(getContent(currentResponse));
146-
currentResponse = null;
187+
getCurrentResponse()
188+
.ifPresent(
189+
currentResponse -> {
190+
setCurrentResponse(Optional.empty());
191+
checkFinishReasonAndRemoveLastContent(currentResponse);
192+
history.add(getContent(currentResponse));
193+
});
194+
getCurrentResponseStream()
195+
.ifPresent(
196+
responseStream -> {
197+
if (!responseStream.isConsumed()) {
198+
throw new IllegalStateException("Response stream is not consumed");
199+
} else {
200+
setCurrentResponseStream(Optional.empty());
201+
GenerateContentResponse response = aggregateStreamIntoResponse(responseStream);
202+
checkFinishReasonAndRemoveLastContent(response);
203+
history.add(getContent(response));
204+
}
205+
});
206+
}
207+
208+
/** Removes the last content in the history if the response finished with problems. */
209+
private void checkFinishReasonAndRemoveLastContent(GenerateContentResponse response) {
210+
FinishReason finishReason = getFinishReason(response);
211+
if (finishReason != FinishReason.STOP && finishReason != FinishReason.MAX_TOKENS) {
212+
removeLastContent();
213+
throw new IllegalStateException(
214+
String.format(
215+
"The last round of conversation will not be added to history because response"
216+
+ " stream did not finish normally. Finish reason is %s.",
217+
finishReason));
147218
}
148219
}
149220

@@ -169,9 +240,62 @@ public List<Content> getHistory() {
169240
return Collections.unmodifiableList(history);
170241
}
171242

243+
/**
244+
* Returns the current response of the root chat session (if exists) or the current chat session.
245+
*/
246+
private Optional<GenerateContentResponse> getCurrentResponse() {
247+
if (rootChatSession.isPresent()) {
248+
return rootChatSession.get().getCurrentResponse();
249+
} else {
250+
return currentResponse;
251+
}
252+
}
253+
254+
/**
255+
* Returns the current responseStream of the root chat session (if exists) or the current chat
256+
* session.
257+
*/
258+
private Optional<ResponseStream<GenerateContentResponse>> getCurrentResponseStream() {
259+
if (rootChatSession.isPresent()) {
260+
return rootChatSession.get().getCurrentResponseStream();
261+
} else {
262+
return currentResponseStream;
263+
}
264+
}
265+
172266
/** Set the history to a list of Content */
173267
@BetaApi
174268
public void setHistory(List<Content> history) {
175269
this.history = history;
176270
}
271+
272+
/** Sets the current response of the root chat session (if exists) or the current chat session. */
273+
private void setCurrentResponse(Optional<GenerateContentResponse> response) {
274+
if (currentResponseStream.isPresent()) {
275+
throw new IllegalStateException(
276+
"currentResponse and currentResponseStream cannot be set together");
277+
}
278+
if (rootChatSession.isPresent()) {
279+
rootChatSession.get().setCurrentResponse(response);
280+
} else {
281+
currentResponse = response;
282+
}
283+
}
284+
285+
/**
286+
* Sets the current responseStream of the root chat session (if exists) or the current chat
287+
* session.
288+
*/
289+
private void setCurrentResponseStream(
290+
Optional<ResponseStream<GenerateContentResponse>> responseStream) {
291+
if (currentResponse.isPresent()) {
292+
throw new IllegalStateException(
293+
"currentResponseStream and currentResponse cannot be set together");
294+
}
295+
if (rootChatSession.isPresent()) {
296+
rootChatSession.get().setCurrentResponseStream(responseStream);
297+
} else {
298+
currentResponseStream = responseStream;
299+
}
300+
}
177301
}

0 commit comments

Comments
 (0)