Source code for noob.runner.zmq

"""
- Central command pub/sub
- each sub-runner has its own set of sockets for publishing and consuming events
- use the `node_id.signal` etc. as basically a feed address

.. todo::

    Currently only IPC is supported, and thus the zmq runner can't run across machines.
    Supporting TCP is WIP, it will require some degree of authentication
    among nodes to prevent arbitrary code execution,
    since we shouldn't count on users to properly firewall their runners.


.. todo::

    The socket spawning and event handling is awfully manual here.
    Leaving it as is because it's somewhat unlikely we'll need to generalize it,
    but otherwise it would be great to standardize socket names and have
    event handler decorators like:

        @on_router(MessageType.sometype)

"""

import multiprocessing as mp
import os
import signal
import threading
import traceback
from collections import defaultdict
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from itertools import count
from multiprocessing.synchronize import Event as EventType
from time import time
from types import FrameType
from typing import TYPE_CHECKING, Any, Literal, cast

from noob.network.loop import EventloopMixin

try:
    import zmq
except ImportError as e:
    raise ImportError(
        "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`",
    ) from e

from zmq.eventloop.zmqstream import ZMQStream

from noob.config import config
from noob.event import Event
from noob.input import InputCollection
from noob.logging import init_logger
from noob.network.message import (
    AnnounceMsg,
    AnnounceValue,
    ErrorMsg,
    ErrorValue,
    EventMsg,
    IdentifyMsg,
    IdentifyValue,
    Message,
    MessageType,
    NodeStatus,
    ProcessMsg,
    StatusMsg,
    StopMsg,
)
from noob.node import Node, NodeSpecification, Return
from noob.runner.base import TubeRunner
from noob.scheduler import Scheduler
from noob.store import EventStore
from noob.types import NodeID, ReturnNodeType

if TYPE_CHECKING:
    pass


