diff --git a/internal/server/interop.go b/internal/server/interop.go index 6a555fa..9bb8f84 100644 --- a/internal/server/interop.go +++ b/internal/server/interop.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/http" - "time" "github.com/aws/aws-sdk-go/aws" "github.com/localstack/lambda-runtime-init/internal/localstack" @@ -32,35 +31,6 @@ func NewInteropServer(ls *localstack.LocalStackClient) *LocalStackInteropsServer } } -func (c *LocalStackInteropsServer) Init(initRequest *interop.Init, timeoutMs int64) error { - // This allows us to properly timeout when an INIT request -- which is unimplemented in the upstream. - - initStart := metering.Monotime() - - initDone := make(chan error, 1) - go func() { - initDone <- c.Server.Init(initRequest, timeoutMs) - }() - - var err error - select { - case err = <-initDone: - case <-time.After(time.Duration(timeoutMs) * time.Millisecond): - if _, resetErr := c.Server.Reset("timeout", 2000); resetErr != nil { - log.WithError(resetErr).Error("Failed to reset after init timeout") - } - err = errors.New("timeout") - } - - initDuration := float64(metering.Monotime()-initStart) / float64(time.Millisecond) - - if err != nil { - log.WithError(err).WithField("duration", initDuration).Error("Init failed") - } - - return err -} - func (c *LocalStackInteropsServer) Execute(ctx context.Context, responseWriter http.ResponseWriter, invoke *interop.Invoke) error { ctx, cancel := context.WithTimeout(context.Background(), c.Server.GetInvokeTimeout()) defer cancel() diff --git a/internal/server/service.go b/internal/server/service.go index 5b84ec2..a91f0f8 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "math" "os" @@ -24,6 +25,8 @@ import ( "go.amzn.com/lambda/rapidcore/env" ) +var errTimeout = errors.New("timeout") + type LocalStackService struct { sandbox *LocalStackInteropsServer adapter *localstack.LocalStackClient @@ -129,7 +132,22 @@ func (ls *LocalStackService) Initialize(bs interop.Bootstrap) error { } initStart := metering.Monotime() - err = ls.sandbox.Init(initRequest, initRequest.InitTimeoutMs) + + initDone := make(chan error, 1) + go func() { + initDone <- ls.sandbox.Init(initRequest, initRequest.InvokeTimeoutMs) + }() + + select { + case err = <-initDone: + case <-time.After(initTimeout): + _, resetFailure := ls.sandbox.Reset("timeout", 2000) + if resetFailure != nil { + log.WithError(resetFailure).Error("Failed to reset after init timeout") + } + err = errTimeout + } + ls.initDuration = float64(metering.Monotime()-initStart) / float64(time.Millisecond) if err != nil {