import sys
import subprocess
import importlib
import time
from pathlib import Path
import threading
from tqdm import tqdm

print("""
===============================================================
   -------Hackers Nightlight Python flash utility-------
===============================================================
""")

print(r"""
              @@@@         @@              
          @@@@@@@          @@@@@@          
        @@@@   @@@@      @@@   @@@@@       
      @@@@  @@@@@@        @@@@@@  @@@      
    @@@@ @@@@@  @@@@    @@@  @@@@@ @@@@    
   @@@  @@@  @@@@@       @@@@@  @@@  @@@   
  @@@ @@@@ @@@@ @@@     @@@@@@@@ @@@@ @@@  
 @@@ @@@  @@@@@@    @@@    @@@@@@  @@@ @@@ 
@@@ @@@@ @@@@@   @@@ @    @@@@@@@@ @@@  @@@
@@@ @@@ @@@@@    @@@ @    @@@ @@@@@ @@@ @@@
@@  @@  @@@@      @@ @     @@  @@@@  @@  @@
@@  @@  @@@@  @@@@@@ @ @@@ @@  @@@@  @@  @@
@@  @@    @@   @@ @@ @ @@  @@  @@@@  @@  @@
           @   @@ @@ @ @@@@@   @         @@
            @@ @@@ @@@ @@@   @@            
             @@@ @@@ @ @@  @@@             
               @@ @@ @ @@ @@               
                @ @@ @ @@ @                
                @@@@@@@@@@@                
                @@@@@@@@@@@                
                @@@@@@@@@@@                
                @@@@@@@@@@@                
                   @@@@@                   
                                                                         
""")
print("""
===============================================================
""")

# Third-party distributions (pip names) this script needs; the same four are
# installed on demand by ensure_package() at startup.
REQUIRED_PACKAGES = ["pyserial", "esptool", "requests", "tqdm"]
# Log file: emptied at startup and appended to by log(). Receives esptool
# output that is kept off the console during flashing.
LOG_FILE = "logs.txt"
# USB vendor ID compared against each serial port's `vid` by
# find_cp210x_devices() to select CP210x adapters.
CP210X_VID = 0x10C4

def ensure_package(pkg_name, import_name=None):
    """Import a module, installing its distribution with pip first if it is missing.

    Args:
        pkg_name: Distribution name handed to ``pip install`` (e.g. ``"pyserial"``).
        import_name: Module name to import. Defaults to *pkg_name*; it differs
            for packages such as pyserial, which is imported as ``serial``.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
        ImportError: if the module still cannot be imported after installing.
    """
    if import_name is None:
        import_name = pkg_name
    try:
        importlib.import_module(import_name)
    except ImportError:
        pass
    else:
        return

    # pip rejects ``--user`` inside a virtualenv (user site-packages are not
    # visible there), so only request a per-user install for a base interpreter.
    in_venv = sys.prefix != getattr(sys, "base_prefix", sys.prefix)
    user_flag = [] if in_venv else ["--user"]
    subprocess.check_call([sys.executable, "-m", "pip", "install", *user_flag, pkg_name])

    if not in_venv:
        import site

        # The user site directory is only put on sys.path at interpreter start
        # if it already existed; a first-ever --user install creates it, so
        # add it now or the import below would still fail.
        user_site = site.getusersitepackages()
        if user_site not in sys.path:
            site.addsitedir(user_site)

    importlib.invalidate_caches()
    importlib.import_module(import_name)

# pip distribution name -> import name, for the packages where they differ.
_IMPORT_NAME_OVERRIDES = {"pyserial": "serial"}

# Make every required package importable before the third-party imports below.
for _pkg in REQUIRED_PACKAGES:
    ensure_package(_pkg, _IMPORT_NAME_OVERRIDES.get(_pkg))
del _pkg

import serial
import serial.tools.list_ports
import requests
# Start each run with an empty log file (created if it does not exist yet).
Path(LOG_FILE).write_text("")

# Serialises log() calls: it is invoked concurrently from the per-port
# listener threads as well as from the main thread.
_LOG_LOCK = threading.Lock()


def log(msg):
    """Append *msg* to LOG_FILE as a single line.

    Any object is accepted and formatted with ``str()``. The file is written
    as UTF-8 so that text captured from esptool or a serial port can never
    raise UnicodeEncodeError under a narrower platform default encoding, and
    the lock keeps lines from concurrent threads from interleaving.
    """
    with _LOG_LOCK, open(LOG_FILE, "a", encoding="utf-8") as f:
        f.write(f"{msg}\n")

