导航菜单

实现 Worker Pool

🔴 困难

题目描述

实现一个 Worker Pool,支持以下功能:

  1. 限制并发 goroutine 数量
  2. 动态调整 worker 数量
  3. 优雅关闭(等待所有任务完成)

示例代码

// WorkerPool is the contract the exercise asks for: a pool that runs
// submitted tasks on a limited, adjustable number of goroutines.
type WorkerPool interface {
    // Start launches the given number of workers.
    Start(workers int)
    // Submit hands a task to the pool for execution.
    Submit(task func())
    // Stop shuts the pool down; the exercise requires it to wait until all
    // tasks have completed.
    Stop()
    // Resize changes the number of workers while the pool is running.
    Resize(workers int)
}

使用示例

// main shows the intended usage: start a pool, submit work, grow the pool
// and finally shut it down.
func main() {
    pool := NewWorkerPool()
    pool.Start(5)  // start 5 workers
    
    // Submit 100 tasks.
    for i := 0; i < 100; i++ {
        taskID := i // per-iteration copy captured by the closure below
        pool.Submit(func() {
            fmt.Printf("Processing task %d\n", taskID)
            time.Sleep(time.Second)
        })
    }
    
    // Adjust the pool size on the fly.
    pool.Resize(10)  // grow to 10 workers
    
    // Wait for completion.
    pool.Stop()
}

提示

  • 使用 channel 作为任务队列
  • 使用 sync.WaitGroup 等待 worker 退出
  • 使用 context(或专用的 quit channel,参考答案采用后者)控制 worker 生命周期

解法

参考答案 (3 个标签)
Worker Pool 并发控制 channel

实现代码

// Task is a unit of work run by a pool worker; it takes no arguments and
// returns nothing.
type Task func()

// WorkerPool runs submitted tasks on a set of worker goroutines.
type WorkerPool struct {
    tasks   chan Task      // task queue the workers receive from; closed by Stop
    wg      sync.WaitGroup // tracks running workers so Stop can wait for them
    quit    chan struct{}  // each value received makes one worker exit (sent by Resize)
    workers int            // worker count recorded by Start and Resize
    mu      sync.RWMutex   // held by Start and Resize while they change workers
}

// NewWorkerPool returns a pool with no workers running yet. The task queue
// buffers up to 1000 tasks and the quit channel is unbuffered.
func NewWorkerPool() *WorkerPool {
    pool := new(WorkerPool)
    pool.tasks = make(chan Task, 1000)
    pool.quit = make(chan struct{})
    return pool
}

// Start launches workers additional worker goroutines and adds them to the
// recorded worker count.
//
// A non-positive count is ignored. Calling Start more than once accumulates,
// so the recorded count always matches the number of goroutines actually
// started; Resize relies on that count to compute how many workers to add or
// remove.
func (wp *WorkerPool) Start(workers int) {
    if workers <= 0 {
        return
    }

    wp.mu.Lock()
    defer wp.mu.Unlock()

    for i := 0; i < workers; i++ {
        // Add before spawning so a concurrent Wait cannot miss the worker.
        wp.wg.Add(1)
        go wp.worker()
    }
    wp.workers += workers
}

// worker is the loop run by every pool goroutine. It executes tasks from the
// queue until either the queue is closed or a quit signal is received, and
// marks itself done on the WaitGroup when it returns.
func (wp *WorkerPool) worker() {
    defer wp.wg.Done()

    for {
        select {
        case <-wp.quit:
            // Asked to exit (pool is shrinking).
            return
        case job, open := <-wp.tasks:
            if !open {
                // Queue closed and drained: nothing left to do.
                return
            }
            job()
        }
    }
}

// Submit enqueues task for execution by a worker.
//
// The send blocks while the task queue is full. Because Stop closes the
// queue, calling Submit after Stop panics (send on closed channel).
func (wp *WorkerPool) Submit(task Task) {
    wp.tasks <- task
}

// Stop closes the task queue and blocks until every worker has returned.
//
// Workers keep receiving from the closed queue until it is drained, so tasks
// that were already queued still run. Stop must be called at most once:
// closing an already-closed channel panics.
func (wp *WorkerPool) Stop() {
    close(wp.tasks)
    wp.wg.Wait()
}

