diff --git a/src/socket.rs b/src/socket.rs index b258c93c..45d8ea55 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -16,6 +16,16 @@ use std::net::Ipv6Addr; use std::net::{self, Ipv4Addr, Shutdown}; #[cfg(any(unix, all(target_os = "wasi", not(target_env = "p1"))))] use std::os::fd::{FromRawFd, IntoRawFd}; +#[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "solaris", + target_os = "illumos", + target_os = "nto", +))] +use std::os::raw::c_uchar; #[cfg(windows)] use std::os::windows::io::{FromRawSocket, IntoRawSocket}; use std::time::Duration; @@ -27,6 +37,30 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; #[cfg(not(any(target_os = "redox", target_os = "wasi")))] use crate::{MaybeUninitSlice, MsgHdr, RecvFlags}; +// Match the system headers for these IPv4 multicast socket options. These +// targets declare `unsigned char` rather than `int`. +#[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "solaris", + target_os = "illumos", + target_os = "nto", +))] +type IpV4MultiCastType = c_uchar; + +#[cfg(not(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "solaris", + target_os = "illumos", + target_os = "nto", +)))] +type IpV4MultiCastType = c_int; + /// Owned wrapper around a system socket. /// /// This type simply wraps an instance of a file descriptor (`c_int`) on Unix @@ -1546,7 +1580,7 @@ impl Socket { /// [`set_multicast_loop_v4`]: Socket::set_multicast_loop_v4 pub fn multicast_loop_v4(&self) -> io::Result { unsafe { - getsockopt::(self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_LOOP) + getsockopt::(self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_LOOP) .map(|loop_v4| loop_v4 != 0) } } @@ -1561,7 +1595,7 @@ impl Socket { self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_LOOP, - loop_v4 as c_int, + loop_v4 as IpV4MultiCastType, ) } } @@ -1573,7 +1607,7 @@ impl Socket { /// [`set_multicast_ttl_v4`]: Socket::set_multicast_ttl_v4 pub fn multicast_ttl_v4(&self) -> io::Result { unsafe { - getsockopt::(self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_TTL) + getsockopt::(self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_TTL) .map(|ttl| ttl as u32) } } @@ -1591,7 +1625,7 @@ impl Socket { self.as_raw(), sys::IPPROTO_IP, sys::IP_MULTICAST_TTL, - ttl as c_int, + ttl as IpV4MultiCastType, ) } } diff --git a/tests/socket.rs b/tests/socket.rs index 6dc1aea0..d3a9a564 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -21,7 +21,7 @@ use std::io::Write; #[cfg(not(target_os = "vita"))] use std::mem::MaybeUninit; use std::mem::{self}; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream, UdpSocket}; #[cfg(not(any(target_os = "redox", target_os = "vita")))] use std::net::{Ipv6Addr, SocketAddrV6}; #[cfg(all( @@ -1970,3 +1970,52 @@ fn set_busy_poll() { assert!(socket.busy_poll().unwrap() == i); } } + +/// A helper type to allow testing socket options on both `Socket` and `UdpSocket`. +pub enum SocketRef<'a> { + Socket(&'a Socket), + UdpSocket(&'a UdpSocket), +} + +impl<'a> From<&'a Socket> for SocketRef<'a> { + fn from(socket: &'a Socket) -> Self { + SocketRef::Socket(socket) + } +} + +impl<'a> From<&'a UdpSocket> for SocketRef<'a> { + fn from(socket: &'a UdpSocket) -> Self { + SocketRef::UdpSocket(socket) + } +} + +/// Assert that `multicast_ttl_v4` is set to a given value on `socket`. +#[track_caller] +pub fn assert_multicast_ttl_v4<'a>(socket: impl Into>, want: u32) { + let socket = socket.into(); + let ttl = match socket { + SocketRef::Socket(socket) => socket.multicast_ttl_v4().unwrap(), + SocketRef::UdpSocket(socket) => socket.multicast_ttl_v4().unwrap(), + }; + assert_eq!(ttl, want, "multicast_ttl_v4 option"); +} + +#[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "solaris", + target_os = "illumos", + target_os = "nto", +))] +#[test] +fn multicast_v4_bsd_abi() { + let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + socket.set_multicast_ttl_v4(258).unwrap(); + assert_multicast_ttl_v4(&socket, 2); + + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap(); + socket.set_multicast_ttl_v4(258).unwrap(); + assert_multicast_ttl_v4(&socket, 2); +}