实现 Worker Pool
🔴 困难
题目描述
实现一个 Worker Pool,支持以下功能:
- 限制并发 goroutine 数量
- 动态调整 worker 数量
- 优雅关闭(等待所有任务完成)
示例代码
// WorkerPool is the contract the exercise asks for: a pool that limits the
// number of concurrently running goroutines, can be resized while running,
// and shuts down gracefully.
type WorkerPool interface {
	// Start launches the given number of workers.
	Start(workers int)
	// Submit hands a task to the pool for execution.
	Submit(task func())
	// Stop shuts the pool down gracefully, waiting for all tasks to complete.
	Stop()
	// Resize changes the number of workers.
	Resize(workers int)
}
使用示例
func main() {
pool := NewWorkerPool()
pool.Start(5) // 启动 5 个 worker
// 提交 100 个任务
for i := 0; i < 100; i++ {
taskID := i
pool.Submit(func() {
fmt.Printf("Processing task %d\n", taskID)
time.Sleep(time.Second)
})
}
// 动态调整
pool.Resize(10) // 增加到 10 个 worker
// 等待完成
pool.Stop()
}提示
- 使用 channel 作为任务队列
- 使用 sync.WaitGroup 等待 worker 退出
- 使用 context 控制 worker 生命周期
解法
参考答案 (3 个标签)
Worker Pool 并发控制 channel
实现代码
// Task is a unit of work executed by a pool worker.
type Task func()

// WorkerPool runs submitted tasks on a bounded set of worker goroutines.
type WorkerPool struct {
	// tasks is the buffered queue workers receive from; Stop closes it.
	tasks chan Task
	// wg counts running workers so Stop can wait for them to exit.
	wg sync.WaitGroup
	// quit carries one value per worker that should exit early.
	quit chan struct{}
	// workers is the number of workers the pool has started; guarded by mu.
	workers int
	mu      sync.RWMutex
	// stopOnce makes Stop safe to call more than once.
	stopOnce sync.Once
}

// NewWorkerPool returns an idle pool with a 1000-slot task queue.
// Call Start to launch workers.
func NewWorkerPool() *WorkerPool {
	return &WorkerPool{
		tasks: make(chan Task, 1000),
		quit:  make(chan struct{}),
	}
}

// Start launches the given number of additional workers. Calling it again
// adds to the recorded worker count instead of overwriting it, so the count
// keeps matching the goroutines that are actually running. A non-positive
// argument starts nothing.
func (wp *WorkerPool) Start(workers int) {
	wp.mu.Lock()
	defer wp.mu.Unlock()

	for i := 0; i < workers; i++ {
		wp.wg.Add(1)
		go wp.worker()
	}
	if workers > 0 {
		wp.workers += workers
	}
}

// worker executes queued tasks until the queue is closed and drained, or
// until it receives a value on quit.
func (wp *WorkerPool) worker() {
	defer wp.wg.Done()
	for {
		select {
		case task, ok := <-wp.tasks:
			if !ok {
				// Queue closed and empty: graceful shutdown.
				return
			}
			task()
		case <-wp.quit:
			return
		}
	}
}

// Submit enqueues task and blocks while the queue is full. A nil task is
// ignored, because calling it would panic inside a worker goroutine.
// Submit must not be called after Stop: sending on the closed queue panics.
func (wp *WorkerPool) Submit(task Task) {
	if task == nil {
		return
	}
	wp.tasks <- task
}

// Stop closes the task queue and waits for every worker to exit. Tasks
// already queued are still picked up by the remaining workers. Stop may be
// called more than once; only the first call closes the queue.
func (wp *WorkerPool) Stop() {
	wp.stopOnce.Do(func() { close(wp.tasks) })
	wp.wg.Wait()
}
func (wp *WorkerPool) Resize(workers int) {
wp.mu.Lock()
defer wp.mu.Unlock()
if workers > wp.workers {
// 增加 workers
for i := 0; i < workers-wp.workers; i++ {
wp.wg.Add(1)
go wp.worker()
}
} else if workers < wp.workers {
// 减少 workers
for i := 0; i < wp.workers-workers; i++ {
wp.quit <- struct{}{}
}
}
wp.workers = workers
}关键点
- 任务队列:缓冲 channel,解耦任务提交和执行
- Worker 生命周期:使用 quit channel 通知退出
- 并发控制:限制 worker 数量
- 优雅关闭:关闭任务队列,等待所有 worker 完成
优化版本:支持超时和错误处理
// Task is a unit of work. A non-nil return value is reported on the pool's
// results channel.
type Task func() error

