This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e43555f739 [REFACTOR][DataType] Phase out target custom datatype support (#19760)
e43555f739 is described below

commit e43555f739474e318f1cfc15bd7a1951180d304b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 14 09:11:13 2026 -0400

    [REFACTOR][DataType] Phase out target custom datatype support (#19760)
    
    ## Summary
    
    The in-tree target custom datatype path adds maintenance burden, while
    current development focuses on core datatypes. This PR removes the
    built-in custom datatype registry and lowering implementation while
    leaving the behavior of core dtypes unchanged.
    
    - Remove the `src/target/datatype` implementation, the `USE_BYODT_POSIT`
    build option, and the related Python helpers (`tvm.target.datatype`)
    - Remove the `LowerCustomDatatypes` pass from the TIRX and S-TIR
    finalization pipelines
    - Simplify the remaining TIRX dtype handling so it only supports built-in (core) datatypes
---
 CMakeLists.txt                               |   6 -
 cmake/modules/contrib/Posit.cmake            |  26 --
 docker/Dockerfile.ci_cpu                     |   4 -
 docker/Dockerfile.ci_gpu                     |   4 -
 include/tvm/tirx/op.h                        |   7 -
 include/tvm/tirx/transform.h                 |   9 -
 python/tvm/s_tir/pipeline.py                 |   2 -
 python/tvm/target/__init__.py                |   1 -
 python/tvm/target/datatype.py                | 379 ---------------------------
 python/tvm/tirx/compilation_pipeline.py      |   2 -
 python/tvm/tirx/transform/transform.py       |  13 -
 src/arith/rewrite_simplify.cc                |   1 -
 src/target/datatype/myfloat/myfloat.cc       | 144 ----------
 src/target/datatype/posit/posit-wrapper.cc   | 242 -----------------
 src/target/datatype/registry.cc              | 138 ----------
 src/target/datatype/registry.h               | 182 -------------
 src/target/llvm/codegen_llvm.cc              |   7 +-
 src/tirx/op/op.cc                            |  25 +-
 src/tirx/transform/lower_custom_datatypes.cc | 266 -------------------
 19 files changed, 8 insertions(+), 1450 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad99c4c6ac..0eb8c90184 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,7 +87,6 @@ tvm_option(USE_CCACHE "Use ccache if found when invoking 
compiler" AUTO)
 # 3rdparty libraries
 tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
 # Contrib library options
-tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom 
datatype" OFF)
 tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_AMX "Enable Intel AMX" OFF)
 tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
@@ -356,10 +355,6 @@ tvm_file_glob(GLOB CODEGEN_SRCS
 
 list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
 
-tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
-list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
-list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
-
 tvm_file_glob(GLOB RUNTIME_SRCS
   src/runtime/*.cc
   src/runtime/vm/*.cc
@@ -464,7 +459,6 @@ include(cmake/modules/contrib/DNNL.cmake)
 include(cmake/modules/contrib/AMX.cmake)
 include(cmake/modules/contrib/CUTLASS.cmake)
 include(cmake/modules/contrib/Random.cmake)
-include(cmake/modules/contrib/Posit.cmake)
 include(cmake/modules/contrib/Sort.cmake)
 include(cmake/modules/contrib/CoreML.cmake)
 include(cmake/modules/contrib/TensorRT.cmake)
diff --git a/cmake/modules/contrib/Posit.cmake 
b/cmake/modules/contrib/Posit.cmake
deleted file mode 100644
index b8d180ee44..0000000000
--- a/cmake/modules/contrib/Posit.cmake
+++ /dev/null
@@ -1,26 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-if(USE_BYODT_POSIT)
-  message(STATUS "Build with contrib.posit")
-  if (NOT UNIVERSAL_PATH)
-    message(FATAL_ERROR "Fail to get Universal path")
-  endif(NOT UNIVERSAL_PATH)
-
-  include_directories(${UNIVERSAL_PATH}/include)
-  list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
-endif(USE_BYODT_POSIT)
diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu
index 8e31b310fe..e823db54b2 100644
--- a/docker/Dockerfile.ci_cpu
+++ b/docker/Dockerfile.ci_cpu
@@ -63,10 +63,6 @@ RUN bash /install/ubuntu_install_dnnl.sh
 COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh
 RUN bash /install/ubuntu_install_xgboost.sh
 
-# BYODT deps
-COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh
-RUN bash /install/ubuntu_install_universal.sh
-
 # TensorFlow deps
 COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh
 RUN bash /install/ubuntu_install_tensorflow.sh
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index df15215b94..2f9139f842 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -115,10 +115,6 @@ RUN bash /install/ubuntu_install_vulkan.sh
 COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh
 RUN bash /install/ubuntu_install_xgboost.sh
 
-# BYODT deps
-COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh
-RUN bash /install/ubuntu_install_universal.sh
-
 # sccache
 COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh
 RUN bash /install/ubuntu_install_sccache.sh
diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 2027665712..7a7584aff2 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -998,13 +998,6 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType 
value, Span span = Span())
   }
   if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || 
t.is_float4())
     return FloatImm(t, static_cast<double>(value), span);
-  // For now, we store const scalar values of custom datatypes within doubles; 
later, during the
-  // datatypes lowering pass, we will lower the value to its true 
representation in the format
-  // specified by the datatype.
-  // TODO(gus) when do we need to start worrying about doubles not being 
precise enough?
-  if (static_cast<uint8_t>(t.code()) >= 
static_cast<uint8_t>(DataType::kCustomBegin)) {
-    return FloatImm(t, static_cast<double>(value), span);
-  }
   TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
   throw;
 }
diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h
index 32a3ea8b29..e5a754f6c5 100644
--- a/include/tvm/tirx/transform.h
+++ b/include/tvm/tirx/transform.h
@@ -153,15 +153,6 @@ TVM_DLL Pass MakePackedAPI();
  */
 TVM_DLL Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> axis_map);
 
-/*!
- * \brief Lower custom datatypes.
- *
- * See tvm::datatypes::Registry for more information on adding custom 
datatypes.
- *
- * \return The pass.
- */
-TVM_DLL Pass LowerCustomDatatypes();
-
 /*!
  * \brief Annotate, split, and lower host/device functions.
  *
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index fb8310dc26..df1cd74a21 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -125,7 +125,6 @@ def finalize_host_passes():  # pylint: 
disable=unused-argument
     """The default finalization passes for TIR backend."""
     host_pass_list = [
         tirx.transform.LowerTVMBuiltin(),
-        tirx.transform.LowerCustomDatatypes(),
         tirx.transform.LowerIntrin(),
     ]
     return tvm.ir.transform.Sequential(host_pass_list)
@@ -136,7 +135,6 @@ def finalize_device_passes():  # pylint: 
disable=unused-argument
     device_pass_list = [
         tirx.transform.LowerWarpMemory(),
         tirx.transform.StmtSimplify(),
-        tirx.transform.LowerCustomDatatypes(),
         tirx.transform.LowerIntrin(),
     ]
     return tvm.ir.transform.Sequential(device_pass_list)
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 5c6733cf8c..7303a0d097 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -34,6 +34,5 @@ and :py:func:`tvm.target.register_tag` to register new tags.
 from .target import Target, TargetKind
 from .virtual_device import VirtualDevice
 from .tag import list_tags, register_tag
-from . import datatype
 from . import codegen
 from . import tag_registry  # registers tags on import
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
deleted file mode 100644
index d7a47836b2..0000000000
--- a/python/tvm/target/datatype.py
+++ /dev/null
@@ -1,379 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: F821
-"""Bring Your Own Datatypes custom datatype framework
-
-TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist"""
-
-from tvm_ffi import get_global_func
-from tvm_ffi import register_global_func as _register_global_func
-
-import tvm
-from tvm.runtime import DataType, convert
-from tvm.tirx import call_intrin
-from tvm.tirx.expr import (
-    BinaryOpExpr as _BinaryOpExpr,
-)
-from tvm.tirx.expr import (
-    Call as _Call,
-)
-from tvm.tirx.expr import (
-    Cast as _Cast,
-)
-from tvm.tirx.expr import (
-    FloatImm as _FloatImm,
-)
-from tvm.tirx.op import call_pure_extern
-
-
-def register(type_name, type_code):
-    """Register a custom datatype with the given type name and type code
-
-    Currently, the type code is manually allocated by the user, and the user
-    must ensure that no two custom types share the same code. Generally, this
-    should be straightforward, as the user will be manually registering all of
-    their custom types.
-
-    Example:
-
-    .. code-block:: python
-
-        # Register a dtype named 'posites2' under type code 130.
-        tvm.target.datatype.register('posites2', 130)
-
-
-    Parameters
-    ----------
-    type_name : str
-        The name of the custom datatype.
-
-    type_code : int
-        The type's code, which should be >= kCustomBegin. See
-        include/tvm/runtime/data_type.h.
-    """
-    get_global_func("dtype.register_custom_type")(type_name, type_code)
-
-
-def get_type_name(type_code):
-    """Get the type name of a custom datatype from the type code.
-
-    Note that this only works for custom datatypes registered with
-    tvm.target.datatype.register(). It does not work for TVM-native types.
-
-    Example:
-
-    .. code-block:: python
-
-        tvm.target.datatype.register('posites2', 130)
-        assert tvm.target.datatype.get_type_name(130) == 'posites2'
-
-    Parameters
-    ----------
-    type_code : int
-        The type code of the custom datatype.
-
-    Returns
-    -------
-    type_name : String
-        The name of the custom datatype.
-
-    """
-    return get_global_func("dtype.get_custom_type_name")(type_code)
-
-
-def get_type_code(type_name):
-    """Get the type code of a custom datatype from its type name
-
-    Note that this only works for custom datatypes registered with
-    tvm.target.datatype.register(). It does not work for TVM-native types.
-
-    Example:
-
-    .. code-block:: python
-
-        tvm.target.datatype.register('posites2', 130)
-        assert tvm.target.datatype.get_type_code('posites2') == 130
-
-    Parameters
-    ----------
-    type_name : str
-        The type name
-
-    Returns
-    -------
-    type_code : int
-        The type code of the custom datatype.
-    """
-    return get_global_func("dtype.get_custom_type_code")(type_name)
-
-
-def get_type_registered(type_code):
-    """Returns true if a custom datatype is registered under the given type 
code
-
-    Example:
-
-    .. code-block:: python
-
-        tvm.target.datatype.register('posites2', 130)
-        assert tvm.target.datatype.get_type_registered(130)
-
-    Parameters
-    ----------
-    type_code: int
-        The type code
-
-    Returns
-    -------
-    type_registered : bool
-        True if a custom datatype is registered under this type code, and false
-        otherwise.
-    """
-    return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
-
-
-def register_op(
-    lower_func, op_name, target, src_type_name, dest_type_name=None, 
intrinsic_name=None
-):
-    """Register a lowering function for a specific operator of a custom 
datatype
-
-    At build time, Relay must lower operators over custom datatypes into
-    operators it understands how to compile. For each custom datatype operator
-    which Relay finds while lowering custom datatypes, Relay expects to find a
-    user-defined lowering function. Users register their user-defined lowering
-    functions using this function.
-
-    Users should use create_lower_func to create their lowering function. It
-    should serve most use-cases.
-
-    Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and
-    binary expressions (e.g. Add, Sub, Mul, Div).
-
-    See the LowerCustomDatatypes pass to see how registered functions are used.
-
-    Lowering Functions
-    ------------------
-    TODO(@gussmith23) Get the terminology right here.
-    Lowering functions take in a Relay node, and should return a semantically
-    equivalent Relay node which Relay can build. This means that the returned
-    node should not contain any custom datatypes. Users should likely not need
-    to define lowering functions by hand -- see the helper function
-    create_lower_func.
-
-    Parameters
-    ----------
-    lower_func : function
-        The lowering function to call. See create_lower_func.
-
-    op_name : str
-        The name of the operation which the function computes, given by its
-        class name (e.g. Add, LE, Cast, Call).
-
-    target : str
-        The name of codegen target.
-
-    src_type_name : str
-        The name of the custom datatype, e.g. posites2 (but not 
custom[posites2]32).
-        If op_name is not "Cast", then target type is guaranteed to be the 
same as src_type_name.
-
-    dest_type_name : str
-        If op_name is "Cast", then this is required and should be set to the 
dest datatype of
-        the argument to the Cast. If op_name is not "Cast", this is unused.
-
-    intrinsic_name : str
-        If op_name is "Call" and intrinsic_name is not None, then we assume the
-        op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
-        name.
-    """
-
-    if op_name == "Cast":
-        assert dest_type_name is not None
-        lower_func_name = (
-            "tvm.datatype.lower."
-            + target
-            + "."
-            + op_name
-            + "."
-            + dest_type_name
-            + "."
-            + src_type_name
-        )
-    elif op_name == "Call" and intrinsic_name is not None:
-        lower_func_name = (
-            "tvm.datatype.lower."
-            + target
-            + "."
-            + op_name
-            + ".intrin."
-            + intrinsic_name
-            + "."
-            + src_type_name
-        )
-    else:
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." 
+ src_type_name
-    tvm_ffi.register_global_func(lower_func_name, lower_func)
-
-
-def register_min_func(func, type_name):
-    """Register the function that returns the minimum representable value of 
type_name.
-
-    Operators such as max pooling and argmax require the minimum
-    finite value representable by the datatype the op operating on.
-    Users can use this function to register a function that returns a TIR 
expression node
-    outputting the minimum representable value of their custom data type.
-
-    Users should use create_min_lower_func to create their lowering function. 
It
-    should serve most use-cases.
-
-    Note: for special cases when it is known that the custom datatype is 
representable
-    by a float, the user can create their own lowering func that returns a 
FloatImm.
-    The benefits are allowing optimizations such as rewrites to work as 
expected on custom
-    datatypes.
-
-    Parameters
-    ----------
-    func : function
-        Input is an integer num_bits, should return a TIR expression node that
-        represents a scalar tensor of type custom[type_name]num_bits with the 
minimum
-        representable value.
-
-    type_name : str
-        The name of the custom datatype, e.g. posites2 (but not 
custom[posites2]32).
-    """
-    _register_global_func("tvm.datatype.min." + type_name, func)
-
-
-def create_min_lower_func(extern_func_map, type_name):
-    """Returns a lowering function for getting the minimum value of a custom 
datatype.
-
-    Parameters
-    ----------
-    extern_func_map : map
-        A map from bit lengths to the name of the extern "C" function to lower 
to.
-
-    type_name : string
-        The name of the custom datatype, e.g. posites2 (but not 
custom[posites2]32).
-    """
-
-    def lower(num_bits):
-        dtype = f"custom[{type_name}]{num_bits}"
-
-        if num_bits not in extern_func_map:
-            raise RuntimeError("missing minimum function for {dtype}")
-
-        return call_pure_extern(dtype, extern_func_map[num_bits])
-
-    return lower
-
-
-def create_lower_func(extern_func_map):
-    """Returns a function which lowers an operation to a function call.
-
-    Parameters
-    ----------
-    extern_func_map : map
-        If lowering a Cast, extern_func_map should be a map from tuples of
-        (src_bit_length, dest_bit_length) to the name of the extern "C" 
function to lower to.
-
-        Otherwise, for unary and binary ops, it should simply be a map
-        from bit_length to the name of the extern "C" function to lower to.
-    """
-
-    def lower(op):
-        """
-        Takes an op---either a Cast, Call, or a binary op (e.g. an Add) and 
returns a
-        call to the specified external function, passing the op's argument
-        or arguments. The return type of the call depends
-        on the type of the op: if it is a custom type, then a uint of the same
-        width as the custom type is returned. Otherwise, the type is
-        unchanged."""
-        dtype = op.dtype
-        t = DataType(dtype)
-        if get_type_registered(t.type_code):
-            dtype = "uint" + str(t.bits)
-            if t.lanes > 1:
-                dtype += "x" + str(t.lanes)
-
-        key = t.bits
-        if isinstance(op, _Cast):
-            src_bits = DataType(op.value.dtype).bits
-            key = (src_bits, t.bits)
-
-        if key not in extern_func_map:
-            raise RuntimeError(f"missing key {key} in extern_func_map for 
{op}")
-
-        if isinstance(op, _Cast):
-            return call_pure_extern(dtype, extern_func_map[key], op.value)
-        if isinstance(op, _FloatImm):
-            return call_pure_extern(dtype, extern_func_map[key], op.value)
-        if isinstance(op, _Call):
-            return call_pure_extern(dtype, extern_func_map[key], *op.args)
-        if isinstance(op, _BinaryOpExpr):
-            return call_pure_extern(dtype, extern_func_map[key], op.a, op.b)
-
-        raise RuntimeError(f"lowering unsupported op: {op}")
-
-    return lower
-
-
-def lower_ite(ite_op):
-    """Lowered if then else function that calls intrinsic if_then_else.
-    Unlike a function lowered by create_lower_func, this function
-    calls the tvm intrinsic if_then_else.
-
-    Parameters
-    ----------
-    ite_op : Op
-        Takes an if then else op and returns a
-        call to tirx.if_then_else function, passing the op's
-        arguments. The return type of the call if a uint of the same
-        width as the custom type is returned.
-    """
-    dtype = ite_op.dtype
-    t = tvm.DataType(dtype)
-    assert get_type_registered(t.type_code)
-    dtype = "uint" + str(t.bits)
-    if t.lanes > 1:
-        dtype += "x" + str(t.lanes)
-    return call_intrin(
-        dtype,
-        "tirx.if_then_else",
-        convert(ite_op.args[0]),
-        convert(ite_op.args[1]),
-        convert(ite_op.args[2]),
-    )
-
-
-def lower_call_pure_extern(op):
-    """Lowered call pure extern function that calls intrinsic call_pure_extern.
-    Unlike a function lowered by create_lower_func, this function
-    calls the tvm intrinsic call_pure_extern.
-
-    Parameters
-    ----------
-    ite_op : Op
-        Takes a call_pure_extern op and returns a
-        call to tirx.call_pure_extern function, passing the op's
-        arguments. The return type of the call if a uint of the same
-        width as the custom type is returned.
-    """
-    dtype = op.dtype
-    t = tvm.DataType(dtype)
-    assert get_type_registered(t.type_code)
-    dtype = "uint" + str(t.bits)
-    if t.lanes > 1:
-        dtype += "x" + str(t.lanes)
-    return call_intrin(dtype, "tirx.call_pure_extern", *op.args)
diff --git a/python/tvm/tirx/compilation_pipeline.py 
b/python/tvm/tirx/compilation_pipeline.py
index d2847332b4..23dee416bb 100644
--- a/python/tvm/tirx/compilation_pipeline.py
+++ b/python/tvm/tirx/compilation_pipeline.py
@@ -103,7 +103,6 @@ def finalize_host_passes():  # pylint: 
disable=unused-argument
     """The default finalization passes for TIR backend."""
     host_pass_list = [
         tirx.transform.LowerTVMBuiltin(),
-        tirx.transform.LowerCustomDatatypes(),
         tirx.transform.LowerIntrin(),
     ]
     return tvm.ir.transform.Sequential(host_pass_list)
@@ -114,7 +113,6 @@ def finalize_device_passes():  # pylint: 
disable=unused-argument
     device_pass_list = [
         tirx.transform.LowerWarpMemory(),
         tirx.transform.StmtSimplify(),
-        tirx.transform.LowerCustomDatatypes(),
         tirx.transform.LowerIntrin(),
     ]
     return tvm.ir.transform.Sequential(device_pass_list)
diff --git a/python/tvm/tirx/transform/transform.py 
b/python/tvm/tirx/transform/transform.py
index 72a5b96202..ae6b942b66 100644
--- a/python/tvm/tirx/transform/transform.py
+++ b/python/tvm/tirx/transform/transform.py
@@ -245,19 +245,6 @@ def ConvertSSA():
     return _ffi_api.ConvertSSA()  # type: ignore
 
 
-def LowerCustomDatatypes():
-    """Lower custom datatypes.
-
-    See tvm::datatypes::Registry for more information on adding custom 
datatypes.
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.LowerCustomDatatypes()  # type: ignore
-
-
 def MakePackedAPI():
     """Transform the PrimFuncs in the module to a packed func API.
 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index bec5091883..5a86cdd15a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -34,7 +34,6 @@
 #include <tuple>
 #include <utility>
 
-#include "../target/datatype/registry.h"
 #include "../tirx/analysis/check_contains.h"
 #include "conjunctive_normal_form.h"
 #include "const_fold.h"
diff --git a/src/target/datatype/myfloat/myfloat.cc 
b/src/target/datatype/myfloat/myfloat.cc
deleted file mode 100644
index afee8a7c4b..0000000000
--- a/src/target/datatype/myfloat/myfloat.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file 3rdparty/byodt/my-custom-datatype.cc
- * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) 
framework.
- * This is a toy example that under the hood simulates floats.
- *
- * Users interested in using the BYODT framework can use this file as a 
template.
- *
- * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
- */
-#include <tvm/runtime/base.h>
-
-#include <cmath>
-#include <cstdint>
-#include <limits>
-
-// Custom datatypes are stored as bits in a uint of the appropriate bit length.
-// Thus, when TVM calls these C functions,
-// the arguments of are uints that need to reinterpreted as your custom 
datatype.
-//
-// When returning, your custom datatype needs to be re-wrapped into a uint,
-// which can be thought of as just a wrapper for the raw bits that represent 
your custom datatype.
-template <class T>
-TVM_DLL T Uint32ToCustom32(uint32_t in) {
-  // This is a helper function to interpret the uint as your custom dataype.
-  // The following line should be replaced with the appropriate function
-  // that interprets the bits in `in` and returns your custom datatype
-  T* custom = reinterpret_cast<T*>(&in);
-  return *custom;
-}
-
-template <class T>
-TVM_DLL uint32_t Custom32ToUint32(T in) {
-  // This is a helper function to wrap your custom datatype in a uint.
-  // the following line should be replaced with the appropriate function
-  // that converts your custom datatype into a uint
-  uint32_t* bits = reinterpret_cast<uint32_t*>(&in);
-  return *bits;
-}
-
-extern "C" {
-TVM_DLL uint32_t MinCustom32() {
-  // return minimum representable value
-  float min = std::numeric_limits<float>::lowest();
-  return Custom32ToUint32<float>(min);
-}
-
-TVM_DLL float Custom32ToFloat(uint32_t in) {
-  // cast from custom datatype to float
-  float custom_datatype = Uint32ToCustom32<float>(in);
-  // our custom datatype is float, so the following redundant cast to float
-  // is to remind users to cast their own custom datatype to float
-  return static_cast<float>(custom_datatype);
-}
-
-TVM_DLL uint32_t FloatToCustom32(float in) {
-  // cast from float to custom datatype
-  return Custom32ToUint32<float>(in);
-}
-
-TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
-  // add operation
-  float acustom = Uint32ToCustom32<float>(a);
-  float bcustom = Uint32ToCustom32<float>(b);
-  return Custom32ToUint32<float>(acustom + bcustom);
-}
-
-TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
-  // subtract
-  float acustom = Uint32ToCustom32<float>(a);
-  float bcustom = Uint32ToCustom32<float>(b);
-  return Custom32ToUint32<float>(acustom - bcustom);
-}
-
-TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
-  // multiply
-  float acustom = Uint32ToCustom32<float>(a);
-  float bcustom = Uint32ToCustom32<float>(b);
-  return Custom32ToUint32<float>(acustom * bcustom);
-}
-
-TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
-  // divide
-  float acustom = Uint32ToCustom32<float>(a);
-  float bcustom = Uint32ToCustom32<float>(b);
-  return Custom32ToUint32<float>(acustom / bcustom);
-}
-
-TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
-  // max
-  float acustom = Uint32ToCustom32<float>(a);
-  float bcustom = Uint32ToCustom32<float>(b);
-  return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
-}
-
-TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
-  // sqrt
-  float acustom = Uint32ToCustom32<float>(a);
-  return Custom32ToUint32<float>(sqrt(acustom));
-}
-
-TVM_DLL uint32_t Custom32Exp(uint32_t a) {
-  // exponential
-  float acustom = Uint32ToCustom32<float>(a);
-  return Custom32ToUint32<float>(exp(acustom));
-}
-
-TVM_DLL uint32_t Custom32Log(uint32_t a) {
-  // log
-  float acustom = Uint32ToCustom32<float>(a);
-  return Custom32ToUint32<float>(log(acustom));
-}
-
-TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
-  // sigmoid
-  float acustom = Uint32ToCustom32<float>(a);
-  float one = 1.0f;
-  return Custom32ToUint32<float>(one / (one + exp(-acustom)));
-}
-
-TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
-  // tanh
-  float acustom = Uint32ToCustom32<float>(a);
-  return Custom32ToUint32<float>(tanh(acustom));
-}
-}
diff --git a/src/target/datatype/posit/posit-wrapper.cc 
b/src/target/datatype/posit/posit-wrapper.cc
deleted file mode 100644
index e05695e603..0000000000
--- a/src/target/datatype/posit/posit-wrapper.cc
+++ /dev/null
@@ -1,242 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file 3rdparty/posit/posit-wrapper.cc
- * \brief Wrapper over the Stillwater Universal library for Bring Your Own 
Datatypes tests
- *
- * To compile TVM with this file,
- * 1. clone the Stillwater Universal repo from here 
`https://github.com/stillwater-sc/universal`.
- * 2. set `SET_BYODT_POSIT` ON and `UNIVERSAL_PATH` as the path to the folder 
containing Stillwater
- * Universal in your CMake file
- *
- * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
- */
-#include <tvm/runtime/base.h>
-
-#include <cstdint>
-
-#include "universal/posit/posit.hpp"
-// must go after posit.hpp
-#include "universal/posit/math/exponent.hpp"
-#include "universal/posit/math/hyperbolic.hpp"
-#include "universal/posit/math/logarithm.hpp"
-#include "universal/posit/math/sqrt.hpp"
-#include "universal/posit/numeric_limits.hpp"
-
-TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
-  sw::unum::bitblock<8> bb;
-  bb = static_cast<uint64_t>(in);
-  return sw::unum::posit<8, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
-  return static_cast<uint8_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit8es2() {
-  auto min = std::numeric_limits<sw::unum::posit<8, 2>>::lowest();
-  return Posit8es2toUint8(min);
-}
-
-TVM_DLL float Posit8es2ToFloat(uint8_t in) { return 
Uint8ToPosit8es2(in).operator float(); }
-
-TVM_DLL uint8_t FloatToPosit8es2(float in) {
-  auto posit = sw::unum::posit<8, 2>(in);
-  return Posit8es2toUint8(posit);
-}
-
-TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
-  return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
-  return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
-  return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
-  return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
-  auto a_p = Uint8ToPosit8es2(a);
-  auto b_p = Uint8ToPosit8es2(b);
-  return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
-  return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
-  return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
-  return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
-  auto posit_one = sw::unum::posit<8, 2>(1);
-  return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) + 
posit_one));
-}
-
-TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
-  return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
-}
-}
-
-TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
-  sw::unum::bitblock<16> bb;
-  bb = static_cast<uint64_t>(in);
-  return sw::unum::posit<16, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
-  return static_cast<uint16_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit16es2() {
-  auto min = std::numeric_limits<sw::unum::posit<16, 2>>::lowest();
-  return Posit16es2toUint16(min);
-}
-
-TVM_DLL float Posit16es2ToFloat(uint16_t in) { return 
Uint16ToPosit16es2(in).operator float(); }
-
-TVM_DLL uint16_t FloatToPosit16es2(float in) {
-  auto posit = sw::unum::posit<16, 2>(in);
-  return Posit16es2toUint16(posit);
-}
-
-TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
-  return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
-  return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
-  return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
-  return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
-  auto a_p = Uint16ToPosit16es2(a);
-  auto b_p = Uint16ToPosit16es2(b);
-  return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
-  return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
-  return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
-  return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
-  auto posit_one = sw::unum::posit<16, 2>(1);
-  return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a)) 
+ posit_one));
-}
-
-TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
-  return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
-}
-}
-
-TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
-  sw::unum::bitblock<32> bb;
-  bb = static_cast<uint64_t>(in);
-  return sw::unum::posit<32, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) {
-  return static_cast<uint32_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit32es2() {
-  auto min = std::numeric_limits<sw::unum::posit<32, 2>>::lowest();
-  return Posit32es2ToUint32(min);
-}
-
-TVM_DLL float Posit32es2ToFloat(uint32_t in) { return 
Uint32ToPosit32es2(in).operator float(); }
-
-TVM_DLL uint32_t FloatToPosit32es2(float in) {
-  auto posit = sw::unum::posit<32, 2>(in);
-  return Posit32es2ToUint32(posit);
-}
-
-TVM_DLL uint32_t Posit32es2Add(uint32_t a, uint32_t b) {
-  return Posit32es2ToUint32(Uint32ToPosit32es2(a) + Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Sub(uint32_t a, uint32_t b) {
-  return Posit32es2ToUint32(Uint32ToPosit32es2(a) - Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Mul(uint32_t a, uint32_t b) {
-  return Posit32es2ToUint32(Uint32ToPosit32es2(a) * Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Div(uint32_t a, uint32_t b) {
-  return Posit32es2ToUint32(Uint32ToPosit32es2(a) / Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Max(uint32_t a, uint32_t b) {
-  auto a_p = Uint32ToPosit32es2(a);
-  auto b_p = Uint32ToPosit32es2(b);
-  return Posit32es2ToUint32(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint32_t Posit32es2Sqrt(uint32_t a) {
-  return Posit32es2ToUint32(sw::unum::sqrt(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Exp(uint32_t a) {
-  return Posit32es2ToUint32(sw::unum::exp(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Log(uint32_t a) {
-  return Posit32es2ToUint32(sw::unum::log(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Sigmoid(uint32_t a) {
-  auto posit_one = sw::unum::posit<32, 2>(1);
-  return Posit32es2ToUint32(posit_one / (posit_one + 
sw::unum::exp(-Uint32ToPosit32es2(a))));
-}
-
-TVM_DLL uint32_t Posit32es2Tanh(uint32_t a) {
-  return Posit32es2ToUint32(sw::unum::tanh(Uint32ToPosit32es2(a)));
-}
-}
diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc
deleted file mode 100644
index 9d6459df6c..0000000000
--- a/src/target/datatype/registry.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include "registry.h"
-
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/data_type.h>
-
-namespace tvm {
-namespace datatype {
-
-using ffi::Any;
-using ffi::PackedArgs;
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef()
-      .def_packed("dtype.register_custom_type",
-                  [](ffi::PackedArgs args, ffi::Any* ret) {
-                    datatype::Registry::Global()->Register(
-                        args[0].cast<std::string>(), 
static_cast<uint8_t>(args[1].cast<int>()));
-                  })
-      .def_packed("dtype.get_custom_type_code",
-                  [](ffi::PackedArgs args, ffi::Any* ret) {
-                    *ret = 
datatype::Registry::Global()->GetTypeCode(args[0].cast<std::string>());
-                  })
-      .def_packed("dtype.get_custom_type_name",
-                  [](ffi::PackedArgs args, ffi::Any* ret) {
-                    *ret = 
Registry::Global()->GetTypeName(args[0].cast<int>());
-                  })
-      .def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs 
args, ffi::Any* ret) {
-        *ret = Registry::Global()->GetTypeRegistered(args[0].cast<int>());
-      });
-}
-
-Registry* Registry::Global() {
-  static Registry inst;
-  return &inst;
-}
-
-void Registry::Register(const std::string& type_name, uint8_t type_code) {
-  TVM_FFI_ICHECK(type_code >= DataType::kCustomBegin)
-      << "Please choose a type code >= DataType::kCustomBegin for custom 
types";
-  code_to_name_[type_code] = type_name;
-  name_to_code_[type_name] = type_code;
-}
-
-uint8_t Registry::GetTypeCode(const std::string& type_name) {
-  TVM_FFI_ICHECK(name_to_code_.find(type_name) != name_to_code_.end())
-      << "Type name " << type_name << " not registered";
-  return name_to_code_[type_name];
-}
-
-std::string Registry::GetTypeName(uint8_t type_code) {
-  TVM_FFI_ICHECK(code_to_name_.find(type_code) != code_to_name_.end())
-      << "Type code " << static_cast<unsigned>(type_code) << " not registered";
-  return code_to_name_[type_code];
-}
-
-std::optional<tvm::ffi::Function> GetCastLowerFunc(const std::string& target, 
uint8_t type_code,
-                                                   uint8_t src_type_code) {
-  std::ostringstream ss;
-  ss << "tvm.datatype.lower.";
-  ss << target << ".";
-  ss << "Cast"
-     << ".";
-
-  if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
-    ss << datatype::Registry::Global()->GetTypeName(type_code);
-  } else {
-    ss << 
ffi::details::DLDataTypeCodeAsCStr(static_cast<DLDataTypeCode>(type_code));
-  }
-
-  ss << ".";
-
-  if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
-    ss << datatype::Registry::Global()->GetTypeName(src_type_code);
-  } else {
-    ss << 
ffi::details::DLDataTypeCodeAsCStr(static_cast<DLDataTypeCode>(src_type_code));
-  }
-  return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetMinFunc(uint8_t type_code) {
-  std::ostringstream ss;
-  ss << "tvm.datatype.min.";
-  ss << datatype::Registry::Global()->GetTypeName(type_code);
-  return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetFloatImmLowerFunc(const std::string& 
target,
-                                                       uint8_t type_code) {
-  std::ostringstream ss;
-  ss << "tvm.datatype.lower.";
-  ss << target;
-  ss << ".FloatImm.";
-  ss << datatype::Registry::Global()->GetTypeName(type_code);
-  return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetIntrinLowerFunc(const std::string& target,
-                                                     const std::string& name, 
uint8_t type_code) {
-  std::ostringstream ss;
-  ss << "tvm.datatype.lower.";
-  ss << target;
-  ss << ".Call.intrin.";
-  ss << name;
-  ss << ".";
-  ss << datatype::Registry::Global()->GetTypeName(type_code);
-  return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-uint64_t ConvertConstScalar(uint8_t type_code, double value) {
-  std::ostringstream ss;
-  ss << "tvm.datatype.convertconstscalar.float.";
-  ss << datatype::Registry::Global()->GetTypeName(type_code);
-  auto make_const_scalar_func = tvm::ffi::Function::GetGlobal(ss.str());
-  return (*make_const_scalar_func)(value).cast<uint64_t>();
-}
-
-}  // namespace datatype
-}  // namespace tvm
diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h
deleted file mode 100644
index 363494e0fd..0000000000
--- a/src/target/datatype/registry.h
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_TARGET_DATATYPE_REGISTRY_H_
-#define TVM_TARGET_DATATYPE_REGISTRY_H_
-
-#include <tvm/ffi/function.h>
-
-#include <string>
-#include <unordered_map>
-
-namespace tvm {
-namespace datatype {
-
-/*!
- * \brief Registry for custom datatypes.
- *
- * Adding custom datatypes currently requires two steps:
- * 1. Register the datatype with the registry via a call to
- *    datatype::Registry::Register. This can also be done in Python
- *    directly---see the TVM globals registered in the corresponding .cc file.
- *    Currently, user should manually choose a type name and a type code,
- *    ensuring that neither conflict with existing types.
- * 2. Register the lowering functions needed to
- *    lower the custom datatype. In general, these will look like:
- *      For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
- *        Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
- *                 float to myfloat.
- * For intrinsic Calls: tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
- *             Example: tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat
- *  For other ops: tvm.datatype.lower.<target>.<op>.<type>
- *       Examples: tvm.datatype.lower.llvm.Add.myfloat
- *                 tvm.datatype.lower.llvm.FloatImm.posit
- */
-class Registry {
- public:
-  /*!
-   * \brief Get the global custom datatype registry singleton
-   */
-  static Registry* Global();
-
-  /*!
-   * \brief Register custom datatype
-   * Register a custom datatype with the given type name and type code. 
Currently, the type code is
-   * manually allocated by the user, and the user must ensure that no two 
custom types share the
-   * same code. Generally, this should be straightforward, as the user will be 
manually registering
-   * all of their custom types.
-   * \param type_name The name of the type, e.g. "posites2"
-   * \param type_code The type code, which should be greater than 
TVMArgTypeCode::kTVMExtEnd
-   */
-  void Register(const std::string& type_name, uint8_t type_code);
-
-  /*!
-   * \brief Get type code from type name
-   * \param type_name The type name
-   * \return The type code
-   */
-  uint8_t GetTypeCode(const std::string& type_name);
-
-  /*!
-   * \brief Get type name from type code
-   * \param type_code The type code
-   * \return The type name
-   */
-  std::string GetTypeName(uint8_t type_code);
-
-  /*!
-   * \brief Get bool representing whether type is registered, given the type 
code
-   * \param type_code The type code
-   * \return bool representing whether the type is registered
-   */
-  inline bool GetTypeRegistered(uint8_t type_code) {
-    return code_to_name_.find(type_code) != code_to_name_.end();
-  }
-
-  /*!
-   * \brief Get bool representing whether type is registered, given the type 
name
-   * \param type_name The type name
-   * \return bool representing whether the type is registered
-   */
-  inline bool GetTypeRegistered(std::string type_name) {
-    return name_to_code_.find(type_name) != name_to_code_.end();
-  }
-
- private:
-  // TODO(gus) is there a typedef for the code?
-  std::unordered_map<uint8_t, std::string> code_to_name_;
-  std::unordered_map<std::string, uint8_t> name_to_code_;
-};
-
-/*!
- * \brief Convert scalar value to a custom datatype format
- * \param type_code The custom datatype to convert to, specified by type code
- * \param value The floating point value to convert
- * \return The value, encoded in the bits of a uint64_t
- */
-uint64_t ConvertConstScalar(uint8_t type_code, double value);
-
-/*!
- * \brief Get a function returning the minimum value for a datatype.
- * \param type_code The datatype
- * \return Function which takes the width of the datatype and returns the min 
value
- */
-std::optional<tvm::ffi::Function> GetMinFunc(uint8_t type_code);
-
-/*!
- * \brief Get lowering function for Cast ops
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype being cast to
- * \param src_type_code The datatype being cast from
- * \return Lowering function for Cast ops for the provided target, type, and 
source type
- */
-std::optional<tvm::ffi::Function> GetCastLowerFunc(const std::string& target, 
uint8_t type_code,
-                                                   uint8_t src_type_code);
-
-/*!
- * \brief Get lowering function for FloatImms
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the FloatImm
- * \return Lowering function for FloatImms for the provided target and type
- */
-std::optional<tvm::ffi::Function> GetFloatImmLowerFunc(const std::string& 
target,
-                                                       uint8_t type_code);
-
-/*!
- * \brief Get lowering function for intrinsic Calls/pure intrinsic Calls
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the Call
- * \param name The intrinsic name
- * \return Lowering function for intrinsic Calls for the provided target and 
type
- */
-std::optional<tvm::ffi::Function> GetIntrinLowerFunc(const std::string& target,
-                                                     const std::string& name, 
uint8_t type_code);
-
-/*!
- * \brief Get lowering function for other ops
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the op
- * \return Lowering function for other ops for the provided target and type
- */
-#define DEFINE_GET_LOWER_FUNC_(OP)                                             
                 \
-  inline std::optional<tvm::ffi::Function> Get##OP##LowerFunc(const 
std::string& target,        \
-                                                              uint8_t 
type_code) {              \
-    return tvm::ffi::Function::GetGlobal("tvm.datatype.lower." + target + "." 
#OP "." +         \
-                                         
datatype::Registry::Global()->GetTypeName(type_code)); \
-  }
-
-DEFINE_GET_LOWER_FUNC_(Add)
-DEFINE_GET_LOWER_FUNC_(Sub)
-DEFINE_GET_LOWER_FUNC_(Mul)
-DEFINE_GET_LOWER_FUNC_(Div)
-DEFINE_GET_LOWER_FUNC_(Mod)
-DEFINE_GET_LOWER_FUNC_(Min)
-DEFINE_GET_LOWER_FUNC_(Max)
-DEFINE_GET_LOWER_FUNC_(EQ)
-DEFINE_GET_LOWER_FUNC_(NE)
-DEFINE_GET_LOWER_FUNC_(LT)
-DEFINE_GET_LOWER_FUNC_(LE)
-DEFINE_GET_LOWER_FUNC_(GT)
-DEFINE_GET_LOWER_FUNC_(GE)
-// Later changes may need to add more lowering functions as we support 
workloads with more ops.
-
-}  // namespace datatype
-}  // namespace tvm
-
-#endif  // TVM_TARGET_DATATYPE_REGISTRY_H_
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 97422bf9ed..88a28ebccb 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -597,11 +597,10 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) 
const {
   if (auto* ptr = type.as<PrimTypeNode>()) {
     return DTypeToLLVMType(ptr->dtype);
   } else if (auto* ptr = type.as<PointerTypeNode>()) {
-    // LLVM IR doesn't allow void*, nor do we require custom datatypes
-    // to have LLVM equivalents, so we need to recognize these
-    // patterns explicitly.
+    // LLVM IR doesn't allow void*, so pointer element types that do not
+    // have an LLVM scalar equivalent need explicit handling.
     if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
-      if (primtype->dtype.is_void() || primtype->dtype.code() >= 
DataType::kCustomBegin) {
+      if (primtype->dtype.is_void()) {
         return t_void_p_;
       }
     } else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index 64f7f575d6..5cf896e4fd 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -35,7 +35,6 @@
 #include <cmath>
 // Centralized header for constant folders.
 #include "../../arith/const_fold.h"
-#include "../../target/datatype/registry.h"
 #include "../analysis/check_contains.h"
 
 namespace tvm {
@@ -211,22 +210,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, 
Span span) {  // NOLINT(*)
     } else {
       rhs = cast(ltype, rhs);
     }
-  } else if (!ltype.is_float() &&
-             (rtype.is_float() || 
datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
+  } else if (!ltype.is_float() && rtype.is_float()) {
     // Cast int->float when the other operand is a float
     lhs = cast(rtype, lhs);
-  } else if ((ltype.is_float() || 
datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
-             !rtype.is_float()) {
+  } else if (ltype.is_float() && !rtype.is_float()) {
     // Cast int->float when the other operand is a float
     rhs = cast(ltype, rhs);
-  } else if (!ltype.is_bfloat16() &&
-             (rtype.is_bfloat16() ||
-              datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
+  } else if (!ltype.is_bfloat16() && rtype.is_bfloat16()) {
     // Cast int->bfloat16 when the other operand is a bfloat16
     lhs = cast(rtype, lhs);
-  } else if ((ltype.is_bfloat16() ||
-              datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
-             !rtype.is_bfloat16()) {
+  } else if (ltype.is_bfloat16() && !rtype.is_bfloat16()) {
     // Cast int->bfloat16 when the other operand is a bfloat16
     rhs = cast(ltype, rhs);
   } else if (!ltype.is_float8() && rtype.is_float8()) {
@@ -369,15 +362,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
 PrimExpr min_value(const DataType& dtype, Span span) {
   using namespace tirx;
   TVM_FFI_ICHECK_EQ(dtype.lanes(), 1);
-  if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
-    // TODO(tkonolige): need to convert all registered min functions to use 
the span.
-    auto f = datatype::GetMinFunc(dtype.code());
-    TVM_FFI_ICHECK(f) << "No minimum function registered for custom dtype "
-                      << (unsigned int)dtype.code();
-    // TODO(@hypercubestart) Document this change (and others associated with 
the overflowing
-    // floatimm min bug)
-    return (*f)(dtype.bits()).cast<PrimExpr>();
-  } else if (dtype.is_int()) {
+  if (dtype.is_int()) {
     if (dtype.bits() == 64) {
       return IntImm(dtype, std::numeric_limits<int64_t>::lowest(), span);
     } else if (dtype.bits() < 64) {
diff --git a/src/tirx/transform/lower_custom_datatypes.cc 
b/src/tirx/transform/lower_custom_datatypes.cc
deleted file mode 100644
index d23cfef4fb..0000000000
--- a/src/tirx/transform/lower_custom_datatypes.cc
+++ /dev/null
@@ -1,266 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file tvm/src/pass/lower_custom_datatypes.cc
- * \brief Pass for lowering custom datatypes
- */
-
-#include <tvm/ffi/cast.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/target/target.h>
-#include <tvm/tirx/op.h>
-#include <tvm/tirx/stmt_functor.h>
-#include <tvm/tirx/transform.h>
-
-#include "../../target/datatype/registry.h"
-
-namespace tvm {
-namespace tirx {
-
-/*!
- * \brief Helper mutator to implement lowering of custom datatypes.
- *
- * Lowering datatypes works as follows: for every expression containing a 
custom
- * datatype, we search for a global (registered by the implementer of the 
custom
- * datatype) for lowering this type of expression, and uses it to lower the
- * expression.
- */
-class CustomDatatypesLowerer : public StmtExprMutator {
- public:
-  explicit CustomDatatypesLowerer(const std::string& target) : target_(target) 
{}
-
-  PrimExpr VisitExpr_(const CastNode* op) final {
-    auto type_code = op->dtype.code();
-    auto src_type_code = op->value.dtype().code();
-    // If either datatype is a registered custom datatype, we must lower.
-    bool to_be_lowered = 
datatype::Registry::Global()->GetTypeRegistered(type_code) ||
-                         
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    if (to_be_lowered) {
-      auto lower = datatype::GetCastLowerFunc(target_, type_code, 
src_type_code);
-      TVM_FFI_ICHECK(lower) << "Cast lowering function for target " << target_
-                            << " destination type " << 
static_cast<unsigned>(type_code)
-                            << " source type " << 
static_cast<unsigned>(src_type_code)
-                            << " not found";
-      return (*lower)(expr).cast<PrimExpr>();
-    }
-    return expr;
-  }
-
-  PrimExpr VisitExpr_(const FloatImmNode* imm) final {
-    auto type_code = imm->dtype.code();
-    auto e = ffi::GetRef<PrimExpr>(imm);
-    if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
-      auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
-      TVM_FFI_ICHECK(lower) << "FloatImm lowering function for target " << 
target_ << " type "
-                            << static_cast<unsigned>(type_code) << " not 
found";
-      return (*lower)(e).cast<PrimExpr>();
-    }
-    return e;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    Var var = ffi::GetRef<Var>(op);
-
-    auto itr = var_remap_.find(var);
-    if (itr != var_remap_.end()) {
-      return itr->second;
-    } else {
-      return var;
-    }
-  }
-
-  Stmt VisitStmt_(const AllocBufferNode* op) final {
-    bool to_be_lowered = 
datatype::Registry::Global()->GetTypeRegistered(op->buffer->dtype.code());
-
-    if (to_be_lowered) {
-      auto new_allocate_type = DataType::UInt(op->buffer->dtype.bits(), 
op->buffer->dtype.lanes());
-      auto new_buffer_var =
-          Var(op->buffer->data->name_hint, 
PointerType(PrimType(new_allocate_type)));
-      var_remap_[op->buffer->data] = new_buffer_var;
-    }
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<AllocBufferNode>();
-
-    Buffer new_buf = GetRemappedBuffer(op->buffer);
-    if (!new_buf.same_as(op->buffer)) {
-      auto node = Downcast<AllocBuffer>(stmt);
-      node.CopyOnWrite()->buffer = new_buf;
-      return node;
-    }
-    return stmt;
-  }
-
-  Stmt VisitStmt_(const DeclBufferNode* op) final {
-    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    auto modified = VisitBufferAccess(node);
-
-    // Not needed for BufferStoreNode, so we can't just call
-    // LegalizeDtype() in VisitBufferAccess.
-    if (node.same_as(modified)) {
-      return node;
-
-    } else {
-      auto writer = modified.CopyOnWrite();
-      writer->LegalizeDType();
-      return modified;
-    }
-  }
-
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  template <typename Node>
-  Node VisitBufferAccess(Node node) {
-    Buffer new_buf = GetRemappedBuffer(node->buffer);
-    if (!new_buf.same_as(node->buffer)) {
-      auto writer = node.CopyOnWrite();
-      writer->buffer = new_buf;
-    }
-
-    return node;
-  }
-
-  Buffer GetRemappedBuffer(Buffer buf) {
-    auto key = buf;
-    auto cache_it = buf_remap_.find(key);
-    if (cache_it != buf_remap_.end()) {
-      return cache_it->second;
-    }
-
-    bool to_be_lowered = 
datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code());
-
-    if (to_be_lowered) {
-      auto new_load_type = DataType::UInt(buf->dtype.bits());
-      auto writer = buf.CopyOnWrite();
-      writer->dtype = new_load_type;
-
-      auto var_it = var_remap_.find(buf->data);
-      if (var_it != var_remap_.end()) {
-        writer->data = var_it->second;
-      }
-    }
-
-    buf_remap_[key] = buf;
-    return buf;
-  }
-
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    Stmt ret = StmtExprMutator::VisitStmt_(op);
-    op = ret.as<AttrStmtNode>();
-    // Due to legacy reasons, some attr node can contain
-    // information(e.g. alignment) of buffer variables.
-    // remap these vars when needed
-    // TODO(tvm-team): remove the rewriting once the buffer var
-    // attrs are being refactored into the corresponding definition node
-    if (auto var_node = op->node.as<Var>()) {
-      auto it = var_remap_.find(var_node.value());
-      if (it != var_remap_.end()) {
-        return AttrStmt(it->second, op->attr_key, op->value, op->body);
-      }
-    }
-    return ret;
-  }
-
-  PrimExpr VisitExpr_(const CallNode* call) final {
-    bool to_be_lowered = 
datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
-    PrimExpr expr = StmtExprMutator::VisitExpr_(call);
-    call = expr.as<CallNode>();
-    if (to_be_lowered) {
-      auto op = call->op.as<OpNode>();
-      TVM_FFI_ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not 
implemented";
-      auto lower = datatype::GetIntrinLowerFunc(target_, op->name, 
call->dtype.code());
-      TVM_FFI_ICHECK(lower) << "Intrinsic lowering function for target " << 
target_
-                            << ", intrinsic name " << op->name << ", type "
-                            << static_cast<unsigned>(call->dtype.code()) << " 
not found";
-      return (*lower)(expr).cast<PrimExpr>();
-    }
-    return expr;
-  }
-
-#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName)                           
            \
-  PrimExpr VisitExpr_(const NodeName* op) final {                              
            \
-    auto type_code = op->dtype.code();                                         
            \
-    bool to_be_lowered = 
datatype::Registry::Global()->GetTypeRegistered(type_code);       \
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                           
            \
-    op = expr.as<NodeName>();                                                  
            \
-    if (to_be_lowered) {                                                       
            \
-      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);           
            \
-      TVM_FFI_ICHECK(lower) << #OP " lowering function for target " << target_ 
<< " type " \
-                            << static_cast<unsigned>(type_code) << " not 
found";           \
-      return (*lower)(expr).cast<PrimExpr>();                                  
            \
-    }                                                                          
            \
-    return expr;                                                               
            \
-  }
-
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
-  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
-  // Later changes may need to add more mutate functions as we support 
workloads with more ops.
-
-#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
-
- private:
-  std::string target_;
-  // remap buffer vars
-  std::unordered_map<Var, Var> var_remap_;
-  std::unordered_map<Buffer, Buffer, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> 
buf_remap_;
-};
-
-namespace transform {
-
-Pass LowerCustomDatatypes() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
-    TVM_FFI_ICHECK(target.defined()) << "LowerCustomDatatypes: Require the 
target attribute";
-
-    n->body = 
CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body));
-    return f;
-  };
-  return CreatePrimFuncPass(pass_func, 0, "tirx.LowerCustomDatatypes", {});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tirx.transform.LowerCustomDatatypes", 
LowerCustomDatatypes);
-}
-
-}  // namespace transform
-
-}  // namespace tirx
-}  // namespace tvm

Reply via email to