Flash messages using axum-messages

This commit is contained in:
Alphonse Paix
2025-08-30 01:15:54 +02:00
parent 8447d050d6
commit 3dce578ba0
24 changed files with 820 additions and 45 deletions

5
src/routes/home.rs Normal file
View File

@@ -0,0 +1,5 @@
use axum::response::{Html, IntoResponse};
pub async fn home() -> impl IntoResponse {
Html(include_str!("home/home.html"))
}

11
src/routes/home/home.html Normal file
View File

@@ -0,0 +1,11 @@
<!DOCTYPE html>
<html lang="en">
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width" />
<head>
<title>Home</title>
<body>
<p>Welcome to our newsletter!</p>
</body>
</head>
</html>

94
src/routes/login.rs Normal file
View File

@@ -0,0 +1,94 @@
use crate::{
authentication::{AuthError, Credentials, validate_credentials},
routes::error_chain_fmt,
startup::AppState,
};
use axum::{
Form, Json,
extract::State,
response::{Html, IntoResponse, Redirect, Response},
};
use axum_messages::Messages;
use reqwest::StatusCode;
use secrecy::SecretString;
use std::fmt::Write;
#[derive(thiserror::Error)]
pub enum LoginError {
#[error("Something went wrong.")]
UnexpectedError(#[from] anyhow::Error),
#[error("Authentication failed.")]
AuthError(#[source] anyhow::Error),
}
impl std::fmt::Debug for LoginError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
error_chain_fmt(self, f)
}
}
impl IntoResponse for LoginError {
fn into_response(self) -> Response {
#[derive(serde::Serialize)]
struct ErrorResponse<'a> {
message: &'a str,
}
tracing::error!("{:?}", self);
match &self {
LoginError::UnexpectedError(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
message: "An internal server error occured.",
}),
)
.into_response(),
LoginError::AuthError(_) => Redirect::to("/login").into_response(),
}
}
}
#[derive(serde::Deserialize)]
pub struct LoginFormData {
username: String,
password: SecretString,
}
pub async fn get_login(messages: Messages) -> impl IntoResponse {
let mut error_html = String::new();
for message in messages {
writeln!(error_html, "<p><i>{}</i></p>", message).unwrap();
}
Html(format!(include_str!("login/login.html"), error_html))
}
#[tracing::instrument(
skip(connection_pool, form),
fields(username=tracing::field::Empty, user_id=tracing::field::Empty)
)]
pub async fn post_login(
messages: Messages,
State(AppState {
connection_pool, ..
}): State<AppState>,
Form(form): Form<LoginFormData>,
) -> Result<Redirect, LoginError> {
let credentials = Credentials {
username: form.username,
password: form.password,
};
tracing::Span::current().record("username", tracing::field::display(&credentials.username));
let user_id = validate_credentials(credentials, &connection_pool)
.await
.map_err(|e| match e {
AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()),
AuthError::InvalidCredentials(_) => {
let e = LoginError::AuthError(e.into());
messages.error(e.to_string());
e
}
})?;
tracing::Span::current().record("user_id", tracing::field::display(&user_id));
Ok(Redirect::to("/"))
}

View File

@@ -0,0 +1,16 @@
<!DOCTYPE html>
<html lang="en">
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width" />
<head>
<title>Login</title>
<body>
<form action="/login" method="post">
<input type="text" name="username" placeholder="Username" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Login</button>
</form>
{}
</body>
</head>
</html>

View File

