🦀 One-Click OAuth 2.0 in Rust Axum: Add Social Login for Every Platform in 10 Minutes


Best Practices for Production-Grade OAuth 2.0 Integration in the Rust Axum Framework

Introduction

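The previous articles built a complete security stack: trusted-proxy client IP, TLS, HSTS, CORS, rate limiting, JWT, and refresh tokens. This article adds production-grade OAuth 2.0 social login on top of it. The integration:

  • Supports mainstream providers such as Google, GitHub, GitLab, and Microsoft Azure AD (the example below configures Google and GitHub)
  • Uses the Authorization Code Flow with PKCE (the most secure option, suitable for both SPAs and server-side applications)
  • Stores no user passwords and relies entirely on the third-party identity provider
  • Issues our own access and refresh tokens after a successful login, so it plugs directly into the JWT system from the previous article
  • Supports account linking (multiple providers for the same email)
  • Is readable, configurable, and extensible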
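The example code later in this article does not implement account linking. It derives a placeholder user ID per provider. One common approach is sketched below as a minimal in-memory version. It assumes that each provider account has a stable (provider, subject) key, and that linking by email happens only when the provider reports the email as verified. Linking on unverified emails would let someone who registers a victim's address at another provider take over the account. ProviderIdentity, UserDirectory, and find_or_link are hypothetical names standing in for database tables.

```rust
use std::collections::HashMap;

/// Hypothetical identity returned by a provider after login.
pub struct ProviderIdentity {
    pub provider: String,     // e.g. "google", "github"
    pub subject: String,      // provider's stable user ID (e.g. Google "sub", GitHub numeric id)
    pub email: String,
    pub email_verified: bool, // only verified emails are used for linking
}

/// Hypothetical in-memory user directory; a real app would use database tables.
#[derive(Default)]
pub struct UserDirectory {
    next_id: u64,
    by_identity: HashMap<(String, String), u64>, // (provider, subject) -> local user ID
    by_email: HashMap<String, u64>,              // verified, lowercased email -> local user ID
}

impl UserDirectory {
    /// Returns the local user ID for this provider identity, linking or creating as needed.
    pub fn find_or_link(&mut self, id: &ProviderIdentity) -> u64 {
        let key = (id.provider.clone(), id.subject.clone());
        // 1. This provider account is already known: reuse its user.
        if let Some(&uid) = self.by_identity.get(&key) {
            return uid;
        }
        let email = id.email.to_lowercase();
        // 2. A user with the same *verified* email exists: link this provider to it.
        let existing = if id.email_verified {
            self.by_email.get(&email).copied()
        } else {
            None
        };
        let uid = match existing {
            Some(uid) => uid,
            // 3. Otherwise create a new local user.
            None => {
                self.next_id += 1;
                if id.email_verified {
                    self.by_email.insert(email, self.next_id);
                }
                self.next_id
            }
        };
        self.by_identity.insert(key, uid);
        uid
    }
}
```

With this scheme, a user who first signs in with Google and later with GitHub, both reporting the same verified email, ends up with one local account. An unverified email always creates a separate account.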

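This implementation uses the mature oauth2 crate (version 5.x, as pinned in Cargo.toml below). The openidconnect crate can optionally be added for OpenID Connect discovery and user info; it is not used in the example below.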

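Key best practices

  • Use PKCE to prevent authorization code interception attacks
  • Use the state parameter to prevent CSRF
  • Use a nonce (OpenID Connect) to prevent replay attacks
  • Use HTTPS for every redirect
  • Keep the client secret (client_secret) on the server only
  • Issue our own JWTs after a successful login (unified authentication)
  • Support configuring multiple providers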
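PKCE binds the token request to the authorization request. The client sends code_challenge = BASE64URL(SHA256(code_verifier)) up front and reveals code_verifier only during the token exchange, so an intercepted code is useless on its own. In the code below, oauth2::PkceCodeChallenge::new_random_sha256 handles this. The following dependency-free sketch reproduces only the S256 transform, using a hand-written SHA-256 and base64url encoder. It is for illustration, not production use, and can be checked against the test vector in RFC 7636 Appendix B.

```rust
// PKCE "S256" (RFC 7636, section 4.2):
//   code_challenge = BASE64URL(SHA256(ASCII(code_verifier))) without '=' padding.
// Illustration only: real code should use a vetted crypto crate or oauth2's PkceCodeChallenge.

/// SHA-256 round constants (FIPS 180-4, section 4.2.2).
const K: [u32; 64] = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];

/// Computes the SHA-256 digest of `data`.
fn sha256(data: &[u8]) -> [u8; 32] {
    let mut h: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
    ];
    // Padding: 0x80, zeros up to 56 mod 64, then the bit length as a big-endian u64.
    let mut msg = data.to_vec();
    let bit_len = (data.len() as u64).wrapping_mul(8);
    msg.push(0x80);
    while msg.len() % 64 != 56 {
        msg.push(0);
    }
    msg.extend_from_slice(&bit_len.to_be_bytes());

    for block in msg.chunks(64) {
        // Message schedule.
        let mut w = [0u32; 64];
        for i in 0..16 {
            w[i] = u32::from_be_bytes([block[4 * i], block[4 * i + 1], block[4 * i + 2], block[4 * i + 3]]);
        }
        for i in 16..64 {
            let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
            let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
            w[i] = w[i - 16].wrapping_add(s0).wrapping_add(w[i - 7]).wrapping_add(s1);
        }
        // 64 compression rounds.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut hh] = h;
        for i in 0..64 {
            let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let t1 = hh.wrapping_add(s1).wrapping_add(ch).wrapping_add(K[i]).wrapping_add(w[i]);
            let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let t2 = s0.wrapping_add(maj);
            hh = g; g = f; f = e; e = d.wrapping_add(t1);
            d = c; c = b; b = a; a = t1.wrapping_add(t2);
        }
        for (x, v) in h.iter_mut().zip([a, b, c, d, e, f, g, hh]) {
            *x = x.wrapping_add(v);
        }
    }
    let mut out = [0u8; 32];
    for (i, v) in h.iter().enumerate() {
        out[4 * i..4 * i + 4].copy_from_slice(&v.to_be_bytes());
    }
    out
}

/// Base64url encoding without padding (RFC 4648, section 5).
fn base64url_nopad(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let mut out = String::new();
    for chunk in bytes.chunks(3) {
        let b = [chunk[0], *chunk.get(1).unwrap_or(&0), *chunk.get(2).unwrap_or(&0)];
        let n = (u32::from(b[0]) << 16) | (u32::from(b[1]) << 8) | u32::from(b[2]);
        // k input bytes produce k + 1 output characters (4 for a full 3-byte chunk).
        for i in 0..=chunk.len() {
            out.push(ALPHABET[((n >> (18 - 6 * i)) & 0x3f) as usize] as char);
        }
    }
    out
}

/// Derives the S256 code_challenge sent in the authorization request.
fn pkce_s256_challenge(code_verifier: &str) -> String {
    base64url_nopad(&sha256(code_verifier.as_bytes()))
}
```

Calling pkce_s256_challenge on the RFC's verifier dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk should yield the RFC's challenge E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM.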

Complete example code (Google + GitHub OAuth 2.0)

Cargo.toml

[package]
name = "axum-oauth2-integration"
version = "0.1.0"
edition = "2021"

[dependencies]
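axum = "0.8"
axum-server = { version = "0.7", features = ["tls-rustls"] }   # tls-rustls provides axum_server::tls_rustls
tokio = { version = "1.48.0", features = ["full"] }
tower-http = { version = "0.6.8", features = ["trace", "cors"] }
oauth2 = { version = "5.0", features = ["reqwest"] }           # 5.x; "reqwest" enables request_async(&reqwest::Client)
reqwest = { version = "0.12", features = ["json"] }            # same major version oauth2 5.x uses; also fetches user info
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"                                             # used by serde_json::json! in the callback
url = "2.5"
uuid = { version = "1.8", features = ["v4"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
rustls = "0.23.35"
rustls-pemfile = "2.2.0"
jsonwebtoken = "9.3"
chrono = { version = "0.4", features = ["serde"] }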

src/main.rs

use axum::{
    extract::{Query, State},
    http::{HeaderValue, StatusCode},
    response::{IntoResponse, Redirect},
    routing::get,
    Json, Router,
};
use axum_server::tls_rustls::RustlsConfig;
use oauth2::{
    basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
    Scope, TokenUrl,
};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use uuid::Uuid;

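// === OAuth provider configuration ===
// In oauth2 5.x, BasicClient's type parameters record which endpoints are set, in the order
// (auth URL, device-auth URL, introspection URL, revocation URL, token URL).
// main() sets the auth URL and the token URL, so those two are EndpointSet.
pub type ConfiguredClient = BasicClient<
    oauth2::EndpointSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointSet,
>;

#[derive(Clone)]
pub struct OAuthProvider {
    pub client: ConfiguredClient,
    pub scopes: Vec<String>,
    pub name: String,
}

type OAuthProviders = HashMap<String, OAuthProvider>;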

// === App State ===
#[derive(Clone)]
pub struct AppState {
    pub oauth_providers: OAuthProviders,
    pub jwt_config: Arc<JwtConfig>, // same JWT config as the previous article (simplified)
}

// === JWT configuration (same as the previous refresh-token implementation) ===
#[derive(Clone)]
pub struct JwtConfig {
    pub access_secret: String,
    pub refresh_secret: String,
    pub access_exp: std::time::Duration,
    pub refresh_exp: std::time::Duration,
}

// === Login redirect: GET /api/oauth/login?provider=... ===
async fn oauth_login(
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> impl IntoResponse {
    let provider_name = match params.get("provider") {
        Some(p) => p.to_lowercase(),
        None => return (StatusCode::BAD_REQUEST, "Missing provider").into_response(),
    };

    let provider = match state.oauth_providers.get(&provider_name) {
        Some(p) => p,
        None => return (StatusCode::BAD_REQUEST, "Unsupported provider").into_response(),
    };

    // Generate the PKCE challenge/verifier pair (S256)
    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

    // Generate the CSRF state (a random UUID v4)
    let state_token = CsrfToken::new(Uuid::new_v4().to_string());

    // Store the verifier and provider keyed by state (in production use Redis with a TTL).
    // This in-memory map is a simplification; replace it in production.
    {
        let storage = PKCE_STORAGE.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
        storage.lock().unwrap().insert(
            state_token.secret().clone(),
            (pkce_verifier, provider_name.clone()),
        );
    }

    // Build the provider's authorization URL; url() returns (Url, CsrfToken)
    let (auth_url, _csrf_state) = provider
        .client
        .authorize_url(|| state_token)
        .add_scopes(provider.scopes.iter().map(|s| Scope::new(s.clone())))
        .set_pkce_challenge(pkce_challenge)
        .url();

    Redirect::temporary(auth_url.as_str()).into_response()
}

// === Temporary storage for PKCE verifiers and state (use Redis in production) ===
type PkceStorage = Arc<Mutex<HashMap<String, (oauth2::PkceCodeVerifier, String)>>>; // state -> (verifier, provider)
static PKCE_STORAGE: std::sync::OnceLock<PkceStorage> = std::sync::OnceLock::new();

// === Callback handler: GET /api/oauth/callback?code=...&state=... ===
#[derive(Deserialize)]
struct CallbackQuery {
    code: String,
    state: String,
}

async fn oauth_callback(
    State(state): State<Arc<AppState>>,
    Query(query): Query<CallbackQuery>,
) -> impl IntoResponse {
    let stored = {
        let storage = PKCE_STORAGE.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
        let mut map = storage.lock().unwrap();
        map.remove(&query.state)
    };

    let (pkce_verifier, provider_name) = match stored {
        Some(v) => v,
        None => return (StatusCode::BAD_REQUEST, "Invalid state").into_response(),
    };

    let provider = match state.oauth_providers.get(&provider_name) {
        Some(p) => p,
        None => return (StatusCode::BAD_REQUEST, "Invalid provider").into_response(),
    };

    // Exchange the authorization code (plus the PKCE verifier) for a token.
    // oauth2 5.x takes a reqwest::Client; its docs recommend disabling redirects to avoid SSRF.
    use oauth2::TokenResponse; // brings token.access_token() into scope
    let http_client = match reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .build()
    {
        Ok(c) => c,
        Err(e) => {
            tracing::error!("Failed to build HTTP client: {}", e);
            return StatusCode::INTERNAL_SERVER_ERROR.into_response();
        }
    };

    let token_result = provider
        .client
        .exchange_code(oauth2::AuthorizationCode::new(query.code))
        .set_pkce_verifier(pkce_verifier)
        .request_async(&http_client)
        .await;

    let token = match token_result {
        Ok(t) => t,
        Err(e) => {
            tracing::error!("Token exchange failed: {}", e);
            return (StatusCode::BAD_REQUEST, "Token exchange failed").into_response();
        }
    };

    // === Fetch user info: each provider has its own endpoint and JSON shape ===
    // Both branches are normalized to (email, display name).
    let access = token.access_token().secret();
    let user_info: Option<(String, String)> = match provider_name.as_str() {
        "google" => match http_client
            .get("https://www.googleapis.com/oauth2/v3/userinfo")
            .bearer_auth(access)
            .send()
            .await
        {
            Ok(r) => r.json::<GoogleUserInfo>().await.ok().map(|u| (u.email, u.name)),
            Err(_) => None,
        },
        "github" => match http_client
            .get("https://api.github.com/user")
            .header("User-Agent", "axum-oauth-app") // GitHub's API requires a User-Agent
            .bearer_auth(access)
            .send()
            .await
        {
            // email is null when the GitHub user keeps it private; a fuller version would
            // fall back to GET https://api.github.com/user/emails (user:email scope).
            Ok(r) => r
                .json::<GitHubUserInfo>()
                .await
                .ok()
                .and_then(|u| Some((u.email?, u.name.unwrap_or(u.login)))),
            Err(_) => None,
        },
        _ => None,
    };

    let (email, name) = match user_info {
        Some(info) => info,
        None => return (StatusCode::BAD_REQUEST, "Failed to get user info").into_response(),
    };

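    // === Find or create the user (a real app would query the database) ===
    // Placeholder ID: one ID per (email, provider), so accounts are not linked here.
    let user_id = format!("oauth_{}@{}", email, provider_name);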

    // === Issue our own JWTs (same refresh-token system as the previous article) ===
    let access_token = generate_access_token(&state.jwt_config, &user_id, vec!["user".to_string()]);
    let refresh_token = generate_refresh_token(&state.jwt_config, &user_id);

    let mut response = Json(serde_json::json!({
        "access_token": access_token,
        "message": format!("Welcome {} ({})", name, email)
    })).into_response();

    // Set the refresh token in an HttpOnly cookie
    let cookie = format!(
        "refresh_token={}; HttpOnly; Secure; SameSite=Strict; Path=/api/refresh; Max-Age={}",
        refresh_token,
        state.jwt_config.refresh_exp.as_secs()
    );
    response.headers_mut().insert(
        "set-cookie",
        HeaderValue::from_str(&cookie).unwrap(),
    );

    response
}

// === User info response structs ===
#[derive(Deserialize)]
struct GoogleUserInfo {
    email: String,
    name: String,
    picture: Option<String>,
}

#[derive(Deserialize)]
struct GitHubUserInfo {
    login: String,
    email: Option<String>,
    name: Option<String>,
}

// === JWT generation (simplified placeholders) ===
// Replace the bodies with the jsonwebtoken encode logic from the previous article.
fn generate_access_token(_config: &JwtConfig, _sub: &str, _roles: Vec<String>) -> String {
    "dummy-access-token".to_string()
}

fn generate_refresh_token(_config: &JwtConfig, _sub: &str) -> String {
    "dummy-refresh-token".to_string()
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
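    tracing_subscriber::fmt::init();

    // rustls 0.23 can panic at startup when more than one crypto backend ends up compiled in
    // (e.g. aws-lc-rs via axum-server and ring via reqwest) and no default was chosen.
    // Installing one explicitly is harmless otherwise; an Err only means one is already set.
    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

    // === Configure OAuth clients ===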
    let google_client = BasicClient::new(
        ClientId::new(std::env::var("GOOGLE_CLIENT_ID").expect("GOOGLE_CLIENT_ID")),
    )
    .set_client_secret(ClientSecret::new(
        std::env::var("GOOGLE_CLIENT_SECRET").expect("GOOGLE_CLIENT_SECRET"),
    ))
    .set_auth_uri(AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())?)
    .set_token_uri(TokenUrl::new("https://oauth2.googleapis.com/token".to_string())?)
    .set_redirect_uri(RedirectUrl::new("https://yourdomain.com/api/oauth/callback".to_string())?);

    let github_client = BasicClient::new(
        ClientId::new(std::env::var("GITHUB_CLIENT_ID").expect("GITHUB_CLIENT_ID")),
    )
    .set_client_secret(ClientSecret::new(
        std::env::var("GITHUB_CLIENT_SECRET").expect("GITHUB_CLIENT_SECRET"),
    ))
    .set_auth_uri(AuthUrl::new("https://github.com/login/oauth/authorize".to_string())?)
    .set_token_uri(TokenUrl::new("https://github.com/login/oauth/access_token".to_string())?)
    .set_redirect_uri(RedirectUrl::new("https://yourdomain.com/api/oauth/callback".to_string())?);

    let mut providers = HashMap::new();
    providers.insert(
        "google".to_string(),
        OAuthProvider {
            client: google_client,
            scopes: vec!["openid".to_string(), "email".to_string(), "profile".to_string()],
            name: "Google".to_string(),
        },
    );
    providers.insert(
        "github".to_string(),
        OAuthProvider {
            client: github_client,
            scopes: vec!["read:user".to_string(), "user:email".to_string()],
            name: "GitHub".to_string(),
        },
    );

    let app_state = Arc::new(AppState {
        oauth_providers: providers,
        jwt_config: Arc::new(JwtConfig {
            access_secret: "access-secret".to_string(),
            refresh_secret: "refresh-secret".to_string(),
            access_exp: std::time::Duration::from_secs(900),
            refresh_exp: std::time::Duration::from_secs(7 * 24 * 3600),
        }),
    });

    let cors = CorsLayer::new()
        .allow_origin("https://your-frontend.com".parse::<HeaderValue>()?)
        .allow_credentials(true);

    let tls_config = RustlsConfig::from_pem_file("certs/fullchain.pem", "certs/privkey.pem").await?;

    let app = Router::new()
        .route("/api/oauth/login", get(oauth_login))
        .route("/api/oauth/callback", get(oauth_callback))
        .layer(cors)
        .with_state(app_state)
        .layer(TraceLayer::new_for_http());

    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], 443));
    println!("OAuth2 server running on https://{}", addr);

    axum_server::bind_rustls(addr, tls_config)
        .serve(app.into_make_service())
        .await?;

    Ok(())
}
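The listing above keeps PKCE verifiers in a global HashMap that never expires entries, so abandoned logins accumulate. The comments ask for "Redis + TTL" in production. Below is a minimal single-process sketch of the same semantics: each state is single-use and rejected once a TTL has passed, similar to Redis SET with EX followed by GETDEL. StateStore and PendingLogin are hypothetical names, and the verifier is stored as a plain String. Note that state only defeats login CSRF if it is also tied to the user's browser, for example by echoing it in a short-lived cookie and comparing on callback. This store does not do that.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical pending-login entry: written by the login handler, consumed by the callback.
pub struct PendingLogin {
    pub pkce_verifier: String,
    pub provider: String,
    created: Instant,
}

/// Single-use, TTL-bounded store keyed by the OAuth `state` value.
pub struct StateStore {
    ttl: Duration,
    entries: HashMap<String, PendingLogin>,
}

impl StateStore {
    pub fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    pub fn insert(&mut self, state: String, pkce_verifier: String, provider: String) {
        self.purge_expired(); // keep abandoned logins from growing the map forever
        let entry = PendingLogin { pkce_verifier, provider, created: Instant::now() };
        self.entries.insert(state, entry);
    }

    /// Removes and returns the entry: None if unknown, already used, or expired.
    pub fn take(&mut self, state: &str) -> Option<PendingLogin> {
        let entry = self.entries.remove(state)?;
        (entry.created.elapsed() <= self.ttl).then_some(entry)
    }

    fn purge_expired(&mut self) {
        let ttl = self.ttl;
        self.entries.retain(|_, e| e.created.elapsed() <= ttl);
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
```

To wire this into the handlers, one option is to keep a Mutex<StateStore> in AppState instead of the global. oauth_login would call insert(state, pkce_verifier.secret().clone(), provider_name), and oauth_callback would call take(&query.state) and rebuild the verifier with oauth2::PkceCodeVerifier::new(...).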

Usage flow

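  1. The user clicks "Log in" on the frontend, which navigates to:
    https://yourdomain.com/api/oauth/login?provider=google
  2. The backend redirects to the Google/GitHub authorization page
  3. After the user consents, the provider redirects back to:
    /api/oauth/callback?code=xxx&state=yyy
  4. The backend exchanges the code for a token → fetches user info → issues our own JWT and sets the refresh cookie
  5. The frontend uses the access token to call protected endpoints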
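The redirect in step 2 is the provider's authorization endpoint plus a handful of query parameters. The sketch below builds such a URL by hand to show roughly what oauth2's authorize_url(...).url() produces. The parameter names come from RFC 6749 and RFC 7636. The function names, client ID, state, and challenge values are hypothetical placeholders. The oauth2 crate may encode spaces as + rather than %20; both are common encodings of a space in a query string.

```rust
/// Percent-encodes a query component, keeping only RFC 3986 unreserved characters as-is.
fn encode_component(s: &str) -> String {
    let mut out = String::new();
    for b in s.bytes() {
        match b {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => out.push(b as char),
            _ => out.push_str(&format!("%{:02X}", b)),
        }
    }
    out
}

/// Builds an Authorization Code + PKCE request URL (RFC 6749 section 4.1.1, RFC 7636 section 4.3).
fn authorization_url(
    auth_endpoint: &str,
    client_id: &str,
    redirect_uri: &str,
    scopes: &[&str],
    state: &str,
    code_challenge: &str,
) -> String {
    let params = [
        ("response_type", "code".to_string()),
        ("client_id", client_id.to_string()),
        ("redirect_uri", redirect_uri.to_string()),
        ("scope", scopes.join(" ")), // space-delimited list (RFC 6749 section 3.3)
        ("state", state.to_string()), // echoed back to the callback for the CSRF check
        ("code_challenge", code_challenge.to_string()),
        ("code_challenge_method", "S256".to_string()),
    ];
    let query: Vec<String> = params
        .iter()
        .map(|(k, v)| format!("{}={}", k, encode_component(v)))
        .collect();
    format!("{}?{}", auth_endpoint, query.join("&"))
}
```

For Google, calling authorization_url with https://accounts.google.com/o/oauth2/v2/auth, the callback URL from main(), and the scopes ["openid", "email"] yields a URL containing redirect_uri=https%3A%2F%2Fyourdomain.com%2Fapi%2Foauth%2Fcallback and scope=openid%20email.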


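Your Axum service now supports a complete, modern authentication stack:

  • Self-issued JWT + refresh tokens (username/password login)
  • OAuth 2.0 social login (Google, GitHub, and others)
  • A unified authentication backend
  • Production-grade security (PKCE + HttpOnly cookies + HTTPS)

The code is clearly structured and extensible. Supporting another OAuth 2.0 provider means adding its client configuration in main() and its user-info mapping in oauth_callback.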

Copyright notice: free to repost for non-commercial use, no derivatives, attribution required (Creative Commons 3.0 license)