diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 5e7a68bc1c..f5f2226af1 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -4137,8 +4137,11 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai
             if schema_mismatches is None:
                 return True
 
+            elapsed = self._time.time() - start
             log.debug("[control connection] Schemas mismatched, trying again")
-            self._time.sleep(0.2)
+            remaining = total_timeout - elapsed
+            if remaining > 0:
+                self._time.sleep(min(0.2, remaining))
             elapsed = self._time.time() - start
 
         log.warning("Node %s is reporting a schema disagreement: %s",
diff --git a/cassandra/connection.py b/cassandra/connection.py
index 08501d0a2b..f1806618a9 100644
--- a/cassandra/connection.py
+++ b/cassandra/connection.py
@@ -642,16 +642,25 @@ def maybe_request_more(self):
         if space_in_queue >= max_queue_size / 2:
             self.update_next_pages(space_in_queue)
 
+    def _send_revise_request(self, request, callback):
+        with self.connection.lock:
+            request_id = self.connection.get_request_id()
+            try:
+                self.connection.send_msg(request, request_id, callback)
+            except Exception:
+                if request_id not in self.connection._requests and request_id not in self.connection.request_ids:
+                    self.connection.request_ids.append(request_id)
+                raise
+
     def update_next_pages(self, num_next_pages):
         try:
             self._state.num_pages_requested += num_next_pages
             log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host)
-            with self.connection.lock:
-                self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE,
-                                                              self.stream_id,
-                                                              next_pages=num_next_pages),
-                                         self.connection.get_request_id(),
-                                         self._on_backpressure_response)
+            self._send_revise_request(
+                ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE,
+                                     self.stream_id,
+                                     next_pages=num_next_pages),
+                self._on_backpressure_response)
         except ConnectionShutdown as ex:
             log.debug("Failed to update backpressure for session %s from %s, connection is shutdown",
                       self.stream_id, self.connection.host)
@@ -668,11 +677,10 @@ def _on_backpressure_response(self, response):
     def cancel(self):
         try:
             log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host)
-            with self.connection.lock:
-                self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL,
-                                                              self.stream_id),
-                                         self.connection.get_request_id(),
-                                         self._on_cancel_response)
+            self._send_revise_request(
+                ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL,
+                                     self.stream_id),
+                self._on_cancel_response)
         except ConnectionShutdown:
             log.debug("Failed to cancel session %s from %s, connection is shutdown",
                       self.stream_id, self.connection.host)
diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py
index 2fa7c71196..27110b36b8 100644
--- a/tests/unit/test_connection.py
+++ b/tests/unit/test_connection.py
@@ -363,6 +363,18 @@ def test_wait_for_responses_shutdown_includes_last_error(self):
         assert "already closed" in error_message
         assert "Bad file descriptor" in error_message
 
+    def test_continuous_paging_cancel_releases_request_id_when_send_fails(self):
+        c = self.make_connection()
+        c.push = Mock(side_effect=ConnectionException("write failed"))
+        state = Mock(max_queue_size=100, num_pages_requested=0, num_pages_received=0)
+        session = c.new_continuous_paging_session(1, Mock(), Mock(), state)
+        initial_request_ids = len(c.request_ids)
+
+        with pytest.raises(ConnectionException):
+            session.cancel()
+
+        assert len(c.request_ids) == initial_request_ids
+        assert not c._requests
 
 @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
 class ConnectionHeartbeatTest(unittest.TestCase):
diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py
index 037d4a8888..b9f59ea081 100644
--- a/tests/unit/test_control_connection.py
+++ b/tests/unit/test_control_connection.py
@@ -245,6 +245,150 @@ def test_wait_for_schema_agreement_fails(self):
         # the control connection should have slept until it hit the limit
         assert self.time.clock >= self.cluster.max_schema_agreement_wait
 
