From 9301d9ec7ffab39fe44b856fb90973dcc2b0aa09 Mon Sep 17 00:00:00 2001 From: junos Date: Fri, 9 Apr 2021 16:01:53 +0200 Subject: [PATCH] Enumerate contacts by grouping by participant ID first. Previously, enumeration worked for only one participant. --- features/communication.py | 32 +++++++++++++++++++++----------- test/test_communication.py | 5 ++++- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/features/communication.py b/features/communication.py index 8b9b5d1..13792d5 100644 --- a/features/communication.py +++ b/features/communication.py @@ -72,18 +72,28 @@ def enumerate_contacts(comm_df: pd.DataFrame) -> pd.DataFrame: The altered dataframe with the column contact_id, arranged by frequency. """ contact_counts = ( - comm_df["trace"] - .value_counts(sort=True, ascending=False) - .to_frame(name="frequency") + comm_df.groupby( + ["participant_id", "trace"] + ) # We want to count rows by participant_id and trace + .size() # Count rows + .reset_index() # Make participant_id a regular column. + .rename(columns={0: "freq"}) + .sort_values(["participant_id", "freq"], ascending=False) + # First sort by participant_id and then by call frequency. ) - # A frequency table of different traces (contacts). - contact_counts["contact_id"] = list(range(len(contact_counts.index))) - contact_code = contact_counts["contact_id"].to_dict() - # Create a dictionary translating traces into integers, enumerated by their frequency. - comm_df["contact_id"] = comm_df["trace"].astype("category") - # Transform to categorical data instead of a simple character column. - comm_df["contact_id"] = comm_df["contact_id"].cat.rename_categories(contact_code) - # Recode the contacts into integers from 0 to n_contacts, so that the first one is contacted the most often. + # We now have a frequency table of different traces (contacts) *within* each participant_id. + # Next, enumerate these contacts. 
+ # In other words, recode the contacts into integers from 0 to n_contacts, + # so that the first one is contacted the most often. + contact_ids = ( + contact_counts.groupby("participant_id") # Group again for enumeration. + .cumcount() # Enumerate (count) rows *within* participants. + .to_frame("contact_id") + ) + contact_counts = contact_counts.join(contact_ids) + # Add these contact_ids to the temporary (grouped) data frame. + comm_df = comm_df.merge(contact_counts, on=["participant_id", "trace"]) + # Add these contact_ids to the original data frame. return comm_df diff --git a/test/test_communication.py b/test/test_communication.py index 4d33567..c973f3b 100644 --- a/test/test_communication.py +++ b/test/test_communication.py @@ -63,7 +63,10 @@ class CallsFeatures(unittest.TestCase): # Enumerate callers manually by their frequency as set in setUpClass. self.calls = enumerate_contacts(self.calls) self.assertSeriesEqual( - self.calls["contact_id_manual"], self.calls["contact_id"], check_names=False + self.calls["contact_id_manual"], + self.calls["contact_id"].astype("category"), + check_names=False, + check_category_order=False, ) def test_count_comms(self):