from PyQt5.QtWebEngineWidgets import QWebEngineView
from PyQt5.QtWebChannel import QWebChannel, QWebChannelAbstractTransport, QWebChannel
from PyQt5.QtWidgets import QLabel, QApplication, QMainWindow, QFileDialog, QVBoxLayout, QWidget, QAction, QComboBox, \
    QHBoxLayout, QWidget, QSizePolicy
from PyQt5.QtCore import QObject, pyqtSlot, QUrl
import json
import geopandas as gpd

from src.compute.aggregator import Aggregator

# ---------- Bridge for communication ----------
class MapBridge(QObject):
    """QObject registered with the page's QWebChannel as ``bridge``.

    The map page calls :meth:`zoomChanged` from its Leaflet ``zoomend``
    handler; the call is forwarded to the owning window's ``update_display``.
    """

    def __init__(self, parent):
        # Tie the bridge's Qt lifetime to the window when it is a QObject; a
        # plain Python object is still accepted as the callback target.
        super().__init__(parent if isinstance(parent, QObject) else None)
        # Stored under a private name: assigning ``self.parent`` shadowed
        # QObject.parent(), making that method uncallable on this object.
        self._window = parent

    @pyqtSlot(float)
    def zoomChanged(self, zoom_level):
        """Forward the map's new zoom level to the window for re-aggregation."""
        self._window.update_display(zoom_level)

