diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 871eb152da..107a022b4c 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -392,7 +392,14 @@ def transaction_checkout(self): this connection yet. Return the started one otherwise. This method is a no-op if the connection is in autocommit mode and no - explicit transaction has been started + explicit transaction has been started. + + The transaction is returned without calling ``begin()``. The + underlying ``Transaction.execute_sql`` and ``execute_update`` + methods detect ``_transaction_id is None`` and use *inline begin* + — piggybacking a ``BeginTransaction`` on the first RPC via + ``TransactionSelector(begin=...)``. This eliminates a separate + ``BeginTransaction`` RPC round-trip per transaction. :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` :returns: A Cloud Spanner transaction object, ready to use. @@ -410,7 +417,6 @@ def transaction_checkout(self): self.transaction_tag = None self._snapshot = None self._spanner_transaction_started = True - self._transaction.begin() return self._transaction diff --git a/tests/mockserver_tests/test_dbapi_inline_begin.py b/tests/mockserver_tests/test_dbapi_inline_begin.py new file mode 100644 index 0000000000..b8d61c7729 --- /dev/null +++ b/tests/mockserver_tests/test_dbapi_inline_begin.py @@ -0,0 +1,295 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that the DBAPI uses inline begin for read-write transactions. + +After removing the explicit ``Transaction.begin()`` call from +``Connection.transaction_checkout()``, the DBAPI should piggyback +``BeginTransaction`` on the first ``ExecuteSql`` / ``ExecuteUpdate`` request +via ``TransactionSelector(begin=...)``, eliminating one gRPC round-trip +per transaction. + +Read-only transactions are unaffected — they still use an explicit +``BeginTransaction`` RPC via ``snapshot_checkout()``. +""" + +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_v1 import ( + BeginTransactionRequest, + CommitRequest, + ExecuteSqlRequest, + RollbackRequest, + TypeCode, +) +from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer +from google.cloud.spanner_v1.database_sessions_manager import TransactionType + +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_single_result, + add_update_count, + add_error, + aborted_status, +) + + +class TestDbapiInlineBegin(MockServerTestBase): + @classmethod + def setup_class(cls): + super().setup_class() + add_single_result( + "select name from singers", "name", TypeCode.STRING, [("Some Singer",)] + ) + add_update_count( + "insert into singers (id, name) values (1, 'Some Singer')", 1 + ) + + def test_read_write_inline_begin(self): + """Comprehensive check for a single-statement read-write transaction. + + Verifies: + - No BeginTransactionRequest is sent + - The ExecuteSqlRequest uses TransactionSelector(begin=ReadWrite(...)) + - The request sequence is [ExecuteSqlRequest, CommitRequest] + - The query returns correct data + """ + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + rows = cursor.fetchall() + connection.commit() + + self.assertEqual( + [("Some Singer",)], rows, + "Query should return the mocked result set", + ) + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Read-write DBAPI transactions should not send " + "a separate BeginTransactionRequest") + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertGreaterEqual(len(sql_requests), 1) + first_sql = sql_requests[0] + self.assertTrue( + first_sql.transaction.begin.read_write == first_sql.transaction.begin.read_write, + ) + self.assertIn( + "read_write", first_sql.transaction.begin, + "First ExecuteSqlRequest should use inline begin with " + "TransactionSelector(begin=ReadWrite(...))", + ) + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + def test_read_write_dml_request_sequence(self): + """DML write via DBAPI: ExecuteSql + Commit (no BeginTransaction).""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + def test_read_then_write_full_lifecycle(self): + """Read + write in same transaction: verifies the complete inline begin lifecycle. + + Checks: + - First ExecuteSqlRequest uses TransactionSelector(begin=ReadWrite(...)) + - Second ExecuteSqlRequest uses TransactionSelector(id=) + - CommitRequest uses the same transaction_id as the second statement + - Query returns correct data + - Request sequence is [ExecuteSql, ExecuteSql, Commit] + """ + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + rows = cursor.fetchall() + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + self.assertEqual( + [("Some Singer",)], rows, + "Query should return the mocked result set", + ) + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(2, len(sql_requests)) + + first = sql_requests[0] + self.assertIn( + "read_write", first.transaction.begin, + "First statement should use inline begin", + ) + + second = sql_requests[1] + self.assertNotEqual( + b"", second.transaction.id, + "Second statement should use TransactionSelector(id=...) " + "with the transaction_id returned from inline begin", + ) + + commit_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, CommitRequest) + ] + self.assertEqual(1, len(commit_requests)) + self.assertEqual( + second.transaction.id, commit_requests[0].transaction_id, + "CommitRequest must reference the same transaction_id " + "that the second ExecuteSqlRequest used", + ) + + def test_read_only_still_uses_explicit_begin(self): + """Read-only transactions should still use explicit BeginTransaction.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + with connection.cursor() as cursor: + cursor.execute("select name from singers") + rows = cursor.fetchall() + connection.commit() + + self.assertEqual( + [("Some Singer",)], rows, + "Read-only query should return the mocked result set", + ) + + self.assert_requests_sequence( + self.spanner_service.requests, + [BeginTransactionRequest, ExecuteSqlRequest], + TransactionType.READ_ONLY, + ) + + def test_rollback_after_inline_begin(self): + """Rollback after DML sends RollbackRequest with the correct transaction_id.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.rollback() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Rollback path should not use BeginTransactionRequest") + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(1, len(sql_requests)) + + rollback_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, RollbackRequest) + ] + self.assertEqual(1, len(rollback_requests), + "A RollbackRequest should be sent after DML + rollback") + + txn_id_from_inline_begin = sql_requests[0].transaction.begin + self.assertIn( + "read_write", txn_id_from_inline_begin, + "DML should have used inline begin", + ) + + self.assertNotEqual( + b"", rollback_requests[0].transaction_id, + "RollbackRequest must carry the transaction_id obtained via inline begin", + ) + + def test_inline_begin_with_abort_retry(self): + """Transaction retry after abort should work with inline begin. + + The DBAPI replays recorded statements on abort. With inline begin, + the retried ExecuteSqlRequest should again use inline begin. + """ + add_error(SpannerServicer.Commit.__name__, aborted_status()) + + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Retried transaction should also use inline begin, " + "not explicit BeginTransactionRequest") + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(2, len(sql_requests), + "Expected 2 ExecuteSqlRequests: original + retry") + for i, req in enumerate(sql_requests): + self.assertIn( + "read_write", req.transaction.begin, + f"ExecuteSqlRequest[{i}] should use inline begin", + ) + + commit_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, CommitRequest) + ] + self.assertEqual(2, len(commit_requests), + "Expected 2 CommitRequests: the aborted original + " + "the successful retry") + for i, cr in enumerate(commit_requests): + self.assertNotEqual( + b"", cr.transaction_id, + f"CommitRequest[{i}] must carry a transaction_id " + "from inline begin", + ) diff --git a/tests/mockserver_tests/test_dbapi_isolation_level.py b/tests/mockserver_tests/test_dbapi_isolation_level.py index e912914b19..a5c37e0eef 100644 --- a/tests/mockserver_tests/test_dbapi_isolation_level.py +++ b/tests/mockserver_tests/test_dbapi_isolation_level.py @@ -15,7 +15,7 @@ from google.api_core.exceptions import Unknown from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1 import ( - BeginTransactionRequest, + ExecuteSqlRequest, TransactionOptions, ) from tests.mockserver_tests.mock_server_test_base import ( @@ -24,6 +24,13 @@ ) +def _get_first_execute_sql_request(requests): + """Return the first ExecuteSqlRequest from the captured requests.""" + return next( + req for req in requests if isinstance(req, ExecuteSqlRequest) + ) + + class TestDbapiIsolationLevel(MockServerTestBase): @classmethod def setup_class(cls): @@ -36,15 +43,9 @@ def test_isolation_level_default(self): cursor.execute("insert into singers (id, name) values (1, 'Some Singer')") self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) - ) - self.assertEqual(1, len(begin_requests)) + sql_request = _get_first_execute_sql_request(self.spanner_service.requests) self.assertEqual( - begin_requests[0].options.isolation_level, + sql_request.transaction.begin.isolation_level, TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, ) @@ -62,14 +63,12 @@ def test_custom_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_isolation_level_in_connection_kwargs(self): @@ -85,14 +84,12 @@ def test_isolation_level_in_connection_kwargs(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_transaction_isolation_level(self): @@ -109,14 +106,12 @@ def test_transaction_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_begin_isolation_level(self): @@ -133,14 +128,12 @@ def test_begin_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_begin_invalid_isolation_level(self): diff --git a/tests/mockserver_tests/test_tags.py b/tests/mockserver_tests/test_tags.py index 9e35517797..4d975c8ef7 100644 --- a/tests/mockserver_tests/test_tags.py +++ b/tests/mockserver_tests/test_tags.py @@ -115,7 +115,7 @@ def test_select_read_write_transaction_no_tags(self): requests = self.spanner_service.requests self.assert_requests_sequence( requests, - [BeginTransactionRequest, ExecuteSqlRequest, CommitRequest], + [ExecuteSqlRequest, CommitRequest], TransactionType.READ_WRITE, ) @@ -131,7 +131,7 @@ def test_select_read_write_transaction_with_request_tag(self): requests = self.spanner_service.requests self.assert_requests_sequence( requests, - [BeginTransactionRequest, ExecuteSqlRequest, CommitRequest], + [ExecuteSqlRequest, CommitRequest], TransactionType.READ_WRITE, ) @@ -148,7 +148,6 @@ def test_select_read_write_transaction_with_transaction_tag(self): self.assert_requests_sequence( requests, [ - BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest, @@ -156,7 +155,7 @@ def test_select_read_write_transaction_with_transaction_tag(self): TransactionType.READ_WRITE, ) mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + tag_idx = 2 if mux_enabled else 1 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) @@ -180,7 +179,6 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self): self.assert_requests_sequence( requests, [ - BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest, @@ -188,7 +186,7 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self): TransactionType.READ_WRITE, ) mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + tag_idx = 2 if mux_enabled else 1 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6e8159425f..83d813243c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -211,6 +211,26 @@ def test_transaction_checkout(self): connection._autocommit = True self.assertIsNone(connection.transaction_checkout()) + def test_transaction_checkout_does_not_call_begin(self): + """transaction_checkout must not call Transaction.begin(). + + The transaction should be returned with _transaction_id=None so that + execute_sql/execute_update can use inline begin via + TransactionSelector(begin=...), eliminating a separate + BeginTransaction RPC. + """ + connection = Connection(INSTANCE, DATABASE) + mock_session = mock.MagicMock() + mock_transaction = mock.MagicMock() + mock_session.transaction.return_value = mock_transaction + connection._session_checkout = mock.MagicMock(return_value=mock_session) + + txn = connection.transaction_checkout() + + self.assertEqual(txn, mock_transaction) + self.assertTrue(connection._spanner_transaction_started) + mock_transaction.begin.assert_not_called() + def test_snapshot_checkout(self): connection = build_connection(read_only=True) connection.autocommit = False