Skip to content

Commit eb3d6b9

Browse files
authored
Merge pull request #322 from wprzytula/support-set_record_contacted_hosts
IT: Support `set_record_contacted_hosts()` and `get_attempted_hosts_from_future()`
2 parents db766eb + dbda206 commit eb3d6b9

File tree

11 files changed

+386
-87
lines changed

11 files changed

+386
-87
lines changed

.github/pull_request_template.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
- [ ] PR description sums up the changes and reasons why they should be introduced.
1212
- [ ] I have implemented Rust unit tests for the features/changes introduced.
1313
- [ ] I have enabled appropriate tests in `Makefile` in `{SCYLLA,CASSANDRA}_(NO_VALGRIND_)TEST_FILTER`.
14+
- [ ] I added appropriate `Fixes:` annotations to PR description.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.vscode/
2+
.zed
23
build/
34
scylla-rust-wrapper/target/
45
.idea/

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,4 @@ endif
292292
build/cassandra-integration-tests --version=${CASSANDRA_VERSION} --category=CASSANDRA --verbose=ccm --gtest_filter="${CASSANDRA_NO_VALGRIND_TEST_FILTER}"
293293

294294
run-test-unit: install-cargo-if-missing _update-rust-tooling
295-
@cd ${CURRENT_DIR}/scylla-rust-wrapper; cargo test
295+
@cd ${CURRENT_DIR}/scylla-rust-wrapper; RUSTFLAGS="--cfg cpp_rust_unstable --cfg cpp_integration_testing" cargo test
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[build]
2-
# To enable cpp-rust only features from Rust driver.
2+
# To enable cpp-rust only features from Rust driver.
33
rustflags = ["--cfg", "cpp_rust_unstable"]

scylla-rust-wrapper/src/future.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ pub struct CassFuture {
6060
state: Mutex<CassFutureState>,
6161
result: OnceLock<CassFutureResult>,
6262
wait_for_value: Condvar,
63+
#[cfg(cpp_integration_testing)]
64+
recording_listener: Option<Arc<crate::integration_testing::RecordingHistoryListener>>,
6365
}
6466

