diff --git a/storage/go.mod b/storage/go.mod index 6066b8db33b..63bc0716c6d 100644 --- a/storage/go.mod +++ b/storage/go.mod @@ -8,6 +8,7 @@ require ( cloud.google.com/go v0.112.1 cloud.google.com/go/compute/metadata v0.2.3 cloud.google.com/go/iam v1.1.6 + github.com/golang/protobuf v1.5.3 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/googleapis/gax-go/v2 v2.12.2 @@ -26,7 +27,6 @@ require ( github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.3 // indirect github.com/google/martian/v3 v3.3.2 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect diff --git a/storage/grpc_client.go b/storage/grpc_client.go index bdbf3acfea2..b661d6c8ac2 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -27,6 +27,7 @@ import ( "cloud.google.com/go/internal/trace" gapic "cloud.google.com/go/storage/internal/apiv2" "cloud.google.com/go/storage/internal/apiv2/storagepb" + "github.com/golang/protobuf/proto" "github.com/googleapis/gax-go/v2" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" @@ -34,8 +35,10 @@ import ( "google.golang.org/api/option/internaloption" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protowire" fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" ) @@ -902,12 +905,50 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec return r, nil } +// bytesCodec is a grpc codec which permits receiving messages as either +// protobuf messages, or as raw []bytes. +type bytesCodec struct { + encoding.Codec +} + +func (bytesCodec) Marshal(v any) ([]byte, error) { + vv, ok := v.(proto.Message) + if !ok { + return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + } + return proto.Marshal(vv) +} + +func (bytesCodec) Unmarshal(data []byte, v any) error { + switch v := v.(type) { + case *[]byte: + // If gRPC could recycle the data []byte after unmarshaling (through + // buffer pools), we would need to make a copy here. + *v = data + return nil + case proto.Message: + return proto.Unmarshal(data, v) + default: + return fmt.Errorf("can not unmarshal type %T", v) + } +} + +func (bytesCodec) Name() string { + // If this isn't "", then gRPC sets the content-subtype of the call to this + // value and we get errors. + return "" +} + func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) { ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader") defer func() { trace.EndSpan(ctx, err) }() s := callSettings(c.settings, opts...) + s.gax = append(s.gax, gax.WithGRPCOptions( + grpc.ForceCodec(bytesCodec{}), + )) + if s.userProject != "" { ctx = setUserProjectMetadata(ctx, s.userProject) } @@ -923,6 +964,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange req.Generation = params.gen } + var databuf []byte + // Define a function that initiates a Read with offset and length, assuming // we have already read seen bytes. reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) { @@ -957,12 +1000,23 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange return err } - msg, err = stream.Recv() + // Receive the message into databuf as a wire-encoded message so we can + // use a custom decoder to avoid an extra copy at the protobuf layer. + err := stream.RecvMsg(&databuf) // These types of errors show up on the Recv call, rather than the // initialization of the stream via ReadObject above. if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { return ErrObjectNotExist } + if err != nil { + return err + } + // Use a custom decoder that uses protobuf unmarshalling for all + // fields except the checksummed data. + // Subsequent receives in Read calls will skip all protobuf + // unmarshalling and directly read the content from the gRPC []byte + // response, since only the first call will contain other fields. + msg, err = readFullObjectResponse(databuf) return err }, s.retry, s.idempotent) @@ -1008,6 +1062,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange leftovers: msg.GetChecksummedData().GetContent(), settings: s, zeroRange: params.length == 0, + databuf: databuf, }, } @@ -1406,6 +1461,7 @@ type gRPCReader struct { stream storagepb.Storage_ReadObjectClient reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error) leftovers []byte + databuf []byte cancel context.CancelFunc settings *settings } @@ -1436,7 +1492,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) { } // Attempt to Recv the next message on the stream. - msg, err := r.recv() + content, err := r.recv() if err != nil { return 0, err } @@ -1448,7 +1504,6 @@ func (r *gRPCReader) Read(p []byte) (int, error) { // present in the response here. // TODO: Figure out if we need to support decompressive transcoding // https://cloud.google.com/storage/docs/transcoding. - content := msg.GetChecksummedData().GetContent() n = copy(p[n:], content) leftover := len(content) - n if leftover > 0 { @@ -1471,9 +1526,10 @@ func (r *gRPCReader) Close() error { return nil } -// recv attempts to Recv the next message on the stream. In the event -// that a retryable error is encountered, the stream will be closed, reopened, -// and Recv again. This will attempt to Recv until one of the following is true: +// recv attempts to Recv the next message on the stream and extract the object +// data that it contains. In the event that a retryable error is encountered, +// the stream will be closed, reopened, and RecvMsg again. +// This will attempt to Recv until one of the following is true: // // * Recv is successful // * A non-retryable error is encountered @@ -1481,8 +1537,9 @@ func (r *gRPCReader) Close() error { // // The last error received is the one that is returned, which could be from // an attempt to reopen the stream. -func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) { - msg, err := r.stream.Recv() +func (r *gRPCReader) recv() ([]byte, error) { + err := r.stream.RecvMsg(&r.databuf) + var shouldRetry = ShouldRetry if r.settings.retry != nil && r.settings.retry.shouldRetry != nil { shouldRetry = r.settings.retry.shouldRetry @@ -1492,10 +1549,195 @@ func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) { // reopen the stream, but will backoff if further attempts are necessary. // Reopening the stream Recvs the first message, so if retrying is // successful, the next logical chunk will be returned. - msg, err = r.reopenStream() + msg, err := r.reopenStream() + return msg.GetChecksummedData().GetContent(), err + } + + if err != nil { + return nil, err + } + + return readObjectResponseContent(r.databuf) +} + +// ReadObjectResponse field and subfield numbers. +const ( + checksummedDataField = protowire.Number(1) + checksummedDataContentField = protowire.Number(1) + checksummedDataCRC32CField = protowire.Number(2) + objectChecksumsField = protowire.Number(2) + contentRangeField = protowire.Number(3) + metadataField = protowire.Number(4) +) + +// readObjectResponseContent returns the checksummed_data.content field of a +// ReadObjectResponse message, or an error if the message is invalid. +// This can be used on recvs of objects after the first recv, since only the +// first message will contain non-data fields. +func readObjectResponseContent(b []byte) ([]byte, error) { + checksummedData, err := readProtoBytes(b, checksummedDataField) + if err != nil { + return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err) + } + content, err := readProtoBytes(checksummedData, checksummedDataContentField) + if err != nil { + return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err) } - return msg, err + return content, nil +} + +// readFullObjectResponse returns the ReadObjectResponse that is encoded in the +// wire-encoded message buffer b, or an error if the message is invalid. +// This must be used on the first recv of an object as it may contain all fields +// of ReadObjectResponse, and we use or pass on those fields to the user. +// This function is essentially identical to proto.Unmarshal, except it aliases +// the data in the input []byte. If the proto library adds a feature to +// Unmarshal that does that, this function can be dropped. +func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) { + msg := &storagepb.ReadObjectResponse{} + + // Loop over the entire message, extracting fields as we go. This does not + // handle field concatenation, in which the contents of a single field + // are split across multiple protobuf tags. + off := 0 + for off < len(b) { + // Consume the next tag. This will tell us which field is next in the + // buffer, its type, and how much space it takes up. + fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:]) + if fieldLength < 0 { + return nil, protowire.ParseError(fieldLength) + } + off += fieldLength + + // Unmarshal the field according to its type. Only fields that are not + // nil will be present. + switch { + case fieldNum == checksummedDataField && fieldType == protowire.BytesType: + // The ChecksummedData field was found. Initialize the struct. + msg.ChecksummedData = &storagepb.ChecksummedData{} + + // Get the bytes corresponding to the checksummed data. + fieldContent, n := protowire.ConsumeBytes(b[off:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n)) + } + off += n + + // Get the nested fields. We need to do this manually as it contains + // the object content bytes. + contentOff := 0 + for contentOff < len(fieldContent) { + gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:]) + if n < 0 { + return nil, protowire.ParseError(n) + } + contentOff += n + + switch { + case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType: + // Get the content bytes. + bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n)) + } + msg.ChecksummedData.Content = bytes + contentOff += n + case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type: + v, n := protowire.ConsumeFixed32(fieldContent[contentOff:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n)) + } + msg.ChecksummedData.Crc32C = &v + contentOff += n + default: + n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:]) + if n < 0 { + return nil, protowire.ParseError(n) + } + contentOff += n + } + } + case fieldNum == objectChecksumsField && fieldType == protowire.BytesType: + // The field was found. Initialize the struct. + msg.ObjectChecksums = &storagepb.ObjectChecksums{} + + // Get the bytes corresponding to the checksums. + bytes, n := protowire.ConsumeBytes(b[off:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n)) + } + off += n + + // Unmarshal. + if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil { + return nil, err + } + case fieldNum == contentRangeField && fieldType == protowire.BytesType: + msg.ContentRange = &storagepb.ContentRange{} + + bytes, n := protowire.ConsumeBytes(b[off:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n)) + } + off += n + + if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil { + return nil, err + } + case fieldNum == metadataField && fieldType == protowire.BytesType: + msg.Metadata = &storagepb.Object{} + + bytes, n := protowire.ConsumeBytes(b[off:]) + if n < 0 { + return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n)) + } + off += n + + if err := proto.Unmarshal(bytes, msg.Metadata); err != nil { + return nil, err + } + default: + fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:]) + if fieldLength < 0 { + return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength)) + } + off += fieldLength + } + } + + return msg, nil +} + +// readProtoBytes returns the contents of the protobuf field with number num +// and type bytes from a wire-encoded message. If the field cannot be found, +// the returned slice will be nil and no error will be returned. +// +// It does not handle field concatenation, in which the contents of a single field +// are split across multiple protobuf tags. Encoded data containing split fields +// of this form is technically permissable, but uncommon. +func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) { + off := 0 + for off < len(b) { + gotNum, gotTyp, n := protowire.ConsumeTag(b[off:]) + if n < 0 { + return nil, protowire.ParseError(n) + } + off += n + if gotNum == num && gotTyp == protowire.BytesType { + b, n := protowire.ConsumeBytes(b[off:]) + if n < 0 { + return nil, protowire.ParseError(n) + } + return b, nil + } + n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:]) + if n < 0 { + return nil, protowire.ParseError(n) + } + off += n + } + return nil, nil } // reopenStream "closes" the existing stream and attempts to reopen a stream and diff --git a/storage/grpc_client_test.go b/storage/grpc_client_test.go new file mode 100644 index 00000000000..5c1eb0f1283 --- /dev/null +++ b/storage/grpc_client_test.go @@ -0,0 +1,147 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "crypto/md5" + "hash/crc32" + "math/rand" + "testing" + "time" + + "cloud.google.com/go/storage/internal/apiv2/storagepb" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestBytesCodec(t *testing.T) { + // Generate some random content. + content := make([]byte, 1<<10+1) // 1 kib + 1 byte + rand.New(rand.NewSource(0)).Read(content) + + // Calculate full content hashes. + crc32c := crc32.Checksum(content, crc32.MakeTable(crc32.Castagnoli)) + hasher := md5.New() + if _, err := hasher.Write(content); err != nil { + t.Errorf("hasher.Write: %v", err) + } + md5 := hasher.Sum(nil) + + trueBool := true + metadata := &storagepb.Object{ + Name: "object-name", + Bucket: "bucket-name", + Etag: "etag", + Generation: 100, + Metageneration: 907, + StorageClass: "Standard", + Size: 1025, + ContentEncoding: "none", + ContentDisposition: "inline", + CacheControl: "public, max-age=3600", + Acl: []*storagepb.ObjectAccessControl{{ + Role: "role", + Id: "id", + Entity: "allUsers", + Etag: "tag", + Email: "email@foo.com", + }}, + ContentLanguage: "mi, en", + DeleteTime: toProtoTimestamp(time.Now()), + ContentType: "application/octet-stream", + CreateTime: toProtoTimestamp(time.Now()), + ComponentCount: 1, + Checksums: &storagepb.ObjectChecksums{ + Crc32C: &crc32c, + Md5Hash: md5, + }, + TemporaryHold: true, + Metadata: map[string]string{ + "a-key": "a-value", + }, + EventBasedHold: &trueBool, + Owner: &storagepb.Owner{ + Entity: "user-1", + EntityId: "1", + }, + CustomerEncryption: &storagepb.CustomerEncryption{ + EncryptionAlgorithm: "alg", + KeySha256Bytes: []byte("bytes"), + }, + HardDeleteTime: toProtoTimestamp(time.Now()), + } + + for _, test := range []struct { + desc string + resp *storagepb.ReadObjectResponse + }{ + { + desc: "filled object response", + resp: &storagepb.ReadObjectResponse{ + ChecksummedData: &storagepb.ChecksummedData{ + Content: content, + Crc32C: &crc32c, + }, + ObjectChecksums: &storagepb.ObjectChecksums{ + Crc32C: &crc32c, + Md5Hash: md5, + }, + ContentRange: &storagepb.ContentRange{ + Start: 0, + End: 1025, + CompleteLength: 1025, + }, + Metadata: metadata, + }, + }, + { + desc: "empty object response", + resp: &storagepb.ReadObjectResponse{}, + }, + { + desc: "partially empty", + resp: &storagepb.ReadObjectResponse{ + ChecksummedData: &storagepb.ChecksummedData{}, + ObjectChecksums: &storagepb.ObjectChecksums{Md5Hash: md5}, + Metadata: &storagepb.Object{}, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + // Encode the response. + encodedResp, err := proto.Marshal(test.resp) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + + // Unmarshal and decode response using custom decoding. + encodedBytes := &[]byte{} + if err := bytesCodec.Unmarshal(bytesCodec{}, encodedResp, encodedBytes); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + got, err := readFullObjectResponse(*encodedBytes) + if err != nil { + t.Fatalf("readFullObjectResponse: %v", err) + } + + // Compare the result with the original ReadObjectResponse. + if diff := cmp.Diff(got, test.resp, protocmp.Transform()); diff != "" { + t.Errorf("cmp.Diff got(-),want(+):\n%s", diff) + } + }) + } +} diff --git a/storage/integration_test.go b/storage/integration_test.go index 18f2b5e89bd..0c6aedaf16e 100644 --- a/storage/integration_test.go +++ b/storage/integration_test.go @@ -1023,7 +1023,8 @@ func TestIntegration_ObjectReadChunksGRPC(t *testing.T) { multiTransportTest(skipHTTP("gRPC implementation specific test"), t, func(t *testing.T, ctx context.Context, bucket string, _ string, client *Client) { h := testHelper{t} // Use a larger blob to test chunking logic. This is a little over 5MB. - content := bytes.Repeat([]byte("a"), 5<<20) + content := make([]byte, 5<<20) + rand.New(rand.NewSource(0)).Read(content) // Upload test data. obj := client.Bucket(bucket).Object(uidSpaceObjects.New()) @@ -1066,6 +1067,9 @@ func TestIntegration_ObjectReadChunksGRPC(t *testing.T) { if rem := r.Remain(); rem != 0 { t.Errorf("got %v bytes remaining, want 0", rem) } + if !bytes.Equal(buf, content) { + t.Errorf("content mismatch") + } }) }