use crate::{ authentication::{self, Credentials, validate_credentials}, routes::error_chain_fmt, session_state::TypedSession, startup::AppState, }; use axum::{ Extension, Form, Json, extract::{Request, State}, middleware::Next, response::{Html, IntoResponse, Redirect, Response}, }; use axum_messages::Messages; use reqwest::StatusCode; use secrecy::{ExposeSecret, SecretString}; use std::fmt::Write; use uuid::Uuid; #[derive(thiserror::Error)] pub enum AdminError { #[error("Something went wrong.")] UnexpectedError(#[from] anyhow::Error), #[error("You must be logged in to access the admin dashboard.")] NotAuthenticated, #[error("Updating password failed.")] ChangePassword, } impl std::fmt::Debug for AdminError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { error_chain_fmt(self, f) } } impl IntoResponse for AdminError { fn into_response(self) -> Response { #[derive(serde::Serialize)] struct ErrorResponse<'a> { message: &'a str, } tracing::error!("{:?}", self); match &self { AdminError::UnexpectedError(_) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { message: "An internal server error occured.", }), ) .into_response(), AdminError::NotAuthenticated => Redirect::to("/login").into_response(), AdminError::ChangePassword => Redirect::to("/admin/password").into_response(), } } } pub async fn require_auth( session: TypedSession, mut request: Request, next: Next, ) -> Result { let user_id = session .get_user_id() .await .map_err(|e| AdminError::UnexpectedError(e.into()))? .ok_or(AdminError::NotAuthenticated)?; let username = session .get_username() .await .map_err(|e| AdminError::UnexpectedError(e.into()))? .ok_or(AdminError::UnexpectedError(anyhow::anyhow!( "Could not find username in session." )))?; request .extensions_mut() .insert(AuthenticatedUser { user_id, username }); Ok(next.run(request).await) } #[derive(Clone)] pub struct AuthenticatedUser { user_id: Uuid, username: String, } pub async fn admin_dashboard( Extension(AuthenticatedUser { username, .. }): Extension, ) -> Result { Ok(Html(format!(include_str!("admin/dashboard.html"), username)).into_response()) } #[derive(serde::Deserialize)] pub struct PasswordFormData { pub current_password: SecretString, pub new_password: SecretString, pub new_password_check: SecretString, } pub async fn change_password_form(messages: Messages) -> Result { let mut error_html = String::new(); for message in messages { writeln!(error_html, "

{}

", message).unwrap(); } Ok(Html(format!( include_str!("admin/change_password_form.html"), error_html )) .into_response()) } pub async fn change_password( Extension(AuthenticatedUser { user_id, username }): Extension, State(AppState { connection_pool, .. }): State, messages: Messages, Form(form): Form, ) -> Result { let credentials = Credentials { username, password: form.current_password, }; if form.new_password.expose_secret() != form.new_password_check.expose_secret() { messages.error("You entered two different passwords - the field values must match."); Err(AdminError::ChangePassword) } else if validate_credentials(credentials, &connection_pool) .await .is_err() { messages.error("The current password is incorrect."); Err(AdminError::ChangePassword) } else if let Err(e) = verify_password(form.new_password.expose_secret()) { messages.error(e); Err(AdminError::ChangePassword) } else { authentication::change_password(user_id, form.new_password, &connection_pool) .await .map_err(|_| AdminError::ChangePassword)?; messages.success("Your password has been changed."); Ok(Redirect::to("/admin/password").into_response()) } } #[tracing::instrument(name = "Logging out", skip(messages, session))] pub async fn logout(messages: Messages, session: TypedSession) -> Result { session.clear().await; messages.success("You have successfully logged out."); Ok(Redirect::to("/login").into_response()) } fn verify_password(password: &str) -> Result<(), String> { if password.len() < 12 || password.len() > 128 { return Err("The password must contain between 12 and 128 characters.".into()); } Ok(()) }