68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
import unittest
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from numpy.random import default_rng
|
|
from pandas.testing import assert_series_equal
|
|
|
|
from features.communication import enumerate_contacts, get_call_data
|
|
|
|
# Fixed (arbitrary) seed so the shuffled callers and the random call attributes
# built in setUpClass are identical on every run, making failures reproducible.
RNG_SEED = 42
rng = default_rng(RNG_SEED)
|
|
|
|
|
|
class CallsFeatures(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
call_rows = 10
|
|
callers = np.concatenate(
|
|
(
|
|
np.repeat("caller1", 2),
|
|
np.repeat("caller2", 3),
|
|
np.repeat("caller3", 4),
|
|
np.repeat("caller4", 1),
|
|
),
|
|
axis=None,
|
|
)
|
|
rng.shuffle(callers)
|
|
cls.calls = pd.DataFrame(
|
|
{
|
|
"id": np.linspace(0, call_rows - 1, num=call_rows, dtype="u4") + 100,
|
|
"_id": np.linspace(0, call_rows - 1, num=call_rows, dtype="u4"),
|
|
"timestamp": np.sort(
|
|
rng.integers(1612169903000, 1614556703000, size=call_rows)
|
|
),
|
|
"device_id": "device1",
|
|
"call_type": rng.integers(1, 3, size=call_rows),
|
|
"call_duration": rng.integers(0, 600, size=call_rows),
|
|
"trace": callers,
|
|
"participant_id": 29,
|
|
}
|
|
)
|
|
|
|
@classmethod
|
|
def assertSeriesEqual(cls, a, b, msg=None, **optional):
|
|
try:
|
|
assert_series_equal(a, b, **optional)
|
|
except AssertionError as e:
|
|
raise cls.failureException(msg) from e
|
|
|
|
def setUp(self):
|
|
self.addTypeEqualityFunc(pd.DataFrame, self.assertSeriesEqual)
|
|
|
|
def test_get_calls_data(self):
|
|
calls_from_db = get_call_data(["nokia_0000003"])
|
|
self.assertIsNotNone(calls_from_db)
|
|
|
|
def test_enumeration(self):
|
|
self.calls["contact_id_manual"] = self.calls["trace"].astype("category")
|
|
self.calls["contact_id_manual"] = self.calls[
|
|
"contact_id_manual"
|
|
].cat.rename_categories(
|
|
{"caller1": 2, "caller2": 1, "caller3": 0, "caller4": 3}
|
|
)
|
|
# Enumerate callers manually by their frequency as set in setUpClass.
|
|
self.calls = enumerate_contacts(self.calls)
|
|
self.assertSeriesEqual(
|
|
self.calls["contact_id_manual"], self.calls["contact_id"], check_names=False
|
|
)
|