skynet定时器

要解析一个程序代码,先了解数据结构,这是基础,再看函数。 拿skynet定时器举例子。

数据结构

//定时器事件 用于抛出定时器事件到消息队列里。理解这个数据结构需要先了解skynet的框架原理,
//不理解这个数据结构也不影响下面的论述
struct timer_event {
	int32_t handle;
	int session;
};

//定时器节点
struct timer_node {
	struct timer_node *next;
	uint32_t expire;
};

//定时器链表
struct link_list {
	struct timer_node head;
	struct timer_node *tail;
};

//层级时间轮
//存放所有定时器的地方
struct timer {
	struct link_list near[TIME_NEAR]; //最近的定时器
	struct link_list t[4][TIME_LEVEL];//更久远的定时器
	struct spinlock lock;           //全局锁
	uint32_t time;                  //当前滴答数
	uint32_t starttime;             //程序开始时间 绝对时间 时间戳,单位 s 秒
	uint64_t current;               //当前时间 相对时间 1cs 厘秒 =  10ms 毫秒
	uint64_t current_point;         //系统(pc)运行时间 相对时间 单位: cs 厘秒
};

上面定时器的基本数据结构了解了,再来看下面的函数, 基本思路是一样的。

  1. 假设我们有3层时间轮,第一层基础时间轮一刻度是1s,有60刻度。
  2. 第一层指针走完一圈,第二层指针走完一格
  3. 第二层指针走完一圈,第三层指针走完一格。

精度为什么是10ms

创建+初始化定时器的数据结构就不看了。直接看skynet_updatetime, 驱动时间轮向前进的地方。这也是为什么skynet定时器的精度是10ms也就是1cs的地方。

void
skynet_updatetime(void) {
	uint64_t cp = gettime();
	if(cp < TI->current_point) {
		skynet_error(NULL, "time diff error: change from %lld to %lld", cp, TI->current_point);
		TI->current_point = cp;
	} else if (cp != TI->current_point) {
		uint32_t diff = (uint32_t)(cp - TI->current_point);
		TI->current_point = cp;
		TI->current += diff;
		skynet_error(NULL, "%lld, %lld, %d .    ", TI->current, TI->current_point, diff);
		int i;
		for (i=0;i<diff;i++) {
			timer_update(TI);
		}
	}
}

从上面的函数可以看到,驱动时间轮向前进的两个重要的成员变量

uint64_t current;               //当前时间 相对时间 1cs 厘秒 =  10ms 毫秒
uint64_t current_point;         //系统(pc)运行时间 相对时间 单位: cs 厘秒
  1. 获取系统当前启动的时间cp(单位: cs), 如果cp < TI->current_point, 则把当前启动时间赋值给 TI->current_point
  2. 如果cp != TI->current_point
  3. 计算cpTI->current_point的差值, 这个差值diff单位是cs
  4. 遍历diff, 这就是个时间轮的精度是1cs(10ms)的原因

系统当前启动的时间cp为什么单位是cs,函数gettime能说明一切,代码中有我的注释

timer
static void
systime(uint32_t *sec, uint32_t *cs) {
#if !defined(__APPLE__) || defined(AVAILABLE_MAC_OS_X_VERSION_10_12_AND_LATER)
	struct timespec ti;
	clock_gettime(CLOCK_REALTIME, &ti); //获取真实时间, 绝对时间
	*sec = (uint32_t)ti.tv_sec; //单位s 当前时间戳
	*cs = (uint32_t)(ti.tv_nsec / 10000000);//单位10ms, 当前时间戳
#else
	struct timeval tv;
	gettimeofday(&tv, NULL);
	*sec = tv.tv_sec;
	*cs = tv.tv_usec / 10000;
#endif
}

static uint64_t
gettime() {
	uint64_t t;
#if !defined(__APPLE__) || defined(AVAILABLE_MAC_OS_X_VERSION_10_12_AND_LATER)
	struct timespec ti;
	clock_gettime(CLOCK_MONOTONIC, &ti); //获取系统启动时间, 相对时间
	t = (uint64_t)ti.tv_sec * 100; //单位cs, 系统启动到现在的时间
	t += ti.tv_nsec / 10000000; //单位cs
#else
	struct timeval tv;
	gettimeofday(&tv, NULL);
	t = (uint64_t)tv.tv_sec * 100;
	t += tv.tv_usec / 10000;
#endif
	return t;
}

定时任务的处理方式

timer_update
static void 
timer_update(struct timer *T) {
	SPIN_LOCK(T);

	// try to dispatch timeout 0 (rare condition)
	timer_execute(T);

	// shift time first, and then dispatch timer message
	timer_shift(T);

	timer_execute(T);

	SPIN_UNLOCK(T);
}

  1. skynet_updatetime每2.5毫秒执行一次,如果跨度达到10毫秒,那么就执行一次定时器timer_update。
  2. timer_update函数分别调用了 timer_execute(T)timer_shift(T)

timer_execute

先来看timer_execute(T)

timer_execute
static inline void
timer_execute(struct timer *T) {
	int idx = T->time & TIME_NEAR_MASK;
	
	while (T->near[idx].head.next) {
		struct timer_node *current = link_clear(&T->near[idx]);
		SPIN_UNLOCK(T);
		// dispatch_list don't need lock T
		dispatch_list(current);
		SPIN_LOCK(T);
	}
}

  1. 根据当前滴答数T->time获取时间轮上near的位置idx,
  2. 遍历当前位置的任务链表,抛出定时器事件(如果只想了解定时器原理的,可以先不用看这块抛出事件), 可以理解为执行到期的任务

timer_shift

接着看timer_shift(T)

timer_shift
static void
timer_shift(struct timer *T) {
	int mask = TIME_NEAR;
	uint32_t ct = ++T->time;
	if (ct == 0) {
		move_list(T, 3, 0);
	} else {
		uint32_t time = ct >> TIME_NEAR_SHIFT;
		int i=0;

		while ((ct & (mask-1))==0) {
			int idx=time & TIME_LEVEL_MASK;
			if (idx!=0) {
				move_list(T, i, idx);
				break;				
			}
			mask <<= TIME_LEVEL_SHIFT;
			time >>= TIME_LEVEL_SHIFT;
			++i;
		}
	}
} 

  1. 滴答数T->time, 自增1
  2. 因为T->time,数据类型是uint32_t, 会出现自增溢出的时候,当T->time>2^32(2的32次方)溢出的时候,则代表已经跑完一个周期。 需要移动这里的定时器 T->t[3][0]
  3. 当定时器的expire超过一个uint32的时候,那么这些定时器默认是放在T->t[3][0]这一层的

timer_update

我们再回头来看timer_update, 这个函数所做的就是

  1. 先处理当前滴答 T->near[idx]的定时任务
  2. 过滤一遍定时任务列表,把T->t[],检测高一级别的定时任务有没有可以加入到T->near的地方
  3. 再处理理一遍当前滴答 T->near[idx]的定时任务

add_node

时间轮处理任务的方式我们已经知道来了,下面来看看添加定时任务的函数

add_node
static void
add_node(struct timer *T,struct timer_node *node) {
    // 逾期时间是相对于 T->time的
	uint32_t time=node->expire;
	uint32_t current_time=T->time;
	
    // 高24位一致,表示在一个near周期内。
	if ((time|TIME_NEAR_MASK)==(current_time|TIME_NEAR_MASK)) {
		link(&T->near[time&TIME_NEAR_MASK],node);
	} else {
		int i;
		uint32_t mask=TIME_NEAR << TIME_LEVEL_SHIFT;
		for (i=0;i<3;i++) {
			if ((time|(mask-1))==(current_time|(mask-1))) {
				break;
			}
			mask <<= TIME_LEVEL_SHIFT;
		}

		link(&T->t[i][((time>>(TIME_NEAR_SHIFT + i*TIME_LEVEL_SHIFT)) & TIME_LEVEL_MASK)],node);	
	}
}

  1. 添加定时任务,其实就是根据任务的逾期时间expire,插入对应的层级中。
  2. 逾期时间expire时间是根据T->time算的,所以只要高24位和T->time一样,就表示在一个near周期中,那么就插入near层,反之插入高层

总结

  1. 看代码先看数据结构,给自己提几个问题,再带着问题去找答案。
  2. 层级时间轮的原理其实是一样的,明白了其中的原理,再看skynet的定时器的实现就会很容易理解。

golang实现skynet定时器

timewheel
package main

import (
	"container/list"
	"sync"
	"sync/atomic"
	"time"
)

//skynet的时间轮 + 协程池
const (
	TimeNearShift  = 8
	TimeNear       = 1 << TimeNearShift
	TimeLevelShift = 6
	TimeLevel      = 1 << TimeLevelShift
	TimeNearMask   = TimeNear - 1
	TimeLevelMask  = TimeLevel - 1

	//协程池 大小
	WorkerPoolSize = 10
	MaxTaskPerWorker = 20
)

type bucket struct {
	expiration int32
	timers *list.List

	mu sync.Mutex
}

func newBucket() *bucket {
	return &bucket{
		expiration: -1,
		timers:     list.New(),
		mu: sync.Mutex{},
	}
}

func (b*bucket) Add(t *timer)  {
	b.mu.Lock()
	defer b.mu.Unlock()

	b.timers.PushBack(t)
}

func (b*bucket) Flush(reinsert func(t *timer))  {
	b.mu.Lock()
	defer b.mu.Unlock()

	for e := b.timers.Front(); e != nil; {
		next := e.Next()
		reinsert(e.Value.(*timer))

		b.timers.Remove(e)
		e = next
	}
}

type timer struct {
	expiration	 uint32
	f func()
}

var TimingWheel *TimeWheel

func init()  {
	TimingWheel = NewTimeWheel()
	TimingWheel.Start()
}

type TimeWheel struct {
	tick time.Duration
	ticker *time.Ticker
	near 			[TimeNear]*bucket
	t				[4][TimeLevel]*bucket
	time 			uint32

	WorkPool *WorkPool
	exit chan struct{}
	exitFlag uint32
}