# Hosted firmware location. Every URL has the form
# "<FIRMWARE_BASE_URL>/<device dir>/<FIRMWARE_VERSION>/<file name>";
# change FIRMWARE_VERSION to fetch a different release for all devices at once.
FIRMWARE_BASE_URL = "https://flash.hackersnightlight.com/firmwares"
FIRMWARE_VERSION = "v0.7.0"

# Remote directory that holds each device type's images.
_FIRMWARE_DIRS = {
    "ATK-C6": "ATK",
    "ESP-RGB": "RGB",
}

# (file name, flash offset) pairs written to every device type.
FLASH_FILES_COMMON = [
    ("bootloader.bin", 0x0),
    ("partition-table.bin", 0x8000),
]
# device type -> (application image file name, flash offset).
DEVICE_APP_BIN = {
    "ATK-C6": ("ATK-C6.bin", 0x10000),
    "ESP-RGB": ("ESP-RGB.bin", 0x10000),
}

# device type -> {file name: download URL}. Derived from the two tables above
# so the files that get downloaded always match the files that get flashed.
BIN_URLS = {
    dev_type: {
        name: f"{FIRMWARE_BASE_URL}/{_FIRMWARE_DIRS[dev_type]}/{FIRMWARE_VERSION}/{name}"
        for name in [common for common, _ in FLASH_FILES_COMMON] + [app_name]
    }
    for dev_type, (app_name, _) in DEVICE_APP_BIN.items()
}

def find_cp210x_devices():
    """Return a list of the serial ports whose ``vid`` equals CP210X_VID."""
    def _is_cp210x(port_info):
        return port_info.vid == CP210X_VID

    return list(filter(_is_cp210x, serial.tools.list_ports.comports()))

def wait_for_port_ready(port_name, timeout=5):
    """Poll until *port_name* can be opened at 115200 baud or *timeout* seconds pass.

    Each attempt opens the port and closes it again immediately. A
    serial.SerialException on open causes a 0.5 s pause before retrying.

    Returns:
        True as soon as one open succeeds, False if none did before the timeout.
    """
    # monotonic() is unaffected by wall-clock adjustments (e.g. NTP steps),
    # which could otherwise cut the wait short or stretch it arbitrarily.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with serial.Serial(port_name, baudrate=115200, timeout=0.1):
                return True
        except serial.SerialException:
            time.sleep(0.5)
    return False

def verify_serial(port, device_type, timeout=20):
    """Watch *port* for up to *timeout* seconds for text identifying *device_type*.

    Opens the port at 115200 baud and accumulates decoded output. ``"ATK-C6"``
    is confirmed by the text ``"ATK-C6"``; ``"ESP-RGB"`` by ``"RGB"`` (which
    also covers ``"ESP-RGB"``). Any other device_type can never match and just
    waits out the timeout.

    Returns:
        True if the expected text was seen (a confirmation is printed),
        False on timeout or on serial.SerialException (a message is printed).
    """
    try:
        with serial.Serial(port, baudrate=115200, timeout=0.1) as ser:
            buffer = ""
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                # Block for up to the 0.1 s read timeout waiting for at least
                # one byte instead of spinning on in_waiting while the port is
                # idle, which kept a CPU core at 100% for the whole timeout.
                data = ser.read(ser.in_waiting or 1)
                if not data:
                    continue
                buffer += data.decode(errors='ignore')
                if device_type == "ATK-C6" and "ATK-C6" in buffer:
                    print(f"{device_type} detected on {port} via serial.")
                    return True
                if device_type == "ESP-RGB" and "RGB" in buffer:
                    print(f"{device_type} detected on {port} via serial.")
                    return True
            print(f"No serial output detected on {port}. Firmware may still be running.")
            return False
    except serial.SerialException as e:
        print(f"[!] Serial error on {port}: {e}")
        return False

def reset_device(port):
    """Pulse RTS on *port*, with DTR held de-asserted, to reset the attached chip.

    NOTE(review): assumes the common ESP auto-reset wiring, where asserting
    RTS holds EN low and asserting DTR pulls the boot-mode strapping pin low.
    Confirm against the board schematic.

    Success is printed and logged. A serial.SerialException is printed and
    logged rather than raised.
    """
    try:
        with serial.Serial(port, baudrate=115200) as ser:
            # Keep DTR de-asserted for the whole pulse. Asserting DTR while
            # releasing RTS is esptool's bootloader-entry sequence and would
            # leave the chip in download mode instead of running the firmware.
            ser.dtr = False
            ser.rts = True   # hold the chip in reset
            time.sleep(0.1)
            ser.rts = False  # release reset so it boots normally
            log(f"Manual reset sent to {port}")
            print(f"Manual reset sent to {port}")
    except serial.SerialException as e:
        log(f"[!] Failed to reset {port}: {e}")
        print(f"[!] Failed to reset {port}: {e}")

