Source code for firepit.query

"""Utilities for generating SQL while avoiding SQL injection vulns"""
import re

from firepit.validate import validate_name
from firepit.validate import validate_path

# Comparison operators accepted by Predicate when comparing a column to a value.
COMP_OPS = ['=', '<>', '!=', '<', '>', '<=', '>=',
            'LIKE', 'IN', 'IS', 'NOT LIKE', 'NOT IN', 'IS NOT', 'MATCHES']
# Boolean connectives accepted by Predicate when combining two Predicates.
PRED_OPS = ['AND', 'OR']
# Join types accepted by Join (checked against the upper-cased `how` argument).
JOIN_TYPES = ['INNER', 'OUTER', 'LEFT OUTER', 'CROSS']
# Aggregate functions accepted by Aggregation (checked against the upper-cased
# function name); NUNIQUE is rendered as COUNT(DISTINCT ...).
AGG_FUNCS = ['COUNT', 'SUM', 'MIN', 'MAX', 'AVG', 'NUNIQUE']
# A plain column name: either "*" or a run of ASCII letters and underscores.
COL_PATTERN = r"^(\*|[A-Za-z_]+)$"


def _validate_column_name(name):
    """Check that `name` is acceptable as a column name.

    A name that entirely matches `COL_PATTERN` is accepted immediately.
    Anything else is handed to `validate_path`, whose outcome (including any
    exception it raises) decides whether the name is acceptable.

    :param name: the candidate column name
    """
    # fullmatch rather than match: "$" also matches just before a trailing
    # newline, so re.match would wave "name\n" through without ever reaching
    # validate_path.
    if re.fullmatch(COL_PATTERN, name) is None:
        validate_path(name)  # This is for STIX object paths, not column names...


def _validate_column(col):
    """Validate a column given as a plain string or as a `Column`.

    A string is checked as a column name.  For a `Column`, the name is
    checked, then the table (with `validate_name`) and the alias (with
    `validate_path`) when they are set.  Values of any other type are not
    checked here.
    """
    if isinstance(col, Column):
        _validate_column_name(col.name)
        if col.table:
            validate_name(col.table)
        if col.alias:
            validate_path(col.alias)
        return
    if isinstance(col, str):
        _validate_column_name(col)


class InvalidComparisonOperator(Exception):
    """A comparison operator was rejected (unknown, or unusable with NULL)."""


class InvalidPredicateOperator(Exception):
    """Two predicates were combined with something other than AND/OR."""


class InvalidPredicateOperand(Exception):
    """A predicate was combined with a right-hand side that is not a predicate."""


class InvalidJoinOperator(Exception):
    """A join was requested with an unknown join type."""


class InvalidAggregateFunction(Exception):
    """An aggregation named a function that is not in the allowed list."""


class InvalidQuery(Exception):
    """A query was assembled or rendered in an invalid state."""
def _quote(obj):
    """Render `obj` for use as an SQL identifier.

    Strings are wrapped in double quotes, except for the bare `*` wildcard
    which is returned untouched.  Anything that is not a string is simply
    converted with `str()`.
    """
    if not isinstance(obj, str):
        return str(obj)
    return obj if obj == '*' else f'"{obj}"'


def _alias(obj):
    """Quote the alias of `obj` when it has a non-empty one, else `obj` itself."""
    alias = getattr(obj, 'alias', None)
    return _quote(alias) if alias else _quote(obj)
class Column:
    """SQL Column name

    Holds a validated column name together with an optional table qualifier
    and an optional alias.
    """

    def __init__(self, name, table=None, alias=None):
        """Validate and store the column name, table and alias."""
        _validate_column_name(name)
        if table:
            validate_name(table)
        if alias:
            validate_path(alias)
        self.name = name
        self.table = table
        self.alias = alias

    def __str__(self):
        """Return the quoted column, table-qualified and aliased when set."""
        parts = []
        if self.table:
            parts.append(f'"{self.table}"')
        parts.append(_quote(self.name))
        text = '.'.join(parts)
        if self.alias:
            text = f'{text} AS "{self.alias}"'
        return text

    def endswith(self, s):
        """Return whether the rendered (string) form of the column ends with `s`."""
        return str(self).endswith(s)
