use crate::telemetry::spawn_blocking_with_tracing; use anyhow::Context; use argon2::{ Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version, password_hash::{SaltString, rand_core::OsRng}, }; use secrecy::{ExposeSecret, SecretString}; use sqlx::PgPool; use std::fmt::Display; use uuid::Uuid; pub struct Credentials { pub username: String, pub password: SecretString, } #[derive(Debug, thiserror::Error)] pub enum AuthError { #[error(transparent)] UnexpectedError(#[from] anyhow::Error), #[error("Invalid credentials.")] InvalidCredentials(#[source] anyhow::Error), } #[tracing::instrument(name = "Change password", skip(password, connection_pool))] pub async fn change_password( user_id: Uuid, password: SecretString, connection_pool: &PgPool, ) -> Result<(), anyhow::Error> { let password_hash = spawn_blocking_with_tracing(move || compute_pasword_hash(password)) .await? .context("Failed to hash password.")?; sqlx::query!( "UPDATE users SET password_hash = $1 WHERE user_id = $2", password_hash.expose_secret(), user_id ) .execute(connection_pool) .await .context("Failed to update user password in the database.")?; Ok(()) } pub(crate) fn compute_pasword_hash(password: SecretString) -> Result { let salt = SaltString::generate(&mut OsRng); let password_hash = Argon2::new( Algorithm::Argon2id, Version::V0x13, Params::new(1500, 2, 1, None).unwrap(), ) .hash_password(password.expose_secret().as_bytes(), &salt)? .to_string(); Ok(SecretString::from(password_hash)) } #[tracing::instrument(name = "Validate credentials", skip_all)] pub async fn validate_credentials( Credentials { username, password }: Credentials, connection_pool: &PgPool, ) -> Result<(Uuid, Role), AuthError> { let mut user_id = None; let mut role = None; let mut expected_password_hash = SecretString::from( "$argon2id$v=19$m=15000,t=2,p=1$\ gZiV/M1gPc22ElAH/Jh1Hw$\ CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno" .to_string(), ); if let Some((stored_user_id, stored_expected_password_hash, stored_role)) = get_stored_credentials(&username, connection_pool) .await .context("Failed to retrieve credentials from database.") .map_err(AuthError::UnexpectedError)? { user_id = Some(stored_user_id); role = Some(stored_role); expected_password_hash = stored_expected_password_hash; } let handle = spawn_blocking_with_tracing(|| verify_password_hash(expected_password_hash, password)); let uuid = user_id .ok_or_else(|| anyhow::anyhow!("Unknown username.")) .map_err(AuthError::InvalidCredentials)?; let role = role .ok_or_else(|| anyhow::anyhow!("Unknown role.")) .map_err(AuthError::UnexpectedError)?; handle .await .context("Failed to spawn blocking task.") .map_err(AuthError::UnexpectedError)? .map_err(AuthError::InvalidCredentials) .map(|_| (uuid, role)) } #[tracing::instrument(name = "Verify password", skip_all)] fn verify_password_hash( expected_password_hash: SecretString, password_candidate: SecretString, ) -> Result<(), anyhow::Error> { let expected_password_hash = PasswordHash::new(expected_password_hash.expose_secret()) .context("Failed to parse hash in PHC string format.")?; Argon2::default() .verify_password( password_candidate.expose_secret().as_bytes(), &expected_password_hash, ) .context("Password verification failed.") } #[tracing::instrument(name = "Get stored credentials", skip(connection_pool))] async fn get_stored_credentials( username: &str, connection_pool: &PgPool, ) -> Result, sqlx::Error> { let row = sqlx::query!( r#" SELECT user_id, password_hash, role as "role: Role" FROM users WHERE username = $1 "#, username, ) .fetch_optional(connection_pool) .await? .map(|row| (row.user_id, SecretString::from(row.password_hash), row.role)); Ok(row) } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq, sqlx::Type)] #[sqlx(type_name = "user_role", rename_all = "lowercase")] pub enum Role { Admin, Writer, } impl Display for Role { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Role::Admin => write!(f, "admin"), Role::Writer => write!(f, "writer"), } } } #[derive(Clone)] pub struct AuthenticatedUser { pub user_id: Uuid, pub username: String, pub role: Role, } impl AuthenticatedUser { pub fn is_admin(&self) -> bool { matches!(self.role, Role::Admin) } }