CurtHagenlocher commented on code in PR #2678:
URL: https://github.com/apache/arrow-adbc/pull/2678#discussion_r2047511703
##########
csharp/src/Drivers/Databricks/DatabricksParameters.cs:
##########
@@ -42,6 +42,30 @@ public class DatabricksParameters : SparkParameters
/// Default value is 5 minutes if not specified.
/// </summary>
public const string CloudFetchTimeoutMinutes =
"adbc.databricks.cloudfetch.timeout_minutes";
+
+ /// <summary>
+ /// Maximum number of parallel downloads for CloudFetch operations.
+ /// Default value is 3 if not specified.
+ /// </summary>
+ public const string CloudFetchParallelDownloads =
"adbc.spark.cloudfetch.parallel_downloads";
Review Comment:
The existing CloudFetch properties are named `adbc.databricks.*` instead of
`adbc.spark.*`.
##########
csharp/src/Drivers/Databricks/CloudFetch/cloudfetch-pipeline-design.md:
##########
@@ -0,0 +1,60 @@
+current cloudfetch implementation download the cloud result file inline with
the reader, which generate performance problem, it slows down the reader when
need download the next result file
Review Comment:
This file needs an Apache copyright notice. See some of the other Markdown
files in this project for how to add one.
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs:
##########
@@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
public CloudFetchReader(DatabricksStatement statement, Schema schema,
bool isLz4Compressed)
{
- this.statement = statement;
- this.schema = schema;
+ this.schema = schema ?? throw new
ArgumentNullException(nameof(schema));
this.isLz4Compressed = isLz4Compressed;
- // Get configuration values from connection properties or use
defaults
+ // Check if prefetch is enabled
var connectionProps = statement.Connection.Properties;
-
- // Parse max retries
- int parsedMaxRetries = DefaultMaxRetries;
- if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out
string? maxRetriesStr) &&
- int.TryParse(maxRetriesStr, out parsedMaxRetries) &&
- parsedMaxRetries > 0)
+ isPrefetchEnabled = true; // Default to true
+ if
(connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled,
out string? prefetchEnabledStr) &&
+ bool.TryParse(prefetchEnabledStr, out bool
parsedPrefetchEnabled))
Review Comment:
The same pattern exists in other places, e.g., CloudFetchDownloadManager.cs.
In general, it's better to return an error when parameter validation fails
instead of silently reverting to the default value.
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs:
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Adbc.Drivers.Databricks;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Manages the CloudFetch download pipeline.
+ /// </summary>
+ internal sealed class CloudFetchDownloadManager :
ICloudFetchDownloadManager
+ {
+ // Default values
+ private const int DefaultParallelDownloads = 3;
+ private const int DefaultPrefetchCount = 2;
+ private const int DefaultMemoryBufferSizeMB = 200;
+ private const bool DefaultPrefetchEnabled = true;
+
+ private readonly DatabricksStatement _statement;
+ private readonly Schema _schema;
+ private readonly bool _isLz4Compressed;
+ private readonly ICloudFetchMemoryBufferManager _memoryManager;
+ private readonly BlockingCollection<IDownloadResult> _downloadQueue;
+ private readonly BlockingCollection<IDownloadResult> _resultQueue;
+ private readonly ICloudFetchResultFetcher _resultFetcher;
+ private readonly ICloudFetchDownloader _downloader;
+ private readonly HttpClient _httpClient;
+ private bool _isDisposed;
+ private bool _isStarted;
+ private CancellationTokenSource? _cancellationTokenSource;
+
+ /// <summary>
+ /// Initializes a new instance of the <see
cref="CloudFetchDownloadManager"/> class.
+ /// </summary>
+ /// <param name="statement">The HiveServer2 statement.</param>
+ /// <param name="schema">The Arrow schema.</param>
+ /// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
+ public CloudFetchDownloadManager(DatabricksStatement statement, Schema
schema, bool isLz4Compressed)
+ {
+ _statement = statement ?? throw new
ArgumentNullException(nameof(statement));
+ _schema = schema ?? throw new
ArgumentNullException(nameof(schema));
+ _isLz4Compressed = isLz4Compressed;
+
+ // Get configuration values from connection properties
+ var connectionProps = statement.Connection.Properties;
+
+ // Parse parallel downloads
+ int parallelDownloads = DefaultParallelDownloads;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads,
out string? parallelDownloadsStr) &&
+ int.TryParse(parallelDownloadsStr, out int
parsedParallelDownloads) &&
+ parsedParallelDownloads > 0)
+ {
+ parallelDownloads = parsedParallelDownloads;
+ }
+
+ // Parse prefetch count
+ int prefetchCount = DefaultPrefetchCount;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out
string? prefetchCountStr) &&
+ int.TryParse(prefetchCountStr, out int parsedPrefetchCount) &&
+ parsedPrefetchCount > 0)
+ {
+ prefetchCount = parsedPrefetchCount;
+ }
+
+ // Parse memory buffer size
+ int memoryBufferSizeMB = DefaultMemoryBufferSizeMB;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize,
out string? memoryBufferSizeStr) &&
+ int.TryParse(memoryBufferSizeStr, out int
parsedMemoryBufferSize) &&
+ parsedMemoryBufferSize > 0)
+ {
+ memoryBufferSizeMB = parsedMemoryBufferSize;
+ }
+
+ // Parse max retries
+ int maxRetries = 3;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out
string? maxRetriesStr) &&
+ int.TryParse(maxRetriesStr, out int parsedMaxRetries) &&
+ parsedMaxRetries > 0)
+ {
+ maxRetries = parsedMaxRetries;
+ }
+
+ // Parse retry delay
+ int retryDelayMs = 500;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out
string? retryDelayStr) &&
+ int.TryParse(retryDelayStr, out int parsedRetryDelay) &&
+ parsedRetryDelay > 0)
+ {
+ retryDelayMs = parsedRetryDelay;
+ }
+
+ // Parse timeout minutes
+ int timeoutMinutes = 5;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out
string? timeoutStr) &&
+ int.TryParse(timeoutStr, out int parsedTimeout) &&
+ parsedTimeout > 0)
+ {
+ timeoutMinutes = parsedTimeout;
+ }
+
+ // Initialize the memory manager
+ _memoryManager = new
CloudFetchMemoryBufferManager(memoryBufferSizeMB);
+
+ // Initialize the queues with bounded capacity
+ _downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
+ _resultQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
+
+ // Initialize the HTTP client
+ _httpClient = new HttpClient
+ {
+ Timeout = TimeSpan.FromMinutes(timeoutMinutes)
+ };
+
+ // Initialize the result fetcher
+ _resultFetcher = new CloudFetchResultFetcher(
+ _statement,
+ _memoryManager,
+ _downloadQueue,
+ 2000000);
+
+ // Initialize the downloader
+ _downloader = new CloudFetchDownloader(
+ _downloadQueue,
+ _resultQueue,
+ _memoryManager,
+ _httpClient,
+ parallelDownloads,
+ _isLz4Compressed,
+ maxRetries,
+ retryDelayMs);
+ }
+
+ /// <summary>
+ /// Initializes a new instance of the <see
cref="CloudFetchDownloadManager"/> class.
+ /// This constructor is intended for testing purposes only.
+ /// </summary>
+ /// <param name="statement">The HiveServer2 statement.</param>
+ /// <param name="schema">The Arrow schema.</param>
+ /// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
+ /// <param name="resultFetcher">The result fetcher.</param>
+ /// <param name="downloader">The downloader.</param>
+ internal CloudFetchDownloadManager(
+ DatabricksStatement statement,
+ Schema schema,
+ bool isLz4Compressed,
+ ICloudFetchResultFetcher resultFetcher,
+ ICloudFetchDownloader downloader)
+ {
+ _statement = statement ?? throw new
ArgumentNullException(nameof(statement));
+ _schema = schema ?? throw new
ArgumentNullException(nameof(schema));
+ _isLz4Compressed = isLz4Compressed;
+ _resultFetcher = resultFetcher ?? throw new
ArgumentNullException(nameof(resultFetcher));
+ _downloader = downloader ?? throw new
ArgumentNullException(nameof(downloader));
+
+ // Create empty collections for the test
+ _memoryManager = new
CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB);
+ _downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
+ _resultQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
+ _httpClient = new HttpClient();
+ }
+
+ /// <inheritdoc />
+ public bool HasMoreResults => !_downloader.IsCompleted ||
!_resultQueue.IsCompleted;
+
+ /// <inheritdoc />
+ public async Task<IDownloadResult?>
GetNextDownloadedFileAsync(CancellationToken cancellationToken)
+ {
+ ThrowIfDisposed();
+
+ if (!_isStarted)
+ {
+ throw new InvalidOperationException("Download manager has not
been started.");
+ }
+
+ try
+ {
+ return await
_downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ // If there's an error in the downloader, check if there's
also an error in the fetcher
+ if (_resultFetcher.HasError)
+ {
+ throw new AggregateException("Errors in download
pipeline", new[] { ex, _resultFetcher.Error! });
+ }
+ throw;
+ }
Review Comment:
```suggestion
try
{
return await
_downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (_resultFetcher.HasError)
{
throw new AggregateException("Errors in download pipeline",
new[] { ex, _resultFetcher.Error! });
}
```
##########
csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using Apache.Hive.Service.Rpc.Thrift;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Interface for accessing HiveServer2Statement properties needed by
CloudFetchResultFetcher.
+ /// </summary>
+ public interface IHiveServer2Statement
Review Comment:
This should not be public.
It looks like all of the Thrift classes were mistakenly generated as public
instead of internal. I'll make sure to fix that. We don't want internal
implementation details to be part of the public API.
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs:
##########
@@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
public CloudFetchReader(DatabricksStatement statement, Schema schema,
bool isLz4Compressed)
{
- this.statement = statement;
- this.schema = schema;
+ this.schema = schema ?? throw new
ArgumentNullException(nameof(schema));
Review Comment:
This isn't strictly needed because this is an internal class, `nullable` is
enabled, and the schema argument is already declared as non-nullable.
##########
csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs:
##########
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Hive.Service.Rpc.Thrift;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Represents a downloaded result file with its associated metadata.
+ /// </summary>
+ public interface IDownloadResult : IDisposable
+ {
+ /// <summary>
+ /// Gets the link information for this result.
+ /// </summary>
+ TSparkArrowResultLink Link { get; }
+
+ /// <summary>
+ /// Gets the stream containing the downloaded data.
+ /// </summary>
+ Stream DataStream { get; }
+
+ /// <summary>
+ /// Gets the size of the downloaded data in bytes.
+ /// </summary>
+ long Size { get; }
+
+ /// <summary>
+ /// Gets a task that completes when the download is finished.
+ /// </summary>
+ Task DownloadCompletedTask { get; }
Review Comment:
Nothing seems to use this property.
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs:
##########
@@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
public CloudFetchReader(DatabricksStatement statement, Schema schema,
bool isLz4Compressed)
{
- this.statement = statement;
- this.schema = schema;
+ this.schema = schema ?? throw new
ArgumentNullException(nameof(schema));
this.isLz4Compressed = isLz4Compressed;
- // Get configuration values from connection properties or use
defaults
+ // Check if prefetch is enabled
var connectionProps = statement.Connection.Properties;
-
- // Parse max retries
- int parsedMaxRetries = DefaultMaxRetries;
- if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out
string? maxRetriesStr) &&
- int.TryParse(maxRetriesStr, out parsedMaxRetries) &&
- parsedMaxRetries > 0)
+ isPrefetchEnabled = true; // Default to true
+ if
(connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled,
out string? prefetchEnabledStr) &&
+ bool.TryParse(prefetchEnabledStr, out bool
parsedPrefetchEnabled))
{
- // Value was successfully parsed
+ isPrefetchEnabled = parsedPrefetchEnabled;
}
- else
- {
- parsedMaxRetries = DefaultMaxRetries;
- }
- this.maxRetries = parsedMaxRetries;
- // Parse retry delay
- int parsedRetryDelay = DefaultRetryDelayMs;
- if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out
string? retryDelayStr) &&
- int.TryParse(retryDelayStr, out parsedRetryDelay) &&
- parsedRetryDelay > 0)
+ // Initialize the download manager
+ if (isPrefetchEnabled)
{
- // Value was successfully parsed
+ downloadManager = new CloudFetchDownloadManager(statement,
schema, isLz4Compressed);
+ downloadManager.StartAsync().Wait();
}
else
{
- parsedRetryDelay = DefaultRetryDelayMs;
+ // If prefetch is disabled, use the legacy implementation
+ throw new NotImplementedException("Legacy implementation
without prefetch is not supported.");
Review Comment:
It seems a little odd to have an option whose only effect, when set, is to
cause an operation to fail if the service decides to return file-based
results. Is that the intent for the flag?
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs:
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Adbc.Drivers.Databricks;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Manages the CloudFetch download pipeline.
+ /// </summary>
+ internal sealed class CloudFetchDownloadManager :
ICloudFetchDownloadManager
+ {
+ // Default values
+ private const int DefaultParallelDownloads = 3;
+ private const int DefaultPrefetchCount = 2;
+ private const int DefaultMemoryBufferSizeMB = 200;
+ private const bool DefaultPrefetchEnabled = true;
+
+ private readonly DatabricksStatement _statement;
+ private readonly Schema _schema;
+ private readonly bool _isLz4Compressed;
+ private readonly ICloudFetchMemoryBufferManager _memoryManager;
+ private readonly BlockingCollection<IDownloadResult> _downloadQueue;
+ private readonly BlockingCollection<IDownloadResult> _resultQueue;
+ private readonly ICloudFetchResultFetcher _resultFetcher;
+ private readonly ICloudFetchDownloader _downloader;
+ private readonly HttpClient _httpClient;
+ private bool _isDisposed;
+ private bool _isStarted;
+ private CancellationTokenSource? _cancellationTokenSource;
+
+ /// <summary>
+ /// Initializes a new instance of the <see
cref="CloudFetchDownloadManager"/> class.
+ /// </summary>
+ /// <param name="statement">The HiveServer2 statement.</param>
+ /// <param name="schema">The Arrow schema.</param>
+ /// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
+ public CloudFetchDownloadManager(DatabricksStatement statement, Schema
schema, bool isLz4Compressed)
+ {
+ _statement = statement ?? throw new
ArgumentNullException(nameof(statement));
+ _schema = schema ?? throw new
ArgumentNullException(nameof(schema));
+ _isLz4Compressed = isLz4Compressed;
+
+ // Get configuration values from connection properties
+ var connectionProps = statement.Connection.Properties;
+
+ // Parse parallel downloads
+ int parallelDownloads = DefaultParallelDownloads;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads,
out string? parallelDownloadsStr) &&
+ int.TryParse(parallelDownloadsStr, out int
parsedParallelDownloads) &&
+ parsedParallelDownloads > 0)
+ {
+ parallelDownloads = parsedParallelDownloads;
+ }
+
+ // Parse prefetch count
+ int prefetchCount = DefaultPrefetchCount;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out
string? prefetchCountStr) &&
+ int.TryParse(prefetchCountStr, out int parsedPrefetchCount) &&
+ parsedPrefetchCount > 0)
+ {
+ prefetchCount = parsedPrefetchCount;
+ }
+
+ // Parse memory buffer size
+ int memoryBufferSizeMB = DefaultMemoryBufferSizeMB;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize,
out string? memoryBufferSizeStr) &&
+ int.TryParse(memoryBufferSizeStr, out int
parsedMemoryBufferSize) &&
+ parsedMemoryBufferSize > 0)
+ {
+ memoryBufferSizeMB = parsedMemoryBufferSize;
+ }
+
+ // Parse max retries
+ int maxRetries = 3;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out
string? maxRetriesStr) &&
+ int.TryParse(maxRetriesStr, out int parsedMaxRetries) &&
+ parsedMaxRetries > 0)
+ {
+ maxRetries = parsedMaxRetries;
+ }
+
+ // Parse retry delay
+ int retryDelayMs = 500;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out
string? retryDelayStr) &&
+ int.TryParse(retryDelayStr, out int parsedRetryDelay) &&
+ parsedRetryDelay > 0)
+ {
+ retryDelayMs = parsedRetryDelay;
+ }
+
+ // Parse timeout minutes
+ int timeoutMinutes = 5;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out
string? timeoutStr) &&
+ int.TryParse(timeoutStr, out int parsedTimeout) &&
+ parsedTimeout > 0)
+ {
+ timeoutMinutes = parsedTimeout;
+ }
+
+ // Initialize the memory manager
+ _memoryManager = new
CloudFetchMemoryBufferManager(memoryBufferSizeMB);
+
+ // Initialize the queues with bounded capacity
+ _downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
+ _resultQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
+
+ // Initialize the HTTP client
+ _httpClient = new HttpClient
+ {
+ Timeout = TimeSpan.FromMinutes(timeoutMinutes)
+ };
+
+ // Initialize the result fetcher
+ _resultFetcher = new CloudFetchResultFetcher(
+ _statement,
+ _memoryManager,
+ _downloadQueue,
+ 2000000);
Review Comment:
Should this be configurable? Can it be lifted to a named constant?
##########
csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs:
##########
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Hive.Service.Rpc.Thrift;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Represents a downloaded result file with its associated metadata.
+ /// </summary>
+ public interface IDownloadResult : IDisposable
Review Comment:
These interfaces should not be public contracts.
##########
csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs:
##########
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.IO;
+using System.Threading.Tasks;
+using Apache.Hive.Service.Rpc.Thrift;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Special marker class that indicates the end of results in the download
queue.
+ /// </summary>
+ internal sealed class EndOfResultsGuard : IDownloadResult
+ {
+ private static readonly Task<bool> CompletedTask =
Task.FromResult(true);
Review Comment:
Can we use `Task.CompletedTask` instead of defining a new one?
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs:
##########
@@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
public CloudFetchReader(DatabricksStatement statement, Schema schema,
bool isLz4Compressed)
{
- this.statement = statement;
- this.schema = schema;
+ this.schema = schema ?? throw new
ArgumentNullException(nameof(schema));
this.isLz4Compressed = isLz4Compressed;
- // Get configuration values from connection properties or use
defaults
+ // Check if prefetch is enabled
var connectionProps = statement.Connection.Properties;
-
- // Parse max retries
- int parsedMaxRetries = DefaultMaxRetries;
- if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out
string? maxRetriesStr) &&
- int.TryParse(maxRetriesStr, out parsedMaxRetries) &&
- parsedMaxRetries > 0)
+ isPrefetchEnabled = true; // Default to true
+ if
(connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled,
out string? prefetchEnabledStr) &&
+ bool.TryParse(prefetchEnabledStr, out bool
parsedPrefetchEnabled))
Review Comment:
Consider throwing an exception if the value can't be parsed as a boolean.
##########
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs:
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Adbc.Drivers.Databricks;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+{
+ /// <summary>
+ /// Manages the CloudFetch download pipeline.
+ /// </summary>
+ internal sealed class CloudFetchDownloadManager :
ICloudFetchDownloadManager
+ {
+ // Default values
+ private const int DefaultParallelDownloads = 3;
+ private const int DefaultPrefetchCount = 2;
+ private const int DefaultMemoryBufferSizeMB = 200;
+ private const bool DefaultPrefetchEnabled = true;
+
+ private readonly DatabricksStatement _statement;
+ private readonly Schema _schema;
+ private readonly bool _isLz4Compressed;
+ private readonly ICloudFetchMemoryBufferManager _memoryManager;
+ private readonly BlockingCollection<IDownloadResult> _downloadQueue;
+ private readonly BlockingCollection<IDownloadResult> _resultQueue;
+ private readonly ICloudFetchResultFetcher _resultFetcher;
+ private readonly ICloudFetchDownloader _downloader;
+ private readonly HttpClient _httpClient;
+ private bool _isDisposed;
+ private bool _isStarted;
+ private CancellationTokenSource? _cancellationTokenSource;
+
/// <summary>
/// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class,
/// reading pipeline tuning options from the statement's connection properties.
/// </summary>
/// <param name="statement">The Databricks statement whose results are being fetched.</param>
/// <param name="schema">The Arrow schema of the result set.</param>
/// <param name="isLz4Compressed">Whether the result files are LZ4 compressed.</param>
public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed)
{
    // Named defaults for settings that previously appeared as inline magic numbers,
    // matching the Default* constant convention used for the other settings.
    const int DefaultMaxRetries = 3;
    const int DefaultRetryDelayMs = 500;
    const int DefaultTimeoutMinutes = 5;
    const int DefaultFetchBatchSize = 2000000; // row-count hint passed to the result fetcher

    _statement = statement ?? throw new ArgumentNullException(nameof(statement));
    _schema = schema ?? throw new ArgumentNullException(nameof(schema));
    _isLz4Compressed = isLz4Compressed;

    var connectionProps = statement.Connection.Properties;

    // Reads a positive integer connection property, falling back to the supplied
    // default when the property is absent, unparseable, or non-positive.
    int GetPositiveIntProperty(string key, int defaultValue)
    {
        if (connectionProps.TryGetValue(key, out string? raw) &&
            int.TryParse(raw, out int parsed) &&
            parsed > 0)
        {
            return parsed;
        }

        return defaultValue;
    }

    int parallelDownloads = GetPositiveIntProperty(DatabricksParameters.CloudFetchParallelDownloads, DefaultParallelDownloads);
    int prefetchCount = GetPositiveIntProperty(DatabricksParameters.CloudFetchPrefetchCount, DefaultPrefetchCount);
    int memoryBufferSizeMB = GetPositiveIntProperty(DatabricksParameters.CloudFetchMemoryBufferSize, DefaultMemoryBufferSizeMB);
    int maxRetries = GetPositiveIntProperty(DatabricksParameters.CloudFetchMaxRetries, DefaultMaxRetries);
    int retryDelayMs = GetPositiveIntProperty(DatabricksParameters.CloudFetchRetryDelayMs, DefaultRetryDelayMs);
    int timeoutMinutes = GetPositiveIntProperty(DatabricksParameters.CloudFetchTimeoutMinutes, DefaultTimeoutMinutes);

    // Gates the total memory held by in-flight downloaded files.
    _memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB);

    // Bounded queues apply back-pressure between the fetcher, the downloader,
    // and the reader consuming the results.
    _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
    _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);

    // One HttpClient per manager; owned here and disposed in Dispose().
    _httpClient = new HttpClient
    {
        Timeout = TimeSpan.FromMinutes(timeoutMinutes)
    };

    // Fetches result-file URLs from the server and feeds the download queue.
    _resultFetcher = new CloudFetchResultFetcher(
        _statement,
        _memoryManager,
        _downloadQueue,
        DefaultFetchBatchSize);

    // Downloads files in parallel and publishes them to the result queue.
    _downloader = new CloudFetchDownloader(
        _downloadQueue,
        _resultQueue,
        _memoryManager,
        _httpClient,
        parallelDownloads,
        _isLz4Compressed,
        maxRetries,
        retryDelayMs);
}
+
/// <summary>
/// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class
/// with caller-supplied pipeline components.
/// This constructor is intended for testing purposes only.
/// </summary>
/// <param name="statement">The HiveServer2 statement.</param>
/// <param name="schema">The Arrow schema.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
/// <param name="resultFetcher">The result fetcher.</param>
/// <param name="downloader">The downloader.</param>
internal CloudFetchDownloadManager(
    DatabricksStatement statement,
    Schema schema,
    bool isLz4Compressed,
    ICloudFetchResultFetcher resultFetcher,
    ICloudFetchDownloader downloader)
{
    if (statement is null)
    {
        throw new ArgumentNullException(nameof(statement));
    }
    if (schema is null)
    {
        throw new ArgumentNullException(nameof(schema));
    }
    if (resultFetcher is null)
    {
        throw new ArgumentNullException(nameof(resultFetcher));
    }
    if (downloader is null)
    {
        throw new ArgumentNullException(nameof(downloader));
    }

    _statement = statement;
    _schema = schema;
    _isLz4Compressed = isLz4Compressed;
    _resultFetcher = resultFetcher;
    _downloader = downloader;

    // Small real collaborators so the instance is fully usable under test.
    _memoryManager = new CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB);
    _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10);
    _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10);
    _httpClient = new HttpClient();
}
+
/// <inheritdoc />
public bool HasMoreResults
{
    // Results may still arrive while the downloader is active or while
    // downloaded files remain queued for the reader.
    get { return !(_downloader.IsCompleted && _resultQueue.IsCompleted); }
}
+
/// <inheritdoc />
public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken)
{
    ThrowIfDisposed();

    if (!_isStarted)
    {
        throw new InvalidOperationException("Download manager has not been started.");
    }

    try
    {
        IDownloadResult? next = await _downloader
            .GetNextDownloadedFileAsync(cancellationToken)
            .ConfigureAwait(false);
        return next;
    }
    catch (Exception downloadError)
    {
        // Surface a fetcher-side failure together with the downloader failure
        // so callers see the full pipeline state in one exception.
        if (_resultFetcher.HasError)
        {
            throw new AggregateException(
                "Errors in download pipeline",
                new[] { downloadError, _resultFetcher.Error! });
        }

        throw;
    }
}
+
/// <inheritdoc />
/// <remarks>
/// Starts the result fetcher and then the downloader. If the downloader fails
/// to start, the already-running fetcher is stopped and the cancellation token
/// source is disposed so no background work is leaked.
/// </remarks>
public async Task StartAsync()
{
    ThrowIfDisposed();

    if (_isStarted)
    {
        throw new InvalidOperationException("Download manager is already started.");
    }

    // Create a new cancellation token source for this run of the pipeline.
    _cancellationTokenSource = new CancellationTokenSource();

    // Start the result fetcher first; the downloader consumes its queue.
    await _resultFetcher.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false);

    try
    {
        await _downloader.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false);
    }
    catch
    {
        // Don't leave the fetcher running (or the CTS alive) if the
        // downloader failed to start.
        _cancellationTokenSource.Cancel();
        await _resultFetcher.StopAsync().ConfigureAwait(false);
        _cancellationTokenSource.Dispose();
        _cancellationTokenSource = null;
        throw;
    }

    _isStarted = true;
}
+
/// <inheritdoc />
public async Task StopAsync()
{
    // Stopping an unstarted manager is a no-op.
    if (!_isStarted)
    {
        return;
    }

    // Signal every pipeline component to stop.
    _cancellationTokenSource?.Cancel();

    // Shut down the consumer side first, then the producer.
    await _downloader.StopAsync().ConfigureAwait(false);
    await _resultFetcher.StopAsync().ConfigureAwait(false);

    // Release the cancellation token source for this run.
    if (_cancellationTokenSource is not null)
    {
        _cancellationTokenSource.Dispose();
        _cancellationTokenSource = null;
    }

    _isStarted = false;
}
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ if (_isDisposed)
+ {
+ return;
+ }
+
+ // Stop the pipeline
+ StopAsync().GetAwaiter().GetResult();
+
+ // Dispose the HTTP client
+ _httpClient.Dispose();
+
+ // Dispose the cancellation token source if it hasn't been
disposed yet
+ _cancellationTokenSource?.Dispose();
+ _cancellationTokenSource = null;
+
+ // Mark the queues as completed to release any waiting threads
+ _downloadQueue.CompleteAdding();
+ _resultQueue.CompleteAdding();
+
+ // Dispose any remaining results
+ foreach (var result in
_resultQueue.GetConsumingEnumerable(CancellationToken.None))
+ {
+ if (result != EndOfResultsGuard.Instance)
+ {
+ result.Dispose();
+ }
Review Comment:
Given that `EndOfResultsGuard.Dispose()` is a nop, we could just
`.Dispose()` without the `if` guard.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]