#!/usr/bin/env python3

# All imports first
import asyncio
import logging
import os
import sys
import textwrap
import time
import uuid
import warnings
from argparse import ArgumentParser, Namespace
from contextlib import asynccontextmanager
from functools import cache
from pathlib import Path

import aiohttp
import openai
import qdrant_client
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastembed import SparseTextEmbedding, TextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
from pydantic import BaseModel, Field
from pymilvus import AnnSearchRequest, AsyncMilvusClient, RRFRanker  # type: ignore

# Silence noisy UserWarnings emitted at import time by the embedding/ML libraries.
warnings.filterwarnings("ignore", category=UserWarning)

# Global Vars
# Model identifiers are overridable via environment variables for deployment flexibility.
EMBED_MODEL = os.getenv("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
SPARSE_MODEL = os.getenv("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
RANK_MODEL = os.getenv("RANK_MODEL", "Xenova/ms-marco-MiniLM-L-6-v2")
# Single fixed collection name shared by both vector-store backends.
COLLECTION_NAME = "rag"
# Needed for mac to not give errors
os.environ["TOKENIZERS_PARALLELISM"] = "true"


# OpenAI API Compatible Data Models
class ChatMessage(BaseModel):
    """A single message in an OpenAI-style conversation."""

    role: str = Field(description="The role of the message author")
    content: str = Field(description="The contents of the message")


class ChatCompletionRequest(BaseModel):
    """Request body for /v1/chat/completions, mirroring OpenAI's schema.

    Field bounds (ge/le) follow the OpenAI API's documented valid ranges.
    """

    messages: list[ChatMessage] = Field(description="A list of messages comprising the conversation")
    model: str = Field("", description="ID of the model to use")
    temperature: float | None = Field(1.0, ge=0, le=2, description="Sampling temperature")
    top_p: float | None = Field(1.0, ge=0, le=1, description="Nucleus sampling")
    n: int | None = Field(1, ge=1, le=128, description="Number of completions to generate")
    stream: bool | None = Field(False, description="Whether to stream back partial progress")
    stop: str | list[str] | None = Field(None, description="Sequences where the API will stop generating")
    max_completion_tokens: int | None = Field(None, ge=1, description="Maximum number of tokens to generate")
    presence_penalty: float | None = Field(0, ge=-2, le=2, description="Presence penalty")
    frequency_penalty: float | None = Field(0, ge=-2, le=2, description="Frequency penalty")
    logit_bias: dict[str, float] | None = Field(None, description="Modify likelihood of specified tokens")
    user: str | None = Field(None, description="A unique identifier representing your end-user")


class Usage(BaseModel):
    """Token accounting block of an OpenAI chat-completion response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class Choice(BaseModel):
    """One completion choice in a non-streaming response."""

    index: int
    message: ChatMessage
    finish_reason: str | None = None


class ChatCompletionResponse(BaseModel):
    """Top-level non-streaming response, shaped like OpenAI's chat.completion."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[Choice]
    usage: Usage


# Streaming models
class Delta(BaseModel):
    """Incremental message fragment within a streaming chunk.

    All fields are optional: the first chunk typically carries only `role`,
    subsequent chunks carry `content` and/or `reasoning_content`.
    """

    role: str | None = None
    content: str | None = None
    reasoning_content: str | None = None


class StreamChoice(BaseModel):
    """One choice entry inside a streaming chunk."""

    index: int
    delta: Delta
    finish_reason: str | None = None


class ChatCompletionStreamResponse(BaseModel):
    """Top-level streaming chunk, shaped like OpenAI's chat.completion.chunk."""

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: list[StreamChoice]


def eprint(e, exit_code):
    """Write *e* to stderr (with surrounding quotes stripped) and exit the process."""
    message = str(e).strip("'\"")
    sys.stderr.write(f"Error: {message}\n")
    sys.exit(exit_code)


# Helper Classes and Functions


async def request(host: str, port: int, path: str, timeout: int = 10):
    """GET ``http://host:port/path`` and return ``(response, parsed JSON body)``.

    The aiohttp session is scoped to this single call; the total timeout covers
    connection plus read.
    """
    base_url = f"http://{host}:{port}"
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(base_url, timeout=client_timeout) as session:
        async with session.get(path) as resp:
            return resp, await resp.json()


async def wait_for_llama_server(host: str, port: int, total_timeout: int = 120, print_interval: int = 5):
    """Poll the server's /health endpoint until it reports ready.

    Retries once per second until `total_timeout` elapses, printing progress at
    most every `print_interval` seconds. Returns True on success.

    Raises:
        TimeoutError: if the server never returns HTTP 200 within the window.
    """
    deadline = time.monotonic() + total_timeout
    last_report = 0  # monotonic timestamp of the last progress message
    per_request_timeout = 1  # Timeout per individual HTTP request

    while time.monotonic() < deadline:
        try:
            resp, _ = await request(host, port, "/health", timeout=per_request_timeout)
            if resp.status == 200:
                return True

            now = time.monotonic()
            if now - last_report >= print_interval:
                print(f"Server at {host}:{port} is running but not ready (status={resp.status}), retrying...")
                last_report = now
        except Exception as exc:
            # Connection refused / timeouts are expected while the server boots.
            now = time.monotonic()
            if now - last_report >= print_interval:
                print(f"Error connecting to {host}:{port}: {exc}, retrying...")
                last_report = now

        await asyncio.sleep(1)

    raise TimeoutError(f"LLaMA server at {host}:{port} did not become ready after {total_timeout} seconds.")


class qdrant:
    """Thin async wrapper around an embedded (on-disk) Qdrant vector store.

    Dense and sparse embedding models must already be present in the local
    fastembed cache (local_files_only=True prevents network downloads).
    """

    def __init__(self, vector_path):
        self.client = qdrant_client.AsyncQdrantClient(path=vector_path)
        self.client.set_model(EMBED_MODEL, local_files_only=True)
        self.client.set_sparse_model(SPARSE_MODEL, local_files_only=True)

    async def query(self, prompt):
        """Run a hybrid search for *prompt* and return up to 20 document texts."""
        results = await self.client.query(
            # Fix: use the shared module constant instead of the hard-coded
            # literal "rag", keeping this backend consistent with the milvus class.
            collection_name=COLLECTION_NAME,
            query_text=prompt,
            limit=20,
        )
        return [r.document for r in results]


class milvus:
    """Hybrid dense+sparse retrieval backed by Milvus Lite (local *.db file).

    Embedding models must already be present in the local fastembed cache
    (local_files_only=True prevents network downloads).
    """

    def __init__(self, vector_path):
        self.milvus_client = AsyncMilvusClient(uri=os.path.join(vector_path, "milvus.db"))
        self.dmodel = TextEmbedding(model_name=EMBED_MODEL, local_files_only=True)
        self.smodel = SparseTextEmbedding(model_name=SPARSE_MODEL, local_files_only=True)

    async def query(self, prompt):
        """Embed *prompt* both ways, run an RRF-fused hybrid search, and return
        up to 20 matched document texts."""
        dense_vec = next(self.dmodel.embed([prompt]))
        sparse_vec = next(self.smodel.embed([prompt])).as_dict()

        dense_request = AnnSearchRequest(
            data=[dense_vec],
            anns_field="dense",
            param={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=10,
        )
        sparse_request = AnnSearchRequest(
            data=[sparse_vec],
            anns_field="sparse",
            param={"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
            limit=10,
        )

        results = await self.milvus_client.hybrid_search(
            collection_name=COLLECTION_NAME,
            reqs=[dense_request, sparse_request],
            ranker=RRFRanker(100),
            limit=20,
            output_fields=["text"],
        )
        return [hit["entity"]["text"] for hit in results[0]]


class OpenAICompatibleRagService:
    """OpenAI-compatible chat-completion service with RAG augmentation.

    Per request: retrieve candidate documents for the latest user message from
    the vector store (Milvus or Qdrant, auto-detected from ``vector_path``),
    rerank them with a cross-encoder, then forward a single augmented prompt to
    an upstream OpenAI-compatible LLM server (streaming or non-streaming).
    """

    def __init__(self, vector_path, model_host, model_port):
        # Cross-encoder reranker; local_files_only=True means the model must
        # already exist in the local cache (no network download here).
        self.reranker = TextCrossEncoder(model_name=RANK_MODEL, local_files_only=True)

        # Backend auto-detection: a *.db file under vector_path means Milvus
        # Lite; otherwise the path is treated as an embedded Qdrant store.
        if self.is_milvus(vector_path):
            # setup milvus
            self.vectordb = milvus(vector_path)
        else:
            # setup qdrant
            self.vectordb = qdrant(vector_path)

        # NOTE(review): base_url carries no /v1 suffix — confirm the upstream
        # server exposes OpenAI-style routes at this root, otherwise requests
        # issued by the openai client may 404.
        self.llm = openai.AsyncOpenAI(
            api_key="your-api-key",
            base_url=f"http://{model_host}:{model_port}",
            http_client=openai.DefaultAioHttpClient(),
        )

    def is_milvus(self, vector_path):
        """Return True when vector_path contains a Milvus Lite ``*.db`` file.

        NOTE(review): iterdir() raises if vector_path does not exist or is not
        a directory — confirm callers always pass an existing directory.
        """
        return any(f.suffix == ".db" and f.is_file() for f in Path(vector_path).iterdir())

    def _extract_conversation_context(self, messages: list[ChatMessage]) -> str:
        """Render the last six messages as ``role: content`` lines, one per message."""
        # Use last few messages as conversation context
        context_messages = messages[-6:]
        return "\n".join([f"{msg.role}: {msg.content}" for msg in context_messages])

    async def _perform_rag_lookup(self, query: str) -> str:
        """Vector-search *query*, rerank the hits, and return the top-5
        documents joined into a single context string."""
        # Vector search
        results = await self.vectordb.query(query)

        # Rerank and select top 5: enumerate pairs each candidate's index with
        # its cross-encoder score; sort by score descending and keep five.
        reranked_context = " ".join(
            str(results[i])
            for i, _ in sorted(enumerate(self.reranker.rerank(query, results)), key=lambda x: x[1], reverse=True)[:5]
        )
        return reranked_context.strip()

    async def _build_rag_enhanced_messages(self, messages: list[ChatMessage]) -> list[dict]:
        """Build the message list sent to the LLM.

        If the last message is not from the user, the conversation is passed
        through unchanged (as plain dicts). Otherwise history, retrieved
        context, and the question are collapsed into ONE user message.
        """
        if not messages:
            return []

        # Get the latest user message
        latest_message = messages[-1]
        if latest_message.role != "user":
            return [{"role": msg.role, "content": msg.content} for msg in messages]

        # Perform RAG lookup
        rag_context = await self._perform_rag_lookup(latest_message.content)
        conversation_history = self._extract_conversation_context(messages[:-1])

        # Enhanced system prompt with RAG context
        system_prompt = (
            textwrap.dedent(
                """
            You are an expert software architect assistant.
            Use the provided context and conversation history to answer questions accurately and concisely.
            If the answer is not in the context, respond with "I don't know" - do not fabricate details.

            ### Conversation History:
            {0}

            ### Retrieved Context:
            {1}

            ### Current Question:
            {2}

            Provide a helpful and accurate response based on the context above.
            """
            )
            .strip()
            .format(conversation_history, rag_context, latest_message.content)
        )

        return [{"role": "user", "content": system_prompt}]

    async def create_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse | StreamingResponse:
        """OpenAI-compatible chat completion entry point.

        Dispatches to the streaming or non-streaming path based on request.stream.
        """
        # Mimic OpenAI's id format: "chatcmpl-" plus 29 hex characters.
        completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
        created = int(time.time())

        # Build RAG-enhanced messages
        enhanced_messages = await self._build_rag_enhanced_messages(request.messages)

        if request.stream:
            return await self._stream_completion(completion_id, created, request, enhanced_messages)
        else:
            return await self._complete_chat(completion_id, created, request, enhanced_messages)

    async def _complete_chat(
        self, completion_id: str, created: int, request: ChatCompletionRequest, enhanced_messages: list[dict]
    ) -> ChatCompletionResponse:
        """Non-streaming completion.

        NOTE(review): only temperature and max_completion_tokens are forwarded
        upstream; the request's other sampling fields (top_p, stop, penalties,
        n) are currently ignored — confirm this is intentional.
        """
        response = await self.llm.chat.completions.create(
            model=request.model,
            messages=enhanced_messages,
            temperature=request.temperature,
            max_completion_tokens=request.max_completion_tokens,
            stream=False,
        )

        # Convert to OpenAI format
        return ChatCompletionResponse(
            id=completion_id,
            created=created,
            model=request.model,
            choices=[
                Choice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response.choices[0].message.content),
                    finish_reason=response.choices[0].finish_reason,
                )
            ],
            usage=Usage(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
            ),
        )

    async def _stream_completion(
        self, completion_id: str, created: int, request: ChatCompletionRequest, enhanced_messages: list[dict]
    ) -> StreamingResponse:
        """OpenAI-compatible streaming response generator (SSE-style ``data:`` frames)."""

        async def generate_stream():
            # First chunk with role
            first_chunk = ChatCompletionStreamResponse(
                id=completion_id,
                created=created,
                model=request.model,
                choices=[StreamChoice(index=0, delta=Delta(role="assistant"), finish_reason=None)],
            )
            yield f"data: {first_chunk.model_dump_json()}\n\n"

            # Stream content from LLM
            response = await self.llm.chat.completions.create(
                model=request.model,
                messages=enhanced_messages,
                temperature=request.temperature,
                max_completion_tokens=request.max_completion_tokens,
                stream=True,
            )

            async for chunk in response:
                if chunk.choices and (delta := chunk.choices[0].delta):
                    # Use getattr to safely access optional attributes
                    content = getattr(delta, 'content', None)
                    reasoning_content = getattr(delta, 'reasoning_content', None)

                    # Only send chunk if there's actual content or reasoning
                    if content is not None or reasoning_content is not None:
                        stream_chunk = ChatCompletionStreamResponse(
                            id=completion_id,
                            created=created,
                            model=request.model,
                            choices=[
                                StreamChoice(
                                    index=0,
                                    delta=Delta(content=content, reasoning_content=reasoning_content),
                                    finish_reason=None,
                                )
                            ],
                        )
                        yield f"data: {stream_chunk.model_dump_json()}\n\n"

                # Handle finish reason
                if chunk.choices and chunk.choices[0].finish_reason:
                    final_chunk = ChatCompletionStreamResponse(
                        id=completion_id,
                        created=created,
                        model=request.model,
                        choices=[StreamChoice(index=0, delta=Delta(), finish_reason=chunk.choices[0].finish_reason)],
                    )
                    yield f"data: {final_chunk.model_dump_json()}\n\n"

            # End of stream marker
            yield "data: [DONE]\n\n"

        # NOTE(review): server-sent events conventionally use the
        # "text/event-stream" media type; confirm all intended clients accept
        # "text/plain" before changing it.
        return StreamingResponse(
            generate_stream(),
            media_type="text/plain",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )


@cache
def initialize_rag_service() -> OpenAICompatibleRagService:
    """Build the shared RAG service singleton (memoized via functools.cache)."""
    cli = get_args()
    service = OpenAICompatibleRagService(cli.vector_path, cli.model_host, cli.model_port)
    return service


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the RAG service and block startup until the
    upstream LLM server is healthy."""
    cfg = get_args()
    initialize_rag_service()
    await wait_for_llama_server(cfg.model_host, cfg.model_port, total_timeout=120)
    yield


# FastAPI Application
# Module-level ASGI app; `lifespan` handles startup (service build + health wait).
app = FastAPI(
    title="RAG-Enhanced OpenAI Compatible API",
    lifespan=lifespan,
)


@app.post("/v1/chat/completions", response_model=None)
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse | StreamingResponse:
    """OpenAI-compatible chat completions endpoint with RAG augmentation.

    Returns a full ChatCompletionResponse, or a StreamingResponse when
    request.stream is set. Unexpected failures surface as HTTP 500.
    """
    try:
        service = initialize_rag_service()
        return await service.create_chat_completion(request)
    except HTTPException:
        # Fix: pass deliberate HTTP errors through unchanged, consistent with
        # the other endpoints, instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/models")
async def llama_models() -> dict:
    """Proxy the llama.cpp /models endpoint, tagging each model name with '+rag'."""
    cli = get_args()
    try:
        resp, payload = await request(cli.model_host, cli.model_port, "/models")
        if resp.status != 200:
            raise HTTPException(status_code=resp.status, detail=resp.reason)
        for entry in payload.get("models", []):
            # Append +rag to the model name so it matches what the client expects
            entry["name"] += "+rag"
        return payload
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"LLM service unavailable: {str(e)}")


@app.get("/v1/models")
async def list_models() -> dict:
    """OpenAI-compatible models endpoint.

    Queries the upstream LLM service for its model list and returns each entry
    with a '+rag' suffix on the id. Returns HTTP 503 if the upstream is down.
    """
    try:
        service = initialize_rag_service()
        # Get available models from the LLM service
        llm_models = await service.llm.models.list()

        # Fallback timestamp for upstream entries that omit `created`.
        now = int(time.time())
        # Add each model from the LLM service with +rag suffix
        model_data = [
            {
                "id": f"{model.id}+rag",
                "object": "model",
                "created": getattr(model, "created", now),
                "owned_by": getattr(model, "owned_by", "ramalama"),
            }
            for model in llm_models.data
        ]
        return {"object": "list", "data": model_data}
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"LLM service unavailable: {str(e)}") from e


@app.get("/health")
async def health_check() -> dict[str, str]:
    """Readiness probe: checks the upstream LLM /health endpoint and that the
    local RAG components have been constructed."""
    try:
        cli = get_args()
        resp, _ = await request(cli.model_host, cli.model_port, "/health")
        if resp.status != 200:
            raise HTTPException(status_code=resp.status, detail=resp.reason)
        service = initialize_rag_service()
        # Verify core components are initialized
        if service.vectordb is None or service.llm is None or service.reranker is None:
            raise HTTPException(status_code=503, detail="RAG has not been initialized")
        return {"status": "ok", "rag": "ok"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"LLM service unavailable: {str(e)}")


def serve_rag_api(args):
    """Run the ASGI app under uvicorn on the configured host/port (blocks)."""
    host, port = args.host, args.port
    uvicorn.run(app, host=host, port=port)


def load():
    """Warm the local model cache by instantiating the embedding, sparse, and
    reranker models (downloads them if not already cached)."""
    warm_client = qdrant_client.QdrantClient(":memory:")
    warm_client.set_model(EMBED_MODEL)
    warm_client.set_sparse_model(SPARSE_MODEL)
    TextCrossEncoder(model_name=RANK_MODEL)


# CLI definition (module level so `get_args` can parse lazily on first use).
parser = ArgumentParser(description="OpenAI-compatible RAG API server")
parser.add_argument("--debug", action="store_true", help="Output debug information")
subparsers = parser.add_subparsers(dest='command')

# `serve` runs the HTTP API; its handler receives the parsed Namespace.
serve_parser = subparsers.add_parser('serve', help='Run RAG as OpenAI-compatible HTTP API server')
serve_parser.add_argument("vector_path", type=str, help="Path to the vector database")
serve_parser.add_argument("--model-host", default="localhost", help="The hostname where the model is being served")
serve_parser.add_argument("--model-port", default=8080, type=int, help="The port where the model is being served")
serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind server")
serve_parser.add_argument("--port", default=8081, type=int, help="Port for RAG API")
serve_parser.set_defaults(func=serve_rag_api)

# `load` pre-downloads the embedding/reranker models; its handler takes no args.
load_parser = subparsers.add_parser('load', help='Preload RAG Embedding Models')
load_parser.set_defaults(func=load)


@cache
def get_args() -> Namespace:
    """Parse CLI arguments once and memoize the resulting Namespace so every
    caller sees the same object."""
    parsed = parser.parse_args()
    return parsed


def setup_logging(args: Namespace):
    """Configure root logging: DEBUG when --debug was given, WARNING otherwise."""
    level = logging.DEBUG if args.debug else logging.WARNING
    logging.basicConfig(
        level=level,
        style="{",
        format="{asctime} {name} {levelname}: {message}",
    )


# Script entry point: parse args, configure logging, then dispatch to the
# selected subcommand, or print usage when none was given.
if __name__ == "__main__":
    try:
        args = get_args()
        setup_logging(args)

        if args.command:
            if args.command in ['serve']:
                # 'serve' handlers receive the parsed Namespace.
                args.func(args)
            else:
                # no argument for 'load'
                args.func()
        else:
            parser.print_help()
    except ValueError as e:
        # eprint writes to stderr and exits the process with status 1.
        eprint(e, 1)
