Source code for firepit.sqlitestorage

import ipaddress
import logging
import os
import re
import sqlite3

from base64 import b64decode

from firepit.exceptions import DuplicateTable
from firepit.exceptions import InvalidAttr
from firepit.exceptions import UnexpectedError
from firepit.exceptions import UnknownViewname
from firepit.splitter import SqlWriter
from firepit.sqlstorage import DB_VERSION
from firepit.sqlstorage import SqlStorage
from firepit.sqlstorage import infer_type
from firepit.sqlstorage import validate_name

logger = logging.getLogger(__name__)


# DDL for the internal "__contains" table.
# TODO: add a UNIQUE(source_ref, target_ref) constraint
CONTAINS_TABLE = (
    'CREATE TABLE IF NOT EXISTS "__contains" '
    '(source_ref TEXT,'
    ' target_ref TEXT,'
    ' x_firepit_rank INTEGER);'
)

# DDL for the internal "__columns" table; (otype, path) pairs are unique.
COLUMNS_TABLE = (
    'CREATE TABLE IF NOT EXISTS "__columns" '
    '(otype TEXT,'
    ' path TEXT,'
    ' shortname TEXT,'
    ' dtype TEXT,'
    ' UNIQUE(otype, path));'
)

# Bootstrap some common SDO tables.  Each statement is assembled from its
# column definitions; the resulting SQL text is what gets executed when a
# new database is initialised.
_ID_COLUMNS = (
    '"id" TEXT UNIQUE',
    '"identity_class" TEXT',
    '"name" TEXT',
    '"created" TEXT',
    '"modified" TEXT',
)
ID_TABLE = 'CREATE TABLE "identity" ( ' + ', '.join(_ID_COLUMNS) + ')'

_OD_COLUMNS = (
    '"id" TEXT UNIQUE',
    '"created_by_ref" TEXT',
    '"created" TEXT',
    '"modified" TEXT',
    '"first_observed" TEXT',
    '"last_observed" TEXT',
    '"number_observed" BIGINT',
)
OD_TABLE = 'CREATE TABLE "observed-data" ( ' + ', '.join(_OD_COLUMNS) + ')'


def get_storage(path):
    """Return a SQLiteStorage for the database at *path*.

    :param path: database name/path handed straight to ``SQLiteStorage``
    :returns: a new ``SQLiteStorage`` instance
    """
    return SQLiteStorage(path)
