Files
rust/library/std/src/sys/unix/ext/net/ancillary.rs
Reyk Floeter 3d6bd87b24 unix: Fix feature(unix_socket_ancillary_data) on macos and other BSDs
This adds support for CMSG handling on macOS and fixes it on OpenBSD
and other BSDs.

When traversing the CMSG list, the previous code had an exception for
Android where the next element after the last pointer could point to
the first pointer instead of NULL.  This is actually not specific to
Android: the `libc::CMSG_NXTHDR` implementation for Linux and
emscripten have a special case to return NULL when the length of the
previous element is zero; most other implementations simply return the
previous element plus a zero offset in this case.

This MR additionally adds `SocketAncillary::is_empty` because clippy
is right that it should be added.
2021-03-26 21:12:22 +01:00

676 lines
24 KiB
Rust

use super::{sockaddr_un, SocketAddr};
use crate::convert::TryFrom;
use crate::io::{self, IoSlice, IoSliceMut};
use crate::marker::PhantomData;
use crate::mem::{size_of, zeroed};
use crate::os::unix::io::RawFd;
use crate::path::Path;
use crate::ptr::{eq, read_unaligned};
use crate::slice::from_raw_parts;
use crate::sys::net::Socket;
// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
#[cfg(all(doc, not(target_os = "linux"), not(target_os = "android")))]
#[allow(non_camel_case_types)]
mod libc {
pub use libc::c_int;
pub struct ucred;
pub struct cmsghdr;
pub type pid_t = i32;
pub type gid_t = u32;
pub type uid_t = u32;
}
pub(super) fn recv_vectored_with_ancillary_from(
socket: &Socket,
bufs: &mut [IoSliceMut<'_>],
ancillary: &mut SocketAncillary<'_>,
) -> io::Result<(usize, bool, io::Result<SocketAddr>)> {
unsafe {
let mut msg_name: libc::sockaddr_un = zeroed();
let mut msg: libc::msghdr = zeroed();
msg.msg_name = &mut msg_name as *mut _ as *mut _;
msg.msg_namelen = size_of::<libc::sockaddr_un>() as libc::socklen_t;
msg.msg_iov = bufs.as_mut_ptr().cast();
cfg_if::cfg_if! {
if #[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))] {
msg.msg_iovlen = bufs.len() as libc::size_t;
msg.msg_controllen = ancillary.buffer.len() as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
msg.msg_iovlen = bufs.len() as libc::c_int;
msg.msg_controllen = ancillary.buffer.len() as libc::socklen_t;
}
}
// macos requires that the control pointer is NULL when the len is 0.
if msg.msg_controllen > 0 {
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
}
let count = socket.recv_msg(&mut msg)?;
ancillary.length = msg.msg_controllen as usize;
ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
let truncated = msg.msg_flags & libc::MSG_TRUNC == libc::MSG_TRUNC;
let addr = SocketAddr::from_parts(msg_name, msg.msg_namelen);
Ok((count, truncated, addr))
}
}
pub(super) fn send_vectored_with_ancillary_to(
socket: &Socket,
path: Option<&Path>,
bufs: &[IoSlice<'_>],
ancillary: &mut SocketAncillary<'_>,
) -> io::Result<usize> {
unsafe {
let (mut msg_name, msg_namelen) =
if let Some(path) = path { sockaddr_un(path)? } else { (zeroed(), 0) };
let mut msg: libc::msghdr = zeroed();
msg.msg_name = &mut msg_name as *mut _ as *mut _;
msg.msg_namelen = msg_namelen;
msg.msg_iov = bufs.as_ptr() as *mut _;
cfg_if::cfg_if! {
if #[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))] {
msg.msg_iovlen = bufs.len() as libc::size_t;
msg.msg_controllen = ancillary.length as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
msg.msg_iovlen = bufs.len() as libc::c_int;
msg.msg_controllen = ancillary.length as libc::socklen_t;
}
}
// macos requires that the control pointer is NULL when the len is 0.
if msg.msg_controllen > 0 {
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
}
ancillary.truncated = false;
socket.send_msg(&mut msg)
}
}
fn add_to_ancillary_data<T>(
buffer: &mut [u8],
length: &mut usize,
source: &[T],
cmsg_level: libc::c_int,
cmsg_type: libc::c_int,
) -> bool {
let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) {
if let Ok(source_len) = u32::try_from(source_len) {
source_len
} else {
return false;
}
} else {
return false;
};
unsafe {
let additional_space = libc::CMSG_SPACE(source_len) as usize;
let new_length = if let Some(new_length) = additional_space.checked_add(*length) {
new_length
} else {
return false;
};
if new_length > buffer.len() {
return false;
}
buffer[*length..new_length].fill(0);
*length = new_length;
let mut msg: libc::msghdr = zeroed();
msg.msg_control = buffer.as_mut_ptr().cast();
cfg_if::cfg_if! {
if #[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))] {
msg.msg_controllen = *length as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
msg.msg_controllen = *length as libc::socklen_t;
}
}
let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
let mut previous_cmsg = cmsg;
while !cmsg.is_null() {
previous_cmsg = cmsg;
cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
// Most operating systems, but not Linux or emscripten, return the previous pointer
// when its length is zero. Therefore, check if the previous pointer is the same as
// the current one.
if eq(cmsg, previous_cmsg) {
break;
}
}
if previous_cmsg.is_null() {
return false;
}
(*previous_cmsg).cmsg_level = cmsg_level;
(*previous_cmsg).cmsg_type = cmsg_type;
cfg_if::cfg_if! {
if #[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))] {
(*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
(*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as libc::socklen_t;
}
}
let data = libc::CMSG_DATA(previous_cmsg).cast();
libc::memcpy(data, source.as_ptr().cast(), source_len as usize);
}
true
}
struct AncillaryDataIter<'a, T> {
data: &'a [u8],
phantom: PhantomData<T>,
}
impl<'a, T> AncillaryDataIter<'a, T> {
/// Create `AncillaryDataIter` struct to iterate through the data unit in the control message.
///
/// # Safety
///
/// `data` must contain a valid control message.
unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> {
AncillaryDataIter { data, phantom: PhantomData }
}
}
impl<'a, T> Iterator for AncillaryDataIter<'a, T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if size_of::<T>() <= self.data.len() {
unsafe {
let unit = read_unaligned(self.data.as_ptr().cast());
self.data = &self.data[size_of::<T>()..];
Some(unit)
}
} else {
None
}
}
}
/// Unix credential.
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
#[derive(Clone)]
pub struct SocketCred(libc::ucred);
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
impl SocketCred {
/// Create a Unix credential struct.
///
/// PID, UID and GID is set to 0.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn new() -> SocketCred {
SocketCred(libc::ucred { pid: 0, uid: 0, gid: 0 })
}
/// Set the PID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn set_pid(&mut self, pid: libc::pid_t) {
self.0.pid = pid;
}
/// Get the current PID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn get_pid(&self) -> libc::pid_t {
self.0.pid
}
/// Set the UID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn set_uid(&mut self, uid: libc::uid_t) {
self.0.uid = uid;
}
/// Get the current UID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn get_uid(&self) -> libc::uid_t {
self.0.uid
}
/// Set the GID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn set_gid(&mut self, gid: libc::gid_t) {
self.0.gid = gid;
}
/// Get the current GID.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn get_gid(&self) -> libc::gid_t {
self.0.gid
}
}
/// This control message contains file descriptors.
///
/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub struct ScmRights<'a>(AncillaryDataIter<'a, RawFd>);
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
impl<'a> Iterator for ScmRights<'a> {
type Item = RawFd;
fn next(&mut self) -> Option<RawFd> {
self.0.next()
}
}
/// This control message contains unix credentials.
///
/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_CREDENTIALS` or `SCM_CREDS`.
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub struct ScmCredentials<'a>(AncillaryDataIter<'a, libc::ucred>);
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
impl<'a> Iterator for ScmCredentials<'a> {
type Item = SocketCred;
fn next(&mut self) -> Option<SocketCred> {
Some(SocketCred(self.0.next()?))
}
}
/// The error type which is returned from parsing the type a control message.
#[non_exhaustive]
#[derive(Debug)]
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub enum AncillaryError {
Unknown { cmsg_level: i32, cmsg_type: i32 },
}
/// This enum represent one control message of variable type.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub enum AncillaryData<'a> {
ScmRights(ScmRights<'a>),
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
ScmCredentials(ScmCredentials<'a>),
}
impl<'a> AncillaryData<'a> {
/// Create a `AncillaryData::ScmRights` variant.
///
/// # Safety
///
/// `data` must contain a valid control message and the control message must be type of
/// `SOL_SOCKET` and level of `SCM_RIGHTS`.
unsafe fn as_rights(data: &'a [u8]) -> Self {
let ancillary_data_iter = AncillaryDataIter::new(data);
let scm_rights = ScmRights(ancillary_data_iter);
AncillaryData::ScmRights(scm_rights)
}
/// Create a `AncillaryData::ScmCredentials` variant.
///
/// # Safety
///
/// `data` must contain a valid control message and the control message must be type of
/// `SOL_SOCKET` and level of `SCM_CREDENTIALS` or `SCM_CREDENTIALS`.
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
unsafe fn as_credentials(data: &'a [u8]) -> Self {
let ancillary_data_iter = AncillaryDataIter::new(data);
let scm_credentials = ScmCredentials(ancillary_data_iter);
AncillaryData::ScmCredentials(scm_credentials)
}
fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Result<Self, AncillaryError> {
unsafe {
cfg_if::cfg_if! {
if #[cfg(any(
target_os = "android",
all(target_os = "linux", target_env = "gnu"),
all(target_os = "linux", target_env = "uclibc"),
))] {
let cmsg_len_zero = libc::CMSG_LEN(0) as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
let cmsg_len_zero = libc::CMSG_LEN(0) as libc::socklen_t;
}
}
let data_len = (*cmsg).cmsg_len - cmsg_len_zero;
let data = libc::CMSG_DATA(cmsg).cast();
let data = from_raw_parts(data, data_len as usize);
match (*cmsg).cmsg_level {
libc::SOL_SOCKET => match (*cmsg).cmsg_type {
libc::SCM_RIGHTS => Ok(AncillaryData::as_rights(data)),
#[cfg(any(target_os = "android", target_os = "linux",))]
libc::SCM_CREDENTIALS => Ok(AncillaryData::as_credentials(data)),
cmsg_type => {
Err(AncillaryError::Unknown { cmsg_level: libc::SOL_SOCKET, cmsg_type })
}
},
cmsg_level => {
Err(AncillaryError::Unknown { cmsg_level, cmsg_type: (*cmsg).cmsg_type })
}
}
}
}
}
/// This struct is used to iterate through the control messages.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub struct Messages<'a> {
buffer: &'a [u8],
current: Option<&'a libc::cmsghdr>,
}
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
impl<'a> Iterator for Messages<'a> {
type Item = Result<AncillaryData<'a>, AncillaryError>;
fn next(&mut self) -> Option<Self::Item> {
unsafe {
let mut msg: libc::msghdr = zeroed();
msg.msg_control = self.buffer.as_ptr() as *mut _;
cfg_if::cfg_if! {
if #[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))] {
msg.msg_controllen = self.buffer.len() as libc::size_t;
} else if #[cfg(any(
target_os = "dragonfly",
target_os = "emscripten",
target_os = "freebsd",
all(target_os = "linux", target_env = "musl",),
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
))] {
msg.msg_controllen = self.buffer.len() as libc::socklen_t;
}
}
let cmsg = if let Some(current) = self.current {
libc::CMSG_NXTHDR(&msg, current)
} else {
libc::CMSG_FIRSTHDR(&msg)
};
let cmsg = cmsg.as_ref()?;
// Most operating systems, but not Linux or emscripten, return the previous pointer
// when its length is zero. Therefore, check if the previous pointer is the same as
// the current one.
if let Some(current) = self.current {
if eq(current, cmsg) {
return None;
}
}
self.current = Some(cmsg);
let ancillary_result = AncillaryData::try_from_cmsghdr(cmsg);
Some(ancillary_result)
}
}
}
/// A Unix socket Ancillary data struct.
///
/// # Example
/// ```no_run
/// #![feature(unix_socket_ancillary_data)]
/// use std::os::unix::net::{UnixStream, SocketAncillary, AncillaryData};
/// use std::io::IoSliceMut;
///
/// fn main() -> std::io::Result<()> {
/// let sock = UnixStream::connect("/tmp/sock")?;
///
/// let mut fds = [0; 8];
/// let mut ancillary_buffer = [0; 128];
/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
///
/// let mut buf = [1; 8];
/// let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
/// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
///
/// for ancillary_result in ancillary.messages() {
/// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
/// for fd in scm_rights {
/// println!("receive file descriptor: {}", fd);
/// }
/// }
/// }
/// Ok(())
/// }
/// ```
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
#[derive(Debug)]
pub struct SocketAncillary<'a> {
buffer: &'a mut [u8],
length: usize,
truncated: bool,
}
impl<'a> SocketAncillary<'a> {
/// Create an ancillary data with the given buffer.
///
/// # Example
///
/// ```no_run
/// # #![allow(unused_mut)]
/// #![feature(unix_socket_ancillary_data)]
/// use std::os::unix::net::SocketAncillary;
/// let mut ancillary_buffer = [0; 128];
/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
/// ```
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn new(buffer: &'a mut [u8]) -> Self {
SocketAncillary { buffer, length: 0, truncated: false }
}
/// Returns the capacity of the buffer.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn capacity(&self) -> usize {
self.buffer.len()
}
/// Returns `true` if the ancillary data is empty.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn is_empty(&self) -> bool {
self.length == 0
}
/// Returns the number of used bytes.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn len(&self) -> usize {
self.length
}
/// Returns the iterator of the control messages.
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn messages(&self) -> Messages<'_> {
Messages { buffer: &self.buffer[..self.length], current: None }
}
/// Is `true` if during a recv operation the ancillary was truncated.
///
/// # Example
///
/// ```no_run
/// #![feature(unix_socket_ancillary_data)]
/// use std::os::unix::net::{UnixStream, SocketAncillary};
/// use std::io::IoSliceMut;
///
/// fn main() -> std::io::Result<()> {
/// let sock = UnixStream::connect("/tmp/sock")?;
///
/// let mut ancillary_buffer = [0; 128];
/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
///
/// let mut buf = [1; 8];
/// let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
/// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
///
/// println!("Is truncated: {}", ancillary.truncated());
/// Ok(())
/// }
/// ```
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn truncated(&self) -> bool {
self.truncated
}
/// Add file descriptors to the ancillary data.
///
/// The function returns `true` if there was enough space in the buffer.
/// If there was not enough space then no file descriptors was appended.
/// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
/// and type `SCM_RIGHTS`.
///
/// # Example
///
/// ```no_run
/// #![feature(unix_socket_ancillary_data)]
/// use std::os::unix::net::{UnixStream, SocketAncillary};
/// use std::os::unix::io::AsRawFd;
/// use std::io::IoSlice;
///
/// fn main() -> std::io::Result<()> {
/// let sock = UnixStream::connect("/tmp/sock")?;
///
/// let mut ancillary_buffer = [0; 128];
/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
/// ancillary.add_fds(&[sock.as_raw_fd()][..]);
///
/// let mut buf = [1; 8];
/// let mut bufs = &mut [IoSlice::new(&mut buf[..])][..];
/// sock.send_vectored_with_ancillary(bufs, &mut ancillary)?;
/// Ok(())
/// }
/// ```
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn add_fds(&mut self, fds: &[RawFd]) -> bool {
self.truncated = false;
add_to_ancillary_data(
&mut self.buffer,
&mut self.length,
fds,
libc::SOL_SOCKET,
libc::SCM_RIGHTS,
)
}
/// Add credentials to the ancillary data.
///
/// The function returns `true` if there was enough space in the buffer.
/// If there was not enough space then no credentials was appended.
/// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
/// and type `SCM_CREDENTIALS` or `SCM_CREDS`.
///
#[cfg(any(doc, target_os = "android", target_os = "linux",))]
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn add_creds(&mut self, creds: &[SocketCred]) -> bool {
self.truncated = false;
add_to_ancillary_data(
&mut self.buffer,
&mut self.length,
creds,
libc::SOL_SOCKET,
libc::SCM_CREDENTIALS,
)
}
/// Clears the ancillary data, removing all values.
///
/// # Example
///
/// ```no_run
/// #![feature(unix_socket_ancillary_data)]
/// use std::os::unix::net::{UnixStream, SocketAncillary, AncillaryData};
/// use std::io::IoSliceMut;
///
/// fn main() -> std::io::Result<()> {
/// let sock = UnixStream::connect("/tmp/sock")?;
///
/// let mut fds1 = [0; 8];
/// let mut fds2 = [0; 8];
/// let mut ancillary_buffer = [0; 128];
/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
///
/// let mut buf = [1; 8];
/// let mut bufs = &mut [IoSliceMut::new(&mut buf[..])][..];
///
/// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
/// for ancillary_result in ancillary.messages() {
/// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
/// for fd in scm_rights {
/// println!("receive file descriptor: {}", fd);
/// }
/// }
/// }
///
/// ancillary.clear();
///
/// sock.recv_vectored_with_ancillary(bufs, &mut ancillary)?;
/// for ancillary_result in ancillary.messages() {
/// if let AncillaryData::ScmRights(scm_rights) = ancillary_result.unwrap() {
/// for fd in scm_rights {
/// println!("receive file descriptor: {}", fd);
/// }
/// }
/// }
/// Ok(())
/// }
/// ```
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
pub fn clear(&mut self) {
self.length = 0;
self.truncated = false;
}
}