diff --git a/Cargo.lock b/Cargo.lock index 616c9b0..52036d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -915,6 +915,7 @@ dependencies = [ "serde", "sqlx", "tokio", + "tokio-util", "tower-http", "tracing", "tracing-subscriber", diff --git a/crates/nodata/Cargo.toml b/crates/nodata/Cargo.toml index 4b321b9..9e03fcb 100644 --- a/crates/nodata/Cargo.toml +++ b/crates/nodata/Cargo.toml @@ -23,3 +23,4 @@ sqlx = { version = "0.7.3", features = [ uuid = { version = "1.7.0", features = ["v4"] } tower-http = { version = "0.5.2", features = ["cors", "trace"] } mad = { git = "https://github.com/kjuulh/mad", branch = "main" } +tokio-util = "0.7.11" diff --git a/crates/nodata/src/http.rs b/crates/nodata/src/http.rs new file mode 100644 index 0000000..9d8d987 --- /dev/null +++ b/crates/nodata/src/http.rs @@ -0,0 +1,65 @@ +use std::net::SocketAddr; + +use anyhow::Context; +use axum::async_trait; +use axum::extract::MatchedPath; +use axum::http::Request; +use axum::routing::get; +use axum::Router; +use mad::Component; +use mad::MadError; +use tokio_util::sync::CancellationToken; +use tower_http::trace::TraceLayer; + +use crate::state::SharedState; + +pub struct HttpServer { + state: SharedState, + host: SocketAddr, +} + +impl HttpServer { + pub fn new(state: &SharedState, host: SocketAddr) -> Self { + Self { + state: state.clone(), + host, + } + } +} + +#[async_trait] +impl Component for HttpServer { + async fn run(&self, cancellation_token: CancellationToken) -> Result<(), mad::MadError> { + let app = Router::new() + .route("/", get(root)) + .with_state(self.state.clone()) + .layer( + TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { + let matched_path = request + .extensions() + .get::() + .map(MatchedPath::as_str); + + tracing::info_span!( + "http_request", + method = ?request.method(), + matched_path, + some_other_field = tracing::field::Empty, + ) + }), + ); + + tracing::info!("http: listening on {}", self.host); + let listener = tokio::net::TcpListener::bind(self.host).await.unwrap(); + axum::serve(listener, app.into_make_service()) + .await + .context("axum server stopped") + .map_err(MadError::Inner)?; + + Ok(()) + } +} + +async fn root() -> &'static str { + "Hello, nodata!" +} diff --git a/crates/nodata/src/main.rs b/crates/nodata/src/main.rs index 839b709..0809784 100644 --- a/crates/nodata/src/main.rs +++ b/crates/nodata/src/main.rs @@ -1,14 +1,12 @@ -use std::{net::SocketAddr, ops::Deref, sync::Arc}; +mod http; +mod state; + +use std::net::SocketAddr; -use anyhow::Context; -use axum::extract::MatchedPath; -use axum::http::Request; -use axum::routing::get; -use axum::Router; use clap::{Parser, Subcommand}; -use mad::{Mad, MadError}; -use sqlx::{Pool, Postgres}; -use tower_http::trace::TraceLayer; +use http::HttpServer; +use mad::Mad; +use state::SharedState; #[derive(Parser)] #[command(author, version, about, long_about = None, subcommand_required = true)] @@ -35,82 +33,13 @@ async fn main() -> anyhow::Result<()> { if let Some(Commands::Serve { host }) = cli.command { tracing::info!("Starting service"); - let state = SharedState(Arc::new(State::new().await?)); + let state = SharedState::new().await?; - let state = state.clone(); Mad::builder() - .add_fn(move |_cancel| { - let state = state.clone(); - async move { - let app = Router::new() - .route("/", get(root)) - .with_state(state.clone()) - .layer(TraceLayer::new_for_http().make_span_with( - |request: &Request<_>| { - let matched_path = request - .extensions() - .get::() - .map(MatchedPath::as_str); - - tracing::info_span!( - "http_request", - method = ?request.method(), - matched_path, - some_other_field = tracing::field::Empty, - ) - }, - )); - - tracing::info!("listening on {}", host); - let listener = tokio::net::TcpListener::bind(host).await.unwrap(); - axum::serve(listener, app.into_make_service()) - .await - .context("axum server stopped") - .map_err(MadError::Inner)?; - - Ok(()) - } - }) + .add(HttpServer::new(&state, host)) .run() .await?; } Ok(()) } - -async fn root() -> &'static str { - "Hello, nodata!" -} - -#[derive(Clone)] -pub struct SharedState(Arc); - -impl Deref for SharedState { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -pub struct State { - pub db: Pool, -} - -impl State { - pub async fn new() -> anyhow::Result { - let db = sqlx::PgPool::connect( - &std::env::var("DATABASE_URL").context("DATABASE_URL is not set")?, - ) - .await?; - - sqlx::migrate!("migrations/crdb") - .set_locking(false) - .run(&db) - .await?; - - let _ = sqlx::query("SELECT 1;").fetch_one(&db).await?; - - Ok(Self { db }) - } -} diff --git a/crates/nodata/src/state.rs b/crates/nodata/src/state.rs new file mode 100644 index 0000000..a02d80a --- /dev/null +++ b/crates/nodata/src/state.rs @@ -0,0 +1,43 @@ +use std::{ops::Deref, sync::Arc}; + +use anyhow::Context; +use sqlx::{Pool, Postgres}; + +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + pub async fn new() -> anyhow::Result { + Ok(Self(Arc::new(State::new().await?))) + } +} + +impl Deref for SharedState { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct State { + pub db: Pool, +} + +impl State { + pub async fn new() -> anyhow::Result { + let db = sqlx::PgPool::connect( + &std::env::var("DATABASE_URL").context("DATABASE_URL is not set")?, + ) + .await?; + + sqlx::migrate!("migrations/crdb") + .set_locking(false) + .run(&db) + .await?; + + let _ = sqlx::query("SELECT 1;").fetch_one(&db).await?; + + Ok(Self { db }) + } +}