# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import importlib.util

import pytest
from transformers import AutoTokenizer
import nemo_automodel.components.datasets.llm.squad as SQUAD
from nemo_automodel.components.datasets.llm.formatting_utils import _add_pad_token


# Explicit list of tokenizer snapshot directory names under test. They are not
# vendored in this repo: _specific_tokenizer_dirs() resolves each one under the
# shared test-data directory (/home/TestData/akoumparouli/tokenizers/).
TOKENIZER_NAMES = [
    "gpt-oss-20b",
    "llama_3.2_1b",
    "qwen3_30b_a3b_instruct_2507",
]


def _specific_tokenizer_dirs(
    names=None,
    tokenizers_dir: Path = Path("/home/TestData/akoumparouli/tokenizers/"),
) -> list[Path]:
    """Return the on-disk directory of each tokenizer under test.

    Args:
        names: Iterable of tokenizer directory names to resolve. ``None``
            (the default) means the module-level ``TOKENIZER_NAMES``.
        tokenizers_dir: Root directory that holds the tokenizer snapshots.

    Returns:
        One ``Path`` per name, in the same order as ``names``. Existence of
        the directories is not checked here.
    """
    if names is None:
        # Looked up at call time so the default always reflects TOKENIZER_NAMES.
        names = TOKENIZER_NAMES
    return [tokenizers_dir / name for name in names]


@pytest.mark.parametrize("tokenizer_dir", _specific_tokenizer_dirs(), ids=lambda p: p.name)
@pytest.mark.parametrize("seq_length", [None, 256])
def test_formatting_prompts_func_returns_valid_shapes_and_masks(tokenizer_dir: Path, seq_length, monkeypatch):
    """Check ``SQUAD._formatting_prompts_func`` for tokenizers without a chat template.

    Verifies that:
      * ``input_ids`` and ``labels`` are non-empty and equally long, and equal
        to ``seq_length`` when one is given;
      * the first label and trailing padding labels are masked with -100;
      * some answer tokens are supervised;
      * EOS appears exactly once in ``labels`` and is never the last real
        (non-padding) token of ``input_ids``;
      * no attended position (``attention_mask == 1``) feeds the EOS id.

    Tokenizers that define a chat template are skipped; they are exercised by
    ``test_formatting_with_chat_template_when_available``.
    """
    # monkeypatch restores the previous environment on teardown, including when
    # the test is skipped or an assertion fails (a trailing os.environ.pop would
    # not run in those cases, and would also clobber pre-existing values).
    monkeypatch.setenv("TRANSFORMERS_OFFLINE", "1")
    monkeypatch.setenv("HF_HUB_OFFLINE", "1")
    tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir))

    # Only run if this tokenizer has NO chat template; otherwise skip.
    if getattr(tokenizer, "chat_template", None):
        pytest.skip(f"Tokenizer {tokenizer_dir.name} has chat-template")

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "pad_token_id", None) is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    example = {
        "question": "What is the capital of France?",
        "context": "France is a country in Europe. The capital of France is Paris.",
        "answers": {"text": ["Paris"]},
    }

    eos_token_id = getattr(tokenizer, "eos_token_id", 0)
    pad_token_id = _add_pad_token(tokenizer) or eos_token_id
    result = SQUAD._formatting_prompts_func(example, tokenizer, eos_token_id, pad_token_id, seq_length=seq_length)

    # Basic structure
    assert isinstance(result, dict)
    assert {"input_ids", "labels"}.issubset(result.keys())
    assert len(result["input_ids"]) == len(result["labels"]) > 0

    # If seq_length is provided, lengths must match it
    if isinstance(seq_length, int):
        assert len(result["input_ids"]) == seq_length
        assert len(result["labels"]) == seq_length
        # padding after answer should be masked
        assert result['labels'][-1] == -100

    if getattr(tokenizer, "eos_token_id", None) is not None:
        # EOS must be supervised exactly once in labels (at most once here,
        # presence checked below) but must not be the last real input token.
        assert result['labels'].count(tokenizer.eos_token_id) <= 1
        if pad_token_id != tokenizer.eos_token_id:
            assert tokenizer.eos_token_id != result['input_ids'][-1]
        else:
            # Padding reuses the EOS id, so walk back past it to the last real token.
            last = len(result['input_ids']) - 1
            while last >= 0 and result['input_ids'][last] == pad_token_id:
                last -= 1
            assert last >= 0
            assert result['input_ids'][last] != tokenizer.eos_token_id

        assert tokenizer.eos_token_id in result['labels']
        if isinstance(seq_length, int):
            assert tokenizer.eos_token_id != result['labels'][-1]

    # No attended position may feed the EOS id as input.
    assert "attention_mask" in result
    for pos, mask in enumerate(result['attention_mask']):
        if mask == 1:
            assert result['input_ids'][pos] != tokenizer.eos_token_id

    # context tokens should be masked
    assert result['labels'][0] == -100

    # There should be some supervised tokens (labels != -100)
    assert any(label != -100 for label in result["labels"])  # answer tokens should be supervised

    # If there is a trailing pad token, its label should be masked
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None and "___PAD_TOKEN_IDS___" in result:
        pad_id = result["___PAD_TOKEN_IDS___"].get("input_ids")
    if pad_id is not None and result["input_ids"]:
        if result["input_ids"][-1] == pad_id:
            assert result["labels"][-1] == -100


