This is an automated email from the ASF dual-hosted git repository.
joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git
The following commit(s) were added to refs/heads/trunk by this push:
new 63b6d78 CASSGO 1 Support for Native Protocol 5
63b6d78 is described below
commit 63b6d7830710a0c6c5411c08bf16d38cb88d2a8a
Author: Bohdan Siryk <[email protected]>
AuthorDate: Thu Jul 18 12:33:47 2024 +0300
CASSGO 1 Support for Native Protocol 5
Native Protocol 5 was introduced with the release of C* 4.0. This PR
provides full support
for the newer version, including the new frame format (segments) and new fields
for QUERY, BATCH, and EXECUTE messages.
Also, this PR brings changes to the Compressor interface to follow an
append-like design.
One more thing, it bumps the minimum Go version to 1.19.
Patch by Bohdan Siryk; Reviewed by João Reis, James Hartig for CASSGO-1
CASSGO-30
---
.github/workflows/main.yml | 19 +-
CHANGELOG.md | 5 +
batch_test.go | 82 ++++++++
cassandra_test.go | 467 ++++++++++++++++++++++++++++++++++++++++++++-
common_test.go | 7 +-
compressor.go | 40 +++-
compressor_test.go | 6 +-
conn.go | 386 +++++++++++++++++++++++++++++--------
conn_test.go | 108 ++++++++++-
control.go | 4 +-
crc.go | 58 ++++++
crc_test.go | 90 +++++++++
frame.go | 418 +++++++++++++++++++++++++++++++++++-----
frame_test.go | 314 ++++++++++++++++++++++++++++++
lz4/lz4.go | 67 +++++--
lz4/lz4_test.go | 205 +++++++++++++++++++-
prepared_cache.go | 17 ++
session.go | 70 ++++++-
18 files changed, 2178 insertions(+), 185 deletions(-)
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 714a46a..0ca9d20 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -37,8 +37,12 @@ jobs:
go: [ '1.22', '1.23' ]
cassandra_version: [ '4.0.13', '4.1.6' ]
auth: [ "false" ]
- compressor: [ "snappy" ]
+ compressor: [ "no-compression", "snappy", "lz4" ]
tags: [ "cassandra", "integration", "ccm" ]
+ proto_version: [ "4", "5" ]
+ exclude:
+ - proto_version: "5"
+ compressor: "snappy"
steps:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
@@ -102,7 +106,7 @@ jobs:
ccm status
ccm node1 nodetool status
- args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=3
-autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION
-cluster=$(ccm liveset) ./..."
+ args="-gocql.timeout=60s -runssl -proto=${{ matrix.proto_version }}
-rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }}
-gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..."
echo "args=$args" >> $GITHUB_ENV
echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV
@@ -115,7 +119,7 @@ jobs:
if: 'failure()'
uses: actions/upload-artifact@v4
with:
- name: ccm-cluster-cassandra-${{ matrix.cassandra_version }}-go-${{
matrix.go }}-tag-${{ matrix.tags }}
+ name: ccm-cluster-cassandra-${{ matrix.cassandra_version }}-go-${{
matrix.go }}-tag-${{ matrix.tags }}-proto-version-${{ matrix.proto_version
}}-compressor-${{ matrix.compressor }}
path: /home/runner/.ccm/test
retention-days: 5
integration-auth-cassandra:
@@ -129,9 +133,12 @@ jobs:
matrix:
go: [ '1.22', '1.23' ]
cassandra_version: [ '4.0.13' ]
- compressor: [ "snappy" ]
+ compressor: [ "no-compression", "snappy", "lz4" ]
tags: [ "integration" ]
-
+ proto_version: [ "4", "5" ]
+ exclude:
+ - proto_version: "5"
+ compressor: "snappy"
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v4
@@ -193,7 +200,7 @@ jobs:
ccm status
ccm node1 nodetool status
- args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=1
-autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION
-cluster=$(ccm liveset) ./..."
+ args="-gocql.timeout=60s -runssl -proto=${{ matrix.proto_version }}
-rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }}
-gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..."
echo "args=$args" >> $GITHUB_ENV
echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7cf2fb4..351b731 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,9 @@ and this project adheres to [Semantic
Versioning](https://semver.org/spec/v2.0.0
- Support of sending queries to the specific node with Query.SetHostID()
(CASSGO-4)
+- Support for Native Protocol 5. The protocol changes expose new APIs:
+  Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(),
Batch.WithNowInSeconds() (CASSGO-1)
+
### Changed
- Move lz4 compressor to lz4 package within the gocql module (CASSGO-32)
@@ -43,6 +46,8 @@ and this project adheres to [Semantic
Versioning](https://semver.org/spec/v2.0.0
- Refactor HostInfo creation and ConnectAddress() method (CASSGO-45)
+- gocql.Compressor interface changes to follow an append-like design. Bumped Go
version to 1.19 (CASSGO-1)
+
### Fixed
- Cassandra version unmarshal fix (CASSGO-49)
diff --git a/batch_test.go b/batch_test.go
index 44b5266..3a8a7e7 100644
--- a/batch_test.go
+++ b/batch_test.go
@@ -28,6 +28,7 @@
package gocql
import (
+ "github.com/stretchr/testify/require"
"testing"
"time"
)
@@ -86,3 +87,84 @@ func TestBatch_WithTimestamp(t *testing.T) {
t.Errorf("got ts %d, expected %d", storedTs, micros)
}
}
+
+func TestBatch_WithNowInSeconds(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("Batch now in seconds are only available on protocol >=
5")
+ }
+
+ if err := createTable(session, `CREATE TABLE IF NOT EXISTS
batch_now_in_seconds (id int primary key, val text)`); err != nil {
+ t.Fatal(err)
+ }
+
+ b := session.NewBatch(LoggedBatch)
+ b.WithNowInSeconds(0)
+ b.Query("INSERT INTO batch_now_in_seconds (id, val) VALUES (?, ?) USING
TTL 20", 1, "val")
+ if err := session.ExecuteBatch(b); err != nil {
+ t.Fatal(err)
+ }
+
+ var remainingTTL int
+ err := session.Query(`SELECT TTL(val) FROM batch_now_in_seconds WHERE
id = ?`, 1).
+ WithNowInSeconds(10).
+ Scan(&remainingTTL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, remainingTTL, 10)
+}
+
+func TestBatch_SetKeyspace(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("keyspace for BATCH message is not supported in protocol
< 5")
+ }
+
+ const keyspaceStmt = `
+ CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test
+ WITH replication = {
+ 'class': 'SimpleStrategy',
+ 'replication_factor': '1'
+ };
+`
+
+ err := session.Query(keyspaceStmt).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY
(id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ids := []int{1, 2}
+ texts := []string{"val1", "val2"}
+
+ b :=
session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test")
+ b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0],
texts[0])
+ b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1],
texts[1])
+ err = session.ExecuteBatch(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ id int
+ text string
+ )
+
+ iter := session.Query("SELECT * FROM
gocql_keyspace_override_test.batch_keyspace").Iter()
+ defer iter.Close()
+
+ for i := 0; iter.Scan(&id, &text); i++ {
+ require.Equal(t, id, ids[i])
+ require.Equal(t, text, texts[i])
+ }
+}
diff --git a/cassandra_test.go b/cassandra_test.go
index 88ef56d..54a54f4 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -1671,7 +1671,7 @@ func TestQueryInfo(t *testing.T) {
defer session.Close()
conn := getRandomConn(t, session)
- info, err := conn.prepareStatement(context.Background(), "SELECT
release_version, host_id FROM system.local WHERE key = ?", nil)
+ info, err := conn.prepareStatement(context.Background(), "SELECT
release_version, host_id FROM system.local WHERE key = ?", nil,
conn.currentKeyspace)
if err != nil {
t.Fatalf("Failed to execute query for preparing statement: %v",
err)
@@ -2761,7 +2761,7 @@ func TestRoutingKey(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}
- routingKeyInfo, err := session.routingKeyInfo(context.Background(),
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+ routingKeyInfo, err := session.routingKeyInfo(context.Background(),
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
@@ -2785,7 +2785,7 @@ func TestRoutingKey(t *testing.T) {
}
// verify the cache is working
- routingKeyInfo, err = session.routingKeyInfo(context.Background(),
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+ routingKeyInfo, err = session.routingKeyInfo(context.Background(),
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
@@ -2819,7 +2819,7 @@ func TestRoutingKey(t *testing.T) {
t.Errorf("Expected routing key %v but was %v",
expectedRoutingKey, routingKey)
}
- routingKeyInfo, err = session.routingKeyInfo(context.Background(),
"SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
+ routingKeyInfo, err = session.routingKeyInfo(context.Background(),
"SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
@@ -3342,6 +3342,10 @@ func TestCreateSession_DontSwallowError(t *testing.T) {
func TestControl_DiscoverProtocol(t *testing.T) {
cluster := createCluster()
cluster.ProtoVersion = 0
+	// Force this test to run without any compression.
+	// If a compressor is configured, CI will fail when snappy
compression is enabled, since
+	// protocol v5 doesn't support it.
+ cluster.Compressor = nil
session, err := cluster.CreateSession()
if err != nil {
@@ -3496,3 +3500,458 @@ func TestQuery_SetHostID(t *testing.T) {
t.Fatalf("Expected error to be: %v, but got %v",
ErrNoConnections, err)
}
}
+
+func TestQuery_WithNowInSeconds(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("Query now in seconds are only available on protocol >=
5")
+ }
+
+ if err := createTable(session, `CREATE TABLE IF NOT EXISTS
query_now_in_seconds (id int primary key, val text)`); err != nil {
+ t.Fatal(err)
+ }
+
+ err := session.Query("INSERT INTO query_now_in_seconds (id, val) VALUES
(?, ?) USING TTL 20", 1, "val").
+ WithNowInSeconds(int(0)).
+ Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var remainingTTL int
+ err = session.Query(`SELECT TTL(val) FROM query_now_in_seconds WHERE id
= ?`, 1).
+ WithNowInSeconds(10).
+ Scan(&remainingTTL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, remainingTTL, 10)
+}
+
+func TestQuery_SetKeyspace(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("keyspace for QUERY message is not supported in protocol
< 5")
+ }
+
+ const keyspaceStmt = `
+ CREATE KEYSPACE IF NOT EXISTS
gocql_query_keyspace_override_test
+ WITH replication = {
+ 'class': 'SimpleStrategy',
+ 'replication_factor': '1'
+ };
+`
+
+ err := session.Query(keyspaceStmt).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_query_keyspace_override_test.query_keyspace(id int, value text, PRIMARY
KEY (id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expectedID := 1
+ expectedText := "text"
+
+ // Testing PREPARE message
+ err = session.Query("INSERT INTO
gocql_query_keyspace_override_test.query_keyspace (id, value) VALUES (?, ?)",
expectedID, expectedText).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ id int
+ text string
+ )
+
+ q := session.Query("SELECT * FROM
gocql_query_keyspace_override_test.query_keyspace").
+ SetKeyspace("gocql_query_keyspace_override_test")
+ err = q.Scan(&id, &text)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, expectedID, id)
+ require.Equal(t, expectedText, text)
+
+ // Testing QUERY message
+ id = 0
+ text = ""
+
+ q = session.Query("SELECT * FROM
gocql_query_keyspace_override_test.query_keyspace").
+ SetKeyspace("gocql_query_keyspace_override_test")
+ q.skipPrepare = true
+ err = q.Scan(&id, &text)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, expectedID, id)
+ require.Equal(t, expectedText, text)
+}
+
+// TestLargeSizeQuery runs a query bigger than the max allowed size of the
payload of a frame,
+// so it should be sent as 2 different frames where each contains a
self-contained bit set to zero.
+func TestLargeSizeQuery(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if err := createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test.large_size_query(id int, text_col text, PRIMARY KEY (id))"); err !=
nil {
+ t.Fatal(err)
+ }
+
+ longString := strings.Repeat("a", 500_000)
+
+ err := session.Query("INSERT INTO gocql_test.large_size_query (id,
text_col) VALUES (?, ?)", "1", longString).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var result string
+ err = session.Query("SELECT text_col FROM
gocql_test.large_size_query").Scan(&result)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, longString, result)
+}
+
+// TestQueryCompressionNotWorthIt runs a query that is not likely to be
compressed efficiently
+// (uncompressed payload size > compressed payload size).
+// So, it should send a Compressed Frame where:
+// 1. Compressed length is set to the length of the uncompressed payload;
+// 2. Uncompressed length is set to zero;
+// 3. Payload is the uncompressed payload.
+func TestQueryCompressionNotWorthIt(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if err := createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test.compression_now_worth_it(id int, text_col text, PRIMARY KEY (id))");
err != nil {
+ t.Fatal(err)
+ }
+
+ str :=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&*()_+"
+ err := session.Query("INSERT INTO gocql_test.large_size_query (id,
text_col) VALUES (?, ?)", "1", str).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var result string
+ err = session.Query("SELECT text_col FROM
gocql_test.large_size_query").Scan(&result)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Equal(t, str, result)
+}
+
+// This test ensures that the whole Metadata_changed flow
+// is handled properly.
+//
+// To trigger C* to return Metadata_changed we should do:
+// 1. Create a table
+// 2. Prepare stmt which uses the created table
+// 3. Change the table schema in order to affect prepared stmt (e.g. add a
column)
+// 4. Execute prepared stmt. As a result C* should return RESULT/ROWS
response with
+// Metadata_changed flag, new metadata id and updated metadata resultset.
+//
+// The driver should handle this by updating its prepared statement inside the
cache
+// when it receives RESULT/ROWS with Metadata_changed flag
+func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("Metadata_changed mechanism is only available in proto >
4")
+ }
+
+ if err := createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test.metadata_changed(id int, PRIMARY KEY (id))"); err != nil {
+ t.Fatal(err)
+ }
+
+ type record struct {
+ id int
+ newCol int
+ }
+
+ firstRecord := record{
+ id: 1,
+ }
+ err := session.Query("INSERT INTO gocql_test.metadata_changed (id)
VALUES (?)", firstRecord.id).Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We have to specify conn for all queries to ensure that
+ // all queries are running on the same node
+ conn := session.getConn()
+
+ const selectStmt = "SELECT * FROM gocql_test.metadata_changed"
+ queryBeforeTableAltering := session.Query(selectStmt)
+ queryBeforeTableAltering.conn = conn
+ row := make(map[string]interface{})
+ err = queryBeforeTableAltering.MapScan(row)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ require.Len(t, row, 1, "Expected to retrieve a single column")
+ require.Equal(t, 1, row["id"])
+
+ stmtCacheKey := session.stmtsLRU.keyFor(conn.host.HostID(),
conn.currentKeyspace, queryBeforeTableAltering.stmt)
+ inflight, _ := session.stmtsLRU.get(stmtCacheKey)
+ preparedStatementBeforeTableAltering := inflight.preparedStatment
+
+ // Changing table schema in order to cause C* to return RESULT/ROWS
Metadata_changed
+ alteringTableQuery := session.Query("ALTER TABLE
gocql_test.metadata_changed ADD new_col int")
+ alteringTableQuery.conn = conn
+ err = alteringTableQuery.Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ secondRecord := record{
+ id: 2,
+ newCol: 10,
+ }
+ err = session.Query("INSERT INTO gocql_test.metadata_changed (id,
new_col) VALUES (?, ?)", secondRecord.id, secondRecord.newCol).
+ Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Handles result from iter and ensures integrity of the result,
+ // closes iter and handles error
+ handleRows := func(iter *Iter) {
+ t.Helper()
+
+ var scannedID int
+		var scannedNewCol *int // pointer type so null values can be scanned
+
+		// when the driver handles null values during unmarshalling,
+		// it sets the dest to the type's zero value, which is (*int)(nil)
for this case
+ var nilIntPtr *int
+
+ // Scanning first row
+ if iter.Scan(&scannedID, &scannedNewCol) {
+ require.Equal(t, firstRecord.id, scannedID)
+ require.Equal(t, nilIntPtr, scannedNewCol)
+ }
+
+ // Scanning second row
+ if iter.Scan(&scannedID, &scannedNewCol) {
+ require.Equal(t, secondRecord.id, scannedID)
+ require.Equal(t, &secondRecord.newCol, scannedNewCol)
+ }
+
+ err := iter.Close()
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ t.Fatal("It is likely failed due deadlock")
+ }
+ t.Fatal(err)
+ }
+ }
+
+ // Expecting C* will return RESULT/ROWS Metadata_changed
+ // and it will be properly handled
+ queryAfterTableAltering := session.Query(selectStmt)
+ queryAfterTableAltering.conn = conn
+ iter := queryAfterTableAltering.Iter()
+ handleRows(iter)
+
+	// Ensuring the cache contains the updated prepared statement
+ inflight, _ = session.stmtsLRU.get(stmtCacheKey)
+ preparedStatementAfterTableAltering := inflight.preparedStatment
+ require.NotEqual(t,
preparedStatementBeforeTableAltering.resultMetadataID,
preparedStatementAfterTableAltering.resultMetadataID)
+ require.NotEqual(t, preparedStatementBeforeTableAltering.response,
preparedStatementAfterTableAltering.response)
+
+ // FORCE SEND OLD RESULT METADATA ID
(https://issues.apache.org/jira/browse/CASSANDRA-20028)
+ closedCh := make(chan struct{})
+ close(closedCh)
+ session.stmtsLRU.add(stmtCacheKey, &inflightPrepare{
+ done: closedCh,
+ err: nil,
+ preparedStatment: preparedStatementBeforeTableAltering,
+ })
+
+	// Running the query with a timeout to ensure there are no deadlocks.
+	// However, it doesn't 100% prove that there is no deadlock...
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+ defer cancel()
+
+ queryAfterTableAltering2 := session.Query(selectStmt).WithContext(ctx)
+ queryAfterTableAltering2.conn = conn
+ iter = queryAfterTableAltering2.Iter()
+ handleRows(iter)
+ err = iter.Close()
+
+ inflight, _ = session.stmtsLRU.get(stmtCacheKey)
+ preparedStatementAfterTableAltering2 := inflight.preparedStatment
+ require.NotEqual(t,
preparedStatementBeforeTableAltering.resultMetadataID,
preparedStatementAfterTableAltering2.resultMetadataID)
+ require.NotEqual(t, preparedStatementBeforeTableAltering.response,
preparedStatementAfterTableAltering2.response)
+
+ require.Equal(t, preparedStatementAfterTableAltering.resultMetadataID,
preparedStatementAfterTableAltering2.resultMetadataID)
+ require.NotEqual(t, preparedStatementAfterTableAltering.response,
preparedStatementAfterTableAltering2.response) // METADATA_CHANGED flag
+ require.True(t,
preparedStatementAfterTableAltering2.response.flags&flagMetaDataChanged != 0)
+
+ // Executing prepared stmt and expecting that C* won't return
+ // Metadata_changed because the table is not being changed.
+ queryAfterTableAltering3 := session.Query(selectStmt).WithContext(ctx)
+ queryAfterTableAltering3.conn = conn
+ iter = queryAfterTableAltering2.Iter()
+ handleRows(iter)
+
+ // Ensuring metadata of prepared stmt is not changed
+ inflight, _ = session.stmtsLRU.get(stmtCacheKey)
+ preparedStatementAfterTableAltering3 := inflight.preparedStatment
+ require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID,
preparedStatementAfterTableAltering3.resultMetadataID)
+ require.Equal(t, preparedStatementAfterTableAltering2.response,
preparedStatementAfterTableAltering3.response)
+}
+
+func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("This tests only runs on proto > 4 due SetKeyspace
availability")
+ }
+
+ const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
+ WITH replication = {
+ 'class' : 'SimpleStrategy',
+ 'replication_factor' : 1
+ }`
+
+ err := createTable(session, fmt.Sprintf(createKeyspaceStmt,
"gocql_test_stmt_cache"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id)
VALUES (?)"
+
+ // Inserting data via Batch to ensure that batches
+ // properly accounts for keyspace overriding
+ b1 := session.NewBatch(LoggedBatch)
+ b1.Query(insertQuery, 1)
+ err = session.ExecuteBatch(b1)
+ require.NoError(t, err)
+
+ b2 := session.NewBatch(LoggedBatch)
+ b2.SetKeyspace("gocql_test_stmt_cache")
+ b2.Query(insertQuery, 2)
+ err = session.ExecuteBatch(b2)
+ require.NoError(t, err)
+
+ var scannedID int
+
+ const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks"
+
+ // By default in our test suite session uses gocql_test ks
+ err = session.Query(selectStmt).Scan(&scannedID)
+ require.NoError(t, err)
+ require.Equal(t, 1, scannedID)
+
+ scannedID = 0
+ err =
session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID)
+ require.NoError(t, err)
+ require.Equal(t, 2, scannedID)
+
+ session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec()
+}
+
+func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
+ session := createSession(t)
+ defer session.Close()
+
+ if session.cfg.ProtoVersion < protoVersion5 {
+ t.Skip("This tests only runs on proto > 4 due SetKeyspace
availability")
+ }
+
+ const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
+ WITH replication = {
+ 'class' : 'SimpleStrategy',
+ 'replication_factor' : 1
+ }`
+
+ err := createTable(session, fmt.Sprintf(createKeyspaceStmt,
"gocql_test_routing_key_cache"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = createTable(session, "CREATE TABLE IF NOT EXISTS
gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int,
PRIMARY KEY (id))")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ getRoutingKeyInfo := func(key string) *routingKeyInfo {
+ t.Helper()
+ session.routingKeyInfoCache.mu.Lock()
+ value, _ := session.routingKeyInfoCache.lru.Get(key)
+ session.routingKeyInfoCache.mu.Unlock()
+
+ inflight := value.(*inflightCachedEntry)
+ return inflight.value.(*routingKeyInfo)
+ }
+
+ const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks
(id) VALUES (?)"
+
+ // Running batch in default ks gocql_test
+ b1 := session.NewBatch(LoggedBatch)
+ b1.Query(insertQuery, 1)
+ _, err = b1.GetRoutingKey()
+ require.NoError(t, err)
+
+ // Ensuring that the cache contains the query with default ks
+ routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt)
+ require.Equal(t, "gocql_test", routingKeyInfo1.keyspace)
+
+ // Running batch in gocql_test_routing_key_cache ks
+ b2 := session.NewBatch(LoggedBatch)
+ b2.SetKeyspace("gocql_test_routing_key_cache")
+ b2.Query(insertQuery, 2)
+ _, err = b2.GetRoutingKey()
+ require.NoError(t, err)
+
+ // Ensuring that the cache contains the query with
gocql_test_routing_key_cache ks
+ routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" +
b2.Entries[0].Stmt)
+ require.Equal(t, "gocql_test_routing_key_cache",
routingKeyInfo2.keyspace)
+
+ const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks
WHERE id=?"
+
+ // Running query in default ks gocql_test
+ q1 := session.Query(selectStmt, 1)
+ _, err = q1.GetRoutingKey()
+ require.NoError(t, err)
+ require.Equal(t, "gocql_test", q1.routingInfo.keyspace)
+
+ // Running query in gocql_test_routing_key_cache ks
+ q2 := session.Query(selectStmt, 1)
+ _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey()
+ require.NoError(t, err)
+ require.Equal(t, "gocql_test_routing_key_cache",
q2.routingInfo.keyspace)
+
+ session.Query("DROP KEYSPACE IF EXISTS
gocql_test_routing_key_cache").Exec()
+}
diff --git a/common_test.go b/common_test.go
index d9c4753..e420bbf 100644
--- a/common_test.go
+++ b/common_test.go
@@ -27,7 +27,6 @@ package gocql
import (
"flag"
"fmt"
- "github.com/gocql/gocql/lz4"
"log"
"net"
"reflect"
@@ -35,6 +34,8 @@ import (
"sync"
"testing"
"time"
+
+ "github.com/gocql/gocql/lz4"
)
var (
@@ -47,7 +48,7 @@ var (
flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond,
"time to wait for autodiscovery to fill the hosts poll")
flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl
test")
flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run
authentication test")
- flagCompressTest = flag.String("compressor", "", "compressor to use")
+ flagCompressTest = flag.String("compressor", "no-compression",
"compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets
the connection `timeout` for all operations")
flagCassVersion cassVersion
@@ -114,7 +115,7 @@ func createCluster(opts ...func(*ClusterConfig))
*ClusterConfig {
cluster.Compressor = &SnappyCompressor{}
case "lz4":
cluster.Compressor = &lz4.LZ4Compressor{}
- case "":
+ case "no-compression":
default:
panic("invalid compressor: " + *flagCompressTest)
}
diff --git a/compressor.go b/compressor.go
index f3d451a..c1b7b2b 100644
--- a/compressor.go
+++ b/compressor.go
@@ -24,14 +24,28 @@
package gocql
-import (
- "github.com/golang/snappy"
-)
+import "github.com/golang/snappy"
type Compressor interface {
Name() string
- Encode(data []byte) ([]byte, error)
- Decode(data []byte) ([]byte, error)
+
+ // AppendCompressedWithLength compresses src bytes, appends the length
of the compressed bytes to dst
+ // and then appends the compressed bytes to dst.
+ // It returns a new byte slice that is the result of the append
operation.
+ AppendCompressedWithLength(dst, src []byte) ([]byte, error)
+
+	// AppendDecompressedWithLength reads the length of the decompressed
bytes from src,
+	// decompresses bytes from src and appends the decompressed bytes to
dst.
+ // It returns a new byte slice that is the result of the append
operation.
+ AppendDecompressedWithLength(dst, src []byte) ([]byte, error)
+
+ // AppendCompressed compresses src bytes and appends the compressed
bytes to dst.
+ // It returns a new byte slice that is the result of the append
operation.
+ AppendCompressed(dst, src []byte) ([]byte, error)
+
+ // AppendDecompressed decompresses bytes from src and appends the
decompressed bytes to dst.
+ // It returns a new byte slice that is the result of the append
operation.
+ AppendDecompressed(dst, src []byte, decompressedLength uint32) ([]byte,
error)
}
// SnappyCompressor implements the Compressor interface and can be used to
@@ -43,10 +57,18 @@ func (s SnappyCompressor) Name() string {
return "snappy"
}
-func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
- return snappy.Encode(nil, data), nil
+func (s SnappyCompressor) AppendCompressedWithLength(dst, src []byte) ([]byte,
error) {
+ return snappy.Encode(dst, src), nil
+}
+
+func (s SnappyCompressor) AppendDecompressedWithLength(dst, src []byte)
([]byte, error) {
+ return snappy.Decode(dst, src)
+}
+
+func (s SnappyCompressor) AppendCompressed(dst, src []byte) ([]byte, error) {
+ panic("SnappyCompressor.AppendCompressed is not supported")
}
-func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
- return snappy.Decode(nil, data)
+func (s SnappyCompressor) AppendDecompressed(dst, src []byte,
decompressedLength uint32) ([]byte, error) {
+ panic("SnappyCompressor.AppendDecompressed is not supported")
}
diff --git a/compressor_test.go b/compressor_test.go
index 20cf934..d2d2de0 100644
--- a/compressor_test.go
+++ b/compressor_test.go
@@ -40,13 +40,13 @@ func TestSnappyCompressor(t *testing.T) {
str := "My Test String"
//Test Encoding
expected := snappy.Encode(nil, []byte(str))
- if res, err := c.Encode([]byte(str)); err != nil {
+ if res, err := c.AppendCompressedWithLength(nil, []byte(str)); err !=
nil {
t.Fatalf("failed to encode '%v' with error %v", str, err)
} else if bytes.Compare(expected, res) != 0 {
t.Fatal("failed to match the expected encoded value with the
result encoded value.")
}
- val, err := c.Encode([]byte(str))
+ val, err := c.AppendCompressedWithLength(nil, []byte(str))
if err != nil {
t.Fatalf("failed to encode '%v' with error '%v'", str, err)
}
@@ -54,7 +54,7 @@ func TestSnappyCompressor(t *testing.T) {
//Test Decoding
if expected, err := snappy.Decode(nil, val); err != nil {
t.Fatalf("failed to decode '%v' with error %v", val, err)
- } else if res, err := c.Decode(val); err != nil {
+ } else if res, err := c.AppendDecompressedWithLength(nil, val); err !=
nil {
t.Fatalf("failed to decode '%v' with error %v", val, err)
} else if bytes.Compare(expected, res) != 0 {
t.Fatal("failed to match the expected decoded value with the
result decoded value.")
diff --git a/conn.go b/conn.go
index aac75e4..d2f83d7 100644
--- a/conn.go
+++ b/conn.go
@@ -26,6 +26,7 @@ package gocql
import (
"bufio"
+ "bytes"
"context"
"crypto/tls"
"errors"
@@ -170,11 +171,9 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err
error, closed bool) {
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
- conn net.Conn
- r *bufio.Reader
- w contextWriter
+ r ConnReader
+ w contextWriter
- timeout time.Duration
writeTimeout time.Duration
cfg *ConnConfig
frameObserver FrameHeaderObserver
@@ -252,8 +251,10 @@ func (s *Session) dialWithoutObserver(ctx context.Context,
host *HostInfo, cfg *
ctx, cancel := context.WithCancel(ctx)
c := &Conn{
- conn: dialedHost.Conn,
- r: bufio.NewReader(dialedHost.Conn),
+ r: &connReader{
+ conn: dialedHost.Conn,
+ r: bufio.NewReader(dialedHost.Conn),
+ },
cfg: cfg,
calls: make(map[int]*callReq),
version: uint8(cfg.ProtoVersion),
@@ -303,16 +304,16 @@ func (c *Conn) init(ctx context.Context, dialedHost
*DialedHost) error {
conn: c,
}
- c.timeout = c.cfg.ConnectTimeout
+ c.r.SetTimeout(c.cfg.ConnectTimeout)
if err := startup.setupConn(ctx); err != nil {
return err
}
- c.timeout = c.cfg.Timeout
+ c.r.SetTimeout(c.cfg.Timeout)
// dont coalesce startup frames
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce &&
!dialedHost.DisableCoalesce {
- c.w = newWriteCoalescer(c.conn, c.writeTimeout,
c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
+ c.w = newWriteCoalescer(dialedHost.Conn, c.writeTimeout,
c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}
go c.serve(ctx)
@@ -325,29 +326,6 @@ func (c *Conn) Write(p []byte) (n int, err error) {
return c.w.writeContext(context.Background(), p)
}
-func (c *Conn) Read(p []byte) (n int, err error) {
- const maxAttempts = 5
-
- for i := 0; i < maxAttempts; i++ {
- var nn int
- if c.timeout > 0 {
- c.conn.SetReadDeadline(time.Now().Add(c.timeout))
- }
-
- nn, err = io.ReadFull(c.r, p[n:])
- n += nn
- if err == nil {
- break
- }
-
- if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
- break
- }
- }
-
- return
-}
-
type startupCoordinator struct {
conn *Conn
frameTicker chan struct{}
@@ -355,17 +333,26 @@ type startupCoordinator struct {
func (s *startupCoordinator) setupConn(ctx context.Context) error {
var cancel context.CancelFunc
- if s.conn.timeout > 0 {
- ctx, cancel = context.WithTimeout(ctx, s.conn.timeout)
+ if s.conn.r.GetTimeout() > 0 {
+ ctx, cancel = context.WithTimeout(ctx, s.conn.r.GetTimeout())
} else {
ctx, cancel = context.WithCancel(ctx)
}
defer cancel()
+ // Only for proto v5+.
+ // Indicates if STARTUP has been completed.
+ // github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec
+ // 2.3.1 Initial Handshake
+ // In order to support both v5 and earlier formats, the v5 framing
format is not
+ // applied to message exchanges before an initial handshake is
completed.
+ startupCompleted := &atomic.Bool{}
+ startupCompleted.Store(false)
+
startupErr := make(chan error)
go func() {
for range s.frameTicker {
- err := s.conn.recv(ctx)
+ err := s.conn.recv(ctx, startupCompleted.Load())
if err != nil {
select {
case startupErr <- err:
@@ -379,7 +366,7 @@ func (s *startupCoordinator) setupConn(ctx context.Context)
error {
go func() {
defer close(s.frameTicker)
- err := s.options(ctx)
+ err := s.options(ctx, startupCompleted)
select {
case startupErr <- err:
case <-ctx.Done():
@@ -398,14 +385,14 @@ func (s *startupCoordinator) setupConn(ctx
context.Context) error {
return nil
}
-func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder)
(frame, error) {
+func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder,
startupCompleted *atomic.Bool) (frame, error) {
select {
case s.frameTicker <- struct{}{}:
case <-ctx.Done():
return nil, ctx.Err()
}
- framer, err := s.conn.exec(ctx, frame, nil)
+ framer, err := s.conn.execInternal(ctx, frame, nil,
startupCompleted.Load())
if err != nil {
return nil, err
}
@@ -413,8 +400,8 @@ func (s *startupCoordinator) write(ctx context.Context,
frame frameBuilder) (fra
return framer.parseFrame()
}
-func (s *startupCoordinator) options(ctx context.Context) error {
- frame, err := s.write(ctx, &writeOptionsFrame{})
+func (s *startupCoordinator) options(ctx context.Context, startupCompleted
*atomic.Bool) error {
+ frame, err := s.write(ctx, &writeOptionsFrame{}, startupCompleted)
if err != nil {
return err
}
@@ -424,10 +411,10 @@ func (s *startupCoordinator) options(ctx context.Context)
error {
return NewErrProtocol("Unknown type of response to startup
frame: %T", frame)
}
- return s.startup(ctx, supported.supported)
+ return s.startup(ctx, supported.supported, startupCompleted)
}
-func (s *startupCoordinator) startup(ctx context.Context, supported
map[string][]string) error {
+func (s *startupCoordinator) startup(ctx context.Context, supported
map[string][]string, startupCompleted *atomic.Bool) error {
m := map[string]string{
"CQL_VERSION": s.conn.cfg.CQLVersion,
"DRIVER_NAME": driverName,
@@ -449,7 +436,7 @@ func (s *startupCoordinator) startup(ctx context.Context,
supported map[string][
}
}
- frame, err := s.write(ctx, &writeStartupFrame{opts: m})
+ frame, err := s.write(ctx, &writeStartupFrame{opts: m},
startupCompleted)
if err != nil {
return err
}
@@ -458,15 +445,19 @@ func (s *startupCoordinator) startup(ctx context.Context,
supported map[string][
case error:
return v
case *readyFrame:
+ // Startup is successfully completed, so we could use Native
Protocol 5
+ startupCompleted.Store(true)
return nil
case *authenticateFrame:
- return s.authenticateHandshake(ctx, v)
+ // Startup is successfully completed, so we could use Native
Protocol 5
+ startupCompleted.Store(true)
+ return s.authenticateHandshake(ctx, v, startupCompleted)
default:
return NewErrProtocol("Unknown type of response to startup
frame: %s", v)
}
}
-func (s *startupCoordinator) authenticateHandshake(ctx context.Context,
authFrame *authenticateFrame) error {
+func (s *startupCoordinator) authenticateHandshake(ctx context.Context,
authFrame *authenticateFrame, startupCompleted *atomic.Bool) error {
if s.conn.auth == nil {
return fmt.Errorf("authentication required (using %q)",
authFrame.class)
}
@@ -478,7 +469,7 @@ func (s *startupCoordinator) authenticateHandshake(ctx
context.Context, authFram
req := &writeAuthResponseFrame{data: resp}
for {
- frame, err := s.write(ctx, req)
+ frame, err := s.write(ctx, req, startupCompleted)
if err != nil {
return err
}
@@ -547,7 +538,7 @@ func (c *Conn) closeWithError(err error) {
// if error was nil then unblock the quit channel
c.cancel()
- cerr := c.close()
+ cerr := c.r.Close()
if err != nil {
c.errorHandler.HandleError(c, err, true)
@@ -557,10 +548,6 @@ func (c *Conn) closeWithError(err error) {
}
}
-func (c *Conn) close() error {
- return c.conn.Close()
-}
-
func (c *Conn) Close() {
c.closeWithError(nil)
}
@@ -571,14 +558,14 @@ func (c *Conn) Close() {
func (c *Conn) serve(ctx context.Context) {
var err error
for err == nil {
- err = c.recv(ctx)
+ err = c.recv(ctx, true)
}
c.closeWithError(err)
}
-func (c *Conn) discardFrame(head frameHeader) error {
- _, err := io.CopyN(ioutil.Discard, c, int64(head.length))
+func (c *Conn) discardFrame(r io.Reader, head frameHeader) error {
+ _, err := io.CopyN(ioutil.Discard, r, int64(head.length))
if err != nil {
return err
}
@@ -643,18 +630,28 @@ func (c *Conn) heartBeat(ctx context.Context) {
}
}
-func (c *Conn) recv(ctx context.Context) error {
+func (c *Conn) recv(ctx context.Context, startupCompleted bool) error {
+ // If startup is completed and native proto 5+ is set up then we should
+ // unwrap payload from compressed/uncompressed frame
+ if startupCompleted && c.version > protoVersion4 {
+ return c.recvSegment(ctx)
+ }
+
+ return c.processFrame(ctx, c.r)
+}
+
+func (c *Conn) processFrame(ctx context.Context, r io.Reader) error {
// not safe for concurrent reads
// read a full header, ignore timeouts, as this is being ran in a loop
// TODO: TCP level deadlines? or just query level deadlines?
- if c.timeout > 0 {
- c.conn.SetReadDeadline(time.Time{})
+ if c.r.GetTimeout() > 0 {
+ c.r.SetReadDeadline(time.Time{})
}
headStartTime := time.Now()
// were just reading headers over and over and copy bodies
- head, err := readHeader(c.r, c.headerBuf[:])
+ head, err := readHeader(r, c.headerBuf[:])
headEndTime := time.Now()
if err != nil {
return err
@@ -678,7 +675,7 @@ func (c *Conn) recv(ctx context.Context) error {
} else if head.stream == -1 {
// TODO: handle cassandra event frames, we shouldnt get any
currently
framer := newFramer(c.compressor, c.version)
- if err := framer.readFrame(c, &head); err != nil {
+ if err := framer.readFrame(r, &head); err != nil {
return err
}
go c.session.handleEvent(framer)
@@ -687,7 +684,7 @@ func (c *Conn) recv(ctx context.Context) error {
// reserved stream that we dont use, probably due to a protocol
error
// or a bug in Cassandra, this should be an error, parse it and
return.
framer := newFramer(c.compressor, c.version)
- if err := framer.readFrame(c, &head); err != nil {
+ if err := framer.readFrame(r, &head); err != nil {
return err
}
@@ -711,14 +708,14 @@ func (c *Conn) recv(ctx context.Context) error {
c.mu.Unlock()
if call == nil || !ok {
c.logger.Printf("gocql: received response for stream which has
no handler: header=%v\n", head)
- return c.discardFrame(head)
+ return c.discardFrame(r, head)
} else if head.stream != call.streamID {
panic(fmt.Sprintf("call has incorrect streamID: got %d expected
%d", call.streamID, head.stream))
}
framer := newFramer(c.compressor, c.version)
- err = framer.readFrame(c, &head)
+ err = framer.readFrame(r, &head)
if err != nil {
// only net errors should cause the connection to be closed.
Though
// cassandra returning corrupt frames will be returned here as
well.
@@ -761,6 +758,172 @@ func (c *Conn) handleTimeout() {
}
}
+func (c *Conn) recvSegment(ctx context.Context) error {
+ var (
+ frame []byte
+ isSelfContained bool
+ err error
+ )
+
+ // Read frame based on compression
+ if c.compressor != nil {
+ frame, isSelfContained, err = readCompressedSegment(c.r,
c.compressor)
+ } else {
+ frame, isSelfContained, err = readUncompressedSegment(c.r)
+ }
+ if err != nil {
+ return err
+ }
+
+ if isSelfContained {
+ return c.processAllFramesInSegment(ctx, bytes.NewReader(frame))
+ }
+
+ head, err := readHeader(bytes.NewReader(frame), c.headerBuf[:])
+ if err != nil {
+ return err
+ }
+
+ const frameHeaderLength = 9
+ buf := bytes.NewBuffer(make([]byte, 0, head.length+frameHeaderLength))
+ buf.Write(frame)
+
+	// Compute how many bytes of the message are left to read
+ bytesToRead := head.length - len(frame) + frameHeaderLength
+
+ err = c.recvPartialFrames(buf, bytesToRead)
+ if err != nil {
+ return err
+ }
+
+ return c.processFrame(ctx, buf)
+}
+
+// recvPartialFrames reads proto v5 segments from Conn.r and writes decoded
partial frames to dst.
+// It reads data until the bytesToRead is reached.
+// If Conn.compressor is not nil, it processes Compressed Format segments.
+func (c *Conn) recvPartialFrames(dst *bytes.Buffer, bytesToRead int) error {
+ var (
+ read int
+ frame []byte
+ isSelfContained bool
+ err error
+ )
+
+ for read != bytesToRead {
+ // Read frame based on compression
+ if c.compressor != nil {
+ frame, isSelfContained, err =
readCompressedSegment(c.r, c.compressor)
+ } else {
+ frame, isSelfContained, err =
readUncompressedSegment(c.r)
+ }
+ if err != nil {
+ return fmt.Errorf("gocql: failed to read non
self-contained frame: %w", err)
+ }
+
+ if isSelfContained {
+ return fmt.Errorf("gocql: received self-contained
segment, but expected not")
+ }
+
+ if totalLength := dst.Len() + len(frame); totalLength >
dst.Cap() {
+ return fmt.Errorf("gocql: expected partial frame of
length %d, got %d", dst.Cap(), totalLength)
+ }
+
+ // Write the frame to the destination writer
+ n, _ := dst.Write(frame)
+ read += n
+ }
+
+ return nil
+}
+
+func (c *Conn) processAllFramesInSegment(ctx context.Context, r *bytes.Reader)
error {
+ var err error
+ for r.Len() > 0 && err == nil {
+ err = c.processFrame(ctx, r)
+ }
+
+ return err
+}
+
+// ConnReader is like net.Conn but also allows to set timeout duration.
+type ConnReader interface {
+ net.Conn
+
+	// SetTimeout sets timeout duration for reading data from conn
+ SetTimeout(timeout time.Duration)
+
+ // GetTimeout returns timeout duration
+ GetTimeout() time.Duration
+}
+
+// connReader implements ConnReader.
+// It retries reads up to 5 times or returns an error.
+type connReader struct {
+ conn net.Conn
+ r *bufio.Reader
+ timeout time.Duration
+}
+
+func (c *connReader) Read(p []byte) (n int, err error) {
+ const maxAttempts = 5
+
+ for i := 0; i < maxAttempts; i++ {
+ var nn int
+ if c.timeout > 0 {
+ c.conn.SetReadDeadline(time.Now().Add(c.timeout))
+ }
+
+ nn, err = io.ReadFull(c.r, p[n:])
+ n += nn
+ if err == nil {
+ break
+ }
+
+ if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
+ break
+ }
+ }
+
+ return
+}
+
+func (c *connReader) Write(b []byte) (n int, err error) {
+ return c.conn.Write(b)
+}
+
+func (c *connReader) Close() error {
+ return c.conn.Close()
+}
+
+func (c *connReader) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+func (c *connReader) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+func (c *connReader) SetDeadline(t time.Time) error {
+ return c.conn.SetDeadline(t)
+}
+
+func (c *connReader) SetReadDeadline(t time.Time) error {
+ return c.conn.SetReadDeadline(t)
+}
+
+func (c *connReader) SetWriteDeadline(t time.Time) error {
+ return c.conn.SetWriteDeadline(t)
+}
+
+func (c *connReader) SetTimeout(timeout time.Duration) {
+ c.timeout = timeout
+}
+
+func (c *connReader) GetTimeout() time.Duration {
+ return c.timeout
+}
+
type callReq struct {
// resp will receive the frame that was sent as a response to this
stream.
resp chan callResp
@@ -1011,6 +1174,10 @@ func (c *Conn) addCall(call *callReq) error {
}
func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer)
(*framer, error) {
+ return c.execInternal(ctx, req, tracer, true)
+}
+
+func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer
Tracer, startupCompleted bool) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
@@ -1070,7 +1237,14 @@ func (c *Conn) exec(ctx context.Context, req
frameBuilder, tracer Tracer) (*fram
return nil, err
}
- n, err := c.w.writeContext(ctx, framer.buf)
+ var n int
+
+ if c.version > protoVersion4 && startupCompleted {
+ err = framer.prepareModernLayout()
+ }
+ if err == nil {
+ n, err = c.w.writeContext(ctx, framer.buf)
+ }
if err != nil {
// closeWithError will block waiting for this stream to either
receive a response
// or for us to timeout, close the timeout chan here. Im not
entirely sure
@@ -1099,7 +1273,7 @@ func (c *Conn) exec(ctx context.Context, req
frameBuilder, tracer Tracer) (*fram
}
var timeoutCh <-chan time.Time
- if c.timeout > 0 {
+ if timeout := c.r.GetTimeout(); timeout > 0 {
if call.timer == nil {
call.timer = time.NewTimer(0)
<-call.timer.C
@@ -1112,7 +1286,7 @@ func (c *Conn) exec(ctx context.Context, req
frameBuilder, tracer Tracer) (*fram
}
}
- call.timer.Reset(c.timeout)
+ call.timer.Reset(timeout)
timeoutCh = call.timer.C
}
@@ -1207,9 +1381,10 @@ type StreamObserverContext interface {
}
type preparedStatment struct {
- id []byte
- request preparedMetadata
- response resultMetadata
+ id []byte
+ resultMetadataID []byte
+ request preparedMetadata
+ response resultMetadata
}
type inflightPrepare struct {
@@ -1219,8 +1394,8 @@ type inflightPrepare struct {
preparedStatment *preparedStatment
}
-func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer
Tracer) (*preparedStatment, error) {
- stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(),
c.currentKeyspace, stmt)
+func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer
Tracer, keyspace string) (*preparedStatment, error) {
+ stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace,
stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru
*lru.Cache) *inflightPrepare {
flight := &inflightPrepare{
done: make(chan struct{}),
@@ -1237,7 +1412,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt
string, tracer Tracer)
statement: stmt,
}
if c.version > protoVersion4 {
- prep.keyspace = c.currentKeyspace
+ prep.keyspace = keyspace
}
// we won the race to do the load, if our context is
canceled we shouldnt
@@ -1268,7 +1443,8 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt
string, tracer Tracer)
flight.preparedStatment = &preparedStatment{
// defensively copy as we will recycle
the underlying buffer after we
// return.
- id: copyBytes(x.preparedID),
+ id:
copyBytes(x.preparedID),
+ resultMetadataID:
copyBytes(x.resultMetadataID),
// the type info's should _not_ have a
reference to the framers read buffer,
// therefore we can just copy them
directly.
request: x.reqMeta,
@@ -1331,7 +1507,15 @@ func (c *Conn) executeQuery(ctx context.Context, qry
*Query) *Iter {
params.pageSize = qry.pageSize
}
if c.version > protoVersion4 {
- params.keyspace = c.currentKeyspace
+ params.keyspace = qry.keyspace
+ params.nowInSeconds = qry.nowInSecondsValue
+ }
+
+	// If a keyspace for the qry is overridden,
+ // then we should use it to create stmt cache key
+ usedKeyspace := c.currentKeyspace
+ if qry.keyspace != "" {
+ usedKeyspace = qry.keyspace
}
var (
@@ -1342,7 +1526,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry
*Query) *Iter {
if !qry.skipPrepare && qry.shouldPrepare() {
// Prepare all DML queries. Other queries can not be prepared.
var err error
- info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
+ info, err = c.prepareStatement(ctx, qry.stmt, qry.trace,
usedKeyspace)
if err != nil {
return &Iter{err: err}
}
@@ -1379,14 +1563,18 @@ func (c *Conn) executeQuery(ctx context.Context, qry
*Query) *Iter {
params.skipMeta = !(c.session.cfg.DisableSkipMetadata ||
qry.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData
== 0
frame = &writeExecuteFrame{
- preparedID: info.id,
- params: params,
- customPayload: qry.customPayload,
+ preparedID: info.id,
+ params: params,
+ customPayload: qry.customPayload,
+ resultMetadataID: info.resultMetadataID,
}
// Set "keyspace" and "table" property in the query if it is
present in preparedMetadata
qry.routingInfo.mu.Lock()
qry.routingInfo.keyspace = info.request.keyspace
+ if info.request.keyspace == "" {
+ qry.routingInfo.keyspace = usedKeyspace
+ }
qry.routingInfo.table = info.request.table
qry.routingInfo.mu.Unlock()
} else {
@@ -1415,13 +1603,39 @@ func (c *Conn) executeQuery(ctx context.Context, qry
*Query) *Iter {
case *resultVoidFrame:
return &Iter{framer: framer}
case *resultRowsFrame:
+ if x.meta.newMetadataID != nil {
+ // If a RESULT/Rows message reports
+ // changed resultset metadata with the
Metadata_changed flag, the reported new
+ // resultset metadata must be used in subsequent
executions
+ stmtCacheKey :=
c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt)
+ oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey)
+ if ok {
+ newInflight := &inflightPrepare{
+ done: make(chan struct{}),
+ preparedStatment: &preparedStatment{
+ id:
oldInflight.preparedStatment.id,
+ resultMetadataID:
x.meta.newMetadataID,
+ request:
oldInflight.preparedStatment.request,
+ response: x.meta,
+ },
+ }
+ // The driver should close this done to avoid
deadlocks of
+ // other subsequent requests
+ close(newInflight.done)
+ c.session.stmtsLRU.add(stmtCacheKey,
newInflight)
+ // Updating info to ensure the code is looking
at the updated
+ // version of the prepared statement
+ info = newInflight.preparedStatment
+ }
+ }
+
iter := &Iter{
meta: x.meta,
framer: framer,
numRows: x.numRows,
}
- if params.skipMeta {
+ if x.meta.noMetaData() {
if info != nil {
iter.meta = info.response
iter.meta.pagingState =
copyBytes(x.meta.pagingState)
@@ -1462,7 +1676,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry
*Query) *Iter {
// is not consistent with regards to its schema.
return iter
case *RequestErrUnprepared:
- stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(),
c.currentKeyspace, qry.stmt)
+ stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(),
usedKeyspace, qry.stmt)
c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId)
return c.executeQuery(ctx, qry)
case error:
@@ -1539,6 +1753,16 @@ func (c *Conn) executeBatch(ctx context.Context, batch
*Batch) *Iter {
customPayload: batch.CustomPayload,
}
+ if c.version > protoVersion4 {
+ req.keyspace = batch.keyspace
+ req.nowInSeconds = batch.nowInSeconds
+ }
+
+ usedKeyspace := c.currentKeyspace
+ if batch.keyspace != "" {
+ usedKeyspace = batch.keyspace
+ }
+
stmts := make(map[string]string, len(batch.Entries))
for i := 0; i < n; i++ {
@@ -1546,7 +1770,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch
*Batch) *Iter {
b := &req.statements[i]
if len(entry.Args) > 0 || entry.binding != nil {
- info, err := c.prepareStatement(batch.Context(),
entry.Stmt, batch.trace)
+ info, err := c.prepareStatement(batch.Context(),
entry.Stmt, batch.trace, usedKeyspace)
if err != nil {
return &Iter{err: err}
}
@@ -1608,7 +1832,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch
*Batch) *Iter {
case *RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
- key := c.session.stmtsLRU.keyFor(c.host.HostID(),
c.currentKeyspace, stmt)
+ key := c.session.stmtsLRU.keyFor(c.host.HostID(),
usedKeyspace, stmt)
c.session.stmtsLRU.evictPreparedID(key, x.StatementId)
}
return c.executeBatch(ctx, batch)
diff --git a/conn_test.go b/conn_test.go
index 76d67cb..e9b33fe 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -47,6 +47,8 @@ import (
"testing"
"time"
+ "github.com/stretchr/testify/require"
+
"github.com/gocql/gocql/internal/streams"
)
@@ -711,12 +713,14 @@ func TestStream0(t *testing.T) {
}
conn := &Conn{
- r: bufio.NewReader(&buf),
+ r: &connReader{
+ r: bufio.NewReader(&buf),
+ },
streams: streams.New(protoVersion4),
logger: &defaultLogger{},
}
- err := conn.recv(context.Background())
+ err := conn.recv(context.Background(), false)
if err == nil {
t.Fatal("expected to get an error on stream 0")
} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -1338,12 +1342,12 @@ func (srv *TestServer) process(conn net.Conn, reqFrame
*framer) {
id := binary.BigEndian.Uint64(b)
// <query_parameters>
reqFrame.readConsistency() // <consistency>
- var flags byte
+ var flags uint32
if srv.protocol > protoVersion4 {
ui := reqFrame.readInt()
- flags = byte(ui)
+ flags = uint32(ui)
} else {
- flags = reqFrame.readByte()
+ flags = uint32(reqFrame.readByte())
}
switch id {
case 1:
@@ -1419,3 +1423,97 @@ func (srv *TestServer) readFrame(conn net.Conn)
(*framer, error) {
return framer, nil
}
+
+func TestConnProcessAllFramesInSingleSegment(t *testing.T) {
+ server, client, err := tcpConnPair()
+ require.NoError(t, err)
+
+ c := &Conn{
+ r: &connReader{
+ conn: server,
+ r: bufio.NewReader(server),
+ },
+ calls: make(map[int]*callReq),
+ version: protoVersion5,
+ addr: server.RemoteAddr().String(),
+ streams: streams.New(protoVersion5),
+ isSchemaV2: true,
+ w: &deadlineContextWriter{
+ w: server,
+ timeout: time.Second * 10,
+ semaphore: make(chan struct{}, 1),
+ quit: make(chan struct{}),
+ },
+ writeTimeout: time.Second * 10,
+ }
+
+ call1 := &callReq{
+ timeout: make(chan struct{}),
+ streamID: 1,
+ resp: make(chan callResp),
+ }
+
+ call2 := &callReq{
+ timeout: make(chan struct{}),
+ streamID: 2,
+ resp: make(chan callResp),
+ }
+
+ c.calls[1] = call1
+ c.calls[2] = call2
+
+ req := writeQueryFrame{
+ statement: "SELECT * FROM system.local",
+ params: queryParams{
+ consistency: Quorum,
+ keyspace: "gocql_test",
+ },
+ }
+
+ framer1 := newFramer(nil, protoVersion5)
+ err = req.buildFrame(framer1, 1)
+ require.NoError(t, err)
+
+ framer2 := newFramer(nil, protoVersion5)
+ err = req.buildFrame(framer2, 2)
+ require.NoError(t, err)
+
+ go func() {
+ var buf []byte
+ buf = append(buf, framer1.buf...)
+ buf = append(buf, framer2.buf...)
+
+ uncompressedSegment, err := newUncompressedSegment(buf, true)
+ require.NoError(t, err)
+
+ _, err = client.Write(uncompressedSegment)
+ require.NoError(t, err)
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
+ defer cancel()
+
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- c.recvSegment(ctx)
+ }()
+
+ go func() {
+ resp1 := <-call1.resp
+ close(call1.timeout)
+ // Skipping here the header of the frame because resp.framer
contains already parsed header
+ // and resp.framer.buf contains frame body
+ require.Equal(t, framer1.buf[9:], resp1.framer.buf)
+
+ resp2 := <-call2.resp
+ close(call2.timeout)
+ require.Equal(t, framer2.buf[9:], resp2.framer.buf)
+ }()
+
+ select {
+ case <-ctx.Done():
+ t.Fatal("Timed out waiting for frames")
+ case err := <-errCh:
+ require.NoError(t, err)
+ }
+}
diff --git a/control.go b/control.go
index 0e2a859..95ba1c0 100644
--- a/control.go
+++ b/control.go
@@ -225,7 +225,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo)
(int, error) {
hosts = shuffleHosts(hosts)
connCfg := *c.session.connCfg
- connCfg.ProtoVersion = 4 // TODO: define maxProtocol
+ connCfg.ProtoVersion = 5 // TODO: define maxProtocol
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected
to a
@@ -303,7 +303,7 @@ type connHost struct {
func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystemLocal(context.TODO())
- host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress,
conn.conn.RemoteAddr().(*net.TCPAddr).Port)
+ host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress,
conn.r.RemoteAddr().(*net.TCPAddr).Port)
if err != nil {
return err
}
diff --git a/crc.go b/crc.go
new file mode 100644
index 0000000..64474ad
--- /dev/null
+++ b/crc.go
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package gocql
+
+import (
+ "hash/crc32"
+)
+
+var (
+ // Initial CRC32 bytes: 0xFA, 0x2D, 0x55, 0xCA
+ initialCRC32Bytes = []byte{0xfa, 0x2d, 0x55, 0xca}
+)
+
+// Crc32 calculates the CRC32 checksum of the given byte slice.
+func Crc32(b []byte) uint32 {
+ crc := crc32.NewIEEE()
+ crc.Write(initialCRC32Bytes) // Include initial CRC32 bytes
+ crc.Write(b)
+ return crc.Sum32()
+}
+
+const (
+ crc24Init = 0x875060 // Initial value for CRC24 calculation
+ crc24Poly = 0x1974F0B // Polynomial for CRC24 calculation
+)
+
+// Crc24 calculates the CRC24 checksum using the Koopman polynomial.
+func Crc24(buf []byte) uint32 {
+ crc := crc24Init
+ for _, b := range buf {
+ crc ^= int(b) << 16
+
+ for i := 0; i < 8; i++ {
+ crc <<= 1
+ if crc&0x1000000 != 0 {
+ crc ^= crc24Poly
+ }
+ }
+ }
+
+ return uint32(crc)
+}
diff --git a/crc_test.go b/crc_test.go
new file mode 100644
index 0000000..cf5e40a
--- /dev/null
+++ b/crc_test.go
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package gocql
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestChecksumIEEE(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected uint32
+ }{
+	// expected values are manually generated using crc32 impl in
Cassandra
+ {
+ name: "empty buf",
+ buf: []byte{},
+ expected: 1148681939,
+ },
+ {
+ name: "buf filled with 0",
+ buf: []byte{0, 0, 0, 0, 0},
+ expected: 1178391023,
+ },
+ {
+ name: "buf filled with some data",
+ buf: []byte{1, 2, 3, 4, 5, 6},
+ expected: 3536190002,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.expected, Crc32(tt.buf))
+ })
+ }
+}
+
+func TestKoopmanChecksum(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected uint32
+ }{
+	// expected values are manually generated using crc24 impl in
Cassandra
+ {
+ name: "buf filled with 0 (len 3)",
+ buf: []byte{0, 0, 0},
+ expected: 8251255,
+ },
+ {
+ name: "buf filled with 0 (len 5)",
+ buf: []byte{0, 0, 0, 0, 0},
+ expected: 11185162,
+ },
+ {
+ name: "buf filled with some data (len 3)",
+ buf: []byte{64, -30 & 0xff, 1},
+ expected: 5891942,
+ },
+ {
+ name: "buf filled with some data (len 5)",
+ buf: []byte{64, -30 & 0xff, 1, 0, 0},
+ expected: 8775784,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.expected, Crc24(tt.buf))
+ })
+ }
+}
diff --git a/frame.go b/frame.go
index fa6edd0..99b07e2 100644
--- a/frame.go
+++ b/frame.go
@@ -25,7 +25,9 @@
package gocql
import (
+ "bytes"
"context"
+ "encoding/binary"
"errors"
"fmt"
"io"
@@ -70,6 +72,8 @@ const (
protoVersion5 = 0x05
maxFrameSize = 256 * 1024 * 1024
+
+ maxSegmentPayloadSize = 0x1FFFF
)
type protoVersion byte
@@ -168,16 +172,18 @@ const (
flagGlobalTableSpec int = 0x01
flagHasMorePages int = 0x02
flagNoMetaData int = 0x04
+ flagMetaDataChanged int = 0x08
// query flags
- flagValues byte = 0x01
- flagSkipMetaData byte = 0x02
- flagPageSize byte = 0x04
- flagWithPagingState byte = 0x08
- flagWithSerialConsistency byte = 0x10
- flagDefaultTimestamp byte = 0x20
- flagWithNameValues byte = 0x40
- flagWithKeyspace byte = 0x80
+ flagValues uint32 = 0x01
+ flagSkipMetaData uint32 = 0x02
+ flagPageSize uint32 = 0x04
+ flagWithPagingState uint32 = 0x08
+ flagWithSerialConsistency uint32 = 0x10
+ flagDefaultTimestamp uint32 = 0x20
+ flagWithNameValues uint32 = 0x40
+ flagWithKeyspace uint32 = 0x80
+ flagWithNowInSeconds uint32 = 0x100
// prepare flags
flagWithPreparedKeyspace uint32 = 0x01
@@ -495,12 +501,12 @@ func (f *framer) readFrame(r io.Reader, head
*frameHeader) error {
return fmt.Errorf("unable to read frame body: read %d/%d bytes:
%v", n, head.length, err)
}
- if head.flags&flagCompress == flagCompress {
+ if f.proto < protoVersion5 && head.flags&flagCompress == flagCompress {
if f.compres == nil {
return NewErrProtocol("no compressor available with
compressed frame body")
}
- f.buf, err = f.compres.Decode(f.buf)
+ f.buf, err = f.compres.AppendDecompressedWithLength(nil, f.buf)
if err != nil {
return err
}
@@ -739,13 +745,13 @@ func (f *framer) finish() error {
return ErrFrameTooBig
}
- if f.buf[1]&flagCompress == flagCompress {
+ if f.proto < protoVersion5 && f.buf[1]&flagCompress == flagCompress {
if f.compres == nil {
panic("compress flag set with no compressor")
}
// TODO: only compress frames which are big enough
- compressed, err := f.compres.Encode(f.buf[f.headSize:])
+ compressed, err := f.compres.AppendCompressedWithLength(nil,
f.buf[f.headSize:])
if err != nil {
return err
}
@@ -988,14 +994,20 @@ type resultMetadata struct {
// it is at minimum len(columns) but may be larger, for instance when a
column
// is a UDT or tuple.
actualColCount int
+
+ newMetadataID []byte
}
func (r *resultMetadata) morePages() bool {
return r.flags&flagHasMorePages == flagHasMorePages
}
+func (r *resultMetadata) noMetaData() bool {
+ return r.flags&flagNoMetaData == flagNoMetaData
+}
+
func (r resultMetadata) String() string {
- return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]",
r.flags, r.pagingState, r.columns)
+ return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v
new_metadata_id=% X]", r.flags, r.pagingState, r.columns, r.newMetadataID)
}
func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec
bool, keyspace, table string) {
@@ -1031,7 +1043,11 @@ func (f *framer) parseResultMetadata() resultMetadata {
meta.pagingState = copyBytes(f.readBytes())
}
- if meta.flags&flagNoMetaData == flagNoMetaData {
+ if f.proto > protoVersion4 && meta.flags&flagMetaDataChanged ==
flagMetaDataChanged {
+ meta.newMetadataID = copyBytes(f.readShortBytes())
+ }
+
+ if meta.noMetaData() {
return meta
}
@@ -1135,18 +1151,24 @@ func (f *framer) parseResultSetKeyspace() frame {
type resultPreparedFrame struct {
frameHeader
- preparedID []byte
- reqMeta preparedMetadata
- respMeta resultMetadata
+ preparedID []byte
+ resultMetadataID []byte
+ reqMeta preparedMetadata
+ respMeta resultMetadata
}
func (f *framer) parseResultPrepared() frame {
frame := &resultPreparedFrame{
frameHeader: *f.header,
preparedID: f.readShortBytes(),
- reqMeta: f.parsePreparedMetadata(),
}
+ if f.proto > protoVersion4 {
+ frame.resultMetadataID = copyBytes(f.readShortBytes())
+ }
+
+ frame.reqMeta = f.parsePreparedMetadata()
+
if f.proto < protoVersion2 {
return frame
}
@@ -1428,12 +1450,13 @@ type queryParams struct {
defaultTimestamp bool
defaultTimestampValue int64
// v5+
- keyspace string
+ keyspace string
+ nowInSeconds *int
}
func (q queryParams) String() string {
- return fmt.Sprintf("[query_params consistency=%v skip_meta=%v
page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v
values=%v keyspace=%s]",
- q.consistency, q.skipMeta, q.pageSize, q.pagingState,
q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace)
+ return fmt.Sprintf("[query_params consistency=%v skip_meta=%v
page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v
values=%v keyspace=%s now_in_seconds=%v]",
+ q.consistency, q.skipMeta, q.pageSize, q.pagingState,
q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace, q.nowInSeconds)
}
func (f *framer) writeQueryParams(opts *queryParams) {
@@ -1443,7 +1466,9 @@ func (f *framer) writeQueryParams(opts *queryParams) {
return
}
- var flags byte
+ var flags uint32
+ names := false
+
if len(opts.values) > 0 {
flags |= flagValues
}
@@ -1460,8 +1485,6 @@ func (f *framer) writeQueryParams(opts *queryParams) {
flags |= flagWithSerialConsistency
}
- names := false
-
// protoV3 specific things
if f.proto > protoVersion2 {
if opts.defaultTimestamp {
@@ -1475,17 +1498,23 @@ func (f *framer) writeQueryParams(opts *queryParams) {
}
if opts.keyspace != "" {
- if f.proto > protoVersion4 {
- flags |= flagWithKeyspace
- } else {
+ if f.proto < protoVersion5 {
panic(fmt.Errorf("the keyspace can only be set with
protocol 5 or higher"))
}
+ flags |= flagWithKeyspace
+ }
+
+ if opts.nowInSeconds != nil {
+ if f.proto < protoVersion5 {
+ panic(fmt.Errorf("now_in_seconds can only be set with
protocol 5 or higher"))
+ }
+ flags |= flagWithNowInSeconds
}
if f.proto > protoVersion4 {
- f.writeUint(uint32(flags))
+ f.writeUint(flags)
} else {
- f.writeByte(flags)
+ f.writeByte(byte(flags))
}
if n := len(opts.values); n > 0 {
@@ -1529,6 +1558,10 @@ func (f *framer) writeQueryParams(opts *queryParams) {
if opts.keyspace != "" {
f.writeString(opts.keyspace)
}
+
+ if opts.nowInSeconds != nil {
+ f.writeInt(int32(*opts.nowInSeconds))
+ }
}
type writeQueryFrame struct {
@@ -1575,6 +1608,9 @@ type writeExecuteFrame struct {
// v4+
customPayload map[string][]byte
+
+ // v5+
+ resultMetadataID []byte
}
func (e *writeExecuteFrame) String() string {
@@ -1582,16 +1618,21 @@ func (e *writeExecuteFrame) String() string {
}
func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error {
- return fr.writeExecuteFrame(streamID, e.preparedID, &e.params,
&e.customPayload)
+ return fr.writeExecuteFrame(streamID, e.preparedID, e.resultMetadataID,
&e.params, &e.customPayload)
}
-func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params
*queryParams, customPayload *map[string][]byte) error {
+func (f *framer) writeExecuteFrame(streamID int, preparedID, resultMetadataID
[]byte, params *queryParams, customPayload *map[string][]byte) error {
if len(*customPayload) > 0 {
f.payload()
}
f.writeHeader(f.flags, opExecute, streamID)
f.writeCustomPayload(customPayload)
f.writeShortBytes(preparedID)
+
+ if f.proto > protoVersion4 {
+ f.writeShortBytes(resultMetadataID)
+ }
+
if f.proto > protoVersion1 {
f.writeQueryParams(params)
} else {
@@ -1630,6 +1671,10 @@ type writeBatchFrame struct {
//v4+
customPayload map[string][]byte
+
+ //v5+
+ keyspace string
+ nowInSeconds *int
}
func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error {
@@ -1647,7 +1692,7 @@ func (f *framer) writeBatchFrame(streamID int, w
*writeBatchFrame, customPayload
n := len(w.statements)
f.writeShort(uint16(n))
- var flags byte
+ var flags uint32
for i := 0; i < n; i++ {
b := &w.statements[i]
@@ -1688,26 +1733,48 @@ func (f *framer) writeBatchFrame(streamID int, w
*writeBatchFrame, customPayload
if w.defaultTimestamp {
flags |= flagDefaultTimestamp
}
+ }
- if f.proto > protoVersion4 {
- f.writeUint(uint32(flags))
- } else {
- f.writeByte(flags)
+ if w.keyspace != "" {
+ if f.proto < protoVersion5 {
+ panic(fmt.Errorf("the keyspace can only be set with
protocol 5 or higher"))
}
+ flags |= flagWithKeyspace
+ }
- if w.serialConsistency > 0 {
- f.writeConsistency(Consistency(w.serialConsistency))
+ if w.nowInSeconds != nil {
+ if f.proto < protoVersion5 {
+ panic(fmt.Errorf("now_in_seconds can only be set with
protocol 5 or higher"))
}
+ flags |= flagWithNowInSeconds
+ }
- if w.defaultTimestamp {
- var ts int64
- if w.defaultTimestampValue != 0 {
- ts = w.defaultTimestampValue
- } else {
- ts = time.Now().UnixNano() / 1000
- }
- f.writeLong(ts)
+ if f.proto > protoVersion4 {
+ f.writeUint(flags)
+ } else {
+ f.writeByte(byte(flags))
+ }
+
+ if w.serialConsistency > 0 {
+ f.writeConsistency(Consistency(w.serialConsistency))
+ }
+
+ if w.defaultTimestamp {
+ var ts int64
+ if w.defaultTimestampValue != 0 {
+ ts = w.defaultTimestampValue
+ } else {
+ ts = time.Now().UnixNano() / 1000
}
+ f.writeLong(ts)
+ }
+
+ if w.keyspace != "" {
+ f.writeString(w.keyspace)
+ }
+
+ if w.nowInSeconds != nil {
+ f.writeInt(int32(*w.nowInSeconds))
}
return f.finish()
@@ -2041,3 +2108,262 @@ func (f *framer) writeBytesMap(m map[string][]byte) {
f.writeBytes(v)
}
}
+
+func (f *framer) prepareModernLayout() error {
+ // Ensure protocol version is V5 or higher
+ if f.proto < protoVersion5 {
+ panic("Modern layout is not supported with version V4 or less")
+ }
+
+ selfContained := true
+
+ var (
+ adjustedBuf []byte
+ tempBuf []byte
+ err error
+ )
+
+ // Process the buffer in chunks if it exceeds the max payload size
+ for len(f.buf) > maxSegmentPayloadSize {
+ if f.compres != nil {
+ tempBuf, err =
newCompressedSegment(f.buf[:maxSegmentPayloadSize], false, f.compres)
+ } else {
+ tempBuf, err =
newUncompressedSegment(f.buf[:maxSegmentPayloadSize], false)
+ }
+ if err != nil {
+ return err
+ }
+
+ adjustedBuf = append(adjustedBuf, tempBuf...)
+ f.buf = f.buf[maxSegmentPayloadSize:]
+ selfContained = false
+ }
+
+ // Process the remaining buffer
+ if f.compres != nil {
+ tempBuf, err = newCompressedSegment(f.buf, selfContained,
f.compres)
+ } else {
+ tempBuf, err = newUncompressedSegment(f.buf, selfContained)
+ }
+ if err != nil {
+ return err
+ }
+
+ adjustedBuf = append(adjustedBuf, tempBuf...)
+ f.buf = adjustedBuf
+
+ return nil
+}
+
+const (
+ crc24Size = 3
+ crc32Size = 4
+)
+
+func readUncompressedSegment(r io.Reader) ([]byte, bool, error) {
+ const (
+ headerSize = 3
+ )
+
+ header := [headerSize + crc24Size]byte{}
+
+ // Read the frame header
+ if _, err := io.ReadFull(r, header[:]); err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read
uncompressed frame, err: %w", err)
+ }
+
+ // Compute and verify the header CRC24
+ computedHeaderCRC24 := Crc24(header[:headerSize])
+ readHeaderCRC24 := uint32(header[3]) | uint32(header[4])<<8 |
uint32(header[5])<<16
+ if computedHeaderCRC24 != readHeaderCRC24 {
+ return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame
header, computed: %d, got: %d", computedHeaderCRC24, readHeaderCRC24)
+ }
+
+ // Extract the payload length and self-contained flag
+ headerInt := uint32(header[0]) | uint32(header[1])<<8 |
uint32(header[2])<<16
+ payloadLen := int(headerInt & maxSegmentPayloadSize)
+ isSelfContained := (headerInt & (1 << 17)) != 0
+
+ // Read the payload
+ payload := make([]byte, payloadLen)
+ if _, err := io.ReadFull(r, payload); err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read
uncompressed frame payload, err: %w", err)
+ }
+
+ // Read and verify the payload CRC32
+ if _, err := io.ReadFull(r, header[:crc32Size]); err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read payload
crc32, err: %w", err)
+ }
+
+ computedPayloadCRC32 := Crc32(payload)
+ readPayloadCRC32 := binary.LittleEndian.Uint32(header[:crc32Size])
+ if computedPayloadCRC32 != readPayloadCRC32 {
+ return nil, false, fmt.Errorf("gocql: payload crc32 mismatch,
computed: %d, got: %d", computedPayloadCRC32, readPayloadCRC32)
+ }
+
+ return payload, isSelfContained, nil
+}
+
+func newUncompressedSegment(payload []byte, isSelfContained bool) ([]byte,
error) {
+ const (
+ headerSize = 6
+ selfContainedBit = 1 << 17
+ )
+
+ payloadLen := len(payload)
+ if payloadLen > maxSegmentPayloadSize {
+ return nil, fmt.Errorf("gocql: payload length (%d) exceeds
maximum size of %d", payloadLen, maxSegmentPayloadSize)
+ }
+
+ // Create the segment
+ segmentSize := headerSize + payloadLen + crc32Size
+ segment := make([]byte, segmentSize)
+
+ // First 3 bytes: payload length and self-contained flag
+ headerInt := uint32(payloadLen)
+ if isSelfContained {
+ headerInt |= selfContainedBit // Set the self-contained flag
+ }
+
+ // Encode the first 3 bytes as a single little-endian integer
+ segment[0] = byte(headerInt)
+ segment[1] = byte(headerInt >> 8)
+ segment[2] = byte(headerInt >> 16)
+
+ // Calculate CRC24 for the first 3 bytes of the header
+ crc := Crc24(segment[:3])
+
+ // Encode CRC24 into the next 3 bytes of the header
+ segment[3] = byte(crc)
+ segment[4] = byte(crc >> 8)
+ segment[5] = byte(crc >> 16)
+
+ copy(segment[headerSize:], payload) // Copy the payload to the segment
+
+ // Calculate CRC32 for the payload
+ payloadCRC32 := Crc32(payload)
+ binary.LittleEndian.PutUint32(segment[headerSize+payloadLen:],
payloadCRC32)
+
+ return segment, nil
+}
+
+func newCompressedSegment(uncompressedPayload []byte, isSelfContained bool,
compressor Compressor) ([]byte, error) {
+ const (
+ headerSize = 5
+ selfContainedBit = 1 << 34
+ )
+
+ uncompressedLen := len(uncompressedPayload)
+ if uncompressedLen > maxSegmentPayloadSize {
+ return nil, fmt.Errorf("gocql: payload length (%d) exceeds
maximum size of %d", uncompressedPayload, maxSegmentPayloadSize)
+ }
+
+ compressedPayload, err := compressor.AppendCompressed(nil,
uncompressedPayload)
+ if err != nil {
+ return nil, err
+ }
+
+ compressedLen := len(compressedPayload)
+
+ // Compression is not worth it
+ if uncompressedLen < compressedLen {
+ // native_protocol_v5.spec
+ // 2.2
+ // An uncompressed length of 0 signals that the compressed
payload
+ // should be used as-is and not decompressed.
+ compressedPayload = uncompressedPayload
+ compressedLen = uncompressedLen
+ uncompressedLen = 0
+ }
+
+ // Combine compressed and uncompressed lengths and set the
self-contained flag if needed
+ combined := uint64(compressedLen) | uint64(uncompressedLen)<<17
+ if isSelfContained {
+ combined |= selfContainedBit
+ }
+
+ var headerBuf [headerSize + crc24Size]byte
+
+ // Write the combined value into the header buffer
+ binary.LittleEndian.PutUint64(headerBuf[:], combined)
+
+ // Create a buffer with enough capacity to hold the header, compressed
payload, and checksums
+ buf := bytes.NewBuffer(make([]byte, 0,
headerSize+crc24Size+compressedLen+crc32Size))
+
+ // Write the first 5 bytes of the header (compressed and uncompressed
sizes)
+ buf.Write(headerBuf[:headerSize])
+
+ // Compute and write the CRC24 checksum of the first 5 bytes
+ headerChecksum := Crc24(headerBuf[:headerSize])
+
+ // LittleEndian 3 bytes
+ headerBuf[0] = byte(headerChecksum)
+ headerBuf[1] = byte(headerChecksum >> 8)
+ headerBuf[2] = byte(headerChecksum >> 16)
+ buf.Write(headerBuf[:3])
+
+ buf.Write(compressedPayload)
+
+ // Compute and write the CRC32 checksum of the payload
+ payloadChecksum := Crc32(compressedPayload)
+ binary.LittleEndian.PutUint32(headerBuf[:], payloadChecksum)
+ buf.Write(headerBuf[:4])
+
+ return buf.Bytes(), nil
+}
+
+func readCompressedSegment(r io.Reader, compressor Compressor) ([]byte, bool,
error) {
+ const headerSize = 5
+ var (
+ headerBuf [headerSize + crc24Size]byte
+ err error
+ )
+
+ if _, err = io.ReadFull(r, headerBuf[:]); err != nil {
+ return nil, false, err
+ }
+
+ // Reading checksum from frame header
+ readHeaderChecksum := uint32(headerBuf[5]) | uint32(headerBuf[6])<<8 |
uint32(headerBuf[7])<<16
+ if computedHeaderChecksum := Crc24(headerBuf[:headerSize]);
computedHeaderChecksum != readHeaderChecksum {
+ return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame
header, read: %d, computed: %d", readHeaderChecksum, computedHeaderChecksum)
+ }
+
+ // First 17 bits - payload size after compression
+ compressedLen := uint32(headerBuf[0]) | uint32(headerBuf[1])<<8 |
uint32(headerBuf[2]&0x1)<<16
+
+ // The next 17 bits - payload size before compression
+ uncompressedLen := (uint32(headerBuf[2]) >> 1) |
uint32(headerBuf[3])<<7 | uint32(headerBuf[4]&0b11)<<15
+
+ // Self-contained flag
+ selfContained := (headerBuf[4] & 0b100) != 0
+
+ compressedPayload := make([]byte, compressedLen)
+ if _, err = io.ReadFull(r, compressedPayload); err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read compressed
frame payload, err: %w", err)
+ }
+
+ if _, err = io.ReadFull(r, headerBuf[:crc32Size]); err != nil {
+ return nil, false, fmt.Errorf("gocql: failed to read payload
crc32, err: %w", err)
+ }
+
+ // Ensuring if payload checksum matches
+ readPayloadChecksum := binary.LittleEndian.Uint32(headerBuf[:crc32Size])
+ if computedPayloadChecksum := Crc32(compressedPayload);
readPayloadChecksum != computedPayloadChecksum {
+ return nil, false, fmt.Errorf("gocql: crc32 mismatch in
payload, read: %d, computed: %d", readPayloadChecksum, computedPayloadChecksum)
+ }
+
+ var uncompressedPayload []byte
+ if uncompressedLen > 0 {
+ if uncompressedPayload, err =
compressor.AppendDecompressed(nil, compressedPayload, uncompressedLen); err !=
nil {
+ return nil, false, err
+ }
+ if uint32(len(uncompressedPayload)) != uncompressedLen {
+ return nil, false, fmt.Errorf("gocql: length mismatch
after payload decoding, got %d, expected %d", len(uncompressedPayload),
uncompressedLen)
+ }
+ } else {
+ uncompressedPayload = compressedPayload
+ }
+
+ return uncompressedPayload, selfContained, nil
+}
diff --git a/frame_test.go b/frame_test.go
index 170cba7..8cb9024 100644
--- a/frame_test.go
+++ b/frame_test.go
@@ -26,8 +26,12 @@ package gocql
import (
"bytes"
+ "errors"
"os"
"testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestFuzzBugs(t *testing.T) {
@@ -127,3 +131,313 @@ func TestFrameReadTooLong(t *testing.T) {
t.Fatalf("expected to get header %v got %v", opReady, head.op)
}
}
+
+func Test_framer_writeExecuteFrame(t *testing.T) {
+ framer := newFramer(nil, protoVersion5)
+ nowInSeconds := 123
+ frame := writeExecuteFrame{
+ preparedID: []byte{1, 2, 3},
+ resultMetadataID: []byte{4, 5, 6},
+ customPayload: map[string][]byte{
+ "key1": []byte("value1"),
+ },
+ params: queryParams{
+ nowInSeconds: &nowInSeconds,
+ keyspace: "test_keyspace",
+ },
+ }
+
+ err := framer.writeExecuteFrame(123, frame.preparedID,
frame.resultMetadataID, &frame.params, &frame.customPayload)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // skipping header
+ framer.buf = framer.buf[9:]
+
+ assertDeepEqual(t, "customPayload", frame.customPayload,
framer.readBytesMap())
+ assertDeepEqual(t, "preparedID", frame.preparedID,
framer.readShortBytes())
+ assertDeepEqual(t, "resultMetadataID", frame.resultMetadataID,
framer.readShortBytes())
+ assertDeepEqual(t, "constistency", frame.params.consistency,
Consistency(framer.readShort()))
+
+ flags := framer.readInt()
+ if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) {
+ t.Fatal("expected flagNowInSeconds to be set, but it is not")
+ }
+
+ if flags&int(flagWithKeyspace) != int(flagWithKeyspace) {
+ t.Fatal("expected flagWithKeyspace to be set, but it is not")
+ }
+
+ assertDeepEqual(t, "keyspace", frame.params.keyspace,
framer.readString())
+ assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt())
+}
+
+func Test_framer_writeBatchFrame(t *testing.T) {
+ framer := newFramer(nil, protoVersion5)
+ nowInSeconds := 123
+ frame := writeBatchFrame{
+ customPayload: map[string][]byte{
+ "key1": []byte("value1"),
+ },
+ nowInSeconds: &nowInSeconds,
+ }
+
+ err := framer.writeBatchFrame(123, &frame, frame.customPayload)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // skipping header
+ framer.buf = framer.buf[9:]
+
+ assertDeepEqual(t, "customPayload", frame.customPayload,
framer.readBytesMap())
+ assertDeepEqual(t, "typ", frame.typ, BatchType(framer.readByte()))
+ assertDeepEqual(t, "len(statements)", len(frame.statements),
int(framer.readShort()))
+ assertDeepEqual(t, "consistency", frame.consistency,
Consistency(framer.readShort()))
+
+ flags := framer.readInt()
+ if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) {
+ t.Fatal("expected flagNowInSeconds to be set, but it is not")
+ }
+
+ assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt())
+}
+
// testMockedCompressor is a Compressor stub for exercising the segment
// read/write paths without a real compression algorithm: it passes data
// through unchanged, and can be configured to fail or to return a payload
// of the wrong length.
type testMockedCompressor struct {
	// expectedError, when non-nil, is returned by every method.
	expectedError error

	// invalidateDecodedDataLength makes AppendDecompressed return a slice
	// one byte shorter than the promised decompressed length.
	invalidateDecodedDataLength bool
}

func (c testMockedCompressor) Name() string {
	return "testMockedCompressor"
}

func (c testMockedCompressor) AppendCompressed(_, src []byte) ([]byte, error) {
	if c.expectedError != nil {
		return nil, c.expectedError
	}
	// Identity "compression": hand the input straight back.
	return src, nil
}

func (c testMockedCompressor) AppendDecompressed(_, src []byte, decompressedLength uint32) ([]byte, error) {
	if c.expectedError != nil {
		return nil, c.expectedError
	}
	if c.invalidateDecodedDataLength {
		// Deliberately drop one byte to trigger the caller's length check.
		return src[:decompressedLength-1], nil
	}
	return src, nil
}

func (c testMockedCompressor) AppendCompressedWithLength(dst, src []byte) ([]byte, error) {
	panic("testMockedCompressor.AppendCompressedWithLength is not implemented")
}

func (c testMockedCompressor) AppendDecompressedWithLength(dst, src []byte) ([]byte, error) {
	panic("testMockedCompressor.AppendDecompressedWithLength is not implemented")
}
+
+func Test_readUncompressedFrame(t *testing.T) {
+ tests := []struct {
+ name string
+ modifyFrame func([]byte) []byte
+ expectedErr string
+ }{
+ {
+ name: "header crc24 mismatch",
+ modifyFrame: func(frame []byte) []byte {
+ // simulating some crc invalidation
+ frame[0] = 255
+ return frame
+ },
+ expectedErr: "gocql: crc24 mismatch in frame header",
+ },
+ {
+ name: "body crc32 mismatch",
+ modifyFrame: func(frame []byte) []byte {
+ // simulating body crc32 mismatch
+ frame[len(frame)-1] = 255
+ return frame
+ },
+ expectedErr: "gocql: payload crc32 mismatch",
+ },
+ {
+ name: "invalid frame length",
+ modifyFrame: func(frame []byte) []byte {
+ // simulating body length invalidation
+ frame = frame[:7]
+ return frame
+ },
+ expectedErr: "gocql: failed to read uncompressed frame
payload",
+ },
+ {
+ name: "cannot read body checksum",
+ modifyFrame: func(frame []byte) []byte {
+ // simulating body length invalidation
+ frame = frame[:len(frame)-4]
+ return frame
+ },
+ expectedErr: "gocql: failed to read payload crc32",
+ },
+ {
+ name: "success",
+ modifyFrame: nil,
+ expectedErr: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ framer := newFramer(nil, protoVersion5)
+ req := writeQueryFrame{
+ statement: "SELECT * FROM system.local",
+ params: queryParams{
+ consistency: Quorum,
+ keyspace: "gocql_test",
+ },
+ }
+
+ err := req.buildFrame(framer, 128)
+ require.NoError(t, err)
+
+ frame, err := newUncompressedSegment(framer.buf, true)
+ require.NoError(t, err)
+
+ if tt.modifyFrame != nil {
+ frame = tt.modifyFrame(frame)
+ }
+
+ readFrame, isSelfContained, err :=
readUncompressedSegment(bytes.NewReader(frame))
+
+ if tt.expectedErr != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ assert.True(t, isSelfContained)
+ assert.Equal(t, framer.buf, readFrame)
+ }
+ })
+ }
+}
+
+func Test_readCompressedFrame(t *testing.T) {
+ tests := []struct {
+ name string
+ // modifyFrameFn is useful for simulating frame data
invalidation
+ modifyFrameFn func([]byte) []byte
+ compressor testMockedCompressor
+
+ // expectedErrorMsg is an error message that should be returned
by Error() method.
+ // We need this to understand which of fmt.Errorf() is returned
+ expectedErrorMsg string
+ }{
+ {
+ name: "header crc24 mismatch",
+ modifyFrameFn: func(frame []byte) []byte {
+ // simulating some crc invalidation
+ frame[0] = 255
+ return frame
+ },
+ expectedErrorMsg: "gocql: crc24 mismatch in frame
header",
+ },
+ {
+ name: "body crc32 mismatch",
+ modifyFrameFn: func(frame []byte) []byte {
+ // simulating body crc32 mismatch
+ frame[len(frame)-1] = 255
+ return frame
+ },
+ expectedErrorMsg: "gocql: crc32 mismatch in payload",
+ },
+ {
+ name: "invalid frame length",
+ modifyFrameFn: func(frame []byte) []byte {
+ // simulating body length invalidation
+ return frame[:12]
+ },
+ expectedErrorMsg: "gocql: failed to read compressed
frame payload",
+ },
+ {
+ name: "cannot read body checksum",
+ modifyFrameFn: func(frame []byte) []byte {
+ // simulating body length invalidation
+ return frame[:len(frame)-4]
+ },
+ expectedErrorMsg: "gocql: failed to read payload crc32",
+ },
+ {
+ name: "failed to encode payload",
+ modifyFrameFn: nil,
+ compressor: testMockedCompressor{
+ expectedError: errors.New("failed to encode
payload"),
+ },
+ expectedErrorMsg: "failed to encode payload",
+ },
+ {
+ name: "failed to decode payload",
+ modifyFrameFn: nil,
+ compressor: testMockedCompressor{
+ expectedError: errors.New("failed to decode
payload"),
+ },
+ expectedErrorMsg: "failed to decode payload",
+ },
+ {
+ name: "length mismatch after decoding",
+ modifyFrameFn: nil,
+ compressor: testMockedCompressor{
+ invalidateDecodedDataLength: true,
+ },
+ expectedErrorMsg: "gocql: length mismatch after payload
decoding",
+ },
+ {
+ name: "success",
+ modifyFrameFn: nil,
+ expectedErrorMsg: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ framer := newFramer(nil, protoVersion5)
+ req := writeQueryFrame{
+ statement: "SELECT * FROM system.local",
+ params: queryParams{
+ consistency: Quorum,
+ keyspace: "gocql_test",
+ },
+ }
+
+ err := req.buildFrame(framer, 128)
+ require.NoError(t, err)
+
+ frame, err := newCompressedSegment(framer.buf, true,
testMockedCompressor{})
+ require.NoError(t, err)
+
+ if tt.modifyFrameFn != nil {
+ frame = tt.modifyFrameFn(frame)
+ }
+
+ readFrame, selfContained, err :=
readCompressedSegment(bytes.NewReader(frame), tt.compressor)
+
+ switch {
+ case tt.expectedErrorMsg != "":
+ require.Error(t, err)
+ require.Contains(t, err.Error(),
tt.expectedErrorMsg)
+ case tt.compressor.expectedError != nil:
+ require.ErrorIs(t, err,
tt.compressor.expectedError)
+ default:
+ require.NoError(t, err)
+ assert.True(t, selfContained)
+ assert.Equal(t, framer.buf, readFrame)
+ }
+ })
+ }
+}
diff --git a/lz4/lz4.go b/lz4/lz4.go
index 049fdc0..c836a09 100644
--- a/lz4/lz4.go
+++ b/lz4/lz4.go
@@ -27,7 +27,6 @@ package lz4
import (
"encoding/binary"
"fmt"
-
"github.com/pierrec/lz4/v4"
)
@@ -47,29 +46,71 @@ func (s LZ4Compressor) Name() string {
return "lz4"
}
-func (s LZ4Compressor) Encode(data []byte) ([]byte, error) {
- buf := make([]byte, lz4.CompressBlockBound(len(data)+4))
+const dataLengthSize = 4
+
+func (s LZ4Compressor) AppendCompressedWithLength(dst, src []byte) ([]byte,
error) {
+ maxLength := lz4.CompressBlockBound(len(src))
+ oldDstLen := len(dst)
+ dst = grow(dst, maxLength+dataLengthSize)
+
var compressor lz4.Compressor
- n, err := compressor.CompressBlock(data, buf[4:])
+ n, err := compressor.CompressBlock(src, dst[oldDstLen+dataLengthSize:])
// According to lz4.CompressBlock doc, it doesn't fail as long as the
dst
// buffer length is at least lz4.CompressBlockBound(len(data))) bytes,
but
// we check for error anyway just to be thorough.
if err != nil {
return nil, err
}
- binary.BigEndian.PutUint32(buf, uint32(len(data)))
- return buf[:n+4], nil
+ binary.BigEndian.PutUint32(dst[oldDstLen:oldDstLen+dataLengthSize],
uint32(len(src)))
+ return dst[:oldDstLen+n+dataLengthSize], nil
+}
+
+func (s LZ4Compressor) AppendDecompressedWithLength(dst, src []byte) ([]byte,
error) {
+ if len(src) < dataLengthSize {
+ return nil, fmt.Errorf("cassandra lz4 block size should be >4,
got=%d", len(src))
+ }
+ uncompressedLength := binary.BigEndian.Uint32(src[:dataLengthSize])
+ if uncompressedLength == 0 {
+ return nil, nil
+ }
+ oldDstLen := len(dst)
+ dst = grow(dst, int(uncompressedLength))
+ n, err := lz4.UncompressBlock(src[dataLengthSize:], dst[oldDstLen:])
+ return dst[:oldDstLen+n], err
+
}
-func (s LZ4Compressor) Decode(data []byte) ([]byte, error) {
- if len(data) < 4 {
- return nil, fmt.Errorf("cassandra lz4 block size should be >4,
got=%d", len(data))
+func (s LZ4Compressor) AppendCompressed(dst, src []byte) ([]byte, error) {
+ maxLength := lz4.CompressBlockBound(len(src))
+ oldDstLen := len(dst)
+ dst = grow(dst, maxLength)
+
+ var compressor lz4.Compressor
+ n, err := compressor.CompressBlock(src, dst[oldDstLen:])
+ if err != nil {
+ return nil, err
}
- uncompressedLength := binary.BigEndian.Uint32(data)
+
+ return dst[:oldDstLen+n], nil
+}
+
+func (s LZ4Compressor) AppendDecompressed(dst, src []byte, uncompressedLength
uint32) ([]byte, error) {
if uncompressedLength == 0 {
return nil, nil
}
- buf := make([]byte, uncompressedLength)
- n, err := lz4.UncompressBlock(data[4:], buf)
- return buf[:n], err
+ oldDstLen := len(dst)
+ dst = grow(dst, int(uncompressedLength))
+ n, err := lz4.UncompressBlock(src, dst[oldDstLen:])
+ return dst[:oldDstLen+n], err
+}
+
// grow extends b by n bytes, preserving its contents. The backing array is
// reused when it already has room; otherwise a fresh array of exactly
// len(b)+n bytes is allocated and the old data copied over.
func grow(b []byte, n int) []byte {
	if cap(b)-len(b) < n {
		grown := make([]byte, len(b)+n)
		copy(grown, b)
		return grown
	}
	return b[:len(b)+n]
}
diff --git a/lz4/lz4_test.go b/lz4/lz4_test.go
index e0834b9..379afd4 100644
--- a/lz4/lz4_test.go
+++ b/lz4/lz4_test.go
@@ -25,6 +25,7 @@
package lz4
import (
+ "github.com/pierrec/lz4/v4"
"testing"
"github.com/stretchr/testify/require"
@@ -34,21 +35,215 @@ func TestLZ4Compressor(t *testing.T) {
var c LZ4Compressor
require.Equal(t, "lz4", c.Name())
- _, err := c.Decode([]byte{0, 1, 2})
+ _, err := c.AppendDecompressedWithLength(nil, []byte{0, 1, 2})
require.EqualError(t, err, "cassandra lz4 block size should be >4,
got=3")
- _, err = c.Decode([]byte{0, 1, 2, 4, 5})
+ _, err = c.AppendDecompressedWithLength(nil, []byte{0, 1, 2, 4, 5})
require.EqualError(t, err, "lz4: invalid source or destination buffer
too short")
// If uncompressed size is zero then nothing is decoded even if present.
- decoded, err := c.Decode([]byte{0, 0, 0, 0, 5, 7, 8})
+ decoded, err := c.AppendDecompressedWithLength(nil, []byte{0, 0, 0, 0,
5, 7, 8})
require.NoError(t, err)
require.Nil(t, decoded)
original := []byte("My Test String")
- encoded, err := c.Encode(original)
+ encoded, err := c.AppendCompressedWithLength(nil, original)
require.NoError(t, err)
- decoded, err = c.Decode(encoded)
+ decoded, err = c.AppendDecompressedWithLength(nil, encoded)
require.NoError(t, err)
require.Equal(t, original, decoded)
}
+
+func TestLZ4Compressor_AppendCompressedDecompressed(t *testing.T) {
+ c := LZ4Compressor{}
+
+ invalidUncompressedLength := uint32(10)
+ _, err := c.AppendDecompressed(nil, []byte{0, 1, 2, 4, 5},
invalidUncompressedLength)
+ require.EqualError(t, err, "lz4: invalid source or destination buffer
too short")
+
+ original := []byte("My Test String")
+ encoded, err := c.AppendCompressed(nil, original)
+ require.NoError(t, err)
+ decoded, err := c.AppendDecompressed(nil, encoded,
uint32(len(original)))
+ require.NoError(t, err)
+ require.Equal(t, original, decoded)
+}
+
+func TestLZ4Compressor_AppendWithLengthGrowSliceWithData(t *testing.T) {
+ var tests = []struct {
+ name string
+ src []byte
+ dst []byte
+ shouldReuseDst bool
+ decodeDst []byte
+ shouldReuseDecodeDst bool
+ }{
+ {
+ name: "both dst are empty",
+ src: []byte("small data"),
+ dst: nil,
+ decodeDst: nil,
+ },
+ {
+ name: "dst is nil",
+ src: []byte("another piece of data"),
+ dst: nil,
+ decodeDst: []byte("something"),
+ },
+ {
+ name: "decodeDst is nil",
+ src: []byte("another piece of data"),
+ dst: []byte("some"),
+ decodeDst: nil,
+ },
+ {
+ name: "both dst are not empty",
+ src: []byte("another piece of data"),
+ dst: []byte("dst"),
+ decodeDst: []byte("decodeDst"),
+ },
+ {
+ name: "both dst slices have enough
capacity",
+ src: []byte("small"),
+ dst:
createBufWithCapAndData("cap=128", 128),
+ shouldReuseDst: true,
+ decodeDst:
createBufWithCapAndData("cap=256", 256),
+ shouldReuseDecodeDst: true,
+ },
+ {
+ name: "both dsts have some data and not enough
capacity",
+ src: []byte("small"),
+ dst: createBufWithCapAndData("data", 6),
+ decodeDst: createBufWithCapAndData("wow", 4),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ compressor := LZ4Compressor{}
+
+ // Appending compressed data to dst,
+ // expecting that dst still contains "test"
+ result, err :=
compressor.AppendCompressedWithLength(tt.dst, tt.src)
+ require.NoError(t, err)
+
+ var expectedCap int
+ if tt.shouldReuseDst {
+ expectedCap = cap(tt.dst)
+ } else {
+ expectedCap = len(tt.dst) +
lz4.CompressBlockBound(len(tt.src)) + dataLengthSize
+ }
+
+ require.Equal(t, expectedCap, cap(result))
+ if len(tt.dst) > 0 {
+ require.Equal(t, tt.dst, result[:len(tt.dst)])
+ }
+
+ result, err =
compressor.AppendDecompressedWithLength(tt.decodeDst, result[len(tt.dst):])
+ require.NoError(t, err)
+
+ var expectedDecodeCap int
+ if tt.shouldReuseDecodeDst {
+ expectedDecodeCap = cap(tt.decodeDst)
+ } else {
+ expectedDecodeCap = len(tt.decodeDst) +
len(tt.src)
+ }
+
+ require.Equal(t, expectedDecodeCap, cap(result))
+ require.Equal(t, tt.src, result[len(tt.decodeDst):])
+ })
+ }
+}
+
+func TestLZ4Compressor_AppendGrowSliceWithData(t *testing.T) {
+ var tests = []struct {
+ name string
+ src []byte
+ dst []byte
+ shouldReuseDst bool
+ decodeDst []byte
+ shouldReuseDecodeDst bool
+ }{
+ {
+ name: "both dst are empty",
+ src: []byte("small data"),
+ dst: nil,
+ decodeDst: nil,
+ },
+ {
+ name: "dst is nil",
+ src: []byte("another piece of data"),
+ dst: nil,
+ decodeDst: []byte("something"),
+ },
+ {
+ name: "decodeDst is nil",
+ src: []byte("another piece of data"),
+ dst: []byte("some"),
+ decodeDst: nil,
+ },
+ {
+ name: "both dst are not empty",
+ src: []byte("another piece of data"),
+ dst: []byte("dst"),
+ decodeDst: []byte("decodeDst"),
+ },
+ {
+ name: "both dst slices have enough
capacity",
+ src: []byte("small"),
+ dst:
createBufWithCapAndData("cap=128", 128),
+ shouldReuseDst: true,
+ decodeDst:
createBufWithCapAndData("cap=256", 256),
+ shouldReuseDecodeDst: true,
+ },
+ {
+ name: "both dst slices have some data and not
enough capacity",
+ src: []byte("small"),
+ dst: createBufWithCapAndData("data", 6),
+ decodeDst: createBufWithCapAndData("wow", 4),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ compressor := LZ4Compressor{}
+
+ // Appending compressed data to dst,
+ // expecting that dst still contains "test"
+ result, err := compressor.AppendCompressed(tt.dst,
tt.src)
+ require.NoError(t, err)
+
+ var expectedCap int
+ if tt.shouldReuseDst {
+ expectedCap = cap(tt.dst)
+ } else {
+ expectedCap = len(tt.dst) +
lz4.CompressBlockBound(len(tt.src))
+ }
+
+ require.Equal(t, expectedCap, cap(result))
+ if len(tt.dst) > 0 {
+ require.Equal(t, tt.dst, result[:len(tt.dst)])
+ }
+
+ uncompressedLen := uint32(len(tt.src))
+ result, err =
compressor.AppendDecompressed(tt.decodeDst, result[len(tt.dst):],
uncompressedLen)
+ require.NoError(t, err)
+
+ var expectedDecodeCap int
+ if tt.shouldReuseDst {
+ expectedDecodeCap = cap(tt.decodeDst)
+ } else {
+ expectedDecodeCap = len(tt.decodeDst) +
len(tt.src)
+ }
+
+ require.Equal(t, expectedDecodeCap, cap(result))
+ require.Equal(t, tt.src, result[len(tt.decodeDst):])
+ })
+ }
+}
+
// createBufWithCapAndData returns a buffer whose length is len(data) and
// whose capacity is exactly capacity, pre-filled with data. Assumes
// len(data) <= capacity (as all callers in this file guarantee); otherwise
// the final reslice would panic.
// The parameter was renamed from "cap", which shadowed the builtin.
func createBufWithCapAndData(data string, capacity int) []byte {
	buf := make([]byte, capacity)
	copy(buf, data)
	return buf[:len(data)]
}
diff --git a/prepared_cache.go b/prepared_cache.go
index 3fd256d..7f5533a 100644
--- a/prepared_cache.go
+++ b/prepared_cache.go
@@ -100,3 +100,20 @@ func (p *preparedLRU) evictPreparedID(key string, id []byte) {
}
}
+
+func (p *preparedLRU) get(key string) (*inflightPrepare, bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ val, ok := p.lru.Get(key)
+ if !ok {
+ return nil, false
+ }
+
+ ifp, ok := val.(*inflightPrepare)
+ if !ok {
+ return nil, false
+ }
+
+ return ifp, true
+}
diff --git a/session.go b/session.go
index 957a742..ed1a078 100644
--- a/session.go
+++ b/session.go
@@ -596,11 +596,20 @@ func (s *Session) getConn() *Conn {
return nil
}
-// returns routing key indexes and type info
-func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
+// Returns routing key indexes and type info.
+// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace
+func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) {
+ if keyspace == "" {
+ keyspace = s.cfg.Keyspace
+ }
+
+ routingKeyInfoCacheKey := keyspace + stmt
+
s.routingKeyInfoCache.mu.Lock()
- entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
+ // Using here keyspace + stmt as a cache key because
+ // the query keyspace could be overridden via SetKeyspace
+ entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey)
if cached {
// done accessing the cache
s.routingKeyInfoCache.mu.Unlock()
@@ -624,7 +633,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
inflight := new(inflightCachedEntry)
inflight.wg.Add(1)
defer inflight.wg.Done()
- s.routingKeyInfoCache.lru.Add(stmt, inflight)
+ s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight)
s.routingKeyInfoCache.mu.Unlock()
var (
@@ -640,7 +649,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
 	}
// get the query info for the statement
- info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
+ info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace)
if inflight.err != nil {
// don't cache this error
s.routingKeyInfoCache.Remove(stmt)
@@ -656,7 +665,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
 	}
table := info.request.table
- keyspace := info.request.keyspace
+ if info.request.keyspace != "" {
+ keyspace = info.request.keyspace
+ }
if len(info.request.pkeyColumns) > 0 {
// proto v4 dont need to calculate primary key columns
@@ -958,6 +969,9 @@ type Query struct {
// hostID specifies the host on which the query should be executed.
// If it is empty, then the host is picked by HostSelectionPolicy
hostID string
+
+ keyspace string
+ nowInSecondsValue *int
}
type queryRoutingInfo struct {
@@ -1165,6 +1179,9 @@ func (q *Query) Keyspace() string {
if q.routingInfo.keyspace != "" {
return q.routingInfo.keyspace
}
+ if q.keyspace != "" {
+ return q.keyspace
+ }
if q.session == nil {
return ""
@@ -1196,7 +1213,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
}
// try to determine the routing key
- routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt)
+	routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace)
if err != nil {
return nil, err
}
@@ -1468,6 +1485,24 @@ func (q *Query) GetHostID() string {
return q.hostID
}
+// SetKeyspace will enable keyspace flag on the query.
+// It allows to specify the keyspace that the query should be executed in
+//
+// Only available on protocol >= 5.
+func (q *Query) SetKeyspace(keyspace string) *Query {
+ q.keyspace = keyspace
+ return q
+}
+
+// WithNowInSeconds will enable the with now_in_seconds flag on the query.
+// Also, it allows to define now_in_seconds value.
+//
+// Only available on protocol >= 5.
+func (q *Query) WithNowInSeconds(now int) *Query {
+ q.nowInSecondsValue = &now
+ return q
+}
+
// Iter represents an iterator that can be used to iterate over all rows that
// were returned by a query. The iterator might send additional queries to the
// database during the iteration if paging was enabled.
@@ -1787,6 +1822,7 @@ type Batch struct {
cancelBatch func()
keyspace string
metrics *queryMetrics
+ nowInSeconds *int
 	// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo
@@ -2031,7 +2067,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) {
return nil, nil
}
// try to determine the routing key
- routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt)
+	routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace)
if err != nil {
return nil, err
}
@@ -2091,6 +2127,24 @@ func (b *Batch) GetHostID() string {
return ""
}
+// SetKeyspace will enable keyspace flag on the query.
+// It allows to specify the keyspace that the query should be executed in
+//
+// Only available on protocol >= 5.
+func (b *Batch) SetKeyspace(keyspace string) *Batch {
+ b.keyspace = keyspace
+ return b
+}
+
+// WithNowInSeconds will enable the with now_in_seconds flag on the query.
+// Also, it allows to define now_in_seconds value.
+//
+// Only available on protocol >= 5.
+func (b *Batch) WithNowInSeconds(now int) *Batch {
+ b.nowInSeconds = &now
+ return b
+}
+
type BatchType byte
const (
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]