use crate::{ routes::AdminError, session_state::TypedSession, telemetry::spawn_blocking_with_tracing, }; use anyhow::Context; use argon2::{ Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version, password_hash::{SaltString, rand_core::OsRng}, }; use axum::{extract::Request, middleware::Next, response::Response}; use secrecy::{ExposeSecret, SecretString}; use sqlx::PgPool; use uuid::Uuid; pub struct Credentials { pub username: String, pub password: SecretString, } #[derive(Debug, thiserror::Error)] pub enum AuthError { #[error(transparent)] UnexpectedError(#[from] anyhow::Error), #[error("Invalid credentials.")] InvalidCredentials(#[source] anyhow::Error), #[error("Not authenticated.")] NotAuthenticated, } #[tracing::instrument(name = "Change password", skip(password, connection_pool))] pub async fn change_password( user_id: Uuid, password: SecretString, connection_pool: &PgPool, ) -> Result<(), anyhow::Error> { let password_hash = spawn_blocking_with_tracing(move || compute_pasword_hash(password)) .await? .context("Failed to hash password")?; sqlx::query!( "UPDATE users SET password_hash = $1 WHERE user_id = $2", password_hash.expose_secret(), user_id ) .execute(connection_pool) .await .context("Failed to update user password in the database.")?; Ok(()) } fn compute_pasword_hash(password: SecretString) -> Result { let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::new( Algorithm::Argon2id, Version::V0x13, Params::new(1500, 2, 1, None).unwrap(), ) .hash_password(password.expose_secret().as_bytes(), &salt)? .to_string(); Ok(SecretString::from(password_hash)) } #[tracing::instrument( name = "Validate credentials", skip(username, password, connection_pool) )] pub async fn validate_credentials( Credentials { username, password }: Credentials, connection_pool: &PgPool, ) -> Result { let mut user_id = None; let mut expected_password_hash = SecretString::from( "$argon2id$v=19$m=15000,t=2,p=1$\ gZiV/M1gPc22ElAH/Jh1Hw$\ CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno" .to_string(), ); if let Some((stored_user_id, stored_expected_password_hash)) = get_stored_credentials(&username, connection_pool) .await .map_err(AuthError::UnexpectedError)? { user_id = Some(stored_user_id); expected_password_hash = stored_expected_password_hash; } spawn_blocking_with_tracing(|| verify_password_hash(expected_password_hash, password)) .await .context("Failed to spawn blocking task.") .map_err(AuthError::UnexpectedError)??; user_id .ok_or_else(|| anyhow::anyhow!("Unknown username.")) .map_err(AuthError::InvalidCredentials) } #[tracing::instrument( name = "Verify password", skip(expected_password_hash, password_candidate) )] fn verify_password_hash( expected_password_hash: SecretString, password_candidate: SecretString, ) -> Result<(), AuthError> { let expected_password_hash = PasswordHash::new(expected_password_hash.expose_secret()) .context("Failed to parse hash in PHC string format.")?; Argon2::default() .verify_password( password_candidate.expose_secret().as_bytes(), &expected_password_hash, ) .context("Password verification failed.") .map_err(AuthError::InvalidCredentials) } #[tracing::instrument(name = "Get stored credentials", skip(username, connection_pool))] async fn get_stored_credentials( username: &str, connection_pool: &PgPool, ) -> Result, anyhow::Error> { let row = sqlx::query!( r#" SELECT user_id, password_hash FROM users WHERE username = $1 "#, username, ) .fetch_optional(connection_pool) .await .context("Failed to perform a query to retrieve stored credentials.")? .map(|row| (row.user_id, SecretString::from(row.password_hash))); Ok(row) } pub async fn require_auth( session: TypedSession, mut request: Request, next: Next, ) -> Result { let user_id = session .get_user_id() .await .map_err(|e| AdminError::UnexpectedError(e.into()))? .ok_or(AdminError::NotAuthenticated)?; let username = session .get_username() .await .map_err(|e| AdminError::UnexpectedError(e.into()))? .ok_or(AdminError::UnexpectedError(anyhow::anyhow!( "Could not find username in session." )))?; request .extensions_mut() .insert(AuthenticatedUser { user_id, username }); Ok(next.run(request).await) } #[derive(Clone)] pub struct AuthenticatedUser { pub user_id: Uuid, pub username: String, }