class CoalescedColumn:
    """First non-null column from a list - used after a JOIN"""

    def __init__(self, names, alias):
        """Validate every candidate column name and the (mandatory) alias."""
        for name in names:
            _validate_column_name(name)
        validate_path(alias)
        self.names = names
        self.alias = alias

    def __str__(self):
        """Return `COALESCE(<names>) AS "<alias>"`; names are joined as given."""
        candidates = ', '.join(self.names)
        return f'COALESCE({candidates}) AS "{self.alias}"'
class BinnedColumn(Column):
    """Bin (or "bucket") column values, presumably for easier grouping

    With a recognised time unit the column is treated as a timestamp and
    truncated to multiples of `n` units; otherwise it is treated as a numeric
    column and truncated to multiples of `n`.
    """

    # Seconds per supported time unit (long and short spellings).
    _UNIT_SECONDS = {
        'days': 86400, 'd': 86400,
        'hours': 3600, 'h': 3600,
        'minutes': 60, 'm': 60,
        'seconds': 1, 's': 1,
    }

    def __init__(self,
                 prop: str,
                 n: int,
                 unit: str = None,
                 table: str = None,
                 alias: str = None):
        """
        :param prop: column name (validated by `Column`)
        :param n: bin size; in `unit`s for timestamps, else a plain number
        :param unit: optional time unit (days/d, hours/h, minutes/m, seconds/s)
        :param table: optional table qualifier
        :param alias: optional alias for the binned result
        :raises ValueError, TypeError: if `n` cannot be converted to a number
        """
        super().__init__(prop, table, alias)
        # `n` ends up verbatim in the SQL text, so make sure it really is a
        # number (same idea as int(num) in Limit/Offset).  Ints and floats are
        # kept as-is; anything else must convert cleanly to int.
        if isinstance(n, bool) or not isinstance(n, (int, float)):
            n = int(n)
        self.n = n
        self.unit = unit.lower() if unit else ''

    def render(self, _placeholder, dialect=None):
        """Return the SQL expression for the binned column, aliased.

        :param _placeholder: unused
        :param dialect: `'postgresql'` selects PostgreSQL syntax; anything
            else produces sqlite3 syntax
        :returns: SQL text ending in `AS "<alias or column name>"`
        """
        if self.table:
            col = f'"{self.table}".{_quote(self.name)}'
        else:
            col = _quote(self.name)
        alias = self.alias if self.alias else self.name

        secs = self._UNIT_SECONDS.get(self.unit)  # None: not a timestamp
        if secs:
            # Must be a timestamp
            bin_size = self.n * secs
            if dialect == 'postgresql':
                dt = f'to_timestamp((FLOOR(EXTRACT(epoch from {col}::timestamp)/{bin_size})*{bin_size}))'
                return f'to_char({dt}, \'yyyy-MM-dd"T"HH24:MI:SS"Z"\') AS "{alias}"'
            # sqlite3
            dt = f"datetime(strftime('%s', {col})/{bin_size}*{bin_size}, 'unixepoch')"
            return f'strftime(\'%Y-%m-%dT%H:%M:%SZ\', {dt}) AS "{alias}"'

        # else we assume it's some numeric column
        bin_size = self.n
        return f'{col}/{bin_size}*{bin_size} AS "{alias}"'
