Implementing bidirectional copying of TCP connections in Golang

Preface

This article covers the details of a Golang implementation of bidirectional TCP copying, shared for your reference and learning. Without further ado, let's get into the details.

The simplest implementation

For every incoming server-side connection, a new client-side connection is opened. One goroutine copies from server to client, and another copies from client to server. If either side disconnects, both sides are disconnected.


package main

import (
	"fmt"
	"io"
	"net"
	"runtime"
)

func main() {
	runtime.GOMAXPROCS(1)
	listener, err := net.Listen("tcp", "127.0.0.1:8848")
	if err != nil {
		panic(err)
	}
	for {
		conn, err := listener.Accept()
		if err != nil {
			panic(err)
		}
		go handle(conn.(*net.TCPConn))
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	client, err := net.Dial("tcp", "127.0.0.1:8849")
	if err != nil {
		fmt.Print(err)
		return
	}
	defer client.Close()
	go func() {
		// copy client -> server; closing both conns on exit also
		// unblocks the opposite copy below
		defer server.Close()
		defer client.Close()
		buf := make([]byte, 2048)
		io.CopyBuffer(server, client, buf)
	}()
	// copy server -> client
	buf := make([]byte, 2048)
	io.CopyBuffer(client, server, buf)
}

One thing to note: when no buffer is supplied, io.Copy allocates a 32 KB buffer internally for each copy, and there are two copies per proxied connection. Passing a smaller buffer to io.CopyBuffer (2 KB here) supports many more concurrent connections for the same memory.

The two goroutines are coupled so that when one exits, the other exits as well. This is achieved by closing both the server socket and the client socket: once a socket is closed, the io.CopyBuffer blocked on it returns.
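As a standalone sketch of this mechanism (my example, not part of the proxy; net.Pipe stands in for a real TCP pair so it runs without a listener), closing a connection from another goroutine immediately unblocks a pending Read:

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	a, b := net.Pipe() // in-memory connection pair, no listener needed
	defer a.Close()
	go func() {
		time.Sleep(100 * time.Millisecond)
		b.Close() // unblocks the Read below
	}()
	buf := make([]byte, 1)
	_, err := b.Read(buf) // blocks until data arrives or Close is called
	fmt.Println(err)      // "io: read/write on closed pipe"; a closed
	// *net.TCPConn reports "use of closed network connection" instead
}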

Connection pooling on the client end

The obvious problem with this version is that every incoming server-side connection requires dialing a brand-new client-side connection on the spot, so a full TCP handshake is included in the total proxying time. If the client end pools and reuses existing connections, end-to-end latency can be reduced.


var pool = make(chan net.Conn, 100)

func borrow() (net.Conn, error) {
	select {
	case conn := <-pool:
		return conn, nil
	default:
		return net.Dial("tcp", "127.0.0.1:8849")
	}
}

func release(conn net.Conn) error {
	select {
	case pool <- conn:
		// returned to pool
		return nil
	default:
		// pool is full; drop the connection
		return conn.Close()
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	client, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	defer release(client)
	go func() {
		defer server.Close()
		defer release(client)
		buf := make([]byte, 2048)
		io.CopyBuffer(server, client, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(client, server, buf)
}

This version is obviously problematic. There is no guarantee that a connection is still alive when it is returned to the pool. A more serious problem: because the client connection is no longer closed, when the server end closes its connection, the goroutine running io.CopyBuffer from client to server can never exit. (Note also that release is deferred in both goroutines, so the same connection can even be pushed into the pool twice.)

Therefore, there are several problems to be solved:

How does one goroutine exit when the other goroutine exits?
How do we ensure that a connection returned to the pool is still valid?
How do we keep the connections sitting in the pool valid?

Interrupting a goroutine with SetDeadline

A common view is that a goroutine cannot be interrupted: while a goroutine is blocked in conn.Read, it simply sits there. But it can in fact be interrupted, for example with conn.Close. With connection pooling, however, closing the link is not an option. The other option is SetDeadline with a timestamp in the past, which interrupts the current blocking read or write.
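Before the full pool implementation, here is a minimal standalone sketch of the interruption mechanism (again my example, with net.Pipe in place of TCP): a deadline in the past makes a blocked Read return a timeout error while leaving the connection itself open.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	_, server := net.Pipe() // in-memory pair; deadlines are supported
	go func() {
		time.Sleep(100 * time.Millisecond)
		// a deadline in the past fails the pending Read right away
		server.SetDeadline(time.Now().Add(-time.Second))
	}()
	buf := make([]byte, 1)
	_, err := server.Read(buf) // blocked here until the deadline fires
	fmt.Println(err)           // an i/o timeout; the conn stays open and is
	// usable again after SetDeadline(time.Time{})
}

With that, the pooled version becomes: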


// (main and the accept loop are unchanged from the first version; the
// imports now also include sync and time)

var pool = make(chan net.Conn, 100)

type client struct {
	conn  net.Conn
	inUse *sync.WaitGroup
	once  sync.Once // ensures the conn is recycled exactly once
}

func borrow() (clt *client, err error) {
	var conn net.Conn
	select {
	case conn = <-pool:
	default:
		conn, err = net.Dial("tcp", "127.0.0.1:18849")
	}
	if err != nil {
		return nil, err
	}
	return &client{
		conn:  conn,
		inUse: &sync.WaitGroup{},
	}, nil
}

func release(clt *client) {
	// a deadline in the past interrupts any Read/Write blocked on this conn
	clt.conn.SetDeadline(time.Now().Add(-time.Second))
	clt.inUse.Done()
	clt.inUse.Wait()
	// both copy goroutines call release; only one may recycle the conn
	clt.once.Do(func() {
		// clear the artificial deadline before handing the conn back
		clt.conn.SetDeadline(time.Time{})
		select {
		case pool <- clt.conn:
			// returned to pool
		default:
			// pool is full; drop the connection
			clt.conn.Close()
		}
	})
}

func handle(server *net.TCPConn) {
	defer server.Close()
	clt, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	// count both copy goroutines up front, before either can call
	// release, so Wait cannot race with a late Add
	clt.inUse.Add(2)
	defer release(clt)
	go func() {
		defer server.Close()
		defer release(clt)
		buf := make([]byte, 2048)
		io.CopyBuffer(server, clt.conn, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(clt.conn, server, buf)
}

SetDeadline interrupts the copying goroutines, and sync.WaitGroup then ensures that both users of the connection have exited before it is returned to the pool (hence the Add(2) before the second goroutine starts). Without this, by the time a connection is reused, its previous user might not have exited yet.

Connection validity

We must ensure that a connection is still valid when it is returned to the pool. If an error occurs while the connection is being read or written, we mark the connection as broken so that release closes it directly. However, the SetDeadline in release inevitably triggers a timeout error on the blocked read or write, so timeout errors must be excluded from this check.


// (as before, main and the accept loop are unchanged; the imports now
// also include sync/atomic)

var pool = make(chan net.Conn, 100)

type client struct {
	conn    net.Conn
	inUse   *sync.WaitGroup
	once    sync.Once
	isValid int32
}

const maybeValid = 0 // fresh connection, not yet proven either way
const isValid = 1    // last read/write succeeded
const isInvalid = 2  // failed with a non-timeout error; must not be reused

func (clt *client) Read(b []byte) (n int, err error) {
	n, err = clt.conn.Read(b)
	if err != nil {
		if !isTimeoutError(err) {
			atomic.StoreInt32(&clt.isValid, isInvalid)
		}
	} else {
		atomic.StoreInt32(&clt.isValid, isValid)
	}
	return
}

func (clt *client) Write(b []byte) (n int, err error) {
	n, err = clt.conn.Write(b)
	if err != nil {
		if !isTimeoutError(err) {
			atomic.StoreInt32(&clt.isValid, isInvalid)
		}
	} else {
		atomic.StoreInt32(&clt.isValid, isValid)
	}
	return
}

type timeoutErr interface {
	Timeout() bool
}

func isTimeoutError(err error) bool {
	te, ok := err.(timeoutErr)
	return ok && te.Timeout()
}

func borrow() (clt *client, err error) {
	var conn net.Conn
	select {
	case conn = <-pool:
	default:
		conn, err = net.Dial("tcp", "127.0.0.1:18849")
	}
	if err != nil {
		return nil, err
	}
	return &client{
		conn:    conn,
		inUse:   &sync.WaitGroup{},
		isValid: maybeValid,
	}, nil
}

func release(clt *client) {
	clt.conn.SetDeadline(time.Now().Add(-time.Second))
	clt.inUse.Done()
	clt.inUse.Wait()
	clt.once.Do(func() {
		// a connection that failed with a real error is dropped, not pooled
		if atomic.LoadInt32(&clt.isValid) == isInvalid {
			clt.conn.Close()
			return
		}
		clt.conn.SetDeadline(time.Time{}) // clear the artificial deadline
		select {
		case pool <- clt.conn:
			// returned to pool
		default:
			// pool is full; drop the connection
			clt.conn.Close()
		}
	})
}

func handle(server *net.TCPConn) {
	defer server.Close()
	clt, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	clt.inUse.Add(2)
	defer release(clt)
	go func() {
		defer server.Close()
		defer release(clt)
		buf := make([]byte, 2048)
		// copy through clt (not clt.conn) so errors are recorded
		io.CopyBuffer(server, clt, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(clt, server, buf)
}

To determine whether an error is a timeout, a type assertion is required.
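The standard library already exposes such an interface: errors returned by net package operations generally implement net.Error, which has a Timeout method. An equivalent check (a sketch using that interface) would be:

// equivalent to isTimeoutError above, using the net package's own interface
func isTimeout(err error) bool {
	nerr, ok := err.(net.Error)
	return ok && nerr.Timeout()
}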

As for whether a conn sitting in the pool is still valid, detecting it with a background ping would be more expensive to implement, because different protocols need different kinds of ping for keep-alive. The cheapest approach is to just try the connection the next time it is used: if it no longer works, avoid returning it to the pool and dial a new connection instead, so that invalid connections do not persist. This is how invalid connections are weeded out.
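If some liveness check at borrow time is still desired, one best-effort option (my sketch, not from the article, and it cannot catch every failure mode) is a short read probe: a dead TCP connection typically reports EOF or a reset immediately, while a healthy idle one just times out.

// looksAlive is a hypothetical probe reusing isTimeoutError from above.
// It costs up to 1 ms per borrow, which is why lazy detection on first
// real use is usually preferred.
func looksAlive(conn net.Conn) bool {
	conn.SetReadDeadline(time.Now().Add(time.Millisecond))
	defer conn.SetReadDeadline(time.Time{}) // clear the probe deadline
	var b [1]byte
	_, err := conn.Read(b[:])
	// only a timeout means "idle and probably healthy"; EOF, a reset, or
	// unexpected buffered data all make the connection unsafe to reuse
	return isTimeoutError(err)
}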

About correctness

This article was written at Hangzhou Airport; the correctness of its content is not guaranteed.
