// 实现 (Implementation)
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

/**
 * Non-blocking leaky-bucket rate limiter.
 *
 * <p>Every acquired permit adds one unit to a bucket of fixed capacity ({@code bucket}); the bucket
 * drains {@code rate} units per {@code rateUnit}, credited in whole {@code rateUnit} steps. A call
 * to {@link #tryAcquire(int)} succeeds only if the requested units still fit in the bucket; it never
 * blocks.
 *
 * <p>Thread-safe: all mutable state is read and written while holding this instance's monitor.
 *
 * @author <a href="mailto:wf2311@163.com">wf2311</a>
 * @since 2021/4/25 11:52.
 */
public class LeakyBucketRateLimiter {
    /** Capacity of the bucket. */
    private final long bucket;

    /** Units currently in the bucket; always within {@code [0, bucket]}. Guarded by {@code this}. */
    private long used;

    /** Units drained per {@code rateUnit}. */
    private final int rate;

    /** Length of one {@code rateUnit} in nanoseconds (always >= 1). */
    private final long nanosPerUnit;

    /** Monotonic nanosecond time source; {@link System#nanoTime()} unless injected for tests. */
    private final LongSupplier nanoClock;

    /**
     * Clock reading (nanoseconds) from which the next not-yet-credited drain interval is measured.
     * Guarded by {@code this}.
     */
    private long lastRefreshTime;

    /**
     * Creates a limiter with an empty bucket.
     *
     * @param bucket   bucket capacity, must be {@code >= 0}
     * @param rate     units drained per {@code rateUnit}, must be {@code >= 0} (0 means never drains)
     * @param rateUnit unit in which {@code rate} is expressed; every {@link TimeUnit} is supported
     * @return a new limiter
     * @throws IllegalArgumentException if {@code bucket} or {@code rate} is negative
     * @throws NullPointerException     if {@code rateUnit} is null
     */
    public static LeakyBucketRateLimiter create(long bucket, int rate, TimeUnit rateUnit) {
        return new LeakyBucketRateLimiter(bucket, rate, rateUnit, System::nanoTime);
    }

    /**
     * Same as {@link #create(long, int, TimeUnit)} but with an injectable monotonic nanosecond
     * clock, so that time can be controlled deterministically (e.g. in tests).
     */
    static LeakyBucketRateLimiter create(
            long bucket, int rate, TimeUnit rateUnit, LongSupplier nanoClock) {
        return new LeakyBucketRateLimiter(bucket, rate, rateUnit, nanoClock);
    }

    private LeakyBucketRateLimiter(
            long bucket, int rate, TimeUnit rateUnit, LongSupplier nanoClock) {
        if (bucket < 0) {
            throw new IllegalArgumentException("bucket must be >= 0: " + bucket);
        }
        if (rate < 0) {
            throw new IllegalArgumentException("rate must be >= 0: " + rate);
        }
        Objects.requireNonNull(rateUnit, "rateUnit");
        this.nanoClock = Objects.requireNonNull(nanoClock, "nanoClock");
        this.bucket = bucket;
        this.rate = rate;
        // toNanos(1) is >= 1 for every TimeUnit, so later divisions by it are safe.
        this.nanosPerUnit = rateUnit.toNanos(1);
        this.lastRefreshTime = nanoClock.getAsLong();
    }

    /** Returns {@code a * b} for non-negative operands, clamped to {@link Long#MAX_VALUE}. */
    private static long saturatedMultiply(long a, long b) {
        if (b == 0) {
            return 0;
        }
        return a > Long.MAX_VALUE / b ? Long.MAX_VALUE : a * b;
    }

    /** Drains the bucket for every whole {@code rateUnit} elapsed since the last credited step. */
    private void refreshBucketUsed() {
        long now = nanoClock.getAsLong();
        if (used == 0) {
            // Nothing to drain: restart the interval so idle time is not credited to future water.
            lastRefreshTime = now;
            return;
        }
        long intervals = (now - lastRefreshTime) / nanosPerUnit;
        if (intervals <= 0) {
            // Less than one full unit elapsed: keep lastRefreshTime so partial time accumulates
            // across calls instead of being discarded.
            return;
        }
        long leaked = saturatedMultiply(intervals, rate);
        if (leaked >= used) {
            used = 0;
            lastRefreshTime = now;
        } else {
            used -= leaked;
            // Advance only by the whole intervals consumed; the remainder carries to the next call.
            // intervals * nanosPerUnit <= now - lastRefreshTime, so this cannot overflow.
            lastRefreshTime += intervals * nanosPerUnit;
        }
    }

    /**
     * Tries to acquire a single permit.
     *
     * @return {@code true} if the permit was acquired, {@code false} if the bucket is full
     */
    public boolean tryAcquire() {
        return tryAcquire(1);
    }

    /**
     * Tries to add {@code n} units to the bucket without blocking.
     *
     * @param n number of units to acquire, must be {@code >= 0}
     * @return {@code true} if the units fit and were added, {@code false} otherwise
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public synchronized boolean tryAcquire(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be >= 0: " + n);
        }
        // Refresh the bucket's usage before checking capacity.
        refreshBucketUsed();
        // Compare against the remaining room; 0 <= used <= bucket, so this cannot overflow,
        // unlike used + n <= bucket.
        if (n <= bucket - used) {
            used += n;
            return true;
        }
        return false;
    }
}