Futures Nostalgia

Up until recently, hyper was my favorite Rust HTTP framework. It's low-level, but that gives you a lot of control over what happens.

Here's what a sample hyper application would look like:

Shell session
$ cargo new nostalgia
     Created binary (application) `nostalgia` package
Shell session
$ cd nostalgia
$ cargo add hyper@0.14 --features "http1 tcp server"
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding hyper v0.14 to dependencies with features: ["http1", "tcp", "server"]
$ cargo add tokio@1 --features "full"
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tokio v1 to dependencies with features: ["full"]
Rust code
use std::{
    convert::Infallible,
    future::{ready, Ready},
    task::{Context, Poll},
};

use hyper::{server::conn::AddrStream, service::Service, Body, Request, Response, Server};

#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(MyServiceFactory)
        .await
        .unwrap();
}

struct MyServiceFactory;

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        println!("Accepted connection from {}", req.remote_addr());
        ready(Ok(MyService))
    }
}

struct MyService;

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("Handling {req:?}");
        ready(Ok(Response::builder().body("Hello World!\n".into()).unwrap()))
    }
}
Shell session
$ cargo run
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished dev [unoptimized + debuginfo] target(s) in 1.20s
     Running `target/debug/nostalgia`
(server runs peacefully in the background)

Aaaand half the readers have closed the page already.

What? Let me show how it works at least!

See, you can curl it, and it works:

Shell session
$ curl 0:1025
Hello World!

Wait a minute hold on - what kind of address is that?

Well... IPv4 addresses can be written with fewer than four parts, in which case the last part fills in all the remaining bytes, so 127.1 is 127.0.0.1 for example...

So... 0 is 0.0.0.0? Isn't that the address we listen on when we want to accept traffic from all network interfaces? How does that work?

I'll uh... I'll get back to you on that. But it does work.
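That shorthand rule comes from the classic inet_aton parser: when an address has fewer than four parts, the last part fills every remaining byte. Here's a small sketch of the rule, assuming decimal parts only (the C parser also accepts hex and octal, which this skips). The function name is made up for illustration. Rust's own Ipv4Addr parser only accepts the full four-part form.

```rust
use std::net::Ipv4Addr;

/// Hypothetical helper: parses inet_aton-style shorthand (decimal parts only).
/// Every part but the last is one byte; the last part fills the remaining bytes.
fn parse_ipv4_shorthand(s: &str) -> Option<Ipv4Addr> {
    let parts: Vec<u32> = s
        .split('.')
        .map(|p| p.parse::<u32>().ok())
        .collect::<Option<Vec<_>>>()?;
    if parts.is_empty() || parts.len() > 4 {
        return None;
    }
    let (leading, last) = parts.split_at(parts.len() - 1);
    let last = last[0];
    // Leading parts must each fit in a single byte.
    if leading.iter().any(|&b| b > 0xff) {
        return None;
    }
    // The last part covers whatever bytes are left (32 bits for "0").
    let remaining_bits = 8 * (4 - leading.len() as u32);
    if remaining_bits < 32 && last >= (1u32 << remaining_bits) {
        return None;
    }
    let mut addr: u32 = 0;
    for (i, &b) in leading.iter().enumerate() {
        addr |= b << (24 - 8 * i as u32);
    }
    addr |= last;
    Some(Ipv4Addr::from(addr))
}
```

So "0" is a single part covering all 32 bits, which is how curl's 0:1025 ends up meaning 0.0.0.0.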

For completeness, here's what it shows in the terminal pane where the server is running:

Shell session
Accepted connection from 127.0.0.1:50408
Handling Request { method: GET, uri: /, version: HTTP/1.1, headers: {"host": "0.0.0.0:1025", "user-agent": "curl/7.79.1", "accept": "*/*"}, body: Body(Empty) }

Okay, cool. That's an ungodly amount of code though - I could do the same in, like, ten lines of G-

Ahhh but that's not the point! The point is that the design of hyper is delightfully perceptible to the naked eye! It's 90% the tower Service trait and 10%, huh... (waves hands) HTTP stuff.

HTTP stuff?

Yeah, you know! Easy. Text-based protocol, couple headers, some chunking if needed, the occasional trailer, cheerio good day sir.

...and HTTP/2?

Oh yeah, lil' bit of binary, adaptive windows, header compression, add multiplexing to taste. Nothing too hard there. Really it's mostly just the Service trait, look at it.

...and TLS?

Eh that's all rustls or OpenSSL or some fork thereof, but who cares, LOOK AT THAT TRAIT:

Rust code
struct MyServiceFactory;

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        println!("Accepted connection from {}", req.remote_addr());
        ready(Ok(MyService))
    }
}

Isn't it beautiful?

What... what am I supposed to see?

Backpressure, bear! The cornerstone of a nutritious breakf- of a stable application server.

See, before accepting a connection (and spawning a task to serve it), hyper asks us if we're ready! With poll_ready. And if we're not, it... well, it waits until we are.

That means we can control how many concurrent connections there can be to our service!

What would that look like?

Ah jeez, bear, that's not really the point of the article, but if you insist, I suppose we could, uh... use a semaphore maybe?

But first we should probably keep track of how many connections we have concurrently...

Maybe we keep an Arc<AtomicU64> in MyServiceFactory, which we increment when we accept a connection, and decrement when we drop a MyService?

Rust code
#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        // 👇 previously was a unit struct (just `MyServiceFactory`)
        .serve(MyServiceFactory::default())
        .await
        .unwrap();
}

// 👇 Now holding an atomically-reference-counted atomic counter
#[derive(Default)]
struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
}

impl Service<&AddrStream> for MyServiceFactory {
    // (cut: everything except for call)

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        let prev = self.num_connected.fetch_add(1, Ordering::SeqCst);
        println!(
            "⬆️ {} connections (accepted {})",
            prev + 1,
            req.remote_addr()
        );
        ready(Ok(MyService {
            num_connected: self.num_connected.clone(),
        }))
    }
}

// 👇 Now also holding a counter
struct MyService {
    num_connected: Arc<AtomicU64>,
}

impl Drop for MyService {
    fn drop(&mut self) {
        let prev = self.num_connected.fetch_sub(1, Ordering::SeqCst);
        println!("⬇️ {} connections (dropped)", prev - 1);
    }
}

impl Service<Request<Body>> for MyService {
    // (cut: everything except call)
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // 👇 made these logs a little nicer
        println!("{} {}", req.method(), req.uri());
        // otherwise the same
        ready(Ok(Response::builder()
            .body("Hello World!\n".into())
            .unwrap()))
    }
}

And now, a single curl request results in these logs:

Shell session
$ cargo run --quiet
⬆️ 1 connections (accepted 127.0.0.1:50416)
GET /
⬇️ 0 connections (dropped)

But we can also make requests by hand:

Shell session
$ socat - TCP4:0:1025
GET / HTTP/1.1

> HTTP/1.1 200 OK
> content-length: 13
> date: Sat, 02 Apr 2022 23:59:58 GMT
>
> Hello World!
GET /ahAH HTTP/1.1

> HTTP/1.1 200 OK
> content-length: 13
> date: Sun, 03 Apr 2022 00:00:04 GMT
> 
> Hello World!

(Note: I've prefixed the response lines with > by hand. The GET lines, each followed by an empty line, are what I typed.)

And from the server's perspective, this looks like this:

Shell session
$ cargo run --quiet
⬆️ 1 connections (accepted 127.0.0.1:50420)
GET /
GET /ahAH
⬇️ 0 connections (dropped)

Two requests! From the same connection! Who needs h2/h3/quic? Huh?

Me. I do. Come onnnnn.
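So how can one connection carry several requests? Each response here says exactly how long its body is (content-length), so the client knows where one response stops and the next one starts. Here's a std-only sketch: hello_response rebuilds the response from the socat session, with the CRLF line endings HTTP/1.1 uses on the wire, and split_bodies is a hypothetical helper that walks back-to-back responses using content-length. It assumes every response has a content-length header (no chunked encoding).

```rust
/// Rebuilds the response from the socat session, with CRLF line endings.
/// `date` is a parameter since it changes on every response.
fn hello_response(date: &str) -> String {
    let body = "Hello World!\n";
    format!(
        "HTTP/1.1 200 OK\r\ncontent-length: {}\r\ndate: {}\r\n\r\n{}",
        body.len(), // 13 bytes: "Hello World!" plus the newline
        date,
        body
    )
}

/// Hypothetical helper: splits back-to-back responses read from one connection,
/// using each `content-length` to find where the body ends.
/// Stops at the first incomplete response.
fn split_bodies(mut stream: &str) -> Vec<&str> {
    let mut bodies = Vec::new();
    while let Some(head_end) = stream.find("\r\n\r\n") {
        let len: usize = stream[..head_end]
            .lines()
            .find_map(|line| {
                let (name, value) = line.split_once(':')?;
                if name.eq_ignore_ascii_case("content-length") {
                    value.trim().parse().ok()
                } else {
                    None
                }
            })
            .unwrap_or(0);
        let body_start = head_end + 4;
        match stream.get(body_start..body_start + len) {
            Some(body) => bodies.push(body),
            None => break,
        }
        stream = &stream[body_start + len..];
    }
    bodies
}
```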

I don't feel like doing enough typing to test our limiter though... how about we use a load testing tool? Like oha?

Shell session
$ oha http://127.0.0.1:1025
(a lot of output ensues)

Status code distribution:
  [200] 200 responses

That generated... a LOT of output on the server side. But I've also saved it to a file, because I'm a forward-thinking young lad.

Like so:

Shell session
$ cargo run --release --quiet | tee /tmp/server-log.txt
(server output is printed as normal, but is also logged to the file)

Which means I can now easily count how many connections we had at our peakest of peaks:

Shell session
$ cat /tmp/server-log.txt | grep '⬆' | cut -d ' ' -f 2 | sort -n | tail -1
50
Cool bear's hot tip

You may think that's a useless use of cat, but actually, cat is union now, so think twice about giving its work away to someone else.

50! What a suspicious value, almost as if it was oha's default...

Shell session
$ oha --help
(cut)
    -c <N_WORKERS>
            Number of workers to run concurrently. You may should increase limit to number of open
            files for larger `-c`. [default: 50]

Mh yep!
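For the record, here's what that shell pipeline computes, as a small Rust function. The name is made up, and it assumes the exact log format printed above: keep the ⬆ lines, take the second space-separated field, return the largest number.

```rust
/// Rust equivalent of `grep '⬆' | cut -d ' ' -f 2 | sort -n | tail -1`,
/// assuming lines like "⬆️ 3 connections (accepted 127.0.0.1:50416)".
fn peak_connections(log: &str) -> Option<u64> {
    log.lines()
        .filter(|line| line.contains('⬆')) // grep '⬆'
        .filter_map(|line| line.split(' ').nth(1)) // cut -d ' ' -f 2
        .filter_map(|field| field.parse::<u64>().ok())
        .max() // sort -n | tail -1
}
```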

Okay but we're trying to limit it. Let's use... a semaphore!

Shell session
$ cargo add tokio-util@0.7
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tokio-util v0.7 to dependencies
Rust code
struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
    semaphore: PollSemaphore,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            num_connected: Default::default(),
            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(5))),
        }
    }
}

There! I've set a limit of 5 permits, so now all we have to do is... acquire a permit from the semaphore in poll_ready, I suppose!

Shell session
$ cargo add futures
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding futures v0.3.21 to dependencies
Rust code
impl Service<&AddrStream> for MyServiceFactory {
    // (cut: all except for poll_ready)

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let permit = futures::ready!(self.semaphore.poll_acquire(cx)).unwrap();
        Ok(()).into()
    }
}

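There! That handy little ready! macro returns Poll::Pending early if the semaphore has no permit for us yet, and otherwise hands us the T out of Poll::Ready(T). That T is an Option<OwnedSemaphorePermit>, which is only None if the semaphore has been closed, and we unwrap it because, well, we never close the semaphore.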
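In case the macro looks magical: it's a small early-return match. futures::ready! behaves like the sketch below, reproduced under a different name (my_ready!) so it doesn't clash; the halve function is a made-up example. Since Rust 1.64, the standard library ships the same macro as std::task::ready!.

```rust
use std::task::Poll;

/// Same logic as `futures::ready!`: turn `Poll::Ready(t)` into `t`,
/// or return `Poll::Pending` from the *enclosing* function.
macro_rules! my_ready {
    ($e:expr) => {
        match $e {
            std::task::Poll::Ready(t) => t,
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
    };
}

/// Hypothetical poll-style function: only does its work once `input` is ready.
fn halve(input: Poll<u32>) -> Poll<u32> {
    let n = my_ready!(input);
    Poll::Ready(n / 2)
}
```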

But uh... where do we store that permit? We're not creating a MyService from poll_ready, merely checking for readiness!

Okay we'll need one more field:

Rust code
struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
    semaphore: PollSemaphore,
    // 👇 new!
    permit: Option<OwnedSemaphorePermit>,
}

Adjusting the impl Default is left as an exercise to the reader.

And now, in poll_ready, we simply try to acquire a permit if we don't have one already.

Rust code
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.permit.is_none() {
            self.permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

And in call, well, we just take it! And explode loudly in case someone hasn't called poll_ready before:

Rust code
    fn call(&mut self, req: &AddrStream) -> Self::Future {
        let permit = self.permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );

        let prev = self.num_connected.fetch_add(1, Ordering::SeqCst);
        println!(
            "⬆️ {} connections (accepted {})",
            prev + 1,
            req.remote_addr()
        );
        ready(Ok(MyService {
            num_connected: self.num_connected.clone(),
            permit,
        }))
    }

Of course we need somewhere to store that permit:

Rust code
struct MyService {
    num_connected: Arc<AtomicU64>,
    permit: OwnedSemaphorePermit,
}

rustc will complain that this is dead code, but it's very much not - holding that type is proof that we're allowed to run within the concurrency limits we've set for ourselves.

We... no, you can fix that by renaming it to _permit. I won't.
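Here's the "holding the permit is the proof" idea in miniature, using only std. This Semaphore is a hypothetical, non-blocking stand-in for tokio's (no waiting, just try_acquire), but the ownership story is the same: a MyService can't be built without a permit, and the permit goes back as soon as whatever holds it is dropped.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for `tokio::sync::Semaphore`: a shared count of free permits.
struct Semaphore {
    available: Arc<Mutex<usize>>,
}

/// Holding one of these means we're within the limit.
/// (In a real library, the field would be private so `try_acquire` is the only way in.)
struct OwnedPermit {
    available: Arc<Mutex<usize>>,
}

impl Semaphore {
    fn new(permits: usize) -> Self {
        Semaphore { available: Arc::new(Mutex::new(permits)) }
    }

    fn available_permits(&self) -> usize {
        *self.available.lock().unwrap()
    }

    fn try_acquire(&self) -> Option<OwnedPermit> {
        let mut free = self.available.lock().unwrap();
        if *free == 0 {
            return None;
        }
        *free -= 1;
        Some(OwnedPermit { available: self.available.clone() })
    }
}

impl Drop for OwnedPermit {
    // Releasing is automatic: dropping the holder gives the permit back.
    fn drop(&mut self) {
        *self.available.lock().unwrap() += 1;
    }
}

/// Like `MyService`: never read, but its presence is the point.
struct MyService {
    _permit: OwnedPermit,
}
```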

So now, let's try our oha test again!

Shell session
$ cargo run --release --quiet | tee /tmp/server-log.txt

# in another pane:
$ oha http://127.0.0.1:1025

# after stopping the server:
$ cat /tmp/server-log.txt | grep '⬆' | cut -d ' ' -f 2 | sort -n | tail -1
5

Wonderful! How often do things work exactly as expected? Not so often, I tell you what.

Oh, and we can actually remove our atomic counter, because semaphores do their own counting.

Which makes our complete program this:

Rust code
use std::{
    convert::Infallible,
    future::{ready, Ready},
    sync::Arc,
    task::{Context, Poll},
};

use hyper::{server::conn::AddrStream, service::Service, Body, Request, Response, Server};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;

#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(MyServiceFactory::default())
        .await
        .unwrap();
}

const MAX_CONNS: usize = 5;

struct MyServiceFactory {
    semaphore: PollSemaphore,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            permit: None,
        }
    }
}

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.permit.is_none() {
            self.permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        let permit = self.permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );

        println!(
            "⬆️ {} connections",
            MAX_CONNS - self.semaphore.available_permits()
        );
        ready(Ok(MyService { _permit: permit }))
    }
}

struct MyService {
    _permit: OwnedSemaphorePermit,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        ready(Ok(Response::builder()
            .body("Hello World!\n".into())
            .unwrap()))
    }
}

That's not an async server though

I mean... yes it is an async server, but I see what you mean - we just immediately reply with "Hello World!", we're not even pretending to think about it a little bit.

So let's pretend. Now our Future type cannot be Ready anymore. We could make a custom future, like so:

Rust code
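use std::{
    convert::Infallible,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use hyper::{Body, Response};
use tokio::time::Sleep;

struct PretendFuture {
    sleep: Sleep,
    response: Option<Response<Body>>,
}

impl Future for PretendFuture {
    type Output = Result<Response<Body>, Infallible>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Hand-rolled pin projection: `sleep` is never moved out of `self`,
        // so it's OK to poll it through a `Pin<&mut Sleep>`
        futures::ready!(
            unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.sleep) }.poll(cx)
        );
        // `response` isn't structurally pinned, so taking it out is fine
        Ok(unsafe { self.get_unchecked_mut() }.response.take().unwrap()).into()
    }
}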

Which we could then use as our Future type in MyService:

Rust code
impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = PretendFuture;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        PretendFuture {
            sleep: tokio::time::sleep(Duration::from_millis(250)),
            response: Some(Response::builder().body("Hello World!\n".into()).unwrap()),
        }
    }
}

And now our latency has SHOT UP, from this:

Shell session
$ oha etc.
(cut)
Latency distribution:
  10% in 0.0001 secs
  25% in 0.0001 secs
  50% in 0.0002 secs
  75% in 0.0015 secs
  90% in 0.0091 secs
  95% in 0.0095 secs
  99% in 0.0103 secs

To this:

Shell session
$ oha etc.
(cut)
Latency distribution:
  10% in 0.2514 secs
  25% in 0.2521 secs
  50% in 0.2522 secs
  75% in 0.2528 secs
  90% in 9.3291 secs
  95% in 9.8342 secs
  99% in 10.0861 secs

Wait, why are some requests taking 10 seconds?

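Concurrency limits! oha runs 50 workers, but we only accept 5 connections at a time, so most workers spend their time waiting in line for a slot. If we raise MAX_CONNS to 50, the p99 falls to 256 milliseconds.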
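A back-of-the-envelope check, assuming oha sent 200 requests (the 200 responses in the earlier run suggest that's its default) and each one takes 250 ms: with at most 5 in flight, getting through all of them takes 200 / 5 = 40 rounds, about 10 seconds, which lines up with that 10-second tail. Throughput tops out at 5 / 0.25 = 20 requests per second. With 50 in flight, it's 4 rounds, about a second. The function names are made up for illustration.

```rust
use std::time::Duration;

/// Upper bound on throughput when at most `max_in_flight` requests
/// run at once and each one takes `latency`.
fn max_requests_per_sec(max_in_flight: u32, latency: Duration) -> f64 {
    max_in_flight as f64 / latency.as_secs_f64()
}

/// Minimum time for `total` requests to get through, in whole rounds of
/// `max_in_flight` concurrent requests (rounded up).
fn drain_time(total: u32, max_in_flight: u32, latency: Duration) -> Duration {
    let rounds = (total + max_in_flight - 1) / max_in_flight;
    latency * rounds
}
```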

Now we have something that looks, more or less, like a real-world application. From the outside at least.

But before we move on, and more importantly, before the unsafe police rain on me, let's use pin-project-lite to get rid of those gnarly map_unchecked_mut and get_unchecked_mut calls:

Shell session
$ cargo add pin-project-lite
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding pin-project-lite v0.2.8 to dependencies
Rust code
pin_project_lite::pin_project! {
    struct PretendFuture {
        #[pin]
        sleep: Sleep,
        response: Option<Response<Body>>,
    }
}

impl Future for PretendFuture {
    type Output = Result<Response<Body>, Infallible>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        futures::ready!(this.sleep.poll(cx));
        Ok(this.response.take().unwrap()).into()
    }
}

There! Don't you feel better already? And remember: pinning is either structural for a field, or it's not. I think that says it all.

?? No it bloody doesn't?

Ah, would you rather have the long-form explanation?

...perhaps not.
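(For anyone who does want a peek anyway, here's a std-only sketch. Pretend is a hypothetical stand-in for PretendFuture: inner is structurally pinned, since we only ever touch it through a Pin, while response is not, since we take() it through a plain &mut. Countdown, which is pending for a few polls, stands in for tokio's Sleep, and block_on is a toy busy-polling executor, not something you'd ship.)

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Stand-in for `Sleep`: returns `Pending` `remaining` times, then `Ready`.
struct Countdown {
    remaining: u32,
}

impl Future for Countdown {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // `Countdown` is `Unpin`, so mutating through the `Pin` is fine.
        if self.remaining == 0 {
            return Poll::Ready(());
        }
        self.remaining -= 1;
        cx.waker().wake_by_ref(); // ask to be polled again
        Poll::Pending
    }
}

/// Same shape as `PretendFuture`: a pinned inner future plus an unpinned payload.
struct Pretend<F> {
    inner: F,
    response: Option<String>,
}

impl<F: Future<Output = ()>> Future for Pretend<F> {
    type Output = String;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<String> {
        // SAFETY: `inner` is never moved out of `self` (structural pinning),
        // and `Pretend` has no `Drop` or manual `Unpin` impl that could break that.
        let this = unsafe { self.get_unchecked_mut() };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        match inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            // `response` isn't structurally pinned: moving it out is fine.
            Poll::Ready(()) => Poll::Ready(this.response.take().expect("polled after completion")),
        }
    }
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Toy executor: polls in a loop until the future completes (busy-waits!).
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```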

But then, another question arises, from the endless well of questions that is me, late at night, in lieu of resting.

How would we limit the number of in-flight requests?

Well... couldn't the future hold a semaphore permit itself?

Ah, that might work! Let's try it.

Rust code
const MAX_CONNS: usize = 50;
const MAX_INFLIGHT_REQUESTS: usize = 5;

struct MyServiceFactory {
    conn_semaphore: PollSemaphore,
    reqs_semaphore: PollSemaphore,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            conn_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            reqs_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS))),
            permit: None,
        }
    }
}

Now to hand off the semaphore to every instance of MyService, by cloning it - which doesn't change the number of available permits, by the way, it just clones the inner Arc<Semaphore>, so it's all good.

Rust code
impl Service<&AddrStream> for MyServiceFactory {
    // (cut: everything except call)

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        // (cut: everything except this:)

        ready(Ok(MyService {
            _conn_permit: permit,
            // 👇
            semaphore: self.reqs_semaphore.clone(),
            // 👇 (for later)
            reqs_permit: None,
        }))
    }
}

Now MyService needs to hold its connection permit, the requests semaphore, and an optional reqs_permit, filled in from poll_ready:

Rust code
struct MyService {
    _conn_permit: OwnedSemaphorePermit,
    semaphore: PollSemaphore,
    reqs_permit: Option<OwnedSemaphorePermit>,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = PretendFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 👇
        if self.reqs_permit.is_none() {
            self.reqs_permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // 👇
        let permit = self.reqs_permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );

        println!("{} {}", req.method(), req.uri());
        PretendFuture {
            sleep: tokio::time::sleep(Duration::from_millis(250)),
            response: Some(Response::builder().body("Hello World!\n".into()).unwrap()),
            // 👇
            permit,
        }
    }
}

And the PretendFuture simply needs to hold onto the OwnedSemaphorePermit: again, its presence is proof enough. We can't build a PretendFuture without having an OwnedSemaphorePermit, and dropping the PretendFuture (which happens after it's polled to completion) also releases the permit.

Rust code
pin_project_lite::pin_project! {
    struct PretendFuture {
        #[pin]
        sleep: Sleep,
        response: Option<Response<Body>>,
        // 👇
        permit: OwnedSemaphorePermit,
    }
}

And now, try as you might, you're not gonna get more than... 20 requests per second out of this web server. Because only 5 requests can be in-flight at any given time, and each request takes about 1/4 of a second.

Shell session
$ oha http://127.0.0.1:1025
Summary:
  Success rate: 1.0000
  Total:        10.0610 secs
  Slowest:      2.5176 secs
  Fastest:      0.2519 secs
  Average:      2.2320 secs
  Requests/sec: 19.8788

Forget everything you just learned

So the exercise we just went through is neat, because it shows you how hyper and tower actually work: turning connections into "http services" is done through a service. And turning requests into responses is also done through a service.

And services need to be driven to readiness, before we can call them, which gives us a future, which we can then await, or spawn on a runtime, etc. - something we don't really have to worry about with the way we use hyper, but that we could definitely worry about if such was our wish.
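Here's that contract boiled down to std-only Rust. ReadyService is a hypothetical, trimmed-down copy of the Service shape (its call returns the response directly instead of a future), and Acceptor and Conn are made-up names: the acceptor reports Pending once max connections are alive, and each Conn gives its slot back when dropped. A caller like hyper's accept loop only calls call after poll_ready has returned Ready.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical, trimmed-down `Service` shape: callers must see `Poll::Ready`
/// from `poll_ready` before they're allowed to `call`.
trait ReadyService<Req> {
    type Response;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
    fn call(&mut self, req: Req) -> Self::Response;
}

/// Hands out at most `max` live connections at a time.
struct Acceptor {
    live: Arc<AtomicUsize>,
    max: usize,
}

/// One accepted connection; frees its slot when dropped.
struct Conn {
    live: Arc<AtomicUsize>,
}

impl Drop for Conn {
    fn drop(&mut self) {
        self.live.fetch_sub(1, Ordering::SeqCst);
    }
}

impl ReadyService<u32> for Acceptor {
    type Response = Conn;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> {
        if self.live.load(Ordering::SeqCst) < self.max {
            Poll::Ready(())
        } else {
            // A real implementation must also arrange for `cx.waker()` to be
            // woken once a slot frees up; this sketch leaves that out.
            Poll::Pending
        }
    }

    fn call(&mut self, _conn_id: u32) -> Conn {
        self.live.fetch_add(1, Ordering::SeqCst);
        Conn { live: self.live.clone() }
    }
}

/// A waker that does nothing, so we can build a `Context` by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn noop_waker() -> Waker {
    Waker::from(Arc::new(NoopWake))
}
```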

But now, tokei informs me that we have 122 lines of Rust in this project, which seems a tad excessive. In fact, everyone who's already stopped reading probably walked away thinking "dang this language is verbose"!

And who could blame em. But they're gone now, and they're not coming back. It's just us. Hey. How are you holding up? Yeah. I feel that.

So let's remove a bunch of code!

First off, this isn't 2018 anymore - we actually have async blocks and an await postfix keyword now.

We don't really need PretendFuture, we can just have an async block and box it!

Rust code
impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.reqs_permit.is_none() {
            self.reqs_permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let permit = self.reqs_permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );

        println!("{} {}", req.method(), req.uri());
        Box::pin(async move {
            let _permit = permit;
            tokio::time::sleep(Duration::from_millis(250)).await;
            Ok(Response::builder().body("Hello World!\n".into()).unwrap())
        })
    }
}

Boom. That async block gets turned into a generator, which implements Future (or so I'm told), and it captures ("closes over" if you want to get schmancy) anything it needs (in this case really just permit), which means it keeps "our place in the queue" (disclaimer: not actually a queue) until it's dropped.
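Here's a std-only way to watch that happen. InFlight is a hypothetical guard standing in for OwnedSemaphorePermit (it bumps a counter while alive), and handle moves it into an async block just like call does with permit. Even if the future is dropped without ever being polled, the guard goes with it. The `let _guard = guard;` line matters: an async block only captures the variables it mentions, so without it the guard would be dropped as soon as handle returns.

```rust
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for `OwnedSemaphorePermit`: counts itself while alive.
struct InFlight(Arc<AtomicUsize>);

impl InFlight {
    fn new(counter: &Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        InFlight(counter.clone())
    }
}

impl Drop for InFlight {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Same shape as `MyService::call`: grab the guard, then move it into the future.
fn handle(counter: &Arc<AtomicUsize>) -> impl Future<Output = &'static str> {
    let guard = InFlight::new(counter);
    async move {
        // Mentioning `guard` is what moves it into the future's state.
        let _guard = guard;
        "Hello World!\n"
    }
}
```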

Hey, we can even remove a dependency:

Shell session
$ cargo rm pin-project-lite
    Removing pin-project-lite from dependencies

There. How often have you seen me do that, huh?

...thirteen times.

Huh.

Now, we're this close to having clippy yell at us for type complexity with this line:

Rust code
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

So I'm gonna take a defensive stance and go with this instead:

Rust code
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

(That's futures::future::BoxFuture, an alias for Pin<Box<dyn Future<Output = T> + Send + 'a>>.) And look, the lifetime is explicit now! Did you know Box<dyn T> actually means Box<dyn T + 'static>? Well it does!
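To make that concrete, here's a std-only sketch. The BoxFuture alias mirrors how the futures crate defines it (note the + Send baked in), and respond, body_len and make_greeter are made-up examples. The first future owns its data, so it fits in BoxFuture<'static, _>; the second borrows, so its lifetime has to say so. make_greeter shows the plain Box<dyn Fn> case: no lifetime written means + 'static, so the closure has to own what it uses.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Mirrors `futures::future::BoxFuture`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Owns its data, so the future can be `'static`.
fn respond(body: String) -> BoxFuture<'static, String> {
    Box::pin(async move { body })
}

/// Borrows `body`, so the future can't outlive it, and the lifetime says so.
fn body_len<'a>(body: &'a str) -> BoxFuture<'a, usize> {
    Box::pin(async move { body.len() })
}

/// `Box<dyn Fn() -> String>` means `Box<dyn Fn() -> String + 'static>`,
/// which is why the closure has to `move` (own) `name`.
fn make_greeter(name: String) -> Box<dyn Fn() -> String> {
    Box::new(move || format!("Hello {name}!"))
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls once; enough here because these futures never wait on anything.
fn poll_once<T>(mut fut: BoxFuture<'_, T>) -> Option<T> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(v) => Some(v),
        Poll::Pending => None,
    }
}
```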

Our code does the exact same thing, but we're down to 100 lines.

Let's keep going!

Because the Service trait is so well thought-out (they even did a whole piece about it), there's a bunch of re-usable services out there!

For example, there's ConcurrencyLimit, which does... precisely what you think it does.

Here's one way to use it:

Shell session
$ cargo add tower --features limit
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tower v0.4.12 to dependencies with features: ["limit"]

We still want a shared semaphore:

Rust code
struct MyServiceFactory {
    conn_semaphore: PollSemaphore,
    // 👇 was: PollSemaphore
    reqs_semaphore: Arc<Semaphore>,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            conn_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            // 👇 that changed too
            reqs_semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS)),
            permit: None,
        }
    }
}

But now we return a ConcurrencyLimit<MyService>:

Rust code
impl Service<&AddrStream> for MyServiceFactory {
    // 👇 new! now 100% more nested
    type Response = ConcurrencyLimit<MyService>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // we're still limiting concurrency by hand for connections
        if self.permit.is_none() {
            self.permit = Some(futures::ready!(self.conn_semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        let permit = self.permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );

        println!(
            "⬆️ {} connections",
            MAX_CONNS - self.conn_semaphore.available_permits()
        );

        // 👇 the nesting occurs here
        ready(Ok(ConcurrencyLimit::with_semaphore(
            MyService {
                _conn_permit: permit,
            },
            self.reqs_semaphore.clone(),
        )))
    }
}

And our MyService service is now considerably simpler - it only concerns itself with the business logic: pretending to do work for a while, then quickly throwing something together:

Rust code
struct MyService {
    _conn_permit: OwnedSemaphorePermit,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(250)).await;
            Ok(Response::builder().body("Hello World!\n".into()).unwrap())
        })
    }
}

So, that's for limiting request concurrency. Can we limit connections the same way?

I fully expected to go there, but after giving it some thought, my professional opinion is: I don't think so. See, ConcurrencyLimit limits how many futures for a service can exist at any given time, but in the case of MyServiceFactory, the futures are very short-lived, and yield a service (which is used to handle requests coming through the newly-established connections, via some tasks that hyper spawns on the tokio runtime).

So, no dice. But hey, we're down to 94 lines already!

I really want to show you some other stuff now, so I will!

Let's set aside the connections concurrency limiting for now:

Rust code
const MAX_INFLIGHT_REQUESTS: usize = 5;

struct MyServiceFactory {
    reqs_semaphore: Arc<Semaphore>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            reqs_semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS)),
        }
    }
}

impl Service<&AddrStream> for MyServiceFactory {
    // 👇
    type Response = ConcurrencyLimit<MyService>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        ready(Ok(ConcurrencyLimit::with_semaphore(
            MyService,
            self.reqs_semaphore.clone(),
        )))
    }
}

Did you know we need none of this? tower provides convenience functions that take closures instead, and hyper ships its own versions of them (make_service_fn and service_fn) for extra convenience:

Rust code
use hyper::service::make_service_fn;

const MAX_INFLIGHT_REQUESTS: usize = 5;

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));
    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move { Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(MyService, sem)) }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

Boom! Just like that, we got rid of our entire MyServiceFactory. And the actual hyper hello world makes a lot more sense now.

We're down to 51 lines.

Yeah, but we lost connection limiting...

Shhhletskeepgoing.

Did you know we don't need MyService either?

We don't do anything interesting in its poll_ready anymore - this all could be just an async function.

And it can be!

Rust code
use std::{convert::Infallible, sync::Arc, time::Duration};

use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};
use tokio::sync::Semaphore;
use tower::limit::ConcurrencyLimit;

const MAX_INFLIGHT_REQUESTS: usize = 5;

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));
    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move {
            Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(
                service_fn(|req: Request<Body>| async move {
                    println!("{} {}", req.method(), req.uri());
                    tokio::time::sleep(Duration::from_millis(250)).await;
                    Ok::<_, Infallible>(
                        Response::builder()
                            .body(Body::from("Hello World!\n"))
                            .unwrap(),
                    )
                }),
                sem,
            ))
        }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

Down to 38 lines, still functionally equivalent.
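If you're wondering how a closure becomes a service: the trick is a wrapper struct plus a blanket impl. Here's a std-only sketch of the idea, with a hypothetical synchronous SyncService trait (no readiness, no futures) so it runs without any crates. hyper's and tower's real service_fn apply the same trick to closures that return futures.

```rust
/// Hypothetical synchronous stand-in for `Service` (no readiness, no futures).
trait SyncService<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

/// Wraps a closure; the blanket impl below is what makes it a service.
struct ServiceFn<F>(F);

/// Hypothetical counterpart of `service_fn`.
fn service_fn<F>(f: F) -> ServiceFn<F> {
    ServiceFn(f)
}

impl<F, Req, Res> SyncService<Req> for ServiceFn<F>
where
    F: FnMut(Req) -> Res,
{
    type Response = Res;

    fn call(&mut self, req: Req) -> Res {
        (self.0)(req)
    }
}

/// Generic code only sees a `SyncService`, never the closure itself.
fn handle_all<S: SyncService<String>>(svc: &mut S, reqs: Vec<String>) -> Vec<S::Response> {
    reqs.into_iter().map(|req| svc.call(req)).collect()
}
```

Because call takes &mut self, the closure can be FnMut and keep state between requests.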

Now let's stop obsessing over silly metrics like lines of code, and focus on readability instead. I vote we move the service back out of this tower and into its own function:

Rust code
#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));
    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move {
            Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(
                service_fn(hello_world),
                sem,
            ))
        }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    println!("{} {}", req.method(), req.uri());
    tokio::time::sleep(Duration::from_millis(250)).await;
    Ok(Response::builder()
        .body(Body::from("Hello World!\n"))
        .unwrap())
}

And, for readability, let's introduce ServiceBuilder. On top of the Service trait, tower provides the Layer trait, which lets you "decorate" a service:

Its definition is short and elegant:

Rust code
pub trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}
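Here's a deliberately simplified, synchronous sketch of what "decorating" means, without any of the async machinery. Only the `Layer` trait matches the definition above. The `Service` trait here is a hypothetical stand-in: tower's real one is async and has `poll_ready`. `Upper`, `LogLayer` and `Logged` are made-up names for illustration.

```rust
// A hypothetical, synchronous stand-in for tower's `Service` trait.
trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

// Same shape as tower's `Layer` trait shown above.
trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// The innermost service: uppercases whatever it's given.
struct Upper;

impl Service<String> for Upper {
    type Response = String;
    fn call(&mut self, req: String) -> String {
        req.to_uppercase()
    }
}

// A layer that wraps any service and records every request it sees.
struct LogLayer;

struct Logged<S> {
    inner: S,
    log: Vec<String>,
}

impl<S> Layer<S> for LogLayer {
    type Service = Logged<S>;
    fn layer(&self, inner: S) -> Logged<S> {
        Logged { inner, log: Vec::new() }
    }
}

impl<S> Service<String> for Logged<S>
where
    S: Service<String>,
{
    type Response = S::Response;
    fn call(&mut self, req: String) -> S::Response {
        // The "decoration": do a bit of extra work, then delegate to the inner service.
        self.log.push(req.clone());
        self.inner.call(req)
    }
}

fn main() {
    let mut svc = LogLayer.layer(Upper);
    let res = svc.call("hello".to_string());
    println!("{res} (log: {:?})", svc.log);
}
```

The layer never needs to know what it's wrapping. It takes some `S` and returns a new service that does its own thing before handing off to `S`.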

And ServiceBuilder helps us compose layers together, like so:

Rust code
// in main
    let app = make_service_fn(move |_stream: &AddrStream| async move {
        let svc = ServiceBuilder::new()
            .layer(ConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS))
            .service(service_fn(hello_world));
        Ok::<_, Infallible>(svc)
    });

Wait, where did our semaphore go? What is that actually limiting?

Good catch, bear: we actually changed the semantics here. The limit we're adding is per service, and make_service_fn builds a new service for every connection. So it's a per-connection limit, not a global one like we had before.

But before we fix that, I just wanted to show this even shorter version:

TOML markup
# in `Cargo.toml`
# (cut: everything but the tower dependency)

[dependencies]
tower = { version = "0.4.12", features = ["limit", "util"] } # 👈 util is new
Rust code
    let app = make_service_fn(move |_stream: &AddrStream| async move {
        let svc = ServiceBuilder::new()
            .concurrency_limit(MAX_INFLIGHT_REQUESTS)
            .service_fn(hello_world);
        Ok::<_, Infallible>(svc)
    });

Isn't that pretty?

Pretty... but incorrect 🧐

...and also only 29 lines. Are you sure it isn't easier to convince product that it'll benefit users?

...

Fine, fine, we'll fix it. Luckily, GlobalConcurrencyLimitLayer has us covered:

Rust code
    let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);
    let app = make_service_fn(move |_stream: &AddrStream| {
        let reqs_limit = reqs_limit.clone();
        async move {
            let svc = ServiceBuilder::new()
                .layer(reqs_limit)
                .service_fn(hello_world);
            Ok::<_, Infallible>(svc)
        }
    });

Theeere.

Oh, by the way, we're not doing anything asynchronous here. So we don't really need an async move block, and we don't need that dance where we clone before the async move block because the resulting future must be 'static. We've already used std::future::ready in this article (not to be confused with the futures::ready! macro), so we can totally do this:

Rust code
    let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);
    let app = make_service_fn(move |_stream: &AddrStream| {
        std::future::ready(Ok::<_, Infallible>(
            ServiceBuilder::new()
                .layer(reqs_limit.clone())
                .service_fn(hello_world),
        ))
    });
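`std::future::ready` gets away without any of that because it produces a future that is already complete: the very first poll returns `Poll::Ready`. Here's a std-only sketch that polls one by hand. `noop_waker` and `poll_once` are hypothetical helpers written for this example. Recent Rust versions also ship `std::task::Waker::noop`, but building a waker from a `RawWaker` works on older toolchains too.

```rust
use std::future::{ready, Future};
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// A waker that does nothing: good enough to poll a future by hand once.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // Safety: none of the vtable functions touch the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Poll a future exactly once and return whatever it said.
fn poll_once<F: Future + Unpin>(mut fut: F) -> Poll<F::Output> {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    Pin::new(&mut fut).poll(&mut cx)
}

fn main() {
    // No async block, nothing to clone into it, nothing to await:
    // the value is there on the very first poll.
    let res = poll_once(ready(Ok::<u32, std::convert::Infallible>(42)));
    println!("{res:?}");
}
```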

But now... now I'm looking at this and I'm thinking... how hard could it be to bring back connections concurrency?

I have a terrible, terrible idea.

See, ServiceBuilder has a then method, which lets you run a function on whatever the service returns, once it's done.

Let's give it a try:

Rust code
    let app = make_service_fn(move |_stream: &AddrStream| {
        std::future::ready(Ok::<_, Infallible>(
            ServiceBuilder::new()
                .layer(reqs_limit.clone())
                .then(|res: Result<Response<Body>, Infallible>| async move {
                    println!("Just served a request!");
                    res
                })
                .service_fn(hello_world),
        ))
    });
Shell session
$ cargo run --release
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished release [optimized] target(s) in 9.11s
     Running `target/release/nostalgia`
GET /
Just served a request!
GET /
Just served a request!
GET /
Just served a request!

# (I'm running curl in another pane, the server isn't haunted)

In that case... I'm much less interested in running something after the service, and much more interested in the state the closure can capture...

...say maybe it could capture a permit 😈
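Here's the idea in isolation, as a std-only sketch. `Permit` is a made-up stand-in for tokio's semaphore permit that only counts how many are alive. The closure plays the part of the `then` callback.

The permit sits behind an `Arc`, so `drop(permit)` inside a clone of the closure only drops that clone's reference. The permit itself goes away when the last reference does, which is when the original closure (and whatever owns it) is dropped. That's presumably why the real code wraps the permit in an `Arc`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical stand-in for a semaphore permit: it only counts how many are alive.
static LIVE_PERMITS: AtomicUsize = AtomicUsize::new(0);

struct Permit;

impl Permit {
    fn acquire() -> Permit {
        LIVE_PERMITS.fetch_add(1, Ordering::SeqCst);
        Permit
    }
}

impl Drop for Permit {
    fn drop(&mut self) {
        LIVE_PERMITS.fetch_sub(1, Ordering::SeqCst);
    }
}

fn live_permits() -> usize {
    LIVE_PERMITS.load(Ordering::SeqCst)
}

fn main() {
    let permit = Arc::new(Permit::acquire());
    // Plays the part of the `then` callback: captures the Arc, "drops" it when called.
    let on_response = move |res: u32| {
        drop(permit);
        res
    };

    // Each "request" runs its own clone of the closure, holding its own Arc clone.
    for req in 0..3 {
        let callback = on_response.clone();
        assert_eq!(callback(req), req);
        // Only that clone's reference was dropped: the permit is still held.
        assert_eq!(live_permits(), 1);
    }

    // Dropping the original closure (think: the connection's service) releases it.
    drop(on_response);
    assert_eq!(live_permits(), 0);
    println!("permit released once the last Arc reference went away");
}
```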

Rust code
    let conns_limit = Arc::new(Semaphore::new(MAX_CONNS));
    let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);
    let app = make_service_fn(move |_stream: &AddrStream| {
        let conns_limit = conns_limit.clone();
        let reqs_limit = reqs_limit.clone();
        async move {
            let permit = Arc::new(conns_limit.acquire_owned().await.unwrap());
            Ok::<_, Infallible>(
                ServiceBuilder::new()
                    .layer(reqs_limit)
                    .then(move |res: Result<Response<Body>, Infallible>| {
                        drop(permit);
                        std::future::ready(res)
                    })
                    .service_fn(hello_world),
            )
        }
    });

This is definitely a crime, and not what the innocent authors of ThenLayer had in mind, but hey, if it's stupid and it works then it must be 4AM:

Shell session
$ oha etc.
(cut)

Latency distribution:
  10% in 0.2517 secs
  25% in 0.2521 secs
  50% in 0.2524 secs
  75% in 0.2532 secs
  90% in 9.3343 secs
  95% in 9.8390 secs
  99% in 10.0909 secs

Now let's try raising that MAX_CONNS value to 50:

Shell session
$ oha etc.
(cut)

Latency distribution:
  10% in 1.2601 secs
  25% in 2.5188 secs
  50% in 2.5198 secs
  75% in 2.5203 secs
  90% in 2.5209 secs
  95% in 2.5210 secs
  99% in 2.5215 secs

Yup, that checks ou- wait a minute. Those results look really different from before... Mhh, did we have a max in-flight requests limit before? Let's try removing that:

Rust code
use std::{convert::Infallible, sync::Arc, time::Duration};

use hyper::{server::conn::AddrStream, service::make_service_fn, Body, Request, Response, Server};
use tokio::sync::Semaphore;
use tower::ServiceBuilder;

const MAX_CONNS: usize = 50;

#[tokio::main]
async fn main() {
    let conns_limit = Arc::new(Semaphore::new(MAX_CONNS));
    let app = make_service_fn(move |_stream: &AddrStream| {
        let conns_limit = conns_limit.clone();
        async move {
            let permit = Arc::new(conns_limit.acquire_owned().await.unwrap());
            Ok::<_, Infallible>(
                ServiceBuilder::new()
                    .then(move |res: Result<Response<Body>, Infallible>| {
                        drop(permit);
                        std::future::ready(res)
                    })
                    .service_fn(hello_world),
            )
        }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    println!("{} {}", req.method(), req.uri());
    tokio::time::sleep(Duration::from_millis(250)).await;
    Ok(Response::builder()
        .body(Body::from("Hello World!\n"))
        .unwrap())
}
Shell session
Latency distribution:
  10% in 0.2517 secs
  25% in 0.2519 secs
  50% in 0.2521 secs
  75% in 0.2524 secs
  90% in 0.2531 secs
  95% in 0.2543 secs
  99% in 0.2550 secs

Okay, that makes more sense. See, if we have a connections limit of 50 but a requests limit of 5, oha thinks it can issue 50 requests at a time (it's over HTTP/1, there's no multiplexing here), but only the first 5 get serviced immediately. The others wait their turn.

~ backpressure ~
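Here's a back-of-the-envelope check. It's my own, based on Little's law, not something oha reports.

- With 5 requests in flight at a time and about 250ms per request, the server completes roughly 5 / 0.25 = 20 requests per second.
- If 50 connections each keep one request outstanding, every request spends about 50 / 20 = 2.5 seconds in the system. That lines up with the ~2.52s latencies from the earlier run with both limits in place.
- With the requests limit gone, the same formula gives 0.25s.

```rust
/// Little's law: L = λ × W (requests in the system = throughput × time in system).
/// Solved for W: how long each request spends queued plus being served.
fn steady_state_latency_secs(outstanding: f64, max_inflight: f64, service_secs: f64) -> f64 {
    // When saturated, the server finishes `max_inflight` requests every `service_secs`.
    let throughput = max_inflight / service_secs;
    outstanding / throughput
}

fn main() {
    // 50 connections with one request each (HTTP/1), 5 in-flight slots, 250ms per request.
    let with_limit = steady_state_latency_secs(50.0, 5.0, 0.25);
    // Same load, but every connection gets served right away.
    let without_limit = steady_state_latency_secs(50.0, 50.0, 0.25);
    // prints "with limit: 2.50s, without: 0.25s"
    println!("with limit: {with_limit:.2}s, without: {without_limit:.2}s");
}
```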

And now for the grand finale: let's remove the conns limit too. And the sleep.

Rust code
use std::convert::Infallible;

use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};

#[tokio::main]
async fn main() {
    let app = make_service_fn(move |_stream: &AddrStream| async move {
        Ok::<_, Infallible>(service_fn(hello_world))
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    println!("{} {}", req.method(), req.uri());
    Ok(Response::builder()
        .body(Body::from("Hello World!\n"))
        .unwrap())
}

And now we have a hyper hello world. And we know exactly what's going on.

Well... there's one piece we haven't really talked about yet. And that's the piece I originally wanted to talk about.

Accepting connections

We've been doing this all along:

Rust code
    Server::bind(&([127, 0, 0, 1], 1025).into())

Without really questioning it.

Well, it's not the only way! We can build a TCP listener ourselves:

Rust code
use hyper::server::conn::AddrIncoming;
use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
    let app = make_service_fn(move |_stream: &AddrStream| async move {
        Ok::<_, Infallible>(service_fn(hello_world))
    });

    // 👇
    let ln = TcpListener::bind("127.0.0.1:1025").await.unwrap();
    Server::builder(AddrIncoming::from_listener(ln).unwrap())
        .serve(app)
        .await
        .unwrap();
}

That works just as well.

But, you see, I found myself wanting to do something weird...

I have this test suite with hundreds of tests, and most of them start listening on some port. Because they all run concurrently, on a machine that's busy doing other things, I can't exactly specify the port every time. Instead, I just pass in port 0, and let the operating system pick a free port for me, like so:

Rust code
    let ln = TcpListener::bind("127.0.0.1:0").await.unwrap();
    println!("Listening on {}", ln.local_addr().unwrap());
Shell session
$ cargo run --release
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished release [optimized] target(s) in 1.28s
     Running `target/release/nostalgia`
Listening on 127.0.0.1:35137
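The same port-0 trick works with the standard library's blocking listener. Here's a minimal std-only sketch, where `bind_any_port` is a hypothetical helper. Bind to port 0, then ask `local_addr` which port the OS actually picked. Two listeners alive at the same time always get distinct ports, which is exactly what a pile of concurrent tests needs.

```rust
use std::net::{SocketAddr, TcpListener};

// Bind to port 0 and let the OS pick a free port, then report which one it chose.
fn bind_any_port() -> std::io::Result<(TcpListener, SocketAddr)> {
    let ln = TcpListener::bind("127.0.0.1:0")?;
    let addr = ln.local_addr()?;
    Ok((ln, addr))
}

fn main() -> std::io::Result<()> {
    // Keep both listeners alive so neither port gets released.
    let (_ln1, a) = bind_any_port()?;
    let (_ln2, b) = bind_any_port()?;
    println!("got {a} and {b}, no coordination needed");
    Ok(())
}
```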

But in some of the tests, we want to simulate a service that's... up, but not quite up. Like, it has reserved a port (so another service can't bind to it, not without some naughty socket flags we're not using here), but it hasn't started listening for connections yet. So any connection attempts will be met with a RST packet, which is TCP for "we see what you're going for but nah". On the client side, we'd see something like "connection refused".

(As opposed to just dropping the packet, which is TCP for "miss me with that shit")
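The client-side view of that RST is easy to reproduce with std alone. In this sketch, `connect_to_closed_port` is a hypothetical helper. We let the OS pick a free port, immediately close the listener, then try to connect to that port.

On Linux I'd expect the error kind to be `ConnectionRefused`. The sketch assumes nothing else grabs the port in the meantime, which is unlikely but not impossible.

```rust
use std::io::ErrorKind;
use std::net::{TcpListener, TcpStream};

// Returns the error kind we get when connecting to a port nobody listens on
// (or `None` if, against the odds, the connection succeeded).
fn connect_to_closed_port() -> std::io::Result<Option<ErrorKind>> {
    // The temporary listener is dropped (and its socket closed) at the end of this statement.
    let addr = TcpListener::bind("127.0.0.1:0")?.local_addr()?;
    // Nobody is listening there anymore, so the kernel answers our SYN with a RST.
    Ok(TcpStream::connect(addr).err().map(|e| e.kind()))
}

fn main() -> std::io::Result<()> {
    println!("{:?}", connect_to_closed_port()?);
    Ok(())
}
```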

And that's a super duper basic endeavor with the BSD socket API (what essentially everybody uses - even winsock2), but Rust APIs like libstd and tokio's TcpListener helpfully group together the bind and listen calls, which means we... can't do what we want.

Well, until we remember that, as always, there is a crate for that:

Shell session
$ cargo add socket2
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding socket2 v0.4.4 to dependencies
Rust code
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{SocketAddr, TcpListener};

#[tokio::main]
async fn main() {
    let addr = SocketAddr::from(([127, 0, 0, 1], 0));
    let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP)).unwrap();
    socket.bind(&addr.into()).unwrap();

    let addr = socket.local_addr().unwrap().as_socket().unwrap();
    println!("Bound but not listening on {}", addr);
    assert!(TcpListener::bind(addr).is_err());
    println!("As expected, nobody else can listen on the same address");
    println!("Try curling it, it'll fail (press Enter when done)");
    std::io::stdin().read_line(&mut String::new()).unwrap();

    socket.listen(128).unwrap();
    println!("Okay now we're listening (try curling it now, it should hang)");
    std::io::stdin().read_line(&mut String::new()).unwrap();
}
Shell session
$ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.02s
     Running `target/debug/nostalgia`
Bound but not listening on 127.0.0.1:44277
As expected, nobody else can listen on the same address
Try curling it, it'll fail (press Enter when done)

# (in another pane)
$ curl http://127.0.0.1:44277
curl: (7) Failed to connect to 127.0.0.1 port 44277 after 0 ms: Connection refused

# (back to the server's pane, after pressing enter)
Okay now we're listening (try curling it now, it should hang)

# (in another pane)
$ curl http://127.0.0.1:44277

(it does hang)

At that point, the kernel's TCP stack has accepted the connection on our behalf (by which I mean it's completed the TCP three-way handshake), and put it in the accept queue, waiting for our call to accept (which would pop it from the queue and let us play with it).
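You can watch the accept queue do its thing with std, too. `std::net::TcpListener::bind` performs both the bind and listen steps, so a client's `connect` (and even a small write) succeeds before the server ever calls `accept`. Here's a std-only sketch, with `handshake_before_accept` as a hypothetical helper:

```rust
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};

// Connect (and write) before the server accepts, then accept and read afterwards.
fn handshake_before_accept() -> std::io::Result<Vec<u8>> {
    // std's `bind` already does both bind(2) and listen(2).
    let ln = TcpListener::bind("127.0.0.1:0")?;
    let addr = ln.local_addr()?;

    // The kernel completes the three-way handshake on our behalf,
    // so this returns even though nobody has called `accept` yet.
    let mut client = TcpStream::connect(addr)?;
    client.write_all(b"hello")?;
    // Closing the client lets the server side see EOF after the data.
    drop(client);

    // Now pop the already-established connection off the accept queue.
    let (mut server_side, _) = ln.accept()?;
    let mut buf = Vec::new();
    server_side.read_to_end(&mut buf)?;
    Ok(buf)
}

fn main() -> std::io::Result<()> {
    let data = handshake_before_accept()?;
    println!("read {:?} from a connection accepted late", String::from_utf8_lossy(&data));
    Ok(())
}
```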

So, that's exactly the behavior we're looking for: bind, wait a bit, then listen. And eventually, start accepting connections.

Once we're done building our socket2::Socket, we can turn it into a std::net::TcpListener, then a tokio::net::TcpListener quite easily:

Rust code
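    // `as_raw_fd` and `from_raw_fd` come from these traits
    use std::os::unix::io::{AsRawFd, FromRawFd};

    println!("Okay now let's accept one connection");
    // tokio expects the std listener to already be in non-blocking mode
    socket.set_nonblocking(true).unwrap();
    let fd = socket.as_raw_fd();
    // don't let `socket` close the fd on drop: the std listener owns it now
    std::mem::forget(socket);
    let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
    let ln = tokio::net::TcpListener::from_std(ln).unwrap();
    let (_stream, _) = ln.accept().await.unwrap();
    println!("Accepted one conn!");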

From the client side, we now see a connection reset:

Shell session
$ curl 127.0.0.1:38145
curl: (56) Recv failure: Connection reset by peer

tokio::net::TcpListener can then be turned into an AddrIncoming, which we can pass to Server::builder... simple, right?

Well... not that simple. Because by the time we convert it to a tokio::net::TcpListener, we have to have called listen already.

That's why from_raw_fd is unsafe: not only does it need to be a valid, open file descriptor, that is not going to get closed by another thread right after, but it also, in this case, needs to be listening already.

Of course, we could wait until we create the server itself, and I don't know why this occurs to me JUST NOW, but hey, we got this article out of it so let's not question my motives too much.

So... what I'm getting at is this. If we want that behavior (bind, wait a couple seconds, then listen) with a hyper server, and we haven't realized we could simply wait to build the server, then we can't use AddrIncoming: we need to implement the Accept trait ourselves.

It looks like this:

Rust code
pub trait Accept {
    type Conn;
    type Error;
    fn poll_accept(
        self: Pin<&mut Self>, 
        cx: &mut Context<'_>
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>>;
}
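Here's a std-only sketch of that contract outside of hyper:

- `Accept` is a local copy of the trait above, so there's no hyper dependency.
- `QueueAcceptor` is a made-up acceptor whose "connections" are just strings.
- `drain` is a very dumb busy-polling executor.

`Poll::Pending` means "not yet, I'll wake you", which is why the toy calls `wake_by_ref` before returning it. `Poll::Ready(None)` means there will never be another connection.

```rust
use std::collections::VecDeque;
use std::convert::Infallible;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Local copy of the shape of hyper's `Accept` trait, so this runs without hyper.
trait Accept {
    type Conn;
    type Error;
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>>;
}

// Hypothetical acceptor: "connections" are strings, and every other poll is Pending.
struct QueueAcceptor {
    queue: VecDeque<String>,
    ready: bool,
}

impl Accept for QueueAcceptor {
    type Conn = String;
    type Error = Infallible;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<String, Infallible>>> {
        if !self.ready {
            // Not ready yet: ask to be polled again, *then* return Pending.
            self.ready = true;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        self.ready = false;
        // An empty queue yields `None`: "no more connections, ever".
        Poll::Ready(self.queue.pop_front().map(Ok))
    }
}

// A waker that does nothing; fine here because `drain` polls in a loop anyway.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // Safety: none of the vtable functions touch the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Poll an acceptor until it reports `None`, counting how often it said Pending.
fn drain<A: Accept + Unpin>(mut acc: A) -> (Vec<Result<A::Conn, A::Error>>, usize) {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut conns = Vec::new();
    let mut pendings = 0;
    loop {
        match Pin::new(&mut acc).poll_accept(&mut cx) {
            Poll::Pending => pendings += 1,
            Poll::Ready(Some(conn)) => conns.push(conn),
            Poll::Ready(None) => return (conns, pendings),
        }
    }
}

fn main() {
    let acc = QueueAcceptor {
        queue: VecDeque::from(vec!["conn-1".to_string(), "conn-2".to_string()]),
        ready: false,
    };
    let (conns, pendings) = drain(acc);
    println!("accepted {} connections, saw {} Pending polls", conns.len(), pendings);
}
```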

So, okay, let's do it without the sleep, as a warm-up:

Shell session
$ cargo add color-eyre
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding color-eyre v0.6.1 to dependencies
Rust code
use color_eyre::Report;
use hyper::{
    server::accept::Accept,
    service::{make_service_fn, service_fn},
    Body, Request, Response,
};
use socket2::{Domain, Protocol, Socket, Type};
use std::{
    convert::Infallible,
    net::SocketAddr,
    os::unix::prelude::{AsRawFd, FromRawFd},
    pin::Pin,
    task::Context,
    time::Duration,
};
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> Result<(), Report> {
    let addr = SocketAddr::from(([127, 0, 0, 1], 0));
    let acc = Acceptor::new(addr)?;

    hyper::Server::builder(acc)
        .serve(make_service_fn(|_: &TcpStream| async move {
            Ok::<_, Report>(service_fn(hello_world))
        }))
        .await?;
    Ok(())
}

async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    println!("{} {}", req.method(), req.uri());
    tokio::time::sleep(Duration::from_millis(250)).await;
    Ok(Response::builder()
        .body(Body::from("Hello World!\n"))
        .unwrap())
}

struct Acceptor {
    ln: tokio::net::TcpListener,
}

impl Acceptor {
    fn new(addr: SocketAddr) -> Result<Self, Report> {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;
        println!(
            "Listening on {}...",
            socket.local_addr()?.as_socket().unwrap()
        );
        socket.listen(128)?;
        socket.set_nonblocking(true)?;
        let fd = socket.as_raw_fd();
        std::mem::forget(socket);
        let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
        let ln = tokio::net::TcpListener::from_std(ln)?;

        Ok(Self { ln })
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
        let (stream, _) = futures::ready!(self.ln.poll_accept(cx)?);
        Some(Ok(stream)).into()
    }
}
Shell session
$ cargo run
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished dev [unoptimized + debuginfo] target(s) in 1.35s
     Running `target/debug/nostalgia`
Binding...
Listening on 127.0.0.1:44063...
GET /

# (in another pane)
$ curl 0:44063
Hello World!

And now... let's add the sleep! Well, just as before, we'll need a Sleep future, and for the rest, well... we'll need to hold onto the Socket until we listen and turn it into a tokio::net::TcpListener. So we'll need an enum to know where we're at.

As for the rest, uh... read slowly.

Rust code
enum Acceptor {
    Waiting { sleep: Sleep, socket: Socket },
    Listening { ln: tokio::net::TcpListener },
}

impl Acceptor {
    fn new(addr: SocketAddr) -> Result<Self, Report> {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;

        Ok(Self::Waiting {
            sleep: tokio::time::sleep(Duration::from_secs(2)),
            socket,
        })
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // Safety: we do our own pin-projection
        match unsafe { self.as_mut().get_unchecked_mut() } {
            Acceptor::Waiting { sleep, socket } => {
                // Safety: `sleep` is structurally pinned
                let sleep = unsafe { Pin::new_unchecked(sleep) };
                futures::ready!(sleep.poll(cx));

                println!(
                    "Listening on {}...",
                    socket.local_addr()?.as_socket().unwrap()
                );
                socket.listen(128)?;
                socket.set_nonblocking(true)?;
                let fd = socket.as_raw_fd();
                // Safety: `fd` comes from a well-initialized and listening `socket2::Socket`
                let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
                let ln = tokio::net::TcpListener::from_std(ln)?;
                let mut state = Self::Listening { ln };

                // Safety: we never use `sleep` anymore, and `socket` is `Unpin`
                std::mem::swap(unsafe { self.as_mut().get_unchecked_mut() }, &mut state);
                match state {
                    Acceptor::Waiting { socket, .. } => {
                        // necessary to avoid closing the socket on drop
                        std::mem::forget(socket)
                    }
                    _ => unreachable!(),
                };

                match unsafe { self.get_unchecked_mut() } {
                    Acceptor::Listening { ln } => {
                        let (stream, _) = futures::ready!(ln.poll_accept(cx)?);
                        Some(Ok(stream)).into()
                    }
                    _ => unreachable!(),
                }
            }
            Acceptor::Listening { ln } => {
                let (stream, _) = futures::ready!(ln.poll_accept(cx)?);
                Some(Ok(stream)).into()
            }
        }
    }
}

This works just fine.

The thing is... I just... I just don't like it. I'm so tired of writing poll-style functions in Rust. So tired. Sure, I feel smart. This is like Sudoku for extra-nerds. Yey, let's be fun at parties together.

But like... where's my async? Where's my await? I don't want to think about pinning if I don't have to.

So let's try to uhh simplify this.

Here's where my mind went first: how about we have a Listener struct with an async method that does everything we need it to?

Rust code
enum Listener {
    Waiting { socket: Socket },
    Listening { ln: tokio::net::TcpListener },
}

impl Listener {
    fn new(addr: SocketAddr) -> Result<Self, Report> {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;
        Ok(Self::Waiting { socket })
    }

    async fn accept(&mut self) -> Result<TcpStream, Report> {
        match self {
            Listener::Waiting { socket } => {
                tokio::time::sleep(Duration::from_secs(2)).await;

                println!(
                    "Listening on {}...",
                    socket.local_addr()?.as_socket().unwrap()
                );
                socket.listen(128)?;
                socket.set_nonblocking(true)?;
                let fd = socket.as_raw_fd();
                // Safety: `fd` comes from a well-initialized and listening `socket2::Socket`
                let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
                let ln = tokio::net::TcpListener::from_std(ln)?;
                let mut state = Self::Listening { ln };

                std::mem::swap(self, &mut state);
                match state {
                    Listener::Waiting { socket } => {
                        // necessary to avoid closing the socket on drop
                        std::mem::forget(socket)
                    }
                    _ => unreachable!(),
                };

                match self {
                    Listener::Listening { ln } => Ok(ln.accept().await?.0),
                    _ => unreachable!(),
                }
            }
            Listener::Listening { ln } => Ok(ln.accept().await?.0),
        }
    }
}

Much clearer! And then all we need to do is... adapt that into the Accept interface, right?

Which is as easy as, uh, mhhh.

Well, you'd want to try something like this, right?

Rust code
struct Acceptor {
    listener: Listener,
    fut: BoxFuture<'static, Result<TcpStream, Report>>,
}

impl Acceptor {
    fn from_listener(listener: Listener) -> Self {
        let fut = listener.accept();
        Self {
            listener,
            fut: Box::pin(fut),
        }
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // project!
        let (listener, fut) = unsafe {
            let this = self.get_unchecked_mut();
            (&mut this.listener, Pin::new_unchecked(&mut this.fut))
        };

        let res = futures::ready!(fut.poll(cx));
        self.get_mut().fut = Box::pin(listener.accept());
        Some(res).into()
    }
}

We can't name the type of the future returned by Listener::accept, so we have to box it. That means we have to pick a lifetime and uhh...

Shell session
$ cargo check
error[E0759]: `self` has an anonymous lifetime `'_` but it needs to satisfy a `'static` lifetime requirement
   --> src/main.rs:114:29
    |
109 |         mut self: Pin<&mut Self>,
    |                   -------------- this data with an anonymous lifetime `'_`...
...
114 |             let this = self.get_unchecked_mut();
    |                             ^^^^^^^^^^^^^^^^^ ...is used here...
...
119 |         self.get_mut().fut = Box::pin(listener.accept());
    |                              --------------------------- ...and is required to live as long as `'static` here

Yeah. It ain't static, that's for sure. But I mean, what else can we do? What we have here is a self-referential struct: fut holds a mutable reference to listener - so one field borrows the other.

I even made a whole video about it!

Which is fine, afaik, as long as we "manually match up their lifetimes": if we were to drop listener and then use fut, things would go very wrong.

But like... if we're so convinced it's fine, we can straight up lie to the compiler's face. WHICH YOU SHOULD NEVER DO, but let this be a lesson:

Rust code
struct Acceptor {
    listener: Listener,
    fut: BoxFuture<'static, Result<TcpStream, Report>>,
}

impl Acceptor {
    fn from_listener(mut listener: Listener) -> Self {
        let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>;
        // Safety: transmuting to the static lifetime 💀
        let fut = unsafe { std::mem::transmute(fut) };
        Self { listener, fut }
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // project!
        let (listener, fut) = unsafe {
            let this = self.as_mut().get_unchecked_mut();
            (&mut this.listener, Pin::new_unchecked(&mut this.fut))
        };

        let res = futures::ready!(fut.poll(cx));
        // Safety: transmuting to the static lifetime 💀
        let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>;
        self.get_mut().fut = unsafe { std::mem::transmute(fut) };
        Some(res).into()
    }
}

Does this work?

Shell session
$ cargo run
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished dev [unoptimized + debuginfo] target(s) in 1.35s
     Running `target/debug/nostalgia`
Binding...
Error: error accepting connection: Socket operation on non-socket (os error 88)

Caused by:
    Socket operation on non-socket (os error 88)

Location:
    src/main.rs:25:5

No! It doesn't! Whoops, we broke an invariant. We let a future borrow mutably from Listener, and then we moved the Listener out from under it.

We really stepped in it this time.

Of course I know how to fix it, because I ran into three other odd issues (memory corruption, yay!) while trying this particular crime.

You know one way to make the listener stay in place? Shove it on the heap!

Rust code
struct Acceptor {
    // 👇
    listener: Box<Listener>,
    fut: BoxFuture<'static, Result<TcpStream, Report>>,
}

impl Acceptor {
    fn from_listener(listener: Listener) -> Self {
        // 👇
        let mut listener = Box::new(listener);
        let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>;
        // Safety: transmuting to the static lifetime 💀
        let fut: BoxFuture<'static, _> = unsafe { std::mem::transmute(fut) };

        // Safety: we're moving a _pointer_ to `Listener`, the listener itself
        // is not moved, it stays pinned.
        Self { listener, fut }
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // project!
        let (listener, fut) = unsafe {
            let this = self.as_mut().get_unchecked_mut();
            (&mut this.listener, Pin::new_unchecked(&mut this.fut))
        };

        let res = futures::ready!(fut.poll(cx));
        // Safety: transmuting to the static lifetime 💀
        let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>;
        self.get_mut().fut = unsafe { std::mem::transmute(fut) };
        Some(res).into()
    }
}

And now it works! But at what cost? At what cost?

We can do better than this. Much better.

Just move stuff. That's it. That's the whole tweet.

It is trivial to avoid self-referential structs in this case.

We can simply... change our function signature to this:

Rust code
impl Listener {
    fn new(addr: SocketAddr) -> Result<Self, Report> {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;
        Ok(Self::Waiting { socket })
    }

    //    now taking 👇 ownership
    async fn accept(mut self) -> Result<(Self, TcpStream), Report> {
        match self {
            Listener::Waiting { socket } => {
                tokio::time::sleep(Duration::from_secs(2)).await;

                println!(
                    "Listening on {}...",
                    socket.local_addr()?.as_socket().unwrap()
                );
                socket.listen(128)?;
                socket.set_nonblocking(true)?;
                let fd = socket.as_raw_fd();
                // Safety: `fd` comes from a well-initialized and listening `socket2::Socket`
                let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
                std::mem::forget(socket);
                let ln = tokio::net::TcpListener::from_std(ln)?;
                
                let conn = ln.accept().await?.0;
                Ok((Self::Listening { ln }, conn))
            }
            Listener::Listening { ref mut ln } => {
                let conn = ln.accept().await?.0;
                Ok((self, conn))
            }
        }
    }
}

This gets rid of the match with unreachable! arms, which is nice!

And then our acceptor becomes simply this:

Rust code
struct Acceptor {
    fut: BoxFuture<'static, Result<(Listener, TcpStream), Report>>,
}

impl Acceptor {
    fn from_listener(listener: Listener) -> Self {
        Self {
            fut: Box::pin(listener.accept()),
        }
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let res = futures::ready!(self.fut.poll_unpin(cx));
        let (listener, stream) = match res {
            Ok(tup) => tup,
            Err(e) => return Some(Err(e)).into(),
        };
        self.fut = Box::pin(listener.accept());
        Some(Ok(stream)).into()
    }
}

But... we could've done that without changing the function signature at all. If we go back to the &mut self version, we can simply implement Acceptor like this:

Rust code
struct Acceptor {
    fut: BoxFuture<'static, (Listener, Result<TcpStream, Report>)>,
}

impl Acceptor {
    fn from_listener(mut listener: Listener) -> Self {
        Self {
            fut: Box::pin(async move {
                let res = listener.accept().await;
                (listener, res)
            }),
        }
    }
}

impl Accept for Acceptor {
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let (mut listener, res) = futures::ready!(self.fut.poll_unpin(cx));
        self.fut = Box::pin(async move {
            let res = listener.accept().await;
            (listener, res)
        });
        Some(res).into()
    }
}

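There! Isn't that nice? The listener is moved into the async block (the future), and the whole future is boxed. The future itself is still a self-referential struct: while it's waiting on accept, it holds a mutable borrow of the listener it owns. But because the compiler generated it and it's pinned on the heap, we don't need to be particularly careful with it.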
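Here's that move-in, hand-back loop in std-only form. `Counter` is a hypothetical stand-in for `Listener`: its `next_conn` borrows `&mut self`, like `accept`. `take_three` drives the boxed future by hand the same way `poll_accept` does. When one future completes, it gets the state back and boxes up the next one. The boxed future can be `'static` because it owns the counter outright.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Hypothetical stand-in for `Listener`: each "accept" hands out the next number.
struct Counter {
    next: u32,
}

impl Counter {
    // Borrows mutably, like `Listener::accept(&mut self)`.
    async fn next_conn(&mut self) -> u32 {
        self.next += 1;
        self.next
    }
}

// Owns everything it touches, so it can be `'static`.
type OwnedFut = Pin<Box<dyn Future<Output = (Counter, u32)> + Send + 'static>>;

// Move the counter *into* the future, and hand it back alongside the result.
fn accept_owned(mut counter: Counter) -> OwnedFut {
    Box::pin(async move {
        // Internally self-referential (borrows `counter` across the await),
        // which is fine: the future is pinned on the heap.
        let conn = counter.next_conn().await;
        (counter, conn)
    })
}

// A waker that does nothing; these futures never return Pending anyway.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // Safety: none of the vtable functions touch the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Drive the futures by hand, the way `poll_accept` does.
fn take_three() -> Vec<u32> {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut fut = accept_owned(Counter { next: 0 });
    let mut conns = Vec::new();
    while conns.len() < 3 {
        let polled = fut.as_mut().poll(&mut cx);
        if let Poll::Ready((counter, conn)) = polled {
            conns.push(conn);
            // Got the state back: start the next "accept" with it.
            fut = accept_owned(counter);
        }
    }
    conns
}

fn main() {
    println!("{:?}", take_three());
}
```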

This pattern is so common that there's even a name for it: unfold. It's also a function from the futures crate (futures::stream::unfold) that gives us a Stream.

Rust code
struct Acceptor<S>(S);

impl Listener {
    fn into_acceptor(self) -> Acceptor<impl Stream<Item = Result<TcpStream, Report>>> {
        Acceptor(unfold(self, |mut ln| async move {
            let stream = ln.accept().await;
            Some((stream, ln))
        }))
    }
}

impl<S> Accept for Acceptor<S>
where
    S: Stream<Item = Result<TcpStream, Report>>,
{
    type Conn = TcpStream;
    type Error = Report;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // project!
        let stream = unsafe { self.map_unchecked_mut(|this| &mut this.0) };
        futures::ready!(stream.poll_next(cx)).into()
    }
}

Not gonna lie, I was pretty happy when I found this.
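If the shape of `unfold` still feels a bit odd, a synchronous cousin might help. `unfold_iter` below is a hypothetical helper built on `std::iter::from_fn`; it isn't part of std or the futures crate. It threads a piece of state through a closure that takes the state by value and returns `Some((item, next_state))`, or `None` to stop. That's the same shape as the `Some((stream, ln))` returned above, minus the async.

```rust
// Hypothetical synchronous analog of `futures::stream::unfold`, built on `from_fn`.
fn unfold_iter<S, T, F>(init: S, mut f: F) -> impl Iterator<Item = T>
where
    F: FnMut(S) -> Option<(T, S)>,
{
    // Becomes `None` for good once `f` returns `None`.
    let mut state = Some(init);
    std::iter::from_fn(move || {
        // Move the state out, hand it to `f`, and keep whatever it hands back.
        let (item, next) = f(state.take()?)?;
        state = Some(next);
        Some(item)
    })
}

fn main() {
    // The state is a Fibonacci pair, moved into the closure and handed back each step.
    let fib: Vec<u64> = unfold_iter((0u64, 1u64), |(a, b)| Some((a, (b, a + b))))
        .take(8)
        .collect();
    println!("{fib:?}"); // [0, 1, 1, 2, 3, 5, 8, 13]
}
```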

But we can go even further! This is such a common pattern that hyper ships with an accept::from_stream function!

TOML markup
# in `Cargo.toml`

[dependencies]
hyper = { version = "0.14", features = ["http1", "tcp", "server", "stream"] }
# new: stream feature

And now we don't even need an Acceptor struct at all:

Rust code
impl Listener {
    fn into_acceptor(self) -> impl Accept<Conn = TcpStream, Error = Report> {
        hyper::server::accept::from_stream(unfold(self, |mut ln| async move {
            let stream = ln.accept().await;
            Some((stream, ln))
        }))
    }
}

And it does the exact same thing.

Heck, we can go even further. Think unfold is confusing? Not ready to hop on the functional programming train yet?

Then have a macro!

Shell session
$ cargo add async-stream
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding async-stream v0.3.3 to dependencies
Rust code
impl Listener {
    fn into_acceptor(mut self) -> impl Accept<Conn = TcpStream, Error = Report> {
        from_stream(async_stream::stream! {
            loop {
                yield self.accept().await;
            }
        })
    }
}

It really doesn't get any better than this.

Or does it?

Simplifying the listener

After I posted this article, /u/usr_bin_nya on twitter pointed out that we can apply two more simplifications.

The first one is: instead of doing an std::mem::forget dance, we can use into_raw_fd, which transfers ownership of the underlying file descriptor.
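For comparison, here's that ownership transfer on plain std types, as a Unix-only sketch (rewrap is a hypothetical helper name). into_raw_fd consumes the listener without closing its descriptor, so wrapping the fd again with from_raw_fd leaves exactly one owner, and there's nothing left to forget.

Rust code

```rust
use std::net::TcpListener;
use std::os::unix::io::{FromRawFd, IntoRawFd};

/// Moves a listener's file descriptor out and back into a new `TcpListener`.
fn rewrap(ln: TcpListener) -> TcpListener {
    // Consumes `ln` without closing the fd: we now own the raw descriptor.
    let fd = ln.into_raw_fd();
    // Safety: `fd` is an open socket and `ln` no longer owns it, so the new
    // listener is its sole owner and the only one that will close it.
    unsafe { TcpListener::from_raw_fd(fd) }
}
```

With as_raw_fd instead, both the old and the new owner would try to close the same descriptor on drop, which is exactly what the std::mem::forget dance was working around.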

It's actually awkward to do that with the current design of Listener, because we build a tokio::net::TcpListener while the socket2::Socket is still stored in self, and accept only has &mut self, so we can't just move the socket out. We can do it if we add a third enum variant, a pattern that's quite common when writing state machines in Rust:

Rust code
enum Listener {
    Waiting { socket: Socket },
    Listening { ln: tokio::net::TcpListener },
    // 👇 new variant
    Transition,
}

impl Listener {
    fn new(addr: SocketAddr) -> Result<Self, Report> {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;
        Ok(Self::Waiting { socket })
    }

    async fn accept(&mut self) -> Result<TcpStream, Report> {
        match self {
            Listener::Waiting { socket } => {
                tokio::time::sleep(Duration::from_secs(2)).await;

                println!(
                    "Listening on {}...",
                    socket.local_addr()?.as_socket().unwrap()
                );

                // 👇 swapping 'Transition' in so we can take ownership of socket
                let mut state = Self::Transition;
                std::mem::swap(self, &mut state);

                // 👇 that kind of awkward code is super common when writing
                // Rust state machines by hand. There are crates that make it
                // better.
                let socket = match state {
                    Self::Waiting { socket } => socket,
                    _ => unreachable!(),
                };

                socket.listen(128)?;
                socket.set_nonblocking(true)?;
                // 👇 Using `into_raw_fd` instead of `as_raw_fd`
                let fd = socket.into_raw_fd();
                // Safety: `fd` comes from a well-initialized and listening `socket2::Socket`
                let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) };
                let ln = tokio::net::TcpListener::from_std(ln)?;

                *self = Self::Listening { ln };

                match self {
                    Listener::Listening { ln } => Ok(ln.accept().await?.0),
                    _ => unreachable!(),
                }
            }
            Listener::Listening { ln } => Ok(ln.accept().await?.0),
            // 👇 Using unreachable! because reaching that part would be a bug
            Listener::Transition => unreachable!(),
        }
    }
}
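As an aside, std::mem::replace can fold the swap-then-match step into one expression: it puts the placeholder in and hands back the old value in a single call. Here's a minimal sketch with a hypothetical Conn state machine that has nothing to do with sockets; it shows the alternative technique, not the listener code above.

Rust code

```rust
/// Hypothetical state machine using the same "placeholder variant" trick.
#[derive(Debug, PartialEq)]
enum Conn {
    Idle { buf: Vec<u8> },
    Active { buf: Vec<u8>, sent: usize },
    // Only ever present mid-transition; seeing it anywhere else is a bug.
    Transition,
}

impl Conn {
    fn activate(&mut self) {
        // Swap in the placeholder and take ownership of the old state in one step.
        *self = match std::mem::replace(self, Conn::Transition) {
            // `buf` is moved, not cloned, into the new state.
            Conn::Idle { buf } => Conn::Active { buf, sent: 0 },
            // Already active: put the state back untouched.
            other @ Conn::Active { .. } => other,
            Conn::Transition => unreachable!("activate called mid-transition"),
        };
    }
}
```

Because the match arms own the old state, there's no second match with an unreachable!() just to get the socket (or here, the buffer) back out.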

But as it turns out, it gets much, much better still. Just like we discovered unfold before, there is a try_flatten_stream method provided by TryFutureExt, in the futures crate.

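It lets us turn a "future that returns a stream" into "a stream". More precisely, a future that resolves to Result<Stream, E> becomes a stream of Result<_, E>: if the future itself fails, that error is the stream's only item.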
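Here's a synchronous analogue of that behavior using only std iterators. The try_flatten name is ours; the real method lives on TryFutureExt and works on futures and streams, not on Result and iterators.

Rust code

```rust
/// Sync analogue of `try_flatten_stream`: an outer fallible step that, on
/// success, produces an iterator of fallible items.
fn try_flatten<I, T, E>(outer: Result<I, E>) -> impl Iterator<Item = Result<T, E>>
where
    I: IntoIterator<Item = Result<T, E>>,
{
    let (err, inner) = match outer {
        // Success: no leading error, then every inner item.
        Ok(items) => (None, Some(items.into_iter())),
        // Failure: the error is the one and only item.
        Err(e) => (Some(Err(e)), None),
    };
    err.into_iter().chain(inner.into_iter().flatten())
}
```

In delayed_acceptor, the outer step is "bind, wait, listen", and the inner items are accepted connections: a failed bind shows up as a single error from the acceptor.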

And with it, our whole acceptor becomes this:

Rust code
fn delayed_acceptor(
    addr: SocketAddr,
    delay: Duration,
) -> impl Accept<Conn = TcpStream, Error = std::io::Error> {
    from_stream(
        async move {
            let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
            println!("Binding...");
            socket.bind(&addr.into())?;
            tokio::time::sleep(delay).await;
            println!("Listening...");
            socket.listen(128)?;
            socket.set_nonblocking(true)?;

            let ln = tokio::net::TcpListener::from_std(unsafe {
                std::net::TcpListener::from_raw_fd(socket.into_raw_fd())
            })?;
            let stream = unfold(ln, |ln| async move {
                let stream = ln.accept().await.map(|(stream, _)| stream);
                Some((stream, ln))
            });
            Ok(stream)
        }
        .try_flatten_stream(),
    )
}

And it can be used like this:

Rust code
    let acc = delayed_acceptor(addr, Duration::from_secs(2));

    hyper::Server::builder(acc) // etc.

And if we're willing to use the async-stream macros, we can simplify that some more!

Rust code
fn delayed_acceptor(
    addr: SocketAddr,
    delay: Duration,
) -> impl Accept<Conn = TcpStream, Error = std::io::Error> {
    from_stream(async_stream::try_stream! {
        let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
        println!("Binding...");
        socket.bind(&addr.into())?;
        tokio::time::sleep(delay).await;
        println!("Listening...");
        socket.listen(128)?;
        socket.set_nonblocking(true)?;

        let ln = tokio::net::TcpListener::from_std(unsafe {
            std::net::TcpListener::from_raw_fd(socket.into_raw_fd())
        })?;

        loop {
            yield ln.accept().await?.0;
        }
    })
}

Now it doesn't get any better. For now.

Update: it does get even better! Thanks to /u/Shadow0133 on reddit for pointing out that we don't need any unsafe code at all, since there's an impl From<socket2::Socket> for std::net::TcpListener.

This code:

Rust code
        let ln = tokio::net::TcpListener::from_std(unsafe {
            std::net::TcpListener::from_raw_fd(socket.into_raw_fd())
        })?;

Becomes:

Rust code
        let ln = tokio::net::TcpListener::from_std(socket.into())?;

How neat!
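The standard library offers the same kind of safe conversion on Unix: TcpListener converts into an OwnedFd and back through From impls, so moving a descriptor between owners needs no unsafe at all. A minimal sketch (rewrap_safely is a hypothetical name, and this uses std's OwnedFd rather than socket2):

Rust code

```rust
use std::net::TcpListener;
use std::os::unix::io::OwnedFd;

/// Round-trips a listener through `OwnedFd` using only safe `From` impls.
fn rewrap_safely(ln: TcpListener) -> TcpListener {
    // `OwnedFd` takes over the descriptor; it would close it if dropped.
    let fd: OwnedFd = ln.into();
    // Ownership moves straight back into a listener: no raw fds, no `unsafe`.
    TcpListener::from(fd)
}
```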

One last thing...

Yes?

What's your new favorite HTTP framework in Rust?

axum! It lets me ignore hyperisms and towerisms most of the time, while still letting me dive down into them and tap into their ecosystem whenever I want. I really recommend checking it out.

I hope you've enjoyed this journey through some of the best and most awkward bits of async Rust, and until next time, take exemplary care of yourself.
