Skip to content

Commit

Permalink
Hoisted userManagementClient and providerConfigClient into baseClient (
Browse files Browse the repository at this point in the history
…firebase#317)

* Hoisted userManagementClient and providerConfigClient into baseClient

* Removed providerConfigClient
  • Loading branch information
hiranya911 committed Dec 21, 2019
1 parent 945b1b1 commit d30698d
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 168 deletions.
44 changes: 26 additions & 18 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,25 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
return nil, err
}

hc, _, err := transport.NewHTTPClient(ctx, conf.Opts...)
transport, _, err := transport.NewHTTPClient(ctx, conf.Opts...)
if err != nil {
return nil, err
}

hc := internal.WithDefaultRetryConfig(transport)
hc.CreateErrFn = handleHTTPError
hc.SuccessFn = internal.HasSuccessStatus
hc.Opts = []internal.HTTPOption{
internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)),
}

base := &baseClient{
userManagementClient: newUserManagementClient(hc, conf),
providerConfigClient: newProviderConfigClient(hc, conf),
idTokenVerifier: idTokenVerifier,
cookieVerifier: cookieVerifier,
userManagementEndpoint: idToolkitV1Endpoint,
providerConfigEndpoint: providerConfigEndpoint,
projectID: conf.ProjectID,
httpClient: hc,
idTokenVerifier: idTokenVerifier,
cookieVerifier: cookieVerifier,
}
return &Client{
baseClient: base,
Expand Down Expand Up @@ -183,7 +192,7 @@ func (c *Client) SessionCookie(
idToken string,
expiresIn time.Duration,
) (string, error) {
return c.baseClient.userManagementClient.createSessionCookie(ctx, idToken, expiresIn)
return c.baseClient.createSessionCookie(ctx, idToken, expiresIn)
}

// Token represents a decoded Firebase ID token.
Expand Down Expand Up @@ -213,22 +222,21 @@ type FirebaseInfo struct {
Identities map[string]interface{} `json:"identities"`
}

// baseClient exposes the APIs common to both auth.Client and auth.TenantClient.
type baseClient struct {
*userManagementClient
*providerConfigClient
idTokenVerifier *tokenVerifier
cookieVerifier *tokenVerifier
tenantID string
userManagementEndpoint string
providerConfigEndpoint string
projectID string
tenantID string
httpClient *internal.HTTPClient
idTokenVerifier *tokenVerifier
cookieVerifier *tokenVerifier
}

func (c *baseClient) withTenantID(tenantID string) *baseClient {
return &baseClient{
userManagementClient: c.userManagementClient.withTenantID(tenantID),
providerConfigClient: c.providerConfigClient.withTenantID(tenantID),
idTokenVerifier: c.idTokenVerifier,
cookieVerifier: c.cookieVerifier,
tenantID: tenantID,
}
copy := *c
copy.tenantID = tenantID
return &copy
}

// VerifyIDToken verifies the signature and payload of the provided ID token.
Expand Down
27 changes: 15 additions & 12 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ func TestNewClientWithServiceAccountCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, creds.ProjectID); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, creds.ProjectID); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, creds.ProjectID); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand All @@ -127,8 +127,8 @@ func TestNewClientWithoutCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand All @@ -155,8 +155,8 @@ func TestNewClientWithServiceAccountID(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand Down Expand Up @@ -194,8 +194,8 @@ func TestNewClientWithUserCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand Down Expand Up @@ -1057,10 +1057,13 @@ func checkCookieVerifier(tv *tokenVerifier, projectID string) error {
return nil
}

func checkUserManagementClient(client *Client, wantProjectID string) error {
umc := client.userManagementClient
if umc.baseURL != idToolkitV1Endpoint {
return fmt.Errorf("baseURL = %q; want = %q", umc.baseURL, idToolkitV1Endpoint)
func checkBaseClient(client *Client, wantProjectID string) error {
umc := client.baseClient
if umc.userManagementEndpoint != idToolkitV1Endpoint {
return fmt.Errorf("userManagementEndpoint = %q; want = %q", umc.userManagementEndpoint, idToolkitV1Endpoint)
}
if umc.providerConfigEndpoint != providerConfigEndpoint {
return fmt.Errorf("providerConfigEndpoint = %q; want = %q", umc.providerConfigEndpoint, providerConfigEndpoint)
}
if umc.projectID != wantProjectID {
return fmt.Errorf("projectID = %q; want = %q", umc.projectID, wantProjectID)
Expand Down
12 changes: 6 additions & 6 deletions auth/email_action_links.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,38 +71,38 @@ const (

// EmailVerificationLink generates the out-of-band email action link for email verification flows for the specified
// email address.
func (c *userManagementClient) EmailVerificationLink(ctx context.Context, email string) (string, error) {
func (c *baseClient) EmailVerificationLink(ctx context.Context, email string) (string, error) {
return c.EmailVerificationLinkWithSettings(ctx, email, nil)
}

// EmailVerificationLinkWithSettings generates the out-of-band email action link for email verification flows for the
// specified email address, using the action code settings provided.
func (c *userManagementClient) EmailVerificationLinkWithSettings(
func (c *baseClient) EmailVerificationLinkWithSettings(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, emailVerification, email, settings)
}

// PasswordResetLink generates the out-of-band email action link for password reset flows for the specified email
// address.
func (c *userManagementClient) PasswordResetLink(ctx context.Context, email string) (string, error) {
func (c *baseClient) PasswordResetLink(ctx context.Context, email string) (string, error) {
return c.PasswordResetLinkWithSettings(ctx, email, nil)
}

// PasswordResetLinkWithSettings generates the out-of-band email action link for password reset flows for the
// specified email address, using the action code settings provided.
func (c *userManagementClient) PasswordResetLinkWithSettings(
func (c *baseClient) PasswordResetLinkWithSettings(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, passwordReset, email, settings)
}

// EmailSignInLink generates the out-of-band email action link for email link sign-in flows, using the action
// code settings provided.
func (c *userManagementClient) EmailSignInLink(
func (c *baseClient) EmailSignInLink(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, emailLinkSignIn, email, settings)
}

func (c *userManagementClient) generateEmailActionLink(
func (c *baseClient) generateEmailActionLink(
ctx context.Context, linkType linkType, email string, settings *ActionCodeSettings) (string, error) {

if email == "" {
Expand Down
2 changes: 1 addition & 1 deletion auth/email_action_links_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func TestEmailVerificationLinkError(t *testing.T) {
}
s := echoServer(testActionLinkResponse, t)
defer s.Close()
s.Client.userManagementClient.httpClient.RetryConfig = nil
s.Client.baseClient.httpClient.RetryConfig = nil
s.Status = http.StatusInternalServerError

for code, check := range cases {
Expand Down
4 changes: 2 additions & 2 deletions auth/export_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const maxReturnedResults = 1000
//
// If nextPageToken is empty, the iterator will start at the beginning.
// If the nextPageToken is not empty, the iterator starts after the token.
func (c *userManagementClient) Users(ctx context.Context, nextPageToken string) *UserIterator {
func (c *baseClient) Users(ctx context.Context, nextPageToken string) *UserIterator {
it := &UserIterator{
ctx: ctx,
client: c,
Expand All @@ -49,7 +49,7 @@ func (c *userManagementClient) Users(ctx context.Context, nextPageToken string)
//
// Also see: https://github.com/GoogleCloudPlatform/google-cloud-go/wiki/Iterator-Guidelines
type UserIterator struct {
client *userManagementClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down
2 changes: 1 addition & 1 deletion auth/import_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type ErrorInfo struct {
//
// No more than 1000 users can be imported in a single call. If at least one user specifies a
// password, a UserImportHash must be specified as an option.
func (c *userManagementClient) ImportUsers(
func (c *baseClient) ImportUsers(
ctx context.Context, users []*UserToImport, opts ...UserImportOption) (*UserImportResult, error) {

if len(users) == 0 {
Expand Down
60 changes: 17 additions & 43 deletions auth/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (config *OIDCProviderConfigToUpdate) buildRequest() (nestedMap, error) {

// OIDCProviderConfigIterator is an iterator over OIDC provider configurations.
type OIDCProviderConfigIterator struct {
client *providerConfigClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down Expand Up @@ -538,7 +538,7 @@ func (config *SAMLProviderConfigToUpdate) buildRequest() (nestedMap, error) {

// SAMLProviderConfigIterator is an iterator over SAML provider configurations.
type SAMLProviderConfigIterator struct {
client *providerConfigClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down Expand Up @@ -595,36 +595,8 @@ func (it *SAMLProviderConfigIterator) fetch(pageSize int, pageToken string) (str
return result.NextPageToken, nil
}

type providerConfigClient struct {
endpoint string
projectID string
tenantID string
httpClient *internal.HTTPClient
}

func newProviderConfigClient(client *http.Client, conf *internal.AuthConfig) *providerConfigClient {
hc := internal.WithDefaultRetryConfig(client)
hc.CreateErrFn = handleHTTPError
hc.SuccessFn = internal.HasSuccessStatus
hc.Opts = []internal.HTTPOption{
internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)),
}

return &providerConfigClient{
endpoint: providerConfigEndpoint,
projectID: conf.ProjectID,
httpClient: hc,
}
}

func (c *providerConfigClient) withTenantID(tenantID string) *providerConfigClient {
copy := *c
copy.tenantID = tenantID
return &copy
}

// OIDCProviderConfig returns the OIDCProviderConfig with the given ID.
func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) {
func (c *baseClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) {
if err := validateOIDCConfigID(id); err != nil {
return nil, err
}
Expand All @@ -642,7 +614,7 @@ func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string
}

// CreateOIDCProviderConfig creates a new OIDC provider config from the given parameters.
func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) {
func (c *baseClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) {
if config == nil {
return nil, errors.New("config must not be nil")
}
Expand All @@ -669,7 +641,7 @@ func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, con
}

// UpdateOIDCProviderConfig updates an existing OIDC provider config with the given parameters.
func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) {
func (c *baseClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) {
if err := validateOIDCConfigID(id); err != nil {
return nil, err
}
Expand Down Expand Up @@ -704,7 +676,7 @@ func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id
}

// DeleteOIDCProviderConfig deletes the OIDCProviderConfig with the given ID.
func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error {
func (c *baseClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error {
if err := validateOIDCConfigID(id); err != nil {
return err
}
Expand All @@ -721,7 +693,7 @@ func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id
//
// If nextPageToken is empty, the iterator will start at the beginning. Otherwise,
// iterator starts after the token.
func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator {
func (c *baseClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator {
it := &OIDCProviderConfigIterator{
ctx: ctx,
client: c,
Expand All @@ -736,7 +708,7 @@ func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPage
}

// SAMLProviderConfig returns the SAMLProviderConfig with the given ID.
func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) {
func (c *baseClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) {
if err := validateSAMLConfigID(id); err != nil {
return nil, err
}
Expand All @@ -754,7 +726,7 @@ func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string
}

// CreateSAMLProviderConfig creates a new SAML provider config from the given parameters.
func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) {
func (c *baseClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) {
if config == nil {
return nil, errors.New("config must not be nil")
}
Expand All @@ -781,7 +753,7 @@ func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, con
}

// UpdateSAMLProviderConfig updates an existing SAML provider config with the given parameters.
func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) {
func (c *baseClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) {
if err := validateSAMLConfigID(id); err != nil {
return nil, err
}
Expand Down Expand Up @@ -816,7 +788,7 @@ func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id
}

// DeleteSAMLProviderConfig deletes the SAMLProviderConfig with the given ID.
func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error {
func (c *baseClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error {
if err := validateSAMLConfigID(id); err != nil {
return err
}
Expand All @@ -833,7 +805,7 @@ func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id
//
// If nextPageToken is empty, the iterator will start at the beginning. Otherwise,
// iterator starts after the token.
func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator {
func (c *baseClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator {
it := &SAMLProviderConfigIterator{
ctx: ctx,
client: c,
Expand All @@ -847,15 +819,17 @@ func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPage
return it
}

func (c *providerConfigClient) makeRequest(ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) {
func (c *baseClient) makeRequest(
ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) {

if c.projectID == "" {
return nil, errors.New("project id not available")
}

if c.tenantID != "" {
req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.endpoint, c.projectID, c.tenantID, req.URL)
req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.providerConfigEndpoint, c.projectID, c.tenantID, req.URL)
} else {
req.URL = fmt.Sprintf("%s/projects/%s%s", c.endpoint, c.projectID, req.URL)
req.URL = fmt.Sprintf("%s/projects/%s%s", c.providerConfigEndpoint, c.projectID, req.URL)
}

return c.httpClient.DoAndUnmarshal(ctx, req, v)
Expand Down
Loading

0 comments on commit d30698d

Please sign in to comment.