Migrating from warp to axum
Falling out of love with warp
Back when I wrote this codebase, warp was the best / only alternative for something relatively high-level on top of hyper.
I was never super fond of warp's model — it's a fine crate, just not for me.
The way routing works is essentially building a type that gets larger and larger. One route might look like:
let bye = warp::path("bye") .and(warp::path::param()) .map(|name: String| format!("Good bye, {}!", name));
And then all the routes combined might look like:
let routes = warp::get().and(
    hello_world
        .or(hi)
        .or(hello_from_warp)
        .or(bye)
        .or(math)
        .or(sum)
        .or(times),
);
That's from warp's routing example.
We can look at the type of the whole routes expression with this simple trick:
let a: () = routes;
And as we see, it's fairly large:
$ cargo check --examples cargo check --examples Checking warp v0.3.3 (/home/amos/bearcove/warp) error[E0308]: mismatched types --> examples/routing.rs:103:17 (cut) = note: expected unit type `()` found struct `warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>, warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::map::Map<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>, [closure@examples/routing.rs:13:45: 13:47]>, warp::filter::map::Map<Exact<warp::path::internal::Opaque<&str>>, [closure@examples/routing.rs:16:35: 16:37]>>, warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:21:70: 21:72]>>, warp::filter::map::Map<warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (String,), Error = Rejection>>, [closure@examples/routing.rs:51:14: 51:28]>>, warp::filter::or::Or<warp::filter::map::Map<warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:70:14: 70:16]>, warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, warp::filter::or::Or<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:26:50: 26:56]>, warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:32:46: 32:52]>>>>>, warp::filter::map::Map<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = 
Rejection>>, [closure@examples/routing.rs:26:50: 26:56]>, [closure@examples/routing.rs:74:23: 74:31]>>, warp::filter::map::Map<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:32:46: 32:52]>, [closure@examples/routing.rs:76:19: 76:27]>>>` For more information about this error, try `rustc --explain E0308`. error: could not compile `warp` due to previous error
(In fact, it made the estimated reading time for this article jump from 8 to 13 minutes, all by itself).
This is an issue on several levels. First, that's a lot of work for the compiler. "My warp app takes forever to compile" is the most active issue on the issue tracker. In fact, I've written about this specific problem before.
Second, the errors are awful. You might think I'm pretty good at Rust by now, but facing those errors, even I recoiled in horror. I just had no idea what it wanted me to do. So I ended up doing some... terrible, non-warpy things. This is what my router ended up looking like:
// this is a `ServerState` struct that has a cache, the config (including secrets)
// stuff like that. we want that from any request handlers.
let with_state = warp::any().map(move || server_state.clone());

// because most of my website is dynamic and rendered from liquid templates,
// (and those are cached), I need stuff like the full path, raw query string,
// cookies (for login state), etc.
let with_cx = with_state
    .clone()
    .and(warp::filters::path::full())
    .and(
        warp::filters::query::raw()
            .recover(|_| async {
                let res: Result<String, Infallible> = Ok("".to_string());
                res
            })
            .unify(),
    )
    .and(warp::filters::header::optional("cookie"))
    .and(
        warp::filters::body::bytes()
            .recover(|rej| {
                warn!("while getting body bytes: {:?}", rej);
                async {
                    let res: Result<Bytes, Infallible> = Ok(Default::default());
                    res
                }
            })
            .unify(),
    )
    .map(
        |state,
         full_path: FullPath,
         query: String,
         cookie_header: Option<String>,
         body: Bytes| {
            let mut cookies = HashMap::new();
            if let Some(ch) = cookie_header {
                for cookie in ch
                    .split(';')
                    .map(|x| x.trim())
                    .filter_map(|x| Cookie::parse_encoded(x.to_string()).ok())
                {
                    cookies.insert(cookie.name().to_string(), cookie);
                }
            }

            let cx = Context {
                state,
                raw_query: Arc::new(Query::new(query)),
                path: full_path.as_str().to_string(),
                cookies,
                raw_body: body,
            };
            cx
        },
    );

// I forget why this is needed 🙃 but it is.
let with_cx = move || with_cx.clone();

// This is just how I did live-reloading.
use std::convert::Infallible;
use warp::filters::sse::Event as ServerSentEvent;

// This is a free-standing function because it's the only reasonable way
// to annotate the error type of the TryStream.
fn sse_events(cx: Context) -> impl Stream<Item = Result<ServerSentEvent, Infallible>> {
    let mut rx = cx.state.broadcast_rev.subscribe();

    // cf. https://crates.io/crates/async-stream
    // could be written without it, but I like that crate.
    async_stream::try_stream! {
        yield ServerSentEvent::default().event("message").data("Live reloading enabled");

        loop {
            tokio::select! {
                Ok(rev_id) = rx.recv() => {
                    yield ServerSentEvent::default().event("new-revision").data(rev_id);
                },
                _ = tokio::time::sleep(Duration::from_secs(1)) => {
                    yield ServerSentEvent::default().event("ping").data("Just keeping connections alive");
                },
            }
        }
    }
}

let livereload_route = method::get()
    .and(warp::path!("api" / "livereload"))
    .and(with_cx())
    .map(|cx: Context| warp::sse::reply(warp::sse::keep_alive().stream(sse_events(cx))));

// uwu what's this? a single route? weird...
let catchall_get = method::get()
    .and(with_cx())
    .and_then(|cx: Context| cx.handle(routes::serve_get));
let catchall_head = method::head()
    .and(with_cx())
    .and_then(|cx: Context| cx.handle(routes::serve_get));
let catchall_post = method::post()
    .and(with_cx())
    .and_then(|cx: Context| cx.handle(routes::serve_post));

let all_routes = livereload_route
    .boxed()
    .or(catchall_get.boxed())
    .or(catchall_head.boxed())
    .or(catchall_post.boxed());

let access_log = warp::filters::log::log("access");
It used to be much worse, because I used to use warp the way it's meant to be used... before I gave up and just started using handlers like serve_get:
pub async fn serve_get(cx: Context) -> HttpResult {
    if cx.path.ends_with('/') && cx.path != "/" {
        return serve_redirect(cx.path.trim_end_matches('/')).await;
    }

    if cx.path == "/robots.txt" {
        return Ok(Box::new(
            Response::builder()
                .status(StatusCode::OK)
                .body(Body::empty()),
        ));
    }

    if cx.path == "/tags" {
        return tags::serve_list(cx).await;
    } else if cx.path.starts_with("/tags/") {
        return tags::serve_single(cx).await;
    }

    if cx.path == "/settings" {
        return settings::serve(cx).await;
    }

    if cx.path == "/search" {
        return search::serve(cx).await;
    }

    if cx.path == "/login" {
        return login::serve_login(cx).await;
    }

    if cx.path == "/patreon/oauth" {
        return login::serve_patreon_oauth(cx).await;
    }

    if cx.path == "/logout" {
        return login::serve_logout(cx).await;
    }

    if cx.path == "/debug-credentials" {
        return login::serve_debug_credentials(cx).await;
    }

    if cx.path == "/comments" {
        return comments::serve(cx).await;
    }

    if cx.path == "/latest-video" {
        return latest_video::serve(cx).await;
    }

    if cx.path == "/patron-list" {
        return patron_list::serve(cx).await;
    }

    if cx.path == "/index.xml" {
        return cx
            .serve_template("index.xml", Default::default(), mime::atom())
            .await;
    }

    revision_routes::serve(cx).await
}
A single route would then look something like:
pub async fn serve_login(cx: Context) -> HttpResult {
    #[derive(Deserialize)]
    struct QueryParams {
        #[serde(default)]
        return_to: Option<String>,
    }

    let location = make_patreon_login_url(cx.config())?;
    let mut res = Response::builder()
        .status(StatusCode::TEMPORARY_REDIRECT)
        .header("location", location);

    if let Ok(params) = cx.query::<QueryParams>() {
        if let Some(return_to) = params.return_to {
            let mut cookie = Cookie::new("return_to", return_to);
            cookie.set_path("/");
            cookie.set_expires(time::OffsetDateTime::now_utc() + time::Duration::minutes(30));
            res = res.insert_cookie(cookie);
        }
    }

    let res = res.body("Redirecting to Patreon login...".to_string())?;
    Ok(Box::new(res))
}
This is absolutely not the way warp was meant to be used, but this is what I ended up doing, to survive the compile times + type spaghetti party.
The HttpResult type enforced boxed replies, to avoid the "I'm returning different types from different codepaths and they won't unify" problem:
pub use warp::Reply;

pub type HttpResult = color_eyre::Result<Box<dyn Reply>>;
The opinions of axum, also nice error handling
One nice thing about axum is that it defaults to something simple. It provides its own BoxBody body type, and an axum::body::boxed function to turn damn near anything into it.
It has an IntoResponse trait, which lets your handlers look like:
// basic handler that responds with a static string
async fn root() -> &'static str {
    "Hello, World!"
}
(Because IntoResponse is implemented for &'static str)
Or like:
async fn create_user(
    // this argument tells axum to parse the request body
    // as JSON into a `CreateUser` type
    Json(payload): Json<CreateUser>,
) -> impl IntoResponse {
    // insert your application logic here
    let user = User {
        id: 1337,
        username: payload.username,
    };

    // this will be converted into a JSON response
    // with a status code of `201 Created`
    (StatusCode::CREATED, Json(user))
}
(Because IntoResponse is implemented for tuples like (StatusCode, impl IntoResponse)).
There's still the "type unification" footgun with that approach — if you
conditionally returned (StatusCode::FORBIDDEN, "nuh-huh")
, you'd get a compile
error, and you'd need to change the return type to Response<BoxBody>
, and call
.into_response()
on both these tuples.
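Something like this compiles fine, even though the two branches build different tuple types under the hood (a minimal sketch, not code from my actual app):

use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};

// `Response` is axum's alias for `http::Response<BoxBody>`
async fn guarded(allowed: bool) -> Response {
    if !allowed {
        // without `.into_response()`, this branch would have a different type
        // than the one below, and the function wouldn't compile
        return (StatusCode::FORBIDDEN, "nuh-huh").into_response();
    }
    (StatusCode::OK, "welcome in").into_response()
}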
It's also useful to be able to turn errors into HTTP responses, so I came up with my own HttpResult type:
// in `crates/futile/src/serve/routes/response.rs`

use std::borrow::Cow;

use axum::response::{IntoResponse, Response};
use color_eyre::Report;
use futile_backtrace_printer::make_backtrace_printer;
use futile_config::is_production;
use http::{header, StatusCode};
use time::format_description::well_known::Rfc3339;
use tracing::error;

use crate::serve::{
    html_color_output::HtmlColorOutput, routes::mime, template_renderer::inject_livereload,
};

pub type HttpResult = Result<Response, HttpError>;

pub trait IntoHttp {
    fn into_http(self) -> HttpResult;
}

impl<T: IntoResponse> IntoHttp for T {
    fn into_http(self) -> HttpResult {
        Ok(self.into_response())
    }
}

#[derive(Debug)]
pub enum HttpError {
    NotFound { msg: Cow<'static, str> },
    Internal { err: String },
}

impl HttpError {
    fn from_report(err: Report) -> Self {
        error!("HTTP handler error: {}", err.root_cause());
        let bt_printer = make_backtrace_printer();

        let maybe_bt = err
            .context()
            .downcast_ref::<color_eyre::Handler>()
            .and_then(|h| h.backtrace());

        if let Some(bt) = maybe_bt {
            error!("Backtrace:");
            let mut stream = termcolor::StandardStream::stderr(termcolor::ColorChoice::Always);
            bt_printer.print_trace(bt, &mut stream).unwrap();
        } else {
            error!("No Backtrace");
        }

        let trace_content = if is_production() {
            "".into()
        } else {
            let mut err_string = String::new();
            for (i, e) in err.chain().enumerate() {
                use std::fmt::Write;
                let _ = writeln!(&mut err_string, "{}. {}", i + 1, e);
            }
            let mut err_string_escaped = String::new();
            html_escape::encode_safe_to_string(&err_string, &mut err_string_escaped);

            let backtrace: String = if let Some(bt) = maybe_bt {
                let mut output = HtmlColorOutput::new();
                bt_printer.print_trace(bt, &mut output).unwrap();
                output.into()
            } else {
                "".into()
            };

            format!(
                r#"
                <pre class="trace">{err_string_escaped}
                {backtrace}</pre>
                "#,
            )
        };

        let date = time::OffsetDateTime::now_utc().format(&Rfc3339).unwrap();
        let body = format!(
            r#"
            <html>
            {padding}
            <head>
                <style>
                    @import url('https://fonts.googleapis.com/css2?family=Alegreya+Sans&family=Source+Code+Pro&display=swap');

                    body {{
                        font-family: "Alegreya Sans", sans-serif;
                        max-width: 1200px;
                        margin: 20px auto;
                        line-height: 1.6;
                    }}

                    h2 {{
                        color: #b73535;
                    }}

                    pre.trace {{
                        font-family: "Source Code Pro", monospace;
                        padding: 1em;
                        border: 2px solid #999;
                        font-size: 14px;
                        background: #212121;
                        color: #d9d9d9;
                        white-space: pre-wrap;
                        overflow-x: auto;
                    }}
                </style>
            </head>
            <body>
                <h1>Something went terribly wrong ({date})</h1>
                {trace_content}
                <p>
                    Try <a href="/">Going back to the homepage</a>, maybe?
                </p>
            </body>
            </html>
            "#,
            padding = "<!-- Padding to avoid browser 500 error -->\n".repeat(10),
            trace_content = trace_content
        );
        HttpError::Internal { err: body }
    }
}

macro_rules! impl_from {
    ($from:ty) => {
        impl From<$from> for HttpError {
            fn from(err: $from) -> Self {
                Self::from_report(err.into())
            }
        }
    };
}

impl_from!(std::io::Error);
impl_from!(color_eyre::Report);
impl_from!(http::Error);
impl_from!(http::header::InvalidHeaderValue);
impl_from!(http::uri::InvalidUri);
impl_from!(serde_json::Error);
impl_from!(url::ParseError);
impl_from!(liquid_core::Error);
impl_from!(sqlx::Error);
impl_from!(crate::cached::CachedError);
impl_from!(futile_patreon::PatreonError);

impl IntoResponse for HttpError {
    fn into_response(self) -> Response {
        match self {
            HttpError::NotFound { msg } => (StatusCode::NOT_FOUND, msg).into_response(),
            HttpError::Internal { err } => (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, mime::html())],
                inject_livereload(&err).to_string(),
            )
                .into_response(),
        }
    }
}
There's a lot here, and I'm not going to comment on everything, but essentially, by default, things like deserialization errors get turned into 500 Internal Server Errors, with a nice (colored!) stack trace in development and an opaque error in production.
Extractors, handling shared state in axum 0.6
That means my handlers can return that type, and use ? (unless I want to map those errors to more precise status codes with nicer error pages):
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_login(
    State(state): State<ServerState>,
    cookies: Cookies,
    params: Option<Form<ReturnTo>>,
) -> HttpResult {
    let cookies = state.private_cookies(cookies);

    if let Some(params) = params {
        if let Some(return_to) = params.return_to.as_ref() {
            let mut cookie = Cookie::new("return_to", return_to.to_owned());
            cookie.set_path("/");
            cookie.set_expires(time::OffsetDateTime::now_utc() + time::Duration::minutes(30));
            cookies.add(cookie);
        }
    }

    let location = make_patreon_login_url(&state.config)?;
    Redirect::to(&location).into_http()
}
The axum::debug_handler macro is invaluable for debugging type errors (there are some with axum too), like, for example, accidentally letting a non-Send type slip in. It lets you specify which state type you're using in that application, which is actually unnecessary here since it's also specified as an extractor, in this here line:
State(state): State<ServerState>,
Cookies are extracted, too, via tower-cookies:
cookies: Cookies,
Lastly, query parameters are extracted (optionally) via the Form extractor:
params: Option<Form<ReturnTo>>,
The ReturnTo type is a struct that derives serde's Deserialize:
#[derive(Deserialize)]
pub struct ReturnTo {
    #[serde(default)]
    return_to: Option<String>,
}
And that lets us decode a URL like /login?return_to=articles%2Fupdating-my-website-for-2022 into something nice and strongly typed.
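Under the hood, that's just form-urlencoded deserialization. Here's a rough standalone equivalent using the serde_urlencoded crate (which, as far as I know, is what axum's Form uses); it's a sketch, not code from the site:

#[derive(serde::Deserialize, Debug)]
struct ReturnTo {
    #[serde(default)]
    return_to: Option<String>,
}

fn main() {
    let params: ReturnTo =
        serde_urlencoded::from_str("return_to=articles%2Fupdating-my-website-for-2022").unwrap();
    // percent-encoding is decoded for us:
    // ReturnTo { return_to: Some("articles/updating-my-website-for-2022") }
    println!("{params:?}");
}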
Each of these extractors implements FromRequestParts, which is something you can just do yourself, so I ended up making up a TemplateRenderer extractor, since, again, a lot of my website is dynamic, and templates take a lot of state as input.
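I won't paste the whole TemplateRenderer here, but the general shape of a hand-rolled extractor in axum 0.6 looks something like this (a minimal sketch with a made-up RequestedPath extractor, not my actual code):

use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
use std::convert::Infallible;

// a toy extractor that just grabs the request path
struct RequestedPath(String);

#[async_trait]
impl<S> FromRequestParts<S> for RequestedPath
where
    S: Send + Sync,
{
    // what to return if extraction fails (it can't, here)
    type Rejection = Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        Ok(RequestedPath(parts.uri.path().to_owned()))
    }
}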
A handler using it looks something like:
#[axum::debug_handler(state = crate::serve::ServerState)]
async fn atom_feed(tr: TemplateRenderer) -> HttpResult {
    tr.render_ex(
        "index.xml",
        TemplateCacheBehavior::Cache,
        Default::default(),
        mime::atom(),
    )
    .await
}
Or, even simpler:
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve(tr: TemplateRenderer) -> HttpResult {
    tr.render("settings.html").await
}
The live-reload stuff, using server-sent events, looks very similar, since warp and axum are using the same underlying crates:
#[axum::debug_handler(state = crate::serve::ServerState)]
async fn livereload(State(state): State<ServerState>) -> HttpResult {
    fn make_stream(state: &ServerState) -> impl Stream<Item = Result<sse::Event, Infallible>> {
        let mut rx = state.broadcast_rev.subscribe();

        async_stream::try_stream! {
            yield sse::Event::default().event("message").data("Live reloading enabled");

            while let Ok(rev_id) = rx.recv().await {
                yield sse::Event::default().event("new-revision").data(rev_id);
            }
        }
    }

    Sse::new(make_stream(&state))
        .keep_alive(KeepAlive::default())
        .into_http()
}
(I also discovered I didn't need to do my own keep-alive handling — it's built right in there!)
Routing in axum 0.6
How are these routes composed then? Up until axum 0.5, I was very sad, because routes like /fixed and /*splat used to conflict (you'd get a runtime error too, yuck). But as of 0.6 (currently in release candidate), they don't anymore! The fixed one takes priority, and it falls back to the splat.
So, my whole router looks like this:
pub fn all_routes(state: ServerState) -> Router<ServerState> {
    Router::with_state(state.clone())
        .route("/robots.txt", get(robots_txt))
        .route("/api/livereload", get(livereload))
        .route("/tags", get(tags::serve_list))
        .route("/tags/:tag", get(tags::serve_single))
        .route("/settings", get(settings::serve).post(settings::save))
        .route("/search", get(search::serve))
        .route("/login", get(login::serve_login))
        .route("/debug-credentials", get(login::serve_debug_credentials))
        .route("/patreon/oauth", get(login::serve_patreon_oauth))
        .route("/logout", get(login::serve_logout))
        .route("/comments", get(comments::serve))
        .route("/patron-list", get(patron_list::serve))
        .route("/index.xml", get(atom_feed))
        .route("/", get(revision_routes::serve))
        .route("/*path", get(revision_routes::serve))
}
Different handlers can be set with method functions like get and post — and you get HEAD support for free for get handlers, as expected. Oh, and that type never "gets bigger" in a way that would cause compile-time explosions.
You can do a bunch of interesting stuff like "nest" sub-routers, that I haven't shown here. You might be wondering how extracting path elements works and it's... through extractors, as usual!
Here's both routes for /tags and /tags/:tag:
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_list(tr: TemplateRenderer) -> HttpResult {
    tr.render("tags.html").await
}

#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_single(tr: TemplateRenderer, Path(tag): Path<String>) -> HttpResult {
    let mut globals = Object::new();
    globals.insert("tag_name".into(), to_value(&tag)?);

    tr.render_ex(
        "tag.html",
        TemplateCacheBehavior::Cache,
        globals,
        mime::html(),
    )
    .await
}
(Note: the name that Path is destructured into doesn't matter, only order matters).
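If a route has several parameters, you can grab them all at once as a tuple, in the order they appear in the path. Just a sketch, not a route my site actually has:

use axum::extract::Path;

// for a route like "/articles/:year/:slug"
async fn article(Path((year, slug)): Path<(u16, String)>) -> String {
    format!("article {slug} from {year}")
}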
I haven't shown how the whole app was started - here it is with warp:
let addr: SocketAddr = config.address.parse()?;
let conns_limit = Semaphore::new(2048);

let svc = ServiceBuilder::new()
    .layer(GlobalConcurrencyLimitLayer::new(2048))
    .layer(BanUserAgents::new())
    .layer(RequestIdLayer)
    .layer(IncomingHttpSpanLayer::default())
    .service(warp::service(all_routes.with(access_log)));

let factory = ServiceFactory {
    inner: svc,
    semaphore: PollSemaphore::new(Arc::new(conns_limit)),
    permit: None,
};

let acceptor = timeout_acceptor(addr);
let server = hyper::Server::builder(acceptor).serve(factory);

info!("Listening on http://{}", addr);
info!("Base URL is {}", config.base_url);
server.await?;
And here it is with axum:
let addr: SocketAddr = config.address.parse()?;
let acceptor = timeout_acceptor(addr);

let app = routes::all_routes(server_state).layer(
    ServiceBuilder::new()
        .layer(GlobalConcurrencyLimitLayer::new(2048))
        .layer(TraceLayer::new_for_http())
        .layer(BanUserAgents::new())
        .layer(RequestIdLayer)
        .layer(IncomingHttpSpanLayer::default())
        .layer(CookieManagerLayer::new()),
);

let server = hyper::Server::builder(acceptor).serve(app.into_make_service());

info!("Listening on http://{}", addr);
info!("Base URL is {}", config.base_url);
server.await?;
That's right, I was able to re-use almost all my custom tower layers with almost no changes (I had to make one accept different body types, an annoying ordeal I've become exceedingly efficient at).
Using tower-cookies (properly)
In the previous iteration, I was using the cookie crate to allow users to log in. But I don't think I ever properly understood how it was supposed to be used, so... I came up with my own signing scheme (I didn't need them to be encrypted, just tamper-proof).
And I had code of my own to extract/set cookies. You've seen that bit from the warp version:
let mut cookies = HashMap::new();
if let Some(ch) = cookie_header {
    for cookie in ch
        .split(';')
        .map(|x| x.trim())
        .filter_map(|x| Cookie::parse_encoded(x.to_string()).ok())
    {
        cookies.insert(cookie.name().to_string(), cookie);
    }
}
(Again I'm sure there's a warpier way to do that and I was just being silly).
But you haven't seen that part:
pub trait HasCookies {
    fn insert_cookie(self, c: Cookie<'static>) -> Self;
    fn remove_cookie(self, c: Cookie<'static>) -> Self;
    fn apply_session_cookies(self, sc: SessionCookies) -> Self;
}

impl HasCookies for http::response::Builder {
    fn insert_cookie(self, c: Cookie<'static>) -> Self {
        self.header("set-cookie", c.encoded().to_string())
    }

    fn remove_cookie(self, mut c: Cookie<'static>) -> Self {
        c.set_expires(Some(
            time::OffsetDateTime::now_utc() - time::Duration::days(1),
        ));
        self.insert_cookie(c)
    }

    fn apply_session_cookies(self, sc: SessionCookies) -> Self {
        sc.apply(self)
    }
}
Or that part:
#[derive(Debug, Serialize, Deserialize)]
pub struct SignedCookie {
    pub signature: String,
    pub payload: String,
}

impl SignedCookie {
    pub fn new<T>(config: &Config, payload: &T) -> Result<Self>
    where
        T: Serialize,
    {
        let payload = serde_json::to_string(payload)?;
        let signature =
            hmac_sha256::HMAC::mac(payload.as_bytes(), config.secrets.cookie_sauce.as_bytes());
        let signature = format!("{}", HexSlice(&signature));
        Ok(Self { signature, payload })
    }

    pub fn get<T>(&self, config: &Config) -> Result<T>
    where
        T: DeserializeOwned,
    {
        let actual_sig = hmac_sha256::HMAC::mac(
            self.payload.as_bytes(),
            config.secrets.cookie_sauce.as_bytes(),
        );
        let actual_sig = format!("{}", HexSlice(&actual_sig));
        if self.signature != actual_sig {
            return Err(SignedCookieError::InvalidSignature.into());
        }

        let res = serde_json::from_str(&self.payload)?;
        Ok(res)
    }
}

#[derive(Debug, thiserror::Error)]
enum SignedCookieError {
    #[error("invalid cookie signature")]
    InvalidSignature,
}
Well, turns out, none of that is needed. As shown earlier, tower-cookies works well with axum (I had to fork it for 0.6 support, since it's not stable yet, but it was a simple change).
With tower-cookies, you get a Cookies jar, and you can call get / add / remove at any time, and by the time your HTTP handler returns, it sets the right headers.
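A handler using it looks something like this (a made-up visit-counting example, assuming the CookieManagerLayer shown earlier is installed):

use tower_cookies::{Cookie, Cookies};

async fn visits(cookies: Cookies) -> String {
    let count: u64 = cookies
        .get("visits")
        .and_then(|c| c.value().parse().ok())
        .unwrap_or(0);
    // this queues up a `set-cookie` header for when the response goes out
    cookies.add(Cookie::new("visits", (count + 1).to_string()));
    format!("you've been here {count} times before")
}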
Also, from a Cookies jar, you can derive a SignedCookies jar or a PrivateCookies jar (I ended up using the latter), if you have a cryptographic master key, which you can derive from another master key, or trust the cookie crate to generate from a secure source for you.
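For reference, the derivation itself is just a couple of calls. A minimal sketch, with the key generated on the spot rather than loaded from config, and assuming tower-cookies' private feature:

use tower_cookies::{Cookie, Cookies, Key};

async fn remember(cookies: Cookies) {
    // in a real app you'd generate/derive this once and keep it in your state
    let key = Key::generate();
    let private = cookies.private(&key);

    // encrypted + authenticated on the wire, transparent to read back
    private.add(Cookie::new("credentials", "{\"user\":\"amos\"}"));
    let _creds = private.get("credentials");
}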
I don't have much code to show you here — the diff is strictly negative, I just got rid of all my ad-hoc stuff and things worked out of the box.
Here's what the "loading credentials from cookies" code looks like now, with the shiny new let-else feature and all:
impl FutileCredentials {
    pub async fn load_from_cookies(config: &Config, cookies: &PrivateCookies<'_>) -> Option<Self> {
        let cookie = cookies.get(Self::COOKIE_NAME)?;
        let creds: Self = match serde_json::from_str(cookie.value()) {
            Ok(v) => v,
            Err(e) => {
                warn!(?e, "Got undeserializable cookie, removing");
                cookies.remove(cookie);
                return None;
            }
        };

        let now = Utc::now();
        if now < creds.expires_at {
            // credentials aren't expired yet
            return Some(creds);
        }

        let Some(patreon_credentials) = &creds.patreon_credentials else {
            warn!("Don't know how to renew non-Patreon credentials");
            return None;
        };

        info!(
            "Refreshing patreon credentials (expired at {:?}, is now {now:?})",
            creds.expires_at,
        );

        let mut refresh_creds = patreon_credentials.clone();
        if is_development() && test_patreon_renewal() {
            refresh_creds.access_token = "bad-token-for-testing".into()
        }

        // async because this hits the Patreon API
        let creds = match refresh_creds.to_futile_credentials(config).await {
            Ok(creds) => creds,
            Err(e) => {
                warn!(?e, "Could not renew patreon credentials, will log out");
                cookies.remove(cookie);
                return None;
            }
        };

        cookies.add(creds.as_cookie(config));
        Some(creds)
    }
}