# Source code for cheese.pipeline.datasets

from abc import abstractmethod
from typing import Iterable, Dict, Any

from datasets import Dataset

from cheese.pipeline import Pipeline
from cheese.utils import safe_mkdir

import pandas as pd

class DatasetPipeline(Pipeline):
    """
    Base class for any pipeline whose results are accumulated row by row and written
    to ``write_path``. In ``"arrow"`` format the results are held as a
    ``datasets.Dataset``; in ``"csv"`` format they are held as a pandas DataFrame.

    Subclasses are expected to assign ``write_path`` before loading or saving.

    :param format: Format to save result dataset to. Defaults to csv. Can be arrow or csv.
    :type format: str

    :param save_every: Save dataset whenever this number of rows is added.
    :type save_every: int
    """
    def __init__(self, format : str = "csv", save_every : int = 1):
        super().__init__()

        # Destination of the results. Left unset here; load_dataset/save_dataset raise
        # if it is still None when they need it.
        self.write_path : str = None
        # Accumulated results: a datasets.Dataset in "arrow" format, a pandas DataFrame otherwise.
        self.res_dataset : Dataset = None

        self.format = format
        self.save_every = save_every
        # Number of rows added since the last save.
        self.save_accum = 0

    def load_dataset(self) -> bool:
        """
        Loads the results dataset from ``write_path``. Returns False if the load fails.
        Assumes write_path has been set already.

        :return: Whether load was successful
        :rtype: bool

        :raises Exception: If ``write_path`` has not been set.
        """
        if self.write_path is None:
            raise Exception("Error: Attempted to load results dataset without ever specifiying a path to write it to")

        try:
            if self.format == "arrow":
                # keep_in_memory=True so the loaded data is not memory-mapped from write_path;
                # otherwise a later save_to_disk(write_path) is refused by `datasets`
                # because a dataset cannot overwrite its own cache files.
                self.res_dataset = Dataset.load_from_disk(self.write_path, keep_in_memory = True)
            elif self.format == "csv":
                self.res_dataset = pd.read_csv(self.write_path)
            # NOTE(review): for an unrecognised format nothing is loaded but True is still returned.
            return True
        except Exception:
            # Best-effort: any load error is reported as False instead of propagating.
            # Catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            return False

    def save_dataset(self):
        """
        Saves the result dataset to ``write_path`` using ``self.format``.
        Does nothing if there is no data to save yet.

        :raises Exception: If there is data to save but ``write_path`` has not been set.
        """
        if self.res_dataset is None:
            return
        if self.write_path is None:
            raise Exception("Error: Attempted to save result dataset without ever specifiying a path to write to")

        if self.format == "arrow":
            self.res_dataset.save_to_disk(self.write_path)
        elif self.format == "csv":
            self.res_dataset.to_csv(self.write_path, index = False)
        # NOTE(review): any other format value silently skips saving.

    def add_row_to_dataset(self, row : Dict[str, Any]):
        """
        Add a single row to the result dataset. Saves once ``save_every`` rows have
        been added since the last save.

        :param row: The row, as a dictionary mapping column name to value, to add to the result dataset
        :type row: Dict[str, Any]
        """
        if self.format == "arrow":
            if self.res_dataset is None:
                # Dataset.from_dict expects column -> list of values.
                self.res_dataset = Dataset.from_dict({key : [value] for key, value in row.items()})
            else:
                # datasets.Dataset has no append(); add_item returns a new Dataset
                # with the row appended.
                self.res_dataset = self.res_dataset.add_item(row)
        else:
            new_df = pd.DataFrame({key : [value] for key, value in row.items()})
            if self.res_dataset is None:
                self.res_dataset = new_df
            else:
                self.res_dataset = pd.concat([self.res_dataset, new_df], ignore_index = True)

        self.save_accum += 1
        if self.save_accum >= self.save_every:
            self.save_dataset()
            self.save_accum = 0