def download_bins(dest_folder="firmware", timeout=30):
    """Download every image listed in BIN_URLS into *dest_folder*.

    Each file is saved as ``<device_type>_<file name>``, which is the path
    flash_device_with_progress() reads back. The folder is created if needed
    and existing files are overwritten.

    Args:
        dest_folder: Target directory.
        timeout: Seconds, forwarded as ``timeout=`` to ``requests.get``.

    Raises:
        requests.HTTPError: on a non-success HTTP status.
        requests.RequestException: on connection failures or timeouts.
    """
    folder = Path(dest_folder)
    folder.mkdir(parents=True, exist_ok=True)
    for dev_type, bins in BIN_URLS.items():
        for name, url in bins.items():
            dest = folder / f"{dev_type}_{name}"
            log(f"Downloading {name} for {dev_type}...")
            # Without a timeout a stalled server blocks the tool forever; the
            # with-block releases the streamed connection even if writing fails.
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                with open(dest, "wb") as f:
                    for chunk in r.iter_content(4096):
                        f.write(chunk)
    log("All binaries downloaded.")

def _build_esptool_cmd(port, files):
    """Return the esptool ``write_flash`` argv that writes *files* to *port*.

    *files* is a sequence of ``(path, offset)`` pairs. Each one becomes a
    ``<hex offset> <path>`` argument pair, in the given order.
    """
    cmd = [
        sys.executable, "-m", "esptool",
        "--chip", "esp32c6",
        "--port", port,
        "--baud", "115200",
        "--before", "default-reset",
        "--after", "hard-reset",
        "write_flash",
        "--flash-mode", "dio",
        "--flash-freq", "40m",
        "--flash-size", "16MB",
    ]
    for path, offset in files:
        cmd += [hex(offset), path]
    return cmd


def flash_device_with_progress(port, device_type, bin_folder="firmware"):
    """Flash the downloaded images for *device_type* onto *port* with esptool.

    Writes FLASH_FILES_COMMON plus the device's entry in DEVICE_APP_BIN, read
    from ``<bin_folder>/<device_type>_<file name>``. esptool runs as a
    subprocess: its progress lines advance a tqdm bar and every other output
    line goes to the log file. Only if esptool succeeds is the device reset
    and its serial output checked with verify_serial().

    Args:
        port: Serial port passed to esptool's ``--port``.
        device_type: Key of DEVICE_APP_BIN (KeyError otherwise).
        bin_folder: Directory holding the downloaded images.

    Returns:
        True if esptool exited with status 0, False otherwise (the failure is
        printed and logged).
    """
    images = FLASH_FILES_COMMON + [DEVICE_APP_BIN[device_type]]
    files = [(str(Path(bin_folder) / f"{device_type}_{name}"), offset) for name, offset in images]
    cmd = _build_esptool_cmd(port, files)

    print(f"Flashing {device_type} on {port}...")
    # errors="replace": an undecodable byte in esptool's output must not abort
    # the read loop with UnicodeDecodeError in the middle of a flash.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, errors="replace") as process:
        with tqdm(total=100, desc=f"Flashing {device_type} ({port})", unit="%") as pbar:
            for line in process.stdout:
                if "Writing at" in line or "Erasing" in line:
                    # No overall percentage is parsed, so each progress line
                    # advances the bar one step, held at 95 until esptool exits.
                    if pbar.n < 95:
                        pbar.update(1)
                else:
                    log(line.strip())  # log everything else to file, keep console clean
            returncode = process.wait()
            if returncode == 0:
                pbar.update(100 - pbar.n)

    if returncode != 0:
        message = (f"[!] Flashing failed on {port} ({device_type}): "
                   f"esptool exited with code {returncode}; see {LOG_FILE}.")
        log(message)
        print(message)
        return False

    log(f"Flashing complete on {port} ({device_type}).")
    print(f"Flashing complete on {port} ({device_type}).")
    reset_device(port)
    verify_serial(port, device_type)
    return True

