This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git


The following commit(s) were added to refs/heads/main by this push:
     new bf85c1fd3 refactor: Polish multipart writer to allow oneshot 
optimization (#3031)
bf85c1fd3 is described below

commit bf85c1fd3dcbda1bbb565032692fb799ee08b175
Author: Xuanwo <[email protected]>
AuthorDate: Mon Sep 11 20:58:20 2023 +0800

    refactor: Polish multipart writer to allow oneshot optimization (#3031)
    
    * polish multipart writer
    
    Signed-off-by: Xuanwo <[email protected]>
    
    * Fix doctest
    
    Signed-off-by: Xuanwo <[email protected]>
    
    ---------
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/raw/oio/write/multipart_upload_write.rs | 98 ++++++++++++++++++++----
 core/src/services/cos/backend.rs                 |  6 +-
 core/src/services/cos/writer.rs                  | 20 ++---
 core/src/services/obs/backend.rs                 |  6 +-
 core/src/services/obs/writer.rs                  | 20 ++---
 core/src/services/oss/backend.rs                 |  6 +-
 core/src/services/oss/writer.rs                  | 20 ++---
 core/src/services/s3/backend.rs                  |  6 +-
 core/src/services/s3/writer.rs                   | 17 ++--
 9 files changed, 112 insertions(+), 87 deletions(-)

diff --git a/core/src/raw/oio/write/multipart_upload_write.rs 
b/core/src/raw/oio/write/multipart_upload_write.rs
index 7124554ff..67f18ecf3 100644
--- a/core/src/raw/oio/write/multipart_upload_write.rs
+++ b/core/src/raw/oio/write/multipart_upload_write.rs
@@ -21,6 +21,7 @@ use std::task::Context;
 use std::task::Poll;
 
 use async_trait::async_trait;
+use bytes::Bytes;
 use futures::future::BoxFuture;
 
 use crate::raw::*;
@@ -37,8 +38,26 @@ use crate::*;
 /// - Services impl `MultipartUploadWrite`
 /// - `MultipartUploadWriter` impl `Write`
 /// - Expose `MultipartUploadWriter` as `Accessor::Writer`
+///
+/// # Notes
+///
+/// `MultipartUploadWrite` has an oneshot optimization when `write` has been 
called only once:
+///
+/// ```no_build
+/// w.write(bs).await?;
+/// w.close().await?;
+/// ```
+///
+/// We will use `write_once` instead of starting a new multipart upload.
 #[async_trait]
 pub trait MultipartUploadWrite: Send + Sync + Unpin + 'static {
+    /// write_once is used to write the data to underlying storage at once.
+    ///
+    /// MultipartUploadWriter will call this API when:
+    ///
+    /// - All the data has been written to the buffer and we can perform the 
upload at once.
+    async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>;
+
     /// initiate_part will call start a multipart upload and return the upload 
id.
     ///
     /// MultipartUploadWriter will call this when:
@@ -90,6 +109,7 @@ pub struct MultipartUploadPart {
 pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
     state: State<W>,
 
+    cache: Option<Bytes>,
     upload_id: Option<Arc<String>>,
     parts: Vec<MultipartUploadPart>,
 }
@@ -97,7 +117,7 @@ pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
 enum State<W> {
     Idle(Option<W>),
     Init(BoxFuture<'static, (W, Result<String>)>),
-    Write(BoxFuture<'static, (W, usize, Result<MultipartUploadPart>)>),
+    Write(BoxFuture<'static, (W, Result<MultipartUploadPart>)>),
     Close(BoxFuture<'static, (W, Result<()>)>),
     Abort(BoxFuture<'static, (W, Result<()>)>),
 }
@@ -113,6 +133,7 @@ impl<W: MultipartUploadWrite> MultipartUploadWriter<W> {
         Self {
             state: State::Idle(Some(inner)),
 
+            cache: None,
             upload_id: None,
             parts: Vec::new(),
         }
@@ -128,15 +149,15 @@ where
         loop {
             match &mut self.state {
                 State::Idle(w) => {
-                    let w = w.take().expect("writer must be valid");
                     match self.upload_id.as_ref() {
                         Some(upload_id) => {
-                            let size = bs.remaining();
-                            let bs = bs.copy_to_bytes(size);
                             let upload_id = upload_id.clone();
                             let part_number = self.parts.len();
 
+                            let bs = self.cache.clone().expect("cache must be 
valid").clone();
+                            let w = w.take().expect("writer must be valid");
                             self.state = State::Write(Box::pin(async move {
+                                let size = bs.len();
                                 let part = w
                                     .write_part(
                                         &upload_id,
@@ -146,10 +167,18 @@ where
                                     )
                                     .await;
 
-                                (w, size, part)
+                                (w, part)
                             }));
                         }
                         None => {
+                            // Fill cache with the first write.
+                            if self.cache.is_none() {
+                                let size = bs.remaining();
+                                self.cache = Some(bs.copy_to_bytes(size));
+                                return Poll::Ready(Ok(size));
+                            }
+
+                            let w = w.take().expect("writer must be valid");
                             self.state = State::Init(Box::pin(async move {
                                 let upload_id = w.initiate_part().await;
                                 (w, upload_id)
@@ -163,10 +192,12 @@ where
                     self.upload_id = Some(Arc::new(upload_id?));
                 }
                 State::Write(fut) => {
-                    let (w, size, part) = ready!(fut.as_mut().poll(cx));
+                    let (w, part) = ready!(fut.as_mut().poll(cx));
                     self.state = State::Idle(Some(w));
-
                     self.parts.push(part?);
+                    // Replace the cache when last write succeeded
+                    let size = bs.remaining();
+                    self.cache = Some(bs.copy_to_bytes(size));
                     return Poll::Ready(Ok(size));
                 }
                 State::Close(_) => {
@@ -191,25 +222,57 @@ where
                     match self.upload_id.clone() {
                         Some(upload_id) => {
                             let parts = self.parts.clone();
-                            self.state = State::Close(Box::pin(async move {
-                                let res = w.complete_part(&upload_id, 
&parts).await;
-                                (w, res)
-                            }));
+                            match self.cache.clone() {
+                                Some(bs) => {
+                                    let upload_id = upload_id.clone();
+                                    self.state = State::Write(Box::pin(async 
move {
+                                        let size = bs.len();
+                                        let part = w
+                                            .write_part(
+                                                &upload_id,
+                                                parts.len(),
+                                                size as u64,
+                                                AsyncBody::Bytes(bs),
+                                            )
+                                            .await;
+                                        (w, part)
+                                    }));
+                                }
+                                None => {
+                                    self.state = State::Close(Box::pin(async 
move {
+                                        let res = w.complete_part(&upload_id, 
&parts).await;
+                                        (w, res)
+                                    }));
+                                }
+                            }
                         }
-                        None => return Poll::Ready(Ok(())),
+                        None => match self.cache.clone() {
+                            Some(bs) => {
+                                self.state = State::Close(Box::pin(async move {
+                                    let size = bs.len();
+                                    let res = w.write_once(size as u64, 
AsyncBody::Bytes(bs)).await;
+                                    (w, res)
+                                }));
+                            }
+                            None => return Poll::Ready(Ok(())),
+                        },
                     }
                 }
                 State::Close(fut) => {
                     let (w, res) = futures::ready!(fut.as_mut().poll(cx));
                     self.state = State::Idle(Some(w));
+                    self.cache = None;
                     return Poll::Ready(res);
                 }
                 State::Init(_) => unreachable!(
                     "MultipartUploadWriter must not go into State::Init during 
poll_close"
                 ),
-                State::Write(_) => unreachable!(
-                    "MultipartUploadWriter must not go into State::Write 
during poll_close"
-                ),
+                State::Write(fut) => {
+                    let (w, part) = ready!(fut.as_mut().poll(cx));
+                    self.state = State::Idle(Some(w));
+                    self.parts.push(part?);
+                    self.cache = None;
+                }
                 State::Abort(_) => unreachable!(
                     "MultipartUploadWriter must not go into State::Abort 
during poll_close"
                 ),
@@ -229,7 +292,10 @@ where
                                 (w, res)
                             }));
                         }
-                        None => return Poll::Ready(Ok(())),
+                        None => {
+                            self.cache = None;
+                            return Poll::Ready(Ok(()));
+                        }
                     }
                 }
                 State::Abort(fut) => {
diff --git a/core/src/services/cos/backend.rs b/core/src/services/cos/backend.rs
index 46e3b3064..eced1ccc9 100644
--- a/core/src/services/cos/backend.rs
+++ b/core/src/services/cos/backend.rs
@@ -337,11 +337,9 @@ impl Accessor for CosBackend {
         let writer = CosWriter::new(self.core.clone(), path, args.clone());
 
         let w = if args.append() {
-            CosWriters::Three(oio::AppendObjectWriter::new(writer))
-        } else if args.content_length().is_some() {
-            CosWriters::One(oio::OneShotWriter::new(writer))
+            CosWriters::Two(oio::AppendObjectWriter::new(writer))
         } else {
-            CosWriters::Two(oio::MultipartUploadWriter::new(writer))
+            CosWriters::One(oio::MultipartUploadWriter::new(writer))
         };
 
         let w = if let Some(buffer_size) = args.buffer_size() {
diff --git a/core/src/services/cos/writer.rs b/core/src/services/cos/writer.rs
index fba7cea09..11681b99a 100644
--- a/core/src/services/cos/writer.rs
+++ b/core/src/services/cos/writer.rs
@@ -18,7 +18,6 @@
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use bytes::Bytes;
 use http::StatusCode;
 
 use super::core::*;
@@ -26,11 +25,8 @@ use super::error::parse_error;
 use crate::raw::*;
 use crate::*;
 
-pub type CosWriters = oio::ThreeWaysWriter<
-    oio::OneShotWriter<CosWriter>,
-    oio::MultipartUploadWriter<CosWriter>,
-    oio::AppendObjectWriter<CosWriter>,
->;
+pub type CosWriters =
+    oio::TwoWaysWriter<oio::MultipartUploadWriter<CosWriter>, 
oio::AppendObjectWriter<CosWriter>>;
 
 pub struct CosWriter {
     core: Arc<CosCore>,
@@ -50,16 +46,15 @@ impl CosWriter {
 }
 
 #[async_trait]
-impl oio::OneShotWrite for CosWriter {
-    async fn write_once(&self, buf: Bytes) -> Result<()> {
-        let size = buf.len();
+impl oio::MultipartUploadWrite for CosWriter {
+    async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
         let mut req = self.core.cos_put_object_request(
             &self.path,
-            Some(size as u64),
+            Some(size),
             self.op.content_type(),
             self.op.content_disposition(),
             self.op.cache_control(),
-            AsyncBody::Bytes(buf),
+            body,
         )?;
 
         self.core.sign(&mut req).await?;
@@ -76,10 +71,7 @@ impl oio::OneShotWrite for CosWriter {
             _ => Err(parse_error(resp).await?),
         }
     }
-}
 
-#[async_trait]
-impl oio::MultipartUploadWrite for CosWriter {
     async fn initiate_part(&self) -> Result<String> {
         let resp = self
             .core
diff --git a/core/src/services/obs/backend.rs b/core/src/services/obs/backend.rs
index cc14eedd0..9cca757a8 100644
--- a/core/src/services/obs/backend.rs
+++ b/core/src/services/obs/backend.rs
@@ -375,11 +375,9 @@ impl Accessor for ObsBackend {
         let writer = ObsWriter::new(self.core.clone(), path, args.clone());
 
         let w = if args.append() {
-            ObsWriters::Three(oio::AppendObjectWriter::new(writer))
-        } else if args.content_length().is_some() {
-            ObsWriters::One(oio::OneShotWriter::new(writer))
+            ObsWriters::Two(oio::AppendObjectWriter::new(writer))
         } else {
-            ObsWriters::Two(oio::MultipartUploadWriter::new(writer))
+            ObsWriters::One(oio::MultipartUploadWriter::new(writer))
         };
 
         let w = if let Some(buffer_size) = args.buffer_size() {
diff --git a/core/src/services/obs/writer.rs b/core/src/services/obs/writer.rs
index 62882b1ca..94b8380a7 100644
--- a/core/src/services/obs/writer.rs
+++ b/core/src/services/obs/writer.rs
@@ -18,7 +18,6 @@
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use bytes::Bytes;
 use http::StatusCode;
 
 use super::core::*;
@@ -27,11 +26,8 @@ use crate::raw::oio::MultipartUploadPart;
 use crate::raw::*;
 use crate::*;
 
-pub type ObsWriters = oio::ThreeWaysWriter<
-    oio::OneShotWriter<ObsWriter>,
-    oio::MultipartUploadWriter<ObsWriter>,
-    oio::AppendObjectWriter<ObsWriter>,
->;
+pub type ObsWriters =
+    oio::TwoWaysWriter<oio::MultipartUploadWriter<ObsWriter>, 
oio::AppendObjectWriter<ObsWriter>>;
 
 pub struct ObsWriter {
     core: Arc<ObsCore>,
@@ -51,15 +47,14 @@ impl ObsWriter {
 }
 
 #[async_trait]
-impl oio::OneShotWrite for ObsWriter {
-    async fn write_once(&self, bs: Bytes) -> Result<()> {
-        let size = bs.len();
+impl oio::MultipartUploadWrite for ObsWriter {
+    async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
         let mut req = self.core.obs_put_object_request(
             &self.path,
-            Some(size as u64),
+            Some(size),
             self.op.content_type(),
             self.op.cache_control(),
-            AsyncBody::Bytes(bs),
+            body,
         )?;
 
         self.core.sign(&mut req).await?;
@@ -76,10 +71,7 @@ impl oio::OneShotWrite for ObsWriter {
             _ => Err(parse_error(resp).await?),
         }
     }
-}
 
-#[async_trait]
-impl oio::MultipartUploadWrite for ObsWriter {
     async fn initiate_part(&self) -> Result<String> {
         let resp = self
             .core
diff --git a/core/src/services/oss/backend.rs b/core/src/services/oss/backend.rs
index 212c42ce0..6d027ba5b 100644
--- a/core/src/services/oss/backend.rs
+++ b/core/src/services/oss/backend.rs
@@ -473,11 +473,9 @@ impl Accessor for OssBackend {
         let writer = OssWriter::new(self.core.clone(), path, args.clone());
 
         let w = if args.append() {
-            OssWriters::Three(oio::AppendObjectWriter::new(writer))
-        } else if args.content_length().is_some() {
-            OssWriters::One(oio::OneShotWriter::new(writer))
+            OssWriters::Two(oio::AppendObjectWriter::new(writer))
         } else {
-            OssWriters::Two(oio::MultipartUploadWriter::new(writer))
+            OssWriters::One(oio::MultipartUploadWriter::new(writer))
         };
 
         let w = if let Some(buffer_size) = args.buffer_size() {
diff --git a/core/src/services/oss/writer.rs b/core/src/services/oss/writer.rs
index 56d262f17..296055e29 100644
--- a/core/src/services/oss/writer.rs
+++ b/core/src/services/oss/writer.rs
@@ -18,7 +18,6 @@
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use bytes::Bytes;
 use http::StatusCode;
 
 use super::core::*;
@@ -26,11 +25,8 @@ use super::error::parse_error;
 use crate::raw::*;
 use crate::*;
 
-pub type OssWriters = oio::ThreeWaysWriter<
-    oio::OneShotWriter<OssWriter>,
-    oio::MultipartUploadWriter<OssWriter>,
-    oio::AppendObjectWriter<OssWriter>,
->;
+pub type OssWriters =
+    oio::TwoWaysWriter<oio::MultipartUploadWriter<OssWriter>, 
oio::AppendObjectWriter<OssWriter>>;
 
 pub struct OssWriter {
     core: Arc<OssCore>,
@@ -50,16 +46,15 @@ impl OssWriter {
 }
 
 #[async_trait]
-impl oio::OneShotWrite for OssWriter {
-    async fn write_once(&self, bs: Bytes) -> Result<()> {
-        let size = bs.len();
+impl oio::MultipartUploadWrite for OssWriter {
+    async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
         let mut req = self.core.oss_put_object_request(
             &self.path,
-            Some(size as u64),
+            Some(size),
             self.op.content_type(),
             self.op.content_disposition(),
             self.op.cache_control(),
-            AsyncBody::Bytes(bs),
+            body,
             false,
         )?;
 
@@ -77,10 +72,7 @@ impl oio::OneShotWrite for OssWriter {
             _ => Err(parse_error(resp).await?),
         }
     }
-}
 
-#[async_trait]
-impl oio::MultipartUploadWrite for OssWriter {
     async fn initiate_part(&self) -> Result<String> {
         let resp = self
             .core
diff --git a/core/src/services/s3/backend.rs b/core/src/services/s3/backend.rs
index d2d8bb039..aa2b95eb8 100644
--- a/core/src/services/s3/backend.rs
+++ b/core/src/services/s3/backend.rs
@@ -977,11 +977,7 @@ impl Accessor for S3Backend {
     async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, 
Self::Writer)> {
         let writer = S3Writer::new(self.core.clone(), path, args.clone());
 
-        let w = if args.content_length().is_some() {
-            S3Writers::One(oio::OneShotWriter::new(writer))
-        } else {
-            S3Writers::Two(oio::MultipartUploadWriter::new(writer))
-        };
+        let w = oio::MultipartUploadWriter::new(writer);
 
         let w = if let Some(buffer_size) = args.buffer_size() {
             let buffer_size = max(MINIMUM_MULTIPART_SIZE, buffer_size);
diff --git a/core/src/services/s3/writer.rs b/core/src/services/s3/writer.rs
index a3a1bd5bd..76c874ed0 100644
--- a/core/src/services/s3/writer.rs
+++ b/core/src/services/s3/writer.rs
@@ -18,7 +18,6 @@
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use bytes::Bytes;
 use http::StatusCode;
 
 use super::core::*;
@@ -26,8 +25,7 @@ use super::error::parse_error;
 use crate::raw::*;
 use crate::*;
 
-pub type S3Writers =
-    oio::TwoWaysWriter<oio::OneShotWriter<S3Writer>, 
oio::MultipartUploadWriter<S3Writer>>;
+pub type S3Writers = oio::MultipartUploadWriter<S3Writer>;
 
 pub struct S3Writer {
     core: Arc<S3Core>,
@@ -47,17 +45,15 @@ impl S3Writer {
 }
 
 #[async_trait]
-impl oio::OneShotWrite for S3Writer {
-    async fn write_once(&self, bs: Bytes) -> Result<()> {
-        let size = bs.len();
-
+impl oio::MultipartUploadWrite for S3Writer {
+    async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
         let mut req = self.core.s3_put_object_request(
             &self.path,
-            Some(size as u64),
+            Some(size),
             self.op.content_type(),
             self.op.content_disposition(),
             self.op.cache_control(),
-            AsyncBody::Bytes(bs),
+            body,
         )?;
 
         self.core.sign(&mut req).await?;
@@ -74,10 +70,7 @@ impl oio::OneShotWrite for S3Writer {
             _ => Err(parse_error(resp).await?),
         }
     }
-}
 
-#[async_trait]
-impl oio::MultipartUploadWrite for S3Writer {
     async fn initiate_part(&self) -> Result<String> {
         let resp = self
             .core

Reply via email to