# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-based reader
==================
Base class for readers that use pre-trained models for
generating diagrams.
"""
from __future__ import annotations
__all__ = ['ModelBasedReader']
from abc import abstractmethod
from pathlib import Path
from typing import Any
import torch
from lambeq.core.globals import VerbosityLevel
from lambeq.core.utils import (SentenceBatchType,
                               tokenised_batch_type_check,
                               TokenisedSentenceBatchType,
                               untokenised_batch_type_check)
from lambeq.text2diagram.base import Reader
from lambeq.text2diagram.model_based_reader.model_downloader import (
    ModelDownloader,
    ModelDownloaderError,
    MODELS
)
from lambeq.typing import StrPathT
class ModelBasedReader(Reader):
    """Base class for readers that use pre-trained models.

    This is an abstract base class that provides common functionality
    for model-based readers: locating (and, when needed, downloading)
    the model directory and validating batches of input sentences.
    Subclasses must implement the specific model initialisation and
    inference logic.

    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        device: int | str | torch.device = 'cpu',
        cache_dir: StrPathT | None = None,
        force_download: bool = False,
        verbose: str = VerbosityLevel.PROGRESS.value,
    ) -> None:
        """Initialise the model-based reader.

        Parameters
        ----------
        model_name_or_path : str
            Required, despite the `None` default in the signature.
            Can be either:
                - The path to a directory containing a model.
                - The name of a pre-trained model.
        device : int, str, or torch.device, default: 'cpu'
            Specifies the device on which to run the tagger model.
            - For CPU, use `'cpu'`.
            - For CUDA devices, use `'cuda:<device_id>'` or `<device_id>`.
            - For Apple Silicon (MPS), use `'mps'`.
            - You may also pass a :py:class:`torch.device` object.
            - For other devices, refer to the PyTorch documentation.
        cache_dir : str or os.PathLike, optional
            The directory to which a downloaded pre-trained model should
            be cached instead of the standard cache.
        force_download : bool, default: False
            Force the model to be downloaded, even if it is already
            available locally.
        verbose : str, default: 'progress'
            See :py:class:`VerbosityLevel` for options.

        Raises
        ------
        ValueError
            If `model_name_or_path` is `None`.
        ModelDownloaderError
            If the model has to be downloaded, the download fails and
            there is no usable local copy to fall back to.

        """
        super().__init__(verbose=verbose)

        if model_name_or_path is None:
            raise ValueError(f'Invalid value `{model_name_or_path}`'
                             ' for argument `model_name_or_path`.')

        self.model_name_or_path = model_name_or_path
        self.device = device
        self.cache_dir = cache_dir
        self.force_download = force_download
        # Resolved by `_prepare_model_artifacts` below.
        self.model_dir: Path | None = None

        # Prepare model artifacts
        self._prepare_model_artifacts()

    def _prepare_model_artifacts(self) -> None:
        """Resolve `self.model_dir`, downloading the model if needed.

        If `self.model_name_or_path` is an existing directory it is
        used as the model directory and nothing is downloaded.
        Otherwise the value is handed to a :py:class:`ModelDownloader`
        and the model is downloaded when `force_download` is set, when
        the downloader's model directory does not exist yet, or when
        the downloader reports the local copy as stale.

        Raises
        ------
        ModelDownloaderError
            If the download fails and no local version of the model
            is available to continue with.

        """
        self.model_dir = Path(self.model_name_or_path)
        if self.model_dir.is_dir():
            # A local model path was given: use it as is and skip the
            # update check.
            return

        downloader = ModelDownloader(self.model_name_or_path,
                                     self.cache_dir)
        self.model_dir = downloader.model_dir

        needs_download = (self.force_download
                          or not self.model_dir.is_dir()
                          or downloader.model_is_stale())
        if not needs_download:
            return

        try:
            downloader.download_model(self.verbose)
        except ModelDownloaderError as e:
            local_model_version = downloader.get_local_model_version()
            if (not self.model_dir.is_dir()
                    or local_model_version is None):
                # No local version to fall back to; a bare `raise`
                # re-raises the active exception with its traceback.
                raise
            print('Failed to update model with '
                  f'exception: {e}')
            print('Attempting to continue with version '
                  f'{local_model_version}')

    @abstractmethod
    def _initialise_model(self, **kwargs: Any) -> None:
        """Initialise the model and put it into the appropriate device.

        Also handle required miscellaneous initialisation steps here.

        """

    def validate_sentence_batch(
        self,
        sentences: SentenceBatchType,
        tokenised: bool = False,
        suppress_exceptions: bool = False,
    ) -> tuple[TokenisedSentenceBatchType, list[int]]:
        """Prepare input sentences for parsing.

        Untokenised sentences are converted to `str` and split on
        whitespace. Empty sentences (no tokens) are either rejected or
        removed from the batch, depending on `suppress_exceptions`.

        Parameters
        ----------
        sentences : list of str, or list of list of str
            The sentences to be parsed, passed either as strings or as
            lists of tokens.
        tokenised : bool, default: False
            Whether each sentence has been passed as a list of tokens.
        suppress_exceptions : bool, default: False
            Whether to tolerate empty sentences. If :py:obj:`True`,
            empty sentences are dropped from the returned batch and
            their positions are reported in the second return value
            instead of raising an exception.

        Returns
        -------
        tuple of (list of list of str, list of int)
            The tokenised sentences with empty sentences removed, and
            the indices (relative to the input batch, in ascending
            order) of the empty sentences that were removed.

        Raises
        ------
        ValueError
            If `sentences` does not pass the type check matching the
            `tokenised` flag, or if a sentence is empty and
            `suppress_exceptions` is :py:obj:`False`.

        """
        tokenised_sentences: TokenisedSentenceBatchType
        if tokenised:
            if not tokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `True`, but variable '
                                 '`sentences` does not have type '
                                 '`List[List[str]]`.')
            # Copy the outer list so the caller's batch is not modified.
            tokenised_sentences = list(sentences)   # type: ignore[arg-type]
        else:
            if not untokenised_batch_type_check(sentences):
                raise ValueError('`tokenised` set to `False`, but variable '
                                 '`sentences` does not have type '
                                 '`List[str]`.')
            tokenised_sentences = [str(s).split() for s in sentences]

        # Locate empty sentences (whitespace-only strings split to `[]`)
        empty_indices: list[int] = [
            i for i, sentence in enumerate(tokenised_sentences)
            if not sentence
        ]
        if empty_indices:
            if not suppress_exceptions:
                raise ValueError('sentence is empty.')
            # Remove empty sentences
            tokenised_sentences = [sentence
                                   for sentence in tokenised_sentences
                                   if sentence]

        return tokenised_sentences, empty_indices

    @staticmethod
    def available_models() -> list[str]:
        """List the names of the available pre-trained models.

        Returns
        -------
        list of str
            The keys of the `MODELS` registry imported from the model
            downloader module.

        """
        return list(MODELS)