Source file src/archive/zip/writer_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zip
     6  
     7  import (
     8  	"bytes"
     9  	"compress/flate"
    10  	"encoding/binary"
    11  	"fmt"
    12  	"hash/crc32"
    13  	"io"
    14  	"io/fs"
    15  	"math/rand"
    16  	"os"
    17  	"strings"
    18  	"testing"
    19  	"testing/fstest"
    20  	"time"
    21  )
    22  
    23  // TODO(adg): a more sophisticated test suite
    24  
// WriteTest describes one archive entry used by the round-trip tests.
// Entries are written with testCreate and verified with testReadFile;
// writeTestsToFS also turns a slice of them into an fstest.MapFS.
type WriteTest struct {
	Name   string      // entry name stored in the archive
	Data   []byte      // contents written to the entry and expected back when it is read
	Method uint16      // compression method copied into FileHeader.Method (e.g. Store, Deflate)
	Mode   fs.FileMode // applied via FileHeader.SetMode by testCreate when non-zero
}
    31  
// writeTests is the shared fixture used by TestWriter, TestWriterOffset and
// TestWriterCopy. It mixes Store and Deflate entries with a range of file
// modes: plain permissions, setuid/setgid bits, a symlink, and device files
// both without and with fs.ModeCharDevice.
var writeTests = []WriteTest{
	{
		Name:   "foo",
		Data:   []byte("Rabbits, guinea pigs, gophers, marsupial rats, and quolls."),
		Method: Store,
		Mode:   0666,
	},
	{
		Name:   "bar",
		Data:   nil, // TestWriter and TestWriterOffset temporarily set this to 1<<17 random bytes
		Method: Deflate,
		Mode:   0644,
	},
	{
		Name:   "setuid",
		Data:   []byte("setuid file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSetuid,
	},
	{
		Name:   "setgid",
		Data:   []byte("setgid file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSetgid,
	},
	{
		Name:   "symlink",
		Data:   []byte("../link/target"),
		Method: Deflate,
		Mode:   0755 | fs.ModeSymlink,
	},
	{
		Name:   "device",
		Data:   []byte("device file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeDevice,
	},
	{
		Name:   "chardevice",
		Data:   []byte("char device file"),
		Method: Deflate,
		Mode:   0755 | fs.ModeDevice | fs.ModeCharDevice,
	},
}
    76  
    77  func TestWriter(t *testing.T) {
    78  	largeData := make([]byte, 1<<17)
    79  	if _, err := rand.Read(largeData); err != nil {
    80  		t.Fatal("rand.Read failed:", err)
    81  	}
    82  	writeTests[1].Data = largeData
    83  	defer func() {
    84  		writeTests[1].Data = nil
    85  	}()
    86  
    87  	// write a zip file
    88  	buf := new(bytes.Buffer)
    89  	w := NewWriter(buf)
    90  
    91  	for _, wt := range writeTests {
    92  		testCreate(t, w, &wt)
    93  	}
    94  
    95  	if err := w.Close(); err != nil {
    96  		t.Fatal(err)
    97  	}
    98  
    99  	// read it back
   100  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	for i, wt := range writeTests {
   105  		testReadFile(t, r.File[i], &wt)
   106  	}
   107  }
   108  
   109  // TestWriterComment is test for EOCD comment read/write.
   110  func TestWriterComment(t *testing.T) {
   111  	var tests = []struct {
   112  		comment string
   113  		ok      bool
   114  	}{
   115  		{"hi, hello", true},
   116  		{"hi, こんにちわ", true},
   117  		{strings.Repeat("a", uint16max), true},
   118  		{strings.Repeat("a", uint16max+1), false},
   119  	}
   120  
   121  	for _, test := range tests {
   122  		// write a zip file
   123  		buf := new(bytes.Buffer)
   124  		w := NewWriter(buf)
   125  		if err := w.SetComment(test.comment); err != nil {
   126  			if test.ok {
   127  				t.Fatalf("SetComment: unexpected error %v", err)
   128  			}
   129  			continue
   130  		} else {
   131  			if !test.ok {
   132  				t.Fatalf("SetComment: unexpected success, want error")
   133  			}
   134  		}
   135  
   136  		if err := w.Close(); test.ok == (err != nil) {
   137  			t.Fatal(err)
   138  		}
   139  
   140  		if w.closed != test.ok {
   141  			t.Fatalf("Writer.closed: got %v, want %v", w.closed, test.ok)
   142  		}
   143  
   144  		// skip read test in failure cases
   145  		if !test.ok {
   146  			continue
   147  		}
   148  
   149  		// read it back
   150  		r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   151  		if err != nil {
   152  			t.Fatal(err)
   153  		}
   154  		if r.Comment != test.comment {
   155  			t.Fatalf("Reader.Comment: got %v, want %v", r.Comment, test.comment)
   156  		}
   157  	}
   158  }
   159  
   160  func TestWriterUTF8(t *testing.T) {
   161  	var utf8Tests = []struct {
   162  		name    string
   163  		comment string
   164  		nonUTF8 bool
   165  		flags   uint16
   166  	}{
   167  		{
   168  			name:    "hi, hello",
   169  			comment: "in the world",
   170  			flags:   0x8,
   171  		},
   172  		{
   173  			name:    "hi, こんにちわ",
   174  			comment: "in the world",
   175  			flags:   0x808,
   176  		},
   177  		{
   178  			name:    "hi, こんにちわ",
   179  			comment: "in the world",
   180  			nonUTF8: true,
   181  			flags:   0x8,
   182  		},
   183  		{
   184  			name:    "hi, hello",
   185  			comment: "in the 世界",
   186  			flags:   0x808,
   187  		},
   188  		{
   189  			name:    "hi, こんにちわ",
   190  			comment: "in the 世界",
   191  			flags:   0x808,
   192  		},
   193  		{
   194  			name:    "the replacement rune is �",
   195  			comment: "the replacement rune is �",
   196  			flags:   0x808,
   197  		},
   198  		{
   199  			// Name is Japanese encoded in Shift JIS.
   200  			name:    "\x93\xfa\x96{\x8c\xea.txt",
   201  			comment: "in the 世界",
   202  			flags:   0x008, // UTF-8 must not be set
   203  		},
   204  	}
   205  
   206  	// write a zip file
   207  	buf := new(bytes.Buffer)
   208  	w := NewWriter(buf)
   209  
   210  	for _, test := range utf8Tests {
   211  		h := &FileHeader{
   212  			Name:    test.name,
   213  			Comment: test.comment,
   214  			NonUTF8: test.nonUTF8,
   215  			Method:  Deflate,
   216  		}
   217  		w, err := w.CreateHeader(h)
   218  		if err != nil {
   219  			t.Fatal(err)
   220  		}
   221  		w.Write([]byte{})
   222  	}
   223  
   224  	if err := w.Close(); err != nil {
   225  		t.Fatal(err)
   226  	}
   227  
   228  	// read it back
   229  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   230  	if err != nil {
   231  		t.Fatal(err)
   232  	}
   233  	for i, test := range utf8Tests {
   234  		flags := r.File[i].Flags
   235  		if flags != test.flags {
   236  			t.Errorf("CreateHeader(name=%q comment=%q nonUTF8=%v): flags=%#x, want %#x", test.name, test.comment, test.nonUTF8, flags, test.flags)
   237  		}
   238  	}
   239  }
   240  
   241  func TestWriterTime(t *testing.T) {
   242  	var buf bytes.Buffer
   243  	h := &FileHeader{
   244  		Name:     "test.txt",
   245  		Modified: time.Date(2017, 10, 31, 21, 11, 57, 0, timeZone(-7*time.Hour)),
   246  	}
   247  	w := NewWriter(&buf)
   248  	if _, err := w.CreateHeader(h); err != nil {
   249  		t.Fatalf("unexpected CreateHeader error: %v", err)
   250  	}
   251  	if err := w.Close(); err != nil {
   252  		t.Fatalf("unexpected Close error: %v", err)
   253  	}
   254  
   255  	want, err := os.ReadFile("testdata/time-go.zip")
   256  	if err != nil {
   257  		t.Fatalf("unexpected ReadFile error: %v", err)
   258  	}
   259  	if got := buf.Bytes(); !bytes.Equal(got, want) {
   260  		fmt.Printf("%x\n%x\n", got, want)
   261  		t.Error("contents of time-go.zip differ")
   262  	}
   263  }
   264  
   265  func TestWriterOffset(t *testing.T) {
   266  	largeData := make([]byte, 1<<17)
   267  	if _, err := rand.Read(largeData); err != nil {
   268  		t.Fatal("rand.Read failed:", err)
   269  	}
   270  	writeTests[1].Data = largeData
   271  	defer func() {
   272  		writeTests[1].Data = nil
   273  	}()
   274  
   275  	// write a zip file
   276  	buf := new(bytes.Buffer)
   277  	existingData := []byte{1, 2, 3, 1, 2, 3, 1, 2, 3}
   278  	n, _ := buf.Write(existingData)
   279  	w := NewWriter(buf)
   280  	w.SetOffset(int64(n))
   281  
   282  	for _, wt := range writeTests {
   283  		testCreate(t, w, &wt)
   284  	}
   285  
   286  	if err := w.Close(); err != nil {
   287  		t.Fatal(err)
   288  	}
   289  
   290  	// read it back
   291  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   292  	if err != nil {
   293  		t.Fatal(err)
   294  	}
   295  	for i, wt := range writeTests {
   296  		testReadFile(t, r.File[i], &wt)
   297  	}
   298  }
   299  
   300  func TestWriterFlush(t *testing.T) {
   301  	var buf bytes.Buffer
   302  	w := NewWriter(struct{ io.Writer }{&buf})
   303  	_, err := w.Create("foo")
   304  	if err != nil {
   305  		t.Fatal(err)
   306  	}
   307  	if buf.Len() > 0 {
   308  		t.Fatalf("Unexpected %d bytes already in buffer", buf.Len())
   309  	}
   310  	if err := w.Flush(); err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	if buf.Len() == 0 {
   314  		t.Fatal("No bytes written after Flush")
   315  	}
   316  }
   317  
   318  func TestWriterDir(t *testing.T) {
   319  	w := NewWriter(io.Discard)
   320  	dw, err := w.Create("dir/")
   321  	if err != nil {
   322  		t.Fatal(err)
   323  	}
   324  	if _, err := dw.Write(nil); err != nil {
   325  		t.Errorf("Write(nil) to directory: got %v, want nil", err)
   326  	}
   327  	if _, err := dw.Write([]byte("hello")); err == nil {
   328  		t.Error(`Write("hello") to directory: got nil error, want non-nil`)
   329  	}
   330  }
   331  
   332  func TestWriterDirAttributes(t *testing.T) {
   333  	var buf bytes.Buffer
   334  	w := NewWriter(&buf)
   335  	if _, err := w.CreateHeader(&FileHeader{
   336  		Name:               "dir/",
   337  		Method:             Deflate,
   338  		CompressedSize64:   1234,
   339  		UncompressedSize64: 5678,
   340  	}); err != nil {
   341  		t.Fatal(err)
   342  	}
   343  	if err := w.Close(); err != nil {
   344  		t.Fatal(err)
   345  	}
   346  	b := buf.Bytes()
   347  
   348  	var sig [4]byte
   349  	binary.LittleEndian.PutUint32(sig[:], uint32(fileHeaderSignature))
   350  
   351  	idx := bytes.Index(b, sig[:])
   352  	if idx == -1 {
   353  		t.Fatal("file header not found")
   354  	}
   355  	b = b[idx:]
   356  
   357  	if !bytes.Equal(b[6:10], []byte{0, 0, 0, 0}) { // FileHeader.Flags: 0, FileHeader.Method: 0
   358  		t.Errorf("unexpected method and flags: %v", b[6:10])
   359  	}
   360  
   361  	if !bytes.Equal(b[14:26], make([]byte, 12)) { // FileHeader.{CRC32,CompressSize,UncompressedSize} all zero.
   362  		t.Errorf("unexpected crc, compress and uncompressed size to be 0 was: %v", b[14:26])
   363  	}
   364  
   365  	binary.LittleEndian.PutUint32(sig[:], uint32(dataDescriptorSignature))
   366  	if bytes.Contains(b, sig[:]) {
   367  		t.Error("there should be no data descriptor")
   368  	}
   369  }
   370  
   371  func TestWriterCopy(t *testing.T) {
   372  	// make a zip file
   373  	buf := new(bytes.Buffer)
   374  	w := NewWriter(buf)
   375  	for _, wt := range writeTests {
   376  		testCreate(t, w, &wt)
   377  	}
   378  	if err := w.Close(); err != nil {
   379  		t.Fatal(err)
   380  	}
   381  
   382  	// read it back
   383  	src, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	for i, wt := range writeTests {
   388  		testReadFile(t, src.File[i], &wt)
   389  	}
   390  
   391  	// make a new zip file copying the old compressed data.
   392  	buf2 := new(bytes.Buffer)
   393  	dst := NewWriter(buf2)
   394  	for _, f := range src.File {
   395  		if err := dst.Copy(f); err != nil {
   396  			t.Fatal(err)
   397  		}
   398  	}
   399  	if err := dst.Close(); err != nil {
   400  		t.Fatal(err)
   401  	}
   402  
   403  	// read the new one back
   404  	r, err := NewReader(bytes.NewReader(buf2.Bytes()), int64(buf2.Len()))
   405  	if err != nil {
   406  		t.Fatal(err)
   407  	}
   408  	for i, wt := range writeTests {
   409  		testReadFile(t, r.File[i], &wt)
   410  	}
   411  }
   412  
   413  func TestWriterCreateRaw(t *testing.T) {
   414  	files := []struct {
   415  		name             string
   416  		content          []byte
   417  		method           uint16
   418  		flags            uint16
   419  		crc32            uint32
   420  		uncompressedSize uint64
   421  		compressedSize   uint64
   422  	}{
   423  		{
   424  			name:    "small store w desc",
   425  			content: []byte("gophers"),
   426  			method:  Store,
   427  			flags:   0x8,
   428  		},
   429  		{
   430  			name:    "small deflate wo desc",
   431  			content: bytes.Repeat([]byte("abcdefg"), 2048),
   432  			method:  Deflate,
   433  		},
   434  	}
   435  
   436  	// write a zip file
   437  	archive := new(bytes.Buffer)
   438  	w := NewWriter(archive)
   439  
   440  	for i := range files {
   441  		f := &files[i]
   442  		f.crc32 = crc32.ChecksumIEEE(f.content)
   443  		size := uint64(len(f.content))
   444  		f.uncompressedSize = size
   445  		f.compressedSize = size
   446  
   447  		var compressedContent []byte
   448  		if f.method == Deflate {
   449  			var buf bytes.Buffer
   450  			w, err := flate.NewWriter(&buf, flate.BestSpeed)
   451  			if err != nil {
   452  				t.Fatalf("flate.NewWriter err = %v", err)
   453  			}
   454  			_, err = w.Write(f.content)
   455  			if err != nil {
   456  				t.Fatalf("flate Write err = %v", err)
   457  			}
   458  			err = w.Close()
   459  			if err != nil {
   460  				t.Fatalf("flate Writer.Close err = %v", err)
   461  			}
   462  			compressedContent = buf.Bytes()
   463  			f.compressedSize = uint64(len(compressedContent))
   464  		}
   465  
   466  		h := &FileHeader{
   467  			Name:               f.name,
   468  			Method:             f.method,
   469  			Flags:              f.flags,
   470  			CRC32:              f.crc32,
   471  			CompressedSize64:   f.compressedSize,
   472  			UncompressedSize64: f.uncompressedSize,
   473  		}
   474  		w, err := w.CreateRaw(h)
   475  		if err != nil {
   476  			t.Fatal(err)
   477  		}
   478  		if compressedContent != nil {
   479  			_, err = w.Write(compressedContent)
   480  		} else {
   481  			_, err = w.Write(f.content)
   482  		}
   483  		if err != nil {
   484  			t.Fatalf("%s Write got %v; want nil", f.name, err)
   485  		}
   486  	}
   487  
   488  	if err := w.Close(); err != nil {
   489  		t.Fatal(err)
   490  	}
   491  
   492  	// read it back
   493  	r, err := NewReader(bytes.NewReader(archive.Bytes()), int64(archive.Len()))
   494  	if err != nil {
   495  		t.Fatal(err)
   496  	}
   497  	for i, want := range files {
   498  		got := r.File[i]
   499  		if got.Name != want.name {
   500  			t.Errorf("got Name %s; want %s", got.Name, want.name)
   501  		}
   502  		if got.Method != want.method {
   503  			t.Errorf("%s: got Method %#x; want %#x", want.name, got.Method, want.method)
   504  		}
   505  		if got.Flags != want.flags {
   506  			t.Errorf("%s: got Flags %#x; want %#x", want.name, got.Flags, want.flags)
   507  		}
   508  		if got.CRC32 != want.crc32 {
   509  			t.Errorf("%s: got CRC32 %#x; want %#x", want.name, got.CRC32, want.crc32)
   510  		}
   511  		if got.CompressedSize64 != want.compressedSize {
   512  			t.Errorf("%s: got CompressedSize64 %d; want %d", want.name, got.CompressedSize64, want.compressedSize)
   513  		}
   514  		if got.UncompressedSize64 != want.uncompressedSize {
   515  			t.Errorf("%s: got UncompressedSize64 %d; want %d", want.name, got.UncompressedSize64, want.uncompressedSize)
   516  		}
   517  
   518  		r, err := got.Open()
   519  		if err != nil {
   520  			t.Errorf("%s: Open err = %v", got.Name, err)
   521  			continue
   522  		}
   523  
   524  		buf, err := io.ReadAll(r)
   525  		if err != nil {
   526  			t.Errorf("%s: ReadAll err = %v", got.Name, err)
   527  			continue
   528  		}
   529  
   530  		if !bytes.Equal(buf, want.content) {
   531  			t.Errorf("%v: ReadAll returned unexpected bytes", got.Name)
   532  		}
   533  	}
   534  }
   535  
   536  func testCreate(t *testing.T, w *Writer, wt *WriteTest) {
   537  	header := &FileHeader{
   538  		Name:   wt.Name,
   539  		Method: wt.Method,
   540  	}
   541  	if wt.Mode != 0 {
   542  		header.SetMode(wt.Mode)
   543  	}
   544  	f, err := w.CreateHeader(header)
   545  	if err != nil {
   546  		t.Fatal(err)
   547  	}
   548  	_, err = f.Write(wt.Data)
   549  	if err != nil {
   550  		t.Fatal(err)
   551  	}
   552  }
   553  
   554  func testReadFile(t *testing.T, f *File, wt *WriteTest) {
   555  	if f.Name != wt.Name {
   556  		t.Fatalf("File name: got %q, want %q", f.Name, wt.Name)
   557  	}
   558  	testFileMode(t, f, wt.Mode)
   559  	rc, err := f.Open()
   560  	if err != nil {
   561  		t.Fatalf("opening %s: %v", f.Name, err)
   562  	}
   563  	b, err := io.ReadAll(rc)
   564  	if err != nil {
   565  		t.Fatalf("reading %s: %v", f.Name, err)
   566  	}
   567  	err = rc.Close()
   568  	if err != nil {
   569  		t.Fatalf("closing %s: %v", f.Name, err)
   570  	}
   571  	if !bytes.Equal(b, wt.Data) {
   572  		t.Errorf("File contents %q, want %q", b, wt.Data)
   573  	}
   574  }
   575  
   576  func BenchmarkCompressedZipGarbage(b *testing.B) {
   577  	bigBuf := bytes.Repeat([]byte("a"), 1<<20)
   578  
   579  	runOnce := func(buf *bytes.Buffer) {
   580  		buf.Reset()
   581  		zw := NewWriter(buf)
   582  		for j := 0; j < 3; j++ {
   583  			w, _ := zw.CreateHeader(&FileHeader{
   584  				Name:   "foo",
   585  				Method: Deflate,
   586  			})
   587  			w.Write(bigBuf)
   588  		}
   589  		zw.Close()
   590  	}
   591  
   592  	b.ReportAllocs()
   593  	// Run once and then reset the timer.
   594  	// This effectively discards the very large initial flate setup cost,
   595  	// as well as the initialization of bigBuf.
   596  	runOnce(&bytes.Buffer{})
   597  	b.ResetTimer()
   598  
   599  	b.RunParallel(func(pb *testing.PB) {
   600  		var buf bytes.Buffer
   601  		for pb.Next() {
   602  			runOnce(&buf)
   603  		}
   604  	})
   605  }
   606  
   607  func writeTestsToFS(tests []WriteTest) fs.FS {
   608  	fsys := fstest.MapFS{}
   609  	for _, wt := range tests {
   610  		fsys[wt.Name] = &fstest.MapFile{
   611  			Data: wt.Data,
   612  			Mode: wt.Mode,
   613  		}
   614  	}
   615  	return fsys
   616  }
   617  
   618  func TestWriterAddFS(t *testing.T) {
   619  	buf := new(bytes.Buffer)
   620  	w := NewWriter(buf)
   621  	tests := []WriteTest{
   622  		{
   623  			Name: "file.go",
   624  			Data: []byte("hello"),
   625  			Mode: 0644,
   626  		},
   627  		{
   628  			Name: "subfolder/another.go",
   629  			Data: []byte("world"),
   630  			Mode: 0644,
   631  		},
   632  	}
   633  	err := w.AddFS(writeTestsToFS(tests))
   634  	if err != nil {
   635  		t.Fatal(err)
   636  	}
   637  
   638  	if err := w.Close(); err != nil {
   639  		t.Fatal(err)
   640  	}
   641  
   642  	// read it back
   643  	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
   644  	if err != nil {
   645  		t.Fatal(err)
   646  	}
   647  	for i, wt := range tests {
   648  		testReadFile(t, r.File[i], &wt)
   649  	}
   650  }
   651  
   652  func TestIssue61875(t *testing.T) {
   653  	buf := new(bytes.Buffer)
   654  	w := NewWriter(buf)
   655  	tests := []WriteTest{
   656  		{
   657  			Name:   "symlink",
   658  			Data:   []byte("../link/target"),
   659  			Method: Deflate,
   660  			Mode:   0755 | fs.ModeSymlink,
   661  		},
   662  		{
   663  			Name:   "device",
   664  			Data:   []byte(""),
   665  			Method: Deflate,
   666  			Mode:   0755 | fs.ModeDevice,
   667  		},
   668  	}
   669  	err := w.AddFS(writeTestsToFS(tests))
   670  	if err == nil {
   671  		t.Errorf("expected error, got nil")
   672  	}
   673  }
   674  

View as plain text