Skip to content

Implement local address and local port range configuration #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 254 additions & 0 deletions scylla-rust-wrapper/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ use scylla::policies::load_balancing::{
};
use scylla::policies::retry::RetryPolicy;
use scylla::policies::speculative_execution::SimpleSpeculativeExecutionPolicy;
use scylla::routing::ShardAwarePortRange;
use scylla::statement::{Consistency, SerialConsistency};
use std::collections::HashMap;
use std::convert::TryInto;
use std::future::Future;
use std::net::IpAddr;
use std::num::NonZero;
use std::os::raw::{c_char, c_int, c_uint};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -56,6 +59,11 @@ const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5000);
const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
// - keepalive timeout is 60 secs
const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(60);
// - default local ip address is arbitrary
const DEFAULT_LOCAL_IP_ADDRESS: Option<IpAddr> = None;
// - default shard aware local port range is ephemeral range
const DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE: ShardAwarePortRange =
ShardAwarePortRange::EPHEMERAL_PORT_RANGE;

const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");
Expand Down Expand Up @@ -227,6 +235,8 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
.write_coalescing_delay(DEFAULT_WRITE_COALESCING_DELAY)
.keepalive_interval(DEFAULT_KEEPALIVE_INTERVAL)
.keepalive_timeout(DEFAULT_KEEPALIVE_TIMEOUT)
.local_ip_address(DEFAULT_LOCAL_IP_ADDRESS)
.shard_aware_local_port_range(DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE)
};