func NewTimeWheel() *TimeWheel {
	tw := &TimeWheel{
		tick:     10*time.Millisecond,
		time:     0,
		WorkPool: NewWorkPool(WorkerPoolSize, MaxTaskPerWorker),
		exit:     make(chan struct{}),
		exitFlag: 0,
	}
	for i :=0; i < TimeNear; i++ {
		tw.near[i] = newBucket()
	}

	for i :=0; i < 4; i++ {
		for j :=0; j < TimeLevel; j++ {
			tw.t[i][j] = newBucket()
		}
	}
	return tw
}

func (tw *TimeWheel) add(t *timer) bool {
	time := t.expiration
	currentTime := atomic.LoadUint32(&tw.time)
	if time <= currentTime {
		return false
	}

	if (time | TimeNearMask) == (currentTime | TimeNearMask) {
		tw.near[time&TimeNearMask].Add(t)
	}else {
		i := 0
		mask := TimeNear << TimeNearShift
		for i=0; i < 3; i ++ {
			if (time | uint32(mask - 1)) == (currentTime | uint32(mask - 1)) {
				break
			}
			mask <<= TimeLevelShift
		}

		tw.t[i][((time>>(TimeNearShift + i*TimeLevelShift)) & TimeLevelMask)].Add(t)
	}
	return true
}

func (tw *TimeWheel) addOrRun(t *timer)  {
	if !tw.add(t) {
		workerID := int64(t.expiration) % tw.WorkPool.WorkerPoolSize
		//将请求消息发送给任务队列
		tw.WorkPool.TaskQueue[workerID] <- t.f
	}
}

func (tw *TimeWheel) moveList(level, idx int)  {
	current := tw.t[level][idx]
	current.Flush(tw.addOrRun)
}

func (tw *TimeWheel) shift()  {
	mask := TimeNear
	ct := atomic.AddUint32(&tw.time, 1)
	if ct == 0 {
		tw.moveList(3, 0)
	}else {
		time := ct >> TimeNearShift

		i := 0
		for (ct & uint32(mask-1)) == 0{
			idx := time & TimeLevelMask
			if idx != 0 {
				tw.moveList(i, int(idx))
				break
			}

			mask <<= TimeLevelShift
			time >>= TimeLevelShift
			i++
		}
	}
}

func (tw *TimeWheel) execute()  {
	idx := tw.time & TimeNearMask
	tw.near[idx].Flush(tw.addOrRun)
}

func (tw *TimeWheel) update()  {
	tw.execute()
	tw.shift()
	tw.execute()
}

func (tw *TimeWheel) Start()  {
	tw.ticker = time.NewTicker(tw.tick)
	tw.WorkPool.StartWorkerPool()

	go func() {
		for  {
			select {
			case <- tw.ticker.C:
				tw.update()
			case <- tw.exit:
				return
			}
		}
	}()
}

func (tw *TimeWheel) Stop()  {
	flag := atomic.LoadUint32(&tw.exitFlag)
	if flag != 0 {
		return
	}

	atomic.StoreUint32(&tw.exitFlag, 1)
	close(tw.exit)
}

func (tw *TimeWheel) afterFunc(expiration time.Duration, f func()) {
	time := atomic.LoadUint32(&tw.time)
	tw.addOrRun(&timer{
		expiration: uint32(expiration / tw.tick) + time,
		f:          f,
	})
}

func TimeOut(expire time.Duration, f func()) {
	TimingWheel.afterFunc(expire, f)
}

func StopTimer()  {
	TimingWheel.Stop()
}

workpool
package main

type Job func()

type WorkPool struct {
	WorkerPoolSize int64
	MaxTaskPerWorker int64
	TaskQueue []chan Job
}

func NewWorkPool(poolSize, maxTaskSize int64) *WorkPool {
	return &WorkPool{
		WorkerPoolSize: poolSize,
		MaxTaskPerWorker: maxTaskSize,
		TaskQueue: make([]chan Job, poolSize),
	}
}

//StartOneWorker 启动一个Worker工作流程
func (wp *WorkPool) StartOneWorker(workerID int, taskQueue chan Job) {
	//不断的等待队列中的消息
	for {
		select {
		//有消息则取出队列的Request,并执行绑定的业务方法
		case job := <-taskQueue:
			_ = workerID
			job()
		}
	}
}

func (wp *WorkPool) StartWorkerPool() {
	//遍历需要启动worker的数量,依此启动
	for i := 0; i < int(wp.WorkerPoolSize); i++ {
		//一个worker被启动
		//给当前worker对应的任务队列开辟空间
		wp.TaskQueue[i] = make(chan Job,  wp.MaxTaskPerWorker)
		//启动当前Worker,阻塞的等待对应的任务队列是否有消息传递进来
		go wp.StartOneWorker(i, wp.TaskQueue[i])
	}
}

timewheel_test
package main

import (
	"fmt"
	"testing"
	"time"
)

func PRINT()  {
	fmt.Println("12312312312")
}

func TestTimeWheel_Start(t *testing.T) {
	TimeOut(1 * time.Second, func() {
		fmt.Println("12312313123")
	})
	select{}
}

--完--