http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/protocol/mod.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs index b230d63..4f13914 100644 --- a/lib/rs/src/protocol/mod.rs +++ b/lib/rs/src/protocol/mod.rs @@ -19,59 +19,77 @@ //! //! # Examples //! -//! Create and use a `TOutputProtocol`. +//! Create and use a `TInputProtocol`. //! //! ```no_run -//! use std::cell::RefCell; -//! use std::rc::Rc; -//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; -//! use thrift::transport::{TTcpTransport, TTransport}; +//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; +//! use thrift::transport::TTcpChannel; //! //! // create the I/O channel -//! let mut transport = TTcpTransport::new(); -//! transport.open("127.0.0.1:9090").unwrap(); -//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); //! -//! // create the protocol to encode types into bytes -//! let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +//! // create the protocol to decode bytes into types +//! let mut protocol = TBinaryInputProtocol::new(channel, true); //! -//! // write types -//! o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); -//! o_prot.write_string("foo").unwrap(); -//! o_prot.write_field_end().unwrap(); +//! // read types from the wire +//! let field_identifier = protocol.read_field_begin().unwrap(); +//! let field_contents = protocol.read_string().unwrap(); +//! let field_end = protocol.read_field_end().unwrap(); //! ``` //! -//! Create and use a `TInputProtocol`. +//! Create and use a `TOutputProtocol`. //! //! ```no_run -//! use std::cell::RefCell; -//! use std::rc::Rc; -//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -//! 
use thrift::transport::{TTcpTransport, TTransport}; +//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; +//! use thrift::transport::TTcpChannel; //! //! // create the I/O channel -//! let mut transport = TTcpTransport::new(); -//! transport.open("127.0.0.1:9090").unwrap(); -//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); //! -//! // create the protocol to decode bytes into types -//! let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); +//! // create the protocol to encode types into bytes +//! let mut protocol = TBinaryOutputProtocol::new(channel, true); //! -//! // read types from the wire -//! let field_identifier = i_prot.read_field_begin().unwrap(); -//! let field_contents = i_prot.read_string().unwrap(); -//! let field_end = i_prot.read_field_end().unwrap(); +//! // write types +//! protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +//! protocol.write_string("foo").unwrap(); +//! protocol.write_field_end().unwrap(); //! ``` -use std::cell::RefCell; use std::fmt; use std::fmt::{Display, Formatter}; use std::convert::From; -use std::rc::Rc; use try_from::TryFrom; -use ::{ProtocolError, ProtocolErrorKind}; -use ::transport::TTransport; +use {ProtocolError, ProtocolErrorKind}; +use transport::{TReadTransport, TWriteTransport}; + +#[cfg(test)] +macro_rules! assert_eq_written_bytes { + ($o_prot:ident, $expected_bytes:ident) => { + { + assert_eq!($o_prot.transport.write_bytes(), &$expected_bytes); + } + }; +} + +// FIXME: should take both read and write +#[cfg(test)] +macro_rules! copy_write_buffer_to_read_buffer { + ($o_prot:ident) => { + { + $o_prot.transport.copy_write_buffer_to_read_buffer(); + } + }; +} + +#[cfg(test)] +macro_rules! 
set_readable_bytes { + ($i_prot:ident, $bytes:expr) => { + $i_prot.transport.set_readable_bytes($bytes); + } +} mod binary; mod compact; @@ -107,20 +125,17 @@ const MAXIMUM_SKIP_DEPTH: i8 = 64; /// Create and use a `TInputProtocol` /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); +/// let mut protocol = TBinaryInputProtocol::new(channel, true); /// -/// let field_identifier = i_prot.read_field_begin().unwrap(); -/// let field_contents = i_prot.read_string().unwrap(); -/// let field_end = i_prot.read_field_end().unwrap(); +/// let field_identifier = protocol.read_field_begin().unwrap(); +/// let field_contents = protocol.read_string().unwrap(); +/// let field_end = protocol.read_field_end().unwrap(); /// ``` pub trait TInputProtocol { /// Read the beginning of a Thrift message. @@ -171,10 +186,14 @@ pub trait TInputProtocol { /// Skip a field with type `field_type` recursively up to `depth` levels. 
fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> ::Result<()> { if depth == 0 { - return Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::DepthLimit, - message: format!("cannot parse past {:?}", field_type), - })); + return Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::DepthLimit, + message: format!("cannot parse past {:?}", field_type), + }, + ), + ); } match field_type { @@ -213,9 +232,11 @@ pub trait TInputProtocol { TType::Map => { let map_ident = self.read_map_begin()?; for _ in 0..map_ident.size { - let key_type = map_ident.key_type + let key_type = map_ident + .key_type .expect("non-zero sized map should contain key type"); - let val_type = map_ident.value_type + let val_type = map_ident + .value_type .expect("non-zero sized map should contain value type"); self.skip_till_depth(key_type, depth - 1)?; self.skip_till_depth(val_type, depth - 1)?; @@ -223,10 +244,14 @@ pub trait TInputProtocol { self.read_map_end() } u => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot skip field type {:?}", &u), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot skip field type {:?}", &u), + }, + ), + ) } } } @@ -259,20 +284,17 @@ pub trait TInputProtocol { /// Create and use a `TOutputProtocol` /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +/// let 
mut protocol = TBinaryOutputProtocol::new(channel, true); /// -/// o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); -/// o_prot.write_string("foo").unwrap(); -/// o_prot.write_field_end().unwrap(); +/// protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +/// protocol.write_string("foo").unwrap(); +/// protocol.write_field_end().unwrap(); /// ``` pub trait TOutputProtocol { /// Write the beginning of a Thrift message. @@ -330,6 +352,192 @@ pub trait TOutputProtocol { fn write_byte(&mut self, b: u8) -> ::Result<()>; // FIXME: REMOVE } +impl<P> TInputProtocol for Box<P> +where + P: TInputProtocol + ?Sized, +{ + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + (**self).read_message_begin() + } + + fn read_message_end(&mut self) -> ::Result<()> { + (**self).read_message_end() + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + (**self).read_struct_begin() + } + + fn read_struct_end(&mut self) -> ::Result<()> { + (**self).read_struct_end() + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + (**self).read_field_begin() + } + + fn read_field_end(&mut self) -> ::Result<()> { + (**self).read_field_end() + } + + fn read_bool(&mut self) -> ::Result<bool> { + (**self).read_bool() + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + (**self).read_bytes() + } + + fn read_i8(&mut self) -> ::Result<i8> { + (**self).read_i8() + } + + fn read_i16(&mut self) -> ::Result<i16> { + (**self).read_i16() + } + + fn read_i32(&mut self) -> ::Result<i32> { + (**self).read_i32() + } + + fn read_i64(&mut self) -> ::Result<i64> { + (**self).read_i64() + } + + fn read_double(&mut self) -> ::Result<f64> { + (**self).read_double() + } + + fn read_string(&mut self) -> ::Result<String> { + (**self).read_string() + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + (**self).read_list_begin() + } + + fn read_list_end(&mut 
self) -> ::Result<()> { + (**self).read_list_end() + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + (**self).read_set_begin() + } + + fn read_set_end(&mut self) -> ::Result<()> { + (**self).read_set_end() + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + (**self).read_map_begin() + } + + fn read_map_end(&mut self) -> ::Result<()> { + (**self).read_map_end() + } + + fn read_byte(&mut self) -> ::Result<u8> { + (**self).read_byte() + } +} + +impl<P> TOutputProtocol for Box<P> +where + P: TOutputProtocol + ?Sized, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + (**self).write_message_begin(identifier) + } + + fn write_message_end(&mut self) -> ::Result<()> { + (**self).write_message_end() + } + + fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> { + (**self).write_struct_begin(identifier) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + (**self).write_struct_end() + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + (**self).write_field_begin(identifier) + } + + fn write_field_end(&mut self) -> ::Result<()> { + (**self).write_field_end() + } + + fn write_field_stop(&mut self) -> ::Result<()> { + (**self).write_field_stop() + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + (**self).write_bool(b) + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + (**self).write_bytes(b) + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + (**self).write_i8(i) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + (**self).write_i16(i) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + (**self).write_i32(i) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + (**self).write_i64(i) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + (**self).write_double(d) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + (**self).write_string(s) + } + + fn 
write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + (**self).write_list_begin(identifier) + } + + fn write_list_end(&mut self) -> ::Result<()> { + (**self).write_list_end() + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + (**self).write_set_begin(identifier) + } + + fn write_set_end(&mut self) -> ::Result<()> { + (**self).write_set_end() + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + (**self).write_map_begin(identifier) + } + + fn write_map_end(&mut self) -> ::Result<()> { + (**self).write_map_end() + } + + fn flush(&mut self) -> ::Result<()> { + (**self).flush() + } + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + (**self).write_byte(b) + } +} + /// Helper type used by servers to create `TInputProtocol` instances for /// accepted client connections. /// @@ -338,21 +546,27 @@ pub trait TOutputProtocol { /// Create a `TInputProtocolFactory` and use it to create a `TInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocolFactory, TInputProtocolFactory}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut i_proto_factory = TBinaryInputProtocolFactory::new(); -/// let i_prot = i_proto_factory.create(transport); +/// let factory = TBinaryInputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); /// ``` pub trait TInputProtocolFactory { - /// Create a `TInputProtocol` that reads bytes from `transport`. 
- fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol>; + // Create a `TInputProtocol` that reads bytes from `transport`. + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send>; +} + +impl<T> TInputProtocolFactory for Box<T> +where + T: TInputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> { + (**self).create(transport) + } } /// Helper type used by servers to create `TOutputProtocol` instances for @@ -363,21 +577,27 @@ pub trait TInputProtocolFactory { /// Create a `TOutputProtocolFactory` and use it to create a `TOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocolFactory, TOutputProtocolFactory}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("127.0.0.1:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); /// -/// let mut o_proto_factory = TBinaryOutputProtocolFactory::new(); -/// let o_prot = o_proto_factory.create(transport); +/// let factory = TBinaryOutputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); /// ``` pub trait TOutputProtocolFactory { /// Create a `TOutputProtocol` that writes bytes to `transport`. - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol>; + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send>; +} + +impl<T> TOutputProtocolFactory for Box<T> +where + T: TOutputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> { + (**self).create(transport) + } } /// Thrift message identifier. 
@@ -394,10 +614,11 @@ pub struct TMessageIdentifier { impl TMessageIdentifier { /// Create a `TMessageIdentifier` for a Thrift service-call named `name` /// with message type `message_type` and sequence number `sequence_number`. - pub fn new<S: Into<String>>(name: S, - message_type: TMessageType, - sequence_number: i32) - -> TMessageIdentifier { + pub fn new<S: Into<String>>( + name: S, + message_type: TMessageType, + sequence_number: i32, + ) -> TMessageIdentifier { TMessageIdentifier { name: name.into(), message_type: message_type, @@ -443,9 +664,10 @@ impl TFieldIdentifier { /// /// `id` should be `None` if `field_type` is `TType::Stop`. pub fn new<N, S, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier - where N: Into<Option<S>>, - S: Into<String>, - I: Into<Option<i16>> + where + N: Into<Option<S>>, + S: Into<String>, + I: Into<Option<i16>>, { TFieldIdentifier { name: name.into().map(|n| n.into()), @@ -510,8 +732,9 @@ impl TMapIdentifier { /// Create a `TMapIdentifier` for a map with `size` entries of type /// `key_type -> value_type`. 
pub fn new<K, V>(key_type: K, value_type: V, size: i32) -> TMapIdentifier - where K: Into<Option<TType>>, - V: Into<Option<TType>> + where + K: Into<Option<TType>>, + V: Into<Option<TType>>, { TMapIdentifier { key_type: key_type.into(), @@ -565,10 +788,14 @@ impl TryFrom<u8> for TMessageType { 0x03 => Ok(TMessageType::Exception), 0x04 => Ok(TMessageType::OneWay), unkn => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} to TMessageType", unkn), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TMessageType", unkn), + }, + ), + ) } } } @@ -642,10 +869,14 @@ pub fn verify_expected_sequence_number(expected: i32, actual: i32) -> ::Result<( if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::BadSequenceId, - message: format!("expected {} got {}", expected, actual), - })) + Err( + ::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::BadSequenceId, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -657,10 +888,14 @@ pub fn verify_expected_service_call(expected: &str, actual: &str) -> ::Result<() if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::WrongMethodName, - message: format!("expected {} got {}", expected, actual), - })) + Err( + ::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::WrongMethodName, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -672,10 +907,14 @@ pub fn verify_expected_message_type(expected: TMessageType, actual: TMessageType if expected == actual { Ok(()) } else { - Err(::Error::Application(::ApplicationError { - kind: ::ApplicationErrorKind::InvalidMessageType, - message: format!("expected {} got {}", expected, actual), - })) + Err( + 
::Error::Application( + ::ApplicationError { + kind: ::ApplicationErrorKind::InvalidMessageType, + message: format!("expected {} got {}", expected, actual), + }, + ), + ) } } @@ -686,10 +925,14 @@ pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> : match *field { Some(_) => Ok(()), None => { - Err(::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::Unknown, - message: format!("missing required field {}", field_name), - })) + Err( + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing required field {}", field_name), + }, + ), + ) } } } @@ -700,10 +943,67 @@ pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> : /// /// Return `TFieldIdentifier.id` if an id exists, `Err` otherwise. pub fn field_id(field_ident: &TFieldIdentifier) -> ::Result<i16> { - field_ident.id.ok_or_else(|| { - ::Error::Protocol(::ProtocolError { - kind: ::ProtocolErrorKind::Unknown, - message: format!("missing field in in {:?}", field_ident), - }) - }) + field_ident + .id + .ok_or_else( + || { + ::Error::Protocol( + ::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing field in in {:?}", field_ident), + }, + ) + }, + ) +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use super::*; + use transport::{TReadTransport, TWriteTransport}; + + #[test] + fn must_create_usable_input_protocol_from_concrete_input_protocol() { + let r: Box<TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t = TCompactInputProtocol::new(r); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_input_protocol_from_boxed_input() { + let r: Box<TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t: Box<TInputProtocol> = Box::new(TCompactInputProtocol::new(r)); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_concrete_output_protocol() { + let w: Box<TWriteTransport> = Box::new(vec![0u8; 
10]); + let mut t = TCompactOutputProtocol::new(w); + takes_output_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_boxed_output() { + let w: Box<TWriteTransport> = Box::new(vec![0u8; 10]); + let mut t: Box<TOutputProtocol> = Box::new(TCompactOutputProtocol::new(w)); + takes_output_protocol(&mut t) + } + + fn takes_input_protocol<R>(t: &mut R) + where + R: TInputProtocol, + { + t.read_byte().unwrap(); + } + + fn takes_output_protocol<W>(t: &mut W) + where + W: TOutputProtocol, + { + t.flush().unwrap(); + } }
http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/protocol/multiplexed.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/protocol/multiplexed.rs b/lib/rs/src/protocol/multiplexed.rs index a30aca8..db08027 100644 --- a/lib/rs/src/protocol/multiplexed.rs +++ b/lib/rs/src/protocol/multiplexed.rs @@ -37,33 +37,37 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// Create and use a `TMultiplexedOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TMessageIdentifier, TMessageType, TOutputProtocol}; /// use thrift::protocol::{TBinaryOutputProtocol, TMultiplexedOutputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let o_prot = TBinaryOutputProtocol::new(transport, true); -/// let mut o_prot = TMultiplexedOutputProtocol::new("service_name", Box::new(o_prot)); +/// let protocol = TBinaryOutputProtocol::new(channel, true); +/// let mut protocol = TMultiplexedOutputProtocol::new("service_name", protocol); /// /// let ident = TMessageIdentifier::new("svc_call", TMessageType::Call, 1); -/// o_prot.write_message_begin(&ident).unwrap(); +/// protocol.write_message_begin(&ident).unwrap(); /// ``` -pub struct TMultiplexedOutputProtocol<'a> { +#[derive(Debug)] +pub struct TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ service_name: String, - inner: Box<TOutputProtocol + 'a>, + inner: P, } -impl<'a> TMultiplexedOutputProtocol<'a> { +impl<P> TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ /// Create a `TMultiplexedOutputProtocol` that identifies outgoing 
messages /// as originating from a service named `service_name` and sends them over /// the `wrapped` `TOutputProtocol`. Outgoing messages are encoded and sent /// by `wrapped`, not by this instance. - pub fn new(service_name: &str, wrapped: Box<TOutputProtocol + 'a>) -> TMultiplexedOutputProtocol<'a> { + pub fn new(service_name: &str, wrapped: P) -> TMultiplexedOutputProtocol<P> { TMultiplexedOutputProtocol { service_name: service_name.to_owned(), inner: wrapped, @@ -72,7 +76,10 @@ impl<'a> TMultiplexedOutputProtocol<'a> { } // FIXME: avoid passthrough methods -impl<'a> TOutputProtocol for TMultiplexedOutputProtocol<'a> { +impl<P> TOutputProtocol for TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { match identifier.message_type { // FIXME: is there a better way to override identifier here? TMessageType::Call | TMessageType::OneWay => { @@ -181,39 +188,50 @@ impl<'a> TOutputProtocol for TMultiplexedOutputProtocol<'a> { #[cfg(test)] mod tests { - use std::cell::RefCell; - use std::rc::Rc; - - use ::protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; + use protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; + use transport::{TBufferChannel, TIoChannel, WriteHalf}; use super::*; #[test] fn must_write_message_begin_with_prefixed_service_name() { - let (trans, mut o_prot) = test_objects(); + let mut o_prot = test_objects(); let ident = TMessageIdentifier::new("bar", TMessageType::Call, 2); assert_success!(o_prot.write_message_begin(&ident)); - let expected: [u8; 19] = - [0x80, 0x01 /* protocol identifier */, 0x00, 0x01 /* message type */, 0x00, - 0x00, 0x00, 0x07, 0x66, 0x6F, 0x6F /* "foo" */, 0x3A /* ":" */, 0x62, 0x61, - 0x72 /* "bar" */, 0x00, 0x00, 0x00, 0x02 /* sequence number */]; - - 
assert_eq!(&trans.borrow().write_buffer_to_vec(), &expected); - } - - fn test_objects<'a>() -> (Rc<RefCell<Box<TBufferTransport>>>, TMultiplexedOutputProtocol<'a>) { - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40)))); - - let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let inner = Rc::new(RefCell::new(inner)); - - let o_prot = TBinaryOutputProtocol::new(inner.clone(), true); - let o_prot = TMultiplexedOutputProtocol::new("foo", Box::new(o_prot)); - - (mem, o_prot) + let expected: [u8; 19] = [ + 0x80, + 0x01, /* protocol identifier */ + 0x00, + 0x01, /* message type */ + 0x00, + 0x00, + 0x00, + 0x07, + 0x66, + 0x6F, + 0x6F, /* "foo" */ + 0x3A, /* ":" */ + 0x62, + 0x61, + 0x72, /* "bar" */ + 0x00, + 0x00, + 0x00, + 0x02 /* sequence number */, + ]; + + assert_eq!(o_prot.inner.transport.write_bytes(), expected); + } + + fn test_objects + () + -> TMultiplexedOutputProtocol<TBinaryOutputProtocol<WriteHalf<TBufferChannel>>> + { + let c = TBufferChannel::with_capacity(40, 40); + let (_, w_chan) = c.split().unwrap(); + let prot = TBinaryOutputProtocol::new(w_chan, true); + TMultiplexedOutputProtocol::new("foo", prot) } } http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/protocol/stored.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/protocol/stored.rs b/lib/rs/src/protocol/stored.rs index 6826c00..b3f305f 100644 --- a/lib/rs/src/protocol/stored.rs +++ b/lib/rs/src/protocol/stored.rs @@ -17,8 +17,8 @@ use std::convert::Into; -use ::ProtocolErrorKind; -use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TInputProtocol, +use ProtocolErrorKind; +use super::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, TSetIdentifier, TStructIdentifier}; /// `TInputProtocol` required to use a `TMultiplexedProcessor`. 
@@ -40,35 +40,34 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// Create and use a `TStoredInputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift; /// use thrift::protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; /// use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TStoredInputProtocol}; /// use thrift::server::TProcessor; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::{TIoChannel, TTcpChannel}; /// /// // sample processor /// struct ActualProcessor; /// impl TProcessor for ActualProcessor { /// fn process( -/// &mut self, +/// &self, /// _: &mut TInputProtocol, /// _: &mut TOutputProtocol /// ) -> thrift::Result<()> { /// unimplemented!() /// } /// } -/// let mut processor = ActualProcessor {}; +/// let processor = ActualProcessor {}; /// /// // construct the shared transport -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let (i_chan, o_chan) = channel.split().unwrap(); /// /// // construct the actual input and output protocols -/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true); -/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true); +/// let mut i_prot = TBinaryInputProtocol::new(i_chan, true); +/// let mut o_prot = TBinaryOutputProtocol::new(o_chan, true); /// /// // message identifier received from remote and modified to remove the service name /// let new_msg_ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1); @@ -77,6 +76,7 @@ use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifie /// let mut proxy_i_prot = TStoredInputProtocol::new(&mut i_prot, new_msg_ident); /// let res = 
processor.process(&mut proxy_i_prot, &mut o_prot); /// ``` +// FIXME: implement Debug pub struct TStoredInputProtocol<'a> { inner: &'a mut TInputProtocol, message_ident: Option<TMessageIdentifier>, @@ -88,9 +88,10 @@ impl<'a> TStoredInputProtocol<'a> { /// `TInputProtocol`. `message_ident` is the modified message identifier - /// with service name stripped - that will be passed to /// `wrapped.read_message_begin(...)`. - pub fn new(wrapped: &mut TInputProtocol, - message_ident: TMessageIdentifier) - -> TStoredInputProtocol { + pub fn new( + wrapped: &mut TInputProtocol, + message_ident: TMessageIdentifier, + ) -> TStoredInputProtocol { TStoredInputProtocol { inner: wrapped, message_ident: message_ident.into(), @@ -100,10 +101,16 @@ impl<'a> TStoredInputProtocol<'a> { impl<'a> TInputProtocol for TStoredInputProtocol<'a> { fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { - self.message_ident.take().ok_or_else(|| { - ::errors::new_protocol_error(ProtocolErrorKind::Unknown, - "message identifier already read") - }) + self.message_ident + .take() + .ok_or_else( + || { + ::errors::new_protocol_error( + ProtocolErrorKind::Unknown, + "message identifier already read", + ) + }, + ) } fn read_message_end(&mut self) -> ::Result<()> { http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/server/mod.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/server/mod.rs b/lib/rs/src/server/mod.rs index ceac18a..21c392c 100644 --- a/lib/rs/src/server/mod.rs +++ b/lib/rs/src/server/mod.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! Types required to implement a Thrift server. +//! Types used to implement a Thrift server. 
-use ::protocol::{TInputProtocol, TOutputProtocol}; +use protocol::{TInputProtocol, TOutputProtocol}; -mod simple; mod multiplexed; +mod threaded; -pub use self::simple::TSimpleServer; pub use self::multiplexed::TMultiplexedProcessor; +pub use self::threaded::TServer; /// Handles incoming Thrift messages and dispatches them to the user-defined /// handler functions. @@ -56,14 +56,14 @@ pub use self::multiplexed::TMultiplexedProcessor; /// /// // `TProcessor` implementation for `SimpleService` /// impl TProcessor for SimpleServiceSyncProcessor { -/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { /// unimplemented!(); /// } /// } /// /// // service functions for SimpleService /// trait SimpleServiceSyncHandler { -/// fn service_call(&mut self) -> thrift::Result<()>; +/// fn service_call(&self) -> thrift::Result<()>; /// } /// /// // @@ -73,7 +73,7 @@ pub use self::multiplexed::TMultiplexedProcessor; /// // define a handler that will be invoked when `service_call` is received /// struct SimpleServiceHandlerImpl; /// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { -/// fn service_call(&mut self) -> thrift::Result<()> { +/// fn service_call(&self) -> thrift::Result<()> { /// unimplemented!(); /// } /// } @@ -82,7 +82,7 @@ pub use self::multiplexed::TMultiplexedProcessor; /// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); /// /// // at this point you can pass the processor to the server -/// // let server = TSimpleServer::new(..., processor); +/// // let server = TServer::new(..., processor); /// ``` pub trait TProcessor { /// Process a Thrift service call. @@ -91,5 +91,5 @@ pub trait TProcessor { /// the response to `o`. /// /// Returns `()` if the handler was executed; `Err` otherwise. 
- fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>; + fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>; } http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/server/multiplexed.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/server/multiplexed.rs b/lib/rs/src/server/multiplexed.rs index d2314a1..b1243a8 100644 --- a/lib/rs/src/server/multiplexed.rs +++ b/lib/rs/src/server/multiplexed.rs @@ -17,9 +17,10 @@ use std::collections::HashMap; use std::convert::Into; +use std::sync::{Arc, Mutex}; -use ::{new_application_error, ApplicationErrorKind}; -use ::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol}; +use {ApplicationErrorKind, new_application_error}; +use protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol}; use super::TProcessor; @@ -33,8 +34,9 @@ use super::TProcessor; /// /// A `TMultiplexedProcessor` can only handle messages sent by a /// `TMultiplexedOutputProtocol`. +// FIXME: implement Debug pub struct TMultiplexedProcessor { - processors: HashMap<String, Box<TProcessor>>, + processors: Mutex<HashMap<String, Arc<Box<TProcessor>>>>, } impl TMultiplexedProcessor { @@ -46,46 +48,62 @@ impl TMultiplexedProcessor { /// Return `false` if a mapping previously existed (the previous mapping is /// *not* overwritten). 
#[cfg_attr(feature = "cargo-clippy", allow(map_entry))] - pub fn register_processor<S: Into<String>>(&mut self, - service_name: S, - processor: Box<TProcessor>) - -> bool { + pub fn register_processor<S: Into<String>>( + &mut self, + service_name: S, + processor: Box<TProcessor>, + ) -> bool { + let mut processors = self.processors.lock().unwrap(); + let name = service_name.into(); - if self.processors.contains_key(&name) { + if processors.contains_key(&name) { false } else { - self.processors.insert(name, processor); + processors.insert(name, Arc::new(processor)); true } } } impl TProcessor for TMultiplexedProcessor { - fn process(&mut self, - i_prot: &mut TInputProtocol, - o_prot: &mut TOutputProtocol) - -> ::Result<()> { + fn process(&self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> ::Result<()> { let msg_ident = i_prot.read_message_begin()?; - let sep_index = msg_ident.name + let sep_index = msg_ident + .name .find(':') - .ok_or_else(|| { - new_application_error(ApplicationErrorKind::Unknown, - "no service separator found in incoming message") - })?; + .ok_or_else( + || { + new_application_error( + ApplicationErrorKind::Unknown, + "no service separator found in incoming message", + ) + }, + )?; let (svc_name, svc_call) = msg_ident.name.split_at(sep_index); - match self.processors.get_mut(svc_name) { - Some(ref mut processor) => { - let new_msg_ident = TMessageIdentifier::new(svc_call, - msg_ident.message_type, - msg_ident.sequence_number); + let processor: Option<Arc<Box<TProcessor>>> = { + let processors = self.processors.lock().unwrap(); + processors.get(svc_name).cloned() + }; + + match processor { + Some(arc) => { + let new_msg_ident = TMessageIdentifier::new( + svc_call, + msg_ident.message_type, + msg_ident.sequence_number, + ); let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident); - processor.process(&mut proxy_i_prot, o_prot) + (*arc).process(&mut proxy_i_prot, o_prot) } None => { - 
Err(new_application_error(ApplicationErrorKind::Unknown, - format!("no processor found for service {}", svc_name))) + Err( + new_application_error( + ApplicationErrorKind::Unknown, + format!("no processor found for service {}", svc_name), + ), + ) } } } http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/server/simple.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/server/simple.rs b/lib/rs/src/server/simple.rs deleted file mode 100644 index 89ed977..0000000 --- a/lib/rs/src/server/simple.rs +++ /dev/null @@ -1,189 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::cell::RefCell; -use std::net::{TcpListener, TcpStream}; -use std::rc::Rc; - -use ::{ApplicationError, ApplicationErrorKind}; -use ::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; -use ::transport::{TTcpTransport, TTransport, TTransportFactory}; - -use super::TProcessor; - -/// Single-threaded blocking Thrift socket server. -/// -/// A `TSimpleServer` listens on a given address and services accepted -/// connections *synchronously* and *sequentially* - i.e. in a blocking manner, -/// one at a time - on the main thread. 
Each accepted connection has an input -/// half and an output half, each of which uses a `TTransport` and `TProtocol` -/// to translate messages to and from byes. Any combination of `TProtocol` and -/// `TTransport` may be used. -/// -/// # Examples -/// -/// Creating and running a `TSimpleServer` using Thrift-compiler-generated -/// service code. -/// -/// ```no_run -/// use thrift; -/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; -/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory}; -/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; -/// use thrift::transport::{TBufferedTransportFactory, TTransportFactory}; -/// use thrift::server::{TProcessor, TSimpleServer}; -/// -/// // -/// // auto-generated -/// // -/// -/// // processor for `SimpleService` -/// struct SimpleServiceSyncProcessor; -/// impl SimpleServiceSyncProcessor { -/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { -/// unimplemented!(); -/// } -/// } -/// -/// // `TProcessor` implementation for `SimpleService` -/// impl TProcessor for SimpleServiceSyncProcessor { -/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { -/// unimplemented!(); -/// } -/// } -/// -/// // service functions for SimpleService -/// trait SimpleServiceSyncHandler { -/// fn service_call(&mut self) -> thrift::Result<()>; -/// } -/// -/// // -/// // user-code follows -/// // -/// -/// // define a handler that will be invoked when `service_call` is received -/// struct SimpleServiceHandlerImpl; -/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { -/// fn service_call(&mut self) -> thrift::Result<()> { -/// unimplemented!(); -/// } -/// } -/// -/// // instantiate the processor -/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); -/// -/// // instantiate the server -/// let i_tr_fact: Box<TTransportFactory> = 
Box::new(TBufferedTransportFactory::new()); -/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new()); -/// let o_tr_fact: Box<TTransportFactory> = Box::new(TBufferedTransportFactory::new()); -/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new()); -/// -/// let mut server = TSimpleServer::new( -/// i_tr_fact, -/// i_pr_fact, -/// o_tr_fact, -/// o_pr_fact, -/// processor -/// ); -/// -/// // start listening for incoming connections -/// match server.listen("127.0.0.1:8080") { -/// Ok(_) => println!("listen completed"), -/// Err(e) => println!("listen failed with error {:?}", e), -/// } -/// ``` -pub struct TSimpleServer<PR: TProcessor> { - i_trans_factory: Box<TTransportFactory>, - i_proto_factory: Box<TInputProtocolFactory>, - o_trans_factory: Box<TTransportFactory>, - o_proto_factory: Box<TOutputProtocolFactory>, - processor: PR, -} - -impl<PR: TProcessor> TSimpleServer<PR> { - /// Create a `TSimpleServer`. - /// - /// Each accepted connection has an input and output half, each of which - /// requires a `TTransport` and `TProtocol`. `TSimpleServer` uses - /// `input_transport_factory` and `input_protocol_factory` to create - /// implementations for the input, and `output_transport_factory` and - /// `output_protocol_factory` to create implementations for the output. - pub fn new(input_transport_factory: Box<TTransportFactory>, - input_protocol_factory: Box<TInputProtocolFactory>, - output_transport_factory: Box<TTransportFactory>, - output_protocol_factory: Box<TOutputProtocolFactory>, - processor: PR) - -> TSimpleServer<PR> { - TSimpleServer { - i_trans_factory: input_transport_factory, - i_proto_factory: input_protocol_factory, - o_trans_factory: output_transport_factory, - o_proto_factory: output_protocol_factory, - processor: processor, - } - } - - /// Listen for incoming connections on `listen_address`. 
- /// - /// `listen_address` should be in the form `host:port`, - /// for example: `127.0.0.1:8080`. - /// - /// Return `()` if successful. - /// - /// Return `Err` when the server cannot bind to `listen_address` or there - /// is an unrecoverable error. - pub fn listen(&mut self, listen_address: &str) -> ::Result<()> { - let listener = TcpListener::bind(listen_address)?; - for stream in listener.incoming() { - match stream { - Ok(s) => self.handle_incoming_connection(s), - Err(e) => warn!("failed to accept remote connection with error {:?}", e), - } - } - - Err(::Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: "aborted listen loop".into(), - })) - } - - fn handle_incoming_connection(&mut self, stream: TcpStream) { - // create the shared tcp stream - let stream = TTcpTransport::with_stream(stream); - let stream: Box<TTransport> = Box::new(stream); - let stream = Rc::new(RefCell::new(stream)); - - // input protocol and transport - let i_tran = self.i_trans_factory.create(stream.clone()); - let i_tran = Rc::new(RefCell::new(i_tran)); - let mut i_prot = self.i_proto_factory.create(i_tran); - - // output protocol and transport - let o_tran = self.o_trans_factory.create(stream.clone()); - let o_tran = Rc::new(RefCell::new(o_tran)); - let mut o_prot = self.o_proto_factory.create(o_tran); - - // process loop - loop { - let r = self.processor.process(&mut *i_prot, &mut *o_prot); - if let Err(e) = r { - warn!("processor failed with error: {:?}", e); - break; // FIXME: close here - } - } - } -} http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/server/threaded.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/server/threaded.rs b/lib/rs/src/server/threaded.rs new file mode 100644 index 0000000..a486c5a --- /dev/null +++ b/lib/rs/src/server/threaded.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; +use threadpool::ThreadPool; + +use {ApplicationError, ApplicationErrorKind}; +use protocol::{TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory}; +use transport::{TIoChannel, TReadTransportFactory, TTcpChannel, TWriteTransportFactory}; + +use super::TProcessor; + +/// Fixed-size thread-pool blocking Thrift server. +/// +/// A `TServer` listens on a given address and submits accepted connections +/// to an **unbounded** queue. Connections from this queue are serviced by +/// the first available worker thread from a **fixed-size** thread pool. Each +/// accepted connection is handled by that worker thread, and communication +/// over this thread occurs sequentially and synchronously (i.e. calls block). +/// Accepted connections have an input half and an output half, each of which +/// uses a `TTransport` and `TInputProtocol`/`TOutputProtocol` to translate +/// messages to and from bytes. Any combination of `TInputProtocol`, `TOutputProtocol` +/// and `TTransport` may be used. +/// +/// # Examples +/// +/// Creating and running a `TServer` using Thrift-compiler-generated +/// service code.
+/// +/// ```no_run +/// use thrift; +/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; +/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory}; +/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; +/// use thrift::transport::{TBufferedReadTransportFactory, TBufferedWriteTransportFactory, TReadTransportFactory, TWriteTransportFactory}; +/// use thrift::server::{TProcessor, TServer}; +/// +/// // +/// // auto-generated +/// // +/// +/// // processor for `SimpleService` +/// struct SimpleServiceSyncProcessor; +/// impl SimpleServiceSyncProcessor { +/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { +/// unimplemented!(); +/// } +/// } +/// +/// // `TProcessor` implementation for `SimpleService` +/// impl TProcessor for SimpleServiceSyncProcessor { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // service functions for SimpleService +/// trait SimpleServiceSyncHandler { +/// fn service_call(&self) -> thrift::Result<()>; +/// } +/// +/// // +/// // user-code follows +/// // +/// +/// // define a handler that will be invoked when `service_call` is received +/// struct SimpleServiceHandlerImpl; +/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { +/// fn service_call(&self) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // instantiate the processor +/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); +/// +/// // instantiate the server +/// let i_tr_fact: Box<TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new()); +/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new()); +/// let o_tr_fact: Box<TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new()); +/// let o_pr_fact: Box<TOutputProtocolFactory> = 
Box::new(TBinaryOutputProtocolFactory::new()); +/// +/// let mut server = TServer::new( +/// i_tr_fact, +/// i_pr_fact, +/// o_tr_fact, +/// o_pr_fact, +/// processor, +/// 10 +/// ); +/// +/// // start listening for incoming connections +/// match server.listen("127.0.0.1:8080") { +/// Ok(_) => println!("listen completed"), +/// Err(e) => println!("listen failed with error {:?}", e), +/// } +/// ``` +#[derive(Debug)] +pub struct TServer<PRC, RTF, IPF, WTF, OPF> +where + PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + r_trans_factory: RTF, + i_proto_factory: IPF, + w_trans_factory: WTF, + o_proto_factory: OPF, + processor: Arc<PRC>, + worker_pool: ThreadPool, +} + +impl<PRC, RTF, IPF, WTF, OPF> TServer<PRC, RTF, IPF, WTF, OPF> + where PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static { + /// Create a `TServer`. + /// + /// Each accepted connection has an input and output half, each of which + /// requires a `TTransport` and `TProtocol`. `TServer` uses + /// `read_transport_factory` and `input_protocol_factory` to create + /// implementations for the input, and `write_transport_factory` and + /// `output_protocol_factory` to create implementations for the output. 
+ pub fn new( + read_transport_factory: RTF, + input_protocol_factory: IPF, + write_transport_factory: WTF, + output_protocol_factory: OPF, + processor: PRC, + num_workers: usize, + ) -> TServer<PRC, RTF, IPF, WTF, OPF> { + TServer { + r_trans_factory: read_transport_factory, + i_proto_factory: input_protocol_factory, + w_trans_factory: write_transport_factory, + o_proto_factory: output_protocol_factory, + processor: Arc::new(processor), + worker_pool: ThreadPool::new_with_name( + "Thrift service processor".to_owned(), + num_workers, + ), + } + } + + /// Listen for incoming connections on `listen_address`. + /// + /// `listen_address` should be in the form `host:port`, + /// for example: `127.0.0.1:8080`. + /// + /// Return `()` if successful. + /// + /// Return `Err` when the server cannot bind to `listen_address` or there + /// is an unrecoverable error. + pub fn listen(&mut self, listen_address: &str) -> ::Result<()> { + let listener = TcpListener::bind(listen_address)?; + for stream in listener.incoming() { + match stream { + Ok(s) => { + let (i_prot, o_prot) = self.new_protocols_for_connection(s)?; + let processor = self.processor.clone(); + self.worker_pool + .execute(move || handle_incoming_connection(processor, i_prot, o_prot),); + } + Err(e) => { + warn!("failed to accept remote connection with error {:?}", e); + } + } + } + + Err( + ::Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: "aborted listen loop".into(), + }, + ), + ) + } + + + fn new_protocols_for_connection( + &mut self, + stream: TcpStream, + ) -> ::Result<(Box<TInputProtocol + Send>, Box<TOutputProtocol + Send>)> { + // create the shared tcp stream + let channel = TTcpChannel::with_stream(stream); + + // split it into two - one to be owned by the + // input tran/proto and the other by the output + let (r_chan, w_chan) = channel.split()?; + + // input protocol and transport + let r_tran = self.r_trans_factory.create(Box::new(r_chan)); + let i_prot = 
self.i_proto_factory.create(r_tran); + + // output protocol and transport + let w_tran = self.w_trans_factory.create(Box::new(w_chan)); + let o_prot = self.o_proto_factory.create(w_tran); + + Ok((i_prot, o_prot)) + } +} + +fn handle_incoming_connection<PRC>( + processor: Arc<PRC>, + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) where + PRC: TProcessor, +{ + let mut i_prot = i_prot; + let mut o_prot = o_prot; + loop { + let r = processor.process(&mut *i_prot, &mut *o_prot); + if let Err(e) = r { + warn!("processor completed with error: {:?}", e); + break; + } + } +} http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/transport/buffered.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/transport/buffered.rs b/lib/rs/src/transport/buffered.rs index 3f240d8..b588ec1 100644 --- a/lib/rs/src/transport/buffered.rs +++ b/lib/rs/src/transport/buffered.rs @@ -15,104 +15,94 @@ // specific language governing permissions and limitations // under the License. -use std::cell::RefCell; use std::cmp; use std::io; use std::io::{Read, Write}; -use std::rc::Rc; -use super::{TTransport, TTransportFactory}; +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; /// Default capacity of the read buffer in bytes. -const DEFAULT_RBUFFER_CAPACITY: usize = 4096; +const READ_CAPACITY: usize = 4096; /// Default capacity of the write buffer in bytes.. -const DEFAULT_WBUFFER_CAPACITY: usize = 4096; +const WRITE_CAPACITY: usize = 4096; -/// Transport that communicates with endpoints using a byte stream. +/// Transport that reads messages via an internal buffer. /// -/// A `TBufferedTransport` maintains a fixed-size internal write buffer. All -/// writes are made to this buffer and are sent to the wrapped transport only -/// when `TTransport::flush()` is called. 
On a flush a fixed-length header with a -/// count of the buffered bytes is written, followed by the bytes themselves. -/// -/// A `TBufferedTransport` also maintains a fixed-size internal read buffer. -/// On a call to `TTransport::read(...)` one full message - both fixed-length -/// header and bytes - is read from the wrapped transport and buffered. +/// A `TBufferedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TBufferedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and buffered. /// Subsequent read calls are serviced from the internal buffer until it is /// exhausted, at which point the next full message is read from the wrapped -/// transport. +/// channel. /// /// # Examples /// -/// Create and use a `TBufferedTransport`. +/// Create and use a `TBufferedReadTransport`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; -/// use std::io::{Read, Write}; -/// use thrift::transport::{TBufferedTransport, TTcpTransport, TTransport}; +/// use std::io::Read; +/// use thrift::transport::{TBufferedReadTransport, TTcpChannel}; /// -/// let mut t = TTcpTransport::new(); -/// t.open("localhost:9090").unwrap(); +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); /// -/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>)); -/// let mut t = TBufferedTransport::new(t); +/// let mut t = TBufferedReadTransport::new(c); /// -/// // read /// t.read(&mut vec![0u8; 1]).unwrap(); -/// -/// // write -/// t.write(&[0x00]).unwrap(); -/// t.flush().unwrap(); /// ``` -pub struct TBufferedTransport { - rbuf: Box<[u8]>, - rpos: usize, - rcap: usize, - wbuf: Vec<u8>, - inner: Rc<RefCell<Box<TTransport>>>, +#[derive(Debug)] +pub struct TBufferedReadTransport<C> +where + C: Read, +{ + buf: Box<[u8]>, + pos: usize, + cap: usize, + chan: C, } -impl TBufferedTransport { +impl<C> TBufferedReadTransport<C> +where + C: Read, +{ /// Create a 
`TBufferedTransport` with default-sized internal read and - /// write buffers that wraps an `inner` `TTransport`. - pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TBufferedTransport { - TBufferedTransport::with_capacity(DEFAULT_RBUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner) + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport::with_capacity(READ_CAPACITY, channel) } /// Create a `TBufferedTransport` with an internal read buffer of size - /// `read_buffer_capacity` and an internal write buffer of size - /// `write_buffer_capacity` that wraps an `inner` `TTransport`. - pub fn with_capacity(read_buffer_capacity: usize, - write_buffer_capacity: usize, - inner: Rc<RefCell<Box<TTransport>>>) - -> TBufferedTransport { - TBufferedTransport { - rbuf: vec![0; read_buffer_capacity].into_boxed_slice(), - rpos: 0, - rcap: 0, - wbuf: Vec::with_capacity(write_buffer_capacity), - inner: inner, + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. 
+ pub fn with_capacity(read_capacity: usize, channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport { + buf: vec![0; read_capacity].into_boxed_slice(), + pos: 0, + cap: 0, + chan: channel, } } fn get_bytes(&mut self) -> io::Result<&[u8]> { - if self.rcap - self.rpos == 0 { - self.rpos = 0; - self.rcap = self.inner.borrow_mut().read(&mut self.rbuf)?; + if self.cap - self.pos == 0 { + self.pos = 0; + self.cap = self.chan.read(&mut self.buf)?; } - Ok(&self.rbuf[self.rpos..self.rcap]) + Ok(&self.buf[self.pos..self.cap]) } fn consume(&mut self, consumed: usize) { // TODO: was a bug here += <-- test somehow - self.rpos = cmp::min(self.rcap, self.rpos + consumed); + self.pos = cmp::min(self.cap, self.pos + consumed); } } -impl Read for TBufferedTransport { +impl<C> Read for TBufferedReadTransport<C> +where + C: Read, +{ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { let mut bytes_read = 0; @@ -137,65 +127,127 @@ impl Read for TBufferedTransport { } } -impl Write for TBufferedTransport { +/// Factory for creating instances of `TBufferedReadTransport`. +#[derive(Default)] +pub struct TBufferedReadTransportFactory; + +impl TBufferedReadTransportFactory { + pub fn new() -> TBufferedReadTransportFactory { + TBufferedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TBufferedReadTransportFactory { + /// Create a `TBufferedReadTransport`. + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> { + Box::new(TBufferedReadTransport::new(channel)) + } +} + +/// Transport that writes messages via an internal buffer. +/// +/// A `TBufferedWriteTransport` maintains a fixed-size internal write buffer. +/// All writes are made to this buffer and are sent to the wrapped channel only +/// when `TBufferedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. 
+/// +/// # Examples +/// +/// Create and use a `TBufferedWriteTransport`. +/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TBufferedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TBufferedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBufferedWriteTransport<C> +where + C: Write, +{ + buf: Vec<u8>, + channel: C, +} + +impl<C> TBufferedWriteTransport<C> +where + C: Write, +{ + /// Create a `TBufferedWriteTransport` with a default-sized internal + /// write buffer that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedWriteTransport<C> { + TBufferedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TBufferedWriteTransport` that wraps the given `TIoChannel` + /// and buffers writes in an internal write buffer of size + /// `write_capacity`.
+ pub fn with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C> { + TBufferedWriteTransport { + buf: Vec::with_capacity(write_capacity), + channel: channel, + } + } +} + +impl<C> Write for TBufferedWriteTransport<C> +where + C: Write, +{ fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let avail_bytes = cmp::min(buf.len(), self.wbuf.capacity() - self.wbuf.len()); - self.wbuf.extend_from_slice(&buf[..avail_bytes]); - assert!(self.wbuf.len() <= self.wbuf.capacity(), - "copy overflowed buffer"); + let avail_bytes = cmp::min(buf.len(), self.buf.capacity() - self.buf.len()); + self.buf.extend_from_slice(&buf[..avail_bytes]); + assert!( + self.buf.len() <= self.buf.capacity(), + "copy overflowed buffer" + ); Ok(avail_bytes) } fn flush(&mut self) -> io::Result<()> { - self.inner.borrow_mut().write_all(&self.wbuf)?; - self.inner.borrow_mut().flush()?; - self.wbuf.clear(); + self.channel.write_all(&self.buf)?; + self.channel.flush()?; + self.buf.clear(); Ok(()) } } -/// Factory for creating instances of `TBufferedTransport` +/// Factory for creating instances of `TBufferedWriteTransport`. #[derive(Default)] -pub struct TBufferedTransportFactory; +pub struct TBufferedWriteTransportFactory; -impl TBufferedTransportFactory { - /// Create a `TBufferedTransportFactory`. - pub fn new() -> TBufferedTransportFactory { - TBufferedTransportFactory {} +impl TBufferedWriteTransportFactory { + pub fn new() -> TBufferedWriteTransportFactory { + TBufferedWriteTransportFactory {} } } -impl TTransportFactory for TBufferedTransportFactory { - fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> { - Box::new(TBufferedTransport::new(inner)) as Box<TTransport> +impl TWriteTransportFactory for TBufferedWriteTransportFactory { + /// Create a `TBufferedWriteTransport`. 
+ fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> { + Box::new(TBufferedWriteTransport::new(channel)) } } #[cfg(test)] mod tests { - use std::cell::RefCell; use std::io::{Read, Write}; - use std::rc::Rc; use super::*; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; - - macro_rules! new_transports { - ($wbc:expr, $rbc:expr) => ( - { - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity($wbc, $rbc)))); - let thru: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let thru = Rc::new(RefCell::new(thru)); - (mem, thru) - } - ); - } + use transport::TBufferChannel; #[test] fn must_return_zero_if_read_buffer_is_empty() { - let (_, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(10, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); let mut b = vec![0; 10]; let read_result = t.read(&mut b); @@ -205,8 +257,8 @@ mod tests { #[test] fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() { - let (_, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(10, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); let read_result = t.read(&mut []); @@ -215,10 +267,10 @@ mod tests { #[test] fn must_return_zero_if_nothing_more_can_be_read() { - let (mem, thru) = new_transports!(4, 0); - let mut t = TBufferedTransport::with_capacity(4, 0, thru); + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); - mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]); + t.chan.set_readable_bytes(&[0, 1, 2, 3]); // read buffer is exactly the same size as bytes available let mut buf = vec![0u8; 4]; @@ -239,10 +291,10 @@ mod tests { #[test] fn must_fill_user_buffer_with_only_as_many_bytes_as_available() { - let (mem, thru) = 
new_transports!(4, 0); - let mut t = TBufferedTransport::with_capacity(4, 0, thru); + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); - mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]); + t.chan.set_readable_bytes(&[0, 1, 2, 3]); // read buffer is much larger than the bytes available let mut buf = vec![0u8; 8]; @@ -268,15 +320,16 @@ mod tests { // we have a much smaller buffer than the // underlying transport has bytes available - let (mem, thru) = new_transports!(10, 0); - let mut t = TBufferedTransport::with_capacity(2, 0, thru); + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(2, mem); // fill the underlying transport's byte buffer let mut readable_bytes = [0u8; 10]; for i in 0..10 { readable_bytes[i] = i as u8; } - mem.borrow_mut().set_readable_bytes(&readable_bytes); + + t.chan.set_readable_bytes(&readable_bytes); // we ask to read into a buffer that's much larger // than the one the buffered transport has; as a result @@ -312,8 +365,8 @@ mod tests { #[test] fn must_return_zero_if_nothing_can_be_written() { - let (_, thru) = new_transports!(0, 0); - let mut t = TBufferedTransport::with_capacity(0, 0, thru); + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TBufferedWriteTransport::with_capacity(0, mem); let b = vec![0; 10]; let r = t.write(&b); @@ -323,19 +376,20 @@ mod tests { #[test] fn must_return_zero_if_caller_calls_write_with_empty_buffer() { - let (mem, thru) = new_transports!(0, 10); - let mut t = TBufferedTransport::with_capacity(0, 10, thru); + let mem = TBufferChannel::with_capacity(0, 10); + let mut t = TBufferedWriteTransport::with_capacity(10, mem); let r = t.write(&[]); + let expected: [u8; 0] = []; assert_eq!(r.unwrap(), 0); - assert_eq!(mem.borrow_mut().write_buffer_as_ref(), &[]); + assert_eq_transport_written_bytes!(t, expected); } #[test] fn must_return_zero_if_write_buffer_full() { - let (_, thru) = 
new_transports!(0, 0); - let mut t = TBufferedTransport::with_capacity(0, 4, thru); + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TBufferedWriteTransport::with_capacity(4, mem); let b = [0x00, 0x01, 0x02, 0x03]; @@ -350,26 +404,22 @@ mod tests { #[test] fn must_only_write_to_inner_transport_on_flush() { - let (mem, thru) = new_transports!(10, 10); - let mut t = TBufferedTransport::new(thru); + let mem = TBufferChannel::with_capacity(10, 10); + let mut t = TBufferedWriteTransport::new(mem); let b: [u8; 5] = [0, 1, 2, 3, 4]; assert_eq!(t.write(&b).unwrap(), 5); - assert_eq!(mem.borrow_mut().write_buffer_as_ref().len(), 0); + assert_eq_transport_num_written_bytes!(t, 0); assert!(t.flush().is_ok()); - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); } #[test] fn must_write_successfully_after_flush() { - let (mem, thru) = new_transports!(0, 5); - let mut t = TBufferedTransport::with_capacity(0, 5, thru); + let mem = TBufferChannel::with_capacity(0, 5); + let mut t = TBufferedWriteTransport::with_capacity(5, mem); // write and flush let b: [u8; 5] = [0, 1, 2, 3, 4]; @@ -377,24 +427,16 @@ mod tests { assert!(t.flush().is_ok()); // check the flushed bytes - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); // reset our underlying transport - mem.borrow_mut().empty_write_buffer(); + t.channel.empty_write_buffer(); // write and flush again assert_eq!(t.write(&b).unwrap(), 5); assert!(t.flush().is_ok()); // check the flushed bytes - { - let inner = mem.borrow_mut(); - let underlying_buffer = inner.write_buffer_as_ref(); - assert_eq!(b, underlying_buffer); - } + assert_eq_transport_written_bytes!(t, b); } } http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/transport/framed.rs 
---------------------------------------------------------------------- diff --git a/lib/rs/src/transport/framed.rs b/lib/rs/src/transport/framed.rs index 75c12f4..d78d2f7 100644 --- a/lib/rs/src/transport/framed.rs +++ b/lib/rs/src/transport/framed.rs @@ -16,165 +16,242 @@ // under the License. use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use std::cell::RefCell; use std::cmp; use std::io; use std::io::{ErrorKind, Read, Write}; -use std::rc::Rc; -use super::{TTransport, TTransportFactory}; +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; /// Default capacity of the read buffer in bytes. -const WRITE_BUFFER_CAPACITY: usize = 4096; +const READ_CAPACITY: usize = 4096; -/// Default capacity of the write buffer in bytes.. -const DEFAULT_WBUFFER_CAPACITY: usize = 4096; +/// Default capacity of the write buffer in bytes. +const WRITE_CAPACITY: usize = 4096; -/// Transport that communicates with endpoints using framed messages. +/// Transport that reads framed messages. /// -/// A `TFramedTransport` maintains a fixed-size internal write buffer. All -/// writes are made to this buffer and are sent to the wrapped transport only -/// when `TTransport::flush()` is called. On a flush a fixed-length header with a -/// count of the buffered bytes is written, followed by the bytes themselves. -/// -/// A `TFramedTransport` also maintains a fixed-size internal read buffer. -/// On a call to `TTransport::read(...)` one full message - both fixed-length -/// header and bytes - is read from the wrapped transport and buffered. -/// Subsequent read calls are serviced from the internal buffer until it is -/// exhausted, at which point the next full message is read from the wrapped -/// transport. +/// A `TFramedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TFramedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and +/// buffered. 
Subsequent read calls are serviced from the internal buffer +/// until it is exhausted, at which point the next full message is read +/// from the wrapped channel. /// /// # Examples /// -/// Create and use a `TFramedTransport`. +/// Create and use a `TFramedReadTransport`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; -/// use std::io::{Read, Write}; -/// use thrift::transport::{TFramedTransport, TTcpTransport, TTransport}; +/// use std::io::Read; +/// use thrift::transport::{TFramedReadTransport, TTcpChannel}; /// -/// let mut t = TTcpTransport::new(); -/// t.open("localhost:9090").unwrap(); +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); /// -/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>)); -/// let mut t = TFramedTransport::new(t); +/// let mut t = TFramedReadTransport::new(c); /// -/// // read /// t.read(&mut vec![0u8; 1]).unwrap(); -/// -/// // write -/// t.write(&[0x00]).unwrap(); -/// t.flush().unwrap(); /// ``` -pub struct TFramedTransport { - rbuf: Box<[u8]>, - rpos: usize, - rcap: usize, - wbuf: Box<[u8]>, - wpos: usize, - inner: Rc<RefCell<Box<TTransport>>>, +#[derive(Debug)] +pub struct TFramedReadTransport<C> +where + C: Read, +{ + buf: Box<[u8]>, + pos: usize, + cap: usize, + chan: C, } -impl TFramedTransport { +impl<C> TFramedReadTransport<C> +where + C: Read, +{ /// Create a `TFramedTransport` with default-sized internal read and - /// write buffers that wraps an `inner` `TTransport`. - pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TFramedTransport { - TFramedTransport::with_capacity(WRITE_BUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner) + /// write buffers that wraps the given `TIoChannel`. 
+ pub fn new(channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport::with_capacity(READ_CAPACITY, channel) } /// Create a `TFramedTransport` with an internal read buffer of size - /// `read_buffer_capacity` and an internal write buffer of size - /// `write_buffer_capacity` that wraps an `inner` `TTransport`. - pub fn with_capacity(read_buffer_capacity: usize, - write_buffer_capacity: usize, - inner: Rc<RefCell<Box<TTransport>>>) - -> TFramedTransport { - TFramedTransport { - rbuf: vec![0; read_buffer_capacity].into_boxed_slice(), - rpos: 0, - rcap: 0, - wbuf: vec![0; write_buffer_capacity].into_boxed_slice(), - wpos: 0, - inner: inner, + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(read_capacity: usize, channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport { + buf: vec![0; read_capacity].into_boxed_slice(), + pos: 0, + cap: 0, + chan: channel, } } } -impl Read for TFramedTransport { +impl<C> Read for TFramedReadTransport<C> +where + C: Read, +{ fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { - if self.rcap - self.rpos == 0 { - let message_size = self.inner.borrow_mut().read_i32::<BigEndian>()? as usize; - if message_size > self.rbuf.len() { - return Err(io::Error::new(ErrorKind::Other, - format!("bytes to be read ({}) exceeds buffer \ + if self.cap - self.pos == 0 { + let message_size = self.chan.read_i32::<BigEndian>()? 
as usize; + if message_size > self.buf.len() { + return Err( + io::Error::new( + ErrorKind::Other, + format!( + "bytes to be read ({}) exceeds buffer \ capacity ({})", - message_size, - self.rbuf.len()))); + message_size, + self.buf.len() + ), + ), + ); } - self.inner.borrow_mut().read_exact(&mut self.rbuf[..message_size])?; - self.rpos = 0; - self.rcap = message_size as usize; + self.chan.read_exact(&mut self.buf[..message_size])?; + self.pos = 0; + self.cap = message_size as usize; } - let nread = cmp::min(b.len(), self.rcap - self.rpos); - b[..nread].clone_from_slice(&self.rbuf[self.rpos..self.rpos + nread]); - self.rpos += nread; + let nread = cmp::min(b.len(), self.cap - self.pos); + b[..nread].clone_from_slice(&self.buf[self.pos..self.pos + nread]); + self.pos += nread; Ok(nread) } } -impl Write for TFramedTransport { +/// Factory for creating instances of `TFramedReadTransport`. +#[derive(Default)] +pub struct TFramedReadTransportFactory; + +impl TFramedReadTransportFactory { + pub fn new() -> TFramedReadTransportFactory { + TFramedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TFramedReadTransportFactory { + /// Create a `TFramedReadTransport`. + fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> { + Box::new(TFramedReadTransport::new(channel)) + } +} + +/// Transport that writes framed messages. +/// +/// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All +/// writes are made to this buffer and are sent to the wrapped channel only +/// when `TFramedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. +/// +/// # Examples +/// +/// Create and use a `TFramedWriteTransport`. 
+/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TFramedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TFramedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TFramedWriteTransport<C> +where + C: Write, +{ + buf: Box<[u8]>, + pos: usize, + channel: C, +} + +impl<C> TFramedWriteTransport<C> +where + C: Write, +{ + /// Create a `TFramedWriteTransport` with a default-sized internal + /// write buffer that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TFramedWriteTransport` with an internal write buffer + /// of size `write_capacity` that wraps the given + /// `TIoChannel`. + pub fn with_capacity(write_capacity: usize, channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport { + buf: vec![0; write_capacity].into_boxed_slice(), + pos: 0, + channel: channel, + } + } +} + +impl<C> Write for TFramedWriteTransport<C> +where + C: Write, +{ fn write(&mut self, b: &[u8]) -> io::Result<usize> { - if b.len() > (self.wbuf.len() - self.wpos) { - return Err(io::Error::new(ErrorKind::Other, - format!("bytes to be written ({}) exceeds buffer \ + if b.len() > (self.buf.len() - self.pos) { + return Err( + io::Error::new( + ErrorKind::Other, + format!( + "bytes to be written ({}) exceeds buffer \ capacity ({})", - b.len(), - self.wbuf.len() - self.wpos))); + b.len(), + self.buf.len() - self.pos + ), + ), + ); } let nwrite = b.len(); // always less than available write buffer capacity - self.wbuf[self.wpos..(self.wpos + nwrite)].clone_from_slice(b); - self.wpos += nwrite; + self.buf[self.pos..(self.pos + nwrite)].clone_from_slice(b); + self.pos += nwrite; Ok(nwrite) } fn flush(&mut self) -> io::Result<()> { - let
message_size = self.wpos; + let message_size = self.pos; if let 0 = message_size { return Ok(()); } else { - self.inner.borrow_mut().write_i32::<BigEndian>(message_size as i32)?; + self.channel + .write_i32::<BigEndian>(message_size as i32)?; } let mut byte_index = 0; - while byte_index < self.wpos { - let nwrite = self.inner.borrow_mut().write(&self.wbuf[byte_index..self.wpos])?; - byte_index = cmp::min(byte_index + nwrite, self.wpos); + while byte_index < self.pos { + let nwrite = self.channel.write(&self.buf[byte_index..self.pos])?; + byte_index = cmp::min(byte_index + nwrite, self.pos); } - self.wpos = 0; - self.inner.borrow_mut().flush() + self.pos = 0; + self.channel.flush() } } -/// Factory for creating instances of `TFramedTransport`. +/// Factory for creating instances of `TFramedWriteTransport`. #[derive(Default)] -pub struct TFramedTransportFactory; +pub struct TFramedWriteTransportFactory; -impl TFramedTransportFactory { - // Create a `TFramedTransportFactory`. - pub fn new() -> TFramedTransportFactory { - TFramedTransportFactory {} +impl TFramedWriteTransportFactory { + pub fn new() -> TFramedWriteTransportFactory { + TFramedWriteTransportFactory {} } } -impl TTransportFactory for TFramedTransportFactory { - fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> { - Box::new(TFramedTransport::new(inner)) as Box<TTransport> +impl TWriteTransportFactory for TFramedWriteTransportFactory { + /// Create a `TFramedWriteTransport`. + fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> { + Box::new(TFramedWriteTransport::new(channel)) } } @@ -183,5 +260,5 @@ mod tests { // use std::io::{Read, Write}; // // use super::*; - // use ::transport::mem::TBufferTransport; + // use ::transport::mem::TBufferChannel; }