// ErrTaskTimeout is the error sent on the results channel when a task does
// not finish within the pool's timeout.
var ErrTaskTimeout = fmt.Errorf("task timeout")

// WorkerPool runs tasks on worker goroutines, enforcing a per-task timeout
// and collecting task errors.
type WorkerPool struct {
	// tasks is the buffered queue workers receive from.
	tasks chan Task
	// results receives task errors and timeout errors.
	results chan error
	// wg counts running workers.
	wg sync.WaitGroup
	// quit tells one worker to exit per value received.
	quit    chan struct{}
	workers int
	mu      sync.RWMutex
	// timeout bounds how long a worker waits for one task; a value <= 0
	// disables the limit.
	timeout time.Duration
}

// NewWorkerPool returns a pool whose workers wait at most timeout for each
// task. A non-positive timeout means tasks run without a time limit.
func NewWorkerPool(timeout time.Duration) *WorkerPool {
	return &WorkerPool{
		tasks:   make(chan Task, 1000),
		results: make(chan error, 1000),
		quit:    make(chan struct{}),
		timeout: timeout,
	}
}

// worker executes queued tasks until the queue is closed and drained, or
// until it receives a value on quit. Failures are sent on results; that send
// blocks once the results buffer is full, so callers should drain it.
func (wp *WorkerPool) worker() {
	defer wp.wg.Done()
	for {
		select {
		case task, ok := <-wp.tasks:
			if !ok {
				return
			}
			if err := wp.run(task); err != nil {
				wp.results <- err
			}
		case <-wp.quit:
			return
		}
	}
}

// run executes task and returns its error, or ErrTaskTimeout when the task
// outlives the pool's timeout. A timed-out task is not interrupted: its
// goroutine keeps running until the task returns on its own.
func (wp *WorkerPool) run(task Task) error {
	if wp.timeout <= 0 {
		// No limit configured; a zero-duration timer would race the task
		// and report spurious timeouts.
		return task()
	}

	// Buffered so the task goroutine can finish even after a timeout.
	done := make(chan error, 1)
	go func() {
		done <- task()
	}()

	timer := time.NewTimer(wp.timeout)
	defer timer.Stop() // release the timer as soon as the task finishes

	select {
	case err := <-done:
		return err
	case <-timer.C:
		return ErrTaskTimeout
	}
}

// Submit enqueues task and blocks while the queue is full. It returns an
// error for a nil task, which would otherwise panic inside a worker.
func (wp *WorkerPool) Submit(task Task) error {
	if task == nil {
		return fmt.Errorf("submit: nil task")
	}
	wp.tasks <- task
	return nil
}
// Results returns the receive-only side of the pool's results channel, on
// which workers send task errors and timeout errors.
func (wp *WorkerPool) Results() <-chan error {
	return wp.results
}
扩展:其他并发模式
1. 限流器(Token Bucket)
// RateLimiter is a token-bucket limiter: tokens accrue at rate per second up
// to capacity, and each allowed call spends one token.
type RateLimiter struct {
	// rate is the number of tokens added per second.
	rate int
	// capacity is the maximum number of tokens the bucket holds.
	capacity int
	// tokens is the number of whole tokens currently available.
	tokens int
	// lastTime is the instant up to which elapsed time has been converted
	// into tokens.
	lastTime time.Time
	mu       sync.Mutex
}

// NewRateLimiter returns a limiter that starts with a full bucket.
func NewRateLimiter(rate, capacity int) *RateLimiter {
	return &RateLimiter{
		rate:     rate,
		capacity: capacity,
		tokens:   capacity,
		lastTime: time.Now(),
	}
}

