This is an automated email from the ASF dual-hosted git repository.
mtaha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/age.git
The following commit(s) were added to refs/heads/master by this push:
new 5f5b744a Update python-driver security and formatting (#2330)
5f5b744a is described below
commit 5f5b744a08641225652de83332d73bc7acfc889d
Author: John Gemignani <[email protected]>
AuthorDate: Fri Feb 13 09:30:05 2026 -0800
Update python-driver security and formatting (#2330)
Note: This PR was created with AI tools and a human.
- Add parameterized query construction using psycopg.sql to prevent
SQL injection in all Cypher execution paths (age.py, networkx/lib.py)
- Replace all %-format and f-string SQL in networkx/lib.py with
sql.Identifier() for schema/table names and sql.Literal() for values
- Add validate_graph_name() with AGE-aligned VALID_GRAPH_NAME regex:
start with letter/underscore, allow dots and hyphens in middle positions,
end with letter/digit/underscore, min 3 chars, max 63 chars
- Add validate_identifier() with strict VALID_IDENTIFIER regex for labels,
column names, and SQL types (no dots or hyphens)
- Add validation calls to all networkx/lib.py entry points:
graph names validated on entry, labels validated before SQL construction
- Add _validate_column() to sanitize column specifications in buildCypher()
- Fix exception constructors (AgeNotSet, GraphNotFound, GraphAlreadyExists)
to always call super().__init__() with a meaningful default message so
that str(exception) never returns an empty string
- Add InvalidGraphName and InvalidIdentifier exception classes with
structured name/reason/context fields
- Fix builder.py: change erroneous 'return Exception(...)' to
'raise ValueError(...)' for unknown float expressions
- Fix copy-paste docstring in create_elabel() ('create_vlabels' ->
'create_elabels')
- Remove unused 'from psycopg.adapt import Loader' import in age.py
- Add design documentation in source explaining:
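- VALID_GRAPH_NAME regex uses '*' (not '+') intentionally: the minimum
length is enforced by a separate MIN_GRAPH_NAME_LENGTH check that runs
before the regex, so too-short names get a clear length error message
instead of a generic pattern-mismatch error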
- buildCypher uses string concatenation (not sql.Identifier) because
column specs are pre-validated 'name type' pairs that don't map to
sql.Identifier(); graphName and cypherStmt are NOT embedded in the
generated SQL text (only the validated column list is)
- Update test_networkx.py GraphNotFound assertion to use assertIn()
instead of assertEqual() to match the improved exception messages
- Strip Windows carriage returns (^M) from 7 source files
- Fix requirements.txt: convert from UTF-16LE+BOM+CRLF to clean UTF-8+LF,
move --no-binary flag from requirements.txt to CI workflow pip command
- Upgrade actions/setup-python from v4 (deprecated) to v5 in CI workflow
- Add 46 security unit tests in test_security.py covering:
- Graph name validation (AGE naming rules, injection, edge cases)
- SQL identifier validation (labels, columns, types)
- Column spec sanitization
- buildCypher injection prevention
- Exception constructor correctness (str() never empty)
- Add test_security.py to CI pipeline (python-driver.yaml)
- pip-audit reports 0 known vulnerabilities across all dependencies
modified: .github/workflows/python-driver.yaml
modified: drivers/python/age/VERSION.py
modified: drivers/python/age/__init__.py
modified: drivers/python/age/age.py
modified: drivers/python/age/builder.py
modified: drivers/python/age/exceptions.py
modified: drivers/python/age/models.py
modified: drivers/python/age/networkx/lib.py
modified: drivers/python/requirements.txt
modified: drivers/python/setup.py
modified: drivers/python/test_agtypes.py
modified: drivers/python/test_networkx.py
new file: drivers/python/test_security.py
---
.github/workflows/python-driver.yaml | 5 +-
drivers/python/age/VERSION.py | 44 +--
drivers/python/age/__init__.py | 80 ++---
drivers/python/age/age.py | 611 +++++++++++++++++++++--------------
drivers/python/age/builder.py | 420 ++++++++++++------------
drivers/python/age/exceptions.py | 59 +++-
drivers/python/age/models.py | 586 ++++++++++++++++-----------------
drivers/python/age/networkx/lib.py | 97 ++++--
drivers/python/requirements.txt | Bin 176 -> 59 bytes
drivers/python/setup.py | 44 +--
drivers/python/test_agtypes.py | 264 +++++++--------
drivers/python/test_networkx.py | 2 +-
drivers/python/test_security.py | 274 ++++++++++++++++
13 files changed, 1481 insertions(+), 1005 deletions(-)
diff --git a/.github/workflows/python-driver.yaml
b/.github/workflows/python-driver.yaml
index 4dad1463..16ccface 100644
--- a/.github/workflows/python-driver.yaml
+++ b/.github/workflows/python-driver.yaml
@@ -22,14 +22,14 @@ jobs:
run: docker compose up -d
- name: Set up python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install pre-requisites
run: |
sudo apt-get install python3-dev libpq-dev
- pip install -r requirements.txt
+ pip install --no-binary psycopg -r requirements.txt
- name: Build
run: |
@@ -40,3 +40,4 @@ jobs:
python test_age_py.py -db "postgres" -u "postgres" -pass "agens"
python test_networkx.py -db "postgres" -u "postgres" -pass "agens"
python -m unittest -v test_agtypes.py
+ python -m unittest -v test_security.py
diff --git a/drivers/python/age/VERSION.py b/drivers/python/age/VERSION.py
index 3b014ea5..5136181a 100644
--- a/drivers/python/age/VERSION.py
+++ b/drivers/python/age/VERSION.py
@@ -1,22 +1,22 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-
-VER_MAJOR = 1
-VER_MINOR = 0
-VER_MICRO = 0
-
-VERSION = '.'.join([str(VER_MAJOR),str(VER_MINOR),str(VER_MICRO)])
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+
+VER_MAJOR = 1
+VER_MINOR = 0
+VER_MICRO = 0
+
+VERSION = '.'.join([str(VER_MAJOR),str(VER_MINOR),str(VER_MICRO)])
diff --git a/drivers/python/age/__init__.py b/drivers/python/age/__init__.py
index fd50135a..685f0fe7 100644
--- a/drivers/python/age/__init__.py
+++ b/drivers/python/age/__init__.py
@@ -1,40 +1,40 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import psycopg.conninfo as conninfo
-from . import age
-from .age import *
-from .models import *
-from .builder import ResultHandler, DummyResultHandler, parseAgeValue,
newResultHandler
-from . import VERSION
-
-def version():
- return VERSION.VERSION
-
-
-def connect(dsn=None, graph=None, connection_factory=None,
cursor_factory=ClientCursor, load_from_plugins=False,
- **kwargs):
-
- dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs)
-
- ag = Age()
- ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory,
cursor_factory=cursor_factory,
- load_from_plugins=load_from_plugins, **kwargs)
- return ag
-
-# Dummy ResultHandler
-rawPrinter = DummyResultHandler()
-
-__name__="age"
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import psycopg.conninfo as conninfo
+from . import age
+from .age import *
+from .models import *
+from .builder import ResultHandler, DummyResultHandler, parseAgeValue,
newResultHandler
+from . import VERSION
+
+def version():
+ return VERSION.VERSION
+
+
+def connect(dsn=None, graph=None, connection_factory=None,
cursor_factory=ClientCursor, load_from_plugins=False,
+ **kwargs):
+
+ dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs)
+
+ ag = Age()
+ ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory,
cursor_factory=cursor_factory,
+ load_from_plugins=load_from_plugins, **kwargs)
+ return ag
+
+# Dummy ResultHandler
+rawPrinter = DummyResultHandler()
+
+__name__="age"
diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py
index b1aa8215..fad1f27b 100644
--- a/drivers/python/age/age.py
+++ b/drivers/python/age/age.py
@@ -1,236 +1,375 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import re
-import psycopg
-from psycopg.types import TypeInfo
-from psycopg.adapt import Loader
-from psycopg import sql
-from psycopg.client_cursor import ClientCursor
-from .exceptions import *
-from .builder import parseAgeValue
-
-
-_EXCEPTION_NoConnection = NoConnection()
-_EXCEPTION_GraphNotSet = GraphNotSet()
-
-WHITESPACE = re.compile(r'\s')
-
-
-class AgeDumper(psycopg.adapt.Dumper):
- def dump(self, obj: Any) -> bytes | bytearray | memoryview:
- pass
-
-
-class AgeLoader(psycopg.adapt.Loader):
- def load(self, data: bytes | bytearray | memoryview) -> Any | None:
- if isinstance(data, memoryview):
- data_bytes = data.tobytes()
- else:
- data_bytes = data
-
- return parseAgeValue(data_bytes.decode('utf-8'))
-
-
-def setUpAge(conn:psycopg.connection, graphName:str,
load_from_plugins:bool=False):
- with conn.cursor() as cursor:
- if load_from_plugins:
- cursor.execute("LOAD '$libdir/plugins/age';")
- else:
- cursor.execute("LOAD 'age';")
-
- cursor.execute("SET search_path = ag_catalog, '$user', public;")
-
- ag_info = TypeInfo.fetch(conn, 'agtype')
-
- if not ag_info:
- raise AgeNotSet()
-
- conn.adapters.register_loader(ag_info.oid, AgeLoader)
- conn.adapters.register_loader(ag_info.array_oid, AgeLoader)
-
- # Check graph exists
- if graphName != None:
- checkGraphCreated(conn, graphName)
-
-# Create the graph, if it does not exist
-def checkGraphCreated(conn:psycopg.connection, graphName:str):
- with conn.cursor() as cursor:
- cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE
name={graphName}").format(graphName=sql.Literal(graphName)))
- if cursor.fetchone()[0] == 0:
- cursor.execute(sql.SQL("SELECT
create_graph({graphName});").format(graphName=sql.Literal(graphName)))
- conn.commit()
-
-
-def deleteGraph(conn:psycopg.connection, graphName:str):
- with conn.cursor() as cursor:
- cursor.execute(sql.SQL("SELECT drop_graph({graphName},
true);").format(graphName=sql.Literal(graphName)))
- conn.commit()
-
-
-def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
- if graphName == None:
- raise _EXCEPTION_GraphNotSet
-
- columnExp=[]
- if columns != None and len(columns) > 0:
- for col in columns:
- if col.strip() == '':
- continue
- elif WHITESPACE.search(col) != None:
- columnExp.append(col)
- else:
- columnExp.append(col + " agtype")
- else:
- columnExp.append('v agtype')
-
- stmtArr = []
- stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
- stmtArr.append(','.join(columnExp))
- stmtArr.append(");")
- return "".join(stmtArr)
-
-def execSql(conn:psycopg.connection, stmt:str, commit:bool=False,
params:tuple=None) -> psycopg.cursor :
- if conn == None or conn.closed:
- raise _EXCEPTION_NoConnection
-
- cursor = conn.cursor()
- try:
- cursor.execute(stmt, params)
- if commit:
- conn.commit()
-
- return cursor
- except SyntaxError as cause:
- conn.rollback()
- raise cause
- except Exception as cause:
- conn.rollback()
- raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt
+")", cause)
-
-
-def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) ->
psycopg.cursor :
- return execSql(conn, stmt, False, params)
-
-# Execute cypher statement and return cursor.
-# If cypher statement changes data (create, set, remove),
-# You must commit session(ag.commit())
-# (Otherwise the execution cannot make any effect.)
-def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str,
cols:list=None, params:tuple=None) -> psycopg.cursor :
- if conn == None or conn.closed:
- raise _EXCEPTION_NoConnection
-
- cursor = conn.cursor()
- #clean up the string for mogrification
- cypherStmt = cypherStmt.replace("\n", "")
- cypherStmt = cypherStmt.replace("\t", "")
- cypher = str(cursor.mogrify(cypherStmt, params))
- cypher = cypher.strip()
-
- preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
-
- cursor = conn.cursor()
- try:
-
cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
- except SyntaxError as cause:
- conn.rollback()
- raise cause
- except Exception as cause:
- conn.rollback()
- raise SqlExecutionError("Execution ERR[" + str(cause) +"](" +
preparedStmt +")", cause)
-
- stmt = buildCypher(graphName, cypher, cols)
-
- cursor = conn.cursor()
- try:
- cursor.execute(stmt)
- return cursor
- except SyntaxError as cause:
- conn.rollback()
- raise cause
- except Exception as cause:
- conn.rollback()
- raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt
+")", cause)
-
-
-def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str,
cols:list=None, params:tuple=None) -> psycopg.cursor :
- #clean up the string for mogrification
- cypherStmt = cypherStmt.replace("\n", "")
- cypherStmt = cypherStmt.replace("\t", "")
- cypher = str(cursor.mogrify(cypherStmt, params))
- cypher = cypher.strip()
-
- preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
-
cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
-
- stmt = buildCypher(graphName, cypher, cols)
- cursor.execute(stmt)
-
-
-# def execCypherWithReturn(conn:psycopg.connection, graphName:str,
cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
-# stmt = buildCypher(graphName, cypherStmt, columns)
-# return execSql(conn, stmt, False, params)
-
-# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str,
columns:list=None , params:tuple=None) -> psycopg.cursor :
-# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)
-
-
-class Age:
- def __init__(self):
- self.connection = None # psycopg connection]
- self.graphName = None
-
- # Connect to PostgreSQL Server and establish session and type extension
environment.
- def connect(self, graph:str=None, dsn:str=None, connection_factory=None,
cursor_factory=ClientCursor,
- load_from_plugins:bool=False, **kwargs):
- conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
- setUpAge(conn, graph, load_from_plugins)
- self.connection = conn
- self.graphName = graph
- return self
-
- def close(self):
- self.connection.close()
-
- def setGraph(self, graph:str):
- checkGraphCreated(self.connection, graph)
- self.graphName = graph
- return self
-
- def commit(self):
- self.connection.commit()
-
- def rollback(self):
- self.connection.rollback()
-
- def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) ->
psycopg.cursor :
- return execCypher(self.connection, self.graphName, cypherStmt,
cols=cols, params=params)
-
- def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None,
params:tuple=None) -> psycopg.cursor :
- return cypher(cursor, self.graphName, cypherStmt, cols=cols,
params=params)
-
- # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) ->
psycopg.cursor :
- # return execSql(self.connection, stmt, commit, params)
-
-
- # def execCypher(self, cypherStmt:str, commit:bool=False,
params:tuple=None) -> psycopg.cursor :
- # return execCypher(self.connection, self.graphName, cypherStmt,
commit, params)
-
- # def execCypherWithReturn(self, cypherStmt:str, columns:list=None ,
params:tuple=None) -> psycopg.cursor :
- # return execCypherWithReturn(self.connection, self.graphName,
cypherStmt, columns, params)
-
- # def queryCypher(self, cypherStmt:str, columns:list=None ,
params:tuple=None) -> psycopg.cursor :
- # return queryCypher(self.connection, self.graphName, cypherStmt,
columns, params)
-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+import psycopg
+from psycopg.types import TypeInfo
+from psycopg import sql
+from psycopg.client_cursor import ClientCursor
+from .exceptions import *
+from .builder import parseAgeValue
+
+
+_EXCEPTION_NoConnection = NoConnection()
+_EXCEPTION_GraphNotSet = GraphNotSet()
+
+WHITESPACE = re.compile(r'\s')
+
+# Valid AGE graph name pattern aligned with Apache AGE's internal validation
+# and Neo4j/openCypher naming conventions.
+# Start: letter or underscore
+# Middle: letter, digit, underscore, dot, or hyphen
+# End: letter, digit, or underscore
+#
+# Design note: The middle segment uses `*` (not `+`) intentionally.
+# This makes the regex match names as short as 2 characters at the
+# regex level. However, validate_graph_name() checks MIN_GRAPH_NAME_LENGTH
+# *before* applying this regex, so 2-character names are rejected with a
+# clear "must be at least 3 characters" error rather than a confusing
+# regex-mismatch error. This ordering gives users actionable feedback.
+VALID_GRAPH_NAME = re.compile(r'^[A-Za-z_][A-Za-z0-9_.\-]*[A-Za-z0-9_]$')
+MIN_GRAPH_NAME_LENGTH = 3
+
+# Valid SQL identifier for labels, column names, and types.
+# Stricter than graph names — no dots or hyphens.
+VALID_IDENTIFIER = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
+MAX_IDENTIFIER_LENGTH = 63
+
+
+def validate_graph_name(graph_name: str) -> None:
+ """Validate that a graph name conforms to Apache AGE's naming rules.
+
+ Graph names must:
+ - Be at least 3 characters and at most 63 characters
+ - Start with a letter or underscore
+ - Contain only letters, digits, underscores, dots, and hyphens
+ - End with a letter, digit, or underscore
+
+ This aligns with AGE's internal validation and Neo4j/openCypher
+ naming conventions.
+
+ Args:
+ graph_name: The graph name to validate.
+
+ Raises:
+ InvalidGraphName: If the graph name is invalid.
+ """
+ if not graph_name or not isinstance(graph_name, str):
+ raise InvalidGraphName(
+ str(graph_name),
+ "Graph name must be a non-empty string."
+ )
+ if len(graph_name) < MIN_GRAPH_NAME_LENGTH:
+ raise InvalidGraphName(
+ graph_name,
+ f"Graph names must be at least {MIN_GRAPH_NAME_LENGTH} characters."
+ )
+ if len(graph_name) > MAX_IDENTIFIER_LENGTH:
+ raise InvalidGraphName(
+ graph_name,
+ f"Must not exceed {MAX_IDENTIFIER_LENGTH} characters "
+ "(PostgreSQL name limit)."
+ )
+ if not VALID_GRAPH_NAME.match(graph_name):
+ raise InvalidGraphName(
+ graph_name,
+ "Graph names must start with a letter or underscore, "
+ "may contain letters, digits, underscores, dots, and hyphens, "
+ "and must end with a letter, digit, or underscore."
+ )
+
+
+def validate_identifier(name: str, context: str = "identifier") -> None:
+ """Validate that a name is a safe SQL identifier for labels, columns, or
types.
+
+ This follows stricter rules than graph names — only letters, digits,
+ and underscores are permitted (no dots or hyphens).
+
+ Args:
+ name: The identifier to validate.
+ context: What the identifier represents (for error messages).
+
+ Raises:
+ InvalidIdentifier: If the identifier is invalid.
+ """
+ if not name or not isinstance(name, str):
+ raise InvalidIdentifier(
+ str(name),
+ f"{context} must be a non-empty string."
+ )
+ if len(name) > MAX_IDENTIFIER_LENGTH:
+ raise InvalidIdentifier(
+ name,
+ f"{context} must not exceed {MAX_IDENTIFIER_LENGTH} characters."
+ )
+ if not VALID_IDENTIFIER.match(name):
+ raise InvalidIdentifier(
+ name,
+ f"{context} must start with a letter or underscore "
+ "and contain only letters, digits, and underscores."
+ )
+
+
+class AgeDumper(psycopg.adapt.Dumper):
+ def dump(self, obj: Any) -> bytes | bytearray | memoryview:
+ pass
+
+
+class AgeLoader(psycopg.adapt.Loader):
+ def load(self, data: bytes | bytearray | memoryview) -> Any | None:
+ if isinstance(data, memoryview):
+ data_bytes = data.tobytes()
+ else:
+ data_bytes = data
+
+ return parseAgeValue(data_bytes.decode('utf-8'))
+
+
+def setUpAge(conn:psycopg.connection, graphName:str,
load_from_plugins:bool=False):
+ with conn.cursor() as cursor:
+ if load_from_plugins:
+ cursor.execute("LOAD '$libdir/plugins/age';")
+ else:
+ cursor.execute("LOAD 'age';")
+
+ cursor.execute("SET search_path = ag_catalog, '$user', public;")
+
+ ag_info = TypeInfo.fetch(conn, 'agtype')
+
+ if not ag_info:
+ raise AgeNotSet(
+ "AGE agtype type not found. Ensure the AGE extension is "
+ "installed and loaded in the current database. "
+ "Run CREATE EXTENSION age; first."
+ )
+
+ conn.adapters.register_loader(ag_info.oid, AgeLoader)
+ conn.adapters.register_loader(ag_info.array_oid, AgeLoader)
+
+ # Check graph exists
+ if graphName != None:
+ checkGraphCreated(conn, graphName)
+
+# Create the graph, if it does not exist
+def checkGraphCreated(conn:psycopg.connection, graphName:str):
+ validate_graph_name(graphName)
+ with conn.cursor() as cursor:
+ cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE
name={graphName}").format(graphName=sql.Literal(graphName)))
+ if cursor.fetchone()[0] == 0:
+ cursor.execute(sql.SQL("SELECT
create_graph({graphName});").format(graphName=sql.Literal(graphName)))
+ conn.commit()
+
+
+def deleteGraph(conn:psycopg.connection, graphName:str):
+ validate_graph_name(graphName)
+ with conn.cursor() as cursor:
+ cursor.execute(sql.SQL("SELECT drop_graph({graphName},
true);").format(graphName=sql.Literal(graphName)))
+ conn.commit()
+
+
+def _validate_column(col: str) -> str:
+ """Validate and normalize a column specification for use in SQL.
+
+ Accepts either a plain column name (e.g. 'v') or a name with type
+ (e.g. 'v agtype'). Validates each component to prevent SQL injection.
+
+ Args:
+ col: Column specification string.
+
+ Returns:
+ Normalized column specification, or empty string if blank.
+
+ Raises:
+ InvalidIdentifier: If any component is invalid.
+ """
+ col = col.strip()
+ if not col:
+ return ''
+
+ if WHITESPACE.search(col):
+ parts = col.split()
+ if len(parts) != 2:
+ raise InvalidIdentifier(
+ col,
+ "Column specification must be 'name' or 'name type'."
+ )
+ name, type_name = parts
+ validate_identifier(name, "Column name")
+ validate_identifier(type_name, "Column type")
+ return f"{name} {type_name}"
+ else:
+ validate_identifier(col, "Column name")
+ return f"{col} agtype"
+
+
+def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
+ if graphName == None:
+ raise _EXCEPTION_GraphNotSet
+
+ columnExp=[]
+ if columns != None and len(columns) > 0:
+ for col in columns:
+ validated = _validate_column(col)
+ if validated:
+ columnExp.append(validated)
+ else:
+ columnExp.append('v agtype')
+
+ # Design note: String concatenation is used here instead of
+ # psycopg.sql.Identifier() because column specifications are
+ # "name type" pairs (e.g. "v agtype") that don't map directly to
+ # sql.Identifier(). Each component has already been validated by
+ # _validate_column() → validate_identifier(), which restricts
+ # names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. The
+ # graphName and cypherStmt are NOT embedded here — this template
+ # only contains the validated column list and static SQL keywords.
+ stmtArr = []
+ stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
+ stmtArr.append(','.join(columnExp))
+ stmtArr.append(");")
+ return "".join(stmtArr)
+
+def execSql(conn:psycopg.connection, stmt:str, commit:bool=False,
params:tuple=None) -> psycopg.cursor :
+ if conn == None or conn.closed:
+ raise _EXCEPTION_NoConnection
+
+ cursor = conn.cursor()
+ try:
+ cursor.execute(stmt, params)
+ if commit:
+ conn.commit()
+
+ return cursor
+ except SyntaxError as cause:
+ conn.rollback()
+ raise cause
+ except Exception as cause:
+ conn.rollback()
+ raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt
+")", cause)
+
+
+def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) ->
psycopg.cursor :
+ return execSql(conn, stmt, False, params)
+
+# Execute cypher statement and return cursor.
+# If cypher statement changes data (create, set, remove),
+# You must commit session(ag.commit())
+# (Otherwise the execution cannot make any effect.)
+def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str,
cols:list=None, params:tuple=None) -> psycopg.cursor :
+ if conn == None or conn.closed:
+ raise _EXCEPTION_NoConnection
+
+ cursor = conn.cursor()
+ #clean up the string for mogrification
+ cypherStmt = cypherStmt.replace("\n", "")
+ cypherStmt = cypherStmt.replace("\t", "")
+ cypher = str(cursor.mogrify(cypherStmt, params))
+ cypher = cypher.strip()
+
+ preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
+
+ cursor = conn.cursor()
+ try:
+
cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+ except SyntaxError as cause:
+ conn.rollback()
+ raise cause
+ except Exception as cause:
+ conn.rollback()
+ raise SqlExecutionError("Execution ERR[" + str(cause) +"](" +
preparedStmt +")", cause)
+
+ stmt = buildCypher(graphName, cypher, cols)
+
+ cursor = conn.cursor()
+ try:
+ cursor.execute(stmt)
+ return cursor
+ except SyntaxError as cause:
+ conn.rollback()
+ raise cause
+ except Exception as cause:
+ conn.rollback()
+ raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt
+")", cause)
+
+
+def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str,
cols:list=None, params:tuple=None) -> psycopg.cursor :
+ #clean up the string for mogrification
+ cypherStmt = cypherStmt.replace("\n", "")
+ cypherStmt = cypherStmt.replace("\t", "")
+ cypher = str(cursor.mogrify(cypherStmt, params))
+ cypher = cypher.strip()
+
+ preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
+
cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+
+ stmt = buildCypher(graphName, cypher, cols)
+ cursor.execute(stmt)
+
+
+# def execCypherWithReturn(conn:psycopg.connection, graphName:str,
cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
+# stmt = buildCypher(graphName, cypherStmt, columns)
+# return execSql(conn, stmt, False, params)
+
+# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str,
columns:list=None , params:tuple=None) -> psycopg.cursor :
+# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)
+
+
+class Age:
+ def __init__(self):
+ self.connection = None # psycopg connection]
+ self.graphName = None
+
+ # Connect to PostgreSQL Server and establish session and type extension
environment.
+ def connect(self, graph:str=None, dsn:str=None, connection_factory=None,
cursor_factory=ClientCursor,
+ load_from_plugins:bool=False, **kwargs):
+ conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
+ setUpAge(conn, graph, load_from_plugins)
+ self.connection = conn
+ self.graphName = graph
+ return self
+
+ def close(self):
+ self.connection.close()
+
+ def setGraph(self, graph:str):
+ checkGraphCreated(self.connection, graph)
+ self.graphName = graph
+ return self
+
+ def commit(self):
+ self.connection.commit()
+
+ def rollback(self):
+ self.connection.rollback()
+
+ def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) ->
psycopg.cursor :
+ return execCypher(self.connection, self.graphName, cypherStmt,
cols=cols, params=params)
+
+ def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None,
params:tuple=None) -> psycopg.cursor :
+ return cypher(cursor, self.graphName, cypherStmt, cols=cols,
params=params)
+
+ # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) ->
psycopg.cursor :
+ # return execSql(self.connection, stmt, commit, params)
+
+
+ # def execCypher(self, cypherStmt:str, commit:bool=False,
params:tuple=None) -> psycopg.cursor :
+ # return execCypher(self.connection, self.graphName, cypherStmt,
commit, params)
+
+ # def execCypherWithReturn(self, cypherStmt:str, columns:list=None ,
params:tuple=None) -> psycopg.cursor :
+ # return execCypherWithReturn(self.connection, self.graphName,
cypherStmt, columns, params)
+
+ # def queryCypher(self, cypherStmt:str, columns:list=None ,
params:tuple=None) -> psycopg.cursor :
+ # return queryCypher(self.connection, self.graphName, cypherStmt,
columns, params)
+
diff --git a/drivers/python/age/builder.py b/drivers/python/age/builder.py
index a3815b82..f1e7a2ce 100644
--- a/drivers/python/age/builder.py
+++ b/drivers/python/age/builder.py
@@ -1,210 +1,210 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from . import gen
-from .gen.AgtypeLexer import AgtypeLexer
-from .gen.AgtypeParser import AgtypeParser
-from .gen.AgtypeVisitor import AgtypeVisitor
-from .models import *
-from .exceptions import *
-from antlr4 import InputStream, CommonTokenStream, ParserRuleContext
-from antlr4.tree.Tree import TerminalNode
-from decimal import Decimal
-
-resultHandler = None
-
-class ResultHandler:
- def parse(ageData):
- pass
-
-def newResultHandler(query=""):
- resultHandler = Antlr4ResultHandler(None, query)
- return resultHandler
-
-def parseAgeValue(value, cursor=None):
- if value is None:
- return None
-
- global resultHandler
- if (resultHandler == None):
- resultHandler = Antlr4ResultHandler(None)
- try:
- return resultHandler.parse(value)
- except Exception as ex:
- raise AGTypeError(value, ex)
-
-
-class Antlr4ResultHandler(ResultHandler):
- def __init__(self, vertexCache, query=None):
- self.lexer = AgtypeLexer()
- self.parser = AgtypeParser(None)
- self.visitor = ResultVisitor(vertexCache)
-
- def parse(self, ageData):
- if not ageData:
- return None
- # print("Parse::", ageData)
-
- self.lexer.inputStream = InputStream(ageData)
- self.parser.setTokenStream(CommonTokenStream(self.lexer))
- self.parser.reset()
- tree = self.parser.agType()
- parsed = tree.accept(self.visitor)
- return parsed
-
-
-# print raw result String
-class DummyResultHandler(ResultHandler):
- def parse(self, ageData):
- print(ageData)
-
-# default agType visitor
-class ResultVisitor(AgtypeVisitor):
- vertexCache = None
-
- def __init__(self, cache) -> None:
- super().__init__()
- self.vertexCache = cache
-
-
- def visitAgType(self, ctx:AgtypeParser.AgTypeContext):
- agVal = ctx.agValue()
- if agVal != None:
- obj = ctx.agValue().accept(self)
- return obj
-
- return None
-
- def visitAgValue(self, ctx:AgtypeParser.AgValueContext):
- annoCtx = ctx.typeAnnotation()
- valueCtx = ctx.value()
-
- if annoCtx is not None:
- annoCtx.accept(self)
- anno = annoCtx.IDENT().getText()
- return self.handleAnnotatedValue(anno, valueCtx)
- else:
- return valueCtx.accept(self)
-
-
- # Visit a parse tree produced by AgtypeParser#StringValue.
- def visitStringValue(self, ctx:AgtypeParser.StringValueContext):
- return ctx.STRING().getText().strip('"')
-
-
- # Visit a parse tree produced by AgtypeParser#IntegerValue.
- def visitIntegerValue(self, ctx:AgtypeParser.IntegerValueContext):
- return int(ctx.INTEGER().getText())
-
- # Visit a parse tree produced by AgtypeParser#floatLiteral.
- def visitFloatLiteral(self, ctx:AgtypeParser.FloatLiteralContext):
- c = ctx.getChild(0)
- tp = c.symbol.type
- text = ctx.getText()
- if tp == AgtypeParser.RegularFloat:
- return float(text)
- elif tp == AgtypeParser.ExponentFloat:
- return float(text)
- else:
- if text == 'NaN':
- return float('nan')
- elif text == '-Infinity':
- return float('-inf')
- elif text == 'Infinity':
- return float('inf')
- else:
- return Exception("Unknown float expression:"+text)
-
-
- # Visit a parse tree produced by AgtypeParser#TrueBoolean.
- def visitTrueBoolean(self, ctx:AgtypeParser.TrueBooleanContext):
- return True
-
-
- # Visit a parse tree produced by AgtypeParser#FalseBoolean.
- def visitFalseBoolean(self, ctx:AgtypeParser.FalseBooleanContext):
- return False
-
-
- # Visit a parse tree produced by AgtypeParser#NullValue.
- def visitNullValue(self, ctx:AgtypeParser.NullValueContext):
- return None
-
-
- # Visit a parse tree produced by AgtypeParser#obj.
- def visitObj(self, ctx:AgtypeParser.ObjContext):
- obj = dict()
- for c in ctx.getChildren():
- if isinstance(c, AgtypeParser.PairContext):
- namVal = self.visitPair(c)
- name = namVal[0]
- valCtx = namVal[1]
- val = valCtx.accept(self)
- obj[name] = val
- return obj
-
-
- # Visit a parse tree produced by AgtypeParser#pair.
- def visitPair(self, ctx:AgtypeParser.PairContext):
- self.visitChildren(ctx)
- return (ctx.STRING().getText().strip('"') , ctx.agValue())
-
-
- # Visit a parse tree produced by AgtypeParser#array.
- def visitArray(self, ctx:AgtypeParser.ArrayContext):
- li = list()
- for c in ctx.getChildren():
- if not isinstance(c, TerminalNode):
- val = c.accept(self)
- li.append(val)
- return li
-
- def handleAnnotatedValue(self, anno:str, ctx:ParserRuleContext):
- if anno == "numeric":
- return Decimal(ctx.getText())
- elif anno == "vertex":
- dict = ctx.accept(self)
- vid = dict["id"]
- vertex = None
- if self.vertexCache != None and vid in self.vertexCache :
- vertex = self.vertexCache[vid]
- else:
- vertex = Vertex()
- vertex.id = dict["id"]
- vertex.label = dict["label"]
- vertex.properties = dict["properties"]
-
- if self.vertexCache != None:
- self.vertexCache[vid] = vertex
-
- return vertex
-
- elif anno == "edge":
- edge = Edge()
- dict = ctx.accept(self)
- edge.id = dict["id"]
- edge.label = dict["label"]
- edge.end_id = dict["end_id"]
- edge.start_id = dict["start_id"]
- edge.properties = dict["properties"]
-
- return edge
-
- elif anno == "path":
- arr = ctx.accept(self)
- path = Path(arr)
-
- return path
-
- return ctx.accept(self)
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from . import gen
+from .gen.AgtypeLexer import AgtypeLexer
+from .gen.AgtypeParser import AgtypeParser
+from .gen.AgtypeVisitor import AgtypeVisitor
+from .models import *
+from .exceptions import *
+from antlr4 import InputStream, CommonTokenStream, ParserRuleContext
+from antlr4.tree.Tree import TerminalNode
+from decimal import Decimal
+
# Module-wide default handler used by parseAgeValue(); created lazily there.
resultHandler = None

class ResultHandler:
    """Base interface for objects that turn raw agtype text into Python values."""

    def parse(self, ageData):
        """Parse ``ageData``; the base implementation does nothing and returns None.

        Subclasses override this. (The original signature lacked ``self``,
        so calling it on an instance raised TypeError.)
        """
        return None
+
def newResultHandler(query=""):
    """Create a fresh Antlr4ResultHandler without a vertex cache.

    ``query`` is forwarded to the handler's constructor. The module-level
    ``resultHandler`` used by parseAgeValue() is not touched.
    """
    return Antlr4ResultHandler(None, query)
+
def parseAgeValue(value, cursor=None):
    """Parse a raw agtype string into Python / AGE model objects.

    A single module-wide Antlr4ResultHandler (without vertex cache) is
    created on first use and reused afterwards.

    Args:
        value: raw agtype text, or None.
        cursor: not used by this function.

    Returns:
        None when ``value`` is None, otherwise the parsed value.

    Raises:
        AGTypeError: wrapping any exception raised while parsing; the
            original exception is chained as ``__cause__``.
    """
    global resultHandler
    if value is None:
        return None

    if resultHandler is None:
        resultHandler = Antlr4ResultHandler(None)
    try:
        return resultHandler.parse(value)
    except Exception as ex:
        raise AGTypeError(value, ex) from ex
+
+
class Antlr4ResultHandler(ResultHandler):
    """ANTLR-based agtype parser that reuses one lexer, parser and visitor.

    Args:
        vertexCache: handed to ResultVisitor; a dict-like cache of vertices
            keyed by id, or None to disable caching.
        query: accepted but neither stored nor used.
    """

    def __init__(self, vertexCache, query=None):
        self.lexer = AgtypeLexer()
        self.parser = AgtypeParser(None)
        self.visitor = ResultVisitor(vertexCache)

    def parse(self, ageData):
        """Parse one agtype string; falsy input (None, "") yields None."""
        if not ageData:
            return None

        lexer = self.lexer
        parser = self.parser
        # Re-point the shared lexer/parser at the new input before parsing.
        lexer.inputStream = InputStream(ageData)
        parser.setTokenStream(CommonTokenStream(lexer))
        parser.reset()
        return parser.agType().accept(self.visitor)
+
+
class DummyResultHandler(ResultHandler):
    """Debugging handler: prints the raw result string instead of parsing it."""

    def parse(self, ageData):
        """Print ``ageData`` verbatim and return None."""
        print(ageData)
        return None
+
# default agType visitor
class ResultVisitor(AgtypeVisitor):
    """Default agtype visitor that converts a parse tree into Python values.

    Plain values become str/int/float/bool/None, objects become dict and
    arrays become list. Annotated values (``::numeric``, ``::vertex``,
    ``::edge``, ``::path``) become Decimal, Vertex, Edge and Path.
    """

    # Optional dict-like cache of Vertex objects keyed by vertex id;
    # None disables caching.
    vertexCache = None

    def __init__(self, cache) -> None:
        super().__init__()
        self.vertexCache = cache

    @staticmethod
    def _decodeString(token):
        r"""Turn a quoted STRING token into its Python string value.

        ``strip('"')`` alone left escapes such as ``\"``, ``\\`` or ``\n``
        undecoded and could also strip an escaped quote at the end of the
        value. ``json.loads`` decodes them. Assumes the grammar's STRING
        token uses JSON string syntax -- TODO confirm against Agtype.g4.
        Tokens that are not valid JSON strings fall back to the old behaviour.
        """
        import json
        try:
            decoded = json.loads(token)
        except ValueError:
            return token.strip('"')
        return decoded if isinstance(decoded, str) else token.strip('"')

    def visitAgType(self, ctx:AgtypeParser.AgTypeContext):
        agVal = ctx.agValue()
        if agVal is None:
            return None
        return agVal.accept(self)

    def visitAgValue(self, ctx:AgtypeParser.AgValueContext):
        annoCtx = ctx.typeAnnotation()
        valueCtx = ctx.value()

        if annoCtx is None:
            return valueCtx.accept(self)
        annoCtx.accept(self)
        anno = annoCtx.IDENT().getText()
        return self.handleAnnotatedValue(anno, valueCtx)

    # Visit a parse tree produced by AgtypeParser#StringValue.
    def visitStringValue(self, ctx:AgtypeParser.StringValueContext):
        return self._decodeString(ctx.STRING().getText())

    # Visit a parse tree produced by AgtypeParser#IntegerValue.
    def visitIntegerValue(self, ctx:AgtypeParser.IntegerValueContext):
        return int(ctx.INTEGER().getText())

    # Visit a parse tree produced by AgtypeParser#floatLiteral.
    def visitFloatLiteral(self, ctx:AgtypeParser.FloatLiteralContext):
        """Convert a float token, including NaN / Infinity / -Infinity.

        Raises:
            ValueError: for any other non-regular, non-exponent float text.
        """
        c = ctx.getChild(0)
        tp = c.symbol.type
        text = ctx.getText()
        if tp == AgtypeParser.RegularFloat or tp == AgtypeParser.ExponentFloat:
            return float(text)
        if text == 'NaN':
            return float('nan')
        if text == '-Infinity':
            return float('-inf')
        if text == 'Infinity':
            return float('inf')
        raise ValueError("Unknown float expression: " + text)

    # Visit a parse tree produced by AgtypeParser#TrueBoolean.
    def visitTrueBoolean(self, ctx:AgtypeParser.TrueBooleanContext):
        return True

    # Visit a parse tree produced by AgtypeParser#FalseBoolean.
    def visitFalseBoolean(self, ctx:AgtypeParser.FalseBooleanContext):
        return False

    # Visit a parse tree produced by AgtypeParser#NullValue.
    def visitNullValue(self, ctx:AgtypeParser.NullValueContext):
        return None

    # Visit a parse tree produced by AgtypeParser#obj.
    def visitObj(self, ctx:AgtypeParser.ObjContext):
        obj = dict()
        for c in ctx.getChildren():
            if isinstance(c, AgtypeParser.PairContext):
                name, valCtx = self.visitPair(c)
                obj[name] = valCtx.accept(self)
        return obj

    # Visit a parse tree produced by AgtypeParser#pair.
    def visitPair(self, ctx:AgtypeParser.PairContext):
        """Return ``(key, value_context)``; the value is NOT visited here.

        visitObj visits the returned context exactly once. The previous
        ``self.visitChildren(ctx)`` call parsed every value an extra time and
        discarded the result, doubling the work at each nesting level.
        """
        return (self._decodeString(ctx.STRING().getText()), ctx.agValue())

    # Visit a parse tree produced by AgtypeParser#array.
    def visitArray(self, ctx:AgtypeParser.ArrayContext):
        # Terminal children are the brackets and commas; skip them.
        return [c.accept(self) for c in ctx.getChildren()
                if not isinstance(c, TerminalNode)]

    def handleAnnotatedValue(self, anno:str, ctx:ParserRuleContext):
        """Build the Python value for a ``value::annotation`` pair.

        ``numeric`` -> Decimal, ``vertex`` -> Vertex (reused from
        ``vertexCache`` when present), ``edge`` -> Edge, ``path`` -> Path;
        any other annotation yields the plain visited value.
        """
        if anno == "numeric":
            return Decimal(ctx.getText())

        if anno == "vertex":
            d = ctx.accept(self)
            vid = d["id"]
            cache = self.vertexCache
            if cache is not None and vid in cache:
                return cache[vid]
            vertex = Vertex()
            vertex.id = d["id"]
            vertex.label = d["label"]
            vertex.properties = d["properties"]
            if cache is not None:
                cache[vid] = vertex
            return vertex

        if anno == "edge":
            edge = Edge()
            d = ctx.accept(self)
            edge.id = d["id"]
            edge.label = d["label"]
            edge.end_id = d["end_id"]
            edge.start_id = d["start_id"]
            edge.properties = d["properties"]
            return edge

        if anno == "path":
            return Path(ctx.accept(self))

        return ctx.accept(self)
diff --git a/drivers/python/age/exceptions.py b/drivers/python/age/exceptions.py
index 3aa94f4b..18292cc0 100644
--- a/drivers/python/age/exceptions.py
+++ b/drivers/python/age/exceptions.py
@@ -16,39 +16,74 @@
from psycopg.errors import *
class AgeNotSet(Exception):
- def __init__(self, name):
+ def __init__(self, name=None):
self.name = name
+ super().__init__(name or 'AGE extension is not set.')
- def __repr__(self) :
+ def __repr__(self):
return 'AGE extension is not set.'
class GraphNotFound(Exception):
- def __init__(self, name):
+ def __init__(self, name=None):
self.name = name
+ super().__init__(f'Graph[{name}] does not exist.' if name else 'Graph
does not exist.')
- def __repr__(self) :
- return 'Graph[' + self.name + '] does not exist.'
+ def __repr__(self):
+ if self.name:
+ return 'Graph[' + self.name + '] does not exist.'
+ return 'Graph does not exist.'
class GraphAlreadyExists(Exception):
- def __init__(self, name):
+ def __init__(self, name=None):
self.name = name
+ super().__init__(f'Graph[{name}] already exists.' if name else 'Graph
already exists.')
- def __repr__(self) :
- return 'Graph[' + self.name + '] already exists.'
+ def __repr__(self):
+ if self.name:
+ return 'Graph[' + self.name + '] already exists.'
+ return 'Graph already exists.'
+
+
class InvalidGraphName(Exception):
    """Raised when a graph name contains invalid characters.

    Attributes:
        name: the rejected graph name.
        reason: optional explanation appended to the message.
    """
    def __init__(self, name, reason=None):
        self.name = name
        self.reason = reason
        msg = f"Invalid graph name: '{name}'."
        if reason:
            msg += f" {reason}"
        super().__init__(msg)

    def __repr__(self):
        # !r keeps the repr unambiguous when the name itself contains quotes;
        # for ordinary names the output is unchanged.
        return f"InvalidGraphName({self.name!r})"
+
+
class InvalidIdentifier(Exception):
    """Raised when an identifier (column, label, etc.) is invalid.

    Attributes:
        name: the rejected identifier.
        context: optional description (e.g. what kind of identifier it was)
            appended to the message.
    """
    def __init__(self, name, context=None):
        self.name = name
        self.context = context
        msg = f"Invalid identifier: '{name}'."
        if context:
            msg += f" {context}"
        super().__init__(msg)

    def __repr__(self):
        # !r keeps the repr unambiguous when the identifier contains quotes;
        # for ordinary identifiers the output is unchanged.
        return f"InvalidIdentifier({self.name!r})"
class GraphNotSet(Exception):
- def __repr__(self) :
+ def __repr__(self):
return 'Graph name is not set.'
class NoConnection(Exception):
- def __repr__(self) :
+ def __repr__(self):
return 'No Connection'
class NoCursor(Exception):
- def __repr__(self) :
+ def __repr__(self):
return 'No Cursor'
class SqlExecutionError(Exception):
@@ -57,7 +92,7 @@ class SqlExecutionError(Exception):
self.cause = cause
super().__init__(msg, cause)
- def __repr__(self) :
+ def __repr__(self):
return 'SqlExecution [' + self.msg + ']'
class AGTypeError(Exception):
diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py
index aee1b759..6d909548 100644
--- a/drivers/python/age/models.py
+++ b/drivers/python/age/models.py
@@ -1,294 +1,294 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-from io import StringIO
-
-
-TP_NONE = 0
-TP_VERTEX = 1
-TP_EDGE = 2
-TP_PATH = 3
-
-
-class Graph():
- def __init__(self, stmt=None) -> None:
- self.statement = stmt
- self.rows = list()
- self.vertices = dict()
-
- def __iter__(self):
- return self.rows.__iter__()
-
- def __len__(self):
- return self.rows.__len__()
-
- def __getitem__(self,index):
- return self.rows[index]
-
- def size(self):
- return self.rows.__len__()
-
- def append(self, agObj):
- self.rows.append(agObj)
-
- def getVertices(self):
- return self.vertices
-
- def getVertex(self, id):
- if id in self.vertices:
- return self.vertices[id]
- else:
- return None
-
-class AGObj:
- @property
- def gtype(self):
- return TP_NONE
-
-
-class Path(AGObj):
- entities = []
- def __init__(self, entities=None) -> None:
- self.entities = entities
-
- @property
- def gtype(self):
- return TP_PATH
-
- def __iter__(self):
- return self.entities.__iter__()
-
- def __len__(self):
- return self.entities.__len__()
-
- def __getitem__(self,index):
- return self.entities[index]
-
- def size(self):
- return self.entities.__len__()
-
- def append(self, agObj:AGObj ):
- self.entities.append(agObj)
-
- def __str__(self) -> str:
- return self.toString()
-
- def __repr__(self) -> str:
- return self.toString()
-
- def toString(self) -> str:
- buf = StringIO()
- buf.write("[")
- max = len(self.entities)
- idx = 0
- while idx < max:
- if idx > 0:
- buf.write(",")
- self.entities[idx]._toString(buf)
- idx += 1
- buf.write("]::PATH")
-
- return buf.getvalue()
-
- def toJson(self) -> str:
- buf = StringIO()
- buf.write("{\"gtype\": \"path\", \"elements\": [")
-
- max = len(self.entities)
- idx = 0
- while idx < max:
- if idx > 0:
- buf.write(",")
- self.entities[idx]._toJson(buf)
- idx += 1
- buf.write("]}")
-
- return buf.getvalue()
-
-
-
-
-class Vertex(AGObj):
- def __init__(self, id=None, label=None, properties=None) -> None:
- self.id = id
- self.label = label
- self.properties = properties
-
- @property
- def gtype(self):
- return TP_VERTEX
-
- def __setitem__(self,name, value):
- self.properties[name]=value
-
- def __getitem__(self,name):
- if name in self.properties:
- return self.properties[name]
- else:
- return None
-
- def __str__(self) -> str:
- return self.toString()
-
- def __repr__(self) -> str:
- return self.toString()
-
- def toString(self) -> str:
- return nodeToString(self)
-
- def _toString(self, buf):
- _nodeToString(self, buf)
-
- def toJson(self) -> str:
- return nodeToJson(self)
-
- def _toJson(self, buf):
- _nodeToJson(self, buf)
-
-
-class Edge(AGObj):
- def __init__(self, id=None, label=None, properties=None) -> None:
- self.id = id
- self.label = label
- self.start_id = None
- self.end_id = None
- self.properties = properties
-
- @property
- def gtype(self):
- return TP_EDGE
-
- def __setitem__(self,name, value):
- self.properties[name]=value
-
- def __getitem__(self,name):
- if name in self.properties:
- return self.properties[name]
- else:
- return None
-
- def __str__(self) -> str:
- return self.toString()
-
- def __repr__(self) -> str:
- return self.toString()
-
- def extraStrFormat(node, buf):
- if node.start_id != None:
- buf.write(", start_id:")
- buf.write(str(node.start_id))
-
- if node.end_id != None:
- buf.write(", end_id:")
- buf.write(str(node.end_id))
-
-
- def toString(self) -> str:
- return nodeToString(self, Edge.extraStrFormat)
-
- def _toString(self, buf):
- _nodeToString(self, buf, Edge.extraStrFormat)
-
- def extraJsonFormat(node, buf):
- if node.start_id != None:
- buf.write(", \"start_id\": \"")
- buf.write(str(node.start_id))
- buf.write("\"")
-
- if node.end_id != None:
- buf.write(", \"end_id\": \"")
- buf.write(str(node.end_id))
- buf.write("\"")
-
- def toJson(self) -> str:
- return nodeToJson(self, Edge.extraJsonFormat)
-
- def _toJson(self, buf):
- _nodeToJson(self, buf, Edge.extraJsonFormat)
-
-
-def nodeToString(node, extraFormatter=None):
- buf = StringIO()
- _nodeToString(node,buf,extraFormatter=extraFormatter)
- return buf.getvalue()
-
-
-def _nodeToString(node, buf, extraFormatter=None):
- buf.write("{")
- if node.label != None:
- buf.write("label:")
- buf.write(node.label)
-
- if node.id != None:
- buf.write(", id:")
- buf.write(str(node.id))
-
- if node.properties != None:
- buf.write(", properties:{")
- prop_list = []
- for k, v in node.properties.items():
- prop_list.append(f"{k}: {str(v)}")
-
- # Join properties with comma and write to buffer
- buf.write(", ".join(prop_list))
- buf.write("}")
-
- if extraFormatter != None:
- extraFormatter(node, buf)
-
- if node.gtype == TP_VERTEX:
- buf.write("}::VERTEX")
- if node.gtype == TP_EDGE:
- buf.write("}::EDGE")
-
-
-def nodeToJson(node, extraFormatter=None):
- buf = StringIO()
- _nodeToJson(node, buf, extraFormatter=extraFormatter)
- return buf.getvalue()
-
-
-def _nodeToJson(node, buf, extraFormatter=None):
- buf.write("{\"gtype\": ")
- if node.gtype == TP_VERTEX:
- buf.write("\"vertex\", ")
- if node.gtype == TP_EDGE:
- buf.write("\"edge\", ")
-
- if node.label != None:
- buf.write("\"label\":\"")
- buf.write(node.label)
- buf.write("\"")
-
- if node.id != None:
- buf.write(", \"id\":")
- buf.write(str(node.id))
-
- if extraFormatter != None:
- extraFormatter(node, buf)
-
- if node.properties != None:
- buf.write(", \"properties\":{")
-
- prop_list = []
- for k, v in node.properties.items():
- prop_list.append(f"\"{k}\": \"{str(v)}\"")
-
- # Join properties with comma and write to buffer
- buf.write(", ".join(prop_list))
- buf.write("}")
- buf.write("}")
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from io import StringIO
+
+
# Type codes returned by the ``gtype`` property of AGObj subclasses:
# 0 = none/unknown, 1 = vertex, 2 = edge, 3 = path.
TP_NONE, TP_VERTEX, TP_EDGE, TP_PATH = range(4)
+
+
class Graph:
    """Row container for the results of one statement.

    Attributes:
        statement: the statement the rows belong to (may be None).
        rows: result rows, in the order they were appended.
        vertices: dict of vertices keyed by id; this class only reads it.
    """

    def __init__(self, stmt=None) -> None:
        self.statement = stmt
        self.rows = []
        self.vertices = {}

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        return self.rows[index]

    def size(self):
        """Return the number of rows (same as ``len(graph)``)."""
        return len(self.rows)

    def append(self, agObj):
        """Add one result row."""
        self.rows.append(agObj)

    def getVertices(self):
        """Return the (live) vertex dict."""
        return self.vertices

    def getVertex(self, id):
        """Return the vertex stored under ``id``, or None if there is none."""
        return self.vertices.get(id)
+
class AGObj:
    """Base class of AGE result objects; ``gtype`` reports which kind it is."""

    @property
    def gtype(self):
        # The base class is not any particular graph entity.
        return TP_NONE
+
+
class Path(AGObj):
    """An ordered sequence of path entities (as produced by the parser).

    Entities are expected to provide ``_toString(buf)`` / ``_toJson(buf)``,
    as Vertex and Edge do.
    """

    def __init__(self, entities=None) -> None:
        # Each Path owns its own list. Previously ``Path()`` left entities as
        # None (so len/iter/append failed), and a class-level ``entities = []``
        # default was shared between instances.
        self.entities = [] if entities is None else entities

    @property
    def gtype(self):
        return TP_PATH

    def __iter__(self):
        return iter(self.entities)

    def __len__(self):
        return len(self.entities)

    def __getitem__(self, index):
        return self.entities[index]

    def size(self):
        """Return the number of entities (same as ``len(path)``)."""
        return len(self.entities)

    def append(self, agObj:AGObj):
        self.entities.append(agObj)

    def __str__(self) -> str:
        return self.toString()

    def __repr__(self) -> str:
        return self.toString()

    def toString(self) -> str:
        """Render as ``[e1,e2,...]::PATH`` using each entity's ``_toString``."""
        buf = StringIO()
        buf.write("[")
        for idx, entity in enumerate(self.entities):
            if idx > 0:
                buf.write(",")
            entity._toString(buf)
        buf.write("]::PATH")
        return buf.getvalue()

    def toJson(self) -> str:
        """Render as ``{"gtype": "path", "elements": [...]}`` via each entity's ``_toJson``."""
        buf = StringIO()
        buf.write("{\"gtype\": \"path\", \"elements\": [")
        for idx, entity in enumerate(self.entities):
            if idx > 0:
                buf.write(",")
            entity._toJson(buf)
        buf.write("]}")
        return buf.getvalue()
+
+
+
+
class Vertex(AGObj):
    """A graph vertex: id, label and a properties mapping."""

    def __init__(self, id=None, label=None, properties=None) -> None:
        self.id = id
        self.label = label
        self.properties = properties

    @property
    def gtype(self):
        return TP_VERTEX

    def __setitem__(self, name, value):
        self.properties[name] = value

    def __getitem__(self, name):
        """Return property ``name``, or None when the property is absent."""
        props = self.properties
        return props[name] if name in props else None

    def __str__(self) -> str:
        return self.toString()

    def __repr__(self) -> str:
        return self.toString()

    def toString(self) -> str:
        return nodeToString(self)

    def _toString(self, buf):
        # Buffer-writing variant used when this vertex is part of a Path.
        _nodeToString(self, buf)

    def toJson(self) -> str:
        return nodeToJson(self)

    def _toJson(self, buf):
        # Buffer-writing variant used when this vertex is part of a Path.
        _nodeToJson(self, buf)
+
+
class Edge(AGObj):
    """A graph edge: id, label, start/end vertex ids and a properties mapping."""

    def __init__(self, id=None, label=None, properties=None) -> None:
        self.id = id
        self.label = label
        # Endpoint ids are not constructor arguments; they are assigned later.
        self.start_id = None
        self.end_id = None
        self.properties = properties

    @property
    def gtype(self):
        return TP_EDGE

    def __setitem__(self, name, value):
        self.properties[name] = value

    def __getitem__(self, name):
        """Return property ``name``, or None when the property is absent."""
        props = self.properties
        return props[name] if name in props else None

    def __str__(self) -> str:
        return self.toString()

    def __repr__(self) -> str:
        return self.toString()

    # Deliberately has no ``self``: it is passed around and called as
    # ``Edge.extraStrFormat(node, buf)`` by toString/_toString.
    def extraStrFormat(node, buf):
        for field in ("start_id", "end_id"):
            value = getattr(node, field)
            if value is not None:
                buf.write(", " + field + ":")
                buf.write(str(value))

    def toString(self) -> str:
        return nodeToString(self, Edge.extraStrFormat)

    def _toString(self, buf):
        _nodeToString(self, buf, Edge.extraStrFormat)

    # Same calling convention as extraStrFormat; endpoint ids are written as
    # JSON strings.
    def extraJsonFormat(node, buf):
        for field in ("start_id", "end_id"):
            value = getattr(node, field)
            if value is not None:
                buf.write(', "' + field + '": "')
                buf.write(str(value))
                buf.write('"')

    def toJson(self) -> str:
        return nodeToJson(self, Edge.extraJsonFormat)

    def _toJson(self, buf):
        _nodeToJson(self, buf, Edge.extraJsonFormat)
+
+
def nodeToString(node, extraFormatter=None):
    """Return the human-readable string form of a vertex/edge.

    ``extraFormatter(node, buf)``, when given, appends extra fields.
    """
    out = StringIO()
    _nodeToString(node, out, extraFormatter=extraFormatter)
    return out.getvalue()
+
+
def _nodeToString(node, buf, extraFormatter=None):
    """Write the string form of a vertex/edge into ``buf``.

    Layout: ``{label:L, id:I, properties:{k: v, ...}<extra>}::VERTEX``
    (``::EDGE`` for edges). Parts whose attribute is None are skipped, and
    ``extraFormatter(node, buf)`` is called just before the closing brace.
    """
    buf.write("{")
    label = node.label
    if label is not None:
        buf.write("label:")
        buf.write(label)

    if node.id is not None:
        buf.write(", id:")
        buf.write(str(node.id))

    props = node.properties
    if props is not None:
        buf.write(", properties:{")
        buf.write(", ".join(f"{k}: {v!s}" for k, v in props.items()))
        buf.write("}")

    if extraFormatter is not None:
        extraFormatter(node, buf)

    kind = node.gtype
    if kind == TP_VERTEX:
        buf.write("}::VERTEX")
    elif kind == TP_EDGE:
        buf.write("}::EDGE")
+
+
def nodeToJson(node, extraFormatter=None):
    """Return the JSON text for a vertex/edge (see _nodeToJson for the layout)."""
    out = StringIO()
    _nodeToJson(node, out, extraFormatter=extraFormatter)
    return out.getvalue()
+
+
def _nodeToJson(node, buf, extraFormatter=None):
    """Write a JSON object describing a vertex/edge into ``buf``.

    Layout (example)::

        {"gtype": "vertex", "label":"Person", "id":1<extra>, "properties":{"name": "Bob"}}

    ``extraFormatter(node, buf)`` (e.g. Edge.extraJsonFormat) writes extra
    fields, each starting with ", ". Property values are emitted as JSON
    strings of ``str(v)``, as before.

    Fixes over the previous version, which could emit invalid JSON:
    - label, property keys and values are now escaped, so quotes,
      backslashes and control characters are handled;
    - a None label is written as ``null`` instead of leaving a dangling
      ``, ,``;
    - a gtype that is neither vertex nor edge is written as ``null``.
    Output for plain strings without special characters is unchanged.
    """
    def _jsonStr(value):
        # ensure_ascii=False keeps non-ASCII text as-is, like the old output.
        return json.dumps(str(value), ensure_ascii=False)

    buf.write("{\"gtype\": ")
    kind = node.gtype
    if kind == TP_VERTEX:
        buf.write("\"vertex\", ")
    elif kind == TP_EDGE:
        buf.write("\"edge\", ")
    else:
        buf.write("null, ")

    buf.write("\"label\":")
    buf.write("null" if node.label is None else _jsonStr(node.label))

    if node.id is not None:
        buf.write(", \"id\":")
        buf.write(str(node.id))

    if extraFormatter is not None:
        extraFormatter(node, buf)

    if node.properties is not None:
        buf.write(", \"properties\":{")
        buf.write(", ".join(f"{_jsonStr(k)}: {_jsonStr(v)}"
                            for k, v in node.properties.items()))
        buf.write("}")
    buf.write("}")
\ No newline at end of file
diff --git a/drivers/python/age/networkx/lib.py
b/drivers/python/age/networkx/lib.py
index 30865862..5df761ea 100644
--- a/drivers/python/age/networkx/lib.py
+++ b/drivers/python/age/networkx/lib.py
@@ -20,17 +20,18 @@ import networkx as nx
from psycopg import sql
from typing import Dict, Any, List, Set
from age.models import Vertex, Edge, Path
+from age.age import validate_graph_name, validate_identifier
def checkIfGraphNameExistInAGE(connection: psycopg.connect,
graphName: str):
"""Check if the age graph exists"""
+ validate_graph_name(graphName)
with connection.cursor() as cursor:
- cursor.execute(sql.SQL("""
- SELECT count(*)
- FROM ag_catalog.ag_graph
- WHERE name='%s'
- """ % (graphName)))
+ cursor.execute(
+ sql.SQL("SELECT count(*) FROM ag_catalog.ag_graph WHERE name={gn}")
+ .format(gn=sql.Literal(graphName))
+ )
if cursor.fetchone()[0] == 0:
raise GraphNotFound(graphName)
@@ -38,11 +39,13 @@ def checkIfGraphNameExistInAGE(connection: psycopg.connect,
def getOidOfGraph(connection: psycopg.connect,
graphName: str) -> int:
"""Returns oid of a graph"""
+ validate_graph_name(graphName)
try:
with connection.cursor() as cursor:
- cursor.execute(sql.SQL("""
- SELECT graphid FROM ag_catalog.ag_graph WHERE
name='%s' ;
- """ % (graphName)))
+ cursor.execute(
+ sql.SQL("SELECT graphid FROM ag_catalog.ag_graph WHERE
name={gn}")
+ .format(gn=sql.Literal(graphName))
+ )
oid = cursor.fetchone()[0]
return oid
except Exception as e:
@@ -56,7 +59,9 @@ def get_vlabel(connection: psycopg.connect,
try:
with connection.cursor() as cursor:
cursor.execute(
- """SELECT name FROM ag_catalog.ag_label WHERE kind='v' AND
graph=%s;""" % oid)
+ sql.SQL("SELECT name FROM ag_catalog.ag_label WHERE kind='v'
AND graph={oid}")
+ .format(oid=sql.Literal(oid))
+ )
for row in cursor:
node_label_list.append(row[0])
@@ -69,18 +74,19 @@ def create_vlabel(connection: psycopg.connect,
graphName: str,
node_label_list: List):
"""create_vlabels from list if not exist"""
+ validate_graph_name(graphName)
try:
node_label_set = set(get_vlabel(connection, graphName))
- crete_label_statement = ''
for label in node_label_list:
if label in node_label_set:
continue
- crete_label_statement += """SELECT create_vlabel('%s','%s');\n"""
% (
- graphName, label)
- if crete_label_statement != '':
+ validate_identifier(label, "Vertex label")
with connection.cursor() as cursor:
- cursor.execute(crete_label_statement)
- connection.commit()
+ cursor.execute(
+ sql.SQL("SELECT create_vlabel({gn},{lbl})")
+ .format(gn=sql.Literal(graphName), lbl=sql.Literal(label))
+ )
+ connection.commit()
except Exception as e:
raise Exception(e)
@@ -92,7 +98,9 @@ def get_elabel(connection: psycopg.connect,
try:
with connection.cursor() as cursor:
cursor.execute(
- """SELECT name FROM ag_catalog.ag_label WHERE kind='e' AND
graph=%s;""" % oid)
+ sql.SQL("SELECT name FROM ag_catalog.ag_label WHERE kind='e'
AND graph={oid}")
+ .format(oid=sql.Literal(oid))
+ )
for row in cursor:
edge_label_list.append(row[0])
except Exception as ex:
@@ -103,19 +111,20 @@ def get_elabel(connection: psycopg.connect,
def create_elabel(connection: psycopg.connect,
graphName: str,
edge_label_list: List):
- """create_vlabels from list if not exist"""
+ """create_elabels from list if not exist"""
+ validate_graph_name(graphName)
try:
edge_label_set = set(get_elabel(connection, graphName))
- crete_label_statement = ''
for label in edge_label_list:
if label in edge_label_set:
continue
- crete_label_statement += """SELECT create_elabel('%s','%s');\n"""
% (
- graphName, label)
- if crete_label_statement != '':
+ validate_identifier(label, "Edge label")
with connection.cursor() as cursor:
- cursor.execute(crete_label_statement)
- connection.commit()
+ cursor.execute(
+ sql.SQL("SELECT create_elabel({gn},{lbl})")
+ .format(gn=sql.Literal(graphName), lbl=sql.Literal(label))
+ )
+ connection.commit()
except Exception as e:
raise Exception(e)
@@ -171,6 +180,7 @@ def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph):
def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G:
nx.DiGraph, node_label_list: Set):
"""Add all node to AGE"""
+ validate_graph_name(graphName)
try:
queue_data = {label: [] for label in node_label_list}
id_data = {}
@@ -180,8 +190,11 @@ def addAllNodesIntoAGE(connection: psycopg.connect,
graphName: str, G: nx.DiGrap
queue_data[data['label']].append((json_string,))
for label, rows in queue_data.items():
- table_name = """%s."%s" """ % (graphName, label)
- insert_query = f"INSERT INTO {table_name} (properties) VALUES (%s)
RETURNING id"
+ validate_identifier(label, "Node label")
+ insert_query = sql.SQL("INSERT INTO {schema}.{table} (properties)
VALUES (%s) RETURNING id").format(
+ schema=sql.Identifier(graphName),
+ table=sql.Identifier(label)
+ )
cursor = connection.cursor()
cursor.executemany(insert_query, rows, returning=True)
ids = []
@@ -205,6 +218,7 @@ def addAllNodesIntoAGE(connection: psycopg.connect,
graphName: str, G: nx.DiGrap
def addAllEdgesIntoAGE(connection: psycopg.connect, graphName: str, G:
nx.DiGraph, edge_label_list: Set):
"""Add all edge to AGE"""
+ validate_graph_name(graphName)
try:
queue_data = {label: [] for label in edge_label_list}
for u, v, data in G.edges(data=True):
@@ -213,8 +227,11 @@ def addAllEdgesIntoAGE(connection: psycopg.connect,
graphName: str, G: nx.DiGrap
(G.nodes[u]['properties']['__gid__'],
G.nodes[v]['properties']['__gid__'], json_string,))
for label, rows in queue_data.items():
- table_name = """%s."%s" """ % (graphName, label)
- insert_query = f"INSERT INTO {table_name}
(start_id,end_id,properties) VALUES (%s, %s, %s)"
+ validate_identifier(label, "Edge label")
+ insert_query = sql.SQL("INSERT INTO {schema}.{table}
(start_id,end_id,properties) VALUES (%s, %s, %s)").format(
+ schema=sql.Identifier(graphName),
+ table=sql.Identifier(label)
+ )
cursor = connection.cursor()
cursor.executemany(insert_query, rows)
connection.commit()
@@ -225,14 +242,19 @@ def addAllEdgesIntoAGE(connection: psycopg.connect,
graphName: str, G: nx.DiGrap
def addAllNodesIntoNetworkx(connection: psycopg.connect, graphName: str, G:
nx.DiGraph):
"""Add all nodes to Networkx"""
+ validate_graph_name(graphName)
node_label_list = get_vlabel(connection, graphName)
try:
for label in node_label_list:
+ validate_identifier(label, "Node label")
with connection.cursor() as cursor:
- cursor.execute("""
- SELECT id, CAST(properties AS VARCHAR)
- FROM %s."%s";
- """ % (graphName, label))
+ cursor.execute(
+ sql.SQL("SELECT id, CAST(properties AS VARCHAR) FROM
{schema}.{table}")
+ .format(
+ schema=sql.Identifier(graphName),
+ table=sql.Identifier(label)
+ )
+ )
rows = cursor.fetchall()
for row in rows:
G.add_node(int(row[0]), label=label,
@@ -243,14 +265,19 @@ def addAllNodesIntoNetworkx(connection: psycopg.connect,
graphName: str, G: nx.D
def addAllEdgesIntoNetworkx(connection: psycopg.connect, graphName: str, G:
nx.DiGraph):
"""Add All edges to Networkx"""
+ validate_graph_name(graphName)
try:
edge_label_list = get_elabel(connection, graphName)
for label in edge_label_list:
+ validate_identifier(label, "Edge label")
with connection.cursor() as cursor:
- cursor.execute("""
- SELECT start_id, end_id, CAST(properties AS
VARCHAR)
- FROM %s."%s";
- """ % (graphName, label))
+ cursor.execute(
+ sql.SQL("SELECT start_id, end_id, CAST(properties AS
VARCHAR) FROM {schema}.{table}")
+ .format(
+ schema=sql.Identifier(graphName),
+ table=sql.Identifier(label)
+ )
+ )
rows = cursor.fetchall()
for row in rows:
G.add_edge(int(row[0]), int(
diff --git a/drivers/python/requirements.txt b/drivers/python/requirements.txt
index b0593b79..449d38c6 100644
Binary files a/drivers/python/requirements.txt and
b/drivers/python/requirements.txt differ
diff --git a/drivers/python/setup.py b/drivers/python/setup.py
index d0eed26b..853f1006 100644
--- a/drivers/python/setup.py
+++ b/drivers/python/setup.py
@@ -1,22 +1,22 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# This setup.py is maintained for backward compatibility.
-# All package configuration is in pyproject.toml. For installation,
-# use: pip install .
-
-from setuptools import setup
-
-setup()
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This setup.py is maintained for backward compatibility.
+# All package configuration is in pyproject.toml. For installation,
+# use: pip install .
+
+from setuptools import setup
+
+setup()
diff --git a/drivers/python/test_agtypes.py b/drivers/python/test_agtypes.py
index 69bbbc29..4e9752e6 100644
--- a/drivers/python/test_agtypes.py
+++ b/drivers/python/test_agtypes.py
@@ -1,132 +1,132 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import unittest
-from decimal import Decimal
-import math
-import age
-
-
-class TestAgtype(unittest.TestCase):
- resultHandler = None
-
- def __init__(self, methodName: str) -> None:
- super().__init__(methodName=methodName)
- self.resultHandler = age.newResultHandler()
-
- def parse(self, exp):
- return self.resultHandler.parse(exp)
-
- def test_scalar(self):
- print("\nTesting Scalar Value Parsing. Result : ", end='')
-
- mapStr = '{"name": "Smith", "num":123, "yn":true,
"bigInt":123456789123456789123456789123456789::numeric}'
- arrStr = '["name", "Smith", "num", 123, "yn", true,
123456789123456789123456789123456789.8888::numeric]'
- strStr = '"abcd"'
- intStr = '1234'
- floatStr = '1234.56789'
- floatStr2 = '6.45161290322581e+46'
- numericStr1 = '12345678901234567890123456789123456789.789::numeric'
- numericStr2 = '12345678901234567890123456789123456789::numeric'
- boolStr = 'true'
- nullStr = ''
- nanStr = "NaN"
- infpStr = "Infinity"
- infnStr = "-Infinity"
-
- mapVal = self.parse(mapStr)
- arrVal = self.parse(arrStr)
- str = self.parse(strStr)
- intVal = self.parse(intStr)
- floatVal = self.parse(floatStr)
- floatVal2 = self.parse(floatStr2)
- bigFloat = self.parse(numericStr1)
- bigInt = self.parse(numericStr2)
- boolVal = self.parse(boolStr)
- nullVal = self.parse(nullStr)
- nanVal = self.parse(nanStr)
- infpVal = self.parse(infpStr)
- infnVal = self.parse(infnStr)
-
- self.assertEqual(mapVal, {'name': 'Smith', 'num': 123, 'yn': True,
'bigInt': Decimal(
- '123456789123456789123456789123456789')})
- self.assertEqual(arrVal, ["name", "Smith", "num", 123, "yn", True,
Decimal(
- "123456789123456789123456789123456789.8888")])
- self.assertEqual(str, "abcd")
- self.assertEqual(intVal, 1234)
- self.assertEqual(floatVal, 1234.56789)
- self.assertEqual(floatVal2, 6.45161290322581e+46)
- self.assertEqual(bigFloat, Decimal(
- "12345678901234567890123456789123456789.789"))
- self.assertEqual(bigInt, Decimal(
- "12345678901234567890123456789123456789"))
- self.assertEqual(boolVal, True)
- self.assertTrue(math.isnan(nanVal))
- self.assertTrue(math.isinf(infpVal))
- self.assertTrue(math.isinf(infnVal))
-
- def test_vertex(self):
-
- print("\nTesting vertex Parsing. Result : ", end='')
-
- vertexExp = '''{"id": 2251799813685425, "label": "Person",
- "properties": {"name": "Smith", "numInt":123, "numFloat":
384.23424,
-
"bigInt":123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789::numeric,
- "bigFloat":123456789123456789123456789123456789.12345::numeric,
- "yn":true, "nullVal": null}}::vertex'''
-
- vertex = self.parse(vertexExp)
- self.assertEqual(vertex.id, 2251799813685425)
- self.assertEqual(vertex.label, "Person")
- self.assertEqual(vertex["name"], "Smith")
- self.assertEqual(vertex["numInt"], 123)
- self.assertEqual(vertex["numFloat"], 384.23424)
- self.assertEqual(vertex["bigInt"], Decimal(
-
"123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789"))
- self.assertEqual(vertex["bigFloat"], Decimal(
- "123456789123456789123456789123456789.12345"))
- self.assertEqual(vertex["yn"], True)
- self.assertEqual(vertex["nullVal"], None)
-
- def test_path(self):
-
- print("\nTesting Path Parsing. Result : ", end='')
-
- pathExp = '''[{"id": 2251799813685425, "label": "Person",
"properties": {"name": "Smith"}}::vertex,
- {"id": 2533274790396576, "label": "workWith", "end_id":
2251799813685425, "start_id": 2251799813685424,
- "properties": {"weight": 3,
"bigFloat":123456789123456789123456789.12345::numeric}}::edge,
- {"id": 2251799813685424, "label": "Person", "properties": {"name":
"Joe"}}::vertex]::path'''
-
- path = self.parse(pathExp)
- vertexStart = path[0]
- edge = path[1]
- vertexEnd = path[2]
- self.assertEqual(vertexStart.id, 2251799813685425)
- self.assertEqual(vertexStart.label, "Person")
- self.assertEqual(vertexStart["name"], "Smith")
-
- self.assertEqual(edge.id, 2533274790396576)
- self.assertEqual(edge.label, "workWith")
- self.assertEqual(edge["weight"], 3)
- self.assertEqual(edge["bigFloat"], Decimal(
- "123456789123456789123456789.12345"))
-
- self.assertEqual(vertexEnd.id, 2251799813685424)
- self.assertEqual(vertexEnd.label, "Person")
- self.assertEqual(vertexEnd["name"], "Joe")
-
-
-if __name__ == '__main__':
- unittest.main()
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from decimal import Decimal
+import math
+import age
+
+
+class TestAgtype(unittest.TestCase):
+ resultHandler = None
+
+ def __init__(self, methodName: str) -> None:
+ super().__init__(methodName=methodName)
+ self.resultHandler = age.newResultHandler()
+
+ def parse(self, exp):
+ return self.resultHandler.parse(exp)
+
+ def test_scalar(self):
+ print("\nTesting Scalar Value Parsing. Result : ", end='')
+
+ mapStr = '{"name": "Smith", "num":123, "yn":true,
"bigInt":123456789123456789123456789123456789::numeric}'
+ arrStr = '["name", "Smith", "num", 123, "yn", true,
123456789123456789123456789123456789.8888::numeric]'
+ strStr = '"abcd"'
+ intStr = '1234'
+ floatStr = '1234.56789'
+ floatStr2 = '6.45161290322581e+46'
+ numericStr1 = '12345678901234567890123456789123456789.789::numeric'
+ numericStr2 = '12345678901234567890123456789123456789::numeric'
+ boolStr = 'true'
+ nullStr = ''
+ nanStr = "NaN"
+ infpStr = "Infinity"
+ infnStr = "-Infinity"
+
+ mapVal = self.parse(mapStr)
+ arrVal = self.parse(arrStr)
+ str = self.parse(strStr)
+ intVal = self.parse(intStr)
+ floatVal = self.parse(floatStr)
+ floatVal2 = self.parse(floatStr2)
+ bigFloat = self.parse(numericStr1)
+ bigInt = self.parse(numericStr2)
+ boolVal = self.parse(boolStr)
+ nullVal = self.parse(nullStr)
+ nanVal = self.parse(nanStr)
+ infpVal = self.parse(infpStr)
+ infnVal = self.parse(infnStr)
+
+ self.assertEqual(mapVal, {'name': 'Smith', 'num': 123, 'yn': True,
'bigInt': Decimal(
+ '123456789123456789123456789123456789')})
+ self.assertEqual(arrVal, ["name", "Smith", "num", 123, "yn", True,
Decimal(
+ "123456789123456789123456789123456789.8888")])
+ self.assertEqual(str, "abcd")
+ self.assertEqual(intVal, 1234)
+ self.assertEqual(floatVal, 1234.56789)
+ self.assertEqual(floatVal2, 6.45161290322581e+46)
+ self.assertEqual(bigFloat, Decimal(
+ "12345678901234567890123456789123456789.789"))
+ self.assertEqual(bigInt, Decimal(
+ "12345678901234567890123456789123456789"))
+ self.assertEqual(boolVal, True)
+ self.assertTrue(math.isnan(nanVal))
+ self.assertTrue(math.isinf(infpVal))
+ self.assertTrue(math.isinf(infnVal))
+
+ def test_vertex(self):
+
+ print("\nTesting vertex Parsing. Result : ", end='')
+
+ vertexExp = '''{"id": 2251799813685425, "label": "Person",
+ "properties": {"name": "Smith", "numInt":123, "numFloat":
384.23424,
+
"bigInt":123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789::numeric,
+ "bigFloat":123456789123456789123456789123456789.12345::numeric,
+ "yn":true, "nullVal": null}}::vertex'''
+
+ vertex = self.parse(vertexExp)
+ self.assertEqual(vertex.id, 2251799813685425)
+ self.assertEqual(vertex.label, "Person")
+ self.assertEqual(vertex["name"], "Smith")
+ self.assertEqual(vertex["numInt"], 123)
+ self.assertEqual(vertex["numFloat"], 384.23424)
+ self.assertEqual(vertex["bigInt"], Decimal(
+
"123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789"))
+ self.assertEqual(vertex["bigFloat"], Decimal(
+ "123456789123456789123456789123456789.12345"))
+ self.assertEqual(vertex["yn"], True)
+ self.assertEqual(vertex["nullVal"], None)
+
+ def test_path(self):
+
+ print("\nTesting Path Parsing. Result : ", end='')
+
+ pathExp = '''[{"id": 2251799813685425, "label": "Person",
"properties": {"name": "Smith"}}::vertex,
+ {"id": 2533274790396576, "label": "workWith", "end_id":
2251799813685425, "start_id": 2251799813685424,
+ "properties": {"weight": 3,
"bigFloat":123456789123456789123456789.12345::numeric}}::edge,
+ {"id": 2251799813685424, "label": "Person", "properties": {"name":
"Joe"}}::vertex]::path'''
+
+ path = self.parse(pathExp)
+ vertexStart = path[0]
+ edge = path[1]
+ vertexEnd = path[2]
+ self.assertEqual(vertexStart.id, 2251799813685425)
+ self.assertEqual(vertexStart.label, "Person")
+ self.assertEqual(vertexStart["name"], "Smith")
+
+ self.assertEqual(edge.id, 2533274790396576)
+ self.assertEqual(edge.label, "workWith")
+ self.assertEqual(edge["weight"], 3)
+ self.assertEqual(edge["bigFloat"], Decimal(
+ "123456789123456789123456789.12345"))
+
+ self.assertEqual(vertexEnd.id, 2251799813685424)
+ self.assertEqual(vertexEnd.label, "Person")
+ self.assertEqual(vertexEnd["name"], "Joe")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/drivers/python/test_networkx.py b/drivers/python/test_networkx.py
index 310d2cf5..dbaaf866 100644
--- a/drivers/python/test_networkx.py
+++ b/drivers/python/test_networkx.py
@@ -224,7 +224,7 @@ class TestAgeToNetworkx(unittest.TestCase):
with self.assertRaises(GraphNotFound) as context:
age_to_networkx(ag.connection, graphName=non_existing_graph)
# Check the raised exception has the expected error message
- self.assertEqual(str(context.exception), non_existing_graph)
+ self.assertIn(non_existing_graph, str(context.exception))
class TestNetworkxToAGE(unittest.TestCase):
diff --git a/drivers/python/test_security.py b/drivers/python/test_security.py
new file mode 100644
index 00000000..55347868
--- /dev/null
+++ b/drivers/python/test_security.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Security tests for the Apache AGE Python driver.
+
+Tests input validation, SQL injection prevention, and exception handling.
+"""
+
+import unittest
+from age.age import (
+ validate_graph_name,
+ validate_identifier,
+ buildCypher,
+ _validate_column,
+)
+from age.exceptions import (
+ AgeNotSet,
+ GraphNotFound,
+ GraphAlreadyExists,
+ GraphNotSet,
+ InvalidGraphName,
+ InvalidIdentifier,
+)
+
+
+class TestGraphNameValidation(unittest.TestCase):
+ """Test validate_graph_name rejects dangerous inputs."""
+
+ def test_rejects_empty_string(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('')
+
+ def test_rejects_none(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name(None)
+
+ def test_rejects_non_string(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name(123)
+
+ def test_rejects_digit_start(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('123graph')
+
+ def test_rejects_sql_injection_drop_table(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name("'; DROP TABLE ag_graph; --")
+
+ def test_rejects_sql_injection_semicolon(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name("test'); DROP TABLE users; --")
+
+ def test_rejects_sql_injection_select(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name("graph; SELECT * FROM pg_shadow")
+
+ def test_accepts_hyphenated_graph_name(self):
+ # AGE allows hyphens in middle positions of graph names.
+ validate_graph_name('my-graph')
+
+ def test_rejects_space(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('my graph')
+
+ def test_accepts_dotted_graph_name(self):
+ # AGE allows dots in middle positions of graph names.
+ validate_graph_name('my.graph')
+
+ def test_rejects_dollar(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('my$graph')
+
+ def test_rejects_exceeding_63_chars(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('a' * 64)
+
+ def test_accepts_valid_names(self):
+ # These should NOT raise
+ validate_graph_name('my_graph')
+ validate_graph_name('MyGraph')
+ validate_graph_name('_pr_ivate')
+ validate_graph_name('graph123')
+ validate_graph_name('my-graph')
+ validate_graph_name('my.graph')
+ validate_graph_name('a-b.c_d')
+ validate_graph_name('abc')
+ validate_graph_name('a' * 63)
+
+ def test_rejects_shorter_than_3_chars(self):
+ # AGE requires minimum 3 character graph names.
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('a')
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('ab')
+
+ def test_rejects_name_ending_with_hyphen(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('graph-')
+
+ def test_rejects_name_ending_with_dot(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('graph.')
+
+ def test_rejects_name_starting_with_hyphen(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('-graph')
+
+ def test_rejects_name_starting_with_dot(self):
+ with self.assertRaises(InvalidGraphName):
+ validate_graph_name('.graph')
+
+ def test_error_message_contains_name(self):
+ try:
+ validate_graph_name("bad;name")
+ self.fail("Expected InvalidGraphName")
+ except InvalidGraphName as e:
+ self.assertIn("bad;name", str(e))
+ self.assertIn("Invalid graph name", str(e))
+
+
+class TestIdentifierValidation(unittest.TestCase):
+ """Test validate_identifier rejects dangerous inputs."""
+
+ def test_rejects_empty_string(self):
+ with self.assertRaises(InvalidIdentifier):
+ validate_identifier('')
+
+ def test_rejects_none(self):
+ with self.assertRaises(InvalidIdentifier):
+ validate_identifier(None)
+
+ def test_rejects_sql_injection(self):
+ with self.assertRaises(InvalidIdentifier):
+ validate_identifier("Person'; DROP TABLE--")
+
+ def test_rejects_special_chars(self):
+ with self.assertRaises(InvalidIdentifier):
+ validate_identifier("col; DROP TABLE")
+
+ def test_accepts_valid_identifiers(self):
+ validate_identifier('Person')
+ validate_identifier('KNOWS')
+ validate_identifier('_internal')
+ validate_identifier('col1')
+
+ def test_error_includes_context(self):
+ try:
+ validate_identifier("bad;name", "Column name")
+ self.fail("Expected InvalidIdentifier")
+ except InvalidIdentifier as e:
+ self.assertIn("Column name", str(e))
+
+
+class TestColumnValidation(unittest.TestCase):
+ """Test _validate_column prevents injection through column specs."""
+
+ def test_plain_column_name(self):
+ self.assertEqual(_validate_column('v'), 'v agtype')
+
+ def test_column_with_type(self):
+ self.assertEqual(_validate_column('n agtype'), 'n agtype')
+
+ def test_empty_column(self):
+ self.assertEqual(_validate_column(''), '')
+ self.assertEqual(_validate_column(' '), '')
+
+ def test_rejects_injection_in_column_name(self):
+ with self.assertRaises(InvalidIdentifier):
+ _validate_column("v); DROP TABLE ag_graph; --")
+
+ def test_rejects_injection_in_column_type(self):
+ with self.assertRaises(InvalidIdentifier):
+ _validate_column("v agtype); DROP TABLE")
+
+ def test_rejects_three_part_column(self):
+ with self.assertRaises(InvalidIdentifier):
+ _validate_column("a b c")
+
+ def test_rejects_semicolon_in_name(self):
+ with self.assertRaises(InvalidIdentifier):
+ _validate_column("col;")
+
+
+class TestBuildCypher(unittest.TestCase):
+ """Test buildCypher validates columns and rejects injection."""
+
+ def test_default_column(self):
+ result = buildCypher('test_graph', 'MATCH (n) RETURN n', None)
+ self.assertIn('v agtype', result)
+
+ def test_single_column(self):
+ result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n'])
+ self.assertIn('n agtype', result)
+
+ def test_typed_column(self):
+ result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n agtype'])
+ self.assertIn('n agtype', result)
+
+ def test_multiple_columns(self):
+ result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['a', 'b'])
+ self.assertIn('a agtype', result)
+ self.assertIn('b agtype', result)
+
+ def test_rejects_injection_in_column(self):
+ with self.assertRaises(InvalidIdentifier):
+ buildCypher('test_graph', 'MATCH (n) RETURN n',
+ ["v); DROP TABLE ag_graph;--"])
+
+ def test_rejects_none_graph_name(self):
+ with self.assertRaises(GraphNotSet):
+ buildCypher(None, 'MATCH (n) RETURN n', None)
+
+
+class TestExceptionConstructors(unittest.TestCase):
+ """Test that exception constructors work correctly."""
+
+ def test_age_not_set_no_args(self):
+ """AgeNotSet() must work without arguments (previously crashed)."""
+ e = AgeNotSet()
+ self.assertIsNone(e.name)
+ self.assertIn('not set', repr(e))
+
+ def test_age_not_set_with_message(self):
+ e = AgeNotSet("custom message")
+ self.assertEqual(e.name, "custom message")
+
+ def test_graph_not_found_no_args(self):
+ e = GraphNotFound()
+ self.assertIsNone(e.name)
+ self.assertIn('does not exist', repr(e))
+
+ def test_graph_not_found_with_name(self):
+ e = GraphNotFound("test_graph")
+ self.assertEqual(e.name, "test_graph")
+ self.assertIn('test_graph', repr(e))
+
+ def test_graph_already_exists_no_args(self):
+ e = GraphAlreadyExists()
+ self.assertIsNone(e.name)
+ self.assertIn('already exists', repr(e))
+
+ def test_graph_already_exists_with_name(self):
+ e = GraphAlreadyExists("test_graph")
+ self.assertEqual(e.name, "test_graph")
+ self.assertIn('test_graph', repr(e))
+
+ def test_invalid_graph_name_fields(self):
+ e = InvalidGraphName("bad;name", "must be valid")
+ self.assertEqual(e.name, "bad;name")
+ self.assertEqual(e.reason, "must be valid")
+ self.assertIn("bad;name", str(e))
+ self.assertIn("must be valid", str(e))
+
+ def test_invalid_identifier_fields(self):
+ e = InvalidIdentifier("col;drop", "Column name")
+ self.assertEqual(e.name, "col;drop")
+ self.assertEqual(e.context, "Column name")
+ self.assertIn("col;drop", str(e))
+
+
+if __name__ == '__main__':
+ unittest.main()