#!/usr/bin/env python3
import json
import argparse
import importlib
from typing import List, Dict, Union, Tuple, Optional, Any

import rclpy
import setproctitle
from rclpy.executors import MultiThreadedExecutor
from rclpy.utilities import try_shutdown

from agents import config as all_configs
from agents import components as all_components
from agents import clients
from agents.clients.model_base import ModelClient
from agents.ros import (
    Topic,
    FixedInput,
    MapLayer,
    Route,
    QoSConfig,
    SupportedType,
    Event,
)

# Alias for the type of ``None``, used in type annotations in this module
# (e.g. the return type of ``_parse_trigger``). ``types.NoneType`` is only
# available from Python 3.10, so the alias is derived from ``type(None)`` to
# stay Python 3.8 compatible.
NoneType = type(None)

def _parse_args() -> Tuple[argparse.Namespace, List[str]]:
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Component Executable Config")
    parser.add_argument(
        "--config_type", type=str, help="Component configuration class name"
    )
    parser.add_argument("--component_type", type=str, help="Component class name")
    parser.add_argument(
        "--node_name",
        type=str,
        help="Component ROS2 node name",
    )
    parser.add_argument("--config", type=str, help="Component configuration object")
    parser.add_argument(
        "--inputs",
        type=str,
        help="Component input topics",
    )
    parser.add_argument(
        "--outputs",
        type=str,
        help="Component output topics",
    )
    parser.add_argument(
        "--routes",
        type=str,
        help="Semantic router routes",
    )
    parser.add_argument(
        "--layers",
        type=str,
        help="Map Encoding layers",
    )
    parser.add_argument(
        "--trigger",
        type=str,
        help="Component trigger",
    )
    parser.add_argument(
        "--model_client",
        type=str,
        help="Model Client",
    )
    parser.add_argument(
        "--db_client",
        type=str,
        help="DB Client",
    )
    parser.add_argument(
        "--additional_model_clients",
        type=str,
        help="Additional model clients",
    )
    parser.add_argument(
        "--config_file", type=str, help="Path to configuration YAML file"
    )
    parser.add_argument(
        "--events", type=str, help="Events to be monitored by the component"
    )
    parser.add_argument(
        "--actions", type=str, help="Actions associated with the component Events"
    )
    parser.add_argument(
        "--fallbacks", type=str, help="Fallbacks to be executed on component Failure"
    )
    parser.add_argument(
        "--external_processors",
        type=str,
        help="External processors associated with the component input and output topics",
    )
    parser.add_argument(
        "--additional_types",
        type=str,
        help="Additional type modules from derived packages",
    )

    return parser.parse_known_args()


def _parse_component_config(
    args: argparse.Namespace,
) -> all_configs.BaseComponentConfig:
    """Instantiate the component config object described on the command line.

    The config class is looked up by name (``args.config_type``) in
    ``agents.config`` and constructed with the keyword arguments decoded from
    the JSON object in ``args.config``. When ``args.config`` is missing or
    empty, the class is constructed without keyword arguments.

    :param args: Command line arguments
    :type args: argparse.Namespace
    :raises ValueError: If ``config_type`` was not provided
    :raises TypeError: If ``config_type`` does not name a class in ``agents.config``
    :return: Component config object
    :rtype: all_configs.BaseComponentConfig
    """
    config_type = args.config_type
    if not config_type:
        raise ValueError("config_type must be provided")

    # A default of None lets an unknown name reach the descriptive error below
    # instead of escaping as a bare AttributeError from getattr.
    config_class = getattr(all_configs, config_type, None)
    if not config_class:
        raise TypeError(
            f"Unknown config_type '{config_type}'. Known types are {all_configs.__all__}"
        )

    # json.loads(None) would raise TypeError when --config is omitted
    config_kwargs = json.loads(args.config) if args.config else {}
    return config_class(**config_kwargs)


def _parse_additional_types(value: str):
    """Get additional types"""
    serialized_types = json.loads(value)
    _additional_types = []
    for s_t in serialized_types:
        module_name, _, class_name = s_t.rpartition(".")
        if not module_name:
            continue
        module = importlib.import_module(module_name)
        new_type = getattr(module, class_name)
        if issubclass(new_type, SupportedType):
            _additional_types.append(new_type)
    return _additional_types