class Predicate:
    """Row value predicate

    Either a comparison (`lhs` column, comparison operator, `rhs` value,
    column, list or sub-query) or a combination of two Predicates joined by
    AND/OR.  `values` holds the parameters to bind when the rendered text is
    executed.
    """

    def __init__(self, lhs, op, rhs):
        """
        :param lhs: a column name, a `Column`, or another `Predicate`
        :param op: one of `PRED_OPS` when `lhs` is a Predicate, else one of
            `COMP_OPS`
        :param rhs: a Predicate (when `lhs` is one); otherwise a value, a
            list/tuple of values, a `Column`, a `Query`, or `None`/'NULL'
        :raises InvalidPredicateOperator: `lhs` is a Predicate and `op` is
            not AND/OR
        :raises InvalidPredicateOperand: `lhs` is a Predicate but `rhs` is not
        :raises InvalidComparisonOperator: unknown comparison operator, or an
            operator that cannot be used with NULL
        """
        if isinstance(lhs, Predicate):
            if op not in PRED_OPS:
                raise InvalidPredicateOperator(op)
            if not isinstance(rhs, Predicate):
                raise InvalidPredicateOperand(str(rhs))
            self.values = lhs.values + rhs.values
        else:
            table = alias = None
            if op not in COMP_OPS:
                raise InvalidComparisonOperator(op)
            if rhs is None:
                rhs = 'NULL'
            if isinstance(lhs, Column):
                # Unpack so the name can be rewritten below, then re-wrapped
                table = lhs.table
                alias = lhs.alias
                lhs = lhs.name
            if '[*]' in lhs:  # STIX list property
                lhs, _, _ = lhs.partition('[*]')  # Need to remove this
                if rhs not in ['null', 'NULL']:
                    rhs = f"%{rhs}%"  # wrap with SQL wildcards since list is encoded as string
                    if op == '=':
                        op = 'LIKE'
                    elif op == '!=':
                        op = 'NOT LIKE'
            if isinstance(lhs, str):
                lhs = Column(lhs, table, alias)
            if rhs in ['null', 'NULL']:
                self.values = ()
                if op not in ['=', '!=', '<>', 'IS', 'IS NOT']:
                    raise InvalidComparisonOperator(op)  # Maybe need different exception here?
            elif isinstance(rhs, (list, tuple)):
                self.values = tuple(rhs)
            elif isinstance(rhs, Column):
                self.values = tuple()
            elif isinstance(rhs, Query):
                # Only the sub-query's bound values are needed here
                _, self.values = rhs.render('IGNORED')
            else:
                self.values = (rhs, )
        self.lhs = lhs
        self.op = op
        self.rhs = rhs

    def render(self, placeholder, _dialect=None):
        """Return the SQL text of this predicate.

        :param placeholder: parameter placeholder text to emit for each
            bound value
        :param _dialect: unused
        :raises InvalidComparisonOperator: NULL comparison with an operator
            that has no IS NULL / IS NOT NULL equivalent
        """
        if isinstance(self.lhs, Predicate):
            text = self.lhs.render(placeholder)
            text += f' {self.op} '
            text += self.rhs.render(placeholder)
            return f'({text})'  # Do we really need parens?

        # Split e.g. 'NOT LIKE' into neg='NOT', op='LIKE'
        neg, _, op = self.op.rpartition(' ')

        # Special case for base64-encoded artifacts
        if self.lhs.name == 'payload_bin' and op in ('LIKE', 'MATCHES'):
            if op == 'MATCHES':
                text = f'{neg} match_bin(CAST({placeholder} AS TEXT), {_quote(self.lhs)})'
            elif op == 'LIKE':
                text = f'{neg} like_bin(CAST({placeholder} AS TEXT), {_quote(self.lhs)})'
        elif self.rhs in ['null', 'NULL']:
            # __init__ accepts IS / IS NOT with NULL, so they must render too
            if self.op in ['!=', '<>', 'IS NOT']:
                text = f'({_quote(self.lhs)} IS NOT NULL)'
            elif self.op in ['=', 'IS']:
                text = f'({_quote(self.lhs)} IS NULL)'
            else:
                raise InvalidComparisonOperator(self.op)
        elif isinstance(self.rhs, Column):
            text = f'({_quote(self.lhs)} {self.op} {_quote(self.rhs)})'
        elif op == 'IN':
            if isinstance(self.rhs, Query):  # there's probably a better way to detect this
                rhs, _ = self.rhs.render(placeholder)
            else:
                rhs = ', '.join([placeholder] * len(self.rhs))
            text = f'({_quote(self.lhs)} {self.op} ({rhs}))'
        else:
            text = f'({_quote(self.lhs)} {self.op} {placeholder})'
        return text

    def set_table(self, table):
        """Specify table for ALL columns in Predicate"""
        if isinstance(self.lhs, Predicate):
            self.lhs.set_table(table)
        elif isinstance(self.lhs, Column):
            # Note: the replacement Column carries no alias
            self.lhs = Column(self.lhs.name, table)
        else:
            self.lhs = Column(self.lhs, table)
        if isinstance(self.rhs, Predicate):
            self.rhs.set_table(table)
