Go+Redis实现常见限流算法的示例代码
限流是项目中经常需要使用到的一种工具,一般用于限制用户的请求的频率,也可以避免瞬间流量过大导致系统崩溃,或者稳定消息处理速率。并且有时候我们还需要使用到分布式限流,常见的实现方式是使用Redis作为中心存储。
这个文章主要是使用Go+Redis实现常见的限流算法,如果需要了解每种限流算法的原理可以阅读文章 Go实现常见的限流算法
下面的代码使用到了go-redis客户端
固定窗口
使用Redis实现固定窗口比较简单,主要是由于固定窗口同时只会存在一个窗口,所以我们可以在第一次进入窗口时使用pexpire
命令设置过期时间为窗口时间大小,这样窗口会随过期时间而失效,同时我们使用incr
命令增加窗口计数。
因为我们需要在counter==1
的时候设置窗口的过期时间,为了保证原子性,我们使用简单的Lua
脚本实现。
- const fixedwindowLimiterTryAcquireRedisScript = `
- — ARGV[1]: 窗口时间大小
- — ARGV[2]: 窗口请求上限
- local window = tonumber(ARGV[1])
- local limit = tonumber(ARGV[2])
- — 获取原始值
- local counter = tonumber(redis.call(“get”, KEYS[1]))
- if counter == nil then
- counter = 0
- end
- — 若到达窗口请求上限,请求失败
- if counter >= limit then
- return 0
- end
- — 窗口值+1
- redis.call(“incr”, KEYS[1])
- if counter == 0 then
- redis.call(“pexpire”, KEYS[1], window)
- end
- return 1
- `
- package redis
- import (
- “context”
- “errors”
- “github.com/go-redis/redis/v8″
- “time”
- )
- // FixedWindowLimiter 固定窗口限流器
- type FixedWindowLimiter struct {
- limit int // 窗口请求上限
- window int // 窗口时间大小
- client *redis.Client // Redis客户端
- script *redis.Script // TryAcquire脚本
- }
- func NewFixedWindowLimiter(client *redis.Client, limit int, window time.Duration) (*FixedWindowLimiter, error) {
- // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
- if window%time.Millisecond != 0 {
- return nil, errors.New(“the window uint must not be less than millisecond”)
- }
- return &FixedWindowLimiter{
- limit: limit,
- window: int(window / time.Millisecond),
- client: client,
- script: redis.NewScript(fixedWindowLimiterTryAcquireRedisScript),
- }, nil
- }
- func (l *FixedWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
- success, err := l.script.Run(ctx, l.client, []string{resource}, l.window, l.limit).Bool()
- if err != nil {
- return err
- }
- // 若到达窗口请求上限,请求失败
- if !success {
- return ErrAcquireFailed
- }
- return nil
- }
滑动窗口
hash实现
我们使用Redis的hash
存储每个小窗口的计数,每次请求会把所有有效窗口
的计数累加到count
,使用hdel
删除失效窗口,最后判断窗口的总计数是否大于上限。
我们基本上把所有的逻辑都放到Lua脚本里面,其中大头是对hash
的遍历,时间复杂度是O(N),N是小窗口数量,所以小窗口数量最好不要太多。
- const slidingWindowLimiterTryAcquireRedisScriptHashImpl = `
- — ARGV[1]: 窗口时间大小
- — ARGV[2]: 窗口请求上限
- — ARGV[3]: 当前小窗口值
- — ARGV[4]: 起始小窗口值
- local window = tonumber(ARGV[1])
- local limit = tonumber(ARGV[2])
- local currentSmallWindow = tonumber(ARGV[3])
- local startSmallWindow = tonumber(ARGV[4])
- — 计算当前窗口的请求总数
- local counters = redis.call(“hgetall”, KEYS[1])
- local count = 0
- for i = 1, #(counters) / 2 do
- local smallWindow = tonumber(counters[i * 2 – 1])
- local counter = tonumber(counters[i * 2])
- if smallWindow < startSmallWindow then
- redis.call(“hdel”, KEYS[1], smallWindow)
- else
- count = count + counter
- end
- end
- — 若到达窗口请求上限,请求失败
- if count >= limit then
- return 0
- end
- — 若没到窗口请求上限,当前小窗口计数器+1,请求成功
- redis.call(“hincrby”, KEYS[1], currentSmallWindow, 1)
- redis.call(“pexpire”, KEYS[1], window)
- return 1
- `
- package redis
- import (
- “context”
- “errors”
- “github.com/go-redis/redis/v8”
- “time”
- )
- // SlidingWindowLimiter 滑动窗口限流器
- type SlidingWindowLimiter struct {
- limit int // 窗口请求上限
- window int64 // 窗口时间大小
- smallWindow int64 // 小窗口时间大小
- smallWindows int64 // 小窗口数量
- client *redis.Client // Redis客户端
- script *redis.Script // TryAcquire脚本
- }
- func NewSlidingWindowLimiter(client *redis.Client, limit int, window, smallWindow time.Duration) (
- *SlidingWindowLimiter, error) {
- // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
- if window%time.Millisecond != 0 || smallWindow%time.Millisecond != 0 {
- return nil, errors.New(“the window uint must not be less than millisecond”)
- }
- // 窗口时间必须能够被小窗口时间整除
- if window%smallWindow != 0 {
- return nil, errors.New(“window cannot be split by integers”)
- }
- return &SlidingWindowLimiter{
- limit: limit,
- window: int64(window / time.Millisecond),
- smallWindow: int64(smallWindow / time.Millisecond),
- smallWindows: int64(window / smallWindow),
- client: client,
- script: redis.NewScript(slidingWindowLimiterTryAcquireRedisScriptHashImpl),
- }, nil
- }
- func (l *SlidingWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
- // 获取当前小窗口值
- currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
- // 获取起始小窗口值
- startSmallWindow := currentSmallWindow – l.smallWindow*(l.smallWindows–1)
- success, err := l.script.Run(
- ctx, l.client, []string{resource}, l.window, l.limit, currentSmallWindow, startSmallWindow).Bool()
- if err != nil {
- return err
- }
- // 若到达窗口请求上限,请求失败
- if !success {
- return ErrAcquireFailed
- }
- return nil
- }
list实现
如果小窗口数量特别多,可以使用list
优化时间复杂度,list的结构是:
[counter, smallWindow1, count1, smallWindow2, count2, smallWindow3, count3...]
也就是我们使用list的第一个元素存储计数器,每个窗口用两个元素表示,第一个元素表示小窗口值,第二个元素表示这个小窗口的计数。不直接把小窗口值和计数放到一个元素里是因为Redis Lua脚本里没有分割字符串的函数。
具体操作流程:
1.获取list长度
2.如果长度是0,设置counter,长度+1
3.如果长度大于1,获取第二第三个元素
如果该值小于起始小窗口值,counter-第三个元素的值,删除第二第三个元素,长度-2
4.如果counter大于等于limit,请求失败
5.如果长度大于1,获取倒数第二第一个元素
- 如果倒数第二个元素小窗口值大于等于当前小窗口值,表示当前请求因为网络延迟的问题,到达服务器的时候,窗口已经过时了,把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
- 否则,添加新的窗口值,添加新的计数(1),更新过期时间
6.否则,添加新的窗口值,添加新的计数(1),更新过期时间
7.counter + 1
8.返回成功
- const slidingWindowLimiterTryAcquireRedisScriptListImpl = `
- — ARGV[1]: 窗口时间大小
- — ARGV[2]: 窗口请求上限
- — ARGV[3]: 当前小窗口值
- — ARGV[4]: 起始小窗口值
- local window = tonumber(ARGV[1])
- local limit = tonumber(ARGV[2])
- local currentSmallWindow = tonumber(ARGV[3])
- local startSmallWindow = tonumber(ARGV[4])
- — 获取list长度
- local len = redis.call(“llen”, KEYS[1])
- — 如果长度是0,设置counter,长度+1
- local counter = 0
- if len == 0 then
- redis.call(“rpush”, KEYS[1], 0)
- redis.call(“pexpire”, KEYS[1], window)
- len = len + 1
- else
- — 如果长度大于1,获取第二第个元素
- local smallWindow1 = tonumber(redis.call(“lindex”, KEYS[1], 1))
- counter = tonumber(redis.call(“lindex”, KEYS[1], 0))
- — 如果该值小于起始小窗口值
- if smallWindow1 < startSmallWindow then
- local count1 = redis.call(“lindex”, KEYS[1], 2)
- — counter-第三个元素的值
- counter = counter – count1
- — 长度-2
- len = len – 2
- — 删除第二第三个元素
- redis.call(“lrem”, KEYS[1], 1, smallWindow1)
- redis.call(“lrem”, KEYS[1], 1, count1)
- end
- end
- — 若到达窗口请求上限,请求失败
- if counter >= limit then
- return 0
- end
- — 如果长度大于1,获取倒数第二第一个元素
- if len > 1 then
- local smallWindown = tonumber(redis.call(“lindex”, KEYS[1], -2))
- — 如果倒数第二个元素小窗口值大于等于当前小窗口值
- if smallWindown >= currentSmallWindow then
- — 把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
- local countn = redis.call(“lindex”, KEYS[1], -1)
- redis.call(“lset”, KEYS[1], -1, countn + 1)
- else
- — 否则,添加新的窗口值,添加新的计数(1),更新过期时间
- redis.call(“rpush”, KEYS[1], currentSmallWindow, 1)
- redis.call(“pexpire”, KEYS[1], window)
- end
- else
- — 否则,添加新的窗口值,添加新的计数(1),更新过期时间
- redis.call(“rpush”, KEYS[1], currentSmallWindow, 1)
- redis.call(“pexpire”, KEYS[1], window)
- end
- — counter + 1并更新
- redis.call(“lset”, KEYS[1], 0, counter + 1)
- return 1
- `
算法都是操作list
头部或者尾部,所以时间复杂度接近O(1)
漏桶算法
漏桶需要保存当前水位和上次放水时间,因此我们使用hash
来保存这两个值。
- const leakyBucketLimiterTryAcquireRedisScript = `
- — ARGV[1]: 最高水位
- — ARGV[2]: 水流速度/秒
- — ARGV[3]: 当前时间(秒)
- local peakLevel = tonumber(ARGV[1])
- local currentVelocity = tonumber(ARGV[2])
- local now = tonumber(ARGV[3])
- local lastTime = tonumber(redis.call(“hget”, KEYS[1], “lastTime”))
- local currentLevel = tonumber(redis.call(“hget”, KEYS[1], “currentLevel”))
- — 初始化
- if lastTime == nil then
- lastTime = now
- currentLevel = 0
- redis.call(“hmset”, KEYS[1], “currentLevel”, currentLevel, “lastTime”, lastTime)
- end
- — 尝试放水
- — 距离上次放水的时间
- local interval = now – lastTime
- if interval > 0 then
- — 当前水位-距离上次放水的时间(秒)*水流速度
- local newLevel = currentLevel – interval * currentVelocity
- if newLevel < 0 then
- newLevel = 0
- end
- currentLevel = newLevel
- redis.call(“hmset”, KEYS[1], “currentLevel”, newLevel, “lastTime”, now)
- end
- — 若到达最高水位,请求失败
- if currentLevel >= peakLevel then
- return 0
- end
- — 若没有到达最高水位,当前水位+1,请求成功
- redis.call(“hincrby”, KEYS[1], “currentLevel”, 1)
- redis.call(“expire”, KEYS[1], peakLevel / currentVelocity)
- return 1
- `
- package redis
- import (
- “context”
- “github.com/go-redis/redis/v8”
- “time”
- )
- // LeakyBucketLimiter 漏桶限流器
- type LeakyBucketLimiter struct {
- peakLevel int // 最高水位
- currentVelocity int // 水流速度/秒
- client *redis.Client // Redis客户端
- script *redis.Script // TryAcquire脚本
- }
- func NewLeakyBucketLimiter(client *redis.Client, peakLevel, currentVelocity int) *LeakyBucketLimiter {
- return &LeakyBucketLimiter{
- peakLevel: peakLevel,
- currentVelocity: currentVelocity,
- client: client,
- script: redis.NewScript(leakyBucketLimiterTryAcquireRedisScript),
- }
- }
- func (l *LeakyBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
- // 当前时间
- now := time.Now().Unix()
- success, err := l.script.Run(ctx, l.client, []string{resource}, l.peakLevel, l.currentVelocity, now).Bool()
- if err != nil {
- return err
- }
- // 若到达窗口请求上限,请求失败
- if !success {
- return ErrAcquireFailed
- }
- return nil
- }
令牌桶
令牌桶可以看作是漏桶的相反算法,它们一个是把水倒进桶里,一个是从桶里获取令牌。
- const tokenBucketLimiterTryAcquireRedisScript = `
- — ARGV[1]: 容量
- — ARGV[2]: 发放令牌速率/秒
- — ARGV[3]: 当前时间(秒)
- local capacity = tonumber(ARGV[1])
- local rate = tonumber(ARGV[2])
- local now = tonumber(ARGV[3])
- local lastTime = tonumber(redis.call(“hget”, KEYS[1], “lastTime”))
- local currentTokens = tonumber(redis.call(“hget”, KEYS[1], “currentTokens”))
- — 初始化
- if lastTime == nil then
- lastTime = now
- currentTokens = capacity
- redis.call(“hmset”, KEYS[1], “currentTokens”, currentTokens, “lastTime”, lastTime)
- end
- — 尝试发放令牌
- — 距离上次发放令牌的时间
- local interval = now – lastTime
- if interval > 0 then
- — 当前令牌数量+距离上次发放令牌的时间(秒)*发放令牌速率
- local newTokens = currentTokens + interval * rate
- if newTokens > capacity then
- newTokens = capacity
- end
- currentTokens = newTokens
- redis.call(“hmset”, KEYS[1], “currentTokens”, newTokens, “lastTime”, now)
- end
- — 如果没有令牌,请求失败
- if currentTokens == 0 then
- return 0
- end
- — 果有令牌,当前令牌-1,请求成功
- redis.call(“hincrby”, KEYS[1], “currentTokens”, -1)
- redis.call(“expire”, KEYS[1], capacity / rate)
- return 1
- `
- package redis
- import (
- “context”
- “github.com/go-redis/redis/v8”
- “time”
- )
- // TokenBucketLimiter 令牌桶限流器
- type TokenBucketLimiter struct {
- capacity int // 容量
- rate int // 发放令牌速率/秒
- client *redis.Client // Redis客户端
- script *redis.Script // TryAcquire脚本
- }
- func NewTokenBucketLimiter(client *redis.Client, capacity, rate int) *TokenBucketLimiter {
- return &TokenBucketLimiter{
- capacity: capacity,
- rate: rate,
- client: client,
- script: redis.NewScript(tokenBucketLimiterTryAcquireRedisScript),
- }
- }
- func (l *TokenBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
- // 当前时间
- now := time.Now().Unix()
- success, err := l.script.Run(ctx, l.client, []string{resource}, l.capacity, l.rate, now).Bool()
- if err != nil {
- return err
- }
- // 若到达窗口请求上限,请求失败
- if !success {
- return ErrAcquireFailed
- }
- return nil
- }
滑动日志
算法流程与滑动窗口相同,只是它可以指定多个策略,同时在请求失败的时候,需要通知调用方是被哪个策略所拦截。
- const slidingLogLimiterTryAcquireRedisScriptHashImpl = `
- — ARGV[1]: 当前小窗口值
- — ARGV[2]: 第一个策略的窗口时间大小
- — ARGV[i * 2 + 1]: 每个策略的起始小窗口值
- — ARGV[i * 2 + 2]: 每个策略的窗口请求上限
- local currentSmallWindow = tonumber(ARGV[1])
- — 第一个策略的窗口时间大小
- local window = tonumber(ARGV[2])
- — 第一个策略的起始小窗口值
- local startSmallWindow = tonumber(ARGV[3])
- local strategiesLen = #(ARGV) / 2 – 1
- — 计算每个策略当前窗口的请求总数
- local counters = redis.call(“hgetall”, KEYS[1])
- local counts = {}
- — 初始化counts
- for j = 1, strategiesLen do
- counts[j] = 0
- end
- for i = 1, #(counters) / 2 do
- local smallWindow = tonumber(counters[i * 2 – 1])
- local counter = tonumber(counters[i * 2])
- if smallWindow < startSmallWindow then
- redis.call(“hdel”, KEYS[1], smallWindow)
- else
- for j = 1, strategiesLen do
- if smallWindow >= tonumber(ARGV[j * 2 + 1]) then
- counts[j] = counts[j] + counter
- end
- end
- end
- end
- — 若到达对应策略窗口请求上限,请求失败,返回违背的策略下标
- for i = 1, strategiesLen do
- if counts[i] >= tonumber(ARGV[i * 2 + 2]) then
- return i – 1
- end
- end
- — 若没到窗口请求上限,当前小窗口计数器+1,请求成功
- redis.call(“hincrby”, KEYS[1], currentSmallWindow, 1)
- redis.call(“pexpire”, KEYS[1], window)
- return -1
- `
- package redis
- import (
- “context”
- “errors”
- “fmt”
- “github.com/go-redis/redis/v8”
- “sort”
- “time”
- )
- // ViolationStrategyError 违背策略错误
- type ViolationStrategyError struct {
- Limit int // 窗口请求上限
- Window time.Duration // 窗口时间大小
- }
- func (e *ViolationStrategyError) Error() string {
- return fmt.Sprintf(“violation strategy that limit = %d and window = %d”, e.Limit, e.Window)
- }
- // SlidingLogLimiterStrategy 滑动日志限流器的策略
- type SlidingLogLimiterStrategy struct {
- limit int // 窗口请求上限
- window int64 // 窗口时间大小
- smallWindows int64 // 小窗口数量
- }
- func NewSlidingLogLimiterStrategy(limit int, window time.Duration) *SlidingLogLimiterStrategy {
- return &SlidingLogLimiterStrategy{
- limit: limit,
- window: int64(window),
- }
- }
- // SlidingLogLimiter 滑动日志限流器
- type SlidingLogLimiter struct {
- strategies []*SlidingLogLimiterStrategy // 滑动日志限流器策略列表
- smallWindow int64 // 小窗口时间大小
- client *redis.Client // Redis客户端
- script *redis.Script // TryAcquire脚本
- }
- func NewSlidingLogLimiter(client *redis.Client, smallWindow time.Duration, strategies …*SlidingLogLimiterStrategy) (
- *SlidingLogLimiter, error) {
- // 复制策略避免被修改
- strategies = append(make([]*SlidingLogLimiterStrategy, 0, len(strategies)), strategies…)
- // 不能不设置策略
- if len(strategies) == 0 {
- return nil, errors.New(“must be set strategies”)
- }
- // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
- if smallWindow%time.Millisecond != 0 {
- return nil, errors.New(“the window uint must not be less than millisecond”)
- }
- smallWindow = smallWindow / time.Millisecond
- for _, strategy := range strategies {
- if strategy.window%int64(time.Millisecond) != 0 {
- return nil, errors.New(“the window uint must not be less than millisecond”)
- }
- strategy.window = strategy.window / int64(time.Millisecond)
- }
- // 排序策略,窗口时间大的排前面,相同窗口上限大的排前面
- sort.Slice(strategies, func(i, j int) bool {
- a, b := strategies[i], strategies[j]
- if a.window == b.window {
- return a.limit > b.limit
- }
- return a.window > b.window
- })
- for i, strategy := range strategies {
- // 随着窗口时间变小,窗口上限也应该变小
- if i > 0 {
- if strategy.limit >= strategies[i–1].limit {
- return nil, errors.New(“the smaller window should be the smaller limit”)
- }
- }
- // 窗口时间必须能够被小窗口时间整除
- if strategy.window%int64(smallWindow) != 0 {
- return nil, errors.New(“window cannot be split by integers”)
- }
- strategy.smallWindows = strategy.window / int64(smallWindow)
- }
- return &SlidingLogLimiter{
- strategies: strategies,
- smallWindow: int64(smallWindow),
- client: client,
- script: redis.NewScript(slidingLogLimiterTryAcquireRedisScriptHashImpl),
- }, nil
- }
- func (l *SlidingLogLimiter) TryAcquire(ctx context.Context, resource string) error {
- // 获取当前小窗口值
- currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
- args := make([]interface{}, len(l.strategies)*2+2)
- args[0] = currentSmallWindow
- args[1] = l.strategies[0].window
- // 获取每个策略的起始小窗口值
- for i, strategy := range l.strategies {
- args[i*2+2] = currentSmallWindow – l.smallWindow*(strategy.smallWindows–1)
- args[i*2+3] = strategy.limit
- }
- index, err := l.script.Run(
- ctx, l.client, []string{resource}, args…).Int()
- if err != nil {
- return err
- }
- // 若到达窗口请求上限,请求失败
- if index != –1 {
- return &ViolationStrategyError{
- Limit: l.strategies[index].limit,
- Window: time.Duration(l.strategies[index].window),
- }
- }
- return nil
- }
总结
由于Redis拥有丰富而且高性能的数据类型,因此使用Redis实现限流算法并不困难,但是每个算法都需要编写Lua脚本,所以如果不熟悉Lua可能会踩一些坑。
需要完整代码和测试代码可以查看:github.com/jiaxwu/limiter/tree/main/redis
以上就是Go+Redis实现常见限流算法的示例代码的详细内容,更多关于Go Redis限流算法的资料请关注我们其它相关文章!
发表评论