def _parse_trigger(
    trigger_str: Optional[str],
) -> Union[Topic, List[Topic], float, Event, NoneType]:
    """Parse the component trigger JSON string.

    The string holds either a plain JSON value (expected: a number or
    ``null``), which is returned unchanged, or a JSON object with a
    ``trigger_type`` key (``"List"``, ``"Topic"`` or ``"Event"``) and a
    ``trigger`` key whose value is itself a JSON string with the payload.

    :param trigger_str: Trigger JSON string; None or empty means no trigger
    :type trigger_str: str | None
    :raises ValueError: If the object carries an unrecognized ``trigger_type``
    :return: Trigger topic(s), Event, float or None
    :rtype: Topic | List[Topic] | float | Event | None
    """
    # TODO: Handle additional types here

    if not trigger_str:
        # No --trigger given; json.loads(None) would raise TypeError
        return None

    # Deserialize main dict or float value
    trigger_deserialized = json.loads(trigger_str)
    if not isinstance(trigger_deserialized, dict):
        # return float or None
        return trigger_deserialized

    # Deserialize internal trigger content
    trigger_type = trigger_deserialized["trigger_type"]
    if trigger_type == "List":
        # List always contains topics
        return [
            Topic(**json.loads(t))
            for t in json.loads(trigger_deserialized["trigger"])
        ]
    if trigger_type == "Topic":
        return Topic(**json.loads(trigger_deserialized["trigger"]))
    if trigger_type == "Event":
        return Event.from_json(trigger_deserialized["trigger"])

    # Fail loudly rather than silently returning None for an unknown type
    raise ValueError(
        f"Unknown trigger_type '{trigger_type}'. Expected one of 'List', 'Topic' or 'Event'"
    )


def _deserialize_topics(
    serialized_topics: str, additional_types: Optional[List] = None
) -> List[Dict]:
    list_of_str = json.loads(serialized_topics)
    topic_dicts = []
    for t in list_of_str:
        topic_dict = json.loads(t)
        topic_dict["qos_profile"] = QoSConfig(**topic_dict.get("qos_profile", {}))
        topic_dict["additional_types"] = (
            additional_types if additional_types else []
        )  # Add any additional types
        topic_dicts.append(topic_dict)
    return topic_dicts


def _load_primary_clients(
    args: argparse.Namespace,
) -> Tuple[Optional[Any], Optional[Any]]:
    """Instantiates Model and DB clients."""
    model_client = None
    db_client = None

    if args.model_client:
        mc_json = json.loads(args.model_client)
        model_client = getattr(clients, mc_json["client_type"])(**mc_json)

    if args.db_client:
        dbc_json = json.loads(args.db_client)
        db_client = getattr(clients, dbc_json["client_type"])(**dbc_json)

    return model_client, db_client


def _load_additional_model_clients(
    additional_clients_json: str,
) -> Dict[str, ModelClient]:
    """Instantiate the additional model clients described by a JSON object.

    Each value of the JSON object is a client description whose
    ``client_type`` names a class in ``agents.clients``. The description is
    passed to that class as keyword arguments. Keys are kept as-is.

    :param additional_clients_json: JSON object mapping names to client descriptions
    :type additional_clients_json: str
    :return: Mapping from the same keys to instantiated clients
    :rtype: Dict[str, ModelClient]
    """
    client_specs = json.loads(additional_clients_json)
    return {
        name: getattr(clients, spec["client_type"])(**spec)
        for name, spec in client_specs.items()
    }


def _parse_ros_args(args_names: List[str]) -> List[str]:
    """Parse ROS arguments from command line arguments

    :param args_names: List of all parsed arguments
    :type args_names: list[str]

    :return: List ROS parsed arguments
    :rtype: list[str]
    """
    # Look for --ros-args in ros_args
    ros_args_start = None
    if "--ros-args" in args_names:
        ros_args_start = args_names.index("--ros-args")

    if ros_args_start is not None:
        ros_specific_args = args_names[ros_args_start:]
    else:
        ros_specific_args = []
    return ros_specific_args


def _setup_component_post_init(component: Any, args: argparse.Namespace) -> None:
    """Perform post-initialization setup on the component instance.

    :param component: The instantiated component object
    :param args: Parsed command line arguments
    """
    # Init the node with rclpy
    component.rclpy_init_node()

    # Set events/actions
    if args.events and args.actions:
        component._events_json = args.events
        component._actions_json = args.actions

    # Set fallbacks
    if fallbacks_json := args.fallbacks:
        component._fallbacks_json = fallbacks_json

    # Set external processors
    if args.external_processors:
        component._external_processors_json = args.external_processors

    # Set additional model clients if any
    component.additional_model_clients = (
        _load_additional_model_clients(args.additional_model_clients)
        if args.additional_model_clients
        else None
    )


