"""Rescheduling of failed cleanup tasks."""

import asyncio
import logging
import random
import sqlite3
import time
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Tuple
from arcnagios.nagutils import NagiosReport, WARNING
from arcnagios.utils import nth

# Seconds to wait on a locked database before sqlite3 raises
# OperationalError; passed as the ``timeout`` argument of sqlite3.connect.
DB_BUSY_TIMEOUT = 10

# Fallback logger, used when a Rescheduler is given neither a logger nor a
# NagiosReport.
_default_log = logging.getLogger(__name__)

# Schema of the table holding pending retries; ``%s`` is the table name.
# IF NOT EXISTS makes creation idempotent, so opening an existing database
# does not depend on catching an "already exists" error.  Note that
# t_sched receives float POSIX times; SQLite keeps non-integral values as
# REAL despite the integer column affinity.
_CREATE_SQL = """\
CREATE TABLE IF NOT EXISTS %s (
    n_attempts integer NOT NULL,
    t_sched integer NOT NULL,
    task_type varchar(16) NOT NULL,
    arg text NOT NULL
)"""

def _format_time(t: float):
    """Render the epoch timestamp *t* in local time as 'YYYY-MM-DD HH:MM'."""
    local_tm = time.localtime(t)
    return time.strftime('%Y-%m-%d %H:%M', local_tm)

# A task handler is awaited as handler(arg, n_attempts), where n_attempts is
# 0 for the immediate try and the stored attempt count on retries; a true
# result marks the task as done.
Handler = Callable[
    [str, int],
    Awaitable[bool],
]

class TaskType:
    """Handler and retry policy for one kind of rescheduled task.

    :param handler: coroutine function awaited as ``handler(arg, n_attempts)``;
        a true result means the task succeeded.
    :param min_delay: base retry delay in seconds, doubled per attempt.
    :param max_attempts: a task that fails with at least this many recorded
        attempts is given up.
    :param delay_dev: standard deviation of the Gaussian jitter factor
        (mean 1.0) applied to each delay.
    """

    def __init__(self, handler: Handler,
                 min_delay: int = 3600, max_attempts: int = 12,
                 delay_dev: float = 0.1):
        self.handler = handler
        self.min_delay = min_delay
        self.max_attempts = max_attempts
        self.delay_dev = delay_dev

    def next_delay(self, n_attempts: int) -> float:
        """Return the jittered delay in seconds for attempt count *n_attempts*.

        The nominal delay is ``min_delay * 2**n_attempts``.  It is multiplied
        by a Gaussian factor with mean 1 and standard deviation ``delay_dev``.
        Multiplying by ``1 << n_attempts`` gives the same result as the former
        ``min_delay << n_attempts`` for integers and also accepts a float
        ``min_delay``.  The result is clamped at zero, so a large
        ``delay_dev`` cannot yield a negative delay.

        :raises ValueError: if *n_attempts* is negative.
        """
        nominal = self.min_delay * (1 << n_attempts)
        return max(0.0, nominal * random.gauss(1.0, self.delay_dev))

