from abc import abstractmethod
import pickle
from cheese.data import BatchElement
from cheese.tasks import Task
from b_rabbit import BRabbit
from cheese.utils.rabbit_utils import rabbitmq_callback
class BaseModel:
    """
    Base class for a model that receives tasks over RabbitMQ, processes them, and publishes them back.

    Subclasses implement :meth:`process`. Incoming messages are delivered to :meth:`dequeue_task`
    (registered as the subscriber's event listener in :meth:`init_connection`). They are appended
    to ``task_queue``, processed first-in first-out by :meth:`handle_queued_tasks`, and published
    again with :meth:`queue_task`.
    """

    def __init__(self):
        self.publisher = None  # created in init_connection
        self.subscriber = None  # created in init_connection
        self.task_queue = []  # pending tasks, consumed from the front (FIFO)
        self.working = False  # Is task loop running?

    def init_connection(self, connection: BRabbit):
        """
        Initialize RabbitMQ connection.

        Creates a publisher named ``'model'``. Also creates a subscriber that listens on routing key
        ``'model'`` for events from publisher ``'client'`` and passes them to :meth:`dequeue_task`.
        The subscriber is then started with ``subscribe_on_thread()``. Presumably this delivers
        events on a separate thread; confirm against b_rabbit.

        :param connection: Connection used to build the publisher and subscriber.
        :type connection: BRabbit
        """
        self.publisher = connection.EventPublisher(
            b_rabbit=connection,
            publisher_name='model',
        )
        self.subscriber = connection.EventSubscriber(
            b_rabbit=connection,
            routing_key='model',
            publisher_name='client',
            event_listener=self.dequeue_task,
        )
        self.subscriber.subscribe_on_thread()

    @abstractmethod
    def process(self, data: BatchElement) -> BatchElement:
        """
        Process BatchElement with model. Assume the inputs to the model are in the BatchElement,
        then use them to create some outputs. The outputs should be added to the BatchElement before it is returned.

        :param data: Element holding the model inputs.
        :type data: BatchElement
        :return: The same element (or a replacement) with the model outputs added.
        :rtype: BatchElement
        :raises NotImplementedError: If a subclass does not override this method.
        """
        # BaseModel does not use ABCMeta, so @abstractmethod is not enforced when the class is
        # instantiated. Raise explicitly instead of silently returning None: a None here would
        # become task.data and fail later as an AttributeError on ``task.data.trip``.
        raise NotImplementedError(f"{type(self).__name__} must implement process()")

    def handle_queued_tasks(self):
        """
        Handle every task in queue. New tasks can be added to queue if needed. Should not be called again if still running.

        Each task is popped from the front of ``task_queue``, and its ``data`` is replaced by the
        result of :meth:`process`. The task is then published with :meth:`queue_task`. The
        ``working`` flag is cleared on exit even if processing or publishing raises.

        :raises RuntimeError: If the processing loop is already running.
        """
        if self.working:
            raise RuntimeError("Error: Tried to handle model queue twice after already calling method once.")
        self.working = True
        try:
            while self.task_queue:
                task = self.task_queue.pop(0)
                task.data = self.process(task.data)
                self.queue_task(task)
        finally:
            # Without this, an exception in process()/queue_task() would leave the model
            # permanently marked busy. dequeue_task would then never drain the queue again.
            self.working = False

    def queue_task(self, task: Task):
        """
        Publish a task back to the client.

        The task is serialized with :func:`pickle.dumps`. It is routed to ``'pipeline'`` when
        ``task.data.trip`` equals ``task.data.trip_max``, and to ``'active'`` otherwise.

        :param task: The task to queue
        :type task: Task
        """
        route = 'pipeline' if task.data.trip == task.data.trip_max else 'active'
        self.publisher.publish(
            routing_key=route,
            payload=pickle.dumps(task),
        )

    @rabbitmq_callback
    def dequeue_task(self, tasks: bytes):
        """
        Receive one pickled task from RabbitMQ and queue it for processing.

        Unpickles the payload, increments ``task.data.trip``, and appends the task to
        ``task_queue``. If the processing loop is not already running, it is started.

        :param tasks: Pickled task payload, as produced by :meth:`queue_task`.
        """
        # SECURITY: pickle.loads can execute arbitrary code embedded in the payload. This is only
        # safe if every publisher that can reach this routing key is trusted.
        # TODO(review): consider a data-only serialization format.
        task = pickle.loads(tasks)
        task.data.trip += 1
        self.task_queue.append(task)
        if not self.working:
            self.handle_queued_tasks()