import math
import time
from collections import defaultdict

import geopandas as gpd
import networkx as nx
import numpy as np
import triangle
from geopandas import GeoDataFrame
from shapely.geometry import Polygon, MultiPolygon, LineString
from shapely.geometry import Polygon, Point
from shapely.ops import unary_union
from shapely.strtree import STRtree

# Perimeter weight ``alpha`` for the min-cut aggregation, keyed by map zoom level.
# ``alpha`` multiplies boundary lengths in the graph capacities, so larger values
# penalise boundary length more heavily relative to area.
ZOOM_LEVEL_TO_ALPHA: dict[float, float] = {
    14.0: 500.0,
    15.0: 250.0,
    16.0: 50.0,
    17.0: 25.0,
    18.0: 10.0,
}

# alpha used for zoom levels below the smallest tabulated level
DEFAULT_ALPHA_LOW_ZOOM: float = 2000.0
# alpha used for zoom levels above the largest tabulated level
DEFAULT_ALPHA_HIGH_ZOOM: float = 0.0

# EPSG code of the projected (metric) CRS used for all geometric computations
# (EPSG:25832 = ETRS89 / UTM zone 32N).
LOCAL_CRS: int = 25832

def extract_vertices(gdf : GeoDataFrame) -> np.ndarray:
    """Return the unique exterior-ring vertices of all (multi)polygons in *gdf*.

    Interior rings (holes) and non-polygonal geometries are ignored. Duplicate
    coordinates are removed, keeping the first occurrence, so the output order is
    deterministic (it follows the order of the geometries and their rings).

    Args:
        gdf: frame whose ``geometry`` column is scanned.

    Returns:
        Array with one row per unique vertex. When *gdf* contains no polygon
        vertices the result has shape ``(0, 2)`` rather than a 1-D empty array.
    """
    points = []
    for geom in gdf.geometry:
        if isinstance(geom, Polygon):
            points.extend(geom.exterior.coords)
        elif isinstance(geom, MultiPolygon):
            for poly in geom.geoms:
                points.extend(poly.exterior.coords)

    # dict.fromkeys deduplicates while preserving insertion order; a set would not.
    unique_points = list(dict.fromkeys(points))
    if not unique_points:
        return np.empty((0, 2))
    return np.array(unique_points)

