diff --git a/internal/writers/writers.go b/internal/writers/writers.go index 70ab6220..b40f4e73 100644 --- a/internal/writers/writers.go +++ b/internal/writers/writers.go @@ -1,8 +1,10 @@ package writers import ( + "bufio" "bytes" "errors" + "fmt" "io" ) @@ -85,4 +87,85 @@ func (n *NopCloser) Close() error { var ( _ io.WriteCloser = (*NopCloser)(nil) _ io.WriteCloser = (*emptySkipper)(nil) + _ io.WriteCloser = (*sameSkipper)(nil) ) + +type sameSkipper struct { + open func() (io.WriteCloser, error) + + // internal + r *bufio.Reader + w io.WriteCloser + buf *bytes.Buffer + diff bool +} + +// SameSkipper creates an io.WriteCloser that will only start writing once a +// difference with the current output has been encountered. The wrapped +// io.WriteCloser must be provided by 'open'. +func SameSkipper(r io.Reader, open func() (io.WriteCloser, error)) io.WriteCloser { + br := bufio.NewReader(r) + return &sameSkipper{ + r: br, + w: nil, + buf: &bytes.Buffer{}, + diff: false, + open: open, + } +} + +// Write - writes to the buffer, until a difference with the output is found, +// then flushes and writes to the wrapped writer. +func (f *sameSkipper) Write(p []byte) (n int, err error) { + if !f.diff { + in := make([]byte, len(p)) + _, err := f.r.Read(in) + if err != nil && err != io.EOF { + return 0, fmt.Errorf("failed to read: %w", err) + } + if bytes.Equal(in, p) { + return f.buf.Write(p) + } + + f.diff = true + err = f.flush() + if err != nil { + return 0, err + } + } + return f.w.Write(p) +} + +func (f *sameSkipper) flush() (err error) { + if f.w == nil { + f.w, err = f.open() + if err != nil { + return err + } + if f.w == nil { + return fmt.Errorf("nil writer returned by open") + } + } + // empty the buffer into the wrapped writer + _, err = f.buf.WriteTo(f.w) + return err +} + +// Close - implements io.Closer +func (f *sameSkipper) Close() error { + // Check to see if we missed anything in the reader + if !f.diff { + n, err := f.r.Peek(1) + if len(n) > 0 || err != io.EOF { + err = f.flush() + if err != nil { + return fmt.Errorf("failed to flush on close: %w", err) + } + } + } + + if f.w != nil { + return f.w.Close() + } + return nil +} diff --git a/internal/writers/writers_test.go b/internal/writers/writers_test.go index df9ccd8f..b3452fcd 100644 --- a/internal/writers/writers_test.go +++ b/internal/writers/writers_test.go @@ -2,6 +2,7 @@ package writers import ( "bytes" + "fmt" "io" "testing" @@ -47,6 +48,8 @@ func TestEmptySkipper(t *testing.T) { n, err := f.Write(d.in) assert.NoError(t, err) assert.Equal(t, len(d.in), n) + err = f.Close() + assert.NoError(t, err) if d.empty { assert.Nil(t, f.w) assert.False(t, opened) @@ -65,3 +68,46 @@ type bufferCloser struct { func (b *bufferCloser) Close() error { return nil } + +func TestSameSkipper(t *testing.T) { + testdata := []struct { + in []byte + out []byte + same bool + }{ + {[]byte(" "), []byte(" "), true}, + {[]byte("foo"), []byte("foo"), true}, + {[]byte("foo"), nil, false}, + {[]byte("foo"), []byte("bar"), false}, + {[]byte("foobar"), []byte("foo"), false}, + {[]byte("foo"), []byte("foobar"), false}, + } + + for _, d := range testdata { + t.Run(fmt.Sprintf("in:%q/out:%q/same:%v", d.in, d.out, d.same), func(t *testing.T) { + r := bytes.NewBuffer(d.out) + w := &bufferCloser{&bytes.Buffer{}} + opened := false + f, ok := SameSkipper(r, func() (io.WriteCloser, error) { + opened = true + return w, nil + }).(*sameSkipper) + assert.True(t, ok) + + n, err := f.Write(d.in) + assert.NoError(t, err) + assert.Equal(t, len(d.in), n) + err = f.Close() + assert.NoError(t, err) + if d.same { + assert.Nil(t, f.w) + assert.False(t, opened) + assert.Empty(t, w.Bytes()) + } else { + assert.NotNil(t, f.w) + assert.True(t, opened) + assert.EqualValues(t, d.in, w.Bytes()) + } + }) + } +} diff --git a/template.go b/template.go index 3559b5ca..b500ce73 100644 --- a/template.go +++ b/template.go @@ -249,10 +249,26 @@ func createOutFile(filename string, mode os.FileMode, modeOverride bool) (out io return nil, fmt.Errorf("failed to chmod output file '%s' with mode %q: %w", filename, mode.Perm(), err) } } - out, err = fs.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()) - if err != nil { + + open := func() (out io.WriteCloser, err error) { + out, err = fs.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()) + if err != nil { + return out, fmt.Errorf("failed to open output file '%s' for writing: %w", filename, err) + } + return out, err } + + // if the output file already exists, we'll use a SameSkipper + f, err := fs.OpenFile(filename, os.O_RDONLY, mode.Perm()) + if err != nil { + // likely means the file just doesn't exist - open's error will be more useful + return open() + } + out = writers.SameSkipper(f, func() (io.WriteCloser, error) { + return open() + }) + return out, err }