Make SSE less dependent on tokio (#3154)

This commit is contained in:
Nano 2025-05-01 13:54:29 +05:00 committed by GitHub
parent bf7c5fc5f3
commit 6ad76dd9a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 72 deletions

View File

@ -4,7 +4,6 @@ use http::{header, HeaderValue, StatusCode};
mod redirect; mod redirect;
#[cfg(feature = "tokio")]
pub mod sse; pub mod sse;
#[doc(no_inline)] #[doc(no_inline)]
@ -27,7 +26,6 @@ pub use axum_core::response::{
pub use self::redirect::Redirect; pub use self::redirect::Redirect;
#[doc(inline)] #[doc(inline)]
#[cfg(feature = "tokio")]
pub use sse::Sse; pub use sse::Sse;
/// An HTML response. /// An HTML response.

View File

@ -38,21 +38,18 @@ use futures_util::stream::{Stream, TryStream};
use http_body::Frame; use http_body::Frame;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
fmt, fmt, mem,
future::Future,
pin::Pin, pin::Pin,
task::{ready, Context, Poll}, task::{ready, Context, Poll},
time::Duration, time::Duration,
}; };
use sync_wrapper::SyncWrapper; use sync_wrapper::SyncWrapper;
use tokio::time::Sleep;
/// An SSE response /// An SSE response
#[derive(Clone)] #[derive(Clone)]
#[must_use] #[must_use]
pub struct Sse<S> { pub struct Sse<S> {
stream: S, stream: S,
keep_alive: Option<KeepAlive>,
} }
impl<S> Sse<S> { impl<S> Sse<S> {
@ -65,18 +62,17 @@ impl<S> Sse<S> {
S: TryStream<Ok = Event> + Send + 'static, S: TryStream<Ok = Event> + Send + 'static,
S::Error: Into<BoxError>, S::Error: Into<BoxError>,
{ {
Sse { Sse { stream }
stream,
keep_alive: None,
}
} }
/// Configure the interval between keep-alive messages. /// Configure the interval between keep-alive messages.
/// ///
/// Defaults to no keep-alive messages. /// Defaults to no keep-alive messages.
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { #[cfg(feature = "tokio")]
self.keep_alive = Some(keep_alive); pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
self Sse {
stream: KeepAliveStream::new(keep_alive, self.stream),
}
} }
} }
@ -84,7 +80,6 @@ impl<S> fmt::Debug for Sse<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sse") f.debug_struct("Sse")
.field("stream", &format_args!("{}", std::any::type_name::<S>())) .field("stream", &format_args!("{}", std::any::type_name::<S>()))
.field("keep_alive", &self.keep_alive)
.finish() .finish()
} }
} }
@ -102,7 +97,6 @@ where
], ],
Body::new(SseBody { Body::new(SseBody {
event_stream: SyncWrapper::new(self.stream), event_stream: SyncWrapper::new(self.stream),
keep_alive: self.keep_alive.map(KeepAliveStream::new),
}), }),
) )
.into_response() .into_response()
@ -113,8 +107,6 @@ pin_project! {
struct SseBody<S> { struct SseBody<S> {
#[pin] #[pin]
event_stream: SyncWrapper<S>, event_stream: SyncWrapper<S>,
#[pin]
keep_alive: Option<KeepAliveStream>,
} }
} }
@ -131,35 +123,67 @@ where
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project(); let this = self.project();
match this.event_stream.get_pin_mut().poll_next(cx) { match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
Poll::Pending => { Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
if let Some(keep_alive) = this.keep_alive.as_pin_mut() { Some(Err(error)) => Poll::Ready(Some(Err(error))),
keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e)))) None => Poll::Ready(None),
} else { }
Poll::Pending }
}
/// The state of an event's buffer.
///
/// This type allows creating events in a `const` context
/// by using a finalized buffer.
///
/// While the buffer is active, more bytes can be written to it.
/// Once finalized, it's immutable and cheap to clone.
/// The buffer is active during the event building, but eventually
/// becomes finalized to send http body frames as [`Bytes`].
#[derive(Debug, Clone)]
enum Buffer {
Active(BytesMut),
Finalized(Bytes),
}
impl Buffer {
/// Returns a mutable reference to the internal buffer.
///
/// If the buffer was finalized, this method creates
/// a new active buffer with the previous contents.
fn as_mut(&mut self) -> &mut BytesMut {
match self {
Buffer::Active(bytes_mut) => bytes_mut,
Buffer::Finalized(bytes) => {
*self = Buffer::Active(BytesMut::from(mem::take(bytes)));
match self {
Buffer::Active(bytes_mut) => bytes_mut,
Buffer::Finalized(_) => unreachable!(),
} }
} }
Poll::Ready(Some(Ok(event))) => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.reset();
}
Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
} }
} }
} }
/// Server-sent event /// Server-sent event
#[derive(Debug, Default, Clone)] #[derive(Debug, Clone)]
#[must_use] #[must_use]
pub struct Event { pub struct Event {
buffer: BytesMut, buffer: Buffer,
flags: EventFlags, flags: EventFlags,
} }
impl Event { impl Event {
/// Default keep-alive event
pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
const fn finalized(bytes: Bytes) -> Self {
Self {
buffer: Buffer::Finalized(bytes),
flags: EventFlags::from_bits(0),
}
}
/// Set the event's data data field(s) (`data: <content>`) /// Set the event's data data field(s) (`data: <content>`)
/// ///
/// Newlines in `data` will automatically be broken across `data: ` fields. /// Newlines in `data` will automatically be broken across `data: ` fields.
@ -179,7 +203,7 @@ impl Event {
T: AsRef<str>, T: AsRef<str>,
{ {
if self.flags.contains(EventFlags::HAS_DATA) { if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::data` multiple times"); panic!("Called `Event::data` multiple times");
} }
for line in memchr_split(b'\n', data.as_ref().as_bytes()) { for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
@ -222,13 +246,14 @@ impl Event {
} }
} }
if self.flags.contains(EventFlags::HAS_DATA) { if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::json_data` multiple times"); panic!("Called `Event::json_data` multiple times");
} }
self.buffer.extend_from_slice(b"data: "); let buffer = self.buffer.as_mut();
serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data) buffer.extend_from_slice(b"data: ");
serde_json::to_writer(IgnoreNewLines(buffer.writer()), &data)
.map_err(axum_core::Error::new)?; .map_err(axum_core::Error::new)?;
self.buffer.put_u8(b'\n'); buffer.put_u8(b'\n');
self.flags.insert(EventFlags::HAS_DATA); self.flags.insert(EventFlags::HAS_DATA);
@ -272,7 +297,7 @@ impl Event {
T: AsRef<str>, T: AsRef<str>,
{ {
if self.flags.contains(EventFlags::HAS_EVENT) { if self.flags.contains(EventFlags::HAS_EVENT) {
panic!("Called `EventBuilder::event` multiple times"); panic!("Called `Event::event` multiple times");
} }
self.flags.insert(EventFlags::HAS_EVENT); self.flags.insert(EventFlags::HAS_EVENT);
@ -292,33 +317,32 @@ impl Event {
/// Panics if this function has already been called on this event. /// Panics if this function has already been called on this event.
pub fn retry(mut self, duration: Duration) -> Event { pub fn retry(mut self, duration: Duration) -> Event {
if self.flags.contains(EventFlags::HAS_RETRY) { if self.flags.contains(EventFlags::HAS_RETRY) {
panic!("Called `EventBuilder::retry` multiple times"); panic!("Called `Event::retry` multiple times");
} }
self.flags.insert(EventFlags::HAS_RETRY); self.flags.insert(EventFlags::HAS_RETRY);
self.buffer.extend_from_slice(b"retry:"); let buffer = self.buffer.as_mut();
buffer.extend_from_slice(b"retry:");
let secs = duration.as_secs(); let secs = duration.as_secs();
let millis = duration.subsec_millis(); let millis = duration.subsec_millis();
if secs > 0 { if secs > 0 {
// format seconds // format seconds
self.buffer buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
// pad milliseconds // pad milliseconds
if millis < 10 { if millis < 10 {
self.buffer.extend_from_slice(b"00"); buffer.extend_from_slice(b"00");
} else if millis < 100 { } else if millis < 100 {
self.buffer.extend_from_slice(b"0"); buffer.extend_from_slice(b"0");
} }
} }
// format milliseconds // format milliseconds
self.buffer buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
self.buffer.put_u8(b'\n'); buffer.put_u8(b'\n');
self self
} }
@ -340,7 +364,7 @@ impl Event {
T: AsRef<str>, T: AsRef<str>,
{ {
if self.flags.contains(EventFlags::HAS_ID) { if self.flags.contains(EventFlags::HAS_ID) {
panic!("Called `EventBuilder::id` multiple times"); panic!("Called `Event::id` multiple times");
} }
self.flags.insert(EventFlags::HAS_ID); self.flags.insert(EventFlags::HAS_ID);
@ -362,20 +386,36 @@ impl Event {
None, None,
"SSE field value cannot contain newlines or carriage returns", "SSE field value cannot contain newlines or carriage returns",
); );
self.buffer.extend_from_slice(name.as_bytes());
self.buffer.put_u8(b':'); let buffer = self.buffer.as_mut();
self.buffer.put_u8(b' '); buffer.extend_from_slice(name.as_bytes());
self.buffer.extend_from_slice(value); buffer.put_u8(b':');
self.buffer.put_u8(b'\n'); buffer.put_u8(b' ');
buffer.extend_from_slice(value);
buffer.put_u8(b'\n');
} }
fn finalize(mut self) -> Bytes { fn finalize(self) -> Bytes {
self.buffer.put_u8(b'\n'); match self.buffer {
self.buffer.freeze() Buffer::Finalized(bytes) => bytes,
Buffer::Active(mut bytes_mut) => {
bytes_mut.put_u8(b'\n');
bytes_mut.freeze()
}
}
} }
} }
#[derive(Default, Debug, Copy, Clone, PartialEq)] impl Default for Event {
fn default() -> Self {
Self {
buffer: Buffer::Active(BytesMut::new()),
flags: EventFlags::from_bits(0),
}
}
}
#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8); struct EventFlags(u8);
impl EventFlags { impl EventFlags {
@ -406,7 +446,7 @@ impl EventFlags {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[must_use] #[must_use]
pub struct KeepAlive { pub struct KeepAlive {
event: Bytes, event: Event,
max_interval: Duration, max_interval: Duration,
} }
@ -414,7 +454,7 @@ impl KeepAlive {
/// Create a new `KeepAlive`. /// Create a new `KeepAlive`.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
event: Bytes::from_static(b":\n\n"), event: Event::DEFAULT_KEEP_ALIVE,
max_interval: Duration::from_secs(15), max_interval: Duration::from_secs(15),
} }
} }
@ -451,7 +491,7 @@ impl KeepAlive {
/// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
/// comments. /// comments.
pub fn event(mut self, event: Event) -> Self { pub fn event(mut self, event: Event) -> Self {
self.event = event.finalize(); self.event = Event::finalized(event.finalize());
self self
} }
} }
@ -462,19 +502,25 @@ impl Default for KeepAlive {
} }
} }
#[cfg(feature = "tokio")]
pin_project! { pin_project! {
/// A wrapper around a stream that produces keep-alive events
#[derive(Debug)] #[derive(Debug)]
struct KeepAliveStream { pub struct KeepAliveStream<S> {
keep_alive: KeepAlive,
#[pin] #[pin]
alive_timer: Sleep, alive_timer: tokio::time::Sleep,
#[pin]
inner: S,
keep_alive: KeepAlive,
} }
} }
impl KeepAliveStream { #[cfg(feature = "tokio")]
fn new(keep_alive: KeepAlive) -> Self { impl<S> KeepAliveStream<S> {
fn new(keep_alive: KeepAlive, inner: S) -> Self {
Self { Self {
alive_timer: tokio::time::sleep(keep_alive.max_interval), alive_timer: tokio::time::sleep(keep_alive.max_interval),
inner,
keep_alive, keep_alive,
} }
} }
@ -484,17 +530,38 @@ impl KeepAliveStream {
this.alive_timer this.alive_timer
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval); .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
} }
}
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> { #[cfg(feature = "tokio")]
let this = self.as_mut().project(); impl<S, E> Stream for KeepAliveStream<S>
where
S: Stream<Item = Result<Event, E>>,
{
type Item = Result<Event, E>;
ready!(this.alive_timer.poll(cx)); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use std::future::Future;
let event = this.keep_alive.event.clone(); let mut this = self.as_mut().project();
self.reset(); match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(event))) => {
self.reset();
Poll::Ready(event) Poll::Ready(Some(Ok(event)))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => {
ready!(this.alive_timer.poll(cx));
let event = this.keep_alive.event.clone();
self.reset();
Poll::Ready(Some(Ok(event)))
}
}
} }
} }