From b2d93e06868580a32545d58189736ea379923ce3 Mon Sep 17 00:00:00 2001 From: junos Date: Tue, 5 Jan 2021 17:00:45 +0100 Subject: [PATCH] Add a method to get Calls data. Add a test for this. --- config/environment.yml | 1 + config/models.py | 15 +++------------ features/__init__.py | 0 features/communication.py | 17 +++++++++++++++++ setup.py | 16 ++++++++++++---- test/test_database.py | 11 +++++++---- 6 files changed, 40 insertions(+), 20 deletions(-) create mode 100644 features/__init__.py create mode 100644 features/communication.py diff --git a/config/environment.yml b/config/environment.yml index 55ffcc8..7f1d829 100644 --- a/config/environment.yml +++ b/config/environment.yml @@ -9,6 +9,7 @@ dependencies: - flake8 - jupyterlab - mypy + - pandas - psycopg2 - python-dotenv - sqlalchemy \ No newline at end of file diff --git a/config/models.py b/config/models.py index b217326..1d98ad7 100644 --- a/config/models.py +++ b/config/models.py @@ -1,17 +1,8 @@ from datetime import datetime -from sqlalchemy import ( - TIMESTAMP, - BigInteger, - Boolean, - Column, - Float, - ForeignKey, - Integer, - SmallInteger, - String, - UniqueConstraint, -) +from sqlalchemy import (TIMESTAMP, BigInteger, Boolean, Column, Float, + ForeignKey, Integer, SmallInteger, String, + UniqueConstraint) from sqlalchemy.dialects.postgresql import ARRAY as PSQL_ARRAY from sqlalchemy.dialects.postgresql import INTEGER as PSQL_INTEGER from sqlalchemy.dialects.postgresql import JSONB as PSQL_JSONB diff --git a/features/__init__.py b/features/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/features/communication.py b/features/communication.py new file mode 100644 index 0000000..a915ecd --- /dev/null +++ b/features/communication.py @@ -0,0 +1,17 @@ +from typing import List + +import pandas as pd + +from config.models import Call, Participant +from setup import db_engine, session + + +def get_call_data(usernames: List) -> pd.DataFrame: + query_calls = ( + session.query(Call, Participant.username) + .filter(Participant.id == Call.participant_id) + .filter(Participant.username.in_(usernames)) + ) + with db_engine.connect() as connection: + df_calls = pd.read_sql(query_calls.statement, connection) + return df_calls diff --git a/setup.py b/setup.py index b939deb..68c824c 100644 --- a/setup.py +++ b/setup.py @@ -1,21 +1,29 @@ import os + import sqlalchemy.engine.url +from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from dotenv import load_dotenv load_dotenv() +testing: bool = False + db_password = os.getenv("DB_PASSWORD") db_uri = sqlalchemy.engine.url.URL( - drivername='postgresql+psycopg2', + drivername="postgresql+psycopg2", username="staw_db", password=db_password, host="212.235.208.113", port=5432, - database="staw" + database="staw", ) -db_engine = create_engine('sqlite:///:memory:', echo=True) +if testing: + db_engine = create_engine("sqlite:///:memory:", echo=True) +else: + db_engine = create_engine(db_uri) + Session = sessionmaker(bind=db_engine) +session = Session() diff --git a/test/test_database.py b/test/test_database.py index f313aac..6abd8ce 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -4,6 +4,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from config.models import LightSensor, Participant +from features.communication import get_call_data from setup import db_uri @@ -23,11 +24,13 @@ class DatabaseConnection(unittest.TestCase): connection.close() def test_get_participant(self): - self.participant_0 = self.session.query(Participant).first() - self.assertIsNotNone(self.participant_0) - print(self.participant_0) + participant_0 = self.session.query(Participant).first() + self.assertIsNotNone(participant_0) def test_get_light_data(self): light_0 = self.session.query(Participant).join(LightSensor).first() self.assertIsNotNone(light_0) - print(light_0) + + def test_get_calls_data(self): + calls = get_call_data(["nokia_0000003"]) + self.assertIsNotNone(calls)