Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(storage): remove protobuf's copy of data on unmarshalling #9526

Merged
merged 13 commits into from
Mar 19, 2024
2 changes: 1 addition & 1 deletion storage/go.mod
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go v0.112.1
cloud.google.com/go/compute/metadata v0.2.3
cloud.google.com/go/iam v1.1.6
github.com/golang/protobuf v1.5.3
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/googleapis/gax-go/v2 v2.12.2
Expand All @@ -26,7 +27,6 @@ require (
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/martian/v3 v3.3.2 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
Expand Down
244 changes: 234 additions & 10 deletions storage/grpc_client.go
Expand Up @@ -27,15 +27,18 @@ import (
"cloud.google.com/go/internal/trace"
gapic "cloud.google.com/go/storage/internal/apiv2"
"cloud.google.com/go/storage/internal/apiv2/storagepb"
"github.com/golang/protobuf/proto"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protowire"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
)

Expand Down Expand Up @@ -902,12 +905,50 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec
return r, nil
}

// bytesCodec is a grpc codec which permits receiving messages as either
// protobuf messages, or as raw []bytes.
//
// It embeds encoding.Codec so that any methods not overridden here fall
// through to the default gRPC codec behavior.
type bytesCodec struct {
encoding.Codec
}

// Marshal encodes v into the protobuf wire format. Only proto.Message
// values are supported; anything else is rejected with an error.
func (bytesCodec) Marshal(v any) ([]byte, error) {
	m, ok := v.(proto.Message)
	if !ok {
		return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
	}
	return proto.Marshal(m)
}

// Unmarshal decodes a received message. When the caller passes a *[]byte,
// the raw wire bytes are handed over without any protobuf decoding; when it
// passes a proto.Message, normal protobuf unmarshalling is performed.
func (bytesCodec) Unmarshal(data []byte, v any) error {
	if raw, ok := v.(*[]byte); ok {
		// If gRPC could recycle the data []byte after unmarshaling (through
		// buffer pools), we would need to make a copy here.
		*raw = data
		return nil
	}
	if m, ok := v.(proto.Message); ok {
		return proto.Unmarshal(data, m)
	}
	return fmt.Errorf("can not unmarshal type %T", v)
}

// Name returns the empty string. A non-empty name would be used by gRPC as
// the content-subtype of the call, which results in errors against the
// service, so this codec deliberately stays anonymous.
func (bytesCodec) Name() string {
	return ""
}

func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) {
ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader")
defer func() { trace.EndSpan(ctx, err) }()

s := callSettings(c.settings, opts...)

s.gax = append(s.gax, gax.WithGRPCOptions(
grpc.ForceCodec(bytesCodec{}),
))

if s.userProject != "" {
ctx = setUserProjectMetadata(ctx, s.userProject)
}
Expand All @@ -923,6 +964,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
req.Generation = params.gen
}

var databuf []byte

// Define a function that initiates a Read with offset and length, assuming
// we have already read seen bytes.
reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) {
Expand Down Expand Up @@ -957,12 +1000,23 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
return err
}

msg, err = stream.Recv()
// Receive the message into databuf as a wire-encoded message so we can
// use a custom decoder to avoid an extra copy at the protobuf layer.
err := stream.RecvMsg(&databuf)
// These types of errors show up on the Recv call, rather than the
// initialization of the stream via ReadObject above.
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return ErrObjectNotExist
}
if err != nil {
return err
}
// Use a custom decoder that uses protobuf unmarshalling for all
// fields except the checksummed data.
// Subsequent receives in Read calls will skip all protobuf
// unmarshalling and directly read the content from the gRPC []byte
// response, since only the first call will contain other fields.
msg, err = readFullObjectResponse(databuf)

return err
}, s.retry, s.idempotent)
Expand Down Expand Up @@ -1008,6 +1062,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
leftovers: msg.GetChecksummedData().GetContent(),
settings: s,
zeroRange: params.length == 0,
databuf: databuf,
},
}

Expand Down Expand Up @@ -1406,6 +1461,7 @@ type gRPCReader struct {
stream storagepb.Storage_ReadObjectClient
reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error)
leftovers []byte
databuf []byte
cancel context.CancelFunc
settings *settings
}
Expand Down Expand Up @@ -1436,7 +1492,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
}

