Source code for cheese.api

from re import I
from typing import ClassVar, Iterable, Tuple, Dict, Any

from cheese.client import ClientManager
from cheese.client.gradio_client import GradioClientManager
from cheese.pipeline import Pipeline
from cheese.models import BaseModel

import cheese.utils.msg_constants as msg_constants
from cheese.utils.rabbit_utils import rabbitmq_callback

from b_rabbit import BRabbit

# Master object for CHEESE
[docs]class CHEESE: """ Main object to use for running tasks with CHEESE :param pipeline_cls: Class for pipeline :type pipeline_cls: Callable[, Pipeline] :param client_cls: Class for client :type client_cls: Callable[,GradioFront] if gradio :param model_cls: Class for model :type model_cls: Callable[,BaseModel] :param pipeline_kwargs: Additional keyword arguments to pass to pipeline constructor :type pipeline_kwargs: Dict[str, Any] :param gradio: Whether to use gradio or custom frontend :type gradio: bool """ def __init__( self, pipeline_cls, client_cls = None, model_cls = None, pipeline_kwargs : Dict[str, Any] = {}, model_kwargs : Dict[str, Any] = {}, gradio : bool = True ): self.gradio = gradio # Initialize rabbit MQ server self.connection = BRabbit(host='localhost', port=5672) # Channel for client to notify of task completion self.subscriber = self.connection.EventSubscriber( b_rabbit = self.connection, routing_key = 'main', publisher_name = 'client', event_listener = self.client_ping ) self.subscriber.subscribe_on_thread() # API components initialized self.pipeline : Pipeline = pipeline_cls(**pipeline_kwargs) self.model : BaseModel = model_cls(**model_kwargs) if model_cls is not None else None self.client_cls = client_cls if gradio: self.client_manager = GradioClientManager() else: self.client_manager = ClientManager() self.pipeline.init_connection(self.connection) self.client_manager.init_connection(self.connection) if self.model is not None: self.model.init_connection(self.connection) self.clients = 0 self.busy_clients = 0 self.finished = False # For when pipeline is exhausted
[docs] def launch(self) -> str: """ Launch the frontend and return URL for users to access it. """ return self.client_manager.init_front(self.client_cls)
@rabbitmq_callback def client_ping(self, msg): """ Method for ClientManager to ping the API when it needs more tasks or has taken a task """ msg = msg.decode('utf-8') if msg == msg_constants.SENT: # Client sent task to pipeline, needs a new one self.busy_clients -= 1 self.draw() elif msg == msg_constants.RECEIVED: self.busy_clients += 1 else: raise Exception("Error: Client pinged master with unknown message")
[docs] def create_client(self, id : int) -> Tuple[int, int]: """ Create a client instance with given id. :param id: A unique identifying number for the client. :type id: int :return: Username and password user can use to log in to CHEESE """ id, pwd = self.client_manager.add_client(id) self.clients += 1 self.draw() # pre-emptively draw a task for the client to pick up return id, pwd
[docs] def draw(self): """ Draws a sample from data pipeline and creates a task to send to clients. Does nothing if no free clients. """ if self.busy_clients >= self.clients: return exhausted = not self.pipeline.queue_task() if exhausted and self.pipeline.exhausted(): # finished, so we can stop self.finished = True