// Resize changes the number of workers to the given target.
//
// Growing starts the missing goroutines immediately. Shrinking sends one quit
// signal per surplus worker; each send blocks until some worker is idle enough
// to receive it, so Resize may wait for in-flight tasks to finish.
//
// A negative target is treated as 0. Without that clamp the method would try
// to send more quit signals than there are workers and block forever while
// holding the mutex.
func (wp *WorkerPool) Resize(workers int) {
    if workers < 0 {
        workers = 0
    }

    wp.mu.Lock()
    defer wp.mu.Unlock()

    delta := workers - wp.workers
    switch {
    case delta > 0:
        // Grow: spawn the missing workers.
        for i := 0; i < delta; i++ {
            wp.wg.Add(1)
            go wp.worker()
        }
    case delta < 0:
        // Shrink: every value sent on quit stops exactly one worker.
        for i := 0; i < -delta; i++ {
            wp.quit <- struct{}{}
        }
    }

    wp.workers = workers
}

关键点

  1. 任务队列:缓冲 channel,解耦任务提交和执行
  2. Worker 生命周期:使用 quit channel 通知退出
  3. 并发控制:限制 worker 数量
  4. 优雅关闭:关闭任务队列,等待所有 worker 处理完已入队的任务后退出(注意:Stop 之后不能再调用 Submit,否则向已关闭的 channel 发送会 panic)

优化版本:支持超时和错误处理

// Task is a unit of work run by a pool worker; a non-nil return value is
// reported as the task's failure.
type Task func() error

// WorkerPool is the extended pool: in addition to running tasks it applies a
// per-task timeout and reports failures on a results channel.
type WorkerPool struct {
    tasks    chan Task      // task queue the workers receive from
    results  chan error     // task errors and timeout errors written by workers
    wg       sync.WaitGroup // tracks running workers
    quit     chan struct{}  // a value received here makes one worker exit
    workers  int            // recorded worker count
    mu       sync.RWMutex   // NOTE(review): not used by the methods shown in this snippet
    timeout  time.Duration  // maximum time a worker waits for a single task
}

// NewWorkerPool returns a pool with no workers running yet, configured with
// the given per-task timeout. The task queue and the results channel each
// buffer up to 1000 entries; the quit channel is unbuffered.
func NewWorkerPool(timeout time.Duration) *WorkerPool {
    pool := new(WorkerPool)
    pool.timeout = timeout
    pool.tasks = make(chan Task, 1000)
    pool.results = make(chan error, 1000)
    pool.quit = make(chan struct{})
    return pool
}

// worker is the loop run by every pool goroutine. It executes tasks from the
// queue until the queue is closed or a quit signal arrives. A task's non-nil
// error, or a timeout error, is sent to the results channel; that send blocks
// if the results buffer is full and nobody is reading it.
func (wp *WorkerPool) worker() {
    defer wp.wg.Done()

    for {
        select {
        case task, ok := <-wp.tasks:
            if !ok {
                return
            }
            if err := wp.runWithTimeout(task); err != nil {
                wp.results <- err
            }

        case <-wp.quit:
            return
        }
    }
}

// runWithTimeout runs task and returns its error, or a "task timeout" error
// when the task does not finish within wp.timeout.
//
// A non-positive timeout means "no timeout": the task runs inline and is
// waited for indefinitely. On timeout the task itself is not interrupted; it
// keeps running in its own goroutine, and the buffered done channel lets that
// goroutine finish without blocking.
func (wp *WorkerPool) runWithTimeout(task func() error) error {
    if wp.timeout <= 0 {
        return task()
    }

    done := make(chan error, 1)
    go func() {
        done <- task()
    }()

    // An explicit timer is stopped as soon as the task finishes, instead of
    // lingering until it fires as time.After would.
    timer := time.NewTimer(wp.timeout)
    defer timer.Stop()

    select {
    case err := <-done:
        return err
    case <-timer.C:
        return fmt.Errorf("task timeout")
    }
}