def main():
    """Run a single component as a ROS2 node in the current process.

    Entry point used to start a node with the Sugarcoat Launcher (extends the
    ROS Sugar executable). The component class, its config, topics, trigger
    and clients are all described by command line arguments; any arguments
    from ``--ros-args`` onward are passed to ``rclpy.init``. The node is spun
    on a ``MultiThreadedExecutor`` until interrupted, after which it is removed
    from the executor and rclpy is shut down.

    :raises ValueError: If ``component_type`` or ``node_name`` is missing, or
        ``component_type`` is not a known component class
    :raises RuntimeError: If a MapEncoding component is launched without a DB client
    """
    args, args_names = _parse_args()

    # Initialize rclpy with the ros-specific arguments
    rclpy.init(args=_parse_ros_args(args_names))

    component_type = args.component_type
    if not component_type:
        raise ValueError("Cannot launch without providing a component_type")

    # A default of None lets an unknown name reach the descriptive error below
    # instead of escaping as a bare AttributeError from getattr.
    comp_class = getattr(all_components, component_type, None)
    if not comp_class:
        raise ValueError(
            f"Cannot launch unknown component type '{component_type}'. Known types are: '{all_components.__all__}'"
        )

    component_name = args.node_name
    if not component_name:
        raise ValueError("Cannot launch component without specifying a name")

    # Name the OS process after the node
    setproctitle.setproctitle(component_name)

    config = _parse_component_config(args)

    # Get Yaml config file if provided
    config_file = args.config_file or None

    additional_types = (
        _parse_additional_types(args.additional_types)
        if args.additional_types
        else None
    )

    # Get inputs/outputs/layers/routes
    inputs = (
        [
            FixedInput(**i) if i.get("fixed") else Topic(**i)
            for i in _deserialize_topics(args.inputs, additional_types)
        ]
        if args.inputs
        else None
    )
    outputs = (
        [Topic(**o) for o in _deserialize_topics(args.outputs, additional_types)]
        if args.outputs
        else None
    )

    # TODO: Handle additional types and qos in deserialization
    layers = (
        [MapLayer(**json.loads(i)) for i in json.loads(args.layers)]
        if args.layers
        else None
    )

    # TODO: Handle additional types and qos in deserialization
    routes = (
        [Route(**json.loads(r)) for r in json.loads(args.routes)]
        if args.routes
        else None
    )

    # --trigger is optional; json.loads(None) would raise TypeError
    trigger = _parse_trigger(args.trigger) if args.trigger else None

    # Initialize clients (each is None when its argument was not given)
    model_client, db_client = _load_primary_clients(args)

    if component_type == all_components.SemanticRouter.__name__:
        # The default route is not passed here: it is already part of the
        # config and will be set from there
        component = comp_class(
            inputs=inputs,
            routes=routes,
            db_client=db_client,
            model_client=model_client,
            config=config,
            component_name=component_name,
            config_file=config_file,
        )
    elif component_type == all_components.MapEncoding.__name__:
        # Reuse the DB client built above rather than instantiating a second
        # one; this also turns a missing --db_client into the intended error.
        if not db_client:
            raise RuntimeError("The map encoding component expects a vectorDB client")
        component = comp_class(
            layers=layers,
            position=config._position,
            map_topic=config._map_topic,
            db_client=db_client,
            config=config,
            trigger=trigger,
            component_name=component_name,
            config_file=config_file,
        )
    else:
        # All other components
        component = comp_class(
            inputs=inputs,
            outputs=outputs,
            model_client=model_client,
            db_client=db_client,
            trigger=trigger,
            config=config,
            component_name=component_name,
            config_file=config_file,
        )

    # Run Post-Init Setup
    _setup_component_post_init(component, args)

    executor = MultiThreadedExecutor()
    executor.add_node(component)

    try:
        executor.spin()
    except KeyboardInterrupt:
        pass
    finally:
        executor.remove_node(component)
        try_shutdown()


# Run the component launcher when this file is executed as a script
if __name__ == "__main__":
    main()
