diff --git a/auth/auth.go b/auth/auth.go index 322cf297..c12a82ed 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -93,16 +93,25 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error) return nil, err } - hc, _, err := transport.NewHTTPClient(ctx, conf.Opts...) + transport, _, err := transport.NewHTTPClient(ctx, conf.Opts...) if err != nil { return nil, err } + hc := internal.WithDefaultRetryConfig(transport) + hc.CreateErrFn = handleHTTPError + hc.SuccessFn = internal.HasSuccessStatus + hc.Opts = []internal.HTTPOption{ + internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)), + } + base := &baseClient{ - userManagementClient: newUserManagementClient(hc, conf), - providerConfigClient: newProviderConfigClient(hc, conf), - idTokenVerifier: idTokenVerifier, - cookieVerifier: cookieVerifier, + userManagementEndpoint: idToolkitV1Endpoint, + providerConfigEndpoint: providerConfigEndpoint, + projectID: conf.ProjectID, + httpClient: hc, + idTokenVerifier: idTokenVerifier, + cookieVerifier: cookieVerifier, } return &Client{ baseClient: base, @@ -183,7 +192,7 @@ func (c *Client) SessionCookie( idToken string, expiresIn time.Duration, ) (string, error) { - return c.baseClient.userManagementClient.createSessionCookie(ctx, idToken, expiresIn) + return c.baseClient.createSessionCookie(ctx, idToken, expiresIn) } // Token represents a decoded Firebase ID token. @@ -213,22 +222,21 @@ type FirebaseInfo struct { Identities map[string]interface{} `json:"identities"` } +// baseClient exposes the APIs common to both auth.Client and auth.TenantClient. type baseClient struct { - *userManagementClient - *providerConfigClient - idTokenVerifier *tokenVerifier - cookieVerifier *tokenVerifier - tenantID string + userManagementEndpoint string + providerConfigEndpoint string + projectID string + tenantID string + httpClient *internal.HTTPClient + idTokenVerifier *tokenVerifier + cookieVerifier *tokenVerifier } func (c *baseClient) withTenantID(tenantID string) *baseClient { - return &baseClient{ - userManagementClient: c.userManagementClient.withTenantID(tenantID), - providerConfigClient: c.providerConfigClient.withTenantID(tenantID), - idTokenVerifier: c.idTokenVerifier, - cookieVerifier: c.cookieVerifier, - tenantID: tenantID, - } + copy := *c + copy.tenantID = tenantID + return © } // VerifyIDToken verifies the signature and payload of the provided ID token. diff --git a/auth/auth_test.go b/auth/auth_test.go index ca94a307..72ad19d9 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -100,8 +100,8 @@ func TestNewClientWithServiceAccountCredentials(t *testing.T) { if err := checkCookieVerifier(client.cookieVerifier, creds.ProjectID); err != nil { t.Errorf("NewClient().cookieVerifier: %v", err) } - if err := checkUserManagementClient(client, creds.ProjectID); err != nil { - t.Errorf("NewClient().userManagementClient: %v", err) + if err := checkBaseClient(client, creds.ProjectID); err != nil { + t.Errorf("NewClient().baseClient: %v", err) } if client.clock != internal.SystemClock { t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock) @@ -127,8 +127,8 @@ func TestNewClientWithoutCredentials(t *testing.T) { if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil { t.Errorf("NewClient().cookieVerifier: %v", err) } - if err := checkUserManagementClient(client, ""); err != nil { - t.Errorf("NewClient().userManagementClient: %v", err) + if err := checkBaseClient(client, ""); err != nil { + t.Errorf("NewClient().baseClient: %v", err) } if client.clock != internal.SystemClock { t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock) @@ -155,8 +155,8 @@ func TestNewClientWithServiceAccountID(t *testing.T) { if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil { t.Errorf("NewClient().cookieVerifier: %v", err) } - if err := checkUserManagementClient(client, ""); err != nil { - t.Errorf("NewClient().userManagementClient: %v", err) + if err := checkBaseClient(client, ""); err != nil { + t.Errorf("NewClient().baseClient: %v", err) } if client.clock != internal.SystemClock { t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock) @@ -194,8 +194,8 @@ func TestNewClientWithUserCredentials(t *testing.T) { if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil { t.Errorf("NewClient().cookieVerifier: %v", err) } - if err := checkUserManagementClient(client, ""); err != nil { - t.Errorf("NewClient().userManagementClient: %v", err) + if err := checkBaseClient(client, ""); err != nil { + t.Errorf("NewClient().baseClient: %v", err) } if client.clock != internal.SystemClock { t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock) @@ -1057,10 +1057,13 @@ func checkCookieVerifier(tv *tokenVerifier, projectID string) error { return nil } -func checkUserManagementClient(client *Client, wantProjectID string) error { - umc := client.userManagementClient - if umc.baseURL != idToolkitV1Endpoint { - return fmt.Errorf("baseURL = %q; want = %q", umc.baseURL, idToolkitV1Endpoint) +func checkBaseClient(client *Client, wantProjectID string) error { + umc := client.baseClient + if umc.userManagementEndpoint != idToolkitV1Endpoint { + return fmt.Errorf("userManagementEndpoint = %q; want = %q", umc.userManagementEndpoint, idToolkitV1Endpoint) + } + if umc.providerConfigEndpoint != providerConfigEndpoint { + return fmt.Errorf("providerConfigEndpoint = %q; want = %q", umc.providerConfigEndpoint, providerConfigEndpoint) } if umc.projectID != wantProjectID { return fmt.Errorf("projectID = %q; want = %q", umc.projectID, wantProjectID) diff --git a/auth/email_action_links.go b/auth/email_action_links.go index 25e1bf15..6b649254 100644 --- a/auth/email_action_links.go +++ b/auth/email_action_links.go @@ -71,38 +71,38 @@ const ( // EmailVerificationLink generates the out-of-band email action link for email verification flows for the specified // email address. -func (c *userManagementClient) EmailVerificationLink(ctx context.Context, email string) (string, error) { +func (c *baseClient) EmailVerificationLink(ctx context.Context, email string) (string, error) { return c.EmailVerificationLinkWithSettings(ctx, email, nil) } // EmailVerificationLinkWithSettings generates the out-of-band email action link for email verification flows for the // specified email address, using the action code settings provided. -func (c *userManagementClient) EmailVerificationLinkWithSettings( +func (c *baseClient) EmailVerificationLinkWithSettings( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, emailVerification, email, settings) } // PasswordResetLink generates the out-of-band email action link for password reset flows for the specified email // address. -func (c *userManagementClient) PasswordResetLink(ctx context.Context, email string) (string, error) { +func (c *baseClient) PasswordResetLink(ctx context.Context, email string) (string, error) { return c.PasswordResetLinkWithSettings(ctx, email, nil) } // PasswordResetLinkWithSettings generates the out-of-band email action link for password reset flows for the // specified email address, using the action code settings provided. -func (c *userManagementClient) PasswordResetLinkWithSettings( +func (c *baseClient) PasswordResetLinkWithSettings( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, passwordReset, email, settings) } // EmailSignInLink generates the out-of-band email action link for email link sign-in flows, using the action // code settings provided. -func (c *userManagementClient) EmailSignInLink( +func (c *baseClient) EmailSignInLink( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, emailLinkSignIn, email, settings) } -func (c *userManagementClient) generateEmailActionLink( +func (c *baseClient) generateEmailActionLink( ctx context.Context, linkType linkType, email string, settings *ActionCodeSettings) (string, error) { if email == "" { diff --git a/auth/email_action_links_test.go b/auth/email_action_links_test.go index 21780c0e..5db51d60 100644 --- a/auth/email_action_links_test.go +++ b/auth/email_action_links_test.go @@ -273,7 +273,7 @@ func TestEmailVerificationLinkError(t *testing.T) { } s := echoServer(testActionLinkResponse, t) defer s.Close() - s.Client.userManagementClient.httpClient.RetryConfig = nil + s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError for code, check := range cases { diff --git a/auth/export_users.go b/auth/export_users.go index 955990f5..baeb5bbc 100644 --- a/auth/export_users.go +++ b/auth/export_users.go @@ -31,7 +31,7 @@ const maxReturnedResults = 1000 // // If nextPageToken is empty, the iterator will start at the beginning. // If the nextPageToken is not empty, the iterator starts after the token. -func (c *userManagementClient) Users(ctx context.Context, nextPageToken string) *UserIterator { +func (c *baseClient) Users(ctx context.Context, nextPageToken string) *UserIterator { it := &UserIterator{ ctx: ctx, client: c, @@ -49,7 +49,7 @@ func (c *userManagementClient) Users(ctx context.Context, nextPageToken string) // // Also see: https://github.com/GoogleCloudPlatform/google-cloud-go/wiki/Iterator-Guidelines type UserIterator struct { - client *userManagementClient + client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo diff --git a/auth/import_users.go b/auth/import_users.go index 22cf1255..cfb92493 100644 --- a/auth/import_users.go +++ b/auth/import_users.go @@ -50,7 +50,7 @@ type ErrorInfo struct { // // No more than 1000 users can be imported in a single call. If at least one user specifies a // password, a UserImportHash must be specified as an option. -func (c *userManagementClient) ImportUsers( +func (c *baseClient) ImportUsers( ctx context.Context, users []*UserToImport, opts ...UserImportOption) (*UserImportResult, error) { if len(users) == 0 { diff --git a/auth/provider_config.go b/auth/provider_config.go index 8f208944..ac87f4a4 100644 --- a/auth/provider_config.go +++ b/auth/provider_config.go @@ -247,7 +247,7 @@ func (config *OIDCProviderConfigToUpdate) buildRequest() (nestedMap, error) { // OIDCProviderConfigIterator is an iterator over OIDC provider configurations. type OIDCProviderConfigIterator struct { - client *providerConfigClient + client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo @@ -538,7 +538,7 @@ func (config *SAMLProviderConfigToUpdate) buildRequest() (nestedMap, error) { // SAMLProviderConfigIterator is an iterator over SAML provider configurations. type SAMLProviderConfigIterator struct { - client *providerConfigClient + client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo @@ -595,36 +595,8 @@ func (it *SAMLProviderConfigIterator) fetch(pageSize int, pageToken string) (str return result.NextPageToken, nil } -type providerConfigClient struct { - endpoint string - projectID string - tenantID string - httpClient *internal.HTTPClient -} - -func newProviderConfigClient(client *http.Client, conf *internal.AuthConfig) *providerConfigClient { - hc := internal.WithDefaultRetryConfig(client) - hc.CreateErrFn = handleHTTPError - hc.SuccessFn = internal.HasSuccessStatus - hc.Opts = []internal.HTTPOption{ - internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)), - } - - return &providerConfigClient{ - endpoint: providerConfigEndpoint, - projectID: conf.ProjectID, - httpClient: hc, - } -} - -func (c *providerConfigClient) withTenantID(tenantID string) *providerConfigClient { - copy := *c - copy.tenantID = tenantID - return © -} - // OIDCProviderConfig returns the OIDCProviderConfig with the given ID. -func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) { +func (c *baseClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) { if err := validateOIDCConfigID(id); err != nil { return nil, err } @@ -642,7 +614,7 @@ func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string } // CreateOIDCProviderConfig creates a new OIDC provider config from the given parameters. -func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) { +func (c *baseClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) { if config == nil { return nil, errors.New("config must not be nil") } @@ -669,7 +641,7 @@ func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, con } // UpdateOIDCProviderConfig updates an existing OIDC provider config with the given parameters. -func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) { +func (c *baseClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) { if err := validateOIDCConfigID(id); err != nil { return nil, err } @@ -704,7 +676,7 @@ func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id } // DeleteOIDCProviderConfig deletes the OIDCProviderConfig with the given ID. -func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error { +func (c *baseClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error { if err := validateOIDCConfigID(id); err != nil { return err } @@ -721,7 +693,7 @@ func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id // // If nextPageToken is empty, the iterator will start at the beginning. Otherwise, // iterator starts after the token. -func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator { +func (c *baseClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator { it := &OIDCProviderConfigIterator{ ctx: ctx, client: c, @@ -736,7 +708,7 @@ func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPage } // SAMLProviderConfig returns the SAMLProviderConfig with the given ID. -func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) { +func (c *baseClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) { if err := validateSAMLConfigID(id); err != nil { return nil, err } @@ -754,7 +726,7 @@ func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string } // CreateSAMLProviderConfig creates a new SAML provider config from the given parameters. -func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) { +func (c *baseClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) { if config == nil { return nil, errors.New("config must not be nil") } @@ -781,7 +753,7 @@ func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, con } // UpdateSAMLProviderConfig updates an existing SAML provider config with the given parameters. -func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) { +func (c *baseClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) { if err := validateSAMLConfigID(id); err != nil { return nil, err } @@ -816,7 +788,7 @@ func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id } // DeleteSAMLProviderConfig deletes the SAMLProviderConfig with the given ID. -func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error { +func (c *baseClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error { if err := validateSAMLConfigID(id); err != nil { return err } @@ -833,7 +805,7 @@ func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id // // If nextPageToken is empty, the iterator will start at the beginning. Otherwise, // iterator starts after the token. -func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator { +func (c *baseClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator { it := &SAMLProviderConfigIterator{ ctx: ctx, client: c, @@ -847,15 +819,17 @@ func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPage return it } -func (c *providerConfigClient) makeRequest(ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { +func (c *baseClient) makeRequest( + ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { + if c.projectID == "" { return nil, errors.New("project id not available") } if c.tenantID != "" { - req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.endpoint, c.projectID, c.tenantID, req.URL) + req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.providerConfigEndpoint, c.projectID, c.tenantID, req.URL) } else { - req.URL = fmt.Sprintf("%s/projects/%s%s", c.endpoint, c.projectID, req.URL) + req.URL = fmt.Sprintf("%s/projects/%s%s", c.providerConfigEndpoint, c.projectID, req.URL) } return c.httpClient.DoAndUnmarshal(ctx, req, v) diff --git a/auth/provider_config_test.go b/auth/provider_config_test.go index bb707527..0bde823f 100644 --- a/auth/provider_config_test.go +++ b/auth/provider_config_test.go @@ -123,7 +123,7 @@ func TestOIDCProviderConfig(t *testing.T) { } func TestOIDCProviderConfigInvalidID(t *testing.T) { - client := &providerConfigClient{} + client := &baseClient{} wantErr := "invalid OIDC provider id: " for _, id := range invalidOIDCConfigIDs { @@ -241,7 +241,7 @@ func TestCreateOIDCProviderConfigError(t *testing.T) { defer s.Close() client := s.Client - client.providerConfigClient.httpClient.RetryConfig = nil + client.baseClient.httpClient.RetryConfig = nil options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). ClientID(oidcProviderConfig.ClientID). @@ -304,7 +304,7 @@ func TestCreateOIDCProviderConfigInvalidInput(t *testing.T) { }, } - client := &providerConfigClient{} + client := &baseClient{} for _, tc := range cases { _, err := client.CreateOIDCProviderConfig(context.Background(), tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { @@ -408,7 +408,7 @@ func TestUpdateOIDCProviderConfigZeroValues(t *testing.T) { func TestUpdateOIDCProviderConfigInvalidID(t *testing.T) { cases := []string{"", "saml.config"} - client := &providerConfigClient{} + client := &baseClient{} options := (&OIDCProviderConfigToUpdate{}). DisplayName("") want := "invalid OIDC provider id: " @@ -456,7 +456,7 @@ func TestUpdateOIDCProviderConfigInvalidInput(t *testing.T) { }, } - client := &providerConfigClient{} + client := &baseClient{} for _, tc := range cases { _, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { @@ -486,7 +486,7 @@ func TestDeleteOIDCProviderConfig(t *testing.T) { } func TestDeleteOIDCProviderConfigInvalidID(t *testing.T) { - client := &providerConfigClient{} + client := &baseClient{} wantErr := "invalid OIDC provider id: " for _, id := range invalidOIDCConfigIDs { @@ -580,7 +580,7 @@ func TestOIDCProviderConfigsError(t *testing.T) { s.Status = http.StatusInternalServerError client := s.Client - client.providerConfigClient.httpClient.RetryConfig = nil + client.baseClient.httpClient.RetryConfig = nil it := client.OIDCProviderConfigs(context.Background(), "") config, err := it.Next() if config != nil || err == nil || !IsUnknown(err) { @@ -614,7 +614,7 @@ func TestSAMLProviderConfig(t *testing.T) { } func TestSAMLProviderConfigInvalidID(t *testing.T) { - client := &providerConfigClient{} + client := &baseClient{} wantErr := "invalid SAML provider id: " for _, id := range invalidSAMLConfigIDs { @@ -766,7 +766,7 @@ func TestCreateSAMLProviderConfigError(t *testing.T) { defer s.Close() client := s.Client - client.providerConfigClient.httpClient.RetryConfig = nil + client.baseClient.httpClient.RetryConfig = nil options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). IDPEntityID(samlProviderConfig.IDPEntityID). @@ -879,7 +879,7 @@ func TestCreateSAMLProviderConfigInvalidInput(t *testing.T) { }, } - client := &providerConfigClient{} + client := &baseClient{} for _, tc := range cases { _, err := client.CreateSAMLProviderConfig(context.Background(), tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { @@ -1004,7 +1004,7 @@ func TestUpdateSAMLProviderConfigZeroValues(t *testing.T) { func TestUpdateSAMLProviderConfigInvalidID(t *testing.T) { cases := []string{"", "oidc.config"} - client := &providerConfigClient{} + client := &baseClient{} options := (&SAMLProviderConfigToUpdate{}). DisplayName(""). Enabled(false). @@ -1084,7 +1084,7 @@ func TestUpdateSAMLProviderConfigInvalidInput(t *testing.T) { }, } - client := &providerConfigClient{} + client := &baseClient{} for _, tc := range cases { _, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { @@ -1114,7 +1114,7 @@ func TestDeleteSAMLProviderConfig(t *testing.T) { } func TestDeleteSAMLProviderConfigInvalidID(t *testing.T) { - client := &providerConfigClient{} + client := &baseClient{} wantErr := "invalid SAML provider id: " for _, id := range invalidSAMLConfigIDs { @@ -1208,7 +1208,7 @@ func TestSAMLProviderConfigsError(t *testing.T) { s.Status = http.StatusInternalServerError client := s.Client - client.providerConfigClient.httpClient.RetryConfig = nil + client.baseClient.httpClient.RetryConfig = nil it := client.SAMLProviderConfigs(context.Background(), "") config, err := it.Next() if config != nil || err == nil || !IsUnknown(err) { @@ -1217,7 +1217,7 @@ func TestSAMLProviderConfigsError(t *testing.T) { } func TestSAMLProviderConfigNoProjectID(t *testing.T) { - client := &providerConfigClient{} + client := &baseClient{} want := "project id not available" if _, err := client.SAMLProviderConfig(context.Background(), "saml.provider"); err == nil || err.Error() != want { t.Errorf("SAMLProviderConfig() = %v; want = %q", err, want) diff --git a/auth/tenant_mgt.go b/auth/tenant_mgt.go index 50f1f50d..d35ab139 100644 --- a/auth/tenant_mgt.go +++ b/auth/tenant_mgt.go @@ -85,19 +85,12 @@ type TenantManager struct { httpClient *internal.HTTPClient } -func newTenantManager(client *http.Client, conf *internal.AuthConfig, base *baseClient) *TenantManager { - hc := internal.WithDefaultRetryConfig(client) - hc.CreateErrFn = handleHTTPError - hc.SuccessFn = internal.HasSuccessStatus - hc.Opts = []internal.HTTPOption{ - internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)), - } - +func newTenantManager(client *internal.HTTPClient, conf *internal.AuthConfig, base *baseClient) *TenantManager { return &TenantManager{ base: base, endpoint: tenantMgtEndpoint, projectID: conf.ProjectID, - httpClient: hc, + httpClient: client, } } diff --git a/auth/tenant_mgt_test.go b/auth/tenant_mgt_test.go index 166f646f..ec5811d0 100644 --- a/auth/tenant_mgt_test.go +++ b/auth/tenant_mgt_test.go @@ -55,12 +55,8 @@ func TestTenantID(t *testing.T) { t.Errorf("TenantID() = %q; want = %q", tenantID, want) } - if client.userManagementClient.tenantID != want { - t.Errorf("userManagementClient.tenantID = %q; want = %q", client.userManagementClient.tenantID, want) - } - - if client.providerConfigClient.tenantID != want { - t.Errorf("providerConfigClient.tenantID = %q; want = %q", client.providerConfigClient.tenantID, want) + if client.baseClient.tenantID != want { + t.Errorf("baseClient.tenantID = %q; want = %q", client.baseClient.tenantID, want) } } diff --git a/auth/user_mgt.go b/auth/user_mgt.go index b95dc542..e1a18a0f 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -493,37 +493,8 @@ func validatePhone(phone string) error { // End of validators -// userManagementClient is a helper for interacting with the Identity Toolkit REST API. -type userManagementClient struct { - baseURL string - projectID string - tenantID string - httpClient *internal.HTTPClient -} - -func newUserManagementClient(client *http.Client, conf *internal.AuthConfig) *userManagementClient { - hc := internal.WithDefaultRetryConfig(client) - hc.CreateErrFn = handleHTTPError - hc.SuccessFn = internal.HasSuccessStatus - hc.Opts = []internal.HTTPOption{ - internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)), - } - - return &userManagementClient{ - baseURL: idToolkitV1Endpoint, - projectID: conf.ProjectID, - httpClient: hc, - } -} - -func (c *userManagementClient) withTenantID(tenantID string) *userManagementClient { - copy := *c - copy.tenantID = tenantID - return © -} - // GetUser gets the user data corresponding to the specified user ID. -func (c *userManagementClient) GetUser(ctx context.Context, uid string) (*UserRecord, error) { +func (c *baseClient) GetUser(ctx context.Context, uid string) (*UserRecord, error) { return c.getUser(ctx, &userQuery{ field: "localId", value: uid, @@ -532,7 +503,7 @@ func (c *userManagementClient) GetUser(ctx context.Context, uid string) (*UserRe } // GetUserByEmail gets the user data corresponding to the specified email. -func (c *userManagementClient) GetUserByEmail(ctx context.Context, email string) (*UserRecord, error) { +func (c *baseClient) GetUserByEmail(ctx context.Context, email string) (*UserRecord, error) { if err := validateEmail(email); err != nil { return nil, err } @@ -543,7 +514,7 @@ func (c *userManagementClient) GetUserByEmail(ctx context.Context, email string) } // GetUserByPhoneNumber gets the user data corresponding to the specified user phone number. -func (c *userManagementClient) GetUserByPhoneNumber(ctx context.Context, phone string) (*UserRecord, error) { +func (c *baseClient) GetUserByPhoneNumber(ctx context.Context, phone string) (*UserRecord, error) { if err := validatePhone(phone); err != nil { return nil, err } @@ -574,7 +545,7 @@ func (q *userQuery) build() map[string]interface{} { } } -func (c *userManagementClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { +func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { var parsed struct { Users []*userQueryResponse `json:"users"` } @@ -665,7 +636,7 @@ func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error } // CreateUser creates a new user with the specified properties. -func (c *userManagementClient) CreateUser(ctx context.Context, user *UserToCreate) (*UserRecord, error) { +func (c *baseClient) CreateUser(ctx context.Context, user *UserToCreate) (*UserRecord, error) { uid, err := c.createUser(ctx, user) if err != nil { return nil, err @@ -673,7 +644,7 @@ func (c *userManagementClient) CreateUser(ctx context.Context, user *UserToCreat return c.GetUser(ctx, uid) } -func (c *userManagementClient) createUser(ctx context.Context, user *UserToCreate) (string, error) { +func (c *baseClient) createUser(ctx context.Context, user *UserToCreate) (string, error) { if user == nil { user = &UserToCreate{} } @@ -691,7 +662,7 @@ func (c *userManagementClient) createUser(ctx context.Context, user *UserToCreat } // UpdateUser updates an existing user account with the specified properties. -func (c *userManagementClient) UpdateUser( +func (c *baseClient) UpdateUser( ctx context.Context, uid string, user *UserToUpdate) (ur *UserRecord, err error) { if err := c.updateUser(ctx, uid, user); err != nil { return nil, err @@ -707,7 +678,7 @@ func (c *userManagementClient) UpdateUser( // While this revokes all sessions for a specified user and disables any new ID tokens for existing sessions // from getting minted, existing ID tokens may remain active until their natural expiration (one hour). // To verify that ID tokens are revoked, use `verifyIdTokenAndCheckRevoked(ctx, idToken)`. -func (c *userManagementClient) RevokeRefreshTokens(ctx context.Context, uid string) error { +func (c *baseClient) RevokeRefreshTokens(ctx context.Context, uid string) error { return c.updateUser(ctx, uid, (&UserToUpdate{}).revokeRefreshTokens()) } @@ -719,14 +690,14 @@ func (c *userManagementClient) RevokeRefreshTokens(ctx context.Context, uid stri // can be accessed via the user's ID token JWT. If a reserved OIDC claim is specified (sub, iat, // iss, etc), an error is thrown. Claims payload must also not be larger then 1000 characters // when serialized into a JSON string. -func (c *userManagementClient) SetCustomUserClaims(ctx context.Context, uid string, customClaims map[string]interface{}) error { +func (c *baseClient) SetCustomUserClaims(ctx context.Context, uid string, customClaims map[string]interface{}) error { if customClaims == nil || len(customClaims) == 0 { customClaims = map[string]interface{}{} } return c.updateUser(ctx, uid, (&UserToUpdate{}).CustomClaims(customClaims)) } -func (c *userManagementClient) updateUser(ctx context.Context, uid string, user *UserToUpdate) error { +func (c *baseClient) updateUser(ctx context.Context, uid string, user *UserToUpdate) error { if err := validateUID(uid); err != nil { return err } @@ -745,7 +716,7 @@ func (c *userManagementClient) updateUser(ctx context.Context, uid string, user } // DeleteUser deletes the user by the given UID. -func (c *userManagementClient) DeleteUser(ctx context.Context, uid string) error { +func (c *baseClient) DeleteUser(ctx context.Context, uid string) error { if err := validateUID(uid); err != nil { return err } @@ -763,7 +734,7 @@ func (c *userManagementClient) DeleteUser(ctx context.Context, uid string) error // // This function is only exposed via [auth.Client] for now, since the tenant-scoped variant // of it is currently not supported. -func (c *userManagementClient) createSessionCookie( +func (c *baseClient) createSessionCookie( ctx context.Context, idToken string, expiresIn time.Duration, @@ -788,7 +759,7 @@ func (c *userManagementClient) createSessionCookie( return result.SessionCookie, err } -func (c *userManagementClient) post( +func (c *baseClient) post( ctx context.Context, path string, payload, resp interface{}, @@ -807,16 +778,16 @@ func (c *userManagementClient) post( return c.httpClient.DoAndUnmarshal(ctx, req, resp) } -func (c *userManagementClient) makeUserMgtURL(path string) (string, error) { +func (c *baseClient) makeUserMgtURL(path string) (string, error) { if c.projectID == "" { return "", errors.New("project id not available") } var url string if c.tenantID != "" { - url = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.baseURL, c.projectID, c.tenantID, path) + url = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.userManagementEndpoint, c.projectID, c.tenantID, path) } else { - url = fmt.Sprintf("%s/projects/%s%s", c.baseURL, c.projectID, path) + url = fmt.Sprintf("%s/projects/%s%s", c.userManagementEndpoint, c.projectID, path) } return url, nil diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 49a317bc..b2591f1e 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -141,9 +141,7 @@ func TestGetUserByPhoneNumber(t *testing.T) { func TestInvalidGetUser(t *testing.T) { client := &Client{ - baseClient: &baseClient{ - userManagementClient: &userManagementClient{}, - }, + baseClient: &baseClient{}, } user, err := client.GetUser(context.Background(), "") if user != nil || err == nil { @@ -1217,9 +1215,7 @@ func TestSessionCookieError(t *testing.T) { func TestSessionCookieWithoutProjectID(t *testing.T) { client := &Client{ - baseClient: &baseClient{ - userManagementClient: &userManagementClient{}, - }, + baseClient: &baseClient{}, } _, err := client.SessionCookie(context.Background(), "idToken", 10*time.Minute) want := "project id not available" @@ -1260,7 +1256,7 @@ func TestSessionCookieLongExpiresIn(t *testing.T) { func TestHTTPError(t *testing.T) { s := echoServer([]byte(`{"error":"test"}`), t) defer s.Close() - s.Client.userManagementClient.httpClient.RetryConfig = nil + s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError u, err := s.Client.GetUser(context.Background(), "some uid") @@ -1287,7 +1283,7 @@ func TestHTTPErrorWithCode(t *testing.T) { } s := echoServer(nil, t) defer s.Close() - s.Client.userManagementClient.httpClient.RetryConfig = nil + s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError for code, check := range errorCodes { @@ -1382,8 +1378,8 @@ func echoServer(resp interface{}, t *testing.T) *mockAuthServer { t.Fatal(err) } - authClient.userManagementClient.baseURL = s.Srv.URL - authClient.providerConfigClient.endpoint = s.Srv.URL + authClient.baseClient.userManagementEndpoint = s.Srv.URL + authClient.baseClient.providerConfigEndpoint = s.Srv.URL authClient.TenantManager.endpoint = s.Srv.URL s.Client = authClient return &s