[docs] class CommandNode(EventloopMixin): """ Pub node that controls the state of the other nodes/announces addresses - one PUB socket to distribute commands - one ROUTER socket to receive return messages from runner nodes - one SUB socket to subscribe to all events The wrapping runner should register callbacks with `add_callback` to handle incoming messages. """ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = None): """ Args: runner_id (str): The unique ID for the runner/tube session. All nodes within a runner use this to limit communication within a tube protocol: port: """ super().__init__() self.runner_id = runner_id self.port = port self.protocol = protocol self.logger = init_logger(f"runner.node.{runner_id}.command") self._outbox: zmq.Socket = None # type: ignore[assignment] self._inbox: ZMQStream = None # type: ignore[assignment] self._router: ZMQStream = None # type: ignore[assignment] self._nodes: dict[str, IdentifyValue] = {} self._ready_condition = threading.Condition() self._callbacks: dict[str, list[Callable[[Message], Any]]] = defaultdict(list) @property def pub_address(self) -> str: """Address the publisher bound to""" if self.protocol == "ipc": path = config.tmp_dir / f"{self.runner_id}/command/outbox" path.parent.mkdir(parents=True, exist_ok=True) return f"{self.protocol}://{str(path)}" else: raise NotImplementedError() @property def router_address(self) -> str: """Address the return router is bound to""" if self.protocol == "ipc": path = config.tmp_dir / f"{self.runner_id}/command/inbox" path.parent.mkdir(parents=True, exist_ok=True) return f"{self.protocol}://{str(path)}" else: raise NotImplementedError()
[docs] def start(self) -> None: self.logger.debug("Starting command runner") self._init_sockets() self.start_loop() self.logger.debug("Command runner started")
[docs] def stop(self) -> None: self.logger.debug("Stopping command runner") msg = StopMsg(node_id="command") self._outbox.send_multipart([b"stop", msg.to_bytes()]) self.stop_loop() self.logger.debug("Command runner stopped")
def _init_sockets(self) -> None: self._outbox = self._init_outbox() self._router = self._init_router() self._inbox = self._init_inbox() def _init_outbox(self) -> zmq.Socket: """Create the main control publisher""" pub = self.context.socket(zmq.PUB) pub.bind(self.pub_address) pub.setsockopt_string(zmq.IDENTITY, "command.outbox") return pub def _init_router(self) -> ZMQStream: """Create the inbox router""" router = self.context.socket(zmq.ROUTER) router.bind(self.router_address) router.setsockopt_string(zmq.IDENTITY, "command.router") router = ZMQStream(router, self.loop) router.on_recv(self.on_router) self.logger.debug("Inbox bound to %s", self.router_address) return router def _init_inbox(self) -> ZMQStream: """Subscriber that receives all events from running nodes""" sub = self.context.socket(zmq.SUB) sub.setsockopt_string(zmq.IDENTITY, "command.inbox") sub.setsockopt_string(zmq.SUBSCRIBE, "") sub = ZMQStream(sub, self.loop) sub.on_recv(self.on_inbox) return sub
[docs] def announce(self) -> None: msg = AnnounceMsg( node_id="command", value=AnnounceValue(inbox=self.router_address, nodes=self._nodes) ) self._outbox.send_multipart([b"announce", msg.to_bytes()])
[docs] def process(self) -> None: """Emit a ProcessMsg to process a single round through the graph""" self._outbox.send_multipart([b"process", ProcessMsg(node_id="command").to_bytes()]) self.logger.debug("Sent process message")
[docs] def add_callback(self, type_: Literal["inbox", "router"], cb: Callable[[Message], Any]) -> None: """ Add a callback called for message received - by the inbox: the subscriber that receives all events from node runners - by the router: direct messages sent by node runners to the command node """ self._callbacks[type_].append(cb)
[docs] def await_ready(self, node_ids: list[NodeID]) -> None: """ Wait until all the node_ids have announced themselves """ with self._ready_condition: if set(node_ids) == set(self._nodes): return def _is_ready() -> bool: return set(node_ids) == { node_id for node_id, state in self._nodes.items() if state["status"] == "ready" } self._ready_condition.wait_for(_is_ready)
[docs] def on_router(self, msg: list[bytes]) -> None: try: message = Message.from_bytes(msg) self.logger.debug("Received ROUTER message %s", message) except Exception as e: self.logger.exception("Exception decoding: %s, %s", msg, e) raise e for cb in self._callbacks["router"]: cb(message) if message.type_ == MessageType.identify: message = cast(IdentifyMsg, message) self.on_identify(message) elif message.type_ == MessageType.status: message = cast(StatusMsg, message) self.on_status(message)
[docs] def on_inbox(self, msg: list[bytes]) -> None: message = Message.from_bytes(msg) self.logger.debug("Received INBOX message: %s", message) for cb in self._callbacks["inbox"]: cb(message)
[docs] def on_identify(self, msg: IdentifyMsg) -> None: with self._ready_condition: self._nodes[msg.node_id] = msg.value self._inbox.connect(msg.value["outbox"]) self._ready_condition.notify_all() try: self.announce() self.logger.debug("Announced") except Exception as e: self.logger.exception("Exception announced: %s", e)
[docs] def on_status(self, msg: StatusMsg) -> None: with self._ready_condition: self._nodes[msg.node_id]["status"] = msg.value self._ready_condition.notify_all()
[docs] class NodeRunner(EventloopMixin): """ Runner for a single node - DEALER to communicate with command inbox - PUB (outbox) to publish events - SUB (inbox) to subscribe to events from other nodes. """ def __init__( self, spec: NodeSpecification, runner_id: str, command_outbox: str, command_router: str, input_collection: InputCollection, protocol: str = "ipc", ): super().__init__() self.spec = spec self.runner_id = runner_id self.input_collection = input_collection self.command_outbox = command_outbox self.command_router = command_router self.protocol = protocol self.store = EventStore() self.scheduler: Scheduler = None # type: ignore[assignment] self.logger = init_logger(f"runner.node.{runner_id}.{self.spec.id}") self._dealer: ZMQStream = None # type: ignore[assignment] self._outbox: zmq.Socket = None # type: ignore[assignment] self._inbox: ZMQStream = None # type: ignore[assignment] self._node: Node | None = None self._depends: tuple[tuple[str, str], ...] | None = None self._nodes: dict[str, IdentifyValue] = {} self._counter = count() self._process_quitting = mp.Event() self._freerun = mp.Event() self._process_one = mp.Event() self._status: NodeStatus = NodeStatus.stopped self._status_lock = mp.RLock() self._to_process = 0 @property def outbox_address(self) -> str: if self.protocol == "ipc": path = config.tmp_dir / f"{self.runner_id}/nodes/{self.spec.id}/outbox" path.parent.mkdir(parents=True, exist_ok=True) return f"{self.protocol}://{str(path)}" else: raise NotImplementedError() @property def depends(self) -> tuple[tuple[str, str], ...] | None: """(node, signal) tuples of the wrapped node's dependencies""" if self._node is None: return None elif self._depends is None: self._depends = tuple( (edge.source_node, edge.source_signal) for edge in self._node.edges ) return self._depends @property def status(self) -> NodeStatus: with self._status_lock: return self._status @status.setter def status(self, status: NodeStatus) -> None: with self._status_lock: self._status = status
[docs] @classmethod def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: """ Target for multiprocessing.run, init the class and start it! """ runner = NodeRunner(spec=spec, **kwargs) try: def _handler(sig: int, frame: FrameType | None = None) -> None: raise KeyboardInterrupt() signal.signal(signal.SIGTERM, _handler) runner.init() runner._node = cast(Node, runner._node) runner._process_quitting.clear() runner._freerun.clear() runner._process_one.clear() for args, kwargs, epoch in runner.await_inputs(): runner.logger.debug( "Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch ) value = runner._node.process(*args, **kwargs) events = runner.store.add_value(runner._node.signals, value, runner._node.id, epoch) runner.scheduler.add_epoch() # node runners should not report epoch endings events = [e for e in events if e["node_id"] != "meta"] if events: runner.update_graph(events) runner.publish_events(events) except KeyboardInterrupt: runner.logger.debug("Got keyboard interrupt, quitting") except Exception as e: runner.error(e) finally: runner.deinit()
[docs] def await_inputs(self) -> Generator[tuple[list[Any], dict[str, Any], int]]: self._node = cast(Node, self._node) while not self._process_quitting.is_set(): # if we are not freerunning, keep track of how many times we are supposed to run, # and run until we aren't supposed to anymore! if not self._freerun.is_set(): if self._to_process == 0: self._process_one.wait() self._to_process -= 1 if self._to_process == 0: self._process_one.clear() ready = self.scheduler.await_node(self.spec.id) edges = self._node.edges inputs = self.store.collect(edges, ready["epoch"]) if inputs is None: inputs = {} args, kwargs = self.store.split_args_kwargs(inputs) yield args, kwargs, ready["epoch"]
[docs] def update_graph(self, events: list[Event]) -> None: self.scheduler.update(events)
[docs] def publish_events(self, events: list[Event]) -> None: msg = EventMsg(node_id=self.spec.id, value=events) self._outbox.send_multipart([b"event", msg.to_bytes()])
[docs] def init(self) -> None: self.logger.debug("Initializing") self.init_node() self.start_sockets() self.status = NodeStatus.waiting if self.depends else NodeStatus.ready self.identify() self.logger.debug("Initialization finished")
[docs] def deinit(self) -> None: self.logger.debug("Deinitializing") if self._node is not None: self._node.deinit() self.update_status(NodeStatus.closed) self.stop_loop() self.logger.debug("Deinitialization finished")
[docs] def identify(self) -> None: """ Send the command node an announce to say we're alive """ if self._node is None: raise RuntimeError( "Node was not initialized by the time we tried to " "identify ourselves to the command node." ) with self._status_lock: ann = IdentifyMsg( node_id=self.spec.id, value=IdentifyValue( node_id=self.spec.id, status=self.status, outbox=self.outbox_address, signals=[s.name for s in self._node.signals] if self._node.signals else None, slots=( [slot_name for slot_name in self._node.slots] if self._node.slots else None ), ), ) self._dealer.send_multipart([ann.to_bytes()]) self.logger.debug("Sent identification message: %s", ann)
[docs] def update_status(self, status: NodeStatus) -> None: """Update our internal status and announce it to the command node""" with self._status_lock: self.status = status msg = StatusMsg(node_id=self.spec.id, value=status) self._dealer.send_multipart([msg.to_bytes()])
[docs] def start_sockets(self) -> None: self._init_sockets() self.start_loop()
[docs] def init_node(self) -> None: self._node = Node.from_specification(self.spec, self.input_collection) self._node.init() self.scheduler = Scheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) self.scheduler.add_epoch()
def _init_sockets(self) -> None: self._dealer = self._init_dealer() self._outbox = self._init_outbox() self._inbox = self._init_inbox() def _init_dealer(self) -> ZMQStream: dealer = self.context.socket(zmq.DEALER) dealer.setsockopt_string(zmq.IDENTITY, self.spec.id) dealer.connect(self.command_router) dealer = ZMQStream(dealer, self.loop) dealer.on_recv(self.on_dealer) self.logger.debug("Connected to command node at %s", self.command_router) return dealer def _init_outbox(self) -> zmq.Socket: pub = self.context.socket(zmq.PUB) pub.setsockopt_string(zmq.IDENTITY, self.spec.id) if self.protocol == "ipc": pub.bind(self.outbox_address) else: raise NotImplementedError() # something like: # port = pub.bind_to_random_port(self.protocol) return pub def _init_inbox(self) -> ZMQStream: """ Init the subscriber, but don't attempt to subscribe to anything but the command yet! we do that when we get node Announces """ sub = self.context.socket(zmq.SUB) sub.setsockopt_string(zmq.IDENTITY, self.spec.id) sub.setsockopt_string(zmq.SUBSCRIBE, "") sub.connect(self.command_outbox) sub = ZMQStream(sub, self.loop) sub.on_recv(self.on_inbox) self.logger.debug("Subscribed to command outbox %s", self.command_outbox) return sub
[docs] def on_dealer(self, msg: list[bytes]) -> None: self.logger.debug("DEALER received %s", msg)
[docs] def on_inbox(self, msg: list[bytes]) -> None: try: message = Message.from_bytes(msg) self.logger.debug("INBOX received %s", msg) except Exception as e: self.logger.exception("Error decoding message %s %s", msg, e) return # FIXME: all this switching sux, # just have a decorator to register a handler for a given message type if message.type_ == MessageType.announce: message = cast(AnnounceMsg, message) self.on_announce(message) elif message.type_ == MessageType.event: message = cast(EventMsg, message) self.on_event(message) elif message.type_ == MessageType.process: message = cast(ProcessMsg, message) self.on_process(message) elif message.type_ == MessageType.stop: message = cast(StopMsg, message) self.on_stop(message) else: # log but don't throw - other nodes shouldn't be able to crash us self.logger.error(f"{message.type_} not implemented!") self.logger.debug("%s", message)
[docs] def on_announce(self, msg: AnnounceMsg) -> None: """ Store map, connect to the nodes we depend on """ self._node = cast(Node, self._node) with self._status_lock: depended_nodes = {edge.source_node for edge in self._node.edges} for node_id in msg.value["nodes"]: if node_id in depended_nodes and node_id not in self._nodes: # TODO: a way to check if we're already connected, without storing it locally? outbox = msg.value["nodes"][node_id]["outbox"] self._inbox.connect(outbox) self.logger.debug("Subscribed to %s at %s", node_id, outbox) self._nodes = msg.value["nodes"] if set(self._nodes) >= depended_nodes and self.status == NodeStatus.waiting: self.update_status(NodeStatus.ready)
[docs] def on_event(self, msg: EventMsg) -> None: events = msg.value if not self.depends: self.logger.debug("No dependencies, not storing events") return to_add = [e for e in events if (e["node_id"], e["signal"]) in self.depends] for event in to_add: self.store.add(event) self.scheduler.update(events)
[docs] def on_process(self, msg: ProcessMsg) -> None: """ Process a single graph iteration """ # TODO: handle inputs self._to_process += 1 self._process_one.set()
[docs] def on_stop(self, msg: StopMsg) -> None: """Stop processing!""" pid = mp.current_process().pid if pid is None: return self.logger.debug("Emitting sigterm to self %s", msg) os.kill(pid, signal.SIGTERM)
[docs] def error(self, err: Exception) -> None: """ Capture the error and traceback context from an exception using :class:`traceback.TracebackException` and send to command node to re-raise """ tbexception = "\n".join(traceback.format_tb(err.__traceback__)) self.logger.debug("Throwing error in main runner: %s", tbexception) msg = ErrorMsg( node_id=self.spec.id, value=ErrorValue( err_type=type(err), err_args=err.args, traceback=tbexception, ), ) self._dealer.send_multipart([msg.to_bytes()])
[docs] @dataclass class ZMQRunner(TubeRunner): """ A concurrent runner that uses zmq to broker events between nodes running in separate processes """ node_procs: dict[NodeID, mp.Process] = field(default_factory=dict) command: CommandNode | None = None quit_timeout: float = 10 """time in seconds to wait after calling deinit to wait before killing runner processes""" store: EventStore = field(default_factory=EventStore) _running: EventType = field(default_factory=mp.Event) _return_node: Return | None = None _init_lock: threading.Lock = field(default_factory=threading.Lock) _to_throw: ErrorValue | None = None @property def running(self) -> bool: with self._init_lock: return self._running.is_set()
[docs] def init(self) -> None: if self.running: return with self._init_lock: self._logger.debug("Initializing ZMQ runner") self.command = CommandNode(runner_id=self.runner_id) self.command.add_callback("inbox", self.on_event) self.command.add_callback("router", self.on_router) self.command.start() self._logger.debug("Command node initialized") for node_id, node in self.tube.nodes.items(): if isinstance(node, Return): self._return_node = node continue self.node_procs[node_id] = mp.Process( target=NodeRunner.run, args=(node.spec,), kwargs={ "runner_id": self.runner_id, "command_outbox": self.command.pub_address, "command_router": self.command.router_address, "input_collection": self.tube.input_collection, }, name=".".join([self.runner_id, node_id]), daemon=True, ) self.node_procs[node_id].start() self._logger.debug("Started node processes, awaiting ready") self.command.await_ready( [k for k, v in self.tube.nodes.items() if not isinstance(v, Return)] ) self._logger.debug("Nodes ready") self._running.set()
[docs] def deinit(self) -> None: if not self.running: return with self._init_lock: self.command = cast(CommandNode, self.command) self.command.stop() # wait for nodes to finish, if they don't finish in the timeout, kill them started_waiting = time() waiting_on = set(self.node_procs.values()) while time() < started_waiting + self.quit_timeout and len(waiting_on) > 0: _waiting = waiting_on.copy() for proc in _waiting: if not proc.is_alive(): waiting_on.remove(proc) for proc in waiting_on: self._logger.info( f"NodeRunner {proc.name} was still alive after timeout expired, killing it" ) proc.kill() proc.close() self._running.clear()
[docs] def process(self, **kwargs: Any) -> ReturnNodeType: if not self.running: self._logger.info("Runner called process without calling `init`, initializing now.") self.init() self._current_epoch = self.tube.scheduler.add_epoch() self.command = cast(CommandNode, self.command) self.command.process() self._logger.debug("awaiting epoch %s", self._current_epoch) self.tube.scheduler.await_epoch(self._current_epoch) if self._to_throw: self._throw_error() self._logger.debug("collecting return") return self.collect_return(self._current_epoch)
[docs] def on_event(self, msg: Message) -> None: self._logger.debug("EVENT received: %s", msg) if msg.type_ != MessageType.event: self._logger.debug(f"Ignoring message type {msg.type_}") return msg = cast(EventMsg, msg) for event in msg.value: self.store.add(event) self.tube.scheduler.update(msg.value) if self._return_node is not None: # mark the return node done if we've received the expected events for an epoch # do it here since we don't really run the return node like a real node # to avoid an unnecessary pickling/unpickling across the network epochs = set(e["epoch"] for e in msg.value) for epoch in epochs: if self.tube.scheduler.node_is_ready(self._return_node.id, epoch): self._logger.debug("Marking return node ready in epoch %s", epoch) self.tube.scheduler.done(epoch, self._return_node.id)
[docs] def on_router(self, msg: Message) -> None: if isinstance(msg, ErrorMsg): self._handle_error(msg)
[docs] def collect_return(self, epoch: int | None = None) -> Any: if epoch is None: raise ValueError("Must specify epoch in concurrent runners") if self._return_node is None: return None else: events = self.store.collect(self._return_node.edges, epoch) if events is None: return None args, kwargs = self.store.split_args_kwargs(events) self._return_node.process(*args, **kwargs) return self._return_node.get(keep=False)
def _handle_error(self, msg: ErrorMsg) -> None: """Cancel current epoch, stash error for process method to throw""" self._logger.error("Received error from node: %s", msg) self._to_throw = msg.value self.tube.scheduler.end_epoch(self._current_epoch) def _throw_error(self) -> None: errval = self._to_throw if errval is None: return # clear instance object and store locally, we aren't locked here. self._to_throw = None self._logger.debug( "Deinitializing before throwing error", ) self.deinit() # add the traceback as a note, # sort of the best we can do without using tblib err = errval["err_type"](*errval["err_args"]) tb_message = "\nError re-raised from node runner process\n\n" tb_message += "Original traceback:\n" tb_message += "-" * 20 + "\n" tb_message += errval["traceback"] err.add_note(tb_message) raise err
[docs] def enable_node(self, node_id: str) -> None: raise NotImplementedError()
[docs] def disable_node(self, node_id: str) -> None: raise NotImplementedError()