Source file src/archive/zip/writer_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zip
     6  
     7  import (
     8  	"bytes"
     9  	"compress/flate"
    10  	"encoding/binary"
    11  	"fmt"
    12  	"hash/crc32"
    13  	"io"
    14  	"io/fs"
    15  	"math/rand"
    16  	"os"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  // TODO(adg): a more sophisticated test suite
    23  
// WriteTest describes one archive entry used by the round-trip writer tests:
// testCreate writes it into an archive and testReadFile checks it afterwards.
type WriteTest struct {
	Name   string      // entry name stored in the FileHeader
	Data   []byte      // uncompressed contents written to the entry
	Method uint16      // compression method passed to the FileHeader (Store or Deflate in writeTests)
	Mode   fs.FileMode // applied via FileHeader.SetMode only when non-zero
}
    30  
// writeTests is the shared table of entries written and read back by the
// writer round-trip tests. It covers both Store and Deflate, and a range of
// file modes (setuid, setgid, symlink, device, char device) to check that
// mode bits survive a round trip.
var writeTests = []WriteTest{
	{
		Name:   "foo",
		Data:   []byte("Rabbits, guinea pigs, gophers, marsupial rats, and quolls."),
		Method: Store,
		Mode:   0666,
	},
	{
		// Data is nil here; tests that need a large deflated payload
		// supply it themselves.
		Name:   "bar",
		Data:   nil, // large data set in the test
		Method: Deflate,
		Mode:   0644,
	},
	{
		Name:   "setuid",
		Data:   []byte("setuid file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSetuid,
	},
	{
		Name:   "setgid",
		Data:   []byte("setgid file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSetgid,
	},
	{
		Name:   "symlink",
		Data:   []byte("../link/target"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSymlink,
	},
	{
		Name:   "device",
		Data:   []byte("device file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeDevice,
	},
	{
		Name:   "chardevice",
		Data:   []byte("char device file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeDevice | fs.ModeCharDevice,
	},
}
    75  
    76  func TestWriter(t *testing.T) {
    77  	largeData := make([]byte, 1<<17)
    78  	if _, err := rand.Read(largeData); err != nil {
    79  		t.Fatal("rand.Read failed:", err)
    80  	}
    81  	writeTests[1].Data = largeData
    82  	defer func() {
    83  		writeTests[1].Data = nil
    84  	}()
    85  
    86  	// write a zip file
    87  	buf := new(bytes.Buffer)
    88  	w := NewWriter(buf)
    89  
    90  	for _, wt := range writeTests {
    91  		testCreate(t, w, &wt)
    92  	}
    93  
    94  	if err := w.Close(); err != nil {
    95  		t.Fatal(err)
    96  	}
    97  
    98  	// read it back
    99  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  	for i, wt := range writeTests {
   104  		testReadFile(t, r.File[i], &wt)
   105  	}
   106  }
   107  
   108  // TestWriterComment is test for EOCD comment read/write.
   109  func TestWriterComment(t *testing.T) {
   110  	var tests = []struct {
   111  		comment string
   112  		ok      bool
   113  	}{
   114  		{"hi, hello", true},
   115  		{"hi, こんにちわ", true},
   116  		{strings.Repeat("a", uint16max), true},
   117  		{strings.Repeat("a", uint16max+1), false},
   118  	}
   119  
   120  	for _, test := range tests {
   121  		// write a zip file
   122  		buf := new(bytes.Buffer)
   123  		w := NewWriter(buf)
   124  		if err := w.SetComment(test.comment); err != nil {
   125  			if test.ok {
   126  				t.Fatalf("SetComment: unexpected error %v", err)
   127  			}
   128  			continue
   129  		} else {
   130  			if !test.ok {
   131  				t.Fatalf("SetComment: unexpected success, want error")
   132  			}
   133  		}
   134  
   135  		if err := w.Close(); test.ok == (err != nil) {
   136  			t.Fatal(err)
   137  		}
   138  
   139  		if w.closed != test.ok {
   140  			t.Fatalf("Writer.closed: got %v, want %v", w.closed, test.ok)
   141  		}
   142  
   143  		// skip read test in failure cases
   144  		if !test.ok {
   145  			continue
   146  		}
   147  
   148  		// read it back
   149  		r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   150  		if err != nil {
   151  			t.Fatal(err)
   152  		}
   153  		if r.Comment != test.comment {
   154  			t.Fatalf("Reader.Comment: got %v, want %v", r.Comment, test.comment)
   155  		}
   156  	}
   157  }
   158  
   159  func TestWriterUTF8(t *testing.T) {
   160  	var utf8Tests = []struct {
   161  		name    string
   162  		comment string
   163  		nonUTF8 bool
   164  		flags   uint16
   165  	}{
   166  		{
   167  			name:    "hi, hello",
   168  			comment: "in the world",
   169  			flags:   0x8,
   170  		},
   171  		{
   172  			name:    "hi, こんにちわ",
   173  			comment: "in the world",
   174  			flags:   0x808,
   175  		},
   176  		{
   177  			name:    "hi, こんにちわ",
   178  			comment: "in the world",
   179  			nonUTF8: true,
   180  			flags:   0x8,
   181  		},
   182  		{
   183  			name:    "hi, hello",
   184  			comment: "in the 世界",
   185  			flags:   0x808,
   186  		},
   187  		{
   188  			name:    "hi, こんにちわ",
   189  			comment: "in the 世界",
   190  			flags:   0x808,
   191  		},
   192  		{
   193  			name:    "the replacement rune is �",
   194  			comment: "the replacement rune is �",
   195  			flags:   0x808,
   196  		},
   197  		{
   198  			// Name is Japanese encoded in Shift JIS.
   199  			name:    "\x93\xfa\x96{\x8c\xea.txt",
   200  			comment: "in the 世界",
   201  			flags:   0x008, // UTF-8 must not be set
   202  		},
   203  	}
   204  
   205  	// write a zip file
   206  	buf := new(bytes.Buffer)
   207  	w := NewWriter(buf)
   208  
   209  	for _, test := range utf8Tests {
   210  		h := &FileHeader{
   211  			Name:    test.name,
   212  			Comment: test.comment,
   213  			NonUTF8: test.nonUTF8,
   214  			Method:  Deflate,
   215  		}
   216  		w, err := w.CreateHeader(h)
   217  		if err != nil {
   218  			t.Fatal(err)
   219  		}
   220  		w.Write([]byte{})
   221  	}
   222  
   223  	if err := w.Close(); err != nil {
   224  		t.Fatal(err)
   225  	}
   226  
   227  	// read it back
   228  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   229  	if err != nil {
   230  		t.Fatal(err)
   231  	}
   232  	for i, test := range utf8Tests {
   233  		flags := r.File[i].Flags
   234  		if flags != test.flags {
   235  			t.Errorf("CreateHeader(name=%q comment=%q nonUTF8=%v): flags=%#x, want %#x", test.name, test.comment, test.nonUTF8, flags, test.flags)
   236  		}
   237  	}
   238  }
   239  
   240  func TestWriterTime(t *testing.T) {
   241  	var buf bytes.Buffer
   242  	h := &FileHeader{
   243  		Name:     "test.txt",
   244  		Modified: time.Date(2017, 10, 31, 21, 11, 57, 0, timeZone(-7*time.Hour)),
   245  	}
   246  	w := NewWriter(&buf)
   247  	if _, err := w.CreateHeader(h); err != nil {
   248  		t.Fatalf("unexpected CreateHeader error: %v", err)
   249  	}
   250  	if err := w.Close(); err != nil {
   251  		t.Fatalf("unexpected Close error: %v", err)
   252  	}
   253  
   254  	want, err := os.ReadFile("testdata/time-go.zip")
   255  	if err != nil {
   256  		t.Fatalf("unexpected ReadFile error: %v", err)
   257  	}
   258  	if got := buf.Bytes(); !bytes.Equal(got, want) {
   259  		fmt.Printf("%x\n%x\n", got, want)
   260  		t.Error("contents of time-go.zip differ")
   261  	}
   262  }
   263  
   264  func TestWriterOffset(t *testing.T) {
   265  	largeData := make([]byte, 1<<17)
   266  	if _, err := rand.Read(largeData); err != nil {
   267  		t.Fatal("rand.Read failed:", err)
   268  	}
   269  	writeTests[1].Data = largeData
   270  	defer func() {
   271  		writeTests[1].Data = nil
   272  	}()
   273  
   274  	// write a zip file
   275  	buf := new(bytes.Buffer)
   276  	existingData := []byte{1, 2, 3, 1, 2, 3, 1, 2, 3}
   277  	n, _ := buf.Write(existingData)
   278  	w := NewWriter(buf)
   279  	w.SetOffset(int64(n))
   280  
   281  	for _, wt := range writeTests {
   282  		testCreate(t, w, &wt)
   283  	}
   284  
   285  	if err := w.Close(); err != nil {
   286  		t.Fatal(err)
   287  	}
   288  
   289  	// read it back
   290  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  	for i, wt := range writeTests {
   295  		testReadFile(t, r.File[i], &wt)
   296  	}
   297  }
   298  
   299  func TestWriterFlush(t *testing.T) {
   300  	var buf bytes.Buffer
   301  	w := NewWriter(struct{ io.Writer }{&buf})
   302  	_, err := w.Create("foo")
   303  	if err != nil {
   304  		t.Fatal(err)
   305  	}
   306  	if buf.Len() > 0 {
   307  		t.Fatalf("Unexpected %d bytes already in buffer", buf.Len())
   308  	}
   309  	if err := w.Flush(); err != nil {
   310  		t.Fatal(err)
   311  	}
   312  	if buf.Len() == 0 {
   313  		t.Fatal("No bytes written after Flush")
   314  	}
   315  }
   316  
   317  func TestWriterDir(t *testing.T) {
   318  	w := NewWriter(io.Discard)
   319  	dw, err := w.Create("dir/")
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  	if _, err := dw.Write(nil); err != nil {
   324  		t.Errorf("Write(nil) to directory: got %v, want nil", err)
   325  	}
   326  	if _, err := dw.Write([]byte("hello")); err == nil {
   327  		t.Error(`Write("hello") to directory: got nil error, want non-nil`)
   328  	}
   329  }
   330  
   331  func TestWriterDirAttributes(t *testing.T) {
   332  	var buf bytes.Buffer
   333  	w := NewWriter(&buf)
   334  	if _, err := w.CreateHeader(&FileHeader{
   335  		Name:               "dir/",
   336  		Method:             Deflate,
   337  		CompressedSize64:   1234,
   338  		UncompressedSize64: 5678,
   339  	}); err != nil {
   340  		t.Fatal(err)
   341  	}
   342  	if err := w.Close(); err != nil {
   343  		t.Fatal(err)
   344  	}
   345  	b := buf.Bytes()
   346  
   347  	var sig [4]byte
   348  	binary.LittleEndian.PutUint32(sig[:], uint32(fileHeaderSignature))
   349  
   350  	idx := bytes.Index(b, sig[:])
   351  	if idx == -1 {
   352  		t.Fatal("file header not found")
   353  	}
   354  	b = b[idx:]
   355  
   356  	if !bytes.Equal(b[6:10], []byte{0, 0, 0, 0}) { // FileHeader.Flags: 0, FileHeader.Method: 0
   357  		t.Errorf("unexpected method and flags: %v", b[6:10])
   358  	}
   359  
   360  	if !bytes.Equal(b[14:26], make([]byte, 12)) { // FileHeader.{CRC32,CompressSize,UncompressedSize} all zero.
   361  		t.Errorf("unexpected crc, compress and uncompressed size to be 0 was: %v", b[14:26])
   362  	}
   363  
   364  	binary.LittleEndian.PutUint32(sig[:], uint32(dataDescriptorSignature))
   365  	if bytes.Contains(b, sig[:]) {
   366  		t.Error("there should be no data descriptor")
   367  	}
   368  }
   369  
   370  func TestWriterCopy(t *testing.T) {
   371  	// make a zip file
   372  	buf := new(bytes.Buffer)
   373  	w := NewWriter(buf)
   374  	for _, wt := range writeTests {
   375  		testCreate(t, w, &wt)
   376  	}
   377  	if err := w.Close(); err != nil {
   378  		t.Fatal(err)
   379  	}
   380  
   381  	// read it back
   382  	src, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   383  	if err != nil {
   384  		t.Fatal(err)
   385  	}
   386  	for i, wt := range writeTests {
   387  		testReadFile(t, src.File[i], &wt)
   388  	}
   389  
   390  	// make a new zip file copying the old compressed data.
   391  	buf2 := new(bytes.Buffer)
   392  	dst := NewWriter(buf2)
   393  	for _, f := range src.File {
   394  		if err := dst.Copy(f); err != nil {
   395  			t.Fatal(err)
   396  		}
   397  	}
   398  	if err := dst.Close(); err != nil {
   399  		t.Fatal(err)
   400  	}
   401  
   402  	// read the new one back
   403  	r, err := NewReader(bytes.NewReader(buf2.Bytes()), int64(buf2.Len()))
   404  	if err != nil {
   405  		t.Fatal(err)
   406  	}
   407  	for i, wt := range writeTests {
   408  		testReadFile(t, r.File[i], &wt)
   409  	}
   410  }
   411  
   412  func TestWriterCreateRaw(t *testing.T) {
   413  	files := []struct {
   414  		name             string
   415  		content          []byte
   416  		method           uint16
   417  		flags            uint16
   418  		crc32            uint32
   419  		uncompressedSize uint64
   420  		compressedSize   uint64
   421  	}{
   422  		{
   423  			name:    "small store w desc",
   424  			content: []byte("gophers"),
   425  			method:  Store,
   426  			flags:   0x8,
   427  		},
   428  		{
   429  			name:    "small deflate wo desc",
   430  			content: bytes.Repeat([]byte("abcdefg"), 2048),
   431  			method:  Deflate,
   432  		},
   433  	}
   434  
   435  	// write a zip file
   436  	archive := new(bytes.Buffer)
   437  	w := NewWriter(archive)
   438  
   439  	for i := range files {
   440  		f := &files[i]
   441  		f.crc32 = crc32.ChecksumIEEE(f.content)
   442  		size := uint64(len(f.content))
   443  		f.uncompressedSize = size
   444  		f.compressedSize = size
   445  
   446  		var compressedContent []byte
   447  		if f.method == Deflate {
   448  			var buf bytes.Buffer
   449  			w, err := flate.NewWriter(&buf, flate.BestSpeed)
   450  			if err != nil {
   451  				t.Fatalf("flate.NewWriter err = %v", err)
   452  			}
   453  			_, err = w.Write(f.content)
   454  			if err != nil {
   455  				t.Fatalf("flate Write err = %v", err)
   456  			}
   457  			err = w.Close()
   458  			if err != nil {
   459  				t.Fatalf("flate Writer.Close err = %v", err)
   460  			}
   461  			compressedContent = buf.Bytes()
   462  			f.compressedSize = uint64(len(compressedContent))
   463  		}
   464  
   465  		h := &FileHeader{
   466  			Name:               f.name,
   467  			Method:             f.method,
   468  			Flags:              f.flags,
   469  			CRC32:              f.crc32,
   470  			CompressedSize64:   f.compressedSize,
   471  			UncompressedSize64: f.uncompressedSize,
   472  		}
   473  		w, err := w.CreateRaw(h)
   474  		if err != nil {
   475  			t.Fatal(err)
   476  		}
   477  		if compressedContent != nil {
   478  			_, err = w.Write(compressedContent)
   479  		} else {
   480  			_, err = w.Write(f.content)
   481  		}
   482  		if err != nil {
   483  			t.Fatalf("%s Write got %v; want nil", f.name, err)
   484  		}
   485  	}
   486  
   487  	if err := w.Close(); err != nil {
   488  		t.Fatal(err)
   489  	}
   490  
   491  	// read it back
   492  	r, err := NewReader(bytes.NewReader(archive.Bytes()), int64(archive.Len()))
   493  	if err != nil {
   494  		t.Fatal(err)
   495  	}
   496  	for i, want := range files {
   497  		got := r.File[i]
   498  		if got.Name != want.name {
   499  			t.Errorf("got Name %s; want %s", got.Name, want.name)
   500  		}
   501  		if got.Method != want.method {
   502  			t.Errorf("%s: got Method %#x; want %#x", want.name, got.Method, want.method)
   503  		}
   504  		if got.Flags != want.flags {
   505  			t.Errorf("%s: got Flags %#x; want %#x", want.name, got.Flags, want.flags)
   506  		}
   507  		if got.CRC32 != want.crc32 {
   508  			t.Errorf("%s: got CRC32 %#x; want %#x", want.name, got.CRC32, want.crc32)
   509  		}
   510  		if got.CompressedSize64 != want.compressedSize {
   511  			t.Errorf("%s: got CompressedSize64 %d; want %d", want.name, got.CompressedSize64, want.compressedSize)
   512  		}
   513  		if got.UncompressedSize64 != want.uncompressedSize {
   514  			t.Errorf("%s: got UncompressedSize64 %d; want %d", want.name, got.UncompressedSize64, want.uncompressedSize)
   515  		}
   516  
   517  		r, err := got.Open()
   518  		if err != nil {
   519  			t.Errorf("%s: Open err = %v", got.Name, err)
   520  			continue
   521  		}
   522  
   523  		buf, err := io.ReadAll(r)
   524  		if err != nil {
   525  			t.Errorf("%s: ReadAll err = %v", got.Name, err)
   526  			continue
   527  		}
   528  
   529  		if !bytes.Equal(buf, want.content) {
   530  			t.Errorf("%v: ReadAll returned unexpected bytes", got.Name)
   531  		}
   532  	}
   533  }
   534  
   535  func testCreate(t *testing.T, w *Writer, wt *WriteTest) {
   536  	header := &FileHeader{
   537  		Name:   wt.Name,
   538  		Method: wt.Method,
   539  	}
   540  	if wt.Mode != 0 {
   541  		header.SetMode(wt.Mode)
   542  	}
   543  	f, err := w.CreateHeader(header)
   544  	if err != nil {
   545  		t.Fatal(err)
   546  	}
   547  	_, err = f.Write(wt.Data)
   548  	if err != nil {
   549  		t.Fatal(err)
   550  	}
   551  }
   552  
   553  func testReadFile(t *testing.T, f *File, wt *WriteTest) {
   554  	if f.Name != wt.Name {
   555  		t.Fatalf("File name: got %q, want %q", f.Name, wt.Name)
   556  	}
   557  	testFileMode(t, f, wt.Mode)
   558  	rc, err := f.Open()
   559  	if err != nil {
   560  		t.Fatalf("opening %s: %v", f.Name, err)
   561  	}
   562  	b, err := io.ReadAll(rc)
   563  	if err != nil {
   564  		t.Fatalf("reading %s: %v", f.Name, err)
   565  	}
   566  	err = rc.Close()
   567  	if err != nil {
   568  		t.Fatalf("closing %s: %v", f.Name, err)
   569  	}
   570  	if !bytes.Equal(b, wt.Data) {
   571  		t.Errorf("File contents %q, want %q", b, wt.Data)
   572  	}
   573  }
   574  
   575  func BenchmarkCompressedZipGarbage(b *testing.B) {
   576  	bigBuf := bytes.Repeat([]byte("a"), 1<<20)
   577  
   578  	runOnce := func(buf *bytes.Buffer) {
   579  		buf.Reset()
   580  		zw := NewWriter(buf)
   581  		for j := 0; j < 3; j++ {
   582  			w, _ := zw.CreateHeader(&FileHeader{
   583  				Name:   "foo",
   584  				Method: Deflate,
   585  			})
   586  			w.Write(bigBuf)
   587  		}
   588  		zw.Close()
   589  	}
   590  
   591  	b.ReportAllocs()
   592  	// Run once and then reset the timer.
   593  	// This effectively discards the very large initial flate setup cost,
   594  	// as well as the initialization of bigBuf.
   595  	runOnce(&bytes.Buffer{})
   596  	b.ResetTimer()
   597  
   598  	b.RunParallel(func(pb *testing.PB) {
   599  		var buf bytes.Buffer
   600  		for pb.Next() {
   601  			runOnce(&buf)
   602  		}
   603  	})
   604  }
   605  

View as plain text