diff --git a/library/std/src/sys/unix/ext/net/ancillary.rs b/library/std/src/sys/unix/ext/net/ancillary.rs index 4e584cb143fa..77214801e3e0 100644 --- a/library/std/src/sys/unix/ext/net/ancillary.rs +++ b/library/std/src/sys/unix/ext/net/ancillary.rs @@ -1,3 +1,4 @@ +use crate::convert::TryFrom; use crate::io::{self, IoSliceMut}; use crate::mem; use crate::os::unix::io::RawFd; @@ -145,6 +146,13 @@ impl<'a> Iterator for ScmCredentials<'a> { } } +#[non_exhaustive] +#[derive(Debug)] +#[unstable(feature = "unix_socket_ancillary_data", issue = "none")] +pub enum AncillaryError { + Unknown { cmsg_level: i32, cmsg_type: i32 }, +} + #[cfg(any( target_os = "haiku", target_os = "solaris", @@ -240,17 +248,19 @@ impl<'a> AncillaryData<'a> { target_env = "uclibc", ))] #[unstable(feature = "unix_socket_ancillary_data", issue = "none")] -impl<'a> AncillaryData<'a> { - fn from(cmsg: &'a libc::cmsghdr) -> Self { +impl<'a> TryFrom<&'a libc::cmsghdr> for AncillaryData<'a> { + type Error = AncillaryError; + + fn try_from(cmsg: &'a libc::cmsghdr) -> Result { unsafe { let cmsg_len_zero = libc::CMSG_LEN(0) as usize; let data_len = (*cmsg).cmsg_len - cmsg_len_zero; let data = libc::CMSG_DATA(cmsg).cast(); let data = from_raw_parts(data, data_len); - if (*cmsg).cmsg_level == libc::SOL_SOCKET { - match (*cmsg).cmsg_type { - libc::SCM_RIGHTS => AncillaryData::as_rights(data), + match (*cmsg).cmsg_level { + libc::SOL_SOCKET => match (*cmsg).cmsg_type { + libc::SCM_RIGHTS => Ok(AncillaryData::as_rights(data)), #[cfg(any( target_os = "linux", target_os = "android", @@ -258,7 +268,7 @@ impl<'a> AncillaryData<'a> { target_os = "fuchsia", target_env = "uclibc", ))] - libc::SCM_CREDENTIALS => AncillaryData::as_credentials(data), + libc::SCM_CREDENTIALS => Ok(AncillaryData::as_credentials(data)), #[cfg(any( target_os = "netbsd", target_os = "openbsd", @@ -267,11 +277,14 @@ impl<'a> AncillaryData<'a> { target_os = "macos", target_os = "ios", ))] - libc::SCM_CREDS => AncillaryData::as_credentials(data), - _ => panic!("Unknown cmsg type"), + libc::SCM_CREDS => Ok(AncillaryData::as_credentials(data)), + cmsg_type => { + Err(AncillaryError::Unknown { cmsg_level: libc::SOL_SOCKET, cmsg_type }) + } + }, + cmsg_level => { + Err(AncillaryError::Unknown { cmsg_level, cmsg_type: (*cmsg).cmsg_type }) } - } else { - panic!("Unknown cmsg level"); } } } @@ -317,9 +330,9 @@ pub struct Messages<'a> { ))] #[unstable(feature = "unix_socket_ancillary_data", issue = "none")] impl<'a> Iterator for Messages<'a> { - type Item = AncillaryData<'a>; + type Item = Result, AncillaryError>; - fn next(&mut self) -> Option> { + fn next(&mut self) -> Option { unsafe { let msg = libc::msghdr { msg_name: null_mut(), @@ -339,8 +352,8 @@ impl<'a> Iterator for Messages<'a> { let cmsg = cmsg.as_ref()?; self.current = Some(cmsg); - let ancillary_data = AncillaryData::from(cmsg); - Some(ancillary_data) + let ancillary_result = AncillaryData::try_from(cmsg); + Some(ancillary_result) } } } @@ -364,8 +377,8 @@ impl<'a> Iterator for Messages<'a> { /// let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..]; /// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?; /// -/// for ancillary_data in ancillary.messages() { -/// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { +/// for ancillary_result in ancillary.messages() { +/// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } @@ -585,8 +598,8 @@ impl<'a> SocketAncillary<'a> { /// let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..]; /// /// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?; - /// for ancillary_data in ancillary.messages() { - /// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { + /// for ancillary_result in ancillary.messages() { + /// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } @@ -596,8 +609,8 @@ impl<'a> SocketAncillary<'a> { /// ancillary.clear(); /// /// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?; - /// for ancillary_data in ancillary.messages() { - /// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { + /// for ancillary_result in ancillary.messages() { + /// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } diff --git a/library/std/src/sys/unix/ext/net/datagram.rs b/library/std/src/sys/unix/ext/net/datagram.rs index e5630127ccbc..e2fa65572e1e 100644 --- a/library/std/src/sys/unix/ext/net/datagram.rs +++ b/library/std/src/sys/unix/ext/net/datagram.rs @@ -343,8 +343,8 @@ impl UnixDatagram { /// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]); /// let (size, _truncated, sender) = sock.recv_vectored_with_ancillary_from(bufs, &mut ancillary)?; /// println!("received {}", size); - /// for ancillary_data in ancillary.messages() { - /// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { + /// for ancillary_result in ancillary.messages() { + /// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } @@ -391,8 +391,8 @@ impl UnixDatagram { /// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]); /// let (size, _truncated) = sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?; /// println!("received {}", size); - /// for ancillary_data in ancillary.messages() { - /// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { + /// for ancillary_result in ancillary.messages() { + /// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } diff --git a/library/std/src/sys/unix/ext/net/stream.rs b/library/std/src/sys/unix/ext/net/stream.rs index 75b919026681..d144c41de3c8 100644 --- a/library/std/src/sys/unix/ext/net/stream.rs +++ b/library/std/src/sys/unix/ext/net/stream.rs @@ -451,8 +451,8 @@ impl UnixStream { /// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]); /// let size = socket.recv_vectored_with_ancillary(bufs, &mut ancillary)?; /// println!("received {}", size); - /// for ancillary_data in ancillary.messages() { - /// if let AncillaryData::ScmRights(scm_rights) = ancillary_data { + /// for ancillary_result in ancillary.messages() { + /// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() { /// for fd in scm_rights { /// println!("receive file descriptor: {}", fd); /// } diff --git a/library/std/src/sys/unix/ext/net/test.rs b/library/std/src/sys/unix/ext/net/test.rs index 724d73c9b1ed..3be9bb48583f 100644 --- a/library/std/src/sys/unix/ext/net/test.rs +++ b/library/std/src/sys/unix/ext/net/test.rs @@ -481,7 +481,7 @@ fn test_send_vectored_fds_unix_stream() { let mut ancillary_data_vec = Vec::from_iter(ancillary2.messages()); assert_eq!(ancillary_data_vec.len(), 1); - if let AncillaryData::ScmRights(scm_rights) = ancillary_data_vec.pop().unwrap() { + if let AncillaryData::ScmRights(scm_rights) = ancillary_data_vec.pop().unwrap().unwrap() { let fd_vec = Vec::from_iter(scm_rights); assert_eq!(fd_vec.len(), 1); unsafe { @@ -551,7 +551,9 @@ fn test_send_vectored_with_ancillary_to_unix_datagram() { let mut ancillary_data_vec = Vec::from_iter(ancillary2.messages()); assert_eq!(ancillary_data_vec.len(), 1); - if let AncillaryData::ScmCredentials(scm_credentials) = ancillary_data_vec.pop().unwrap() { + if let AncillaryData::ScmCredentials(scm_credentials) = + ancillary_data_vec.pop().unwrap().unwrap() + { let cred_vec = Vec::from_iter(scm_credentials); assert_eq!(cred_vec.len(), 1); assert_eq!(cred1.pid, cred_vec[0].pid); @@ -596,7 +598,7 @@ fn test_send_vectored_with_ancillary_unix_datagram() { let mut ancillary_data_vec = Vec::from_iter(ancillary2.messages()); assert_eq!(ancillary_data_vec.len(), 1); - if let AncillaryData::ScmRights(scm_rights) = ancillary_data_vec.pop().unwrap() { + if let AncillaryData::ScmRights(scm_rights) = ancillary_data_vec.pop().unwrap().unwrap() { let fd_vec = Vec::from_iter(scm_rights); assert_eq!(fd_vec.len(), 1); unsafe {