Source code for cheese.api

from typing import ClassVar, Iterable, Tuple, Dict, Any, Callable

from cheese.client import ClientManager, ClientStatistics
from cheese.client.gradio_client import GradioClientManager
from cheese.pipeline import Pipeline
from cheese.models import BaseModel

import cheese.utils.msg_constants as msg_constants
from cheese.utils.rabbit_utils import rabbitmq_callback

import pickle
from b_rabbit import BRabbit
from tqdm import tqdm
import time

# Master object for CHEESE
class CHEESEAPI:
    """
    API to access CHEESE master object.

    Commands are pickled and published to the main server over RabbitMQ;
    replies are received on a subscriber thread and handed over through a
    single-slot buffer. The main server is expected to be running already:
    the constructor performs a handshake and raises if no reply arrives
    within ``timeout`` seconds.

    :param host: Host for rabbitmq server. Normally just localhost if you are running locally
    :type host: str

    :param port: Port to run rabbitmq server on
    :type port: int

    :param timeout: Timeout (in seconds) for waiting for main server to respond
    :type timeout: float

    :param debug: Print debug messages for rabbitmq
    :type debug: bool

    :raises ConnectionError: If the main server does not answer the handshake.
    """
    def __init__(self, host : str = 'localhost', port : int = 5672, timeout : float = 10, debug : bool = False):
        self.timeout = timeout
        self.debug = debug

        # State read by main_listener. It must exist before the subscriber
        # thread starts, otherwise an early message would hit an unset attribute.
        # Any received values from main will be placed here
        self.buffer : Any = None
        # Stays False until the handshake has been sent; messages seen before
        # that are treated as leftovers from a previous run and dropped.
        self.connected_to_main : bool = False

        # Initialize rabbit MQ server
        self.connection = BRabbit(host=host, port=port)

        # Channel to get results back from main server
        self.subscriber = self.connection.EventSubscriber(
            b_rabbit = self.connection,
            routing_key = 'api',
            publisher_name = 'main',
            event_listener = self.main_listener
        )
        self.subscriber.subscribe_on_thread()

        # Channel to send commands to main server
        self.publisher = self.connection.EventPublisher(
            b_rabbit = self.connection,
            publisher_name = 'api'
        )

        # Check if main server is running
        self.publisher.publish('main', pickle.dumps(msg_constants.READY))
        self.connected_to_main = True
        if not self.await_result():
            raise ConnectionError("Main server not running")

    @rabbitmq_callback
    def main_listener(self, msg : bytes):
        """
        Callback for main server. Receives messages from main server and places them in buffer.

        Messages that arrive before the handshake has been sent are reported
        and discarded instead of being buffered.

        :param msg: Pickled payload published by the main server.
        :type msg: bytes
        """
        # SECURITY NOTE: pickle.loads executes arbitrary code from the payload.
        # Only use this with a RabbitMQ broker whose publishers are trusted.
        if not self.connected_to_main:
            if self.debug:
                print(f"Received message from main server: {pickle.loads(msg)}")
            print("Warning: RabbitMQ queue non-empty at startup. Consider restarting RabbitMQ server if unexpected errors arise.")
            return

        # Unpickle once and reuse the result for both the debug print and the buffer.
        payload = pickle.loads(msg)
        if self.debug:
            print(f"Received message from main server: {payload}")
        self.buffer = payload

    def await_result(self, time_step : float = 0.5):
        """
        Block until the listener places a value in the buffer, then return it
        and reset the buffer to None.

        :param time_step: How long to sleep between checks of the buffer, in seconds.
        :type time_step: float

        :return: The received value, or None if nothing arrived within ``self.timeout`` seconds.
        """
        # Measure elapsed wall-clock time rather than summing requested sleep
        # durations, so the wait ends at the timeout regardless of time_step.
        deadline = time.monotonic() + self.timeout
        while self.buffer is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                print("Warning: Timeout exceeded awaiting API result.")
                return None
            time.sleep(min(time_step, remaining))

        res = self.buffer
        self.buffer = None
        return res

    def _request(self, msg : Any):
        """
        Publish ``msg`` to the main server and wait for its reply.

        :param msg: Any picklable command understood by the main server.

        :return: The reply from the main server, or None on timeout.
        """
        # Drop anything still sitting in the buffer (e.g. a reply that arrived
        # after an earlier call had already timed out) so it is not returned
        # as the answer to this request. A stale reply arriving after the
        # publish below can still not be told apart from the real one.
        self.buffer = None
        self.publisher.publish('main', pickle.dumps(msg))
        return self.await_result()

    def launch(self) -> str:
        """
        Launch the frontend and return URL for users to access it.

        :return: Whatever the main server replies to the launch command
            (None if it does not reply within the timeout).
        """
        return self._request(msg_constants.LAUNCH)

    def create_client(self, id : int) -> Tuple[int, int]:
        """
        Create a client instance with given id.

        :param id: A unique identifying number for the client.
        :type id: int

        :return: Username and password user can use to log in to CHEESE
            (None if the main server does not reply within the timeout).
        """
        return self._request(f"{msg_constants.ADD}|{id}")

    def remove_client(self, id : int):
        """
        Remove client with given id.

        :param id: A unique identifying number for the client.
        :type id: int

        :return: The main server's reply, or None on timeout.
        """
        return self._request(f"{msg_constants.REMOVE}|{id}")

    def get_stats(self) -> Dict:
        """
        Get various statistics in the form of a dictionary.

        :return: Dictionary containing following statistics:
            - num_clients: Number of clients connected to CHEESE
            - num_busy_clients: Number of clients currently working on a task
            - num_tasks: Number of tasks completed overall
            - client_stats: Dictionary of client statistics
            - model_stats: Dictionary of model statistics
            - pipeline_stats: Dictionary of pipeline statistics

            None is returned if the main server does not reply within the timeout.
        """
        return self._request(msg_constants.STATS)

    def draw(self):
        """
        Draws a sample from data pipeline and creates a task to send to clients.
        Does nothing if no free clients. This check is overridden if draw_always is set to True.

        The command is sent without waiting for a reply.
        """
        self.publisher.publish('main', pickle.dumps(msg_constants.DRAW))

    def progress_bar(self, max_tasks : int, access_stat : Callable, call_every : Callable = None, check_every : float = 1.0):
        """
        This function shows a progress bar via tqdm for some given stat.
        The bar advances by one each time the stat is observed to change.
        Blocks execution. Not recommended for interactive use.

        :param max_tasks: The maximum number of tasks to show progress to before returning
        :type max_tasks: int

        :param access_stat: Some callable that returns a stat we want to see progress for (i.e. as an integer).
        :type access_stat: Callable[[], int]

        :param call_every: Optional callable invoked on every polling iteration, before the stat is checked.
        :type call_every: Callable[[], None]

        :param check_every: How often to check for updates to the stat in seconds.
        :type check_every: float
        """
        # Keep the last observed value across ticks instead of re-reading a
        # fresh baseline for each one; re-reading would silently swallow any
        # change that happens between detecting one update and the next read.
        last_seen = access_stat()
        for _ in tqdm(range(max_tasks)):
            while True:
                if call_every:
                    call_every()
                latest = access_stat()
                if latest != last_seen:
                    last_seen = latest
                    break
                time.sleep(check_every)