diff --git a/.github/.editorconfig b/.github/.editorconfig new file mode 100644 index 0000000..0902c6a --- /dev/null +++ b/.github/.editorconfig @@ -0,0 +1,2 @@ +[{*.yml,*.yaml}] +indent_size = 2 diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..f80132b --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,9 @@ +version: 2 + +updates: + - package-ecosystem: github-actions + directory: / + labels: + - dependencies + schedule: + interval: daily diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 0000000..ff093dc --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,29 @@ +changelog: + exclude: + labels: + - release-note/ignore + categories: + - title: Exciting New Features 🎉 + labels: + - release-note/new-feature + - title: Enhancements 🚀 + labels: + - enhancement + - release-note/enhancement + - title: Bug Fixes 🐛 + labels: + - bug + - release-note/bug-fix + - title: Breaking Changes 🛠 + labels: + - release-note/breaking-change + - title: Deprecations ❌ + labels: + - release-note/deprecation + - title: Dependency Updates ⬆️ + labels: + - dependencies + - release-note/dependency-update + - title: Other Changes + labels: + - "*" diff --git a/.github/workflows/.editorconfig b/.github/workflows/.editorconfig deleted file mode 100644 index 7bd3346..0000000 --- a/.github/workflows/.editorconfig +++ /dev/null @@ -1,2 +0,0 @@ -[*.yml] -indent_size = 2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yaml similarity index 60% rename from .github/workflows/ci.yml rename to .github/workflows/ci.yaml index e45455c..9c2ec25 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yaml @@ -2,8 +2,7 @@ name: CI on: push: - branches: - - master + branches: [master] pull_request: jobs: @@ -11,20 +10,18 @@ jobs: name: Test runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - go: ['1.11', '1.12', '1.13', '1.14'] - env: - VERBOSE: 1 - GOFLAGS: -mod=readonly + go: ['1.15', '1.16', '1.17', '1.18', '1.19', '1.20'] steps: - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Run tests run: go test -v -race diff --git a/README.md b/README.md index cad6083..42970da 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge-flat.svg)](https://github.com/avelino/awesome-go#utilities) -[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/jonboulle/clockwork/CI?style=flat-square)](https://github.com/jonboulle/clockwork/actions?query=workflow%3ACI) +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/jonboulle/clockwork/ci.yaml?style=flat-square)](https://github.com/jonboulle/clockwork/actions?query=workflow%3ACI) [![Go Report Card](https://goreportcard.com/badge/github.com/jonboulle/clockwork?style=flat-square)](https://goreportcard.com/report/github.com/jonboulle/clockwork) -![Go Version](https://img.shields.io/badge/go%20version-%3E=1.11-61CFDD.svg?style=flat-square) +![Go Version](https://img.shields.io/badge/go%20version-%3E=1.15-61CFDD.svg?style=flat-square) [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/mod/github.com/jonboulle/clockwork) **A simple fake clock for Go.** diff --git a/clockwork.go b/clockwork.go index 1018051..3206b36 100644 --- a/clockwork.go +++ b/clockwork.go @@ -1,30 +1,38 @@ package clockwork import ( + "context" + "sort" "sync" "time" ) -// Clock provides an interface that packages can use instead of directly -// using the time module, so that chronology-related behavior can be tested +// Clock provides an interface that packages can use instead of directly using +// the [time] module, so that chronology-related behavior can be tested. type Clock interface { After(d time.Duration) <-chan time.Time Sleep(d time.Duration) Now() time.Time Since(t time.Time) time.Duration NewTicker(d time.Duration) Ticker + NewTimer(d time.Duration) Timer + AfterFunc(d time.Duration, f func()) Timer } -// FakeClock provides an interface for a clock which can be -// manually advanced through time +// FakeClock provides an interface for a clock which can be manually advanced +// through time. +// +// FakeClock maintains a list of "waiters," which consists of all callers +// waiting on the underlying clock (i.e. Tickers and Timers including callers of +// Sleep or After). Users can call BlockUntil to block until the clock has an +// expected number of waiters. type FakeClock interface { Clock // Advance advances the FakeClock to a new point in time, ensuring any existing - // sleepers are notified appropriately before returning + // waiters are notified appropriately before returning. Advance(d time.Duration) - // BlockUntil will block until the FakeClock has the given number of - // sleepers (callers of Sleep or After) - BlockUntil(n int) + // BlockUntil blocks until the FakeClock has the given number of waiters. + BlockUntil(waiters int) } // NewRealClock returns a Clock which simply delegates calls to the actual time @@ -35,10 +43,11 @@ func NewRealClock() Clock { // NewFakeClock returns a FakeClock implementation which can be // manually advanced through time for testing. The initial time of the -// FakeClock will be an arbitrary non-zero time. +// FakeClock will be the current system time. +// +// Tests that require a deterministic time must use NewFakeClockAt. func NewFakeClock() FakeClock { - // use a fixture that does not fulfill Time.IsZero() - return NewFakeClockAt(time.Date(1984, time.April, 4, 0, 0, 0, 0, time.UTC)) + return NewFakeClockAt(time.Now()) } // NewFakeClockAt returns a FakeClock initialised at the given time.Time. @@ -67,129 +76,274 @@ func (rc *realClock) Since(t time.Time) time.Duration { } func (rc *realClock) NewTicker(d time.Duration) Ticker { - return &realTicker{time.NewTicker(d)} + return realTicker{time.NewTicker(d)} } -type fakeClock struct { - sleepers []*sleeper - blockers []*blocker - time time.Time +func (rc *realClock) NewTimer(d time.Duration) Timer { + return realTimer{time.NewTimer(d)} +} - l sync.RWMutex +func (rc *realClock) AfterFunc(d time.Duration, f func()) Timer { + return realTimer{time.AfterFunc(d, f)} } -// sleeper represents a caller of After or Sleep -type sleeper struct { - until time.Time - done chan time.Time +type fakeClock struct { + // l protects all attributes of the clock, including all attributes of all + // waiters and blockers. + l sync.RWMutex + waiters []expirer + blockers []*blocker + time time.Time } -// blocker represents a caller of BlockUntil +// blocker is a caller of BlockUntil. type blocker struct { count int - ch chan struct{} + + // ch is closed when the underlying clock has the specificed number of blockers. + ch chan struct{} } -// After mimics time.After; it waits for the given duration to elapse on the -// fakeClock, then sends the current time on the returned channel. -func (fc *fakeClock) After(d time.Duration) <-chan time.Time { - fc.l.Lock() - defer fc.l.Unlock() - now := fc.time - done := make(chan time.Time, 1) - if d.Nanoseconds() <= 0 { - // special case - trigger immediately - done <- now - } else { - // otherwise, add to the set of sleepers - s := &sleeper{ - until: now.Add(d), - done: done, - } - fc.sleepers = append(fc.sleepers, s) - // and notify any blockers - fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers)) - } - return done +// expirer is a timer or ticker that expires at some point in the future. +type expirer interface { + // expire the expirer at the given time, returning the desired duration until + // the next expiration, if any. + expire(now time.Time) (next *time.Duration) + + // Get and set the expiration time. + expiry() time.Time + setExpiry(time.Time) } -// notifyBlockers notifies all the blockers waiting until the -// given number of sleepers are waiting on the fakeClock. It -// returns an updated slice of blockers (i.e. those still waiting) -func notifyBlockers(blockers []*blocker, count int) (newBlockers []*blocker) { - for _, b := range blockers { - if b.count == count { - close(b.ch) - } else { - newBlockers = append(newBlockers, b) - } - } - return +// After mimics [time.After]; it waits for the given duration to elapse on the +// fakeClock, then sends the current time on the returned channel. +func (fc *fakeClock) After(d time.Duration) <-chan time.Time { + return fc.NewTimer(d).Chan() } -// Sleep blocks until the given duration has passed on the fakeClock +// Sleep blocks until the given duration has passed on the fakeClock. func (fc *fakeClock) Sleep(d time.Duration) { <-fc.After(d) } -// Time returns the current time of the fakeClock +// Now returns the current time of the fakeClock func (fc *fakeClock) Now() time.Time { fc.l.RLock() - t := fc.time - fc.l.RUnlock() - return t + defer fc.l.RUnlock() + return fc.time } -// Since returns the duration that has passed since the given time on the fakeClock +// Since returns the duration that has passed since the given time on the +// fakeClock. func (fc *fakeClock) Since(t time.Time) time.Duration { return fc.Now().Sub(t) } +// NewTicker returns a Ticker that will expire only after calls to +// fakeClock.Advance() have moved the clock past the given duration. func (fc *fakeClock) NewTicker(d time.Duration) Ticker { - ft := &fakeTicker{ - c: make(chan time.Time, 1), - stop: make(chan bool, 1), - clock: fc, - period: d, + var ft *fakeTicker + ft = &fakeTicker{ + firer: newFirer(), + d: d, + reset: func(d time.Duration) { fc.set(ft, d) }, + stop: func() { fc.stop(ft) }, } - ft.runTickThread() + fc.set(ft, d) return ft } -// Advance advances fakeClock to a new point in time, ensuring channels from any -// previous invocations of After are notified appropriately before returning +// NewTimer returns a Timer that will fire only after calls to +// fakeClock.Advance() have moved the clock past the given duration. +func (fc *fakeClock) NewTimer(d time.Duration) Timer { + return fc.newTimer(d, nil) +} + +// AfterFunc mimics [time.AfterFunc]; it returns a Timer that will invoke the +// given function only after calls to fakeClock.Advance() have moved the clock +// past the given duration. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) Timer { + return fc.newTimer(d, f) +} + +// newTimer returns a new timer, using an optional afterFunc. +func (fc *fakeClock) newTimer(d time.Duration, afterfunc func()) *fakeTimer { + var ft *fakeTimer + ft = &fakeTimer{ + firer: newFirer(), + reset: func(d time.Duration) bool { + fc.l.Lock() + defer fc.l.Unlock() + // fc.l must be held across the calls to stopExpirer & setExpirer. + stopped := fc.stopExpirer(ft) + fc.setExpirer(ft, d) + return stopped + }, + stop: func() bool { return fc.stop(ft) }, + + afterFunc: afterfunc, + } + fc.set(ft, d) + return ft +} + +// Advance advances fakeClock to a new point in time, ensuring waiters and +// blockers are notified appropriately before returning. func (fc *fakeClock) Advance(d time.Duration) { fc.l.Lock() defer fc.l.Unlock() end := fc.time.Add(d) - var newSleepers []*sleeper - for _, s := range fc.sleepers { - if end.Sub(s.until) >= 0 { - s.done <- end - } else { - newSleepers = append(newSleepers, s) + // Expire the earliest waiter until the earliest waiter's expiration is after + // end. + // + // We don't iterate because the callback of the waiter might register a new + // waiter, so the list of waiters might change as we execute this. + for len(fc.waiters) > 0 && !end.Before(fc.waiters[0].expiry()) { + w := fc.waiters[0] + fc.waiters = fc.waiters[1:] + + // Use the waiter's expriation as the current time for this expiration. + now := w.expiry() + fc.time = now + if d := w.expire(now); d != nil { + // Set the new exipration if needed. + fc.setExpirer(w, *d) } } - fc.sleepers = newSleepers - fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers)) fc.time = end } -// BlockUntil will block until the fakeClock has the given number of sleepers -// (callers of Sleep or After) +// BlockUntil blocks until the fakeClock has the given number of waiters. +// +// Prefer BlockUntilContext, which offers context cancellation to prevent +// deadlock. +// +// Deprecation warning: This function might be deprecated in later versions. func (fc *fakeClock) BlockUntil(n int) { - fc.l.Lock() - // Fast path: current number of sleepers is what we're looking for - if len(fc.sleepers) == n { - fc.l.Unlock() + b := fc.newBlocker(n) + if b == nil { return } - // Otherwise, set up a new blocker + <-b.ch +} + +// BlockUntilContext blocks until the fakeClock has the given number of waiters +// or the context is cancelled. +func (fc *fakeClock) BlockUntilContext(ctx context.Context, n int) error { + b := fc.newBlocker(n) + if b == nil { + return nil + } + + select { + case <-b.ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (fc *fakeClock) newBlocker(n int) *blocker { + fc.l.Lock() + defer fc.l.Unlock() + // Fast path: we already have >= n waiters. + if len(fc.waiters) >= n { + return nil + } + // Set up a new blocker to wait for more waiters. b := &blocker{ count: n, ch: make(chan struct{}), } fc.blockers = append(fc.blockers, b) - fc.l.Unlock() - <-b.ch + return b +} + +// stop stops an expirer, returning true if the expirer was stopped. +func (fc *fakeClock) stop(e expirer) bool { + fc.l.Lock() + defer fc.l.Unlock() + return fc.stopExpirer(e) +} + +// stopExpirer stops an expirer, returning true if the expirer was stopped. +// +// The caller must hold fc.l. +func (fc *fakeClock) stopExpirer(e expirer) bool { + for i, t := range fc.waiters { + if t == e { + // Remove element, maintaining order. + copy(fc.waiters[i:], fc.waiters[i+1:]) + fc.waiters[len(fc.waiters)-1] = nil + fc.waiters = fc.waiters[:len(fc.waiters)-1] + return true + } + } + return false +} + +// set sets an expirer to expire at a future point in time. +func (fc *fakeClock) set(e expirer, d time.Duration) { + fc.l.Lock() + defer fc.l.Unlock() + fc.setExpirer(e, d) +} + +// setExpirer sets an expirer to expire at a future point in time. +// +// The caller must hold fc.l. +func (fc *fakeClock) setExpirer(e expirer, d time.Duration) { + if d.Nanoseconds() <= 0 { + // special case - trigger immediately, never reset. + // + // TODO: Explain what cases this covers. + e.expire(fc.time) + return + } + // Add the expirer to the set of waiters and notify any blockers. + e.setExpiry(fc.time.Add(d)) + fc.waiters = append(fc.waiters, e) + sort.Slice(fc.waiters, func(i int, j int) bool { + return fc.waiters[i].expiry().Before(fc.waiters[j].expiry()) + }) + + // Notify blockers of our new waiter. + var blocked []*blocker + count := len(fc.waiters) + for _, b := range fc.blockers { + if b.count <= count { + close(b.ch) + continue + } + blocked = append(blocked, b) + } + fc.blockers = blocked +} + +// firer is used by fakeTimer and fakeTicker used to help implement expirer. +type firer struct { + // The channel associated with the firer, used to send expriation times. + c chan time.Time + + // The time when the firer expires. Only meaningful if the firer is currently + // one of a fakeClock's waiters. + exp time.Time +} + +func newFirer() firer { + return firer{c: make(chan time.Time, 1)} +} + +func (f *firer) Chan() <-chan time.Time { + return f.c +} + +// expiry implements expirer. +func (f *firer) expiry() time.Time { + return f.exp +} + +// setExpiry implements expirer. +func (f *firer) setExpiry(t time.Time) { + f.exp = t } diff --git a/clockwork_test.go b/clockwork_test.go index 6b8b5cf..d76b1d3 100644 --- a/clockwork_test.go +++ b/clockwork_test.go @@ -1,12 +1,20 @@ package clockwork import ( + "context" + "errors" "reflect" "testing" "time" ) +// Use a consistent timeout across tests that block on channels. Keeps test +// timeouts limited while being able to easily extend it to allow the test +// process to get killed, providing a stack trace. +const timeout = time.Minute + func TestFakeClockAfter(t *testing.T) { + t.Parallel() fc := &fakeClock{} neg := fc.After(-1) @@ -81,39 +89,8 @@ func TestFakeClockAfter(t *testing.T) { } } -func TestNotifyBlockers(t *testing.T) { - b1 := &blocker{1, make(chan struct{})} - b2 := &blocker{2, make(chan struct{})} - b3 := &blocker{5, make(chan struct{})} - b4 := &blocker{10, make(chan struct{})} - b5 := &blocker{10, make(chan struct{})} - bs := []*blocker{b1, b2, b3, b4, b5} - bs1 := notifyBlockers(bs, 2) - if n := len(bs1); n != 4 { - t.Fatalf("got %d blockers, want %d", n, 4) - } - select { - case <-b2.ch: - case <-time.After(time.Second): - t.Fatalf("timed out waiting for channel close!") - } - bs2 := notifyBlockers(bs1, 10) - if n := len(bs2); n != 2 { - t.Fatalf("got %d blockers, want %d", n, 2) - } - select { - case <-b4.ch: - case <-time.After(time.Second): - t.Fatalf("timed out waiting for channel close!") - } - select { - case <-b5.ch: - case <-time.After(time.Second): - t.Fatalf("timed out waiting for channel close!") - } -} - func TestNewFakeClock(t *testing.T) { + t.Parallel() fc := NewFakeClock() now := fc.Now() if now.IsZero() { @@ -127,6 +104,7 @@ func TestNewFakeClock(t *testing.T) { } func TestNewFakeClockAt(t *testing.T) { + t.Parallel() t1 := time.Date(1999, time.February, 3, 4, 5, 6, 7, time.UTC) fc := NewFakeClockAt(t1) now := fc.Now() @@ -136,6 +114,7 @@ func TestNewFakeClockAt(t *testing.T) { } func TestFakeClockSince(t *testing.T) { + t.Parallel() fc := NewFakeClock() now := fc.Now() elapsedTime := time.Second @@ -144,3 +123,79 @@ func TestFakeClockSince(t *testing.T) { t.Fatalf("fakeClock.Since() returned unexpected duration, got: %d, want: %d", fc.Since(now), elapsedTime) } } + +// This used to result in a deadlock. +// https://github.com/jonboulle/clockwork/issues/35 +func TestTwoBlockersOneBlock(t *testing.T) { + t.Parallel() + fc := &fakeClock{} + + ft1 := fc.NewTicker(time.Second) + ft2 := fc.NewTicker(time.Second) + + fc.BlockUntil(1) + fc.BlockUntil(2) + ft1.Stop() + ft2.Stop() +} + +func TestBlockUntilContext(t *testing.T) { + t.Parallel() + fc := &fakeClock{} + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + blockCtx, cancelBlock := context.WithCancel(ctx) + errCh := make(chan error) + + go func() { + select { + case errCh <- fc.BlockUntilContext(blockCtx, 2): + case <-ctx.Done(): // Error case, captured below. + } + }() + cancelBlock() + + select { + case err := <-errCh: + if !errors.Is(err, context.Canceled) { + t.Errorf("BlockUntilContext returned %v, want context.Canceled.", err) + } + case <-ctx.Done(): + t.Errorf("Never receved error on context cancellation.") + } +} + +func TestAfterDeliveryInOrder(t *testing.T) { + t.Parallel() + fc := &fakeClock{} + for i := 0; i < 1000; i++ { + three := fc.After(3 * time.Second) + for j := 0; j < 100; j++ { + fc.After(1 * time.Second) + } + two := fc.After(2 * time.Second) + go func() { + fc.Advance(5 * time.Second) + }() + <-three + select { + case <-two: + default: + t.Fatalf("Signals from After delivered out of order") + } + } +} + +// TestFakeClockRace detects data races in fakeClock when invoked with run using `go -race ...`. +// There are no failure conditions when invoked without the -race flag. +func TestFakeClockRace(t *testing.T) { + t.Parallel() + fc := &fakeClock{} + d := time.Second + go func() { fc.Advance(d) }() + go func() { fc.NewTicker(d) }() + go func() { fc.NewTimer(d) }() + go func() { fc.Sleep(d) }() +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..edbb368 --- /dev/null +++ b/context.go @@ -0,0 +1,25 @@ +package clockwork + +import ( + "context" +) + +// contextKey is private to this package so we can ensure uniqueness here. This +// type identifies context values provided by this package. +type contextKey string + +// keyClock provides a clock for injecting during tests. If absent, a real clock should be used. +var keyClock = contextKey("clock") // clockwork.Clock + +// AddToContext creates a derived context that references the specified clock. +func AddToContext(ctx context.Context, clock Clock) context.Context { + return context.WithValue(ctx, keyClock, clock) +} + +// FromContext extracts a clock from the context. If not present, a real clock is returned. +func FromContext(ctx context.Context) Clock { + if clock, ok := ctx.Value(keyClock).(Clock); ok { + return clock + } + return NewRealClock() +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..ee10d5b --- /dev/null +++ b/context_test.go @@ -0,0 +1,26 @@ +package clockwork + +import ( + "context" + "reflect" + "testing" +) + +func TestContextOps(t *testing.T) { + t.Parallel() + ctx := context.Background() + assertIsType(t, NewRealClock(), FromContext(ctx)) + + ctx = AddToContext(ctx, NewFakeClock()) + assertIsType(t, NewFakeClock(), FromContext(ctx)) + + ctx = AddToContext(ctx, NewRealClock()) + assertIsType(t, NewRealClock(), FromContext(ctx)) +} + +func assertIsType(t *testing.T, expectedType, object interface{}) { + t.Helper() + if reflect.TypeOf(object) != reflect.TypeOf(expectedType) { + t.Fatalf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)) + } +} diff --git a/example_test.go b/example_test.go index 3d5f291..2631c5a 100644 --- a/example_test.go +++ b/example_test.go @@ -6,21 +6,20 @@ import ( "time" ) -// myFunc is an example of a time-dependent function, using an -// injected clock +// myFunc is an example of a time-dependent function, using an injected clock. func myFunc(clock Clock, i *int) { clock.Sleep(3 * time.Second) *i += 1 } -// assertState is an example of a state assertion in a test +// assertState is an example of a state assertion in a test. func assertState(t *testing.T, i, j int) { if i != j { t.Fatalf("i %d, j %d", i, j) } } -// TestMyFunc tests myFunc's behaviour with a FakeClock +// TestMyFunc tests myFunc's behaviour with a FakeClock. func TestMyFunc(t *testing.T) { var i int c := NewFakeClock() @@ -32,18 +31,18 @@ func TestMyFunc(t *testing.T) { wg.Done() }() - // Wait until myFunc is actually sleeping on the clock + // Wait until myFunc is actually sleeping on the clock. c.BlockUntil(1) - // Assert the initial state + // Assert the initial state. assertState(t, i, 0) - // Now advance the clock forward in time + // Now advance the clock forward in time. c.Advance(1 * time.Hour) - // Wait until the function completes + // Wait until the function completes. wg.Wait() - // Assert the final state + // Assert the final state. assertState(t, i, 1) } diff --git a/go.mod b/go.mod index 4f4bb16..507295d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/jonboulle/clockwork -go 1.13 +go 1.15 diff --git a/ticker.go b/ticker.go index 32b5d01..b68e4d7 100644 --- a/ticker.go +++ b/ticker.go @@ -1,72 +1,48 @@ package clockwork -import ( - "time" -) +import "time" -// Ticker provides an interface which can be used instead of directly -// using the ticker within the time module. The real-time ticker t -// provides ticks through t.C which becomes now t.Chan() to make -// this channel requirement definable in this interface. +// Ticker provides an interface which can be used instead of directly using +// [time.Ticker]. The real-time ticker t provides ticks through t.C which +// becomes t.Chan() to make this channel requirement definable in this +// interface. type Ticker interface { Chan() <-chan time.Time + Reset(d time.Duration) Stop() } type realTicker struct{ *time.Ticker } -func (rt *realTicker) Chan() <-chan time.Time { - return rt.C +func (r realTicker) Chan() <-chan time.Time { + return r.C } type fakeTicker struct { - c chan time.Time - stop chan bool - clock FakeClock - period time.Duration + firer + + // reset and stop provide the implementation of the respective exported + // functions. + reset func(d time.Duration) + stop func() + + // The duration of the ticker. + d time.Duration } -func (ft *fakeTicker) Chan() <-chan time.Time { - return ft.c +func (f *fakeTicker) Reset(d time.Duration) { + f.reset(d) } -func (ft *fakeTicker) Stop() { - ft.stop <- true +func (f *fakeTicker) Stop() { + f.stop() } -// runTickThread initializes a background goroutine to send the tick time to the ticker channel -// after every period. Tick events are discarded if the underlying ticker channel does not have -// enough capacity. -func (ft *fakeTicker) runTickThread() { - nextTick := ft.clock.Now().Add(ft.period) - next := ft.clock.After(ft.period) - go func() { - for { - select { - case <-ft.stop: - return - case <-next: - // We send the time that the tick was supposed to occur at. - tick := nextTick - // Before sending the tick, we'll compute the next tick time and star the clock.After call. - now := ft.clock.Now() - // First, figure out how many periods there have been between "now" and the time we were - // supposed to have trigged, then advance over all of those. - skipTicks := (now.Sub(tick) + ft.period - 1) / ft.period - nextTick = nextTick.Add(skipTicks * ft.period) - // Now, keep advancing until we are past now. This should happen at most once. - for !nextTick.After(now) { - nextTick = nextTick.Add(ft.period) - } - // Figure out how long between now and the next scheduled tick, then wait that long. - remaining := nextTick.Sub(now) - next = ft.clock.After(remaining) - // Finally, we can actually send the tick. - select { - case ft.c <- tick: - default: - } - } - } - }() +func (f *fakeTicker) expire(now time.Time) *time.Duration { + // Never block on expiration. + select { + case f.c <- now: + default: + } + return &f.d } diff --git a/ticker_test.go b/ticker_test.go index 1f34036..61a7b53 100644 --- a/ticker_test.go +++ b/ticker_test.go @@ -1,11 +1,13 @@ package clockwork import ( + "context" "testing" "time" ) func TestFakeTickerStop(t *testing.T) { + t.Parallel() fc := &fakeClock{} ft := fc.NewTicker(1) @@ -18,6 +20,10 @@ func TestFakeTickerStop(t *testing.T) { } func TestFakeTickerTick(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + fc := &fakeClock{} now := fc.Now() @@ -39,7 +45,7 @@ func TestFakeTickerTick(t *testing.T) { if tick != first { t.Errorf("wrong tick time, got: %v, want: %v", tick, first) } - case <-time.After(time.Millisecond): + case <-ctx.Done(): t.Errorf("expected tick!") } @@ -52,13 +58,17 @@ func TestFakeTickerTick(t *testing.T) { if tick != second { t.Errorf("wrong tick time, got: %v, want: %v", tick, second) } - case <-time.After(time.Millisecond): + case <-ctx.Done(): t.Errorf("expected tick!") } ft.Stop() } func TestFakeTicker_Race(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + fc := NewFakeClock() tickTime := 1 * time.Millisecond @@ -67,23 +77,57 @@ func TestFakeTicker_Race(t *testing.T) { fc.Advance(tickTime) - timeout := time.NewTimer(500 * time.Millisecond) - defer timeout.Stop() - select { case <-ticker.Chan(): - // Pass - case <-timeout.C: + case <-ctx.Done(): t.Fatalf("Ticker didn't detect the clock advance!") } } func TestFakeTicker_Race2(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() fc := NewFakeClock() ft := fc.NewTicker(5 * time.Second) for i := 0; i < 100; i++ { fc.Advance(5 * time.Second) - <-ft.Chan() + select { + case <-ft.Chan(): + case <-ctx.Done(): + t.Fatalf("Ticker didn't detect the clock advance!") + } + } ft.Stop() } + +func TestFakeTicker_DeliveryOrder(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + fc := NewFakeClock() + ticker := fc.NewTicker(2 * time.Second).Chan() + timer := fc.NewTimer(5 * time.Second).Chan() + go func() { + for j := 0; j < 10; j++ { + fc.BlockUntil(1) + fc.Advance(1 * time.Second) + } + }() + <-ticker + a := <-timer + // Only perform ordering check if ticker channel is drained at first. + select { + case <-ticker: + default: + select { + case b := <-ticker: + if a.After(b) { + t.Fatalf("Expected timer before ticker, got timer %v after %v", a, b) + } + case <-ctx.Done(): + t.Fatalf("Expected ticker event didn't arrive!") + } + } +} diff --git a/timer.go b/timer.go new file mode 100644 index 0000000..6f928b3 --- /dev/null +++ b/timer.go @@ -0,0 +1,53 @@ +package clockwork + +import "time" + +// Timer provides an interface which can be used instead of directly using +// [time.Timer]. The real-time timer t provides events through t.C which becomes +// t.Chan() to make this channel requirement definable in this interface. +type Timer interface { + Chan() <-chan time.Time + Reset(d time.Duration) bool + Stop() bool +} + +type realTimer struct{ *time.Timer } + +func (r realTimer) Chan() <-chan time.Time { + return r.C +} + +type fakeTimer struct { + firer + + // reset and stop provide the implmenetation of the respective exported + // functions. + reset func(d time.Duration) bool + stop func() bool + + // If present when the timer fires, the timer calls afterFunc in its own + // goroutine rather than sending the time on Chan(). + afterFunc func() +} + +func (f *fakeTimer) Reset(d time.Duration) bool { + return f.reset(d) +} + +func (f *fakeTimer) Stop() bool { + return f.stop() +} + +func (f *fakeTimer) expire(now time.Time) *time.Duration { + if f.afterFunc != nil { + go f.afterFunc() + return nil + } + + // Never block on expiration. + select { + case f.c <- now: + default: + } + return nil +} diff --git a/timer_test.go b/timer_test.go new file mode 100644 index 0000000..90776af --- /dev/null +++ b/timer_test.go @@ -0,0 +1,224 @@ +package clockwork + +import ( + "context" + "testing" + "time" +) + +func TestFakeClockTimerStop(t *testing.T) { + t.Parallel() + fc := &fakeClock{} + + ft := fc.NewTimer(1) + ft.Stop() + select { + case <-ft.Chan(): + t.Errorf("received unexpected tick!") + default: + } +} + +func TestFakeClockTimers(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + fc := &fakeClock{} + + zero := fc.NewTimer(0) + + if zero.Stop() { + t.Errorf("zero timer could be stopped") + } + + select { + case <-zero.Chan(): + case <-ctx.Done(): + t.Errorf("zero timer didn't emit time") + } + + one := fc.NewTimer(1) + + select { + case <-one.Chan(): + t.Errorf("non-zero timer did emit time") + default: + } + if !one.Stop() { + t.Errorf("non-zero timer couldn't be stopped") + } + + fc.Advance(5) + + select { + case <-one.Chan(): + t.Errorf("stopped timer did emit time") + default: + } + + if one.Reset(1) { + t.Errorf("resetting stopped timer didn't return false") + } + if !one.Reset(1) { + t.Errorf("resetting active timer didn't return true") + } + + fc.Advance(1) + + select { + case <-time.After(500 * time.Millisecond): + } + + if one.Stop() { + t.Errorf("triggered timer could be stopped") + } + + select { + case <-one.Chan(): + case <-ctx.Done(): + t.Errorf("triggered timer didn't emit time") + } + + fc.Advance(1) + + select { + case <-one.Chan(): + t.Errorf("triggered timer emitted time more than once") + default: + } + + one.Reset(0) + + if one.Stop() { + t.Errorf("reset to zero timer could be stopped") + } + + select { + case <-one.Chan(): + case <-ctx.Done(): + t.Errorf("reset to zero timer didn't emit time") + } +} + +func TestFakeClockTimer_Race(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + fc := NewFakeClock() + timer := fc.NewTimer(1 * time.Millisecond) + defer timer.Stop() + fc.Advance(1 * time.Millisecond) + + select { + case <-timer.Chan(): + case <-ctx.Done(): + t.Fatalf("Timer didn't detect the clock advance!") + } +} + +func TestFakeClockTimer_Race2(t *testing.T) { + t.Parallel() + fc := NewFakeClock() + timer := fc.NewTimer(5 * time.Second) + for i := 0; i < 100; i++ { + fc.Advance(5 * time.Second) + <-timer.Chan() + timer.Reset(5 * time.Second) + } + timer.Stop() +} + +func TestFakeClockTimer_ResetRace(t *testing.T) { + t.Parallel() + fc := NewFakeClock() + d := 5 * time.Second + var times []time.Time + timer := fc.NewTimer(d) + timerStopped := make(chan struct{}) + doneAddingTimes := make(chan struct{}) + go func() { + defer close(doneAddingTimes) + for { + select { + case <-timerStopped: + return + case now := <-timer.Chan(): + times = append(times, now) + } + } + }() + for i := 0; i < 100; i++ { + for j := 0; j < 10; j++ { + timer.Reset(d) + } + fc.Advance(d) + } + timer.Stop() + close(timerStopped) + <-doneAddingTimes // Prevent race condition on times. + for i := 1; i < len(times); i++ { + if times[i-1].Equal(times[i]) { + t.Fatalf("Timer repeatedly reported the same time.") + } + } +} + +func TestFakeClockTimer_ZeroResetDoesNotBlock(t *testing.T) { + t.Parallel() + fc := NewFakeClock() + timer := fc.NewTimer(0) + for i := 0; i < 10; i++ { + timer.Reset(0) + } + <-timer.Chan() +} + +func TestAfterFunc_Concurrent(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + fc := NewFakeClock() + blocker := make(chan struct{}) + ch := make(chan int) + // AfterFunc should start goroutines, so each should be able to make progress + // independent of the others. + fc.AfterFunc(2*time.Second, func() { + <-blocker + ch <- 222 + }) + fc.AfterFunc(2*time.Second, func() { + ch <- 111 + }) + fc.AfterFunc(2*time.Second, func() { + <-blocker + ch <- 222 + }) + fc.Advance(2 * time.Second) + select { + case a := <-ch: + if a != 111 { + t.Fatalf("Expected 111, got %d", a) + } + case <-ctx.Done(): + t.Fatalf("Expected signal hasn't arrived") + } + close(blocker) + select { + case a := <-ch: + if a != 222 { + t.Fatalf("Expected 222, got %d", a) + } + case <-ctx.Done(): + t.Fatalf("Expected signal hasn't arrived") + } + select { + case a := <-ch: + if a != 222 { + t.Fatalf("Expected 222, got %d", a) + } + case <-ctx.Done(): + t.Fatalf("Expected signal hasn't arrived") + } +}