# Source code for lories.connectors.csv

# -*- coding: utf-8 -*-
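"""
lories.connectors.csv
~~~~~~~~~~~~~~~~~~~~~

This module provides the CSV database connector, which reads and writes resource data from and to CSV files.

"""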

from __future__ import annotations

import os
from typing import Mapping, Optional, Tuple

import pandas as pd
from lories.connectors import ConnectionError, Database, register_connector_type
from lories.core.configs import ConfigurationError
from lories.io import csv
from lories.typing import Configurations, Resources, Timestamp
from lories.util import ceil_date, floor_date, parse_freq


# noinspection PyShadowingBuiltins
@register_connector_type("csv")
class CsvDatabase(Database):
    """
    Connector that reads and writes resource data from and to CSV files.

    Data is either read from a single file (``file`` configuration), which is loaded once
    on :meth:`connect`, or from a directory (``dir`` configuration) of files whose names
    are formatted from timestamps with :attr:`format`.
    """

    # Data loaded from the single configured file on connect, if any
    _data: Optional[pd.DataFrame] = None
    # Absolute path of the single configured file, if any
    _data_path: Optional[str] = None
    # Absolute directory that holds the CSV files
    _data_dir: str

    index_column: str = "timestamp"
    index_type: str = "timestamp"

    # Passed through to the csv writer functions
    override: bool = False
    # Write via csv.write_files (sliced by freq) instead of a single file per write
    slice: bool = False

    freq: str = "D"
    # strftime format of the file names; derived from freq if not configured
    format: Optional[str] = None
    suffix: Optional[str] = None

    decimal: str = "."
    separator: str = ","

    # Explicit column names, keyed by resource key. Shared class default; only ever read.
    columns: Mapping[str, str] = {}

    pretty: bool = False

    # noinspection PyTypeChecker
    def configure(self, configs: Configurations) -> None:
        """
        Configure the file locations, index handling and file naming of this database.

        :param configs: the connector configurations.
        :raises ConfigurationError: if the index type is unknown, or if no file name
            format is configured and none can be derived from the frequency.
        """
        super().configure(configs)

        data_dir = configs.get("dir", default=None)
        if data_dir is not None:
            # expanduser only acts on a leading "~" and is a no-op otherwise
            data_dir = os.path.expanduser(data_dir)
            if not os.path.isabs(data_dir):
                data_dir = os.path.join(configs.dirs.data, data_dir)

        data_path = configs.get("file", default=None)
        if data_path is not None:
            data_path = os.path.expanduser(data_path)
            if not os.path.isabs(data_path):
                if data_dir is None:
                    data_dir = configs.dirs.data
                data_path = os.path.join(data_dir, data_path)
            elif data_dir is None:
                # An absolute file path without a directory defines the data directory
                data_dir = os.path.dirname(data_path)

        if data_dir is None:
            data_dir = configs.dirs.data

        self._data_dir = data_dir
        self._data_path = data_path

        self.index_column = configs.get("index_column", default=CsvDatabase.index_column)
        index_type = configs.get("index_type", default=CsvDatabase.index_type)
        if isinstance(index_type, str):
            # None is an accepted index type and must not be lower-cased
            index_type = index_type.lower()
        if index_type not in ["timestamp", "unix", "none", None]:
            raise ConfigurationError(f"Unknown index type: {index_type}")
        self.index_type = index_type

        self.override = configs.get_bool("override", default=CsvDatabase.override)
        self.slice = configs.get_bool("slice", default=CsvDatabase.slice)

        self.freq = parse_freq(configs.get("freq", default=CsvDatabase.freq))
        file_format = configs.get("format", default=CsvDatabase.format)
        if file_format is not None:
            self.format = file_format
        elif self.freq == "Y":
            self.format = "%Y"
        elif self.freq == "M":
            self.format = "%Y-%m"
        elif self.freq.endswith("D"):
            self.format = "%Y%m%d"
        elif self.freq.endswith(("h", "min", "s")):
            self.format = "%Y%m%d_%H%M%S"
        else:
            raise ConfigurationError(f"Invalid frequency: {self.freq}")

        self.suffix = configs.get("suffix", default=CsvDatabase.suffix)
        if self.suffix is not None:
            self.format += f"_{self.suffix}"

        self.decimal = configs.get("decimal", default=CsvDatabase.decimal)
        self.separator = configs.get("separator", default=CsvDatabase.separator)

        self.pretty = configs.get_bool("pretty", default=CsvDatabase.pretty)
        self.columns = configs.get("columns", default=CsvDatabase.columns)

    def _build_columns(self, resources: Optional[Resources] = None) -> Mapping[str, str]:
        """
        Build the mapping of resource IDs to their CSV column names.

        The column name of a resource is, in order of precedence: the entry for the
        resource key in the configured ``columns`` mapping; the full resource name with
        unit, if ``pretty`` is enabled and the resource has a name; otherwise the resource
        ``column`` attribute, defaulting to the resource key.

        :param resources: optional resources added to (or overriding) the mapping built
            from this database's own resources.
        :returns: a dictionary of resource IDs to column names.
        """

        def _get_column(resource) -> str:
            if resource.key in self.columns:
                return self.columns[resource.key]
            if self.pretty and "name" in resource:
                return resource.full_name(unit=True)
            return resource.get("column", default=resource.key)

        columns = {r.id: _get_column(r) for r in self.resources}
        if resources is not None:
            columns.update({r.id: _get_column(r) for r in resources})
        return columns

    def _select_resource_columns(self, data: Optional[pd.DataFrame], resources: Resources) -> pd.DataFrame:
        """
        Extract one column per resource from data, renamed to the resource ID.

        Resources whose column is missing in data are returned as empty series,
        so that every requested resource appears in the result.
        """
        if data is None:
            data = pd.DataFrame()
        columns = self._build_columns(resources)
        results = []
        for resource in resources:
            resource_column = columns[resource.id]
            if resource_column not in data.columns:
                results.append(pd.Series(name=resource.id))
                continue
            resource_data = data.loc[:, resource_column].copy()
            resource_data.name = resource.id
            results.append(resource_data)
        if len(results) == 0:
            # pd.concat raises on an empty list
            return pd.DataFrame()
        return pd.concat(results, axis="columns")

    def _slice_data(
        self,
        data: pd.DataFrame,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
    ) -> pd.DataFrame:
        """Restrict time-indexed data to the inclusive range between start and end, where given."""
        if self.index_type not in ["timestamp", "unix"]:
            return data
        if not pd.isna(start):
            data = data.loc[data.index >= start]
        if not pd.isna(end):
            data = data.loc[data.index <= end]
        return data

    def _read_boundary_data(self, resources: Resources, position: int) -> Optional[pd.DataFrame]:
        """
        Return the loaded single-file data, or read the file at the given position
        (0 for the first, -1 for the last) of the data directory's file list.

        :returns: the data, or None if no files exist.
        """
        if self._data is not None:
            return self._data
        files = csv.get_files(
            self._data_dir,
            self.freq,
            self.format,
            timezone=self.timezone,
        )
        if len(files) == 0:
            return None
        return csv.read_file(
            files[position],
            index_column=self.index_column,
            index_type=self.index_type,
            timezone=self.timezone,
            separator=self.separator,
            decimal=self.decimal,
            rename=self._build_columns(resources),
        )

    def connect(self, resources: Resources) -> None:
        """
        Create the data directory if necessary and load the single configured file, if any.

        :raises ConnectionError: if the directory cannot be created or the file cannot be read.
        """
        try:
            os.makedirs(self._data_dir, exist_ok=True)
            if self._data_path is not None:
                self._data = csv.read_file(
                    self._data_path,
                    index_column=self.index_column,
                    index_type=self.index_type,
                    timezone=self.timezone,
                    separator=self.separator,
                    decimal=self.decimal,
                    rename=self._build_columns(resources),
                )
        except IOError as e:
            raise ConnectionError(self, str(e)) from e

    def disconnect(self) -> None:
        """Drop the data loaded from the single configured file."""
        self._data = None

    def is_connected(self) -> bool:
        """Always report this database as connected."""
        return True

    def read(
        self,
        resources: Resources,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
    ) -> pd.DataFrame:
        """
        Read the data of the given resources.

        If neither start nor end is given and the data is time-indexed, only the row
        nearest to the current time is returned.

        :raises ConnectionError: if reading the CSV files fails with an IOError.
        """

        def _infer_dates(s=start, e=end) -> Tuple[pd.Timestamp, pd.Timestamp]:
            # Without any bounds, read the period (of length freq) containing now
            if all(pd.isna(d) for d in [s, e]):
                n = pd.Timestamp.now(tz=self.timezone)
                s = floor_date(n, timezone=self.timezone, freq=self.freq)
                e = ceil_date(n, timezone=self.timezone, freq=self.freq)
            return s, e

        try:
            if self._data is not None:
                # The single loaded file is not filtered by the csv reader, so slice here
                data = self._slice_data(self._data, start, end)
            else:
                data = csv.read_files(
                    self._data_dir,
                    self.freq,
                    self.format,
                    *_infer_dates(),
                    index_column=self.index_column,
                    index_type=self.index_type,
                    timezone=self.timezone,
                    separator=self.separator,
                    decimal=self.decimal,
                )
            if (
                data is not None
                and len(data.index) > 0
                and self.index_type in ["timestamp", "unix"]
                and all(pd.isna(d) for d in [start, end])
            ):
                now = pd.Timestamp.now(tz=self.timezone)
                index = data.index.tz_convert(self.timezone).get_indexer([now], method="nearest")
                data = data.iloc[[index[-1]], :]
            return self._select_resource_columns(data, resources)
        except IOError as e:
            raise ConnectionError(self, str(e)) from e

    # noinspection PyTypeChecker
    def read_first(self, resources: Resources) -> Optional[pd.DataFrame]:
        """
        Read the first row of the given resources' data.

        :returns: a single-row DataFrame, or None if no data is available.
        :raises ConnectionError: if reading the CSV file fails with an IOError.
        """
        try:
            data = self._read_boundary_data(resources, 0)
            if data is None or data.empty:
                return None
            return self._select_resource_columns(data, resources).head(1)
        except IOError as e:
            raise ConnectionError(self, str(e)) from e

    # noinspection PyTypeChecker
    def read_last(self, resources: Resources) -> Optional[pd.DataFrame]:
        """
        Read the last row of the given resources' data.

        :returns: a single-row DataFrame, or None if no data is available.
        :raises ConnectionError: if reading the CSV file fails with an IOError.
        """
        try:
            data = self._read_boundary_data(resources, -1)
            if data is None or data.empty:
                return None
            return self._select_resource_columns(data, resources).tail(1)
        except IOError as e:
            raise ConnectionError(self, str(e)) from e

    def write(self, data: pd.DataFrame) -> None:
        """
        Write data into the data directory, renaming resource IDs to their column names.

        With ``slice`` enabled, writing is delegated to csv.write_files; otherwise a
        single file named after the first index timestamp (formatted with
        :attr:`format`) is written.

        NOTE(review): a configured single ``file`` is not used as the write target here;
        confirm this is intended.
        """
        if data is None or len(data.index) == 0:
            # Nothing to write; also avoids indexing data.index[0] below
            return

        columns = self._build_columns(self.resources)
        kwargs = {
            "timezone": self.timezone,
            "separator": self.separator,
            "decimal": self.decimal,
            "override": self.override,
            "rename": columns,
        }
        index_name = self.index_column.title() if self.pretty else self.index_column
        # Rename the index on a shallow copy, so the caller's frame is not modified
        data = data.copy(deep=False)
        data.index = data.index.rename(index_name)

        if self.slice:
            csv.write_files(data, self._data_dir, self.freq, self.format, **kwargs)
        else:
            csv_file = os.path.join(self._data_dir, data.index[0].strftime(self.format) + ".csv")
            csv.write_file(data, csv_file, **kwargs)