Flash messages using axum-messages
This commit is contained in:
93
src/authentication.rs
Normal file
93
src/authentication.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use anyhow::Context;
|
||||
use argon2::{Argon2, PasswordHash, PasswordVerifier};
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::telemetry::spawn_blocking_with_tracing;
|
||||
|
||||
pub struct Credentials {
|
||||
pub username: String,
|
||||
pub password: SecretString,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error(transparent)]
|
||||
UnexpectedError(#[from] anyhow::Error),
|
||||
#[error("Invalid credentials.")]
|
||||
InvalidCredentials(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "Validate credentials",
|
||||
skip(username, password, connection_pool)
|
||||
)]
|
||||
pub async fn validate_credentials(
|
||||
Credentials { username, password }: Credentials,
|
||||
connection_pool: &PgPool,
|
||||
) -> Result<Uuid, AuthError> {
|
||||
let mut user_id = None;
|
||||
let mut expected_password_hash = SecretString::from(
|
||||
"$argon2id$v=19$m=15000,t=2,p=1$\
|
||||
gZiV/M1gPc22ElAH/Jh1Hw$\
|
||||
CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
|
||||
.to_string(),
|
||||
);
|
||||
if let Some((stored_user_id, stored_expected_password_hash)) =
|
||||
get_stored_credentials(&username, connection_pool)
|
||||
.await
|
||||
.map_err(AuthError::UnexpectedError)?
|
||||
{
|
||||
user_id = Some(stored_user_id);
|
||||
expected_password_hash = stored_expected_password_hash;
|
||||
}
|
||||
|
||||
spawn_blocking_with_tracing(|| verify_password_hash(expected_password_hash, password))
|
||||
.await
|
||||
.context("Failed to spawn blocking task.")
|
||||
.map_err(AuthError::UnexpectedError)??;
|
||||
|
||||
user_id
|
||||
.ok_or_else(|| anyhow::anyhow!("Unknown username."))
|
||||
.map_err(AuthError::InvalidCredentials)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "Verify password",
|
||||
skip(expected_password_hash, password_candidate)
|
||||
)]
|
||||
fn verify_password_hash(
|
||||
expected_password_hash: SecretString,
|
||||
password_candidate: SecretString,
|
||||
) -> Result<(), AuthError> {
|
||||
let expected_password_hash = PasswordHash::new(expected_password_hash.expose_secret())
|
||||
.context("Failed to parse hash in PHC string format.")?;
|
||||
Argon2::default()
|
||||
.verify_password(
|
||||
password_candidate.expose_secret().as_bytes(),
|
||||
&expected_password_hash,
|
||||
)
|
||||
.context("Password verification failed.")
|
||||
.map_err(AuthError::InvalidCredentials)
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "Get stored credentials", skip(username, connection_pool))]
|
||||
async fn get_stored_credentials(
|
||||
username: &str,
|
||||
connection_pool: &PgPool,
|
||||
) -> Result<Option<(Uuid, SecretString)>, anyhow::Error> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT user_id, password_hash
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
"#,
|
||||
username,
|
||||
)
|
||||
.fetch_optional(connection_pool)
|
||||
.await
|
||||
.context("Failed to perform a query to retrieve stored credentials.")?
|
||||
.map(|row| (row.user_id, SecretString::from(row.password_hash)));
|
||||
Ok(row)
|
||||
}
|
||||
@@ -68,6 +68,7 @@ pub struct ApplicationSettings {
|
||||
pub port: u16,
|
||||
pub host: String,
|
||||
pub base_url: String,
|
||||
pub hmac_secret: SecretString,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod authentication;
|
||||
pub mod configuration;
|
||||
pub mod domain;
|
||||
pub mod email_client;
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
mod health_check;
|
||||
mod home;
|
||||
mod login;
|
||||
mod newsletters;
|
||||
mod subscriptions;
|
||||
mod subscriptions_confirm;
|
||||
|
||||
pub use health_check::*;
|
||||
pub use home::*;
|
||||
pub use login::*;
|
||||
pub use newsletters::*;
|
||||
pub use subscriptions::*;
|
||||
pub use subscriptions_confirm::*;
|
||||
|
||||
5
src/routes/home.rs
Normal file
5
src/routes/home.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
use axum::response::{Html, IntoResponse};
|
||||
|
||||
pub async fn home() -> impl IntoResponse {
|
||||
Html(include_str!("home/home.html"))
|
||||
}
|
||||
11
src/routes/home/home.html
Normal file
11
src/routes/home/home.html
Normal file
@@ -0,0 +1,11 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width" />
|
||||
<head>
|
||||
<title>Home</title>
|
||||
<body>
|
||||
<p>Welcome to our newsletter!</p>
|
||||
</body>
|
||||
</head>
|
||||
</html>
|
||||
94
src/routes/login.rs
Normal file
94
src/routes/login.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use crate::{
|
||||
authentication::{AuthError, Credentials, validate_credentials},
|
||||
routes::error_chain_fmt,
|
||||
startup::AppState,
|
||||
};
|
||||
use axum::{
|
||||
Form, Json,
|
||||
extract::State,
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_messages::Messages;
|
||||
use reqwest::StatusCode;
|
||||
use secrecy::SecretString;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(thiserror::Error)]
|
||||
pub enum LoginError {
|
||||
#[error("Something went wrong.")]
|
||||
UnexpectedError(#[from] anyhow::Error),
|
||||
#[error("Authentication failed.")]
|
||||
AuthError(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for LoginError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
error_chain_fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for LoginError {
|
||||
fn into_response(self) -> Response {
|
||||
#[derive(serde::Serialize)]
|
||||
struct ErrorResponse<'a> {
|
||||
message: &'a str,
|
||||
}
|
||||
|
||||
tracing::error!("{:?}", self);
|
||||
|
||||
match &self {
|
||||
LoginError::UnexpectedError(_) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
message: "An internal server error occured.",
|
||||
}),
|
||||
)
|
||||
.into_response(),
|
||||
LoginError::AuthError(_) => Redirect::to("/login").into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct LoginFormData {
|
||||
username: String,
|
||||
password: SecretString,
|
||||
}
|
||||
|
||||
pub async fn get_login(messages: Messages) -> impl IntoResponse {
|
||||
let mut error_html = String::new();
|
||||
for message in messages {
|
||||
writeln!(error_html, "<p><i>{}</i></p>", message).unwrap();
|
||||
}
|
||||
Html(format!(include_str!("login/login.html"), error_html))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip(connection_pool, form),
|
||||
fields(username=tracing::field::Empty, user_id=tracing::field::Empty)
|
||||
)]
|
||||
pub async fn post_login(
|
||||
messages: Messages,
|
||||
State(AppState {
|
||||
connection_pool, ..
|
||||
}): State<AppState>,
|
||||
Form(form): Form<LoginFormData>,
|
||||
) -> Result<Redirect, LoginError> {
|
||||
let credentials = Credentials {
|
||||
username: form.username,
|
||||
password: form.password,
|
||||
};
|
||||
tracing::Span::current().record("username", tracing::field::display(&credentials.username));
|
||||
let user_id = validate_credentials(credentials, &connection_pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()),
|
||||
AuthError::InvalidCredentials(_) => {
|
||||
let e = LoginError::AuthError(e.into());
|
||||
messages.error(e.to_string());
|
||||
e
|
||||
}
|
||||
})?;
|
||||
tracing::Span::current().record("user_id", tracing::field::display(&user_id));
|
||||
Ok(Redirect::to("/"))
|
||||
}
|
||||
16
src/routes/login/login.html
Normal file
16
src/routes/login/login.html
Normal file
@@ -0,0 +1,16 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width" />
|
||||
<head>
|
||||
<title>Login</title>
|
||||
<body>
|
||||
<form action="/login" method="post">
|
||||
<input type="text" name="username" placeholder="Username" />
|
||||
<input type="password" name="password" placeholder="Password" />
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
{}
|
||||
</body>
|
||||
</head>
|
||||
</html>
|
||||
@@ -1,17 +1,27 @@
|
||||
use crate::{domain::SubscriberEmail, routes::error_chain_fmt, startup::AppState};
|
||||
use crate::{
|
||||
authentication::{AuthError, Credentials, validate_credentials},
|
||||
domain::SubscriberEmail,
|
||||
routes::error_chain_fmt,
|
||||
startup::AppState,
|
||||
};
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::State,
|
||||
http::{HeaderMap, HeaderValue},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use reqwest::StatusCode;
|
||||
use base64::Engine;
|
||||
use reqwest::{StatusCode, header};
|
||||
use secrecy::SecretString;
|
||||
use sqlx::PgPool;
|
||||
|
||||
#[derive(thiserror::Error)]
|
||||
pub enum PublishError {
|
||||
#[error(transparent)]
|
||||
UnexpectedError(#[from] anyhow::Error),
|
||||
#[error("Authentication failed.")]
|
||||
AuthError(#[source] anyhow::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PublishError {
|
||||
@@ -29,11 +39,25 @@ impl IntoResponse for PublishError {
|
||||
|
||||
tracing::error!("{:?}", self);
|
||||
|
||||
let mut authenticate_header_value = None;
|
||||
let status = match self {
|
||||
PublishError::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
PublishError::AuthError(_) => {
|
||||
authenticate_header_value =
|
||||
Some(HeaderValue::from_str(r#"Basic realm="publish""#).unwrap());
|
||||
StatusCode::UNAUTHORIZED
|
||||
}
|
||||
};
|
||||
|
||||
let message = "An internal server error occured.";
|
||||
(status, Json(ErrorResponse { message })).into_response()
|
||||
let mut response = (status, Json(ErrorResponse { message })).into_response();
|
||||
if let Some(header_value) = authenticate_header_value {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(header::WWW_AUTHENTICATE, header_value);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,9 +75,11 @@ pub struct Content {
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "Publishing a newsletter",
|
||||
skip(connection_pool, email_client, body)
|
||||
skip(headers, connection_pool, email_client, body),
|
||||
fields(username=tracing::field::Empty, user_id=tracing::field::Empty)
|
||||
)]
|
||||
pub async fn publish_newsletter(
|
||||
headers: HeaderMap,
|
||||
State(AppState {
|
||||
connection_pool,
|
||||
email_client,
|
||||
@@ -61,6 +87,15 @@ pub async fn publish_newsletter(
|
||||
}): State<AppState>,
|
||||
body: Json<BodyData>,
|
||||
) -> Result<Response, PublishError> {
|
||||
let credentials = basic_authentication(&headers).map_err(PublishError::AuthError)?;
|
||||
tracing::Span::current().record("username", tracing::field::display(&credentials.username));
|
||||
let user_id = validate_credentials(credentials, &connection_pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
AuthError::UnexpectedError(_) => PublishError::UnexpectedError(e.into()),
|
||||
AuthError::InvalidCredentials(_) => PublishError::AuthError(e.into()),
|
||||
})?;
|
||||
tracing::Span::current().record("user_id", tracing::field::display(&user_id));
|
||||
let subscribers = get_confirmed_subscribers(&connection_pool).await?;
|
||||
for subscriber in subscribers {
|
||||
match subscriber {
|
||||
@@ -83,6 +118,37 @@ pub async fn publish_newsletter(
|
||||
Ok(StatusCode::OK.into_response())
|
||||
}
|
||||
|
||||
fn basic_authentication(headers: &HeaderMap) -> Result<Credentials, anyhow::Error> {
|
||||
let header_value = headers
|
||||
.get("Authorization")
|
||||
.context("The 'Authorization' header was missing.")?
|
||||
.to_str()
|
||||
.context("The 'Authorization' header was not a valid UTF8 string.")?;
|
||||
let base64encoded_segment = header_value
|
||||
.strip_prefix("Basic ")
|
||||
.context("The authorization scheme was not 'Basic'.")?;
|
||||
let decoded_bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(base64encoded_segment)
|
||||
.context("Failed to base64-decode 'Basic' credentials.")?;
|
||||
let decoded_credentials = String::from_utf8(decoded_bytes)
|
||||
.context("The decoded credential string is not valid UTF-8.")?;
|
||||
|
||||
let mut credentials = decoded_credentials.splitn(2, ':');
|
||||
let username = credentials
|
||||
.next()
|
||||
.context("A username must be provided in 'Basic' auth.")?
|
||||
.to_string();
|
||||
let password = credentials
|
||||
.next()
|
||||
.context("A password must be provided in 'Basic' auth.")?
|
||||
.to_string();
|
||||
|
||||
Ok(Credentials {
|
||||
username,
|
||||
password: SecretString::from(password),
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct ConfirmedSubscriber {
|
||||
name: String,
|
||||
|
||||
@@ -85,6 +85,7 @@ pub async fn subscribe(
|
||||
connection_pool,
|
||||
email_client,
|
||||
base_url,
|
||||
..
|
||||
}): State<AppState>,
|
||||
Form(form): Form<FormData>,
|
||||
) -> Result<Response, SubscribeError> {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::startup::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
@@ -7,8 +8,6 @@ use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::startup::AppState;
|
||||
|
||||
#[tracing::instrument(name = "Confirming new subscriber", skip(params))]
|
||||
pub async fn confirm(
|
||||
State(AppState {
|
||||
|
||||
@@ -5,10 +5,13 @@ use axum::{
|
||||
http::Request,
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum_messages::MessagesManagerLayer;
|
||||
use secrecy::SecretString;
|
||||
use sqlx::{PgPool, postgres::PgPoolOptions};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_sessions::{MemoryStore, SessionManagerLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct Application {
|
||||
@@ -21,6 +24,7 @@ pub struct AppState {
|
||||
pub connection_pool: PgPool,
|
||||
pub email_client: Arc<EmailClient>,
|
||||
pub base_url: String,
|
||||
pub hmac_secret: SecretString,
|
||||
}
|
||||
|
||||
impl Application {
|
||||
@@ -37,6 +41,7 @@ impl Application {
|
||||
connection_pool,
|
||||
email_client,
|
||||
configuration.application.base_url,
|
||||
configuration.application.hmac_secret,
|
||||
);
|
||||
Ok(Self { listener, router })
|
||||
}
|
||||
@@ -51,35 +56,43 @@ impl Application {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn app(connection_pool: PgPool, email_client: EmailClient, base_url: String) -> Router {
|
||||
pub fn app(
|
||||
connection_pool: PgPool,
|
||||
email_client: EmailClient,
|
||||
base_url: String,
|
||||
hmac_secret: SecretString,
|
||||
) -> Router {
|
||||
let app_state = AppState {
|
||||
connection_pool,
|
||||
email_client: Arc::new(email_client),
|
||||
base_url,
|
||||
hmac_secret,
|
||||
};
|
||||
Router::new()
|
||||
.route("/", get(home))
|
||||
.route("/login", get(get_login).post(post_login))
|
||||
.route("/health_check", get(health_check))
|
||||
.route("/subscriptions", post(subscribe))
|
||||
.route("/subscriptions/confirm", get(confirm))
|
||||
.route("/newsletters", post(publish_newsletter))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &Request<_>| {
|
||||
let matched_path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.map(MatchedPath::as_str);
|
||||
let request_id = Uuid::new_v4().to_string();
|
||||
TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {
|
||||
let matched_path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.map(MatchedPath::as_str);
|
||||
let request_id = Uuid::new_v4().to_string();
|
||||
|
||||
tracing::info_span!(
|
||||
"http_request",
|
||||
method = ?request.method(),
|
||||
matched_path,
|
||||
request_id,
|
||||
some_other_field = tracing::field::Empty,
|
||||
)
|
||||
})
|
||||
.on_failure(()),
|
||||
tracing::info_span!(
|
||||
"http_request",
|
||||
method = ?request.method(),
|
||||
matched_path,
|
||||
request_id,
|
||||
some_other_field = tracing::field::Empty,
|
||||
)
|
||||
}),
|
||||
)
|
||||
.layer(MessagesManagerLayer)
|
||||
.layer(SessionManagerLayer::new(MemoryStore::default()))
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
|
||||
use tracing_subscriber::{fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
@@ -20,3 +21,12 @@ where
|
||||
.with(formatting_layer)
|
||||
.init();
|
||||
}
|
||||
|
||||
pub fn spawn_blocking_with_tracing<F, R>(f: F) -> JoinHandle<R>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let current_span = tracing::Span::current();
|
||||
tokio::task::spawn_blocking(move || current_span.in_scope(f))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user