import argparse
import functools
import json
import os
import sys
from typing import List

import torch
import torch_npu

# Make the local `example` package importable for the `example.common...`
# imports that follow. This must run before those imports are executed.
# NOTE(review): presumably the directory two levels up is the repository root
# that contains `example/` — confirm against the actual layout.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))  # directory holding this script
MSMODELSLIM_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "..", ".."))  # two levels above it
# Insert at the front of sys.path so it is searched first; skip when already listed.
if MSMODELSLIM_ROOT not in sys.path:
    sys.path.insert(0, MSMODELSLIM_ROOT)

from example.common.security.path import get_valid_read_path, get_write_directory
from example.common.security.type import check_number
from example.common.utils import SafeGenerator, cmd_bool
from example.common.rot_utils.rot_qwen import rot_model
from example.common.copy_config_files import copy_config_files, modify_config_json
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim.utils.logging import set_logger_level
from msmodelslim import logger


def parse_args():
    """Parse the command-line options of the quantization script.

    Options are registered in a fixed order from a single table so the
    generated ``--help`` output keeps the same ordering.

    Returns:
        argparse.Namespace: the parsed options. ``--model_path``,
        ``--save_path``, ``--anti_dataset`` and ``--calib_dataset`` are
        required; every other option has a default.
    """
    # Each entry is (flag, keyword options forwarded to ``add_argument``).
    option_table = [
        ("--model_path", {"type": str, "required": True, "help": "Path to float model"}),
        ("--save_path", {"type": str, "required": True, "help": "Path to save quantized model"}),
        # Dataset files have no default and must be given explicitly; the
        # original note says this avoids a "/common" relative-path bug.
        ("--anti_dataset", {"type": str, "required": True, "help": "JSON file for anti-outlier prompts"}),
        ("--calib_dataset", {"type": str, "required": True, "help": "JSON file for calibration prompts"}),
        ("--layer_count", {"type": int, "default": 0, "help": "0 means all layers"}),
        ("--batch_size", {"type": int, "default": 1, "help": "Use small batch first on MoE"}),
        ("--mindie_format", {"action": "store_true", "help": "Enable only mindie config save"}),
        ("--trust_remote_code", {"type": cmd_bool, "default": False}),
        ("--rot", {"action": "store_true", "help": "Apply ROT before quantization"}),
        # Safer knobs.
        ("--skip_anti_outlier", {"action": "store_true", "help": "Skip anti-outlier stage"}),
        ("--device_id", {"type": int, "default": 0, "help": "Current NPU id for quant ops"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def custom_hook(model_config):
    """Mark a model config as using the ``w8a8_dynamic`` quantization scheme.

    Used in ``main`` as the ``custom_hook`` argument bound onto
    ``modify_config_json``.

    Args:
        model_config: mutable mapping that is updated in place; its
            ``"quantize"`` entry is set (or overwritten).
    """
    quantize_mode = "w8a8_dynamic"
    model_config["quantize"] = quantize_mode


def load_json_list(json_path: str) -> List[str]:
    """Read a JSON file that must hold a non-empty list of strings.

    Args:
        json_path: path of the UTF-8 encoded JSON file to read.

    Returns:
        The decoded list of strings.

    Raises:
        ValueError: if the top-level JSON value is not a list, the list is
            empty, or any element is not a string.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        prompts = json.load(handle)

    # Validate from the outside in: container type, then size, then items.
    if not isinstance(prompts, list):
        raise ValueError(f"{json_path} must contain a JSON list.")
    if not prompts:
        raise ValueError(f"{json_path} is empty.")
    if any(not isinstance(item, str) for item in prompts):
        raise ValueError(f"{json_path} must be a JSON list of strings.")
    return prompts


def get_calib_dataset_batch(tokenizer, texts: List[str], batch_size: int, device: str = "npu"):
    """Tokenize ``texts`` in groups of ``batch_size`` and collect the tensors.

    Args:
        tokenizer: callable invoked as
            ``tokenizer(chunk, return_tensors="pt", padding=True)``; its
            result must offer ``.to(device)`` and a ``.data`` mapping.
            Presumably a Hugging Face tokenizer — confirm against the caller.
        texts: prompts to tokenize; consecutive slices form the batches.
        batch_size: maximum number of prompts per batch.
        device: target passed to ``.to(...)`` for the encoding and for every
            tensor taken from it.

    Returns:
        A list with one entry per batch. Each entry is the list of
        ``torch.Tensor`` values from the encoding's ``.data`` mapping, in
        mapping order, moved to ``device``; non-tensor values are dropped.
    """
    batches = []
    for start in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[start:start + batch_size], return_tensors="pt", padding=True
        ).to(device)

        tensors = []
        for value in encoded.data.values():
            # Only tensors are kept; other values in the mapping are skipped.
            if isinstance(value, torch.Tensor):
                tensors.append(value.to(device))
        batches.append(tensors)
    return batches


def maybe_log_devices():
    """Log the NPU-related environment variables at info level.

    Emits one ``NAME=value`` line each for ``ASCEND_RT_VISIBLE_DEVICES`` and
    ``PYTORCH_NPU_ALLOC_CONF``; an unset variable is logged with an empty
    value.
    """
    for env_name in ("ASCEND_RT_VISIBLE_DEVICES", "PYTORCH_NPU_ALLOC_CONF"):
        env_value = os.environ.get(env_name, "")
        logger.info(f"{env_name}={env_value}")


def main():
    """Run the quantization flow from the command line end to end.

    Steps, in order:
      1. Parse arguments, set the log level and log the device environment.
      2. Validate the model directory, dataset files, output directory and
         batch size, then load both prompt lists.
      3. Load config, tokenizer and model (optionally truncated to
         ``--layer_count`` layers).
      4. Tokenize the prompt lists into batches on the model's device.
      5. Optionally apply ``rot_model`` and, for truncated models, log the
         MSE between the outputs before and after that step.
      6. Unless ``--skip_anti_outlier`` is given, run the anti-outlier pass.
      7. Calibrate, save the quantized weights and copy the config files.

    Raises:
        ValueError: if ``--model_path`` is not an existing directory, or
            ``--layer_count`` is outside ``[0, config.num_hidden_layers]``.
            The path/number validators and ``load_json_list`` may raise too.
    """
    args = parse_args()
    set_logger_level("info")
    maybe_log_devices()

    # Important for Ascend quantization flow on NPU
    torch.npu.set_compile_mode(jit_compile=False)

    # model_path is deliberately only checked with os.path.isdir here rather
    # than being routed through get_valid_read_path.
    model_path = args.model_path
    if not os.path.isdir(model_path):
        raise ValueError(f"model_path must be an existing directory: {model_path}")

    # Validate every remaining input before any heavy work starts.
    anti_dataset_path = get_valid_read_path(args.anti_dataset, "json", is_dir=False)
    calib_dataset_path = get_valid_read_path(args.calib_dataset, "json", is_dir=False)
    save_path = get_write_directory(args.save_path, write_mode=0o750)
    check_number(args.batch_size, int, 1, 16, "batch_size")

    anti_prompt = load_json_list(anti_dataset_path)
    calib_prompt = load_json_list(calib_dataset_path)

    safe_generator = SafeGenerator()
    config = safe_generator.get_config_from_pretrained(
        model_path=model_path,
        trust_remote_code=args.trust_remote_code,
    )

    total_layers = config.num_hidden_layers
    if not 0 <= args.layer_count <= total_layers:
        raise ValueError(
            f"Invalid layer_count={args.layer_count}. Must be between 0 and {total_layers}."
        )
    # layer_count == 0 keeps the layer count from the config; any other value
    # overrides it.
    if args.layer_count:
        config.num_hidden_layers = args.layer_count

    # Avoid unnecessary KV cache memory during quantization
    config.use_cache = False

    # Keyword arguments shared by the tokenizer and model loaders.
    pretrained_kwargs = {
        "model_path": model_path,
        "config": config,
        "trust_remote_code": args.trust_remote_code,
    }
    tokenizer = safe_generator.get_tokenizer_from_pretrained(
        use_fast=True,
        add_eos_token=True,
        **pretrained_kwargs,
    )

    # Keep embeddings/lm_head on NPU, heavy decoder layers on CPU to reduce NPU pressure
    npu_id = args.device_id
    placement = {
        "model.embed_tokens": npu_id,
        "model.layers": "cpu",
        "model.norm": "cpu",
        "lm_head": npu_id,
    }
    model = safe_generator.get_model_from_pretrained(
        device_map=placement,
        torch_dtype="auto",
        attn_implementation="eager",
        **pretrained_kwargs,
    )

    # NOTE(review): the anti-outlier batches are built even when
    # --skip_anti_outlier is set, in which case they are never used.
    anti_dataset = get_calib_dataset_batch(tokenizer, anti_prompt, args.batch_size, model.device)
    dataset_calib = get_calib_dataset_batch(tokenizer, calib_prompt, args.batch_size, model.device)

    # For a truncated model, run one probe prompt before and after the
    # optional ROT step and log the MSE between both outputs.
    # NOTE(review): the comparison also runs when --rot is not given.
    compare_outputs = args.layer_count > 0
    with torch.no_grad():
        if compare_outputs:
            probe_text = "what is deep learning?"
            probe = tokenizer(probe_text, return_tensors="pt").to(model.device)
            baseline_out = model(**probe)

        if args.rot:
            rot_model(model)

        if compare_outputs:
            rotated_out = model(**probe)
            mse = torch.nn.MSELoss()(baseline_out[0], rotated_out[0])
            logger.info(f"ROT MSE = {mse}")

    if not args.skip_anti_outlier:
        with torch.no_grad():
            anti_config = AntiOutlierConfig(
                w_bit=8,
                a_bit=8,
                anti_method="m4",
                dev_type="npu",
                dev_id=npu_id,
            )
            AntiOutlier(model, calib_data=anti_dataset, cfg=anti_config).process()

    # The mlp.gate module of every layer is listed in disable_names.
    disable_names = [
        f"model.layers.{idx}.mlp.gate" for idx in range(config.num_hidden_layers)
    ]

    quant_config = QuantConfig(
        a_bit=8,
        w_bit=8,
        disable_names=disable_names,
        dev_type="npu",
        dev_id=npu_id,
        act_method=1,
        pr=1.0,
        w_sym=True,
        mm_tensor=False,
    )

    # Entry order is kept as written: the "*.mlp.*" pattern precedes "*".
    mix_cfg = {
        "*.mlp.*": "w8a8_dynamic",
        "*": "w8a8",
    }
    calibrator = Calibrator(
        model,
        quant_config,
        calib_data=dataset_calib,
        disable_level="L0",
        mix_cfg=mix_cfg,
    )
    calibrator.run()

    # The weight file name is the same for both formats; only the description
    # file name and the save type differ.
    safetensors_name = "quant_model_weight_w8a8_dynamic.safetensors"
    if args.mindie_format:
        description_json_name = "quant_model_description_w8a8_dynamic.json"
        save_type = "safe_tensor"
    else:
        description_json_name = "quant_model_description.json"
        save_type = "ascendV1"

    calibrator.save(
        save_path,
        json_name=description_json_name,
        safetensors_name=safetensors_name,
        save_type=[save_type],
        part_file_size=4,
    )

    # config.json is passed through modify_config_json with custom_hook bound.
    config_json_hook = functools.partial(modify_config_json, custom_hook=custom_hook)
    copy_config_files(
        input_path=model_path,
        output_path=save_path,
        quant_config=quant_config,
        mindie_format=args.mindie_format,
        custom_hooks={"config.json": config_json_hook},
    )


if __name__ == "__main__":
    # Script entry point. The NPU runtime is initialized explicitly before
    # main() runs: per the original note, torch_npu lazy init can fail after a
    # big model load, so this call must stay ahead of main().
    torch_npu.npu.init()
    main()