+    def test_wait_for_schema_agreement_falls_back_to_session_when_connection_closes(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+        self.connection.wait_for_responses.side_effect = ConnectionShutdown("closed")
+
+        assert self.control_connection.wait_for_schema_agreement()
+        session.wait_for_schema_agreement.assert_called_once_with(wait_time=self.cluster.max_schema_agreement_wait)
+
+    def test_wait_for_schema_agreement_falls_back_to_session_when_connection_is_busy(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+        self.connection.wait_for_responses.side_effect = ConnectionBusy("overloaded")
+
+        assert self.control_connection.wait_for_schema_agreement()
+        session.wait_for_schema_agreement.assert_called_once_with(wait_time=self.cluster.max_schema_agreement_wait)
+
+    def test_wait_for_schema_agreement_falls_back_to_session_when_connection_errors(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+        self.connection.wait_for_responses.side_effect = ConnectionException("write failed")
+
+        assert self.control_connection.wait_for_schema_agreement()
+        session.wait_for_schema_agreement.assert_called_once_with(wait_time=self.cluster.max_schema_agreement_wait)
+
+    def test_wait_for_schema_agreement_session_fallback_skips_failing_sessions(self):
+        failing_session = Mock(is_shutdown=False)
+        failing_session.wait_for_schema_agreement.side_effect = ConnectionException("session broken")
+        healthy_session = Mock(is_shutdown=False)
+        healthy_session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [failing_session, healthy_session]
+        self.connection.wait_for_responses.side_effect = ConnectionBusy("overloaded")
+
+        assert self.control_connection.wait_for_schema_agreement()
+        failing_session.wait_for_schema_agreement.assert_called_once_with(
+            wait_time=self.cluster.max_schema_agreement_wait)
+        healthy_session.wait_for_schema_agreement.assert_called_once_with(
+            wait_time=self.cluster.max_schema_agreement_wait)
+
+    def test_wait_for_schema_agreement_subtracts_elapsed_time_before_session_fallback(self):
+        session = Mock(is_shutdown=False)
+
+        def wait_for_responses(*args, **kwargs):
+            self.time.sleep(3)
+            raise ConnectionShutdown("closed")
+
+        def wait_for_schema_agreement(wait_time=None):
+            self.time.sleep(wait_time)
+            return False
+
+        self.cluster.sessions = [session]
+        self.connection.wait_for_responses.side_effect = wait_for_responses
+        session.wait_for_schema_agreement.side_effect = wait_for_schema_agreement
+
+        assert not self.control_connection.wait_for_schema_agreement()
+        assert self.time.clock == self.cluster.max_schema_agreement_wait
+
+    def test_wait_for_schema_agreement_does_not_accept_session_fallback_after_known_mismatch(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+        self.connection.peer_results[1][1][2] = 'b'
+
+        assert not self.control_connection.wait_for_schema_agreement()
+        session.wait_for_schema_agreement.assert_not_called()
+
+    def test_wait_for_schema_agreement_retries_control_connection_after_mismatch_then_busy(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+
+        peer_columns = self.connection.peer_results[0]
+        mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]]
+        mismatching_peer_rows[1][2] = 'b'
+        matching_peer_rows = [list(row) for row in self.connection.peer_results[1]]
+        self.connection.wait_for_responses.side_effect = [
+            _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows)),
+            ConnectionBusy("overloaded"),
+            _node_meta_results(self.connection.local_results, (peer_columns, matching_peer_rows))]
+
+        assert self.control_connection.wait_for_schema_agreement()
+        session.wait_for_schema_agreement.assert_not_called()
+        assert self.connection.wait_for_responses.call_count == 3
+
+    def test_wait_for_schema_agreement_raises_connection_error_after_mismatch(self):
+        peer_columns = self.connection.peer_results[0]
+        mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]]
+        mismatching_peer_rows[1][2] = 'b'
+        self.connection.wait_for_responses.side_effect = [
+            _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows)),
+            ConnectionShutdown("closed")]
+
+        with self.assertRaises(ConnectionShutdown):
+            self.control_connection.wait_for_schema_agreement()
+
+    def test_schema_change_refresh_does_not_session_fallback_after_mismatch_then_connection_error(self):
+        session = Mock(is_shutdown=False)
+        session.wait_for_schema_agreement.return_value = True
+        self.cluster.sessions = [session]
+        self.cluster.metadata.refresh = Mock()
+
+        peer_columns = self.connection.peer_results[0]
+        mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]]
+        mismatching_peer_rows[1][2] = 'b'
+        self.connection.wait_for_responses.side_effect = [
+            _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows)),
+            ConnectionShutdown("closed")]
+
+        response_future = Mock()
+        response_future.session = session
+        event = {'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED,
+                 'keyspace': "keyspace1", "table": "table1"}
+
+        refresh_schema_and_set_result(self.control_connection, response_future, self.connection, **event)
+
+        session.wait_for_schema_agreement.assert_not_called()
+        self.cluster.metadata.refresh.assert_not_called()
+        assert not response_future.is_schema_agreed
+        response_future._set_final_result.assert_called_once_with(None)
+
+    def test_wait_for_schema_agreement_does_not_sleep_past_deadline_after_mismatch(self):
+        self.cluster.max_schema_agreement_wait = 0.1
+        self.connection.peer_results[1][1][2] = 'b'
+
+        assert not self.control_connection.wait_for_schema_agreement()
+        assert self.time.clock == self.cluster.max_schema_agreement_wait
+
+    def test_wait_for_schema_agreement_counts_query_time_before_mismatch_retry_sleep(self):
+        self.cluster.max_schema_agreement_wait = 0.1
+        peer_columns = self.connection.peer_results[0]
+        mismatching_peer_rows = [list(row) for row in self.connection.peer_results[1]]
+        mismatching_peer_rows[1][2] = 'b'
+
+        def wait_for_responses(*args, **kwargs):
+            self.time.sleep(0.09)
+            return _node_meta_results(self.connection.local_results, (peer_columns, mismatching_peer_rows))
+
+        self.connection.wait_for_responses.side_effect = wait_for_responses
+
+        assert not self.control_connection.wait_for_schema_agreement()
+        self.assertAlmostEqual(self.time.clock, self.cluster.max_schema_agreement_wait)
+
     def test_wait_for_schema_agreement_skipping(self):
         """
         If rpc_address or schema_version isn't set, the host should be skipped