6567
impl FFI for CassFuture {
@@ -79,17 +81,30 @@ struct JoinHandleTimeout(JoinHandle<()>);
7981
impl CassFuture {
8082
pub(crate) fn make_raw(
8183
fut: impl Future<Output = CassFutureResult> + Send + 'static,
84+
#[cfg(cpp_integration_testing)] recording_listener: Option<
85+
Arc<crate::integration_testing::RecordingHistoryListener>,
86+
>,
8287
) -> CassOwnedSharedPtr<CassFuture, CMut> {
83-
Self::new_from_future(fut).into_raw()
88+
Self::new_from_future(
89+
fut,
90+
#[cfg(cpp_integration_testing)]
91+
recording_listener,
92+
)
93+
.into_raw()
8494
}
8595

8696
pub(crate) fn new_from_future(
8797
fut: impl Future<Output = CassFutureResult> + Send + 'static,
98+
#[cfg(cpp_integration_testing)] recording_listener: Option<
99+
Arc<crate::integration_testing::RecordingHistoryListener>,
100+
>,
88101
) -> Arc<CassFuture> {
89102
let cass_fut = Arc::new(CassFuture {
90103
state: Mutex::new(Default::default()),
91104
result: OnceLock::new(),
92105
wait_for_value: Condvar::new(),
106+
#[cfg(cpp_integration_testing)]
107+
recording_listener,
93108
});
94109
let cass_fut_clone = Arc::clone(&cass_fut);
95110
let join_handle = RUNTIME.spawn(async move {
@@ -125,6 +140,8 @@ impl CassFuture {
125140
state: Mutex::new(CassFutureState::default()),
126141
result: OnceLock::from(r),
127142
wait_for_value: Condvar::new(),
143+
#[cfg(cpp_integration_testing)]
144+
recording_listener: None,
128145
})
129146
}
130147

@@ -300,6 +317,15 @@ impl CassFuture {
300317
fn into_raw(self: Arc<Self>) -> CassOwnedSharedPtr<Self, CMut> {
301318
ArcFFI::into_ptr(self)
302319
}
320+
321+
#[cfg(cpp_integration_testing)]
322+
pub(crate) fn attempted_hosts(&self) -> Vec<std::net::SocketAddr> {
323+
if let Some(listener) = &self.recording_listener {
324+
listener.get_attempted_hosts()
325+
} else {
326+
vec![]
327+
}
328+
}
303329
}
304330

305331
// Do not remove; this asserts that `CassFuture` implements Send + Sync,
@@ -527,7 +553,11 @@ mod tests {
527553
tokio::time::sleep(Duration::from_millis(10)).await;
528554
Err((CassError::CASS_OK, ERROR_MSG.into()))
529555
};
530-
let cass_fut = CassFuture::make_raw(fut);
556+
let cass_fut = CassFuture::make_raw(
557+
fut,
558+
#[cfg(cpp_integration_testing)]
559+
None,
560+
);
531561

532562
struct PtrWrapper(CassBorrowedSharedPtr<'static, CassFuture, CMut>);
533563
unsafe impl Send for PtrWrapper {}
@@ -562,7 +592,11 @@ mod tests {
562592
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS)).await;
563593
Err((CassError::CASS_OK, ERROR_MSG.into()))
564594
};
565-
let cass_fut = CassFuture::make_raw(fut);
595+
let cass_fut = CassFuture::make_raw(
596+
fut,
597+
#[cfg(cpp_integration_testing)]
598+
None,
599+
);
566600

567601
unsafe {
568602
// This should timeout on tokio::time::timeout.
@@ -609,7 +643,11 @@ mod tests {
609643
tokio::time::sleep(Duration::from_micros(HUNDRED_MILLIS_IN_MICROS)).await;
610644
Err((CassError::CASS_OK, ERROR_MSG.into()))
611645
};
612-
let cass_fut = CassFuture::make_raw(fut);
646+
let cass_fut = CassFuture::make_raw(
647+
fut,
648+
#[cfg(cpp_integration_testing)]
649+
None,
650+
);
613651
let flag = Box::new(false);
614652
let flag_ptr = Box::into_raw(flag);
615653

scylla-rust-wrapper/src/integration_testing.rs

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::ffi::{CString, c_char};
2+
use std::fmt::Write as _;
23
use std::net::SocketAddr;
34
use std::sync::Arc;
45
use std::time::Duration;
@@ -16,7 +17,7 @@ use crate::cluster::CassCluster;
1617
use crate::future::{CassFuture, CassResultValue};
1718
use crate::retry_policy::CassRetryPolicy;
1819
use crate::statement::{BoundStatement, CassStatement};
19-
use crate::types::{cass_int32_t, cass_uint16_t, cass_uint64_t, size_t};
20+
use crate::types::{cass_bool_t, cass_int32_t, cass_uint16_t, cass_uint64_t, size_t};
2021

2122
#[unsafe(no_mangle)]
2223
pub unsafe extern "C" fn testing_cluster_get_connect_timeout(
@@ -177,6 +178,156 @@ pub unsafe extern "C" fn testing_batch_set_sleeping_history_listener(
177178
.set_history_listener(history_listener)
178179
}
179180

181+
/// Used to record attempted hosts during a request.
182+
/// This is useful for testing purposes and is used in integration tests.
183+
/// This is enabled by `testing_statement_set_recording_history_listener`
184+
/// and can be queried using `testing_future_get_attempted_hosts`.
185+
#[derive(Debug)]
186+
pub(crate) struct RecordingHistoryListener {
187+
attempted_hosts: std::sync::Mutex<Vec<SocketAddr>>,
188+
}
189+
190+
impl RecordingHistoryListener {
191+
pub(crate) fn new() -> Self {
192+
RecordingHistoryListener {
193+
attempted_hosts: std::sync::Mutex::new(Vec::new()),
194+
}
195+
}
196+
197+
pub(crate) fn get_attempted_hosts(&self) -> Vec<SocketAddr> {
198+
self.attempted_hosts.lock().unwrap().clone()
199+
}
200+
}
201+
202+
impl HistoryListener for RecordingHistoryListener {
203+
fn log_request_start(&self) -> RequestId {
204+
RequestId(0)
205+
}
206+
207+
fn log_request_success(&self, _request_id: RequestId) {}
208+
209+
fn log_request_error(&self, _request_id: RequestId, _error: &RequestError) {}
210+
211+
fn log_new_speculative_fiber(&self, _request_id: RequestId) -> SpeculativeId {
212+
SpeculativeId(0)
213+
}
214+
215+
fn log_attempt_start(
216+
&self,
217+
_request_id: RequestId,
218+
_speculative_id: Option<SpeculativeId>,
219+
node_addr: SocketAddr,
220+
) -> AttemptId {
221+
// Record the host that was attempted.
222+
self.attempted_hosts.lock().unwrap().push(node_addr);
223+
224+
AttemptId(0)
225+
}
226+
227+
fn log_attempt_success(&self, _attempt_id: AttemptId) {}
228+
229+
fn log_attempt_error(
230+
&self,
231+
_attempt_id: AttemptId,
232+
_error: &RequestAttemptError,
233+
_retry_decision: &RetryDecision,
234+
) {
235+
}
236+
}
237+
238+
/// Enables recording of attempted hosts in the statement's history listener.
239+
/// This is useful for testing purposes, allowing us to verify which hosts were attempted
240+
/// during a request.
241+
/// Later, `testing_future_get_attempted_hosts` can be used to retrieve this information.
242+
#[unsafe(no_mangle)]
243+
pub unsafe extern "C" fn testing_statement_set_recording_history_listener(
244+
statement_raw: CassBorrowedExclusivePtr<CassStatement, CMut>,
245+
enable: cass_bool_t,
246+
) {
247+
let statement = &mut BoxFFI::as_mut_ref(statement_raw).unwrap();
248+
249+
statement.record_hosts = enable != 0;
250+
}
251+
252+
/// Retrieves hosts that were attempted during the execution of a future.
253+
/// In order for this to work, the statement must have been configured with
254+
/// `testing_statement_set_recording_history_listener`.
255+
#[unsafe(no_mangle)]
256+
pub unsafe extern "C" fn testing_future_get_attempted_hosts(
257+
future_raw: CassBorrowedSharedPtr<CassFuture, CMut>,
258+
) -> *mut c_char {
259+
// This function should return a list of attempted hosts.
260+
// Care should be taken to ensure that the list is properly allocated and freed.
261+
// The problem is that the return type must be understandable by C code.
262+
// See testing.cpp:53 (get_attempted_hosts_from_future()) for an example of how
263+
// this is done in C++.
264+
//
265+
// Idea: Create a concatenated string of attempted hosts.
266+
// Return a pointer to that string. Caller is responsible for freeing it.
267+
268+
let future: &CassFuture = ArcFFI::as_ref(future_raw).unwrap();
269+
270+
let attempted_hosts = future.attempted_hosts();
271+
let concatenated_hosts = attempted_hosts.iter().fold(String::new(), |mut acc, host| {
272+
// Convert the SocketAddr to a string.
273+
// Delimit address strings with '\n' to enable easy parsing in C.
274+
275+
write!(&mut acc, "{}\n", host.ip()).unwrap();
276+
acc
277+
});
278+
279+
// The caller is responsible for freeing this memory, by calling `testing_free_cstring`.
280+
unsafe { CString::from_vec_unchecked(concatenated_hosts.into_bytes()) }.into_raw()
281+
}
282+
283+
/// Ensures that the `testing_future_get_attempted_hosts` function
284+
/// behaves correctly, i.e., it returns a list of attempted hosts as a concatenated string.
285+
#[test]
286+
fn test_future_get_attempted_hosts() {
287+
use scylla::observability::history::HistoryListener as _;
288+
289+
let listener = Arc::new(RecordingHistoryListener::new());
290+
let future = CassFuture::new_from_future(std::future::pending(), Some(listener.clone()));
291+
292+
fn assert_attempted_hosts_eq(future: &Arc<CassFuture>, hosts: &[String]) {
293+
let hosts_str = unsafe { testing_future_get_attempted_hosts(ArcFFI::as_ptr(future)) };
294+
let hosts_cstr = unsafe { CString::from_raw(hosts_str) };
295+
let hosts_string = hosts_cstr.to_str().unwrap();
296+
297+
// Split the string by '\n' and collect into a Vec<&str>.
298+
let attempted_hosts: Vec<&str> =
299+
hosts_string.split('\n').filter(|s| !s.is_empty()).collect();
300+
301+
assert_eq!(attempted_hosts, hosts);
302+
}
303+
304+
// 1. Test with no attempted hosts.
305+
{
306+
assert_attempted_hosts_eq(&future, &[]);
307+
}
308+
309+
let addr1: SocketAddr = SocketAddr::from(([127, 0, 0, 1], 9042));
310+
let addr2: SocketAddr = SocketAddr::from(([127, 0, 0, 2], 9042));
311+
312+
// 2. Attempt two hosts and see if they are recorded correctly, in order.
313+
{
314+
listener.log_attempt_start(RequestId(0), None, addr1);
315+
listener.log_attempt_start(RequestId(0), None, addr2);
316+
assert_attempted_hosts_eq(&future, &[addr1, addr2].map(|addr| addr.ip().to_string()))
317+
}
318+
319+
let addr3: SocketAddr = SocketAddr::from(([127, 0, 0, 3], 9042));
320+
321+
// 3. Attempt one more host and see if all hosts are present, in order.
322+
{
323+
listener.log_attempt_start(RequestId(0), None, addr3);
324+
assert_attempted_hosts_eq(
325+
&future,
326+
&[addr1, addr2, addr3].map(|addr| addr.ip().to_string()),
327+
)
328+
}
329+
}
330+
180331
/// A retry policy that always ignores all errors.
181332
///
182333
/// Useful for testing purposes.

scylla-rust-wrapper/src/prepared.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use scylla::response::PagingState;
21
use scylla::value::MaybeUnset::Unset;
32
use std::{os::raw::c_char, sync::Arc};
43

@@ -7,7 +6,7 @@ use crate::{
76
cass_error::CassError,
87
cass_types::{CassDataType, get_column_type},
98
query_result::CassResultMetadata,
10-
statement::{BoundPreparedStatement, BoundStatement, CassStatement},
9+
statement::{BoundPreparedStatement, CassStatement},
1110
types::size_t,
1211
};
1312
use scylla::statement::prepared::PreparedStatement;
@@ -98,19 +97,12 @@ pub unsafe extern "C" fn cass_prepared_bind(
9897
// cloning prepared statement's arc, because creating CassStatement should not invalidate
9998
// the CassPrepared argument
10099

101-
let statement = BoundStatement::Prepared(BoundPreparedStatement {
100+
let statement = BoundPreparedStatement {
102101
statement: prepared,
103102
bound_values: vec![Unset; bound_values_size],
104-
});
105-
106-
BoxFFI::into_ptr(Box::new(CassStatement {
107-
statement,
108-
paging_state: PagingState::start(),
109-
// Cpp driver disables paging by default.
110-
paging_enabled: false,
111-
request_timeout_ms: None,
112-
exec_profile: None,
113-
}))
103+
};
104+
105+
BoxFFI::into_ptr(Box::new(CassStatement::new_prepared(statement)))
114106
}
115107

116108
#[unsafe(no_mangle)]

0 commit comments

Comments
 (0)