Skip to content

feat: unstable HttpClient Config #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ h1_client = ["async-h1", "async-std", "deadpool", "futures"]
native_client = ["curl_client", "wasm_client"]
curl_client = ["isahc", "async-std"]
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"]
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"]
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util", "tokio"]

native-tls = ["async-native-tls"]
rustls = ["async-tls"]
rustls = ["async-tls", "rustls_crate"]

unstable-config = []

[dependencies]
async-trait = "0.1.37"
Expand All @@ -48,11 +50,13 @@ futures = { version = "0.3.8", optional = true }

# h1_client_rustls
async-tls = { version = "0.10.0", optional = true }
rustls_crate = { version = "0.18", optional = true, package = "rustls" }

# hyper_client
hyper = { version = "0.13.6", features = ["tcp"], optional = true }
hyper-tls = { version = "0.4.3", optional = true }
futures-util = { version = "0.3.5", features = ["io"], optional = true }
tokio = { version = "0.2", features = ["time"], optional = true }

# curl_client
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
Expand Down
111 changes: 111 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//! Configuration for `HttpClient`s.

use std::fmt::Debug;
use std::time::Duration;

/// Configuration for `HttpClient`s.
#[non_exhaustive]
#[derive(Clone)]
pub struct Config {
/// HTTP/1.1 `keep-alive` (connection pooling).
///
/// Default: `true`.
pub http_keep_alive: bool,
/// TCP `NO_DELAY`.
///
/// Default: `false`.
pub tcp_no_delay: bool,
/// Connection timeout duration.
///
/// Default: `Some(Duration::from_secs(60))`.
pub timeout: Option<Duration>,
/// TLS Configuration (Rustls)
#[cfg(all(feature = "h1_client", feature = "rustls"))]
pub tls_config: Option<std::sync::Arc<rustls_crate::ClientConfig>>,
/// TLS Configuration (Native TLS)
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
pub tls_config: Option<std::sync::Arc<async_native_tls::TlsConnector>>,
}

impl Debug for Config {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut dbg_struct = f.debug_struct("Config");
dbg_struct
.field("http_keep_alive", &self.http_keep_alive)
.field("tcp_no_delay", &self.tcp_no_delay)
.field("timeout", &self.timeout);

#[cfg(all(feature = "h1_client", feature = "rustls"))]
{
if self.tls_config.is_some() {
dbg_struct.field("tls_config", &"Some(rustls::ClientConfig)");
} else {
dbg_struct.field("tls_config", &"None");
}
}
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
{
dbg_struct.field("tls_config", &self.tls_config);
}

dbg_struct.finish()
}
}

impl Config {
/// Construct new empty config.
pub fn new() -> Self {
Self {
http_keep_alive: true,
tcp_no_delay: false,
timeout: Some(Duration::from_secs(60)),
#[cfg(all(feature = "h1_client", any(feature = "rustls", feature = "native-tls")))]
tls_config: None,
}
}
}

impl Default for Config {
fn default() -> Self {
Self::new()
}
}

impl Config {
/// Set HTTP/1.1 `keep-alive` (connection pooling).
pub fn set_http_keep_alive(mut self, keep_alive: bool) -> Self {
self.http_keep_alive = keep_alive;
self
}

/// Set TCP `NO_DELAY`.
pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
self.tcp_no_delay = no_delay;
self
}

/// Set connection timeout duration.
pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
self.timeout = timeout;
self
}

/// Set TLS Configuration (Rustls)
#[cfg(all(feature = "h1_client", feature = "rustls"))]
pub fn set_tls_config(
mut self,
tls_config: Option<std::sync::Arc<rustls_crate::ClientConfig>>,
) -> Self {
self.tls_config = tls_config;
self
}
/// Set TLS Configuration (Native TLS)
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
pub fn set_tls_config(
mut self,
tls_config: Option<std::sync::Arc<async_native_tls::TlsConnector>>,
) -> Self {
self.tls_config = tls_config;
self
}
}
101 changes: 95 additions & 6 deletions src/h1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").

#[cfg(feature = "unstable-config")]
use std::convert::{Infallible, TryFrom};

use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;

use async_h1::client;
use async_std::net::TcpStream;
Expand All @@ -17,6 +21,8 @@ cfg_if::cfg_if! {
}
}

use crate::Config;

use super::{async_trait, Error, HttpClient, Request, Response};

mod tcp;
Expand All @@ -40,6 +46,7 @@ pub struct H1Client {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
https_pools: HttpsPool,
max_concurrent_connections: usize,
config: Arc<Config>,
}

impl Debug for H1Client {
Expand Down Expand Up @@ -79,6 +86,7 @@ impl Debug for H1Client {
"max_concurrent_connections",
&self.max_concurrent_connections,
)
.field("config", &self.config)
.finish()
}
}
Expand All @@ -97,6 +105,7 @@ impl H1Client {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
https_pools: DashMap::new(),
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
config: Arc::new(Config::default()),
}
}