// Allow reports whether a call may proceed now, spending one token if so.
//
// Only whole tokens are credited, and lastTime advances only by the time that
// paid for them. The fractional remainder is kept rather than discarded, so a
// caller that polls more often than one token interval still sees the bucket
// refill at the configured rate.
func (rl *RateLimiter) Allow() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	earned := now.Sub(rl.lastTime).Seconds() * float64(rl.rate)
	if room := rl.capacity - rl.tokens; earned >= float64(room) {
		// Bucket is full; surplus time buys nothing, so restart the clock.
		rl.tokens = rl.capacity
		rl.lastTime = now
	} else if whole := int(earned); whole > 0 {
		rl.tokens += whole
		paid := time.Duration(float64(whole) / float64(rl.rate) * float64(time.Second))
		rl.lastTime = rl.lastTime.Add(paid)
	}

	if rl.tokens > 0 {
		rl.tokens--
		return true
	}
	return false
}
// min returns the smaller of a and b; when they are equal it returns b.
// NOTE(review): Go 1.21+ provides a built-in min, which this package-level
// declaration shadows.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
2. 扇出-扇入模式
// FanOut starts the requested number of worker stages that all read from the
// same input channel, and returns their output channels in start order.
func FanOut(input <-chan int, workers int) []<-chan int {
	stages := make([]<-chan int, workers)
	for idx := range stages {
		stages[idx] = worker(input)
	}
	return stages
}
// worker starts a goroutine that applies process to every value received
// from input and sends the result on the returned channel. The returned
// channel is closed once input is closed and drained.
func worker(input <-chan int) <-chan int {
	results := make(chan int)
	go func() {
		defer close(results)
		for {
			value, open := <-input
			if !open {
				return
			}
			results <- process(value)
		}
	}()
	return results
}
// FanIn merges every input channel into one output channel. The output is
// closed after all inputs have been closed and drained. Values from different
// inputs are forwarded by separate goroutines, so no ordering between inputs
// is established here.
func FanIn(inputs ...<-chan int) <-chan int {
	output := make(chan int)
	var wg sync.WaitGroup
	for _, input := range inputs {
		// Register the forwarder before starting it so the closer goroutine
		// below cannot observe a zero counter too early.
		wg.Add(1)
		go func(ch <-chan int) {
			defer wg.Done()
			for data := range ch {
				output <- data
			}
		}(input)
	}
	// Close output only after every forwarder has finished; waiting in a
	// separate goroutine lets FanIn return the channel immediately.
	go func() {
		wg.Wait()
		close(output)
	}()
	return output
}
3. 超时控制
// WithTimeout returns a middleware that wraps an http.RoundTripper in a
// timeoutTransport configured with the given timeout.
func WithTimeout(timeout time.Duration) func(http.RoundTripper) http.RoundTripper {
	wrap := func(next http.RoundTripper) http.RoundTripper {
		tt := new(timeoutTransport)
		tt.next = next
		tt.timeout = timeout
		return tt
	}
	return wrap
}
// timeoutTransport is an http.RoundTripper that forwards each request to next
// with a context deadline of timeout.
type timeoutTransport struct {
	next    http.RoundTripper
	timeout time.Duration
}
func (t *timeoutTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx, cancel := context.WithTimeout(req.Context(), t.timeout)
defer cancel()
req = req.WithContext(ctx)
return t.next.RoundTrip(req)
}最佳实践
1. 合理设置 Worker 数量
// CPU 密集型
workers := runtime.NumCPU()
// IO 密集型
workers := runtime.NumCPU() * 2
2. 任务队列容量
// 根据任务提交速率和处理速率调整
tasks := make(chan Task, 1000)
3. 错误处理
func (wp *WorkerPool) Submit(task Task) error {
select {
case wp.tasks <- task:
return nil
case <-time.After(time.Second):
return errors.New("task queue full")
}
}4. 监控指标
// WorkerPool is shown here with only its monitoring counters; the remaining
// fields are elided in this excerpt.
type WorkerPool struct {
	// ...

	// Counters read atomically by Metrics.
	submitted, completed, failed int64
}

// Metrics returns a snapshot of the pool's counters keyed by name. Each
// counter is loaded atomically, in the order submitted, completed, failed.
func (wp *WorkerPool) Metrics() map[string]int64 {
	snapshot := make(map[string]int64, 3)
	snapshot["submitted"] = atomic.LoadInt64(&wp.submitted)
	snapshot["completed"] = atomic.LoadInt64(&wp.completed)
	snapshot["failed"] = atomic.LoadInt64(&wp.failed)
	return snapshot
}