forked from databricks/databricks-sdk-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
auth_m2m.go
84 lines (75 loc) · 2.48 KB
/
auth_m2m.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"github.com/xuxiaoshuo/databricks-sdk-go/logger"
)
// errOAuthNotSupported is the sentinel error reported by oidcEndpoints when
// the OIDC discovery endpoint of the host answers with a non-200 status.
var (
	errOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
)
// M2mCredentials is the credentials strategy registered under the name
// "oauth-m2m". It has no fields; everything it needs is read from the Config
// passed to Configure.
type M2mCredentials struct{}
// Name returns the fixed identifier of this credentials strategy.
func (c M2mCredentials) Name() string {
	const strategyName = "oauth-m2m"
	return strategyName
}
// Configure prepares OAuth client-credentials authentication from cfg.
//
// When cfg.ClientID or cfg.ClientSecret is empty it returns (nil, nil),
// leaving this strategy unconfigured. Otherwise it resolves the token
// endpoint through oidcEndpoints, builds a client-credentials token source
// requesting the "all-apis" scope, and returns that token source wrapped by
// refreshableVisitor.
//
// A failure to resolve the endpoints is returned wrapped with the "oidc: "
// prefix.
func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
	hasClientCredentials := cfg.ClientID != "" && cfg.ClientSecret != ""
	if !hasClientCredentials {
		return nil, nil
	}

	servers, err := oidcEndpoints(ctx, cfg)
	if err != nil {
		return nil, fmt.Errorf("oidc: %w", err)
	}

	logger.Debugf(ctx, "Generating Databricks OAuth token for Service Principal (%s)", cfg.ClientID)

	// Client ID and secret are sent in the Authorization header rather than
	// in the request body (oauth2.AuthStyleInHeader).
	m2m := &clientcredentials.Config{
		ClientID:     cfg.ClientID,
		ClientSecret: cfg.ClientSecret,
		AuthStyle:    oauth2.AuthStyleInHeader,
		TokenURL:     servers.TokenEndpoint,
		Scopes:       []string{"all-apis"},
	}
	tokens := m2m.TokenSource(ctx)
	return refreshableVisitor(tokens), nil
}
// oidcEndpoints resolves the OAuth authorization and token endpoints for cfg.
//
// When cfg.IsAccountClient() reports true and cfg.AccountID is non-empty, the
// endpoints are derived directly from cfg.Host and cfg.AccountID and no
// network request is made. Otherwise the function fetches
// {cfg.Host}/oidc/.well-known/oauth-authorization-server and decodes the JSON
// document found there.
//
// The discovery request is bound to ctx, so cancelling ctx or exceeding its
// deadline aborts the request.
//
// It returns errOAuthNotSupported when the discovery endpoint answers with
// any status other than 200 OK. Request, read and JSON decoding failures are
// returned wrapped with a prefix naming the step that failed.
func oidcEndpoints(ctx context.Context, cfg *Config) (*oauthAuthorizationServer, error) {
	if cfg.IsAccountClient() && cfg.AccountID != "" {
		// TODO: technically, we could use the same config profile for both workspace
		// and account, but we have to add logic for determining accounts host from
		// workspace host.
		accountPrefix := fmt.Sprintf("%s/oidc/accounts/%s", cfg.Host, cfg.AccountID)
		return &oauthAuthorizationServer{
			AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", accountPrefix),
			TokenEndpoint:         fmt.Sprintf("%s/v1/token", accountPrefix),
		}, nil
	}

	discoveryURL := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", cfg.Host)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return nil, fmt.Errorf("fetch .well-known: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetch .well-known: %w", err)
	}
	// Close on every path, including the non-200 one below; http.Client
	// guarantees a non-nil Body whenever the returned error is nil.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errOAuthNotSupported
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read .well-known: %w", err)
	}
	var endpoints oauthAuthorizationServer
	if err := json.Unmarshal(raw, &endpoints); err != nil {
		return nil, fmt.Errorf("parse .well-known: %w", err)
	}
	return &endpoints, nil
}
// oauthAuthorizationServer holds the two OAuth endpoint URLs used by this
// package. It is either decoded from the JSON document served at
// /oidc/.well-known/oauth-authorization-server or filled in directly for
// account-level clients.
type oauthAuthorizationServer struct {
	// AuthorizationEndpoint is read from the "authorization_endpoint" key,
	// e.g. ../v1/authorize.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is read from the "token_endpoint" key, e.g. ../v1/token.
	TokenEndpoint string `json:"token_endpoint"`
}