利用时间轮实现定时器
理解定时器
适用场景
- 定时任务(每隔1s钟打印一次数据)
- 超时控制(xx分钟没有动作就断开连接)
- 频率限制(最快只能每5s调用一次API)
定时器常用的数据结构有如下几种:
- 链表
- 双向有序链表
- 最小堆
- 时间轮
- 层级时间轮
链表
链表用于存储所有的定时器。每次都需要遍历链表,从里面找出过期的任务,简单粗暴不采用。
双向有序链表
通过过期时间(expireTime)查找合适的位置插入。由于整个双向链表是基于过期时间(expireTime)有序的,所以调度器只需要轮询第一个任务即可。
这也是使用双向有序链表而不是普通链表的原因。
最小堆
golang中的NewTick采用的就是最小堆。
时间轮
时间轮是一个环形+链表的结构,可以类比成时钟,钟面上有很多槽
,每一个槽
上可以存放多个任务。使用一个链表保存该时刻到期的所有任务。每走过一刻度,我称之为一滴答(tick)
假设一个tick是1毫秒(ms),指针转动一轮的时间为8ms
- eg: 当前指针指向0,如果需要调度一个3ms后执行的任务,应该加入到(0+3=3)的方格中。
- eg: 当前指针指向0,如果需要调度一个10ms后执行的任务,应该等指针走完(一轮+2格)的方格中。
开始代码编写,一个时间轮的struct, 和任务的struct
// Job 延时任务回调函数
type Job func(interface{})
// Task 延时任务
type Task struct {
delay time.Duration // 延迟时间
circle int // 时间轮需要转动几圈
key int64 // 定时器唯一标识, 用于删除定时器
data interface{} // 回调函数参数
job Job
}
type TimeWheel struct {
id int64 //auto incr
tick time.Duration//in millisecond 时间轮上时间单位 ms
slots []*list.List // 时间轮槽/刻度
currentPos int // 当前指针指向哪一个槽
taskMap map[int64]int //任务 key=pos
slotNum int // 槽数量
addTaskChannel chan *Task // 新增任务channel
removeTaskChannel chan int64 // 删除任务channel
stopChannel chan bool // 停止定时器channel
ticker *time.Ticker
}
开始时间轮
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(tw.tick)
go func() {
for {
select {
case <-tw.ticker.C:
tw.tickHandler() //时间轮每转动一下指针,则调用这个tickHandler()函数
case task := <-tw.addTaskChannel:
tw.addTask(task) //添加定时任务
case key := <-tw.removeTaskChannel:
tw.removeTask(key) //删除任务
case <-tw.stopChannel:
tw.ticker.Stop() //关闭时间轮
tw.removeAllTask() //删除所有任务
return
}
}
}()
}
按照上面的时间轮的思路,可以看到一个时间轮+定时任务所需要的数据结构和处理函数。梳理一下逻辑:
- 启动时间轮
- 监听时间轮指针 tw.ticker.C, 每拨动指针则调用 tickHandler()
- 监听添加任务的channel
- 监听删除任务的channel
- 监听关闭时间轮的channel
详细看看tickHandler函数
golang
func (tw *TimeWheel) tickHandler() {
//获取指针指向的槽
l := tw.slots[tw.currentPos]
//扫描当前槽中过期的定任务,时间复杂度为O(n),并执行回调
tw.scanAndRunTask(l)
//如果当前指针已经到时间轮的末尾,则赋值=0
if tw.currentPos == tw.slotNum - 1 {
tw.currentPos = 0
}else {
//指针自增
tw.currentPos ++
}
}
func (tw *TimeWheel) scanAndRunTask(l *list.List) {
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task.circle > 0 {
task.circle --
e = e.Next()
continue
}
//执行定时任务. 每个任务都需要一个协程,需优化
go task.job(task.data)
next := e.Next()
//删除任务
l.Remove(e)
delete(tw.taskMap, task.key)
e = next
}
}
再来看看添加定时任务函数 addTask
:
func (tw *TimeWheel) getRoundCount(d time.Duration) (int,int) {
// 需要转动的圈数 = (过期时间/一滴答) / 槽数量
circle := int(d/tw.tick)/tw.slotNum
//在时间轮上的位置 = (指针当前位置 + (过期时间/一滴答)) % 槽数量
pos := (tw.currentPos + int(d/tw.tick)) % tw.slotNum
return circle, pos
}
func (tw *TimeWheel) addTask(task *Task) {
//获取该任务 时间轮需要转动圈数,和在时间轮上的位置
circle, pos := tw.getRoundCount(task.delay)
task.circle = circle
tw.slots[pos].PushBack(task)
tw.taskMap[task.key] = pos
}
完整代码
timewheel
package timewheel
import (
"container/list"
"sync/atomic"
"time"
)
// Job 延时任务回调函数
type Job func(interface{})
// Task 延时任务
type Task struct {
delay time.Duration // 延迟时间
circle int // 时间轮需要转动几圈
key int64 // 定时器唯一标识, 用于删除定时器
data interface{} // 回调函数参数
job Job
}
type TimeWheel struct {
id int64 //auto incr
tick time.Duration//in millisecond 时间轮上时间单位 ms
slots []*list.List // 时间轮槽
currentPos int // 当前指针指向哪一个槽
taskMap map[int64]int
slotNum int // 槽数量
addTaskChannel chan *Task // 新增任务channel
removeTaskChannel chan int64 // 删除任务channel
stopChannel chan bool // 停止定时器channel
ticker *time.Ticker
}
func NewTimeWheel(tick time.Duration, slotNum int) *TimeWheel{
tw := &TimeWheel{
tick: tick,
currentPos: 0,
slots: make([]*list.List, slotNum),
taskMap: make(map[int64]int),
slotNum: slotNum,
addTaskChannel: make(chan *Task),
removeTaskChannel: make(chan int64),
stopChannel: make(chan bool),
}
tw.initSlots()
return tw
}
func (tw *TimeWheel) initSlots() {
for i := 0; i < tw.slotNum; i++ {
tw.slots[i] = list.New()
}
}
// 扫描链表中过期定时器, 并执行回调函数 时间复杂度为O(n)
func (tw *TimeWheel) scanAndRunTask(l *list.List) {
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task.circle > 0 {
task.circle --
e = e.Next()
continue
}
go task.job(task.data)
next := e.Next()
l.Remove(e)
delete(tw.taskMap, task.key)
e = next
}
}
func (tw *TimeWheel) tickHandler() {
l := tw.slots[tw.currentPos]
tw.scanAndRunTask(l)
if tw.currentPos == tw.slotNum - 1 {
tw.currentPos = 0
}else {
tw.currentPos ++
}
}
func (tw *TimeWheel) getRoundCount(d time.Duration) (int,int) {
circle := int(d/tw.tick)/tw.slotNum
pos := (tw.currentPos + int(d/tw.tick)) % tw.slotNum
return circle, pos
}
func (tw *TimeWheel) addTask(task *Task) {
circle, pos := tw.getRoundCount(task.delay)
task.circle = circle
tw.slots[pos].PushBack(task)
tw.taskMap[task.key] = pos
}
func (tw *TimeWheel) removeTask(key int64) {
pos, ok := tw.taskMap[key]
if !ok {
return
}
l := tw.slots[pos]
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task.key == key {
l.Remove(e)
delete(tw.taskMap, key)
return
}
e = e.Next()
}
}
func (tw *TimeWheel) removeAllTask() {
for _, l := range tw.slots {
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task != nil {
next := e.Next()
l.Remove(e)
if task.key != 0{
delete(tw.taskMap, task.key)
}
e = next
}else {
return
}
}
}
}
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(tw.tick)
go func() {
for {
select {
case <-tw.ticker.C:
tw.tickHandler()
case task := <-tw.addTaskChannel:
tw.addTask(task)
case key := <-tw.removeTaskChannel:
tw.removeTask(key)
case <-tw.stopChannel:
tw.ticker.Stop()
tw.removeAllTask()
return
}
}
}()
}
func (tw *TimeWheel) Stop() {
tw.stopChannel <- true
}
func (tw *TimeWheel) AddTimer(delay time.Duration, data interface{}, job Job) int64 {
id := atomic.AddInt64(&tw.id, 1)
tw.addTaskChannel <- &Task{
key: id,
delay: delay,
data: data,
job: job,
}
return id
}
func (tw *TimeWheel) RemoveTimer(key int64) {
tw.removeTaskChannel <- key
}
测试
testting
package timewheel
import (
"fmt"
"testing"
"time"
)
func TestTimeWheel_Start(t *testing.T) {
tw := NewTimeWheel(time.Millisecond, 3600)
tw.Start()
id := tw.AddTimer(10*time.Millisecond, "123", func(i interface{}) {
fmt.Println(i)
})
fmt.Println("timer id: ", id)
time.Sleep(100*time.Millisecond)
tw.Stop()
}
可以看到时间轮代码还有需要优化的点:
- 每个任务需要启动一个临时协程来处理
- 每次都需要遍历指针所指向的
槽
中的所有的定时任务,时间复杂度为O(n) - 当时间跨度较大的时候,提升单层时间轮的tick可以减少空转次数,但是会导致时间精度较低
由此引入了层级时间轮的概念
层级时间轮
如果任务的时间跨度很大,数量也多,传统的时间轮会造成单个刻度所对应的任务链表过长,对于过期时间久远的任务所需转动的圈数circle
过大。
这时可以时间轮按时间粒度分级。可以理解为一个时钟,含有时分秒指针。最基础的指针是秒针。依次类推…
按照时钟的概念,可以类推到层级时间轮的原理
- 假设我们有3层时间轮,第一层基础时间轮一刻度是1s,有60刻度。
- 第一层指针走完一圈,第二层指针走完一格
- 第二层指针走完一圈,第三层指针走完一格。 …
用golang来实现层级时间轮, 时间轮结构体
type TimeWheel struct {
ticker *time.Ticker
tickMs int64 //一滴答的时间 1ms 可以自定义 我们这里选择使用1ms
wheelSize int64
startMs int64 //开始时间 in millisecond
endMs int64
wheelTime int64 //跑完一圈所需时间
//时间刻度 列表
bucket []*bucket
currentTime int64 //当前时间 in millisecond
overflowWheel unsafe.Pointer // type: *TimingWheel 用于递归过去更高层的时间轮
exitC chan struct{}
}
刻度bucket的struct
type bucket struct {
//过期时间
expiration int64
mu sync.Mutex
//相同过期时间的任务队列
timers *list.List
}
//一个定时任务包含的 task和过期时间
type Timer struct {
expiration int64
//要被执行的任务
task func()
}
接下来的思路跟传统的时间轮是一样的,拨动时间轮的函数 + 添加定时任务函数
开始时间轮
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(time.Duration(tw.tickMs) * time.Millisecond)
go func() {
for {
select {
case t := <- tw.ticker.C:
tw.advanceClock(t.UnixMilli())
case <- tw.exitC:
return
}
}
}()
}
拨动时间轮的函数 tw.advanceClock()
//拨动时钟
func (tw *TimeWheel) advanceClock(expiration int64) {
level := atomic.LoadInt64(&tw.level)
currentTime := truncate(expiration, tw.tickMs)
atomic.StoreInt64(&tw.currentTime, currentTime)
if level == 0 {
virtualID := expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Flush(tw.addOrRun)
} else {
prevflowWheel := atomic.LoadPointer(&tw.prevflowWheel)
if prevflowWheel != nil {
virtualID := expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Flush((*TimeWheel)(prevflowWheel).addOrRun)
}
}
//如果基础的时钟指针转完了一圈,则递归拨动下一级时钟
if currentTime >= tw.endMs {
atomic.StoreInt64(&tw.startMs, currentTime)
atomic.StoreInt64(&tw.endMs, currentTime + tw.wheelTime)
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel != nil {
(*TimeWheel)(overflowWheel).advanceClock(currentTime)
}
}
}
- 传入时间轮当前时间
expiration
- 计算出指针走到当前时间经过了多少滴答
- 如果当前层级为零,则获取任务列表
- 根据滴答数与整个时间轮长度取模,得到当前指针所指向的位置,得到任务列表
tw.bucket[virtualID%tw.wheelSize]
如果当前层级不为零,把快要过期的任务列表添加到上一层级的时间轮中,
添加或者执行任务列表里的任务
如果基础的时钟指针转完了一圈,则递归拨动下一级时钟。回到步骤3.
再开看看添加定时器的函数 AfterFunc
:
AfterFunc
func (tw *TimeWheel) AfterFunc(d time.Duration, f func()) *Timer {
t := &Timer{
expiration: time.Now().UTC().Add(d).UnixMilli(),
task: f,
}
tw.addOrRun(t)
return t
}
func (tw *TimeWheel) addOrRun(t *Timer) {
if !tw.add(t) {
workerID := t.expiration % tw.WorkPool.WorkerPoolSize
//将请求消息发送给任务队列
tw.WorkPool.TaskQueue[workerID] <- t.task
}
}
func (tw *TimeWheel) add(t *Timer) bool {
currentTime := atomic.LoadInt64(&tw.currentTime)
if t.expiration < currentTime + tw.tickMs {
return false
}else if t.expiration < currentTime + tw.wheelTime {
virtualID := t.expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Add(t)
b.SetExpiration(virtualID * tw.tickMs)
}else {
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel == nil {
atomic.CompareAndSwapPointer(
&tw.overflowWheel,
nil,
unsafe.Pointer(newTimingWheel(tw.wheelTime, tw.wheelSize, currentTime)),
)
overflowWheel = atomic.LoadPointer(&tw.overflowWheel)
}
//递归添加到下一级定时器中
(*TimeWheel)(overflowWheel).add(t)
}
return true
}
- 如果
t.expiration<currentTime + tw.tickMs
,则说明定时任务已经过期,直接执行任务 - 如果
t.expiration<currentTime + tw.wheelTime
,则说明当前定时任务还未过期,则添加到当前层级中 - 如果
t.expiration >= currentTime + tw.wheelTime
,则说明当前定时任务不再当前层起中,需要添加到下层级时间轮。
贴完整的代码
timewheel.go
package timewheel
import (
"pro2d/src/common"
"pro2d/src/components/workpool"
"sync/atomic"
"time"
"unsafe"
)
type TimeWheel struct {
ticker *time.Ticker
tickMs int64 //一滴答的时间 1ms 可以自定义 我们这里选择使用1ms
wheelSize int64
startMs int64 //开始时间 in millisecond
endMs int64
wheelTime int64 //跑完一圈所需时间
level int64 //层级
//时间刻度 列表
bucket []*bucket
currentTime int64 //当前时间 in millisecond
prevflowWheel unsafe.Pointer // type: *TimingWheel
overflowWheel unsafe.Pointer // type: *TimingWheel
exitC chan struct{}
WorkPool *workpool.WorkPool
}
func NewTimeWheel(tick time.Duration, wheelSize int64) *TimeWheel {
//转化为毫秒
tickMs := int64(tick / time.Millisecond)
//如果小于零
if tickMs <=0 {
panic("tick must be greater than or equal to 1 ms")
}
startMs := time.Now().UnixMilli() //ms
workpool := workpool.NewWorkPool(common.WorkerPoolSize, common.MaxTaskPerWorker)
return newTimingWheel(tickMs, wheelSize, startMs, 0, nil, workpool)
}
func newTimingWheel(tick, wheelSize int64, start, level int64, prev *TimeWheel, pool *workpool.WorkPool) *TimeWheel {
buckets := make([]*bucket, wheelSize)
for i := range buckets {
buckets[i] = newBucket()
}
return &TimeWheel{
tickMs: tick,
wheelSize: wheelSize,
startMs: start,
endMs: wheelSize * tick + start,
wheelTime: wheelSize * tick,
bucket: buckets,
currentTime: truncate(start, tick),
exitC: make(chan struct{}),
WorkPool: pool,
prevflowWheel: unsafe.Pointer(prev),
level: level,
}
}
func truncate(dst, m int64) int64 {
return dst - dst%m
}
func (tw *TimeWheel) add(t *Timer) bool {
currentTime := atomic.LoadInt64(&tw.currentTime)
if t.expiration < currentTime + tw.tickMs {
return false
}else if t.expiration < currentTime + tw.wheelTime {
virtualID := t.expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Add(t)
b.SetExpiration(virtualID * tw.tickMs)
}else {
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel == nil {
level := atomic.LoadInt64(&tw.level) + 1
atomic.CompareAndSwapPointer(
&tw.overflowWheel,
nil,
unsafe.Pointer(newTimingWheel(tw.wheelTime, tw.wheelSize, currentTime, level, tw , tw.WorkPool)),
)
overflowWheel = atomic.LoadPointer(&tw.overflowWheel)
}
//递归添加到下一级定时器中
(*TimeWheel)(overflowWheel).add(t)
}
return true
}
func (tw *TimeWheel) addOrRun(t *Timer) {
if !tw.add(t) {
workerID := t.expiration % tw.WorkPool.WorkerPoolSize
//将请求消息发送给任务队列
tw.WorkPool.TaskQueue[workerID] <- t.task
}
}
//拨动时钟
func (tw *TimeWheel) advanceClock(expiration int64) {
level := atomic.LoadInt64(&tw.level)
currentTime := truncate(expiration, tw.tickMs)
atomic.StoreInt64(&tw.currentTime, currentTime)
if level == 0 {
virtualID := expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Flush(tw.addOrRun)
} else {
prevflowWheel := atomic.LoadPointer(&tw.prevflowWheel)
if prevflowWheel != nil {
virtualID := expiration / tw.tickMs //需要多少滴答数
b := tw.bucket[virtualID%tw.wheelSize] //pos = 所需滴答数 % wheelSize
b.Flush((*TimeWheel)(prevflowWheel).addOrRun)
}
}
//如果基础的时钟指针转完了一圈,则递归拨动下一级时钟
if currentTime >= tw.endMs {
atomic.StoreInt64(&tw.startMs, currentTime)
atomic.StoreInt64(&tw.endMs, currentTime + tw.wheelTime)
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel != nil {
(*TimeWheel)(overflowWheel).advanceClock(currentTime)
}
}
}
func (tw *TimeWheel) AfterFunc(d time.Duration, f func()) *Timer {
t := &Timer{
expiration: time.Now().UTC().Add(d).UnixMilli(),
task: f,
}
tw.addOrRun(t)
return t
}
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(time.Duration(tw.tickMs) * time.Millisecond)
tw.WorkPool.StartWorkerPool()
go func() {
for {
select {
case t := <- tw.ticker.C:
tw.advanceClock(t.UnixMilli())
case <- tw.exitC:
return
}
}
}()
}
func (tw *TimeWheel) Stop() {
tw.exitC <- struct{}{}
}
bucket.go
package timewheel
import (
"container/list"
"sync"
"sync/atomic"
"unsafe"
)
type bucket struct {
//过期时间
expiration int64
mu sync.Mutex
//相同过期时间的任务队列
timers *list.List
}
func newBucket() *bucket {
return &bucket{
expiration: -1,
mu: sync.Mutex{},
timers: list.New(),
}
}
func (b *bucket) SetExpiration(expiration int64) {
atomic.AddInt64(&b.expiration, expiration)
}
func (b *bucket) Add(t *Timer) {
b.mu.Lock()
defer b.mu.Unlock()
e := b.timers.PushBack(t)
t.setBucket(b)
t.element = e
}
func (b *bucket) Flush(reinsert func(*Timer)) {
b.mu.Lock()
defer b.mu.Unlock()
for e := b.timers.Front(); e != nil; {
next := e.Next()
t := e.Value.(*Timer)
b.remove(t)
reinsert(t)
e = next
}
}
func (b *bucket) remove(t *Timer) bool {
if t.getBucket() != b {
return false
}
b.timers.Remove(t.element)
t.setBucket(nil)
t.element = nil
return true
}
func (b *bucket) Remove(t *Timer) bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.remove(t)
}
type Timer struct {
expiration int64
//要被执行的任务
task func()
b unsafe.Pointer
element *list.Element
}
func (t *Timer) setBucket(b *bucket) {
atomic.StorePointer(&t.b, unsafe.Pointer(b))
}
func (t *Timer) getBucket() *bucket {
return (*bucket)(atomic.LoadPointer(&t.b))
}
timewheel_test.go
package timewheel
import (
"fmt"
"testing"
"time"
)
var tw *TimeWheel
func Add() {
fmt.Println("ADD : 123456")
tw.AfterFunc(6*time.Second, Add)
}
func Add1() {
fmt.Println("GET : 78901112")
tw.AfterFunc(9*time.Second, Add1)
}
func TestTimeWheel_AfterFunc(t *testing.T) {
tw = NewTimeWheel(time.Second, 5)
tw.Start()
defer tw.Stop()
Add()
Add1()
time.Sleep(time.Second * 200)
}
workpool
package workpool
type Job func()
type WorkPool struct {
WorkerPoolSize int64
MaxTaskPerWorker int64
TaskQueue []chan Job
}
func NewWorkPool(poolSize, maxTaskSize int64) *WorkPool {
return &WorkPool{
WorkerPoolSize: poolSize,
MaxTaskPerWorker: maxTaskSize,
TaskQueue: make([]chan Job, poolSize),
}
}
//StartOneWorker 启动一个Worker工作流程
func (wp *WorkPool) StartOneWorker(workerID int, taskQueue chan Job) {
//不断的等待队列中的消息
for {
select {
case job := <-taskQueue:
_ = workerID
job()
}
}
}
func (wp *WorkPool) StartWorkerPool() {
//遍历需要启动worker的数量,依此启动
for i := 0; i < int(wp.WorkerPoolSize); i++ {
//一个worker被启动
//给当前worker对应的任务队列开辟空间
wp.TaskQueue[i] = make(chan Job, wp.MaxTaskPerWorker)
//启动当前Worker,阻塞的等待对应的任务队列是否有消息传递进来
go wp.StartOneWorker(i, wp.TaskQueue[i])
}
}
总结
传统时间轮的问题
- 当时间跨度较大的时候,提升单层时间轮的tick可以减少空转次数,但是会导致时间精度较低.
层级时间轮
- 层级时间轮既可以避免精度降低,又避免了指针空转的次数。
- 上文中可以看到,使用单协程来调度 + 协程池的方式来处理业务
- 应使用单例,全局化一个定时器
什么时候使用层级定时器
- 当时间跨度很大的时候
还有什么问题
- 当任务较少的时候,可以看到,虽然是单协程调度,但是避免不了指针空转的情况。 需求是那种可以知道任务快过期了抛出来给定时器。
参考
--完--
- 原文作者: 留白
- 原文链接: https://zfunnily.github.io/2022/03/timer/
- 更新时间:2024-04-16 01:01:05
- 本文声明:转载请标记原文作者及链接