BoxFFI::into_ptr(Box::new(CassCluster {
Expand Down Expand Up @@ -524,6 +534,93 @@ pub unsafe extern "C" fn cass_cluster_set_port(
CassError::CASS_OK
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_cluster_set_local_address(
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
ip: *const c_char,
) -> CassError {
// Safety: We assume that string is null-terminated.
unsafe { cass_cluster_set_local_address_n(cluster_raw, ip, strlen(ip)) }
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_cluster_set_local_address_n(
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
ip: *const c_char,
ip_length: size_t,
) -> CassError {
let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
tracing::error!("Provided null cluster pointer to cass_cluster_set_local_address_n!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

// Semantics from cpp-driver - if pointer is null or length is 0, use the
// arbitrary address (INADDR_ANY, or in6addr_any).
let local_addr: Option<IpAddr> = if ip.is_null() || ip_length == 0 {
None
} else {
// SAFETY: We assume that user provides valid pointer and length.
match unsafe { ptr_to_cstr_n(ip, ip_length) } {
Some(ip_str) => match IpAddr::from_str(ip_str) {
Ok(addr) => Some(addr),
Err(err) => {
tracing::error!("Failed to parse ip address <{}>: {}", ip_str, err);
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
}
},
None => {
tracing::error!("Provided non-utf8 ip string to cass_cluster_set_local_address_n!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
}
}
};

cluster.session_builder.config.local_ip_address = local_addr;

CassError::CASS_OK
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_cluster_set_local_port_range(
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
lo: c_int,
hi: c_int,
) -> CassError {
let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
tracing::error!("Provided null cluster pointer to cass_cluster_set_local_port_range!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

fn compute_range_from_raw(lo: i32, hi: i32) -> Result<ShardAwarePortRange, CassError> {
let start: u16 = lo
.try_into()
.map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)?;
// In cpp-driver, the `hi` is exluded from the port range.
// In rust-driver, OTOH, we include the upper bound of the range - thus -1.
let end: u16 = hi
.checked_sub(1)
.ok_or(CassError::CASS_ERROR_LIB_BAD_PARAMS)?
.try_into()
.map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)?;

// Further validation is performed by the constructor.
ShardAwarePortRange::new(start..=end).map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)
}

let range: ShardAwarePortRange = match compute_range_from_raw(lo, hi) {
Ok(range) => range,
Err(cass_error) => {
// Let's use the error message from cpp-driver.
tracing::error!("Invalid local port range. Expected: 1024 < lo <= hi < 65536.");
return cass_error;
}
};

cluster.session_builder.config.shard_aware_local_port_range = range;

CassError::CASS_OK
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_cluster_set_credentials(
cluster: CassBorrowedExclusivePtr<CassCluster, CMut>,
Expand Down Expand Up @@ -1017,12 +1114,169 @@ mod tests {
exec_profile::{cass_execution_profile_free, cass_execution_profile_new},
};
use assert_matches::assert_matches;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::{
collections::HashSet,
convert::{TryFrom, TryInto},
os::raw::c_char,
};

#[test]
fn test_local_ip_address() {
unsafe {
let mut cluster_raw = cass_cluster_new();

// Check default address
{
let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
assert!(cluster.session_builder.config.local_ip_address.is_none());
}

// null ip pointer
{
assert_cass_error_eq!(
cass_cluster_set_local_address(cluster_raw.borrow_mut(), std::ptr::null()),
CassError::CASS_OK
);

let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
assert!(cluster.session_builder.config.local_ip_address.is_none());
}

// empty string
{
assert_cass_error_eq!(
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"".as_ptr()),
CassError::CASS_OK
);

let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
assert!(cluster.session_builder.config.local_ip_address.is_none());
}

// valid ipv4 address
{
assert_cass_error_eq!(
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"1.2.3.4".as_ptr()),
CassError::CASS_OK
);

let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
assert_eq!(
cluster.session_builder.config.local_ip_address,
Some(Ipv4Addr::new(1, 2, 3, 4).into())
);
}

// valid ipv6 address
{
assert_cass_error_eq!(
cass_cluster_set_local_address(
cluster_raw.borrow_mut(),
c"2001:db8::8a2e:370:7334".as_ptr()
),
CassError::CASS_OK
);

let cluster = BoxFFI::as_ref(cluster_raw.borrow()).unwrap();
assert_eq!(
cluster.session_builder.config.local_ip_address,
Some(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0x8a2e, 0x0370, 0x7334,).into())
);
}

// non-numeric address
{
assert_cass_error_eq!(
cass_cluster_set_local_address(cluster_raw.borrow_mut(), c"foo".as_ptr()),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// non-valid-utf8 slice
{
let non_utf8_slice: &[u8] = &[0xF0, 0x28, 0x8C, 0x28, 0x00];
assert_cass_error_eq!(
cass_cluster_set_local_address(
cluster_raw.borrow_mut(),
non_utf8_slice.as_ptr() as *const c_char
),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

cass_cluster_free(cluster_raw);
}
}

#[test]
fn test_local_port_range() {
// TODO: Currently no way to compare the `ShardAwarePortRange`. Either implement `PartialEq`
// or expose a getter for underlying range on rust-driver side. We can test the validation, though.

unsafe {
let mut cluster_raw = cass_cluster_new();

// negative value
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), -1, 1025),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// start (inclusive) == end (exclusive)
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, 5555),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// start == end - 1
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5556, 5556),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// start > end
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5556, 5555),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// 0 <= start,end < 1024
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 1, 3),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// end is i32::MIN - check that does not panic due to overflow
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, i32::MIN),
CassError::CASS_ERROR_LIB_BAD_PARAMS
);
}

// some valid port range
{
assert_cass_error_eq!(
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, 5557),
CassError::CASS_OK
);
}

cass_cluster_free(cluster_raw);
}
}

#[test]
#[ntest::timeout(100)]
fn test_coalescing_delay() {
Expand Down
3 changes: 0 additions & 3 deletions src/testing_unimplemented.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ CASS_EXPORT CassError cass_cluster_set_host_listener_callback(CassCluster* clust
void* data) {
throw std::runtime_error("UNIMPLEMENTED cass_cluster_set_host_listener_callback\n");
}
CASS_EXPORT CassError cass_cluster_set_local_address(CassCluster* cluster, const char* name) {
throw std::runtime_error("UNIMPLEMENTED cass_cluster_set_local_address\n");
}
CASS_EXPORT CassError cass_cluster_set_no_compact(CassCluster* cluster, cass_bool_t enabled) {
throw std::runtime_error("UNIMPLEMENTED cass_cluster_set_no_compact\n");
}
Expand Down