This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch branch-1.8.1 in repository https://gitbox.apache.org/repos/asf/sedona.git
commit fa104d5a3fc8535d9fd166958e59481db43df160 Author: Jia Yu <[email protected]> AuthorDate: Mon Feb 9 02:40:33 2026 -0700 [GH-908] Add ST_GeoHashNeighbors and ST_GeoHashNeighbor functions (#2628) --- .gitignore | 1 + .../java/org/apache/sedona/common/Functions.java | 27 ++ .../apache/sedona/common/utils/GeoHashDecoder.java | 4 +- .../sedona/common/utils/GeoHashNeighbor.java | 274 +++++++++++++++++++++ .../apache/sedona/common/utils/GeoHashUtils.java | 75 ++++++ .../sedona/common/utils/PointGeoHashEncoder.java | 4 +- .../org/apache/sedona/common/FunctionsTest.java | 113 +++++++++ docs/api/flink/Function.md | 40 +++ docs/api/snowflake/vector-data/Function.md | 48 ++++ docs/api/sql/Function.md | 40 +++ .../main/java/org/apache/sedona/flink/Catalog.java | 2 + .../apache/sedona/flink/expressions/Functions.java | 14 ++ .../java/org/apache/sedona/flink/FunctionTest.java | 35 +++ python/pyproject.toml | 1 + python/sedona/spark/sql/st_functions.py | 30 +++ python/tests/sql/test_dataframe_api.py | 26 ++ python/tests/sql/test_function.py | 27 ++ .../sedona/snowflake/snowsql/TestFunctions.java | 13 + .../sedona/snowflake/snowsql/TestFunctionsV2.java | 13 + .../org/apache/sedona/snowflake/snowsql/UDFs.java | 10 + .../apache/sedona/snowflake/snowsql/UDFsV2.java | 14 ++ .../scala/org/apache/sedona/sql/UDF/Catalog.scala | 2 + .../sql/sedona_sql/expressions/Functions.scala | 16 ++ .../expressions/InferredExpression.scala | 28 +++ .../sql/sedona_sql/expressions/st_functions.scala | 14 ++ .../apache/sedona/sql/dataFrameAPITestScala.scala | 37 +++ .../org/apache/sedona/sql/functionTestScala.scala | 33 +++ 27 files changed, 937 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 6f5f1e8d08..cd609e393b 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,4 @@ target docs-overrides/node_modules/ uv.lock +.env diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 
0be88a9b1b..f0f130a5bf 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -740,6 +740,33 @@ public class Functions { return GeometryGeoHashEncoder.calculate(geometry, precision); } + /** + * Returns the 8 neighbors of the given geohash in the order: [N, NE, E, SE, S, SW, W, NW]. + * + * @param geohash a geohash string (1-12 characters, lowercase base32) + * @return an array of 8 neighbor geohash strings, or null if the input is null + * @throws IllegalArgumentException if the geohash is empty, too long, or not lowercase base32 + */ + public static String[] geohashNeighbors(String geohash) { + return GeoHashNeighbor.getNeighbors(geohash); + } + + /** + * Returns a single neighbor of the given geohash in the specified direction. + * + * <p>Accepted direction values (case-insensitive): {@code "N"}, {@code "NE"}, {@code "E"}, {@code + * "SE"}, {@code "S"}, {@code "SW"}, {@code "W"}, {@code "NW"}. + * + * @param geohash a geohash string (1-12 characters, lowercase base32) + * @param direction the compass direction of the desired neighbor + * @return the neighbor geohash string, or null if the geohash is null + * @throws IllegalArgumentException if the geohash is empty, too long, or not lowercase base32, or if + * the direction is null or invalid + */ + public static String geohashNeighbor(String geohash, String direction) { + return GeoHashNeighbor.getNeighbor(geohash, direction); + } + public static Geometry pointOnSurface(Geometry geometry) { return GeomUtils.getInteriorPoint(geometry); } diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java index b28070310c..9842038c31 100644 --- a/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java +++ b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java @@ -26,8 +26,8 @@ import org.apache.sedona.common.S2Geography.PolygonGeography;
import org.locationtech.jts.geom.Geometry; public class GeoHashDecoder { - private static final int[] bits = new int[] {16, 8, 4, 2, 1}; - private static final String base32 = "0123456789bcdefghjkmnpqrstuvwxyz"; + private static final int[] bits = GeoHashUtils.BITS; + private static final String base32 = GeoHashUtils.BASE32; public static class InvalidGeoHashException extends Exception { public InvalidGeoHashException(String message) { diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeoHashNeighbor.java b/common/src/main/java/org/apache/sedona/common/utils/GeoHashNeighbor.java new file mode 100644 index 0000000000..75d1266f2d --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/utils/GeoHashNeighbor.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.utils; + +/** + * Utility class for computing geohash neighbors. + * + * <p>The neighbor algorithm is ported from the geohash-java library by Silvio Heuberger (Apache 2.0 + * License). 
It works by de-interleaving the geohash bits into separate latitude/longitude + * components, incrementing/decrementing the appropriate component, masking for wrap-around, and + * re-interleaving back to a geohash string. + * + * @see <a href="https://github.com/kungfoo/geohash-java">geohash-java</a> + */ +public class GeoHashNeighbor { + + private static final int MAX_BIT_PRECISION = 64; + private static final long FIRST_BIT_FLAGGED = 0x8000000000000000L; + + // Shared base32 constants from GeoHashUtils + private static final char[] base32 = GeoHashUtils.BASE32_CHARS; + private static final int[] decodeArray = GeoHashUtils.DECODE_ARRAY; + + /** + * Returns all 8 neighbors of the given geohash in the order: [N, NE, E, SE, S, SW, W, NW]. + * + * @param geohash the geohash string + * @return array of 8 neighbor geohash strings, or null if input is null + */ + public static String[] getNeighbors(String geohash) { + if (geohash == null) { + return null; + } + + long[] parsed = parseGeohash(geohash); + long bits = parsed[0]; + int significantBits = (int) parsed[1]; + + // Ported from geohash-java: GeoHash.getAdjacent() + String northern = neighborLat(bits, significantBits, 1); + String eastern = neighborLon(bits, significantBits, 1); + String southern = neighborLat(bits, significantBits, -1); + String western = neighborLon(bits, significantBits, -1); + + // Diagonal neighbors: compose two cardinal moves + long[] northBits = parseGeohash(northern); + long[] southBits = parseGeohash(southern); + + return new String[] { + northern, + neighborLon(northBits[0], significantBits, 1), // NE + eastern, + neighborLon(southBits[0], significantBits, 1), // SE + southern, + neighborLon(southBits[0], significantBits, -1), // SW + western, + neighborLon(northBits[0], significantBits, -1) // NW + }; + } + + /** + * Returns the neighbor of the given geohash in the specified direction. 
+ * + * @param geohash the geohash string + * @param direction compass direction: "n", "ne", "e", "se", "s", "sw", "w", "nw" + * (case-insensitive) + * @return the neighbor geohash string, or null if input geohash is null + */ + public static String getNeighbor(String geohash, String direction) { + if (geohash == null) { + return null; + } + if (direction == null) { + throw new IllegalArgumentException( + "Direction cannot be null. Valid values: n, ne, e, se, s, sw, w, nw"); + } + + long[] parsed = parseGeohash(geohash); + long bits = parsed[0]; + int significantBits = (int) parsed[1]; + + switch (direction.toLowerCase()) { + case "n": + return neighborLat(bits, significantBits, 1); + case "s": + return neighborLat(bits, significantBits, -1); + case "e": + return neighborLon(bits, significantBits, 1); + case "w": + return neighborLon(bits, significantBits, -1); + case "ne": + { + long[] n = parseGeohash(neighborLat(bits, significantBits, 1)); + return neighborLon(n[0], significantBits, 1); + } + case "se": + { + long[] s = parseGeohash(neighborLat(bits, significantBits, -1)); + return neighborLon(s[0], significantBits, 1); + } + case "sw": + { + long[] s = parseGeohash(neighborLat(bits, significantBits, -1)); + return neighborLon(s[0], significantBits, -1); + } + case "nw": + { + long[] n = parseGeohash(neighborLat(bits, significantBits, 1)); + return neighborLon(n[0], significantBits, -1); + } + default: + throw new IllegalArgumentException( + "Invalid direction: '" + direction + "'. Valid values: n, ne, e, se, s, sw, w, nw"); + } + } + + /** + * Parses a geohash string into [bits, significantBits]. Based on geohash-java + * GeoHash.fromGeohashString(). 
+ */ + private static long[] parseGeohash(String geohash) { + if (geohash.isEmpty()) { + throw new IllegalArgumentException("Geohash string cannot be empty"); + } + if (geohash.length() * GeoHashUtils.BITS_PER_CHAR > MAX_BIT_PRECISION) { + throw new IllegalArgumentException( + "Geohash '" + + geohash + + "' is too long (max " + + (MAX_BIT_PRECISION / GeoHashUtils.BITS_PER_CHAR) + + " characters)"); + } + long bits = 0; + int significantBits = 0; + for (int i = 0; i < geohash.length(); i++) { + char c = geohash.charAt(i); + int cd; + if (c >= decodeArray.length || (cd = decodeArray[c]) < 0) { + throw new IllegalArgumentException( + "Invalid character '" + c + "' in geohash '" + geohash + "'"); + } + for (int j = 0; j < GeoHashUtils.BITS_PER_CHAR; j++) { + significantBits++; + bits <<= 1; + if ((cd & (16 >> j)) != 0) { + bits |= 0x1; + } + } + } + bits <<= (MAX_BIT_PRECISION - significantBits); + return new long[] {bits, significantBits}; + } + + // Ported from geohash-java: GeoHash.getNorthernNeighbour() / getSouthernNeighbour() + private static String neighborLat(long bits, int significantBits, int delta) { + long[] latBits = getRightAlignedLatitudeBits(bits, significantBits); + long[] lonBits = getRightAlignedLongitudeBits(bits, significantBits); + latBits[0] += delta; + latBits[0] = maskLastNBits(latBits[0], latBits[1]); + return recombineLatLonBitsToBase32(latBits, lonBits); + } + + // Ported from geohash-java: GeoHash.getEasternNeighbour() / getWesternNeighbour() + private static String neighborLon(long bits, int significantBits, int delta) { + long[] latBits = getRightAlignedLatitudeBits(bits, significantBits); + long[] lonBits = getRightAlignedLongitudeBits(bits, significantBits); + lonBits[0] += delta; + lonBits[0] = maskLastNBits(lonBits[0], lonBits[1]); + return recombineLatLonBitsToBase32(latBits, lonBits); + } + + // Ported from geohash-java: GeoHash.getRightAlignedLatitudeBits() + private static long[] getRightAlignedLatitudeBits(long bits, int 
significantBits) { + long copyOfBits = bits << 1; + int[] numBits = getNumberOfLatLonBits(significantBits); + long value = extractEverySecondBit(copyOfBits, numBits[0]); + return new long[] {value, numBits[0]}; + } + + // Ported from geohash-java: GeoHash.getRightAlignedLongitudeBits() + private static long[] getRightAlignedLongitudeBits(long bits, int significantBits) { + int[] numBits = getNumberOfLatLonBits(significantBits); + long value = extractEverySecondBit(bits, numBits[1]); + return new long[] {value, numBits[1]}; + } + + // Copied from geohash-java: GeoHash.extractEverySecondBit() + private static long extractEverySecondBit(long copyOfBits, int numberOfBits) { + long value = 0; + for (int i = 0; i < numberOfBits; i++) { + if ((copyOfBits & FIRST_BIT_FLAGGED) == FIRST_BIT_FLAGGED) { + value |= 0x1; + } + value <<= 1; + copyOfBits <<= 2; + } + value >>>= 1; + return value; + } + + // Copied from geohash-java: GeoHash.getNumberOfLatLonBits() + private static int[] getNumberOfLatLonBits(int significantBits) { + if (significantBits % 2 == 0) { + return new int[] {significantBits / 2, significantBits / 2}; + } else { + return new int[] {significantBits / 2, significantBits / 2 + 1}; + } + } + + // Copied from geohash-java: GeoHash.maskLastNBits() + private static long maskLastNBits(long value, long n) { + long mask = 0xFFFFFFFFFFFFFFFFL; + mask >>>= (MAX_BIT_PRECISION - n); + return value & mask; + } + + /** + * Re-interleaves lat/lon bits and converts to base32 string. Simplified from geohash-java's + * GeoHash.recombineLatLonBitsToHash() — we only need the base32 output, not the full GeoHash + * object with bounding box. 
+ */ + private static String recombineLatLonBitsToBase32(long[] latBits, long[] lonBits) { + int significantBits = (int) (latBits[1] + lonBits[1]); + long lat = latBits[0] << (MAX_BIT_PRECISION - latBits[1]); + long lon = lonBits[0] << (MAX_BIT_PRECISION - lonBits[1]); + + long bits = 0; + boolean isEvenBit = false; + for (int i = 0; i < significantBits; i++) { + bits <<= 1; + if (isEvenBit) { + if ((lat & FIRST_BIT_FLAGGED) == FIRST_BIT_FLAGGED) { + bits |= 0x1; + } + lat <<= 1; + } else { + if ((lon & FIRST_BIT_FLAGGED) == FIRST_BIT_FLAGGED) { + bits |= 0x1; + } + lon <<= 1; + } + isEvenBit = !isEvenBit; + } + bits <<= (MAX_BIT_PRECISION - significantBits); + + // Ported from geohash-java: GeoHash.toBase32() + StringBuilder buf = new StringBuilder(); + long firstFiveBitsMask = 0xF800000000000000L; + long bitsCopy = bits; + int numChars = significantBits / GeoHashUtils.BITS_PER_CHAR; + for (int i = 0; i < numChars; i++) { + int pointer = (int) ((bitsCopy & firstFiveBitsMask) >>> 59); + buf.append(base32[pointer]); + bitsCopy <<= 5; + } + return buf.toString(); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeoHashUtils.java b/common/src/main/java/org/apache/sedona/common/utils/GeoHashUtils.java new file mode 100644 index 0000000000..267d6f6ec9 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/utils/GeoHashUtils.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.utils; + +import java.util.Arrays; + +/** + * Shared constants and utilities for geohash encoding/decoding. + * + * <p>Centralizes the base32 alphabet and decode lookup table previously duplicated across {@link + * GeoHashDecoder}, {@link PointGeoHashEncoder}, and {@link GeoHashNeighbor}. + */ +public class GeoHashUtils { + + /** The base32 alphabet used by geohash encoding (Gustavo Niemeyer's specification). */ + public static final String BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz"; + + /** + * Base32 character array for index-based lookup. Package-private to prevent external mutation. + */ + static final char[] BASE32_CHARS = BASE32.toCharArray(); + + /** + * Bit masks for extracting 5-bit groups: {16, 8, 4, 2, 1}. Package-private to prevent external + * mutation. + */ + static final int[] BITS = new int[] {16, 8, 4, 2, 1}; + + /** Number of bits per base32 character. */ + public static final int BITS_PER_CHAR = 5; + + /** + * Reverse lookup array: maps a character (as int) to its base32 index. Invalid characters map to + * -1. Package-private to prevent external mutation. + */ + static final int[] DECODE_ARRAY = new int['z' + 1]; + + static { + Arrays.fill(DECODE_ARRAY, -1); + for (int i = 0; i < BASE32_CHARS.length; i++) { + DECODE_ARRAY[BASE32_CHARS[i]] = i; + } + } + + /** + * Decodes a single base32 character to its integer value (0-31). 
+ * + * @param c the character to decode + * @return the integer value, or -1 if the character is not a valid base32 character + */ + public static int decodeChar(char c) { + if (c >= DECODE_ARRAY.length) { + return -1; + } + return DECODE_ARRAY[c]; + } + + private GeoHashUtils() {} +} diff --git a/common/src/main/java/org/apache/sedona/common/utils/PointGeoHashEncoder.java b/common/src/main/java/org/apache/sedona/common/utils/PointGeoHashEncoder.java index e8be8249ed..8ad7113a5a 100644 --- a/common/src/main/java/org/apache/sedona/common/utils/PointGeoHashEncoder.java +++ b/common/src/main/java/org/apache/sedona/common/utils/PointGeoHashEncoder.java @@ -21,8 +21,8 @@ package org.apache.sedona.common.utils; import org.locationtech.jts.geom.Point; public class PointGeoHashEncoder { - private static String base32 = "0123456789bcdefghjkmnpqrstuvwxyz"; - private static int[] bits = new int[] {16, 8, 4, 2, 1}; + private static final String base32 = GeoHashUtils.BASE32; + private static final int[] bits = GeoHashUtils.BITS; public static String calculateGeoHash(Point geom, long precision) { BBox bbox = new BBox(-180, 180, -90, 90); diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index c5eac14c39..bf84e5dff2 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -5031,4 +5031,117 @@ public class FunctionsTest extends TestBase { + straightSkel.getNumGeometries(), result.getNumGeometries() <= straightSkel.getNumGeometries()); } + + @Test + public void testGeoHashNeighborsKnownValues() { + // Known values from geohash-java: GeoHashTest.testKnownNeighbouringHashes() + // Center: "u1pb" + assertEquals("u1pc", Functions.geohashNeighbor("u1pb", "n")); + assertEquals("u0zz", Functions.geohashNeighbor("u1pb", "s")); + assertEquals("u300", Functions.geohashNeighbor("u1pb", "e")); + 
assertEquals("u1p8", Functions.geohashNeighbor("u1pb", "w")); + + // Double east: u1pb -> e -> u300 -> e -> u302 + assertEquals("u302", Functions.geohashNeighbor(Functions.geohashNeighbor("u1pb", "e"), "e")); + } + + @Test + public void testGeoHashNeighborsAllEight() { + // Known values from geohash-java: GeoHashTest.testKnownAdjacentNeighbours() + // Center: "dqcjqc" + String[] neighbors = Functions.geohashNeighbors("dqcjqc"); + assertNotNull(neighbors); + assertEquals(8, neighbors.length); + // Verify all 8 expected values are present + java.util.Set<String> expected = + new java.util.HashSet<>( + java.util.Arrays.asList( + "dqcjqf", "dqcjr4", "dqcjr1", "dqcjr0", "dqcjq9", "dqcjq8", "dqcjqb", "dqcjqd")); + java.util.Set<String> actual = new java.util.HashSet<>(java.util.Arrays.asList(neighbors)); + assertEquals(expected, actual); + + // Center: "u1x0dfg" (7-char precision) + neighbors = Functions.geohashNeighbors("u1x0dfg"); + assertNotNull(neighbors); + assertEquals(8, neighbors.length); + expected = + new java.util.HashSet<>( + java.util.Arrays.asList( + "u1x0dg4", "u1x0dg5", "u1x0dgh", "u1x0dfu", "u1x0dfs", "u1x0dfe", "u1x0dfd", + "u1x0dff")); + actual = new java.util.HashSet<>(java.util.Arrays.asList(neighbors)); + assertEquals(expected, actual); + + // Center: "sp2j" (near prime meridian — neighbors cross into "ezr*" prefix) + neighbors = Functions.geohashNeighbors("sp2j"); + assertNotNull(neighbors); + assertEquals(8, neighbors.length); + expected = + new java.util.HashSet<>( + java.util.Arrays.asList( + "ezry", "sp2n", "sp2q", "sp2m", "sp2k", "sp2h", "ezru", "ezrv")); + actual = new java.util.HashSet<>(java.util.Arrays.asList(neighbors)); + assertEquals(expected, actual); + } + + @Test + public void testGeoHashNeighborNearMeridian() { + // From geohash-java: GeoHashTest.testNeibouringHashesNearMeridian() + // "sp2j" is near the prime meridian; western neighbors cross into "ezr*" prefix + assertEquals("ezrv", Functions.geohashNeighbor("sp2j", "w")); + 
assertEquals("ezrt", Functions.geohashNeighbor(Functions.geohashNeighbor("sp2j", "w"), "w")); + } + + @Test + public void testGeoHashNeighborCircleMovement() { + // Moving E -> S -> W -> N should return to the starting cell + // From geohash-java: GeoHashTest.testMovingInCircle() + String[] testHashes = {"u1pb", "sp2j", "ezrv", "dqcjqc", "u1x0dfg", "pbpbpbpbpbpb"}; + for (String start : testHashes) { + String result = Functions.geohashNeighbor(start, "e"); + result = Functions.geohashNeighbor(result, "s"); + result = Functions.geohashNeighbor(result, "w"); + result = Functions.geohashNeighbor(result, "n"); + assertEquals("Circle movement failed for " + start, start, result); + } + } + + @Test + public void testGeoHashNeighborCaseInsensitive() { + assertEquals(Functions.geohashNeighbor("u1pb", "n"), Functions.geohashNeighbor("u1pb", "N")); + assertEquals(Functions.geohashNeighbor("u1pb", "ne"), Functions.geohashNeighbor("u1pb", "NE")); + assertEquals(Functions.geohashNeighbor("u1pb", "sw"), Functions.geohashNeighbor("u1pb", "SW")); + } + + @Test + public void testGeoHashNeighborsNullInput() { + assertNull(Functions.geohashNeighbors(null)); + assertNull(Functions.geohashNeighbor(null, "n")); + } + + @Test(expected = IllegalArgumentException.class) + public void testGeoHashNeighborsEmptyInput() { + Functions.geohashNeighbors(""); + } + + @Test(expected = IllegalArgumentException.class) + public void testGeoHashNeighborInvalidDirection() { + Functions.geohashNeighbor("u1pb", "north"); + } + + @Test(expected = IllegalArgumentException.class) + public void testGeoHashNeighborNullDirection() { + Functions.geohashNeighbor("u1pb", null); + } + + @Test(expected = IllegalArgumentException.class) + public void testGeoHashNeighborsInvalidCharacter() { + Functions.geohashNeighbors("u1pb!"); + } + + @Test(expected = IllegalArgumentException.class) + public void testGeoHashNeighborsTooLongInput() { + // 13 chars * 5 bits = 65 > 64 (MAX_BIT_PRECISION) + 
Functions.geohashNeighbors("0123456789abc"); + } } diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md index 019f91fed2..348e794e9f 100644 --- a/docs/api/flink/Function.md +++ b/docs/api/flink/Function.md @@ -1779,6 +1779,46 @@ Output: u3r0p ``` +## ST_GeoHashNeighbors + +Introduction: Returns the 8 neighboring geohash cells of a given geohash string. The result is an array of 8 geohash strings in the order: N, NE, E, SE, S, SW, W, NW. + +Format: `ST_GeoHashNeighbors(geohash: String)` + +Since: `v1.9.0` + +Example: + +```sql +SELECT ST_GeoHashNeighbors('u1pb') +``` + +Output: + +``` +[u1pc, u301, u300, u2bp, u0zz, u0zx, u1p8, u1p9] +``` + +## ST_GeoHashNeighbor + +Introduction: Returns the neighbor geohash cell in the given direction. Valid directions are: `n`, `ne`, `e`, `se`, `s`, `sw`, `w`, `nw` (case-insensitive). + +Format: `ST_GeoHashNeighbor(geohash: String, direction: String)` + +Since: `v1.9.0` + +Example: + +```sql +SELECT ST_GeoHashNeighbor('u1pb', 'n') +``` + +Output: + +``` +u1pc +``` + ## ST_GeometricMedian Introduction: Computes the approximate geometric median of a MultiPoint geometry using the Weiszfeld algorithm. The geometric median provides a centrality measure that is less sensitive to outlier points than the centroid. diff --git a/docs/api/snowflake/vector-data/Function.md b/docs/api/snowflake/vector-data/Function.md index dd769d4dea..b300594693 100644 --- a/docs/api/snowflake/vector-data/Function.md +++ b/docs/api/snowflake/vector-data/Function.md @@ -1396,6 +1396,54 @@ Result: +-----------------------------+ ``` +## ST_GeoHashNeighbors + +Introduction: Returns the 8 neighboring geohash cells of a given geohash string. The result is an array of 8 geohash strings in the order: N, NE, E, SE, S, SW, W, NW. 
+ +Format: `ST_GeoHashNeighbors(geohash: String)` + +Example: + +Query: + +```sql +SELECT ST_GeoHashNeighbors('u1pb') +``` + +Result: + +``` ++-----------------------------+ +|geohash_neighbors | ++-----------------------------+ +|[u1pc, u301, u300, u2bp, ...] | ++-----------------------------+ +``` + +## ST_GeoHashNeighbor + +Introduction: Returns the neighbor geohash cell in the given direction. Valid directions are: `n`, `ne`, `e`, `se`, `s`, `sw`, `w`, `nw` (case-insensitive). + +Format: `ST_GeoHashNeighbor(geohash: String, direction: String)` + +Example: + +Query: + +```sql +SELECT ST_GeoHashNeighbor('u1pb', 'n') +``` + +Result: + +``` ++-----------------------------+ +|geohash_neighbor | ++-----------------------------+ +|u1pc | ++-----------------------------+ +``` + ## ST_GeometricMedian Introduction: Computes the approximate geometric median of a MultiPoint geometry using the Weiszfeld algorithm. The geometric median provides a centrality measure that is less sensitive to outlier points than the centroid. diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index c6c5d7cb24..d2eb48d39f 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -1883,6 +1883,46 @@ Output: u3r0p ``` +## ST_GeoHashNeighbors + +Introduction: Returns the 8 neighboring geohash cells of a given geohash string. The result is an array of 8 geohash strings in the order: N, NE, E, SE, S, SW, W, NW. + +Format: `ST_GeoHashNeighbors(geohash: String)` + +Since: `v1.9.0` + +SQL Example + +```sql +SELECT ST_GeoHashNeighbors('u1pb') +``` + +Output: + +``` +[u1pc, u301, u300, u2bp, u0zz, u0zx, u1p8, u1p9] +``` + +## ST_GeoHashNeighbor + +Introduction: Returns the neighbor geohash cell in the given direction. Valid directions are: `n`, `ne`, `e`, `se`, `s`, `sw`, `w`, `nw` (case-insensitive). 
+ +Format: `ST_GeoHashNeighbor(geohash: String, direction: String)` + +Since: `v1.9.0` + +SQL Example + +```sql +SELECT ST_GeoHashNeighbor('u1pb', 'n') +``` + +Output: + +``` +u1pc +``` + ## ST_GeometricMedian Introduction: Computes the approximate geometric median of a MultiPoint geometry using the Weiszfeld algorithm. The geometric median provides a centrality measure that is less sensitive to outlier points than the centroid. diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index 0ee24de799..eefe48c821 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -106,6 +106,8 @@ public class Catalog { new FunctionsGeoTools.ST_Transform(), new Functions.ST_FlipCoordinates(), new Functions.ST_GeoHash(), + new Functions.ST_GeoHashNeighbors(), + new Functions.ST_GeoHashNeighbor(), new Functions.ST_Perimeter(), new Functions.ST_Perimeter2D(), new Functions.ST_PointOnSurface(), diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java index 6be5c0f562..c4567a6a98 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java @@ -951,6 +951,20 @@ public class Functions { } } + public static class ST_GeoHashNeighbors extends ScalarFunction { + @DataTypeHint("ARRAY<STRING>") + public String[] eval(String geohash) { + return org.apache.sedona.common.Functions.geohashNeighbors(geohash); + } + } + + public static class ST_GeoHashNeighbor extends ScalarFunction { + @DataTypeHint("String") + public String eval(String geohash, String direction) { + return org.apache.sedona.common.Functions.geohashNeighbor(geohash, direction); + } + } + public static class ST_Perimeter extends ScalarFunction { @DataTypeHint(value = "Double") public 
Double eval( diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index d1095c98e7..bd310c65b7 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -29,6 +29,7 @@ import java.util.stream.Collectors; import org.apache.commons.codec.binary.Hex; import org.apache.commons.lang3.tuple.Pair; import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; import org.apache.sedona.flink.expressions.Functions; import org.apache.sedona.flink.expressions.FunctionsGeoTools; import org.geotools.api.referencing.FactoryException; @@ -794,6 +795,40 @@ public class FunctionTest extends TestBase { assertEquals(first(pointTable).getField(0), "s0000"); } + @Test + public void testGeoHashNeighbors() { + Table resultTable = tableEnv.sqlQuery("SELECT ST_GeoHashNeighbors('u1pb')"); + Row result = first(resultTable); + String[] neighbors = (String[]) result.getField(0); + assertEquals(8, neighbors.length); + assertEquals("u1pc", neighbors[0]); // N + assertEquals("u300", neighbors[2]); // E + assertEquals("u0zz", neighbors[4]); // S + assertEquals("u1p8", neighbors[6]); // W + } + + @Test + public void testGeoHashNeighborsNull() { + Table resultTable = tableEnv.sqlQuery("SELECT ST_GeoHashNeighbors(CAST(NULL AS STRING))"); + assertNull(first(resultTable).getField(0)); + } + + @Test + public void testGeoHashNeighbor() { + Table resultTable = tableEnv.sqlQuery("SELECT ST_GeoHashNeighbor('u1pb', 'n')"); + assertEquals("u1pc", first(resultTable).getField(0)); + resultTable = tableEnv.sqlQuery("SELECT ST_GeoHashNeighbor('u1pb', 'e')"); + assertEquals("u300", first(resultTable).getField(0)); + resultTable = tableEnv.sqlQuery("SELECT ST_GeoHashNeighbor('u1pb', 'NE')"); + assertEquals("u301", first(resultTable).getField(0)); + } + + @Test + public void testGeoHashNeighborNull() { + Table resultTable = 
tableEnv.sqlQuery("SELECT ST_GeoHashNeighbor(CAST(NULL AS STRING), 'n')"); + assertNull(first(resultTable).getField(0)); + } + @Test public void testGeometryType() { Table pointTable = diff --git a/python/pyproject.toml b/python/pyproject.toml index b988966c4f..b73a70d21a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,6 +54,7 @@ all = [ dev = [ "pytest", "pytest-cov", + "setuptools>=69,<82", "notebook==6.4.12", "jupyter", "mkdocs", diff --git a/python/sedona/spark/sql/st_functions.py b/python/sedona/spark/sql/st_functions.py index bc6b2d5e20..34afd113b0 100644 --- a/python/sedona/spark/sql/st_functions.py +++ b/python/sedona/spark/sql/st_functions.py @@ -763,6 +763,36 @@ def ST_GeoHash(geometry: ColumnOrName, precision: Union[ColumnOrName, int]) -> C return _call_st_function("ST_GeoHash", (geometry, precision)) +@validate_argument_types +def ST_GeoHashNeighbors(geohash: ColumnOrName) -> Column: + """Return the 8 neighboring geohash cells of the given geohash string. + + The neighbors are returned in the order: [N, NE, E, SE, S, SW, W, NW]. + + :param geohash: Geohash string column. + :type geohash: ColumnOrName + :return: Array of 8 neighboring geohash strings. + :rtype: Column + """ + return _call_st_function("ST_GeoHashNeighbors", (geohash,)) + + +@validate_argument_types +def ST_GeoHashNeighbor( + geohash: ColumnOrName, direction: Union[ColumnOrName, str] +) -> Column: + """Return the neighboring geohash cell in the specified direction. + + :param geohash: Geohash string column. + :type geohash: ColumnOrName + :param direction: Compass direction: 'n', 'ne', 'e', 'se', 's', 'sw', 'w', 'nw'. + :type direction: Union[ColumnOrName, str] + :return: Neighboring geohash string. 
+ :rtype: Column + """ + return _call_st_function("ST_GeoHashNeighbor", (geohash, direction)) + + @validate_argument_types def ST_GeometricMedian( geometry: ColumnOrName, diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 9629a7ca55..ad158e69ae 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -612,6 +612,29 @@ test_configurations = [ ), (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"), (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"), + ( + stf.ST_GeoHashNeighbors, + ("geohash",), + "constructor", + "", + [ + "s00twy01mk", + "s00twy01mm", + "s00twy01mq", + "s00twy01ms", + "s00twy01mu", + "s00twy01mv", + "s00twy01mw", + "s00twy01my", + ], + ), + ( + stf.ST_GeoHashNeighbor, + ("geohash", lambda: f.lit("n")), + "constructor", + "", + "s00twy01mw", + ), ( stf.ST_HausdorffDistance, ( @@ -1357,6 +1380,9 @@ wrong_type_configurations = [ (stf.ST_GeometryN, ("", None)), (stf.ST_GeometryN, ("", 0.0)), (stf.ST_GeometryType, (None,)), + (stf.ST_GeoHashNeighbors, (None,)), + (stf.ST_GeoHashNeighbor, (None, "n")), + (stf.ST_GeoHashNeighbor, ("", None)), (stf.ST_GeneratePoints, (None, 0.0)), (stf.ST_GeneratePoints, ("", None)), (stf.ST_InteriorRingN, (None, 0)), diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index 1c35ce417f..74bd28ef07 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1859,6 +1859,33 @@ class TestPredicateJoin(TestBase): for calculated_geohash, expected_geohash in geohash: assert calculated_geohash == expected_geohash + def test_st_geohash_neighbors(self): + result = self.spark.sql("SELECT ST_GeoHashNeighbors('u1pb')").collect()[0][0] + + assert len(result) == 8 + # Order: N, NE, E, SE, S, SW, W, NW + assert result[0] == "u1pc" + assert result[1] == "u301" + assert result[2] == "u300" + assert result[3] == "u2bp" + assert result[4] == "u0zz" + assert 
result[5] == "u0zx" + assert result[6] == "u1p8" + assert result[7] == "u1p9" + + def test_st_geohash_neighbor(self): + # Test north neighbor + result_n = self.spark.sql("SELECT ST_GeoHashNeighbor('u1pb', 'n')").collect()[ + 0 + ][0] + assert result_n == "u1pc" + + # Test east neighbor + result_e = self.spark.sql("SELECT ST_GeoHashNeighbor('u1pb', 'e')").collect()[ + 0 + ][0] + assert result_e == "u300" + def test_geom_from_geohash(self): # Given geometry_df = self.spark.createDataFrame( diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java index 20ab767fad..58bf1dc79f 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java @@ -471,6 +471,19 @@ public class TestFunctions extends TestBase { "u3r0p"); } + @Test + public void test_ST_GeoHashNeighbors() { + registerUDF("ST_GeoHashNeighbors", String.class); + verifySqlSingleRes("select ARRAY_SIZE(sedona.ST_GeoHashNeighbors('u1pb'))", 8); + } + + @Test + public void test_ST_GeoHashNeighbor() { + registerUDF("ST_GeoHashNeighbor", String.class, String.class); + verifySqlSingleRes("select sedona.ST_GeoHashNeighbor('u1pb', 'n')", "u1pc"); + verifySqlSingleRes("select sedona.ST_GeoHashNeighbor('u1pb', 'e')", "u300"); + } + @Test public void test_ST_GeometryN() { registerUDF("ST_GeometryN", byte[].class, int.class); diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java index 425a5e288e..7aa8f15657 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java @@ -451,6 +451,19 @@ public class 
TestFunctionsV2 extends TestBase { "u3r0p"); } + @Test + public void test_ST_GeoHashNeighbors() { + registerUDFV2("ST_GeoHashNeighbors", String.class); + verifySqlSingleRes("select ARRAY_SIZE(sedona.ST_GeoHashNeighbors('u1pb'))", 8); + } + + @Test + public void test_ST_GeoHashNeighbor() { + registerUDFV2("ST_GeoHashNeighbor", String.class, String.class); + verifySqlSingleRes("select sedona.ST_GeoHashNeighbor('u1pb', 'n')", "u1pc"); + verifySqlSingleRes("select sedona.ST_GeoHashNeighbor('u1pb', 'e')", "u300"); + } + @Test public void test_ST_GeometryN() { registerUDFV2("ST_GeometryN", String.class, int.class); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java index 3e2d0095a6..0730d26ea3 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java @@ -439,6 +439,16 @@ public class UDFs { return Functions.geohash(GeometrySerde.deserialize(geometry), precision); } + @UDFAnnotations.ParamMeta(argNames = {"geohash"}) + public static String[] ST_GeoHashNeighbors(String geohash) { + return Functions.geohashNeighbors(geohash); + } + + @UDFAnnotations.ParamMeta(argNames = {"geohash", "direction"}) + public static String ST_GeoHashNeighbor(String geohash, String direction) { + return Functions.geohashNeighbor(geohash, direction); + } + @UDFAnnotations.ParamMeta(argNames = {"gml"}) public static byte[] ST_GeomFromGML(String gml) throws IOException, ParserConfigurationException, SAXException { diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java index 5d545f4b6c..bc3a962898 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java @@ -621,6 +621,20 @@ public class UDFsV2 
{ return Functions.geohash(GeometrySerde.deserGeoJson(geometry), precision); } + @UDFAnnotations.ParamMeta( + argNames = {"geohash"}, + argTypes = {"String"}) + public static String[] ST_GeoHashNeighbors(String geohash) { + return Functions.geohashNeighbors(geohash); + } + + @UDFAnnotations.ParamMeta( + argNames = {"geohash", "direction"}, + argTypes = {"String", "String"}) + public static String ST_GeoHashNeighbor(String geohash, String direction) { + return Functions.geohashNeighbor(geohash, direction); + } + @UDFAnnotations.ParamMeta( argNames = {"geometry", "n"}, argTypes = {"Geometry", "int"}, diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 8b5882a989..a2d483398a 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -187,6 +187,8 @@ object Catalog extends AbstractCatalog with Logging { function[ST_MaximumInscribedCircle](), function[ST_MaxDistance](), function[ST_GeoHash](), + function[ST_GeoHashNeighbors](), + function[ST_GeoHashNeighbor](), function[ST_GeomFromGeoHash](null), function[ST_PointFromGeoHash](null), function[ST_GeogFromGeoHash](null), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 54424510ff..329848d2fc 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1156,6 +1156,22 @@ private[apache] case class ST_GeoHash(inputExpressions: Seq[Expression]) } } +private[apache] case class ST_GeoHashNeighbors(inputExpressions: Seq[Expression]) + extends InferredExpression(Functions.geohashNeighbors _) { + + protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + +private[apache] case class ST_GeoHashNeighbor(inputExpressions: Seq[Expression]) + extends InferredExpression(Functions.geohashNeighbor _) { + + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + /** * Return the difference between geometry A and B * diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala index 606dcbb931..757948216c 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala @@ -194,6 +194,8 @@ object InferrableType { new InferrableType[Array[java.lang.Long]] {} implicit val doubleArrayInstance: InferrableType[Array[Double]] = new InferrableType[Array[Double]] {} + implicit val stringArrayInstance: InferrableType[Array[String]] = + new InferrableType[Array[String]] {} implicit val javaDoubleListInstance: InferrableType[java.util.List[java.lang.Double]] = new InferrableType[java.util.List[java.lang.Double]] {} implicit val javaGeomListInstance: InferrableType[java.util.List[Geometry]] = @@ -220,6 +222,22 @@ object InferredTypes { expr.asString(input) } else if (t =:= typeOf[Array[Long]]) { expr => input => expr.eval(input).asInstanceOf[ArrayData].toLongArray() + } else if (t =:= typeOf[Array[String]]) { expr => input => + expr.eval(input).asInstanceOf[ArrayData] match { + case null => null + case arrayData: ArrayData => + val n = arrayData.numElements() + val result = new Array[String](n) + var i = 0 + while (i < n) { + if (!arrayData.isNullAt(i)) { + val utf8 = arrayData.getUTF8String(i) + if (utf8 != null) result(i) = utf8.toString + } + i += 1 + } + result + } } else if (t =:= 
typeOf[Array[Int]]) { expr => input => expr.eval(input).asInstanceOf[ArrayData] match { case null => null @@ -263,6 +281,14 @@ object InferredTypes { } else { null } + } else if (t =:= typeOf[Array[String]]) { output => + if (output != null) { + ArrayData.toArrayData(output.asInstanceOf[Array[String]].map { s => + if (s != null) UTF8String.fromString(s) else null + }) + } else { + null + } } else if (t =:= typeOf[java.util.List[java.lang.Double]]) { output => if (output != null) { ArrayData.toArrayData( @@ -330,6 +356,8 @@ object InferredTypes { DataTypes.createArrayType(LongType) } else if (t =:= typeOf[Array[Double]] || t =:= typeOf[java.util.List[java.lang.Double]]) { DataTypes.createArrayType(DoubleType) + } else if (t =:= typeOf[Array[String]]) { + DataTypes.createArrayType(StringType) } else if (t =:= typeOf[Option[Boolean]]) { BooleanType } else if (t =:= typeOf[Boolean]) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index f34a128307..818055b906 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -246,6 +246,20 @@ object st_functions { def ST_GeoHash(geometry: String, precision: Int): Column = wrapExpression[ST_GeoHash](geometry, precision) + def ST_GeoHashNeighbors(geohash: Column): Column = + wrapExpression[ST_GeoHashNeighbors](geohash) + def ST_GeoHashNeighbors(geohash: String): Column = + wrapExpression[ST_GeoHashNeighbors](geohash) + + def ST_GeoHashNeighbor(geohash: Column, direction: Column): Column = + wrapExpression[ST_GeoHashNeighbor](geohash, direction) + def ST_GeoHashNeighbor(geohash: String, direction: String): Column = + wrapExpression[ST_GeoHashNeighbor](geohash, direction) + def ST_GeoHashNeighbor(geohash: Column, direction: String): Column = 
+ wrapExpression[ST_GeoHashNeighbor](geohash, direction) + def ST_GeoHashNeighbor(geohash: String, direction: Column): Column = + wrapExpression[ST_GeoHashNeighbor](geohash, direction) + def ST_GeometryN(multiGeometry: Column, n: Column): Column = wrapExpression[ST_GeometryN](multiGeometry, n) def ST_GeometryN(multiGeometry: String, n: Int): Column = diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 2fb0d5b5c5..725f4cc8a3 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -411,6 +411,43 @@ class dataFrameAPITestScala extends TestBaseScala { assert(expected.equals(actual)) } + it("Passed ST_GeoHashNeighbors") { + val df = sparkSession + .sql("SELECT 'u1pb' AS geohash") + .select(ST_GeoHashNeighbors("geohash")) + val result = df.take(1)(0).getSeq[String](0) + assert(result.size == 8) + // Order: N, NE, E, SE, S, SW, W, NW + assert(result(0) == "u1pc") + assert(result(2) == "u300") + assert(result(4) == "u0zz") + assert(result(6) == "u1p8") + } + + it("Passed ST_GeoHashNeighbor") { + val df = sparkSession + .sql("SELECT 'u1pb' AS geohash") + .select(ST_GeoHashNeighbor(col("geohash"), lit("n"))) + val result = df.take(1)(0).getString(0) + assert(result == "u1pc") + } + + it("Passed ST_GeoHashNeighbor with Column geohash and String direction") { + val df = sparkSession + .sql("SELECT 'u1pb' AS geohash, 'n' AS direction") + .select(ST_GeoHashNeighbor(col("geohash"), "direction")) + val result = df.take(1)(0).getString(0) + assert(result == "u1pc") + } + + it("Passed ST_GeoHashNeighbor with String geohash and Column direction") { + val df = sparkSession + .sql("SELECT 'u1pb' AS geohash, 'n' AS direction") + .select(ST_GeoHashNeighbor("geohash", col("direction"))) + val result = df.take(1)(0).getString(0) + assert(result == 
"u1pc") + } + it("passed st_geomfromgml") { val gmlString = "<gml:LineString srsName=\"EPSG:4269\"><gml:coordinates>-71.16028,42.258729 -71.160837,42.259112 -71.161143,42.25932</gml:coordinates></gml:LineString>" diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index a095fc1bfa..01b831d6c3 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -2790,6 +2790,10 @@ class functionTestScala assert(functionDf.first().get(0) == null) functionDf = sparkSession.sql("select ST_GeoHash(null, 1)") assert(functionDf.first().get(0) == null) + functionDf = sparkSession.sql("select ST_GeoHashNeighbors(null)") + assert(functionDf.first().get(0) == null) + functionDf = sparkSession.sql("select ST_GeoHashNeighbor(null, 'n')") + assert(functionDf.first().get(0) == null) functionDf = sparkSession.sql("select ST_Difference(null, null)") assert(functionDf.first().get(0) == null) functionDf = sparkSession.sql("select ST_SymDifference(null, null)") @@ -4264,4 +4268,33 @@ class functionTestScala FileUtils.deleteDirectory(new File(tmpDir)) } + + it("Should pass ST_GeoHashNeighbors") { + var result = sparkSession.sql("SELECT ST_GeoHashNeighbors('u1pb')").first().getList[String](0) + assert(result.size() == 8) + assert(result.get(0) == "u1pc") // N + assert(result.get(2) == "u300") // E + assert(result.get(4) == "u0zz") // S + assert(result.get(6) == "u1p8") // W + + result = sparkSession.sql("SELECT ST_GeoHashNeighbors('dqcjqc')").first().getList[String](0) + assert(result.size() == 8) + val expected = + Set("dqcjqf", "dqcjr4", "dqcjr1", "dqcjr0", "dqcjq9", "dqcjq8", "dqcjqb", "dqcjqd") + val actual = (0 until 8).map(result.get).toSet + assert(actual == expected) + } + + it("Should pass ST_GeoHashNeighbor") { + var result = sparkSession.sql("SELECT 
ST_GeoHashNeighbor('u1pb', 'n')").first().getString(0) + assert(result == "u1pc") + result = sparkSession.sql("SELECT ST_GeoHashNeighbor('u1pb', 'e')").first().getString(0) + assert(result == "u300") + result = sparkSession.sql("SELECT ST_GeoHashNeighbor('u1pb', 's')").first().getString(0) + assert(result == "u0zz") + result = sparkSession.sql("SELECT ST_GeoHashNeighbor('u1pb', 'w')").first().getString(0) + assert(result == "u1p8") + result = sparkSession.sql("SELECT ST_GeoHashNeighbor('u1pb', 'NE')").first().getString(0) + assert(result == "u301") + } }