def listen_to_port(port_name, results, timeout=20):
    """Identify the device on *port_name* from the text it prints.

    Opens the port at 115200 baud and accumulates decoded output. Text
    containing ``"ATK-C6"`` stores ``results[port_name] = "ATK-C6"``; otherwise
    text containing ``"RGB"`` stores ``"ESP-RGB"``. ATK-C6 is checked first.
    The function is used as a thread target, and each call writes only its
    own key.

    Args:
        port_name: Serial port to open.
        results: Dict that receives the identification. If it already holds
            *port_name*, listening stops without overwriting the entry.
        timeout: Seconds to listen before giving up with no entry written.
            The default matches verify_serial(). Pass None to listen
            indefinitely; with no limit, one silent port makes a caller that
            joins this thread block forever.

    A serial.SerialException is logged, not raised.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    try:
        with serial.Serial(port_name, baudrate=115200, timeout=0.1) as ser:
            buffer = ""
            while deadline is None or time.monotonic() < deadline:
                if port_name in results:
                    return
                # Block for up to the 0.1 s read timeout waiting for at least
                # one byte rather than spinning on in_waiting while idle.
                data = ser.read(ser.in_waiting or 1)
                if not data:
                    continue
                buffer += data.decode(errors='ignore')
                if "ATK-C6" in buffer:
                    log(f"ATK-C6 detected on {port_name}")
                    results[port_name] = "ATK-C6"
                    return
                if "RGB" in buffer:  # also matches "ESP-RGB"
                    log(f"ESP-RGB detected on {port_name}")
                    results[port_name] = "ESP-RGB"
                    return
            log(f"[!] No device identified on {port_name} within {timeout}s")
    except serial.SerialException as e:
        log(f"[!] {port_name} serial error: {e}")

def _identify_devices(ports):
    """Listen on every ready port in parallel; return {port: device_type} for identified ones.

    Ports for which wait_for_port_ready() fails are skipped. Blocks until every
    listener thread has returned.
    """
    results = {}
    listeners = []
    for port in ports:
        if wait_for_port_ready(port):
            listener = threading.Thread(target=listen_to_port, args=(port, results))
            listener.start()
            listeners.append(listener)
    for listener in listeners:
        listener.join()
    return results


def _wait_for_new_port(known_ports, poll_interval=1):
    """Poll indefinitely until a CP210x port outside *known_ports* appears; return it.

    Ports that disappear are dropped from the known set on each poll. That way
    a device plugged in after another was unplugged is still recognised even
    if it reuses the old port's name.
    """
    known = set(known_ports)
    while True:
        current = {p.device for p in find_cp210x_devices()}
        new_ports = current - known
        if new_ports:
            return new_ports.pop()
        known &= current
        time.sleep(poll_interval)


def _flash_other_device(flashed_type, known_ports):
    """Ask for the complementary device, wait for it, confirm its type, and flash it."""
    opposite_type = "ESP-RGB" if flashed_type == "ATK-C6" else "ATK-C6"
    print(f"\nPlease plug in the {opposite_type} device now...")
    print("Waiting for new CP210x device...")
    new_port = _wait_for_new_port(known_ports)
    print(f"New device detected on {new_port}")
    if not wait_for_port_ready(new_port):
        print(f"[!] {new_port} could not be opened. Aborting.")
        return
    temp_results = {}
    listen_to_port(new_port, temp_results)
    if temp_results.get(new_port) == opposite_type:
        print(f"Correct device type detected: {opposite_type}")
        flash_device_with_progress(new_port, opposite_type)
        print("\nSecond device flashed")
    else:
        print(f"[!] Detected device is not {opposite_type}. Aborting.")


def main():
    """Detect CP210x-attached devices, identify them, download firmware, and flash on confirmation.

    Exits with status 1 if no CP210x port exists and with status 0 if no
    device identifies itself.
    """
    print("Detecting CP210x devices for flash...")
    devices = find_cp210x_devices()
    if not devices:
        print("No CP210x devices found.")
        sys.exit(1)

    results = _identify_devices(dev.device for dev in devices)
    if not results:
        print("No ESP devices detected.")
        sys.exit(0)

    for port, dev_type in results.items():
        print(f"Found Device: {dev_type} on {port}")
        log(f"Detected {dev_type} on {port}")

    download_bins()

    if len(results) == 1:
        port, dev_type = next(iter(results.items()))
        print(f"\nOnly one device detected: {dev_type} on {port}")
        choice = input("Flash this device? [y/n]: ").strip().lower()
        if choice == "y":
            flash_device_with_progress(port, dev_type)
            print("\nDevice flashed")

            choice2 = input("\nDo you want to flash the other device? [y/n]: ").strip().lower()
            if choice2 == "y":
                _flash_other_device(dev_type, results.keys())
        else:
            print("\nFlashing aborted by user.")
    else:
        choice = input("\nDo you want to flash the detected devices? [y/n]: ").strip().lower()
        if choice == "y":
            for port, dev_type in results.items():
                flash_device_with_progress(port, dev_type)
        else:
            print("\nFlashing aborted by user.")

    print("\nAll flashing operations complete.")


if __name__ == "__main__":
    main()