Migrating from warp to axum

This article is part of the Updating fasterthanli.me for 2022 series.

Falling out of love with warp

Back when I wrote this codebase, warp was the best / only alternative for something relatively high-level on top of hyper.

I was never super fond of warp's model — it's a fine crate, just not for me.

The way routing works is essentially building a type that gets larger and larger. One route might look like:

Rust code
    let bye = warp::path("bye")
        .and(warp::path::param())
        .map(|name: String| format!("Good bye, {}!", name));

And then all the routes combined might look like:

Rust code
    let routes = warp::get().and(
        hello_world
            .or(hi)
            .or(hello_from_warp)
            .or(bye)
            .or(math)
            .or(sum)
            .or(times),
    );

That's from warp's routing example.

We can look at the type of the whole routes expression with this simple trick:

Rust code
    let a: () = routes;

And as we see, it's fairly large:

sh
$ cargo check --examples
    Checking warp v0.3.3 (/home/amos/bearcove/warp)
error[E0308]: mismatched types
   --> examples/routing.rs:103:17
    (cut)
    = note: expected unit type `()`
                  found struct `warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>, warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::or::Or<warp::filter::map::Map<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>, [closure@examples/routing.rs:13:45: 13:47]>, warp::filter::map::Map<Exact<warp::path::internal::Opaque<&str>>, [closure@examples/routing.rs:16:35: 16:37]>>, warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:21:70: 21:72]>>, warp::filter::map::Map<warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (String,), Error = Rejection>>, [closure@examples/routing.rs:51:14: 51:28]>>, warp::filter::or::Or<warp::filter::map::Map<warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:70:14: 70:16]>, warp::filter::and::And<Exact<warp::path::internal::Opaque<&str>>, warp::filter::or::Or<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = 
Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:26:50: 26:56]>, warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:32:46: 32:52]>>>>>, warp::filter::map::Map<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u32,), Error = Rejection>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:26:50: 26:56]>, [closure@examples/routing.rs:74:23: 74:31]>>, warp::filter::map::Map<warp::filter::map::Map<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<warp::filter::and::And<impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (), Error = Infallible>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, Exact<warp::path::internal::Opaque<main::{closure#0}::__StaticPath>>>, impl warp::Filter + Copy + warp::filter::FilterBase<Extract = (u16,), Error = Rejection>>, impl warp::Filter + Copy + 
warp::filter::FilterBase<Extract = (), Error = Rejection>>, [closure@examples/routing.rs:32:46: 32:52]>, [closure@examples/routing.rs:76:19: 76:27]>>>`

For more information about this error, try `rustc --explain E0308`.
error: could not compile `warp` due to previous error

(In fact, it made the estimated reading time for this article jump from 8 to 13 minutes, all by itself).
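
To make the "type that gets larger and larger" point concrete, here's a toy version of the combinator pattern using only the standard library. `Route`, `Path`, `Or` and `type_name_of` are made-up names for this sketch, not warp's actual types:

```rust
use std::any::type_name;

// Hypothetical stand-in for a filter: tries to answer a request path.
trait Route {
    fn call(&self, path: &str) -> Option<String>;

    // Combining two routes wraps both of their types in a new one.
    fn or<R: Route>(self, other: R) -> Or<Self, R>
    where
        Self: Sized,
    {
        Or(self, other)
    }
}

struct Path(&'static str);

impl Route for Path {
    fn call(&self, path: &str) -> Option<String> {
        (path == self.0).then(|| format!("matched {}", self.0))
    }
}

// The combined route remembers the full type of both sides.
struct Or<A, B>(A, B);

impl<A: Route, B: Route> Route for Or<A, B> {
    fn call(&self, path: &str) -> Option<String> {
        self.0.call(path).or_else(|| self.1.call(path))
    }
}

// Returns the compiler's name for the type of any value.
fn type_name_of<T>(_: &T) -> &'static str {
    type_name::<T>()
}
```

Calling `type_name_of` on `Path("/hi").or(Path("/bye")).or(Path("/math"))` gives a name with one more level of `Or<...>` nesting per `.or()` call, which is the same effect as in the compiler output above, just on a much smaller scale.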

This is an issue on several levels. First, that's a lot of work for the compiler. "My warp app takes forever to compile" is the most active issue on the issue tracker. In fact, I've written about this specific problem before.

Second, the errors are awful. In your mind, I might be pretty good at Rust, but facing those errors, even I recoiled in horror. I just had no idea what it wanted me to do. So I ended up doing some... terrible, non-warpy things. This is what my router ended up looking like:

Rust code
    // this is a `ServerState` struct that has a cache, the config (including secrets)
    // stuff like that. we want that from any request handlers.
    let with_state = warp::any().map(move || server_state.clone());

    // because most of my website is dynamic and rendered from liquid templates,
    // (and those are cached), I need stuff like the full path, raw query string,
    // cookies (for login state), etc.
    let with_cx = with_state
        .clone()
        .and(warp::filters::path::full())
        .and(
            warp::filters::query::raw()
                .recover(|_| async {
                    let res: Result<String, Infallible> = Ok("".to_string());
                    res
                })
                .unify(),
        )
        .and(warp::filters::header::optional("cookie"))
        .and(
            warp::filters::body::bytes()
                .recover(|rej| {
                    warn!("while getting body bytes: {:?}", rej);
                    async {
                        let res: Result<Bytes, Infallible> = Ok(Default::default());
                        res
                    }
                })
                .unify(),
        )
        .map(
            |state,
             full_path: FullPath,
             query: String,
             cookie_header: Option<String>,
             body: Bytes| {
                let mut cookies = HashMap::new();
                if let Some(ch) = cookie_header {
                    for cookie in ch
                        .split(';')
                        .map(|x| x.trim())
                        .filter_map(|x| Cookie::parse_encoded(x.to_string()).ok())
                    {
                        cookies.insert(cookie.name().to_string(), cookie);
                    }
                }
                let cx = Context {
                    state,
                    raw_query: Arc::new(Query::new(query)),
                    path: full_path.as_str().to_string(),
                    cookies,
                    raw_body: body,
                };
                cx
            },
        );

    // I forget why this is needed 🙃 but it is.
    let with_cx = move || with_cx.clone();

    // This is just how I did live-reloading.
    use std::convert::Infallible;
    use warp::filters::sse::Event as ServerSentEvent;
    // This is a free-standing function because it's the only reasonable way
    // to annotate the error type of the TryStream.
    fn sse_events(cx: Context) -> impl Stream<Item = Result<ServerSentEvent, Infallible>> {
        let mut rx = cx.state.broadcast_rev.subscribe();
        // cf. https://lib.rs/crates/async-stream
        // could be written without it, but I like that crate.
        async_stream::try_stream! {
            yield ServerSentEvent::default().event("message").data("Live reloading enabled");

            loop {
                tokio::select! {
                    Ok(rev_id) = rx.recv() => {
                        yield ServerSentEvent::default().event("new-revision").data(rev_id);
                    },
                    _ = tokio::time::sleep(Duration::from_secs(1)) => {
                        yield ServerSentEvent::default().event("ping").data("Just keeping connections alive");
                    },
                }
            }
        }
    }

    let livereload_route = method::get()
        .and(warp::path!("api" / "livereload"))
        .and(with_cx())
        .map(|cx: Context| warp::sse::reply(warp::sse::keep_alive().stream(sse_events(cx))));

    // uwu what's this? a single route? weird...
    let catchall_get = method::get()
        .and(with_cx())
        .and_then(|cx: Context| cx.handle(routes::serve_get));

    let catchall_head = method::head()
        .and(with_cx())
        .and_then(|cx: Context| cx.handle(routes::serve_get));

    let catchall_post = method::post()
        .and(with_cx())
        .and_then(|cx: Context| cx.handle(routes::serve_post));

    let all_routes = livereload_route
        .boxed()
        .or(catchall_get.boxed())
        .or(catchall_head.boxed())
        .or(catchall_post.boxed());
    let access_log = warp::filters::log::log("access");

It used to be much worse, because I used to use warp the way it's /meant/ to be used... before I gave up and just started using handlers like serve_get:

Rust code
pub async fn serve_get(cx: Context) -> HttpResult {
    if cx.path.ends_with('/') && cx.path != "/" {
        return serve_redirect(cx.path.trim_end_matches('/')).await;
    }

    if cx.path == "/robots.txt" {
        return Ok(Box::new(
            Response::builder()
                .status(StatusCode::OK)
                .body(Body::empty()),
        ));
    }

    if cx.path == "/tags" {
        return tags::serve_list(cx).await;
    } else if cx.path.starts_with("/tags/") {
        return tags::serve_single(cx).await;
    }

    if cx.path == "/settings" {
        return settings::serve(cx).await;
    }

    if cx.path == "/search" {
        return search::serve(cx).await;
    }

    if cx.path == "/login" {
        return login::serve_login(cx).await;
    }

    if cx.path == "/patreon/oauth" {
        return login::serve_patreon_oauth(cx).await;
    }

    if cx.path == "/logout" {
        return login::serve_logout(cx).await;
    }

    if cx.path == "/debug-credentials" {
        return login::serve_debug_credentials(cx).await;
    }

    if cx.path == "/comments" {
        return comments::serve(cx).await;
    }

    if cx.path == "/latest-video" {
        return latest_video::serve(cx).await;
    }

    if cx.path == "/patron-list" {
        return patron_list::serve(cx).await;
    }

    if cx.path == "/index.xml" {
        return cx
            .serve_template("index.xml", Default::default(), mime::atom())
            .await;
    }

    revision_routes::serve(cx).await
}

A single route would then look something like:

Rust code
pub async fn serve_login(cx: Context) -> HttpResult {
    #[derive(Deserialize)]
    struct QueryParams {
        #[serde(default)]
        return_to: Option<String>,
    }

    let location = make_patreon_login_url(cx.config())?;

    let mut res = Response::builder()
        .status(StatusCode::TEMPORARY_REDIRECT)
        .header("location", location);
    if let Ok(params) = cx.query::<QueryParams>() {
        if let Some(return_to) = params.return_to {
            let mut cookie = Cookie::new("return_to", return_to);
            cookie.set_path("/");
            cookie.set_expires(time::OffsetDateTime::now_utc() + time::Duration::minutes(30));
            res = res.insert_cookie(cookie);
        }
    }
    let res = res.body("Redirecting to Patreon login...".to_string())?;
    Ok(Box::new(res))
}

This is absolutely not the way warp was meant to be used, but this is what I ended up doing, to survive the compile times + type spaghetti party.

The HttpResult type enforced boxed replies, to avoid the "I'm returning different types from different codepaths and they won't unify" problem:

Rust code
pub use warp::Reply;
pub type HttpResult = color_eyre::Result<Box<dyn Reply>>;
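
Why boxing helps can be shown with a small std-only sketch. The `Reply` trait, `Html`, `Redirect` and `serve` below are simplified stand-ins invented for illustration (the error type is a plain `String` here), not warp's real definitions:

```rust
// Hypothetical stand-in for warp's `Reply` trait.
trait Reply {
    fn render(&self) -> String;
}

struct Html(String);
struct Redirect(&'static str);

impl Reply for Html {
    fn render(&self) -> String {
        format!("200 OK: {}", self.0)
    }
}

impl Reply for Redirect {
    fn render(&self) -> String {
        format!("307 Temporary Redirect -> {}", self.0)
    }
}

// Same shape as the alias above, with `String` standing in for the error type.
type HttpResult = Result<Box<dyn Reply>, String>;

// Each branch builds a different concrete type, but boxing erases the
// difference, so the function still has a single return type.
fn serve(path: &str) -> HttpResult {
    if path.ends_with('/') && path != "/" {
        return Ok(Box::new(Redirect("/without-trailing-slash")));
    }
    Ok(Box::new(Html(format!("page for {path}"))))
}
```

The cost is one allocation and dynamic dispatch per reply, which is a fair trade for handlers that can return whatever they need.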

The opinions of axum, also nice error handling

One nice thing about axum is that it defaults to something simple. It provides its own BoxBody body type, and an axum::body::boxed function to turn damn near anything into it.

It has an IntoResponse trait, which lets your handlers look like:

Rust code
// basic handler that responds with a static string
async fn root() -> &'static str {
    "Hello, World!"
}

(Because IntoResponse is implemented for &'static str)

Or like:

Rust code
async fn create_user(
    // this argument tells axum to parse the request body
    // as JSON into a `CreateUser` type
    Json(payload): Json<CreateUser>,
) -> impl IntoResponse {
    // insert your application logic here
    let user = User {
        id: 1337,
        username: payload.username,
    };

    // this will be converted into a JSON response
    // with a status code of `201 Created`
    (StatusCode::CREATED, Json(user))
}

(Because IntoResponse is implemented for tuples like (StatusCode, IntoResponse)).

There's still the "type unification" footgun with that approach — if you conditionally returned (StatusCode::FORBIDDEN, "nuh-huh"), you'd get a compile error, and you'd need to change the return type to Response<BoxBody>, and call .into_response() on both these tuples.
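
Here's a rough sketch of that fix with toy types, so it runs without axum. `Response`, `IntoResponse`, `FORBIDDEN` and `secret` are simplified inventions for this example; the real trait and response type are richer:

```rust
// Toy versions of axum's `Response` and `IntoResponse`, for illustration only.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for &'static str {
    fn into_response(self) -> Response {
        Response { status: 200, body: self.to_string() }
    }
}

// A (status, body) tuple overrides the status of the inner response.
impl<T: IntoResponse> IntoResponse for (u16, T) {
    fn into_response(self) -> Response {
        Response { status: self.0, ..self.1.into_response() }
    }
}

const FORBIDDEN: u16 = 403;

// With `-> impl IntoResponse` this would not compile: the two branches have
// different types (`(u16, &str)` and `&str`). Converting each branch to the
// concrete `Response` type makes them agree.
fn secret(allowed: bool) -> Response {
    if !allowed {
        return (FORBIDDEN, "nuh-huh").into_response();
    }
    "the secret".into_response()
}
```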

It's also useful to be able to turn errors into HTTP responses, so, I came up with my own HttpResult type:

Rust code
// in `crates/futile/src/serve/routes/response.rs`

use std::borrow::Cow;

use axum::response::{IntoResponse, Response};
use color_eyre::Report;
use futile_backtrace_printer::make_backtrace_printer;
use futile_config::is_production;
use http::{header, StatusCode};
use time::format_description::well_known::Rfc3339;
use tracing::error;

use crate::serve::{
    html_color_output::HtmlColorOutput, routes::mime, template_renderer::inject_livereload,
};

pub type HttpResult = Result<Response, HttpError>;

pub trait IntoHttp {
    fn into_http(self) -> HttpResult;
}

impl<T: IntoResponse> IntoHttp for T {
    fn into_http(self) -> HttpResult {
        Ok(self.into_response())
    }
}

#[derive(Debug)]
pub enum HttpError {
    NotFound { msg: Cow<'static, str> },
    Internal { err: String },
}

impl HttpError {
    fn from_report(err: Report) -> Self {
        error!("HTTP handler error: {}", err.root_cause());

        let bt_printer = make_backtrace_printer();

        let maybe_bt = err
            .context()
            .downcast_ref::<color_eyre::Handler>()
            .and_then(|h| h.backtrace());
        if let Some(bt) = maybe_bt {
            error!("Backtrace:");
            let mut stream = termcolor::StandardStream::stderr(termcolor::ColorChoice::Always);
            bt_printer.print_trace(bt, &mut stream).unwrap();
        } else {
            error!("No Backtrace");
        }

        let trace_content = if is_production() {
            "".into()
        } else {
            let mut err_string = String::new();
            for (i, e) in err.chain().enumerate() {
                use std::fmt::Write;
                let _ = writeln!(&mut err_string, "{}. {}", i + 1, e);
            }

            let mut err_string_escaped = String::new();
            html_escape::encode_safe_to_string(&err_string, &mut err_string_escaped);

            let backtrace: String = if let Some(bt) = maybe_bt {
                let mut output = HtmlColorOutput::new();
                bt_printer.print_trace(bt, &mut output).unwrap();
                output.into()
            } else {
                "".into()
            };

            format!(
                r#"
<pre class="trace">{err_string_escaped}
{backtrace}</pre>
                "#,
            )
        };

        let date = time::OffsetDateTime::now_utc().format(&Rfc3339).unwrap();

        let body = format!(
            r#"
            <html>
            {padding}
            <head>
            <style>
                @import url('https://fonts.googleapis.com/css2?family=Alegreya+Sans&family=Source+Code+Pro&display=swap');

                body {{
                    font-family: "Alegreya Sans", sans-serif;
                    max-width: 1200px;
                    margin: 20px auto;
                    line-height: 1.6;
                }}

                h2 {{
                    color: #b73535;
                }}

                pre.trace {{
                    font-family: "Source Code Pro", monospace;
                    padding: 1em;
                    border: 2px solid #999;
                    font-size: 14px;
                    background: #212121;
                    color: #d9d9d9;
                    white-space: pre-wrap;
                    overflow-x: auto;
                }}
            </style>
            </head>
            <body>
                <h1>Something went terribly wrong ({date})</h1>

                {trace_content}

                <p>
                    Try <a href="/">Going back to the homepage</a>, maybe?
                </p>
            </body>
            </html>
            "#,
            padding = "<!-- Padding to avoid browser 500 error -->\n".repeat(10),
            trace_content = trace_content
        );

        HttpError::Internal { err: body }
    }
}

macro_rules! impl_from {
    ($from:ty) => {
        impl From<$from> for HttpError {
            fn from(err: $from) -> Self {
                Self::from_report(err.into())
            }
        }
    };
}

impl_from!(std::io::Error);
impl_from!(color_eyre::Report);
impl_from!(http::Error);
impl_from!(http::header::InvalidHeaderValue);
impl_from!(http::uri::InvalidUri);
impl_from!(serde_json::Error);
impl_from!(url::ParseError);
impl_from!(liquid_core::Error);
impl_from!(sqlx::Error);
impl_from!(crate::cached::CachedError);
impl_from!(futile_patreon::PatreonError);

impl IntoResponse for HttpError {
    fn into_response(self) -> Response {
        match self {
            HttpError::NotFound { msg } => (StatusCode::NOT_FOUND, msg).into_response(),
            HttpError::Internal { err } => (
                StatusCode::INTERNAL_SERVER_ERROR,
                [(header::CONTENT_TYPE, mime::html())],
                inject_livereload(&err).to_string(),
            )
                .into_response(),
        }
    }
}

There's a lot here, and I'm not going to comment everything, but essentially, by default, things like deserialization errors etc., get turned into 500 Internal Server Errors, with a nice (colored!) stack trace in development and an opaque error in production.
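
Stripped of the HTML and backtrace rendering, the core mechanism is just `From` impls plus `?`. Here's a minimal std-only sketch of that idea; the `page_number` handler and the error types chosen here are hypothetical, picked only so the example compiles without external crates:

```rust
use std::num::ParseIntError;
use std::str::Utf8Error;

#[derive(Debug)]
enum HttpError {
    NotFound { msg: String },
    Internal { err: String },
}

impl HttpError {
    fn status(&self) -> u16 {
        match self {
            HttpError::NotFound { .. } => 404,
            HttpError::Internal { .. } => 500,
        }
    }

    fn message(&self) -> &str {
        match self {
            HttpError::NotFound { msg } => msg,
            HttpError::Internal { err } => err,
        }
    }
}

// Any listed error type becomes a 500 by default.
macro_rules! impl_from {
    ($from:ty) => {
        impl From<$from> for HttpError {
            fn from(err: $from) -> Self {
                HttpError::Internal { err: err.to_string() }
            }
        }
    };
}

impl_from!(ParseIntError);
impl_from!(Utf8Error);

// Hypothetical handler: `?` converts both error types into `HttpError`.
fn page_number(raw: &[u8]) -> Result<u32, HttpError> {
    let text = std::str::from_utf8(raw)?;
    let page: u32 = text.trim().parse()?;
    if page == 0 {
        // Opting into a more precise status is still possible.
        return Err(HttpError::NotFound { msg: "pages start at 1".into() });
    }
    Ok(page)
}
```

Anything that goes through `?` ends up as a 500, and the handler only has to spell out the cases where a different status makes sense.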

Extractors, handling shared state in axum 0.6

That means my handlers can return HttpResult, and use ? (unless I want to map those errors to more precise status codes with nicer error pages):

Rust code
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_login(
    State(state): State<ServerState>,
    cookies: Cookies,
    params: Option<Form<ReturnTo>>,
) -> HttpResult {
    let cookies = state.private_cookies(cookies);

    if let Some(params) = params {
        if let Some(return_to) = params.return_to.as_ref() {
            let mut cookie = Cookie::new("return_to", return_to.to_owned());
            cookie.set_path("/");
            cookie.set_expires(time::OffsetDateTime::now_utc() + time::Duration::minutes(30));
            cookies.add(cookie);
        }
    }

    let location = make_patreon_login_url(&state.config)?;
    Redirect::to(&location).into_http()
}

The axum::debug_handler macro is invaluable for debugging type errors (there are some with axum too), like, for example, accidentally having a non-Send type slip in. It lets you specify which state type you're using in that application, which is actually unnecessary here since it's also specified as an extractor, in this here line:

Rust code
    State(state): State<ServerState>,

Cookies are extracted, too, via tower-cookies:

Rust code
    cookies: Cookies,

Lastly, query parameters are extracted (optionally) via the Form extractor:

Rust code
    params: Option<Form<ReturnTo>>,

The ReturnTo type is a struct that derives serde's Deserialize:

Rust code
#[derive(Deserialize)]
pub struct ReturnTo {
    #[serde(default)]
    return_to: Option<String>,
}

And that lets us decode a URL like /login?return_to=articles%2Fupdating-my-website-for-2022 into something nice and strongly typed.

Each of these extractors implements FromRequestParts (or FromRequest, for the ones that may need the request body), which is something you can just do yourself, so I ended up making up a TemplateRenderer extractor, since, again, a lot of my website is dynamic, and templates take a lot of state as input.

A handler using it looks something like:

Rust code
#[axum::debug_handler(state = crate::serve::ServerState)]
async fn atom_feed(tr: TemplateRenderer) -> HttpResult {
    tr.render_ex(
        "index.xml",
        TemplateCacheBehavior::Cache,
        Default::default(),
        mime::atom(),
    )
    .await
}

Or, even simpler:

Rust code
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve(tr: TemplateRenderer) -> HttpResult {
    tr.render("settings.html").await
}
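
The extractor mechanism itself can be sketched without axum. Everything below is a simplified, synchronous toy: `Parts`, this `FromRequestParts` trait, `FullPath`, `UserAgent`, `call` and `show` are assumptions made for the example, and the real trait is async and also receives the application state:

```rust
use std::collections::HashMap;

// Toy request head; the real `Parts` comes from the `http` crate.
struct Parts {
    path: String,
    headers: HashMap<String, String>,
}

// Simplified stand-in for axum's `FromRequestParts`.
trait FromRequestParts: Sized {
    fn from_request_parts(parts: &Parts) -> Result<Self, String>;
}

struct FullPath(String);
struct UserAgent(String);

impl FromRequestParts for FullPath {
    fn from_request_parts(parts: &Parts) -> Result<Self, String> {
        Ok(FullPath(parts.path.clone()))
    }
}

impl FromRequestParts for UserAgent {
    fn from_request_parts(parts: &Parts) -> Result<Self, String> {
        parts
            .headers
            .get("user-agent")
            .cloned()
            .map(UserAgent)
            .ok_or_else(|| "missing user-agent header".to_string())
    }
}

// Wrapping an extractor in `Option` turns a rejection into `None`.
impl<T: FromRequestParts> FromRequestParts for Option<T> {
    fn from_request_parts(parts: &Parts) -> Result<Self, String> {
        Ok(T::from_request_parts(parts).ok())
    }
}

// The framework's job: build each argument from the request, in order,
// then call the handler with them.
fn call<A, B>(parts: &Parts, handler: impl Fn(A, B) -> String) -> Result<String, String>
where
    A: FromRequestParts,
    B: FromRequestParts,
{
    Ok(handler(A::from_request_parts(parts)?, B::from_request_parts(parts)?))
}

// A handler only declares what it needs as arguments.
fn show(FullPath(path): FullPath, ua: Option<UserAgent>) -> String {
    match ua {
        Some(UserAgent(ua)) => format!("{path} from {ua}"),
        None => format!("{path} from unknown"),
    }
}
```

A custom extractor is then just one more type with its own `from_request_parts`, which is roughly the role a TemplateRenderer-style extractor plays.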

The live-reload stuff, using server-sent events, looks very similar, since warp and axum are using the same underlying crates:

Rust code
#[axum::debug_handler(state = crate::serve::ServerState)]
async fn livereload(State(state): State<ServerState>) -> HttpResult {
    fn make_stream(state: &ServerState) -> impl Stream<Item = Result<sse::Event, Infallible>> {
        let mut rx = state.broadcast_rev.subscribe();
        async_stream::try_stream! {
            yield sse::Event::default().event("message").data("Live reloading enabled");

            while let Ok(rev_id) = rx.recv().await {
                yield sse::Event::default().event("new-revision").data(rev_id);
            }
        }
    }

    Sse::new(make_stream(&state))
        .keep_alive(KeepAlive::default())
        .into_http()
}

(I also discovered I didn't need to do my own keep-alive handling — it's built right in there!)

Routing in axum 0.6

How are these routes composed then? Up until axum 0.5, I was very sad, because routes like /fixed and /*splat used to conflict (you'd get a runtime error too, yuck). But as of 0.6 (currently in release candidate) they don't anymore! The fixed one takes priority, and it falls back to the splat.

So, my whole router looks like this:

Rust code
pub fn all_routes(state: ServerState) -> Router<ServerState> {
    Router::with_state(state.clone())
        .route("/robots.txt", get(robots_txt))
        .route("/api/livereload", get(livereload))
        .route("/tags", get(tags::serve_list))
        .route("/tags/:tag", get(tags::serve_single))
        .route("/settings", get(settings::serve).post(settings::save))
        .route("/search", get(search::serve))
        .route("/login", get(login::serve_login))
        .route("/debug-credentials", get(login::serve_debug_credentials))
        .route("/patreon/oauth", get(login::serve_patreon_oauth))
        .route("/logout", get(login::serve_logout))
        .route("/comments", get(comments::serve))
        .route("/patron-list", get(patron_list::serve))
        .route("/index.xml", get(atom_feed))
        .route("/", get(revision_routes::serve))
        .route("/*path", get(revision_routes::serve))
}

Different handlers can be set with method functions/methods like get and post — and you get HEAD support for free for get handlers, as expected. Oh and that type never "gets bigger" in a way that would cause compile-time explosions.
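
The "fixed route wins, splat is the fallback" rule can be illustrated with a toy matcher. This is only a sketch of the priority rule with a made-up `find_handler` function; it is not how axum actually matches routes internally:

```rust
// Toy router: exact paths win, and a trailing `/*name` pattern is the fallback.
fn find_handler<'a>(routes: &[(&str, &'a str)], path: &str) -> Option<&'a str> {
    // First pass: fixed routes take priority, wherever they were registered.
    if let Some((_, handler)) = routes.iter().find(|(pattern, _)| *pattern == path) {
        return Some(*handler);
    }
    // Second pass: fall back to the first wildcard whose prefix matches.
    routes.iter().find_map(|(pattern, handler)| {
        let (prefix, _name) = pattern.split_once("/*")?;
        let rest = path.strip_prefix(prefix)?.strip_prefix('/')?;
        // In this sketch a wildcard needs at least one character to capture,
        // so "/" needs its own route.
        (!rest.is_empty()).then_some(*handler)
    })
}
```

With routes `/*path`, `/tags` and `/` registered in that order, `/tags` still resolves to the fixed handler, `/articles/foo` falls through to the splat, and `/` only matches its own route.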

You can do a bunch of interesting stuff like "nest" sub-routers, that I haven't shown here. You might be wondering how extracting path elements works and it's... through extractors, as usual!

Here are both routes for /tags and /tags/:tag:

Rust code
#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_list(tr: TemplateRenderer) -> HttpResult {
    tr.render("tags.html").await
}

#[axum::debug_handler(state = crate::serve::ServerState)]
pub async fn serve_single(tr: TemplateRenderer, Path(tag): Path<String>) -> HttpResult {
    let mut globals = Object::new();
    globals.insert("tag_name".into(), to_value(&tag)?);

    tr.render_ex(
        "tag.html",
        TemplateCacheBehavior::Cache,
        globals,
        mime::html(),
    )
    .await
}

(Note: the name that Path is destructured into doesn't matter, only order matters).

I haven't shown how the whole app was started - here it is with warp:

Rust code
    let addr: SocketAddr = config.address.parse()?;

    let conns_limit = Semaphore::new(2048);
    let svc = ServiceBuilder::new()
        .layer(GlobalConcurrencyLimitLayer::new(2048))
        .layer(BanUserAgents::new())
        .layer(RequestIdLayer)
        .layer(IncomingHttpSpanLayer::default())
        .service(warp::service(all_routes.with(access_log)));

    let factory = ServiceFactory {
        inner: svc,
        semaphore: PollSemaphore::new(Arc::new(conns_limit)),
        permit: None,
    };

    let acceptor = timeout_acceptor(addr);
    let server = hyper::Server::builder(acceptor).serve(factory);
    info!("Listening on http://{}", addr);
    info!("Base URL is {}", config.base_url);
    server.await?;

And here it is with axum:

Rust code
    let addr: SocketAddr = config.address.parse()?;

    let acceptor = timeout_acceptor(addr);
    let app = routes::all_routes(server_state).layer(
        ServiceBuilder::new()
            .layer(GlobalConcurrencyLimitLayer::new(2048))
            .layer(TraceLayer::new_for_http())
            .layer(BanUserAgents::new())
            .layer(RequestIdLayer)
            .layer(IncomingHttpSpanLayer::default())
            .layer(CookieManagerLayer::new()),
    );
    let server = hyper::Server::builder(acceptor).serve(app.into_make_service());
    info!("Listening on http://{}", addr);
    info!("Base URL is {}", config.base_url);
    server.await?;

That's right, I was able to re-use almost all my custom tower layers with almost no changes (I had to make one accept different body types, an annoying ordeal I've become exceedingly efficient at).
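
For a feel of what "accepting different body types" means, here's a toy middleware in the tower style. The `Service` trait below is a simplified, synchronous stand-in, and `Request`, `Response`, `BlockAgent`, `EchoLen` and `Ping` are all hypothetical names made up for this sketch:

```rust
// Toy request type, generic over the body like the `http` crate's.
struct Request<B> {
    user_agent: String,
    body: B,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
}

// Simplified stand-in for tower's `Service` trait.
trait Service<Req> {
    fn call(&self, req: Req) -> Response;
}

// Hypothetical middleware: rejects one user agent, forwards everything else.
struct BlockAgent<S> {
    inner: S,
    banned: &'static str,
}

// Being generic over `B` is what lets the same middleware sit in front of
// services that use different body types.
impl<S, B> Service<Request<B>> for BlockAgent<S>
where
    S: Service<Request<B>>,
{
    fn call(&self, req: Request<B>) -> Response {
        if req.user_agent == self.banned {
            return Response { status: 403 };
        }
        self.inner.call(req)
    }
}

// Two inner services with different body types.
struct EchoLen;

impl Service<Request<Vec<u8>>> for EchoLen {
    fn call(&self, req: Request<Vec<u8>>) -> Response {
        Response { status: if req.body.is_empty() { 204 } else { 200 } }
    }
}

struct Ping;

impl Service<Request<()>> for Ping {
    fn call(&self, _req: Request<()>) -> Response {
        Response { status: 200 }
    }
}
```

A middleware hard-coded to one body type would only wrap one of these two services; the generic version wraps both unchanged.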

Using tower-cookies (properly)

In the previous iteration, I was using the cookie crate to allow users to log in. But I don't think I ever properly understood how it was supposed to be used, so... I came up with my own signing scheme (I didn't need them to be tamper-proof).

And I had code of my own to extract/set cookies. You've seen that bit from the warp version:

Rust code
                let mut cookies = HashMap::new();
                if let Some(ch) = cookie_header {
                    for cookie in ch
                        .split(';')
                        .map(|x| x.trim())
                        .filter_map(|x| Cookie::parse_encoded(x.to_string()).ok())
                    {
                        cookies.insert(cookie.name().to_string(), cookie);
                    }
                }

(Again I'm sure there's a warpier way to do that and I was just being silly).

But you haven't seen that part:

Rust code
pub trait HasCookies {
    fn insert_cookie(self, c: Cookie<'static>) -> Self;
    fn remove_cookie(self, c: Cookie<'static>) -> Self;
    fn apply_session_cookies(self, sc: SessionCookies) -> Self;
}

impl HasCookies for http::response::Builder {
    fn insert_cookie(self, c: Cookie<'static>) -> Self {
        self.header("set-cookie", c.encoded().to_string())
    }
    fn remove_cookie(self, mut c: Cookie<'static>) -> Self {
        c.set_expires(Some(
            time::OffsetDateTime::now_utc() - time::Duration::days(1),
        ));
        self.insert_cookie(c)
    }
    fn apply_session_cookies(self, sc: SessionCookies) -> Self {
        sc.apply(self)
    }
}

Or that part:

Rust code
#[derive(Debug, Serialize, Deserialize)]
pub struct SignedCookie {
    pub signature: String,
    pub payload: String,
}

impl SignedCookie {
    pub fn new<T>(config: &Config, payload: &T) -> Result<Self>
    where
        T: Serialize,
    {
        let payload = serde_json::to_string(payload)?;
        let signature =
            hmac_sha256::HMAC::mac(payload.as_bytes(), config.secrets.cookie_sauce.as_bytes());
        let signature = format!("{}", HexSlice(&signature));
        Ok(Self { signature, payload })
    }

    pub fn get<T>(&self, config: &Config) -> Result<T>
    where
        T: DeserializeOwned,
    {
        let actual_sig = hmac_sha256::HMAC::mac(
            self.payload.as_bytes(),
            config.secrets.cookie_sauce.as_bytes(),
        );
        let actual_sig = format!("{}", HexSlice(&actual_sig));
        if self.signature != actual_sig {
            return Err(SignedCookieError::InvalidSignature.into());
        }

        let res = serde_json::from_str(&self.payload)?;
        Ok(res)
    }
}

#[derive(Debug, thiserror::Error)]
enum SignedCookieError {
    #[error("invalid cookie signature")]
    InvalidSignature,
}

Well, turns out, none of that is needed. As shown earlier, tower-cookies works well with axum (I had to fork it for 0.6 support, since it's not stable yet, but it was a simple change).

With tower-cookies, you get a Cookies jar, and you can call get / add / remove at any time, and by the time your HTTP handler returns, it sets the right headers.
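
The "record changes now, set headers later" idea can be sketched in a few lines of std-only Rust. `Jar`, its methods and `handler` are invented for this example and are not tower-cookies' API; the real `Cookies` type also has to work across threads, whereas `RefCell` keeps this sketch short:

```rust
use std::cell::RefCell;

// Toy jar: handlers record changes, and the middleware turns them into
// headers once the handler has returned.
#[derive(Default)]
struct Jar {
    changes: RefCell<Vec<String>>,
}

impl Jar {
    // Takes `&self`, so a handler can call it without owning the jar.
    fn add(&self, name: &str, value: &str) {
        self.changes.borrow_mut().push(format!("{name}={value}; Path=/"));
    }

    // Removal is expressed as a cookie that has already expired.
    fn remove(&self, name: &str) {
        self.changes.borrow_mut().push(format!("{name}=; Path=/; Max-Age=0"));
    }

    // What a middleware would do after the handler: one header per change.
    fn into_headers(self) -> Vec<(&'static str, String)> {
        self.changes
            .into_inner()
            .into_iter()
            .map(|cookie| ("set-cookie", cookie))
            .collect()
    }
}

// The handler never touches response headers directly.
fn handler(cookies: &Jar) -> &'static str {
    cookies.add("return_to", "/settings");
    cookies.remove("stale");
    "Redirecting..."
}
```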

Also, from a Cookies jar, you can derive a SignedCookies jar or a PrivateCookies jar (I ended up using the latter). Both need a cryptographic master key, which you can either derive from another master key, or trust the cookie crate to generate from a secure source for you.

I don't have much code to show you here — the diff is strictly negative, I just got rid of all my ad-hoc stuff and things worked out of the box.

Here's what the "loading credentials from cookies" code looks like now, with the shiny new let-else feature and all:

Rust code
impl FutileCredentials {
    pub async fn load_from_cookies(config: &Config, cookies: &PrivateCookies<'_>) -> Option<Self> {
        let cookie = cookies.get(Self::COOKIE_NAME)?;

        let creds: Self = match serde_json::from_str(cookie.value()) {
            Ok(v) => v,
            Err(e) => {
                warn!(?e, "Got undeserializable cookie, removing");
                cookies.remove(cookie);
                return None;
            }
        };

        let now = Utc::now();
        if now < creds.expires_at {
            // credentials aren't expired yet
            return Some(creds);
        }

        let Some(patreon_credentials) = &creds.patreon_credentials else {
            warn!("Don't know how to renew non-Patreon credentials");
            return None;
        };

        info!(
            "Refreshing patreon credentials (expired at {:?}, is now {now:?})",
            creds.expires_at,
        );

        let mut refresh_creds = patreon_credentials.clone();
        if is_development() && test_patreon_renewal() {
            refresh_creds.access_token = "bad-token-for-testing".into()
        }

        // async because this hits the Patreon API
        let creds = match refresh_creds.to_futile_credentials(config).await {
            Ok(creds) => creds,
            Err(e) => {
                warn!(?e, "Could not renew patreon credentials, will log out");
                cookies.remove(cookie);
                return None;
            }
        };

        cookies.add(creds.as_cookie(config));
        Some(creds)
    }
}

This article is part 2 of the Updating fasterthanli.me for 2022 series.
