Enumerate contacts by grouping by participant ID first.

Previously, enumeration worked for only one participant.
communication
junos 2021-04-09 16:01:53 +02:00
parent 5ffd85e05b
commit 9301d9ec7f
2 changed files with 25 additions and 12 deletions

View File

@ -72,18 +72,28 @@ def enumerate_contacts(comm_df: pd.DataFrame) -> pd.DataFrame:
The altered dataframe with the column contact_id, arranged by frequency.
"""
contact_counts = (
comm_df["trace"]
.value_counts(sort=True, ascending=False)
.to_frame(name="frequency")
comm_df.groupby(
["participant_id", "trace"]
) # We want to count rows by participant_id and trace
.size() # Count rows
.reset_index() # Make participant_id a regular column.
.rename(columns={0: "freq"})
.sort_values(["participant_id", "freq"], ascending=False)
# First sort by participant_id and then by call frequency.
)
# A frequency table of different traces (contacts).
contact_counts["contact_id"] = list(range(len(contact_counts.index)))
contact_code = contact_counts["contact_id"].to_dict()
# Create a dictionary translating traces into integers, enumerated by their frequency.
comm_df["contact_id"] = comm_df["trace"].astype("category")
# Transform to categorical data instead of a simple character column.
comm_df["contact_id"] = comm_df["contact_id"].cat.rename_categories(contact_code)
# Recode the contacts into integers from 0 to n_contacts, so that the first one is contacted the most often.
# We now have a frequency table of different traces (contacts) *within* each participant_id.
# Next, enumerate these contacts.
# In other words, recode the contacts into integers from 0 to n_contacts,
# so that the first one is contacted the most often.
contact_ids = (
contact_counts.groupby("participant_id") # Group again for enummeration.
.cumcount() # Enummerate (count) rows *within* participants.
.to_frame("contact_id")
)
contact_counts = contact_counts.join(contact_ids)
# Add these contact_ids to the temporary (grouped) data frame.
comm_df = comm_df.merge(contact_counts, on=["participant_id", "trace"])
# Add these contact_ids to the original data frame.
return comm_df

View File

@ -63,7 +63,10 @@ class CallsFeatures(unittest.TestCase):
# Enumerate callers manually by their frequency as set in setUpClass.
self.calls = enumerate_contacts(self.calls)
self.assertSeriesEqual(
self.calls["contact_id_manual"], self.calls["contact_id"], check_names=False
self.calls["contact_id_manual"],
self.calls["contact_id"].astype("category"),
check_names=False,
check_category_order=False,
)
def test_count_comms(self):