#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Copyright(C) 2007 INL
Written by Romain Bignon <romain AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

$Id: inl.py 12147 2008-01-11 17:05:31Z romain $
"""

# NEEDS python-mysqldb
from twisted.enterprise import adbapi # SQL
from twisted.internet import defer
from IPy import IP
import sys, traceback
from twisted.enterprise.util import safe
import struct
import socket
import re
import datetime
from cStringIO import StringIO
import nulog

#############################################
#                                           #
#            Database Object                #
#                                           #
#############################################
class DataBase:
    """ Database handle shared by the Request and TableBase objects.

        Holds a twisted adbapi connection pool (MySQLdb driver) and the
        settings needed to build requests: request type, ulog table name,
        IP version and the optional conntrack ulog table name.
    """

    def __init__(self, _type, host, user, passwd, db, ip_type, ulog):
        """ Create a connection to database.
            @param _type [string] type of database ("ulog", "triggers", etc, see createRequest function)
            @param host [string] hostname of MySQL database
            @param user [string] username
            @param passwd [string] password
            @param db [string] Database name
            @param ip_type [integer] 4 or 6.
            @param ulog [string] ulog table name.
        """

        self.dbpool = adbapi.ConnectionPool("MySQLdb", host, user, passwd, db)
        self._type = _type
        self.ulog = ulog
        self.ip_type = int(ip_type)
        # Name of the conntrack ulog table, None when there is none.
        self.conntrack_ulog = None

    def setConntrack(self, conntrack_ulog):
        """ Set the conntrack ulog table name (None disables it). """
        self.conntrack_ulog = conntrack_ulog

    def createRequest(self, table=None):
        """ Create the request object matching this database type.
            @param table [string] table to query; defaults to the ulog table.
            @return [Request] a Request (or child class) instance.

            Unknown database types fall back to the plain Request class.
        """

        types = {'ulog':     Request,
                 'triggers': TriggerRequest,
                 }

        # Explicit fallback: picking "the first" dict value would depend on
        # dict ordering and give an arbitrary class.
        obj = types.get(self._type, Request)

        if not table:
            table = self.ulog

        return obj(self, table, self.ip_type)

    def runQuery(self, query):
        """ Log a query on stdout and send it to the SGDB through the pool.
            @param query [string] SQL formated query

            @return [Deferred] fired with the rows returned by the query. If this
                               method is called by a SOAP client, return this
                               Deferred to tell SOAP to wait for the result.
        """

        print('DB: %s;' % query)
        return self.dbpool.runQuery(query)

#############################################
#                                           #
#             Request Object                #
#                                           #
#############################################
class Request:
    """ This class is used to abstract SQL requests.

        Every select_*/count_*/get_* method only returns a SQL string; running it
        is up to the caller. Child classes (see TriggerRequest) keep the same
        interface to make other requests; DataBase.createRequest() decides which
        class is instantiated.
    """

    def __init__(self, database, tablename, ip_type):
        """ Create a database request
            @param database [DataBase] database object
            @param tablename [string] table to query: the ulog table, or an
                                      archive table whose name starts with it.
            @param ip_type [integer] 4 or 6.
        """

        self.database = database
        self.ulog = safe(tablename)

        # Suffix of an archive table ('' for the live ulog table). It selects
        # the matching per-archive tables (see select_conusers) and lets child
        # classes detect archive tables.
        self.sufix = self._table_suffix(self.database.ulog, self.ulog)

        self.ip_type = int(ip_type)

    @staticmethod
    def _table_suffix(base, table):
        """ Return what follows `base` in `table`, or '' if `table` does not
            start with `base`. Plain prefix test: table names are not regexes.
        """
        if table.startswith(base):
            return table[len(base):]
        return ''

    @staticmethod
    def _extend_where(where, condition):
        """ AND `condition` to an existing WHERE clause, or start a new one
            when `where` is empty.
        """
        if where:
            return '%s AND %s' % (where, condition)
        return 'WHERE %s' % condition

    def select_packets(self, where, conntrack=False):
        """ Method used to select all packets which matches "where" clause.
            Returns '' when conntrack is requested but no conntrack table is set.
        """

        if conntrack:
            if not self.database.conntrack_ulog:
                return ""
            table = self.database.conntrack_ulog
        else:
            table = self.ulog

        return """SELECT id, username, user_id, ip_saddr, ip_daddr,
                    (IF(tcp_sport IS NOT NULL, "tcp", IF(udp_sport IS NOT NULL, "udp", "icmp"))) as proto,
                    (IF(tcp_sport IS NOT NULL, tcp_sport, IF(udp_sport IS NOT NULL, udp_sport, ""))) as SPort,
                    (IF(tcp_dport IS NOT NULL, tcp_dport, IF(udp_dport IS NOT NULL, udp_dport, ""))) as DPort,
                    timestamp, oob_prefix, state
                    FROM %s
                    %s""" % (table, where)

    def get_packet(self, _id):
        """ Get a specific packet with his id.
            @param _id [integer or digit string] packet id
            @raise ValueError if _id is not an integer (it is put in the SQL
                   unquoted, so it must never be arbitrary text).
        """

        return """SELECT id, username, user_id, timestamp, oob_time_usec, oob_in, oob_out,
                        oob_prefix, oob_mark, ip_saddr, ip_daddr, ip_tos, ip_ttl, ip_totlen,
                        ip_ihl, ip_csum, ip_id,
                        (IF(tcp_sport IS NOT NULL, "tcp", IF(udp_sport IS NOT NULL, "udp", "icmp"))) as protocol,
                        (IF(tcp_sport IS NOT NULL, tcp_sport, IF(udp_sport IS NOT NULL, udp_sport, ""))) as sport,
                        (IF(tcp_dport IS NOT NULL, tcp_dport, IF(udp_dport IS NOT NULL, udp_dport, ""))) as dport,
                        tcp_seq, tcp_ackseq, tcp_window, tcp_urg, tcp_urgp, tcp_ack, tcp_psh, tcp_rst, tcp_syn, tcp_fin,
                        udp_len, icmp_type, icmp_code, client_os, client_app, bytes_in, bytes_out,
                        packets_in, packets_out, raw_mac
                    FROM %s
                    WHERE id = %d
                """ % (self.ulog, int(_id))

    def select_conusers(self, where):
        """ Get all connected users """

        return """SELECT username, user_id, ip_saddr, os_sysname, start_time, end_time
                    FROM users%s
                    %s""" % (self.sufix, where)

    def select_ports(self, proto, where):
        """ List all dropped destination ports of a protocol.
            @param proto [string] 'tcp' or 'udp'
        """

        where = self._extend_where(where, 'state = 0 AND %s_dport IS NOT NULL' % proto)

        return """SELECT %s_dport, COUNT(*) AS packets, MIN(timestamp) AS begin, MAX(timestamp) AS end
                    FROM %s
                    %s
                    GROUP BY %s_dport
                """ % (proto, self.ulog, where, proto)

    def count_drop_port(self, proto, where):
        """ Count number of dropped ports on this protocol
            @param proto [string] 'tcp' or 'udp'
        """

        where = self._extend_where(where, 'state = 0 AND %s_dport IS NOT NULL' % proto)

        return """SELECT COUNT(*) FROM %s %s""" % (self.ulog, where)

    def select_ip(self, direction, where):
        """ List all ips which matches filters
            @param direction [string] 's' (source) or 'd' (destination)
        """

        where = self._extend_where(where, 'state = 0 AND ip_%saddr IS NOT NULL' % direction)

        return """SELECT ip_%saddr, COUNT(*) AS packets, MIN(timestamp) AS begin, MAX(timestamp) AS end
                    FROM %s
                    %s
                    GROUP BY ip_%saddr
                """ % (direction, self.ulog, where, direction)

    def count_drop_ip(self, direction, where):
        """ Count all dropped packets having an address in this direction """

        where = self._extend_where(where, 'state = 0 AND ip_%saddr IS NOT NULL' % direction)

        return """SELECT COUNT(*) FROM %s %s""" % (self.ulog, where)

    def select_apps(self, where):
        """ List all client applications which have dropped packets """

        where = self._extend_where(where, 'state = 0 AND client_app IS NOT NULL')

        return """SELECT client_app, COUNT(*) AS packets, MIN(timestamp) AS begin, MAX(timestamp) AS end
                    FROM %s
                    %s
                    GROUP BY client_app
                """ % (self.ulog, where)

    def count_drop_apps(self, where):
        """ Count all dropped packets having a client application """

        where = self._extend_where(where, 'state = 0 AND client_app IS NOT NULL')

        return """SELECT COUNT(*) FROM %s %s""" % (self.ulog, where)

    def select_user(self, where):
        """ List all users who have dropped packets """

        where = self._extend_where(where, 'state = 0 AND user_id IS NOT NULL')

        return """SELECT username, user_id, COUNT(*) AS packets, MIN(timestamp) AS begin, MAX(timestamp) AS end
                    FROM %s
                    %s
                    GROUP BY user_id
                """ % (self.ulog, where)

    def count_drop_user(self, where):
        """ Count all users packets which are dropped """

        where = self._extend_where(where, 'state = 0 AND user_id IS NOT NULL')

        return """SELECT COUNT(*) FROM %s %s""" % (self.ulog, where)

    def select_badhosts(self):
        """ Dropped packets per second for each source address over the last 5 minutes """

        return """SELECT ip_saddr, count(*)/300 as RATE
                  FROM %s
                  WHERE (timestamp > NOW() - INTERVAL 5 MINUTE) AND state = 0
                  GROUP BY ip_saddr""" % self.ulog

    def select_badusers(self):
        """ Dropped packets per second for each named user over the last 5 minutes """

        return """SELECT user_id, username, count(*)/300 as RATE
                  FROM %s
                  WHERE timestamp > NOW()- INTERVAL 5 MINUTE AND username NOT LIKE \"\" AND state=0
                  GROUP BY user_id""" % self.ulog



#############################################
#                                           #
#          TriggerRequest Object            #
#                                           #
#############################################

class TriggerRequest(Request):
    """ This class, child of Request, is used to make SQL requests on help tables
        'offenders', 'tcp_ports', 'udp_ports' and 'usersstats'.

        Those tables are only read for the live ulog table (empty self.sufix) and
        when no filter is given; otherwise every method falls back to the
        Request implementation, which scans the ulog table.
    """

    def select_ports(self, proto, where):

        if where or self.sufix:
            # We can only show all dropped packets, without any filter.
            return Request.select_ports(self, proto, where)

        return """SELECT %s_dport, count AS packets, first_time AS begin, last_time AS end
                    FROM %s_ports """ % (proto, proto)

    def count_drop_port(self, proto, where):

        if where or self.sufix:
            return Request.count_drop_port(self, proto, where)

        # SUM() is NULL on an empty table; COUNT(*) in Request gives 0.
        return """SELECT COALESCE(SUM(count), 0) FROM %s_ports""" % proto

    def select_ip(self, direction, where):

        # 'offenders' only holds one address column, exposed as ip_saddr.
        if where or direction != 's' or self.sufix:
            return Request.select_ip(self, direction, where)

        return """SELECT ip_addr AS ip_saddr, count AS packets, first_time AS begin, last_time AS end
                    FROM offenders
               """

    def count_drop_ip(self, direction, where):

        if where or direction != 's' or self.sufix:
            return Request.count_drop_ip(self, direction, where)

        # SUM() is NULL on an empty table; COUNT(*) in Request gives 0.
        return """SELECT COALESCE(SUM(count), 0) FROM offenders"""

    def select_user(self, where):

        if where or self.sufix:
            return Request.select_user(self, where)

        return """SELECT username, user_id, bad_conns AS packets, first_time AS begin, last_time AS end
                    FROM usersstats
                    WHERE bad_conns > 0
                """

    def count_drop_user(self, where):

        if where or self.sufix:
            return Request.count_drop_user(self, where)

        # SUM() is NULL on an empty table; COUNT(*) in Request gives 0.
        return """SELECT COALESCE(SUM(bad_conns), 0) FROM usersstats"""

def getBacktrace(empty="Empty backtrace."):
    """
    Return the exception currently being handled, formatted as a traceback.

    @param empty [string] value returned when no exception is being handled.
    @return [string] the formatted traceback, `empty`, or
            "Error while trying to get backtrace" if formatting itself fails.
    """
    try:
        info = sys.exc_info()
        if info[0] is None:
            return empty
        trace = traceback.format_exception(*info)
        # sys.exc_clear() only exists on Python 2; it drops the reference to
        # the handled exception (and the frames its traceback keeps alive).
        exc_clear = getattr(sys, 'exc_clear', None)
        if exc_clear is not None:
            exc_clear()
        return "".join(trace)
    except Exception:
        # No i18n here (imagine if i18n function calls error...)
        return "Error while trying to get backtrace"

#############################################
#                                           #
#            TableBase Object               #
#                                           #
#############################################

class TableBase:

    """ This is a table base object.
    You can call a SGDB and receive his answer in a formated table.
    Note that it is designed for SOAP calls.

    Example:

        class PortTable(TableBase):

            def __init__(self, database):
                TableBase.__init__(self, database, ['port', 'packets', 'begin', 'end'])

            def __call__(self, **args):

                # ORDER BY / LIMIT are read from self.args by _sql_query().
                self._arg_in(args, 'sortby', self.columns)
                self._arg_in(args, 'sort', ('ASC', 'DESC'))

                result = self._sql_query('select_ports', 'tcp', '')

                result.addCallback(self._print_result)
                return result

            def entry_form(self, entry):

                # Rows are tuples, so build a new one to change a value.
                return ('port %s' % entry[0],) + tuple(entry[1:])

    Usage:

        >>> table = PortTable(database)     # Create PortTable instance
        >>> deferred = table(**{'sortby': 'packets', 'sort': 'DESC'})    # <=> table.__call__()

        If this is a SOAP function, return the deferred object, and SOAP will
        wait for the callback function to return its value.
        _print_result() stores an array of tuples [()] in self.table and
        returns self.

    """

    # Integer types accepted by _is_integer(): Python 2 also has `long`.
    try:
        _INTEGER_TYPES = (int, long)
    except NameError:
        _INTEGER_TYPES = (int,)

    def __init__(self, database, columns):
        """
            @param database [DataBase] database object
            @param columns [list] list of columns id
        """

        self.columns = columns
        self.table = []
        self.args = {}
        self.filters = {}
        self.count = 0
        self.database = database

        assert isinstance(self.columns, list)

    def __iter__(self):
        """ Iterate over the rows stored in self.table. """

        return iter(self.table)

    def entry_form(self, entry):
        """ This is a callback for each entry received by _sql_query()
        Overload this method to process a entry.

            @param entry [tuple] This is a line.

            @return [tuple]

        WARNING: tuple is an imutable object, so you HAVE to recreate another tuple
        to modify it !
        """

        return entry

    def _save_count(self, result):
        """ Callback used by a _sql_query() when I want to receive a result of a
            COUNT(*) query: store the first column of the first row in self.count.

            @param result [array of tuple] rows of the query, or self when
                          _sql_query() found no table to query.
            @return [TableBase] self
        """

        if result is self or not result:
            # No table overlapped the requested period: nothing was counted.
            self.count = 0
        else:
            self.count = result[0][0]

        return self

    def _print_result(self, result):
        """ Callback of _sql_query() when I receive a result from SGDB
            @param result [array of tuple] rows of the query, or self when
                          _sql_query() found no table to query.

            @return [TableBase] If called by SOAP, client will receive result of this function
        """

        # No new rows in that case; iterating self would also walk
        # self.table while appending to it.
        if result is self:
            return self

        for entry in result:
            self.table.append(self.entry_form(entry))

        return self

    def ip2str(self, ip):
        """ Get an IP from the database and return a formated IP string
            @param ip [integer] for IPv4; for IPv6, a string of bytes, most
                      significant byte first.
            @return [string]
        """

        if self.database.ip_type == 6:
            number = 0
            for char in ip:
                number = (number << 8) | ord(char)

            return IP(number, 6).strCompressed()

        return IP(int(ip), 4).strCompressed()

    def str2ip(self, string):
        """ Get a string to return an IP integer as SQL literal
            @param string [string] IP string
            @return [string] hexadecimal literal ('0x...') for IPv6,
                             decimal string for IPv4
        """
        if self.database.ip_type == 6:
            return '0x%X' % IP(string).int()
        else:
            return '%d' % IP(string).int()

    def _sql_query(self, functioname, *args, **kwargs):
        """ Build a query over every archive table overlapping the requested
            period (self.args 'begin'/'end') and send it to the SGDB.

            @param functioname [string] name of the Request method building the
                                        SQL for one table (e.g. 'select_packets')
            @param args, kwargs forwarded to that Request method, except the
                                keyword 'display' (default True), consumed here:
                                when false, ORDER BY/LIMIT from self.args are skipped.

            @return [Deferred] This is a deferred object. If this method is called by a
                               SOAP client, return this Deferred object to tell SOAP to
                               wait callback result to send it to client. When no
                               table overlaps the period, it is already fired with self.
        """

        # 'display' is our own option and must never reach the Request method.
        display = kwargs.pop('display', True)

        archives = nulog.core.archives

        # Actual 'ulog' table is from the first packet (0) and now.
        archives[self.database.ulog] = (archives[self.database.ulog][0], datetime.datetime.today())

        # if there is 'conntrack_ulog' table, begin and end time is same than 'ulog' table.
        if self.database.conntrack_ulog in archives:
            archives[self.database.conntrack_ulog] = archives[self.database.ulog]

        if 'begin' in self.args:
            begin = datetime.datetime.fromtimestamp(int(self.args['begin']))
        else:
            begin = archives[self.database.ulog][0]

        if 'end' in self.args:
            end = datetime.datetime.fromtimestamp(int(self.args['end']))
        else:
            end = archives[self.database.ulog][1]

        assert begin
        assert end

        fragments = []
        for table, date in archives.items():

            if date[1] and date[0] and begin < date[1] and end > date[0]:

                request = self.database.createRequest(table)
                fragment = getattr(request, functioname)(*args, **kwargs)
                # A Request method may return '' (e.g. select_packets without a
                # conntrack table); keeping it would give "... UNION  UNION ...".
                if fragment:
                    fragments.append(fragment)

        if not fragments:
            # There isn't any query, so we exit
            return defer.succeed(self)

        query = ' UNION '.join(fragments)

        if display:
            if 'sortby' in self.args and 'sort' in self.args:
                query += ' ORDER BY %s %s' % (self.args['sortby'], self.args['sort'])
            if 'start' in self.args and self.args['start'] >= 0 and 'limit' in self.args and self.args['limit'] >= 0:
                query += ' LIMIT %d,%d' % (self.args['start'], self.args['limit'])

        return self.database.runQuery(query)

    def _remove_column(self, name):
        """ Remove a column id from self.columns if it is there. """

        try:
            self.columns.remove(name)
        except ValueError:
            pass

    def _arg_int(self, args, argname):
        """ Check if argname is in args, and if it is an integer
            @param args [dict] Args where we look for argname
            @param argname [string] Argument we want to check
            @return NOTHING. Value is changed in self.args
            @raise Exception if argname is neither in args nor already in
                   self.args, or if it is not an integer.
        """

        if argname not in args:
            if argname not in self.args:
                raise Exception('You must tell a specific %s' % argname)
            return

        try:
            self.args[argname] = int(args[argname])
        except (TypeError, ValueError, OverflowError):
            raise Exception('%s must be an integer' % argname)

    def _arg_bool(self, args, argname):
        """ Check if argname is in args, and store it in self.args as a boolean
            (true for 1, True, '1', 'true', 'True', 'yes', 'y'; false otherwise).
            @raise Exception if argname is neither in args nor already in self.args.
        """

        if argname not in args:
            if argname not in self.args:
                raise Exception('You must tell a specific %s' % argname)
            return

        self.args[argname] = args[argname] in (1, True, '1', 'true', 'True', 'yes', 'y')

    def _arg_in(self, args, argname, lst):
        """ Look for argname in args, and check if it is in lst.
            @param args [dict] Args
            @param argname [string] Argument name
            @param lst [list] Value may be in this list.
            @return NOTHING. Value is changed in self.args (only if it is in lst)
        """

        if argname not in args:
            return

        if args[argname] in lst:
            self.args[argname] = args[argname]

    def _is_integer(self, value):
        """ True for an int/long, or a str made only of digits. """

        return isinstance(value, self._INTEGER_TYPES) or \
               (isinstance(value, str) and value.isdigit())

    def _arg_where(self, args, where, lst):
        """ Make a SQL WHERE string with args and a definition of args (in lst).
            @param args [dict] Args
            @param where [StringIO (!)] We will put here WHERE string. Note that if where has already
                                        a value, we will add our string in it (with " AND ..")
            @param lst [dict] This dict contains an association between arg and function to call
                              (None means a plain "arg = value" test)
            @return NOTHING.

            Note: help function must have a prototype like :
                     def help_function(args[DICT], key[STRING], value[STRING])
                  It receives an escaped string and returns a condition, or an
                  empty value to add nothing. You can use one of helper function below.
        """

        for w in lst:
            if w not in args:
                continue

            value = args[w]
            if value is None or value == '':
                continue # We don't take empty strings

            if lst[w] is None:
                # There isn't any function to call
                if self._is_integer(value):
                    condition = "%s = %s" % (w, str(value))
                else:
                    condition = "%s = '%s'" % (w, safe(value))
            else:
                # Exceptions raised by helpers propagate with their traceback.
                condition = lst[w](args, w, safe(str(value)))
                if not condition:
                    continue

            self.args[w] = args[w]
            self.filters[w] = args[w]

            if where.getvalue():
                where.write(" AND ")
            else:
                where.write("WHERE ")
            where.write(condition)

    def __arg_where_priv_ip(self, args, key, value, function):
        """ Never call this function outside of _arg_where_ip* functions.

            Resolve value (hostname or IP) and OR together function(args, key, ip)
            for every resulting address.
            @raise Exception if value is neither resolvable nor a valid IP.
        """

        try:
            ips = socket.gethostbyname_ex(value)[2]
        except (socket.gaierror, socket.herror):
            try:
                IP(value)
            except Exception:
                raise Exception('Please give a correct hostname or IP')
            ips = [value]

        conditions = []
        for ip in ips:
            if self.database.ip_type == 6:
                # gethostbyname_ex ALWAYS return an ipv4.
                # We cast it to an ipv6 if necessary.
                ip = IP(ip, 6).strCompressed()

            conditions.append('(%s)' % function(args, key, ip))

        return '(%s)' % ' OR '.join(conditions)

    def __arg_where_priv_ip_generic(self, args, key, value):
        if self.database.ip_type == 6:
            return '%s = LPAD(%s, 16, 0x00)' % (key, self.str2ip(value))
        else:
            return '%s = %s' % (key, self.str2ip(value))

    def _arg_where_ip(self, args, key, value):
        return self.__arg_where_priv_ip(args, key, value, self.__arg_where_priv_ip_generic)

    def __arg_where_priv_ip_REVERSE(self, args, key, value):
        if self.database.ip_type == 6:
            return self._arg_where_ip(args, key, value)
        else:
            ip = self.str2ip(value)
            # Compare against the byte-swapped 32-bit value.
            value = struct.unpack("<I", struct.pack(">I", int(ip)))[0]
            return '%s = %s' % (key, value)

    def _arg_where_REVERSEip(self, args, key, value):
        return self.__arg_where_priv_ip(args, key, value, self.__arg_where_priv_ip_REVERSE)

    def __arg_where_priv_ip_both(self, args, key, value):

        # ip_from 's'/'d' restricts the test to one direction; otherwise both.
        if 'ip_from' not in args or args['ip_from'] not in ('s', 'd'):
            return '(%s OR %s)' % (self._arg_where_ip(args, 'ip_saddr', value),
                                   self._arg_where_ip(args, 'ip_daddr', value))
        else:
            return self._arg_where_ip(args, ('ip_%saddr' % args['ip_from']), value)

    def _arg_where_ip_both(self, args, key, value):
        return self.__arg_where_priv_ip(args, key, value, self.__arg_where_priv_ip_both)

    def _arg_where_port(self, args, key, value):

        if not self._is_integer(value):
            raise Exception('Please specify an integer value')

        protos = ('tcp', 'udp')
        if 'proto' not in args or args['proto'] not in protos:
            return '(tcp_%s = %s OR udp_%s = %s)' % (key, value, key, value)
        else:
            return '%s_%s = %s' % (args['proto'], key, value)

    def _arg_where_int(self, args, key, value):

        if not self._is_integer(value):
            raise Exception('Please specify an integer value')

        return '%s = %s' % (key, value)

    def _arg_where_proto(self, args, key, value):

        if value in ('tcp', 'udp'):
            return '%s_sport IS NOT NULL' % value
        elif value == 'icmp':
            return 'tcp_sport IS NULL AND udp_sport IS NULL'
        else:
            raise Exception('Protocol must be tcp, udp or icmp (is %s)' % value)

    def _arg_where_state(self, args, key, value):
        """ -1 means all states (no condition), 4 means state 1 or 2,
            0..3 select that state.
        """

        try:
            value = int(value)
        except (TypeError, ValueError):
            raise Exception('Please specify an integer value')

        if value == -1:
            # -1 = ALL
            return ''

        if value < 0 or value > 4:
            raise Exception('State must be an integer between 0 and 4.')

        if value == 4:
            return '(%s = 1 OR %s = 2)' % (key, key)
        return '%s = %s' % (key, value)

    def _arg_where_begin_time(self, args, key, value):

        if not self._is_integer(value):
            raise Exception("Please specify an integer value")

        return 'timestamp >= FROM_UNIXTIME(%s)' % value

    def _arg_where_end_time(self, args, key, value):

        if not self._is_integer(value):
            raise Exception("Please specify an integer value")

        return 'timestamp <= FROM_UNIXTIME(%s)' % value

    def _arg_where_like(self, args, key, value):

        return '%s LIKE \'%%%s%%\'' % (key, value)

class InfoBase(TableBase):
    """ TableBase whose _print_result() also exposes the first row as a dict
        (self.info), keyed by self.columns.
    """


    def _print_result(self, result):
        """ Callback of _sql_query() when I receive a result from SGDB
            In this function, we will create a dict to associate key and value.
            @param result [array of tuple] rows of the query, or self when
                          _sql_query() found no table to query.

            @return [InfoBase] self (with self.info set from the first row), or
                               None when there is no row at all.
        """

        # _sql_query() fires with self when no table matched: no new rows then,
        # and iterating self would walk self.table while appending to it.
        if result is not self:
            for entry in result:
                self.table.append(self.entry_form(entry))

        if not self.table:
            return None

        self.info = dict(zip(self.columns, self.table[0]))
        return self

def start_dbpool(host, db, user, passwd, _type, ip, ulog):
    """ Build and return the DataBase object (which creates its connection pool).

        The parameter order differs from DataBase.__init__, so arguments are
        forwarded by keyword. Nothing is stored in a global variable: the
        caller keeps the returned object.

        @return [DataBase]
    """
    database = DataBase(_type=_type, host=host, user=user, passwd=passwd,
                        db=db, ip_type=ip, ulog=ulog)
    return database
