This is an automated email from the ASF dual-hosted git repository. uwe pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new 4481b07 ARROW-2608: [Java/Python] Add pyarrow.{Array,Field}.from_jvm / jvm_buffer 4481b07 is described below commit 4481b070c9eca4140aaa3a2470ede920411598a0 Author: Korn, Uwe <uwe.k...@blue-yonder.com> AuthorDate: Tue Jun 26 10:20:14 2018 +0200 ARROW-2608: [Java/Python] Add pyarrow.{Array,Field}.from_jvm / jvm_buffer `jpype` is used for the tests but the actual implementation does not reference any specific Python<->Java bridge. It may work with others but this is not tested. Author: Korn, Uwe <uwe.k...@blue-yonder.com> Closes #2062 from xhochy/ARROW-2608 and squashes the following commits: 701e1b8e <Korn, Uwe> Add a module docstring 1d35d234 <Korn, Uwe> ARROW-2608: Add pyarrow.{Array,Field}.from_jvm / jvm_buffer --- .travis.yml | 4 + ci/travis_script_python.sh | 5 + python/pyarrow/jvm.py | 255 +++++++++++++++++++++++++++++++++++++++ python/pyarrow/tests/test_jvm.py | 216 +++++++++++++++++++++++++++++++++ 4 files changed, 480 insertions(+) diff --git a/.travis.yml b/.travis.yml index 195fe00..b172dff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,6 +45,7 @@ matrix: - compiler: gcc language: cpp os: linux + jdk: openjdk8 env: - ARROW_TRAVIS_USE_TOOLCHAIN=1 - ARROW_TRAVIS_VALGRIND=1 @@ -55,6 +56,7 @@ matrix: - ARROW_TRAVIS_PYTHON_BENCHMARKS=1 - ARROW_TRAVIS_PYTHON_DOCS=1 - ARROW_BUILD_WARNING_LEVEL=CHECKIN + - ARROW_TRAVIS_PYTHON_JVM=1 - CC="clang-6.0" - CXX="clang++-6.0" before_script: @@ -71,6 +73,8 @@ matrix: # All test steps are required for accurate C++ coverage info - $TRAVIS_BUILD_DIR/ci/travis_script_cpp.sh - $TRAVIS_BUILD_DIR/ci/travis_build_parquet_cpp.sh + # Build Arrow Java to test the pyarrow<->JVM in-process bridge + - $TRAVIS_BUILD_DIR/ci/travis_script_java.sh - $TRAVIS_BUILD_DIR/ci/travis_script_python.sh 2.7 - $TRAVIS_BUILD_DIR/ci/travis_script_python.sh 3.6 - $TRAVIS_BUILD_DIR/ci/travis_upload_cpp_coverage.sh diff --git a/ci/travis_script_python.sh 
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Functions to interact with Arrow memory allocated by Arrow Java.

These functions convert the objects holding the metadata, the actual
data is not copied at all.