class Rescheduler:
    """Store failed tasks in an SQLite table and retry them later.

    Task types are registered with :meth:`register`.  A failed task is
    stored with an attempt count and a scheduled time.  :meth:`run` retries
    the tasks that are due, rescheduling failures with a growing delay until
    they succeed or reach the task type's ``max_attempts``.
    ``OperationalError`` from the database is logged and, when a report is
    attached, turned into a WARNING status instead of being raised.
    """

    def __init__(self, db_path: str, table_name: str,
                 nagios_report: Optional[NagiosReport] = None,
                 log: Optional[logging.Logger] = None):
        """Connect to the database and make sure the task table exists.

        :param db_path: path of the SQLite database.
        :param table_name: table holding scheduled tasks.  It is interpolated
            directly into SQL, so it must be a trusted identifier.
        :param nagios_report: optional report that receives WARNING on
            database errors.  Its ``log`` is used when *log* is not given.
        :param log: logger to use.  Defaults to the report's logger, else the
            module logger.
        """
        self._db = sqlite3.connect(db_path, DB_BUSY_TIMEOUT)
        self._table = table_name
        self._task_types: Dict[str, TaskType] = {}

        self._nagios_report = nagios_report
        if log is None:
            if nagios_report is None:
                log = _default_log
            else:
                log = nagios_report.log
        self._log = log

        try:
            self._db.execute(_CREATE_SQL % self._table)
        except sqlite3.OperationalError as exn:
            # Expected when the table already exists.  A genuine failure
            # will typically also break later queries/updates, where
            # _query/_update log it as an error; only note it here.
            self._log.debug('Not creating table %s: %s', self._table, exn)

    def close(self) -> None:
        """Close the database connection."""
        self._db.close()

    def _report_warning(self, msg: str = 'Check rescheduler errors.') -> None:
        """Pass WARNING and *msg* to the attached report, if there is one."""
        if self._nagios_report is not None:
            self._nagios_report.update_status(WARNING, msg)

    def _update(self, stmt: str, *args) -> None:
        """Execute modifying *stmt* with *args* bound, then commit.

        OperationalError is logged and reported as a warning, not raised.
        """
        try:
            self._db.execute(stmt, args)
            self._db.commit()
        except sqlite3.OperationalError as exn:
            self._log.error('Failed to update rescheduled work: %s', exn)
            self._report_warning()

    def _query(self, stmt: str, *args) -> Sequence[Any]:
        """Return all rows of *stmt* with *args* bound.

        Returns an empty list after logging and reporting an
        OperationalError.
        """
        try:
            return self._db.execute(stmt, args).fetchall()
        except sqlite3.OperationalError as exn:
            self._log.error('Failed to fetch rescheduled work: %s', exn)
            self._report_warning()
            return []

    def register(self, task_type_name: str, h: Handler,
                 min_delay: int = 3600, max_attempts: int = 12,
                 delay_dev: float = 0.1) -> None:
        """Register handler *h* and its retry policy under *task_type_name*.

        A later registration with the same name replaces the earlier one.
        """
        self._task_types[task_type_name] \
            = TaskType(h, min_delay, max_attempts, delay_dev)

    def schedule(
            self, task_type_name: str, arg: str,
            n_attempts: int = 0, t_sched: Optional[float] = None
        ) -> None:
        """Store a task for a later attempt.

        :param task_type_name: a name previously passed to :meth:`register`.
        :param arg: argument to pass to the handler.
        :param n_attempts: attempt count stored with the task.
        :param t_sched: POSIX time of the attempt.  Defaults to now plus the
            task type's ``next_delay(n_attempts)``.
        :raises KeyError: if *task_type_name* is not registered.
        """
        task_type = self._task_types[task_type_name]
        if t_sched is None:
            t_sched = time.time() + task_type.next_delay(n_attempts)
        self._update('INSERT INTO %s (n_attempts, t_sched, task_type, arg) '
                     'VALUES (?, ?, ?, ?)' % self._table,
                     n_attempts, t_sched, task_type_name, arg)

    def _unschedule_rowid(self, rowid: int) -> None:
        """Delete the stored task with the given ROWID."""
        self._update('DELETE FROM %s WHERE ROWID = ?' % self._table, rowid)

    def _reschedule_rowid(self, rowid: int, n_attempts: int, t_sched: float) \
            -> None:
        """Set the attempt count and scheduled time of the task at ROWID."""
        self._update('UPDATE %s SET n_attempts = ?, t_sched = ? '
                     'WHERE ROWID = ?' % self._table,
                     n_attempts, t_sched, rowid)

    async def call(self, task_type_name: str, arg: str) -> bool:
        """Run a task once now, and schedule a retry if it fails.

        Returns True if the handler succeeded.  On failure the task is stored
        with ``n_attempts = 1`` and False is returned.  If the handler raises,
        the retry is scheduled too, consistent with :meth:`run`, which treats
        exceptions as failures.  The exception is then re-raised.

        :raises KeyError: if *task_type_name* is not registered.
        """
        handler = self._task_types[task_type_name].handler
        try:
            is_ok = await handler(arg, 0)
        except Exception:
            self.schedule(task_type_name, arg, n_attempts = 1)
            raise
        if is_ok:
            return True
        self.schedule(task_type_name, arg, n_attempts = 1)
        return False

    async def run(
            self, timeout: float, semaphore: asyncio.Semaphore
        ) -> Tuple[int, int, int, int]:
        """Run pending jobs. Currently timeout is the deadline for starting
           jobs, so the maximum full running time will be timeout plus the
           maximum time of an individual job.

           Only tasks whose scheduled time has passed when this call starts
           are considered.  At most as many run concurrently as *semaphore*
           allows.  Tasks of unregistered types are left in the table with a
           warning.  Tasks not started before the deadline are left for a
           later run.

           Returns the counts ``(succeeded, rescheduled, given_up,
           postponed)``.
        """

        t_now = time.time()
        t_deadline = t_now + timeout
        success_count = 0
        failed_count = 0
        resched_count = 0
        postponed_count = 0

        async def process(
                rowid: int, n_attempts: int, t_sched: float,
                task_type_name: str, arg: str):
            # Attempt one stored task and update its row according to the
            # outcome.  The positional parameters match the SELECT columns.
            nonlocal success_count, failed_count, resched_count, postponed_count

            task_type = self._task_types.get(task_type_name)
            if task_type is None:
                self._log.warning('No task type %s.', task_type_name)
                return
            if time.time() >= t_deadline:
                postponed_count += 1
                return

            try:
                is_ok = await task_type.handler(arg, n_attempts)
            except Exception as exn: # pylint: disable=broad-except
                self._log.error('Task %s(%r) raised exception: %s',
                                task_type_name, arg, exn)
                is_ok = False

            if is_ok:
                self._log.info('Finished %s(%r)', task_type_name, arg)
                self._unschedule_rowid(rowid)
                success_count += 1
            elif n_attempts >= task_type.max_attempts:
                self._log.info('Giving up on %s(%r)', task_type_name, arg)
                self._unschedule_rowid(rowid)
                failed_count += 1
            else:
                # The next delay is measured from the start of this run.
                t_next = t_now + task_type.next_delay(n_attempts)
                n_attempts += 1
                self._log.info('Scheduling %s attempt at %s to %s(%r)',
                               nth(n_attempts), _format_time(t_next),
                               task_type_name, arg)
                self._reschedule_rowid(rowid, n_attempts, t_next)
                resched_count += 1

        async def process_with_semaphore(row):
            async with semaphore:
                await process(*row)

        rows = self._query(
                'SELECT ROWID, n_attempts, t_sched, task_type, arg '
                'FROM %s WHERE t_sched <= ?' % self._table,
                t_now)
        if rows:
            tasks = [
                asyncio.create_task(process_with_semaphore(row)) for row in rows
            ]
            await asyncio.gather(*tasks)
        return (success_count, resched_count, failed_count, postponed_count)