class Filter:
    """Alternative SQL WHERE clause

    A list of predicates joined by a single connective (`AND` by default).
    `values` is the concatenation of the predicates' bound values, in order.
    """

    OR = ' OR '
    AND = ' AND '

    def __init__(self, preds, op=AND):
        self.preds = preds
        self.op = op
        # Tuple concatenation, predicate by predicate
        self.values = sum((pred.values for pred in self.preds), ())

    def render(self, placeholder, _dialect=None):
        """Join the rendered predicates; an OR filter is wrapped in parentheses."""
        joined = self.op.join(pred.render(placeholder) for pred in self.preds)
        return f'({joined})' if self.op == Filter.OR else joined

    def set_table(self, table):
        """Specify table for ALL Predicates in Filter"""
        for pred in self.preds:
            pred.set_table(table)
class Order:
    """SQL ORDER BY clause

    Each entry of `cols` is either a column (ascending order is assumed) or a
    `(column, direction)` tuple.
    """

    ASC = 'ASC'
    DESC = 'DESC'

    def __init__(self, cols):
        """
        :param cols: iterable of columns or `(column, direction)` tuples;
            string columns are checked with `validate_path`
        :raises ValueError: if a direction is not ASC or DESC (any case)
        """
        self.cols = []
        for col in cols:
            if not isinstance(col, tuple):
                col = (col, Order.ASC)
            if isinstance(col[0], str):
                validate_path(col[0])
            # The direction is written verbatim into the SQL text, so it has
            # to be checked just like the column is.
            direction = col[1]
            if (not isinstance(direction, str)
                    or direction.upper() not in (Order.ASC, Order.DESC)):
                raise ValueError(f'invalid sort direction: {direction!r}')
            self.cols.append(col)

    def render(self, _placeholder, _dialect=None):
        """Return the comma-separated `<column or alias> <direction>` list."""
        return ', '.join(f'{_alias(col[0])} {col[1]}' for col in self.cols)
class Projection:
    """SQL SELECT (really projection - pick column subset) clause"""

    def __init__(self, cols):
        """Validate each column, then keep the list as given."""
        for col in cols:
            _validate_column(col)
        self.cols = cols

    def render(self, placeholder, dialect=None):
        """Return the comma-separated column list.

        Columns that know how to render themselves get the placeholder and
        dialect (this is how db-specific features are reached); all others
        are quoted.
        """
        rendered = []
        for col in self.cols:
            if hasattr(col, 'render'):
                rendered.append(col.render(placeholder, dialect))
            else:
                rendered.append(_quote(col))
        return ', '.join(rendered)
class Table:
    """SQL Table selection"""

    def __init__(self, name):
        """Store `name` after checking it with `validate_name`."""
        validate_name(name)
        self.name = name

    def render(self, _placeholder, _dialect=None):
        """Return the bare (unquoted) table name; both arguments are unused."""
        return self.name
class Group:
    """SQL GROUP clause"""

    def __init__(self, cols):
        """Validate each grouping column, then keep the list as given."""
        for col in cols:
            _validate_column(col)
        self.cols = cols

    def render(self, _placeholder, _dialect=None):
        """Return the comma-separated, quoted list of grouping columns."""
        names = []
        for col in self.cols:
            if getattr(col, 'alias', None):
                # Ugly, ugly hack
                # Can only define a new column in Projection
                # So here we just use the alias
                names.append(col.alias)
            elif isinstance(col, Column):
                # Again, nasty hacks: the inner '"."' combines with the outer
                # quoting below to give "table"."name"
                if col.table:
                    names.append(f'{col.table}"."{col.name}')
                else:
                    names.append(col.name)
            else:
                names.append(col)
        return ', '.join(_quote(name) for name in names)
