Hi everyone,

There have been some discussions about asynchronous C++ clients and
servers (THRIFT-1). Inspired by the Twisted server (THRIFT-148) by
Esteve Fernandez, as well as his C++ ideas (THRIFT-311), we (Erik
Bernhardsson and Mattias de Zalenski) decided to implement our own
version here at Spotify.

Based on TFramedTransport and Boost.ASIO, this version is event-driven
and runs fully asynchronously in a single thread; no additional threads
are created. The code is not yet thread-safe, in the sense that
multiple threads cannot simultaneously invoke calls on the same client,
though we're definitely interested in discussing how to implement this.
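To illustrate the execution model (this is not Boost.ASIO itself), here is a toy single-threaded loop: handlers are queued and then run one at a time by whichever thread calls run(), roughly the role io_service::run() plays here. The class and function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Toy single-threaded event loop (hypothetical; stands in for io_service).
// post() queues a handler; run() executes queued handlers one at a time on
// the calling thread until the queue is empty.
class ToyLoop {
 public:
  void post(std::function<void()> handler) {
    queue_.push_back(std::move(handler));
  }

  // Returns the number of handlers executed.
  std::size_t run() {
    std::size_t count = 0;
    while (!queue_.empty()) {
      std::function<void()> handler = std::move(queue_.front());
      queue_.pop_front();
      handler();  // handlers never overlap: only this thread runs them
      ++count;
    }
    return count;
  }

 private:
  std::deque<std::function<void()>> queue_;
};

// Records the order in which handlers fire when a "response" handler is
// queued while other work is already pending.
std::vector<std::string> run_scenario() {
  ToyLoop loop;
  std::vector<std::string> events;
  loop.post([&] {
    events.push_back("send request");
    // The reply is handled later, by another handler on the same thread.
    loop.post([&] { events.push_back("invoke callback"); });
  });
  loop.post([&] { events.push_back("other work"); });
  loop.run();
  return events;
}
```

Because only the thread inside run() touches the loop's state, nothing needs locking. That guarantee disappears as soon as a second thread calls post() (or issues client calls) concurrently, which is exactly the thread-safety gap described above.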

We have implemented an additional C++ stub generator (invoked with
--gen cpp:async) that adds two extra arguments to every client and
server method: a callback and an errback, both Boost closures
(boost::function objects). On the client side, these are passed when
making a call and are invoked when the call returns. On the server
side, the handler is free to invoke them at any point in the future in
order to respond to the client. Calls and responses can be sent in any
order. See the code example below.
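Out-of-order responses work because every request carries a sequence id. In the attached diff, the generated async client stores a bound recv_ handler per id in requests_ and dispatches each reply by its seqid in process(). The following is a simplified, stdlib-only sketch of that bookkeeping. The class and its methods are hypothetical, and the real client passes a TProtocol to the stored handler instead of a ready-made value.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Simplified stand-in for the generated async client's request table.
class PendingCalls {
 public:
  typedef std::function<void(int32_t)> Callback;
  typedef std::function<void(const std::string&)> Errback;

  // Registers a call and returns the sequence id that would be written
  // into the request's message header.
  int32_t send(Callback callback, Errback errback) {
    int32_t seqid = ++counter_;
    pending_[seqid] = Handlers{std::move(callback), std::move(errback)};
    return seqid;
  }

  // Dispatches a reply by sequence id; returns false for unknown ids,
  // like the generated process() method.
  bool on_reply(int32_t seqid, bool ok, int32_t value, const std::string& error) {
    std::map<int32_t, Handlers>::iterator it = pending_.find(seqid);
    if (it == pending_.end()) return false;
    Handlers handlers = std::move(it->second);
    pending_.erase(it);  // remove first, so the entry is gone even if a handler throws
    if (ok) {
      handlers.callback(value);
    } else {
      handlers.errback(error);
    }
    return true;
  }

  std::size_t outstanding() const { return pending_.size(); }

 private:
  struct Handlers {
    Callback callback;
    Errback errback;
  };
  std::map<int32_t, Handlers> pending_;
  int32_t counter_ = 0;
};
```

Replies may arrive in any order: a reply for the second call can complete before the first, and a duplicate or unknown seqid is simply rejected.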

The stub generator is independent of the ASIO code and can easily be
plugged into other reactor loops or frameworks. It shares most of its
code with the standard cpp generator (and generates the synchronous
stubs as well). We have attached a diff between the two.
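A reactor loop hosting these stubs mainly has to hand complete messages to the processor (or to the async client's process()). With TFramedTransport, each message is preceded by a 4-byte big-endian length, so a loop can accumulate incoming bytes and cut out frames as in the sketch below. FrameSplitter and make_frame are hypothetical helpers, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Splits a byte stream into frames: a 4-byte big-endian length followed by
// that many payload bytes (the TFramedTransport layout).
class FrameSplitter {
 public:
  // Appends newly received bytes and returns every frame that is now complete.
  std::vector<std::string> feed(const std::string& bytes) {
    buffer_ += bytes;
    std::vector<std::string> frames;
    while (buffer_.size() >= 4) {
      uint32_t size = 0;
      for (int i = 0; i < 4; ++i) {
        size = (size << 8) | static_cast<unsigned char>(buffer_[i]);
      }
      if (buffer_.size() - 4 < size) break;  // payload not fully received yet
      frames.push_back(buffer_.substr(4, size));
      buffer_.erase(0, 4 + size);
    }
    return frames;
  }

 private:
  std::string buffer_;  // bytes received but not yet consumed
};

// Builds a frame: big-endian length prefix followed by the payload.
std::string make_frame(const std::string& payload) {
  uint32_t size = static_cast<uint32_t>(payload.size());
  std::string out;
  for (int shift = 24; shift >= 0; shift -= 8) {
    out.push_back(static_cast<char>((size >> shift) & 0xFF));
  }
  return out + payload;
}
```

A partial read yields no frames. Once the rest of the data arrives, all complete frames are returned in order, so the loop never hands a half-received message to the protocol layer.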

This is still very much an early version, so comments are very
welcome. There are a few design decisions we still aren't 100% sure
of, and we haven't yet worked out how to handle low-level failures
such as disconnects. We are also working on a stress test and some
unit tests.

We compiled everything by adding the stub generator to
compiler/cpp/src/ and the client/server code to lib/cpp/src/async/.
For this to work, the corresponding Makefiles have to be modified and
everything has to be linked with -lboost_system.

The code example for the server follows below. We have modified the
add method so that it waits num1+num2 seconds (using an asynchronous
timer, so the event loop is not blocked) before returning. This makes
it easy to generate responses in a different order than the requests.
We only include the most relevant parts below, but have attached the
full source code. The client code follows the server code.

class CalculatorAsyncHandler : public CalculatorAsyncIf {
 public:
  CalculatorAsyncHandler() {}

  virtual void ping(boost::function<void (void)> callback,
                    boost::function<void (Calculator_ping_result)> errback) {
    printf("ping()\n");
    callback();
  }

  virtual void add(const int32_t num1, const int32_t num2,
                   boost::function<void (int32_t)> callback,
                   boost::function<void (Calculator_add_result)> errback) {
    printf("add(%d,%d)\n", num1, num2);
    boost::shared_ptr<boost::asio::deadline_timer> timer(
        new boost::asio::deadline_timer(io_service,
                                        boost::posix_time::seconds(num1 + num2)));
    timer->async_wait(boost::bind(&CalculatorAsyncHandler::wait_done,
                                  this, num1 + num2, callback, timer));
  }

  virtual void wait_done(const int32_t sum,
                         boost::function<void (int32_t)> callback,
                         boost::shared_ptr<boost::asio::deadline_timer>) {
    callback(sum);
    // timer will fall out of scope now and will be deleted
  }

  virtual void calculate(const int32_t logid, const Work& w,
                         boost::function<void (int32_t)> callback,
                         boost::function<void (Calculator_calculate_result)> errback) {
    (...)
    case DIVIDE:
      if (w.num2 == 0) {
        InvalidOperation io;
        io.what = w.op;
        io.why = "Cannot divide by 0";
        errback(calculate_ouch(io));
        return;
      }
      val = w.num1 / w.num2;
      break;
    default:
      errback(calculate_failure(std::string("Invalid Operation")));
      return;
    }
    (...)
  }

  (... other methods, omitted for brevity)

};
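The errback calls above (calculate_ouch, calculate_failure) use static helpers that generate_async_interface emits into CalculatorAsyncIf: one per result field, each returning a result struct with that field set and flagged. The sketch below mirrors that shape with simplified stand-in types. The names ApplicationError, CalculatorAsyncIfSketch and describe are hypothetical. The real failure type is apache::thrift::TApplicationException, and the real flags live in a member named __isset.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-ins for the generated/tutorial types.
struct InvalidOperation {
  int32_t what = 0;
  std::string why;
};

struct ApplicationError {
  std::string message;
};

// Shape of a generated Calculator_calculate_result: one field per outcome
// plus "is set" flags (Thrift itself names this member __isset).
struct Calculator_calculate_result {
  int32_t success = 0;
  InvalidOperation ouch;
  ApplicationError failure;
  struct {
    bool success = false;
    bool ouch = false;
    bool failure = false;
  } isset;
};

// One static helper per result field: copy the value in, mark it as set.
struct CalculatorAsyncIfSketch {
  typedef Calculator_calculate_result calculate_result;

  static calculate_result calculate_success(int32_t success) {
    calculate_result result;
    result.success = success;
    result.isset.success = true;
    return result;
  }
  static calculate_result calculate_ouch(InvalidOperation ouch) {
    calculate_result result;
    result.ouch = ouch;
    result.isset.ouch = true;
    return result;
  }
  static calculate_result calculate_failure(ApplicationError failure) {
    calculate_result result;
    result.failure = failure;
    result.isset.failure = true;
    return result;
  }
};

// What a client-side errback might do with the result it receives.
std::string describe(const Calculator_calculate_result& r) {
  if (r.isset.ouch) return "InvalidOperation: " + r.ouch.why;
  if (r.isset.failure) return "failure: " + r.failure.message;
  return "no error field set";
}
```

Passing the whole result struct to the errback lets one errback signature cover every declared exception, at the cost of the receiver checking the flags itself.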

int main(int argc, char **argv) {
  // io_service is not declared here: it is defined in the full attached
  // source and is also used by CalculatorAsyncHandler::add.
  boost::shared_ptr<protocol::TProtocolFactory> protocolFactory(
      new protocol::TBinaryProtocolFactory());
  boost::shared_ptr<CalculatorAsyncHandler> handler(new CalculatorAsyncHandler());
  boost::shared_ptr<TProcessor> processor(new CalculatorAsyncProcessor(handler));

  boost::shared_ptr<apache::thrift::async::TAsioServer> server(
      new apache::thrift::async::TAsioServer(io_service,
                                             9090,
                                             protocolFactory,
                                             protocolFactory,
                                             processor));

  server->start();  // Nonblocking
  io_service.run(); // Blocking

  return 0;
}

Code for the client:

void pingback() {
  printf("ping()\n");
}

void pingerr(tutorial::Calculator_ping_result result) {
  printf("Exception caught\n");
}

void addback(int32_t a, int32_t b, int32_t sum) {
  printf("%d+%d=%d\n", a, b, sum);
}

void adderr(tutorial::Calculator_add_result result) {
  printf("Exception caught\n");
}

void connected(boost::shared_ptr<tutorial::CalculatorAsyncClient> client) {
  client->ping(pingback, pingerr);

  client->add(2, 3, boost::bind(&addback, 2, 3, _1), &adderr);  // will return after 5s
  client->add(1, 2, boost::bind(&addback, 1, 2, _1), &adderr);  // will return after 3s
  client->add(1, 1, boost::bind(&addback, 1, 1, _1), &adderr);  // will return after 2s
}

int main(int argc, char* argv[])
{
  try
  {
    boost::asio::io_service io_service;

    boost::shared_ptr<protocol::TProtocolFactory> protocolFactory(
        new protocol::TBinaryProtocolFactory());

    boost::shared_ptr<async::TAsioClient> client(
        new async::TAsioClient(io_service,
                               protocolFactory,
                               protocolFactory));

    // The type of the client (tutorial::CalculatorAsyncClient) is inferred
    // from the signature of connected.
    client->connect("localhost", 9090, connected);

    io_service.run();
  }
  catch (std::exception& e)
  {
    std::cout << "Exception: " << e.what() << "\n";
  }

  return 0;
}
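Given the server's num1+num2 second delay, the client above should print ping() first and then the three sums in deadline order rather than request order: 1+1=2 (after 2s), 1+2=3 (after 3s), 2+3=5 (after 5s). The following stdlib-only sketch simulates that ordering with a virtual clock instead of real timers or Thrift. Completion, Later and simulate_adds are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// A pending reply: fires at `deadline` (virtual seconds); `order` breaks ties
// in favor of the call that was issued first.
struct Completion {
  int32_t deadline;
  int32_t order;
  std::function<void()> callback;
};

// Comparator for std::priority_queue: makes the earliest deadline the top.
struct Later {
  bool operator()(const Completion& x, const Completion& y) const {
    if (x.deadline != y.deadline) return x.deadline > y.deadline;
    return x.order > y.order;
  }
};

// Issues add(a, b) for each pair; each reply is due after a + b seconds.
std::vector<std::string> simulate_adds(const std::vector<std::pair<int32_t, int32_t>>& calls) {
  std::priority_queue<Completion, std::vector<Completion>, Later> timers;
  std::vector<std::string> output;
  int32_t issued = 0;
  for (const auto& call : calls) {
    int32_t a = call.first;
    int32_t b = call.second;
    timers.push(Completion{a + b, issued++, [&output, a, b] {
      output.push_back(std::to_string(a) + "+" + std::to_string(b) + "=" +
                       std::to_string(a + b));
    }});
  }
  // Fire callbacks in deadline order, as the timers would expire.
  while (!timers.empty()) {
    Completion next = timers.top();
    timers.pop();
    next.callback();
  }
  return output;
}
```

Calling simulate_adds({{2, 3}, {1, 2}, {1, 1}}) yields "1+1=2", "1+2=3", "2+3=5", the order the comments in connected() predict.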

Regards,
Erik Bernhardsson,
Mattias de Zalenski
Spotify - http://www.spotify.com/
Index: thrift/trunk/compiler/cpp/src/generate/t_cpp_async_generator.cc
===================================================================
--- thrift/trunk/compiler/cpp/src/generate/t_cpp_generator.cc (revision 47862)
+++ thrift/trunk/compiler/cpp/src/generate/t_cpp_async_generator.cc (revision 50426)
@@ -53,4 +53,7 @@
     use_include_prefix_ = (iter != parsed_options.end());
 
+               iter = parsed_options.find("async");
+               gen_async_ = (iter != parsed_options.end());
+
     out_dir_base_ = "gen-cpp";
   }
@@ -103,4 +106,12 @@
   void generate_process_function  (t_service* tservice, t_function* tfunction);
   void generate_function_helpers  (t_service* tservice, t_function* tfunction);
+
+  /**
+   * Service-level async generation functions
+   */
+       void generate_async_interface(t_service* tservice);
+       void generate_async_client(t_service* tservice);
+  void generate_async_processor (t_service* tservice);
+  void generate_async_process_function  (t_service* tservice, t_function* tfunction);
 
   /**
@@ -174,4 +185,6 @@
   std::string type_to_enum(t_type* ttype);
   std::string local_reflection_name(const char*, t_type* ttype, bool external=false);
+  std::string async_function_signature(t_service* tservice, t_function* tfunction, std::string prefix="", bool name_params=true);
+  std::string async_recv_signature(t_service* tservice, t_function* tfunction, std::string prefix="", bool name_params=true);
 
   // These handles checking gen_dense_ and checking for duplicates.
@@ -206,4 +219,9 @@
 
   /**
+   * True iff we should generate asynchronous callback code
+   */
+  bool gen_async_;
+
+  /**
    * True iff we should use a path prefix in our #include statements for other
    * thrift-generated header files.
@@ -270,4 +288,10 @@
     "#include <transport/TTransport.h>" << endl <<
     endl;
+       if (gen_async_) {
+               f_types_ <<
+                       "#include <boost/bind.hpp>" << endl <<
+                       "#include <boost/function.hpp>" << endl <<
+                       endl;
+       }
 
   // Include other Thrift includes
@@ -752,4 +776,6 @@
     scope_up(out);
     for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+                       if (gen_async_ && (*m_iter)->get_name() == "failure")
+                               continue;
       // Most existing Thrift code does not use isset or optional/required,
       // so we treat "default" fields as required.
@@ -1296,4 +1322,10 @@
   generate_service_multiface(tservice);
   generate_service_skeleton(tservice);
+       if (gen_async_) {
+               generate_async_interface(tservice);
+               generate_async_client(tservice);
+               generate_async_processor(tservice);
+       }
+
 
   // Close the namespace
@@ -1361,4 +1393,69 @@
       indent() << "virtual " << function_signature(*f_iter) << " = 0;" << endl;
   }
+  indent_down();
+  f_header_ <<
+    "};" << endl << endl;
+}
+
+/**
+ * Generates a service asynchronous callback interface definition.
+ *
+ * @param tservice The service to generate a header definition for
+ */
+void t_cpp_generator::generate_async_interface(t_service* tservice) {
+  string extends = "";
+  if (tservice->get_extends() != NULL) {
+    extends = " : virtual public " + type_name(tservice->get_extends()) + "AsyncIf";
+  }
+  f_header_ <<
+    "class " << service_name_ << "AsyncIf" << extends << " {" << endl <<
+    " public:" << endl;
+  indent_up();
+  f_header_ <<
+    indent() << "virtual ~" << service_name_ << "AsyncIf() {}" << endl;
+
+  vector<t_function*> functions = tservice->get_functions();
+  vector<t_function*>::iterator f_iter;
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    f_header_ <<
+      indent() << "virtual " << async_function_signature(tservice, *f_iter) << " = 0;" << endl;
+  }
+
+       // Callback result typedefs and convenience functions
+       f_header_ << endl;
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+               if (!(*f_iter)->is_oneway()) {
+                       string resultname = (*f_iter)->get_name() + "_result";
+                       indent(f_header_) << 
+                               "typedef " << tservice->get_name() << "_" << resultname << " " << resultname << ";" << endl;
+                       vector<t_field*>::iterator ff_iter;
+                       vector<t_field*> results = (*f_iter)->get_xceptions()->get_members();
+                       t_field success((*f_iter)->get_returntype(), "success", 0);
+                       if (!(*f_iter)->get_returntype()->is_void()) {
+                               results.insert(results.begin(), &success);
+                       }
+                       t_struct exc(NULL, "apache::thrift::TApplicationException");
+                       t_field failure(&exc, "failure", -1337); // FIXME: failure field id
+                       results.insert(results.begin(), &failure);
+                       for (ff_iter = results.begin(); ff_iter != results.end(); ++ff_iter) {
+                               indent(f_header_) <<
+                                       "static " << resultname << " " << (*f_iter)->get_name() << "_" << (*ff_iter)->get_name() << "(" << type_name((*ff_iter)->get_type()) << " " << (*ff_iter)->get_name() << ");" << endl;
+                               // Implementation
+                               indent_down();
+                               f_service_ <<
+                                       indent() << service_name_ << "AsyncIf::" << resultname << " " << service_name_ << "AsyncIf::" << (*f_iter)->get_name() << "_" << (*ff_iter)->get_name() << "(" << type_name((*ff_iter)->get_type()) << " " << (*ff_iter)->get_name() << ")" << endl;
+                               scope_up(f_service_);
+                               f_service_ <<
+                                       indent() << resultname << " result;" << endl <<
+                                       indent() << "result." << (*ff_iter)->get_name() << " = " << (*ff_iter)->get_name() << ";" << endl <<
+                                       indent() << "result.__isset." << (*ff_iter)->get_name() << " = true;" << endl <<
+                                       indent() << "return result;" << endl;
+                               scope_down(f_service_);
+                               f_service_ << endl;
+                               indent_up();
+                       }
+               }
+       }
+
   indent_down();
   f_header_ <<
@@ -1814,4 +1911,294 @@
 
 /**
+ * Generates a service async client definition.
+ *
+ * @param tservice The service to generate an async client for.
+ */
+void t_cpp_generator::generate_async_client(t_service* tservice) {
+  string extends = "";
+  string extends_client = "";
+       string extends_processor = "";
+  if (tservice->get_extends() != NULL) {
+    extends = type_name(tservice->get_extends());
+    extends_client = ", public " + extends + "AsyncClient";
+  }
+       else {
+               extends_processor = ", public apache::thrift::TProcessor";
+       }
+
+  // Generate the header portion
+  f_header_ <<
+    "class " << service_name_ << "AsyncClient : " <<
+    "virtual public " << service_name_ << "AsyncIf" <<
+    extends_client << extends_processor << " {" << endl <<
+    " public:" << endl;
+
+  indent_up();
+  f_header_ <<
+    indent() << service_name_ << "AsyncClient(boost::shared_ptr<apache::thrift::transport::TTransport> ot, boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> opf) :" << endl;
+  if (extends.empty()) {
+    f_header_ <<
+      indent() << "  potransport_(ot)," << endl <<
+      indent() << "  poprotfact_(opf)," << endl <<
+      indent() << "  request_counter_(0) {" << endl <<
+      indent() << "  otransport_ = ot.get();" << endl <<
+      indent() << "  oprotfact_ = opf.get();" << endl <<
+      indent() << "}" << endl;
+  } else {
+    f_header_ <<
+      indent() << "  " << extends << "AsyncClient(ot, opf) {}" << endl;
+  }
+
+  // Generate getters for the transport and protocol factory.
+  f_header_ <<
+    indent() << "boost::shared_ptr<apache::thrift::transport::TTransport> getOutputTransport() {" << endl <<
+    indent() << "  return potransport_;" << endl <<
+    indent() << "}" << endl;
+  f_header_ <<
+    indent() << "boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> getOutputProtocolFactory() {" << endl <<
+    indent() << "  return poprotfact_;" << endl <<
+    indent() << "}" << endl;
+
+  vector<t_function*> functions = tservice->get_functions();
+  vector<t_function*>::const_iterator f_iter;
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    t_function send_function(g_type_void,
+                             string("send_") + (*f_iter)->get_name(),
+                             (*f_iter)->get_arglist());
+    indent(f_header_) << async_function_signature(tservice, *f_iter) << ";" << endl;
+    indent(f_header_) << function_signature(&send_function) << ";" << endl;
+    if (!(*f_iter)->is_oneway()) {
+      //t_struct noargs(program_);
+      //t_function recv_function((*f_iter)->get_returntype(),
+      //                         string("recv_") + (*f_iter)->get_name(),
+      //                         &noargs);
+      //indent(f_header_) << function_signature(&recv_function) << ";" << endl;
+                       //indent(f_header_) << "void " << "recv_" << (*f_iter)->get_name() << "(apache::thrift::protocol::TProtocol* iprot, std::string fname, apache::thrift::protocol::TMessageType mtype);" << endl;
+                       indent(f_header_) << async_recv_signature(tservice, *f_iter) << ";" << endl;
+    }
+  }
+  indent(f_header_) << "bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot);" << endl;
+
+  indent_down();
+
+  if (extends.empty()) {
+    f_header_ <<
+      " protected:" << endl;
+    indent_up();
+    f_header_ <<
+                       indent() << "boost::shared_ptr<apache::thrift::transport::TTransport> potransport_;" << endl <<
+      indent() << "boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> poprotfact_;" << endl <<
+                       indent() << "apache::thrift::transport::TTransport* otransport_;" << endl <<
+      indent() << "apache::thrift::protocol::TProtocolFactory* oprotfact_;" << endl <<
+                       indent() << "typedef std::map<int32_t, boost::function<void (boost::shared_ptr<apache::thrift::protocol::TProtocol> prot, std::string fname, apache::thrift::protocol::TMessageType mtype)> > request_map_t;" << endl <<
+                       indent() << "request_map_t requests_;" << endl <<
+                       indent() << "int32_t request_counter_;" << endl;
+    indent_down();
+  }
+
+  f_header_ <<
+    "};" << endl <<
+    endl;
+
+  string scope = service_name_ + "AsyncClient::";
+
+  // Generate async client method implementations
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    string funname = (*f_iter)->get_name();
+
+    // Open function
+    indent(f_service_) <<
+      async_function_signature(tservice, *f_iter, scope) << endl;
+    scope_up(f_service_);
+    indent(f_service_) <<
+      "send_" << funname << "(";
+
+    // Get the struct of function call params
+    t_struct* arg_struct = (*f_iter)->get_arglist();
+
+    // Declare the function arguments
+    const vector<t_field*>& fields = arg_struct->get_members();
+    vector<t_field*>::const_iterator fld_iter;
+    bool first = true;
+    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
+      if (first) {
+        first = false;
+      } else {
+        f_service_ << ", ";
+      }
+      f_service_ << (*fld_iter)->get_name();
+    }
+    f_service_ << ");" << endl;
+
+    if (!(*f_iter)->is_oneway()) {
+                       f_service_ << indent() << "requests_[request_counter_] = boost::bind(&" << scope << "recv_" << funname << ", this, _1 /* piprot */, _2 /* fname */, _3 /* mtype */, callback, errback);" << endl;
+                       /*
+      f_service_ << indent();
+      if (!(*f_iter)->get_returntype()->is_void()) {
+        if (is_complex_type((*f_iter)->get_returntype())) {
+          f_service_ << "recv_" << funname << "(_return);" << endl;
+        } else {
+          f_service_ << "return recv_" << funname << "();" << endl;
+        }
+      } else {
+        f_service_ <<
+          "recv_" << funname << "();" << endl;
+      }
+                       */
+    }
+    scope_down(f_service_);
+    f_service_ << endl;
+
+    // Function for sending
+    t_function send_function(g_type_void,
+                             string("send_") + (*f_iter)->get_name(),
+                             (*f_iter)->get_arglist());
+
+    // Open the send function
+    indent(f_service_) <<
+      function_signature(&send_function, scope) << endl;
+    scope_up(f_service_);
+
+    // Function arguments and results
+    string argsname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs";
+    string resultname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_result";
+
+    // Serialize the request
+    f_service_ <<
+                       indent() << "boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot = oprotfact_->getProtocol(potransport_);" << endl <<
+                       indent() << "apache::thrift::protocol::TProtocol* oprot = poprot.get();" << endl <<
+      indent() << "int32_t cseqid = ++request_counter_;" << endl <<
+      indent() << "oprot->writeMessageBegin(\"" << (*f_iter)->get_name() << "\", apache::thrift::protocol::T_CALL, cseqid);" << endl <<
+      endl <<
+      indent() << argsname << " args;" << endl;
+
+    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
+      f_service_ <<
+        indent() << "args." << (*fld_iter)->get_name() << " = &" << (*fld_iter)->get_name() << ";" << endl;
+    }
+
+    f_service_ <<
+      indent() << "args.write(oprot);" << endl <<
+      endl <<
+      indent() << "oprot->writeMessageEnd();" << endl <<
+      indent() << "oprot->getTransport()->flush();" << endl <<
+      indent() << "oprot->getTransport()->writeEnd();" << endl;
+
+    scope_down(f_service_);
+    f_service_ << endl;
+
+    // Generate recv function only if not an oneway function
+    if (!(*f_iter)->is_oneway()) {
+      t_struct noargs(program_);
+      t_function recv_function((*f_iter)->get_returntype(),
+                               string("recv_") + (*f_iter)->get_name(),
+                               &noargs);
+
+      // Open function
+                       indent(f_service_) << async_recv_signature(tservice, *f_iter, scope) << endl;
+      scope_up(f_service_);
+
+      f_service_ <<
+        endl <<
+        indent() << "if (mtype == apache::thrift::protocol::T_EXCEPTION) {" << endl <<
+        indent() << "  apache::thrift::TApplicationException x;" << endl <<
+        indent() << "  x.read(iprot.get());" << endl <<
+        indent() << "  iprot->readMessageEnd();" << endl <<
+        indent() << "  iprot->getTransport()->readEnd();" << endl <<
+        indent() << "  errback(" << (*f_iter)->get_name() << "_failure(x));" << endl <<
+        indent() << "  return;" << endl <<
+        indent() << "}" << endl <<
+        indent() << "if (mtype != apache::thrift::protocol::T_REPLY) {" << endl <<
+        indent() << "  iprot->skip(apache::thrift::protocol::T_STRUCT);" << endl <<
+        indent() << "  iprot->readMessageEnd();" << endl <<
+        indent() << "  iprot->getTransport()->readEnd();" << endl <<
+        indent() << "  errback(" << (*f_iter)->get_name() << "_failure(apache::thrift::TApplicationException(apache::thrift::TApplicationException::INVALID_MESSAGE_TYPE)));" << endl <<
+                               indent() << "  return;" << endl <<
+        indent() << "}" << endl <<
+        indent() << "if (fname.compare(\"" << (*f_iter)->get_name() << "\") != 0) {" << endl <<
+        indent() << "  iprot->skip(apache::thrift::protocol::T_STRUCT);" << endl <<
+        indent() << "  iprot->readMessageEnd();" << endl <<
+        indent() << "  iprot->getTransport()->readEnd();" << endl <<
+        indent() << "  errback(" << (*f_iter)->get_name() << "_failure(apache::thrift::TApplicationException(apache::thrift::TApplicationException::WRONG_METHOD_NAME)));" << endl <<
+                               indent() << "  return;" << endl <<
+        indent() << "}" << endl;
+
+      if (!(*f_iter)->get_returntype()->is_void() &&
+          !is_complex_type((*f_iter)->get_returntype())) {
+        t_field returnfield((*f_iter)->get_returntype(), "_return");
+        f_service_ <<
+          indent() << declare_field(&returnfield) << endl;
+      }
+
+      f_service_ <<
+        indent() << resultname << " result;" << endl;
+
+      f_service_ <<
+        indent() << "result.read(iprot.get());" << endl <<
+        indent() << "iprot->readMessageEnd();" << endl <<
+        indent() << "iprot->getTransport()->readEnd();" << endl <<
+        endl;
+
+      // Careful, only look for _result if not a void function
+      if (!(*f_iter)->get_returntype()->is_void()) {
+                               f_service_ <<
+                                       indent() << "if (result.__isset.success) {" << endl <<
+                                       indent() << "  callback(result.success);" << endl <<
+                                       indent() << "  return;" << endl <<
+                                       indent() << "}" << endl;
+      }
+
+                       // Handle failure
+                       // FIXME: deprecated! failure is sent as exception
+                       f_service_ <<
+                               indent() << "if (result.__isset.failure) {" << endl <<
+                               indent() << "  errback(result);" << endl <<
+                               indent() << "  return;" << endl <<
+                               indent() << "}" << endl;
+
+                       // Handle exception
+      t_struct* xs = (*f_iter)->get_xceptions();
+      const std::vector<t_field*>& xceptions = xs->get_members();
+      vector<t_field*>::const_iterator x_iter;
+      for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
+        f_service_ <<
+          indent() << "if (result.__isset." << (*x_iter)->get_name() << ") {" << endl <<
+          indent() << "  errback(result);" << endl <<
+          indent() << "  return;" << endl <<
+          indent() << "}" << endl;
+      }
+
+      // We only get here if we are a void function
+      if ((*f_iter)->get_returntype()->is_void()) {
+        indent(f_service_) << "callback();" << endl;
+      } else {
+        f_service_ <<
+                                       indent() << "  errback(" << (*f_iter)->get_name() << "_failure(apache::thrift::TApplicationException(apache::thrift::TApplicationException::MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\")));" << endl <<
+                                       indent() << "  return;" << endl;
+      }
+
+      // Close function
+      scope_down(f_service_);
+      f_service_ << endl;
+    }
+  }
+  indent(f_service_) << "bool " << scope << "process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot)" << endl;
+  scope_up(f_service_);
+       f_service_ <<
+               indent() << "std::string fname;" << endl <<
+               indent() << "apache::thrift::protocol::TMessageType mtype;" << endl <<
+               indent() << "int32_t rseqid;" << endl <<
+               endl <<
+               indent() << "piprot->readMessageBegin(fname, mtype, rseqid);" << endl <<
+               indent() << "request_map_t::iterator it = requests_.find(rseqid);" << endl <<
+               indent() << "if (it == requests_.end()) return false;" << endl <<
+               indent() << "it->second(piprot, fname, mtype); /* Invoke user method */" << endl <<
+               indent() << "requests_.erase(it);" << endl <<
+               indent() << "return true;" << endl;
+
+  scope_down(f_service_);
+}
+
+/**
  * Generates a service server definition.
  *
@@ -1973,4 +2360,174 @@
 
 /**
+ * Generates a service async processor definition.
+ *
+ * @param tservice The service to generate a server for.
+ */
+void t_cpp_generator::generate_async_processor(t_service* tservice) {
+  // Generate the dispatch methods
+  vector<t_function*> functions = tservice->get_functions();
+  vector<t_function*>::iterator f_iter;
+
+  string extends = "";
+  string extends_processor = "";
+  if (tservice->get_extends() != NULL) {
+    extends = type_name(tservice->get_extends());
+    extends_processor = ", public " + extends + "AsyncProcessor";
+  }
+
+  // Generate the header portion
+  f_header_ <<
+    "class " << service_name_ << "AsyncProcessor : " <<
+    "virtual public apache::thrift::TProcessor" <<
+    extends_processor << " {" << endl;
+
+  // Protected data members
+  f_header_ <<
+    " protected:" << endl;
+  indent_up();
+  f_header_ <<
+    indent() << "boost::shared_ptr<" << service_name_ << "AsyncIf> iface_;" << endl;
+  f_header_ <<
+    indent() << "virtual bool process_fn(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, std::string& fname, int32_t seqid);" << endl;
+  indent_down();
+
+  // Process function declarations
+  f_header_ <<
+    " private:" << endl;
+  indent_up();
+  f_header_ <<
+    indent() << "std::map<std::string, void (" << service_name_ << "AsyncProcessor::*)(int32_t, boost::shared_ptr<apache::thrift::protocol::TProtocol>, boost::shared_ptr<apache::thrift::protocol::TProtocol>)> processMap_;" << endl;
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    indent(f_header_) <<
+      "void process_" << (*f_iter)->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot);" << endl;
+               if (!(*f_iter)->is_oneway()) {
+                       if (!(*f_iter)->get_returntype()->is_void()) {
+                               indent(f_header_) <<
+                                       "void success_" << (*f_iter)->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, " << type_name((*f_iter)->get_returntype()) << " value);" << endl;
+                       }
+                       else {
+                               indent(f_header_) <<
+                                       "void success_" << (*f_iter)->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot);" << endl;
+                       }
+                       indent(f_header_) <<
+                               "void reply_" << (*f_iter)->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, " << service_name_ << "_" << (*f_iter)->get_name() << "_result result);" << endl;
+               }
+       }
+  indent_down();
+
+  indent_up();
+  string declare_map = "";
+  indent_up();
+
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    declare_map += indent();
+    declare_map += "processMap_[\"";
+    declare_map += (*f_iter)->get_name();
+    declare_map += "\"] = &";
+    declare_map += service_name_;
+    declare_map += "AsyncProcessor::process_";
+    declare_map += (*f_iter)->get_name();
+    declare_map += ";\n";
+  }
+  indent_down();
+
+  f_header_ <<
+    " public:" << endl <<
+               // Constructor
+    indent() << service_name_ << "AsyncProcessor(boost::shared_ptr<" << service_name_ << "AsyncIf> iface) :" << endl;
+  if (extends.empty()) {
+    f_header_ <<
+      indent() << "  iface_(iface) {" << endl;
+  } else {
+    f_header_ <<
+      indent() << "  " << extends << "AsyncProcessor(iface)," << endl <<
+      indent() << "  iface_(iface) {" << endl;
+  }
+  f_header_ <<
+    declare_map <<
+    indent() << "}" << endl <<
+    endl <<
+    indent() << "virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot);" << endl <<
+    indent() << "virtual ~" << service_name_ << "AsyncProcessor() {}" << endl;
+  indent_down();
+  f_header_ <<
+    "};" << endl << endl;
+
+  // Generate the server implementation
+  f_service_ <<
+    "bool " << service_name_ << "AsyncProcessor::process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot) {" << endl;
+  indent_up();
+
+  f_service_ <<
+    endl <<
+    indent() << "std::string fname;" << endl <<
+    indent() << "apache::thrift::protocol::TMessageType mtype;" << endl <<
+    indent() << "int32_t seqid;" << endl <<
+    endl <<
+    indent() << "piprot->readMessageBegin(fname, mtype, seqid);" << endl <<
+    endl <<
+    indent() << "if (mtype != apache::thrift::protocol::T_CALL && mtype != apache::thrift::protocol::T_ONEWAY) {" << endl <<
+    indent() << "  piprot->skip(apache::thrift::protocol::T_STRUCT);" << endl <<
+    indent() << "  piprot->readMessageEnd();" << endl <<
+    indent() << "  piprot->getTransport()->readEnd();" << endl <<
+    indent() << "  apache::thrift::TApplicationException x(apache::thrift::TApplicationException::INVALID_MESSAGE_TYPE);" << endl <<
+    indent() << "  poprot->writeMessageBegin(fname, apache::thrift::protocol::T_EXCEPTION, seqid);" << endl <<
+    indent() << "  x.write(poprot.get());" << endl <<
+    indent() << "  poprot->writeMessageEnd();" << endl <<
+    indent() << "  poprot->getTransport()->flush();" << endl <<
+    indent() << "  poprot->getTransport()->writeEnd();" << endl <<
+    indent() << "  return true;" << endl <<
+    indent() << "}" << endl <<
+    endl <<
+    indent() << "return process_fn(piprot, poprot, fname, seqid);" <<
+    endl;
+
+  indent_down();
+  f_service_ <<
+    indent() << "}" << endl <<
+    endl;
+
+  f_service_ <<
+    "bool " << service_name_ << "AsyncProcessor::process_fn(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, std::string& fname, int32_t seqid) {" << endl;
+  indent_up();
+
+  // HOT: member function pointer map
+  f_service_ <<
+    indent() << "std::map<std::string, void (" << service_name_ << "AsyncProcessor::*)(int32_t, boost::shared_ptr<apache::thrift::protocol::TProtocol>, boost::shared_ptr<apache::thrift::protocol::TProtocol>)>::iterator pfn;" << endl <<
+    indent() << "pfn = processMap_.find(fname);" << endl <<
+    indent() << "if (pfn == processMap_.end()) {" << endl;
+  if (extends.empty()) {
+    f_service_ <<
+      indent() << "  piprot->skip(apache::thrift::protocol::T_STRUCT);" << endl <<
+      indent() << "  piprot->readMessageEnd();" << endl <<
+      indent() << "  piprot->getTransport()->readEnd();" << endl <<
+      indent() << "  apache::thrift::TApplicationException x(apache::thrift::TApplicationException::UNKNOWN_METHOD, \"Invalid method name: '\"+fname+\"'\");" << endl <<
+      indent() << "  poprot->writeMessageBegin(fname, apache::thrift::protocol::T_EXCEPTION, seqid);" << endl <<
+      indent() << "  x.write(poprot.get());" << endl <<
+      indent() << "  poprot->writeMessageEnd();" << endl <<
+      indent() << "  poprot->getTransport()->flush();" << endl <<
+      indent() << "  poprot->getTransport()->writeEnd();" << endl <<
+      indent() << "  return true;" << endl;
+  } else {
+    f_service_ <<
+      indent() << "  return " << extends << "AsyncProcessor::process_fn(piprot, poprot, fname, seqid);" << endl;
+  }
+  f_service_ <<
+    indent() << "}" << endl <<
+    indent() << "(this->*(pfn->second))(seqid, piprot, poprot);" << endl <<
+    indent() << "return true;" << endl;
+
+  indent_down();
+  f_service_ <<
+    "}" << endl <<
+    endl;
+
+  // Generate the process subfunctions
+  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+    generate_async_process_function(tservice, *f_iter);
+  }
+}
+
+/**
  * Generates a struct and helpers for a function.
  *
@@ -1988,4 +2545,10 @@
     result.append(&success);
   }
+
+  t_struct exc(NULL, "apache::thrift::TApplicationException");
+  t_field failure(&exc, "failure", -1337); // FIXME: failure field id
+  if (gen_async_) {
+    result.append(&failure);
+  }
 
   t_struct* xs = tfunction->get_xceptions();
@@ -2136,4 +2699,121 @@
   scope_down(f_service_);
   f_service_ << endl;
+}
+
+/**
+ * Generates async process function definitions.
+ *
+ * @param tservice The service containing the function
+ * @param tfunction The function to write a dispatcher for
+ */
+void t_cpp_generator::generate_async_process_function(t_service* tservice,
+                                                      t_function* tfunction) {
+  // Open function
+  f_service_ <<
+    "void " << tservice->get_name() << "AsyncProcessor::" <<
+    "process_" << tfunction->get_name() <<
+    "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot)" << endl;
+  scope_up(f_service_);
+
+  string argsname = tservice->get_name() + "_" + tfunction->get_name() + "_args";
+  string resultname = tservice->get_name() + "_" + tfunction->get_name() + "_result";
+
+  f_service_ <<
+    indent() << argsname << " args;" << endl <<
+    indent() << "args.read(piprot.get());" << endl <<
+    indent() << "piprot->readMessageEnd();" << endl <<
+    indent() << "piprot->getTransport()->readEnd();" << endl <<
+    endl;
+
+  //t_struct* xs = tfunction->get_xceptions();
+  //const std::vector<t_field*>& xceptions = xs->get_members();
+  //vector<t_field*>::const_iterator x_iter;
+
+  // Callbacks
+  if (!tfunction->is_oneway()) {
+    if (!tfunction->get_returntype()->is_void()) {
+      f_service_ <<
+        indent() << "boost::function<void (" << type_name(tfunction->get_returntype()) << ")> callback = boost::bind(&" << tservice->get_name() << "AsyncProcessor::success_" << tfunction->get_name() << ", this, seqid, poprot, _1);" << endl;
+    } else {
+      f_service_ <<
+        indent() << "boost::function<void (void)> callback = boost::bind(&" << tservice->get_name() << "AsyncProcessor::success_" << tfunction->get_name() << ", this, seqid, poprot);" << endl;
+    }
+    f_service_ <<
+      indent() << "boost::function<void (" << resultname << ")> errback = boost::bind(&" << tservice->get_name() << "AsyncProcessor::reply_" << tfunction->get_name() << ", this, seqid, poprot, _1);" << endl;
+  }
+
+  // Generate the function call
+  t_struct* arg_struct = tfunction->get_arglist();
+  const std::vector<t_field*>& fields = arg_struct->get_members();
+  vector<t_field*>::const_iterator f_iter;
+
+  bool first = true;
+  indent(f_service_) <<
+    "iface_->" << tfunction->get_name() << "(";
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if (first) {
+      first = false;
+    } else {
+      f_service_ << ", ";
+    }
+    f_service_ << "args." << (*f_iter)->get_name();
+  }
+  if (!tfunction->is_oneway()) {
+    f_service_ << (fields.empty() ? "" : ", ") << "callback, errback";
+  }
+  f_service_ << ");" << endl;
+  scope_down(f_service_);
+  f_service_ << endl;
+
+  if (!tfunction->is_oneway()) {
+    // Success implementation
+    if (!tfunction->get_returntype()->is_void()) {
+      indent(f_service_) <<
+        "void " << tservice->get_name() << "AsyncProcessor::success_" << tfunction->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, " << type_name(tfunction->get_returntype()) << " value)" << endl;
+    } else {
+      indent(f_service_) <<
+        "void " << tservice->get_name() << "AsyncProcessor::success_" << tfunction->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot)" << endl;
+    }
+    scope_up(f_service_);
+    indent(f_service_) << resultname << " result;" << endl;
+    // Set isset on success field
+    if (!tfunction->get_returntype()->is_void()) {
+      indent(f_service_) << "result.success = value;" << endl;
+      indent(f_service_) << "result.__isset.success = true;" << endl;
+    }
+    indent(f_service_) << "reply_" << tfunction->get_name() << "(seqid, poprot, result);" << endl;
+    scope_down(f_service_);
+    f_service_ << endl;
+
+    // Reply implementation
+    indent(f_service_) <<
+      "void " << tservice->get_name() << "AsyncProcessor::reply_" << tfunction->get_name() << "(int32_t seqid, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot, " << resultname << " result)" << endl;
+    scope_up(f_service_);
+
+    // Send the failure field as a TApplicationException
+    f_service_ <<
+      indent() << "if (result.__isset.failure) {" << endl <<
+      indent() << "  poprot->writeMessageBegin(\"" << tfunction->get_name() << "\", apache::thrift::protocol::T_EXCEPTION, seqid);" << endl <<
+      indent() << "  result.failure.write(poprot.get());" << endl <<
+      indent() << "  poprot->writeMessageEnd();" << endl <<
+      indent() << "  poprot->getTransport()->flush();" << endl <<
+      indent() << "  poprot->getTransport()->writeEnd();" << endl <<
+      indent() << "  return;" << endl <<
+      indent() << "}" << endl;
+
+    // Serialize the result into a struct
+    f_service_ <<
+      endl <<
+      indent() << "poprot->writeMessageBegin(\"" << tfunction->get_name() << "\", apache::thrift::protocol::T_REPLY, seqid);" << endl <<
+      indent() << "result.write(poprot.get());" << endl <<
+      indent() << "poprot->writeMessageEnd();" << endl <<
+      indent() << "poprot->getTransport()->flush();" << endl <<
+      indent() << "poprot->getTransport()->writeEnd();" << endl;
+
+    // Close function
+    scope_down(f_service_);
+    f_service_ << endl;
+  }
 }
 
@@ -2869,4 +3549,39 @@
 
 /**
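+ * Renders an async function signature of the form
+ * 'void name(args, callback, errback)'; oneway functions get no callbacks.
+ *
+ * @param tservice The service containing the function
+ * @param tfunction Function definition
+ * @param prefix String prepended to the function name
+ * @param name_params Whether to render parameter names
+ * @return String of rendered function definition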
+ */
+string t_cpp_generator::async_function_signature(t_service* tservice,
+                                                 t_function* tfunction,
+                                                 string prefix,
+                                                 bool name_params) {
+  t_type* ttype = tfunction->get_returntype();
+  t_struct* arglist = tfunction->get_arglist();
+
+  bool empty = arglist->get_members().empty();
+
+  return
+    "void " + prefix + tfunction->get_name() +
+    "(" + argument_list(arglist, name_params) +
+    (!tfunction->is_oneway() ?
+     (empty ? "" : ", ") +
+     ("boost::function<void (" + type_name(ttype) + ")> callback, ") +
+     ("boost::function<void (" + tservice->get_name() + "_" + tfunction->get_name() + "_result)> errback)")
+     : ")");
+}
+
+string t_cpp_generator::async_recv_signature(t_service* tservice,
+                                             t_function* tfunction,
+                                             string prefix,
+                                             bool name_params) {
+  t_type* ttype = tfunction->get_returntype();
+
+  return "void " + prefix + "recv_" + tfunction->get_name() + "(boost::shared_ptr<apache::thrift::protocol::TProtocol> iprot, std::string fname, apache::thrift::protocol::TMessageType mtype, boost::function<void (" + type_name(ttype) + ")> callback, boost::function<void (" + tservice->get_name() + "_" + tfunction->get_name() + "_result)> errback)";
+}
+
+
+/**
  * Renders a field list
  *
@@ -3001,3 +3716,4 @@
 "    dense:           Generate type specifications for the dense protocol.\n"
 "    include_prefix:  Use full include paths in generated files.\n"
+"    async:           Generate code for asynchronous callbacks.\n"
 );

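The sketch below shows the handler signatures produced by the `async_function_signature()` hunk above. It restates the string concatenation from that hunk as a standalone function. It is a sketch under assumptions, not generator code:

* `render_async_signature` is a hypothetical name.
* The return type and argument list are passed in already rendered. In the real method they come from `type_name()` and `argument_list()`, and the argument strings used in the test are illustrative, not verified generator output.

```cpp
#include <string>

// Hypothetical restatement of async_function_signature(): `ret` and `args`
// are supplied pre-rendered instead of coming from type_name()/argument_list().
std::string render_async_signature(const std::string& service,
                                   const std::string& fname,
                                   const std::string& ret,
                                   const std::string& args,
                                   bool oneway,
                                   const std::string& prefix = "") {
  bool empty = args.empty();
  return "void " + prefix + fname + "(" + args +
         (!oneway ? std::string(empty ? "" : ", ") +
                        "boost::function<void (" + ret + ")> callback, " +
                        "boost::function<void (" + service + "_" + fname +
                        "_result)> errback)"
                  : std::string(")"));
}
```

For a non-oneway method, the callback is typed on the return value. The errback is typed on the whole `_result` struct, so a handler can report a declared exception or the new `failure` field through the same closure. A oneway method gets neither closure.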
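The next sketch traces the server-side control flow that the processor hunks generate. It is a hand-written approximation of the emitted code for a hypothetical `Calculator` service with `i32 add(1: i32 num1, 2: i32 num2)`, not actual generator output. The assumptions are:

* Protocols and transports are replaced by a vector of strings.
* `std::function`/`std::bind` stand in for `boost::function`/`boost::bind`.
* The struct layouts, the `DeferredCalculator` handler and `run_demo` are illustrative.

It shows how `process_fn` dispatches through the member-function-pointer map and how `process_add` binds the sequence id into the callback and errback. It also shows how `reply_add` picks an exception or a normal reply.

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for the generated args/result structs.
struct Calculator_add_args { int32_t num1; int32_t num2; };
struct Calculator_add_result {
  int32_t success = 0;
  std::string failure;  // plays the role of the TApplicationException field
  struct { bool success = false, failure = false; } isset;  // __isset in Thrift
};

class CalculatorAsyncIf {
 public:
  virtual ~CalculatorAsyncIf() {}
  virtual void add(int32_t num1, int32_t num2,
                   std::function<void(int32_t)> callback,
                   std::function<void(Calculator_add_result)> errback) = 0;
};

class CalculatorAsyncProcessor {
 public:
  // `out` stands in for the output protocol: one string per message written.
  CalculatorAsyncProcessor(CalculatorAsyncIf* iface, std::vector<std::string>& out)
      : iface_(iface), out_(out) {
    processMap_["add"] = &CalculatorAsyncProcessor::process_add;
  }

  // Like process_fn: look the method up in the member-function-pointer map.
  bool process_fn(const std::string& fname, int32_t seqid,
                  const Calculator_add_args& args) {
    auto pfn = processMap_.find(fname);
    if (pfn == processMap_.end()) {
      out_.push_back("EXCEPTION " + fname + " seqid=" + std::to_string(seqid));
      return true;
    }
    (this->*(pfn->second))(seqid, args);
    return true;
  }

 private:
  typedef void (CalculatorAsyncProcessor::*ProcessFn)(int32_t, const Calculator_add_args&);

  // Like process_add: bind seqid into both closures, then call the handler.
  void process_add(int32_t seqid, const Calculator_add_args& args) {
    using std::placeholders::_1;
    std::function<void(int32_t)> callback =
        std::bind(&CalculatorAsyncProcessor::success_add, this, seqid, _1);
    std::function<void(Calculator_add_result)> errback =
        std::bind(&CalculatorAsyncProcessor::reply_add, this, seqid, _1);
    iface_->add(args.num1, args.num2, callback, errback);
  }

  // Like success_add: wrap the value in a result struct and reply.
  void success_add(int32_t seqid, int32_t value) {
    Calculator_add_result result;
    result.success = value;
    result.isset.success = true;
    reply_add(seqid, result);
  }

  // Like reply_add: T_EXCEPTION if failure is set, otherwise T_REPLY.
  void reply_add(int32_t seqid, Calculator_add_result result) {
    if (result.isset.failure) {
      out_.push_back("EXCEPTION add seqid=" + std::to_string(seqid) + " " + result.failure);
      return;
    }
    out_.push_back("REPLY add seqid=" + std::to_string(seqid) +
                   " success=" + std::to_string(result.success));
  }

  CalculatorAsyncIf* iface_;
  std::vector<std::string>& out_;
  std::map<std::string, ProcessFn> processMap_;
};

// Hypothetical handler that queues its answers and sends them later.
class DeferredCalculator : public CalculatorAsyncIf {
 public:
  void add(int32_t num1, int32_t num2, std::function<void(int32_t)> callback,
           std::function<void(Calculator_add_result)> errback) override {
    if (num1 < 0 || num2 < 0) {  // arbitrary failure rule for the demo
      Calculator_add_result r;
      r.failure = "negative input";
      r.isset.failure = true;
      pending.push_back([errback, r] { errback(r); });
    } else {
      pending.push_back([callback, num1, num2] { callback(num1 + num2); });
    }
  }
  std::vector<std::function<void()>> pending;
};

// Two add calls and one unknown method; the adds are answered in reverse order.
std::vector<std::string> run_demo() {
  std::vector<std::string> out;
  DeferredCalculator handler;
  CalculatorAsyncProcessor processor(&handler, out);
  processor.process_fn("add", 1, Calculator_add_args{2, 3});
  processor.process_fn("add", 2, Calculator_add_args{-1, 5});
  processor.process_fn("mul", 3, Calculator_add_args{1, 1});
  handler.pending[1]();
  handler.pending[0]();
  return out;
}
```

`run_demo()` yields three messages in this order:

1. The unknown-method error.
2. The failure for seqid 2.
3. The reply for seqid 1.

Because each closure carries its own sequence id, the handler can answer calls in any order and each response still matches its request.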