// Submit enqueues task for execution by a worker.
//
// The send blocks while the task queue is full. The returned error is always
// nil in this version.
func (wp *WorkerPool) Submit(task Task) error {
    wp.tasks <- task
    return nil
}

// Results returns a receive-only view of the channel on which workers report
// task errors and task timeouts.
func (wp *WorkerPool) Results() <-chan error {
    return wp.results
}

扩展:其他并发模式

1. 限流器(Token Bucket)

// RateLimiter is a token-bucket limiter: tokens are added over time and each
// allowed call consumes one.
type RateLimiter struct {
    rate     int        // tokens added per second
    capacity int        // maximum number of tokens the bucket can hold
    tokens   int        // tokens currently available
    lastTime time.Time  // time of the last refill calculation
    mu       sync.Mutex // guards the fields above
}

// NewRateLimiter returns a limiter that refills at rate tokens per second and
// holds at most capacity tokens. The bucket starts full and the refill clock
// starts now.
func NewRateLimiter(rate, capacity int) *RateLimiter {
    limiter := new(RateLimiter)
    limiter.rate = rate
    limiter.capacity = capacity
    limiter.tokens = capacity
    limiter.lastTime = time.Now()
    return limiter
}

// Allow reports whether a request may proceed now, consuming one token if so.
//
// Tokens accrue at rl.rate per second up to rl.capacity. Only the time that
// was actually converted into whole tokens is consumed from the refill clock:
// the fractional remainder stays pending for the next call. Resetting the
// clock on every call would discard that remainder, so a caller polling
// faster than once per 1/rate seconds would never earn a token.
//
// A non-positive rate never adds tokens.
func (rl *RateLimiter) Allow() bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    now := time.Now()
    earned := now.Sub(rl.lastTime).Seconds() * float64(rl.rate)
    room := rl.capacity - rl.tokens

    switch {
    case earned >= float64(room):
        // Bucket is (or becomes) full: surplus time is dropped so a full
        // bucket does not bank credit. Comparing in floating point also
        // avoids converting a huge elapsed*rate product to int.
        rl.tokens = rl.capacity
        rl.lastTime = now
    case earned >= 1:
        // Add the whole tokens and advance the clock by exactly the time they
        // represent, keeping the fractional part for later.
        whole := int(earned)
        rl.tokens += whole
        spent := time.Duration(float64(whole) / float64(rl.rate) * float64(time.Second))
        rl.lastTime = rl.lastTime.Add(spent)
    }

    if rl.tokens > 0 {
        rl.tokens--
        return true
    }

    return false
}

// min returns the smaller of two ints.
func min(a, b int) int {
    if b <= a {
        return b
    }
    return a
}

2. 扇出-扇入模式

// FanOut starts the given number of workers that all read from the same input
// channel and returns their output channels, one per worker.
func FanOut(input <-chan int, workers int) []<-chan int {
    outputs := make([]<-chan int, workers)
    for idx := range outputs {
        outputs[idx] = worker(input)
    }
    return outputs
}

// worker starts a goroutine that passes every value received from input
// through process and sends the result on the returned channel. The returned
// channel is unbuffered and is closed once input has been closed and drained.
func worker(input <-chan int) <-chan int {
    results := make(chan int)

    go func() {
        defer close(results)
        for {
            item, more := <-input
            if !more {
                return
            }
            results <- process(item)
        }
    }()

    return results
}

// FanIn merges any number of input channels into one. Each input is forwarded
// by its own goroutine; the merged channel is unbuffered and is closed after
// every input has been closed and drained.
func FanIn(inputs ...<-chan int) <-chan int {
    merged := make(chan int)

    var pending sync.WaitGroup
    pending.Add(len(inputs))

    // forward copies one input into the merged channel until it is closed.
    forward := func(src <-chan int) {
        defer pending.Done()
        for v := range src {
            merged <- v
        }
    }
    for i := range inputs {
        go forward(inputs[i])
    }

    // Close the merged channel only after every forwarder has finished.
    go func() {
        pending.Wait()
        close(merged)
    }()

    return merged
}

3. 超时控制

// WithTimeout returns a middleware that wraps an http.RoundTripper in a
// timeoutTransport configured with the given timeout.
func WithTimeout(timeout time.Duration) func(http.RoundTripper) http.RoundTripper {
    wrap := func(next http.RoundTripper) http.RoundTripper {
        tt := new(timeoutTransport)
        tt.timeout = timeout
        tt.next = next
        return tt
    }
    return wrap
}

// timeoutTransport is an http.RoundTripper that applies a timeout to every
// request before delegating to the wrapped transport.
type timeoutTransport struct {
    next    http.RoundTripper // transport that actually performs the request
    timeout time.Duration     // timeout applied to each request's context
}

// RoundTrip performs req through the wrapped transport with a context that
// expires after t.timeout.
//
// The context is not cancelled when RoundTrip returns: the caller reads the
// response body afterwards, and that read is tied to the request context.
// Cancelling on return would make body reads fail with "context canceled".
// Instead the context is released when the body is closed or a read reports
// an error (including EOF). The timeout therefore covers the whole exchange,
// body included. If the round trip fails or yields no body, the context is
// released immediately.
//
// NOTE(review): the body wrapper exposes only Read and Close; a body that
// also implements io.Writer (e.g. after a protocol upgrade) loses that method.
func (t *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    ctx, cancel := context.WithTimeout(req.Context(), t.timeout)

    resp, err := t.next.RoundTrip(req.WithContext(ctx))
    if err != nil || resp == nil || resp.Body == nil {
        cancel()
        return resp, err
    }

    resp.Body = &cancelOnCloseBody{body: resp.Body, cancel: cancel}
    return resp, nil
}

// cancelOnCloseBody wraps a response body so that the request context is
// released once the body has been fully consumed or closed.
type cancelOnCloseBody struct {
    body interface {
        Read(p []byte) (int, error)
        Close() error
    }
    cancel context.CancelFunc
}

// Read forwards to the wrapped body and releases the context on any read
// error, since no further data can follow.
func (b *cancelOnCloseBody) Read(p []byte) (int, error) {
    n, err := b.body.Read(p)
    if err != nil {
        b.cancel()
    }
    return n, err
}

// Close closes the wrapped body and releases the context.
func (b *cancelOnCloseBody) Close() error {
    defer b.cancel()
    return b.body.Close()
}

最佳实践

1. 合理设置 Worker 数量

// CPU-bound work
workers := runtime.NumCPU()

// IO-bound work
workers := runtime.NumCPU() * 2

2. 任务队列容量

// Tune the capacity according to the task submission rate and the processing rate.
tasks := make(chan Task, 1000)

3. 错误处理

// Submit enqueues task, waiting up to one second for room in the queue.
//
// It returns nil once the task is queued and a "task queue full" error if no
// slot became free within that second. A non-blocking attempt is made first
// so the common case allocates no timer; when a wait is needed the timer is
// stopped on return instead of lingering as time.After would.
func (wp *WorkerPool) Submit(task Task) error {
    // Fast path: the queue has room right now.
    select {
    case wp.tasks <- task:
        return nil
    default:
    }

    timer := time.NewTimer(time.Second)
    defer timer.Stop()

    select {
    case wp.tasks <- task:
        return nil
    case <-timer.C:
        return errors.New("task queue full")
    }
}

4. 监控指标

// WorkerPool (excerpt): counters added for monitoring; the remaining fields
// are elided. Where the counters are incremented is not shown in this snippet.
type WorkerPool struct {
    // ...
    submitted int64 // read atomically by Metrics
    completed int64 // read atomically by Metrics
    failed    int64 // read atomically by Metrics
}

// Metrics returns a snapshot of the pool counters keyed by "submitted",
// "completed" and "failed". Each counter is loaded atomically, one at a time.
func (wp *WorkerPool) Metrics() map[string]int64 {
    snapshot := make(map[string]int64, 3)
    snapshot["submitted"] = atomic.LoadInt64(&wp.submitted)
    snapshot["completed"] = atomic.LoadInt64(&wp.completed)
    snapshot["failed"] = atomic.LoadInt64(&wp.failed)
    return snapshot
}

搜索