diff --git a/storage/grpc_client.go b/storage/grpc_client.go index c8c019da513..e337213f03f 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "errors" "fmt" + "hash/crc32" "io" "net/url" "os" @@ -1042,6 +1043,16 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange // This is the size of the entire object, even if only a range was requested. size := obj.GetSize() + // Only support checksums when reading an entire object, not a range. + var ( + wantCRC uint32 + checkCRC bool + ) + if checksums := msg.GetObjectChecksums(); checksums != nil && checksums.Crc32C != nil && params.offset == 0 && params.length < 0 { + wantCRC = checksums.GetCrc32C() + checkCRC = true + } + r = &Reader{ Attrs: ReaderObjectAttrs{ Size: size, @@ -1063,7 +1074,10 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange settings: s, zeroRange: params.length == 0, databuf: databuf, + wantCRC: wantCRC, + checkCRC: checkCRC, }, + checkCRC: checkCRC, } cr := msg.GetContentRange() @@ -1081,12 +1095,6 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange r.reader.Close() } - // Only support checksums when reading an entire object, not a range. - if checksums := msg.GetObjectChecksums(); checksums != nil && checksums.Crc32C != nil && params.offset == 0 && params.length < 0 { - r.wantCRC = checksums.GetCrc32C() - r.checkCRC = true - } - return r, nil } @@ -1464,12 +1472,34 @@ type gRPCReader struct { databuf []byte cancel context.CancelFunc settings *settings + checkCRC bool // should we check the CRC? + wantCRC uint32 // the CRC32c value the server sent in the header + gotCRC uint32 // running crc +} + +// Update the running CRC with the data in the slice, if CRC checking was enabled. +func (r *gRPCReader) updateCRC(b []byte) { + if r.checkCRC { + r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, b) + } +} + +// Checks whether the CRC matches at the conclusion of a read, if CRC checking was enabled. +func (r *gRPCReader) runCRCCheck() error { + if r.checkCRC && r.gotCRC != r.wantCRC { + return fmt.Errorf("storage: bad CRC on read: got %d, want %d", r.gotCRC, r.wantCRC) + } + return nil } // Read reads bytes into the user's buffer from an open gRPC stream. func (r *gRPCReader) Read(p []byte) (int, error) { - // The entire object has been read by this reader, return EOF. + // The entire object has been read by this reader, check the checksum if + // necessary and return EOF. if r.size == r.seen || r.zeroRange { + if err := r.runCRCCheck(); err != nil { + return 0, err + } return 0, io.EOF } @@ -1478,7 +1508,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) { // using the same reader. One encounters an error and the stream is closed // and then reopened while the other routine attempts to read from it. if r.stream == nil { - return 0, fmt.Errorf("reader has been closed") + return 0, fmt.Errorf("storage: reader has been closed") } var n int @@ -1487,6 +1517,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) { if len(r.leftovers) > 0 { n = copy(p, r.leftovers) r.seen += int64(n) + r.updateCRC(p[:n]) r.leftovers = r.leftovers[n:] return n, nil } @@ -1512,10 +1543,78 @@ func (r *gRPCReader) Read(p []byte) (int, error) { r.leftovers = content[n:] } r.seen += int64(n) + r.updateCRC(p[:n]) return n, nil } +// WriteTo writes all the data requested by the Reader into w, implementing +// io.WriterTo. +func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { + // The entire object has been read by this reader, check the checksum if + // necessary and return nil. + if r.size == r.seen || r.zeroRange { + if err := r.runCRCCheck(); err != nil { + return 0, err + } + return 0, nil + } + + // No stream to read from, either never initialized or Close was called. + // Note: There is a potential concurrency issue if multiple routines are + // using the same reader. One encounters an error and the stream is closed + // and then reopened while the other routine attempts to read from it. + if r.stream == nil { + return 0, fmt.Errorf("storage: reader has been closed") + } + + // Track bytes written during before call. + var alreadySeen = r.seen + + // Write any leftovers to the stream. There will be some leftovers from the + // original NewRangeReader call. + if len(r.leftovers) > 0 { + // Write() will write the entire leftovers slice unless there is an error. + written, err := w.Write(r.leftovers) + r.seen += int64(written) + r.updateCRC(r.leftovers) + r.leftovers = nil + if err != nil { + return r.seen - alreadySeen, err + } + } + + // Loop and receive additional messages until the entire data is written. + for { + // Attempt to receive the next message on the stream. + // Will terminate with io.EOF once data has all come through. + // recv() handles stream reopening and retry logic so no need for retries here. + msg, err := r.recv() + if err != nil { + if err == io.EOF { + // We are done; check the checksum if necessary and return. + err = r.runCRCCheck() + } + return r.seen - alreadySeen, err + } + + // TODO: Determine if we need to capture incremental CRC32C for this + // chunk. The Object CRC32C checksum is captured when directed to read + // the entire Object. If directed to read a range, we may need to + // calculate the range's checksum for verification if the checksum is + // present in the response here. + // TODO: Figure out if we need to support decompressive transcoding + // https://cloud.google.com/storage/docs/transcoding. + written, err := w.Write(msg) + r.seen += int64(written) + r.updateCRC(msg) + if err != nil { + return r.seen - alreadySeen, err + } + } + +} + // Close cancels the read stream's context in order for it to be closed and // collected. func (r *gRPCReader) Close() error { diff --git a/storage/http_client.go b/storage/http_client.go index e3e0d761bb0..f75d93897d9 100644 --- a/storage/http_client.go +++ b/storage/http_client.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "errors" "fmt" + "hash/crc32" "io" "io/ioutil" "net/http" @@ -1218,9 +1219,12 @@ func (c *httpStorageClient) DeleteNotification(ctx context.Context, bucket strin } type httpReader struct { - body io.ReadCloser - seen int64 - reopen func(seen int64) (*http.Response, error) + body io.ReadCloser + seen int64 + reopen func(seen int64) (*http.Response, error) + checkCRC bool // should we check the CRC? + wantCRC uint32 // the CRC32c value the server sent in the header + gotCRC uint32 // running crc } func (r *httpReader) Read(p []byte) (int, error) { @@ -1229,7 +1233,22 @@ func (r *httpReader) Read(p []byte) (int, error) { m, err := r.body.Read(p[n:]) n += m r.seen += int64(m) - if err == nil || err == io.EOF { + if r.checkCRC { + r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, p[:n]) + } + if err == nil { + return n, nil + } + if err == io.EOF { + // Check CRC here. It would be natural to check it in Close, but + // everybody defers Close on the assumption that it doesn't return + // anything worth looking at. + if r.checkCRC { + if r.gotCRC != r.wantCRC { + return n, fmt.Errorf("storage: bad CRC on read: got %d, want %d", + r.gotCRC, r.wantCRC) + } + } return n, err } // Read failed (likely due to connection issues), but we will try to reopen @@ -1435,11 +1454,12 @@ func parseReadResponse(res *http.Response, params *newRangeReaderParams, reopen Attrs: attrs, size: size, remain: remain, - wantCRC: crc, checkCRC: checkCRC, reader: &httpReader{ - reopen: reopen, - body: body, + reopen: reopen, + body: body, + wantCRC: crc, + checkCRC: checkCRC, }, }, nil } diff --git a/storage/reader.go b/storage/reader.go index 4673a68d078..0b228a6a76c 100644 --- a/storage/reader.go +++ b/storage/reader.go @@ -198,9 +198,7 @@ var emptyBody = ioutil.NopCloser(strings.NewReader("")) type Reader struct { Attrs ReaderObjectAttrs seen, remain, size int64 - checkCRC bool // should we check the CRC? - wantCRC uint32 // the CRC32c value the server sent in the header - gotCRC uint32 // running crc + checkCRC bool // Did we check the CRC? This is now only used by tests. reader io.ReadCloser ctx context.Context @@ -218,17 +216,17 @@ func (r *Reader) Read(p []byte) (int, error) { if r.remain != -1 { r.remain -= int64(n) } - if r.checkCRC { - r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, p[:n]) - // Check CRC here. It would be natural to check it in Close, but - // everybody defers Close on the assumption that it doesn't return - // anything worth looking at. - if err == io.EOF { - if r.gotCRC != r.wantCRC { - return n, fmt.Errorf("storage: bad CRC on read: got %d, want %d", - r.gotCRC, r.wantCRC) - } - } + return n, err +} + +// WriteTo writes all the data from the Reader to w. Fulfills the io.WriterTo interface. +// This is called implicitly when calling io.Copy on a Reader. +func (r *Reader) WriteTo(w io.Writer) (int64, error) { + // This implicitly calls r.reader.WriteTo for gRPC only. JSON and XML don't have an + // implementation of WriteTo. + n, err := io.Copy(w, r.reader) + if r.remain != -1 { + r.remain -= int64(n) } return n, err } diff --git a/storage/retry_conformance_test.go b/storage/retry_conformance_test.go index 3f9dd618eba..950b542e2c8 100644 --- a/storage/retry_conformance_test.go +++ b/storage/retry_conformance_test.go @@ -21,7 +21,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -211,12 +210,25 @@ var methods = map[string][]retryFunc{ if err != nil { return err } - wr, err := io.Copy(ioutil.Discard, r) + wr, err := r.WriteTo(io.Discard) if got, want := wr, len(randomBytesToWrite); got != int64(want) { return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) } return err }, + func(ctx context.Context, c *Client, fs *resources, _ bool) error { + // This tests downloads by calling Reader.Read rather than Reader.WriteTo. + r, err := c.Bucket(fs.bucket.Name).Object(fs.object.Name).NewReader(ctx) + if err != nil { + return err + } + // Use ReadAll because it calls Read implicitly, not WriteTo. + b, err := io.ReadAll(r) + if got, want := len(b), len(randomBytesToWrite); got != want { + return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + return err + }, func(ctx context.Context, c *Client, fs *resources, _ bool) error { // Test JSON reads. client, ok := c.tc.(*httpStorageClient) @@ -233,7 +245,7 @@ var methods = map[string][]retryFunc{ if err != nil { return err } - wr, err := io.Copy(ioutil.Discard, r) + wr, err := io.Copy(io.Discard, r) if got, want := wr, len(randomBytesToWrite); got != int64(want) { return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) } @@ -253,7 +265,7 @@ var methods = map[string][]retryFunc{ return err } defer r.Close() - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { return fmt.Errorf("failed to ReadAll, err: %v", err) } @@ -265,6 +277,32 @@ var methods = map[string][]retryFunc{ } return nil }, + func(ctx context.Context, c *Client, fs *resources, _ bool) error { + // Test download via Reader.WriteTo. + // Before running the test method, populate a large test object of 9 MiB. + objName := objectIDs.New() + if err := uploadTestObject(fs.bucket.Name, objName, randomBytes3MiB); err != nil { + return fmt.Errorf("failed to create 9 MiB large object pre test, err: %v", err) + } + // Download the large test object for the S8 download method group. + r, err := c.Bucket(fs.bucket.Name).Object(objName).NewReader(ctx) + if err != nil { + return err + } + defer r.Close() + var data bytes.Buffer + _, err = r.WriteTo(&data) + if err != nil { + return fmt.Errorf("failed to ReadAll, err: %v", err) + } + if got, want := data.Len(), 3*MiB; got != want { + return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + if got, want := data.Bytes(), randomBytes3MiB; !bytes.Equal(got, want) { + return fmt.Errorf("body mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + return nil + }, func(ctx context.Context, c *Client, fs *resources, _ bool) error { // Test JSON reads. // Before running the test method, populate a large test object. @@ -289,7 +327,7 @@ var methods = map[string][]retryFunc{ return err } defer r.Close() - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { return fmt.Errorf("failed to ReadAll, err: %v", err) }