Add a method to enumerate contacts and appropriate tests.
parent
b2d93e0686
commit
dfb4236769
|
@ -15,3 +15,8 @@ def get_call_data(usernames: List) -> pd.DataFrame:
|
||||||
with db_engine.connect() as connection:
|
with db_engine.connect() as connection:
|
||||||
df_calls = pd.read_sql(query_calls.statement, connection)
|
df_calls = pd.read_sql(query_calls.statement, connection)
|
||||||
return df_calls
|
return df_calls
|
||||||
|
|
||||||
|
|
||||||
|
def enumerate_contacts(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
# Calculate frequencies and return in descending order.
|
||||||
|
return comm_df
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from numpy.random import default_rng
|
||||||
|
|
||||||
|
from features.communication import enumerate_contacts, get_call_data
|
||||||
|
|
||||||
|
rng = default_rng()
|
||||||
|
|
||||||
|
|
||||||
|
class CallsFeatures(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
call_rows = 10
|
||||||
|
callers = np.concatenate((
|
||||||
|
np.repeat("caller1", 4),
|
||||||
|
np.repeat("caller2", 3),
|
||||||
|
np.repeat("caller3", 2),
|
||||||
|
np.repeat("caller4", 1)), axis=None)
|
||||||
|
rng.shuffle(callers)
|
||||||
|
self.calls = pd.DataFrame({
|
||||||
|
"id": np.linspace(0, call_rows - 1, num=call_rows, dtype="u4") + 100,
|
||||||
|
"_id": np.linspace(0, call_rows - 1, num=call_rows, dtype="u4"),
|
||||||
|
"timestamp": np.sort(rng.integers(1612169903000, 1614556703000, size=call_rows)),
|
||||||
|
"device_id": "device1",
|
||||||
|
"call_type": rng.integers(1, 3, size=call_rows),
|
||||||
|
"call_duration": rng.integers(0, 600, size=call_rows),
|
||||||
|
"trace": callers,
|
||||||
|
"participant_id": 29
|
||||||
|
})
|
||||||
|
print(self.calls)
|
||||||
|
|
||||||
|
def test_get_calls_data(self):
|
||||||
|
calls_from_db = get_call_data(["nokia_0000003"])
|
||||||
|
self.assertIsNotNone(calls_from_db)
|
||||||
|
|
||||||
|
def test_enumeration(self):
|
||||||
|
enumerate_contacts(self.calls)
|
||||||
|
#Enumerate manually and compare
|
|
@ -4,7 +4,6 @@ from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from config.models import LightSensor, Participant
|
from config.models import LightSensor, Participant
|
||||||
from features.communication import get_call_data
|
|
||||||
from setup import db_uri
|
from setup import db_uri
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +29,3 @@ class DatabaseConnection(unittest.TestCase):
|
||||||
def test_get_light_data(self):
|
def test_get_light_data(self):
|
||||||
light_0 = self.session.query(Participant).join(LightSensor).first()
|
light_0 = self.session.query(Participant).join(LightSensor).first()
|
||||||
self.assertIsNotNone(light_0)
|
self.assertIsNotNone(light_0)
|
||||||
|
|
||||||
def test_get_calls_data(self):
|
|
||||||
calls = get_call_data(["nokia_0000003"])
|
|
||||||
self.assertIsNotNone(calls)
|
|
||||||
|
|
Loading…
Reference in New Issue