天天看点

读Tokyo Tyrant源代码

仅供娱乐:Tokyo Tyrant的协议是先请求再返回数据,想给它加上主动推数据的功能。C语言不熟,原来计划少改几行C代码,再加上Tokyo Tyrant的Lua语言扩展就可以搞定。后来发现功能可以实现,但要解决通信和线程的一些问题,可就麻烦了。我用的版本是tokyotyrant-1.1.36,主要看ttserver.c和ttutil.c,代码写得很好。我说好,意思是Java程序员能看懂。

以ttserver.c里的main函数为线索,虽然没什么注释,但是从代码里的log字符串也能看出一段代码的功能。main主要处理命令行参数,然后调用proc函数,其它的先不管,主要关心proc调用的ttservstart函数,来看这个函数,主要看我加的中文注释:

/* Start the service of a server object. */
bool ttservstart(TTSERV *serv){
  assert(serv);
  int lfd;
  if(serv->port < 1){
    lfd = ttopenservsockunix(serv->host);//创建server socket
    if(lfd == -1){
      ttservlog(serv, TTLOGERROR, "ttopenservsockunix failed");
      return false;
    }
  } else {//创建server socket
    lfd = ttopenservsock(serv->addr[0] != '\0' ? serv->addr : NULL, serv->port);
    if(lfd == -1){
      ttservlog(serv, TTLOGERROR, "ttopenservsock failed");
      return false;
    }
  }
  int epfd = epoll_create(TTEVENTMAX);//epoll创建
  if(epfd == -1){
    close(lfd);
    ttservlog(serv, TTLOGERROR, "epoll_create failed");
    return false;
  }
  ttservlog(serv, TTLOGSYSTEM, "service started: %d", getpid());
  bool err = false;
  for(int i = 0; i < serv->timernum; i++){
    TTTIMER *timer = serv->timers + i;
    timer->alive = false;
    timer->serv = serv;
    if(pthread_create(&(timer->thid), NULL, ttservtimer, timer) == 0){
      ttservlog(serv, TTLOGINFO, "timer thread %d started", i + 1);
      timer->alive = true;
    } else {
      ttservlog(serv, TTLOGERROR, "pthread_create (ttservtimer) failed");
      err = true;
    }
  }
  int thnum = serv->thnum;
  TTREQ reqs[thnum];
  for(int i = 0; i < thnum; i++){
    reqs[i].alive = true;
    reqs[i].serv = serv;
    reqs[i].epfd = epfd;
    reqs[i].mtime = tctime();
    reqs[i].keep = false;
    reqs[i].idx = i;
    //创建处理客户端请求的线程
    if(pthread_create(&reqs[i].thid, NULL, ttservdeqtasks, reqs + i) == 0){
      ttservlog(serv, TTLOGINFO, "worker thread %d started", i + 1);
    } else {
      reqs[i].alive = false;
      err = true;
      ttservlog(serv, TTLOGERROR, "pthread_create (ttservdeqtasks) failed");
    }
  }
  struct epoll_event ev;
  memset(&ev, 0, sizeof(ev));
  ev.events = EPOLLIN;
  ev.data.fd = lfd;
  if(epoll_ctl(epfd, EPOLL_CTL_ADD, lfd, &ev) != 0){//开始监听服务器事件
    err = true;
    ttservlog(serv, TTLOGERROR, "epoll_ctl failed");
  }
  ttservlog(serv, TTLOGSYSTEM, "listening started");
  while(!serv->term){
    struct epoll_event events[TTEVENTMAX];
    int fdnum = epoll_wait(epfd, events, TTEVENTMAX, TTWAITREQUEST * 1000);//等待epoll事件
    if(fdnum != -1){
      for(int i = 0; i < fdnum; i++){
        if(events[i].data.fd == lfd){//是服务器事件
          char addr[TTADDRBUFSIZ];
          int port;
          int cfd;
          if(serv->port < 1){
            cfd = ttacceptsockunix(lfd);//接受连接
            sprintf(addr, "(unix)");
            port = 0;
          } else {
            cfd = ttacceptsock(lfd, addr, &port);//接受连接
          }
          if(epoll_reassoc(epfd, lfd) != 0){
            if(cfd != -1) close(cfd);
            cfd = -1;
          }
          if(cfd != -1){
            ttservlog(serv, TTLOGINFO, "connected: %s:%d", addr, port);
            struct epoll_event ev;
            memset(&ev, 0, sizeof(ev));
            ev.events = EPOLLIN | EPOLLONESHOT;//读  一次
            ev.data.fd = cfd;
            if(epoll_ctl(epfd, EPOLL_CTL_ADD, cfd, &ev) != 0){//开始监听客户端事件
              close(cfd);
              err = true;
              ttservlog(serv, TTLOGERROR, "epoll_ctl failed");
            }
          } else {
            err = true;
            ttservlog(serv, TTLOGERROR, "ttacceptsock failed");
          }
        } else {//客户端事件
          int cfd = events[i].data.fd;
          if(pthread_mutex_lock(&serv->qmtx) == 0){
            tclistpush(serv->queue, &cfd, sizeof(cfd));//加入处理队列
            if(pthread_mutex_unlock(&serv->qmtx) != 0){
              err = true;
              ttservlog(serv, TTLOGERROR, "pthread_mutex_unlock failed");
            }
            if(pthread_cond_signal(&serv->qcnd) != 0){
              err = true;
              ttservlog(serv, TTLOGERROR, "pthread_cond_signal failed");
            }
          } else {
            err = true;
            ttservlog(serv, TTLOGERROR, "pthread_mutex_lock failed");
          }
        }
      }
。。。。。。。。。。。。。。。。。
}
           

来看看如何处理客户端的请求,从ttservdeqtasks开始,我还是只关心用中文注释出来的关键点,其它细节暂时用不到:

/* Dequeue tasks of a server object and dispatch them.
   `argp' specifies the argument structure of the server object.
   The return value is `NULL' on success and other on failure. */
static void *ttservdeqtasks(void *argp){
  TTREQ *req = argp;
  TTSERV *serv = req->serv;
  bool err = false;
  if(pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, NULL) != 0){
    err = true;
    ttservlog(serv, TTLOGERROR, "pthread_setcancelstate failed");
  }
  sigset_t sigset;
  sigemptyset(&sigset);
  sigaddset(&sigset, SIGPIPE);
  sigset_t oldsigset;
  sigemptyset(&sigset);
  if(pthread_sigmask(SIG_BLOCK, &sigset, &oldsigset) != 0){
    err = true;
    ttservlog(serv, TTLOGERROR, "pthread_sigmask failed");
  }
  bool empty = false;
  while(!serv->term){
    if(pthread_mutex_lock(&serv->qmtx) == 0){//得到锁的才能操作队列
      struct timeval tv;
      struct timespec ts;
      if(gettimeofday(&tv, NULL) == 0){
        ts.tv_sec = tv.tv_sec;
        ts.tv_nsec = tv.tv_usec * 1000.0 + TTWAITREQUEST * 1000000000.0;
        if(ts.tv_nsec >= 1000000000){
          ts.tv_nsec -= 1000000000;
          ts.tv_sec++;
        }
      } else {
        ts.tv_sec = (1ULL << (sizeof(time_t) * 8 - 1)) - 1;
        ts.tv_nsec = 0;
      }
      int code = empty ? pthread_cond_timedwait(&serv->qcnd, &serv->qmtx, &ts) : 0;
      if(code == 0 || code == ETIMEDOUT || code == EINTR){
        void *val = tclistshift2(serv->queue);//待处理客户端出队
        if(pthread_mutex_unlock(&serv->qmtx) != 0){
          err = true;
          ttservlog(serv, TTLOGERROR, "pthread_mutex_unlock failed");
        }
        if(val){
          empty = false;
          int cfd = *(int *)val;
          tcfree(val);
          pthread_cleanup_push((void (*)(void *))close, (void *)(intptr_t)cfd);
          TTSOCK *sock = ttsocknew(cfd);//客户端socket
          pthread_cleanup_push((void (*)(void *))ttsockdel, sock);
          bool reuse;
          do {
            if(serv->timeout > 0) ttsocksetlife(sock, serv->timeout);
            req->mtime = tctime();
            req->keep = false;
            ttservtask(sock, req);//调用ttserver.c里的do_task来处理具体请求
            reuse = false;
            if(sock->end){
              req->keep = false;
            } else if(sock->ep > sock->rp){
              reuse = true;
            }
          } while(reuse);
          pthread_cleanup_pop(1);
          pthread_cleanup_pop(0);
          if(req->keep){//再把客户端加入epoll事件监听,对通信来说这里很重要
            struct epoll_event ev;
            memset(&ev, 0, sizeof(ev));
            [b]ev.events = EPOLLIN | EPOLLONESHOT;[/b]
            ev.data.fd = cfd;
            if([b]epoll_ctl(req->epfd, EPOLL_CTL_MOD, cfd, &ev)[/b] != 0){
              close(cfd);
              err = true;
              ttservlog(serv, TTLOGERROR, "epoll_ctl failed");
            }
          } else {
            if(epoll_ctl(req->epfd, EPOLL_CTL_DEL, cfd, NULL) != 0){
              err = true;
              ttservlog(serv, TTLOGERROR, "epoll_ctl failed");
            }
            if(!ttclosesock(cfd)){
              err = true;
              ttservlog(serv, TTLOGERROR, "close failed");
            }
            ttservlog(serv, TTLOGINFO, "connection finished");
          }
。。。。。。。。。

}
           

上面函数调用的do_task函数就简单了,先是处理二进制协议的各种命令,然后是对memcached协议和HTTP的支持

/* handle a task and dispatch it */
static void do_task(TTSOCK *sock, void *opq, TTREQ *req){
  TASKARG *arg = (TASKARG *)opq;
  int c = ttsockgetc(sock);
  if(c == TTMAGICNUM){
    switch(ttsockgetc(sock)){//处理二进制协议的各种命令
    case TTCMDPUT:
      do_put(sock, arg, req);
      break;
    case TTCMDPUTKEEP:
      do_putkeep(sock, arg, req);
      break;
。。。。。。。。。
    case TTCMDEXT:
      do_ext(sock, arg, req);//对这个比较感兴趣,还要往下看
      break;
。。。。。。。。。
    case TTCMDREPL:
      do_repl(sock, arg, req);
      break;
    default:
      ttservlog(g_serv, TTLOGINFO, "unknown command");
      break;
    }
  } else {
    ttsockungetc(sock, c);
    char *line = ttsockgets2(sock);
    if(line){
      pthread_cleanup_push(tcfree, line);
      int tnum;
      char **tokens = tokenize(line, &tnum);
      pthread_cleanup_push(tcfree, tokens);
      if(tnum > 0){
        const char *cmd = tokens[0];
        if(!strcmp(cmd, "set")){//[b]memcached协议[/b]
          do_mc_set(sock, arg, req, tokens, tnum);
        } else if(!strcmp(cmd, "add")){
          do_mc_add(sock, arg, req, tokens, tnum);
。。。。。。。。。。。。。。。。
        } else if(!strcmp(cmd, "quit")){
          do_mc_quit(sock, arg, req, tokens, tnum);
        } else if(tnum > 2 && tcstrfwm(tokens[2], "HTTP/1.")){
          int ver = tcatoi(tokens[2] + 7);
          const char *uri = tokens[1];
          if(tcstrifwm(uri, "http://")){
            const char *pv = strchr(uri + 7, '/');
            if(pv) uri = pv;
          }
          if(!strcmp(cmd, "GET")){//[b]HTTP的支持[/b]
            do_http_get(sock, arg, req, ver, uri);
          } else if(!strcmp(cmd, "HEAD")){
            do_http_head(sock, arg, req, ver, uri);
          } else if(!strcmp(cmd, "PUT")){
            do_http_put(sock, arg, req, ver, uri);
          } else if(!strcmp(cmd, "POST")){
            do_http_post(sock, arg, req, ver, uri);
          } else if(!strcmp(cmd, "DELETE")){
            do_http_delete(sock, arg, req, ver, uri);
          } else if(!strcmp(cmd, "OPTIONS")){
            do_http_options(sock, arg, req, ver, uri);
          }
        }
      }
      pthread_cleanup_pop(1);
      pthread_cleanup_pop(1);
    }
  }
}
           

要给客户端推送它事先订阅的内容,要有个地方记录,不能用lua的全局变量,可以用stash,全局的,内存的。想扩展又想尽量少改C代码,ext command 可以调LUA脚本,脚本语言会简单一些吧?再说ext command的本意就是用来扩展的。lua扩展里没有fd,所以要改C。要保证多线程下,推数据和拉的数据不发生混乱,那个以我的c水平搞不定。

/* Handle the ext command: read a script function name, a key and a value from the
   socket, run the script function (optionally under a lock chosen by the client)
   and send the result back.
   `sock' specifies the client socket the request is read from and the reply is written to.
   `arg' specifies the task argument carrying the counters, the command mask, the record
   mutexes and the per-thread script objects.
   `req' specifies the request object; `req->idx' selects the per-thread slots and
   `req->keep' is set to true only when the reply was sent successfully.
   Wire format read here: four 32-bit integers (name size, options, key size, value size)
   followed by the name, key and value bytes.
   Reply: a zero byte, a 32-bit size converted with TTHTONL and the result bytes when the
   script returned a buffer; a single byte 1 otherwise. */
static void do_ext(TTSOCK *sock, TASKARG *arg, TTREQ *req){
  ttservlog(g_serv, TTLOGDEBUG, "doing ext command");
  // per-thread counter of ext commands
  arg->counts[TTSEQNUM*req->idx+TTSEQEXT]++;
  uint64_t mask = arg->mask;
  pthread_mutex_t *rmtxs = arg->rmtxs;
  void *scr = arg->screxts[req->idx];
  int nsiz = ttsockgetint32(sock);
  int opts = ttsockgetint32(sock);// opts is chosen by the client: bit 0 (1<<0) and bit 1 (1<<1) stand for the global lock and the record lock
  int ksiz = ttsockgetint32(sock);
  int vsiz = ttsockgetint32(sock);
  // reject truncated headers and sizes outside the accepted ranges; the name must also
  // fit into a TTADDRBUFSIZ buffer together with its terminator
  if(ttsockcheckend(sock) || nsiz < 0 || nsiz >= TTADDRBUFSIZ ||
     ksiz < 0 || ksiz > MAXARGSIZ || vsiz < 0 || vsiz > MAXARGSIZ){
    ttservlog(g_serv, TTLOGINFO, "do_ext: invalid parameters");
    // mind the byte order on the client side: integers sent with the wrong endianness
    // also end up in this "invalid parameters" branch
    return;
  }
  int rsiz = nsiz + ksiz + vsiz;
  char stack[TTIOBUFSIZ];
  // small requests are read into the on-stack buffer, larger ones into a heap buffer
  char *buf = (rsiz < TTIOBUFSIZ) ? stack : tcmalloc(rsiz + 1);
  // the handler only has work to do for the heap buffer (free(NULL) is a no-op)
  pthread_cleanup_push(free, (buf == stack) ? NULL : buf);
  if(ttsockrecv(sock, buf, rsiz) && !ttsockcheckend(sock)){
    // buf holds name, key and value back to back; only the name is copied out and terminated
    char name[TTADDRBUFSIZ];
    memcpy(name, buf, nsiz);
    name[nsiz] = '\0';
    const char *kbuf = buf + nsiz;
    const char *vbuf = kbuf + ksiz;
    int xsiz = 0;
    char *xbuf = NULL;
    if(mask & ((1ULL << TTSEQEXT) | (1ULL << TTSEQALLORG))){
      // the command is masked out: xbuf stays NULL, so the failure reply below is sent
      ttservlog(g_serv, TTLOGINFO, "do_ext: forbidden");
    } else if(scr){
      if(opts & RDBXOLCKGLB){// global lock
        bool err = false;
        // take every record mutex in ascending order; on failure release the ones
        // already held and give up
        for(int i = 0; i < RECMTXNUM; i++){
          if(pthread_mutex_lock(rmtxs + i) != 0){
            ttservlog(g_serv, TTLOGERROR, "do_ext: pthread_mutex_lock failed");
            while(--i >= 0){
              pthread_mutex_unlock(rmtxs + i);
            }
            err = true;
            break;
          }
        }
        if(!err){// call the script function (Lua, per the article) to process the key and the value
			  xbuf = scrextcallmethod(scr, name, kbuf, ksiz, vbuf, vsiz, &xsiz);
          // release in descending order
          for(int i = RECMTXNUM - 1; i >= 0; i--){
            if(pthread_mutex_unlock(rmtxs + i) != 0)
              ttservlog(g_serv, TTLOGERROR, "do_ext: pthread_mutex_unlock failed");
          }
        }
      } else if(opts & RDBXOLCKREC){// record lock
        // only the mutex that recmtxidx selects for this key is taken
        int mtxidx = recmtxidx(kbuf, ksiz);
        if(pthread_mutex_lock(rmtxs + mtxidx) == 0){// call the script function to process the key and the value
			  xbuf = scrextcallmethod(scr, name, kbuf, ksiz, vbuf, vsiz, &xsiz);
          if(pthread_mutex_unlock(rmtxs + mtxidx) != 0)
            ttservlog(g_serv, TTLOGERROR, "do_ext: pthread_mutex_unlock failed");
        } else {
          ttservlog(g_serv, TTLOGERROR, "do_ext: pthread_mutex_lock failed");
        }
      } else {// no lock: call the script function to process the key and the value
			  xbuf = scrextcallmethod(scr, name, kbuf, ksiz, vbuf, vsiz, &xsiz);
      }
    }
    if(xbuf){// the script returned a result, which has to be sent to the client
      // reply layout: 1 status byte + 4 size bytes + xsiz payload bytes
      int rsiz = xsiz + sizeof(uint8_t) + sizeof(uint32_t);
      // the on-stack buffer is reused for the reply; kbuf and vbuf are not read after this point
      char *rbuf = (rsiz < TTIOBUFSIZ) ? stack : tcmalloc(rsiz);
      pthread_cleanup_push(free, (rbuf == stack) ? NULL : rbuf);
      *rbuf = 0;
      uint32_t num;
      num = TTHTONL((uint32_t)xsiz);
      memcpy(rbuf + sizeof(uint8_t), &num, sizeof(uint32_t));
      memcpy(rbuf + sizeof(uint8_t) + sizeof(uint32_t), xbuf, xsiz);
      tcfree(xbuf);
      if(ttsocksend(sock, rbuf, rsiz)){// send to the client
        req->keep = true;
      } else {
        ttservlog(g_serv, TTLOGINFO, "do_ext: response failed");
      }
      pthread_cleanup_pop(1);
    } else {
      // no result (forbidden, no script object, lock failure or the script returned nothing):
      // reply with the single status byte 1
      uint8_t code = 1;
      if(ttsocksend(sock, &code, sizeof(code))){
        req->keep = true;
      } else {
        ttservlog(g_serv, TTLOGINFO, "do_ext: response failed");
      }
    }
  } else {
    ttservlog(g_serv, TTLOGINFO, "do_ext: invalid entity");
  }
  // runs the handler pushed for buf, releasing the heap request buffer if there was one
  pthread_cleanup_pop(1);
}
           

opts是客户端指定的,1左移0位和1位分别代表全局锁和记录锁,下边用左移2位和3位标识不同操作。

要想保存fd,在判断全局锁之前加入如下代码:

if(opts & (1<<2) ){
				/* Bit 2 of opts: replace the request value with the client's fd written as
				   decimal text, so that the script can keep it in the stash. */
				char fdtxt[32];
				int fdlen = snprintf(fdtxt, sizeof(fdtxt), "%d", sock->fd);
				/* Bytes available from the start of the value region: the on-stack buffer
				   holds sizeof(stack) bytes, a heap buffer was allocated with rsiz + 1 bytes.
				   Writing the text unchecked could run past the end when the client sent a
				   value shorter than the fd text. */
				size_t cap = (buf == stack) ? sizeof(stack) : (size_t)rsiz + 1;
				size_t room = cap - (size_t)nsiz - (size_t)ksiz;
				if(fdlen > 0 && (size_t)fdlen < sizeof(fdtxt) && (size_t)fdlen < room){
					char *tbuf = buf + nsiz + ksiz;
					memcpy(tbuf, fdtxt, (size_t)fdlen + 1);
					vsiz = fdlen;
				} else {
					/* leave the client's value untouched rather than overflow the buffer */
					ttservlog(g_serv, TTLOGERROR, "do_ext: no room to store the client fd");
				}
			}
           

脚本处理的返回值,要发给特定的客户端,比如从stash里拿出的,可以这样试一下,在 if(xbuf){ 之后加上:

if(opts & (1<<3) ){
			/* Bit 3 of opts: treat the script result as a decimal fd (e.g. one taken from
			   the stash) and send the reply to that descriptor instead. */
			/* NOTE(review): atoi needs a terminated string and only reads the first
			   number; confirm that scrextcallmethod terminates its result. */
			int fd2 = atoi(xbuf);
			/* NOTE(review): sock->fd is not restored afterwards, so later reads on this
			   TTSOCK go to fd2 as well; acceptable only for this experiment. */
			sock->fd = fd2;
			/* Send at most vsiz bytes, but never more than the xsiz bytes the script
			   actually returned; a larger size would make the memcpy below read past xbuf. */
			if(vsiz < xsiz) xsiz = vsiz;
		}
           

以上只是用最少的改动验证一些想法是否可行,仅供娱乐,记着编译要./configure --enable-lua。

现在还少相应的lua脚本,这个也是简单改了一下开源的代码(http://github.com/igrigorik/tokyo-recipes),lua脚本咱也不熟。(什么也不熟也出来混?只想看下学一门新技术要多长时间,下次我用java做,也有个比较。)

local SEP = '\n'

-- Count the elements of a SEP-delimited set string.
-- `stream` is the stored string, or nil/false when nothing is stored.
-- Returns the number of pieces `_split` produces for `stream`, or 0 when `stream` is absent.
-- `_split` is not defined in this file; presumably it is the string splitter provided
-- by the Tokyo Tyrant Lua extension and returns an array-like table.
function _set_len(stream)
  if not stream then
    return 0
  end
  -- The length operator replaces table.getn, which is deprecated in Lua 5.1
  -- and no longer exists in Lua 5.2 and later.
  return #_split(stream, SEP)
end

-- Report how many elements the set stored under `key` holds.
-- `value` is part of the (key, value) argument pair the ext command passes in,
-- but it is not used here.
-- Returns whatever `_set_len` returns for the string stored under `key`.
function set_length(key, value)
  local stored = _stashget(key)
  return _set_len(stored)
end

-- Return the raw string stored under `key`, exactly as `_stashget` reports it.
-- `_stashget` is not defined in this file; presumably it is the stash reader of
-- the Tokyo Tyrant Lua extension and yields nil for a missing key.
function set_get(key)
	return _stashget(key)
end

-- Add `value` to the SEP-delimited set stored under `key`.
-- Returns `value` when it was stored, or nil when it is already a member or
-- when writing the extended set back fails.
-- A key with no entry yet is simply initialised with `value`; the result of
-- that first write is not checked.
function set_append(key, value)
  local existing = _stashget(key)

  -- First element for this key.
  if not existing then
    _stashput(key, value)
    return value
  end

  -- Refuse duplicates: a one-element set is compared as a whole string,
  -- a longer one is split and scanned.
  local count = _set_len(existing)
  if count == 1 then
    if existing == value then return nil end
  elseif count > 1 then
    for _, member in ipairs(_split(existing, SEP)) do
      if member == value then return nil end
    end
  end

  if _stashput(key, existing .. SEP .. value) then
    return value
  end
  return nil
end

-- Remove `value` from the SEP-delimited set stored under `key`.
-- Returns `value` when it was removed, nil otherwise (no entry for the key,
-- `value` not a member, or the stash update failed).
function set_delete(key, value)
  local existing = _stashget(key)
  if not existing then
    return nil
  end

  local count = _set_len(existing)

  -- Removing the only member drops the whole entry.
  if count == 1 and existing == value then
    if _stashout(key) then return value end
    return nil
  end

  -- Nothing to search in an empty set or a non-matching single element.
  if count <= 1 then
    return nil
  end

  local members = _split(existing, SEP)
  local position = nil
  for idx, member in ipairs(members) do
    if member == value then
      position = idx
      break
    end
  end
  if not position then
    return nil
  end

  -- Write the set back without the removed member.
  table.remove(members, position)
  if _stashput(key, table.concat(members, SEP)) then
    return value
  end
  return nil
end

-- Call `_stashvanish` and return nothing.
-- `_stashvanish` is not defined in this file; presumably it is the Tokyo Tyrant
-- Lua extension call that clears the whole stash, not just the sets managed here.
-- NOTE(review): do_ext answers with the failure byte when the script yields no
-- result buffer, so a client probably sees this call as failed; confirm.
function set_delete_all()
	_stashvanish()
end
           

还有客户端的代码,我用actionscript,代码太长,只贴片段:

function testext(){
			// Send one ext request that calls the Lua function set_append, putting a
			// string value into the "set" kept in the stash under key "600000".
			var fn:String = "set_append";
			var key:String = "600000";
			var val:String = "flash";

			// Encode the fields first: the size prefixes must be byte counts of the
			// UTF-8 payload, and String.length counts characters, which differs as soon
			// as any field contains non-ASCII text.
			var fnBytes:ByteArray = new ByteArray();
			fnBytes.writeUTFBytes(fn);
			var keyBytes:ByteArray = new ByteArray();
			keyBytes.writeUTFBytes(key);
			var valBytes:ByteArray = new ByteArray();
			valBytes.writeUTFBytes(val);

			var ha:ByteArray = new ByteArray();
			ha.endian = Endian.BIG_ENDIAN;            // the server rejects sizes sent in the wrong byte order
			ha.writeByte(0xc8);                       // first byte, compared against TTMAGICNUM by do_task
			ha.writeByte(0x68);                       // command byte, presumably TTCMDEXT
			ha.writeUnsignedInt(fnBytes.length);      // nsiz
			ha.writeUnsignedInt(5);                   // opts, binary 101: bit 0 (global lock) + bit 2 (pass the client fd as the value)
			ha.writeUnsignedInt(keyBytes.length);     // ksiz
			ha.writeUnsignedInt(valBytes.length);     // vsiz
			ha.writeBytes(fnBytes);
			ha.writeBytes(keyBytes);
			ha.writeBytes(valBytes);
			socket.writeBytes(ha,0,ha.length);
			socket.flush();
		}
		function testpush(){
			// Send one ext request that calls the Lua function set_get, fetching every
			// value of the "set" kept in the stash under key "600000".
			var fn:String = "set_get";
			var key:String = "600000";
			var val:String = "flash";

			// Encode the fields first: the size prefixes must be byte counts of the
			// UTF-8 payload, and String.length counts characters, which differs as soon
			// as any field contains non-ASCII text.
			var fnBytes:ByteArray = new ByteArray();
			fnBytes.writeUTFBytes(fn);
			var keyBytes:ByteArray = new ByteArray();
			keyBytes.writeUTFBytes(key);
			var valBytes:ByteArray = new ByteArray();
			valBytes.writeUTFBytes(val);

			var ha:ByteArray = new ByteArray();
			ha.endian = Endian.BIG_ENDIAN;            // the server rejects sizes sent in the wrong byte order
			ha.writeByte(0xc8);                       // first byte, compared against TTMAGICNUM by do_task
			ha.writeByte(0x68);                       // command byte, presumably TTCMDEXT
			ha.writeUnsignedInt(fnBytes.length);      // nsiz
			ha.writeUnsignedInt(9);                   // opts, binary 1001: bit 0 (global lock) + bit 3 (send the reply to the fd returned by the script)
			ha.writeUnsignedInt(keyBytes.length);     // ksiz
			ha.writeUnsignedInt(valBytes.length);     // vsiz
			ha.writeBytes(fnBytes);
			ha.writeBytes(keyBytes);
			ha.writeBytes(valBytes);
			socket.writeBytes(ha,0,ha.length);
			socket.flush();
		}
           

再想解决一些深入的问题就不是这么简单了,我宁愿用java来做,所以就到这了。