diff --git a/Makefile b/Makefile index a3f014f3..f6446285 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,8 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\ :SerialConsistencyTests.*\ :HeartbeatTests.*\ :PreparedTests.*\ +:StatementNoClusterTests.*\ +:StatementTests.*\ :NamedParametersTests.*\ :CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\ :ControlConnectionTests.*\ @@ -27,6 +29,7 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\ :PreparedMetadataTests.*\ :UseKeyspaceCaseSensitiveTests.*\ :ServerSideFailureTests.*\ +:ServerSideFailureThreeNodeTests.*\ :TimestampTests.*\ :MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\ :MetricsTests.Integration_Cassandra_Requests\ @@ -69,6 +72,8 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\ :SerialConsistencyTests.*\ :HeartbeatTests.*\ :PreparedTests.*\ +:StatementNoClusterTests.*\ +:StatementTests.*\ :NamedParametersTests.*\ :CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\ :ControlConnectionTests.*\ @@ -83,6 +88,7 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\ :PreparedMetadataTests.*\ :UseKeyspaceCaseSensitiveTests.*\ :ServerSideFailureTests.*\ +:ServerSideFailureThreeNodeTests.*\ :TimestampTests.*\ :MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\ :MetricsTests.Integration_Cassandra_Requests\ diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index 278b6ef9..e6620fe4 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -1112,12 +1112,11 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "scylla" -version = "1.1.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.1.0#ef5b0ada61989cedf9bcf5d715c8b36214f8b36f" +version = "1.2.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.2.0#bc6b24da1b17aa4d33df6a0f0283f08937ff1a19" dependencies = [ "arc-swap", "async-trait", - "byteorder", "bytes", "chrono", "dashmap", @@ -1125,14 +1124,11 @@ dependencies = [ "hashbrown 0.14.5", "histogram", "itertools", - "lazy_static", - "lz4_flex", "openssl", "rand 0.9.0", "rand_pcg", "scylla-cql", "smallvec", - "snap", "socket2", "thiserror 2.0.12", "tokio", @@ -1171,10 +1167,9 @@ dependencies = [ [[package]] name = "scylla-cql" -version = "1.1.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.1.0#ef5b0ada61989cedf9bcf5d715c8b36214f8b36f" +version = "1.2.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.2.0#bc6b24da1b17aa4d33df6a0f0283f08937ff1a19" dependencies = [ - "async-trait", "byteorder", "bytes", "chrono", @@ -1191,8 +1186,8 @@ dependencies = [ [[package]] name = "scylla-macros" -version = "1.1.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.1.0#ef5b0ada61989cedf9bcf5d715c8b36214f8b36f" +version = "1.2.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.2.0#bc6b24da1b17aa4d33df6a0f0283f08937ff1a19" dependencies = [ "darling", "proc-macro2", @@ -1202,8 +1197,8 @@ dependencies = [ [[package]] name = "scylla-proxy" -version = "0.0.3" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.1.0#ef5b0ada61989cedf9bcf5d715c8b36214f8b36f" +version = "0.0.4" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v1.2.0#bc6b24da1b17aa4d33df6a0f0283f08937ff1a19" dependencies = [ "bigdecimal", "byteorder", diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index 11a48b38..e5f87295 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -10,7 +10,7 @@ categories = ["database"] license = "MIT OR Apache-2.0" [dependencies] -scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0", features = [ +scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.2.0", features = [ "openssl-010", "metrics", ] } @@ -34,7 +34,7 @@ bindgen = "0.65" chrono = "0.4.20" [dev-dependencies] -scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0" } +scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.2.0" } bytes = "1.10.0" assert_matches = "1.5.0" diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 0b279a3c..b20d2cd0 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -5,17 +5,18 @@ use crate::cass_error::CassErrorMessage; use crate::cass_error::ToCassError; use crate::execution_error::CassErrorResult; use crate::prepared::CassPrepared; -use crate::query_result::CassResult; +use crate::query_result::{CassNode, CassResult}; use crate::types::*; use crate::uuid::CassUuid; use futures::future; use std::future::Future; use std::mem; use std::os::raw::c_void; -use std::sync::{Arc, Condvar, Mutex}; +use std::sync::{Arc, Condvar, Mutex, OnceLock}; use tokio::task::JoinHandle; use tokio::time::Duration; +#[derive(Debug)] pub enum CassResultValue { Empty, QueryResult(Arc), @@ -50,7 +51,6 @@ impl BoundCallback { #[derive(Default)] struct CassFutureState { - value: Option, err_string: Option, callback: Option, join_handle: Option>, @@ -58,6 +58,7 @@ struct CassFutureState { pub struct CassFuture { state: Mutex, + result: OnceLock, wait_for_value: Condvar, } @@ -87,6 +88,7 @@ impl CassFuture { ) -> Arc { let cass_fut = Arc::new(CassFuture { state: Mutex::new(Default::default()), + result: OnceLock::new(), wait_for_value: Condvar::new(), }); let cass_fut_clone = Arc::clone(&cass_fut); @@ -94,7 +96,10 @@ impl CassFuture { let r = fut.await; let maybe_cb = { let mut guard = cass_fut_clone.state.lock().unwrap(); - guard.value = Some(r); + cass_fut_clone + .result + .set(r) + .expect("Tried to resolve future result twice!"); // Take the callback and call it after releasing the lock guard.callback.take() }; @@ -115,16 +120,17 @@ impl CassFuture { pub fn new_ready(r: CassFutureResult) -> Arc { Arc::new(CassFuture { - state: Mutex::new(CassFutureState { - value: Some(r), - ..Default::default() - }), + state: Mutex::new(CassFutureState::default()), + result: OnceLock::from(r), wait_for_value: Condvar::new(), }) } - pub fn with_waited_result(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T { - self.with_waited_state(|s| f(s.value.as_mut().unwrap())) + pub fn with_waited_result<'s, T>(&'s self, f: impl FnOnce(&'s CassFutureResult) -> T) -> T + where + T: 's, + { + self.with_waited_state(|_| f(self.result.get().unwrap())) } /// Awaits the future until completion. @@ -153,7 +159,7 @@ impl CassFuture { guard = self .wait_for_value .wait_while(guard, |state| { - state.value.is_none() && state.join_handle.is_none() + self.result.get().is_none() && state.join_handle.is_none() }) // unwrap: Error appears only when mutex is poisoned. .unwrap(); @@ -171,10 +177,10 @@ impl CassFuture { fn with_waited_result_timed( &self, - f: impl FnOnce(&mut CassFutureResult) -> T, + f: impl FnOnce(&CassFutureResult) -> T, timeout_duration: Duration, ) -> Result { - self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration) + self.with_waited_state_timed(|_| f(self.result.get().unwrap()), timeout_duration) } /// Tries to await the future with a given timeout. @@ -242,7 +248,7 @@ impl CassFuture { let (guard_result, timeout_result) = self .wait_for_value .wait_timeout_while(guard, remaining_timeout, |state| { - state.value.is_none() && state.join_handle.is_none() + self.result.get().is_none() && state.join_handle.is_none() }) // unwrap: Error appears only when mutex is poisoned. .unwrap(); @@ -275,7 +281,7 @@ impl CassFuture { return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET; } let bound_cb = BoundCallback { cb, data }; - if lock.value.is_some() { + if self.result.get().is_some() { // The value is already available, we need to call the callback ourselves mem::drop(lock); bound_cb.invoke(self_ptr); @@ -345,8 +351,7 @@ pub unsafe extern "C" fn cass_future_ready( return cass_false; }; - let state_guard = future.state.lock().unwrap(); - match state_guard.value { + match future.result.get() { None => cass_false, Some(_) => cass_true, } @@ -361,7 +366,7 @@ pub unsafe extern "C" fn cass_future_error_code( return CassError::CASS_ERROR_LIB_BAD_PARAMS; }; - future.with_waited_result(|r: &mut CassFutureResult| match r { + future.with_waited_result(|r: &CassFutureResult| match r { Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), Err((err, _)) => *err, _ => CassError::CASS_OK, @@ -380,7 +385,7 @@ pub unsafe extern "C" fn cass_future_error_message( }; future.with_waited_state(|state: &mut CassFutureState| { - let value = &state.value; + let value = future.result.get(); let msg = state .err_string .get_or_insert_with(|| match value.as_ref().unwrap() { @@ -407,7 +412,7 @@ pub unsafe extern "C" fn cass_future_get_result( }; future - .with_waited_result(|r: &mut CassFutureResult| -> Option> { + .with_waited_result(|r: &CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)), _ => None, @@ -426,7 +431,7 @@ pub unsafe extern "C" fn cass_future_get_error_result( }; future - .with_waited_result(|r: &mut CassFutureResult| -> Option> { + .with_waited_result(|r: &CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryError(qr) => Some(Arc::clone(qr)), _ => None, @@ -445,7 +450,7 @@ pub unsafe extern "C" fn cass_future_get_prepared( }; future - .with_waited_result(|r: &mut CassFutureResult| -> Option> { + .with_waited_result(|r: &CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::Prepared(p) => Some(Arc::clone(p)), _ => None, @@ -464,7 +469,7 @@ pub unsafe extern "C" fn cass_future_tracing_id( return CassError::CASS_ERROR_LIB_BAD_PARAMS; }; - future.with_waited_result(|r: &mut CassFutureResult| match r { + future.with_waited_result(|r: &CassFutureResult| match r { Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { Some(id) => { unsafe { *tracing_id = CassUuid::from(id) }; @@ -476,6 +481,24 @@ pub unsafe extern "C" fn cass_future_tracing_id( }) } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_future_coordinator( + future_raw: CassBorrowedSharedPtr, +) -> CassBorrowedSharedPtr { + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future to cass_future_coordinator!"); + return RefFFI::null(); + }; + + future.with_waited_result(|r| match r { + Ok(CassResultValue::QueryResult(result)) => { + // unwrap: Coordinator is `None` only for tests. + RefFFI::as_ptr(result.coordinator.as_ref().unwrap()) + } + _ => RefFFI::null(), + }) +} + #[cfg(test)] mod tests { use crate::testing::{assert_cass_error_eq, assert_cass_future_error_message_eq}; diff --git a/scylla-rust-wrapper/src/integration_testing.rs b/scylla-rust-wrapper/src/integration_testing.rs index bf372eaa..ce97ac48 100644 --- a/scylla-rust-wrapper/src/integration_testing.rs +++ b/scylla-rust-wrapper/src/integration_testing.rs @@ -7,9 +7,12 @@ use scylla::errors::{RequestAttemptError, RequestError}; use scylla::observability::history::{AttemptId, HistoryListener, RequestId, SpeculativeId}; use scylla::policies::retry::RetryDecision; -use crate::argconv::{BoxFFI, CMut, CassBorrowedExclusivePtr}; +use crate::argconv::{ + ArcFFI, BoxFFI, CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr, +}; use crate::batch::CassBatch; use crate::cluster::CassCluster; +use crate::future::{CassFuture, CassResultValue}; use crate::statement::{BoundStatement, CassStatement}; use crate::types::{cass_int32_t, cass_uint16_t, cass_uint64_t, size_t}; @@ -60,8 +63,47 @@ pub unsafe extern "C" fn testing_cluster_get_contact_points( } #[unsafe(no_mangle)] -pub unsafe extern "C" fn testing_free_contact_points(contact_points: *mut c_char) { - let _ = unsafe { CString::from_raw(contact_points) }; +pub unsafe extern "C" fn testing_future_get_host( + future_raw: CassBorrowedSharedPtr, + host: *mut *mut c_char, + host_length: *mut size_t, +) { + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to testing_future_get_host!"); + unsafe { + *host = std::ptr::null_mut(); + *host_length = 0; + }; + return; + }; + + future.with_waited_result(|r| match r { + Ok(CassResultValue::QueryResult(result)) => { + // unwrap: Coordinator is none only for unit tests. + let coordinator = result.coordinator.as_ref().unwrap(); + + let ip_addr_str = coordinator.node().address.ip().to_string(); + let length = ip_addr_str.len(); + + let ip_addr_cstr = CString::new(ip_addr_str).expect( + "String obtained from IpAddr::to_string() should not contain any nul bytes!", + ); + + unsafe { + *host = ip_addr_cstr.into_raw(); + *host_length = length as size_t + }; + } + _ => unsafe { + *host = std::ptr::null_mut(); + *host_length = 0; + }, + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn testing_free_cstring(s: *mut c_char) { + let _ = unsafe { CString::from_raw(s) }; } #[derive(Debug)] diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index a88b86b1..36816530 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -17,8 +17,8 @@ use scylla::deserialize::row::{ use scylla::deserialize::value::DeserializeValue; use scylla::errors::{DeserializationError, IntoRowsResultError, TypeCheckError}; use scylla::frame::response::result::{ColumnSpec, DeserializedMetadataAndRawRows}; -use scylla::response::PagingStateResponse; use scylla::response::query_result::{ColumnSpecs, QueryResult}; +use scylla::response::{Coordinator, PagingStateResponse}; use scylla::value::{ Counter, CqlDate, CqlDecimalBorrowed, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, }; @@ -29,27 +29,41 @@ use std::sync::Arc; use thiserror::Error; use uuid::Uuid; +#[derive(Debug)] pub enum CassResultKind { NonRows, Rows(CassRowsResult), } +#[derive(Debug)] pub struct CassRowsResult { // Arc: shared with first_row (yoke). pub(crate) shared_data: Arc, pub(crate) first_row: Option, } +#[derive(Debug)] pub(crate) struct CassRowsResultSharedData { pub(crate) raw_rows: DeserializedMetadataAndRawRows, // Arc: shared with CassPrepared pub(crate) metadata: Arc, } +pub type CassNode = Coordinator; + +// Borrowed from CassResult in cass_future_coordinator. +impl FFI for CassNode { + type Origin = FromRef; +} + +#[derive(Debug)] pub struct CassResult { pub tracing_id: Option, pub paging_state_response: PagingStateResponse, pub kind: CassResultKind, + // None only for tests - currently no way to mock coordinator in rust-driver. + // Should be able to do so under "cpp_rust_unstable". + pub(crate) coordinator: Option, } impl CassResult { @@ -61,7 +75,7 @@ impl CassResult { result: QueryResult, paging_state_response: PagingStateResponse, maybe_result_metadata: Option>, - ) -> Result { + ) -> Result> { match result.into_rows_result() { Ok(rows_result) => { // maybe_result_metadata is: @@ -73,7 +87,7 @@ impl CassResult { )) }); - let (raw_rows, tracing_id, _) = rows_result.into_inner(); + let (raw_rows, tracing_id, _, coordinator) = rows_result.into_inner(); let shared_data = Arc::new(CassRowsResultSharedData { raw_rows, metadata }); let first_row = RowWithSelfBorrowedResultData::first_from_raw_rows_and_metadata( Arc::clone(&shared_data), @@ -86,6 +100,7 @@ impl CassResult { shared_data, first_row, }), + coordinator, }; Ok(cass_result) @@ -95,12 +110,13 @@ impl CassResult { tracing_id: result.tracing_id(), paging_state_response, kind: CassResultKind::NonRows, + coordinator: Some(result.request_coordinator().clone()), }; Ok(cass_result) } Err(IntoRowsResultError::ResultMetadataLazyDeserializationError(err)) => { - Err(err.into()) + Err(Arc::new(err.into())) } } } @@ -147,6 +163,7 @@ impl<'frame, 'metadata> DeserializeRow<'frame, 'metadata> for CassRawRow<'frame, /// The lifetime of CassRow is bound to CassResult. /// It will be freed, when CassResult is freed.(see #[cass_result_free]) +#[derive(Debug)] pub struct CassRow<'result> { pub columns: Vec>, pub result_metadata: &'result CassResultMetadata, @@ -218,7 +235,7 @@ mod row_with_self_borrowed_result_data { /// A simple wrapper over CassRow. /// Needed, so we can implement Yokeable for it, instead of implementing it for CassRow. - #[derive(Yokeable)] + #[derive(Debug, Yokeable)] struct CassRowWrapper<'result>(CassRow<'result>); /// A wrapper over struct which self-borrows the metadata allocated using Arc. @@ -231,6 +248,7 @@ mod row_with_self_borrowed_result_data { /// /// This struct is a shared owner of the row bytes and metadata, and self-borrows this data /// to the `CassRow` it contains. + #[derive(Debug)] pub struct RowWithSelfBorrowedResultData( Yoke, Arc>, ); @@ -239,7 +257,7 @@ mod row_with_self_borrowed_result_data { /// Constructs [`RowWithSelfBorrowedResultData`] based on the first row from `raw_rows_and_metadata`. pub(super) fn first_from_raw_rows_and_metadata( raw_rows_and_metadata: Arc, - ) -> Result, CassErrorResult> { + ) -> Result, Arc> { enum AttachError { CassErrorResult(CassErrorResult), NoRows, @@ -276,7 +294,7 @@ mod row_with_self_borrowed_result_data { match yoke_result { Ok(yoke) => Ok(Some(Self(yoke))), Err(AttachError::NoRows) => Ok(None), - Err(AttachError::CassErrorResult(err)) => Err(err), + Err(AttachError::CassErrorResult(err)) => Err(Arc::new(err)), } } @@ -295,6 +313,7 @@ pub(crate) mod cass_raw_value { use scylla::errors::{DeserializationError, TypeCheckError}; use thiserror::Error; + #[derive(Debug)] pub(crate) struct CassRawValue<'frame, 'metadata> { typ: &'metadata ColumnType<'metadata>, slice: Option>, @@ -416,6 +435,7 @@ pub(crate) mod cass_raw_value { } } +#[derive(Debug)] pub struct CassValue<'result> { pub(crate) value: CassRawValue<'result, 'result>, pub(crate) value_type: &'result Arc, @@ -1207,6 +1227,7 @@ mod tests { shared_data, first_row, }), + coordinator: None, } } @@ -1310,6 +1331,7 @@ mod tests { tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, kind: CassResultKind::NonRows, + coordinator: None, } } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 2612c9e2..62c53ad9 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -239,10 +239,11 @@ pub unsafe extern "C" fn cass_session_execute_batch( let query_res = session.batch(&state.batch, &state.bound_values).await; match query_res { - Ok(_result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult { + Ok(result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult { tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, kind: CassResultKind::NonRows, + coordinator: Some(result.request_coordinator().clone()), }))), Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), } @@ -395,7 +396,7 @@ pub unsafe extern "C" fn cass_session_execute( maybe_result_metadata, ) { Ok(result) => Ok(CassResultValue::QueryResult(Arc::new(result))), - Err(e) => Ok(CassResultValue::QueryError(Arc::new(e))), + Err(e) => Ok(CassResultValue::QueryError(e)), } } Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index 5aec5029..33eabf2d 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -1,13 +1,15 @@ use crate::cass_error::CassError; use crate::cass_types::CassConsistency; use crate::exec_profile::PerStatementExecProfile; +use crate::inet::CassInet; use crate::prepared::CassPrepared; -use crate::query_result::CassResult; +use crate::query_result::{CassNode, CassResult}; use crate::retry_policy::CassRetryPolicy; use crate::types::*; use crate::value::CassCqlValue; use crate::{argconv::*, value}; use scylla::frame::types::Consistency; +use scylla::policies::load_balancing::{NodeIdentifier, SingleTargetLoadBalancingPolicy}; use scylla::response::{PagingState, PagingStateResponse}; use scylla::serialize::SerializationError; use scylla::serialize::row::{RowSerializationContext, SerializeRow}; @@ -19,8 +21,10 @@ use scylla::value::MaybeUnset; use scylla::value::MaybeUnset::{Set, Unset}; use std::collections::HashMap; use std::convert::TryInto; +use std::net::{IpAddr, SocketAddr}; use std::os::raw::{c_char, c_int}; use std::slice; +use std::str::FromStr; use std::sync::Arc; use thiserror::Error; @@ -440,6 +444,132 @@ pub unsafe extern "C" fn cass_statement_set_tracing( CassError::CASS_OK } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_statement_set_host( + statement_raw: CassBorrowedExclusivePtr, + host: *const c_char, + port: c_int, +) -> CassError { + unsafe { cass_statement_set_host_n(statement_raw, host, strlen(host), port) } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_statement_set_host_n( + statement_raw: CassBorrowedExclusivePtr, + host: *const c_char, + host_length: size_t, + port: c_int, +) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_host_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let host = match unsafe { ptr_to_cstr_n(host, host_length) } { + Some(v) => v, + None => { + tracing::error!("Provided null or non-utf8 host pointer to cass_statement_set_host_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + } + }; + let Ok(port): Result = port.try_into() else { + tracing::error!("Provided invalid port value to cass_statement_set_host_n: {port}"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let address = match IpAddr::from_str(host) { + Ok(ip_addr) => SocketAddr::new(ip_addr, port), + Err(e) => { + tracing::error!("Failed to parse ip address <{}>: {}", host, e); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + } + }; + let enforce_target_lbp = + SingleTargetLoadBalancingPolicy::new(NodeIdentifier::NodeAddress(address), None); + + match &mut statement.statement { + BoundStatement::Simple(inner) => inner + .query + .set_load_balancing_policy(Some(enforce_target_lbp)), + BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) + .statement + .set_load_balancing_policy(Some(enforce_target_lbp)), + } + + CassError::CASS_OK +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_statement_set_host_inet( + statement_raw: CassBorrowedExclusivePtr, + host: *const CassInet, + port: c_int, +) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_host_inet!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + if host.is_null() { + tracing::error!("Provided null host pointer to cass_statement_set_host_inet!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + } + // SAFETY: Assuming that user provided valid pointer. + let ip_addr: IpAddr = match unsafe { *host }.try_into() { + Ok(ip_addr) => ip_addr, + Err(_) => { + tracing::error!("Provided invalid CassInet value to cass_statement_set_host_inet!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + } + }; + let Ok(port): Result = port.try_into() else { + tracing::error!("Provided invalid port value to cass_statement_set_host_n: {port}"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let address = SocketAddr::new(ip_addr, port); + let enforce_target_lbp = + SingleTargetLoadBalancingPolicy::new(NodeIdentifier::NodeAddress(address), None); + + match &mut statement.statement { + BoundStatement::Simple(inner) => inner + .query + .set_load_balancing_policy(Some(enforce_target_lbp)), + BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) + .statement + .set_load_balancing_policy(Some(enforce_target_lbp)), + } + + CassError::CASS_OK +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn cass_statement_set_node( + statement_raw: CassBorrowedExclusivePtr, + node_raw: CassBorrowedSharedPtr, +) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_node!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(node) = RefFFI::as_ref(node_raw) else { + tracing::error!("Provided null node pointer to cass_statement_set_node!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + let enforce_target_lbp = + SingleTargetLoadBalancingPolicy::new(NodeIdentifier::Node(Arc::clone(node.node())), None); + + match &mut statement.statement { + BoundStatement::Simple(inner) => inner + .query + .set_load_balancing_policy(Some(enforce_target_lbp)), + BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) + .statement + .set_load_balancing_policy(Some(enforce_target_lbp)), + } + + CassError::CASS_OK +} + #[unsafe(no_mangle)] pub unsafe extern "C" fn cass_statement_set_retry_policy( statement: CassBorrowedExclusivePtr, @@ -694,3 +824,158 @@ make_binders!( cass_statement_bind_user_type_by_name, cass_statement_bind_user_type_by_name_n ); + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + use std::ptr::addr_of; + use std::str::FromStr; + + use crate::argconv::{BoxFFI, RefFFI}; + use crate::cass_error::CassError; + use crate::inet::CassInet; + use crate::statement::{ + cass_statement_set_host, cass_statement_set_host_inet, cass_statement_set_node, + }; + use crate::testing::assert_cass_error_eq; + + use super::{cass_statement_free, cass_statement_new}; + + #[test] + fn test_statement_set_host() { + unsafe { + let mut statement_raw = cass_statement_new(c"dummy".as_ptr(), 0); + + // cass_statement_set_host + { + // Null statement + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host(BoxFFI::null_mut(), c"127.0.0.1".as_ptr(), 9042) + ); + + // Null ip address + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host(statement_raw.borrow_mut(), std::ptr::null(), 9042) + ); + + // Unparsable ip address + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host(statement_raw.borrow_mut(), c"invalid".as_ptr(), 9042) + ); + + // Negative port + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host(statement_raw.borrow_mut(), c"127.0.0.1".as_ptr(), -1) + ); + + // Port too big + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host( + statement_raw.borrow_mut(), + c"127.0.0.1".as_ptr(), + 70000 + ) + ); + + // Valid ip address and port + assert_cass_error_eq!( + CassError::CASS_OK, + cass_statement_set_host( + statement_raw.borrow_mut(), + c"127.0.0.1".as_ptr(), + 9042 + ) + ); + } + + // cass_statement_set_host_inet + { + let valid_inet: CassInet = IpAddr::from_str("127.0.0.1").unwrap().into(); + + let invalid_inet = CassInet { + address: [0; 16], + // invalid length - should be 4 or 16 + address_length: 3, + }; + + // Null statement + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host_inet(BoxFFI::null_mut(), addr_of!(valid_inet), 9042) + ); + + // Null CassInet + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host_inet( + statement_raw.borrow_mut(), + std::ptr::null(), + 9042 + ) + ); + + // Invalid CassInet + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host_inet( + statement_raw.borrow_mut(), + addr_of!(invalid_inet), + 9042 + ) + ); + + // Negative port + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host_inet( + statement_raw.borrow_mut(), + addr_of!(valid_inet), + -1 + ) + ); + + // Port too big + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_host_inet( + statement_raw.borrow_mut(), + addr_of!(valid_inet), + 70000 + ) + ); + + // Valid ip address and port + assert_cass_error_eq!( + CassError::CASS_OK, + cass_statement_set_host_inet( + statement_raw.borrow_mut(), + addr_of!(valid_inet), + 9042 + ) + ); + } + + // cass_statement_set_node + { + // Null statement + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_node(BoxFFI::null_mut(), RefFFI::null()) + ); + + // Null CassNode + assert_cass_error_eq!( + CassError::CASS_ERROR_LIB_BAD_PARAMS, + cass_statement_set_node(statement_raw.borrow_mut(), RefFFI::null()) + ); + } + + cass_statement_free(statement_raw); + } + } +} diff --git a/src/testing.cpp b/src/testing.cpp index 7e397901..2f7d9c12 100644 --- a/src/testing.cpp +++ b/src/testing.cpp @@ -32,7 +32,22 @@ namespace datastax { namespace internal { namespace testing { using namespace core; String get_host_from_future(CassFuture* future) { - throw std::runtime_error("Unimplemented 'get_host_from_future'!"); + char* host; + size_t host_length; + + testing_future_get_host(future, &host, &host_length); + + if (host == nullptr) { + throw std::runtime_error("CassFuture returned a null host string."); + } + + std::string host_str(host, host_length); + OStringStream ss; + ss << host_str; + + testing_free_cstring(host); + + return ss.str(); } StringVec get_attempted_hosts_from_future(CassFuture* future) { @@ -59,7 +74,7 @@ String get_contact_points_from_cluster(CassCluster* cluster) { OStringStream ss; ss << contact_points_str; - testing_free_contact_points(contact_points); + testing_free_cstring(contact_points); return ss.str(); } diff --git a/src/testing_rust_impls.h b/src/testing_rust_impls.h index e6f07800..a5866ff7 100644 --- a/src/testing_rust_impls.h +++ b/src/testing_rust_impls.h @@ -16,11 +16,18 @@ CASS_EXPORT cass_int32_t testing_cluster_get_port(CassCluster* cluster); // Then, the resulting pointer is set to null. // // On success, this function allocates a contact points string, which needs to be then -// freed with `testing_free_contact_points`. +// freed with `testing_free_cstring`. CASS_EXPORT void testing_cluster_get_contact_points(CassCluster* cluster, char** contact_points, size_t* contact_points_length); -CASS_EXPORT void testing_free_contact_points(char* contact_points); +// Returns an ip address of request coordinator. +// +// This method fails if the future resolved to some error. +// +// On success, it allocates a host string which needs to be then freed wih `testing_free_cstring`. +CASS_EXPORT void testing_future_get_host(const CassFuture* future, char** host, size_t* host_length); + +CASS_EXPORT void testing_free_cstring(char *s); // Sets a sleeping history listener on the statement. // This can be used to enforce a sleep time during statement execution, which increases the latency. diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp index 960ab3e2..9ba9b257 100644 --- a/src/testing_unimplemented.cpp +++ b/src/testing_unimplemented.cpp @@ -162,9 +162,6 @@ CASS_EXPORT const CassDataType* cass_function_meta_return_type(const CassFunctionMeta* function_meta) { throw std::runtime_error("UNIMPLEMENTED cass_function_meta_return_type\n"); } -CASS_EXPORT const CassNode* cass_future_coordinator(CassFuture* future) { - throw std::runtime_error("UNIMPLEMENTED cass_future_coordinator\n"); -} CASS_EXPORT const CassValue* cass_index_meta_field_by_name(const CassIndexMeta* index_meta, const char* name) { throw std::runtime_error("UNIMPLEMENTED cass_index_meta_field_by_name\n"); @@ -220,20 +217,9 @@ CASS_EXPORT CassError cass_statement_set_custom_payload(CassStatement* statement const CassCustomPayload* payload) { throw std::runtime_error("UNIMPLEMENTED cass_statement_set_custom_payload\n"); } -CASS_EXPORT CassError cass_statement_set_host(CassStatement* statement, const char* host, - int port) { - throw std::runtime_error("UNIMPLEMENTED cass_statement_set_host\n"); -} -CASS_EXPORT CassError cass_statement_set_host_inet(CassStatement* statement, const CassInet* host, - int port) { - throw std::runtime_error("UNIMPLEMENTED cass_statement_set_host_inet\n"); -} CASS_EXPORT CassError cass_statement_set_keyspace(CassStatement* statement, const char* keyspace) { throw std::runtime_error("UNIMPLEMENTED cass_statement_set_keyspace\n"); } -CASS_EXPORT CassError cass_statement_set_node(CassStatement* statement, const CassNode* node) { - throw std::runtime_error("UNIMPLEMENTED cass_statement_set_node\n"); -} CASS_EXPORT CassClusteringOrder cass_table_meta_clustering_key_order(const CassTableMeta* table_meta, size_t index) { throw std::runtime_error("UNIMPLEMENTED cass_table_meta_clustering_key_order\n");