from abc import abstractmethod
from typing import Iterable, Any
import pickle
from cheese.data import BatchElement
from cheese.tasks import Task
from b_rabbit import BRabbit
from cheese.utils.rabbit_utils import rabbitmq_callback
[docs]class BaseModel:
"""
BaseModel object handles anything that may require a model for processing data. It can also be used more generally
just to handle data processing separately from the pipeline and client.
:param batch_size: The maximum number of elements to process at once. If there are not this many elements available,
the model will simply process everything that is in the task queue.
:type batch_size: int
"""
def __init__(self, batch_size : int = 1):
self.publisher = None
self.subscriber = None
self.task_queue = []
self.working = False # Is task loop running?
self.batch_size = batch_size
[docs] def get_stats(self) -> dict:
"""
Get statistics about the model.
"""
return {"num_tasks": len(self.task_queue)}
[docs] def init_connection(self, connection : BRabbit):
"""
Initialize RabbitMQ connection
"""
self.publisher = connection.EventPublisher(
b_rabbit = connection,
publisher_name = 'model'
)
self.subscriber_client = connection.EventSubscriber(
b_rabbit = connection,
routing_key = 'model',
publisher_name = 'client',
event_listener = self.dequeue_task
)
self.subscriber_pipeline = connection.EventSubscriber(
b_rabbit = connection,
routing_key = 'model',
publisher_name = 'pipeline',
event_listener = self.dequeue_task
)
self.subscriber_client.subscribe_on_thread()
self.subscriber_pipeline.subscribe_on_thread()
[docs] @abstractmethod
def process(self, data : Iterable[BatchElement]) -> Iterable[BatchElement]:
"""
Process BatchElement with model. Assume the inputs to the model are in the BatchElement,
then use them to create some outputs. The outputs should be added to the BatchElement before it is returned.
:param data: The data to process. Can be an iterable of BatchElements, or a single one, depending on use-case.
"""
pass
[docs] def handle_queued_tasks(self):
"""
Handle every task in queue. New tasks can be added to queue if needed. Should not be called again if still running.
"""
if self.working:
raise Exception("Error: Tried to handle model queue twice after already calling method once.")
self.working = True
while self.task_queue:
tasks = self.task_queue[:self.batch_size]
self.task_queue = self.task_queue[self.batch_size:]
data_list = self.process([task.data for task in tasks])
for i, data in enumerate(data_list):
tasks[i].data = data
while tasks:
self.queue_task(tasks.pop(0))
self.working = False
[docs] def queue_task(self, task : Task):
"""
Creates a task and queue to client.
:param task: The task to queue
:type task: Task
"""
if task.data.trip == task.data.trip_max:
route = 'pipeline'
else:
route = 'active'
tasks = pickle.dumps(task)
self.publisher.publish(
routing_key = route,
payload = tasks
)
@rabbitmq_callback
def dequeue_task(self, tasks : str):
"""Check inbound queue for completed task."""
task = pickle.loads(tasks)
task.data.trip += 1
self.task_queue.append(task)
if not self.working:
self.handle_queued_tasks()