Source file src/net/mockserver_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !js
     6  
     7  package net
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"os"
    13  	"path/filepath"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  // testUnixAddr uses os.MkdirTemp to get a name that is unique.
    20  func testUnixAddr(t testing.TB) string {
    21  	// Pass an empty pattern to get a directory name that is as short as possible.
    22  	// If we end up with a name longer than the sun_path field in the sockaddr_un
    23  	// struct, we won't be able to make the syscall to open the socket.
    24  	d, err := os.MkdirTemp("", "")
    25  	if err != nil {
    26  		t.Fatal(err)
    27  	}
    28  	t.Cleanup(func() {
    29  		if err := os.RemoveAll(d); err != nil {
    30  			t.Error(err)
    31  		}
    32  	})
    33  	return filepath.Join(d, "sock")
    34  }
    35  
    36  func newLocalListener(t testing.TB, network string) Listener {
    37  	listen := func(net, addr string) Listener {
    38  		ln, err := Listen(net, addr)
    39  		if err != nil {
    40  			t.Helper()
    41  			t.Fatal(err)
    42  		}
    43  		return ln
    44  	}
    45  
    46  	switch network {
    47  	case "tcp":
    48  		if supportsIPv4() {
    49  			if !supportsIPv6() {
    50  				return listen("tcp4", "127.0.0.1:0")
    51  			}
    52  			if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil {
    53  				return ln
    54  			}
    55  		}
    56  		if supportsIPv6() {
    57  			return listen("tcp6", "[::1]:0")
    58  		}
    59  	case "tcp4":
    60  		if supportsIPv4() {
    61  			return listen("tcp4", "127.0.0.1:0")
    62  		}
    63  	case "tcp6":
    64  		if supportsIPv6() {
    65  			return listen("tcp6", "[::1]:0")
    66  		}
    67  	case "unix", "unixpacket":
    68  		return listen(network, testUnixAddr(t))
    69  	}
    70  
    71  	t.Helper()
    72  	t.Fatalf("%s is not supported", network)
    73  	return nil
    74  }
    75  
    76  func newDualStackListener() (lns []*TCPListener, err error) {
    77  	var args = []struct {
    78  		network string
    79  		TCPAddr
    80  	}{
    81  		{"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
    82  		{"tcp6", TCPAddr{IP: IPv6loopback}},
    83  	}
    84  	for i := 0; i < 64; i++ {
    85  		var port int
    86  		var lns []*TCPListener
    87  		for _, arg := range args {
    88  			arg.TCPAddr.Port = port
    89  			ln, err := ListenTCP(arg.network, &arg.TCPAddr)
    90  			if err != nil {
    91  				continue
    92  			}
    93  			port = ln.Addr().(*TCPAddr).Port
    94  			lns = append(lns, ln)
    95  		}
    96  		if len(lns) != len(args) {
    97  			for _, ln := range lns {
    98  				ln.Close()
    99  			}
   100  			continue
   101  		}
   102  		return lns, nil
   103  	}
   104  	return nil, errors.New("no dualstack port available")
   105  }
   106  
   107  type localServer struct {
   108  	lnmu sync.RWMutex
   109  	Listener
   110  	done chan bool // signal that indicates server stopped
   111  	cl   []Conn    // accepted connection list
   112  }
   113  
   114  func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
   115  	go func() {
   116  		handler(ls, ls.Listener)
   117  		close(ls.done)
   118  	}()
   119  	return nil
   120  }
   121  
   122  func (ls *localServer) teardown() error {
   123  	ls.lnmu.Lock()
   124  	defer ls.lnmu.Unlock()
   125  	if ls.Listener != nil {
   126  		network := ls.Listener.Addr().Network()
   127  		address := ls.Listener.Addr().String()
   128  		ls.Listener.Close()
   129  		for _, c := range ls.cl {
   130  			if err := c.Close(); err != nil {
   131  				return err
   132  			}
   133  		}
   134  		<-ls.done
   135  		ls.Listener = nil
   136  		switch network {
   137  		case "unix", "unixpacket":
   138  			os.Remove(address)
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  func newLocalServer(t testing.TB, network string) *localServer {
   145  	t.Helper()
   146  	ln := newLocalListener(t, network)
   147  	return &localServer{Listener: ln, done: make(chan bool)}
   148  }
   149  
   150  type streamListener struct {
   151  	network, address string
   152  	Listener
   153  	done chan bool // signal that indicates server stopped
   154  }
   155  
   156  func (sl *streamListener) newLocalServer() *localServer {
   157  	return &localServer{Listener: sl.Listener, done: make(chan bool)}
   158  }
   159  
   160  type dualStackServer struct {
   161  	lnmu sync.RWMutex
   162  	lns  []streamListener
   163  	port string
   164  
   165  	cmu sync.RWMutex
   166  	cs  []Conn // established connections at the passive open side
   167  }
   168  
   169  func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
   170  	for i := range dss.lns {
   171  		go func(i int) {
   172  			handler(dss, dss.lns[i].Listener)
   173  			close(dss.lns[i].done)
   174  		}(i)
   175  	}
   176  	return nil
   177  }
   178  
   179  func (dss *dualStackServer) teardownNetwork(network string) error {
   180  	dss.lnmu.Lock()
   181  	for i := range dss.lns {
   182  		if network == dss.lns[i].network && dss.lns[i].Listener != nil {
   183  			dss.lns[i].Listener.Close()
   184  			<-dss.lns[i].done
   185  			dss.lns[i].Listener = nil
   186  		}
   187  	}
   188  	dss.lnmu.Unlock()
   189  	return nil
   190  }
   191  
   192  func (dss *dualStackServer) teardown() error {
   193  	dss.lnmu.Lock()
   194  	for i := range dss.lns {
   195  		if dss.lns[i].Listener != nil {
   196  			dss.lns[i].Listener.Close()
   197  			<-dss.lns[i].done
   198  		}
   199  	}
   200  	dss.lns = dss.lns[:0]
   201  	dss.lnmu.Unlock()
   202  	dss.cmu.Lock()
   203  	for _, c := range dss.cs {
   204  		c.Close()
   205  	}
   206  	dss.cs = dss.cs[:0]
   207  	dss.cmu.Unlock()
   208  	return nil
   209  }
   210  
   211  func newDualStackServer() (*dualStackServer, error) {
   212  	lns, err := newDualStackListener()
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	_, port, err := SplitHostPort(lns[0].Addr().String())
   217  	if err != nil {
   218  		lns[0].Close()
   219  		lns[1].Close()
   220  		return nil, err
   221  	}
   222  	return &dualStackServer{
   223  		lns: []streamListener{
   224  			{network: "tcp4", address: lns[0].Addr().String(), Listener: lns[0], done: make(chan bool)},
   225  			{network: "tcp6", address: lns[1].Addr().String(), Listener: lns[1], done: make(chan bool)},
   226  		},
   227  		port: port,
   228  	}, nil
   229  }
   230  
   231  func (ls *localServer) transponder(ln Listener, ch chan<- error) {
   232  	defer close(ch)
   233  
   234  	switch ln := ln.(type) {
   235  	case *TCPListener:
   236  		ln.SetDeadline(time.Now().Add(someTimeout))
   237  	case *UnixListener:
   238  		ln.SetDeadline(time.Now().Add(someTimeout))
   239  	}
   240  	c, err := ln.Accept()
   241  	if err != nil {
   242  		if perr := parseAcceptError(err); perr != nil {
   243  			ch <- perr
   244  		}
   245  		ch <- err
   246  		return
   247  	}
   248  	ls.cl = append(ls.cl, c)
   249  
   250  	network := ln.Addr().Network()
   251  	if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
   252  		ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
   253  		return
   254  	}
   255  	c.SetDeadline(time.Now().Add(someTimeout))
   256  	c.SetReadDeadline(time.Now().Add(someTimeout))
   257  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   258  
   259  	b := make([]byte, 256)
   260  	n, err := c.Read(b)
   261  	if err != nil {
   262  		if perr := parseReadError(err); perr != nil {
   263  			ch <- perr
   264  		}
   265  		ch <- err
   266  		return
   267  	}
   268  	if _, err := c.Write(b[:n]); err != nil {
   269  		if perr := parseWriteError(err); perr != nil {
   270  			ch <- perr
   271  		}
   272  		ch <- err
   273  		return
   274  	}
   275  }
   276  
   277  func transceiver(c Conn, wb []byte, ch chan<- error) {
   278  	defer close(ch)
   279  
   280  	c.SetDeadline(time.Now().Add(someTimeout))
   281  	c.SetReadDeadline(time.Now().Add(someTimeout))
   282  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   283  
   284  	n, err := c.Write(wb)
   285  	if err != nil {
   286  		if perr := parseWriteError(err); perr != nil {
   287  			ch <- perr
   288  		}
   289  		ch <- err
   290  		return
   291  	}
   292  	if n != len(wb) {
   293  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   294  	}
   295  	rb := make([]byte, len(wb))
   296  	n, err = c.Read(rb)
   297  	if err != nil {
   298  		if perr := parseReadError(err); perr != nil {
   299  			ch <- perr
   300  		}
   301  		ch <- err
   302  		return
   303  	}
   304  	if n != len(wb) {
   305  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   306  	}
   307  }
   308  
   309  func newLocalPacketListener(t testing.TB, network string) PacketConn {
   310  	listenPacket := func(net, addr string) PacketConn {
   311  		c, err := ListenPacket(net, addr)
   312  		if err != nil {
   313  			t.Helper()
   314  			t.Fatal(err)
   315  		}
   316  		return c
   317  	}
   318  
   319  	switch network {
   320  	case "udp":
   321  		if supportsIPv4() {
   322  			return listenPacket("udp4", "127.0.0.1:0")
   323  		}
   324  		if supportsIPv6() {
   325  			return listenPacket("udp6", "[::1]:0")
   326  		}
   327  	case "udp4":
   328  		if supportsIPv4() {
   329  			return listenPacket("udp4", "127.0.0.1:0")
   330  		}
   331  	case "udp6":
   332  		if supportsIPv6() {
   333  			return listenPacket("udp6", "[::1]:0")
   334  		}
   335  	case "unixgram":
   336  		return listenPacket(network, testUnixAddr(t))
   337  	}
   338  
   339  	t.Helper()
   340  	t.Fatalf("%s is not supported", network)
   341  	return nil
   342  }
   343  
   344  func newDualStackPacketListener() (cs []*UDPConn, err error) {
   345  	var args = []struct {
   346  		network string
   347  		UDPAddr
   348  	}{
   349  		{"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
   350  		{"udp6", UDPAddr{IP: IPv6loopback}},
   351  	}
   352  	for i := 0; i < 64; i++ {
   353  		var port int
   354  		var cs []*UDPConn
   355  		for _, arg := range args {
   356  			arg.UDPAddr.Port = port
   357  			c, err := ListenUDP(arg.network, &arg.UDPAddr)
   358  			if err != nil {
   359  				continue
   360  			}
   361  			port = c.LocalAddr().(*UDPAddr).Port
   362  			cs = append(cs, c)
   363  		}
   364  		if len(cs) != len(args) {
   365  			for _, c := range cs {
   366  				c.Close()
   367  			}
   368  			continue
   369  		}
   370  		return cs, nil
   371  	}
   372  	return nil, errors.New("no dualstack port available")
   373  }
   374  
   375  type localPacketServer struct {
   376  	pcmu sync.RWMutex
   377  	PacketConn
   378  	done chan bool // signal that indicates server stopped
   379  }
   380  
   381  func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
   382  	go func() {
   383  		handler(ls, ls.PacketConn)
   384  		close(ls.done)
   385  	}()
   386  	return nil
   387  }
   388  
   389  func (ls *localPacketServer) teardown() error {
   390  	ls.pcmu.Lock()
   391  	if ls.PacketConn != nil {
   392  		network := ls.PacketConn.LocalAddr().Network()
   393  		address := ls.PacketConn.LocalAddr().String()
   394  		ls.PacketConn.Close()
   395  		<-ls.done
   396  		ls.PacketConn = nil
   397  		switch network {
   398  		case "unixgram":
   399  			os.Remove(address)
   400  		}
   401  	}
   402  	ls.pcmu.Unlock()
   403  	return nil
   404  }
   405  
   406  func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
   407  	t.Helper()
   408  	c := newLocalPacketListener(t, network)
   409  	return &localPacketServer{PacketConn: c, done: make(chan bool)}
   410  }
   411  
   412  type packetListener struct {
   413  	PacketConn
   414  }
   415  
   416  func (pl *packetListener) newLocalServer() *localPacketServer {
   417  	return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
   418  }
   419  
   420  func packetTransponder(c PacketConn, ch chan<- error) {
   421  	defer close(ch)
   422  
   423  	c.SetDeadline(time.Now().Add(someTimeout))
   424  	c.SetReadDeadline(time.Now().Add(someTimeout))
   425  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   426  
   427  	b := make([]byte, 256)
   428  	n, peer, err := c.ReadFrom(b)
   429  	if err != nil {
   430  		if perr := parseReadError(err); perr != nil {
   431  			ch <- perr
   432  		}
   433  		ch <- err
   434  		return
   435  	}
   436  	if peer == nil { // for connected-mode sockets
   437  		switch c.LocalAddr().Network() {
   438  		case "udp":
   439  			peer, err = ResolveUDPAddr("udp", string(b[:n]))
   440  		case "unixgram":
   441  			peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
   442  		}
   443  		if err != nil {
   444  			ch <- err
   445  			return
   446  		}
   447  	}
   448  	if _, err := c.WriteTo(b[:n], peer); err != nil {
   449  		if perr := parseWriteError(err); perr != nil {
   450  			ch <- perr
   451  		}
   452  		ch <- err
   453  		return
   454  	}
   455  }
   456  
   457  func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
   458  	defer close(ch)
   459  
   460  	c.SetDeadline(time.Now().Add(someTimeout))
   461  	c.SetReadDeadline(time.Now().Add(someTimeout))
   462  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   463  
   464  	n, err := c.WriteTo(wb, dst)
   465  	if err != nil {
   466  		if perr := parseWriteError(err); perr != nil {
   467  			ch <- perr
   468  		}
   469  		ch <- err
   470  		return
   471  	}
   472  	if n != len(wb) {
   473  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   474  	}
   475  	rb := make([]byte, len(wb))
   476  	n, _, err = c.ReadFrom(rb)
   477  	if err != nil {
   478  		if perr := parseReadError(err); perr != nil {
   479  			ch <- perr
   480  		}
   481  		ch <- err
   482  		return
   483  	}
   484  	if n != len(wb) {
   485  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   486  	}
   487  }
   488  

View as plain text