# Copyright (C) 2021- Swedish Meteorological and Hydrological Institute (SMHI)
#
# This file is part of baltrad-exchange.
#
# baltrad-exchange is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# baltrad-exchange is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with baltrad-exchange. If not, see <http://www.gnu.org/licenses/>.
###############################################################################
## The SQL functionality that is used by the server. By default it uses a sqlite database
## for miscellaneous operations; the database is also made available to plugins.
## @file
## @author Anders Henja, SMHI
## @date 2021-08-18
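##
## A minimal usage sketch (the database URI, plugin id, origin and source names below are
## illustrative only, not taken from any actual configuration):
##
##   db = SqlAlchemyDatabase("sqlite:///var/baltrad-exchange.db")
##   db.increment_statistics("my-plugin-id", "some-origin", "some-source")
##   entries = db.find_statistics("my-plugin-id", ["some-origin"], ["some-source"])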
from __future__ import absolute_import
import contextlib
import datetime
import logging
from sqlalchemy import asc,desc,func,text
from sqlalchemy import engine, event, exc as sqlexc, sql
from sqlalchemy.orm import mapper, sessionmaker
from sqlalchemy.types import (
Integer,
BigInteger,
Float,
Text,
DateTime,
TIMESTAMP
)
from sqlalchemy import (
Column,
ForeignKey,
MetaData,
PrimaryKeyConstraint,
UniqueConstraint,
Table,
)
from bexchange.db import util as dbutil
logger = logging.getLogger("bexchange.db.sqldatabase")
dbmeta = MetaData()
##
# Keeps the accumulated statistics totals.
# The created table is in format | spid | origin | source | counter | updated_at |
#
db_statistics = Table("exchange_statistics", dbmeta,
Column("spid", Text, nullable=False),
Column("origin", Text, nullable=False),
Column("source", Text, nullable=True),
Column("counter", BigInteger, nullable=False),
Column("updated_at", DateTime, nullable=False),
PrimaryKeyConstraint("spid","origin","source")
)
db_statentry = Table("exchange_statentry", dbmeta,
Column('id', Integer, primary_key=True),
Column("spid", Text, nullable=False),
Column("origin", Text, nullable=False),
Column("source", Text, nullable=True),
Column("hashid", Text, nullable=True),
Column("entrytime", TIMESTAMP, nullable=False),
Column("optime", Integer, nullable=False),
Column("delay", Integer, nullable=False),
Column("optime_info", Text, nullable=True),
Column("datetime", DateTime, nullable=True),
Column("object_type", Text, nullable=True),
Column("elevation_angle", Float, nullable=True),
UniqueConstraint("spid", "origin", "source", "entrytime")
)
class statistics(object):
def __init__(self, spid, origin, source, counter, updated_at):
""" Keeps the totals for a specified spid + origin + source
:param spid: The statistics plugin id
:param origin: Origin for this stat
:param source: Source
:param counter: Counter
        :param updated_at: When the entry was last updated
"""
self.spid = spid
self.origin = origin
self.source = source
self.counter = counter
self.updated_at = updated_at
def json_repr(self):
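        """Returns the dictionary (json) representation of this statistics entry.
        """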
return {
"spid":self.spid,
"origin":self.origin,
"source":self.source,
"counter":self.counter,
"updated_at":self.updated_at.isoformat()
}
class statentry(object):
def __init__(self, spid, origin, source, hashid, entrytime, optime=0, optime_info=None, delay=0, ndatetime=None, object_type=None, elevation_angle=None):
""" Represents one increment entry. Used for creating averages and such information
:param spid: The statistics plugin id
:param origin: Origin for this stat
:param source: Source
:param hashid: The hash id
:param entrytime: When this entry was created
:param optime: Operation time entry in ms
:param optime_info: Used to identify what was timed
        :param delay: Difference between arrival of the file and the nominal datetime
        :param ndatetime: The nominal datetime of the product
        :param object_type: The object type of the product
        :param elevation_angle: The elevation angle, if applicable
        """
self.spid = spid
self.origin = origin
self.source = source
self.hashid = hashid
self.entrytime = entrytime
self.optime = optime
self.optime_info = optime_info
self.delay = delay
self.datetime = ndatetime
self.object_type = object_type
self.elevation_angle = elevation_angle
self.attributes = {}
def json_repr(self):
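        """Returns the dictionary (json) representation of this entry, including any added attributes.
        """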
        result = {
            "spid":self.spid,
            "origin":self.origin,
            "source":self.source,
            "hashid":self.hashid,
            "entrytime":self.entrytime.isoformat(),
            "optime":self.optime,
            "optime_info":self.optime_info,
            "delay":self.delay,
            "datetime":self.datetime.isoformat() if self.datetime else None,
            "object_type":self.object_type,
            "elevation_angle":self.elevation_angle
        }
if "attributes" in self.__dict__:
for a in self.attributes:
result[a] = self.attributes[a]
return result
def add_attribute(self, name, value):
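        """Adds a named attribute to this entry so that it is included in json_repr, e.g. a calculated average.
        :param name: Name of the attribute
        :param value: The value
        """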
        if "attributes" not in self.__dict__:
self.attributes = {}
self.attributes[name] = value
mapper(statistics, db_statistics)
mapper(statentry, db_statentry)
def force_sqlite_foreign_keys(dbapi_con, con_record):
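    """Connect event listener that enables foreign key enforcement on sqlite connections.
    :param dbapi_con: The dbapi connection
    :param con_record: The connection record
    """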
try:
import sqlite3
except ImportError:
# built without sqlite support
return
if isinstance(dbapi_con, sqlite3.Connection):
dbapi_con.execute("pragma foreign_keys=ON")
class SqlAlchemyDatabase(object):
def __init__(self, uri="sqlite:///tmp/baltrad-exchange.db", poolsize=10):
"""Constructor
:param uri: The uri pointing to the database.
:param poolsize: How many database connections we should use
"""
self._engine = dbutil.create_engine_from_url(uri, poolsize)
if self._engine.driver == "pysqlite":
event.listen(self._engine, "connect", force_sqlite_foreign_keys)
self.init_tables()
dbmeta.bind = self._engine
self.Session = sessionmaker(bind=self._engine)
@property
def driver(self):
"""database driver name
"""
return self._engine.driver
def init_tables(self):
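        """Creates the database tables if they do not already exist.
        """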
dbmeta.create_all(self._engine)
logger.info("Initialized alchemy database")
def get_connection(self):
"""get a context managed connection to the database
"""
return contextlib.closing(self._engine.connect())
def get_session(self):
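        """get a context managed session to the database
        """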
session = self.Session()
return contextlib.closing(session)
def get_statistics_entry(self, spid, origin, source):
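        """Returns the statistics entry matching spid + origin + source, or None if no match exists.
        :param spid: The statistics plugin id
        :param origin: Origin for this stat
        :param source: Source
        :return: the statistics instance or None
        """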
with self.get_session() as s:
q = s.query(statistics).filter(statistics.spid == spid).filter(statistics.origin == origin).filter(statistics.source == source)
return q.one_or_none()
def list_statistic_ids(self):
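        """Returns a list of all distinct spids found in the statistics table.
        """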
with self.get_session() as s:
entries = s.query(statistics.spid).distinct(statistics.spid).all()
result = [e[0] for e in entries]
return result
def list_statentry_ids(self):
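        """Returns a list of all distinct spids found in the statentry table.
        """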
with self.get_session() as s:
entries = s.query(statentry.spid).distinct(statentry.spid).all()
result = [e[0] for e in entries]
return result
def find_statistics(self, spid, origins, sources):
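        """Returns the statistics entries for a spid, optionally limited by origins and sources.
        :param spid: The statistics plugin id
        :param origins: A list of origins. If None or empty, all origins are matched.
        :param sources: A list of sources. If None or empty, all sources are matched.
        :return: a list of statistics instances
        """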
with self.get_session() as s:
q = s.query(statistics).filter(statistics.spid == spid)
if origins and len(origins) > 0:
q = q.filter(statistics.origin.in_(origins))
if sources and len(sources) > 0:
q = q.filter(statistics.source.in_(sources))
return q.all()
def find_statentries(self, spid, origins, sources, hashid=None, filters=None, object_type=None):
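        """Returns all statentry instances matching the criteria, ordered by origin, source and entrytime.
        :param spid: The statistics plugin id
        :param origins: A list of origins. If None or empty, all origins are matched.
        :param sources: A list of sources. If None or empty, all sources are matched.
        :param hashid: If specified, only entries with this hash id are returned.
        :param filters: A list of (name, operator, value) entries where name is one of
            "datetime", "entrytime", "optime" or "delay" and operator is one of ">", ">=", "=", "<=" and "<".
        :param object_type: If specified, only entries with this object type are returned.
        :return: a list of statentry instances
        """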
with self.get_session() as s:
q = s.query(statentry).filter(statentry.spid == spid)
if origins and len(origins) > 0:
q = q.filter(statentry.origin.in_(origins))
if sources and len(sources) > 0:
q = q.filter(statentry.source.in_(sources))
if hashid:
q = q.filter(statentry.hashid == hashid)
if filters:
for dtfilter in filters:
if dtfilter[0] == "datetime":
if dtfilter[1] == ">":
q = q.filter(statentry.datetime > dtfilter[2])
elif dtfilter[1] == ">=":
q = q.filter(statentry.datetime >= dtfilter[2])
elif dtfilter[1] == "=":
q = q.filter(statentry.datetime == dtfilter[2])
elif dtfilter[1] == "<=":
q = q.filter(statentry.datetime <= dtfilter[2])
elif dtfilter[1] == "<":
q = q.filter(statentry.datetime < dtfilter[2])
elif dtfilter[0] == "entrytime":
if dtfilter[1] == ">":
q = q.filter(statentry.entrytime > dtfilter[2])
elif dtfilter[1] == ">=":
q = q.filter(statentry.entrytime >= dtfilter[2])
elif dtfilter[1] == "=":
q = q.filter(statentry.entrytime == dtfilter[2])
elif dtfilter[1] == "<=":
q = q.filter(statentry.entrytime <= dtfilter[2])
elif dtfilter[1] == "<":
q = q.filter(statentry.entrytime < dtfilter[2])
elif dtfilter[0] == "optime":
if dtfilter[1] == ">":
q = q.filter(statentry.optime > dtfilter[2])
elif dtfilter[1] == ">=":
q = q.filter(statentry.optime >= dtfilter[2])
elif dtfilter[1] == "=":
q = q.filter(statentry.optime == dtfilter[2])
elif dtfilter[1] == "<=":
q = q.filter(statentry.optime <= dtfilter[2])
elif dtfilter[1] == "<":
q = q.filter(statentry.optime < dtfilter[2])
elif dtfilter[0] == "delay":
if dtfilter[1] == ">":
q = q.filter(statentry.delay > dtfilter[2])
elif dtfilter[1] == ">=":
q = q.filter(statentry.delay >= dtfilter[2])
elif dtfilter[1] == "=":
q = q.filter(statentry.delay == dtfilter[2])
elif dtfilter[1] == "<=":
q = q.filter(statentry.delay <= dtfilter[2])
elif dtfilter[1] == "<":
q = q.filter(statentry.delay < dtfilter[2])
if object_type:
q = q.filter(statentry.object_type == object_type)
q = q.order_by(asc(statentry.origin)) \
.order_by(asc(statentry.source)) \
.order_by(asc(statentry.entrytime))
return q.all()
def get_average_statentries(self, spid, origins, sources, hashid=None):
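        """Returns one statentry per spid + origin + source combination with the average optime
        added as the attribute "average".
        :param spid: The statistics plugin id
        :param origins: A list of origins. If None or empty, all origins are matched.
        :param sources: A list of sources. If None or empty, all sources are matched.
        :param hashid: Currently not used.
        :return: a list of statentry instances
        """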
with self.get_session() as s:
q = s.query(statentry, func.avg(statentry.optime)).filter(statentry.spid == spid)
if origins and len(origins) > 0:
q = q.filter(statentry.origin.in_(origins))
if sources and len(sources) > 0:
q = q.filter(statentry.source.in_(sources))
q = q.group_by(statentry.spid, statentry.origin, statentry.source)
q = q.order_by(asc(statentry.spid)) \
.order_by(asc(statentry.origin)) \
.order_by(asc(statentry.source)) \
.order_by(asc(statentry.entrytime))
qresult = q.all()
result = []
for e in qresult:
e[0].add_attribute("average", e[1])
result.append(e[0])
return result
def cleanup_statentries(self, maxagedt):
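        """Removes all statentry rows with an entrytime older than maxagedt.
        :param maxagedt: A datetime. Entries older than this are deleted.
        """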
logger.info("Cleanup of statentries older than %s"%maxagedt.strftime("%Y-%m-%d %H:%M"))
q = db_statentry.delete().where(db_statentry.c.entrytime < maxagedt.strftime("%Y-%m-%d %H:%M"))
logger.debug("Query: %s"%q)
self._engine.execute(q)
def increment_statistics(self, spid, origin, source):
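        """Increments the counter for spid + origin + source, creating the entry if it does not already exist.
        :param spid: The statistics plugin id
        :param origin: Origin for this stat
        :param source: Source
        """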
with self.get_session() as session:
stats = session.query(statistics).filter(statistics.spid == spid).filter(statistics.origin == origin).filter(statistics.source == source).first()
if not stats:
stats = statistics(spid, origin, source, 0, datetime.datetime.now())
stats.counter += 1
stats.updated_at = datetime.datetime.now()
try:
session.merge(stats)
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def add(self, obj):
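        """Adds one object or a list of objects to the database. Rolls back on failure.
        :param obj: The object or list of objects to add
        """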
session = self.Session()
xlist = obj
if not isinstance(obj, list):
xlist = [obj]
try:
for x in xlist:
                session.add(x)
session.commit()
except:
session.rollback()
raise
finally:
session.close()
session = None
def update(self, obj):
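        """Updates (merges) one object or a list of objects in the database. Rolls back on failure.
        :param obj: The object or list of objects to update
        """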
session = self.Session()
xlist = obj
if not isinstance(obj, list):
xlist = [obj]
try:
for x in xlist:
nx = session.merge(x)
session.add(nx)
session.commit()
except:
session.rollback()
raise
finally:
session.close()
session = None