#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Copyright(C) 2007 INL
Written by Romain Bignon <romain AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

$Id: nulog.py 12194 2008-01-14 17:01:20Z romain $
"""

import table
import inl
from inl import getBacktrace
from nucentral.core import cron

from ConfigParser import SafeConfigParser
import info
import os
import SOAPpy
from datetime import datetime

from twisted.internet.defer import gatherResults

# System-wide configuration directory.
etcpath = '/etc/nulog/'
# Directory holding this module, with a trailing separator so that a file
# name can simply be appended.  abspath() matters here: when __file__ is a
# bare relative name, dirname() yields '' and the path would degenerate to
# the filesystem root ('/').
mypath = os.path.dirname(os.path.abspath(__file__)) + os.sep
# Configuration file names; both are looked up in etcpath and in mypath.
core_conf = "core.conf"
core_defconf = "default.core.conf"
# Delay handed to cron.scheduleRepeat() for the archive cache refresh.
update_archives_period = 86400  # 60*60*24 = 1 day
# Delay handed to cron.scheduleRepeat() for the conntrack flush
# (presumably seconds, like the period above -- confirm against cron).
update_conntrack_period = 60
# Number of seconds used in the SQL "INTERVAL ... SECOND" expression that
# marks conntrack entries still in state 1 as state 3.
update_conntrack_newtimeout = 120

class NulogCore:
    """ Core object of the nulog component.

        It keeps the configuration parser and the database pool, maintains
        a cache of the time interval covered by each archive table
        (self.archives) and exposes the 'table' and 'count' services.
    """

    def __init__(self, parser, dbpool):
        """ Store the configuration and start the periodic jobs.

            @param parser [ConfigParser] configuration; the "DB" section
                          must provide "table", "maxrotate" and "conntrack";
            @param dbpool database pool object; it must provide runQuery()
                          and setConntrack().
        """

        self.conf = parser
        self.dbpool = dbpool
        self.ulog = parser.get("DB", "table")
        # Kept exactly as returned by the parser (a string).
        self.maxrotate = parser.get("DB", "maxrotate")
        # Maps a table name to a (first timestamp, last timestamp) pair.
        self.archives = {}

        self.check_conntrack(parser.get("DB", "conntrack"))
        cron.scheduleRepeat(update_archives_period, self.update_archives)

    def check_conntrack(self, conntrack_ulog):
        """ Check if the conntrack_ulog table exists.

            A COUNT(*) query is sent to the table.  If it succeeds, the
            table is registered on the pool, added to the archives cache
            with an unbounded time interval, and the periodic conntrack
            flush is scheduled.  If it fails, nothing is set up.

            @param conntrack_ulog [string] name of the conntrack table;
            @return the deferred of the query.
        """

        def eb(result):
            # The query failed: conntrack support simply stays disabled.
            pass

        def cb(result):
            self.dbpool.setConntrack(conntrack_ulog)
            self.archives[conntrack_ulog] = (datetime.min, datetime.max)
            cron.scheduleRepeat(update_conntrack_period, self.update_conntrack)

        query = self.dbpool.runQuery("""SELECT COUNT(*) FROM %s""" % conntrack_ulog)
        return query.addCallback(cb).addErrback(eb)

    def update_conntrack(self):
        """ Move the entries in state 0 or 3 from the conntrack table to
            the main ulog table, then flag the entries still in state 1
            as state 3.

            @return the deferred of the initial MAX(id) query.
        """

        def cb(result):

            try:
                max_id = int(result[0][0])
            except (TypeError, IndexError, ValueError):
                # Nothing to move: MAX(id) is NULL, the result is empty, or
                # the errback below replaced a failure by None.
                return

            conntrack = self.dbpool.conntrack_ulog

            # NOTE(review): the three statements below are issued without
            # waiting for one another; whether the pool runs them in this
            # order is not visible from here.
            insert = self.dbpool.runQuery("INSERT INTO %s SELECT * FROM %s WHERE id <= %d AND (state = 0 OR state = 3)"
                                          % (self.ulog, conntrack, max_id))
            delete = self.dbpool.runQuery("DELETE FROM %s WHERE id <= %d AND (state = 0 OR state = 3)"
                                          % (conntrack, max_id))
            # NOTE(review): DATE_ADD puts the limit in the future, so every
            # state=1 row older than NOW()+timeout matches; DATE_SUB may
            # have been intended -- confirm before changing.
            update = self.dbpool.runQuery("UPDATE %s SET state=3 WHERE state=1 AND timestamp<DATE_ADD(NOW(),INTERVAL %d SECOND)"
                                          % (conntrack, update_conntrack_newtimeout))

            return gatherResults([insert, delete, update])

        def eb(result):
            # Swallow the failure; cb is then called with None and returns
            # immediately.
            pass

        # The errback is attached first on purpose: it only covers the
        # MAX(id) query, not the statements started by cb.
        query = self.dbpool.runQuery("SELECT MAX(id) FROM %s WHERE state = 0 OR state = 3"
                                     % self.dbpool.conntrack_ulog)
        return query.addErrback(eb).addCallback(cb)

    def update_archives(self, result=None, nb_table=0):
        """
            This function is used to update the cache of the archives list.
            We store the first and last packet timestamp, to know
            in which time interval each archive table is used.

            Call it without any param to launch the process, and it
            will call itself as callback to iterate over all archive tables.

            It begins with <table>_1 and increments until a query fails
            (typically because the table does not exist).

            @param result rows of the previous MIN/MAX query, or None for
                          the initial call;
            @param nb_table [int] index of the archive table 'result'
                            belongs to;
            @return the deferred of the query on the next archive table.
        """

        def eb(result):
            # We don't care about an error: it just ends the iteration.
            pass

        if result:
            # A non-empty result means we are called as a callback,
            # so we can store it.
            try:
                row = result[0]
                oldest, newest = row[0], row[1]
            except (TypeError, IndexError, KeyError):
                row = oldest = newest = None

            # Both bounds are needed (an empty table gives NULL values);
            # otherwise the table is skipped.
            if oldest and newest:
                self.archives['%s_%d' % (self.ulog, nb_table)] = row

                # The main table starts where the first valid archive ends.
                current = self.archives.get(self.ulog)
                if current is None or current[0] == datetime.min:
                    self.archives[self.ulog] = (newest, datetime.today())

        else:
            # Initial call: reset the interval of the main table.
            self.archives[self.ulog] = (datetime.min, datetime.today())

        next_table = '%s_%d' % (self.ulog, nb_table + 1)

        query = self.dbpool.runQuery("""SELECT MIN(timestamp), MAX(timestamp) FROM %s"""
                                     % next_table)
        return query.addCallback(self.update_archives, nb_table + 1).addErrback(eb)

    def count_entries(self, what):
        """ Use a keyword to tell what request you want to do on database.
            Goal of this function is only to count entries with filters.

            @param what [string] one of 'conusers', 'packets', 'average1',
                        'average5' or 'average15';
            @return the value returned by dbpool.runQuery();
            @raise Exception if the keyword is unknown.
        """

        counts = {'conusers': 'SELECT COUNT(*) FROM users WHERE end_time IS NULL',
                  'packets':  'SELECT COUNT(*) FROM %s' % self.ulog,
                  'average1': """SELECT COUNT(*)/60 FROM %s
                                 WHERE timestamp > NOW()- INTERVAL 1 MINUTE AND state = 0 """ % self.ulog,
                  'average5': """SELECT COUNT(*)/300 FROM %s
                                 WHERE timestamp > NOW()- INTERVAL 5 MINUTE AND state = 0""" % self.ulog,
                  'average15':"""SELECT COUNT(*)/900 FROM %s
                                 WHERE timestamp > NOW()- INTERVAL 15 MINUTE AND state = 0""" % self.ulog,
                 }

        if what not in counts:
            raise Exception("Unable to find this value: %s" % what)

        return self.dbpool.runQuery(counts[what])

    def table(self, name, args):
        """ This is the 'table' distant access function.

            @param name [string] function name;
            @param args [dictionary] keyword arguments given to the table
                        object (a SOAPpy struct is converted to a dict);
            @return whatever calling the table object returns;
            @raise Exception if no table is registered under this name.
        """

        tables = {
                    'TCPTable':       table.TCPTable,
                    'UDPTable':       table.UDPTable,
                    'IPsrcTable':     table.IPsrcTable,
                    'IPdstTable':     table.IPdstTable,
                    'UserTable':      table.UserTable,
                    'PacketTable':    table.PacketTable,
                    'ConUserTable':   table.ConUserTable,
                    'AppTable':       table.AppTable,
                    'PacketInfo':     info.PacketInfo,
                    'ConnTrackTable': table.ConnTrackTable,
                    'BadHosts':       table.BadHosts,
                    'BadUsers':       table.BadUsers,
                }

        p = tables.get(name) # We get class type

        if p is None:
            # Report the unknown name instead of failing later on a call
            # to None.
            raise Exception('Table %s not found' % name)

        try:
            p = p(self.dbpool) # We instance class

            if isinstance(args, SOAPpy.Types.structType):
                args = args._asdict()
            return p(**args) # <=> p.__call__(..)

        except Exception:
            print(getBacktrace())
            # Bare raise keeps the original traceback.
            raise

# Module-wide NulogCore instance: assigned by getComponentName() and read
# by getServiceList().  It stays None until getComponentName() has run.
core = None

def getComponentName():
    """
        Function called by NuCentral (or wrapper) to get my name.

        As this function is called at program launch, it also reads the
        configuration files, opens the database pool and creates the
        module-wide NulogCore instance.
    """

    global core

    # Default files first, then the specific ones; for each of them the
    # system directory comes before the module directory.
    candidates = [directory + filename
                  for filename in (core_defconf, core_conf)
                  for directory in (etcpath, mypath)]

    config = SafeConfigParser()
    config.read(candidates)

    # Options of the "DB" section, in the order start_dbpool() expects them.
    options = ("host", "db", "user", "password", "type", "ip", "table")
    pool = inl.start_dbpool(*[config.get("DB", option) for option in options])

    core = NulogCore(config, pool)

    return 'nulog-core'

def getServiceList():
    """ Called by NuCentral, I return my services.

        The result maps each service name to the bound method of the
        module-wide core object which implements it.
    """

    return {
        'table': core.table,
        'count': core.count_entries,
    }
