# cSpell:words polylines beziers pydai


import asyncio
import logging
from pathlib import Path
from typing import Any, Optional, Sequence
from xml.etree import ElementTree as ET

from playwright.async_api import async_playwright
from pydantic_ai import Agent
from pydantic_ai import exceptions as pydai_exceptions
from pydantic_ai.agent import AgentRunResult
from pydantic_ai.messages import UserContent


def filter_component_details(json: dict) -> dict:
    """
    Reduce a component's JSON to the fields useful to the agent.

    To keep the token count of a single component lookup low, graphical
    details are dropped, since the agent already has access to an image of
    the schematic:

    * drawing primitives (polylines, polygons, ...) and ``position`` are
      removed;
    * ``attributes`` is flattened to a ``{name: value}`` mapping;
    * each entry of ``pins`` keeps only ``designator`` and
      ``electrical_type``.

    Args:
        json: The component JSON. Falsy input (e.g. ``{}`` or ``None``)
            yields ``{}``.

    Returns:
        A new dict; ``json`` itself is left unmodified.
    """
    if not json:
        return {}

    graphical_keys = (
        "polylines",
        "polygons",
        "rectangles",
        "arcs",
        "ellipses",
        "beziers",
        "bitmaps",
        "texts",
        "position",
    )

    filtered = json.copy()
    for key in graphical_keys:
        filtered.pop(key, None)

    if "attributes" in json:
        # Only each attribute's name/value pair is kept; the outer ids are
        # discarded, so a later attribute with a duplicate name wins.
        filtered["attributes"] = {
            attr.get("name"): attr.get("value")
            for attr in json["attributes"].values()
        }

    if "pins" in json:
        filtered["pins"] = {
            pin_id: {
                "designator": pin.get("designator"),
                "electrical_type": pin.get("electrical_type"),
            }
            for pin_id, pin in json["pins"].items()
        }

    return filtered


def filter_schematic_page(page_json: dict[str, Any]) -> dict[str, Any]:
    """
    Return a token-reduced copy of one schematic page.

    Layout, styling and wiring keys (title blocks, graphics, fonts, border,
    UI group info, wires, no-connects, junctions) are dropped. Every entry
    under ``components`` and under ``ports`` is passed through
    ``filter_component_details``.

    Args:
        page_json: One page of the schematic JSON.

    Returns:
        A new dict; ``page_json`` itself is left unmodified.
    """
    page = page_json.copy()

    for key in (
        "title_blocks",
        "graphics",
        "fonts",
        "border",
        "ui_group_id",
        "ui_group_name",
        "wires",
        "no_connects",
        "junctions",
    ):
        page.pop(key, None)

    # Components and ports are filtered identically, so handle both sections
    # with one loop instead of two copy-pasted ones.
    for section in ("components", "ports"):
        if section in page:
            page[section] = {
                entry_id: filter_component_details(entry)
                for entry_id, entry in page[section].items()
            }

    return page


def filter_schematic_json(schematic_json: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of the schematic JSON with every page token-reduced.

    Each element of ``pages`` is run through ``filter_schematic_page``. A
    missing ``pages`` key is written back as an empty list. The input dict
    is not modified.
    """
    result = schematic_json.copy()
    result["pages"] = [
        filter_schematic_page(page) for page in result.get("pages", [])
    ]
    return result


def split_multipage_svg(svg_text: str) -> list[str]:
    """
    Split a multi-page SVG into individual SVG documents, one per page.

    A page is a ``<style>`` element immediately followed by a ``<g>`` element
    among the root's direct children. Each output document starts from a
    copy of the root's attributes. Its ``width``, ``height`` and ``viewBox``
    are then taken from the group's ``data-width``, ``data-height`` and
    ``data-view-box`` attributes, when those are present. The group's
    ``transform`` attribute, if any, is removed.

    Args:
        svg_text: The content of the multi-page SVG file.

    Returns:
        The serialized SVG for each page, in document order.

    Note:
        Registers the SVG namespace as ElementTree's default (empty-prefix)
        namespace. This is a process-wide registry, and it keeps the output
        free of ``ns0:`` prefixes.
    """
    ET.register_namespace("", "http://www.w3.org/2000/svg")
    parser = ET.XMLParser(encoding="utf-8")
    root = ET.fromstring(svg_text, parser=parser)

    children = list(root)

    # Each adjacent (<style>, <g>) pair of root children is one page.
    page_pairs = [
        (current, following)
        for current, following in zip(children, children[1:])
        if current.tag.endswith("}style") and following.tag.endswith("}g")
    ]

    original_id = root.get("id", "")
    output_files = []

    for index, (style_elem, g_elem) in enumerate(page_pairs, start=1):
        new_svg = ET.Element("svg", attrib=dict(root.attrib))
        new_svg.set("id", original_id if original_id else f"page-{index}")

        # Per-page geometry overrides the values inherited from the root.
        # Missing attributes are skipped: setting None would make
        # ET.tostring raise TypeError.
        for source_attr, target_attr in (
            ("data-width", "width"),
            ("data-height", "height"),
            ("data-view-box", "viewBox"),
        ):
            value = g_elem.get(source_attr)
            if value:
                new_svg.set(target_attr, value)

        new_svg.append(style_elem)
        # NOTE(review): the transform presumably places the page on the
        # shared multi-page canvas; confirm. pop() tolerates groups without one.
        g_elem.attrib.pop("transform", None)
        new_svg.append(g_elem)

        output_files.append(ET.tostring(new_svg, encoding="unicode"))

    return output_files


async def render_svg(svg_path: str, output_path: str) -> None:
    """
    Render an SVG file to an image using Playwright's Firefox.

    The SVG is opened in a 1920x1080 viewport. Only the visible viewport is
    captured (``full_page=False``). Playwright infers the image format from
    ``output_path``'s extension and defaults to PNG.

    Args:
        svg_path: Path to the SVG file. Relative paths are resolved against
            the current working directory.
        output_path: Where the screenshot is written.
    """
    # Path.as_uri() builds a valid, percent-escaped file:// URL. A plain
    # f"file://{svg_path}" breaks for relative paths, where the first segment
    # is read as a host, and for characters such as spaces, '#' or '%'.
    svg_url = Path(svg_path).resolve().as_uri()

    async with async_playwright() as p:
        browser = await p.firefox.launch()
        try:
            context = await browser.new_context(
                viewport={"width": 1920, "height": 1080}
            )
            try:
                page = await context.new_page()
                await page.goto(svg_url)
                await page.screenshot(path=output_path, full_page=False)
            finally:
                # Close even if navigation or the screenshot fails.
                await context.close()
        finally:
            await browser.close()


async def run_agent_with_retries[DT, RT](
    agent: Agent[DT, RT],
    inputs: Sequence[UserContent],
    deps: DT,
    max_attempts: int,
    retry_delay: int,
    logger: logging.Logger,
) -> Optional[AgentRunResult[RT]]:
    attempts = 0
    result = None

    while attempts < max_attempts:
        try:
            result = await agent.run(inputs, deps=deps)
            break
        except pydai_exceptions.ModelHTTPError as e:
            if e.status_code in [429, 529]:
                attempts += 1
                cause = "Rate limited" if e.status_code == 429 else "Overloaded"
                logger.warning(
                    f"{cause}; sleeping {retry_delay}s before "
                    f"retrying attempt {attempts}"
                )
                logger.debug(f"Error: {e}")
                await asyncio.sleep(retry_delay)
                continue
            else:
                raise e

    return result