This is an automated email from the ASF dual-hosted git repository.
kojiromike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git
The following commit(s) were added to refs/heads/master by this push:
new e8c61ba AVRO-3162: Use Argparse to Manage Arguments (#1270)
e8c61ba is described below
commit e8c61ba4f0051a76d4bf2e2d4391ec92a6e90f4f
Author: Michael A. Smith <[email protected]>
AuthorDate: Sun Aug 29 22:34:27 2021 -0400
AVRO-3162: Use Argparse to Manage Arguments (#1270)
Improve the python cli, consisting mainly of scripts/avro and avro.tool, by
implementing argparse, and a `__main__.py` module.
- argparse handles arity and type exceptions more fluently at invocation
time than can easily be done with sys.argv.
- argparse handles sub-commands natively, especially in python 3.7 and up.
- `__main__.py` dispatch (along with setuptools
`entry_points.console_scripts` at install time) allows the cli to be invoked
with both `python -m avro` and just `avro` without having an out-of-band
`scripts/avro` directory.
- Just in case someone is still using scripts/avro, I've kept that around
with a deprecation warning.
- I added type hints as best I can; however, the avro script has an `eval`
call that we might want to get rid of, some day.
---
lang/py/avro/__main__.py | 276 ++++++++++++++++++++++++++++++++
lang/py/avro/datafile.py | 48 +++---
lang/py/avro/schema.py | 27 ++--
lang/py/avro/test/gen_interop_data.py | 63 ++++++--
lang/py/avro/test/mock_tether_parent.py | 23 +--
lang/py/avro/test/sample_http_client.py | 97 +++++------
lang/py/avro/test/test_bench.py | 29 ++--
lang/py/avro/test/test_compatibility.py | 4 +-
lang/py/avro/test/test_schema.py | 2 +-
lang/py/avro/test/test_script.py | 115 ++++++-------
lang/py/avro/tool.py | 153 +++++++++---------
lang/py/scripts/avro | 262 +-----------------------------
lang/py/setup.cfg | 6 +-
lang/py/setup.py | 3 +-
lang/py/tox.ini | 2 +-
15 files changed, 581 insertions(+), 529 deletions(-)
diff --git a/lang/py/avro/__main__.py b/lang/py/avro/__main__.py
new file mode 100755
index 0000000..423de59
--- /dev/null
+++ b/lang/py/avro/__main__.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+
+##
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Command line utility for reading and writing Avro files."""
+import argparse
+import csv
+import functools
+import itertools
+import json
+import sys
+import warnings
+from pathlib import Path
+from typing import (
+ IO,
+ Any,
+ AnyStr,
+ Callable,
+ Collection,
+ Dict,
+ Generator,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
+
+import avro
+import avro.datafile
+import avro.errors
+import avro.io
+import avro.schema
+
+
+def print_json(row: object) -> None:
+ print(json.dumps(row))
+
+
+def print_json_pretty(row: object) -> None:
+ """Pretty print JSON"""
+ # Need to work around https://bugs.python.org/issue16333
+ # where json.dumps leaves trailing spaces.
+ result = json.dumps(row, sort_keys=True, indent=4).replace(" \n", "\n")
+ print(result)
+
+
+_write_row = csv.writer(sys.stdout).writerow
+
+
+def print_csv(row: Mapping[str, object]) -> None:
+    # We sort the keys so the fields will be in the same place
+ # FIXME: Do we want to do it in schema order?
+ _write_row(row[key] for key in sorted(row))
+
+
+def select_printer(format_: str) -> Callable[[Mapping[str, object]], None]:
+ return {"json": print_json, "json-pretty": print_json_pretty, "csv":
print_csv}[format_]
+
+
+def record_match(expr: Any, record: Any) -> Callable[[Mapping[str, object]],
bool]:
+ warnings.warn(avro.errors.AvroWarning("There is no way to safely type
check this."))
+ return cast(Callable[[Mapping[str, object]], bool], eval(expr, None, {"r":
record}))
+
+
+def _passthrough_keys_filter(obj: Mapping[str, object]) -> Dict[str, object]:
+ return {**obj}
+
+
+def field_selector(fields: Collection[str]) -> Callable[[Mapping[str,
object]], Dict[str, object]]:
+ if fields:
+
+ def keys_filter(obj: Mapping[str, object]) -> Dict[str, object]:
+ return {k: obj[k] for k in obj.keys() & set(fields)}
+
+ return keys_filter
+ return _passthrough_keys_filter
+
+
+def print_avro(
+ data_file_reader: avro.datafile.DataFileReader,
+ header: bool,
+ format_: str,
+ filter_: str,
+ skip: int = 0,
+ count: int = 0,
+ fields: Collection[str] = "",
+) -> None:
+ if header and format_ != "csv":
+ raise avro.errors.UsageError("--header applies only to CSV format")
+ predicate = functools.partial(record_match, filter_) if filter_ else None
+ avro_filtered = filter(predicate, data_file_reader)
+ avro_slice = itertools.islice(avro_filtered, skip, (skip + count) if count
else None)
+ fs = field_selector(fields)
+ avro_fields = (fs(cast(Mapping[str, object], r)) for r in avro_slice)
+ avro_enum = enumerate(avro_fields)
+ printer = select_printer(format_)
+ for i, record in avro_enum:
+ if header and i == 0:
+ _write_row(sorted(record))
+ printer(record)
+
+
+def print_schema(data_file_reader: avro.datafile.DataFileReader) -> None:
+ print(json.dumps(json.loads(data_file_reader.schema), indent=4))
+
+
+def cat(
+ files: Sequence[IO[AnyStr]], print_schema_: bool, header: bool, format_:
str, filter_: str, skip: int, count: int, fields: Collection[str]
+) -> int:
+ datum_reader = avro.io.DatumReader()
+ for file_ in files:
+ with avro.datafile.DataFileReader(file_, datum_reader) as
data_file_reader:
+ if print_schema_:
+ print_schema(data_file_reader)
+ continue
+ print_avro(data_file_reader, header, format_, filter_, skip,
count, fields)
+ return 0
+
+
+def iter_json(info: Iterable[AnyStr], _: Any) -> Generator[object, None, None]:
+ for i in info:
+ row = (i if isinstance(i, str) else cast(bytes, i).decode()).strip()
+ if row:
+ yield json.loads(row)
+
+
+def convert(value: str, field: avro.schema.Field) -> Union[int, float, str,
bytes, bool, None]:
+ type_ = field.type.type
+ if type_ in ("int", "long"):
+ return int(value)
+ if type_ in ("float", "double"):
+ return float(value)
+ if type_ == "string":
+ return value
+ if type_ == "bytes":
+ return value.encode()
+ if type_ == "boolean":
+ return value.lower() in ("1", "t", "true")
+ if type_ == "null":
+ return None
+ if type_ == "union":
+ return convert_union(value, field)
+ raise avro.errors.UsageError("No valid conversion type")
+
+
+def convert_union(value: str, field: avro.schema.Field) -> Union[int, float,
str, bytes, bool, None]:
+ for name in (s.name for s in field.type.schemas):
+ try:
+ return convert(value, name)
+ except ValueError:
+ continue
+ raise avro.errors.UsageError("Exhausted Union Schema without finding a
match")
+
+
+def iter_csv(info: IO[AnyStr], schema: avro.schema.RecordSchema) ->
Generator[Dict[str, object], None, None]:
+ header = [field.name for field in schema.fields]
+ for row in csv.reader(getattr(i, "decode", lambda: i)() for i in info):
+ values = [convert(v, f) for v, f in zip(row, schema.fields)]
+ yield dict(zip(header, values))
+
+
+def guess_input_type(io_: IO[AnyStr]) -> str:
+ ext = Path(io_.name).suffixes[0].lower()
+ extensions = {".json": "json", ".js": "json", ".csv": "csv"}
+ try:
+ return extensions[ext]
+ except KeyError as e:
+ raise avro.errors.UsageError("Can't guess input file type (not .json
or .csv)") from e
+
+
+def write(schema: avro.schema.RecordSchema, output: IO[AnyStr], input_files:
Sequence[IO[AnyStr]], input_type: Optional[str] = None) -> int:
+ input_type = input_type or guess_input_type(cast(IO[AnyStr],
next(iter(input_files), sys.stdin)))
+ iter_records = {"json": iter_json, "csv": iter_csv}[input_type]
+ with avro.datafile.DataFileWriter(output, avro.io.DatumWriter(), schema)
as writer:
+ for file_ in input_files:
+ for record in iter_records(file_, schema):
+ writer.append(record)
+ return 0
+
+
+def csv_arg(arg: Optional[str]) -> Optional[Tuple[str, ...]]:
+ return tuple(f.strip() for f in arg.split(",") if f.strip()) if (arg and
arg.strip()) else None
+
+
+def _parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Display/write for Avro
files")
+ parser.add_argument("--version", action="version",
version=avro.__version__)
+ subparsers = parser.add_subparsers(required=True, dest="command") if
sys.version_info >= (3, 7) else parser.add_subparsers(dest="command")
+ subparser_cat = subparsers.add_parser("cat")
+ subparser_cat.add_argument("--version", action="version",
version=avro.__version__)
+ subparser_cat.add_argument(
+ "--count",
+ "-n",
+ default=None,
+ type=int,
+ help="number of records to print",
+ )
+ subparser_cat.add_argument(
+ "--skip",
+ "-s",
+ type=int,
+ default=0,
+ help="number of records to skip",
+ )
+ subparser_cat.add_argument(
+ "--format",
+ "-f",
+ default="json",
+ choices=("json", "csv", "json-pretty"),
+ help="record format",
+ )
+ subparser_cat.add_argument(
+ "--header",
+ "-H",
+ default=False,
+ action="store_true",
+ help="print CSV header",
+ )
+ subparser_cat.add_argument(
+ "--filter",
+ "-F",
+ default=None,
+ help="filter records (e.g. r['age']>1)",
+ )
+ subparser_cat.add_argument(
+ "--print-schema",
+ "-p",
+ action="store_true",
+ default=False,
+ help="print schema",
+ )
+ subparser_cat.add_argument("--fields", type=csv_arg, default=None,
help="fields to show, comma separated (show all by default)")
+ subparser_cat.add_argument("files", nargs="*",
type=argparse.FileType("rb"), help="Files to read", default=(sys.stdin,))
+ subparser_write = subparsers.add_parser("write")
+ subparser_write.add_argument("--version", action="version",
version=avro.__version__)
+ subparser_write.add_argument("--schema", "-s", type=avro.schema.from_path,
required=True, help="schema JSON file (required)")
+ subparser_write.add_argument(
+ "--input-type",
+ "-f",
+ choices=("json", "csv"),
+ default=None,
+ help="input file(s) type (json or csv)",
+ )
+ subparser_write.add_argument("--output", "-o", help="output file",
type=argparse.FileType("wb"), default=sys.stdout)
+ subparser_write.add_argument("input_files", nargs="*",
type=argparse.FileType("r"), help="Files to read", default=(sys.stdin,))
+ return parser.parse_args()
+
+
+def main() -> int:
+ args = _parse_args()
+ if args.command == "cat":
+ return cat(args.files, args.print_schema, args.header, args.format,
args.filter, args.skip, args.count, args.fields)
+ if args.command == "write":
+ return write(args.schema, args.output, args.input_files,
args.input_type)
+ raise avro.errors.UsageError(f"Unknown command - {args.command}")
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/lang/py/avro/datafile.py b/lang/py/avro/datafile.py
index 87383ff..22bf2dd 100644
--- a/lang/py/avro/datafile.py
+++ b/lang/py/avro/datafile.py
@@ -24,8 +24,9 @@
https://avro.apache.org/docs/current/spec.html#Object+Container+Files
"""
import io
import json
+import warnings
from types import TracebackType
-from typing import BinaryIO, MutableMapping, Optional, Type, cast
+from typing import IO, AnyStr, BinaryIO, MutableMapping, Optional, Type, cast
import avro.codecs
import avro.errors
@@ -38,18 +39,21 @@ MAGIC = bytes(b"Obj" + bytearray([VERSION]))
MAGIC_SIZE = len(MAGIC)
SYNC_SIZE = 16
SYNC_INTERVAL = 4000 * SYNC_SIZE # TODO(hammer): make configurable
-META_SCHEMA: avro.schema.RecordSchema = avro.schema.parse(
- json.dumps(
- {
- "type": "record",
- "name": "org.apache.avro.file.Header",
- "fields": [
- {"name": "magic", "type": {"type": "fixed", "name": "magic",
"size": MAGIC_SIZE}},
- {"name": "meta", "type": {"type": "map", "values": "bytes"}},
- {"name": "sync", "type": {"type": "fixed", "name": "sync",
"size": SYNC_SIZE}},
- ],
- }
- )
+META_SCHEMA = cast(
+ avro.schema.RecordSchema,
+ avro.schema.parse(
+ json.dumps(
+ {
+ "type": "record",
+ "name": "org.apache.avro.file.Header",
+ "fields": [
+ {"name": "magic", "type": {"type": "fixed", "name":
"magic", "size": MAGIC_SIZE}},
+ {"name": "meta", "type": {"type": "map", "values":
"bytes"}},
+ {"name": "sync", "type": {"type": "fixed", "name": "sync",
"size": SYNC_SIZE}},
+ ],
+ }
+ )
+ ),
)
NULL_CODEC = "null"
@@ -161,11 +165,14 @@ class DataFileWriter(_DataFileMetadata):
sync_marker: bytes
def __init__(
- self, writer: BinaryIO, datum_writer: avro.io.DatumWriter,
writers_schema: Optional[avro.schema.Schema] = None, codec: str = NULL_CODEC
+ self, writer: IO[AnyStr], datum_writer: avro.io.DatumWriter,
writers_schema: Optional[avro.schema.Schema] = None, codec: str = NULL_CODEC
) -> None:
"""If the schema is not present, presume we're appending."""
- self._writer = writer
- self._encoder = avro.io.BinaryEncoder(writer)
+ if hasattr(writer, "mode") and "b" not in writer.mode:
+ warnings.warn(avro.errors.AvroWarning(f"Writing binary data to a
writer {writer!r} that's opened for text"))
+ bytes_writer = getattr(writer, "buffer", writer)
+ self._writer = bytes_writer
+ self._encoder = avro.io.BinaryEncoder(bytes_writer)
self._datum_writer = datum_writer
self._buffer_writer = io.BytesIO()
self._buffer_encoder = avro.io.BinaryEncoder(self._buffer_writer)
@@ -307,9 +314,12 @@ class DataFileReader(_DataFileMetadata):
# TODO(hammer): allow user to specify expected schema?
# TODO(hammer): allow user to specify the encoder
- def __init__(self, reader, datum_reader):
- self._reader = reader
- self._raw_decoder = avro.io.BinaryDecoder(reader)
+ def __init__(self, reader: IO[AnyStr], datum_reader: avro.io.DatumReader)
-> None:
+ if "b" not in reader.mode:
+            warnings.warn(avro.errors.AvroWarning(f"Reading binary data from a
reader {reader!r} that's opened for text"))
+ bytes_reader = getattr(reader, "buffer", reader)
+ self._reader = bytes_reader
+ self._raw_decoder = avro.io.BinaryDecoder(bytes_reader)
self._datum_decoder = None # Maybe reset at every block.
self._datum_reader = datum_reader
diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index 1f73931..1d3b369 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -44,10 +44,10 @@ import datetime
import decimal
import json
import math
-import sys
import uuid
import warnings
-from typing import Mapping, MutableMapping, Optional, Sequence, cast
+from pathlib import Path
+from typing import Mapping, MutableMapping, Optional, Sequence, Union, cast
import avro.constants
import avro.errors
@@ -1189,27 +1189,22 @@ def make_avsc_object(json_data: object, names:
Optional[avro.name.Names] = None,
raise avro.errors.SchemaParseException(fail_msg)
-# TODO(hammer): make method for reading from a file?
-
-
-def parse(json_string, validate_enum_symbols=True):
+def parse(json_string: str, validate_enum_symbols: bool = True) -> Schema:
"""Constructs the Schema from the JSON text.
@arg json_string: The json string of the schema to parse
@arg validate_enum_symbols: If False, will allow enum symbols that are not
valid Avro names.
@return Schema
"""
- # parse the JSON
try:
json_data = json.loads(json_string)
- except Exception as e:
- msg = f"Error parsing JSON: {json_string}, error = {e}"
- new_exception = avro.errors.SchemaParseException(msg)
- traceback = sys.exc_info()[2]
- raise new_exception.with_traceback(traceback)
+ except json.decoder.JSONDecodeError as e:
+ raise avro.errors.SchemaParseException(f"Error parsing JSON:
{json_string}, error = {e}") from e
+ return make_avsc_object(json_data, Names(), validate_enum_symbols)
- # Initialize the names object
- names = Names()
- # construct the Avro Schema object
- return make_avsc_object(json_data, names, validate_enum_symbols)
+def from_path(path: Union[Path, str], validate_enum_symbols: bool = True) ->
Schema:
+ """
+ Constructs the Schema from a path to an avsc (json) file.
+ """
+ return parse(Path(path).read_text(), validate_enum_symbols)
diff --git a/lang/py/avro/test/gen_interop_data.py
b/lang/py/avro/test/gen_interop_data.py
index 73b47ce..d1f3c0e 100644
--- a/lang/py/avro/test/gen_interop_data.py
+++ b/lang/py/avro/test/gen_interop_data.py
@@ -18,15 +18,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
+import base64
+import io
+import json
import os
-import sys
+from pathlib import Path
+from typing import IO, TextIO
import avro.codecs
import avro.datafile
import avro.io
import avro.schema
-NULL_CODEC = "null"
+NULL_CODEC = avro.datafile.NULL_CODEC
CODECS_TO_VALIDATE = avro.codecs.KNOWN_CODECS.keys()
DATUM = {
@@ -47,17 +52,49 @@ DATUM = {
}
-def generate(schema_path, output_path):
- with open(schema_path) as schema_file:
- interop_schema = avro.schema.parse(schema_file.read())
- for codec in CODECS_TO_VALIDATE:
- filename = output_path
- if codec != NULL_CODEC:
- base, ext = os.path.splitext(output_path)
- filename = base + "_" + codec + ext
- with avro.datafile.DataFileWriter(open(filename, "wb"),
avro.io.DatumWriter(), interop_schema, codec=codec) as dfw:
- dfw.append(DATUM)
+def gen_data(codec: str, datum_writer: avro.io.DatumWriter, interop_schema:
avro.schema.Schema) -> bytes:
+ with io.BytesIO() as file_, avro.datafile.DataFileWriter(file_,
datum_writer, interop_schema, codec=codec) as dfw:
+ dfw.append(DATUM)
+ dfw.flush()
+ return file_.getvalue()
+
+
+def generate(schema_file: TextIO, output_path: IO) -> None:
+ interop_schema = avro.schema.parse(schema_file.read())
+ datum_writer = avro.io.DatumWriter()
+ output = ((codec, gen_data(codec, datum_writer, interop_schema)) for codec
in CODECS_TO_VALIDATE)
+ if output_path.isatty():
+ json.dump({codec: base64.b64encode(data).decode() for codec, data in
output}, output_path)
+ return
+ for codec, data in output:
+ if codec == NULL_CODEC:
+ output_path.write(data)
+ continue
+ base, ext = os.path.splitext(output_path.name)
+ Path(f"{base}_{codec}{ext}").write_bytes(data)
+
+
+def _parse_args() -> argparse.Namespace:
+ """Parse the command-line arguments"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("schema_path", type=argparse.FileType("r"))
+ parser.add_argument(
+ "output_path",
+ type=argparse.FileType("wb"),
+ help=(
+ "Write the different codec variants to these files. "
+ "Will append codec extensions to multiple files. "
+ "If '-', will output base64 encoded binary"
+ ),
+ )
+ return parser.parse_args()
+
+
+def main() -> int:
+ args = _parse_args()
+ generate(args.schema_path, args.output_path)
+ return 0
if __name__ == "__main__":
- generate(sys.argv[1], sys.argv[2])
+ raise SystemExit(main())
diff --git a/lang/py/avro/test/mock_tether_parent.py
b/lang/py/avro/test/mock_tether_parent.py
index 9c7c844..26ae426 100644
--- a/lang/py/avro/test/mock_tether_parent.py
+++ b/lang/py/avro/test/mock_tether_parent.py
@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
import http.server
import sys
from typing import Mapping
@@ -64,17 +65,21 @@ class MockParentHandler(http.server.BaseHTTPRequestHandler):
resp_writer.write_framed_message(resp_body)
-def main() -> None:
- global SERVER_ADDRESS
-
- if len(sys.argv) != 3 or sys.argv[1].lower() != "start_server":
- raise avro.errors.UsageError("Usage: mock_tether_parent start_server
port")
+def _parse_args() -> argparse.Namespace:
+ """Parse the command-line arguments"""
+ parser = argparse.ArgumentParser()
+ subparsers = parser.add_subparsers(required=True, dest="command") if
sys.version_info >= (3, 7) else parser.add_subparsers(dest="command")
+ subparser_start_server = subparsers.add_parser("start_server", help="Start
the server")
+ subparser_start_server.add_argument("port", type=int)
+ return parser.parse_args()
- try:
- port = int(sys.argv[2])
- except ValueError as e:
- raise avro.errors.UsageError("Usage: mock_tether_parent start_server
port") from e
+def main() -> None:
+ global SERVER_ADDRESS
+ args = _parse_args()
+ if args.command != "start_server":
+ raise NotImplementedError(f"{args.command} is not a known command")
+ port = args.port
SERVER_ADDRESS = (SERVER_ADDRESS[0], port)
print(f"mock_tether_parent: Launching Server on Port: {SERVER_ADDRESS[1]}")
diff --git a/lang/py/avro/test/sample_http_client.py
b/lang/py/avro/test/sample_http_client.py
index f1fc254..e4ae5b4 100644
--- a/lang/py/avro/test/sample_http_client.py
+++ b/lang/py/avro/test/sample_http_client.py
@@ -17,76 +17,65 @@
#
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
+import json
-import sys
-
-import avro.errors
import avro.ipc
import avro.protocol
-MAIL_PROTOCOL_JSON = """\
-{"namespace": "example.proto",
- "protocol": "Mail",
-
- "types": [
- {"name": "Message", "type": "record",
- "fields": [
- {"name": "to", "type": "string"},
- {"name": "from", "type": "string"},
- {"name": "body", "type": "string"}
- ]
- }
- ],
-
- "messages": {
- "send": {
- "request": [{"name": "message", "type": "Message"}],
- "response": "string"
- },
- "replay": {
- "request": [],
- "response": "string"
- }
- }
-}
-"""
-MAIL_PROTOCOL = avro.protocol.parse(MAIL_PROTOCOL_JSON)
+MAIL_PROTOCOL = avro.protocol.parse(
+ json.dumps(
+ {
+ "namespace": "example.proto",
+ "protocol": "Mail",
+ "types": [
+ {
+ "name": "Message",
+ "type": "record",
+ "fields": [{"name": "to", "type": "string"}, {"name":
"from", "type": "string"}, {"name": "body", "type": "string"}],
+ }
+ ],
+ "messages": {
+ "send": {"request": [{"name": "message", "type": "Message"}],
"response": "string"},
+ "replay": {"request": [], "response": "string"},
+ },
+ }
+ )
+)
SERVER_HOST = "localhost"
SERVER_PORT = 9090
-def make_requestor(server_host, server_port, protocol):
- client = avro.ipc.HTTPTransceiver(SERVER_HOST, SERVER_PORT)
+def make_requestor(server_host: str, server_port: int, protocol:
avro.protocol.Protocol) -> avro.ipc.Requestor:
+ client = avro.ipc.HTTPTransceiver(server_host, server_port)
return avro.ipc.Requestor(protocol, client)
-if __name__ == "__main__":
- if len(sys.argv) not in [4, 5]:
- raise avro.errors.UsageError("Usage: <to> <from> <body> [<count>]")
-
- # client code - attach to the server and send a message
- # fill in the Message record
- message = dict()
- message["to"] = sys.argv[1]
- message["from"] = sys.argv[2]
- message["body"] = sys.argv[3]
-
- try:
- num_messages = int(sys.argv[4])
- except IndexError:
- num_messages = 1
+def _parse_args() -> argparse.Namespace:
+ """Parse the command-line arguments"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("to", help="Who the message is to")
+ parser.add_argument("from_", help="Who the message is from")
+ parser.add_argument("body", help="The message body")
+ parser.add_argument("num_messages", type=int, default=1, help="The number
of messages")
+ return parser.parse_args()
- # build the parameters for the request
- params = {}
- params["message"] = message
+def main() -> int:
+    # client code - attach to the server and send a message; fill in the
Message record
+ args = _parse_args()
+ params = {"message": {"to": args.to, "from": args.from_, "body":
args.body}}
# send the requests and print the result
- for msg_count in range(num_messages):
+ for msg_count in range(args.num_messages):
requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
result = requestor.request("send", params)
- print("Result: " + result)
-
+ print(f"Result: {result}")
# try out a replay message
requestor = make_requestor(SERVER_HOST, SERVER_PORT, MAIL_PROTOCOL)
result = requestor.request("replay", dict())
- print("Replay Result: " + result)
+ print(f"Replay Result: {result}")
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/lang/py/avro/test/test_bench.py b/lang/py/avro/test/test_bench.py
index eec667f..a188768 100644
--- a/lang/py/avro/test/test_bench.py
+++ b/lang/py/avro/test/test_bench.py
@@ -27,7 +27,7 @@ import timeit
import unittest
import unittest.mock
from pathlib import Path
-from typing import List, Mapping, Sequence
+from typing import Mapping, Sequence, cast
import avro.datafile
import avro.io
@@ -35,18 +35,21 @@ import avro.schema
from avro.utils import randbytes
TYPES = ("A", "CNAME")
-SCHEMA: avro.schema.RecordSchema = avro.schema.parse(
- json.dumps(
- {
- "type": "record",
- "name": "Query",
- "fields": [
- {"name": "query", "type": "string"},
- {"name": "response", "type": "string"},
- {"name": "type", "type": "string", "default": "A"},
- ],
- }
- )
+SCHEMA = cast(
+ avro.schema.RecordSchema,
+ avro.schema.parse(
+ json.dumps(
+ {
+ "type": "record",
+ "name": "Query",
+ "fields": [
+ {"name": "query", "type": "string"},
+ {"name": "response", "type": "string"},
+ {"name": "type", "type": "string", "default": "A"},
+ ],
+ }
+ )
+ ),
)
READER = avro.io.DatumReader(SCHEMA)
WRITER = avro.io.DatumWriter(SCHEMA)
diff --git a/lang/py/avro/test/test_compatibility.py
b/lang/py/avro/test/test_compatibility.py
index 58c09a0..d67a29b 100644
--- a/lang/py/avro/test/test_compatibility.py
+++ b/lang/py/avro/test/test_compatibility.py
@@ -403,7 +403,7 @@ RECORD1_WITH_ENUM_AB = parse(
{
"type": SchemaType.RECORD,
"name": "Record1",
- "fields": [{"name": "field1", "type":
dict(ENUM1_AB_SCHEMA.to_json())}],
+ "fields": [{"name": "field1", "type": ENUM1_AB_SCHEMA.to_json()}],
}
)
)
@@ -412,7 +412,7 @@ RECORD1_WITH_ENUM_ABC = parse(
{
"type": SchemaType.RECORD,
"name": "Record1",
- "fields": [{"name": "field1", "type":
dict(ENUM1_ABC_SCHEMA.to_json())}],
+ "fields": [{"name": "field1", "type": ENUM1_ABC_SCHEMA.to_json()}],
}
)
)
diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py
index 4506744..2015727 100644
--- a/lang/py/avro/test/test_schema.py
+++ b/lang/py/avro/test/test_schema.py
@@ -623,7 +623,7 @@ class TestMisc(unittest.TestCase):
def test_exception_is_not_swallowed_on_parse_error(self):
"""A specific exception message should appear on a json parse error."""
- self.assertRaisesRegexp(
+ self.assertRaisesRegex(
avro.errors.SchemaParseException,
r"Error parsing JSON: /not/a/real/file",
avro.schema.parse,
diff --git a/lang/py/avro/test/test_script.py b/lang/py/avro/test/test_script.py
index 07fef9b..6870c02 100644
--- a/lang/py/avro/test/test_script.py
+++ b/lang/py/avro/test/test_script.py
@@ -27,6 +27,7 @@ import subprocess
import sys
import tempfile
import unittest
+from pathlib import Path
import avro.datafile
import avro.io
@@ -35,18 +36,14 @@ import avro.schema
NUM_RECORDS = 7
-SCHEMA = """
-{
- "namespace": "test.avro",
+SCHEMA = json.dumps(
+ {
+ "namespace": "test.avro",
"name": "LooneyTunes",
"type": "record",
- "fields": [
- {"name": "first", "type": "string"},
- {"name": "last", "type": "string"},
- {"name": "type", "type": "string"}
- ]
-}
-"""
+ "fields": [{"name": "first", "type": "string"}, {"name": "last",
"type": "string"}, {"name": "type", "type": "string"}],
+ }
+)
LOONIES = (
("daffy", "duck", "duck"),
@@ -60,11 +57,8 @@ LOONIES = (
def looney_records():
- for f, l, t in LOONIES:
- yield {"first": f, "last": l, "type": t}
-
+ return ({"first": f, "last": l, "type": t} for f, l, t in LOONIES)
-SCRIPT =
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"scripts", "avro")
_JSON_PRETTY = """{
"first": "daffy",
@@ -75,12 +69,9 @@ _JSON_PRETTY = """{
def gen_avro(filename):
schema = avro.schema.parse(SCHEMA)
- fo = open(filename, "wb")
- writer = avro.datafile.DataFileWriter(fo, avro.io.DatumWriter(), schema)
- for record in looney_records():
- writer.append(record)
- writer.close()
- fo.close()
+ with avro.datafile.DataFileWriter(Path(filename).open("wb"),
avro.io.DatumWriter(), schema) as writer:
+ for record in looney_records():
+ writer.append(record)
def _tempfile():
@@ -99,10 +90,8 @@ class TestCat(unittest.TestCase):
os.unlink(self.avro_file)
def _run(self, *args, **kw):
- out = subprocess.check_output([sys.executable, SCRIPT, "cat",
self.avro_file] + list(args)).decode()
- if kw.get("raw"):
- return out
- return out.splitlines()
+ out = subprocess.check_output([sys.executable, "-m", "avro", "cat",
self.avro_file] + list(args)).decode()
+ return out if kw.get("raw") else out.splitlines()
def test_print(self):
return len(self._run()) == NUM_RECORDS
@@ -116,18 +105,18 @@ class TestCat(unittest.TestCase):
def test_csv(self):
reader = csv.reader(io.StringIO(self._run("-f", "csv", raw=True)))
- assert len(list(reader)) == NUM_RECORDS
+ self.assertEqual(len(list(reader)), NUM_RECORDS)
def test_csv_header(self):
r = {"type": "duck", "last": "duck", "first": "daffy"}
out = self._run("-f", "csv", "--header", raw=True)
io_ = io.StringIO(out)
reader = csv.DictReader(io_)
- assert next(reader) == r
+ self.assertEqual(next(reader), r)
def test_print_schema(self):
out = self._run("--print-schema", raw=True)
- assert json.loads(out)["namespace"] == "test.avro"
+ self.assertEqual(json.loads(out)["namespace"], "test.avro")
def test_help(self):
# Just see we have these
@@ -139,51 +128,47 @@ class TestCat(unittest.TestCase):
self.assertEqual(out.strip(), _JSON_PRETTY.strip())
def test_version(self):
- subprocess.check_output([sys.executable, SCRIPT, "cat", "--version"])
+ subprocess.check_output([sys.executable, "-m", "avro", "cat",
"--version"])
def test_files(self):
out = self._run(self.avro_file)
- assert len(out) == 2 * NUM_RECORDS
+ self.assertEqual(len(out), 2 * NUM_RECORDS)
def test_fields(self):
# One field selection (no comma)
out = self._run("--fields", "last")
- assert json.loads(out[0]) == {"last": "duck"}
+ self.assertEqual(json.loads(out[0]), {"last": "duck"})
# Field selection (with comma and space)
out = self._run("--fields", "first, last")
- assert json.loads(out[0]) == {"first": "daffy", "last": "duck"}
+ self.assertEqual(json.loads(out[0]), {"first": "daffy", "last":
"duck"})
# Empty fields should get all
out = self._run("--fields", "")
- assert json.loads(out[0]) == {"first": "daffy", "last": "duck",
"type": "duck"}
+ self.assertEqual(json.loads(out[0]), {"first": "daffy", "last":
"duck", "type": "duck"})
# Non existing fields are ignored
out = self._run("--fields", "first,last,age")
- assert json.loads(out[0]) == {"first": "daffy", "last": "duck"}
+ self.assertEqual(json.loads(out[0]), {"first": "daffy", "last":
"duck"})
class TestWrite(unittest.TestCase):
def setUp(self):
self.json_file = _tempfile() + ".json"
- fo = open(self.json_file, "w")
- for record in looney_records():
- json.dump(record, fo)
- fo.write("\n")
- fo.close()
+ with Path(self.json_file).open("w") as fo:
+ for record in looney_records():
+ json.dump(record, fo)
+ fo.write("\n")
self.csv_file = _tempfile() + ".csv"
- fo = open(self.csv_file, "w")
- write = csv.writer(fo).writerow
get = operator.itemgetter("first", "last", "type")
- for record in looney_records():
- write(get(record))
- fo.close()
+ with Path(self.csv_file).open("w") as fo:
+ write = csv.writer(fo).writerow
+ for record in looney_records():
+ write(get(record))
self.schema_file = _tempfile()
- fo = open(self.schema_file, "w")
- fo.write(SCHEMA)
- fo.close()
+ Path(self.schema_file).write_text(SCHEMA)
def tearDown(self):
for filename in (self.csv_file, self.json_file, self.schema_file):
@@ -193,26 +178,29 @@ class TestWrite(unittest.TestCase):
continue
def _run(self, *args, **kw):
- args = [sys.executable, SCRIPT, "write", "--schema", self.schema_file]
+ list(args)
+ args = [sys.executable, "-m", "avro", "write", "--schema",
self.schema_file] + list(args)
subprocess.check_call(args, **kw)
def load_avro(self, filename):
- out = subprocess.check_output([sys.executable, SCRIPT, "cat", filename]).decode()
+ out = subprocess.check_output([sys.executable, "-m", "avro", "cat", filename]).decode()
return [json.loads(o) for o in out.splitlines()]
def test_version(self):
- subprocess.check_call([sys.executable, SCRIPT, "write", "--version"])
+ subprocess.check_call([sys.executable, "-m", "avro", "write", "--version"])
def format_check(self, format, filename):
tmp = _tempfile()
- with open(tmp, "wb") as fo:
- self._run(filename, "-f", format, stdout=fo)
-
- records = self.load_avro(tmp)
- assert len(records) == NUM_RECORDS
- assert records[0]["first"] == "daffy"
-
- os.unlink(tmp)
+ try:
+ with Path(tmp).open("w") as fo:  # standard io are always text, never binary
+ self._run(filename, "-f", format, stdout=fo)
+ records = self.load_avro(tmp)
+ finally:
+ try:
+ os.unlink(tmp)
+ except IOError: # TODO: Move this to test teardown.
+ pass
+ self.assertEqual(len(records), NUM_RECORDS)
+ self.assertEqual(records[0]["first"], "daffy")
def test_write_json(self):
self.format_check("json", self.json_file)
@@ -225,20 +213,19 @@ class TestWrite(unittest.TestCase):
os.unlink(tmp)
self._run(self.json_file, "-o", tmp)
- assert len(self.load_avro(tmp)) == NUM_RECORDS
+ self.assertEqual(len(self.load_avro(tmp)), NUM_RECORDS)
os.unlink(tmp)
def test_multi_file(self):
tmp = _tempfile()
- with open(tmp, "wb") as o:
+ with Path(tmp).open("w") as o:  # standard io are always text, never binary
self._run(self.json_file, self.json_file, stdout=o)
- assert len(self.load_avro(tmp)) == 2 * NUM_RECORDS
+ self.assertEqual(len(self.load_avro(tmp)), 2 * NUM_RECORDS)
os.unlink(tmp)
def test_stdin(self):
tmp = _tempfile()
- info = open(self.json_file, "rb")
- self._run("--input-type", "json", "-o", tmp, stdin=info)
-
- assert len(self.load_avro(tmp)) == NUM_RECORDS
+ with open(self.json_file, "r") as info:  # standard io are always text, never binary
+ self._run("--input-type", "json", "-o", tmp, stdin=info)
+ self.assertEqual(len(self.load_avro(tmp)), NUM_RECORDS)
os.unlink(tmp)
diff --git a/lang/py/avro/tool.py b/lang/py/avro/tool.py
old mode 100644
new mode 100755
index a6e3fa8..584265f
--- a/lang/py/avro/tool.py
+++ b/lang/py/avro/tool.py
@@ -23,37 +23,42 @@ Command-line tool
NOTE: The API for the command-line tool is experimental.
"""
+import argparse
import http.server
import os.path
import sys
import threading
import urllib.parse
import warnings
+from pathlib import Path
import avro.datafile
import avro.io
import avro.ipc
import avro.protocol
+server_should_shutdown = False
+responder: "GenericResponder"
+
class GenericResponder(avro.ipc.Responder):
- def __init__(self, proto, msg, datum):
- proto_json = open(proto, "rb").read()
- avro.ipc.Responder.__init__(self, avro.protocol.parse(proto_json))
+ def __init__(self, proto, msg, datum) -> None:
+ avro.ipc.Responder.__init__(self, avro.protocol.parse(Path(proto).read_text()))
self.msg = msg
self.datum = datum
- def invoke(self, message, request):
- if message.name == self.msg:
- print(f"Message: {message.name} Datum: {self.datum}", file=sys.stderr)
- # server will shut down after processing a single Avro request
- global server_should_shutdown
- server_should_shutdown = True
- return self.datum
+ def invoke(self, message, request) -> object:
+ global server_should_shutdown
+ if message.name != self.msg:
+ return None
+ print(f"Message: {message.name} Datum: {self.datum}", file=sys.stderr)
+ # server will shut down after processing a single Avro request
+ server_should_shutdown = True
+ return self.datum
class GenericHandler(http.server.BaseHTTPRequestHandler):
- def do_POST(self):
+ def do_POST(self) -> None:
self.responder = responder
call_request_reader = avro.ipc.FramedReader(self.rfile)
call_request = call_request_reader.read_framed_message()
@@ -70,11 +75,15 @@ class GenericHandler(http.server.BaseHTTPRequestHandler):
quitter.start()
-def run_server(uri, proto, msg, datum):
- url_obj = urllib.parse.urlparse(uri)
- server_addr = (url_obj.hostname, url_obj.port)
+def run_server(uri: str, proto: str, msg: str, datum: object) -> None:
global responder
global server_should_shutdown
+ url_obj = urllib.parse.urlparse(uri)
+ if url_obj.hostname is None:
+ raise RuntimeError(f"uri {uri} must have a hostname.")
+ if url_obj.port is None:
+ raise RuntimeError(f"uri {uri} must have a port.")
+ server_addr = (url_obj.hostname, url_obj.port)
server_should_shutdown = False
responder = GenericResponder(proto, msg, datum)
server = http.server.HTTPServer(server_addr, GenericHandler)
@@ -85,76 +94,66 @@ def run_server(uri, proto, msg, datum):
server.serve_forever()
-def send_message(uri, proto, msg, datum):
+def send_message(uri, proto, msg, datum) -> None:
url_obj = urllib.parse.urlparse(uri)
client = avro.ipc.HTTPTransceiver(url_obj.hostname, url_obj.port)
- proto_json = open(proto, "rb").read()
- requestor = avro.ipc.Requestor(avro.protocol.parse(proto_json), client)
+ requestor = avro.ipc.Requestor(avro.protocol.parse(Path(proto).read_text()), client)
print(requestor.request(msg, datum))
-##
-# TODO: Replace this with fileinput()
-
-
-def file_or_stdin(f):
- return sys.stdin if f == "-" else open(f, "rb")
-
-
-def main(args=sys.argv):
- if len(args) == 1:
- print(f"Usage: {args[0]} [dump|rpcreceive|rpcsend]")
- return 1
-
- if args[1] == "dump":
- if len(args) != 3:
- print(f"Usage: {args[0]} dump input_file")
- return 1
- for d in avro.datafile.DataFileReader(file_or_stdin(args[2]), avro.io.DatumReader()):
- print(repr(d))
- elif args[1] == "rpcreceive":
- usage_str = f"Usage: {args[0]} rpcreceive uri protocol_file message_name (-data d | -file f)"
- if len(args) not in [5, 7]:
- print(usage_str)
- return 1
- uri, proto, msg = args[2:5]
- datum = None
- if len(args) > 5:
- if args[5] == "-file":
- reader = open(args[6], "rb")
- datum_reader = avro.io.DatumReader()
- dfr = avro.datafile.DataFileReader(reader, datum_reader)
- datum = next(dfr)
- elif args[5] == "-data":
- print("JSON Decoder not yet implemented.")
- return 1
- else:
- print(usage_str)
- return 1
- run_server(uri, proto, msg, datum)
- elif args[1] == "rpcsend":
- usage_str = f"Usage: {args[0]} rpcsend uri protocol_file message_name (-data d | -file f)"
- if len(args) not in [5, 7]:
- print(usage_str)
- return 1
- uri, proto, msg = args[2:5]
- datum = None
- if len(args) > 5:
- if args[5] == "-file":
- reader = open(args[6], "rb")
- datum_reader = avro.io.DatumReader()
- dfr = avro.datafile.DataFileReader(reader, datum_reader)
- datum = next(dfr)
- elif args[5] == "-data":
- print("JSON Decoder not yet implemented.")
- return 1
- else:
- print(usage_str)
- return 1
- send_message(uri, proto, msg, datum)
+def _parse_args() -> argparse.Namespace:
+ """Parse the command-line arguments"""
+ parser = argparse.ArgumentParser()
+ subparsers = parser.add_subparsers(required=True, dest="command") if sys.version_info >= (3, 7) else parser.add_subparsers(dest="command")
+ subparser_dump = subparsers.add_parser("dump", help="Dump an avro file")
+ subparser_dump.add_argument("input_file", type=argparse.FileType("rb"))
+ subparser_rpcreceive = subparsers.add_parser("rpcreceive", help="receive a message")
+ subparser_rpcreceive.add_argument("uri")
+ subparser_rpcreceive.add_argument("proto")
+ subparser_rpcreceive.add_argument("msg")
+ subparser_rpcreceive.add_argument("-file", type=argparse.FileType("rb"), required=False)
+ subparser_rpcsend = subparsers.add_parser("rpcsend", help="send a message")
+ subparser_rpcsend.add_argument("uri")
+ subparser_rpcsend.add_argument("proto")
+ subparser_rpcsend.add_argument("msg")
+ subparser_rpcsend.add_argument("-file", type=argparse.FileType("rb"))
+ return parser.parse_args()
+
+
+def main_dump(args: argparse.Namespace) -> int:
+ print("\n".join(f"{d!r}" for d in avro.datafile.DataFileReader(args.input_file, avro.io.DatumReader())))
+ return 0
+
+
+def main_rpcreceive(args: argparse.Namespace) -> int:
+ datum = None
+ if args.file:
+ with avro.datafile.DataFileReader(args.file, avro.io.DatumReader()) as dfr:
+ datum = next(dfr)
+ run_server(args.uri, args.proto, args.msg, datum)
return 0
+def main_rpcsend(args: argparse.Namespace) -> int:
+ datum = None
+ if args.file:
+ with avro.datafile.DataFileReader(args.file, avro.io.DatumReader()) as dfr:
+ datum = next(dfr)
+ send_message(args.uri, args.proto, args.msg, datum)
+ return 0
+
+
+def main() -> int:
+ args = _parse_args()
+ if args.command == "dump":
+ return main_dump(args)
+ if args.command == "rpcreceive":
+ return main_rpcreceive(args)
+ if args.command == "rpcsend":
+ return main_rpcsend(args)
+ return 1
+
+
if __name__ == "__main__":
if os.path.dirname(avro.io.__file__) in sys.path:
warnings.warn(
@@ -162,4 +161,4 @@ if __name__ == "__main__":
"with the python io module. Try doing `python -m avro.tool` instead."
)
- sys.exit(main(sys.argv))
+ sys.exit(main())
diff --git a/lang/py/scripts/avro b/lang/py/scripts/avro
old mode 100755
new mode 100644
index ffb8a00..66a5380
--- a/lang/py/scripts/avro
+++ b/lang/py/scripts/avro
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
+##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,264 +17,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Command line utility for reading and writing Avro files."""
+import warnings
-import csv
-import functools
-import json
-import optparse
-import os.path
-import sys
-
-import avro
-import avro.datafile
-import avro.errors
-import avro.io
-import avro.schema
-
-_AVRO_DIR = os.path.abspath(os.path.dirname(avro.__file__))
-
-
-def _version():
- with open(os.path.join(_AVRO_DIR, 'VERSION.txt')) as v:
- return v.read()
-
-
-_AVRO_VERSION = _version()
-
-
-def print_json(row):
- print(json.dumps(row))
-
-
-def print_json_pretty(row):
- """Pretty print JSON"""
- # Need to work around https://bugs.python.org/issue16333
- # where json.dumps leaves trailing spaces.
- result = json.dumps(row, sort_keys=True, indent=4).replace(' \n', '\n')
- print(result)
-
-
-_write_row = csv.writer(sys.stdout).writerow
-
-
-def print_csv(row):
- # We sort the keys to the fields will be in the same place
- # FIXME: Do we want to do it in schema order?
- _write_row([row[key] for key in sorted(row)])
-
-
-def select_printer(format):
- return {
- "json": print_json,
- "json-pretty": print_json_pretty,
- "csv": print_csv
- }[format]
-
-
-def record_match(expr, record):
- return eval(expr, None, {"r": record})
-
-
-def parse_fields(fields):
- fields = fields or ''
- if not fields.strip():
- return None
-
- return [field.strip() for field in fields.split(',') if field.strip()]
-
-
-def field_selector(fields):
- fields = set(fields)
-
- def keys_filter(obj):
- return dict((k, obj[k]) for k in (set(obj) & fields))
- return keys_filter
-
-
-def print_avro(avro, opts):
- if opts.header and (opts.format != "csv"):
- raise avro.errors.UsageError("--header applies only to CSV format")
-
- # Apply filter first
- if opts.filter:
- predicate = functools.partial(record_match, opts.filter)
- avro = (r for r in avro if predicate(r))
-
- for i in range(opts.skip):
- try:
- next(avro)
- except StopIteration:
- return
-
- fields = parse_fields(opts.fields)
- if fields:
- fs = field_selector(fields)
- avro = (fs(r) for r in avro)
-
- printer = select_printer(opts.format)
- for i, record in enumerate(avro):
- if i == 0 and opts.header:
- _write_row(sorted(record.keys()))
- if i >= opts.count:
- break
- printer(record)
-
-
-def print_schema(avro):
- schema = avro.schema
- # Pretty print
- print(json.dumps(json.loads(schema), indent=4))
-
-
-def cat(opts, args):
- if not args:
- raise avro.errors.UsageError("No files to show")
- for filename in args:
- with avro.datafile.DataFileReader(open(filename, 'rb'), avro.io.DatumReader()) as avro_:
- if opts.print_schema:
- print_schema(avro_)
- continue
- print_avro(avro_, opts)
-
-
-def _open(filename, mode):
- if filename == "-":
- return {
- "rb": sys.stdin,
- "wb": sys.stdout
- }[mode]
-
- return open(filename, mode)
-
-
-def iter_json(info, _):
- for i in info:
- try:
- s = i.decode()
- except AttributeError:
- s = i
- yield json.loads(s)
-
-
-def convert(value, field):
- type = field.type.type
- if type == "union":
- return convert_union(value, field)
-
- return {
- "int": int,
- "long": int,
- "float": float,
- "double": float,
- "string": str,
- "bytes": bytes,
- "boolean": bool,
- "null": lambda _: None,
- "union": lambda v: convert_union(v, field),
- }[type](value)
-
-
-def convert_union(value, field):
- for name in [s.name for s in field.type.schemas]:
- try:
- return convert(name)(value)
- except ValueError:
- continue
-
-
-def iter_csv(info, schema):
- header = [field.name for field in schema.fields]
- for row in csv.reader((getattr(i, "decode", lambda: i)() for i in info)):
- values = [convert(v, f) for v, f in zip(row, schema.fields)]
- yield dict(zip(header, values))
-
-
-def guess_input_type(files):
- if not files:
- return None
-
- ext = os.path.splitext(files[0])[1].lower()
- if ext in (".json", ".js"):
- return "json"
- elif ext in (".csv",):
- return "csv"
-
- return None
-
-
-def write(opts, files):
- if not opts.schema:
- raise avro.errors.UsageError("No schema specified")
-
- input_type = opts.input_type or guess_input_type(files)
- if not input_type:
- raise avro.errors.UsageError("Can't guess input file type (not .json or .csv)")
- iter_records = {"json": iter_json, "csv": iter_csv}[input_type]
-
- try:
- with open(opts.schema) as schema_file:
- schema = avro.schema.parse(schema_file.read())
- out = _open(opts.output, "wb")
- except (IOError, OSError) as e:
- raise avro.errors.UsageError(f"Can't open file - {e}")
-
- writer = avro.datafile.DataFileWriter(getattr(out, 'buffer', out), avro.io.DatumWriter(), schema)
-
- for filename in (files or ["-"]):
- info = _open(filename, "rb")
- for record in iter_records(info, schema):
- writer.append(record)
-
- writer.close()
-
-
-def main(argv):
- parser = optparse.OptionParser(description="Display/write for Avro files",
- version=_AVRO_VERSION,
- usage="usage: %prog cat|write [options] FILE [FILE...]")
- # cat options
- cat_options = optparse.OptionGroup(parser, "cat options")
- cat_options.add_option("-n", "--count", default=float("Infinity"),
- help="number of records to print", type=int)
- cat_options.add_option("-s", "--skip", help="number of records to skip",
- type=int, default=0)
- cat_options.add_option("-f", "--format", help="record format",
- default="json",
- choices=["json", "csv", "json-pretty"])
- cat_options.add_option("--header", help="print CSV header", default=False,
- action="store_true")
- cat_options.add_option("--filter", help="filter records (e.g. r['age']>1)",
- default=None)
- cat_options.add_option("--print-schema", help="print schema",
- action="store_true", default=False)
- cat_options.add_option('--fields', default=None,
- help='fields to show, comma separated (show all by default)')
- parser.add_option_group(cat_options)
-
- # write options
- write_options = optparse.OptionGroup(parser, "write options")
- write_options.add_option("--schema", help="schema JSON file (required)")
- write_options.add_option("--input-type",
- help="input file(s) type (json or csv)",
- choices=["json", "csv"], default=None)
- write_options.add_option("-o", "--output", help="output file", default="-")
- parser.add_option_group(write_options)
-
- opts, args = parser.parse_args(argv[1:])
- if len(args) < 1:
- parser.error("You must specify `cat` or `write`") # Will exit
-
- command_name = args.pop(0)
- try:
- command = {
- "cat": cat,
- "write": write,
- }[command_name]
- except KeyError:
- raise avro.errors.UsageError(f"Unknown command - {command_name!s}")
- command(opts, args)
+from avro.__main__ import main
+warnings.warn(DeprecationWarning("'scripts/avro' is deprecated; use the installed avro cli or `python -m avro`."))
if __name__ == "__main__":
- main(sys.argv)
+ main()
diff --git a/lang/py/setup.cfg b/lang/py/setup.cfg
index 5869e7d..48eb668 100644
--- a/lang/py/setup.cfg
+++ b/lang/py/setup.cfg
@@ -55,10 +55,12 @@ include_package_data = true
install_requires =
typing-extensions;python_version<"3.8"
zip_safe = true
-scripts =
- scripts/avro
python_requires = >=3.6
+[options.entry_points]
+console_scripts =
+ avro = avro.__main__:main
+
[options.package_data]
avro =
HandshakeRequest.avsc
diff --git a/lang/py/setup.py b/lang/py/setup.py
index d8d0fa3..dee0e7f 100755
--- a/lang/py/setup.py
+++ b/lang/py/setup.py
@@ -102,7 +102,8 @@ class GenerateInteropDataCommand(setuptools.Command):
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
- avro.test.gen_interop_data.generate(self.schema_file, os.path.join(self.output_path, "py.avro"))
+ with open(self.schema_file) as schema_file, open(os.path.join(self.output_path, "py.avro"), "wb") as output:
+ avro.test.gen_interop_data.generate(schema_file, output)
def _get_version():
diff --git a/lang/py/tox.ini b/lang/py/tox.ini
index 569846c..bbd42d6 100644
--- a/lang/py/tox.ini
+++ b/lang/py/tox.ini
@@ -82,4 +82,4 @@ deps =
extras =
mypy
commands =
- mypy
+ mypy {posargs}