Skip to content

Commit 8a59fb6

Browse files
committed
cluster: implement cass_cluster_set_local_port_range
Set the default value (ephemeral port range) and implemented unit test.
1 parent 307295c commit 8a59fb6

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

scylla-rust-wrapper/src/cluster.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use scylla::policies::load_balancing::{
2020
};
2121
use scylla::policies::retry::RetryPolicy;
2222
use scylla::policies::speculative_execution::SimpleSpeculativeExecutionPolicy;
23+
use scylla::routing::ShardAwarePortRange;
2324
use scylla::statement::{Consistency, SerialConsistency};
2425
use std::collections::HashMap;
2526
use std::convert::TryInto;
@@ -54,6 +55,9 @@ const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
5455
const DEFAULT_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(60);
5556
// - default local ip address is arbitrary
5657
const DEFAULT_LOCAL_IP_ADDRESS: Option<IpAddr> = None;
58+
// - default shard aware local port range is ephemeral range
59+
const DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE: ShardAwarePortRange =
60+
ShardAwarePortRange::EPHEMERAL_PORT_RANGE;
5761

5862
const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
5963
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -224,6 +228,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
224228
.keepalive_interval(DEFAULT_KEEPALIVE_INTERVAL)
225229
.keepalive_timeout(DEFAULT_KEEPALIVE_TIMEOUT)
226230
.local_ip_address(DEFAULT_LOCAL_IP_ADDRESS)
231+
.shard_aware_local_port_range(DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE)
227232
};
228233

229234
BoxFFI::into_ptr(Box::new(CassCluster {
@@ -533,6 +538,47 @@ pub unsafe extern "C" fn cass_cluster_set_local_address_n(
533538
CassError::CASS_OK
534539
}
535540

541+
#[unsafe(no_mangle)]
542+
pub unsafe extern "C" fn cass_cluster_set_local_port_range(
543+
cluster_raw: CassBorrowedExclusivePtr<CassCluster, CMut>,
544+
lo: c_int,
545+
hi: c_int,
546+
) -> CassError {
547+
let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
548+
tracing::error!("Provided null cluster pointer to cass_cluster_set_local_port_range!");
549+
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
550+
};
551+
552+
fn compute_range_from_raw(lo: i32, hi: i32) -> Result<ShardAwarePortRange, CassError> {
553+
let start: u16 = lo
554+
.try_into()
555+
.map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)?;
556+
// In cpp-driver, the `hi` is exluded from the port range.
557+
// In rust-driver, OTOH, we include the upper bound of the range - thus -1.
558+
let end: u16 = hi
559+
.checked_sub(1)
560+
.ok_or(CassError::CASS_ERROR_LIB_BAD_PARAMS)?
561+
.try_into()
562+
.map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)?;
563+
564+
// Further validation is performed by the constructor.
565+
ShardAwarePortRange::new(start..=end).map_err(|_| CassError::CASS_ERROR_LIB_BAD_PARAMS)
566+
}
567+
568+
let range: ShardAwarePortRange = match compute_range_from_raw(lo, hi) {
569+
Ok(range) => range,
570+
Err(cass_error) => {
571+
// Let's use the error message from cpp-driver.
572+
tracing::error!("Invalid local port range. Expected: 1024 < lo <= hi < 65536.");
573+
return cass_error;
574+
}
575+
};
576+
577+
cluster.session_builder.config.shard_aware_local_port_range = range;
578+
579+
CassError::CASS_OK
580+
}
581+
536582
#[unsafe(no_mangle)]
537583
pub unsafe extern "C" fn cass_cluster_set_credentials(
538584
cluster: CassBorrowedExclusivePtr<CassCluster, CMut>,
@@ -1121,6 +1167,74 @@ mod tests {
11211167
}
11221168
}
11231169

1170+
#[test]
1171+
fn test_local_port_range() {
1172+
// TODO: Currently no way to compare the `ShardAwarePortRange`. Either implement `PartialEq`
1173+
// or expose a getter for underlying range on rust-driver side. We can test the validation, though.
1174+
1175+
unsafe {
1176+
let mut cluster_raw = cass_cluster_new();
1177+
1178+
// negative value
1179+
{
1180+
assert_cass_error_eq!(
1181+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), -1, 1025),
1182+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1183+
);
1184+
}
1185+
1186+
// start (inclusive) == end (exclusive)
1187+
{
1188+
assert_cass_error_eq!(
1189+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, 5555),
1190+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1191+
);
1192+
}
1193+
1194+
// start == end - 1
1195+
{
1196+
assert_cass_error_eq!(
1197+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5556, 5556),
1198+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1199+
);
1200+
}
1201+
1202+
// start > end
1203+
{
1204+
assert_cass_error_eq!(
1205+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5556, 5555),
1206+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1207+
);
1208+
}
1209+
1210+
// 0 <= start,end < 1024
1211+
{
1212+
assert_cass_error_eq!(
1213+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 1, 3),
1214+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1215+
);
1216+
}
1217+
1218+
// end is i32::MIN - check that does not panic due to overflow
1219+
{
1220+
assert_cass_error_eq!(
1221+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, i32::MIN),
1222+
CassError::CASS_ERROR_LIB_BAD_PARAMS
1223+
);
1224+
}
1225+
1226+
// some valid port range
1227+
{
1228+
assert_cass_error_eq!(
1229+
cass_cluster_set_local_port_range(cluster_raw.borrow_mut(), 5555, 5557),
1230+
CassError::CASS_OK
1231+
);
1232+
}
1233+
1234+
cass_cluster_free(cluster_raw);
1235+
}
1236+
}
1237+
11241238
#[test]
11251239
#[ntest::timeout(100)]
11261240
fn test_load_balancing_config() {

0 commit comments

Comments
 (0)