Tokio Semaphore: Concurrency Limiting in Three Lines of Code, Without Blocking Threads

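Introduction

In modern asynchronous programming, concurrency control is key to keeping a system stable and fast. Tokio's Semaphore is one of the most important synchronization primitives in the Rust async ecosystem: it lets developers cap how many tasks may access a shared resource at the same time. Managing a database connection pool, rate-limiting API requests, and bounding concurrent file I/O are all typical uses.

Starting from the basic concepts, this article uses code examples to walk through the core API of the Tokio semaphore, typical application scenarios, advanced patterns, and monitoring, so that developers can use it correctly and efficiently in real projects.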

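Tokio Semaphores in Practice

1. Semaphore Basics

A semaphore is a synchronization primitive that limits how many threads or tasks may access a given resource at the same time. It holds a fixed number of permits: a task acquires a permit before it proceeds, waits if none is available, and returns the permit when it is done.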

use tokio::sync::Semaphore;
use std::sync::Arc;

#[tokio::main]
async fn main() {
    // Create a semaphore that allows at most 3 tasks to run concurrently
    let semaphore = Arc::new(Semaphore::new(3));
    
    let mut handles = vec![];
    
    // Spawn 10 tasks
    for i in 0..10 {
        let semaphore = semaphore.clone();
        let handle = tokio::spawn(async move {
            // Acquire a permit; the task is suspended (the thread is not blocked) until one is free
            let _permit = semaphore.acquire().await.unwrap();
            println!("Task {} acquired permit", i);
            
            // Simulate work
            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
            
            println!("Task {} releasing permit", i);
            // `_permit` is dropped at the end of the scope, which returns the permit
        });
        handles.push(handle);
    }
    
    // Wait for all tasks to finish
    for handle in handles {
        handle.await.unwrap();
    }
}
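
To make the permit-counting idea concrete, here is a minimal sketch of a blocking counting semaphore built only from the standard library. It is not how Tokio implements its semaphore: Tokio suspends the waiting task, whereas this version blocks the calling thread. The names CountingSemaphore and max_observed_concurrency are made up for this illustration.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// A minimal blocking counting semaphore (illustration only).
struct CountingSemaphore {
    permits: Mutex<usize>,
    available: Condvar,
}

impl CountingSemaphore {
    fn new(permits: usize) -> Self {
        Self { permits: Mutex::new(permits), available: Condvar::new() }
    }

    /// Blocks the calling thread until a permit is free, then takes it.
    fn acquire(&self) {
        let mut permits = self.permits.lock().unwrap();
        while *permits == 0 {
            permits = self.available.wait(permits).unwrap();
        }
        *permits -= 1;
    }

    /// Returns a permit and wakes one waiting thread.
    fn release(&self) {
        *self.permits.lock().unwrap() += 1;
        self.available.notify_one();
    }
}

/// Runs `workers` threads behind a semaphore with `limit` permits and returns
/// the highest number of threads seen inside the guarded section at once.
fn max_observed_concurrency(limit: usize, workers: usize) -> usize {
    let sem = Arc::new(CountingSemaphore::new(limit));
    let state = Arc::new(Mutex::new((0usize, 0usize))); // (current, max)

    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let sem = Arc::clone(&sem);
            let state = Arc::clone(&state);
            thread::spawn(move || {
                sem.acquire();
                {
                    let mut s = state.lock().unwrap();
                    s.0 += 1;
                    s.1 = s.1.max(s.0);
                }
                thread::sleep(Duration::from_millis(20)); // simulated work
                state.lock().unwrap().0 -= 1;
                sem.release();
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
    let max = state.lock().unwrap().1;
    max
}
```

Calling max_observed_concurrency(3, 10) never returns more than 3, which is the same guarantee the Tokio example above relies on.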

2. The Core Semaphore API

2.1 Construction and Basic Operations

use tokio::sync::Semaphore;
use std::sync::Arc;

async fn semaphore_basics() {
    // Create a semaphore with 5 initial permits
    let semaphore = Arc::new(Semaphore::new(5));
    
    // Query the number of permits currently available
    println!("Available permits: {}", semaphore.available_permits());
    
    // Acquire one permit, waiting asynchronously if none is available
    let _permit1 = semaphore.acquire().await.unwrap();
    println!("After acquiring 1 permit: {}", semaphore.available_permits());
    
    // Try to acquire a permit without waiting
    match semaphore.try_acquire() {
        Ok(permit) => {
            println!("Successfully acquired permit without blocking");
            // Dropping the permit returns it to the semaphore
            drop(permit);
        },
        Err(_) => println!("No permits available"),
    }
    
    // Add permits (raises the semaphore's capacity)
    semaphore.add_permits(2);
    println!("After adding 2 permits: {}", semaphore.available_permits());
}

2.2 Acquiring with a Timeout and try_acquire

use tokio::sync::Semaphore;
use std::sync::Arc;
use std::time::Duration;

async fn acquire_with_timeout() {
    let semaphore = Arc::new(Semaphore::new(1));
    
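    // Take the only permit so that the spawned task below has to wait
    let _permit = semaphore.acquire().await.unwrap();
    
    // Try to acquire a permit from another task
    let semaphore_clone = semaphore.clone();
    let handle = tokio::spawn(async move {
        // Wait for a permit for at most one second. `acquire_owned` returns a permit
        // that owns its `Arc<Semaphore>`, so it can be returned from the task.
        match tokio::time::timeout(
            Duration::from_millis(1000),
            semaphore_clone.acquire_owned()
        ).await {
            Ok(Ok(permit)) => {
                println!("Successfully acquired permit within the timeout");
                Some(permit)
            },
            Ok(Err(_)) => {
                println!("Semaphore closed");
                None
            },
            Err(_) => {
                println!("Timeout waiting for permit");
                None
            }
        }
    });
    
    // `_permit` is still held here, so the task above times out
    let _result = handle.await.unwrap();
    
    // Try to acquire a permit immediately, without waiting
    match semaphore.try_acquire() {
        Ok(_permit) => println!("Immediate acquisition successful"),
        Err(_) => println!("Immediate acquisition failed - no permits available"),
    }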
}

3. Controlling a Database Connection Pool

use tokio::sync::Semaphore;
use std::sync::{Arc, Mutex};
use std::collections::VecDeque;
use std::time::Duration;

// Simulated database connection
#[derive(Debug)]
struct DatabaseConnection {
    id: u32,
    busy: bool,
}

impl DatabaseConnection {
    fn new(id: u32) -> Self {
        Self { id, busy: false }
    }
    
    async fn query(&mut self, sql: &str) -> String {
        println!("Connection {} executing query: {}", self.id, sql);
        tokio::time::sleep(Duration::from_millis(100)).await;
        format!("Result from connection {}", self.id)
    }
}

struct ConnectionPool {
    // Idle connections. The std Mutex is held only briefly and never across an .await
    connections: Mutex<VecDeque<DatabaseConnection>>,
    semaphore: Semaphore,
}

impl ConnectionPool {
    fn new(max_connections: usize) -> Self {
        let mut connections = VecDeque::new();
        for i in 0..max_connections {
            connections.push_back(DatabaseConnection::new(i as u32));
        }
        
        Self {
            connections: Mutex::new(connections),
            semaphore: Semaphore::new(max_connections),
        }
    }
    
    // Takes &self so that many tasks can share one pool through an Arc
    async fn execute_query(&self, sql: &str) -> String {
        // Acquire a permit; this waits while every connection is in use
        let _permit = self.semaphore.acquire().await.unwrap();
        
        // Holding a permit guarantees that an idle connection is available
        let mut conn = self.connections.lock().unwrap().pop_front().unwrap();
        conn.busy = true;
        
        // Run the query
        let result = conn.query(sql).await;
        
        // Return the connection to the pool; the permit is released when `_permit` drops
        conn.busy = false;
        self.connections.lock().unwrap().push_back(conn);
        
        result
    }
}

#[tokio::main]
async fn database_pool_example() {
    // A single pool shared by all tasks: at most 3 connections in use at once
    let pool = Arc::new(ConnectionPool::new(3));
    
    let mut tasks = vec![];
    for i in 0..10 {
        let pool = pool.clone();
        let task = tokio::spawn(async move {
            let sql = format!("SELECT * FROM table_{}", i);
            let result = pool.execute_query(&sql).await;
            println!("Query {} result: {}", i, result);
        });
        tasks.push(task);
    }
    
    for task in tasks {
        task.await.unwrap();
    }
}

4. API Request Rate Limiting

use tokio::sync::Semaphore;
use std::sync::Arc;
use std::time::Duration;

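struct ApiRateLimiter {
    semaphore: Arc<Semaphore>,
    rate: u32, // requests per second
}

impl ApiRateLimiter {
    // Must be called from within a Tokio runtime, because it spawns the refill task
    fn new(rate: u32) -> Self {
        let semaphore = Arc::new(Semaphore::new(rate as usize));
        
        // Background task: once per second, top the permits back up to `rate`
        let semaphore_clone = semaphore.clone();
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let available = semaphore_clone.available_permits();
                semaphore_clone.add_permits((rate as usize).saturating_sub(available));
            }
        });
        
        Self { semaphore, rate }
    }
    
    async fn make_request(&self, url: &str) -> String {
        // Acquire a permit and consume it: `forget` keeps it from being returned on drop,
        // so only the refill task restores permits and the per-second budget holds
        self.semaphore.acquire().await.unwrap().forget();
        
        println!("Making request to: {}", url);
        // Simulate the API call
        tokio::time::sleep(Duration::from_millis(100)).await;
        
        format!("Response from {}", url)
    }
}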

#[tokio::main]
async fn rate_limiting_example() {
    let limiter = ApiRateLimiter::new(2); // at most 2 requests per second
    
    let mut tasks = vec![];
    for i in 0..10 {
        let limiter = limiter.clone();
        let url = format!("https://api.example.com/endpoint{}", i);
        
        let task = tokio::spawn(async move {
            let response = limiter.make_request(&url).await;
            println!("Received: {}", response);
        });
        tasks.push(task);
    }
    
    for task in tasks {
        task.await.unwrap();
    }
}

// Manual Clone implementation; every clone shares the same underlying semaphore
impl Clone for ApiRateLimiter {
    fn clone(&self) -> Self {
        Self {
            semaphore: self.semaphore.clone(),
            rate: self.rate,
        }
    }
}
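
The limiter above restores permits in one batch per second. A common alternative is a token bucket, which refills continuously in proportion to elapsed time. The sketch below is a different technique from the semaphore-based limiter, shown only to make the refill arithmetic explicit; it takes timestamps from the caller so that the behaviour is deterministic. TokenBucket and try_take are hypothetical names.

```rust
/// A token bucket driven by caller-supplied timestamps (in seconds).
struct TokenBucket {
    capacity: f64,
    tokens: f64,
    refill_per_sec: f64,
    last_refill: f64,
}

impl TokenBucket {
    /// Starts full, with the clock at 0 seconds.
    fn new(capacity: u32, refill_per_sec: f64) -> Self {
        Self {
            capacity: capacity as f64,
            tokens: capacity as f64,
            refill_per_sec,
            last_refill: 0.0,
        }
    }

    /// Tries to take one token at time `now`; returns whether the request may proceed.
    fn try_take(&mut self, now: f64) -> bool {
        // Refill in proportion to elapsed time, capped at the bucket capacity.
        let elapsed = (now - self.last_refill).max(0.0);
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last_refill = self.last_refill.max(now);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
```

With a capacity of 2 and a refill rate of 2 tokens per second, two requests at time 0 pass, a third is rejected, and one more passes half a second later.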

5. Concurrency Control for File I/O

use tokio::sync::Semaphore;
use std::sync::Arc;
use tokio::fs;

struct FileProcessor {
    io_semaphore: Arc<Semaphore>,
    processing_semaphore: Arc<Semaphore>,
}

impl FileProcessor {
    fn new(max_io_concurrent: usize, max_processing_concurrent: usize) -> Self {
        Self {
            io_semaphore: Arc::new(Semaphore::new(max_io_concurrent)),
            processing_semaphore: Arc::new(Semaphore::new(max_processing_concurrent)),
        }
    }
    
    async fn process_file(&self, file_path: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        // Limit concurrent file I/O
        let io_permit = self.io_semaphore.acquire().await.unwrap();
        
        println!("Reading file: {}", file_path);
        let content = fs::read_to_string(file_path).await?;
        
        // Release the I/O permit before the processing stage begins
        drop(io_permit);
        
        // Limit concurrent processing
        let _proc_permit = self.processing_semaphore.acquire().await.unwrap();
        
        println!("Processing content of: {}", file_path);
        // Simulate processing time
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
        
        let processed = format!("Processed: {}", content.trim());
        
        Ok(processed)
    }
}

#[tokio::main]
async fn file_processing_example() {
    let processor = Arc::new(FileProcessor::new(2, 3)); // at most 2 concurrent reads and 3 concurrent processing steps
    
    let files = vec![
        "file1.txt",
        "file2.txt", 
        "file3.txt",
        "file4.txt",
        "file5.txt",
    ];
    
    let mut tasks = vec![];
    for file in files {
        let processor = processor.clone();
        let task = tokio::spawn(async move {
            match processor.process_file(file).await {
                Ok(_result) => println!("File {} processed successfully", file),
                Err(e) => eprintln!("Error processing {}: {}", file, e),
            }
        });
        tasks.push(task);
    }
    
    for task in tasks {
        task.await.unwrap();
    }
}

6. Task Queue and Worker Pool

use tokio::sync::{Semaphore, Notify};
use std::sync::Arc;
use std::collections::VecDeque;
use std::time::Duration;

#[derive(Debug)]
struct Task {
    id: u32,
    payload: String,
}

struct TaskQueue {
    // Wrapped in an Arc so that clones of TaskQueue share one queue
    // (tokio::sync::Mutex itself does not implement Clone)
    queue: Arc<tokio::sync::Mutex<VecDeque<Task>>>,
    semaphore: Arc<Semaphore>,
    notify: Arc<Notify>,
}

impl TaskQueue {
    fn new(max_concurrent: usize) -> Self {
        Self {
            queue: Arc::new(tokio::sync::Mutex::new(VecDeque::new())),
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            notify: Arc::new(Notify::new()),
        }
    }
    
    async fn add_task(&self, task: Task) {
        self.queue.lock().await.push_back(task);
        self.notify.notify_one();
    }
    
    async fn worker(&self, worker_id: u32) {
        loop {
            // Take the next task, if any; the lock is released at the end of this statement
            let task = self.queue.lock().await.pop_front();
            
            if let Some(task) = task {
                // Acquire a permit to run the task
                let _permit = self.semaphore.acquire().await.unwrap();
                
                println!("Worker {} processing task {}: {}", 
                        worker_id, task.id, task.payload);
                
                // Simulate processing time
                tokio::time::sleep(Duration::from_millis(300)).await;
                
                println!("Worker {} completed task {}", worker_id, task.id);
            } else {
                // Queue is empty: wait until add_task signals new work.
                // Draining the queue before waiting means no task is stranded
                // when several are added while every worker is busy.
                self.notify.notified().await;
            }
        }
    }
    
    async fn start_workers(&self, num_workers: usize) {
        for i in 0..num_workers {
            let queue = self.clone();
            tokio::spawn(async move {
                queue.worker(i as u32).await;
            });
        }
    }
}

// Clone implementation so that each worker can own a handle to the queue
impl Clone for TaskQueue {
    fn clone(&self) -> Self {
        Self {
            queue: self.queue.clone(),
            semaphore: self.semaphore.clone(),
            notify: self.notify.clone(),
        }
    }
}

#[tokio::main]
async fn task_queue_example() {
    let queue = Arc::new(TaskQueue::new(2)); // at most 2 tasks run concurrently
    
    // Start the workers
    queue.start_workers(3).await;
    
    // Add tasks
    for i in 0..10 {
        let task = Task {
            id: i,
            payload: format!("Task payload {}", i),
        };
        queue.add_task(task).await;
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    
    // Keep the program alive long enough to observe the output
    tokio::time::sleep(Duration::from_secs(5)).await;
}

7. Advanced Patterns and Best Practices

7.1 Conditional Semaphore

use tokio::sync::{Semaphore, Notify};
use std::sync::Arc;
use std::collections::VecDeque;

struct ConditionalSemaphore {
    semaphore: Arc<Semaphore>,
    notify: Arc<Notify>,
    condition: tokio::sync::Mutex<VecDeque<bool>>,
}

impl ConditionalSemaphore {
    fn new(initial_permits: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(initial_permits)),
            notify: Arc::new(Notify::new()),
            condition: tokio::sync::Mutex::new(VecDeque::new()),
        }
    }
    
    async fn acquire_with_condition(&self, condition_met: bool) -> Option<tokio::sync::SemaphorePermit<'_>> {
        // Enqueue this caller's condition
        {
            let mut conditions = self.condition.lock().await;
            conditions.push_back(condition_met);
        }
        
        loop {
            // Look at the condition at the front of the queue
            let should_acquire = {
                let conditions = self.condition.lock().await;
                conditions.front().copied()
            };
            
            if should_acquire == Some(true) {
                // Condition met: wait asynchronously for a permit rather than
                // spinning on try_acquire; returns None if the semaphore is closed
                let permit = self.semaphore.acquire().await.ok()?;
                // Remove the condition that has now been handled
                self.condition.lock().await.pop_front();
                return Some(permit);
            } else {
                // Condition not met: wait for signal_condition_change
                self.notify.notified().await;
            }
        }
    }
    
    fn signal_condition_change(&self) {
        self.notify.notify_waiters();
    }
}

7.2 Adjusting a Semaphore Dynamically

use tokio::sync::{Semaphore, RwLock};
use std::sync::Arc;
use std::time::Duration;

struct DynamicSemaphore {
    semaphore: Arc<Semaphore>,
    config: Arc<RwLock<SemaphoreConfig>>,
}

#[derive(Debug, Clone)]
struct SemaphoreConfig {
    min_permits: usize,
    max_permits: usize,
    current_target: usize,
}

impl DynamicSemaphore {
    fn new(initial_permits: usize, min: usize, max: usize) -> Self {
        let semaphore = Arc::new(Semaphore::new(initial_permits));
        let config = Arc::new(RwLock::new(SemaphoreConfig {
            min_permits: min,
            max_permits: max,
            current_target: initial_permits,
        }));
        
        // Start the background auto-adjust task
        let sem_clone = semaphore.clone();
        let config_clone = config.clone();
        tokio::spawn(Self::auto_adjust(sem_clone, config_clone));
        
        Self { semaphore, config }
    }
    
    async fn auto_adjust(
        semaphore: Arc<Semaphore>,
        config: Arc<RwLock<SemaphoreConfig>>,
    ) {
        loop {
            tokio::time::sleep(Duration::from_secs(5)).await;
            
            let config_read = config.read().await;
            let available = semaphore.available_permits();
            let target = config_read.current_target;
            
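            if available < target / 2 {
                // Add permits, without going past the configured maximum
                let to_add = target.saturating_sub(available);
                if to_add > 0 {
                    semaphore.add_permits(to_add.min(config_read.max_permits.saturating_sub(target)));
                }
            } else if available > target * 2 {
                // Shrinking is left unimplemented here. Permits cannot be removed with
                // add_permits; one option is to acquire the surplus permits and call
                // forget() on them so that they are never returned
            }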
            
            drop(config_read);
        }
    }
    
    async fn update_target(&self, new_target: usize) {
        let mut config = self.config.write().await;
        config.current_target = new_target.clamp(config.min_permits, config.max_permits);
    }
    
    async fn acquire(&self) -> Result<tokio::sync::SemaphorePermit, tokio::sync::AcquireError> {
        self.semaphore.acquire().await
    }
}

8. Performance Optimization and Monitoring

use tokio::sync::Semaphore;
use std::sync::Arc;
use std::time::Instant;
use tokio::time::Duration;

struct MonitoredSemaphore {
    semaphore: Arc<Semaphore>,
    stats: Arc<tokio::sync::Mutex<SemaphoreStats>>,
}

#[derive(Debug, Default, Clone)]
struct SemaphoreStats {
    total_acquisitions: u64,
    total_wait_time: Duration,
    max_wait_time: Duration,
    current_waiters: u32,
}

impl MonitoredSemaphore {
    fn new(permits: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(permits)),
            stats: Arc::new(tokio::sync::Mutex::new(SemaphoreStats::default())),
        }
    }
    
    async fn acquire_with_monitoring(&self) -> Result<tokio::sync::SemaphorePermit, tokio::sync::AcquireError> {
        let start = Instant::now();
        
        let permit = self.semaphore.acquire().await?;
        
        let duration = start.elapsed();
        {
            let mut stats = self.stats.lock().await;
            stats.total_acquisitions += 1;
            stats.total_wait_time += duration;
            if duration > stats.max_wait_time {
                stats.max_wait_time = duration;
            }
        }
        
        Ok(permit)
    }
    
    async fn get_stats(&self) -> SemaphoreStats {
        self.stats.lock().await.clone()
    }
}

#[tokio::main]
async fn monitoring_example() {
    let semaphore = Arc::new(MonitoredSemaphore::new(2));
    
    let mut tasks = vec![];
    for i in 0..10 {
        let sem = semaphore.clone();
        let task = tokio::spawn(async move {
            let _permit = sem.acquire_with_monitoring().await.unwrap();
            println!("Task {} acquired permit", i);
            tokio::time::sleep(Duration::from_millis(100)).await;
        });
        tasks.push(task);
    }
    
    for task in tasks {
        task.await.unwrap();
    }
    
    let stats = semaphore.get_stats().await;
    println!("Statistics: {:?}", stats);
}
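
The raw counters are easier to act on once they are turned into a per-acquisition figure. A minimal sketch, assuming the same total_wait_time and total_acquisitions fields as SemaphoreStats above (the helper name average_wait is hypothetical):

```rust
use std::time::Duration;

/// Mean wait per acquisition, or `None` when nothing has been recorded yet.
fn average_wait(total_wait_time: Duration, total_acquisitions: u64) -> Option<Duration> {
    if total_acquisitions == 0 {
        return None;
    }
    // Divide in nanoseconds so the count is not limited to u32,
    // as it would be with `Duration / u32`.
    let nanos = total_wait_time.as_nanos() / total_acquisitions as u128;
    Some(Duration::from_nanos(nanos as u64))
}
```

An average wait that keeps climbing while max_wait_time stays high suggests the permit count is too low for the offered load.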

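Summary

Key points about the Tokio semaphore:

  1. Basic purpose: limit how many tasks access a resource concurrently
  2. Core operations: acquiring, releasing, and adding permits
  3. Use cases: connection pools, API rate limiting, file I/O control, task queues
  4. Performance: avoid excessive contention and choose the permit count deliberately
  5. Monitoring: track wait times and resource utilization
  6. Error handling: handle exceptional cases such as a closed semaphore

The right semaphore size depends on the load characteristics and resource limits of the specific application.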
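
As a rough starting point for sizing, Little's law says that the average number of requests in flight equals the arrival rate multiplied by the average time each request spends in the system. The helper below is only a back-of-the-envelope sketch; permits_for and its parameters are names chosen for this example, and the result should be checked against measured wait times and the limits of the resource being protected.

```rust
/// Estimates a permit count from Little's law:
/// items in flight = arrival rate (per second) x average latency (seconds).
fn permits_for(target_rps: f64, avg_latency_secs: f64) -> usize {
    (target_rps * avg_latency_secs).ceil() as usize
}
```

For instance, sustaining 40 requests per second at an average latency of 0.25 seconds keeps about 10 requests in flight, so permits_for(40.0, 0.25) gives 10.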

Copyright notice: free to repost; non-commercial, no derivatives, attribution required (Creative Commons 3.0 license)