Expand All @@ -107,6 +116,7 @@ impl H1Client {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
https_pools: DashMap::new(),
max_concurrent_connections: max,
config: Arc::new(Config::default()),
}
}
}
Expand Down Expand Up @@ -147,12 +157,43 @@ impl HttpClient for H1Client {
for (idx, addr) in addrs.into_iter().enumerate() {
let has_another_addr = idx != max_addrs_idx;

#[cfg(feature = "unstable-config")]
if !self.config.http_keep_alive {
match scheme {
"http" => {
let stream = async_std::net::TcpStream::connect(addr).await?;
req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
let tcp_conn = client::connect(stream, req);
return if let Some(timeout) = self.config.timeout {
async_std::future::timeout(timeout, tcp_conn).await?
} else {
tcp_conn.await
};
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let raw_stream = async_std::net::TcpStream::connect(addr).await?;
req.set_peer_addr(raw_stream.peer_addr().ok());
req.set_local_addr(raw_stream.local_addr().ok());
let tls_stream = tls::add_tls(&host, raw_stream, &self.config).await?;
let tsl_conn = client::connect(tls_stream, req);
return if let Some(timeout) = self.config.timeout {
async_std::future::timeout(timeout, tsl_conn).await?
} else {
tsl_conn.await
};
}
_ => unreachable!(),
}
}

match scheme {
"http" => {
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let manager = TcpConnection::new(addr, self.config.clone());
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
Expand All @@ -168,19 +209,28 @@ impl HttpClient for H1Client {
let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
Err(e) => return Err(Error::from_str(400, e.to_string())),
};

req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
return client::connect(TcpConnWrapper::new(stream), req).await;

let tcp_conn = client::connect(TcpConnWrapper::new(stream), req);
#[cfg(feature = "unstable-config")]
return if let Some(timeout) = self.config.timeout {
async_std::future::timeout(timeout, tcp_conn).await?
} else {
tcp_conn.await
};
#[cfg(not(feature = "unstable-config"))]
return tcp_conn.await;
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let manager = TlsConnection::new(host.clone(), addr, self.config.clone());
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
Expand All @@ -196,13 +246,21 @@ impl HttpClient for H1Client {
let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
Err(e) => return Err(Error::from_str(400, e.to_string())),
};

req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());

return client::connect(TlsConnWrapper::new(stream), req).await;
let tls_conn = client::connect(TlsConnWrapper::new(stream), req);
#[cfg(feature = "unstable-config")]
return if let Some(timeout) = self.config.timeout {
async_std::future::timeout(timeout, tls_conn).await?
} else {
tls_conn.await
};
#[cfg(not(feature = "unstable-config"))]
return tls_conn.await;
}
_ => unreachable!(),
}
Expand All @@ -213,6 +271,37 @@ impl HttpClient for H1Client {
"missing valid address",
))
}

#[cfg(feature = "unstable-config")]
/// Override the existing configuration with new configuration.
///
/// Config options may not impact existing connections.
fn set_config(&mut self, config: Config) -> http_types::Result<()> {
self.config = Arc::new(config);

Ok(())
}

#[cfg(feature = "unstable-config")]
/// Get the current configuration.
fn config(&self) -> &Config {
&*self.config
}
}

#[cfg(feature = "unstable-config")]
impl TryFrom<Config> for H1Client {
type Error = Infallible;

fn try_from(config: Config) -> Result<Self, Self::Error> {
Ok(Self {
http_pools: DashMap::new(),
#[cfg(any(feature = "native-tls", feature = "rustls"))]
https_pools: DashMap::new(),
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
config: Arc::new(config),
})
}
}

#[cfg(test)]
Expand Down
24 changes: 19 additions & 5 deletions src/h1/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
use std::fmt::Debug;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use async_std::net::TcpStream;
use async_trait::async_trait;
use deadpool::managed::{Manager, Object, RecycleResult};
use futures::io::{AsyncRead, AsyncWrite};
use futures::task::{Context, Poll};

#[derive(Clone, Debug)]
use crate::Config;

#[derive(Clone)]
#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))]
pub(crate) struct TcpConnection {
addr: SocketAddr,
config: Arc<Config>,
}

impl TcpConnection {
pub(crate) fn new(addr: SocketAddr) -> Self {
Self { addr }
pub(crate) fn new(addr: SocketAddr, config: Arc<Config>) -> Self {
Self { addr, config }
}
}

Expand Down Expand Up @@ -58,12 +63,21 @@ impl AsyncWrite for TcpConnWrapper {
#[async_trait]
impl Manager<TcpStream, std::io::Error> for TcpConnection {
async fn create(&self) -> Result<TcpStream, std::io::Error> {
TcpStream::connect(self.addr).await
let tcp_stream = TcpStream::connect(self.addr).await?;

#[cfg(feature = "unstable-config")]
tcp_stream.set_nodelay(self.config.tcp_no_delay)?;

Ok(tcp_stream)
}

async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
let mut buf = [0; 4];
let mut cx = Context::from_waker(futures::task::noop_waker_ref());

#[cfg(feature = "unstable-config")]
conn.set_nodelay(self.config.tcp_no_delay)?;

match Pin::new(conn).poll_read(&mut cx, &mut buf) {
Poll::Ready(Err(error)) => Err(error),
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(
Expand Down
Loading