This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opendal.git


The following commit(s) were added to refs/heads/main by this push:
     new e948ec826d feat(core): Add reader size check in complete reader (#4690)
e948ec826d is described below

commit e948ec826ded0c7cc98af939e574e167e9ae7086
Author: Xuanwo <[email protected]>
AuthorDate: Wed Jun 5 19:37:56 2024 +0800

    feat(core): Add reader size check in complete reader (#4690)
    
    Signed-off-by: Xuanwo <[email protected]>
---
 core/src/layers/complete.rs | 63 +++++++++++++++++++++++++++++++++++++++++----
 core/src/types/error.rs     |  4 +--
 2 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/core/src/layers/complete.rs b/core/src/layers/complete.rs
index 8e265b2171..97483f10b1 100644
--- a/core/src/layers/complete.rs
+++ b/core/src/layers/complete.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::cmp::Ordering;
 use std::fmt::Debug;
 use std::fmt::Formatter;
 use std::sync::Arc;
@@ -398,10 +399,12 @@ impl<A: Access> LayeredAccess for CompleteAccessor<A> {
         if !capability.read {
             return Err(self.new_unsupported_error(Operation::Read));
         }
+
+        let size = args.range().size();
         self.inner
             .read(path, args)
             .await
-            .map(|(rp, r)| (rp, CompleteReader(r)))
+            .map(|(rp, r)| (rp, CompleteReader::new(r, size)))
     }
 
     async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
@@ -514,9 +517,11 @@ impl<A: Access> LayeredAccess for CompleteAccessor<A> {
         if !capability.read || !capability.blocking {
             return Err(self.new_unsupported_error(Operation::Read));
         }
+
+        let size = args.range().size();
         self.inner
             .blocking_read(path, args)
-            .map(|(rp, r)| (rp, CompleteReader(r)))
+            .map(|(rp, r)| (rp, CompleteReader::new(r, size)))
     }
 
     fn blocking_write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::BlockingWriter)> {
@@ -584,18 +589,66 @@ impl<A: Access> LayeredAccess for CompleteAccessor<A> {
 pub type CompleteLister<A, P> =
     FourWays<P, FlatLister<Arc<A>, P>, PrefixLister<P>, PrefixLister<FlatLister<Arc<A>, P>>>;
 
-pub struct CompleteReader<R>(R);
+pub struct CompleteReader<R> {
+    inner: R,
+    size: Option<u64>,
+    read: u64,
+}
+
+impl<R> CompleteReader<R> {
+    pub fn new(inner: R, size: Option<u64>) -> Self {
+        Self {
+            inner,
+            size,
+            read: 0,
+        }
+    }
+
+    pub fn check(&self) -> Result<()> {
+        let Some(size) = self.size else {
+            return Ok(());
+        };
+
+        match self.read.cmp(&size) {
+            Ordering::Equal => Ok(()),
+            Ordering::Less => Err(
+                Error::new(ErrorKind::Unexpected, "reader got too little data")
+                    .with_context("expect", size)
+                    .with_context("actual", self.read),
+            ),
+            Ordering::Greater => Err(
+                Error::new(ErrorKind::Unexpected, "reader got too much data")
+                    .with_context("expect", size)
+                    .with_context("actual", self.read),
+            ),
+        }
+    }
+}
 
 impl<R: oio::Read> oio::Read for CompleteReader<R> {
     async fn read(&mut self) -> Result<Buffer> {
-        let buf = self.0.read().await?;
+        let buf = self.inner.read().await?;
+
+        if buf.is_empty() {
+            self.check()?;
+        } else {
+            self.read += buf.len() as u64;
+        }
+
         Ok(buf)
     }
 }
 
 impl<R: oio::BlockingRead> oio::BlockingRead for CompleteReader<R> {
     fn read(&mut self) -> Result<Buffer> {
-        let buf = self.0.read()?;
+        let buf = self.inner.read()?;
+
+        if buf.is_empty() {
+            self.check()?;
+        } else {
+            self.read += buf.len() as u64;
+        }
+
         Ok(buf)
     }
 }
diff --git a/core/src/types/error.rs b/core/src/types/error.rs
index f2c258dc61..6576f3dbfc 100644
--- a/core/src/types/error.rs
+++ b/core/src/types/error.rs
@@ -334,8 +334,8 @@ impl Error {
     }
 
     /// Add more context in error.
-    pub fn with_context(mut self, key: &'static str, value: impl Into<String>) -> Self {
-        self.context.push((key, value.into()));
+    pub fn with_context(mut self, key: &'static str, value: impl ToString) -> Self {
+        self.context.push((key, value.to_string()));
         self
     }
 

Reply via email to