class Aggregation:
    """Aggregate rows

    Built from `(func, col)` or `(func, col, alias)` tuples, where `func`
    is one of `AGG_FUNCS` (case-insensitive).
    """

    def __init__(self, aggs):
        """
        :param aggs: iterable of `(func, col)` / `(func, col, alias)` tuples
        :raises TypeError: an entry is not a tuple, or has a length other
            than 2 or 3
        :raises InvalidAggregateFunction: `func` is not in `AGG_FUNCS`
        """
        self.aggs = []
        for agg in aggs:
            if not isinstance(agg, tuple):
                raise TypeError('expected aggregation tuple but received ' + str(type(agg)))
            if len(agg) == 3:
                func, col, alias = agg
            elif len(agg) == 2:
                func, col = agg
                alias = None
            else:
                # Without this, func/col/alias would be unbound (or silently
                # reused from the previous entry).
                raise TypeError('expected aggregation tuple of length 2 or 3 '
                                f'but received length {len(agg)}')
            if func.upper() not in AGG_FUNCS:
                raise InvalidAggregateFunction(func)
            if col is not None and col != '*':
                _validate_column(col)
            # NOTE(review): alias is written into the SQL text by render()
            # without being validated here - confirm callers sanitize it.
            self.aggs.append((func, col, alias))
        self.group_cols = []  # Filled in by Query

    def render(self, _placeholder, _dialect=None):
        """Return the comma-separated aggregate expressions.

        Any `group_cols` come first (quoted).  NUNIQUE becomes
        `COUNT(DISTINCT ...)`, a missing column becomes `*`, and a missing
        alias defaults to the lower-cased function name.
        """
        exprs = [_quote(col) for col in self.group_cols]
        for func, col, alias in self.aggs:
            distinct = ''
            if func.upper() == 'NUNIQUE':
                func, distinct = 'COUNT', 'DISTINCT '
            if not col:
                col = '*'
            target = col if col == '*' else _quote(col)  # No quotes for *
            if not alias:
                alias = func.lower()
            exprs.append(f'{func}({distinct}{target}) AS "{alias}"')
        return ', '.join(exprs)
class Offset:
    """SQL row offset"""

    def __init__(self, num):
        """Store `num` converted with `int()`."""
        self.num = int(num)

    def render(self, _placeholder, _dialect=None):
        """Return the offset as text; both arguments are unused."""
        return f'{self.num}'
class Limit:
    """SQL row count"""

    def __init__(self, num):
        """Store `num` converted with `int()`."""
        self.num = int(num)

    def render(self, _placeholder, _dialect=None):
        """Return the row limit as text; both arguments are unused."""
        return f'{self.num}'
class Count:
    """Count the rows in a result set"""

    def __init__(self):
        """Nothing to configure."""

    def render(self, _placeholder, _dialect=None):
        """Return the fixed `COUNT(*)` expression, aliased as "count"."""
        return 'COUNT(*) AS "count"'
class Unique:
    """Reduce the rows in a result set to unique tuples"""

    def __init__(self):
        """Nothing to configure."""

    def render(self, placeholder, dialect=None):
        """Return the fixed `SELECT DISTINCT *` text; both arguments are unused."""
        return 'SELECT DISTINCT *'
class CountUnique:
    """Unique count of the rows in a result set"""

    def __init__(self, cols=None):
        """Validate the optional columns, then keep `cols` as given."""
        for col in cols or []:
            _validate_column(col)
        self.cols = cols

    def render(self, _placeholder, _dialect=None):
        """Return `COUNT(DISTINCT ...)` over the columns, or `COUNT(*)` if none."""
        if not self.cols:
            return 'COUNT(*) AS "count"'
        quoted = ', '.join(f'"{col}"' for col in self.cols)
        return f'COUNT(DISTINCT {quoted}) AS "count"'
