CurtHagenlocher commented on code in PR #2678: URL: https://github.com/apache/arrow-adbc/pull/2678#discussion_r2047511703
########## csharp/src/Drivers/Databricks/DatabricksParameters.cs: ########## @@ -42,6 +42,30 @@ public class DatabricksParameters : SparkParameters /// Default value is 5 minutes if not specified. /// </summary> public const string CloudFetchTimeoutMinutes = "adbc.databricks.cloudfetch.timeout_minutes"; + + /// <summary> + /// Maximum number of parallel downloads for CloudFetch operations. + /// Default value is 3 if not specified. + /// </summary> + public const string CloudFetchParallelDownloads = "adbc.spark.cloudfetch.parallel_downloads"; Review Comment: The existing CloudFetch properties are named `adbc.databricks.*` instead of `adbc.spark.*`. ########## csharp/src/Drivers/Databricks/CloudFetch/cloudfetch-pipeline-design.md: ########## @@ -0,0 +1,60 @@ +current cloudfetch implementation download the cloud result file inline with the reader, which generate performance problem, it slows down the reader when need download the next result file Review Comment: This file needs an Apache copyright notice. See some of the other Markdown files in this project for how to add one. ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs: ########## @@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> public CloudFetchReader(DatabricksStatement statement, Schema schema, bool isLz4Compressed) { - this.statement = statement; - this.schema = schema; + this.schema = schema ?? throw new ArgumentNullException(nameof(schema)); this.isLz4Compressed = isLz4Compressed; - // Get configuration values from connection properties or use defaults + // Check if prefetch is enabled var connectionProps = statement.Connection.Properties; - - // Parse max retries - int parsedMaxRetries = DefaultMaxRetries; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? 
maxRetriesStr) && - int.TryParse(maxRetriesStr, out parsedMaxRetries) && - parsedMaxRetries > 0) + isPrefetchEnabled = true; // Default to true + if (connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled, out string? prefetchEnabledStr) && + bool.TryParse(prefetchEnabledStr, out bool parsedPrefetchEnabled)) Review Comment: The same pattern exists in other places e.g. CloudFetchDownloadManager.cs. In general, it's better to return an error when parameter validation fails instead of silently reverting to the default value. ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs: ########## @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Databricks; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Manages the CloudFetch download pipeline. 
+ /// </summary> + internal sealed class CloudFetchDownloadManager : ICloudFetchDownloadManager + { + // Default values + private const int DefaultParallelDownloads = 3; + private const int DefaultPrefetchCount = 2; + private const int DefaultMemoryBufferSizeMB = 200; + private const bool DefaultPrefetchEnabled = true; + + private readonly DatabricksStatement _statement; + private readonly Schema _schema; + private readonly bool _isLz4Compressed; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly ICloudFetchResultFetcher _resultFetcher; + private readonly ICloudFetchDownloader _downloader; + private readonly HttpClient _httpClient; + private bool _isDisposed; + private bool _isStarted; + private CancellationTokenSource? _cancellationTokenSource; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + + // Get configuration values from connection properties + var connectionProps = statement.Connection.Properties; + + // Parse parallel downloads + int parallelDownloads = DefaultParallelDownloads; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads, out string? 
parallelDownloadsStr) && + int.TryParse(parallelDownloadsStr, out int parsedParallelDownloads) && + parsedParallelDownloads > 0) + { + parallelDownloads = parsedParallelDownloads; + } + + // Parse prefetch count + int prefetchCount = DefaultPrefetchCount; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out string? prefetchCountStr) && + int.TryParse(prefetchCountStr, out int parsedPrefetchCount) && + parsedPrefetchCount > 0) + { + prefetchCount = parsedPrefetchCount; + } + + // Parse memory buffer size + int memoryBufferSizeMB = DefaultMemoryBufferSizeMB; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize, out string? memoryBufferSizeStr) && + int.TryParse(memoryBufferSizeStr, out int parsedMemoryBufferSize) && + parsedMemoryBufferSize > 0) + { + memoryBufferSizeMB = parsedMemoryBufferSize; + } + + // Parse max retries + int maxRetries = 3; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr) && + int.TryParse(maxRetriesStr, out int parsedMaxRetries) && + parsedMaxRetries > 0) + { + maxRetries = parsedMaxRetries; + } + + // Parse retry delay + int retryDelayMs = 500; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) && + int.TryParse(retryDelayStr, out int parsedRetryDelay) && + parsedRetryDelay > 0) + { + retryDelayMs = parsedRetryDelay; + } + + // Parse timeout minutes + int timeoutMinutes = 5; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? 
timeoutStr) && + int.TryParse(timeoutStr, out int parsedTimeout) && + parsedTimeout > 0) + { + timeoutMinutes = parsedTimeout; + } + + // Initialize the memory manager + _memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB); + + // Initialize the queues with bounded capacity + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + + // Initialize the HTTP client + _httpClient = new HttpClient + { + Timeout = TimeSpan.FromMinutes(timeoutMinutes) + }; + + // Initialize the result fetcher + _resultFetcher = new CloudFetchResultFetcher( + _statement, + _memoryManager, + _downloadQueue, + 2000000); + + // Initialize the downloader + _downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _memoryManager, + _httpClient, + parallelDownloads, + _isLz4Compressed, + maxRetries, + retryDelayMs); + } + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// This constructor is intended for testing purposes only. + /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + /// <param name="resultFetcher">The result fetcher.</param> + /// <param name="downloader">The downloader.</param> + internal CloudFetchDownloadManager( + DatabricksStatement statement, + Schema schema, + bool isLz4Compressed, + ICloudFetchResultFetcher resultFetcher, + ICloudFetchDownloader downloader) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + _resultFetcher = resultFetcher ?? 
throw new ArgumentNullException(nameof(resultFetcher)); + _downloader = downloader ?? throw new ArgumentNullException(nameof(downloader)); + + // Create empty collections for the test + _memoryManager = new CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB); + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _httpClient = new HttpClient(); + } + + /// <inheritdoc /> + public bool HasMoreResults => !_downloader.IsCompleted || !_resultQueue.IsCompleted; + + /// <inheritdoc /> + public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); + + if (!_isStarted) + { + throw new InvalidOperationException("Download manager has not been started."); + } + + try + { + return await _downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // If there's an error in the downloader, check if there's also an error in the fetcher + if (_resultFetcher.HasError) + { + throw new AggregateException("Errors in download pipeline", new[] { ex, _resultFetcher.Error! }); + } + throw; + } Review Comment: ```suggestion try { return await _downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false); } catch (Exception ex) when (_resultFetcher.HasError) { throw new AggregateException("Errors in download pipeline", new[] { ex, _resultFetcher.Error! }); } ``` ########## csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs: ########## @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Interface for accessing HiveServer2Statement properties needed by CloudFetchResultFetcher. + /// </summary> + public interface IHiveServer2Statement Review Comment: This should not be public. It looks like all of the Thrift classes were mistakenly generated as public instead of internal. I'll make sure to fix that. We don't want internal implementation details to be part of the public API. ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs: ########## @@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> public CloudFetchReader(DatabricksStatement statement, Schema schema, bool isLz4Compressed) { - this.statement = statement; - this.schema = schema; + this.schema = schema ?? throw new ArgumentNullException(nameof(schema)); Review Comment: This isn't strictly needed because it's an internal class and `nullable` is enabled and the schema argument is already declared as non-nullable. ########## csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs: ########## @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Represents a downloaded result file with its associated metadata. + /// </summary> + public interface IDownloadResult : IDisposable + { + /// <summary> + /// Gets the link information for this result. + /// </summary> + TSparkArrowResultLink Link { get; } + + /// <summary> + /// Gets the stream containing the downloaded data. + /// </summary> + Stream DataStream { get; } + + /// <summary> + /// Gets the size of the downloaded data in bytes. + /// </summary> + long Size { get; } + + /// <summary> + /// Gets a task that completes when the download is finished. 
+ /// </summary> + Task DownloadCompletedTask { get; } Review Comment: Nothing seems to use this ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs: ########## @@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> public CloudFetchReader(DatabricksStatement statement, Schema schema, bool isLz4Compressed) { - this.statement = statement; - this.schema = schema; + this.schema = schema ?? throw new ArgumentNullException(nameof(schema)); this.isLz4Compressed = isLz4Compressed; - // Get configuration values from connection properties or use defaults + // Check if prefetch is enabled var connectionProps = statement.Connection.Properties; - - // Parse max retries - int parsedMaxRetries = DefaultMaxRetries; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr) && - int.TryParse(maxRetriesStr, out parsedMaxRetries) && - parsedMaxRetries > 0) + isPrefetchEnabled = true; // Default to true + if (connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled, out string? prefetchEnabledStr) && + bool.TryParse(prefetchEnabledStr, out bool parsedPrefetchEnabled)) { - // Value was successfully parsed + isPrefetchEnabled = parsedPrefetchEnabled; } - else - { - parsedMaxRetries = DefaultMaxRetries; - } - this.maxRetries = parsedMaxRetries; - // Parse retry delay - int parsedRetryDelay = DefaultRetryDelayMs; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? 
retryDelayStr) && - int.TryParse(retryDelayStr, out parsedRetryDelay) && - parsedRetryDelay > 0) + // Initialize the download manager + if (isPrefetchEnabled) { - // Value was successfully parsed + downloadManager = new CloudFetchDownloadManager(statement, schema, isLz4Compressed); + downloadManager.StartAsync().Wait(); } else { - parsedRetryDelay = DefaultRetryDelayMs; + // If prefetch is disabled, use the legacy implementation + throw new NotImplementedException("Legacy implementation without prefetch is not supported."); Review Comment: It seems a little odd to have an option whose only effect is to cause an operation to fail when it is set to false and the service decides to return file-based results. Is that the intent for the flag? ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs: ########## @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Databricks; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Manages the CloudFetch download pipeline. + /// </summary> + internal sealed class CloudFetchDownloadManager : ICloudFetchDownloadManager + { + // Default values + private const int DefaultParallelDownloads = 3; + private const int DefaultPrefetchCount = 2; + private const int DefaultMemoryBufferSizeMB = 200; + private const bool DefaultPrefetchEnabled = true; + + private readonly DatabricksStatement _statement; + private readonly Schema _schema; + private readonly bool _isLz4Compressed; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly ICloudFetchResultFetcher _resultFetcher; + private readonly ICloudFetchDownloader _downloader; + private readonly HttpClient _httpClient; + private bool _isDisposed; + private bool _isStarted; + private CancellationTokenSource? _cancellationTokenSource; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? 
throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + + // Get configuration values from connection properties + var connectionProps = statement.Connection.Properties; + + // Parse parallel downloads + int parallelDownloads = DefaultParallelDownloads; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads, out string? parallelDownloadsStr) && + int.TryParse(parallelDownloadsStr, out int parsedParallelDownloads) && + parsedParallelDownloads > 0) + { + parallelDownloads = parsedParallelDownloads; + } + + // Parse prefetch count + int prefetchCount = DefaultPrefetchCount; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out string? prefetchCountStr) && + int.TryParse(prefetchCountStr, out int parsedPrefetchCount) && + parsedPrefetchCount > 0) + { + prefetchCount = parsedPrefetchCount; + } + + // Parse memory buffer size + int memoryBufferSizeMB = DefaultMemoryBufferSizeMB; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize, out string? memoryBufferSizeStr) && + int.TryParse(memoryBufferSizeStr, out int parsedMemoryBufferSize) && + parsedMemoryBufferSize > 0) + { + memoryBufferSizeMB = parsedMemoryBufferSize; + } + + // Parse max retries + int maxRetries = 3; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr) && + int.TryParse(maxRetriesStr, out int parsedMaxRetries) && + parsedMaxRetries > 0) + { + maxRetries = parsedMaxRetries; + } + + // Parse retry delay + int retryDelayMs = 500; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) && + int.TryParse(retryDelayStr, out int parsedRetryDelay) && + parsedRetryDelay > 0) + { + retryDelayMs = parsedRetryDelay; + } + + // Parse timeout minutes + int timeoutMinutes = 5; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? 
timeoutStr) && + int.TryParse(timeoutStr, out int parsedTimeout) && + parsedTimeout > 0) + { + timeoutMinutes = parsedTimeout; + } + + // Initialize the memory manager + _memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB); + + // Initialize the queues with bounded capacity + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + + // Initialize the HTTP client + _httpClient = new HttpClient + { + Timeout = TimeSpan.FromMinutes(timeoutMinutes) + }; + + // Initialize the result fetcher + _resultFetcher = new CloudFetchResultFetcher( + _statement, + _memoryManager, + _downloadQueue, + 2000000); Review Comment: Should this be configurable? Can it be lifted to a named constant? ########## csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs: ########## @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Represents a downloaded result file with its associated metadata. + /// </summary> + public interface IDownloadResult : IDisposable Review Comment: These interfaces should not be public contracts ########## csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs: ########## @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.IO; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Special marker class that indicates the end of results in the download queue. + /// </summary> + internal sealed class EndOfResultsGuard : IDownloadResult + { + private static readonly Task<bool> CompletedTask = Task.FromResult(true); Review Comment: Can we use `Task.CompletedTask` instead of defining a new one? 
########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs: ########## @@ -61,81 +50,45 @@ internal sealed class CloudFetchReader : IArrowArrayStream /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> public CloudFetchReader(DatabricksStatement statement, Schema schema, bool isLz4Compressed) { - this.statement = statement; - this.schema = schema; + this.schema = schema ?? throw new ArgumentNullException(nameof(schema)); this.isLz4Compressed = isLz4Compressed; - // Get configuration values from connection properties or use defaults + // Check if prefetch is enabled var connectionProps = statement.Connection.Properties; - - // Parse max retries - int parsedMaxRetries = DefaultMaxRetries; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr) && - int.TryParse(maxRetriesStr, out parsedMaxRetries) && - parsedMaxRetries > 0) + isPrefetchEnabled = true; // Default to true + if (connectionProps.TryGetValue(Databricks.DatabricksParameters.CloudFetchPrefetchEnabled, out string? prefetchEnabledStr) && + bool.TryParse(prefetchEnabledStr, out bool parsedPrefetchEnabled)) Review Comment: Consider throwing an exception if the value can't be parsed as a boolean. ########## csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs: ########## @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Databricks; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Manages the CloudFetch download pipeline. + /// </summary> + internal sealed class CloudFetchDownloadManager : ICloudFetchDownloadManager + { + // Default values + private const int DefaultParallelDownloads = 3; + private const int DefaultPrefetchCount = 2; + private const int DefaultMemoryBufferSizeMB = 200; + private const bool DefaultPrefetchEnabled = true; + + private readonly DatabricksStatement _statement; + private readonly Schema _schema; + private readonly bool _isLz4Compressed; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly ICloudFetchResultFetcher _resultFetcher; + private readonly ICloudFetchDownloader _downloader; + private readonly HttpClient _httpClient; + private bool _isDisposed; + private bool _isStarted; + private CancellationTokenSource? _cancellationTokenSource; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. 
+ /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + + // Get configuration values from connection properties + var connectionProps = statement.Connection.Properties; + + // Parse parallel downloads + int parallelDownloads = DefaultParallelDownloads; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads, out string? parallelDownloadsStr) && + int.TryParse(parallelDownloadsStr, out int parsedParallelDownloads) && + parsedParallelDownloads > 0) + { + parallelDownloads = parsedParallelDownloads; + } + + // Parse prefetch count + int prefetchCount = DefaultPrefetchCount; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out string? prefetchCountStr) && + int.TryParse(prefetchCountStr, out int parsedPrefetchCount) && + parsedPrefetchCount > 0) + { + prefetchCount = parsedPrefetchCount; + } + + // Parse memory buffer size + int memoryBufferSizeMB = DefaultMemoryBufferSizeMB; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize, out string? memoryBufferSizeStr) && + int.TryParse(memoryBufferSizeStr, out int parsedMemoryBufferSize) && + parsedMemoryBufferSize > 0) + { + memoryBufferSizeMB = parsedMemoryBufferSize; + } + + // Parse max retries + int maxRetries = 3; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? 
maxRetriesStr) && + int.TryParse(maxRetriesStr, out int parsedMaxRetries) && + parsedMaxRetries > 0) + { + maxRetries = parsedMaxRetries; + } + + // Parse retry delay + int retryDelayMs = 500; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) && + int.TryParse(retryDelayStr, out int parsedRetryDelay) && + parsedRetryDelay > 0) + { + retryDelayMs = parsedRetryDelay; + } + + // Parse timeout minutes + int timeoutMinutes = 5; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? timeoutStr) && + int.TryParse(timeoutStr, out int parsedTimeout) && + parsedTimeout > 0) + { + timeoutMinutes = parsedTimeout; + } + + // Initialize the memory manager + _memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB); + + // Initialize the queues with bounded capacity + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + + // Initialize the HTTP client + _httpClient = new HttpClient + { + Timeout = TimeSpan.FromMinutes(timeoutMinutes) + }; + + // Initialize the result fetcher + _resultFetcher = new CloudFetchResultFetcher( + _statement, + _memoryManager, + _downloadQueue, + 2000000); + + // Initialize the downloader + _downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _memoryManager, + _httpClient, + parallelDownloads, + _isLz4Compressed, + maxRetries, + retryDelayMs); + } + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// This constructor is intended for testing purposes only. 
+ /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + /// <param name="resultFetcher">The result fetcher.</param> + /// <param name="downloader">The downloader.</param> + internal CloudFetchDownloadManager( + DatabricksStatement statement, + Schema schema, + bool isLz4Compressed, + ICloudFetchResultFetcher resultFetcher, + ICloudFetchDownloader downloader) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + _resultFetcher = resultFetcher ?? throw new ArgumentNullException(nameof(resultFetcher)); + _downloader = downloader ?? throw new ArgumentNullException(nameof(downloader)); + + // Create empty collections for the test + _memoryManager = new CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB); + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _httpClient = new HttpClient(); + } + + /// <inheritdoc /> + public bool HasMoreResults => !_downloader.IsCompleted || !_resultQueue.IsCompleted; + + /// <inheritdoc /> + public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); + + if (!_isStarted) + { + throw new InvalidOperationException("Download manager has not been started."); + } + + try + { + return await _downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // If there's an error in the downloader, check if there's also an error in the fetcher + if (_resultFetcher.HasError) + { + throw new AggregateException("Errors in download pipeline", new[] { ex, _resultFetcher.Error! 
}); + } + throw; + } + } + + /// <inheritdoc /> + public async Task StartAsync() + { + ThrowIfDisposed(); + + if (_isStarted) + { + throw new InvalidOperationException("Download manager is already started."); + } + + // Create a new cancellation token source + _cancellationTokenSource = new CancellationTokenSource(); + + // Start the result fetcher + await _resultFetcher.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false); + + // Start the downloader + await _downloader.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false); + + _isStarted = true; + } + + /// <inheritdoc /> + public async Task StopAsync() + { + if (!_isStarted) + { + return; + } + + // Cancel the token to signal all operations to stop + _cancellationTokenSource?.Cancel(); + + // Stop the downloader + await _downloader.StopAsync().ConfigureAwait(false); + + // Stop the result fetcher + await _resultFetcher.StopAsync().ConfigureAwait(false); + + // Dispose the cancellation token source + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + + _isStarted = false; + } + + /// <inheritdoc /> + public void Dispose() + { + if (_isDisposed) + { + return; + } + + // Stop the pipeline + StopAsync().GetAwaiter().GetResult(); + + // Dispose the HTTP client + _httpClient.Dispose(); + + // Dispose the cancellation token source if it hasn't been disposed yet + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + + // Mark the queues as completed to release any waiting threads + _downloadQueue.CompleteAdding(); + _resultQueue.CompleteAdding(); + + // Dispose any remaining results + foreach (var result in _resultQueue.GetConsumingEnumerable(CancellationToken.None)) + { + if (result != EndOfResultsGuard.Instance) + { + result.Dispose(); + } Review Comment: Given that `EndOfResultsGuard.Dispose()` is a nop, we could just `.Dispose()` without the `if` guard. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org