"""Peewee migrations -- 147_MAIN.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name']            # Return model in current state by name
    > Model = migrator.ModelClass                   # Return model in current state by name

    > migrator.sql(sql)                             # Run custom SQL
    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True)    # Remove a model
    > migrator.add_fields(model, **fields)          # Add fields to a model
    > migrator.change_fields(model, **fields)       # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def get_zones(migrator: Migrator, distance: float | None) -> list:
    """Return all transport zones, those whose range contains *distance* first.

    A zone contains *distance* when the distance lies between its
    ``distance_km_min`` and ``distance_km_max``; each bound is compared
    non-strictly only if the matching ``distance_*_inclusive`` flag is set.
    Every remaining zone is appended afterwards, so the result holds each
    zone once. When *distance* is ``None`` no range filtering is done and the
    result is simply every zone.

    :param migrator: migrator whose ``orm`` provides the ``transportzone`` model.
    :param distance: distance to look up, or ``None`` to skip range matching.
    :return: list of zone records without duplicates.
    """
    TransportZone = migrator.orm['transportzone']

    if distance is None:
        preferred = []
    else:
        preferred = None
        # One sub-select per combination of the two "inclusive" flags; the
        # flag decides whether the corresponding bound is compared strictly.
        for min_inclusive in (False, True):
            for max_inclusive in (False, True):
                upper_bound = (
                    TransportZone.distance_km_max >= distance
                    if max_inclusive
                    else TransportZone.distance_km_max > distance
                )
                lower_bound = (
                    TransportZone.distance_km_min <= distance
                    if min_inclusive
                    else TransportZone.distance_km_min < distance
                )
                part = TransportZone.select().where(
                    upper_bound,
                    lower_bound,
                    TransportZone.distance_min_inclusive == min_inclusive,
                    TransportZone.distance_max_inclusive == max_inclusive,
                )
                # Sub-selects are chained with ``+`` into one compound query.
                preferred = part if preferred is None else preferred + part

    ordered = []

    # Matching zones come first, then every zone not collected yet.
    for source in (preferred, TransportZone.select()):
        for zone in source:
            if zone not in ordered:
                ordered.append(zone)

    return ordered


def calc_price_transport(migrator: Migrator, delivery, volume: float | None, kms: float) -> float | None:
    """Compute the transport price of *delivery*.

    With the ``transport_zones`` setup option enabled, the price comes from
    the delivery's ``transport_zone_record`` or, failing that, from the first
    zone returned by ``get_zones`` for *kms*. A zone priced per m3 charges
    ``price_per_m3`` times the larger of *volume* and the zone's
    ``minimal_volume``; otherwise ``price_per_m3`` is returned as is.

    Without that option, the price is the per-km price (``price_per_km_modified``
    or else ``car_price_per_km``) multiplied by *kms*.

    :param migrator: migrator passed through to ``get_zones``.
    :param delivery: delivery record the price is calculated for.
    :param volume: volume used for per-m3 pricing; ``None`` counts as 0.
    :param kms: distance used for zone lookup or per-km pricing.
    :return: the price, or ``None`` when no zone or per-km price is available
        or the zone's price data cannot be multiplied (``TypeError``).
    """
    from atxdispatch import glo

    if not glo.setup.get("transport_zones"):
        per_km = delivery.price_per_km_modified or delivery.car_price_per_km
        if per_km is None:
            return None
        return per_km * kms

    try:
        zone = delivery.transport_zone_record or get_zones(migrator, kms)[0]
    except LookupError:
        # No zone on the delivery and none available from the lookup.
        return None

    try:
        if not zone.price_is_per_m3:
            return zone.price_per_m3
        unit_price = zone.price_per_m3
        billed_volume = max(volume or 0, zone.minimal_volume or 0)
        return unit_price * billed_volume
    except TypeError:
        return None


def recalc_order_prices(migrator: Migrator) -> None:
    """Recompute and store ``price_transport`` for the deliveries of every order.

    Does nothing unless the ``transport_zones`` setup option is enabled.
    Orders flagged ``without_transport`` are skipped. For each remaining
    order, the distance used is ``distance_driven_modified`` when it is set,
    otherwise ``distance_driven`` (or 0 when that is empty as well).

    :param migrator: migrator whose ``orm`` provides the ``order`` model.
    """
    from atxdispatch import glo

    if not glo.setup.get("transport_zones"):
        return

    for order in migrator.orm["order"].select():
        if order.without_transport:
            continue

        # The modified distance takes precedence; fall back to the
        # unmodified value only when no modification exists.
        distance_driven = order.distance_driven_modified
        if distance_driven is None:
            distance_driven = order.distance_driven or 0

        for delivery in order.deliveries:
            delivery.price_transport = calc_price_transport(migrator, delivery, order.volume, distance_driven)
            delivery.save()


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Rename the delivery zone field, add ``price_transport`` and fill it in.

    ``transport_zone_modified`` on ``delivery`` becomes
    ``transport_zone_record``, a nullable double field ``price_transport`` is
    added, and ``recalc_order_prices`` is registered to populate it.
    """
    table = 'delivery'

    migrator.rename_field(table, 'transport_zone_modified', 'transport_zone_record')
    migrator.add_fields(table, price_transport=pw.DoubleField(null=True))

    # Registered after the schema changes it relies on.
    migrator.run(recalc_order_prices, migrator)


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Undo ``migrate``: restore the old zone field name and drop ``price_transport``."""
    table = 'delivery'

    migrator.rename_field(table, 'transport_zone_record', 'transport_zone_modified')
    migrator.remove_fields(table, 'price_transport')
