mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-30 06:21:07 +00:00
Make SSE less dependent on tokio (#3154)
This commit is contained in:
parent
bf7c5fc5f3
commit
6ad76dd9a4
@ -4,7 +4,6 @@ use http::{header, HeaderValue, StatusCode};
|
|||||||
|
|
||||||
mod redirect;
|
mod redirect;
|
||||||
|
|
||||||
#[cfg(feature = "tokio")]
|
|
||||||
pub mod sse;
|
pub mod sse;
|
||||||
|
|
||||||
#[doc(no_inline)]
|
#[doc(no_inline)]
|
||||||
@ -27,7 +26,6 @@ pub use axum_core::response::{
|
|||||||
pub use self::redirect::Redirect;
|
pub use self::redirect::Redirect;
|
||||||
|
|
||||||
#[doc(inline)]
|
#[doc(inline)]
|
||||||
#[cfg(feature = "tokio")]
|
|
||||||
pub use sse::Sse;
|
pub use sse::Sse;
|
||||||
|
|
||||||
/// An HTML response.
|
/// An HTML response.
|
||||||
|
@ -38,21 +38,18 @@ use futures_util::stream::{Stream, TryStream};
|
|||||||
use http_body::Frame;
|
use http_body::Frame;
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
fmt,
|
fmt, mem,
|
||||||
future::Future,
|
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{ready, Context, Poll},
|
task::{ready, Context, Poll},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use sync_wrapper::SyncWrapper;
|
use sync_wrapper::SyncWrapper;
|
||||||
use tokio::time::Sleep;
|
|
||||||
|
|
||||||
/// An SSE response
|
/// An SSE response
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub struct Sse<S> {
|
pub struct Sse<S> {
|
||||||
stream: S,
|
stream: S,
|
||||||
keep_alive: Option<KeepAlive>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Sse<S> {
|
impl<S> Sse<S> {
|
||||||
@ -65,18 +62,17 @@ impl<S> Sse<S> {
|
|||||||
S: TryStream<Ok = Event> + Send + 'static,
|
S: TryStream<Ok = Event> + Send + 'static,
|
||||||
S::Error: Into<BoxError>,
|
S::Error: Into<BoxError>,
|
||||||
{
|
{
|
||||||
Sse {
|
Sse { stream }
|
||||||
stream,
|
|
||||||
keep_alive: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configure the interval between keep-alive messages.
|
/// Configure the interval between keep-alive messages.
|
||||||
///
|
///
|
||||||
/// Defaults to no keep-alive messages.
|
/// Defaults to no keep-alive messages.
|
||||||
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
|
#[cfg(feature = "tokio")]
|
||||||
self.keep_alive = Some(keep_alive);
|
pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
|
||||||
self
|
Sse {
|
||||||
|
stream: KeepAliveStream::new(keep_alive, self.stream),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +80,6 @@ impl<S> fmt::Debug for Sse<S> {
|
|||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Sse")
|
f.debug_struct("Sse")
|
||||||
.field("stream", &format_args!("{}", std::any::type_name::<S>()))
|
.field("stream", &format_args!("{}", std::any::type_name::<S>()))
|
||||||
.field("keep_alive", &self.keep_alive)
|
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,7 +97,6 @@ where
|
|||||||
],
|
],
|
||||||
Body::new(SseBody {
|
Body::new(SseBody {
|
||||||
event_stream: SyncWrapper::new(self.stream),
|
event_stream: SyncWrapper::new(self.stream),
|
||||||
keep_alive: self.keep_alive.map(KeepAliveStream::new),
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.into_response()
|
.into_response()
|
||||||
@ -113,8 +107,6 @@ pin_project! {
|
|||||||
struct SseBody<S> {
|
struct SseBody<S> {
|
||||||
#[pin]
|
#[pin]
|
||||||
event_stream: SyncWrapper<S>,
|
event_stream: SyncWrapper<S>,
|
||||||
#[pin]
|
|
||||||
keep_alive: Option<KeepAliveStream>,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,35 +123,67 @@ where
|
|||||||
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
|
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
|
|
||||||
match this.event_stream.get_pin_mut().poll_next(cx) {
|
match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
|
||||||
Poll::Pending => {
|
Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
|
||||||
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
|
Some(Err(error)) => Poll::Ready(Some(Err(error))),
|
||||||
keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
|
None => Poll::Ready(None),
|
||||||
} else {
|
}
|
||||||
Poll::Pending
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The state of an event's buffer.
|
||||||
|
///
|
||||||
|
/// This type allows creating events in a `const` context
|
||||||
|
/// by using a finalized buffer.
|
||||||
|
///
|
||||||
|
/// While the buffer is active, more bytes can be written to it.
|
||||||
|
/// Once finalized, it's immutable and cheap to clone.
|
||||||
|
/// The buffer is active during the event building, but eventually
|
||||||
|
/// becomes finalized to send http body frames as [`Bytes`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Buffer {
|
||||||
|
Active(BytesMut),
|
||||||
|
Finalized(Bytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Buffer {
|
||||||
|
/// Returns a mutable reference to the internal buffer.
|
||||||
|
///
|
||||||
|
/// If the buffer was finalized, this method creates
|
||||||
|
/// a new active buffer with the previous contents.
|
||||||
|
fn as_mut(&mut self) -> &mut BytesMut {
|
||||||
|
match self {
|
||||||
|
Buffer::Active(bytes_mut) => bytes_mut,
|
||||||
|
Buffer::Finalized(bytes) => {
|
||||||
|
*self = Buffer::Active(BytesMut::from(mem::take(bytes)));
|
||||||
|
match self {
|
||||||
|
Buffer::Active(bytes_mut) => bytes_mut,
|
||||||
|
Buffer::Finalized(_) => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Poll::Ready(Some(Ok(event))) => {
|
|
||||||
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
|
|
||||||
keep_alive.reset();
|
|
||||||
}
|
|
||||||
Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
|
|
||||||
}
|
|
||||||
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
|
|
||||||
Poll::Ready(None) => Poll::Ready(None),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server-sent event
|
/// Server-sent event
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub struct Event {
|
pub struct Event {
|
||||||
buffer: BytesMut,
|
buffer: Buffer,
|
||||||
flags: EventFlags,
|
flags: EventFlags,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Event {
|
impl Event {
|
||||||
|
/// Default keep-alive event
|
||||||
|
pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
|
||||||
|
|
||||||
|
const fn finalized(bytes: Bytes) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: Buffer::Finalized(bytes),
|
||||||
|
flags: EventFlags::from_bits(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the event's data data field(s) (`data: <content>`)
|
/// Set the event's data data field(s) (`data: <content>`)
|
||||||
///
|
///
|
||||||
/// Newlines in `data` will automatically be broken across `data: ` fields.
|
/// Newlines in `data` will automatically be broken across `data: ` fields.
|
||||||
@ -179,7 +203,7 @@ impl Event {
|
|||||||
T: AsRef<str>,
|
T: AsRef<str>,
|
||||||
{
|
{
|
||||||
if self.flags.contains(EventFlags::HAS_DATA) {
|
if self.flags.contains(EventFlags::HAS_DATA) {
|
||||||
panic!("Called `EventBuilder::data` multiple times");
|
panic!("Called `Event::data` multiple times");
|
||||||
}
|
}
|
||||||
|
|
||||||
for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
|
for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
|
||||||
@ -222,13 +246,14 @@ impl Event {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if self.flags.contains(EventFlags::HAS_DATA) {
|
if self.flags.contains(EventFlags::HAS_DATA) {
|
||||||
panic!("Called `EventBuilder::json_data` multiple times");
|
panic!("Called `Event::json_data` multiple times");
|
||||||
}
|
}
|
||||||
|
|
||||||
self.buffer.extend_from_slice(b"data: ");
|
let buffer = self.buffer.as_mut();
|
||||||
serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
|
buffer.extend_from_slice(b"data: ");
|
||||||
|
serde_json::to_writer(IgnoreNewLines(buffer.writer()), &data)
|
||||||
.map_err(axum_core::Error::new)?;
|
.map_err(axum_core::Error::new)?;
|
||||||
self.buffer.put_u8(b'\n');
|
buffer.put_u8(b'\n');
|
||||||
|
|
||||||
self.flags.insert(EventFlags::HAS_DATA);
|
self.flags.insert(EventFlags::HAS_DATA);
|
||||||
|
|
||||||
@ -272,7 +297,7 @@ impl Event {
|
|||||||
T: AsRef<str>,
|
T: AsRef<str>,
|
||||||
{
|
{
|
||||||
if self.flags.contains(EventFlags::HAS_EVENT) {
|
if self.flags.contains(EventFlags::HAS_EVENT) {
|
||||||
panic!("Called `EventBuilder::event` multiple times");
|
panic!("Called `Event::event` multiple times");
|
||||||
}
|
}
|
||||||
self.flags.insert(EventFlags::HAS_EVENT);
|
self.flags.insert(EventFlags::HAS_EVENT);
|
||||||
|
|
||||||
@ -292,33 +317,32 @@ impl Event {
|
|||||||
/// Panics if this function has already been called on this event.
|
/// Panics if this function has already been called on this event.
|
||||||
pub fn retry(mut self, duration: Duration) -> Event {
|
pub fn retry(mut self, duration: Duration) -> Event {
|
||||||
if self.flags.contains(EventFlags::HAS_RETRY) {
|
if self.flags.contains(EventFlags::HAS_RETRY) {
|
||||||
panic!("Called `EventBuilder::retry` multiple times");
|
panic!("Called `Event::retry` multiple times");
|
||||||
}
|
}
|
||||||
self.flags.insert(EventFlags::HAS_RETRY);
|
self.flags.insert(EventFlags::HAS_RETRY);
|
||||||
|
|
||||||
self.buffer.extend_from_slice(b"retry:");
|
let buffer = self.buffer.as_mut();
|
||||||
|
buffer.extend_from_slice(b"retry:");
|
||||||
|
|
||||||
let secs = duration.as_secs();
|
let secs = duration.as_secs();
|
||||||
let millis = duration.subsec_millis();
|
let millis = duration.subsec_millis();
|
||||||
|
|
||||||
if secs > 0 {
|
if secs > 0 {
|
||||||
// format seconds
|
// format seconds
|
||||||
self.buffer
|
buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
|
||||||
.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
|
|
||||||
|
|
||||||
// pad milliseconds
|
// pad milliseconds
|
||||||
if millis < 10 {
|
if millis < 10 {
|
||||||
self.buffer.extend_from_slice(b"00");
|
buffer.extend_from_slice(b"00");
|
||||||
} else if millis < 100 {
|
} else if millis < 100 {
|
||||||
self.buffer.extend_from_slice(b"0");
|
buffer.extend_from_slice(b"0");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// format milliseconds
|
// format milliseconds
|
||||||
self.buffer
|
buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
|
||||||
.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
|
|
||||||
|
|
||||||
self.buffer.put_u8(b'\n');
|
buffer.put_u8(b'\n');
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -340,7 +364,7 @@ impl Event {
|
|||||||
T: AsRef<str>,
|
T: AsRef<str>,
|
||||||
{
|
{
|
||||||
if self.flags.contains(EventFlags::HAS_ID) {
|
if self.flags.contains(EventFlags::HAS_ID) {
|
||||||
panic!("Called `EventBuilder::id` multiple times");
|
panic!("Called `Event::id` multiple times");
|
||||||
}
|
}
|
||||||
self.flags.insert(EventFlags::HAS_ID);
|
self.flags.insert(EventFlags::HAS_ID);
|
||||||
|
|
||||||
@ -362,20 +386,36 @@ impl Event {
|
|||||||
None,
|
None,
|
||||||
"SSE field value cannot contain newlines or carriage returns",
|
"SSE field value cannot contain newlines or carriage returns",
|
||||||
);
|
);
|
||||||
self.buffer.extend_from_slice(name.as_bytes());
|
|
||||||
self.buffer.put_u8(b':');
|
let buffer = self.buffer.as_mut();
|
||||||
self.buffer.put_u8(b' ');
|
buffer.extend_from_slice(name.as_bytes());
|
||||||
self.buffer.extend_from_slice(value);
|
buffer.put_u8(b':');
|
||||||
self.buffer.put_u8(b'\n');
|
buffer.put_u8(b' ');
|
||||||
|
buffer.extend_from_slice(value);
|
||||||
|
buffer.put_u8(b'\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
fn finalize(mut self) -> Bytes {
|
fn finalize(self) -> Bytes {
|
||||||
self.buffer.put_u8(b'\n');
|
match self.buffer {
|
||||||
self.buffer.freeze()
|
Buffer::Finalized(bytes) => bytes,
|
||||||
|
Buffer::Active(mut bytes_mut) => {
|
||||||
|
bytes_mut.put_u8(b'\n');
|
||||||
|
bytes_mut.freeze()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Copy, Clone, PartialEq)]
|
impl Default for Event {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: Buffer::Active(BytesMut::new()),
|
||||||
|
flags: EventFlags::from_bits(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||||
struct EventFlags(u8);
|
struct EventFlags(u8);
|
||||||
|
|
||||||
impl EventFlags {
|
impl EventFlags {
|
||||||
@ -406,7 +446,7 @@ impl EventFlags {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub struct KeepAlive {
|
pub struct KeepAlive {
|
||||||
event: Bytes,
|
event: Event,
|
||||||
max_interval: Duration,
|
max_interval: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -414,7 +454,7 @@ impl KeepAlive {
|
|||||||
/// Create a new `KeepAlive`.
|
/// Create a new `KeepAlive`.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
event: Bytes::from_static(b":\n\n"),
|
event: Event::DEFAULT_KEEP_ALIVE,
|
||||||
max_interval: Duration::from_secs(15),
|
max_interval: Duration::from_secs(15),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -451,7 +491,7 @@ impl KeepAlive {
|
|||||||
/// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
|
/// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
|
||||||
/// comments.
|
/// comments.
|
||||||
pub fn event(mut self, event: Event) -> Self {
|
pub fn event(mut self, event: Event) -> Self {
|
||||||
self.event = event.finalize();
|
self.event = Event::finalized(event.finalize());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -462,19 +502,25 @@ impl Default for KeepAlive {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio")]
|
||||||
pin_project! {
|
pin_project! {
|
||||||
|
/// A wrapper around a stream that produces keep-alive events
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct KeepAliveStream {
|
pub struct KeepAliveStream<S> {
|
||||||
keep_alive: KeepAlive,
|
|
||||||
#[pin]
|
#[pin]
|
||||||
alive_timer: Sleep,
|
alive_timer: tokio::time::Sleep,
|
||||||
|
#[pin]
|
||||||
|
inner: S,
|
||||||
|
keep_alive: KeepAlive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KeepAliveStream {
|
#[cfg(feature = "tokio")]
|
||||||
fn new(keep_alive: KeepAlive) -> Self {
|
impl<S> KeepAliveStream<S> {
|
||||||
|
fn new(keep_alive: KeepAlive, inner: S) -> Self {
|
||||||
Self {
|
Self {
|
||||||
alive_timer: tokio::time::sleep(keep_alive.max_interval),
|
alive_timer: tokio::time::sleep(keep_alive.max_interval),
|
||||||
|
inner,
|
||||||
keep_alive,
|
keep_alive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -484,17 +530,38 @@ impl KeepAliveStream {
|
|||||||
this.alive_timer
|
this.alive_timer
|
||||||
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
|
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
|
#[cfg(feature = "tokio")]
|
||||||
let this = self.as_mut().project();
|
impl<S, E> Stream for KeepAliveStream<S>
|
||||||
|
where
|
||||||
|
S: Stream<Item = Result<Event, E>>,
|
||||||
|
{
|
||||||
|
type Item = Result<Event, E>;
|
||||||
|
|
||||||
ready!(this.alive_timer.poll(cx));
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
use std::future::Future;
|
||||||
|
|
||||||
let event = this.keep_alive.event.clone();
|
let mut this = self.as_mut().project();
|
||||||
|
|
||||||
self.reset();
|
match this.inner.as_mut().poll_next(cx) {
|
||||||
|
Poll::Ready(Some(Ok(event))) => {
|
||||||
|
self.reset();
|
||||||
|
|
||||||
Poll::Ready(event)
|
Poll::Ready(Some(Ok(event)))
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Pending => {
|
||||||
|
ready!(this.alive_timer.poll(cx));
|
||||||
|
|
||||||
|
let event = this.keep_alive.event.clone();
|
||||||
|
|
||||||
|
self.reset();
|
||||||
|
|
||||||
|
Poll::Ready(Some(Ok(event)))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user