Source code for cheese.models

from abc import abstractmethod

import pickle

from cheese.data import BatchElement
from cheese.tasks import Task

from b_rabbit import BRabbit
from cheese.utils.rabbit_utils import rabbitmq_callback

class BaseModel:
    """
    Base class for a model that receives tasks over RabbitMQ, processes them
    and publishes them back out.

    Subclasses implement :meth:`process`. Incoming messages are delivered to
    :meth:`dequeue_task` (registered as the subscriber's event listener in
    :meth:`init_connection`), buffered in ``task_queue`` and drained in FIFO
    order by :meth:`handle_queued_tasks`.

    NOTE(review): ``process`` is decorated with ``@abstractmethod`` but this
    class does not use ``abc.ABCMeta`` / ``abc.ABC``, so the decorator is not
    enforced: ``BaseModel`` and subclasses missing ``process`` can still be
    instantiated.
    """

    def __init__(self):
        self.publisher = None  # EventPublisher; created in init_connection()
        self.subscriber = None  # EventSubscriber; created in init_connection()
        self.task_queue = []  # Received tasks not yet processed (popped from the front)
        self.working = False  # Is task loop running?

    def init_connection(self, connection : BRabbit):
        """
        Initialize RabbitMQ connection.

        Creates a publisher named ``'model'`` and a subscriber bound to routing
        key ``'model'`` for events from publisher ``'client'``. Each received
        event is passed to :meth:`dequeue_task`. The subscriber is started with
        ``subscribe_on_thread()`` (presumably listening on a background thread;
        confirm against the b_rabbit documentation).

        :param connection: Connection used to build the publisher and subscriber.
        :type connection: BRabbit
        """
        self.publisher = connection.EventPublisher(
            b_rabbit = connection,
            publisher_name = 'model'
        )
        self.subscriber = connection.EventSubscriber(
            b_rabbit = connection,
            routing_key = 'model',
            publisher_name = 'client',
            event_listener = self.dequeue_task
        )
        self.subscriber.subscribe_on_thread()

    @abstractmethod
    def process(self, data : BatchElement) -> BatchElement:
        """
        Process BatchElement with model. Assume the inputs to the model are in
        the BatchElement, then use them to create some outputs. The outputs
        should be added to the BatchElement before it is returned.

        :param data: Element holding the model inputs.
        :type data: BatchElement
        :return: The same element (or a replacement) with outputs added.
        :rtype: BatchElement
        """
        pass

    def handle_queued_tasks(self):
        """
        Handle every task in queue. Tasks added to the queue while this runs
        are handled too. Should not be called again while still running.

        For each task, ``task.data`` is replaced by the result of
        :meth:`process` and the task is sent on via :meth:`queue_task`.

        :raises RuntimeError: If the loop is already running.
        """
        if self.working:
            raise RuntimeError("Error: Tried to handle model queue twice after already calling method once.")
        self.working = True
        try:
            while self.task_queue:
                task = self.task_queue.pop(0)
                task.data = self.process(task.data)
                self.queue_task(task)
        finally:
            # Always clear the flag. If process()/queue_task() raised and the
            # flag stayed True, dequeue_task would never restart the loop and
            # incoming tasks would accumulate unprocessed forever.
            self.working = False

    def queue_task(self, task : Task):
        """
        Serialize a task and publish it back out.

        Tasks whose ``data.trip`` equals ``data.trip_max`` are published with
        routing key ``'pipeline'``; all others with routing key ``'active'``.

        :param task: The task to queue
        :type task: Task
        """
        if task.data.trip == task.data.trip_max:
            route = 'pipeline'
        else:
            route = 'active'
        payload = pickle.dumps(task)
        self.publisher.publish(
            routing_key = route,
            payload = payload
        )

    @rabbitmq_callback
    def dequeue_task(self, tasks : bytes):
        """
        Receive a pickled task, increment its trip counter and queue it.

        Starts :meth:`handle_queued_tasks` if the loop is not already running;
        otherwise the running loop will pick the task up.

        :param tasks: Pickled ``Task`` as received from the message broker.
        :type tasks: bytes
        """
        # SECURITY NOTE(review): pickle.loads executes arbitrary code embedded in
        # the payload. Anyone able to publish to routing key 'model' can run
        # code in this process; only use on a trusted broker.
        task = pickle.loads(tasks)
        task.data.trip += 1
        self.task_queue.append(task)
        if not self.working:
            self.handle_queued_tasks()