rtpsw commented on code in PR #13232: URL: https://github.com/apache/arrow/pull/13232#discussion_r888268399
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr<DataType> type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. 
- for (util::string_view name : { - "add", - "equal", - "is_not_distinct_from", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } +namespace { + +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} + + std::vector<util::string_view> Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional<TypeRecord> GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index]}; + } + return {}; + } + + util::optional<TypeRecord> GetType(Id id) const override { + if (auto index = GetIndex(id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index]}; + } + return {}; + } + + Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override { + if (id_to_index_.find(id) != id_to_index_.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); + } + return Status::OK(); + } + + Status RegisterType(Id id, std::shared_ptr<DataType> type) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + auto index = static_cast<int>(type_ids_.size()); + + auto it_success = id_to_index_.emplace(copied_id, index); + + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } + + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index_.erase(it_success.first); + return Status::Invalid("Type was already registered"); } - std::vector<util::string_view> Uris() const override { - return {uris_.begin(), uris_.end()}; + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + return Status::OK(); + } + + util::optional<FunctionRecord> GetFunction( + util::string_view 
arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } - util::optional<TypeRecord> GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index]}; - } - return {}; + util::optional<FunctionRecord> GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } - util::optional<TypeRecord> GetType(Id id) const override { - if (auto index = GetIndex(id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index]}; - } - return {}; + Status CanRegisterFunction(Id id, + const std::string& arrow_function_name) const override { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) != + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); } + return Status::OK(); + } - Status RegisterType(Id id, std::shared_ptr<DataType> type) override { - DCHECK_EQ(type_ids_.size(), types_.size()); + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - auto index = static_cast<int>(type_ids_.size()); + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; - auto it_success = id_to_index_.emplace(copied_id, index); + auto index = static_cast<int>(function_ids_.size()); - if 
(!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + auto it_success = function_id_to_index_.emplace(copied_id, index); - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index_.erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); + } - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - return Status::OK(); + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - util::optional<FunctionRecord> GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } + + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set<std::string> uris_, names_, function_names_; + DataTypeVector types_; + + // non-owning lookup helpers + std::vector<Id> type_ids_, function_ids_; + std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_; + std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_; + + std::vector<const std::string*> function_name_ptrs_; + std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_; + std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash> + function_name_to_index_; +}; + +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) 
+ : parent_(parent) {} + + virtual ~NestedExtensionIdRegistryImpl() {} + + std::vector<util::string_view> Uris() const override { + std::vector<util::string_view> uris = parent_->Uris(); + std::unordered_set<util::string_view> uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector<util::string_view>(uris); + } + + util::optional<TypeRecord> GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; } + return parent_->GetType(type); + } - util::optional<FunctionRecord> GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + util::optional<TypeRecord> GetType(Id id) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id); + if (type_opt) { + return type_opt; } + return parent_->GetType(id); + } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override { + return parent_->CanRegisterType(id, type) & + ExtensionIdRegistryImpl::CanRegisterType(id, type); + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Status RegisterType(Id id, std::shared_ptr<DataType> type) override { + return parent_->CanRegisterType(id, type) & + ExtensionIdRegistryImpl::RegisterType(id, type); Review Comment: I'm not convinced there is something unexpected here. Presumably, the use of `std::move` on a `std::shared_ptr<DataType>` invalidates the caller's copy, which leads to a segmentation fault when the caller tries to access it. AFAICS, passing or returning a `shared_ptr` by value is normal in Arrow, so it should be fine here too. 
To your question, I think my build is using gcc, for which I have version `gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org