天天看點

【Go語言代碼】通過syscall.Select實作echo server

作者:趙帥虎

備註:本文只包含代碼,是 Go BIO/NIO 探討(6):IO 多路複用之 select 的附錄

總共兩個檔案:

  1. main.go
  2. fd_set_linux_amd64.go

這裡我們只實作了對 linux amd64 的支援,其他平台先忽略了

代碼如下:

main.go

package main

import (
  "fmt"
  "log"
  "net"
  "os"
  "os/signal"
  "syscall"

  "github.com/<username>/golang/net/fdsetutil"
)

// ipToSockaddrInet4 builds the syscall.SockaddrInet4 for ip:port.
//
// An empty ip is treated as the IPv4 wildcard address 0.0.0.0. An ip that has
// no 4-byte representation (per net.IP.To4) yields a zero SockaddrInet4 and a
// *net.AddrError describing the rejected address.
func ipToSockaddrInet4(ip net.IP, port int) (syscall.SockaddrInet4, error) {
	target := ip
	if len(target) == 0 {
		target = net.IPv4zero
	}

	v4 := target.To4()
	if v4 == nil {
		return syscall.SockaddrInet4{}, &net.AddrError{Err: "non-IPv4 address", Addr: target.String()}
	}

	var sockaddr syscall.SockaddrInet4
	sockaddr.Port = port
	copy(sockaddr.Addr[:], v4)
	return sockaddr, nil
}

// main runs a single-threaded TCP echo server on 0.0.0.0:8080.
//
// syscall.Select waits until the listening socket or any connected client
// socket is readable. The server then either accepts the new connection or
// echoes the received bytes back. Every socket stays in blocking mode, and
// select is used only to learn which descriptors are readable.
func main() {
	var (
		family        = syscall.AF_INET
		sotype        = syscall.SOCK_STREAM
		listenBacklog = syscall.SOMAXCONN
		serverip      = net.IPv4(0, 0, 0, 0)
		serverport    = 8080
	)

	// Create the listening socket.
	sockfd, err := syscall.Socket(family, sotype, 0)
	if err != nil {
		panic(fmt.Errorf("fails to create socket: %s", err))
	}

	syscall.CloseOnExec(sockfd)

	// Allow an immediate restart while connections from a previous run are
	// still in TIME_WAIT; without SO_REUSEADDR, Bind fails with EADDRINUSE.
	if err := syscall.SetsockoptInt(sockfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
		log.Printf("setsockopt SO_REUSEADDR on sockfd %d error=%v\n", sockfd, err)
	}

	// Non-blocking mode is too complicated to handle in this example, so it
	// stays disabled.
	// if err := syscall.SetNonblock(sockfd, true); err != nil {
	//   syscall.Close(sockfd)
	//   log.Printf("setnonblock error=%v\n", err)
	//   os.Exit(-1)
	// }

	// On Ctrl+C (SIGINT) or SIGTERM, close the listening socket and exit.
	// signal.Notify never blocks when delivering a signal, so the channel
	// must be buffered or a signal sent before the receive is dropped.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		log.Println("\r- Ctrl+C pressed in Terminal")

		if err := syscall.Close(sockfd); err != nil {
			log.Printf("Close sockfd %d fails, err=%v\n", sockfd, err)
		} else {
			log.Printf("Server stopped successfully!!!")
		}
		// The signal must be acted on, otherwise the process hangs forever
		// and needs kill -9 <pid>. os.Exit stops every goroutine at once.
		os.Exit(0)
	}()

	addr, err := ipToSockaddrInet4(serverip, serverport)
	if err != nil {
		panic(fmt.Sprintf("fails to convert address %s:%d to socket addr, err=%s", serverip, serverport, err))
	}

	if err := syscall.Bind(sockfd, &addr); err != nil {
		panic(fmt.Sprintf("fails to bind socket %d to address %s:%d, err=%s",
			sockfd,
			serverip, serverport,
			err))
	}

	if err := syscall.Listen(sockfd, listenBacklog); err != nil {
		log.Printf("listen sockfd %d to addr error=%v\n", sockfd, err)
		panic(fmt.Sprintf("fails to listen socket %d", sockfd))
	}
	log.Printf("Started listening on %s:%d", serverip, serverport)

	var fdSet syscall.FdSet
	// Descriptors >= maxSelectFd (FD_SETSIZE) cannot be stored in an FdSet:
	// on linux/amd64 each element of FdSet.Bits holds 64 descriptor bits.
	maxSelectFd := len(fdSet.Bits) * 64
	nfds := sockfd
	fdsetutil.SetFdBit(sockfd, &fdSet)
	clientFdMap := make(map[int]struct{}, 1024)
	// One read buffer shared by all clients; the event loop is single-threaded.
	var buf [32 * 1024]byte

	for {
		// select overwrites the set it is given, so hand it a copy.
		r := fdSet
		// A nil timeout makes select block until at least one fd is ready.
		nReady, err := syscall.Select(nfds+1, &r, nil, nil, nil)
		if err != nil {
			if err == syscall.EINTR {
				// select(2) is never restarted after a signal handler runs,
				// even with SA_RESTART, and any signal handled by the Go
				// runtime (including the SIGINT above) can interrupt it.
				continue
			}
			log.Printf("select error=%v\n", err)
			panic("select error")
		}

		if fdsetutil.IsSetFdBit(sockfd, &r) {
			if clientSockfd, ok := acceptClient(sockfd, maxSelectFd); ok {
				clientFdMap[clientSockfd] = struct{}{}
				fdsetutil.SetFdBit(clientSockfd, &fdSet)
				if clientSockfd > nfds {
					nfds = clientSockfd
				}
			}
			// The listening socket accounted for one ready descriptor; skip
			// the client scan when nothing else is ready.
			nReady--
			if nReady <= 0 {
				continue
			}
		}

		// Deleting from a map while ranging over it is permitted in Go.
		for clientSockFd := range clientFdMap {
			if !fdsetutil.IsSetFdBit(clientSockFd, &r) {
				continue
			}
			if !serveClient(clientSockFd, buf[:]) {
				_ = syscall.Close(clientSockFd)
				fdsetutil.ClearFdBit(clientSockFd, &fdSet)
				delete(clientFdMap, clientSockFd)
			}
			nReady--
			if nReady <= 0 {
				break
			}
		}
	}
}