class Join:
    """Join 2 tables"""

    def __init__(self, name, left_col=None, op=None, right_col=None,
                 preds=None, how='INNER', alias=None, lhs=None):
        """
        Use *either* `left_col`, `op`, and `right_col` or `preds`

        :param name: table to join
        :param left_col: column of the left-hand (previous) table
        :param op: comparison operator, one of `COMP_OPS`
        :param right_col: column of the joined table
        :param preds: predicates forming the ON condition (ANDed together)
        :param how: join type, one of `JOIN_TYPES` (case-insensitive)
        :param alias: optional alias for the joined table
        :param lhs: optional name of the left-hand table; when omitted it is
            filled in by `Query`
        :raises InvalidComparisonOperator: `op` is given but not in `COMP_OPS`
        :raises InvalidJoinOperator: `how` is not in `JOIN_TYPES`
        """
        validate_name(name)
        # Every piece that render() writes into the SQL text is checked,
        # whether or not the full (left_col, op, right_col) triple was given.
        if left_col:
            _validate_column(left_col)
        if right_col:
            _validate_column(right_col)
        if op and op not in COMP_OPS:
            raise InvalidComparisonOperator(op)
        if alias:
            validate_name(alias)
        if lhs:
            validate_name(lhs)  # was validating `name` a second time
        if how.upper() not in JOIN_TYPES:
            raise InvalidJoinOperator(how)
        self.prev_name = lhs  # If none, filled in by Query
        self.name = name
        self.left_col = left_col
        self.op = op
        self.right_col = right_col
        self.how = how
        self.alias = alias
        self.values = tuple()
        self.preds = preds
        if preds:
            for pred in self.preds:
                self.values += pred.values

    def __repr__(self):
        return f'Join({self.name}, {self.left_col}, {self.op}, {self.right_col}, {self.how}, {self.alias}, {self.prev_name})'

    def __eq__(self, rhs):
        """Joins are equal when all their describing fields (not preds) match."""
        if not isinstance(rhs, Join):
            return NotImplemented
        return (
            self.prev_name == rhs.prev_name and
            self.name == rhs.name and
            self.left_col == rhs.left_col and
            self.op == rhs.op and
            self.right_col == rhs.right_col and
            self.how == rhs.how and
            self.alias == rhs.alias)

    def render(self, placeholder, _dialect=None):
        """Return `<HOW> JOIN "<name>" [AS "<alias>"] ON <condition>`.

        The condition is the column comparison when `left_col` is set,
        otherwise the rendered `preds` joined with AND.
        """
        # Assume there's a FROM before this?
        target = f'"{self.name}"'
        table = target
        if self.alias:
            table = f'"{self.alias}"'
            target = f'{target} AS {table}'
        if self.left_col:
            cond = (f'"{self.prev_name}"."{self.left_col}"'
                    f' {self.op} {table}."{self.right_col}"')
        else:
            cond = ' AND '.join(pred.render(placeholder) for pred in self.preds)
        return f'{self.how.upper()} JOIN {target} ON {cond}'
