Futures Nostalgia
Up until recently, hyper was my favorite Rust HTTP framework. It's low-level, but that gives you a lot of control over what happens.
Here's what a sample hyper application would look like:
$ cargo new nostalgia
     Created binary (application) `nostalgia` package
$ cd nostalgia
$ cargo add hyper@0.14 --features "http1 tcp server"
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding hyper v0.14 to dependencies with features: ["http1", "tcp", "server"]
$ cargo add tokio@1 --features "full"
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tokio v1 to dependencies with features: ["full"]
use std::{
    convert::Infallible,
    future::{ready, Ready},
    task::{Context, Poll},
};

use hyper::{server::conn::AddrStream, service::Service, Body, Request, Response, Server};

#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(MyServiceFactory)
        .await
        .unwrap();
}

struct MyServiceFactory;

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        println!("Accepted connection from {}", req.remote_addr());
        ready(Ok(MyService))
    }
}

struct MyService;

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("Handling {req:?}");
        ready(Ok(Response::builder().body("Hello World!\n".into()).unwrap()))
    }
}
$ cargo run
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished dev [unoptimized + debuginfo] target(s) in 1.20s
     Running `target/debug/nostalgia`
(server runs peacefully in the background)
Aaaand half the readers have closed the page already.
What? Let me show how it works at least!
See, you can curl it, and it works:
$ curl 0:1025
Hello World!
Wait a minute hold on - what kind of address is that?
Well... omitted octets in IPv4 addresses are filled with zeroes, so 127.1 is 127.0.0.1 for example...
So... 0 is 0.0.0.0? Isn't that the address we listen on when we want to accept traffic from all network interfaces? How does that work?
I'll uh... I'll get back to you on that. But it does work.
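The shorthand expansion itself (not the question of why connecting to 0.0.0.0 reaches the local machine) can be sketched in a few lines. This assumes the classic `inet_aton` rule, where every part but the last is one octet and the last part fills whatever bytes remain. `expand_shorthand` is a hypothetical helper, not something curl or hyper exposes, and it only handles decimal parts:

```rust
use std::net::Ipv4Addr;

/// Expands an `inet_aton`-style shorthand such as "127.1" or "0" into a
/// full IPv4 address. Hypothetical helper: decimal parts only.
fn expand_shorthand(input: &str) -> Option<Ipv4Addr> {
    let parts: Vec<u32> = input
        .split('.')
        .map(|part| part.parse::<u32>().ok())
        .collect::<Option<_>>()?;
    if parts.len() > 4 {
        return None;
    }

    // Every part but the last is a single octet.
    let (last, leading) = parts.split_last()?;
    if leading.iter().any(|&part| part > 255) {
        return None;
    }

    // The last part fills all the bytes that are left.
    let remaining_bytes = 4 - leading.len() as u32;
    if remaining_bytes < 4 && *last >= 1u32 << (8 * remaining_bytes) {
        return None;
    }

    let mut value = *last;
    for (i, &octet) in leading.iter().enumerate() {
        value |= octet << (8 * (3 - i as u32));
    }
    Some(Ipv4Addr::from(value))
}
```

With that rule, "127.1" becomes 127.0.0.1 and a lone "0" becomes 0.0.0.0, which is the address curl ended up putting in the Host header.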
For completeness, here's what it shows in the terminal pane where the server is running:
Accepted connection from 127.0.0.1:50408
Handling Request { method: GET, uri: /, version: HTTP/1.1, headers: {"host": "0.0.0.0:1025", "user-agent": "curl/7.79.1", "accept": "*/*"}, body: Body(Empty) }
Okay, cool. That's an ungodly amount of code though - I could do the same in, like, ten lines of G-
Ahhh but that's not the point! The point is that the design of hyper is delightfully perceptible to the naked eye! It's 90% the tower Service trait and 10%, huh... (waves hands) HTTP stuff.
HTTP stuff?
Yeah, you know! Easy. Text-based protocol, couple headers, some chunking if needed, the occasional trailer, cheerio good day sir.
...and HTTP/2?
Oh yeah, lil' bit of binary, adaptive windows, header compression, add multiplexing to taste. Nothing too hard there. Really it's mostly just the Service trait, look at it.
...and TLS?
Eh that's all rustls or OpenSSL or some fork thereof, but who cares, LOOK AT THAT TRAIT:
struct MyServiceFactory;

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        println!("Accepted connection from {}", req.remote_addr());
        ready(Ok(MyService))
    }
}
Isn't it beautiful?
What... what am I supposed to see?
Backpressure, bear! The cornerstone of a nutritious breakfast... er, stable application server.
See, before spawning any future onto the executor, it asks us if we're ready! With poll_ready. And if we're not, it... well it waits until we are.
That means we can control how many concurrent connections there can be to our service!
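To see the mechanism in isolation, without hyper or tower, here's a minimal std-only sketch. `Gate`, `Slot` and `Slots` are made-up names, and `Gate` only imitates the shape of a `Service` (it isn't one): `poll_ready` reserves a slot or says "not yet", and `call` hands the reserved slot out.

```rust
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

/// Shared bookkeeping: slots in use, the cap, and who to wake when one frees up.
struct Slots {
    used: usize,
    max: usize,
    waiter: Option<Waker>,
}

/// Holding a `Slot` is proof that we fit under the limit; dropping it frees the slot.
struct Slot(Arc<Mutex<Slots>>);

impl Drop for Slot {
    fn drop(&mut self) {
        let waiter = {
            let mut slots = self.0.lock().unwrap();
            slots.used -= 1;
            slots.waiter.take()
        };
        // Wake outside the lock so the woken task can take it right away.
        if let Some(waker) = waiter {
            waker.wake();
        }
    }
}

/// Toy stand-in for a service: `poll_ready` reserves a slot, `call` hands it out.
struct Gate {
    slots: Arc<Mutex<Slots>>,
    reserved: Option<Slot>,
}

impl Gate {
    fn new(max: usize) -> Self {
        let slots = Slots { used: 0, max, waiter: None };
        Gate { slots: Arc::new(Mutex::new(slots)), reserved: None }
    }

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.reserved.is_none() {
            let mut slots = self.slots.lock().unwrap();
            if slots.used == slots.max {
                // Not ready: remember the waker so a dropped `Slot` can wake us.
                slots.waiter = Some(cx.waker().clone());
                return Poll::Pending;
            }
            slots.used += 1;
            self.reserved = Some(Slot(self.slots.clone()));
        }
        Poll::Ready(())
    }

    fn call(&mut self) -> Slot {
        self.reserved.take().expect("poll_ready must return Ready before call")
    }
}

/// A waker that does nothing, just enough to build a `Context` by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Readiness observed at three points: first slot, while full, after release.
fn demo() -> [bool; 3] {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut gate = Gate::new(1);

    let first = gate.poll_ready(&mut cx).is_ready();
    let slot = gate.call();
    let while_full = gate.poll_ready(&mut cx).is_ready();
    drop(slot);
    let after_release = gate.poll_ready(&mut cx).is_ready();
    [first, while_full, after_release]
}
```

The caller never gets to `call` while the gate is full, which is the whole trick: the limit is enforced before any work is started, not after.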
What would that look like?
Ah jeeze bear, that's not really the point of the article, but if you insist, I suppose we could, uh... use a semaphore maybe?
But first we should probably keep track of how many connections we have concurrently...
Maybe we keep an Arc<AtomicU64> in MyServiceFactory, which we increment when we accept a connection, and decrement when we drop a MyService?
#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        // 👇 previously was a unit struct (just `MyServiceFactory`)
        .serve(MyServiceFactory::default())
        .await
        .unwrap();
}

// 👇 Now holding an atomically-reference-counted atomic counter
#[derive(Default)]
struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
}

impl Service<&AddrStream> for MyServiceFactory {
    // (cut: everything except for call)

    fn call(&mut self, req: &AddrStream) -> Self::Future {
        let prev = self.num_connected.fetch_add(1, Ordering::SeqCst);
        println!(
            "⬆️ {} connections (accepted {})",
            prev + 1,
            req.remote_addr()
        );
        ready(Ok(MyService {
            num_connected: self.num_connected.clone(),
        }))
    }
}

// 👇 Now also holding a counter
struct MyService {
    num_connected: Arc<AtomicU64>,
}

impl Drop for MyService {
    fn drop(&mut self) {
        let prev = self.num_connected.fetch_sub(1, Ordering::SeqCst);
        println!("⬇️ {} connections (dropped)", prev - 1);
    }
}

impl Service<Request<Body>> for MyService {
    // (cut: everything except call)

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // 👇 made these logs a little nicer
        println!("{} {}", req.method(), req.uri());
        // otherwise the same
        ready(Ok(Response::builder()
            .body("Hello World!\n".into())
            .unwrap()))
    }
}
And now, a single curl request results in these logs:
$ cargo run --quiet
⬆️ 1 connections (accepted 127.0.0.1:50416)
GET /
⬇️ 0 connections (dropped)
But we can also make requests by hand:
$ socat - TCP4:0:1025
GET / HTTP/1.1

> HTTP/1.1 200 OK
> content-length: 13
> date: Sat, 02 Apr 2022 23:59:58 GMT
>
> Hello World!
GET /ahAH HTTP/1.1

> HTTP/1.1 200 OK
> content-length: 13
> date: Sun, 03 Apr 2022 00:00:04 GMT
>
> Hello World!
(Note: I've prefixed the response lines with > by hand. The "GET" lines, each followed by two newlines, are the ones I typed.)
And from the server's perspective, this looks like this:
$ cargo run --quiet
⬆️ 1 connections (accepted 127.0.0.1:50420)
GET /
GET /ahAH
⬇️ 0 connections (dropped)
Two requests! From the same connection! Who needs h2/h3/quic? Huh?
Me. I do. Come onnnnn.
I don't feel like doing enough typing to test our limiter though... how about we use a load testing tool? Like oha?
$ oha http://127.0.0.1:1025
(a lot of output ensues)
Status code distribution:
  [200] 200 responses
That generated... a /lot/ of output on the server side. But I've redirected it to a file, because I'm a forward-thinking young lad.
Like so:
$ cargo run --release --quiet | tee /tmp/server-log.txt
(server output is printed as normal, but is also logged to the file)
Which means I can now easily count how many connections we had at our peakest of peaks:
$ cat /tmp/server-log.txt | grep '⬆' | cut -d ' ' -f 2 | sort -n | tail -1
50
You may think that's a useless use of cat, but actually, cat is union now, so think twice about giving its work away to someone else.
50! What a suspicious value, almost as if it was oha's default...
$ oha --help
(cut)
    -c <N_WORKERS>
            Number of workers to run concurrently. You may should increase limit to number of open files for larger `-c`. [default: 50]
Mh yep!
Okay but we're trying to limit it. Let's use... a semaphore!
$ cargo add tokio-util@0.7
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tokio-util v0.7 to dependencies

struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
    semaphore: PollSemaphore,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            num_connected: Default::default(),
            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(5))),
        }
    }
}
There! I've set a limit of 5 permits, and so now all we have to do is... try to acquire a permit from poll_ready, I suppose!
$ cargo add futures
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding futures v0.3.21 to dependencies

impl Service<&AddrStream> for MyServiceFactory {
    // (cut: all except for poll_ready)

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let permit = futures::ready!(self.semaphore.poll_acquire(cx)).unwrap();
        Ok(()).into()
    }
}
There! That handy little ready! macro lets us extract the successful type of a Poll<T>, and we unwrap that Option<OwnedSemaphorePermit> because, well, we never close the semaphore.
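For the curious, here's roughly what that macro does, as a small sketch. It uses `std::task::ready!`, the standard library's equivalent of `futures::ready!` (available since Rust 1.64), so it runs without any dependency. `poll_doubled` and `poll_doubled_by_hand` are made-up names for illustration:

```rust
use std::task::{ready, Poll};

/// Doubles the value once it is available; passes `Pending` straight through.
fn poll_doubled(inner: Poll<u32>) -> Poll<u32> {
    // `ready!` evaluates to the inner value, or returns `Poll::Pending` early.
    let value = ready!(inner);
    Poll::Ready(value * 2)
}

/// The same thing, written out by hand.
fn poll_doubled_by_hand(inner: Poll<u32>) -> Poll<u32> {
    let value = match inner {
        Poll::Ready(value) => value,
        Poll::Pending => return Poll::Pending,
    };
    Poll::Ready(value * 2)
}
```

In other words, it's the `?` operator's cousin for `Poll`: the early return is what makes `poll_ready` report "not yet" when no permit is available.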
But uh... where do we store that permit? We're not creating a MyService from poll_ready, merely checking for readiness!
Okay we'll need one more field:
struct MyServiceFactory {
    num_connected: Arc<AtomicU64>,
    semaphore: PollSemaphore,
    // 👇 new!
    permit: Option<OwnedSemaphorePermit>,
}
Adjusting the impl Default is left as an exercise to the reader.
And now, in poll_ready, we simply try to acquire a permit if we don't have one already.
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    if self.permit.is_none() {
        self.permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
    }
    Ok(()).into()
}
And in call, well, we just take it! And explode loudly in case someone hasn't called poll_ready before:
fn call(&mut self, req: &AddrStream) -> Self::Future {
    let permit = self.permit.take().expect(
        "you didn't drive me to readiness did you? you know that's a tower crime right?",
    );
    let prev = self.num_connected.fetch_add(1, Ordering::SeqCst);
    println!(
        "⬆️ {} connections (accepted {})",
        prev + 1,
        req.remote_addr()
    );
    ready(Ok(MyService {
        num_connected: self.num_connected.clone(),
        permit,
    }))
}
Of course we need somewhere to store that permit:
struct MyService {
    num_connected: Arc<AtomicU64>,
    permit: OwnedSemaphorePermit,
}
rustc will complain that this field is dead code, but it's very much not - holding that type is proof that we're allowed to run within the concurrency limits we've set for ourselves.
We... no, you can fix that by renaming it to _permit. I won't.
So now, let's try our oha test again!
$ cargo run --release --quiet | tee /tmp/server-log.txt

# in another pane:
$ oha http://127.0.0.1:1025

# after stopping the server:
$ cat /tmp/server-log.txt | grep '⬆' | cut -d ' ' -f 2 | sort -n | tail -1
5
Wonderful! How often do things work exactly as expected? Not so often, I tell you what.
Oh, and we can actually remove our atomic counter, because semaphores do their own counting.
Which makes our complete program this:
use std::{
    convert::Infallible,
    future::{ready, Ready},
    sync::Arc,
    task::{Context, Poll},
};

use hyper::{server::conn::AddrStream, service::Service, Body, Request, Response, Server};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;

#[tokio::main]
async fn main() {
    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(MyServiceFactory::default())
        .await
        .unwrap();
}

const MAX_CONNS: usize = 5;

struct MyServiceFactory {
    semaphore: PollSemaphore,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            permit: None,
        }
    }
}

impl Service<&AddrStream> for MyServiceFactory {
    type Response = MyService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.permit.is_none() {
            self.permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        let permit = self.permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );
        println!(
            "⬆️ {} connections",
            MAX_CONNS - self.semaphore.available_permits()
        );
        ready(Ok(MyService { _permit: permit }))
    }
}

struct MyService {
    _permit: OwnedSemaphorePermit,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        ready(Ok(Response::builder()
            .body("Hello World!\n".into())
            .unwrap()))
    }
}
That's not an async server though
I mean... yes it is an async server, but I see what you mean - we just immediately reply with "Hello World!", we're not even pretending to think about it a little bit.
So let's pretend. Now our Future type cannot be Ready anymore. We could make a custom future, like so:
struct PretendFuture {
    sleep: Sleep,
    response: Option<Response<Body>>,
}

impl Future for PretendFuture {
    type Output = Result<Response<Body>, Infallible>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        futures::ready!(
            unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.sleep) }.poll(cx)
        );
        Ok(unsafe { self.get_unchecked_mut() }.response.take().unwrap()).into()
    }
}
Which we could then use as our Future type in MyService:
impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = PretendFuture;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        PretendFuture {
            sleep: tokio::time::sleep(Duration::from_millis(250)),
            response: Some(Response::builder().body("Hello World!\n".into()).unwrap()),
        }
    }
}
And now our latency has SHOT UP, from this:
$ oha etc.
(cut)
Latency distribution:
  10% in 0.0001 secs
  25% in 0.0001 secs
  50% in 0.0002 secs
  75% in 0.0015 secs
  90% in 0.0091 secs
  95% in 0.0095 secs
  99% in 0.0103 secs
To this:
$ oha etc.
(cut)
Latency distribution:
  10% in 0.2514 secs
  25% in 0.2521 secs
  50% in 0.2522 secs
  75% in 0.2528 secs
  90% in 9.3291 secs
  95% in 9.8342 secs
  99% in 10.0861 secs
Wait, why are some requests taking 10 seconds?
Concurrency limits! If we raise MAX_CONNS to 50, the p99 falls to 256 milliseconds.
Now we have something that looks, more or less, like a real-world application. From the outside at least.
But before we move on, and more importantly, before the unsafe police rain on me, let's use pin-project-lite to get rid of those gnarly map_unchecked_mut and get_unchecked_mut calls:
$ cargo add pin-project-lite
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding pin-project-lite v0.2.8 to dependencies

pin_project_lite::pin_project! {
    struct PretendFuture {
        #[pin]
        sleep: Sleep,
        response: Option<Response<Body>>,
    }
}

impl Future for PretendFuture {
    type Output = Result<Response<Body>, Infallible>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        futures::ready!(this.sleep.poll(cx));
        Ok(this.response.take().unwrap()).into()
    }
}
There! Don't you feel better already? And remember: pinning is either structural for a field, or it's not. I think that says it all.
?? No it bloody doesn't?
Ah, would you rather have the long-form explanation?
...perhaps not.
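A short-form aside instead, then. There's a third way to dodge both the unsafe blocks and the macro, sketched below under the assumption that one extra heap allocation per future is acceptable: box the inner future. `Pin<Box<...>>` is always `Unpin`, so the wrapper struct is `Unpin` too and no projection is needed. `BoxedPretend` and `poll_once` are made-up names, and the response is a plain `String` to keep the sketch free of dependencies:

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll, Wake, Waker};

/// Like `PretendFuture`, but the inner future lives in a pinned box.
struct BoxedPretend {
    inner: Pin<Box<dyn Future<Output = ()>>>,
    response: Option<String>,
}

impl Future for BoxedPretend {
    type Output = String;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Self: Unpin`, so `Pin<&mut Self>` gives plain mutable field access.
        ready!(self.inner.as_mut().poll(cx));
        Poll::Ready(self.response.take().expect("polled after completion"))
    }
}

/// A waker that does nothing, just enough to poll by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls an `Unpin` future exactly once.
fn poll_once<F: Future + Unpin>(fut: &mut F) -> Poll<F::Output> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    Pin::new(fut).poll(&mut cx)
}
```

The trade-off is the allocation and the dynamic dispatch, which is exactly what pin-project-lite lets you avoid.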
But then, another question arises, from the endless well of questions that is me, late at night, in lieu of resting.
How would we limit the number of in-flight requests?
Well... couldn't the future hold a semaphore permit itself?
Ah, that might work! Let's try it.
const MAX_CONNS: usize = 50;
const MAX_INFLIGHT_REQUESTS: usize = 5;

struct MyServiceFactory {
    conn_semaphore: PollSemaphore,
    reqs_semaphore: PollSemaphore,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            conn_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            reqs_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS))),
            permit: None,
        }
    }
}
Now to hand off the semaphore to every instance of MyService, by cloning it - which doesn't change the number of available permits, by the way, it just clones the inner Arc<Semaphore>, so it's all good.
impl Service<&AddrStream> for MyServiceFactory {
    // (cut: everything except call)

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        // (cut: everything except this:)
        ready(Ok(MyService {
            _conn_permit: permit,
            // 👇
            semaphore: self.reqs_semaphore.clone(),
            // 👇 (for later)
            reqs_permit: None,
        }))
    }
}
Now MyService needs to hold its connection permit and the requests semaphore. And an optional reqs_permit, filled in from poll_ready:
struct MyService {
    _conn_permit: OwnedSemaphorePermit,
    semaphore: PollSemaphore,
    reqs_permit: Option<OwnedSemaphorePermit>,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = PretendFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 👇
        if self.reqs_permit.is_none() {
            self.reqs_permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // 👇
        let permit = self.reqs_permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );
        println!("{} {}", req.method(), req.uri());
        PretendFuture {
            sleep: tokio::time::sleep(Duration::from_millis(250)),
            response: Some(Response::builder().body("Hello World!\n".into()).unwrap()),
            // 👇
            permit,
        }
    }
}
And the PretendFuture simply needs to hold onto the OwnedSemaphorePermit: again, its presence is proof enough. We can't build a PretendFuture without having an OwnedSemaphorePermit, and dropping the PretendFuture (which happens after it's polled to completion) also releases the permit.
pin_project_lite::pin_project! {
    struct PretendFuture {
        #[pin]
        sleep: Sleep,
        response: Option<Response<Body>>,
        // 👇
        permit: OwnedSemaphorePermit,
    }
}
And now, try as you might, you're not gonna get more than... 20 requests per second out of this web server. Because only 5 requests can be in-flight at any given time, and each request takes about 1/4 of a second.
$ oha http://127.0.0.1:1025
Summary:
  Success rate: 1.0000
  Total:        10.0610 secs
  Slowest:      2.5176 secs
  Fastest:      0.2519 secs
  Average:      2.2320 secs
  Requests/sec: 19.8788
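The arithmetic behind that ceiling fits in two tiny functions. This is a back-of-the-envelope sketch with made-up names, and it assumes oha sent 200 requests, as in the earlier run that reported 200 responses:

```rust
use std::time::Duration;

/// Upper bound on requests per second when at most `in_flight` requests
/// run at once and each one takes `latency` to complete.
fn max_requests_per_sec(in_flight: usize, latency: Duration) -> f64 {
    in_flight as f64 / latency.as_secs_f64()
}

/// Lower bound on the time needed to serve `total` requests under that cap.
fn min_total_secs(total: usize, in_flight: usize, latency: Duration) -> f64 {
    total as f64 / max_requests_per_sec(in_flight, latency)
}
```

With 5 permits and 250 ms per request, that's 20 requests per second at best, and 200 requests need at least 10 seconds. The measured 19.8788 requests/sec and 10.0610 secs sit right next to those bounds.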
Forget everything you just learned
So the exercise we just went through is neat, because it shows you how hyper and tower actually work: turning connections into "http services" is done through a service. And turning requests into responses is also done through a service.
And services need to be driven to readiness, before we can call them, which gives us a future, which we can then await, or spawn on a runtime, etc. - something we don't really have to worry about with the way we use hyper, but that we could definitely worry about if such was our wish.
But now, tokei informs me that we have 122 lines of Rust in this project, which seems a tad excessive. In fact, everyone who's already stopped reading probably walked away thinking "dang this language is verbose"!
And who could blame em. But they're gone now, and they're not coming back. It's just us. Hey. How are you holding up? Yeah. I feel that.
So let's remove a bunch of code!
First off, this isn't 2018 anymore - we actually have async blocks and an await postfix keyword now.
We don't really need PretendFuture, we can just have an async block and box it!
impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.reqs_permit.is_none() {
            self.reqs_permit = Some(futures::ready!(self.semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let permit = self.reqs_permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );
        println!("{} {}", req.method(), req.uri());
        Box::pin(async move {
            let _permit = permit;
            tokio::time::sleep(Duration::from_millis(250)).await;
            Ok(Response::builder().body("Hello World!\n".into()).unwrap())
        })
    }
}
Boom. That async block gets turned into a generator, which implements Future (or so I'm told), and it captures ("closes over" if you want to get schmancy) anything it needs (in this case really just permit), which means it keeps "our place in the queue" (disclaimer: not actually a queue) until it's dropped.
Hey, we can even remove a dependency:
$ cargo rm pin-project-lite
    Removing pin-project-lite from dependencies
There. How often have you seen me do that, huh?
...thirteen times.
Huh.
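Back to that async block for a second: the "keeps its permit until it's dropped" behavior is easy to observe with nothing but std. In this sketch, `Permit` is a made-up stand-in for `OwnedSemaphorePermit` that simply counts itself while alive, and `handle` plays the role of `call`:

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

/// Stand-in for `OwnedSemaphorePermit`: counts itself while alive.
struct Permit(Arc<AtomicUsize>);

impl Permit {
    fn acquire(held: &Arc<AtomicUsize>) -> Self {
        held.fetch_add(1, Ordering::SeqCst);
        Permit(held.clone())
    }
}

impl Drop for Permit {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

/// The async block takes ownership of `permit`, so the future owns it from here on.
fn handle(permit: Permit) -> Pin<Box<dyn Future<Output = &'static str>>> {
    Box::pin(async move {
        let _permit = permit;
        "Hello World!\n"
    })
}

/// A waker that does nothing, just enough to poll by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Permits held: before polling, after completion, then for a second future
/// that is never polled, while it exists and after it is dropped.
fn demo() -> [usize; 4] {
    let held = Arc::new(AtomicUsize::new(0));
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);

    let mut fut = handle(Permit::acquire(&held));
    let before_poll = held.load(Ordering::SeqCst);
    let _ = fut.as_mut().poll(&mut cx);
    let after_completion = held.load(Ordering::SeqCst);

    let unpolled = handle(Permit::acquire(&held));
    let while_unpolled = held.load(Ordering::SeqCst);
    drop(unpolled);
    let after_drop = held.load(Ordering::SeqCst);

    [before_poll, after_completion, while_unpolled, after_drop]
}
```

Either way the permit comes back: when the block runs to completion, or when the future is dropped without ever being polled (say, because the client hung up).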
Now, we're this close to having clippy yell at us for type complexity with this line:
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
So I'm gonna take a defensive stance and go with this instead:
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
(That's futures::future::BoxFuture). And look, the lifetime is explicit now!
Did you know Box<dyn T> actually means Box<dyn T + 'static>? Well it does!
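Here's a tiny sketch of what that default means in practice, using `Display` instead of `Future` to keep it short. `owned` and `borrowed` are made-up names:

```rust
use std::fmt::Display;

/// No lifetime written: the return type is `Box<dyn Display + 'static>`,
/// so the boxed value has to own its data.
fn owned(text: String) -> Box<dyn Display> {
    Box::new(text)
}

/// To box something that borrows, the bound has to be spelled out.
fn borrowed<'a>(text: &'a str) -> Box<dyn Display + 'a> {
    Box::new(text)
}
```

Dropping the `+ 'a` from `borrowed` makes the compiler ask for `'static` data, which a borrowed `&'a str` can't promise. Same story for `BoxFuture<'static, ...>`: the boxed future isn't allowed to borrow from anything that might go away.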
Our code does the exact same thing, but we're down to 100 lines.
Let's keep going!
Because the Service trait is so well thought-out (they even did a whole piece about it), there's a bunch of re-usable services out there!
For example, there's ConcurrencyLimit, which does... precisely what you think it does.
Here's one way to use it:
$ cargo add tower --features limit
    Updating 'https://github.com/rust-lang/crates.io-index' index
      Adding tower v0.4.12 to dependencies with features: ["limit"]
We still want a shared semaphore:
struct MyServiceFactory {
    conn_semaphore: PollSemaphore,
    // 👇 was: PollSemaphore
    reqs_semaphore: Arc<Semaphore>,
    permit: Option<OwnedSemaphorePermit>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            conn_semaphore: PollSemaphore::new(Arc::new(Semaphore::new(MAX_CONNS))),
            // 👇 that changed too
            reqs_semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS)),
            permit: None,
        }
    }
}
But now we return a ConcurrencyLimit<MyService>:
impl Service<&AddrStream> for MyServiceFactory {
    // 👇 new! now 100% more nested
    type Response = ConcurrencyLimit<MyService>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // we're still limiting concurrency by hand for connections
        if self.permit.is_none() {
            self.permit = Some(futures::ready!(self.conn_semaphore.poll_acquire(cx)).unwrap());
        }
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        let permit = self.permit.take().expect(
            "you didn't drive me to readiness did you? you know that's a tower crime right?",
        );
        println!(
            "⬆️ {} connections",
            MAX_CONNS - self.conn_semaphore.available_permits()
        );
        // 👇 the nesting occurs here
        ready(Ok(ConcurrencyLimit::with_semaphore(
            MyService {
                _conn_permit: permit,
            },
            self.reqs_semaphore.clone(),
        )))
    }
}
And our MyService service is now considerably simpler - it only concerns itself with the business logic: pretending to do work for a while, then quickly throwing something together:
struct MyService {
    _conn_permit: OwnedSemaphorePermit,
}

impl Service<Request<Body>> for MyService {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        println!("{} {}", req.method(), req.uri());
        Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(250)).await;
            Ok(Response::builder().body("Hello World!\n".into()).unwrap())
        })
    }
}
So, that's for limiting request concurrency. Can we limit connections the same way?
I fully expected to go there, but after giving it some thought, my professional opinion is: I don't think so. See, ConcurrencyLimit limits how many futures for a service can exist at any given time, but in the case of MyServiceFactory, the futures are very short-lived, and yield a service (which is used to handle requests coming through the newly-established connections, via some tasks that hyper spawns on the tokio runtime).
So, no dice. But hey, we're down to 94 lines already!
I really want to show you some other stuff now, so I will!
Let's set aside the connections concurrency limiting for now:
const MAX_INFLIGHT_REQUESTS: usize = 5;

struct MyServiceFactory {
    reqs_semaphore: Arc<Semaphore>,
}

impl Default for MyServiceFactory {
    fn default() -> Self {
        Self {
            reqs_semaphore: Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS)),
        }
    }
}

impl Service<&AddrStream> for MyServiceFactory {
    // 👇
    type Response = ConcurrencyLimit<MyService>;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, _req: &AddrStream) -> Self::Future {
        ready(Ok(ConcurrencyLimit::with_semaphore(
            MyService,
            self.reqs_semaphore.clone(),
        )))
    }
}
Did you know we need none of this? tower provides convenience functions that take closures instead. And hyper re-exports them for extra convenience:
use hyper::service::make_service_fn;

const MAX_INFLIGHT_REQUESTS: usize = 5;

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));

    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move { Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(MyService, sem)) }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}
Boom! Just like that, we got rid of our entire MyServiceFactory. And the actual hyper hello world makes a lot more sense now.
We're down to 51 lines.
Yeah, but we lost connection limiting...
Shhhletskeepgoing.
Did you know we don't need MyService either?
We don't do anything interesting in its poll_ready anymore - this all could be just an async function.
And it can be!
use std::{convert::Infallible, sync::Arc, time::Duration};

use hyper::{
    server::conn::AddrStream,
    service::{make_service_fn, service_fn},
    Body, Request, Response, Server,
};
use tokio::sync::Semaphore;
use tower::limit::ConcurrencyLimit;

const MAX_INFLIGHT_REQUESTS: usize = 5;

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));

    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move {
            Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(
                service_fn(|req: Request<Body>| async move {
                    println!("{} {}", req.method(), req.uri());
                    tokio::time::sleep(Duration::from_millis(250)).await;
                    Ok::<_, Infallible>(
                        Response::builder()
                            .body(Body::from("Hello World!\n"))
                            .unwrap(),
                    )
                }),
                sem,
            ))
        }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}
Down to 38 lines, still functionally equivalent.
Now let's stop obsessing over silly metrics like lines of code, and focus on readability instead. I vote we move the service back out of this tower and into its own function:
#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(MAX_INFLIGHT_REQUESTS));

    let app = make_service_fn(move |_stream: &AddrStream| {
        let sem = sem.clone();
        async move {
            Ok::<_, Infallible>(ConcurrencyLimit::with_semaphore(
                service_fn(hello_world),
                sem,
            ))
        }
    });

    Server::bind(&([127, 0, 0, 1], 1025).into())
        .serve(app)
        .await
        .unwrap();
}

async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    println!("{} {}", req.method(), req.uri());
    tokio::time::sleep(Duration::from_millis(250)).await;
    Ok(Response::builder()
        .body(Body::from("Hello World!\n"))
        .unwrap())
}
And, for readability, let's introduce ServiceBuilder. On top of the Service trait, tower provides the Layer trait, which lets you "decorate" a service.
Its definition is short and elegant:
pub trait Layer<S> {
    type Service;

    fn layer(&self, inner: S) -> Self::Service;
}
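To make "decorate" concrete, here's a dependency-free sketch. The `Layer` trait is copied from the definition above; everything else (`ToyService`, `Hello`, `Counted`, `CountedLayer`) is made up for illustration, and `ToyService` is a deliberately simplified, synchronous stand-in for tower's `Service`, with no futures and no `poll_ready`:

```rust
/// Simplified, synchronous stand-in for tower's `Service`.
trait ToyService {
    fn call(&mut self, req: &str) -> String;
}

/// Same shape as the `Layer` trait shown above.
trait Layer<S> {
    type Service;

    fn layer(&self, inner: S) -> Self::Service;
}

/// The innermost service: the actual "business logic".
struct Hello;

impl ToyService for Hello {
    fn call(&mut self, req: &str) -> String {
        format!("Hello, {req}!")
    }
}

/// A decorator that counts calls and tags each response with the count.
struct Counted<S> {
    inner: S,
    calls: usize,
}

impl<S: ToyService> ToyService for Counted<S> {
    fn call(&mut self, req: &str) -> String {
        self.calls += 1;
        let response = self.inner.call(req);
        format!("[#{}] {response}", self.calls)
    }
}

/// The layer itself is just a recipe for wrapping any service in `Counted`.
struct CountedLayer;

impl<S> Layer<S> for CountedLayer {
    type Service = Counted<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Counted { inner, calls: 0 }
    }
}
```

Calling `CountedLayer.layer(Hello)` gives back a `Counted<Hello>`, which is itself a service, so layers can be stacked as deep as you like.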
And ServiceBuilder helps us compose layers together, like so:
// in main
let app = make_service_fn(move |_stream: &AddrStream| async move {
    let svc = ServiceBuilder::new()
        .layer(ConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS))
        .service(service_fn(hello_world));
    Ok::<_, Infallible>(svc)
});
Wait, where did our semaphore go? What is that actually limiting?
Good catch bear - we actually changed the semantics here. The limit we're adding here is per service, not overall as we had before.
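The difference is easy to model without tower. In this sketch, `Limit` is a made-up toy limiter (it never releases slots, which is enough to count admissions): cloning it shares one counter, while building a fresh one per connection gives each connection its own budget.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Toy limiter: clones share one counter, fresh instances get their own.
#[derive(Clone)]
struct Limit {
    max: usize,
    used: Arc<AtomicUsize>,
}

impl Limit {
    fn new(max: usize) -> Self {
        Limit { max, used: Arc::new(AtomicUsize::new(0)) }
    }

    /// Takes a slot if one is free. Slots are never released in this sketch.
    fn try_acquire(&self) -> bool {
        self.used
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |used| {
                (used < self.max).then(|| used + 1)
            })
            .is_ok()
    }
}

/// Simulates `conns` connections that each attempt `attempts` requests,
/// and returns how many requests got a slot in total.
fn admitted(conns: usize, attempts: usize, make_limit: impl Fn() -> Limit) -> usize {
    (0..conns)
        .map(|_| {
            let limit = make_limit();
            (0..attempts).filter(|_| limit.try_acquire()).count()
        })
        .sum()
}
```

With 3 connections and a limit of 5, `admitted(3, 10, || Limit::new(5))` lets 15 requests through, while handing every connection a clone of one shared `Limit` lets only 5 through. The first is what a fresh limit per service amounts to, the second is what the shared semaphore gave us before.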
But before we fix that, I just wanted to show this even shorter version:
# in `Cargo.toml`
# (cut: everything but the tower dependency)

[dependencies]
tower = { version = "0.4.12", features = ["limit", "util"] } # 👈 util is new

let app = make_service_fn(move |_stream: &AddrStream| async move {
    let svc = ServiceBuilder::new()
        .concurrency_limit(MAX_INFLIGHT_REQUESTS)
        .service_fn(hello_world);
    Ok::<_, Infallible>(svc)
});
Isn't that pretty?
Pretty... but incorrect 🧐
...and also only 29 lines. Are you sure it isn't easier to convince product that it'll benefit users?
...
Fine fine we'll fix it. Luckily, GlobalConcurrencyLimitLayer has us covered:
let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);

let app = make_service_fn(move |_stream: &AddrStream| {
    let reqs_limit = reqs_limit.clone();
    async move {
        let svc = ServiceBuilder::new()
            .layer(reqs_limit)
            .service_fn(hello_world);
        Ok::<_, Infallible>(svc)
    }
});
Theeere.
Oh by the way, we're not doing anything asynchronous here, so we don't really need to use an async move block, and we don't need to do that dance where we clone before the async move block because the resulting future must be 'static... we've already used std::future::ready in this article (not to be confused with the futures::ready! macro), so we can totally do this:
let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);

let app = make_service_fn(move |_stream: &AddrStream| {
    std::future::ready(Ok::<_, Infallible>(
        ServiceBuilder::new()
            .layer(reqs_limit.clone())
            .service_fn(hello_world),
    ))
});
But now... now I'm looking at this and I'm thinking... how hard could it be to bring back connections concurrency?
I have a terrible, terrible idea.
See, ServiceBuilder has a then method, that lets you execute a function after a service.
Let's give it a try:
let app = make_service_fn(move |_stream: &AddrStream| {
    std::future::ready(Ok::<_, Infallible>(
        ServiceBuilder::new()
            .layer(reqs_limit.clone())
            .then(|res: Result<Response<Body>, Infallible>| async move {
                println!("Just served a request!");
                res
            })
            .service_fn(hello_world),
    ))
});

$ cargo run --release
   Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia)
    Finished release [optimized] target(s) in 9.11s
     Running `target/release/nostalgia`
GET /
Just served a request!
GET /
Just served a request!
GET /
Just served a request!
# (I'm running curl in another pane, the server isn't haunted)
In that case... I'm much less interested in running something after the service, and much more interested in the state the closure can capture...
...say maybe it could capture a permit 😈
let conns_limit = Arc::new(Semaphore::new(MAX_CONNS));
let reqs_limit = GlobalConcurrencyLimitLayer::new(MAX_INFLIGHT_REQUESTS);

let app = make_service_fn(move |_stream: &AddrStream| {
    let conns_limit = conns_limit.clone();
    let reqs_limit = reqs_limit.clone();

    async move {
        let permit = Arc::new(conns_limit.acquire_owned().await.unwrap());

        Ok::<_, Infallible>(
            ServiceBuilder::new()
                .layer(reqs_limit)
                .then(move |res: Result<Response<Body>, Infallible>| {
                    drop(permit);
                    std::future::ready(res)
                })
                .service_fn(hello_world),
        )
    }
});
This is definitely a crime, and not what the innocent authors of ThenLayer had in mind, but hey, if it's stupid and it works then it must be 4AM:
$ oha etc. (cut) Latency distribution: 10% in 0.2517 secs 25% in 0.2521 secs 50% in 0.2524 secs 75% in 0.2532 secs 90% in 9.3343 secs 95% in 9.8390 secs 99% in 10.0909 secs
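The whole trick rests on one thing: a permit gives its slot back when it's dropped, so whoever owns the permit decides how long the slot stays taken. Here's a std-only sketch of that idea. The Semaphore, Permit and make_handler names below are made up for illustration (they're blocking stand-ins, not tokio's or tower's types), and the closure plays the role of the per-connection service:

```rust
use std::sync::{Arc, Condvar, Mutex};

/// A tiny blocking counting semaphore (std-only stand-in for tokio's `Semaphore`).
struct Semaphore {
    permits: Mutex<usize>,
    cond: Condvar,
}

/// Held while a "connection" is alive; gives its slot back on drop.
struct Permit {
    sem: Arc<Semaphore>,
}

impl Semaphore {
    fn new(permits: usize) -> Arc<Self> {
        Arc::new(Self {
            permits: Mutex::new(permits),
            cond: Condvar::new(),
        })
    }

    fn available(&self) -> usize {
        *self.permits.lock().unwrap()
    }

    /// Blocks until a permit is free, then takes it.
    fn acquire_owned(sem: &Arc<Self>) -> Permit {
        let mut n = sem.permits.lock().unwrap();
        while *n == 0 {
            n = sem.cond.wait(n).unwrap();
        }
        *n -= 1;
        Permit { sem: Arc::clone(sem) }
    }
}

impl Drop for Permit {
    fn drop(&mut self) {
        *self.sem.permits.lock().unwrap() += 1;
        self.sem.cond.notify_one();
    }
}

/// Builds a per-connection handler that keeps `permit` alive by capturing it.
fn make_handler(permit: Permit) -> impl Fn(&str) -> String {
    move |req| {
        let _held = &permit; // the closure owns the permit for as long as it lives
        format!("handled {}", req)
    }
}
```

As long as the handler exists, one slot is taken. Drop the handler (in the article's case: the connection goes away, and its service with it) and the slot is free again.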
Now let's try raising that MAX_CONNS value to 50:
$ oha etc. (cut) Latency distribution: 10% in 1.2601 secs 25% in 2.5188 secs 50% in 2.5198 secs 75% in 2.5203 secs 90% in 2.5209 secs 95% in 2.5210 secs 99% in 2.5215 secs
Yup, that checks ou- wait a minute. Those results look really different from before... Mhh. Did we have a max in-flight requests limit before? Let's try removing that:
use std::{convert::Infallible, sync::Arc, time::Duration}; use hyper::{server::conn::AddrStream, service::make_service_fn, Body, Request, Response, Server}; use tokio::sync::Semaphore; use tower::ServiceBuilder; const MAX_CONNS: usize = 50; #[tokio::main] async fn main() { let conns_limit = Arc::new(Semaphore::new(MAX_CONNS)); let app = make_service_fn(move |_stream: &AddrStream| { let conns_limit = conns_limit.clone(); async move { let permit = Arc::new(conns_limit.acquire_owned().await.unwrap()); Ok::<_, Infallible>( ServiceBuilder::new() .then(move |res: Result<Response<Body>, Infallible>| { drop(permit); std::future::ready(res) }) .service_fn(hello_world), ) } }); Server::bind(&([127, 0, 0, 1], 1025).into()) .serve(app) .await .unwrap(); } async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> { println!("{} {}", req.method(), req.uri()); tokio::time::sleep(Duration::from_millis(250)).await; Ok(Response::builder() .body(Body::from("Hello World!\n")) .unwrap()) }
$ Latency distribution: 10% in 0.2517 secs 25% in 0.2519 secs 50% in 0.2521 secs 75% in 0.2524 secs 90% in 0.2531 secs 95% in 0.2543 secs 99% in 0.2550 secs
Okay, that makes more sense. See, if we have a connections limit of 50 but a requests limit of 5, oha thinks it can issue 50 requests at a time (it's over http/1, there's no multiplexing here), but only the first 5 get serviced immediately. The others wait their turn.
~ backpressure ~
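We can sanity-check those numbers with a back-of-the-envelope model. This is only a sketch under simplifying assumptions: all requests show up at once, they're served in fixed batches of "limit" requests, and each takes exactly the 250ms our handler sleeps for. The expected_latency function is hypothetical and has nothing to do with how tower actually schedules things:

```rust
use std::time::Duration;

/// Rough completion time of the `position`-th request (0-based) when only
/// `limit` requests may run at once and each takes `service_time`.
/// Assumes all requests arrive together and run in fixed-size batches.
fn expected_latency(position: u32, limit: u32, service_time: Duration) -> Duration {
    // Requests 0..limit are in batch 1, the next `limit` in batch 2, and so on.
    let batch = position / limit + 1;
    service_time * batch
}
```

With 50 connections and 5 requests in flight, the last request in line is in batch 10: expected_latency(49, 5, Duration::from_millis(250)) comes out at 2.5 seconds, which is in the same neighborhood as the ~2.52s oha reported when both limits were on.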
And now for the grand finale: let's remove the conns limit too. And the sleep.
use std::convert::Infallible; use hyper::{ server::conn::AddrStream, service::{make_service_fn, service_fn}, Body, Request, Response, Server, }; #[tokio::main] async fn main() { let app = make_service_fn(move |_stream: &AddrStream| async move { Ok::<_, Infallible>(service_fn(hello_world)) }); Server::bind(&([127, 0, 0, 1], 1025).into()) .serve(app) .await .unwrap(); } async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> { println!("{} {}", req.method(), req.uri()); Ok(Response::builder() .body(Body::from("Hello World!\n")) .unwrap()) }
And now we have a hyper hello world. And we know exactly what's going on.
Well... there's one piece we haven't really talked about yet. And that's the piece I originally wanted to talk about.
Accepting connections
We've been doing this all along:
Server::bind(&([127, 0, 0, 1], 1025).into())
Without really questioning it.
Well, it's not the only way! We can build a TCP listener ourselves:
use tokio::net::TcpListener; #[tokio::main] async fn main() { let app = make_service_fn(move |_stream: &AddrStream| async move { Ok::<_, Infallible>(service_fn(hello_world)) }); // 👇 let ln = TcpListener::bind("127.0.0.1:1025").await.unwrap(); Server::builder(AddrIncoming::from_listener(ln).unwrap()) .serve(app) .await .unwrap(); }
That works just as well.
But, you see, I found myself wanting to do something weird...
I have this test suite with hundreds of tests, and most of them start listening on some port. Because they all run concurrently, on a machine that's busy doing other things, I can't exactly specify the port every time. Instead, I just pass in port 0, and let the operating system pick a free port for me, like so:
let ln = TcpListener::bind("127.0.0.1:0").await.unwrap(); println!("Listening on {}", ln.local_addr().unwrap());
$ cargo run --release Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia) Finished release [optimized] target(s) in 1.28s Running `target/release/nostalgia` Listening on 127.0.0.1:35137
But in some of the tests, we want to simulate a service that's... up, but not quite up. Like, it has reserved a port (so another service can't bind to it, not without some naughty socket flags we're not using there), but it hasn't started listening for connections yet. Any connection attempt will be met with a RST packet, which is TCP for "we see what you're going for but nah". On the client side, we'd see something like "connection refused".
(As opposed to just dropping the packet, which is TCP for "miss me with that shit")
And that's a super duper basic endeavor with the BSD socket API (what essentially everybody uses - even winsock2), but Rust APIs like libstd and tokio's TcpListener helpfully group together the bind and listen calls, which means we... can't do what we want.
Well, until we remember that, as always, there is a crate for that:
$ cargo add socket2 Updating 'https://github.com/rust-lang/crates.io-index' index Adding socket2 v0.4.4 to dependencies
use socket2::{Domain, Protocol, Socket, Type}; use std::net::{SocketAddr, TcpListener}; #[tokio::main] async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP)).unwrap(); socket.bind(&addr.into()).unwrap(); let addr = socket.local_addr().unwrap().as_socket().unwrap(); println!("Bound but not listening on {}", addr); assert!(TcpListener::bind(addr).is_err()); println!("As expected, nobody else can listen on the same address"); println!("Try curling it, it'll fail (press Enter when done)"); std::io::stdin().read_line(&mut String::new()).unwrap(); socket.listen(128).unwrap(); println!("Okay now we're listening (try curling it now, it should hang)"); std::io::stdin().read_line(&mut String::new()).unwrap(); }
$ cargo run Finished dev [unoptimized + debuginfo] target(s) in 0.02s Running `target/debug/nostalgia` Bound but not listening on 127.0.0.1:44277 As expected, nobody else can listen on the same address Try curling it, it'll fail (press Enter when done) # (in another pane) $ curl http://127.0.0.1:44277 curl: (7) Failed to connect to 127.0.0.1 port 44277 after 0 ms: Connection refused # (back to the server's pane, after pressing enter) Okay now we're listening (try curling it now, it should hang) # (in another pane) curl http://127.0.0.1:44277 (it does hang)
At that point, the kernel's TCP stack has accepted the connection on our behalf (by which I mean it's completed the TCP three-way handshake), and put it in the accept queue, waiting for our call to accept (which would pop it from the queue and let us play with it).
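You can see the accept queue at work with nothing but std. In this sketch (the connect_before_accept name is made up), the client connects, writes, and hangs up before the server ever calls accept, and the data is still there once we get around to accepting:

```rust
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};

/// Connects and writes *before* the server calls `accept`, then accepts and
/// reads what was sent. Returns the bytes the server saw.
fn connect_before_accept() -> std::io::Result<Vec<u8>> {
    // std's `bind` also calls `listen`, so the kernel queues connections from here on.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;

    // Nobody has called `accept` yet, but the handshake completes anyway.
    let mut client = TcpStream::connect(addr)?;
    client.write_all(b"hello")?;
    drop(client); // closing lets the server-side read hit EOF

    // Only now do we pop the connection off the accept queue.
    let (mut conn, _peer) = listener.accept()?;
    let mut buf = Vec::new();
    conn.read_to_end(&mut buf)?;
    Ok(buf)
}
```

That's the "it does hang" from the curl session above: from the client's point of view the connection is established, there's just nobody home yet.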
So, that's exactly the behavior we're looking for: bind, wait a bit, then listen. And eventually, start accepting connections.
Once we're done building our socket2::Socket, we can turn it into a std::net::TcpListener, then a tokio::net::TcpListener quite easily:
println!("Okay now let's accept one connection"); socket.set_nonblocking(true).unwrap(); let fd = socket.as_raw_fd(); std::mem::forget(socket); let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; let ln = tokio::net::TcpListener::from_std(ln).unwrap(); let (_stream, _) = ln.accept().await.unwrap(); println!("Accepted one conn!");
From the client side, we now see a connection reset:
$ curl 127.0.0.1:38145 curl: (56) Recv failure: Connection reset by peer
tokio::net::TcpListener can then be turned into an AddrIncoming, which we can pass to Server::builder... simple right?
Well... not that simple. Because by the time we convert it to a tokio::net::TcpListener, we have to have called listen already.
That's why from_raw_fd is unsafe: not only does it need to be a valid, open file descriptor, that is not going to get closed by another thread right after, but it also, in this case, needs to be listening already.
Of course, we could wait until we create the server itself, and I don't know why this occurs to me JUST NOW, but hey, we got that article out of it so let's not question my motives too much.
So... what I'm getting at is that, if we want that behavior (bind, wait a couple seconds, then listen) with a hyper server, and we haven't realized we could simply wait to build the server, then we can't use AddrIncoming: we need to implement the Accept trait ourselves.
It looks like this:
pub trait Accept { type Conn; type Error; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll<Option<Result<Self::Conn, Self::Error>>>; }
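Before involving real sockets, it can help to see how that signature behaves when polled by hand. The sketch below redeclares a trait with the same shape so it only needs std, and implements it for a made-up FakeAcceptor that isn't ready on the first poll and then hands out queued numbers instead of connections. The drain function and NoopWake type are also just scaffolding for the example:

```rust
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll, Wake, Waker};

/// Same shape as hyper 0.14's `Accept`, redeclared here so the sketch is std-only.
trait Accept {
    type Conn;
    type Error;
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>>;
}

/// Hypothetical acceptor: not ready on the first poll, then yields queued items.
struct FakeAcceptor {
    warmed_up: bool,
    queue: VecDeque<u32>,
}

impl FakeAcceptor {
    fn poll_warm_up(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.warmed_up {
            return Poll::Ready(());
        }
        self.warmed_up = true;
        cx.waker().wake_by_ref(); // ask to be polled again
        Poll::Pending
    }
}

impl Accept for FakeAcceptor {
    type Conn = u32;
    type Error = std::convert::Infallible;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<u32, Self::Error>>> {
        // `ready!` returns `Poll::Pending` early if the inner poll isn't done.
        ready!(self.poll_warm_up(cx));
        // `None` means "no more connections, ever".
        Poll::Ready(self.queue.pop_front().map(Ok))
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls the acceptor by hand, the way an executor would.
/// Returns how many times it saw `Pending`, and everything it accepted.
fn drain(queued: Vec<u32>) -> (usize, Vec<u32>) {
    let mut acc = FakeAcceptor { warmed_up: false, queue: VecDeque::from(queued) };
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let (mut pendings, mut conns) = (0, Vec::new());
    loop {
        match Pin::new(&mut acc).poll_accept(&mut cx) {
            Poll::Pending => pendings += 1,
            Poll::Ready(Some(Ok(conn))) => conns.push(conn),
            Poll::Ready(Some(Err(never))) => match never {},
            Poll::Ready(None) => return (pendings, conns),
        }
    }
}
```

Three outcomes to keep in mind: Pending ("ask me later"), Ready(Some(..)) ("here's a connection, or an error"), and Ready(None) ("I'm done for good"). A real acceptor never busy-loops like drain does, it relies on the waker instead.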
So, okay, let's do it without the sleep, as a warm-up:
$ cargo add color-eyre Updating 'https://github.com/rust-lang/crates.io-index' index Adding color-eyre v0.6.1 to dependencies
use color_eyre::Report; use hyper::{ server::accept::Accept, service::{make_service_fn, service_fn}, Body, Request, Response, }; use socket2::{Domain, Protocol, Socket, Type}; use std::{ convert::Infallible, net::SocketAddr, os::unix::prelude::{AsRawFd, FromRawFd}, pin::Pin, task::Context, time::Duration, }; use tokio::net::TcpStream; #[tokio::main] async fn main() -> Result<(), Report> { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let acc = Acceptor::new(addr)?; hyper::Server::builder(acc) .serve(make_service_fn(|_: &TcpStream| async move { Ok::<_, Report>(service_fn(hello_world)) })) .await?; Ok(()) } async fn hello_world(req: Request<Body>) -> Result<Response<Body>, Infallible> { println!("{} {}", req.method(), req.uri()); tokio::time::sleep(Duration::from_millis(250)).await; Ok(Response::builder() .body(Body::from("Hello World!\n")) .unwrap()) } struct Acceptor { ln: tokio::net::TcpListener, } impl Acceptor { fn new(addr: SocketAddr) -> Result<Self, Report> { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; println!( "Listening on {}...", socket.local_addr()?.as_socket().unwrap() ); socket.listen(128)?; socket.set_nonblocking(true)?; let fd = socket.as_raw_fd(); std::mem::forget(socket); let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; let ln = tokio::net::TcpListener::from_std(ln)?; Ok(Self { ln }) } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> { let (stream, _) = futures::ready!(self.ln.poll_accept(cx)?); Some(Ok(stream)).into() } }
$ cargo run Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia) Finished dev [unoptimized + debuginfo] target(s) in 1.35s Running `target/debug/nostalgia` Binding... Listening on 127.0.0.1:44063... GET / # (in another pane) $ curl 0:44063 Hello World!
And now... let's add the sleep! Well, just as before, we'll need a Sleep future, and for the rest, well... we'll need to hold onto the Socket until we listen and turn it into a tokio::net::TcpListener. So we'll need an enum to know where we're at.
As for the rest, uh... read slowly.
enum Acceptor { Waiting { sleep: Sleep, socket: Socket }, Listening { ln: tokio::net::TcpListener }, } impl Acceptor { fn new(addr: SocketAddr) -> Result<Self, Report> { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; Ok(Self::Waiting { sleep: tokio::time::sleep(Duration::from_secs(2)), socket, }) } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { // Safety: we do our own pin-projection match unsafe { self.as_mut().get_unchecked_mut() } { Acceptor::Waiting { sleep, socket } => { // Safety: `sleep` is structurally pinned let sleep = unsafe { Pin::new_unchecked(sleep) }; futures::ready!(sleep.poll(cx)); println!( "Listening on {}...", socket.local_addr()?.as_socket().unwrap() ); socket.listen(128)?; socket.set_nonblocking(true)?; let fd = socket.as_raw_fd(); // Safety: `fd` comes from a well-initialized and listening `socket2::Socket` let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; let ln = tokio::net::TcpListener::from_std(ln)?; let mut state = Self::Listening { ln }; // Safety: we never use `sleep` anymore, and `socket` is `Unpin` std::mem::swap(unsafe { self.as_mut().get_unchecked_mut() }, &mut state); match state { Acceptor::Waiting { socket, .. } => { // necessary to avoid closing the socket on drop std::mem::forget(socket) } _ => unreachable!(), }; match unsafe { self.get_unchecked_mut() } { Acceptor::Listening { ln } => { let (stream, _) = futures::ready!(ln.poll_accept(cx)?); Some(Ok(stream)).into() } _ => unreachable!(), } } Acceptor::Listening { ln } => { let (stream, _) = futures::ready!(ln.poll_accept(cx)?); Some(Ok(stream)).into() } } } }
This works just fine.
The thing is... I just... I just don't like it. I'm so tired of writing poll-style functions in Rust. So tired. Sure, I feel smart. This is like Sudoku for extra-nerds. Yey, let's be fun at parties together.
But like... where's my async? Where's my await? I don't want to think about pinning if I don't have to.
So let's try to uhh simplify this.
Here's where my mind went first: how about we have a Listener struct with an async method that does everything we need it to?
enum Listener { Waiting { socket: Socket }, Listening { ln: tokio::net::TcpListener }, } impl Listener { fn new(addr: SocketAddr) -> Result<Self, Report> { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; Ok(Self::Waiting { socket }) } async fn accept(&mut self) -> Result<TcpStream, Report> { match self { Listener::Waiting { socket } => { tokio::time::sleep(Duration::from_secs(2)).await; println!( "Listening on {}...", socket.local_addr()?.as_socket().unwrap() ); socket.listen(128)?; socket.set_nonblocking(true)?; let fd = socket.as_raw_fd(); // Safety: `fd` comes from a well-initialized and listening `socket2::Socket` let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; let ln = tokio::net::TcpListener::from_std(ln)?; let mut state = Self::Listening { ln }; std::mem::swap(self, &mut state); match state { Listener::Waiting { socket } => { // necessary to avoid closing the socket on drop std::mem::forget(socket) } _ => unreachable!(), }; match self { Listener::Listening { ln } => Ok(ln.accept().await?.0), _ => unreachable!(), } } Listener::Listening { ln } => Ok(ln.accept().await?.0), } } }
Much clearer! And then all we need to do is... adapt that into the Accept interface, right?
Which is as easy as, uh, mhhh.
Well you want to try something like that, right?
struct Acceptor { listener: Listener, fut: BoxFuture<'static, Result<TcpStream, Report>>, } impl Acceptor { fn from_listener(listener: Listener) -> Self { let fut = listener.accept(); Self { listener, fut: Box::pin(fut), } } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { // project! let (listener, fut) = unsafe { let this = self.get_unchecked_mut(); (&mut this.listener, Pin::new_unchecked(&mut this.fut)) }; let res = futures::ready!(fut.poll(cx)); self.get_mut().fut = Box::pin(listener.accept()); Some(res).into() } }
We can't name the type of the future returned by Listener::accept, so we have to box it. That means we have to pick a lifetime and uhh...
$ cargo check error[E0759]: `self` has an anonymous lifetime `'_` but it needs to satisfy a `'static` lifetime requirement --> src/main.rs:114:29 | 109 | mut self: Pin<&mut Self>, | -------------- this data with an anonymous lifetime `'_`... ... 114 | let this = self.get_unchecked_mut(); | ^^^^^^^^^^^^^^^^^ ...is used here... ... 119 | self.get_mut().fut = Box::pin(listener.accept()); | --------------------------- ...and is required to live as long as `'static` here
Yeah. It ain't static, that's for sure. But I mean, what else can we do? What we have here is a self-referential struct: fut holds a mutable reference to listener - so one field borrows the other.
I even made a whole video about it!
Which is fine, afaik, as long as we "manually match up their lifetimes": if we were to drop listener and then use fut, things would go very wrong.
But like... if we're so convinced it's fine, we can straight up lie to the compiler's face. WHICH YOU SHOULD NEVER DO, but let this be a lesson:
struct Acceptor { listener: Listener, fut: BoxFuture<'static, Result<TcpStream, Report>>, } impl Acceptor { fn from_listener(mut listener: Listener) -> Self { let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>; // Safety: transmuting to the static lifetime 💀 let fut = unsafe { std::mem::transmute(fut) }; Self { listener, fut } } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { // project! let (listener, fut) = unsafe { let this = self.as_mut().get_unchecked_mut(); (&mut this.listener, Pin::new_unchecked(&mut this.fut)) }; let res = futures::ready!(fut.poll(cx)); // Safety: transmuting to the static lifetime 💀 let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>; self.get_mut().fut = unsafe { std::mem::transmute(fut) }; Some(res).into() } }
Does this work?
$ cargo run Compiling nostalgia v0.1.0 (/home/amos/bearcove/nostalgia) Finished dev [unoptimized + debuginfo] target(s) in 1.35s Running `target/debug/nostalgia` Binding... Error: error accepting connection: Socket operation on non-socket (os error 88) Caused by: Socket operation on non-socket (os error 88) Location: src/main.rs:25:5
No! It doesn't! Woops, we broke an invariant. We let a future borrow mutably from Listener, and then we moved it.
We really stepped in it this time.
Of course I know how to fix it, because I ran into three other odd issues (memory corruption, yay!) while trying this particular crime.
You know one way to make the listener stay in place? Shove it on the heap!
struct Acceptor { // 👇 listener: Box<Listener>, fut: BoxFuture<'static, Result<TcpStream, Report>>, } impl Acceptor { fn from_listener(listener: Listener) -> Self { // 👇 let mut listener = Box::new(listener); let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>; // Safety: transmuting to the static lifetime 💀 let fut: BoxFuture<'static, _> = unsafe { std::mem::transmute(fut) }; // Safety: we're moving a _pointer_ to `Listener`, the listener itself // is not moved, it stays pinned. Self { listener, fut } } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { // project! let (listener, fut) = unsafe { let this = self.as_mut().get_unchecked_mut(); (&mut this.listener, Pin::new_unchecked(&mut this.fut)) }; let res = futures::ready!(fut.poll(cx)); // Safety: transmuting to the static lifetime 💀 let fut = Box::pin(listener.accept()) as BoxFuture<'_, _>; self.get_mut().fut = unsafe { std::mem::transmute(fut) }; Some(res).into() } }
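Why does the Box help? Because moving a Box moves a pointer, and the thing it points to stays where it is on the heap. Here's a small std-only check of just that property, with stand-in Listener and Acceptor types (the port field is arbitrary). It says nothing about whether the transmute is a good idea, only that the address doesn't change:

```rust
/// Stand-in for the article's `Listener`; the field is arbitrary.
struct Listener {
    port: u16,
}

struct Acceptor {
    listener: Box<Listener>,
}

/// Returns the heap address of the listener before and after the `Box`
/// is moved into an `Acceptor`.
fn addresses_across_move() -> (usize, usize) {
    let boxed = Box::new(Listener { port: 1025 });
    let before = &*boxed as *const Listener as usize;

    // This moves the pointer, not the pointee.
    let acceptor = Acceptor { listener: boxed };
    let after = &*acceptor.listener as *const Listener as usize;

    // The value itself is untouched by the move.
    assert_eq!(acceptor.listener.port, 1025);
    (before, after)
}
```

With a plain Listener field, the struct literal would have copied the listener's bytes to a new location, and any reference the future was holding would be pointing at the old one.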
And now it works! But at what cost? At what cost?
We can do better than this. Much better.
Just move stuff. That's it. That's the whole tweet.
It is trivial to avoid self-referential structs in this case.
We can simply... change our function signature to this:
impl Listener { fn new(addr: SocketAddr) -> Result<Self, Report> { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; Ok(Self::Waiting { socket }) } // now taking 👇 ownership async fn accept(mut self) -> Result<(Self, TcpStream), Report> { match self { Listener::Waiting { socket } => { tokio::time::sleep(Duration::from_secs(2)).await; println!( "Listening on {}...", socket.local_addr()?.as_socket().unwrap() ); socket.listen(128)?; socket.set_nonblocking(true)?; let fd = socket.as_raw_fd(); // Safety: `fd` comes from a well-initialized and listening `socket2::Socket` let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; std::mem::forget(socket); let ln = tokio::net::TcpListener::from_std(ln)?; let conn = ln.accept().await?.0; Ok((Self::Listening { ln }, conn)) } Listener::Listening { ref mut ln } => { let conn = ln.accept().await?.0; Ok((self, conn)) } } } }
This gets rid of the match with unreachable! arms, which is nice!
And then our acceptor becomes simply this:
struct Acceptor { fut: BoxFuture<'static, Result<(Listener, TcpStream), Report>>, } impl Acceptor { fn from_listener(listener: Listener) -> Self { Self { fut: Box::pin(listener.accept()), } } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { let res = futures::ready!(self.fut.poll_unpin(cx)); let (listener, stream) = match res { Ok(tup) => tup, Err(e) => return Some(Err(e)).into(), }; self.fut = Box::pin(listener.accept()); Some(Ok(stream)).into() } }
But... we could've done that without changing the function signature at all. If we go back to the &mut self version, we can simply implement Acceptor like this:
struct Acceptor { fut: BoxFuture<'static, (Listener, Result<TcpStream, Report>)>, } impl Acceptor { fn from_listener(mut listener: Listener) -> Self { Self { fut: Box::pin(async move { let res = listener.accept().await; (listener, res) }), } } } impl Accept for Acceptor { type Conn = TcpStream; type Error = Report; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { let (mut listener, res) = futures::ready!(self.fut.poll_unpin(cx)); self.fut = Box::pin(async move { let res = listener.accept().await; (listener, res) }); Some(res).into() } }
There! Isn't that nice? The listener is moved into the async block (the future), but the whole future is boxed. The future is a self-referential struct, we just don't need to be particularly careful with it.
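If you want to poke at that pattern without hyper, tokio or futures in the way, here's a std-only sketch of it. Counter is a made-up stand-in for Listener (its accept just counts), the Acceptor here has a plain poll_accept method instead of implementing a trait, and accept_n polls by hand with a no-op waker, which only works because this particular accept never actually waits:

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for `Listener`: each "accept" yields the next number.
struct Counter {
    next: u32,
}

impl Counter {
    async fn accept(&mut self) -> u32 {
        self.next += 1;
        self.next
    }
}

type AcceptFuture = Pin<Box<dyn Future<Output = (Counter, u32)>>>;

/// Moves `counter` into the future, and gets it back alongside the result.
fn arm(mut counter: Counter) -> AcceptFuture {
    Box::pin(async move {
        let res = counter.accept().await;
        (counter, res)
    })
}

struct Acceptor {
    fut: AcceptFuture,
}

impl Acceptor {
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<u32> {
        match self.fut.as_mut().poll(cx) {
            Poll::Ready((counter, res)) => {
                // The old future is finished; re-arm with the state it returned.
                self.fut = arm(counter);
                Poll::Ready(res)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls by hand with a no-op waker; fine here since `accept` never waits.
fn accept_n(n: usize) -> Vec<u32> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut acc = Acceptor { fut: arm(Counter { next: 0 }) };
    let mut out = Vec::new();
    while out.len() < n {
        if let Poll::Ready(v) = acc.poll_accept(&mut cx) {
            out.push(v);
        }
    }
    out
}
```

There's no unsafe and no lifetime to lie about: at any given time the state is owned by exactly one thing, either the in-flight future or (very briefly) poll_accept.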
This pattern is so common that there's even a name for it: unfold. It's also a function from the futures crate that gives us a Stream.
struct Acceptor<S>(S); impl Listener { fn into_acceptor(self) -> Acceptor<impl Stream<Item = Result<TcpStream, Report>>> { Acceptor(unfold(self, |mut ln| async move { let stream = ln.accept().await; Some((stream, ln)) })) } } impl<S> Accept for Acceptor<S> where S: Stream<Item = Result<TcpStream, Report>>, { type Conn = TcpStream; type Error = Report; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { // project! let stream = unsafe { self.map_unchecked_mut(|this| &mut this.0) }; futures::ready!(stream.poll_next(cx)).into() } }
Not gonna lie, I was pretty happy when I found this.
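If the shape of unfold is new to you, it may be easier to look at a synchronous cousin first. This is a hand-rolled iterator version written for this example (it is not the futures crate's implementation, and the names are only borrowed from it): the closure takes the state by value and returns the next item along with the state to use next time, or None to stop.

```rust
/// Iterator version of `unfold`: `f` gets the state by value and returns the
/// next item together with the state to use next time, or `None` to stop.
struct Unfold<S, F> {
    state: Option<S>,
    f: F,
}

fn unfold<S, T, F>(init: S, f: F) -> Unfold<S, F>
where
    F: FnMut(S) -> Option<(T, S)>,
{
    Unfold { state: Some(init), f }
}

impl<S, T, F> Iterator for Unfold<S, F>
where
    F: FnMut(S) -> Option<(T, S)>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // Move the state out, run one step, move the new state back in.
        let state = self.state.take()?;
        let (item, state) = (self.f)(state)?;
        self.state = Some(state);
        Some(item)
    }
}

/// Hypothetical example: counts up from `start` while below `limit`.
fn count_up(start: u32, limit: u32) -> Vec<u32> {
    unfold(start, |n| if n < limit { Some((n, n + 1)) } else { None }).collect()
}
```

The async version does the same dance, except the closure returns a future, so the state gets moved into that future and handed back when it completes. Which is exactly what we were doing by hand with the boxed future.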
But we can go even further! This is such a common pattern, hyper ships with an accept::from_stream function!
# in `Cargo.toml` [dependencies] hyper = { version = "0.14", features = ["http1", "tcp", "server", "stream"] } # new: stream feature
And now we don't even need an Acceptor struct at all:
impl Listener { fn into_acceptor(self) -> impl Accept<Conn = TcpStream, Error = Report> { hyper::server::accept::from_stream(unfold(self, |mut ln| async move { let stream = ln.accept().await; Some((stream, ln)) })) } }
And it does the exact same thing.
Heck, we can go even further. Think unfold is confusing? Not ready to hop on the functional programming train yet?
Then have a macro!
$ cargo add async-stream Updating 'https://github.com/rust-lang/crates.io-index' index Adding async-stream v0.3.3 to dependencies
impl Listener { fn into_acceptor(mut self) -> impl Accept<Conn = TcpStream, Error = Report> { from_stream(async_stream::stream! { loop { yield self.accept().await; } }) } }
It really doesn't get any better than this.
Or does it?
Simplifying the listener
After I posted this article, /u/usr_bin_nya on twitter pointed out that we can apply two more simplifications.
The first one is: instead of doing an std::mem::forget dance, we can use into_raw_fd, which transfers ownership of the underlying file descriptor.
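Here's what that ownership transfer looks like in isolation, using only std types (Unix only, since it goes through std::os::unix). The round_trip function is made up for the example; it pushes a std TcpListener through a raw fd and back, and checks it's still the same socket:

```rust
use std::net::{SocketAddr, TcpListener};
use std::os::unix::io::{FromRawFd, IntoRawFd};

/// Round-trips a listener through a raw file descriptor (Unix only).
/// Returns the local address before and after, which should match.
fn round_trip() -> std::io::Result<(SocketAddr, SocketAddr)> {
    let original = TcpListener::bind("127.0.0.1:0")?;
    let before = original.local_addr()?;

    // `into_raw_fd` consumes `original`: nothing is left to close the fd on
    // drop, so there's no need for `std::mem::forget`.
    let fd = original.into_raw_fd();

    // Safety: `fd` is an open, listening socket that nobody else owns.
    let restored = unsafe { TcpListener::from_raw_fd(fd) };
    let after = restored.local_addr()?;

    Ok((before, after))
}
```

With as_raw_fd, both the old and the new value think they own the descriptor, hence the forget. With into_raw_fd there's only ever one owner.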
It's actually awkward to do it with the current design of Listener, because we're building a tokio::net::TcpListener before letting go of the socket2::Socket. We could do it if we had a third enum variant, a pattern that's actually quite common when writing state machines in Rust:
enum Listener { Waiting { socket: Socket }, Listening { ln: tokio::net::TcpListener }, // 👇 new variant Transition, } impl Listener { fn new(addr: SocketAddr) -> Result<Self, Report> { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; Ok(Self::Waiting { socket }) } async fn accept(&mut self) -> Result<TcpStream, Report> { match self { Listener::Waiting { socket } => { tokio::time::sleep(Duration::from_secs(2)).await; println!( "Listening on {}...", socket.local_addr()?.as_socket().unwrap() ); // 👇 swapping 'Transition' in so we can take ownership of socket let mut state = Self::Transition; std::mem::swap(self, &mut state); // 👇 that kind of awkward code is super common when writing // Rust state machines by hand. there's crates that make it // better. let socket = match state { Self::Waiting { socket } => socket, _ => unreachable!(), }; socket.listen(128)?; socket.set_nonblocking(true)?; // 👇 Using `into_raw_fd` instead of `as_raw_fd` let fd = socket.into_raw_fd(); // Safety: `fd` comes from a well-initialized and listening `socket2::Socket` let ln = unsafe { std::net::TcpListener::from_raw_fd(fd) }; let ln = tokio::net::TcpListener::from_std(ln)?; *self = Self::Listening { ln }; match self { Listener::Listening { ln } => Ok(ln.accept().await?.0), _ => unreachable!(), } } Listener::Listening { ln } => Ok(ln.accept().await?.0), // 👇 Using unreachable! because reaching that part would be a bug Listener::Transition => unreachable!(), } } }
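Stripped of all the socket business, the placeholder-variant pattern looks something like the sketch below. The State enum and its fields are hypothetical (a String stands in for the resource we need to own), and it uses std::mem::replace, which is the swap-with-a-temporary from above in a single call:

```rust
/// Hypothetical two-step state machine; `String` stands in for a non-`Copy` resource.
enum State {
    Waiting { name: String },
    Ready { greeting: String },
    // Placeholder that only exists while we move from one state to the next.
    Transition,
}

impl State {
    /// Advances `Waiting` to `Ready`, consuming the owned `name` along the way.
    fn advance(&mut self) -> &str {
        // Swap the placeholder in, so we own the old state.
        let next = match std::mem::replace(self, State::Transition) {
            State::Waiting { name } => State::Ready {
                greeting: format!("hello, {}", name),
            },
            // Already there: put it back unchanged.
            ready @ State::Ready { .. } => ready,
            State::Transition => unreachable!("left in Transition by a bug or a panic"),
        };
        *self = next;
        match self {
            State::Ready { greeting } => greeting.as_str(),
            _ => unreachable!(),
        }
    }
}

/// Calls `advance` twice: once from `Waiting`, once from `Ready`.
fn demo() -> (String, String) {
    let mut state = State::Waiting { name: "hyper".to_string() };
    let first = state.advance().to_string();
    let second = state.advance().to_string();
    (first, second)
}
```

One thing to watch out for: if anything between the replace and the final assignment returns early (say, a ? on a fallible call), the value is left in Transition for good. That's also true of the Listener above if listen fails.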
But as it turns out, it gets much, much better still. Just like we discovered unfold before, there is a try_flatten_stream method provided by TryFutureExt, in the futures crate.
It lets us turn a "future that returns a stream" into "a stream".
And with it, our whole acceptor becomes this:
fn delayed_acceptor( addr: SocketAddr, delay: Duration, ) -> impl Accept<Conn = TcpStream, Error = std::io::Error> { from_stream( async move { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; tokio::time::sleep(delay).await; println!("Listening..."); socket.listen(128)?; socket.set_nonblocking(true)?; let ln = tokio::net::TcpListener::from_std(unsafe { std::net::TcpListener::from_raw_fd(socket.into_raw_fd()) })?; let stream = unfold(ln, |ln| async move { let stream = ln.accept().await.map(|(stream, _)| stream); Some((stream, ln)) }); Ok(stream) } .try_flatten_stream(), ) }
And can be used like this:
let acc = delayed_acceptor(addr, Duration::from_secs(2)); hyper::Server::builder(acc) // etc.
And if we're willing to use the async-stream macros, we can simplify that some more!
fn delayed_acceptor( addr: SocketAddr, delay: Duration, ) -> impl Accept<Conn = TcpStream, Error = std::io::Error> { from_stream(async_stream::try_stream! { let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; println!("Binding..."); socket.bind(&addr.into())?; tokio::time::sleep(delay).await; println!("Listening..."); socket.listen(128)?; socket.set_nonblocking(true)?; let ln = tokio::net::TcpListener::from_std(unsafe { std::net::TcpListener::from_raw_fd(socket.into_raw_fd()) })?; loop { yield ln.accept().await?.0; } }) }
Now it doesn't get any better. For now.
Update: it does get even better, thanks to /u/Shadow0133 on reddit for suggesting this; we don't need any unsafe code at all, since there's an impl From<socket2::Socket> for std::net::TcpListener.
This code:
let ln = tokio::net::TcpListener::from_std(unsafe { std::net::TcpListener::from_raw_fd(socket.into_raw_fd()) })?;
Becomes:
let ln = tokio::net::TcpListener::from_std(socket.into())?;
How neat!
One last thing...
Yes?
What's your new favorite http framework in Rust?
axum! It lets me not care about hyperisms/towerisms most of the time but still dive down into them / tap into their ecosystem whenever I want. I really recommend looking at it.
I hope you've enjoyed this journey through some of the best and most awkward bits of async Rust, and until next time, take exemplary care of yourself.