Skip to main content

tokio_util/io/
simplex.rs

1//! Unidirectional byte-oriented channel.
2
3use crate::util::poll_proceed;
4
5use bytes::Buf;
6use bytes::BytesMut;
7use futures_core::ready;
8use std::io::Error as IoError;
9use std::io::ErrorKind as IoErrorKind;
10use std::io::IoSlice;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll, Waker};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Result alias shared by every I/O method in this module.
type IoResult<T> = std::result::Result<T, IoError>;

/// Message carried by the [`IoErrorKind::BrokenPipe`] errors returned once the
/// channel has been closed.
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
19
/// State shared by the two halves of a simplex channel.
///
/// Both the [`Sender`] and the [`Receiver`] hold it behind an `Arc<Mutex<_>>`.
#[derive(Debug)]
struct Inner {
    /// Maximum number of unread bytes `buf` may hold (the `capacity` passed to
    /// [`new`]); `poll_write` returns [`Poll::Pending`] once it is reached.
    backpressure_boundary: usize,

    /// Set when either the [`Sender`] or the [`Receiver`] is closed (dropped,
    /// or the sender is shut down). Nothing in this file resets it.
    is_closed: bool,

    /// Waker of a [`Receiver`] that found nothing to read.
    receiver_waker: Option<Waker>,

    /// Waker of a [`Sender`] that found the buffer full.
    sender_waker: Option<Waker>,

