@@ -20,6 +20,7 @@ use scylla::policies::load_balancing::{
20
20
} ;
21
21
use scylla:: policies:: retry:: RetryPolicy ;
22
22
use scylla:: policies:: speculative_execution:: SimpleSpeculativeExecutionPolicy ;
23
+ use scylla:: routing:: ShardAwarePortRange ;
23
24
use scylla:: statement:: { Consistency , SerialConsistency } ;
24
25
use std:: collections:: HashMap ;
25
26
use std:: convert:: TryInto ;
@@ -54,6 +55,9 @@ const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
54
55
const DEFAULT_KEEPALIVE_TIMEOUT : Duration = Duration :: from_secs ( 60 ) ;
55
56
// - default local ip address is arbitrary
56
57
const DEFAULT_LOCAL_IP_ADDRESS : Option < IpAddr > = None ;
58
+ // - default shard aware local port range is ephemeral range
59
+ const DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE : ShardAwarePortRange =
60
+ ShardAwarePortRange :: EPHEMERAL_PORT_RANGE ;
57
61
58
62
const DRIVER_NAME : & str = "ScyllaDB Cpp-Rust Driver" ;
59
63
const DRIVER_VERSION : & str = env ! ( "CARGO_PKG_VERSION" ) ;
@@ -224,6 +228,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
224
228
. keepalive_interval ( DEFAULT_KEEPALIVE_INTERVAL )
225
229
. keepalive_timeout ( DEFAULT_KEEPALIVE_TIMEOUT )
226
230
. local_ip_address ( DEFAULT_LOCAL_IP_ADDRESS )
231
+ . shard_aware_local_port_range ( DEFAULT_SHARD_AWARE_LOCAL_PORT_RANGE )
227
232
} ;
228
233
229
234
BoxFFI :: into_ptr ( Box :: new ( CassCluster {
@@ -533,6 +538,47 @@ pub unsafe extern "C" fn cass_cluster_set_local_address_n(
533
538
CassError :: CASS_OK
534
539
}
535
540
541
+ #[ unsafe( no_mangle) ]
542
+ pub unsafe extern "C" fn cass_cluster_set_local_port_range (
543
+ cluster_raw : CassBorrowedExclusivePtr < CassCluster , CMut > ,
544
+ lo : c_int ,
545
+ hi : c_int ,
546
+ ) -> CassError {
547
+ let Some ( cluster) = BoxFFI :: as_mut_ref ( cluster_raw) else {
548
+ tracing:: error!( "Provided null cluster pointer to cass_cluster_set_local_port_range!" ) ;
549
+ return CassError :: CASS_ERROR_LIB_BAD_PARAMS ;
550
+ } ;
551
+
552
+ fn compute_range_from_raw ( lo : i32 , hi : i32 ) -> Result < ShardAwarePortRange , CassError > {
553
+ let start: u16 = lo
554
+ . try_into ( )
555
+ . map_err ( |_| CassError :: CASS_ERROR_LIB_BAD_PARAMS ) ?;
556
+ // In cpp-driver, the `hi` is exluded from the port range.
557
+ // In rust-driver, OTOH, we include the upper bound of the range - thus -1.
558
+ let end: u16 = hi
559
+ . checked_sub ( 1 )
560
+ . ok_or ( CassError :: CASS_ERROR_LIB_BAD_PARAMS ) ?
561
+ . try_into ( )
562
+ . map_err ( |_| CassError :: CASS_ERROR_LIB_BAD_PARAMS ) ?;
563
+
564
+ // Further validation is performed by the constructor.
565
+ ShardAwarePortRange :: new ( start..=end) . map_err ( |_| CassError :: CASS_ERROR_LIB_BAD_PARAMS )
566
+ }
567
+
568
+ let range: ShardAwarePortRange = match compute_range_from_raw ( lo, hi) {
569
+ Ok ( range) => range,
570
+ Err ( cass_error) => {
571
+ // Let's use the error message from cpp-driver.
572
+ tracing:: error!( "Invalid local port range. Expected: 1024 < lo <= hi < 65536." ) ;
573
+ return cass_error;
574
+ }
575
+ } ;
576
+
577
+ cluster. session_builder . config . shard_aware_local_port_range = range;
578
+
579
+ CassError :: CASS_OK
580
+ }
581
+
536
582
#[ unsafe( no_mangle) ]
537
583
pub unsafe extern "C" fn cass_cluster_set_credentials (
538
584
cluster : CassBorrowedExclusivePtr < CassCluster , CMut > ,
@@ -1121,6 +1167,74 @@ mod tests {
1121
1167
}
1122
1168
}
1123
1169
1170
+ #[ test]
1171
+ fn test_local_port_range ( ) {
1172
+ // TODO: Currently no way to compare the `ShardAwarePortRange`. Either implement `PartialEq`
1173
+ // or expose a getter for underlying range on rust-driver side. We can test the validation, though.
1174
+
1175
+ unsafe {
1176
+ let mut cluster_raw = cass_cluster_new ( ) ;
1177
+
1178
+ // negative value
1179
+ {
1180
+ assert_cass_error_eq ! (
1181
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , -1 , 1025 ) ,
1182
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1183
+ ) ;
1184
+ }
1185
+
1186
+ // start (inclusive) == end (exclusive)
1187
+ {
1188
+ assert_cass_error_eq ! (
1189
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 5555 , 5555 ) ,
1190
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1191
+ ) ;
1192
+ }
1193
+
1194
+ // start == end - 1
1195
+ {
1196
+ assert_cass_error_eq ! (
1197
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 5556 , 5556 ) ,
1198
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1199
+ ) ;
1200
+ }
1201
+
1202
+ // start > end
1203
+ {
1204
+ assert_cass_error_eq ! (
1205
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 5556 , 5555 ) ,
1206
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1207
+ ) ;
1208
+ }
1209
+
1210
+ // 0 <= start,end < 1024
1211
+ {
1212
+ assert_cass_error_eq ! (
1213
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 1 , 3 ) ,
1214
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1215
+ ) ;
1216
+ }
1217
+
1218
+ // end is i32::MIN - check that does not panic due to overflow
1219
+ {
1220
+ assert_cass_error_eq ! (
1221
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 5555 , i32 :: MIN ) ,
1222
+ CassError :: CASS_ERROR_LIB_BAD_PARAMS
1223
+ ) ;
1224
+ }
1225
+
1226
+ // some valid port range
1227
+ {
1228
+ assert_cass_error_eq ! (
1229
+ cass_cluster_set_local_port_range( cluster_raw. borrow_mut( ) , 5555 , 5557 ) ,
1230
+ CassError :: CASS_OK
1231
+ ) ;
1232
+ }
1233
+
1234
+ cass_cluster_free ( cluster_raw) ;
1235
+ }
1236
+ }
1237
+
1124
1238
#[ test]
1125
1239
#[ ntest:: timeout( 100 ) ]
1126
1240
fn test_load_balancing_config ( ) {
0 commit comments