Skip to content

Commit c1e40d7

Browse files
committed
session: pass Arc to async block instead of reference
cpp-driver assumes that session object can be prematurely dropped. This means, that we should increase the reference count in functions where we pass the session to an async block. This will prevent UAF. There actually is a test case for this (AsyncTests::Close). However, we cannot enable it yet, since it expects that prematurely dropped session awaits all async tasks and before closing. Also removing the outdated comment regarding &'static reference.
1 parent f9a1001 commit c1e40d7

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

scylla-rust-wrapper/src/session.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,7 @@ impl CassSessionInner {
7070
}
7171

7272
fn connect(
73-
// This reference is 'static because this is the only was of assuring the borrow checker
74-
// that holding it in our returned future is sound. Ideally, we would prefer to have
75-
// the returned future's lifetime constrained by real lifetime of the session's RwLock,
76-
// but this is impossible to be guaranteed due to C/Rust cross-language barrier.
77-
session_opt: &'static RwLock<Option<CassSessionInner>>,
73+
session_opt: Arc<RwLock<Option<CassSessionInner>>>,
7874
cluster: &CassCluster,
7975
keyspace: Option<String>,
8076
) -> *mut CassFuture {
@@ -94,7 +90,7 @@ impl CassSessionInner {
9490
}
9591

9692
async fn connect_fut(
97-
session_opt: &RwLock<Option<CassSessionInner>>,
93+
session_opt: Arc<RwLock<Option<CassSessionInner>>>,
9894
session_builder_fut: impl Future<Output = SessionBuilder>,
9995
exec_profile_builder_map: HashMap<ExecProfileName, CassExecProfile>,
10096
client_id: uuid::Uuid,
@@ -154,7 +150,7 @@ pub unsafe extern "C" fn cass_session_connect(
154150
session_raw: *mut CassSession,
155151
cluster_raw: *const CassCluster,
156152
) -> *mut CassFuture {
157-
let session_opt = ArcFFI::as_ref(session_raw);
153+
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
158154
let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw);
159155

160156
CassSessionInner::connect(session_opt, cluster, None)
@@ -176,7 +172,7 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n(
176172
keyspace: *const c_char,
177173
keyspace_length: size_t,
178174
) -> *mut CassFuture {
179-
let session_opt = ArcFFI::as_ref(session_raw);
175+
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
180176
let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw);
181177
let keyspace = ptr_to_cstr_n(keyspace, keyspace_length).map(ToOwned::to_owned);
182178

@@ -188,7 +184,7 @@ pub unsafe extern "C" fn cass_session_execute_batch(
188184
session_raw: *mut CassSession,
189185
batch_raw: *const CassBatch,
190186
) -> *mut CassFuture {
191-
let session_opt = ArcFFI::as_ref(session_raw);
187+
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
192188
let batch_from_raw = BoxFFI::as_ref(batch_raw);
193189
let mut state = batch_from_raw.state.clone();
194190
let request_timeout_ms = batch_from_raw.batch_request_timeout_ms;
@@ -254,7 +250,7 @@ pub unsafe extern "C" fn cass_session_execute(
254250
session_raw: *mut CassSession,
255251
statement_raw: *const CassStatement,
256252
) -> *mut CassFuture {
257-
let session_opt = ArcFFI::as_ref(session_raw);
253+
let session_opt = ArcFFI::cloned_from_ptr(session_raw);
258254

259255
// DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault.
260256
let statement_opt = BoxFFI::as_ref(statement_raw);
@@ -389,7 +385,7 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing(
389385
cass_session: *mut CassSession,
390386
statement: *const CassStatement,
391387
) -> *mut CassFuture {
392-
let session = ArcFFI::as_ref(cass_session);
388+
let session = ArcFFI::cloned_from_ptr(cass_session);
393389
let cass_statement = BoxFFI::as_ref(statement);
394390
let statement = cass_statement.statement.clone();
395391

@@ -441,7 +437,7 @@ pub unsafe extern "C" fn cass_session_prepare_n(
441437
// There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`.
442438
.unwrap_or_default();
443439
let query = Statement::new(query_str.to_string());
444-
let cass_session = ArcFFI::as_ref(cass_session_raw);
440+
let cass_session = ArcFFI::cloned_from_ptr(cass_session_raw);
445441

446442
CassFuture::make_raw(async move {
447443
let session_guard = cass_session.read().await;
@@ -474,7 +470,7 @@ pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) {
474470

475471
#[no_mangle]
476472
pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *mut CassFuture {
477-
let session_opt = ArcFFI::as_ref(session);
473+
let session_opt = ArcFFI::cloned_from_ptr(session);
478474

479475
CassFuture::make_raw(async move {
480476
let mut session_guard = session_opt.write().await;

0 commit comments

Comments
 (0)