Files
zero2prod/src/startup.rs
Alphonse Paix 54218f92a9 Admin can now write posts
Posts can be displayed on the website, and subscribers are automatically
notified by email. This makes it possible to count explicitly how many
people followed the link included in those emails, without being
intrusive (no invisible tracking image).
2025-09-18 17:22:33 +02:00

153 lines
4.8 KiB
Rust

use crate::{
    authentication::require_auth, configuration::Settings, email_client::EmailClient, routes::*,
};
use anyhow::Context as _;
use axum::{
    Router,
    extract::MatchedPath,
    http::Request,
    middleware,
    routing::{get, post},
};
use axum_server::tls_rustls::RustlsConfig;
use secrecy::ExposeSecret;
use sqlx::{PgPool, postgres::PgPoolOptions};
use std::{net::TcpListener, sync::Arc};
use tower_http::{services::ServeDir, trace::TraceLayer};
use tower_sessions::SessionManagerLayer;
use tower_sessions_redis_store::{
    RedisStore,
    fred::prelude::{ClientLike, Config, Pool},
};
use uuid::Uuid;
/// State made available to every request handler through the router.
///
/// axum requires router state to be `Clone`. The email client sits behind an [`Arc`]
/// so that every clone of the state shares one client instance.
#[derive(Clone)]
pub struct AppState {
    /// Postgres connection pool used by the handlers.
    pub connection_pool: PgPool,
    /// Email client, shared (not duplicated) between clones of the state.
    pub email_client: Arc<EmailClient>,
    /// Public base URL of the application, set from the `application.base_url` setting.
    pub base_url: String,
}
/// A fully configured server: its listener is bound, but no request is served until
/// `run_until_stopped` is awaited.
pub struct Application {
    /// Listener bound to the configured `host:port`.
    listener: TcpListener,
    /// Router with every route, layer and piece of shared state attached.
    router: Router,
    /// TLS settings; `None` means the server speaks plain HTTP.
    tls_config: Option<RustlsConfig>,
}
impl Application {
pub async fn build(configuration: Settings) -> Result<Self, anyhow::Error> {
let address = format!(
"{}:{}",
configuration.application.host, configuration.application.port
);
let connection_pool =
PgPoolOptions::new().connect_lazy_with(configuration.database.with_db());
let email_client = EmailClient::build(configuration.email_client).unwrap();
let pool = Pool::new(
Config::from_url(configuration.redis_uri.expose_secret())
.expect("Failed to parse Redis URL string"),
None,
None,
None,
6,
)
.unwrap();
pool.connect();
pool.wait_for_connect().await.unwrap();
let redis_store = RedisStore::new(pool);
let router = app(
connection_pool,
email_client,
configuration.application.base_url,
redis_store,
);
let tls_config = if configuration.require_tls {
Some(
RustlsConfig::from_pem_file(
std::env::var("APP_TLS_CERT")
.expect("Failed to read TLS certificate environment variable"),
std::env::var("APP_TLS_KEY")
.expect("Feiled to read TLS private key environment variable"),
)
.await
.expect("Could not create TLS configuration"),
)
} else {
None
};
let listener = TcpListener::bind(address).unwrap();
Ok(Self {
listener,
router,
tls_config,
})
}
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
tracing::debug!("listening on {}", self.local_addr());
if let Some(tls_config) = self.tls_config {
axum_server::from_tcp_rustls(self.listener, tls_config)
.serve(self.router.into_make_service())
.await
} else {
axum_server::from_tcp(self.listener)
.serve(self.router.into_make_service())
.await
}
}
pub fn local_addr(&self) -> String {
self.listener.local_addr().unwrap().to_string()
}
pub fn port(&self) -> u16 {
self.listener.local_addr().unwrap().port()
}
}
pub fn app(
connection_pool: PgPool,
email_client: EmailClient,
base_url: String,
redis_store: RedisStore<Pool>,
) -> Router {
let app_state = AppState {
connection_pool,
email_client: Arc::new(email_client),
base_url,
};
let admin_routes = Router::new()
.route("/dashboard", get(admin_dashboard))
.route("/password", post(change_password))
.route("/newsletters", post(publish_newsletter))
.route("/posts", post(create_post))
.route("/logout", post(logout))
.layer(middleware::from_fn(require_auth));
Router::new()
.nest_service("/assets", ServeDir::new("assets"))
.route("/", get(home))
.route("/login", get(get_login).post(post_login))
.route("/health_check", get(health_check))
.route("/subscriptions", post(subscribe))
.route("/subscriptions/confirm", get(confirm))
.nest("/admin", admin_routes)
.layer(
TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
let request_id = Uuid::new_v4().to_string();
tracing::info_span!(
"http_request",
method = ?request.method(),
matched_path,
request_id,
some_other_field = tracing::field::Empty,
)
}),
)
.layer(SessionManagerLayer::new(redis_store).with_secure(false))
.with_state(app_state)
}