// Attempt to Recv the next message on the stream.
msg, err := r.recv()
content, err := r.recv()
if err != nil {
return 0, err
}
Expand All @@ -1448,7 +1504,6 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
// present in the response here.
// TODO: Figure out if we need to support decompressive transcoding
// https://cloud.google.com/storage/docs/transcoding.
content := msg.GetChecksummedData().GetContent()
n = copy(p[n:], content)
leftover := len(content) - n
if leftover > 0 {
Expand All @@ -1471,18 +1526,20 @@ func (r *gRPCReader) Close() error {
return nil
}

// recv attempts to Recv the next message on the stream. In the event
// that a retryable error is encountered, the stream will be closed, reopened,
// and Recv again. This will attempt to Recv until one of the following is true:
// recv attempts to Recv the next message on the stream and extract the object
// data that it contains. In the event that a retryable error is encountered,
// the stream will be closed, reopened, and RecvMsg again.
// This will attempt to Recv until one of the following is true:
//
// * Recv is successful
// * A non-retryable error is encountered
// * The Reader's context is canceled
//
// The last error received is the one that is returned, which could be from
// an attempt to reopen the stream.
func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
msg, err := r.stream.Recv()
func (r *gRPCReader) recv() ([]byte, error) {
err := r.stream.RecvMsg(&r.databuf)

var shouldRetry = ShouldRetry
if r.settings.retry != nil && r.settings.retry.shouldRetry != nil {
shouldRetry = r.settings.retry.shouldRetry
Expand All @@ -1492,10 +1549,177 @@ func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
// reopen the stream, but will backoff if further attempts are necessary.
// Reopening the stream Recvs the first message, so if retrying is
// successful, the next logical chunk will be returned.
msg, err = r.reopenStream()
msg, err := r.reopenStream()
tritone marked this conversation as resolved.
Show resolved Hide resolved
return msg.GetChecksummedData().GetContent(), err
}

if err != nil {
return nil, err
}

return readObjectResponseContent(r.databuf)
}

// ReadObjectResponse field and subfield numbers.
//
// These mirror the proto field numbers of storagepb.ReadObjectResponse and
// its ChecksummedData submessage; they are used by the hand-rolled wire
// decoders below (readProtoBytes/readProtoFixed32) to locate fields without
// a full proto.Unmarshal. Note that field number 1 is reused at two levels:
// checksummed_data within the response, and content within checksummed_data.
const (
checksummedDataField = protowire.Number(1) // ReadObjectResponse.checksummed_data
checksummedDataContentField = protowire.Number(1) // ChecksummedData.content
checksummedDataCRC32CField = protowire.Number(2) // ChecksummedData.crc32c
objectChecksumsField = protowire.Number(2) // ReadObjectResponse.object_checksums
contentRangeField = protowire.Number(3) // ReadObjectResponse.content_range
metadataField = protowire.Number(4) // ReadObjectResponse.metadata
)

// readObjectResponseContent returns the checksummed_data.content field of a
// ReadObjectResponse message, or an error if the message is invalid.
// This can be used on recvs of objects after the first recv, since only the
// first message will contain non-data fields.
func readObjectResponseContent(b []byte) ([]byte, error) {
	cdBytes, err := readProtoBytes(b, checksummedDataField)
	if err != nil {
		return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
	}
	data, err := readProtoBytes(cdBytes, checksummedDataContentField)
	if err != nil {
		return data, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
	}
	return data, nil
}

// readFullObjectResponse returns the ReadObjectResponse encoded in the
// wire-encoded message buffer b, or an error if the message is invalid.
// This is used on the first recv of an object as it may contain all fields of
// ReadObjectResponse.
tritone marked this conversation as resolved.
Show resolved Hide resolved
func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to note that this function is essentially identical to proto.Unmarshal, except it aliases the data in the input []byte. If we ever add a feature to Unmarshal that does that, this function can be dropped.

var checksummedData *storagepb.ChecksummedData

// Extract object content.
fieldContent, err := readProtoBytes(b, checksummedDataField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
}
// Only fill the contents if the checksummedData field was found.
if fieldContent != nil {
content, err := readProtoBytes(fieldContent, checksummedDataContentField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
}
crc32c, err := readProtoFixed32(fieldContent, checksummedDataCRC32CField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", err)
}

checksummedData = &storagepb.ChecksummedData{
Content: content,
Crc32C: crc32c,
}
}

// Unmarshal remaining fields.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're unmarshaling all the fields in the message, so this will be a bit more efficient if you loop once over the fields rather than looking up each individually:

off := 0
for off < len(b) {
  num, typ, n := protowire.ConsumeTag(b[off:])
  if n < 0 {
    return nil, protowire.ParseError(n)
  }
  off += n
  switch {
  case num == checksummedDataField && typ == protowire.BytesType:
    // unmarshal the checksummed_data field
  case num == objectChecksumsField && typ == protowire.BytesType:
    // unmarshal the object_checksums field
  case num == contentRangeField && typ == protowire.BytesType:
    // unmarshal the content_range field
  case num == metadataField && typ == protowire.BytesType:
    // unmarshal the metadata field
  default:
    n = protowire.ConsumeFieldValue(num, typ, b[off:])
    if n < 0 {
      return nil, protowire.ParseError(n)
    }
    off += n
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

var checksums *storagepb.ObjectChecksums
fieldContent, err = readProtoBytes(b, objectChecksumsField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", err)
}
// Only unmarshal the contents if the field was found.
if fieldContent != nil {
checksums = &storagepb.ObjectChecksums{}
if err := proto.Unmarshal(fieldContent, checksums); err != nil {
return nil, err
}
}

return msg, err
var contentRange *storagepb.ContentRange
fieldContent, err = readProtoBytes(b, contentRangeField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", "err")
}
if fieldContent != nil {
contentRange = &storagepb.ContentRange{}
if err := proto.Unmarshal(fieldContent, contentRange); err != nil {
return nil, err
}
}

var metadata *storagepb.Object
fieldContent, err = readProtoBytes(b, metadataField)
if err != nil {
return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", err)
}
if fieldContent != nil {
metadata = &storagepb.Object{}
if err := proto.Unmarshal(fieldContent, metadata); err != nil {
return nil, err
}
}

msg := &storagepb.ReadObjectResponse{
ChecksummedData: checksummedData,
ObjectChecksums: checksums,
ContentRange: contentRange,
Metadata: metadata,
}

return msg, nil
}

// readProtoBytes returns the contents of the protobuf field with number num
// and type bytes from a wire-encoded message. If the field cannot be found,
// the returned slice will be nil and no error will be returned.
//
// It does not handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags. Encoded data containing split fields
// of this form is technically permissable, but uncommon.
func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) {
	for pos := 0; pos < len(b); {
		fieldNum, fieldTyp, tagLen := protowire.ConsumeTag(b[pos:])
		if tagLen < 0 {
			return nil, protowire.ParseError(tagLen)
		}
		pos += tagLen

		// Found the requested field: return its payload, aliasing b.
		if fieldNum == num && fieldTyp == protowire.BytesType {
			payload, payloadLen := protowire.ConsumeBytes(b[pos:])
			if payloadLen < 0 {
				return nil, protowire.ParseError(payloadLen)
			}
			return payload, nil
		}

		// Not the field we want; skip over its value entirely.
		skip := protowire.ConsumeFieldValue(fieldNum, fieldTyp, b[pos:])
		if skip < 0 {
			return nil, protowire.ParseError(skip)
		}
		pos += skip
	}
	return nil, nil
}

// readProtoFixed32 returns the contents of the protobuf field with number num
// and type uint32 from a wire-encoded message. If the field cannot be found,
// the returned pointer will be nil and no error will be returned.
func readProtoFixed32(b []byte, num protowire.Number) (*uint32, error) {
	for pos := 0; pos < len(b); {
		fieldNum, fieldTyp, tagLen := protowire.ConsumeTag(b[pos:])
		if tagLen < 0 {
			return nil, protowire.ParseError(tagLen)
		}
		pos += tagLen

		// Found the requested fixed32 field: decode and return its value.
		if fieldNum == num && fieldTyp == protowire.Fixed32Type {
			val, valLen := protowire.ConsumeFixed32(b[pos:])
			if valLen < 0 {
				return nil, protowire.ParseError(valLen)
			}
			return &val, nil
		}

		// Not the field we want; skip over its value entirely.
		skip := protowire.ConsumeFieldValue(fieldNum, fieldTyp, b[pos:])
		if skip < 0 {
			return nil, protowire.ParseError(skip)
		}
		pos += skip
	}
	return nil, nil
}

// reopenStream "closes" the existing stream and attempts to reopen a stream and
Expand Down