1use crate::util::poll_proceed;
4
5use bytes::Buf;
6use bytes::BytesMut;
7use futures_core::ready;
8use std::io::Error as IoError;
9use std::io::ErrorKind as IoErrorKind;
10use std::io::IoSlice;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll, Waker};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Shorthand for results carrying a [`std::io::Error`].
type IoResult<T> = std::io::Result<T>;

/// Message attached to the `BrokenPipe` errors reported once the pipe is closed.
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
19
20#[derive(Debug)]
21struct Inner {
22 backpressure_boundary: usize,
24
25 is_closed: bool,
27
28 receiver_waker: Option<Waker>,
30
31 sender_waker: Option<Waker>,
33
34 buf: BytesMut,
36}
37
38impl Inner {
39 fn with_capacity(capacity: usize) -> Self {
40 Self {
41 backpressure_boundary: capacity,
42 is_closed: false,
43 receiver_waker: None,
44 sender_waker: None,
45 buf: BytesMut::with_capacity(capacity),
46 }
47 }
48
49 fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
50 match self.receiver_waker.as_mut() {
51 Some(old) if old.will_wake(waker) => None,
52 _ => self.receiver_waker.replace(waker.clone()),
53 }
54 }
55
56 fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
57 match self.sender_waker.as_mut() {
58 Some(old) if old.will_wake(waker) => None,
59 _ => self.sender_waker.replace(waker.clone()),
60 }
61 }
62
63 fn take_receiver_waker(&mut self) -> Option<Waker> {
64 self.receiver_waker.take()
65 }
66
67 fn take_sender_waker(&mut self) -> Option<Waker> {
68 self.sender_waker.take()
69 }
70
71 fn is_closed(&self) -> bool {
72 self.is_closed
73 }
74
75 fn close_receiver(&mut self) -> Option<Waker> {
76 self.is_closed = true;
77 self.take_sender_waker()
78 }
79
80 fn close_sender(&mut self) -> Option<Waker> {
81 self.is_closed = true;
82 self.take_receiver_waker()
83 }
84}
85
86#[derive(Debug)]
101pub struct Receiver {
102 inner: Arc<Mutex<Inner>>,
103}
104
105impl Drop for Receiver {
106 fn drop(&mut self) {
108 let maybe_waker = {
109 let mut inner = self.inner.lock().unwrap();
110 inner.close_receiver()
111 };
112
113 if let Some(waker) = maybe_waker {
114 waker.wake();
115 }
116 }
117}
118
119impl AsyncRead for Receiver {
120 fn poll_read(
121 self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 buf: &mut ReadBuf<'_>,
124 ) -> Poll<IoResult<()>> {
125 let coop = ready!(poll_proceed(cx));
126
127 let mut inner = self.inner.lock().unwrap();
128
129 let to_read = buf.remaining().min(inner.buf.remaining());
130 if to_read == 0 {
131 if inner.is_closed() || buf.remaining() == 0 {
132 return Poll::Ready(Ok(()));
133 }
134
135 let old_waker = inner.register_receiver_waker(cx.waker());
136 let maybe_waker = inner.take_sender_waker();
137
138 drop(inner);
140 drop(old_waker);
141 if let Some(waker) = maybe_waker {
142 waker.wake();
143 }
144 return Poll::Pending;
145 }
146
147 coop.made_progress();
149
150 buf.put_slice(&inner.buf[..to_read]);
151 inner.buf.advance(to_read);
152
153 let waker = inner.take_sender_waker();
154 drop(inner); if let Some(waker) = waker {
156 waker.wake();
157 }
158
159 Poll::Ready(Ok(()))
160 }
161}
162
163#[derive(Debug)]
177pub struct Sender {
178 inner: Arc<Mutex<Inner>>,
179}
180
181impl Drop for Sender {
182 fn drop(&mut self) {
184 let maybe_waker = {
185 let mut inner = self.inner.lock().unwrap();
186 inner.close_sender()
187 };
188
189 if let Some(waker) = maybe_waker {
190 waker.wake();
191 }
192 }
193}
194
195impl AsyncWrite for Sender {
196 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
201 let coop = ready!(poll_proceed(cx));
202
203 let mut inner = self.inner.lock().unwrap();
204
205 if inner.is_closed() {
206 return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
207 }
208
209 let free = inner
210 .backpressure_boundary
211 .checked_sub(inner.buf.len())
212 .expect("backpressure boundary overflow");
213 let to_write = buf.len().min(free);
214 if to_write == 0 {
215 if buf.is_empty() {
216 return Poll::Ready(Ok(0));
217 }
218
219 let old_waker = inner.register_sender_waker(cx.waker());
220 let waker = inner.take_receiver_waker();
221
222 drop(inner);
224 drop(old_waker);
225 if let Some(waker) = waker {
226 waker.wake();
227 }
228
229 return Poll::Pending;
230 }
231
232 coop.made_progress();
234
235 inner.buf.extend_from_slice(&buf[..to_write]);
236
237 let waker = inner.take_receiver_waker();
238 drop(inner); if let Some(waker) = waker {
240 waker.wake();
241 }
242
243 Poll::Ready(Ok(to_write))
244 }
245
246 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
251 let inner = self.inner.lock().unwrap();
252 if inner.is_closed() {
253 Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
254 } else {
255 Poll::Ready(Ok(()))
256 }
257 }
258
259 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
266 let maybe_waker = {
267 let mut inner = self.inner.lock().unwrap();
268 inner.close_sender()
269 };
270
271 if let Some(waker) = maybe_waker {
272 waker.wake();
273 }
274
275 Poll::Ready(Ok(()))
276 }
277
278 fn is_write_vectored(&self) -> bool {
279 true
280 }
281
282 fn poll_write_vectored(
283 self: Pin<&mut Self>,
284 cx: &mut Context<'_>,
285 bufs: &[IoSlice<'_>],
286 ) -> Poll<Result<usize, IoError>> {
287 let coop = ready!(poll_proceed(cx));
288
289 let mut inner = self.inner.lock().unwrap();
290 if inner.is_closed() {
291 return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
292 }
293
294 let free = inner
295 .backpressure_boundary
296 .checked_sub(inner.buf.len())
297 .expect("backpressure boundary overflow");
298 if free == 0 {
299 let old_waker = inner.register_sender_waker(cx.waker());
300 let maybe_waker = inner.take_receiver_waker();
301
302 drop(inner);
304 drop(old_waker);
305 if let Some(waker) = maybe_waker {
306 waker.wake();
307 }
308
309 return Poll::Pending;
310 }
311
312 coop.made_progress();
314
315 let mut rem = free;
316 for buf in bufs {
317 if rem == 0 {
318 break;
319 }
320
321 let to_write = buf.len().min(rem);
322 if to_write == 0 {
323 assert_ne!(rem, 0);
324 assert_eq!(buf.len(), 0);
325 continue;
326 }
327
328 inner.buf.extend_from_slice(&buf[..to_write]);
329 rem -= to_write;
330 }
331
332 let waker = inner.take_receiver_waker();
333 drop(inner); if let Some(waker) = waker {
335 waker.wake();
336 }
337
338 Poll::Ready(Ok(free - rem))
339 }
340}
341
342pub fn new(capacity: usize) -> (Sender, Receiver) {
352 assert_ne!(capacity, 0, "capacity must be greater than zero");
353
354 let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
355 let tx = Sender {
356 inner: Arc::clone(&inner),
357 };
358 let rx = Receiver { inner };
359 (tx, rx)
360}