Skip to content

Commit

Permalink
feat: [vertexai] add GenerateContentConfig to generateContent method (#…
Browse files Browse the repository at this point in the history
…10425)

PiperOrigin-RevId: 609363710

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li committed Feb 28, 2024
1 parent ec9dd00 commit 903abf3
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
Expand Up @@ -409,6 +409,22 @@ public GenerateContentResponse generateContent(String text) throws IOException {
return generateContent(text, null, null);
}

/**
* Generates content from generative model given a text and configs.
*
* @param text a text message to send to the generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains
* response contents and other metadata
* @throws IOException if an I/O error occurs while making the API call
*/
@BetaApi
public GenerateContentResponse generateContent(String text, GenerateContentConfig config)
throws IOException {
return generateContent(ContentMaker.fromString(text), config);
}

/**
* Generate content from generative model given a text and generation config.
*
Expand Down Expand Up @@ -511,6 +527,41 @@ public GenerateContentResponse generateContent(
return generateContent(contents, null, safetySettings);
}

/**
* Generates content from generative model given a list of contents and configs.
*
* @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the
* generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains
* response contents and other metadata
* @throws IOException if an I/O error occurs while making the API call
*/
@BetaApi
public GenerateContentResponse generateContent(
List<Content> contents, GenerateContentConfig config) throws IOException {
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (config.getGenerationConfig() != null) {
requestBuilder.setGenerationConfig(config.getGenerationConfig());
} else if (this.generationConfig != null) {
requestBuilder.setGenerationConfig(this.generationConfig);
}
if (config.getSafetySettings().isEmpty() == false) {
requestBuilder.addAllSafetySettings(config.getSafetySettings());
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (config.getTools().isEmpty() == false) {
requestBuilder.addAllTools(config.getTools());
} else if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}

return generateContent(requestBuilder);
}

/**
* Generate content from generative model given a list of contents, generation config, and safety
* settings.
Expand Down Expand Up @@ -581,6 +632,22 @@ public GenerateContentResponse generateContent(Content content) throws IOExcepti
return generateContent(content, null, null);
}

/**
* Generates content from generative model given a single content and configs.
*
* @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance that contains
* response contents and other metadata
* @throws IOException if an I/O error occurs while making the API call
*/
@BetaApi
public GenerateContentResponse generateContent(Content content, GenerateContentConfig config)
throws IOException {
return generateContent(Arrays.asList(content), config);
}

/**
* Generate content from this model given a single content and generation config.
*
Expand Down
Expand Up @@ -452,6 +452,36 @@ public void testGenerateContentwithDefaultTools() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentwithGenerateContentConfig() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
GenerateContentConfig config =
GenerateContentConfig.newBuilder()
.setGenerationConfig(GENERATION_CONFIG)
.setSafetySettings(safetySettings)
.setTools(tools)
.build();

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

GenerateContentResponse unused = model.generateContent(TEXT, config);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());

assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down

0 comments on commit 903abf3

Please sign in to comment.