    /// Bytes written by the sender and not yet consumed by the receiver.
    buf: BytesMut,
}
37
38impl Inner {
39    fn with_capacity(capacity: usize) -> Self {
40        Self {
41            backpressure_boundary: capacity,
42            is_closed: false,
43            receiver_waker: None,
44            sender_waker: None,
45            buf: BytesMut::with_capacity(capacity),
46        }
47    }
48
49    fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
50        match self.receiver_waker.as_mut() {
51            Some(old) if old.will_wake(waker) => None,
52            _ => self.receiver_waker.replace(waker.clone()),
53        }
54    }
55
56    fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
57        match self.sender_waker.as_mut() {
58            Some(old) if old.will_wake(waker) => None,
59            _ => self.sender_waker.replace(waker.clone()),
60        }
61    }
62
63    fn take_receiver_waker(&mut self) -> Option<Waker> {
64        self.receiver_waker.take()
65    }
66
67    fn take_sender_waker(&mut self) -> Option<Waker> {
68        self.sender_waker.take()
69    }
70
71    fn is_closed(&self) -> bool {
72        self.is_closed
73    }
74
75    fn close_receiver(&mut self) -> Option<Waker> {
76        self.is_closed = true;
77        self.take_sender_waker()
78    }
79
80    fn close_sender(&mut self) -> Option<Waker> {
81        self.is_closed = true;
82        self.take_receiver_waker()
83    }
84}
85
/// Receiver of the simplex channel.
///
/// Reads the bytes written by the paired [`Sender`] through [`AsyncRead`].
/// Dropping the `Receiver` closes the channel, after which writes on the
/// [`Sender`] fail with [`std::io::ErrorKind::BrokenPipe`].
///
/// # Cancellation safety
///
/// The `Receiver` is cancel safe. If it is used as the event in a
/// [`tokio::select!`] statement and some other branch completes
/// first, it is guaranteed that no bytes were received on this
/// channel.
///
/// Bytes that are already buffered can still be read after the write half
/// has been shut down or dropped; a read reports EOF (completes without
/// filling any bytes) only once the buffer has been drained.
/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
///
/// [`tokio::select!`]: https://docs.rs/tokio/latest/tokio/macro.select.html
#[derive(Debug)]
pub struct Receiver {
    /// State shared with the paired [`Sender`].
    inner: Arc<Mutex<Inner>>,
}
104
105impl Drop for Receiver {
106    /// This also wakes up the [`Sender`].
107    fn drop(&mut self) {
108        let maybe_waker = {
109            let mut inner = self.inner.lock().unwrap();
110            inner.close_receiver()
111        };
112
113        if let Some(waker) = maybe_waker {
114            waker.wake();
115        }
116    }
117}
118
119impl AsyncRead for Receiver {
120    fn poll_read(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123        buf: &mut ReadBuf<'_>,
124    ) -> Poll<IoResult<()>> {
125        let coop = ready!(poll_proceed(cx));
126
127        let mut inner = self.inner.lock().unwrap();
128
129        let to_read = buf.remaining().min(inner.buf.remaining());
130        if to_read == 0 {
131            if inner.is_closed() || buf.remaining() == 0 {
132                return Poll::Ready(Ok(()));
133            }
134
135            let old_waker = inner.register_receiver_waker(cx.waker());
136            let maybe_waker = inner.take_sender_waker();
137
138            // unlock before waking up and dropping old waker
139            drop(inner);
140            drop(old_waker);
141            if let Some(waker) = maybe_waker {
142                waker.wake();
143            }
144            return Poll::Pending;
145        }
146
147        // this is to avoid starving other tasks
148        coop.made_progress();
149
150        buf.put_slice(&inner.buf[..to_read]);
151        inner.buf.advance(to_read);
152
153        let waker = inner.take_sender_waker();
154        drop(inner); // unlock before waking up
155        if let Some(waker) = waker {
156            waker.wake();
157        }
158
159        Poll::Ready(Ok(()))
160    }
161}
162
/// Sender of the simplex channel.
///
/// Bytes are written through [`AsyncWrite`] into a buffer bounded by the
/// `capacity` passed to [`new`]; once it is full, writes return
/// [`Poll::Pending`] until the [`Receiver`] reads some of the data.
/// Dropping the `Sender` closes the channel; the [`Receiver`] can still drain
/// what was buffered and then observes EOF.
///
/// # Cancellation safety
///
/// The `Sender` is cancel safe. If it is used as the event in a
/// [`tokio::select!`] statement and some other branch completes
/// first, it is guaranteed that no bytes were sent on this channel.
///
/// # Shutdown
///
/// See [`Sender::poll_shutdown`].
///
/// [`tokio::select!`]: https://docs.rs/tokio/latest/tokio/macro.select.html
#[derive(Debug)]
pub struct Sender {
    /// State shared with the paired [`Receiver`].
    inner: Arc<Mutex<Inner>>,
}
180
181impl Drop for Sender {
182    /// This also wakes up the [`Receiver`].
183    fn drop(&mut self) {
184        let maybe_waker = {
185            let mut inner = self.inner.lock().unwrap();
186            inner.close_sender()
187        };
188
189        if let Some(waker) = maybe_waker {
190            waker.wake();
191        }
192    }
193}
194
195impl AsyncWrite for Sender {
196    /// # Errors
197    ///
198    /// This method will return [`IoErrorKind::BrokenPipe`]
199    /// if the channel has been closed.
200    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
201        let coop = ready!(poll_proceed(cx));
202
203        let mut inner = self.inner.lock().unwrap();
204
205        if inner.is_closed() {
206            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
207        }
208
209        let free = inner
210            .backpressure_boundary
211            .checked_sub(inner.buf.len())
212            .expect("backpressure boundary overflow");
213        let to_write = buf.len().min(free);
214        if to_write == 0 {
215            if buf.is_empty() {
216                return Poll::Ready(Ok(0));
217            }
218
219            let old_waker = inner.register_sender_waker(cx.waker());
220            let waker = inner.take_receiver_waker();
221
222            // unlock before waking up and dropping old waker
223            drop(inner);
224            drop(old_waker);
225            if let Some(waker) = waker {
226                waker.wake();
227            }
228
229            return Poll::Pending;
230        }
231
232        // this is to avoid starving other tasks
233        coop.made_progress();
234
235        inner.buf.extend_from_slice(&buf[..to_write]);
236
237        let waker = inner.take_receiver_waker();
238        drop(inner); // unlock before waking up
239        if let Some(waker) = waker {
240            waker.wake();
241        }
242
243        Poll::Ready(Ok(to_write))
244    }
245
246    /// # Errors
247    ///
248    /// This method will return [`IoErrorKind::BrokenPipe`]
249    /// if the channel has been closed.
250    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
251        let inner = self.inner.lock().unwrap();
252        if inner.is_closed() {
253            Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
254        } else {
255            Poll::Ready(Ok(()))
256        }
257    }
258
259    /// After returns [`Poll::Ready`], all the following call to
260    /// [`Sender::poll_write`] and [`Sender::poll_flush`]
261    /// will return error.
262    ///
263    /// The [`Receiver`] can still be used to read remaining data
264    /// until all bytes have been consumed.
265    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
266        let maybe_waker = {
267            let mut inner = self.inner.lock().unwrap();
268            inner.close_sender()
269        };
270
271        if let Some(waker) = maybe_waker {
272            waker.wake();
273        }
274
275        Poll::Ready(Ok(()))
276    }
277
278    fn is_write_vectored(&self) -> bool {
279        true
280    }
281
282    fn poll_write_vectored(
283        self: Pin<&mut Self>,
284        cx: &mut Context<'_>,
285        bufs: &[IoSlice<'_>],
286    ) -> Poll<Result<usize, IoError>> {
287        let coop = ready!(poll_proceed(cx));
288
289        let mut inner = self.inner.lock().unwrap();
290        if inner.is_closed() {
291            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
292        }
293
294        let free = inner
295            .backpressure_boundary
296            .checked_sub(inner.buf.len())
297            .expect("backpressure boundary overflow");
298        if free == 0 {
299            let old_waker = inner.register_sender_waker(cx.waker());
300            let maybe_waker = inner.take_receiver_waker();
301
302            // unlock before waking up and dropping old waker
303            drop(inner);
304            drop(old_waker);
305            if let Some(waker) = maybe_waker {
306                waker.wake();
307            }
308
309            return Poll::Pending;
310        }
311
312        // this is to avoid starving other tasks
313        coop.made_progress();
314
315        let mut rem = free;
316        for buf in bufs {
317            if rem == 0 {
318                break;
319            }
320
321            let to_write = buf.len().min(rem);
322            if to_write == 0 {
323                assert_ne!(rem, 0);
324                assert_eq!(buf.len(), 0);
325                continue;
326            }
327
328            inner.buf.extend_from_slice(&buf[..to_write]);
329            rem -= to_write;
330        }
331
332        let waker = inner.take_receiver_waker();
333        drop(inner); // unlock before waking up
334        if let Some(waker) = waker {
335            waker.wake();
336        }
337
338        Poll::Ready(Ok(free - rem))
339    }
340}
341
342/// Create a simplex channel.
343///
344/// The `capacity` parameter specifies the maximum number of bytes that can be
345/// stored in the channel without making the [`Sender::poll_write`]
346/// return [`Poll::Pending`].
347///
348/// # Panics
349///
350/// This function will panic if `capacity` is zero.
351pub fn new(capacity: usize) -> (Sender, Receiver) {
352    assert_ne!(capacity, 0, "capacity must be greater than zero");
353
354    let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
355    let tx = Sender {
356        inner: Arc::clone(&inner),
357    };
358    let rx = Receiver { inner };
359    (tx, rx)
360}