Skip to content

Commit

Permalink
BREAKING_CHANGE: [vertexai] remove Transport from GenerativeModel (#1…
Browse files Browse the repository at this point in the history
…0530)

PiperOrigin-RevId: 615144883

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li committed Mar 16, 2024
1 parent e153330 commit f024111
Show file tree
Hide file tree
Showing 22 changed files with 8,102 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public class GenerativeModel {
private GenerationConfig generationConfig = null;
private List<SafetySetting> safetySettings = null;
private List<Tool> tools = null;
private Transport transport;

public static Builder newBuilder() {
return new Builder();
Expand All @@ -67,12 +66,6 @@ private GenerativeModel(Builder builder) {
if (builder.tools != null) {
this.tools = builder.tools;
}

if (builder.transport != null) {
this.transport = builder.transport;
} else {
this.transport = this.vertexAi.getTransport();
}
}

/** Builder class for {@link GenerativeModel}. */
Expand All @@ -82,7 +75,6 @@ public static class Builder {
private GenerationConfig generationConfig;
private List<SafetySetting> safetySettings;
private List<Tool> tools;
private Transport transport;

private Builder() {}

Expand Down Expand Up @@ -158,15 +150,6 @@ public Builder setTools(List<Tool> tools) {
}
return this;
}

/**
* Sets the {@link Transport} layer for API calls in the generative model. It overrides the
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public Builder setTransport(Transport transport) {
this.transport = transport;
return this;
}
}

/**
Expand All @@ -180,21 +163,7 @@ public Builder setTransport(Transport transport) {
* for the generative model
*/
public GenerativeModel(String modelName, VertexAI vertexAi) {
this(modelName, null, null, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport) {
this(modelName, null, null, vertexAi, transport);
this(modelName, null, null, vertexAi);
}

/**
Expand All @@ -209,25 +178,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi, Transport transport)
*/
@BetaApi
public GenerativeModel(String modelName, GenerationConfig generationConfig, VertexAI vertexAi) {
this(modelName, generationConfig, null, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default generation config.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
* will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi
public GenerativeModel(
String modelName, GenerationConfig generationConfig, VertexAI vertexAi, Transport transport) {
this(modelName, generationConfig, null, vertexAi, transport);
this(modelName, generationConfig, null, vertexAi);
}

/**
Expand All @@ -242,28 +193,7 @@ public GenerativeModel(
*/
@BetaApi("safetySettings is a preview feature.")
public GenerativeModel(String modelName, List<SafetySetting> safetySettings, VertexAI vertexAi) {
this(modelName, null, safetySettings, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default safety settings.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
* that will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi("safetySettings is a preview feature.")
public GenerativeModel(
String modelName,
List<SafetySetting> safetySettings,
VertexAI vertexAi,
Transport transport) {
this(modelName, null, safetySettings, vertexAi, transport);
this(modelName, null, safetySettings, vertexAi);
}

/**
Expand All @@ -284,30 +214,6 @@ public GenerativeModel(
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
VertexAI vertexAi) {
this(modelName, generationConfig, safetySettings, vertexAi, null);
}

/**
* Constructs a GenerativeModel instance with default generation config and safety settings.
*
* @param modelName the name of the generative model. Supported format: "gemini-pro",
* "models/gemini-pro", "publishers/google/models/gemini-pro"
* @param generationConfig a {@link com.google.cloud.vertexai.api.GenerationConfig} instance that
* will be used by default for generating response
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.SafetySetting} instances
* that will be used by default for generating response
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
* @param transport the {@link Transport} layer for API calls in the generative model. It
* overrides the transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
@BetaApi
public GenerativeModel(
String modelName,
GenerationConfig generationConfig,
List<SafetySetting> safetySettings,
VertexAI vertexAi,
Transport transport) {
modelName = reconcileModelName(modelName);
this.modelName = modelName;
this.resourceName =
Expand All @@ -324,11 +230,6 @@ public GenerativeModel(
}
}
this.vertexAi = vertexAi;
if (transport != null) {
this.transport = transport;
} else {
this.transport = vertexAi.getTransport();
}
}

/**
Expand Down Expand Up @@ -388,7 +289,7 @@ public CountTokensResponse countTokens(List<Content> contents) throws IOExceptio
@BetaApi
private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getLlmUtilityRestClient().countTokens(request);
} else {
return vertexAi.getLlmUtilityClient().countTokens(request);
Expand Down Expand Up @@ -619,7 +520,7 @@ public GenerateContentResponse generateContent(
*/
private GenerateContentResponse generateContent(GenerateContentRequest request)
throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request);
} else {
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
Expand Down Expand Up @@ -1031,7 +932,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
*/
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest request) throws IOException {
if (this.transport == Transport.REST) {
if (vertexAi.getTransport() == Transport.REST) {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
Expand Down Expand Up @@ -1082,24 +983,11 @@ public void setTools(List<Tool> tools) {
}
}

/**
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
* generative model.
*/
public void setTransport(Transport transport) {
this.transport = transport;
}

/** Returns the model name of this generative model. */
public String getModelName() {
return this.modelName;
}

/** Returns the {@link Transport} layer for API calls in this generative model. */
public Transport getTransport() {
return this.transport;
}

/**
* Returns the {@link com.google.cloud.vertexai.api.GenerationConfig} of this generative model.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
Expand All @@ -35,15 +34,18 @@
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.Retrieval;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
Expand Down Expand Up @@ -96,14 +98,30 @@ public final class GenerativeModelTest {
.build())
.addRequired("location")))
.build();
private static final Tool GOOGLE_SEARCH_TOOL =
Tool.newBuilder()
.setGoogleSearchRetrieval(GoogleSearchRetrieval.newBuilder().setDisableAttribution(false))
.build();
private static final Tool VERTEX_AI_SEARCH_TOOL =
Tool.newBuilder()
.setRetrieval(
Retrieval.newBuilder()
.setVertexAiSearch(
VertexAISearch.newBuilder()
.setDatastore(
String.format(
"projects/%s/locations/%s/collections/%s/dataStores/%s",
PROJECT, "global", "default_collection", "test_123")))
.setDisableAttribution(false))
.build();

private static final String TEXT = "What is your name?";

private VertexAI vertexAi;
private GenerativeModel model;
private List<SafetySetting> safetySettings = Arrays.asList(SAFETY_SETTING);
private List<SafetySetting> defaultSafetySettings = Arrays.asList(DEFAULT_SAFETY_SETTING);
private List<Tool> tools = Arrays.asList(TOOL);
private List<Tool> tools = Arrays.asList(TOOL, GOOGLE_SEARCH_TOOL, VERTEX_AI_SEARCH_TOOL);

@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();

Expand Down Expand Up @@ -169,7 +187,6 @@ public void testInstantiateGenerativeModelwithBuilder() {
assertThat(model.getGenerationConfig()).isNull();
assertThat(model.getSafetySettings()).isNull();
assertThat(model.getTools()).isNull();
assertThat(model.getTransport()).isEqualTo(Transport.GRPC);
}

@Test
Expand All @@ -181,13 +198,11 @@ public void testInstantiateGenerativeModelwithBuilderAllConfigs() {
.setGenerationConfig(GENERATION_CONFIG)
.setSafetySettings(safetySettings)
.setTools(tools)
.setTransport(Transport.REST)
.build();
assertThat(model.getModelName()).isEqualTo(MODEL_NAME);
assertThat(model.getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(model.getSafetySettings()).isEqualTo(safetySettings);
assertThat(model.getTools()).isEqualTo(tools);
assertThat(model.getTransport()).isEqualTo(Transport.REST);
}

@Test
Expand Down

0 comments on commit f024111

Please sign in to comment.