2424import com .google .cloud .vertexai .api .Candidate .FinishReason ;
2525import com .google .cloud .vertexai .api .Content ;
2626import com .google .cloud .vertexai .api .GenerateContentResponse ;
27+ import com .google .cloud .vertexai .api .GenerationConfig ;
28+ import com .google .cloud .vertexai .api .SafetySetting ;
29+ import com .google .cloud .vertexai .api .Tool ;
2730import java .io .IOException ;
2831import java .util .ArrayList ;
2932import java .util .Collections ;
3033import java .util .List ;
34+ import java .util .Optional ;
3135
3236/** Represents a conversation between the user and the model */
3337public final class ChatSession {
3438 private final GenerativeModel model ;
39+ private final Optional <ChatSession > rootChatSession ;
3540 private List <Content > history = new ArrayList <>();
36- private ResponseStream <GenerateContentResponse > currentResponseStream = null ;
37- private GenerateContentResponse currentResponse = null ;
41+ private Optional < ResponseStream <GenerateContentResponse >> currentResponseStream ;
42+ private Optional < GenerateContentResponse > currentResponse ;
3843
44+ /**
45+ * Creates a new chat session given a GenerativeModel instance. Configurations of the chat (e.g.,
46+ * GenerationConfig) inherits from the model.
47+ */
3948 @ BetaApi
4049 public ChatSession (GenerativeModel model ) {
50+ this (model , Optional .empty ());
51+ }
52+
53+ /**
54+ * Creates a new chat session given a GenerativeModel instance and a root chat session.
55+ * Configurations of the chat (e.g., GenerationConfig) inherits from the model.
56+ *
57+ * @param model a {@link GenerativeModel} instance that generates contents in the chat.
58+ * @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
59+ * chat session will be merged to the root chat session.
60+ * @return a {@link ChatSession} instance.
61+ */
62+ @ BetaApi
63+ private ChatSession (GenerativeModel model , Optional <ChatSession > rootChatSession ) {
4164 if (model == null ) {
4265 throw new IllegalArgumentException ("model should not be null" );
4366 }
4467 this .model = model ;
68+ this .rootChatSession = rootChatSession ;
69+ currentResponseStream = Optional .empty ();
70+ currentResponse = Optional .empty ();
71+ }
72+
73+ /**
74+ * Creates a copy of the current ChatSession with updated GenerationConfig.
75+ *
76+ * @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} that will be
77+ * used in the new ChatSession.
78+ * @return a new {@link ChatSession} instance with the specified GenerationConfig.
79+ */
80+ @ BetaApi
81+ public ChatSession withGenerationConfig (GenerationConfig generationConfig ) {
82+ ChatSession rootChat = rootChatSession .orElse (this );
83+ ChatSession newChatSession =
84+ new ChatSession (model .withGenerationConfig (generationConfig ), Optional .of (rootChat ));
85+ newChatSession .setHistory (history );
86+ return newChatSession ;
87+ }
88+
89+ /**
90+ * Creates a copy of the current ChatSession with updated SafetySettings.
91+ *
92+ * @param safetySettings a {@link com.google.cloud.vertexai.api.SafetySetting} that will be used
93+ * in the new ChatSession.
94+ * @return a new {@link ChatSession} instance with the specified SafetySettings.
95+ */
96+ @ BetaApi
97+ public ChatSession withSafetySettings (List <SafetySetting > safetySettings ) {
98+ ChatSession rootChat = rootChatSession .orElse (this );
99+ ChatSession newChatSession =
100+ new ChatSession (model .withSafetySettings (safetySettings ), Optional .of (rootChat ));
101+ newChatSession .setHistory (history );
102+ return newChatSession ;
103+ }
104+
105+ /**
106+ * Creates a copy of the current ChatSession with updated Tools.
107+ *
108+ * @param tools a {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
109+ * ChatSession.
110+ * @return a new {@link ChatSession} instance with the specified Tools.
111+ */
112+ @ BetaApi
113+ public ChatSession withTools (List <Tool > tools ) {
114+ ChatSession rootChat = rootChatSession .orElse (this );
115+ ChatSession newChatSession = new ChatSession (model .withTools (tools ), Optional .of (rootChat ));
116+ newChatSession .setHistory (history );
117+ return newChatSession ;
45118 }
46119
47120 /**
@@ -69,8 +142,8 @@ public ResponseStream<GenerateContentResponse> sendMessageStream(Content content
69142 checkLastResponseAndEditHistory ();
70143 history .add (content );
71144 ResponseStream <GenerateContentResponse > respStream = model .generateContentStream (history );
72- currentResponseStream = respStream ;
73- currentResponse = null ;
145+ setCurrentResponseStream ( Optional . of ( respStream )) ;
146+
74147 return respStream ;
75148 }
76149
@@ -96,8 +169,7 @@ public GenerateContentResponse sendMessage(Content content) throws IOException {
96169 checkLastResponseAndEditHistory ();
97170 history .add (content );
98171 GenerateContentResponse response = model .generateContent (history );
99- currentResponse = response ;
100- currentResponseStream = null ;
172+ setCurrentResponse (Optional .of (response ));
101173 return response ;
102174 }
103175
@@ -112,38 +184,37 @@ private void removeLastContent() {
112184 * @throws IllegalStateException if the response stream is not finished.
113185 */
114186 private void checkLastResponseAndEditHistory () {
115- if (currentResponseStream == null && currentResponse == null ) {
116- return ;
117- } else if (currentResponseStream != null && !currentResponseStream .isConsumed ()) {
118- throw new IllegalStateException ("Response stream is not consumed" );
119- } else if (currentResponseStream != null && currentResponseStream .isConsumed ()) {
120- GenerateContentResponse response = aggregateStreamIntoResponse (currentResponseStream );
121- FinishReason finishReason = getFinishReason (response );
122- if (finishReason != FinishReason .STOP && finishReason != FinishReason .MAX_TOKENS ) {
123- // We also remove the request from the history.
124- removeLastContent ();
125- currentResponseStream = null ;
126- throw new IllegalStateException (
127- String .format (
128- "The last round of conversation will not be added to history because response"
129- + " stream did not finish normally. Finish reason is %s." ,
130- finishReason ));
131- }
132- history .add (getContent (response ));
133- } else if (currentResponseStream == null && currentResponse != null ) {
134- FinishReason finishReason = getFinishReason (currentResponse );
135- if (finishReason != FinishReason .STOP && finishReason != FinishReason .MAX_TOKENS ) {
136- // We also remove the request from the history.
137- removeLastContent ();
138- currentResponse = null ;
139- throw new IllegalStateException (
140- String .format (
141- "The last round of conversation will not be added to history because response did"
142- + " not finish normally. Finish reason is %s." ,
143- finishReason ));
144- }
145- history .add (getContent (currentResponse ));
146- currentResponse = null ;
187+ getCurrentResponse ()
188+ .ifPresent (
189+ currentResponse -> {
190+ setCurrentResponse (Optional .empty ());
191+ checkFinishReasonAndRemoveLastContent (currentResponse );
192+ history .add (getContent (currentResponse ));
193+ });
194+ getCurrentResponseStream ()
195+ .ifPresent (
196+ responseStream -> {
197+ if (!responseStream .isConsumed ()) {
198+ throw new IllegalStateException ("Response stream is not consumed" );
199+ } else {
200+ setCurrentResponseStream (Optional .empty ());
201+ GenerateContentResponse response = aggregateStreamIntoResponse (responseStream );
202+ checkFinishReasonAndRemoveLastContent (response );
203+ history .add (getContent (response ));
204+ }
205+ });
206+ }
207+
208+ /** Removes the last content in the history if the response finished with problems. */
209+ private void checkFinishReasonAndRemoveLastContent (GenerateContentResponse response ) {
210+ FinishReason finishReason = getFinishReason (response );
211+ if (finishReason != FinishReason .STOP && finishReason != FinishReason .MAX_TOKENS ) {
212+ removeLastContent ();
213+ throw new IllegalStateException (
214+ String .format (
215+ "The last round of conversation will not be added to history because response"
216+ + " stream did not finish normally. Finish reason is %s." ,
217+ finishReason ));
147218 }
148219 }
149220
@@ -169,9 +240,62 @@ public List<Content> getHistory() {
169240 return Collections .unmodifiableList (history );
170241 }
171242
243+ /**
244+ * Returns the current response of the root chat session (if exists) or the current chat session.
245+ */
246+ private Optional <GenerateContentResponse > getCurrentResponse () {
247+ if (rootChatSession .isPresent ()) {
248+ return rootChatSession .get ().getCurrentResponse ();
249+ } else {
250+ return currentResponse ;
251+ }
252+ }
253+
254+ /**
255+ * Returns the current responseStream of the root chat session (if exists) or the current chat
256+ * session.
257+ */
258+ private Optional <ResponseStream <GenerateContentResponse >> getCurrentResponseStream () {
259+ if (rootChatSession .isPresent ()) {
260+ return rootChatSession .get ().getCurrentResponseStream ();
261+ } else {
262+ return currentResponseStream ;
263+ }
264+ }
265+
172266 /** Set the history to a list of Content */
173267 @ BetaApi
174268 public void setHistory (List <Content > history ) {
175269 this .history = history ;
176270 }
271+
272+ /** Sets the current response of the root chat session (if exists) or the current chat session. */
273+ private void setCurrentResponse (Optional <GenerateContentResponse > response ) {
274+ if (currentResponseStream .isPresent ()) {
275+ throw new IllegalStateException (
276+ "currentResponse and currentResponseStream cannot be set together" );
277+ }
278+ if (rootChatSession .isPresent ()) {
279+ rootChatSession .get ().setCurrentResponse (response );
280+ } else {
281+ currentResponse = response ;
282+ }
283+ }
284+
285+ /**
286+ * Sets the current responseStream of the root chat session (if exists) or the current chat
287+ * session.
288+ */
289+ private void setCurrentResponseStream (
290+ Optional <ResponseStream <GenerateContentResponse >> responseStream ) {
291+ if (currentResponse .isPresent ()) {
292+ throw new IllegalStateException (
293+ "currentResponseStream and currentResponse cannot be set together" );
294+ }
295+ if (rootChatSession .isPresent ()) {
296+ rootChatSession .get ().setCurrentResponseStream (responseStream );
297+ } else {
298+ currentResponseStream = responseStream ;
299+ }
300+ }
177301}
0 commit comments