# ---------- Main Window ----------
class MainWindow(QMainWindow):
    """Main window: a Leaflet map in a QWebEngineView showing aggregated footprints.

    The loaded GeoDataFrame is passed through ``Aggregator.aggregate`` whenever
    the map zoom changes (reported from JavaScript through ``MapBridge``) or
    when one of the "Triangles" / "Graph" mode selectors changes.
    """

    # Zoom level used for the first aggregation after a file is loaded.
    INITIAL_ZOOM = 20

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Building Aggregation Demo (Leaflet + Dynamic Zoom)")
        self.resize(900, 700)

        self.aggregator = Aggregator()
        self.current_gdf = None  # loaded footprints (reprojected to EPSG:4326 when a CRS is set)
        # Zoom level of the aggregation currently displayed; reused when a mode changes.
        self.current_zoom = self.INITIAL_ZOOM
        # Set before the selectors are wired so the handlers always find them.
        self.current_tri_mode = "naive"
        self.current_graph_mode = "naive"

        self.web_view = QWebEngineView()

        # Menu
        file_menu = self.menuBar().addMenu("File")
        open_action = QAction("Open", self)
        open_action.triggered.connect(self.open_file_dialog)
        file_menu.addAction(open_action)

        # Mode selectors, right-aligned above the map
        self.tri_mode_selector, tri_layout = self._make_mode_selector(
            "Triangles", self.on_tri_mode_changed
        )
        self.graph_mode_selector, graph_layout = self._make_mode_selector(
            "Graph", self.on_graph_mode_changed
        )

        top_bar = QHBoxLayout()
        top_bar.setContentsMargins(0, 0, 0, 0)
        top_bar.setSpacing(0)
        top_bar.addStretch()
        top_bar.addLayout(tri_layout)
        top_bar.addLayout(graph_layout)

        # Single central widget: selector bar on top, map below.
        container = QWidget()
        main_layout = QVBoxLayout(container)
        main_layout.addLayout(top_bar)
        main_layout.addWidget(self.web_view)
        self.setCentralWidget(container)

        # Set up WebChannel so the page can call back into Python as ``bridge``
        self.bridge = MapBridge(self)
        self.channel = QWebChannel()
        self.channel.registerObject("bridge", self.bridge)
        self.web_view.page().setWebChannel(self.channel)

        # Load empty map first
        self.init_map(None)

    def _make_mode_selector(self, title, on_change):
        """Build a fixed-size "title above combo box" column for choosing a mode.

        Returns ``(combo, layout)``; *on_change* is connected to the combo's
        ``currentTextChanged`` signal after the initial "naive" selection.
        """
        combo = QComboBox()
        combo.addItems(["naive", "fast"])
        combo.setCurrentText("naive")
        combo.currentTextChanged.connect(on_change)
        combo.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)

        label = QLabel(title)
        label.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)

        column = QVBoxLayout()
        column.setContentsMargins(0, 0, 0, 0)
        column.setSpacing(2)
        column.addWidget(label)
        column.addWidget(combo)
        return combo, column

    def open_file_dialog(self):
        """Ask the user for a footprint file and display it; report load errors in a dialog."""
        filepath, _ = QFileDialog.getOpenFileName(
            self, "Open footprint file", "", "Vector files (*.shp *.gpkg)"
        )
        if not filepath:
            return
        try:
            self.load_and_display(filepath)
        except Exception as exc:
            # UI boundary: PyQt5 (>= 5.5) aborts the application by default when
            # an exception escapes a slot, so show the failure instead.
            from PyQt5.QtWidgets import QMessageBox

            QMessageBox.critical(
                self, "Open footprint file", f"Could not load {filepath}:\n{exc}"
            )

    def load_and_display(self, filepath):
        """Read *filepath* with geopandas, reproject to EPSG:4326 and rebuild the map.

        Errors from reading, reprojecting or aggregating propagate to the caller.
        """
        gdf = gpd.read_file(filepath)

        # Reproject to EPSG:4326 (lat/lon) for Leaflet
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(epsg=4326)

        self.current_gdf = gdf
        self.init_display()

    def _aggregate_geojson(self, zoom_level):
        """Aggregate ``current_gdf`` at *zoom_level* using the selected modes.

        Returns the result as a GeoJSON dict, or None when no file is loaded.
        Records *zoom_level* as ``current_zoom`` on success.
        """
        if self.current_gdf is None:
            return None
        aggregated = self.aggregator.aggregate(
            self.current_gdf,
            zoom_level=zoom_level,
            tri_mode=self.current_tri_mode,
            graph_mode=self.current_graph_mode,
        )
        self.current_zoom = zoom_level
        return json.loads(aggregated.to_json())

    def init_display(self):
        """Aggregate at ``INITIAL_ZOOM`` and reload the map page with the result."""
        geojson = self._aggregate_geojson(self.INITIAL_ZOOM)
        if geojson is None:
            return
        self.init_map(geojson)

    @staticmethod
    def _json_for_script(obj):
        """Serialize *obj* as JSON that is safe to inline inside an HTML <script>.

        ``<``, ``>`` and ``&`` are emitted as ``\\u00XX`` escapes, which are still
        valid JSON/JavaScript, so a value such as "</script>" in the data cannot
        terminate the enclosing script element.
        """
        return (
            json.dumps(obj)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def init_map(self, geojson_data=None):
        """Load map HTML and add initial GeoJSON layer.

        *geojson_data* is a GeoJSON dict; a falsy value yields an empty
        FeatureCollection. The page connects to the QWebChannel ``bridge``,
        fits the view to the data and reports each subsequent ``zoomend`` back
        to Python.

        NOTE(review): QWebEngineView.setHtml() cannot display content larger
        than 2 MB, so very large initial layers may not render; consider
        loading an empty page and pushing data via ``update_map`` instead.
        """
        geojson_obj = geojson_data or {"type": "FeatureCollection", "features": []}
        geojson_js = self._json_for_script(geojson_obj)

        html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8" />
            <title>Map</title>
            <link rel="stylesheet" href="https://unpkg.com/leaflet/dist/leaflet.css"/>
            <script src="https://unpkg.com/leaflet/dist/leaflet.js"></script>
            <script src="qrc:///qtwebchannel/qwebchannel.js"></script>
            <style>html, body, #map {{height:100%; margin:0;}}</style>
        </head>
        <body>
            <div id="map"></div>
            <script>
                new QWebChannel(qt.webChannelTransport, function(channel) {{
                    window.bridge = channel.objects.bridge;

                    // Initialize map
                    window.map = L.map('map');

                    // Add CARTO light basemap tiles
                    L.tileLayer('https://{{s}}.basemaps.cartocdn.com/light_all/{{z}}/{{x}}/{{y}}.png', {{
                        maxZoom: 19
                    }}).addTo(map);

                    // Add initial GeoJSON layer
                    window.geojsonLayer = L.geoJSON({geojson_js}).addTo(map);

                    // Fit map to data bounds
                    var bounds = geojsonLayer.getBounds();
                    if (bounds.isValid()) {{
                        map.fitBounds(bounds);
                    }} else {{
                        if (geojsonLayer.getLayers().length > 0) {{
                            map.setView(geojsonLayer.getLayers()[0].getBounds().getCenter(), 15);
                        }} else {{
                            map.setView([0,0], 2);
                        }}
                    }}

                    // Listen to zoom events
                    map.on('zoomend', function() {{
                        var zoom = map.getZoom();
                        bridge.zoomChanged(zoom);
                    }});

                    // Function to update GeoJSON without resetting zoom/center
                    window.updateGeoJSON = function(newData) {{
                        map.removeLayer(window.geojsonLayer);
                        window.geojsonLayer = L.geoJSON(newData).addTo(map);
                    }};
                }});
            </script>
        </body>
        </html>
        """
        self.web_view.setHtml(html, QUrl("qrc:///"))

    def update_display(self, zoom_level: float):
        """Re-aggregate at *zoom_level* and swap the map layer, keeping the view."""
        geojson = self._aggregate_geojson(zoom_level)
        if geojson is not None:
            self.update_map(geojson)

    def update_map(self, geojson_data):
        """Replace the page's GeoJSON layer with *geojson_data* via ``updateGeoJSON``."""
        self.web_view.page().runJavaScript(f"updateGeoJSON({json.dumps(geojson_data)});")

    def on_tri_mode_changed(self, text):
        """Switch the triangle aggregation mode and redraw at the current zoom."""
        self.current_tri_mode = text
        print(f"Triangle aggregation mode switched to: {self.current_tri_mode}")
        self.update_display(self.current_zoom)

    def on_graph_mode_changed(self, text):
        """Switch the graph aggregation mode and redraw at the current zoom."""
        self.current_graph_mode = text
        print(f"Graph aggregation mode switched to: {self.current_graph_mode}")
        self.update_display(self.current_zoom)

