163 lines
4.8 KiB
Rust
163 lines
4.8 KiB
Rust
use crate::telemetry::spawn_blocking_with_tracing;
|
|
use anyhow::Context;
|
|
use argon2::{
|
|
Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version,
|
|
password_hash::{SaltString, rand_core::OsRng},
|
|
};
|
|
use secrecy::{ExposeSecret, SecretString};
|
|
use sqlx::PgPool;
|
|
use std::fmt::Display;
|
|
use uuid::Uuid;
|
|
|
|
pub struct Credentials {
|
|
pub username: String,
|
|
pub password: SecretString,
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum AuthError {
|
|
#[error(transparent)]
|
|
UnexpectedError(#[from] anyhow::Error),
|
|
#[error("Invalid credentials.")]
|
|
InvalidCredentials(#[source] anyhow::Error),
|
|
}
|
|
|
|
#[tracing::instrument(name = "Change password", skip(password, connection_pool))]
|
|
pub async fn change_password(
|
|
user_id: Uuid,
|
|
password: SecretString,
|
|
connection_pool: &PgPool,
|
|
) -> Result<(), anyhow::Error> {
|
|
let password_hash = spawn_blocking_with_tracing(move || compute_pasword_hash(password))
|
|
.await?
|
|
.context("Failed to hash password.")?;
|
|
sqlx::query!(
|
|
"UPDATE users SET password_hash = $1 WHERE user_id = $2",
|
|
password_hash.expose_secret(),
|
|
user_id
|
|
)
|
|
.execute(connection_pool)
|
|
.await
|
|
.context("Failed to update user password in the database.")?;
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) fn compute_pasword_hash(password: SecretString) -> Result<SecretString, anyhow::Error> {
|
|
let salt = SaltString::generate(&mut OsRng);
|
|
let password_hash = Argon2::new(
|
|
Algorithm::Argon2id,
|
|
Version::V0x13,
|
|
Params::new(1500, 2, 1, None).unwrap(),
|
|
)
|
|
.hash_password(password.expose_secret().as_bytes(), &salt)?
|
|
.to_string();
|
|
Ok(SecretString::from(password_hash))
|
|
}
|
|
|
|
#[tracing::instrument(name = "Validate credentials", skip_all)]
|
|
pub async fn validate_credentials(
|
|
Credentials { username, password }: Credentials,
|
|
connection_pool: &PgPool,
|
|
) -> Result<(Uuid, Role), AuthError> {
|
|
let mut user_id = None;
|
|
let mut role = None;
|
|
let mut expected_password_hash = SecretString::from(
|
|
"$argon2id$v=19$m=15000,t=2,p=1$\
|
|
gZiV/M1gPc22ElAH/Jh1Hw$\
|
|
CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
|
|
.to_string(),
|
|
);
|
|
|
|
if let Some((stored_user_id, stored_expected_password_hash, stored_role)) =
|
|
get_stored_credentials(&username, connection_pool)
|
|
.await
|
|
.context("Failed to retrieve credentials from database.")
|
|
.map_err(AuthError::UnexpectedError)?
|
|
{
|
|
user_id = Some(stored_user_id);
|
|
role = Some(stored_role);
|
|
expected_password_hash = stored_expected_password_hash;
|
|
}
|
|
|
|
let handle =
|
|
spawn_blocking_with_tracing(|| verify_password_hash(expected_password_hash, password));
|
|
|
|
let uuid = user_id
|
|
.ok_or_else(|| anyhow::anyhow!("Unknown username."))
|
|
.map_err(AuthError::InvalidCredentials)?;
|
|
|
|
let role = role
|
|
.ok_or_else(|| anyhow::anyhow!("Unknown role."))
|
|
.map_err(AuthError::UnexpectedError)?;
|
|
|
|
handle
|
|
.await
|
|
.context("Failed to spawn blocking task.")
|
|
.map_err(AuthError::UnexpectedError)?
|
|
.map_err(AuthError::InvalidCredentials)
|
|
.map(|_| (uuid, role))
|
|
}
|
|
|
|
#[tracing::instrument(name = "Verify password", skip_all)]
|
|
fn verify_password_hash(
|
|
expected_password_hash: SecretString,
|
|
password_candidate: SecretString,
|
|
) -> Result<(), anyhow::Error> {
|
|
let expected_password_hash = PasswordHash::new(expected_password_hash.expose_secret())
|
|
.context("Failed to parse hash in PHC string format.")?;
|
|
Argon2::default()
|
|
.verify_password(
|
|
password_candidate.expose_secret().as_bytes(),
|
|
&expected_password_hash,
|
|
)
|
|
.context("Password verification failed.")
|
|
}
|
|
|
|
#[tracing::instrument(name = "Get stored credentials", skip(connection_pool))]
|
|
async fn get_stored_credentials(
|
|
username: &str,
|
|
connection_pool: &PgPool,
|
|
) -> Result<Option<(Uuid, SecretString, Role)>, sqlx::Error> {
|
|
let row = sqlx::query!(
|
|
r#"
|
|
SELECT user_id, password_hash, role as "role: Role"
|
|
FROM users
|
|
WHERE username = $1
|
|
"#,
|
|
username,
|
|
)
|
|
.fetch_optional(connection_pool)
|
|
.await?
|
|
.map(|row| (row.user_id, SecretString::from(row.password_hash), row.role));
|
|
Ok(row)
|
|
}
|
|
|
|
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq, sqlx::Type)]
|
|
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
|
|
pub enum Role {
|
|
Admin,
|
|
Writer,
|
|
}
|
|
|
|
impl Display for Role {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
Role::Admin => write!(f, "admin"),
|
|
Role::Writer => write!(f, "writer"),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AuthenticatedUser {
|
|
pub user_id: Uuid,
|
|
pub username: String,
|
|
pub role: Role,
|
|
}
|
|
|
|
impl AuthenticatedUser {
|
|
pub fn is_admin(&self) -> bool {
|
|
matches!(self.role, Role::Admin)
|
|
}
|
|
}
|