diff --git a/Makefile b/Makefile
index fe8127ac..8436b0b1 100644
--- a/Makefile
+++ b/Makefile
@@ -27,6 +27,7 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:UseKeyspaceCaseSensitiveTests.*\
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
:MetricsTests.Integration_Cassandra_Requests\
+:MetricsTests.Integration_Cassandra_StatsShardConnections\
:-PreparedTests.Integration_Cassandra_PreparedIDUnchangedDuringReprepare\
:HeartbeatTests.Integration_Cassandra_HeartbeatFailed\
:ControlConnectionTests.Integration_Cassandra_TopologyChange\
@@ -68,6 +69,7 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:UseKeyspaceCaseSensitiveTests.*\
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
:MetricsTests.Integration_Cassandra_Requests\
+:MetricsTests.Integration_Cassandra_StatsShardConnections\
:-PreparedTests.Integration_Cassandra_PreparedIDUnchangedDuringReprepare\
:PreparedTests.Integration_Cassandra_FailFastWhenPreparedIDChangesDuringReprepare\
:HeartbeatTests.Integration_Cassandra_HeartbeatFailed\
diff --git a/include/cassandra.h b/include/cassandra.h
index f5882125..5de60fa3 100644
--- a/include/cassandra.h
+++ b/include/cassandra.h
@@ -1780,21 +1780,46 @@ cass_cluster_set_queue_size_event(CassCluster* cluster,
unsigned queue_size));
/**
- * Sets the number of connections made to each server in each
- * IO thread.
+ * Sets the number of connections opened by the driver to each host.
*
- * Default: 1
+ * Notice that this overrides the number of connections per shard
+ * set by `cass_cluster_set_core_connections_per_shard()`.
+ *
+ * Default: 1 per shard (i.e. `cass_cluster_set_core_connections_per_shard(cluster, 1)`)
*
* @public @memberof CassCluster
*
* @param[in] cluster
* @param[in] num_connections
* @return CASS_OK if successful, otherwise an error occurred.
+ *
+ * @see cass_cluster_set_core_connections_per_shard()
*/
CASS_EXPORT CassError
cass_cluster_set_core_connections_per_host(CassCluster* cluster,
unsigned num_connections);
+/**
+ * Sets the number of connections opened by the driver to each shard.
+ *
+ * Cassandra nodes are treated as if they have one shard.
+ *
+ * This will override the `cass_cluster_set_core_connections_per_host`, if set.
+ *
+ * Default: 1
+ *
+ * @public @memberof CassCluster
+ *
+ * @param[in] cluster
+ * @param[in] num_connections
+ * @return CASS_OK if successful, otherwise an error occurred.
+ *
+ * @see cass_cluster_set_core_connections_per_host()
+ */
+CASS_EXPORT CassError
+cass_cluster_set_core_connections_per_shard(CassCluster* cluster,
+ unsigned num_connections);
+
/**
* Sets the maximum number of connections made to each server in each
* IO thread.
diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs
index ddc42ca3..eeba1744 100644
--- a/scylla-rust-wrapper/src/cluster.rs
+++ b/scylla-rust-wrapper/src/cluster.rs
@@ -12,7 +12,7 @@ use openssl::ssl::SslContextBuilder;
use openssl_sys::SSL_CTX_up_ref;
use scylla::client::execution_profile::ExecutionProfileBuilder;
use scylla::client::session_builder::SessionBuilder;
-use scylla::client::{SelfIdentity, WriteCoalescingDelay};
+use scylla::client::{PoolSize, SelfIdentity, WriteCoalescingDelay};
use scylla::frame::Compression;
use scylla::policies::load_balancing::{
DefaultPolicyBuilder, LatencyAwarenessBuilder, LoadBalancingPolicy,
@@ -25,7 +25,7 @@ use std::collections::HashMap;
use std::convert::TryInto;
use std::future::Future;
use std::net::IpAddr;
-use std::num::NonZero;
+use std::num::{NonZero, NonZeroUsize};
use std::os::raw::{c_char, c_int, c_uint};
use std::str::FromStr;
use std::sync::Arc;
@@ -47,6 +47,8 @@ const DEFAULT_MAX_SCHEMA_WAIT_TIME: Duration = Duration::from_millis(10000);
const DEFAULT_SCHEMA_AGREEMENT_INTERVAL: Duration = Duration::from_millis(200);
// - setting TCP_NODELAY is true
const DEFAULT_SET_TCP_NO_DELAY: bool = true;
+// - connection pool size is 1 per shard
+const DEFAULT_CONNECTION_POOL_SIZE: PoolSize = PoolSize::PerShard(NonZeroUsize::new(1).unwrap());
// - enabling write coalescing
const DEFAULT_ENABLE_WRITE_COALESCING: bool = true;
// - write coalescing delay
@@ -234,6 +236,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr,
+ num_connections: c_uint,
+) -> CassError {
+ let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
+ tracing::error!(
+ "Provided null cluster pointer to cass_cluster_set_core_connections_per_host!"
+ );
+ return CassError::CASS_ERROR_LIB_BAD_PARAMS;
+ };
+
+ match NonZeroUsize::new(num_connections as usize) {
+ Some(non_zero_conns) => {
+ cluster.session_builder.config.connection_pool_size = PoolSize::PerHost(non_zero_conns);
+ CassError::CASS_OK
+ }
+ None => {
+ tracing::error!(
+ "Provided zero connections to cass_cluster_set_core_connections_per_host!"
+ );
+ CassError::CASS_ERROR_LIB_BAD_PARAMS
+ }
+ }
+}
+
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn cass_cluster_set_core_connections_per_shard(
+ cluster_raw: CassBorrowedExclusivePtr,
+ num_connections: c_uint,
+) -> CassError {
+ let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else {
+ tracing::error!(
+ "Provided null cluster pointer to cass_cluster_set_core_connections_per_shard!"
+ );
+ return CassError::CASS_ERROR_LIB_BAD_PARAMS;
+ };
+
+ match NonZeroUsize::new(num_connections as usize) {
+ Some(non_zero_conns) => {
+ cluster.session_builder.config.connection_pool_size =
+ PoolSize::PerShard(non_zero_conns);
+ CassError::CASS_OK
+ }
+ None => {
+ tracing::error!(
+ "Provided zero connections to cass_cluster_set_core_connections_per_shard!"
+ );
+ CassError::CASS_ERROR_LIB_BAD_PARAMS
+ }
+ }
+}
+
#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_cluster_set_coalesce_delay(
cluster_raw: CassBorrowedExclusivePtr,
diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp
index 7be97ead..3843b8ac 100644
--- a/src/testing_unimplemented.cpp
+++ b/src/testing_unimplemented.cpp
@@ -65,10 +65,6 @@ CASS_EXPORT CassError cass_cluster_set_cloud_secure_connection_bundle_no_ssl_lib
CASS_EXPORT void cass_cluster_set_constant_reconnect(CassCluster* cluster, cass_uint64_t delay_ms) {
throw std::runtime_error("UNIMPLEMENTED cass_cluster_set_constant_reconnect\n");
}
-CASS_EXPORT CassError cass_cluster_set_core_connections_per_host(CassCluster* cluster,
- unsigned num_connections) {
- throw std::runtime_error("UNIMPLEMENTED cass_cluster_set_core_connections_per_host\n");
-}
CASS_EXPORT CassError cass_cluster_set_host_listener_callback(CassCluster* cluster,
CassHostListenerCallback callback,
void* data) {
diff --git a/tests/src/integration/objects/cluster.hpp b/tests/src/integration/objects/cluster.hpp
index a03f1af9..7b8f32cc 100644
--- a/tests/src/integration/objects/cluster.hpp
+++ b/tests/src/integration/objects/cluster.hpp
@@ -170,6 +170,19 @@ class Cluster : public Object {
return *this;
}
+ /**
+ * Assign the number of connections made to each shard
+ *
+ * NOTE: One extra connection is established (the control connection)
+ *
+ * @param connections Number of connection per shard (default: 1)
+ * @return Cluster object
+ */
+ Cluster& with_core_connections_per_shard(unsigned int connections = 1u) {
+ EXPECT_EQ(CASS_OK, cass_cluster_set_core_connections_per_shard(get(), connections));
+ return *this;
+ }
+
/**
* Sets credentials for plain text authentication
*
diff --git a/tests/src/integration/tests/test_metrics.cpp b/tests/src/integration/tests/test_metrics.cpp
index be517261..006eb356 100644
--- a/tests/src/integration/tests/test_metrics.cpp
+++ b/tests/src/integration/tests/test_metrics.cpp
@@ -76,6 +76,31 @@ CASSANDRA_INTEGRATION_TEST_F(MetricsTests, ErrorsConnectionTimeouts) {
EXPECT_GE(2u, metrics.errors.connection_timeouts);
}
+/**
+ * This test ensures that the driver is reporting the number of connections
+ * when connection pool size is configured per shard.
+ */
+CASSANDRA_INTEGRATION_TEST_F(MetricsTests, StatsShardConnections) {
+ CHECK_FAILURE;
+
+ const unsigned int CONNS_PER_SHARD = 2;
+
+ Session session =
+ default_cluster().with_core_connections_per_shard(CONNS_PER_SHARD).connect();
+
+ size_t nr_hosts = explode(contact_points_, ',').size();
+ size_t nr_shards = Options::is_scylla() ? Options::smp() : 1;
+ size_t expected_connection_count = nr_hosts * nr_shards * CONNS_PER_SHARD;
+
+ CassMetrics metrics = session.metrics();
+ for (int i = 0; i < 100 && metrics.stats.total_connections < expected_connection_count; ++i) {
+ metrics = session.metrics();
+ msleep(100);
+ }
+
+ EXPECT_GE(metrics.stats.total_connections, expected_connection_count);
+}
+
/**
* This test ensures that the driver is reporting the proper timeouts for requests
*