#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Script that filters a bag file using a sequence of message filters."""

from __future__ import absolute_import, division, print_function

import argparse

import os
import sys

import yaml
from argparse import ArgumentParser
from glob import glob
from shutil import copyfile

try:
    from StringIO import StringIO
except ImportError:
    from io import BytesIO as StringIO

import rosbag
from cras import pretty_file_size
from cras_bag_tools.bag_filter import filter_bag
from cras_bag_tools.message_filter import FilterChain, MessageFilter, get_filters
from cras_bag_tools.tqdm_bag import TqdmBag


def out_path(path, fmt):
    """Build the output file path from a template and try to create its parent directory.

    The template is expanded with :meth:`str.format` using three fields derived from ``path``: ``{dirname}``
    (directory part), ``{name}`` (file name without extension) and ``{ext}`` (extension including the leading dot).

    :param str path: Path of the current bag file.
    :param str fmt: The output name template.
    :return: The output path.
    :rtype: str
    """
    source_dir, file_name = os.path.split(path)
    stem, suffix = os.path.splitext(file_name)
    fields = {'dirname': source_dir, 'name': stem, 'ext': suffix}
    result = fmt.format(**fields)

    target_dir = os.path.dirname(result)
    try:
        os.makedirs(target_dir)
    except OSError:
        # Best effort only: the directory may already exist, or the result may have no directory part at all.
        pass

    return result


def copy_params_if_any(bag_path, out_bag_path):
    """If the bagfile has a sidecar parameters file named $BAG.params, copy this file to the out path.

    The copy is best-effort: no exception of type OSError/IOError is propagated. A missing sidecar file is reported
    on stdout; any other copy failure (e.g. insufficient permissions or a missing destination directory) is reported
    on stderr together with the reason.

    :param str bag_path: Source bag path.
    :param str out_bag_path: Output bag path.
    """
    params_path = bag_path + '.params'
    out_params_path = out_bag_path + '.params'
    try:
        copyfile(params_path, out_params_path)
    except (OSError, IOError) as ex:
        if os.path.exists(params_path):
            # The sidecar file is there, so the failure has a different cause than a missing file; say so.
            print('Params: %s could not be copied to %s: %s' % (params_path, out_params_path, ex), file=sys.stderr)
        else:
            print('Params: %s not found.' % (params_path,))
    else:
        print('Params: %s copied to %s.' % (params_path, out_params_path))

def filter_bags(bags, out_format, compression, copy_params, filter):
    """Filter all given bags using the given filter.

    For each bag, progress information is printed, the output path is resolved from ``out_format``, the sidecar
    parameters file is optionally copied and the bag is passed through ``filter`` into the output bag. Bags that do
    not exist, or whose output path resolves to the source bag itself, are reported on stderr and skipped.

    :param list bags: The bags to filter.
    :param str out_format: Output path template.
    :param str compression: Output bag compression. One of 'rosbag.Compression.*' constants.
    :param bool copy_params: If True, copy parameters file along with the bag file if it exists.
    :param MessageFilter filter: The filter to apply.
    """
    bag_count = len(bags)
    for i, bag_path in enumerate(bags, 1):
        print()
        print("[{}/{}] Bag {}".format(i, bag_count, bag_path))
        print('Source:      %s' % (os.path.abspath(bag_path),))
        if not os.path.exists(bag_path):
            print('Source bag does not exist', file=sys.stderr)
            continue

        # Resolve the output path only for existing bags: out_path() also creates the output directory, which
        # should not happen for bags that are skipped.
        out_bag_path = out_path(bag_path, out_format)
        if os.path.realpath(out_bag_path) == os.path.realpath(bag_path):
            # Opening the output for writing would clobber the source bag while it is being read.
            print('Destination is the same file as the source bag, skipping it to avoid overwriting the source',
                  file=sys.stderr)
            continue

        print('- Size: %s' % (pretty_file_size(os.path.getsize(bag_path)),))
        print('Destination: %s' % (os.path.abspath(out_bag_path),))
        print('- Compression: %s' % (compression,))
        if copy_params:
            copy_params_if_any(bag_path, out_bag_path)
        print()
        with TqdmBag(bag_path, skip_index=True) as bag, rosbag.Bag(out_bag_path, 'w', compression=compression) as out:
            # The index is skipped when opening the bag and read explicitly here.
            # NOTE(review): presumably so that TqdmBag can report progress of the index reading - confirm.
            bag.read_index()
            filter_bag(bag, out, filter)


