diff --git a/config/models.py b/config/models.py index f5edefe..efa7730 100644 --- a/config/models.py +++ b/config/models.py @@ -1,8 +1,17 @@ from datetime import datetime -from sqlalchemy import (TIMESTAMP, BigInteger, Boolean, Column, Float, - ForeignKey, Integer, SmallInteger, String, - UniqueConstraint) +from sqlalchemy import ( + TIMESTAMP, + BigInteger, + Boolean, + Column, + Float, + ForeignKey, + Integer, + SmallInteger, + String, + UniqueConstraint, +) from sqlalchemy.dialects.postgresql import ARRAY as PSQL_ARRAY from sqlalchemy.dialects.postgresql import INTEGER as PSQL_INTEGER from sqlalchemy.dialects.postgresql import JSONB as PSQL_JSONB diff --git a/features/communication.py b/features/communication.py index 69eb319..cee9b9e 100644 --- a/features/communication.py +++ b/features/communication.py @@ -2,7 +2,7 @@ from typing import List import pandas as pd -from config.models import Call, Participant, SMS +from config.models import SMS, Call, Participant from setup import db_engine, session call_types = {1: "incoming", 2: "outgoing", 3: "missed"} @@ -117,9 +117,11 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame: .add_prefix("duration_") ) comm_features = comm_counts.join(comm_duration) - try: comm_features.drop(columns="duration_" + call_types[3], inplace=True) + try: + comm_features.drop(columns="duration_" + call_types[3], inplace=True) # The missed calls are always of 0 duration. - except KeyError: pass + except KeyError: + pass # If there were no missed calls, this exception is raised. # But we are dropping the column anyway, so no need to deal with the exception. elif "message_type" in comm_df: diff --git a/test/test_communication.py b/test/test_communication.py index 4b2eda1..4d33567 100644 --- a/test/test_communication.py +++ b/test/test_communication.py @@ -5,7 +5,7 @@ import pandas as pd from numpy.random import default_rng from pandas.testing import assert_series_equal -from features.communication import enumerate_contacts, get_call_data +from features.communication import count_comms, enumerate_contacts, get_call_data rng = default_rng() @@ -32,7 +32,7 @@ class CallsFeatures(unittest.TestCase): rng.integers(1612169903000, 1614556703000, size=call_rows) ), "device_id": "device1", - "call_type": rng.integers(1, 3, size=call_rows), + "call_type": rng.integers(1, 3, size=call_rows, endpoint=True), "call_duration": rng.integers(0, 600, size=call_rows), "trace": callers, "participant_id": 29, @@ -65,3 +65,7 @@ class CallsFeatures(unittest.TestCase): self.assertSeriesEqual( self.calls["contact_id_manual"], self.calls["contact_id"], check_names=False ) + + def test_count_comms(self): + self.features = count_comms(self.calls) + self.assertIsInstance(self.features, pd.DataFrame)