导航菜单

Context 超时控制

🟡 中等

题目描述

实现一个带超时的 HTTP 请求函数。如果请求超过指定时间未返回,应取消请求并返回错误。

示例代码

func FetchWithTimeout(url string, timeout time.Duration) (string, error) {
    // TODO: 实现带超时的 HTTP 请求
    panic("implement me")
}

使用示例

func main() {
    // 正常情况
    result, err := FetchWithTimeout("https://example.com", 5*time.Second)
    if err != nil {
        log.Printf("error: %v", err)
        return
    }
    fmt.Println(result)
    
    // 超时情况
    result, err = FetchWithTimeout("https://slow-server.com", 100*time.Millisecond)
    if err != nil {
        log.Printf("timeout: %v", err)
        return
    }
}

提示

  • 使用 context.WithTimeout 创建超时上下文
  • 使用 http.NewRequestWithContext 绑定上下文
  • 使用 defer cancel() 确保资源释放
  • 检查错误是否是超时错误

解法

参考答案 (3 个标签)
context HTTP 超时

实现代码

func FetchWithTimeout(url string, timeout time.Duration) (string, error) {
    // 1. 创建带超时的 context
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel() // 确保资源释放
    
    // 2. 创建请求并绑定 context
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return "", fmt.Errorf("create request: %w", err)
    }
    
    // 3. 发送请求
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        // 检查是否是超时错误
        if errors.Is(err, context.DeadlineExceeded) {
            return "", fmt.Errorf("request timeout after %v", timeout)
        }
        if errors.Is(err, context.Canceled) {
            return "", fmt.Errorf("request canceled")
        }
        return "", fmt.Errorf("send request: %w", err)
    }
    defer resp.Body.Close()
    
    // 4. 读取响应
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("read response: %w", err)
    }
    
    return string(body), nil
}

关键点

  1. context.WithTimeout:创建超时上下文
  2. defer cancel():确保资源释放(防止泄露)
  3. NewRequestWithContext:绑定上下文到请求
  4. 错误检查:区分超时和其他错误

Context 传播

func FetchWithContext(ctx context.Context, url string) (string, error) {
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return "", err
    }
    
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()
    
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", err
    }
    
    return string(body), nil
}

// 使用
func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    
    result, err := FetchWithContext(ctx, "https://example.com")
    if err != nil {
        log.Printf("error: %v", err)
        return
    }
    fmt.Println(result)
}

扩展:Context 其他用法

1. WithCancel

func Worker(ctx context.Context, jobs <-chan int) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("worker canceled:", ctx.Err())
            return
        case job := <-jobs:
            fmt.Println("processing:", job)
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    
    jobs := make(chan int, 10)
    go Worker(ctx, jobs)
    
    jobs <- 1
    jobs <- 2
    
    cancel() // 取消 worker
    time.Sleep(time.Second)
}

2. WithDeadline

func main() {
    deadline := time.Now().Add(5 * time.Second)
    ctx, cancel := context.WithDeadline(context.Background(), deadline)
    defer cancel()
    
    // 在截止时间前执行
    select {
    case <-time.After(10 * time.Second):
        fmt.Println("completed")
    case <-ctx.Done():
        fmt.Println("deadline exceeded:", ctx.Err())
    }
}

3. WithValue

type contextKey string

const (
    userIDKey contextKey = "userID"
    traceIDKey contextKey = "traceID"
)

func HandleRequest(ctx context.Context) {
    userID := ctx.Value(userIDKey).(string)
    traceID := ctx.Value(traceIDKey).(string)
    fmt.Printf("userID=%s, traceID=%s\n", userID, traceID)
}

func main() {
    ctx := context.WithValue(context.Background(), userIDKey, "12345")
    ctx = context.WithValue(ctx, traceIDKey, "abc-def")
    
    HandleRequest(ctx)
}

4. 链式调用

func main() {
    // 根 context
    ctx := context.Background()
    
    // 添加超时
    ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
    defer cancel()
    
    // 添加值
    ctx = context.WithValue(ctx, "key", "value")
    
    // 传递给下游
    Process(ctx)
}

Context 最佳实践

1. 不要传递 nil Context

// ❌ 错误
func Process() {
    doWork(nil)  // 可能 panic
}

// ✅ 正确
func Process() {
    ctx := context.Background()
    doWork(ctx)
}

2. Context 作为第一个参数

// ✅ 推荐
func Fetch(ctx context.Context, url string) (string, error) {
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    // ...
}

// ❌ 不推荐
func Fetch(url string, ctx context.Context) (string, error) {
    // ...
}

3. 及时调用 cancel

// ✅ 使用 defer
func Process() error {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()  // 确保调用
    
    // ...
    return nil
}

4. 检查 ctx.Done()

func LongRunningTask(ctx context.Context) error {
    for i := 0; i < 1000; i++ {
        select {
        case <-ctx.Done():
            return ctx.Err()  // 及时退出
        default:
            // 执行任务
            time.Sleep(10 * time.Millisecond)
        }
    }
    return nil
}

搜索