@pytest.mark.parametrize("tokenizer_dir", _specific_tokenizer_dirs(), ids=lambda p: p.name)
@pytest.mark.parametrize("seq_length", [None, 64])
def test_formatting_with_chat_template_when_available(tokenizer_dir: Path, seq_length, monkeypatch):
    """Check ``SQUAD._formatting_prompts_func_with_chat_template`` for chat-template tokenizers.

    Verifies that:
      * ``input_ids`` and ``labels`` are non-empty and equally long, and equal
        to ``seq_length`` when one is given (with the last label masked);
      * EOS appears in ``labels`` at most twice and, at its first supervised
        position, is not also the input token;
      * some tokens are supervised.

    Tokenizers without a chat template are skipped.
    """
    # monkeypatch restores the previous environment on teardown, including when
    # the test is skipped or an assertion fails (a trailing os.environ.pop would
    # not run in those cases, and would also clobber pre-existing values).
    monkeypatch.setenv("TRANSFORMERS_OFFLINE", "1")
    monkeypatch.setenv("HF_HUB_OFFLINE", "1")
    tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir))

    # Only run if this tokenizer defines a chat template; otherwise skip
    if not getattr(tokenizer, "chat_template", None):
        pytest.skip(f"No chat_template for tokenizer: {tokenizer_dir.name}")

    example = {
        "question": "What is the capital of France?",
        "context": "France is a country in Europe. The capital of France is Paris.",
        "answers": {"text": ["Paris"]},
    }

    eos_token_id = getattr(tokenizer, "eos_token_id", 0)
    pad_token_id = _add_pad_token(tokenizer) or eos_token_id
    # We don't assume a specific start-of-turn token string here, since it varies by model family
    result = SQUAD._formatting_prompts_func_with_chat_template(
        example,
        tokenizer,
        eos_token_id,
        pad_token_id,
        seq_length=seq_length,
        start_of_turn_token=None,
    )

    assert isinstance(result, dict)
    assert {"input_ids", "labels"}.issubset(result.keys())
    assert len(result["input_ids"]) == len(result["labels"]) > 0

    if isinstance(seq_length, int):
        assert len(result["input_ids"]) == seq_length
        assert len(result["labels"]) == seq_length
        # padding after answer should be masked
        assert result['labels'][-1] == -100

    if getattr(tokenizer, "eos_token_id", None) is not None:
        labels = result['labels']
        eos_count = labels.count(tokenizer.eos_token_id)
        # Up to two EOS targets are tolerated; presumably some chat templates
        # emit the EOS id more than once per sample -- NOTE(review): confirm.
        assert eos_count <= 2, eos_count
        # Check presence explicitly so a missing EOS fails with a clear message
        # rather than a ValueError from list.index().
        assert tokenizer.eos_token_id in labels, "EOS token id not found in labels"
        eos_token_id_pos = labels.index(tokenizer.eos_token_id)
        # Where EOS is the supervised target, the input at that position must not be EOS.
        assert tokenizer.eos_token_id != result['input_ids'][eos_token_id_pos]

    # There should be some supervised tokens
    assert any(label != -100 for label in result["labels"])
