diff --git a/pyproject.toml b/pyproject.toml index 79d5de2..e86728e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dev = [ "caproto!=1.2.0", ] kafka = [ - "kafka-python-ng", + "kafka-python", "msgpack", "msgpack-numpy" ] diff --git a/src/sophys/common/utils/kafka/consumer.py b/src/sophys/common/utils/kafka/consumer.py new file mode 100644 index 0000000..37bb864 --- /dev/null +++ b/src/sophys/common/utils/kafka/consumer.py @@ -0,0 +1,66 @@ +from datetime import datetime, timedelta, timezone, tzinfo + +from kafka import KafkaConsumer +from kafka.consumer.fetcher import ConsumerRecord +from kafka.structs import TopicPartition + + +def seek_start_document(consumer: KafkaConsumer, record: ConsumerRecord): + """ + Attempt to seek into the start document of the current run, based on data from the last received document. + """ + topic_partition = TopicPartition(record.topic, record.partition) + + offset = record.offset + event_name, event_data = record.value + + beginning_offset = consumer.beginning_offsets([topic_partition])[topic_partition] + while event_name != "start" and offset != beginning_offset: + if "seq_num" in event_data: + offset = offset - event_data["seq_num"] - 1 + else: + offset -= 1 + consumer.seek(topic_partition, offset) + + for attempt_number in range(1, 4): + timeout = 1_000 * attempt_number + records = consumer.poll( + timeout_ms=timeout, max_records=1, update_offsets=False + ) + + if topic_partition in records: + break + + if topic_partition not in records: + raise RuntimeError( + f"Failed to retrieve records for the current partition ('{record.partition}' for '{record.topic}')" + ) + + event_name, event_data = records[topic_partition][0].value + + +def seek_back_in_time( + consumer: KafkaConsumer, + rewind_time: timedelta, + server_timezone: tzinfo = timezone.utc, +): + """ + Rewind the consumer by 'rewind_time', up to the beginning offset. + """ + now = datetime.now(server_timezone) + + all_partitions = [ + TopicPartition(topic_name, p) + for topic_name in consumer.subscription() + for p in consumer.partitions_for_topic(topic_name) + ] + + # NOTE: offsets_for_times expects timestamps in ms, so we multiply by 1000. + rewind_timestamp = int((now - rewind_time).timestamp() * 1000) + timestamp_offsets = consumer.offsets_for_times( + {p: rewind_timestamp for p in all_partitions} + ) + + for partition, offset_ts in timestamp_offsets.items(): + if offset_ts is not None: + consumer.seek(partition, offset_ts.offset) diff --git a/src/sophys/common/utils/kafka/monitor.py b/src/sophys/common/utils/kafka/monitor.py index 53c5f16..7441d79 100644 --- a/src/sophys/common/utils/kafka/monitor.py +++ b/src/sophys/common/utils/kafka/monitor.py @@ -1,6 +1,8 @@ +import enum import logging import json from collections import defaultdict +from datetime import timedelta from functools import wraps, partial from typing import Optional @@ -10,10 +12,11 @@ import msgpack_numpy as _m from kafka import KafkaConsumer -from kafka.structs import TopicPartition from event_model import EventPage, unpack_event_page +from .consumer import seek_start_document, seek_back_in_time + def _get_uid_from_event_data(event_data: dict): return event_data.get("uid", None) @@ -259,6 +262,12 @@ def __getitem__(self, data: tuple): return super().__getitem__(self.find_with_resource(resource_uid)) +class SeekStartResult(enum.Enum): + NO_SEEK = enum.auto() + SEEK_SUCCEEDED = enum.auto() + SEEK_FAILED = enum.auto() + + class MonitorBase(KafkaConsumer): def __init__( self, @@ -266,6 +275,7 @@ def __init__( incomplete_documents: list, topic_name: str, logger_name: str, + hour_offset: Optional[float] = None, **configs, ): """ @@ -282,6 +292,8 @@ def __init__( The Kafka topic to monitor. logger_name : str, optional Name of the logger to use for info / debug during the monitor processing. + hour_offset : float, optional + Time in hours to look back in Kafka right after the start of monitoring. **configs : dict or keyword arguments Extra arguments to pass to the KafkaConsumer's constructor. """ @@ -290,6 +302,10 @@ def __init__( self.name = repr(self) self.running = Event() + self._runs_to_ignore = set() + + self.__hour_offset = hour_offset + self.__documents = MultipleDocumentDictionary() self.__save_queue = save_queue @@ -332,19 +348,30 @@ def topic(self): """Get the name of the Kafka topic monitored by this object.""" return "".join(self.subscription()) - def seek_start( - self, topic: str, partition_id: int, offset: int, event_data: dict - ) -> None: - """Attempt to seek into the start document of the current run. May not seek if the current event does not have a sequence number.""" - if "seq_num" not in event_data: - self._logger.debug( - "Sequence numbers are not available! o.O\n {}".format(str(event_data)) - ) - # Hopefully a future event will have it! - return - self.seek( - TopicPartition(topic, partition_id), offset - event_data["seq_num"] - 1 - ) + def _seek_start_if_needed(self, event) -> SeekStartResult: + should_seek_start = False + + try: + if event.value[0] != "start" and len(self.__documents[event.value]) == 0: + # In the middle of a run, try to go back to the beginning + should_seek_start = True + except KeyError: + # In the middle of a run, try to go back to the beginning + should_seek_start = True + + seek_start_succeeded = True + if should_seek_start: + try: + seek_start_document(self, event) + except RuntimeError: + self._logger.error("Run '{}': {}", event) + seek_start_succeeded = False + + if not should_seek_start: + return SeekStartResult.NO_SEEK + if not seek_start_succeeded: + return SeekStartResult.SEEK_FAILED + return SeekStartResult.SEEK_SUCCEEDED def _commit_pending_documents(self): """Commit pending documents to the save queue, when possible.""" @@ -359,11 +386,16 @@ def _commit_pending_documents(self): self.__save_queue.put(doc, block=True, timeout=1.0) self.__saved_document_uids.add(id) except Exception as e: - self._logger.error( - "Unhandled exception while trying to save documents. Will try to continue regardless." - ) - self._logger.error("Exception if you're into that:") - self._logger.exception(e) + if isinstance(e, QueueFullException): + self._logger.warning( + "Save queue is full. Failed to add run '%s'.", id + ) + else: + self._logger.error( + "Unhandled exception while trying to save documents. Will try to continue regardless." + ) + self._logger.error("Exception if you're into that:") + self._logger.exception(e) self.__to_save_documents_save_attempts[id] += 1 @@ -386,61 +418,65 @@ def _commit_pending_documents(self): if id in self.__to_save_documents_save_attempts: del self.__to_save_documents_save_attempts[id] - def handle_event(self, event): - self._logger.debug("Event received.") - - seek_start = False + def _handle_kafka_event(self, event): + data = event.value + if len(data) != 2: + self._logger.warning( + "Event data does not have two elements.\n {}".format(str(data)) + ) + return try: - data = event.value - - if len(data) != 2: - self._logger.warning( - "Event data does not have two elements.\n {}".format(str(data)) - ) + match self._seek_start_if_needed(event): + case SeekStartResult.NO_SEEK: + pass + case SeekStartResult.SEEK_SUCCEEDED: + return + case SeekStartResult.SEEK_FAILED: + start_uid = _get_start_uid_from_event_data(data[1]) + if start_uid is not None: + self._runs_to_ignore.add(start_uid) + return + + start_uid = _get_start_uid_from_event_data(data[1]) + if start_uid in self._runs_to_ignore: return - if data[0] == "start": - self._logger.info("Received a 'start' document.") + match data: + case ("start", _): + self.__documents.append(data) - self.__documents.append(data) - new_run_uid = self.__documents[data].identifier - self.__incomplete_documents.append(new_run_uid) - - self.__documents[data].subscribe( - partial(self._run_subscriptions, new_run_uid) - ) + new_run_uid = self.__documents[data].identifier + self._logger.info( + "Run '{}': Received a 'start' document.".format(new_run_uid) + ) - return + self.__incomplete_documents.append(new_run_uid) - try: - if len(self.__documents[data]) == 0: - # In the middle of a run, try to go back to the beginning - seek_start = True - except KeyError: - # In the middle of a run, try to go back to the beginning - seek_start = True - - if seek_start: - self.seek_start(event.topic, event.partition, event.offset, data[1]) - return + self.__documents[data].subscribe( + partial(self._run_subscriptions, new_run_uid) + ) + case ("stop", _): + self._logger.info( + "Run '{}': Received a 'stop' document.".format( + self.__documents[data].identifier + ) + ) - self.__documents[data].append(*data) + self.__documents[data].append(*data) - if data[0] == "stop": - self._logger.info( - "Run '{}': Received a 'stop' document.".format( - self.__documents[data].identifier - ) - ) + self.__documents[data].clear_subscriptions() + self.__to_save_documents.append(self.__documents[data].identifier) - self.__documents[data].clear_subscriptions() - self.__to_save_documents.append(self.__documents[data].identifier) + start_uid = self.__documents[data].identifier + self._runs_to_ignore.discard(start_uid) - # TODO: Validate number of saved entries via the stop document's num_events - # TODO: Validate successful run via the stop document's exit_status + # TODO: Validate number of saved entries via the stop document's num_events + # TODO: Validate successful run via the stop document's exit_status - self._commit_pending_documents() + self._commit_pending_documents() + case (_, _): + self.__documents[data].append(*data) except Exception as e: self._logger.error("Unhandled exception. Will try to continue regardless.") @@ -452,15 +488,20 @@ def handle_event(self, event): def run(self): """Start monitoring the Kafka topic.""" - partition_number = list(self.partitions_for_topic(self.topic()))[0] - self._update_fetch_positions([TopicPartition(self.topic(), partition_number)]) + # NOTE: Configure the current offset before setting the 'running' flag. + # The timeout time here doesn't matter too much, as we don't care + # whether we received new data or not. + self.poll(timeout_ms=100, max_records=1, update_offsets=False) + + if self.__hour_offset is not None: + seek_back_in_time(self, timedelta(hours=self.__hour_offset)) self.running.set() while not self._closed: try: for event in self: - self.handle_event(event) - except StopIteration: + self._handle_kafka_event(event) + except (StopIteration, AssertionError): pass self.running.clear() diff --git a/tests/conftest.py b/tests/conftest.py index 4e4a8a2..77f7ed3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,11 @@ import pytest +import msgpack_numpy as msgpack + +from kafka import KafkaConsumer, KafkaProducer + +from bluesky import RunEngine + from .soft_ioc import start_soft_ioc @@ -8,3 +14,44 @@ def soft_ioc(): soft_ioc_prefix, stop_soft_ioc = start_soft_ioc() yield soft_ioc_prefix stop_soft_ioc() + + +@pytest.fixture(scope="session") +def kafka_bootstrap_ip(): + return "localhost:9092" + + +@pytest.fixture(scope="session") +def kafka_topic(): + return "test_bluesky_raw_docs" + + +@pytest.fixture(scope="function") +def kafka_producer(kafka_bootstrap_ip): + producer = KafkaProducer( + bootstrap_servers=[kafka_bootstrap_ip], value_serializer=msgpack.dumps + ) + yield producer + producer.flush() + producer.close() + + +@pytest.fixture(scope="function") +def kafka_consumer(kafka_bootstrap_ip, kafka_topic): + consumer = KafkaConsumer( + kafka_topic, + bootstrap_servers=[kafka_bootstrap_ip], + value_deserializer=msgpack.unpackb, + ) + + # Connect the consumer properly to the topic + consumer.poll(timeout_ms=100, max_records=1, update_offsets=False) + + return consumer + + +@pytest.fixture(scope="function") +def run_engine_without_md(kafka_producer, kafka_topic): + RE = RunEngine() + RE.subscribe(lambda name, doc: kafka_producer.send(kafka_topic, (name, doc))) + return RE diff --git a/tests/utils/test_kafka_consumer.py b/tests/utils/test_kafka_consumer.py new file mode 100644 index 0000000..3685d55 --- /dev/null +++ b/tests/utils/test_kafka_consumer.py @@ -0,0 +1,110 @@ +from datetime import datetime, timezone + +from kafka.producer import KafkaProducer +from kafka.consumer import KafkaConsumer +from kafka.structs import TopicPartition + +import numpy as np + +from ophyd.sim import hw +from bluesky import plans as bp + +from sophys.common.utils.kafka.monitor import seek_start_document, seek_back_in_time + + +def test_seek_start( + kafka_producer: KafkaProducer, + kafka_consumer: KafkaConsumer, + kafka_topic, + run_engine_without_md, +): + partition_number = list(kafka_consumer.partitions_for_topic(kafka_topic))[0] + topic_partition = TopicPartition(kafka_topic, partition_number) + original_offset = kafka_consumer.position(topic_partition) + + uid, *_ = run_engine_without_md(bp.count([hw().det], num=10)) + + def seek_and_assert_positions(offset: int, seeked_event_name: str): + kafka_consumer.seek(topic_partition, offset) + + record = kafka_consumer.poll( + timeout_ms=1_000, max_records=1, update_offsets=False + )[topic_partition][0] + + event_name, _ = record.value + assert event_name == seeked_event_name + + seek_start_document(kafka_consumer, record) + + record = kafka_consumer.poll( + timeout_ms=1_000, max_records=1, update_offsets=False + )[topic_partition][0] + + event_name, _ = record.value + assert event_name == "start" + + kafka_producer.flush(timeout=1.0) + while kafka_consumer.poll(timeout_ms=100) != {}: + pass + + new_offset = kafka_consumer.position(topic_partition) + # start (1) + descriptor (1) + events (10) + stop (1) + assert new_offset - original_offset == 13 + + # From stop to start + seek_and_assert_positions(new_offset - 1, "stop") + + # From start to start (do nothing) + seek_and_assert_positions(original_offset, "start") + + # From event in the middle to start + seek_and_assert_positions(original_offset + 5, "event") + + # From descriptor to start + seek_and_assert_positions(original_offset + 1, "descriptor") + + +def test_seek_back_in_time( + kafka_producer: KafkaProducer, + kafka_consumer: KafkaConsumer, + kafka_topic, + run_engine_without_md, +): + partition_number = list(kafka_consumer.partitions_for_topic(kafka_topic))[0] + topic_partition = TopicPartition(kafka_topic, partition_number) + + kafka_consumer.seek_to_beginning() + oldest_offset = kafka_consumer.position(topic_partition) + + kafka_consumer.seek_to_end() + newest_offset = kafka_consumer.position(topic_partition) + + if newest_offset - oldest_offset < 5: + # Add some new time-spaced data. + for _ in range(5): + run_engine_without_md(bp.count([hw().det], num=2, delay=1)) + while kafka_consumer.poll(timeout_ms=100) != {}: + pass + newest_offset = kafka_consumer.position(topic_partition) + + offsets = [round(x) for x in np.linspace(oldest_offset, newest_offset - 1, num=5)] + timestamps = list() + for offset in offsets: + kafka_consumer.seek(topic_partition, offset) + + record = kafka_consumer.poll( + timeout_ms=1_000, max_records=1, update_offsets=False + )[topic_partition][0] + timestamps.append(record.timestamp // 1000) + + for expected_timestamp in timestamps: + time_delta = datetime.now(timezone.utc) - datetime.fromtimestamp( + expected_timestamp, tz=timezone.utc + ) + seek_back_in_time(kafka_consumer, time_delta) + + record = kafka_consumer.poll( + timeout_ms=1_000, max_records=1, update_offsets=False + )[topic_partition][0] + + assert np.isclose(record.timestamp // 1000, expected_timestamp, atol=1) diff --git a/tests/utils/test_kafka_monitor.py b/tests/utils/test_kafka_monitor.py index ac8337d..0d35c84 100644 --- a/tests/utils/test_kafka_monitor.py +++ b/tests/utils/test_kafka_monitor.py @@ -1,14 +1,6 @@ import pytest import queue -import time - -import msgpack -import msgpack_numpy as _m - -from kafka.producer import KafkaProducer -from kafka.consumer import KafkaConsumer -from kafka.structs import TopicPartition from ophyd.sim import hw from bluesky import RunEngine, plans as bp, plan_stubs as bps, preprocessors as bpp @@ -18,19 +10,6 @@ from . import _wait -_m.patch() - - -@pytest.fixture(scope="session") -def kafka_bootstrap_ip(): - return "localhost:9092" - - -@pytest.fixture(scope="session") -def kafka_topic(): - return "test_bluesky_raw_docs" - - @pytest.fixture(scope="function") def save_queue_size() -> int: return 4 @@ -80,34 +59,6 @@ def incomplete_documents(_incomplete_documents): return _incomplete_documents -@pytest.fixture(scope="function") -def kafka_producer(kafka_bootstrap_ip): - producer = KafkaProducer( - bootstrap_servers=[kafka_bootstrap_ip], value_serializer=msgpack.dumps - ) - yield producer - producer.flush() - producer.close() - - -@pytest.fixture(scope="function") -def kafka_consumer(kafka_bootstrap_ip, kafka_topic): - consumer = KafkaConsumer( - bootstrap_servers=[kafka_bootstrap_ip], value_deserializer=msgpack.unpackb - ) - - # Connect the consumer properly to the topic - partition = TopicPartition(kafka_topic, 0) - consumer.assign([partition]) - print("Starting offset:") - # Fun fact: this is actually required for the tests to work properly, - # because otherwise it doesn't update the current offset before the - # producer starts throwing events at the topic. :))))) - print(consumer.position(partition)) - - return consumer - - @pytest.fixture(scope="function") def base_md(tmp_path_factory): return { @@ -123,15 +74,7 @@ def run_engine_with_md(base_md, kafka_producer, kafka_topic): return RE -@pytest.fixture(scope="function") -def run_engine_without_md(kafka_producer, kafka_topic): - RE = RunEngine() - RE.subscribe(lambda name, doc: kafka_producer.send(kafka_topic, (name, doc))) - return RE - - -@pytest.fixture(scope="function") -def good_monitor( +def _create_good_monitor( save_queue, incomplete_documents, kafka_topic, kafka_bootstrap_ip ) -> ThreadedMonitor: mon = ThreadedMonitor( @@ -148,6 +91,15 @@ def good_monitor( return mon +@pytest.fixture(scope="function") +def good_monitor( + save_queue, incomplete_documents, kafka_topic, kafka_bootstrap_ip +) -> ThreadedMonitor: + return _create_good_monitor( + save_queue, incomplete_documents, kafka_topic, kafka_bootstrap_ip + ) + + # # Tests # @@ -368,3 +320,47 @@ def custom_plan(): # One start doc, one descriptor doc, one event doc, one stop doc assert len(docs) == 4, docs.get_raw_data() + + +def test_seek_start_in_monitor( + run_engine_without_md, + incomplete_documents, + save_queue: queue.Queue, + kafka_topic, + kafka_bootstrap_ip, +): + det = hw().det + + # NOTE: Add another run before this one just to check it doesn't skip into the previous run. + run_engine_without_md(bp.count([det], num=1)) + + def custom_plan(): + yield from bps.open_run({}) + yield from bps.declare_stream(det, name="primary") + for _ in range(5): + yield from bps.create() + yield from bps.read(det) + yield from bps.save() + + monitor = _create_good_monitor( + save_queue, incomplete_documents, kafka_topic, kafka_bootstrap_ip + ) + assert monitor.is_alive() + + for _ in range(5): + yield from bps.create() + yield from bps.read(det) + yield from bps.save() + yield from bps.close_run("success") + + # Only populated if 'monitor' is working properly, and rewinded to the start document. + assert save_queue.get(True, timeout=2.0) is not None + + monitor.close() + _wait( + lambda: not monitor.running.is_set(), + timeout=2.0, + timeout_msg="Monitor took too long to close.", + ) + + run_engine_without_md(custom_plan())