2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dev = [
"caproto!=1.2.0",
]
kafka = [
"kafka-python-ng",
"kafka-python",
"msgpack",
"msgpack-numpy"
]
66 changes: 66 additions & 0 deletions src/sophys/common/utils/kafka/consumer.py
@@ -0,0 +1,66 @@
from datetime import datetime, timedelta, timezone, tzinfo

from kafka import KafkaConsumer
from kafka.consumer.fetcher import ConsumerRecord
from kafka.structs import TopicPartition


def seek_start_document(consumer: KafkaConsumer, record: ConsumerRecord):
"""
    Attempt to seek to the start document of the current run, based on data from the last received document.
"""
topic_partition = TopicPartition(record.topic, record.partition)

offset = record.offset
event_name, event_data = record.value

beginning_offset = consumer.beginning_offsets([topic_partition])[topic_partition]
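    # Walk backwards through the partition until a 'start' document is read or
    # the first offset is reached. A document carrying a 1-based 'seq_num'
    # lets us jump back past the run's events in one step; otherwise we step
    # back one offset at a time.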
while event_name != "start" and offset != beginning_offset:
if "seq_num" in event_data:
offset = offset - event_data["seq_num"] - 1
Contributor: Nice idea

else:
offset -= 1
consumer.seek(topic_partition, offset)

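        # A single poll may return nothing before the broker responds, so
        # retry a few times with an increasing timeout.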
for attempt_number in range(1, 4):
timeout = 1_000 * attempt_number
records = consumer.poll(
timeout_ms=timeout, max_records=1, update_offsets=False
)

if topic_partition in records:
break

if topic_partition not in records:
raise RuntimeError(
f"Failed to retrieve records for the current partition ('{record.partition}' for '{record.topic}')"
)

event_name, event_data = records[topic_partition][0].value


def seek_back_in_time(
consumer: KafkaConsumer,
rewind_time: timedelta,
server_timezone: tzinfo = timezone.utc,
):
"""
Rewind the consumer by 'rewind_time', up to the beginning offset.
"""
now = datetime.now(server_timezone)

all_partitions = [
TopicPartition(topic_name, p)
for topic_name in consumer.subscription()
for p in consumer.partitions_for_topic(topic_name)
]

# NOTE: offsets_for_times expects timestamps in ms, so we multiply by 1000.
rewind_timestamp = int((now - rewind_time).timestamp() * 1000)
timestamp_offsets = consumer.offsets_for_times(
{p: rewind_timestamp for p in all_partitions}
)

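    # offsets_for_times returns None for a partition that has no message at or
    # after the requested timestamp; such partitions are left where they are.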
for partition, offset_ts in timestamp_offsets.items():
if offset_ts is not None:
consumer.seek(partition, offset_ts.offset)
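
For reference, a minimal sketch of how these helpers might be driven (the broker address, topic name, and deserializer are illustrative assumptions, not part of this diff):

from datetime import timedelta

import msgpack_numpy as msgpack
from kafka import KafkaConsumer

from sophys.common.utils.kafka.consumer import seek_back_in_time, seek_start_document

consumer = KafkaConsumer(
    "bluesky_raw_docs",                    # placeholder topic
    bootstrap_servers=["localhost:9092"],  # placeholder broker
    value_deserializer=msgpack.unpackb,
)

# Replay everything published over the last two hours...
seek_back_in_time(consumer, timedelta(hours=2))

# ...or, given a record received mid-run, rewind to that run's 'start' document.
for record in consumer:
    seek_start_document(consumer, record)
    break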
173 changes: 107 additions & 66 deletions src/sophys/common/utils/kafka/monitor.py
@@ -1,6 +1,8 @@
+import enum
 import logging
 import json
 from collections import defaultdict
+from datetime import timedelta
 from functools import wraps, partial
 from typing import Optional

@@ -10,10 +12,11 @@
import msgpack_numpy as _m

from kafka import KafkaConsumer
from kafka.structs import TopicPartition

from event_model import EventPage, unpack_event_page

+from .consumer import seek_start_document, seek_back_in_time


def _get_uid_from_event_data(event_data: dict):
return event_data.get("uid", None)
@@ -259,13 +262,20 @@ def __getitem__(self, data: tuple):
return super().__getitem__(self.find_with_resource(resource_uid))


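# Outcome of MonitorBase._seek_start_if_needed: NO_SEEK means no repositioning
# was needed; SEEK_SUCCEEDED means the consumer was rewound to the run's start
# document (the triggering event will be re-read); SEEK_FAILED means the start
# document could not be reached, so the run should be ignored.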
+class SeekStartResult(enum.Enum):
+    NO_SEEK = enum.auto()
+    SEEK_SUCCEEDED = enum.auto()
+    SEEK_FAILED = enum.auto()


class MonitorBase(KafkaConsumer):
def __init__(
self,
save_queue: Optional[Queue],
incomplete_documents: list,
topic_name: str,
logger_name: str,
+        hour_offset: Optional[float] = None,
**configs,
):
"""
@@ -282,6 +292,8 @@ def __init__(
The Kafka topic to monitor.
logger_name : str, optional
Name of the logger to use for info / debug during the monitor processing.
+        hour_offset : float, optional
+            Time in hours to look back in Kafka right after the start of monitoring.
**configs : dict or keyword arguments
Extra arguments to pass to the KafkaConsumer's constructor.
"""
@@ -290,6 +302,10 @@
self.name = repr(self)
self.running = Event()

+        self._runs_to_ignore = set()
+
+        self.__hour_offset = hour_offset
+
self.__documents = MultipleDocumentDictionary()
self.__save_queue = save_queue

@@ -332,19 +348,30 @@ def topic(self):
"""Get the name of the Kafka topic monitored by this object."""
return "".join(self.subscription())

-    def seek_start(
-        self, topic: str, partition_id: int, offset: int, event_data: dict
-    ) -> None:
-        """Attempt to seek into the start document of the current run. May not seek if the current event does not have a sequence number."""
-        if "seq_num" not in event_data:
-            self._logger.debug(
-                "Sequence numbers are not available! o.O\n {}".format(str(event_data))
-            )
-            # Hopefully a future event will have it!
-            return
-        self.seek(
-            TopicPartition(topic, partition_id), offset - event_data["seq_num"] - 1
-        )
+    def _seek_start_if_needed(self, event) -> SeekStartResult:
+        should_seek_start = False
+
+        try:
+            if event.value[0] != "start" and len(self.__documents[event.value]) == 0:
+                # In the middle of a run, try to go back to the beginning
+                should_seek_start = True
+        except KeyError:
+            # In the middle of a run, try to go back to the beginning
+            should_seek_start = True
+
+        seek_start_succeeded = True
+        if should_seek_start:
+            try:
+                seek_start_document(self, event)
+            except RuntimeError as e:
+                self._logger.error(
+                    "Failed to seek to the run's start document: {}".format(e)
+                )
+                seek_start_succeeded = False
+
+        if not should_seek_start:
+            return SeekStartResult.NO_SEEK
+        if not seek_start_succeeded:
+            return SeekStartResult.SEEK_FAILED
+        return SeekStartResult.SEEK_SUCCEEDED

def _commit_pending_documents(self):
"""Commit pending documents to the save queue, when possible."""
@@ -359,11 +386,16 @@ def _commit_pending_documents(self):
self.__save_queue.put(doc, block=True, timeout=1.0)
self.__saved_document_uids.add(id)
except Exception as e:
-                self._logger.error(
-                    "Unhandled exception while trying to save documents. Will try to continue regardless."
-                )
-                self._logger.error("Exception if you're into that:")
-                self._logger.exception(e)
+                if isinstance(e, QueueFullException):
+                    self._logger.warning(
+                        "Save queue is full. Failed to add run '%s'.", id
+                    )
+                else:
+                    self._logger.error(
+                        "Unhandled exception while trying to save documents. Will try to continue regardless."
+                    )
+                    self._logger.error("Exception if you're into that:")
+                    self._logger.exception(e)

self.__to_save_documents_save_attempts[id] += 1

@@ -386,61 +418,65 @@ def _commit_pending_documents(self):
if id in self.__to_save_documents_save_attempts:
del self.__to_save_documents_save_attempts[id]

-    def handle_event(self, event):
-        self._logger.debug("Event received.")
-
-        seek_start = False
+    def _handle_kafka_event(self, event):
+        data = event.value
+        if len(data) != 2:
+            self._logger.warning(
+                "Event data does not have two elements.\n {}".format(str(data))
+            )
+            return

         try:
-            data = event.value
-
-            if len(data) != 2:
-                self._logger.warning(
-                    "Event data does not have two elements.\n {}".format(str(data))
-                )
-
-            if data[0] == "start":
-                self._logger.info("Received a 'start' document.")
-
-                self.__documents.append(data)
-
-                new_run_uid = self.__documents[data].identifier
-                self.__incomplete_documents.append(new_run_uid)
-
-                self.__documents[data].subscribe(
-                    partial(self._run_subscriptions, new_run_uid)
-                )
-
-                return
-
-            try:
-                if len(self.__documents[data]) == 0:
-                    # In the middle of a run, try to go back to the beginning
-                    seek_start = True
-            except KeyError:
-                # In the middle of a run, try to go back to the beginning
-                seek_start = True
-
-            if seek_start:
-                self.seek_start(event.topic, event.partition, event.offset, data[1])
-                return
-
-            self.__documents[data].append(*data)
-
-            if data[0] == "stop":
-                self._logger.info(
-                    "Run '{}': Received a 'stop' document.".format(
-                        self.__documents[data].identifier
-                    )
-                )
-
-                self.__documents[data].clear_subscriptions()
-                self.__to_save_documents.append(self.__documents[data].identifier)
-
-                # TODO: Validate number of saved entries via the stop document's num_events
-                # TODO: Validate successful run via the stop document's exit_status
-
-                self._commit_pending_documents()
+            match self._seek_start_if_needed(event):
+                case SeekStartResult.NO_SEEK:
+                    pass
+                case SeekStartResult.SEEK_SUCCEEDED:
+                    return
+                case SeekStartResult.SEEK_FAILED:
+                    start_uid = _get_start_uid_from_event_data(data[1])
+                    if start_uid is not None:
+                        self._runs_to_ignore.add(start_uid)
+                    return
+
+            start_uid = _get_start_uid_from_event_data(data[1])
+            if start_uid in self._runs_to_ignore:
+                return
+
+            match data:
+                case ("start", _):
+                    self.__documents.append(data)
+
+                    new_run_uid = self.__documents[data].identifier
+                    self._logger.info(
+                        "Run '{}': Received a 'start' document.".format(new_run_uid)
+                    )
+
+                    self.__incomplete_documents.append(new_run_uid)
+
+                    self.__documents[data].subscribe(
+                        partial(self._run_subscriptions, new_run_uid)
+                    )
+                case ("stop", _):
+                    self._logger.info(
+                        "Run '{}': Received a 'stop' document.".format(
+                            self.__documents[data].identifier
+                        )
+                    )
+
+                    self.__documents[data].append(*data)
+
+                    self.__documents[data].clear_subscriptions()
+                    self.__to_save_documents.append(self.__documents[data].identifier)
+
+                    start_uid = self.__documents[data].identifier
+                    self._runs_to_ignore.discard(start_uid)
+
+                    # TODO: Validate number of saved entries via the stop document's num_events
+                    # TODO: Validate successful run via the stop document's exit_status
+
+                    self._commit_pending_documents()
+                case (_, _):
+                    self.__documents[data].append(*data)

         except Exception as e:
             self._logger.error("Unhandled exception. Will try to continue regardless.")
@@ -452,15 +488,20 @@ def handle_event(self, event):

     def run(self):
         """Start monitoring the Kafka topic."""
-        partition_number = list(self.partitions_for_topic(self.topic()))[0]
-        self._update_fetch_positions([TopicPartition(self.topic(), partition_number)])
+        # NOTE: Configure the current offset before setting the 'running' flag.
+        # The timeout here doesn't matter much, as we don't care whether we
+        # received new data or not.
+        self.poll(timeout_ms=100, max_records=1, update_offsets=False)
+
+        if self.__hour_offset is not None:
+            seek_back_in_time(self, timedelta(hours=self.__hour_offset))

         self.running.set()
         while not self._closed:
             try:
                 for event in self:
-                    self.handle_event(event)
-            except StopIteration:
+                    self._handle_kafka_event(event)
+            except (StopIteration, AssertionError):
                 pass

         self.running.clear()
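
For context, a minimal sketch of how a monitor with the new hour_offset might be wired up (every concrete value below is an illustrative assumption; the extra keyword arguments are forwarded to the KafkaConsumer constructor):

import msgpack_numpy as msgpack
from queue import Queue

monitor = MonitorBase(
    save_queue=Queue(),
    incomplete_documents=[],
    topic_name="bluesky_raw_docs",         # placeholder topic
    logger_name="sophys.kafka.monitor",    # placeholder logger name
    hour_offset=2.0,                       # replay the last two hours first
    bootstrap_servers=["localhost:9092"],  # placeholder broker
    value_deserializer=msgpack.unpackb,
)
monitor.run()  # Blocks, consuming documents until the consumer is closed.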
47 changes: 47 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,11 @@
import pytest

import msgpack_numpy as msgpack

from kafka import KafkaConsumer, KafkaProducer

from bluesky import RunEngine

from .soft_ioc import start_soft_ioc


@@ -8,3 +14,44 @@ def soft_ioc():
soft_ioc_prefix, stop_soft_ioc = start_soft_ioc()
yield soft_ioc_prefix
stop_soft_ioc()


@pytest.fixture(scope="session")
def kafka_bootstrap_ip():
return "localhost:9092"


@pytest.fixture(scope="session")
def kafka_topic():
return "test_bluesky_raw_docs"


@pytest.fixture(scope="function")
def kafka_producer(kafka_bootstrap_ip):
producer = KafkaProducer(
bootstrap_servers=[kafka_bootstrap_ip], value_serializer=msgpack.dumps
)
yield producer
producer.flush()
producer.close()


@pytest.fixture(scope="function")
def kafka_consumer(kafka_bootstrap_ip, kafka_topic):
consumer = KafkaConsumer(
kafka_topic,
bootstrap_servers=[kafka_bootstrap_ip],
value_deserializer=msgpack.unpackb,
)

# Connect the consumer properly to the topic
consumer.poll(timeout_ms=100, max_records=1, update_offsets=False)

return consumer


@pytest.fixture(scope="function")
def run_engine_without_md(kafka_producer, kafka_topic):
RE = RunEngine()
RE.subscribe(lambda name, doc: kafka_producer.send(kafka_topic, (name, doc)))
return RE
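
A test built on these fixtures might look like the following sketch (the plan and the simulated detector are illustrative assumptions):

from bluesky.plans import count
from ophyd.sim import det


def test_documents_reach_kafka(run_engine_without_md, kafka_consumer):
    run_engine_without_md(count([det]))

    # The producer sends asynchronously, so allow a generous timeout.
    records = kafka_consumer.poll(timeout_ms=5_000)
    assert any(len(batch) > 0 for batch in records.values())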