def extract_vertices_and_segments(gdf : GeoDataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Collect the 2-D ring vertices and ring edges of every (multi)polygon in *gdf*.

    Exterior and interior rings are included. Identical coordinates are shared, so
    each vertex appears once and polygons that touch reuse the same vertex index.
    Only the x/y components are kept; any Z value is dropped because the result is
    fed to a 2-D triangulation. Zero-length edges (repeated consecutive coordinates)
    are skipped. ``None``, empty and non-polygonal geometries are ignored.

    Args:
        gdf: frame whose ``geometry`` column is scanned.

    Returns:
        ``(vertices, segments)`` where ``vertices`` has shape ``(n, 2)`` (float) and
        ``segments`` has shape ``(m, 2)`` (int) holding index pairs into
        ``vertices``. Both are ``(0, 2)`` when there is nothing to extract.
    """
    vertices: list[tuple[float, float]] = []
    vertex_map: dict[tuple[float, float], int] = {}  # maps (x, y) -> index
    segments: list[tuple[int, int]] = []

    def index_of(coord) -> int:
        # Key on x/y only so 3-D input still yields a valid 2-D vertex array.
        key = (coord[0], coord[1])
        idx = vertex_map.get(key)
        if idx is None:
            idx = len(vertices)
            vertex_map[key] = idx
            vertices.append(key)
        return idx

    for geom in gdf.geometry:
        if geom is None or geom.is_empty:
            continue
        if geom.geom_type == "Polygon":
            parts = [geom]
        elif geom.geom_type == "MultiPolygon":
            parts = list(geom.geoms)
        else:
            continue

        for part in parts:
            for ring in (part.exterior, *part.interiors):
                coords = list(ring.coords)
                # Consecutive pairs; the ring is closed, so the last pair ends on
                # the (repeated) first coordinate.
                for c1, c2 in zip(coords, coords[1:]):
                    i1, i2 = index_of(c1), index_of(c2)
                    if i1 != i2:  # skip degenerate zero-length edges
                        segments.append((i1, i2))

    # reshape keeps the (k, 2) shape even when k == 0, so callers can vstack safely.
    vertex_array = np.array(vertices, dtype=float).reshape(-1, 2)
    segment_array = np.array(segments, dtype=int).reshape(-1, 2)
    return vertex_array, segment_array

def get_bbox_vertices_and_segments(gdf : GeoDataFrame, num_vertices : int) -> tuple[np.ndarray, np.ndarray]:
    """Build a padded rectangular frame around *gdf* for the triangulation.

    The frame extends beyond ``gdf.total_bounds`` on every side by 10% of the larger
    of the two bounding-box side lengths.

    Args:
        gdf: frame whose total bounds define the rectangle.
        num_vertices: number of vertices already present; the four corners are
            assumed to be appended after them, so segment indices start here.

    Returns:
        ``(corners, edges)``: a ``(4, 2)`` array of corner coordinates in the order
        lower-left, lower-right, upper-right, upper-left, and a ``(4, 2)`` array of
        index pairs forming the closed rectangle.
    """
    minx, miny, maxx, maxy = gdf.total_bounds
    margin = 0.1 * max(maxx - minx, maxy - miny)

    lo_x, lo_y = minx - margin, miny - margin
    hi_x, hi_y = maxx + margin, maxy + margin
    corners = np.array([
        [lo_x, lo_y],
        [hi_x, lo_y],
        [hi_x, hi_y],
        [lo_x, hi_y],
    ])

    # Closed ring over the four corners, offset past the existing vertices.
    local_ring = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
    return corners, local_ring + num_vertices

def flag_triangles_naive(triangles : list[Polygon], polygons : list[Polygon]) -> list[bool]:
    """Mark every triangle that overlaps any input polygon (brute force).

    A triangle is flagged when its intersection with at least one polygon has an
    area greater than 1e-4. Every triangle is tested against every polygon, so
    the cost is proportional to ``len(triangles) * len(polygons)``.

    Returns:
        One boolean per triangle, in the same order as *triangles*.
    """
    return [
        any(candidate.intersection(shape).area > 1e-4 for shape in polygons)
        for candidate in triangles
    ]

def flag_triangles_fast(triangles : list[Polygon], polygons : list[Polygon]) -> list[bool]:
    """Mark every triangle that overlaps any input polygon, using a spatial index.

    Same criterion as ``flag_triangles_naive`` (intersection area > 1e-4), but only
    polygons returned by an STRtree query are tested exactly.

    Args:
        triangles: triangles to classify.
        polygons: any sequence of polygons, including a GeoSeries with an
            arbitrary (non-RangeIndex) index.

    Returns:
        One boolean per triangle, in the same order as *triangles*.
    """
    # STRtree.query returns *positional* indices into the input sequence. Indexing a
    # GeoSeries with [] is label-based, so materialise a plain list first to make
    # positions and lookups agree regardless of the frame's index.
    geoms = list(polygons)
    tree = STRtree(geoms)

    flagged = []
    for tri_poly in triangles:
        # The "intersects" predicate only drops candidates whose intersection area
        # would be 0 anyway, so the result is unchanged but fewer overlays run.
        candidates = tree.query(tri_poly, predicate="intersects")
        flagged.append(
            any(tri_poly.intersection(geoms[k]).area > 1e-4 for k in candidates)
        )
    return flagged

def build_triangle_graph_naive(triangles : list[Polygon], flagged_triangles : list[bool], alpha : float) -> tuple[nx.DiGraph, str, str]:
    """Build the s-t flow network for the min-cut aggregation (brute force).

    Nodes are triangle indices (each carrying its geometry under ``"polygon"``) plus
    the string terminals ``"source"`` and ``"sink"``.

    Capacities:
        * every pair of triangles with a non-empty intersection gets edges in both
          directions with capacity ``alpha * intersection length``;
        * flagged triangles get ``source -> i`` with infinite capacity and
          ``i -> sink`` with capacity 0;
        * other triangles get ``source -> i`` with capacity 0 and ``i -> sink`` with
          capacity ``area + alpha * (perimeter - total length shared with others)``.

    All pairs are intersected, so the cost grows quadratically with the number of
    triangles.

    Returns:
        ``(graph, source, sink)``.
    """
    source, sink = "source", "sink"
    graph = nx.DiGraph()
    graph.add_nodes_from((idx, {"polygon": poly}) for idx, poly in enumerate(triangles))

    count = len(triangles)
    for a in range(count):
        for b in range(a + 1, count):
            shared = triangles[a].intersection(triangles[b])
            if shared.is_empty:
                continue
            # weight = length of the common boundary
            cap = alpha * shared.length
            graph.add_edge(a, b, capacity=cap)
            graph.add_edge(b, a, capacity=cap)

    for idx, poly in enumerate(triangles):
        if flagged_triangles[idx]:
            # mandatory triangle: it can never be separated from the source
            graph.add_edge(source, idx, capacity=float("inf"))
            graph.add_edge(idx, sink, capacity=0)
            continue

        shared_total = sum(
            poly.intersection(other).length
            for k, other in enumerate(triangles)
            if k != idx
        )
        exposed = poly.exterior.length - shared_total
        graph.add_edge(source, idx, capacity=0)
        # cost paid if this optional triangle ends up on the source side
        graph.add_edge(idx, sink, capacity=poly.area + alpha * exposed)

    return graph, source, sink

def build_triangle_graph_fast(triangles : list[Polygon], flagged_triangles : list[bool], alpha : float) -> tuple[nx.DiGraph, str, str]:
    """Build the s-t flow network for the min-cut aggregation via shared-edge lookup.

    Like ``build_triangle_graph_naive``, but neighbours are found by hashing the
    triangles' edges instead of intersecting every pair. Triangles that touch only at
    a vertex therefore get no (zero-capacity) edge.

    Capacities:
        * two triangles sharing an edge get edges in both directions with capacity
          ``alpha * edge length``;
        * flagged triangles get ``source -> i`` with infinite capacity and
          ``i -> sink`` with capacity 0;
        * other triangles get ``source -> i`` with capacity 0 and ``i -> sink`` with
          capacity ``area + alpha * (length of edges owned by no other triangle)``.

    Assumes neighbouring triangles share bit-identical vertex coordinates. This holds
    when they are built from one vertex array, as in ``Aggregator.aggregate``.

    Returns:
        ``(graph, source, sink)``.
    """
    G = nx.DiGraph()
    source = "source"
    sink = "sink"

    # Add nodes for triangles
    for i, tri in enumerate(triangles):
        G.add_node(i, polygon=tri)

    # --- Step 1: enumerate each triangle's edges once and map edge -> triangles ---
    # Edges are keyed by their sorted endpoint pair so (a, b) and (b, a) coincide.
    triangle_edges = []
    edge_to_triangles = defaultdict(list)
    for i, tri in enumerate(triangles):
        # The exterior ring is closed; its first three coordinates are the corners.
        a, b, c = list(tri.exterior.coords)[:3]
        edges = [tuple(sorted(pair)) for pair in ((a, b), (b, c), (c, a))]
        triangle_edges.append(edges)
        for edge in edges:
            edge_to_triangles[edge].append(i)

    # --- Step 2: neighbour edges, weighted by the length of the shared edge ---
    for (v1, v2), owners in edge_to_triangles.items():
        # Exactly two owners = interior edge. One owner = outer edge (Step 3). More
        # than two would indicate an invalid triangulation; such edges are ignored.
        if len(owners) == 2:
            t1, t2 = owners
            weight = alpha * math.dist(v1, v2)
            G.add_edge(t1, t2, capacity=weight)
            G.add_edge(t2, t1, capacity=weight)

    # --- Step 3: source/sink edges ---
    for i, tri in enumerate(triangles):
        if flagged_triangles[i]:
            # flagged triangles → infinite connection to source
            G.add_edge(source, i, capacity=float("inf"))
            G.add_edge(i, sink, capacity=0)
        else:
            # outer boundary: edges belonging only to this triangle
            outer_boundary = sum(
                (
                    math.dist(*edge)
                    for edge in triangle_edges[i]
                    if len(edge_to_triangles[edge]) == 1
                ),
                0.0,
            )
            G.add_edge(source, i, capacity=0)
            G.add_edge(i, sink, capacity=tri.area + alpha * outer_boundary)

    return G, source, sink

def retrieve_geometric_solution_from_cut(source_component : list[int], triangles : list[Polygon]) -> GeoDataFrame:
    """Dissolve the triangles on the source side of the cut into polygons.

    Args:
        source_component: nodes on the source side of the minimum cut. Non-integer
            nodes (the string ``"source"`` terminal) are skipped.
        triangles: triangle geometries indexed by graph node id.

    Returns:
        A GeoDataFrame (no CRS set) with one row per resulting polygon. It is empty
        when no triangle was selected. If the union yields a GeometryCollection,
        its polygonal members are kept and other members are dropped.
    """
    # Filter out the special source/sink nodes
    source_triangles = [triangles[i] for i in source_component if isinstance(i, int)]

    # Dissolve triangles and split the union into individual polygons
    merged = unary_union(source_triangles)
    if isinstance(merged, Polygon):
        parts = [merged]
    elif isinstance(merged, MultiPolygon):
        parts = list(merged.geoms)
    elif merged.geom_type == "GeometryCollection":
        # A mixed or degenerate union would otherwise lose all of its polygons.
        parts = []
        for member in merged.geoms:
            if isinstance(member, Polygon):
                parts.append(member)
            elif isinstance(member, MultiPolygon):
                parts.extend(member.geoms)
    else:
        parts = []
    return gpd.GeoDataFrame(geometry=parts)


# ---------- Aggregator ----------
class Aggregator:
    """Aggregate polygons via a constrained triangulation and a minimum s-t cut.

    The input polygons plus a padded bounding box are triangulated. Triangles that
    overlap the input are forced into the result. A minimum cut then decides which
    remaining triangles to add, trading their area against boundary length weighted
    by a zoom-dependent ``alpha``.
    """

    def aggregate(self, gdf_input : GeoDataFrame, zoom_level: float, tri_mode : str="naive", graph_mode : str="naive") -> GeoDataFrame:
        """Aggregate the polygons of *gdf_input* for the given zoom level.

        Computation happens in ``LOCAL_CRS``. If the input has a different CRS it is
        reprojected, and the result is transformed back. Timing for each stage is
        printed to stdout.

        Args:
            gdf_input: polygons to aggregate; not modified.
            zoom_level: map zoom level, mapped to ``alpha`` by
                ``get_alpha_from_zoom_level``.
            tri_mode: ``"naive"`` or ``"fast"``; selects ``flag_triangles_naive`` or
                ``flag_triangles_fast`` for flagging triangles that overlap the input.
            graph_mode: ``"naive"`` or ``"fast"``; selects
                ``build_triangle_graph_naive`` or ``build_triangle_graph_fast``.

        Returns:
            GeoDataFrame with one row per aggregated polygon, in the input's CRS.
            It has no CRS if the input had none, and is empty if the input
            contains no polygon vertices.

        Raises:
            ValueError: if *tri_mode* or *graph_mode* is not ``"naive"``/``"fast"``.
        """
        if tri_mode not in ("naive", "fast"):
            raise ValueError(f"Invalid tri_mode: {tri_mode}")
        if graph_mode not in ("naive", "fast"):
            raise ValueError(f"Invalid graph_mode: {graph_mode}")

        # Remember the caller's CRS as a full CRS object. It is always defined, which
        # avoids an unbound name when the input is already in LOCAL_CRS or has no CRS.
        # It also works for CRSs without an EPSG code (to_epsg() -> None).
        input_crs = gdf_input.crs

        # project for geometric accuracy
        gdf = gdf_input.copy()
        if input_crs is not None and input_crs.to_epsg() != LOCAL_CRS:
            gdf = gdf.to_crs(epsg=LOCAL_CRS)

        start = time.perf_counter()

        alpha = self.get_alpha_from_zoom_level(zoom_level)
        print(f"[{time.perf_counter() - start:.3f}s] set alpha to: {alpha}")

        t0 = time.perf_counter()
        vertices, segments = extract_vertices_and_segments(gdf)
        if len(vertices) == 0:
            # Nothing polygonal to triangulate; the bbox/vstack steps would fail.
            return gpd.GeoDataFrame(geometry=[], crs=input_crs)
        bbox_vertices, bbox_segments = get_bbox_vertices_and_segments(gdf, len(vertices))
        vertices = np.vstack((vertices, bbox_vertices))
        segments = np.vstack((segments, bbox_segments))
        print(f"[{time.perf_counter() - t0:.3f}s] extracted {len(vertices)} vertices.")

        t0 = time.perf_counter()
        A = dict(vertices=vertices, segments=segments)
        tri = triangle.triangulate(A, 'p')  # 'p' = enforce segments (CDT)
        print(f"[{time.perf_counter() - t0:.3f}s] triangulated.")

        t0 = time.perf_counter()
        final_vertices = tri["vertices"]  # ndarray of shape (n_points, 2)
        triangles = [Polygon(final_vertices[simplex]) for simplex in tri["triangles"]]
        print(f"[{time.perf_counter() - t0:.3f}s] created {len(triangles)} polygons.")

        t0 = time.perf_counter()
        if tri_mode == "naive":
            flags = flag_triangles_naive(triangles, gdf.geometry)
        else:
            flags = flag_triangles_fast(triangles, gdf.geometry)
        print(f"[{time.perf_counter() - t0:.3f}s] flagged {sum(flags)} triangles.")

        t0 = time.perf_counter()
        if graph_mode == "naive":
            G, source, sink = build_triangle_graph_naive(triangles, flags, alpha)
        else:
            G, source, sink = build_triangle_graph_fast(triangles, flags, alpha)
        print(f"[{time.perf_counter() - t0:.3f}s] built graph with {len(G.nodes)} nodes and {len(G.edges)} edges.")

        t0 = time.perf_counter()
        cut_value, (reachable, non_reachable) = nx.minimum_cut(G, source, sink)
        print(f"[{time.perf_counter() - t0:.3f}s] cut graph (cut value={cut_value:.2f}).")

        t0 = time.perf_counter()
        result_gdf = retrieve_geometric_solution_from_cut(reachable, triangles)
        if input_crs is not None:
            # The triangles carry LOCAL_CRS coordinates. Without an input CRS nothing
            # was reprojected, so no CRS is claimed.
            result_gdf = result_gdf.set_crs(epsg=LOCAL_CRS)
        print(f"[{time.perf_counter() - t0:.3f}s] retrieved geometric solution → {len(result_gdf)} polygons.")

        total_time = time.perf_counter() - start
        print(f"Finished entire pipeline in {total_time:.3f} seconds.")

        # reset input crs
        if input_crs is not None and input_crs.to_epsg() != LOCAL_CRS:
            result_gdf = result_gdf.to_crs(input_crs)

        return result_gdf

    def get_alpha_from_zoom_level(self,zoom_level: float) -> float:
        """Map a zoom level to the boundary-length weight ``alpha``.

        Zoom levels below the smallest key of ``ZOOM_LEVEL_TO_ALPHA`` return
        ``DEFAULT_ALPHA_LOW_ZOOM``. Zoom levels above the largest key return
        ``DEFAULT_ALPHA_HIGH_ZOOM``. Tabulated zoom levels return their tabulated
        value. Fractional zoom levels between two tabulated levels are linearly
        interpolated.
        """
        levels = sorted(ZOOM_LEVEL_TO_ALPHA)
        if zoom_level < levels[0]:
            return DEFAULT_ALPHA_LOW_ZOOM
        if zoom_level > levels[-1]:
            return DEFAULT_ALPHA_HIGH_ZOOM
        # Interpolate rather than use a fixed fallback value, so alpha follows the
        # table smoothly between tabulated levels instead of jumping
        # (e.g. zoom 14.5 -> 375, between 500 and 250).
        alphas = [ZOOM_LEVEL_TO_ALPHA[level] for level in levels]
        return float(np.interp(zoom_level, levels, alphas))