diff --git a/Cargo.toml b/Cargo.toml index 5ba1349..9f50026 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,5 @@ url = "2.2" rustls = "0.20" webpki-roots = "0.22.3" rustls-native-certs = "0.6" +serde = { version = "1", features = ["derive"] } +enum_dispatch = "0.3.8" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index aaf1b68..8b13789 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,105 +1 @@ -use std::sync::Arc; -use crate::{ClientOptions, Error}; -use bytes::{Bytes, BytesMut}; -use rustls::version::{TLS12, TLS13}; -use ureq::{Request, Response}; -use url::Url; - -const BUFFER_SIZE: usize = 128 * 1024; - -pub struct Inserter { - request: Request, - buffer: BytesMut, -} - -impl Inserter { - pub fn new(options: ClientOptions, table: &str) -> Inserter { - let agent = match options.secure { - true => { - let mut root_store = rustls::RootCertStore::empty(); - - let certs = rustls_native_certs::load_native_certs().expect("Could not load platform certs"); - for cert in certs { - // Repackage the certificate DER bytes. - let rustls_cert = rustls::Certificate(cert.0); - root_store - .add(&rustls_cert) - .expect("Failed to add native certificate too root store"); - } - - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( - |ta| { - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - - let protocol_versions = &[&TLS12, &TLS13]; - - let tls_config = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(protocol_versions) - .unwrap() - .with_root_certificates(root_store) - .with_no_client_auth(); - - ureq::builder().tls_config(Arc::new(tls_config)).build() - } - false => ureq::builder().build(), - }; - - let mut url = Url::parse(&options.url).expect("TODO"); - let query = format!("INSERT INTO {} FORMAT JSONEachRow", table); - - url.query_pairs_mut() - .append_pair("database", &options.database); - - url.query_pairs_mut().append_pair("query", &query); - - let mut request = agent.post(url.as_str()); - // let mut request = ureq::post(url.as_str()); - - if let Some(user) = &options.user { - request = request.set("X-ClickHouse-User", user); - } - - if let Some(password) = &options.password { - request = request.set("X-ClickHouse-Key", password); - } - - Inserter { - request, - buffer: BytesMut::with_capacity(BUFFER_SIZE), - } - } - - pub fn len(&self) -> usize { - self.buffer.len() - } - - pub fn write_bytes(&mut self, payload: Bytes) -> Result<(), Error> { - self.buffer.extend_from_slice(&payload[..]); - Ok(()) - } - - pub fn write_slice(&mut self, payload: &[u8]) -> Result<(), Error> { - self.buffer.extend_from_slice(payload); - Ok(()) - } - - pub fn end(&mut self) -> Result { - let request = self.request.clone(); - let response = request.send_bytes(&self.buffer[..])?; - self.buffer.clear(); - Ok(response) - } - - pub fn clear(&mut self) { - self.buffer.clear(); - } -} diff --git a/src/db/clickhouse.rs b/src/db/clickhouse.rs new file mode 100644 index 0000000..c646b64 --- /dev/null +++ b/src/db/clickhouse.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use crate::{ClientOptions, Error, Inserter}; +use bytes::{Bytes, BytesMut}; +use rustls::version::{TLS12, TLS13}; +use ureq::{Request, Response}; +use url::Url; + +const BUFFER_SIZE: usize = 128 * 1024; + +pub struct Clickhouse { + request: Request, + buffer: BytesMut, +} + +impl Clickhouse { + pub(crate) fn new(options: ClientOptions, table: &str) -> Clickhouse { + let agent = match options.secure { + true => { + let mut root_store = rustls::RootCertStore::empty(); + + let certs = rustls_native_certs::load_native_certs() + .expect("Could not load platform certs"); + for cert in certs { + // Repackage the certificate DER bytes. + let rustls_cert = rustls::Certificate(cert.0); + root_store + .add(&rustls_cert) + .expect("Failed to add native certificate too root store"); + } + + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( + |ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + + let protocol_versions = &[&TLS12, &TLS13]; + + let tls_config = rustls::ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(protocol_versions) + .unwrap() + .with_root_certificates(root_store) + .with_no_client_auth(); + + ureq::builder().tls_config(Arc::new(tls_config)).build() + } + false => ureq::builder().build(), + }; + + let mut url = Url::parse(&options.url).expect("TODO"); + let query = format!("INSERT INTO {} FORMAT JSONEachRow", table); + + url.query_pairs_mut() + .append_pair("database", &options.database); + + url.query_pairs_mut().append_pair("query", &query); + + let mut request = agent.post(url.as_str()); + // let mut request = ureq::post(url.as_str()); + + if let Some(user) = &options.user { + request = request.set("X-ClickHouse-User", user); + } + + if let Some(password) = &options.password { + request = request.set("X-ClickHouse-Key", password); + } + + Clickhouse { + request, + buffer: BytesMut::with_capacity(BUFFER_SIZE), + } + } +} + +impl Inserter for Clickhouse { + fn len(&self) -> usize { + self.buffer.len() + } + + fn write_bytes(&mut self, payload: Bytes) -> Result<(), Error> { + self.buffer.extend_from_slice(&payload[..]); + Ok(()) + } + + fn write_slice(&mut self, payload: &[u8]) -> Result<(), Error> { + self.buffer.extend_from_slice(payload); + Ok(()) + } + + fn end(&mut self) -> Result { + let request = self.request.clone(); + let response = request.send_bytes(&self.buffer[..])?; + self.buffer.clear(); + Ok(response) + } + + fn clear(&mut self) { + self.buffer.clear(); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..a92c69f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,41 @@ +mod clickhouse; + +use bytes::Bytes; +use clickhouse::Clickhouse; +use enum_dispatch::enum_dispatch; +use serde::{Deserialize, Serialize}; +use ureq::Response; + +use crate::{ClientOptions, Error}; + +#[enum_dispatch(Database)] +pub trait Inserter { + fn len(&self) -> usize; + + fn write_bytes(&mut self, payload: Bytes) -> Result<(), Error>; + + fn write_slice(&mut self, payload: &[u8]) -> Result<(), Error>; + + fn clear(&mut self); + + fn end(&mut self) -> Result; +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Type { + #[serde(rename = "clickhouse")] + Clickhouse, +} + +#[enum_dispatch] +pub enum Database { + Clickhouse, +} + +impl Database { + pub fn new(db_type: &Type, options: ClientOptions, table: &str) -> Database { + match db_type { + Type::Clickhouse => Self::Clickhouse(Clickhouse::new(options, table)), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6b84959..2154137 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ +use std::io; + mod client; +mod db; mod options; -use std::io; - -pub use client::Inserter; +pub use db::{Database, Inserter, Type}; +pub use options::ClientOptions; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -45,13 +47,3 @@ impl Default for Compression { Compression::None } } - -pub use options::ClientOptions; - -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -}