from __future__ import annotations

import re
import string
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, cast

import numpy as np
import pandas as pd

from dask.dataframe._compat import PANDAS_GE_220, PANDAS_GE_300
from dask.dataframe._pyarrow import is_object_string_dtype
from dask.dataframe.core import tokenize
from dask.dataframe.io.utils import DataFrameIOFunction
from dask.utils import random_state_data

__all__ = [
    "make_timeseries",
    "with_spec",
    "ColumnSpec",
    "RangeIndexSpec",
    "DatetimeIndexSpec",
    "DatasetSpec",
]

# Per-method defaults for integer generation on ``numpy.random.RandomState``:
# method name -> (positional args, keyword args). Explicit ``ColumnSpec``
# args/kwargs override these (see ``make_int``).
default_int_args: dict[str, tuple[tuple[Any, ...], dict[str, Any]]] = {
    "poisson": ((), {"lam": 1000}),
    "normal": ((), {"scale": 1000}),
    "uniform": ((), {"high": 1000}),
    "binomial": ((1000, 0.5), {}),
    "random": ((0,), {"high": 1000}),
}


@dataclass
class ColumnSpec:
    """Encapsulates properties of a family of columns with the same dtype.
    A different generation method can be specified for integer dtypes
    ("poisson", "uniform", "binomial", etc.)

    Notes
    -----
    This API is still experimental, and will likely change in the future.
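
    Examples
    --------
    A minimal sketch: two Poisson-distributed int columns, which ``with_spec``
    names "p1" and "p2":

    >>> ColumnSpec(prefix="p", dtype=int, number=2, method="poisson")  # doctest: +SKIP
    """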

    prefix: str | None = None
    """Column prefix. If not specified, will default to str(dtype)"""

    dtype: str | type | None = None
    """Column data type. Only supports numpy dtypes"""

    number: int = 1
    """How many columns to create with these properties. Default 1.
    If more than one columns are specified, they will be numbered: "int1", "int2", etc."""

    nunique: int | None = None
    """For a "category" column, how many unique categories to generate"""

    choices: list = field(default_factory=list)
    """For a "category" or str column, list of possible values"""

    low: int | None = None
    """For an int column with random=True, the low end of the range. Optional, since
    ``randint`` defaults ``low`` to 0 when only ``high`` is given."""

    high: int | None = None
    """For an int column, high end of range"""

    length: int | None = None
    """For a str or "category" column with random=True, the length of each generated string"""

    random: bool = False
    """For an int column, whether to use ``randint``. For a string column, whether to
    produce random strings of the specified ``length``"""

    method: str | None = None
    """For an int column, method to use when generating the value, such as "poisson", "uniform", "binomial".
    Default "poisson". Delegates to the same method of ``RandomState``"""

    args: tuple[Any, ...] = field(default_factory=tuple)
    """Args to pass into the method"""

    kwargs: dict[str, Any] = field(default_factory=dict)
    """Any other kwargs to pass into the method"""


@dataclass
class RangeIndexSpec:
    """Properties of the dataframe RangeIndex

    Notes
    -----
    This API is still experimental, and will likely change in the future.
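
    Examples
    --------
    >>> RangeIndexSpec(dtype="int32", step=2)  # doctest: +SKIP
    """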

    dtype: str | type = int
    """Index dtype"""

    step: int = 1
    """Step for a RangeIndex"""


@dataclass
class DatetimeIndexSpec:
    """Properties of the dataframe DatetimeIndex

    Notes
    -----
    This API is still experimental, and will likely change in the future.
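
    Examples
    --------
    >>> DatetimeIndexSpec(start="2000-01-01", freq="1h", partition_freq="1D")  # doctest: +SKIP
    """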

    dtype: str | type = "datetime64[ns]"
    """Index dtype"""

    start: str | None = None
    """First value of the index"""

    freq: str = "1H"
    """Frequency for the index ("1H", "1D", etc.)"""

    partition_freq: str | None = None
    """Partition frequency ("1D", "1M", etc.)"""


@dataclass
class DatasetSpec:
    """Defines a dataset with random data, such as which columns and data types to generate

    Notes
    -----
    This API is still experimental, and will likely change in the future.
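
    Examples
    --------
    A sketch of a 100-row, two-partition dataset with two float columns:

    >>> DatasetSpec(
    ...     npartitions=2,
    ...     nrecords=100,
    ...     column_specs=[ColumnSpec(dtype=float, number=2)],
    ... )  # doctest: +SKIP
    """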

    npartitions: int = 1
    """How many partitions generate in the dataframe. If the dataframe has a DatetimeIndex, specify
    its ``partition_freq`` instead"""

    nrecords: int = 1000
    """Total number of records to generate"""

    index_spec: RangeIndexSpec | DatetimeIndexSpec = field(
        default_factory=RangeIndexSpec
    )
    """Properties of the index"""

    column_specs: list[ColumnSpec] = field(default_factory=list)
    """List of column definitions"""


def make_float(n, rstate, random=False, **kwargs):
    """Generate ``n`` floats: uniform in ``[0, 1)`` via ``RandomState.random``
    when ``random=True``, otherwise uniform in ``[-1, 1)``."""
    kwargs.pop("dtype", None)
    kwargs.pop("args", None)
    if random:
        return rstate.random(size=n, **kwargs)
    return rstate.rand(n) * 2 - 1


def make_int(
    n: int,
    rstate: Any,
    random: bool = False,
    dtype: str | type = int,
    method: str | Callable = "poisson",
    args: tuple[Any, ...] = (),
    **kwargs,
):
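    """Generate ``n`` ints. With ``random=True``, draw via ``rstate.randint``;
    otherwise delegate to the ``RandomState`` method named by ``method`` (or to
    ``method`` itself if it is a callable), merging per-method defaults from
    ``default_int_args`` with any explicit ``args`` and ``kwargs``."""
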
    def _with_defaults(_method):
        handler_args, handler_kwargs = default_int_args.get(_method, ((), {}))
        handler_kwargs = handler_kwargs.copy()
        handler_kwargs.update(**kwargs)
        handler_args = args if args else handler_args
        return handler_args, handler_kwargs

    if random:
        handler_args, handler_kwargs = _with_defaults("random")
        if "low" in handler_kwargs:
            handler_args = ()
        data = rstate.randint(*handler_args, size=n, **handler_kwargs)
    else:
        if isinstance(method, str):
            # "poisson", "binomial", etc.
            handler_args, handler_kwargs = _with_defaults(method)
            handler = getattr(rstate, method)
            data = handler(*handler_args, size=n, **handler_kwargs)
        else:
            # method is a Callable
            data = method(*args, state=rstate, size=n, **kwargs)
    return data


names = [
    "Alice",
    "Bob",
    "Charlie",
    "Dan",
    "Edith",
    "Frank",
    "George",
    "Hannah",
    "Ingrid",
    "Jerry",
    "Kevin",
    "Laura",
    "Michael",
    "Norbert",
    "Oliver",
    "Patricia",
    "Quinn",
    "Ray",
    "Sarah",
    "Tim",
    "Ursula",
    "Victor",
    "Wendy",
    "Xavier",
    "Yvonne",
    "Zelda",
]


def make_random_string(n, rstate, length: int = 25) -> list[str]:
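    """Build ``n`` strings, each of ``length`` characters drawn uniformly from
    printable ASCII (letters, digits, punctuation, and space)."""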
    choices = list(string.ascii_letters + string.digits + string.punctuation + " ")
    return ["".join(rstate.choice(choices, size=length)) for _ in range(n)]


def make_string(n, rstate, choices=None, random=False, length=None, **kwargs):
    kwargs.pop("args", None)
    if random:
        if length is None:
            # Fall back to make_random_string's default length
            return make_random_string(n, rstate)
        return make_random_string(n, rstate, length=length)
    choices = choices or names
    return rstate.choice(choices, size=n)


def make_categorical(n, rstate, choices=None, nunique=None, **kwargs):
    """Generate ``n`` categorical values. Categories are ``choices`` (default
    ``names``), or zero-padded numeric strings when ``nunique`` is given."""
    kwargs.pop("args", None)
    if nunique is not None:
        cat_len = len(str(nunique))
        choices = [str(x + 1).zfill(cat_len) for x in range(nunique)]
    else:
        choices = choices or names
    return pd.Categorical.from_codes(rstate.randint(0, len(choices), size=n), choices)


# Maps column dtype (python type or dtype string) to its generator function
make: dict[type | str, Callable] = {
    float: make_float,
    int: make_int,
    str: make_string,
    object: make_string,
    "string[python]": make_string,
    "string[pyarrow]": make_string,
    "category": make_categorical,
    "int8": make_int,
    "int16": make_int,
    "int32": make_int,
    "int64": make_int,
    "float8": make_float,
    "float16": make_float,
    "float32": make_float,
    "float64": make_float,
}


class MakeDataframePart(DataFrameIOFunction):
    """
    Wrapper class for ``make_dataframe_part``.
    Makes a single dataframe partition.
    """

    def __init__(self, index_dtype, dtypes, kwargs, columns=None):
        self.index_dtype = index_dtype
        self._columns = columns or list(dtypes.keys())
        self.dtypes = dtypes
        self.kwargs = kwargs

    @property
    def columns(self):
        return self._columns

    def project_columns(self, columns):
        """Return a new MakeTimeseriesPart object with
        a sub-column projection.
        """
        if columns == self.columns:
            return self
        return MakeDataframePart(
            self.index_dtype,
            self.dtypes,
            self.kwargs,
            columns=columns,
        )

    def __call__(self, part):
        divisions, state_data = part
        return make_dataframe_part(
            self.index_dtype,
            divisions[0],
            divisions[1],
            self.dtypes,
            self.columns,
            state_data,
            self.kwargs,
        )


def make_dataframe_part(index_dtype, start, end, dtypes, columns, state_data, kwargs):
    state = np.random.RandomState(state_data)
    if pd.api.types.is_datetime64_any_dtype(index_dtype):
        # FIXME: tzinfo would be lost in pd.date_range
        index = pd.date_range(
            start=start, end=end, freq=kwargs.get("freq"), name="timestamp"
        )
    elif pd.api.types.is_integer_dtype(index_dtype):
        step = kwargs.get("freq")
        index = pd.RangeIndex(start=start, stop=end + step, step=step).astype(
            index_dtype
        )
    else:
        raise TypeError(f"Unhandled index dtype: {index_dtype}")
    df = make_partition(columns, dtypes, index, kwargs, state)
    # Division boundaries are inclusive on both ends, so drop any trailing rows
    # at or past ``end``; they belong to the next partition
    while df.index[-1] >= end:
        df = df.iloc[:-1]
    return df


def same_astype(a: str | type, b: str | type):
    """Same as pandas.api.types.is_dtype_equal, but also returns True for str / object"""
    return pd.api.types.is_dtype_equal(a, b) or (
        is_object_string_dtype(a) and is_object_string_dtype(b)
    )


def make_partition(columns: list, dtypes: dict[str, type | str], index, kwargs, state):
    data = {}
    for k, dt in dtypes.items():
        # Extract this column's kwargs, passed as "<column>_<kwarg>" entries
        kws = {
            kk.rsplit("_", 1)[1]: v
            for kk, v in kwargs.items()
            if kk.rsplit("_", 1)[0] == k
        }
        # Note: we compute data for all dtypes in order, not just those in the output
        # columns. This ensures the same output given the same state_data, regardless
        # of whether there is any column projection.
        # cf. https://github.com/dask/dask/pull/9538#issuecomment-1267461887
        result = make[dt](len(index), state, **kws)
        if k in columns:
            data[k] = result
    df = pd.DataFrame(data, index=index, columns=columns)
    update_dtypes = {
        k: v
        for k, v in dtypes.items()
        if k in columns and not same_astype(v, df[k].dtype)
    }
    if update_dtypes:
        kwargs = {} if PANDAS_GE_300 else {"copy": False}
        df = df.astype(update_dtypes, **kwargs)
    return df


_ME = "ME" if PANDAS_GE_220 else "M"


def make_timeseries(
    start="2000-01-01",
    end="2000-12-31",
    dtypes=None,
    freq="10s",
    partition_freq=f"1{_ME}",
    seed=None,
    **kwargs,
):
    """Create timeseries dataframe with random data

    Parameters
    ----------
    start: datetime (or datetime-like string)
        Start of time series
    end: datetime (or datetime-like string)
        End of time series
    dtypes: dict (optional)
        Mapping of column names to types.
        Valid types include {float, int, str, 'category'}
    freq: string
        String like '2s' or '1h' or '12W' for the time series frequency
    partition_freq: string
        String like '1M' or '2Y' to divide the dataframe into partitions
    seed: int (optional)
        RandomState seed
    kwargs:
        Keywords to pass down to individual column creation functions.
        Keywords should be prefixed by the column name and then an underscore.

    Examples
    --------
    >>> import dask.dataframe as dd
    >>> df = dd.demo.make_timeseries('2000', '2010',
    ...                              {'value': float, 'name': str, 'id': int},
    ...                              freq='2h', partition_freq='1D', seed=1)
    >>> df.head()  # doctest: +SKIP
                           id      name     value
    2000-01-01 00:00:00   969     Jerry -0.309014
    2000-01-01 02:00:00  1010       Ray -0.760675
    2000-01-01 04:00:00  1016  Patricia -0.063261
    2000-01-01 06:00:00   960   Charlie  0.788245
    2000-01-01 08:00:00  1031     Kevin  0.466002
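
    Column-creation keywords are prefixed by the column name: for example,
    ``id_lam=500`` forwards ``lam=500`` to the Poisson draw for the ``id``
    column:

    >>> df = dd.demo.make_timeseries('2000', '2001',
    ...                              {'value': float, 'id': int},
    ...                              freq='1D', seed=1,
    ...                              id_lam=500)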
    """
    if dtypes is None:
        dtypes = {"name": str, "id": int, "x": float, "y": float}

    divisions = list(pd.date_range(start=start, end=end, freq=partition_freq))
    npartitions = len(divisions) - 1
    if seed is None:
        # Generate an independent integer seed for each partition; each call to
        # `make_dataframe_part` builds its own RandomState from that seed
        state_data = np.random.randint(int(2e9), size=npartitions)
    else:
        state_data = random_state_data(npartitions, seed)

    # Build (division-pair, seed) tuples, one per partition
    parts = [(divisions[i : i + 2], state_data[i]) for i in range(npartitions)]

    kwargs["freq"] = freq
    index_dtype = "datetime64[ns]"
    meta_start, meta_end = list(pd.date_range(start="2000", freq=freq, periods=2))

    from dask.dataframe import _dask_expr_enabled

    if _dask_expr_enabled():
        from dask_expr import from_map

        k = {}
    else:
        from dask.dataframe.io.io import from_map

        k = {"token": tokenize(start, end, dtypes, freq, partition_freq, state_data)}

    # Construct the output collection with from_map
    return from_map(
        MakeDataframePart(index_dtype, dtypes, kwargs),
        parts,
        meta=make_dataframe_part(
            index_dtype,
            meta_start,
            meta_end,
            dtypes,
            list(dtypes.keys()),
            state_data[0],
            kwargs,
        ),
        divisions=divisions,
        label="make-timeseries",
        enforce_metadata=False,
        **k,
    )


def with_spec(spec: DatasetSpec, seed: int | None = None):
    """Generate a random dataset according to provided spec

    Parameters
    ----------
    spec : DatasetSpec
        Specify all the parameters of the dataset
    seed: int (optional)
        RandomState seed

    Notes
    -----
    This API is still experimental, and will likely change in the future.

    Examples
    --------
    >>> from dask.dataframe.io.demo import ColumnSpec, DatasetSpec, with_spec
    >>> ddf = with_spec(
    ...     DatasetSpec(
    ...         npartitions=10,
    ...         nrecords=10_000,
    ...         column_specs=[
    ...             ColumnSpec(dtype=int, number=2, prefix="p"),
    ...             ColumnSpec(dtype=int, number=2, prefix="n", method="normal"),
    ...             ColumnSpec(dtype=float, number=2, prefix="f"),
    ...             ColumnSpec(dtype=str, prefix="s", number=2, random=True, length=10),
    ...             ColumnSpec(dtype="category", prefix="c", choices=["Y", "N"]),
    ...         ],
    ...     ), seed=42)
    >>> ddf.head(10)  # doctest: +SKIP
         p1    p2    n1    n2        f1        f2          s1          s2 c1
    0  1002   972  -811    20  0.640846 -0.176875  L#h98#}J`?  _8C607/:6e  N
    1   985   982 -1663  -777  0.790257  0.792796  u:XI3,omoZ  w~@ /d)'-@  N
    2   947   970   799  -269  0.740869 -0.118413  O$dnwCuq\\  !WtSe+(;#9  Y
    3  1003   983  1133   521 -0.987459  0.278154  j+Qr_2{XG&  &XV7cy$y1T  Y
    4  1017  1049   826     5 -0.875667 -0.744359  \4bJ3E-{:o  {+jC).?vK+  Y
    5   984  1017  -492  -399  0.748181  0.293761  ~zUNHNgD"!  yuEkXeVot|  Y
    6   992  1027  -856    67 -0.125132 -0.234529  j.7z;o]Gc9  g|Fi5*}Y92  Y
    7  1011   974   762 -1223  0.471696  0.937935  yT?j~N/-u]  JhEB[W-}^$  N
    8   984   974   856    74  0.109963  0.367864  _j"&@ i&;/  OYXQ)w{hoH  N
    9  1030  1001  -792  -262  0.435587 -0.647970  Pmrwl{{|.K  3UTqM$86Sg  N
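
    A sketch with a datetime index (``dtype`` must be a datetime dtype here):

    >>> from dask.dataframe.io.demo import DatetimeIndexSpec
    >>> ddf2 = with_spec(
    ...     DatasetSpec(
    ...         nrecords=100,
    ...         index_spec=DatetimeIndexSpec(
    ...             dtype="datetime64[ns]",
    ...             start="2000-01-01",
    ...             freq="1h",
    ...             partition_freq="1D",
    ...         ),
    ...     ), seed=42)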
    """
    if len(spec.column_specs) == 0:
        spec.column_specs = [
            ColumnSpec(prefix="i", dtype="int64", low=0, high=1_000_000, random=True),
            ColumnSpec(prefix="f", dtype=float, random=True),
            ColumnSpec(prefix="c", dtype="category", choices=["a", "b", "c", "d"]),
            ColumnSpec(prefix="s", dtype=str),
        ]

    columns = []
    dtypes = {}
    partition_freq: str | int | None
    step: str | int
    if isinstance(spec.index_spec, DatetimeIndexSpec):
        start = pd.Timestamp(spec.index_spec.start)
        step = spec.index_spec.freq
        partition_freq = spec.index_spec.partition_freq
        end = start + spec.nrecords * pd.Timedelta(step)
        divisions = list(pd.date_range(start=start, end=end, freq=partition_freq))
        if divisions[-1] < end:
            divisions.append(end)
        meta_start, meta_end = start, start + pd.Timedelta(step)
    elif isinstance(spec.index_spec, RangeIndexSpec):
        step = spec.index_spec.step
        partition_freq = spec.nrecords * step // spec.npartitions
        end = spec.nrecords * step - 1
        divisions = list(pd.RangeIndex(0, stop=end, step=partition_freq))
        if divisions[-1] < (end + 1):
            divisions.append(end + 1)
        meta_start, meta_end = 0, step
    else:
        raise ValueError(f"Unhandled index: {spec.index_spec}")

    kwargs: dict[str, Any] = {"freq": step}
    for col in spec.column_specs:
        if col.prefix:
            prefix = col.prefix
        elif isinstance(col.dtype, str):
            prefix = re.sub(r"[^a-zA-Z0-9]", "_", f"{col.dtype}").rstrip("_")
        elif hasattr(col.dtype, "name"):
            prefix = col.dtype.name  # type: ignore[union-attr]
        else:
            prefix = col.dtype.__name__  # type: ignore[union-attr]
        for i in range(col.number):
            col_n = i + 1
            while (col_name := f"{prefix}{col_n}") in dtypes:
                col_n = col_n + 1
            columns.append(col_name)
            dtypes[col_name] = col.dtype
            kwargs.update(
                {
                    f"{col_name}_{k}": v
                    for k, v in asdict(col).items()
                    if k not in {"prefix", "number", "kwargs"} and v not in (None, [])
                }
            )
            # set untyped kwargs, if any
            for kw_name, kw_val in col.kwargs.items():
                kwargs[f"{col_name}_{kw_name}"] = kw_val

    npartitions = len(divisions) - 1
    if seed is None:
        state_data = cast(list[Any], np.random.randint(int(2e9), size=npartitions))
    else:
        state_data = random_state_data(npartitions, seed)

    parts = [(divisions[i : i + 2], state_data[i]) for i in range(npartitions)]

    from dask.dataframe import _dask_expr_enabled

    if _dask_expr_enabled():
        from dask_expr import from_map

        k = {}
    else:
        from dask.dataframe.io.io import from_map

        k = {
            "token": tokenize(
                0, spec.nrecords, dtypes, step, partition_freq, state_data
            )
        }

    return from_map(
        MakeDataframePart(spec.index_spec.dtype, dtypes, kwargs, columns=columns),
        parts,
        meta=make_dataframe_part(
            spec.index_spec.dtype,
            meta_start,
            meta_end,
            dtypes,
            columns,
            state_data[0],
            kwargs,
        ),
        divisions=divisions,
        label="make-random",
        enforce_metadata=False,
        **k,
    )