// acceptClient accepts one pending connection on the listening socket sockfd.
// It returns the client descriptor and true on success. It returns -1 and
// false if accept fails, or if the new descriptor is >= limit and so cannot
// be placed in an fd_set; in that case the connection is closed immediately.
func acceptClient(sockfd, limit int) (int, bool) {
	clientSockfd, clientSockAddr, err := syscall.Accept(sockfd)
	if err != nil {
		log.Printf("accept sockfd %d error=%v\n", sockfd, err)
		return -1, false
	}
	syscall.CloseOnExec(clientSockfd)

	if clientSockfd >= limit {
		// Setting its bit would index past the end of FdSet.Bits and panic.
		log.Printf("client sockfd %d exceeds select limit %d, closing it\n", clientSockfd, limit)
		_ = syscall.Close(clientSockfd)
		return -1, false
	}

	if sa, ok := clientSockAddr.(*syscall.SockaddrInet4); ok {
		log.Printf("Connected with new client, sock addr = %v:%d\n", sa.Addr, sa.Port)
	} else {
		log.Printf("Connected with new client, sock addr = %v\n", clientSockAddr)
	}
	return clientSockfd, true
}

// serveClient reads one chunk from the readable client socket fd into buf and
// writes it back. It returns false when the connection should be dropped
// (read error, or the peer closed its side), and true otherwise.
func serveClient(fd int, buf []byte) bool {
	nRead, err := syscall.Read(fd, buf)
	switch {
	case err != nil:
		log.Printf("fails to read data from sockfd %d, err=%v\n", fd, err)
		return false
	case nRead == 0:
		// Client closed
		log.Printf("client sock %d closed\n", fd)
		return false
	}

	log.Printf("read %d bytes from sock %d\n", nRead, fd)
	if err := writeAll(fd, buf[:nRead]); err != nil {
		log.Printf("fails to write data %s into sockfd %d, err=%v\n", buf[:nRead], fd, err)
	}
	return true
}

// writeAll writes all of p to the blocking socket fd. write(2) may transfer
// fewer bytes than requested, so it loops until everything has been sent.
func writeAll(fd int, p []byte) error {
	for len(p) > 0 {
		n, err := syscall.Write(fd, p)
		if err == syscall.EINTR {
			continue
		}
		if err != nil {
			return err
		}
		if n <= 0 {
			return fmt.Errorf("write on sockfd %d made no progress", fd)
		}
		p = p[n:]
	}
	return nil
}
           

fd_set_linux_amd64.go

//go:build amd64 && linux

package fdsetutil

import "syscall"

/**
// filepath: ztypes_linux_amd64.go
type FdSet struct {
  Bits [16]int64
}
*/

// maskBits is the number of descriptor bits held by one element of
// syscall.FdSet.Bits (an int64 on linux/amd64).
const maskBits = 64

// totalSlots is the number of elements in syscall.FdSet.Bits on linux/amd64,
// so a set can describe descriptors 0 through totalSlots*maskBits-1.
const totalSlots = 16

// SetFdBit adds descriptor sockfd to fdSet (the Go counterpart of FD_SET).
// Negative descriptors are ignored.
func SetFdBit(sockfd int, fdSet *syscall.FdSet) {
	if sockfd < 0 {
		return
	}
	slot, bit := sockfd/maskBits, uint(sockfd%maskBits)
	fdSet.Bits[slot] |= 1 << bit
}

func IsSetFdBit(sockfd int, fdSet *syscall.FdSet) bool {
  if sockfd >= 0 {
    return fdSet.Bits[sockfd/maskBits] == 1<<(sockfd%maskBits)
  }
  return false
}

// ClearFdBit removes descriptor sockfd from fdSet (the Go counterpart of
// FD_CLR). Negative descriptors are ignored.
func ClearFdBit(sockfd int, fdSet *syscall.FdSet) {
	if sockfd < 0 {
		return
	}
	word := &fdSet.Bits[sockfd/maskBits]
	*word &^= 1 << (sockfd % maskBits)
}

// ResetFd removes every descriptor from fdSet (the Go counterpart of FD_ZERO).
func ResetFd(fdSet *syscall.FdSet) {
	for slot := range fdSet.Bits[:totalSlots] {
		fdSet.Bits[slot] = 0
	}
}

// ToFdList returns every descriptor that is a member of fdSet, in ascending
// order. It returns nil when the set is empty.
func ToFdList(fdSet *syscall.FdSet) []int {
	var fdList []int
	for i := 0; i < totalSlots; i++ {
		word := fdSet.Bits[i]
		if word == 0 {
			continue
		}
		for j := 0; j < maskBits; j++ {
			// Compare against zero, not "> 0": bit 63 is the int64 sign bit,
			// so a positivity test would drop fds 63, 127, ..., 1023.
			if word&(1<<j) != 0 {
				fdList = append(fdList, i*maskBits+j)
			}
		}
	}
	return fdList
}