@@ -1,17 +1,27 @@
use crate::{domain::SubscriberEmail, routes::error_chain_fmt, startup::AppState};
use crate::{
authentication::{AuthError, Credentials, validate_credentials},
domain::SubscriberEmail,
routes::error_chain_fmt,
startup::AppState,
};
use anyhow::Context;
use axum::{
Json,
extract::State,
http::{HeaderMap, HeaderValue},
response::{IntoResponse, Response},
};
use reqwest::StatusCode;
use base64::Engine;
use reqwest::{StatusCode, header};
use secrecy::SecretString;
use sqlx::PgPool;
#[derive(thiserror::Error)]
pub enum PublishError {
#[error(transparent)]
UnexpectedError(#[from] anyhow::Error),
#[error("Authentication failed.")]
AuthError(#[source] anyhow::Error),
}
impl std::fmt::Debug for PublishError {
@@ -29,11 +39,25 @@ impl IntoResponse for PublishError {
tracing::error!("{:?}", self);
let mut authenticate_header_value = None;
let status = match self {
PublishError::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR,
PublishError::AuthError(_) => {
authenticate_header_value =
Some(HeaderValue::from_str(r#"Basic realm="publish""#).unwrap());
StatusCode::UNAUTHORIZED
}
};
let message = "An internal server error occured.";
(status, Json(ErrorResponse { message })).into_response()
let mut response = (status, Json(ErrorResponse { message })).into_response();
if let Some(header_value) = authenticate_header_value {
response
.headers_mut()
.insert(header::WWW_AUTHENTICATE, header_value);
}
response
}
}
@@ -51,9 +75,11 @@ pub struct Content {
#[tracing::instrument(
name = "Publishing a newsletter",
skip(connection_pool, email_client, body)
skip(headers, connection_pool, email_client, body),
fields(username=tracing::field::Empty, user_id=tracing::field::Empty)
)]
pub async fn publish_newsletter(
headers: HeaderMap,
State(AppState {
connection_pool,
email_client,
@@ -61,6 +87,15 @@ pub async fn publish_newsletter(
}): State<AppState>,
body: Json<BodyData>,
) -> Result<Response, PublishError> {
let credentials = basic_authentication(&headers).map_err(PublishError::AuthError)?;
tracing::Span::current().record("username", tracing::field::display(&credentials.username));
let user_id = validate_credentials(credentials, &connection_pool)
.await
.map_err(|e| match e {
AuthError::UnexpectedError(_) => PublishError::UnexpectedError(e.into()),
AuthError::InvalidCredentials(_) => PublishError::AuthError(e.into()),
})?;
tracing::Span::current().record("user_id", tracing::field::display(&user_id));
let subscribers = get_confirmed_subscribers(&connection_pool).await?;
for subscriber in subscribers {
match subscriber {
@@ -83,6 +118,37 @@ pub async fn publish_newsletter(
Ok(StatusCode::OK.into_response())
}
fn basic_authentication(headers: &HeaderMap) -> Result<Credentials, anyhow::Error> {
let header_value = headers
.get("Authorization")
.context("The 'Authorization' header was missing.")?
.to_str()
.context("The 'Authorization' header was not a valid UTF8 string.")?;
let base64encoded_segment = header_value
.strip_prefix("Basic ")
.context("The authorization scheme was not 'Basic'.")?;
let decoded_bytes = base64::engine::general_purpose::STANDARD
.decode(base64encoded_segment)
.context("Failed to base64-decode 'Basic' credentials.")?;
let decoded_credentials = String::from_utf8(decoded_bytes)
.context("The decoded credential string is not valid UTF-8.")?;
let mut credentials = decoded_credentials.splitn(2, ':');
let username = credentials
.next()
.context("A username must be provided in 'Basic' auth.")?
.to_string();
let password = credentials
.next()
.context("A password must be provided in 'Basic' auth.")?
.to_string();
Ok(Credentials {
username,
password: SecretString::from(password),
})
}
#[allow(dead_code)]
struct ConfirmedSubscriber {
name: String,

View File

@@ -85,6 +85,7 @@ pub async fn subscribe(
connection_pool,
email_client,
base_url,
..
}): State<AppState>,
Form(form): Form<FormData>,
) -> Result<Response, SubscribeError> {

View File

@@ -1,3 +1,4 @@
use crate::startup::AppState;
use axum::{
extract::{Query, State},
http::StatusCode,
@@ -7,8 +8,6 @@ use serde::Deserialize;
use sqlx::PgPool;
use uuid::Uuid;
use crate::startup::AppState;
#[tracing::instrument(name = "Confirming new subscriber", skip(params))]
pub async fn confirm(
State(AppState {