def _in_subnet(value, net): """User-defined function to help implement STIX ISSUBSET""" if '/' in value: value = ipaddress.IPv4Network(value).network_address else: value = ipaddress.IPv4Address(value) net = ipaddress.IPv4Network(net) return value in net def _match(pattern, value): """User-defined function to implement SQL MATCH/STIX MATCHES""" return (value is not None and bool(re.search(pattern, value, re.DOTALL))) def _match_bin(pattern, value): """User-defined function to implement SQL MATCH/STIX MATCHES for binaries""" if value is not None: val = b64decode(value).decode("utf-8") return bool(re.search(pattern, val, re.DOTALL)) return False def _like_bin(pattern, value): """User-defined function to implement SQL/STIX LIKE for binaries""" if value is not None: try: exp = re.escape(pattern).replace('%', '.*').replace('_', '.') val = b64decode(value).decode("utf-8") return bool(re.search(exp, val, re.DOTALL)) except Exception as e: logger.error('%s', e, exc_info=e) return False
class SQLiteStorage(SqlStorage):
    """SqlStorage implementation backed by a ``sqlite3`` database."""

    def __init__(self, dbname):
        """Open (or create) the SQLite database *dbname*.

        Registers the ``in_subnet``, ``match``, ``match_bin`` and
        ``like_bin`` SQL functions on the connection.  If a ``__queries``
        table already exists the DB is treated as existing and
        ``_checkdb()`` is called; otherwise the bootstrap tables are
        created and ``_initdb()`` is called.
        """
        super().__init__()
        self.dialect = 'sqlite3'
        self.placeholder = '?'
        self.dbname = dbname
        self.connection = sqlite3.connect(dbname)
        self.connection.row_factory = row_factory
        logger.debug("Connection to SQLite DB %s successful", dbname)

        # Create functions for IP address subnet membership checks
        self.connection.create_function('in_subnet', 2, _in_subnet)

        # Create function for SQL MATCH
        self.connection.create_function("match", 2, _match)
        self.connection.create_function("match_bin", 2, _match_bin)
        self.connection.create_function("like_bin", 2, _like_bin)

        cursor = self.connection.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='__queries'")
        rows = cursor.fetchall()
        if len(rows) == 1:
            # Attaching to existing DB
            self._checkdb()
        else:
            # Do DB initialization
            cursor.execute('BEGIN;')
            cursor.execute(CONTAINS_TABLE)
            cursor.execute(COLUMNS_TABLE)
            cursor.execute(ID_TABLE)
            cursor.execute(OD_TABLE)
            self._initdb(cursor)
        cursor.close()

    def _migrate(self, version, cursor):
        """Upgrade the DB schema step by step, starting from *version*.

        :returns: True if the version reached equals ``DB_VERSION``
        """
        if version == '2':
            self._execute(COLUMNS_TABLE, cursor)
            version = '2.1'
        if version == '2.1':
            # Add unique constraint to __symtable
            # First de-dup the table (last row seen for each name wins)
            data = self.get_view_data()
            views = {}
            for row in data:
                views[row['name']] = row
            # NOTE: from here on a fresh cursor replaces the one passed in
            cursor = self._execute('BEGIN;')
            self._execute('DROP TABLE __symtable', cursor)
            stmt = ('CREATE TABLE IF NOT EXISTS "__symtable" '
                    '(name TEXT, type TEXT, appdata TEXT,'
                    ' UNIQUE(name));')
            self._execute(stmt, cursor)
            for view in views.values():
                stmt = (f'INSERT INTO "__symtable" (name, type, appdata)'
                        f' VALUES ({self.placeholder}, {self.placeholder}, {self.placeholder})')
                cursor.execute(stmt, (view['name'], view['type'], view['appdata']))
            self.connection.commit()
            cursor.close()
            version = '2.2'
        return version == DB_VERSION

    def _get_writer(self, **kwargs):
        """Get a DB inserter object"""
        filedir = os.path.dirname(self.dbname)
        return SqlWriter(filedir, self, infer_type=infer_type)

    def _do_execute(self, query, values=None, cursor=None):
        """Execute *query* and translate SQLite errors.

        :param query: SQL text
        :param values: optional bind parameters
        :param cursor: cursor to use; a new one is created if not given
        :returns: the cursor the statement ran on (a fresh, unexecuted
            cursor when the error was "no such table: main.…")
        :raises InvalidAttr: on "no such column"
        :raises UnknownViewname: on any other "no such table"
        :raises UnexpectedError: on a syntax error
        :raises DuplicateTable: on "table … already exists"
        :raises sqlite3.OperationalError: anything else, unchanged
        """
        if not cursor:
            cursor = self.connection.cursor()
        try:
            logger.debug('Executing query: %s', query)
            if not values:
                cursor.execute(query)
            else:
                cursor.execute(query, values)
        except sqlite3.OperationalError as e:
            logger.error('%s: %s', query, e)  #, exc_info=e)
            msg = e.args[0]
            if msg.startswith("no such column"):
                m = msg.replace("no such column", "invalid attribute")
                raise InvalidAttr(m) from e
            elif msg.startswith("no such table: main."):
                # Just means no match - return empty cursor?
                cursor = self.connection.cursor()
            elif msg.startswith("no such table: "):
                raise UnknownViewname(msg) from e
            elif msg.endswith("syntax error"):
                # We see this on SQL injection attempts
                raise UnexpectedError(msg) from e
            elif msg.startswith("table") and msg.endswith(" already exists"):
                # Message looks like: table "name" already exists
                tablename = msg[len("table"):-len(" already exists")]
                tablename = tablename.strip().strip('"')
                raise DuplicateTable(tablename) from e
            else:
                raise  # See if caller wants special behavior
        return cursor

    def _execute(self, statement, cursor=None):
        """Run *statement* (no bind values, no commit); returns the cursor."""
        return self._do_execute(statement, cursor=cursor)

    def _query(self, query, values=None, cursor=None):
        """Run *query* with optional *values*, then commit the connection.

        :returns: the cursor holding the results
        """
        cursor = self._do_execute(query, values=values, cursor=cursor)
        self.connection.commit()
        return cursor

    def _create_view(self, viewname, select, sco_type, deps=None, cursor=None):
        """Overrides parent

        Creates SQL view *viewname* defined by *select*.  When *viewname*
        is one of its own *deps*, the existing definition is inlined into
        *select* as a subquery aliased ``tmp`` (an existing table is first
        renamed to ``_<viewname>``).

        :returns: the cursor used (a new one is opened with BEGIN if none
            was passed)
        """
        validate_name(viewname)
        if not cursor:
            cursor = self._execute('BEGIN;')
        is_new = True
        if not deps:
            deps = []
        elif viewname in deps:
            is_new = False
            # Get the query that makes up the current view
            slct = self._get_view_def(viewname)
            if self._is_sql_view(viewname, cursor):
                self._execute(f'DROP VIEW IF EXISTS "{viewname}"', cursor)
            else:
                self._execute(f'ALTER TABLE "{viewname}" RENAME TO "_{viewname}"', cursor)
                slct = slct.replace(viewname, f'_{viewname}')
            # Swap out the viewname for its definition.  Plain string
            # replacement is used on purpose: the old definition may
            # contain backslashes, which re.sub() would interpret as
            # escapes in the replacement text.
            select = select.replace(f'FROM "{viewname}"',
                                    f'FROM ({slct}) AS tmp', 1)
            select = select.replace(f'"{viewname}"', 'tmp')
        if self._is_sql_view(viewname, cursor):
            is_new = False
            self._execute(f'DROP VIEW IF EXISTS "{viewname}"', cursor)
        self._execute(f'CREATE VIEW "{viewname}" AS {select}', cursor)
        if is_new:
            self._new_name(cursor, viewname, sco_type)
        return cursor

    def _create_table(self, tablename, columns):
        """Create table *tablename*, index it via ``_create_index``, commit.

        :param columns: mapping of column name -> SQL type
        :raises DuplicateTable: if the table already exists
        :raises sqlite3.OperationalError: on any other SQLite failure
        """
        stmt = f'CREATE TABLE "{tablename}" ('
        stmt += ','.join([f'"{colname}" {coltype}' for colname, coltype in columns.items()])
        stmt += ');'
        logger.debug('_create_table: "%s"', stmt)
        try:
            cursor = self._execute(stmt)
        except DuplicateTable:
            # Already translated by _do_execute; still undo the transaction
            self.connection.rollback()
            raise
        except sqlite3.OperationalError as e:
            self.connection.rollback()
            logger.debug('_create_table: %s', e)  #, exc_info=e)
            if e.args[0].startswith(f'table "{tablename}" already exists'):
                raise DuplicateTable(tablename) from e
            # Never continue without a cursor: propagate unknown failures
            raise
        self._create_index(tablename, cursor)
        self.connection.commit()
        cursor.close()

    def _add_column(self, tablename, prop_name, prop_type):
        """Add column *prop_name* of SQL type *prop_type* to *tablename*.

        A "duplicate column name" error is ignored (after rollback).

        :raises Exception: ('Internal error: …') on any other
            ``sqlite3.OperationalError``
        """
        stmt = f'ALTER TABLE "{tablename}" ADD COLUMN "{prop_name}" {prop_type};'
        logger.debug('new_property: "%s"', stmt)
        try:
            cursor = self._execute(stmt)
            self.connection.commit()
            cursor.close()
        except sqlite3.OperationalError as e:
            self.connection.rollback()
            logger.debug('%s', e)  #, exc_info=e)
            if e.args[0].startswith('duplicate column name: '):
                pass
            else:
                raise Exception('Internal error: ' + e.args[0]) from e

    def _get_view_def(self, viewname):
        """Return the SELECT text that defines *viewname*.

        For a SQL view this is its stored SQL minus the
        ``CREATE VIEW "<viewname>" AS `` prefix; otherwise the name is
        assumed to be a table and ``SELECT * FROM "<viewname>"`` is
        returned.
        """
        view = self._query(("SELECT sql from sqlite_master"
                            " WHERE type='view' and name=?"),
                           values=(viewname,)).fetchone()
        if view:
            slct = view['sql']
            return slct.replace(f'CREATE VIEW "{viewname}" AS ', '')

        # Must be a table
        return f'SELECT * FROM "{viewname}"'

    def _is_sql_view(self, name, cursor=None):
        """Return True if *name* is a view in ``sqlite_master``.

        *cursor* is accepted but not used; the lookup goes through
        ``_query``, which also commits the connection.
        """
        view = self._query(("SELECT sql from sqlite_master"
                            " WHERE type='view' and name=?"),
                           values=(name,)).fetchone()
        return view is not None

    def tables(self):
        """Return names of all tables except those starting with '__' or 'sqlite'."""
        cursor = self.connection.execute(
            "SELECT name FROM sqlite_master WHERE type='table';")
        rows = cursor.fetchall()
        return [i['name'] for i in rows
                if not i['name'].startswith('__') and
                not i['name'].startswith('sqlite')]

    def types(self, private=False):
        """Return names of tables that are not listed in ``__symtable``.

        :param private: if true, also include names starting with '__'
            or 'sqlite'
        """
        stmt = ("SELECT name FROM sqlite_master WHERE type='table'"
                " EXCEPT SELECT name FROM __symtable")
        cursor = self.connection.execute(stmt)
        rows = cursor.fetchall()
        if private:
            return [i['name'] for i in rows]
        return [i['name'] for i in rows
                if not i['name'].startswith('__') and
                not i['name'].startswith('sqlite')]

    def columns(self, viewname):
        """Return the list of column names of *viewname*.

        Returns an empty list if there are none or if fetching the
        ``PRAGMA table_info`` rows fails with ``sqlite3.OperationalError``.
        """
        validate_name(viewname)
        stmt = f'PRAGMA table_info("{viewname}")'
        cursor = self._execute(stmt)
        try:
            mappings = cursor.fetchall()
            result = [e["name"] for e in mappings]
            logger.debug('%s columns = %s', viewname, result)
        except sqlite3.OperationalError as e:
            logger.error('%s', e)
            result = []
        return result

    def schema(self, viewname=None):
        """Describe columns as reported by ``PRAGMA table_info``.

        :param viewname: if given, return ``[{'name', 'type'}, ...]`` for
            that view/table only; otherwise return
            ``[{'table', 'name', 'type'}, ...]`` for every name in
            ``types(True)``
        """
        if viewname:
            validate_name(viewname)
            stmt = f'PRAGMA table_info("{viewname}")'
            cursor = self._execute(stmt)
            return [{k: v for k, v in row.items() if k in ['name', 'type']}
                    for row in cursor.fetchall()]
        else:
            result = []
            for obj_type in self.types(True):
                stmt = f'PRAGMA table_info("{obj_type}")'
                cursor = self._execute(stmt)
                for row in cursor.fetchall():
                    result.append({
                        'table': obj_type,
                        'name': row['name'],
                        'type': row['type']
                    })
            return result

    def delete(self):
        """Delete ALL data in this store"""
        self.connection.close()
        try:
            os.remove(self.dbname)
        except FileNotFoundError:
            # Nothing on disk to remove
            pass
def row_factory(cursor, row):
    """sqlite3 row factory that returns each row as a dict.

    :param cursor: the cursor that produced *row*; the first item of each
        ``cursor.description`` entry is used as the key
    :param row: the tuple of column values
    :returns: ``{column_name: value}`` for the row
    """
    return {col[0]: value for col, value in zip(cursor.description, row)}