import os
import re
import sys

def get_exe_directory():
    """Return the directory the program lives in, also when packaged.

    If ``sys.frozen`` is set (packaged executable), the directory of
    ``sys.executable`` is returned; otherwise the directory of this source file.
    """
    is_frozen = getattr(sys, 'frozen', False)
    anchor = sys.executable if is_frozen else os.path.abspath(__file__)
    return os.path.dirname(anchor)

def search_uvproj_files(root_path):
    """Recursively collect every ``.uvproj`` file below *root_path*.

    The suffix test is case-insensitive; full paths are returned in
    ``os.walk`` traversal order. A missing *root_path* yields an empty list.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_path)
        for name in names
        if name.lower().endswith(".uvproj")
    ]

def extract_listing_path(uvproj_path):
    """Return the absolute listing directory declared in a Keil ``.uvproj`` file.

    The first ``ListingPath`` element found anywhere in the XML is used. Its
    value is resolved relative to the directory that contains *uvproj_path*.

    Args:
        uvproj_path: Path of the ``.uvproj`` project file.

    Returns:
        The absolute listing directory, or None when the file cannot be read
        or parsed, or when no non-blank ``ListingPath`` is present.
    """
    import xml.etree.ElementTree as ET

    try:
        root = ET.parse(uvproj_path).getroot()
    except (ET.ParseError, OSError):
        return None

    listing_node = root.find(".//ListingPath")
    raw_path = (listing_node.text or "").strip() if listing_node is not None else ""
    # A blank or whitespace-only value would otherwise silently resolve to the
    # project directory itself; treat it as "no listing path".
    if not raw_path:
        return None

    # Keil writes Windows-style separators (e.g. ".\Listings\"). Mapping them to
    # the host separator lets the path resolve off Windows too; on Windows this
    # is a no-op.
    relative = raw_path.replace("\\", os.sep)
    uvproj_dir = os.path.dirname(uvproj_path)
    return os.path.abspath(os.path.join(uvproj_dir, relative))

def search_i_files(listing_path):
    """Recursively collect every ``.i`` file below *listing_path*.

    The suffix test is case-insensitive. If *listing_path* is not an existing
    directory, an empty list is returned. Full paths come back in ``os.walk``
    traversal order.
    """
    if os.path.isdir(listing_path):
        return [
            os.path.join(folder, name)
            for folder, _subdirs, names in os.walk(listing_path)
            for name in names
            if name.lower().endswith(".i")
        ]
    return []

def extract_file_paths_from_i(i_file_path, project_root, extensions=(".h", ".c")):
    """Return the files referenced by ``#line N "<file>"`` directives in a ``.i`` file.

    Args:
        i_file_path: Path of the preprocessor output file to scan. It is decoded
            as UTF-8, and undecodable bytes are ignored.
        project_root: Directory that relative paths in the directives are
            resolved against.
        extensions: Suffix, or iterable of suffixes, to keep. Matching is
            case-insensitive, so upper-case names such as ``REG51.H`` or
            ``MAIN.C`` are included.

    Returns:
        A set of absolute paths. If the file cannot be read, a warning is
        printed to stderr and whatever was collected so far is returned
        (best-effort, so one bad file does not abort the whole scan).
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    pattern = re.compile(r'#line\s+\d+\s+"([^"]+)"')
    file_paths = set()
    try:
        with open(i_file_path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                match = pattern.search(line)
                if not match:
                    continue
                referenced = match.group(1).strip()
                abs_path = os.path.abspath(os.path.join(project_root, referenced))
                # Compare case-insensitively, like the other file finders in this tool.
                if abs_path.lower().endswith(suffixes):
                    file_paths.add(abs_path)
    except (OSError, ValueError) as exc:
        print(f"⚠️ 警告：无法读取 {i_file_path}：{exc}", file=sys.stderr)
    return file_paths

def generate_source_insight_filelist(all_file_paths, output_path):
    """Write *all_file_paths* to *output_path* as UTF-8 text, sorted, one path per line.

    An existing file at *output_path* is overwritten.
    """
    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(entry + "\n" for entry in sorted(all_file_paths))

def _pause_before_exit():
    """Wait for the user to press Enter so the console output stays visible."""
    input("\n按回车键退出...")


def main():
    """Generate ``SourceInsight_FileList.txt`` for the single Keil project found.

    Steps:
      1. Recursively locate exactly one ``.uvproj`` below the program directory.
      2. Read its ``ListingPath`` and collect every ``.i`` file there.
      3. Extract the ``.c``/``.h`` files named in their ``#line`` directives.
      4. Write the sorted, de-duplicated list next to the program.

    Every failure prints a message and returns after waiting for Enter.
    """
    print("=" * 60)
    print("       Keil 工程文件列表生成器 v1.0  By 专用工具")
    print("=" * 60)

    script_dir = get_exe_directory()
    print(f"当前目录：{script_dir}")

    uvproj_files = search_uvproj_files(script_dir)
    if not uvproj_files:
        print("\n❌ 错误：未找到任何 .uvproj 工程文件！")
        _pause_before_exit()
        return
    if len(uvproj_files) > 1:
        print("\n" + "=" * 70)
        print("❌ 错误：检测到多个 Keil 工程文件！")
        print(f"共找到 {len(uvproj_files)} 个：")
        for uvproj in uvproj_files:
            print(f" → {os.path.basename(uvproj)}")
        print("\n请将本程序放到【单个 Keil 工程目录】下运行！")
        print("=" * 70)
        _pause_before_exit()
        return

    uvproj_path = uvproj_files[0]
    # The search above is recursive, so the project may live in a sub-directory
    # of script_dir. Relative paths are resolved against the project's own
    # directory: extract_listing_path already does this for ListingPath, and
    # the #line paths are assumed to share that base, because Keil runs its
    # tools from the project directory.
    project_dir = os.path.dirname(uvproj_path)

    print(f"\n✅ 找到工程文件：{os.path.basename(uvproj_path)}")
    listing_path = extract_listing_path(uvproj_path)
    if not listing_path:
        print("❌ 错误：无法提取 ListingPath！")
        _pause_before_exit()
        return

    print(f"✅ 编译目录：{listing_path}")
    all_i_files = search_i_files(listing_path)
    if not all_i_files:
        print("❌ 错误：未找到任何 .i 预处理文件！")
        _pause_before_exit()
        return

    print(f"✅ 找到 {len(all_i_files)} 个 .i 文件，开始解析...")

    all_found_paths = set()
    for i_file in all_i_files:
        all_found_paths.update(extract_file_paths_from_i(i_file, project_dir))

    output_file = os.path.join(script_dir, "SourceInsight_FileList.txt")
    generate_source_insight_filelist(all_found_paths, output_file)

    print("\n" + "=" * 50)
    print("🎉 处理完成！")
    print(f"📄 提取文件总数：{len(all_found_paths)} 个")
    print(f"📁 输出文件：{os.path.basename(output_file)}")
    print("=" * 50)
    _pause_before_exit()


if __name__ == "__main__":
    main()