From 3dce578ba0b1b41f6bf6e98a863c04525d3da704 Mon Sep 17 00:00:00 2001 From: Alphonse Paix Date: Sat, 30 Aug 2025 01:15:54 +0200 Subject: [PATCH] Flash messages using axum-messages --- Cargo.lock | 257 ++++++++++++++++++ Cargo.toml | 15 +- configuration/local.yaml | 1 + .../20250828142613_create_users_table.sql | 5 + .../20250828161700_rename_password_column.sql | 1 + .../20250828204455_add_salt_to_users.sql | 1 + .../20250828212543_remove_salt_from_users.sql | 1 + src/authentication.rs | 93 +++++++ src/configuration.rs | 1 + src/lib.rs | 1 + src/routes.rs | 4 + src/routes/home.rs | 5 + src/routes/home/home.html | 11 + src/routes/login.rs | 94 +++++++ src/routes/login/login.html | 16 ++ src/routes/newsletters.rs | 74 ++++- src/routes/subscriptions.rs | 1 + src/routes/subscriptions_confirm.rs | 3 +- src/startup.rs | 47 ++-- src/telemetry.rs | 10 + tests/api/helpers.rs | 119 ++++++-- tests/api/login.rs | 22 ++ tests/api/main.rs | 1 + tests/api/newsletters.rs | 82 ++++++ 24 files changed, 820 insertions(+), 45 deletions(-) create mode 100644 migrations/20250828142613_create_users_table.sql create mode 100644 migrations/20250828161700_rename_password_column.sql create mode 100644 migrations/20250828204455_add_salt_to_users.sql create mode 100644 migrations/20250828212543_remove_salt_from_users.sql create mode 100644 src/authentication.rs create mode 100644 src/routes/home.rs create mode 100644 src/routes/home/home.html create mode 100644 src/routes/login.rs create mode 100644 src/routes/login/login.html create mode 100644 tests/api/login.rs diff --git a/Cargo.lock b/Cargo.lock index 237c00c..798b4ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,18 @@ version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "arraydeque" version = "0.5.1" @@ -168,6 +180,48 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "serde", + "serde_html_form", + "serde_path_to_error", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-messages" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d67ce6e7bc1e1e71f2a4e86d418045a29c63c4ebb631f3d9bb2f81c4958ea391" +dependencies = [ + "axum-core", + "http", + "parking_lot", + "serde", + "serde_json", + "tower", + "tower-sessions-core", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -210,6 +264,15 @@ dependencies = [ "serde", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -340,6 +403,35 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "cookie_store" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eac901828f88a5241ee0600950ab981148a18f2f756900ffba1b125ca6a3ef9" +dependencies = [ + "cookie", + "document-features", + "idna", + "log", + "publicsuffix", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -472,6 +564,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -512,6 +605,15 @@ dependencies = [ "const-random", ] +[[package]] +name = "document-features" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d" +dependencies = [ + "litrs", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -870,6 +972,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "htmlescape" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9025058dae765dee5070ec375f591e2ba14638c63feff74f13805a72e523163" + [[package]] name = "http" version = "1.3.1" @@ -1181,6 +1289,15 @@ dependencies = [ "serde", ] +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1238,6 +1355,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +[[package]] +name = "litrs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed" + [[package]] name = "lock_api" version = "0.4.13" @@ -1246,6 +1369,7 @@ checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", + "serde", ] [[package]] @@ -1459,6 +1583,17 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "pathdiff" version = "0.2.3" @@ -1618,6 +1753,22 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psl-types" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" + +[[package]] +name = "publicsuffix" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42ea446cab60335f76979ec15e12619a2165b5ae2c12166bef27d283a9fadf" +dependencies = [ + "idna", + "psl-types", +] + [[package]] name = "quickcheck" version = "1.0.3" @@ -1830,6 +1981,8 @@ checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ "base64 0.22.1", "bytes", + "cookie", + "cookie_store", "futures-core", "http", "http-body", @@ -2045,6 +2198,19 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.143" @@ -2110,6 +2276,16 @@ dependencies = [ "digest", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2651,6 +2827,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-cookies" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36" +dependencies = [ + "axum-core", + "cookie", + "futures-util", + "http", + "parking_lot", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-http" version = "0.6.6" @@ -2682,6 +2874,57 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "tower-sessions" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a05911f23e8fae446005fe9b7b97e66d95b6db589dc1c4d59f6a2d4d4927d3" +dependencies = [ + "async-trait", + "http", + "time", + "tokio", + "tower-cookies", + "tower-layer", + "tower-service", + "tower-sessions-core", + "tower-sessions-memory-store", + "tracing", +] + +[[package]] +name = "tower-sessions-core" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce8cce604865576b7751b7a6bc3058f754569a60d689328bb74c52b1d87e355b" +dependencies = [ + "async-trait", + "axum-core", + "base64 0.22.1", + "futures", + "http", + "parking_lot", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", +] + +[[package]] +name = "tower-sessions-memory-store" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb05909f2e1420135a831dd5df9f5596d69196d0a64c3499ca474c4bd3d33242" +dependencies = [ + "async-trait", + "time", + "tokio", + "tower-sessions-core", +] + [[package]] name = "tracing" version = "0.1.41" @@ -2853,6 +3096,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3384,11 +3633,16 @@ name = "zero2prod" version = "0.1.0" dependencies = [ "anyhow", + "argon2", "axum", + "axum-extra", + "axum-messages", + "base64 0.22.1", "chrono", "claims", "config", "fake", + "htmlescape", "linkify", "once_cell", "quickcheck", @@ -3399,14 +3653,17 @@ dependencies = [ "serde", "serde-aux", "serde_json", + "sha3", "sqlx", "thiserror", "tokio", "tower-http", + "tower-sessions", "tracing", "tracing-bunyan-formatter", "tracing-subscriber", "unicode-segmentation", + "urlencoding", "uuid", "validator", "wiremock", diff --git a/Cargo.toml b/Cargo.toml index 773f246..ce483ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "zero2prod" version = "0.1.0" edition = "2024" +resolver = "2" [lib] path = "src/lib.rs" @@ -10,24 +11,36 @@ path = "src/lib.rs" path = "src/main.rs" name = "zero2prod" +[target.x86_64-unknown-linux-gnu] +linker = "clang" +rustflags = ["-C", "link-arg=-fuse-ld=/usr/bin/mold"] + [dependencies] anyhow = "1.0.99" +argon2 = { version = "0.5.3", features = ["std"] } axum = "0.8.4" +axum-extra = { version = "0.10.1", features = ["query", "cookie"] } +axum-messages = "0.8.0" +base64 = "0.22.1" chrono = { version = "0.4.41", default-features = false, features = ["clock"] } config = "0.15.14" +htmlescape = "0.3.1" rand = { version = "0.9.2", features = ["std_rng"] } -reqwest = { version = "0.12.23", default-features = false, features = ["rustls-tls", "json"] } +reqwest = { version = "0.12.23", default-features = false, features = ["rustls-tls", "json", "cookies"] } secrecy = { version = "0.10.3", features = ["serde"] } serde = { version = "1.0.219", features = ["derive"] } serde-aux = "4.7.0" +sha3 = "0.10.8" sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "macros", "postgres", "uuid", "chrono", "migrate"] } thiserror = "2.0.16" tokio = { version = "1.47.1", features = ["macros", "rt-multi-thread"] } tower-http = { version = "0.6.6", features = ["trace"] } +tower-sessions = "0.14.0" tracing = "0.1.41" tracing-bunyan-formatter = "0.3.10" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } unicode-segmentation = "1.12.0" +urlencoding = "2.1.3" uuid = { version = "1.18.0", features = ["v4"] } validator = { version = "0.20.0", features = ["derive"] } diff --git a/configuration/local.yaml b/configuration/local.yaml index f5a4bba..b5fad82 100644 --- a/configuration/local.yaml +++ b/configuration/local.yaml @@ -1,5 +1,6 @@ application: host: "127.0.0.1" base_url: "http://127.0.0.1" + hmac_secret: vPojv$zM3Rxt#RT0D*Tp database: require_ssl: false diff --git a/migrations/20250828142613_create_users_table.sql b/migrations/20250828142613_create_users_table.sql new file mode 100644 index 0000000..0d93965 --- /dev/null +++ b/migrations/20250828142613_create_users_table.sql @@ -0,0 +1,5 @@ +CREATE TABLE users ( + user_id UUID PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT NOT NULL +); diff --git a/migrations/20250828161700_rename_password_column.sql b/migrations/20250828161700_rename_password_column.sql new file mode 100644 index 0000000..e2dbd2c --- /dev/null +++ b/migrations/20250828161700_rename_password_column.sql @@ -0,0 +1 @@ +ALTER TABLE users RENAME password TO password_hash; diff --git a/migrations/20250828204455_add_salt_to_users.sql b/migrations/20250828204455_add_salt_to_users.sql new file mode 100644 index 0000000..30f1a8f --- /dev/null +++ b/migrations/20250828204455_add_salt_to_users.sql @@ -0,0 +1 @@ +ALTER TABLE users ADD COLUMN salt TEXT NOT NULL; diff --git a/migrations/20250828212543_remove_salt_from_users.sql b/migrations/20250828212543_remove_salt_from_users.sql new file mode 100644 index 0000000..dac7f66 --- /dev/null +++ b/migrations/20250828212543_remove_salt_from_users.sql @@ -0,0 +1 @@ +ALTER TABLE users DROP COLUMN salt; diff --git a/src/authentication.rs b/src/authentication.rs new file mode 100644 index 0000000..05c3a52 --- /dev/null +++ b/src/authentication.rs @@ -0,0 +1,93 @@ +use anyhow::Context; +use argon2::{Argon2, PasswordHash, PasswordVerifier}; +use secrecy::{ExposeSecret, SecretString}; +use sqlx::PgPool; +use uuid::Uuid; + +use crate::telemetry::spawn_blocking_with_tracing; + +pub struct Credentials { + pub username: String, + pub password: SecretString, +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + #[error(transparent)] + UnexpectedError(#[from] anyhow::Error), + #[error("Invalid credentials.")] + InvalidCredentials(#[source] anyhow::Error), +} + +#[tracing::instrument( + name = "Validate credentials", + skip(username, password, connection_pool) +)] +pub async fn validate_credentials( + Credentials { username, password }: Credentials, + connection_pool: &PgPool, +) -> Result { + let mut user_id = None; + let mut expected_password_hash = SecretString::from( + "$argon2id$v=19$m=15000,t=2,p=1$\ +gZiV/M1gPc22ElAH/Jh1Hw$\ +CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno" + .to_string(), + ); + if let Some((stored_user_id, stored_expected_password_hash)) = + get_stored_credentials(&username, connection_pool) + .await + .map_err(AuthError::UnexpectedError)? + { + user_id = Some(stored_user_id); + expected_password_hash = stored_expected_password_hash; + } + + spawn_blocking_with_tracing(|| verify_password_hash(expected_password_hash, password)) + .await + .context("Failed to spawn blocking task.") + .map_err(AuthError::UnexpectedError)??; + + user_id + .ok_or_else(|| anyhow::anyhow!("Unknown username.")) + .map_err(AuthError::InvalidCredentials) +} + +#[tracing::instrument( + name = "Verify password", + skip(expected_password_hash, password_candidate) +)] +fn verify_password_hash( + expected_password_hash: SecretString, + password_candidate: SecretString, +) -> Result<(), AuthError> { + let expected_password_hash = PasswordHash::new(expected_password_hash.expose_secret()) + .context("Failed to parse hash in PHC string format.")?; + Argon2::default() + .verify_password( + password_candidate.expose_secret().as_bytes(), + &expected_password_hash, + ) + .context("Password verification failed.") + .map_err(AuthError::InvalidCredentials) +} + +#[tracing::instrument(name = "Get stored credentials", skip(username, connection_pool))] +async fn get_stored_credentials( + username: &str, + connection_pool: &PgPool, +) -> Result, anyhow::Error> { + let row = sqlx::query!( + r#" + SELECT user_id, password_hash + FROM users + WHERE username = $1 + "#, + username, + ) + .fetch_optional(connection_pool) + .await + .context("Failed to perform a query to retrieve stored credentials.")? + .map(|row| (row.user_id, SecretString::from(row.password_hash))); + Ok(row) +} diff --git a/src/configuration.rs b/src/configuration.rs index f1bbe26..37b472a 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -68,6 +68,7 @@ pub struct ApplicationSettings { pub port: u16, pub host: String, pub base_url: String, + pub hmac_secret: SecretString, } #[derive(Deserialize)] diff --git a/src/lib.rs b/src/lib.rs index 66e386a..1bb9b5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod authentication; pub mod configuration; pub mod domain; pub mod email_client; diff --git a/src/routes.rs b/src/routes.rs index 1055760..b0d8892 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,9 +1,13 @@ mod health_check; +mod home; +mod login; mod newsletters; mod subscriptions; mod subscriptions_confirm; pub use health_check::*; +pub use home::*; +pub use login::*; pub use newsletters::*; pub use subscriptions::*; pub use subscriptions_confirm::*; diff --git a/src/routes/home.rs b/src/routes/home.rs new file mode 100644 index 0000000..4610d1e --- /dev/null +++ b/src/routes/home.rs @@ -0,0 +1,5 @@ +use axum::response::{Html, IntoResponse}; + +pub async fn home() -> impl IntoResponse { + Html(include_str!("home/home.html")) +} diff --git a/src/routes/home/home.html b/src/routes/home/home.html new file mode 100644 index 0000000..cbf2601 --- /dev/null +++ b/src/routes/home/home.html @@ -0,0 +1,11 @@ + + + + + + Home + +

Welcome to our newsletter!

+ + + diff --git a/src/routes/login.rs b/src/routes/login.rs new file mode 100644 index 0000000..e7b02e5 --- /dev/null +++ b/src/routes/login.rs @@ -0,0 +1,94 @@ +use crate::{ + authentication::{AuthError, Credentials, validate_credentials}, + routes::error_chain_fmt, + startup::AppState, +}; +use axum::{ + Form, Json, + extract::State, + response::{Html, IntoResponse, Redirect, Response}, +}; +use axum_messages::Messages; +use reqwest::StatusCode; +use secrecy::SecretString; +use std::fmt::Write; + +#[derive(thiserror::Error)] +pub enum LoginError { + #[error("Something went wrong.")] + UnexpectedError(#[from] anyhow::Error), + #[error("Authentication failed.")] + AuthError(#[source] anyhow::Error), +} + +impl std::fmt::Debug for LoginError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + error_chain_fmt(self, f) + } +} + +impl IntoResponse for LoginError { + fn into_response(self) -> Response { + #[derive(serde::Serialize)] + struct ErrorResponse<'a> { + message: &'a str, + } + + tracing::error!("{:?}", self); + + match &self { + LoginError::UnexpectedError(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + message: "An internal server error occured.", + }), + ) + .into_response(), + LoginError::AuthError(_) => Redirect::to("/login").into_response(), + } + } +} + +#[derive(serde::Deserialize)] +pub struct LoginFormData { + username: String, + password: SecretString, +} + +pub async fn get_login(messages: Messages) -> impl IntoResponse { + let mut error_html = String::new(); + for message in messages { + writeln!(error_html, "

{}

", message).unwrap(); + } + Html(format!(include_str!("login/login.html"), error_html)) +} + +#[tracing::instrument( + skip(connection_pool, form), + fields(username=tracing::field::Empty, user_id=tracing::field::Empty) +)] +pub async fn post_login( + messages: Messages, + State(AppState { + connection_pool, .. + }): State, + Form(form): Form, +) -> Result { + let credentials = Credentials { + username: form.username, + password: form.password, + }; + tracing::Span::current().record("username", tracing::field::display(&credentials.username)); + let user_id = validate_credentials(credentials, &connection_pool) + .await + .map_err(|e| match e { + AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()), + AuthError::InvalidCredentials(_) => { + let e = LoginError::AuthError(e.into()); + messages.error(e.to_string()); + e + } + })?; + tracing::Span::current().record("user_id", tracing::field::display(&user_id)); + Ok(Redirect::to("/")) +} diff --git a/src/routes/login/login.html b/src/routes/login/login.html new file mode 100644 index 0000000..817dff1 --- /dev/null +++ b/src/routes/login/login.html @@ -0,0 +1,16 @@ + + + + + + Login + +
+ + + +
+ {} + + + diff --git a/src/routes/newsletters.rs b/src/routes/newsletters.rs index 9907dc1..a1a5454 100644 --- a/src/routes/newsletters.rs +++ b/src/routes/newsletters.rs @@ -1,17 +1,27 @@ -use crate::{domain::SubscriberEmail, routes::error_chain_fmt, startup::AppState}; +use crate::{ + authentication::{AuthError, Credentials, validate_credentials}, + domain::SubscriberEmail, + routes::error_chain_fmt, + startup::AppState, +}; use anyhow::Context; use axum::{ Json, extract::State, + http::{HeaderMap, HeaderValue}, response::{IntoResponse, Response}, }; -use reqwest::StatusCode; +use base64::Engine; +use reqwest::{StatusCode, header}; +use secrecy::SecretString; use sqlx::PgPool; #[derive(thiserror::Error)] pub enum PublishError { #[error(transparent)] UnexpectedError(#[from] anyhow::Error), + #[error("Authentication failed.")] + AuthError(#[source] anyhow::Error), } impl std::fmt::Debug for PublishError { @@ -29,11 +39,25 @@ impl IntoResponse for PublishError { tracing::error!("{:?}", self); + let mut authenticate_header_value = None; let status = match self { PublishError::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR, + PublishError::AuthError(_) => { + authenticate_header_value = + Some(HeaderValue::from_str(r#"Basic realm="publish""#).unwrap()); + StatusCode::UNAUTHORIZED + } }; + let message = "An internal server error occured."; - (status, Json(ErrorResponse { message })).into_response() + let mut response = (status, Json(ErrorResponse { message })).into_response(); + if let Some(header_value) = authenticate_header_value { + response + .headers_mut() + .insert(header::WWW_AUTHENTICATE, header_value); + } + + response } } @@ -51,9 +75,11 @@ pub struct Content { #[tracing::instrument( name = "Publishing a newsletter", - skip(connection_pool, email_client, body) + skip(headers, connection_pool, email_client, body), + fields(username=tracing::field::Empty, user_id=tracing::field::Empty) )] pub async fn publish_newsletter( + headers: HeaderMap, State(AppState { connection_pool, email_client, @@ -61,6 +87,15 @@ pub async fn publish_newsletter( }): State, body: Json, ) -> Result { + let credentials = basic_authentication(&headers).map_err(PublishError::AuthError)?; + tracing::Span::current().record("username", tracing::field::display(&credentials.username)); + let user_id = validate_credentials(credentials, &connection_pool) + .await + .map_err(|e| match e { + AuthError::UnexpectedError(_) => PublishError::UnexpectedError(e.into()), + AuthError::InvalidCredentials(_) => PublishError::AuthError(e.into()), + })?; + tracing::Span::current().record("user_id", tracing::field::display(&user_id)); let subscribers = get_confirmed_subscribers(&connection_pool).await?; for subscriber in subscribers { match subscriber { @@ -83,6 +118,37 @@ pub async fn publish_newsletter( Ok(StatusCode::OK.into_response()) } +fn basic_authentication(headers: &HeaderMap) -> Result { + let header_value = headers + .get("Authorization") + .context("The 'Authorization' header was missing.")? + .to_str() + .context("The 'Authorization' header was not a valid UTF8 string.")?; + let base64encoded_segment = header_value + .strip_prefix("Basic ") + .context("The authorization scheme was not 'Basic'.")?; + let decoded_bytes = base64::engine::general_purpose::STANDARD + .decode(base64encoded_segment) + .context("Failed to base64-decode 'Basic' credentials.")?; + let decoded_credentials = String::from_utf8(decoded_bytes) + .context("The decoded credential string is not valid UTF-8.")?; + + let mut credentials = decoded_credentials.splitn(2, ':'); + let username = credentials + .next() + .context("A username must be provided in 'Basic' auth.")? + .to_string(); + let password = credentials + .next() + .context("A password must be provided in 'Basic' auth.")? + .to_string(); + + Ok(Credentials { + username, + password: SecretString::from(password), + }) +} + #[allow(dead_code)] struct ConfirmedSubscriber { name: String, diff --git a/src/routes/subscriptions.rs b/src/routes/subscriptions.rs index 7ffc222..04360e2 100644 --- a/src/routes/subscriptions.rs +++ b/src/routes/subscriptions.rs @@ -85,6 +85,7 @@ pub async fn subscribe( connection_pool, email_client, base_url, + .. }): State, Form(form): Form, ) -> Result { diff --git a/src/routes/subscriptions_confirm.rs b/src/routes/subscriptions_confirm.rs index eaec61d..4ef60b5 100644 --- a/src/routes/subscriptions_confirm.rs +++ b/src/routes/subscriptions_confirm.rs @@ -1,3 +1,4 @@ +use crate::startup::AppState; use axum::{ extract::{Query, State}, http::StatusCode, @@ -7,8 +8,6 @@ use serde::Deserialize; use sqlx::PgPool; use uuid::Uuid; -use crate::startup::AppState; - #[tracing::instrument(name = "Confirming new subscriber", skip(params))] pub async fn confirm( State(AppState { diff --git a/src/startup.rs b/src/startup.rs index d534425..d96b4e0 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -5,10 +5,13 @@ use axum::{ http::Request, routing::{get, post}, }; +use axum_messages::MessagesManagerLayer; +use secrecy::SecretString; use sqlx::{PgPool, postgres::PgPoolOptions}; use std::sync::Arc; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; +use tower_sessions::{MemoryStore, SessionManagerLayer}; use uuid::Uuid; pub struct Application { @@ -21,6 +24,7 @@ pub struct AppState { pub connection_pool: PgPool, pub email_client: Arc, pub base_url: String, + pub hmac_secret: SecretString, } impl Application { @@ -37,6 +41,7 @@ impl Application { connection_pool, email_client, configuration.application.base_url, + configuration.application.hmac_secret, ); Ok(Self { listener, router }) } @@ -51,35 +56,43 @@ impl Application { } } -pub fn app(connection_pool: PgPool, email_client: EmailClient, base_url: String) -> Router { +pub fn app( + connection_pool: PgPool, + email_client: EmailClient, + base_url: String, + hmac_secret: SecretString, +) -> Router { let app_state = AppState { connection_pool, email_client: Arc::new(email_client), base_url, + hmac_secret, }; Router::new() + .route("/", get(home)) + .route("/login", get(get_login).post(post_login)) .route("/health_check", get(health_check)) .route("/subscriptions", post(subscribe)) .route("/subscriptions/confirm", get(confirm)) .route("/newsletters", post(publish_newsletter)) .layer( - TraceLayer::new_for_http() - .make_span_with(|request: &Request<_>| { - let matched_path = request - .extensions() - .get::() - .map(MatchedPath::as_str); - let request_id = Uuid::new_v4().to_string(); + TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { + let matched_path = request + .extensions() + .get::() + .map(MatchedPath::as_str); + let request_id = Uuid::new_v4().to_string(); - tracing::info_span!( - "http_request", - method = ?request.method(), - matched_path, - request_id, - some_other_field = tracing::field::Empty, - ) - }) - .on_failure(()), + tracing::info_span!( + "http_request", + method = ?request.method(), + matched_path, + request_id, + some_other_field = tracing::field::Empty, + ) + }), ) + .layer(MessagesManagerLayer) + .layer(SessionManagerLayer::new(MemoryStore::default())) .with_state(app_state) } diff --git a/src/telemetry.rs b/src/telemetry.rs index 1343c68..45bb66f 100644 --- a/src/telemetry.rs +++ b/src/telemetry.rs @@ -1,3 +1,4 @@ +use tokio::task::JoinHandle; use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer}; use tracing_subscriber::{fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt}; @@ -20,3 +21,12 @@ where .with(formatting_layer) .init(); } + +pub fn spawn_blocking_with_tracing(f: F) -> JoinHandle +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let current_span = tracing::Span::current(); + tokio::task::spawn_blocking(move || current_span.in_scope(f)) +} diff --git a/tests/api/helpers.rs b/tests/api/helpers.rs index 14930aa..0ed7bf5 100644 --- a/tests/api/helpers.rs +++ b/tests/api/helpers.rs @@ -1,3 +1,7 @@ +use argon2::{ + Argon2, PasswordHasher, + password_hash::{SaltString, rand_core::OsRng}, +}; use linkify::LinkFinder; use once_cell::sync::Lazy; use sqlx::{Connection, Executor, PgConnection, PgPool}; @@ -22,34 +26,49 @@ pub struct ConfirmationLinks { pub text: reqwest::Url, } +pub struct TestUser { + pub user_id: Uuid, + pub username: String, + pub password: String, +} + +impl TestUser { + pub fn generate() -> Self { + Self { + user_id: Uuid::new_v4(), + username: Uuid::new_v4().to_string(), + password: Uuid::new_v4().to_string(), + } + } + + pub async fn store(&self, connection_pool: &PgPool) { + let salt = SaltString::generate(&mut OsRng); + let password_hash = Argon2::default() + .hash_password(self.password.as_bytes(), &salt) + .unwrap() + .to_string(); + sqlx::query!( + "INSERT INTO users (user_id, username, password_hash) VALUES ($1, $2, $3)", + self.user_id, + self.username, + password_hash + ) + .execute(connection_pool) + .await + .expect("Failed to create test user"); + } +} + pub struct TestApp { pub address: String, pub connection_pool: PgPool, pub email_server: wiremock::MockServer, pub port: u16, + pub test_user: TestUser, + pub api_client: reqwest::Client, } impl TestApp { - pub fn get_confirmation_links(&self, request: &wiremock::Request) -> ConfirmationLinks { - let body: serde_json::Value = serde_json::from_slice(&request.body).unwrap(); - let get_link = |s: &str| { - let links: Vec<_> = LinkFinder::new() - .links(s) - .filter(|l| *l.kind() == linkify::LinkKind::Url) - .collect(); - assert_eq!(links.len(), 1); - let raw_link = links[0].as_str(); - let mut confirmation_link = reqwest::Url::parse(raw_link).unwrap(); - assert_eq!(confirmation_link.host_str().unwrap(), "127.0.0.1"); - confirmation_link.set_port(Some(self.port)).unwrap(); - confirmation_link - }; - - let html = get_link(body["html"].as_str().unwrap()); - let text = get_link(body["text"].as_str().unwrap()); - ConfirmationLinks { html, text } - } - pub async fn spawn() -> Self { Lazy::force(&TRACING); @@ -73,11 +92,20 @@ impl TestApp { .parse::() .unwrap(); let address = format!("http://{}", application.local_addr()); + let test_user = TestUser::generate(); + test_user.store(&connection_pool).await; + let api_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .cookie_store(true) + .build() + .unwrap(); let app = TestApp { address, connection_pool, email_server, port, + test_user, + api_client, }; tokio::spawn(application.run_until_stopped()); @@ -85,8 +113,39 @@ impl TestApp { app } + pub fn get_confirmation_links(&self, request: &wiremock::Request) -> ConfirmationLinks { + let body: serde_json::Value = serde_json::from_slice(&request.body).unwrap(); + let get_link = |s: &str| { + let links: Vec<_> = LinkFinder::new() + .links(s) + .filter(|l| *l.kind() == linkify::LinkKind::Url) + .collect(); + assert_eq!(links.len(), 1); + let raw_link = links[0].as_str(); + let mut confirmation_link = reqwest::Url::parse(raw_link).unwrap(); + assert_eq!(confirmation_link.host_str().unwrap(), "127.0.0.1"); + confirmation_link.set_port(Some(self.port)).unwrap(); + confirmation_link + }; + + let html = get_link(body["html"].as_str().unwrap()); + let text = get_link(body["text"].as_str().unwrap()); + ConfirmationLinks { html, text } + } + + pub async fn get_login_html(&self) -> String { + self.api_client + .get(format!("{}/login", &self.address)) + .send() + .await + .expect("Failed to execute request") + .text() + .await + .unwrap() + } + pub async fn post_subscriptions(&self, body: String) -> reqwest::Response { - reqwest::Client::new() + self.api_client .post(format!("{}/subscriptions", self.address)) .header("Content-Type", "application/x-www-form-urlencoded") .body(body) @@ -99,6 +158,19 @@ impl TestApp { reqwest::Client::new() .post(format!("{}/newsletters", self.address)) .json(&body) + .basic_auth(&self.test_user.username, Some(&self.test_user.password)) + .send() + .await + .expect("Failed to execute request") + } + + pub async fn post_login(&self, body: &Body) -> reqwest::Response + where + Body: serde::Serialize, + { + self.api_client + .post(format!("{}/login", self.address)) + .form(body) .send() .await .expect("Failed to execute request") @@ -124,3 +196,8 @@ async fn configure_database(config: &DatabaseSettings) -> PgPool { connection_pool } + +pub fn assert_is_redirect_to(response: &reqwest::Response, location: &str) { + assert_eq!(response.status().as_u16(), 303); + assert_eq!(response.headers().get("Location").unwrap(), location); +} diff --git a/tests/api/login.rs b/tests/api/login.rs new file mode 100644 index 0000000..d4c9b2f --- /dev/null +++ b/tests/api/login.rs @@ -0,0 +1,22 @@ +use crate::helpers::{TestApp, assert_is_redirect_to}; + +#[tokio::test] +async fn an_error_flash_message_is_set_on_failure() { + let app = TestApp::spawn().await; + + let login_body = serde_json::json!({ + "username": "user", + "password": "password" + }); + + let response = app.post_login(&login_body).await; + + assert_eq!(response.status().as_u16(), 303); + assert_is_redirect_to(&response, "/login"); + + let login_page_html = app.get_login_html().await; + assert!(login_page_html.contains("Authentication failed")); + + let login_page_html = app.get_login_html().await; + assert!(!login_page_html.contains("Authentication failed")); +} diff --git a/tests/api/main.rs b/tests/api/main.rs index 129fe56..d7a12d3 100644 --- a/tests/api/main.rs +++ b/tests/api/main.rs @@ -1,5 +1,6 @@ mod health_check; mod helpers; +mod login; mod newsletters; mod subscriptions; mod subscriptions_confirm; diff --git a/tests/api/newsletters.rs b/tests/api/newsletters.rs index 468aad2..e145907 100644 --- a/tests/api/newsletters.rs +++ b/tests/api/newsletters.rs @@ -1,4 +1,5 @@ use crate::helpers::{ConfirmationLinks, TestApp}; +use uuid::Uuid; use wiremock::{ Mock, ResponseTemplate, matchers::{any, method, path}, @@ -21,6 +22,87 @@ async fn newsletters_are_not_delivered_to_unconfirmed_subscribers() { assert_eq!(response.status().as_u16(), 200); } +#[tokio::test] +async fn request_missing_authorization_are_rejected() { + let app = TestApp::spawn().await; + + let newsletter_request_body = serde_json::json!({ + "title": "Newsletter title", + "content": { + "text": "Newsletter body as plain text", + "html": "

Newsletter body as HTML

" + } + }); + let response = reqwest::Client::new() + .post(format!("{}/newsletters", &app.address)) + .json(&newsletter_request_body) + .send() + .await + .expect("Failed to execute request"); + + assert_eq!(response.status().as_u16(), 401); + assert_eq!( + response.headers()["WWW-Authenticate"], + r#"Basic realm="publish""# + ); +} + +#[tokio::test] +async fn non_existing_user_is_rejected() { + let app = TestApp::spawn().await; + + let newsletter_request_body = serde_json::json!({ + "title": "Newsletter title", + "content": { + "text": "Newsletter body as plain text", + "html": "

Newsletter body as HTML

" + } + }); + let username = Uuid::new_v4().to_string(); + let password = Uuid::new_v4().to_string(); + let response = reqwest::Client::new() + .post(format!("{}/newsletters", &app.address)) + .json(&newsletter_request_body) + .basic_auth(username, Some(password)) + .send() + .await + .expect("Failed to execute request"); + + assert_eq!(response.status().as_u16(), 401); + assert_eq!( + response.headers()["WWW-Authenticate"], + r#"Basic realm="publish""# + ); +} + +#[tokio::test] +async fn invalid_password_is_rejected() { + let app = TestApp::spawn().await; + + let newsletter_request_body = serde_json::json!({ + "title": "Newsletter title", + "content": { + "text": "Newsletter body as plain text", + "html": "

Newsletter body as HTML

" + } + }); + let username = app.test_user.username; + let password = Uuid::new_v4().to_string(); + let response = reqwest::Client::new() + .post(format!("{}/newsletters", &app.address)) + .json(&newsletter_request_body) + .basic_auth(username, Some(password)) + .send() + .await + .expect("Failed to execute request"); + + assert_eq!(response.status().as_u16(), 401); + assert_eq!( + response.headers()["WWW-Authenticate"], + r#"Basic realm="publish""# + ); +} + #[tokio::test] async fn newsletters_are_delivered_to_confirmed_subscribers() { let app = TestApp::spawn().await;