class Query:
    """
    SQL Query statement

    SQL order of evaluations:
    FROM, including JOINs
    WHERE
    GROUP BY
    HAVING
    WINDOW functions
    SELECT (projection)
    DISTINCT
    UNION
    ORDER BY
    LIMIT and OFFSET
    """

    def __init__(self, arg=None):
        """
        :param arg: a table name, a `Table`, or a list of stages to append;
            anything else leaves the query empty
        """
        self.table = None
        self.joins = []
        self.where = []
        self.groupby = None
        self.aggs = None
        self.having = []
        # Not supported: windows
        self.proj = None  # Make a list of Projections?
        self.distinct = False
        self.count = False  # FIXME: isn't this an aggregation?
        # TODO: self.union = []
        self.order = None
        self.limit = None
        self.offset = 0
        if isinstance(arg, str):
            self.table = Table(arg)
        elif isinstance(arg, Table):
            self.table = arg
        elif isinstance(arg, list):
            self.extend(arg)

    def append(self, stage):
        """Add one stage to the query, slotting it by its type.

        A `Filter` goes to HAVING once a `Group` has been added, otherwise
        to WHERE.  A `Query` stage becomes the table (sub-query) only if no
        table is set yet.  Stages of unrecognised types are ignored.

        :raises InvalidQuery: a `Join` is appended before any table
        """
        if isinstance(stage, Table):
            self.table = stage
        elif isinstance(stage, Join):
            if not self.table:
                raise InvalidQuery('Join must follow Table or Join')
            self.joins.append(stage)
        elif isinstance(stage, Filter):
            if self.groupby:
                self.having.append(stage)
            else:
                self.where.append(stage)
        elif isinstance(stage, Group):
            self.groupby = stage
        elif isinstance(stage, Aggregation):
            self.aggs = stage
        elif isinstance(stage, Projection):
            self.proj = stage
        elif isinstance(stage, Count):
            self.count = stage
        elif isinstance(stage, Unique):
            self.distinct = True
        elif isinstance(stage, CountUnique):
            self.count = Count()
            self.distinct = True
        elif isinstance(stage, Order):
            self.order = stage
        elif isinstance(stage, Limit):
            self.limit = stage
        elif isinstance(stage, Offset):
            self.offset = stage
        elif isinstance(stage, Query):
            if not self.table:
                self.table = stage
        #TODO: else?

    def extend(self, stages):
        """Append each stage in `stages`, in order."""
        for stage in stages:
            self.append(stage)

    def render(self, placeholder, dialect=None):
        """Render the query.

        :param placeholder: parameter placeholder text to emit for each
            bound value
        :param dialect: passed on to the stages that accept it
        :returns: `(sql_text, values)` where `values` is the tuple of
            parameters to bind, in order of appearance
        :raises InvalidQuery: no table has been set
        """
        if not self.table:
            raise InvalidQuery("no table")  #TODO: better message
        result_cols = ''
        sub_count = 0  # Count of "sub queries"
        values = ()
        from_alias = None  # Alias of the FROM sub-query, if there is one

        # The dialect is forwarded so a sub-query renders for the same db
        text = self.table.render(placeholder, dialect)
        if isinstance(text, tuple):
            # A sub-query renders to (text, values)
            text, values = text
            sub_count += 1
            from_alias = f's{sub_count}'
            query = f'FROM ({text}) AS {from_alias}'
        else:
            query = f'FROM "{text}"'

        for i, join in enumerate(self.joins):
            # prev_name stuff is a hack
            if not join.prev_name:
                if i > 0:
                    join.prev_name = self.joins[i - 1].name
                elif from_alias is not None:
                    # A sub-query has no name; it can only be referenced
                    # through the alias it was given above
                    join.prev_name = from_alias
                else:
                    join.prev_name = self.table.name
            values += join.values
            text = join.render(placeholder)
            query = f'{query} {text}'

        filts = []
        for filt in self.where:
            filts.append(filt.render(placeholder))
            values += filt.values
        if filts:
            where = ' AND '.join(filts)
            query = f'{query} WHERE {where}'

        if self.groupby:
            text = self.groupby.render(placeholder, dialect)
            query = f'{query} GROUP BY {text}'
            # Add group cols to result set automatically
            if result_cols:
                result_cols += ', '
            result_cols += ', '.join([_quote(col) for col in self.groupby.cols])

        filts = []
        for filt in self.having:
            values += filt.values
            filts.append(filt.render(placeholder))
        if filts:
            where = ' AND '.join(filts)
            query = f'{query} HAVING {where}'

        if self.proj:
            # A projection REPLACES whatever columns were collected so far
            # (i.e. the group columns) - existing behavior, kept as is
            result_cols = self.proj.render(placeholder, dialect)
        if self.aggs:
            # Aggregations are added to the result columns
            if result_cols:
                result_cols += ', '
            result_cols += self.aggs.render(placeholder)
        if not result_cols:
            result_cols = '*'

        if self.distinct and self.count and result_cols == '*':
            query = f'COUNT(*) AS "count" FROM (SELECT DISTINCT * {query}) AS tmp'
        elif self.distinct and self.count:
            query = f'COUNT(DISTINCT {result_cols}) AS "count" {query}'
        elif self.distinct:
            query = f'DISTINCT {result_cols} {query}'
        elif self.count:
            query = f'COUNT({result_cols}) AS "count" {query}'
        else:
            query = f'{result_cols} {query}'

        if self.order:
            text = self.order.render(placeholder)
            query = f'{query} ORDER BY {text}'
        if self.limit:
            text = self.limit.render(placeholder)
            query = f'{query} LIMIT {text}'
        if self.offset:
            text = self.offset.render(placeholder)
            query = f'{query} OFFSET {text}'

        query = f'SELECT {query}'
        return query, values