TokenAwareRateLimiter.java

package com.yumu.noveltranslator.adapter.out.translate;

import com.yumu.noveltranslator.properties.TranslationLimitProperties;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;

import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Token-aware rate limiter: throttles each user against a per-tier TPM (tokens per minute)
 * quota so that a single user cannot exhaust the shared LLM capacity pool.
 *
 * <p>Each user key gets its own 60-second sliding-window counter, created lazily on the first
 * {@link #tryConsume} call and dropped by {@link #cleanupIdleCounters()} once it has been idle.
 */
@Service
@Slf4j
public class TokenAwareRateLimiter {

    /** Sliding-window length in seconds; the window keeps one bucket per second. */
    private static final int WINDOW_SECONDS = 60;

    /**
     * Counters not touched by {@code tryConsume} for longer than this are removed; this is also
     * the period of the cleanup task.
     */
    private static final long IDLE_EVICTION_MILLIS = 300_000L;

    private final TranslationLimitProperties limitProperties;
    private final ConcurrentHashMap<String, SlidingWindowCounter> userCounters = new ConcurrentHashMap<>();

    /**
     * @param limitProperties source of the per-tier TPM limits, read on every {@link #tryConsume}
     */
    public TokenAwareRateLimiter(TranslationLimitProperties limitProperties) {
        this.limitProperties = limitProperties;
    }

    /**
     * Estimates the number of LLM tokens in {@code text}.
     *
     * <p>Heuristic: CJK characters (ideographs, kana, Hangul syllables) count as about 0.5 token
     * each and every other character as about 1.3 tokens. Characters are counted per Unicode code
     * point, so a supplementary-plane character (a UTF-16 surrogate pair) is counted once.
     *
     * @param text the text to measure; may be {@code null}
     * @return the estimated token count rounded up, or {@code 0} for {@code null}/empty text
     */
    public static int estimateTokens(String text) {
        if (text == null || text.isEmpty()) {
            return 0;
        }
        int codePointCount = text.codePointCount(0, text.length());
        int cjkCount = (int) text.codePoints().filter(TokenAwareRateLimiter::isCJK).count();
        int westernCount = codePointCount - cjkCount;
        // Computed in double; the (int) cast saturates at Integer.MAX_VALUE instead of wrapping.
        return (int) Math.ceil(cjkCount * 0.5 + westernCount * 1.3);
    }

    /** Returns whether {@code codePoint} falls in one of the CJK ranges the estimate discounts. */
    private static boolean isCJK(int codePoint) {
        return (codePoint >= 0x4E00 && codePoint <= 0x9FFF)      // CJK Unified Ideographs
            || (codePoint >= 0x3400 && codePoint <= 0x4DBF)      // CJK Unified Ideographs Extension A
            || (codePoint >= 0x3040 && codePoint <= 0x309F)      // Hiragana
            || (codePoint >= 0x30A0 && codePoint <= 0x30FF)      // Katakana
            || (codePoint >= 0xAC00 && codePoint <= 0xD7AF)      // Hangul Syllables
            || (codePoint >= 0x20000 && codePoint <= 0x3FFFF);   // Supplementary/Tertiary Ideographic Planes (Ext. B+)
    }

    /**
     * Tries to reserve {@code tokenCount} tokens from the user's quota for the current 60-second
     * window. On success the tokens are recorded; on failure nothing is recorded.
     *
     * @param userId     key of the per-user counter; must not be {@code null}
     *                   ({@link ConcurrentHashMap} rejects null keys)
     * @param userLevel  user tier that selects the TPM limit; {@code null} is treated as anonymous
     * @param tokenCount number of tokens to reserve; must be {@code >= 0}
     * @return {@code true} if the tokens fit under the limit and were recorded, else {@code false}
     * @throws IllegalArgumentException if {@code tokenCount} is negative
     */
    public boolean tryConsume(String userId, String userLevel, int tokenCount) {
        if (tokenCount < 0) {
            // A negative cost would subtract from the window and let a caller bypass the quota.
            throw new IllegalArgumentException("tokenCount must be >= 0: " + tokenCount);
        }
        int tpmLimit = getTpmLimit(userLevel);
        SlidingWindowCounter counter = userCounters.computeIfAbsent(userId, k -> new SlidingWindowCounter());
        return counter.tryConsume(tokenCount, tpmLimit);
    }

    /**
     * Gives back tokens that were previously reserved with {@link #tryConsume}.
     *
     * <p>Tokens are removed from the most recent in-window buckets first. Any amount beyond what
     * is still inside the window is dropped. This is a no-op when the user has no counter or
     * {@code tokenCount <= 0}.
     *
     * @param userId     key of the per-user counter
     * @param tokenCount number of tokens to return
     */
    public void refund(String userId, int tokenCount) {
        SlidingWindowCounter counter = userCounters.get(userId);
        if (counter != null) {
            counter.refund(tokenCount);
        }
    }

    /**
     * Maps a user tier name (case-insensitive) to its TPM limit. Unknown tiers get the free limit;
     * {@code null} gets the anonymous limit.
     */
    private int getTpmLimit(String userLevel) {
        if (userLevel == null) {
            return limitProperties.getAnonymousTpmLimit();
        }
        // Locale.ROOT: default-locale lower-casing (e.g. Turkish) would turn "PREMIUM" into "premıum".
        return switch (userLevel.toLowerCase(Locale.ROOT)) {
            case "max", "premium" -> limitProperties.getMaxTpmLimit();
            case "pro" -> limitProperties.getProTpmLimit();
            case "anonymous" -> limitProperties.getAnonymousTpmLimit();
            default -> limitProperties.getFreeTpmLimit();
        };
    }

    /**
     * Periodically removes counters whose last {@code tryConsume} is older than
     * {@link #IDLE_EVICTION_MILLIS}, so the map does not grow without bound.
     *
     * <p>NOTE(review): a thread that obtained a counter via {@code computeIfAbsent} just before
     * that counter is removed here records its tokens on the detached counter. That one request's
     * tokens are then not counted against the user's next (fresh) counter.
     */
    @Scheduled(fixedRate = IDLE_EVICTION_MILLIS)
    public void cleanupIdleCounters() {
        long now = System.currentTimeMillis();
        int removed = 0;
        for (var it = userCounters.entrySet().iterator(); it.hasNext(); ) {
            var entry = it.next();
            if (now - entry.getValue().getLastAccessTime() > IDLE_EVICTION_MILLIS) {
                it.remove();
                removed++;
            }
        }
        if (removed > 0) {
            log.info("TokenAwareRateLimiter 清理了 {} 个空闲计数器", removed);
        }
    }

    /**
     * Per-user token counter over a 60-second sliding window.
     *
     * <p>It keeps one bucket per second, indexed by {@code epochSecond % WINDOW_SECONDS}. Buckets
     * older than the window are cleared lazily whenever the counter is used. All bucket state is
     * guarded by {@code this}.
     */
    private static class SlidingWindowCounter {
        /** Tokens recorded in each one-second bucket. */
        private final int[] buckets = new int[WINDOW_SECONDS];
        /** Epoch second each bucket currently represents ({@code 0} = unused/cleared). */
        private final long[] bucketSeconds = new long[WINDOW_SECONDS];
        /** Time of the last tryConsume; volatile because the cleanup task reads it without the lock. */
        private volatile long lastAccessTime;

        SlidingWindowCounter() {
            this.lastAccessTime = System.currentTimeMillis();
        }

        /**
         * Records {@code cost} tokens in the current second's bucket if the in-window total plus
         * {@code cost} does not exceed {@code limit}.
         *
         * @return {@code true} if recorded, {@code false} if it would exceed the limit
         */
        synchronized boolean tryConsume(int cost, int limit) {
            long nowMillis = System.currentTimeMillis();
            lastAccessTime = nowMillis;
            long nowSec = nowMillis / 1000;
            long current = evictExpired(nowSec);
            // long arithmetic: with int, a huge cost could wrap negative and slip past the check.
            if (current + cost > limit) {
                return false;
            }
            int idx = bucketIndex(nowSec);
            buckets[idx] += cost;
            bucketSeconds[idx] = nowSec;
            return true;
        }

        /**
         * Removes up to {@code cost} tokens from the in-window buckets, newest second first.
         * The buckets are the single source of truth for the window total, so the refund survives
         * the next {@link #evictExpired} pass.
         */
        synchronized void refund(int cost) {
            if (cost <= 0) {
                return;
            }
            long nowSec = System.currentTimeMillis() / 1000;
            evictExpired(nowSec);
            int remaining = cost;
            for (int back = 0; back < WINDOW_SECONDS && remaining > 0; back++) {
                long second = nowSec - back;
                int idx = bucketIndex(second);
                if (bucketSeconds[idx] != second) {
                    continue; // bucket is empty or holds a different second
                }
                int taken = Math.min(buckets[idx], remaining);
                buckets[idx] -= taken;
                remaining -= taken;
            }
        }

        /**
         * Clears buckets that fell out of the window and returns the sum of those still inside it.
         *
         * @param nowSec current time in epoch seconds
         * @return total tokens recorded in the last {@code WINDOW_SECONDS} seconds
         */
        private long evictExpired(long nowSec) {
            long total = 0;
            for (int i = 0; i < WINDOW_SECONDS; i++) {
                if (nowSec - bucketSeconds[i] < WINDOW_SECONDS) {
                    total += buckets[i];
                } else {
                    buckets[i] = 0;
                    bucketSeconds[i] = 0;
                }
            }
            return total;
        }

        /** Maps an epoch second to its bucket slot (floorMod keeps it non-negative). */
        private static int bucketIndex(long epochSecond) {
            return (int) Math.floorMod(epochSecond, (long) WINDOW_SECONDS);
        }

        long getLastAccessTime() {
            return lastAccessTime;
        }
    }
}