This will only work with a JVM running in the same process such as provided
through jpype. Modules that talk to a remote JVM like py4j will not work as the
memory addresses reported by them are not reachable in the python process.
"""


import pyarrow as pa


def jvm_buffer(arrowbuf):
    """
    Construct an Arrow buffer from io.netty.buffer.ArrowBuf

    Parameters
    ----------

    arrowbuf: io.netty.buffer.ArrowBuf
        Arrow Buffer representation on the JVM

    Returns
    -------
    pyarrow.Buffer
        Python Buffer that references the JVM memory
    """
    address = arrowbuf.memoryAddress()
    size = arrowbuf.capacity()
    # Pass the unwrapped JVM buffer as the base object so a reference to it
    # is kept alive for the lifetime of the returned Python buffer.
    return pa.foreign_buffer(address, size, arrowbuf.unwrap())


def _from_jvm_int_type(jvm_type):
    """
    Convert a JVM int type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Int

    Returns
    -------
    typ: pyarrow.DataType

    Raises
    ------
    NotImplementedError
        If the bit width is not one of 8, 16, 32, 64.
    """
    # Previously an unsupported bit width silently returned None; raise a
    # clear error instead of failing later in pa.field().
    if jvm_type.isSigned:
        candidates = {8: pa.int8(), 16: pa.int16(),
                      32: pa.int32(), 64: pa.int64()}
    else:
        candidates = {8: pa.uint8(), 16: pa.uint16(),
                      32: pa.uint32(), 64: pa.uint64()}
    typ = candidates.get(jvm_type.bitWidth)
    if typ is None:
        raise NotImplementedError(
            "Unsupported JVM int bit width: {}".format(jvm_type.bitWidth))
    return typ


def _from_jvm_float_type(jvm_type):
    """
    Convert a JVM float type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$FloatingPoint

    Returns
    -------
    typ: pyarrow.DataType

    Raises
    ------
    NotImplementedError
        If the precision is not HALF, SINGLE or DOUBLE.
    """
    precision = jvm_type.getPrecision().toString()
    if precision == 'HALF':
        return pa.float16()
    elif precision == 'SINGLE':
        return pa.float32()
    elif precision == 'DOUBLE':
        return pa.float64()
    # Previously fell through and returned None for unknown precisions.
    raise NotImplementedError(
        "Unsupported JVM float precision: {}".format(precision))


def _from_jvm_time_type(jvm_type):
    """
    Convert a JVM time type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Time

    Returns
    -------
    typ: pyarrow.DataType

    Raises
    ------
    NotImplementedError
        If the time unit is not SECOND/MILLISECOND/MICROSECOND/NANOSECOND.
    """
    time_unit = jvm_type.getUnit().toString()
    if time_unit == 'SECOND':
        assert jvm_type.bitWidth == 32
        return pa.time32('s')
    elif time_unit == 'MILLISECOND':
        assert jvm_type.bitWidth == 32
        return pa.time32('ms')
    elif time_unit == 'MICROSECOND':
        assert jvm_type.bitWidth == 64
        return pa.time64('us')
    elif time_unit == 'NANOSECOND':
        assert jvm_type.bitWidth == 64
        return pa.time64('ns')
    # Previously fell through and returned None for unknown units.
    raise NotImplementedError(
        "Unsupported JVM time unit: {}".format(time_unit))


def _from_jvm_timestamp_type(jvm_type):
    """
    Convert a JVM timestamp type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Timestamp

    Returns
    -------
    typ: pyarrow.DataType

    Raises
    ------
    NotImplementedError
        If the time unit is not SECOND/MILLISECOND/MICROSECOND/NANOSECOND.
    """
    time_unit = jvm_type.getUnit().toString()
    timezone = jvm_type.getTimezone()
    if time_unit == 'SECOND':
        return pa.timestamp('s', tz=timezone)
    elif time_unit == 'MILLISECOND':
        return pa.timestamp('ms', tz=timezone)
    elif time_unit == 'MICROSECOND':
        return pa.timestamp('us', tz=timezone)
    elif time_unit == 'NANOSECOND':
        return pa.timestamp('ns', tz=timezone)
    # Previously fell through and returned None for unknown units.
    raise NotImplementedError(
        "Unsupported JVM timestamp unit: {}".format(time_unit))


def _from_jvm_date_type(jvm_type):
    """
    Convert a JVM date type to its Python equivalent

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Date

    Returns
    -------
    typ: pyarrow.DataType

    Raises
    ------
    NotImplementedError
        If the date unit is not DAY or MILLISECOND.
    """
    day_unit = jvm_type.getUnit().toString()
    if day_unit == 'DAY':
        return pa.date32()
    elif day_unit == 'MILLISECOND':
        return pa.date64()
    # Previously fell through and returned None for unknown units.
    raise NotImplementedError(
        "Unsupported JVM date unit: {}".format(day_unit))


def field(jvm_field):
    """
    Construct a Field from a org.apache.arrow.vector.types.pojo.Field
    instance.

    Parameters
    ----------
    jvm_field: org.apache.arrow.vector.types.pojo.Field

    Returns
    -------
    pyarrow.Field

    Raises
    ------
    NotImplementedError
        For complex (nested) JVM types and unknown primitive type IDs.
    """
    name = jvm_field.getName()
    jvm_type = jvm_field.getType()

    typ = None
    if not jvm_type.isComplex():
        # Dispatch on the textual type ID of the JVM ArrowType.
        type_str = jvm_type.getTypeID().toString()
        if type_str == 'Null':
            typ = pa.null()
        elif type_str == 'Int':
            typ = _from_jvm_int_type(jvm_type)
        elif type_str == 'FloatingPoint':
            typ = _from_jvm_float_type(jvm_type)
        elif type_str == 'Utf8':
            typ = pa.string()
        elif type_str == 'Binary':
            typ = pa.binary()
        elif type_str == 'FixedSizeBinary':
            typ = pa.binary(jvm_type.getByteWidth())
        elif type_str == 'Bool':
            typ = pa.bool_()
        elif type_str == 'Time':
            typ = _from_jvm_time_type(jvm_type)
        elif type_str == 'Timestamp':
            typ = _from_jvm_timestamp_type(jvm_type)
        elif type_str == 'Date':
            typ = _from_jvm_date_type(jvm_type)
        elif type_str == 'Decimal':
            typ = pa.decimal128(jvm_type.getPrecision(), jvm_type.getScale())
        else:
            raise NotImplementedError(
                "Unsupported JVM type: {}".format(type_str))
    else:
        # TODO: The following JVM types are not implemented:
        #       Struct, List, FixedSizeList, Union, Dictionary
        raise NotImplementedError(
            "JVM field conversion only implemented for primitive types.")

    nullable = jvm_field.isNullable()
    # An empty JVM metadata map maps to None on the Python side.
    if jvm_field.getMetadata().isEmpty():
        metadata = None
    else:
        metadata = dict(jvm_field.getMetadata())
    return pa.field(name, typ, nullable, metadata)


def array(jvm_array):
    """
    Construct an (Python) Array from its JVM equivalent.

    Parameters
    ----------
    jvm_array : org.apache.arrow.vector.ValueVector

    Returns
    -------
    array : Array

    Raises
    ------
    NotImplementedError
        If the vector's type is complex (nested).
    """
    if jvm_array.getField().getType().isComplex():
        minor_type_str = jvm_array.getMinorType().toString()
        raise NotImplementedError(
            "Cannot convert JVM Arrow array of type {},"
            " complex types not yet implemented.".format(minor_type_str))
    dtype = field(jvm_array.getField()).type
    length = jvm_array.getValueCount()
    # getBuffers(False) does not transfer ownership; jvm_buffer keeps the
    # JVM buffers alive via the base-object reference.
    buffers = [jvm_buffer(buf)
               for buf in list(jvm_array.getBuffers(False))]
    null_count = jvm_array.getNullCount()
    return pa.Array.from_buffers(dtype, length, buffers, null_count)


# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import os
import pyarrow as pa
import pyarrow.jvm as pa_jvm
import pytest
import sys
import xml.etree.ElementTree as ET


# Skip this whole module when jpype (the in-process Python<->JVM bridge used
# for these tests) is not installed.
jpype = pytest.importorskip("jpype")


@pytest.fixture(scope="session")
def root_allocator():
    # This test requires Arrow Java to be built in the same source tree.
    # Read the Arrow version from the Java pom.xml so we can locate the
    # matching arrow-tools jar-with-dependencies.
    pom_path = os.path.join(
        os.path.dirname(__file__), '..', '..', '..',
        'java', 'pom.xml')
    tree = ET.parse(pom_path)
    version = tree.getroot().find(
        'POM:version',
        namespaces={
            'POM': 'http://maven.apache.org/POM/4.0.0'
        }).text
    jar_path = os.path.join(
        os.path.dirname(__file__), '..', '..', '..',
        'java', 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    # ARROW_TOOLS_JAR overrides the in-tree location, e.g. on CI.
    jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
    # session scope: the JVM can only be started once per process.
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)


def test_jvm_buffer(root_allocator):
    # Create a buffer on the JVM side and fill it with known bytes.
    jvm_buffer = root_allocator.buffer(8)
    for i in range(8):
        jvm_buffer.setByte(i, 8 - i)

    # Convert to Python (zero-copy view over the JVM memory)
    buf = pa_jvm.jvm_buffer(jvm_buffer)

    # Check its content
    assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'


def _jvm_field(jvm_spec):
    # Deserialize a JSON field specification into a JVM pojo Field via
    # Jackson's ObjectMapper.
    om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    return om.readValue(jvm_spec, pojo_Field)


# In the following, we use the JSON serialization of the Field objects in Java.
# This ensures that we neither rely on the exact mechanics of how to construct
# them using Java code, and it also enables us to define them as parameters
# without having to invoke the JVM.
#
# The specifications were created using:
#
#   om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
#   field = …  # Code to instantiate the field
#   jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('typ,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
        '"timezone":null}'),
    (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
        '"timezone":null}'),
    (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
        ',"timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
        '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, typ, jvm_spec, nullable):
    # Build the JSON Field spec, round-trip it through the JVM and compare
    # the resulting pyarrow.Field with the expected one.
    spec = {
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    result = pa_jvm.field(jvm_field)
    assert result == pa.field('field_name', typ, nullable=nullable)


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('typ,data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, typ, data, jvm_type):
    """Fill a JVM vector with `data` and compare its pyarrow view with
    the same data built directly in Python."""
    # Create vector on the JVM side
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(data))
    for i, val in enumerate(data):
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(data))

    py_array = pa.array(data, type=typ)
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)


def _string_to_varchar_holder(ra, string):
    """Build a NullableVarCharHolder for ``string``.

    Parameters
    ----------
    ra : org.apache.arrow.memory.RootAllocator
        Allocator used to create the holder's backing buffer.
    string : str or None
        Value to encode; None produces an unset (null) holder.
    """
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        # Fix: encode the actual argument, not the literal "string" —
        # previously every non-null holder contained the same payload.
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder


# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
                          "primitive arrays yet")
def test_jvm_string_array(root_allocator):
    """Round-trip a VarCharVector (including a null and a non-ASCII value)
    through the JVM bridge and compare with the Python-built equivalent."""
    data = [u"string", None, u"töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        # Fix: pass the loop variable, not the literal "string" — otherwise
        # the None and u"töst" entries were never actually exercised.
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
        jvm_vector.setValueCount(i + 1)

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)