def main():
    """Command-line entry point: parse arguments, load YAML configs, build the filter chain and filter the bags.

    The processing steps are:

    1. Build the argument parser and let every loaded filter class add its own CLI arguments.
    2. Handle ``--list-filters`` and ``--list-yaml-keys`` (print the requested listing and exit with code 0).
    3. Parse the command line and print the values that were set.
    4. Expand the ``--config`` glob patterns and load the YAML files. A YAML value is used only if the corresponding
       argument is missing, None or an empty list after command-line parsing, so command-line values take precedence
       and earlier config files take precedence over later ones.
    5. Let every filter class add filters based on the arguments and append the filters configured via ``filters``.
       If no filter is defined, an error is printed and the process exits with code 1.
    6. Apply defaults for compression and output name template and filter all bags.
    """
    parser = ArgumentParser()
    parser.add_argument('bags', nargs='*', help="The bag files to filter.")
    parser.add_argument('-c', '--config', nargs='+', help="YAML configs of filters")
    # No default here so that a value from a YAML config can be told apart from a missing CLI value.
    parser.add_argument('-o', '--out-format', default=argparse.SUPPRESS,
                        help='Template for naming the output bag. Defaults to "{name}.proc{ext}"')
    parser.add_argument('--lz4', dest='compression', action='store_const', const=rosbag.Compression.LZ4,
                        help="Compress the bag via LZ4")
    parser.add_argument('--bz2', dest='compression', action='store_const', const=rosbag.Compression.BZ2,
                        help="Compress the bag via BZ2 (space-efficient, but very slow)")
    parser.add_argument('-f', '--filters', nargs='+', help=argparse.SUPPRESS)
    parser.add_argument('--no-copy-params', dest='copy_params', action='store_false',
                        help="If set, no .params file will be copied.")
    parser.add_argument("--list-yaml-keys", dest="list_yaml_keys", action="store_true",
                        help="Print a list of all available YAML top-level keys provided by filters.")
    parser.add_argument("--list-filters", dest="list_filters", action="store_true",
                        help="Print a list of all available filters.")

    loaded_filters = get_filters()
    unique_filters = set(loaded_filters.values())
    for f in unique_filters:
        if hasattr(f, 'add_cli_args'):
            f.add_cli_args(parser)

    default_yaml_keys = ['bags', 'out_format', 'compression', 'filters']

    # The listing options are looked up directly in sys.argv and handled before the command line is parsed.
    if "--list-filters" in sys.argv:
        for f in unique_filters:
            print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
            if f.__init__.__doc__ is not None:
                print(f.__init__.__doc__)
        sys.exit(0)

    if "--list-yaml-keys" in sys.argv:
        print("Global:")
        for key in default_yaml_keys:
            print("  {}".format(key))
        for f in unique_filters:
            if hasattr(f, 'yaml_config_args'):
                filter_keys = f.yaml_config_args()
                if len(filter_keys) > 0:
                    print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
                    for key in filter_keys:
                        print("  {}".format(key))
        sys.exit(0)

    args = parser.parse_args()

    print()
    print('Command-line arguments:')
    for k, v in sorted(vars(args).items(), key=lambda kv: kv[0]):
        if v is not None:
            print('%s: %s' % (k, v))

    # Expand the glob patterns given to --config into a flat list of config files.
    config_files = []
    for pattern in args.config or []:
        matches = glob(pattern)
        if not matches:
            # Do not drop a mistyped config path silently.
            print('Config %s does not match any file, ignoring it.' % (pattern,), file=sys.stderr)
        config_files.extend(matches)
    args.config = config_files

    print()
    print('YAML arguments:')
    yaml_keys = list(default_yaml_keys)
    for f in unique_filters:
        if hasattr(f, 'yaml_config_args'):
            yaml_keys.extend(f.yaml_config_args())
    for config in args.config:
        with open(config, 'r') as cfg_file:
            # safe_load() returns None for an empty document; treat it as a config without any keys.
            cfg = yaml.safe_load(cfg_file) or {}
        for key in yaml_keys:
            if key not in cfg:
                continue
            # Values that were already set (on the command line or by a previous config) are not overridden.
            current = getattr(args, key, None)
            if current is None or (isinstance(current, list) and len(current) == 0):
                setattr(args, key, cfg[key])
                if key != "filters":
                    print("{}: {}".format(key, cfg[key]))

    # Process command-line args
    filters = []
    for f in unique_filters:
        if hasattr(f, 'process_cli_args'):
            f.process_cli_args(filters, args)

    filter_chain = FilterChain(filters) + MessageFilter.from_config(args.filters)
    if len(filter_chain.filters) == 0:
        print("No filters defined, exiting.", file=sys.stderr)
        sys.exit(1)

    print()
    print('Filters:')
    print('\n'.join(str(f) for f in filter_chain.filters))

    if args.compression is None:
        args.compression = rosbag.Compression.NONE
    # out_format has no argparse default, so the attribute may be missing altogether.
    if getattr(args, 'out_format', None) is None:
        args.out_format = "{name}.proc{ext}"

    filter_bags(args.bags, args.out_format, args.compression, args.copy_params, filter_chain)


# Run the command-line entry point only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
