Compare commits
25 Commits
74b4f9ddbe
...
e2e268148d
Author | SHA1 | Date |
---|---|---|
junos | e2e268148d | |
junos | 00015a3b8d | |
junos | 9a319ac6e5 | |
junos | 777e6f0a58 | |
junos | 2d78aacd18 | |
junos | c88336481e | |
junos | 1bc996413e | |
junos | a2a44c202a | |
junos | 4740e94d37 | |
junos | b1ad8d1309 | |
junos | bb75abcb9b | |
junos | e7fe4e8398 | |
Junos Lukan | cf28aa547a | |
junos | d6f36ec8f8 | |
junos | b06ec6e1ae | |
junos | 622477f19f | |
junos | 577a874288 | |
junos | c8bb481508 | |
junos | 98f1df81c6 | |
junos | ad85f79bc5 | |
junos | 070cfdba80 | |
junos | c6d0e4391e | |
junos | af65d0864f | |
junos | a2180aee54 | |
junos | a06ad0800f |
|
@ -16,6 +16,7 @@ dependencies:
|
|||
- python-dotenv
|
||||
- pytz
|
||||
- seaborn
|
||||
- scikit-learn
|
||||
- sqlalchemy
|
||||
- statsmodels
|
||||
- tabulate
|
|
@ -0,0 +1,63 @@
|
|||
id,timestamp,device_id,_id,double_proximity,accuracy,label,dateTime
|
||||
39017,1565802024310,f67354f7-d675-4b76-80c8-123cc4744a5b,2962,0,3,,2019-08-14T17:00:24Z
|
||||
39018,1565802051075,f67354f7-d675-4b76-80c8-123cc4744a5b,2963,0,3,,2019-08-14T17:00:51Z
|
||||
39019,1565802051354,f67354f7-d675-4b76-80c8-123cc4744a5b,2964,8,3,,2019-08-14T17:00:51Z
|
||||
39089,1565010418305,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,51,5,3,,2019-08-05T13:06:58Z
|
||||
39090,1565010772188,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,52,5,3,,2019-08-05T13:12:52Z
|
||||
39091,1565012334450,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,53,5,3,,2019-08-05T13:38:54Z
|
||||
39092,1565013000660,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,54,5,3,,2019-08-05T13:50:00Z
|
||||
39093,1565022742894,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,55,0,3,,2019-08-05T16:32:22Z
|
||||
39094,1565089295906,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,56,5,3,,2019-08-06T11:01:35Z
|
||||
39095,1565096030817,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,57,0,3,,2019-08-06T12:53:50Z
|
||||
39096,1565096367694,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,58,5,3,,2019-08-06T12:59:27Z
|
||||
39097,1565096408570,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,59,5,3,,2019-08-06T13:00:08Z
|
||||
39098,1565116821528,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,60,5,3,,2019-08-06T18:40:21Z
|
||||
39099,1565131345333,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,61,0,3,,2019-08-06T22:42:25Z
|
||||
39100,1565131375072,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,62,5,3,,2019-08-06T22:42:55Z
|
||||
39101,1565131386353,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,63,0,3,,2019-08-06T22:43:06Z
|
||||
39102,1565131389213,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,64,5,3,,2019-08-06T22:43:09Z
|
||||
39103,1565131448891,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,65,0,3,,2019-08-06T22:44:08Z
|
||||
39104,1565131454131,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,66,5,3,,2019-08-06T22:44:14Z
|
||||
39105,1565176143083,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,67,0,3,,2019-08-07T11:09:03Z
|
||||
39106,1565179569310,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,68,5,3,,2019-08-07T12:06:09Z
|
||||
39107,1565180699173,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,69,5,3,,2019-08-07T12:24:59Z
|
||||
39108,1565182538578,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,70,5,3,,2019-08-07T12:55:38Z
|
||||
39109,1565192592776,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,71,0,3,,2019-08-07T15:43:12Z
|
||||
39110,1565216023797,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,72,5,3,,2019-08-07T22:13:43Z
|
||||
39111,1565248358647,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,73,0,3,,2019-08-08T07:12:38Z
|
||||
39112,1565275859157,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,74,5,3,,2019-08-08T14:50:59Z
|
||||
39113,1565304201431,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,75,0,3,,2019-08-08T22:43:21Z
|
||||
39114,1565304229591,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,76,5,3,,2019-08-08T22:43:49Z
|
||||
39115,1565304262050,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,77,0,3,,2019-08-08T22:44:22Z
|
||||
39116,1565613142970,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,78,5,3,,2019-08-12T12:32:22Z
|
||||
39117,1565618266531,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,79,5,3,,2019-08-12T13:57:46Z
|
||||
39118,1565618410488,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,80,5,3,,2019-08-12T14:00:10Z
|
||||
39119,1565618704942,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,81,5,3,,2019-08-12T14:05:04Z
|
||||
39120,1565619005315,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,82,5,3,,2019-08-12T14:10:05Z
|
||||
39121,1565619405904,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,83,5,3,,2019-08-12T14:16:45Z
|
||||
39122,1565619678037,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,84,5,3,,2019-08-12T14:21:18Z
|
||||
39123,1565621206713,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,85,5,3,,2019-08-12T14:46:46Z
|
||||
39124,1565626622125,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,86,5,3,,2019-08-12T16:17:02Z
|
||||
39125,1565684876738,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,87,5,3,,2019-08-13T08:27:56Z
|
||||
39126,1565684956618,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,88,5,3,,2019-08-13T08:29:16Z
|
||||
39127,1565684965647,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,89,5,3,,2019-08-13T08:29:25Z
|
||||
39128,1565685092246,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,90,5,3,,2019-08-13T08:31:32Z
|
||||
39129,1565685136337,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,91,5,3,,2019-08-13T08:32:16Z
|
||||
39130,1565685147453,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,92,5,3,,2019-08-13T08:32:27Z
|
||||
39131,1565685212523,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,93,5,3,,2019-08-13T08:33:32Z
|
||||
39132,1565703397796,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,94,0,3,,2019-08-13T13:36:37Z
|
||||
39133,1565776203019,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,95,5,3,,2019-08-14T09:50:03Z
|
||||
39134,1565776434168,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,96,5,3,,2019-08-14T09:53:54Z
|
||||
39135,1565776435231,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,97,0,3,,2019-08-14T09:53:55Z
|
||||
39136,1565776443368,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,98,5,3,,2019-08-14T09:54:03Z
|
||||
39137,1565779277109,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,99,0,3,,2019-08-14T10:41:17Z
|
||||
39138,1565780016327,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,100,5,3,,2019-08-14T10:53:36Z
|
||||
39139,1565780027437,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,101,5,3,,2019-08-14T10:53:47Z
|
||||
39140,1565783470934,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,102,5,3,,2019-08-14T11:51:10Z
|
||||
39141,1565783801540,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,103,0,3,,2019-08-14T11:56:41Z
|
||||
39142,1565783802120,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,104,5,3,,2019-08-14T11:56:42Z
|
||||
39143,1565783861495,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,105,5,3,,2019-08-14T11:57:41Z
|
||||
39144,1565785318762,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,106,0,3,,2019-08-14T12:21:58Z
|
||||
39145,1565785319346,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,107,5,3,,2019-08-14T12:21:59Z
|
||||
39146,1565960121019,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,108,5,3,,2019-08-16T12:55:21Z
|
||||
39147,1565960226792,fdb06d4a-ee6e-4336-9a96-fc8d2715f243,109,5,3,,2019-08-16T12:57:06Z
|
|
|
@ -0,0 +1,175 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.11.4
|
||||
# kernelspec:
|
||||
# display_name: straw2analysis
|
||||
# language: python
|
||||
# name: straw2analysis
|
||||
# ---
|
||||
|
||||
# %%
|
||||
# %matplotlib inline
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
import seaborn as sns
|
||||
from sklearn import linear_model
|
||||
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||
|
||||
nb_dir = os.path.split(os.getcwd())[0]
|
||||
if nb_dir not in sys.path:
|
||||
sys.path.append(nb_dir)
|
||||
|
||||
# %%
|
||||
import participants.query_db
|
||||
from features import esm, helper, proximity
|
||||
|
||||
# %% [markdown]
|
||||
# # 1. Get the relevant data
|
||||
|
||||
# %%
|
||||
participants_inactive_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
# Consider only two participants to simplify.
|
||||
ptcp_2 = participants_inactive_usernames[0:2]
|
||||
|
||||
# %% [markdown]
|
||||
# ## 1.1 Labels
|
||||
|
||||
# %%
|
||||
df_esm = esm.get_esm_data(ptcp_2)
|
||||
df_esm_preprocessed = esm.preprocess_esm(df_esm)
|
||||
|
||||
# %%
|
||||
df_esm_PANAS = df_esm_preprocessed[
|
||||
(df_esm_preprocessed["questionnaire_id"] == 8)
|
||||
| (df_esm_preprocessed["questionnaire_id"] == 9)
|
||||
]
|
||||
df_esm_PANAS_clean = esm.clean_up_esm(df_esm_PANAS)
|
||||
|
||||
# %% [markdown]
|
||||
# ## 1.2 Sensor data
|
||||
|
||||
# %%
|
||||
df_proximity = proximity.get_proximity_data(ptcp_2)
|
||||
df_proximity = helper.get_date_from_timestamp(df_proximity)
|
||||
df_proximity = proximity.recode_proximity(df_proximity)
|
||||
|
||||
# %% [markdown]
|
||||
# ## 1.3 Standardization/personalization
|
||||
|
||||
# %% [markdown]
|
||||
# # 2. Grouping/segmentation
|
||||
|
||||
# %%
|
||||
df_esm_PANAS_daily_means = (
|
||||
df_esm_PANAS_clean.groupby(["participant_id", "date_lj", "questionnaire_id"])
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
|
||||
# %%
|
||||
df_esm_PANAS_daily_means = (
|
||||
df_esm_PANAS_daily_means.pivot(
|
||||
index=["participant_id", "date_lj"],
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns={8.0: "PA", 9.0: "NA"})
|
||||
.set_index(["participant_id", "date_lj"])
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
df_proximity_daily_counts = proximity.count_proximity(
|
||||
df_proximity, ["participant_id", "date_lj"]
|
||||
)
|
||||
|
||||
# %%
|
||||
df_proximity_daily_counts
|
||||
|
||||
# %% [markdown]
|
||||
# # 3. Join features (and export to csv?)
|
||||
|
||||
# %%
|
||||
df_full_data_daily_means = df_esm_PANAS_daily_means.join(
|
||||
df_proximity_daily_counts
|
||||
).reset_index()
|
||||
|
||||
# %% [markdown]
|
||||
# # 4. Machine learning model and parameters
|
||||
|
||||
# %%
|
||||
lin_reg_proximity = linear_model.LinearRegression()
|
||||
|
||||
# %% [markdown]
|
||||
# ## 4.1 Validation method
|
||||
|
||||
# %%
|
||||
logo = LeaveOneGroupOut()
|
||||
logo.get_n_splits(
|
||||
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||
df_full_data_daily_means["PA"],
|
||||
groups=df_full_data_daily_means["participant_id"],
|
||||
)
|
||||
|
||||
# %% [markdown]
|
||||
# ## 4.2 Fit results (export?)
|
||||
|
||||
# %%
|
||||
cross_val_score(
|
||||
lin_reg_proximity,
|
||||
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||
df_full_data_daily_means["PA"],
|
||||
groups=df_full_data_daily_means["participant_id"],
|
||||
cv=logo,
|
||||
n_jobs=-1,
|
||||
scoring="r2",
|
||||
)
|
||||
|
||||
# %%
|
||||
lin_reg_proximity.fit(
|
||||
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||
df_full_data_daily_means["PA"],
|
||||
)
|
||||
|
||||
# %%
|
||||
lin_reg_proximity.score(
|
||||
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||
df_full_data_daily_means["PA"],
|
||||
)
|
||||
|
||||
# %% [markdown]
|
||||
# # Merging these into a pipeline
|
||||
|
||||
# %%
|
||||
from machine_learning import pipeline
|
||||
|
||||
# %%
|
||||
ml_pipeline = pipeline.MachineLearningPipeline(
|
||||
labels_questionnaire="PANAS", data_types="proximity"
|
||||
)
|
||||
|
||||
# %%
|
||||
ml_pipeline.get_labels()
|
||||
|
||||
# %% tags=[]
|
||||
ml_pipeline.get_sensor_data()
|
||||
|
||||
# %%
|
||||
ml_pipeline.aggregate_daily()
|
||||
|
||||
# %%
|
||||
ml_pipeline.df_full_data_daily_means
|
||||
|
||||
# %%
|
|
@ -0,0 +1,76 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.11.4
|
||||
# kernelspec:
|
||||
# display_name: straw2analysis
|
||||
# language: python
|
||||
# name: straw2analysis
|
||||
# ---
|
||||
|
||||
# %%
|
||||
# %matplotlib inline
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
nb_dir = os.path.split(os.getcwd())[0]
|
||||
if nb_dir not in sys.path:
|
||||
sys.path.append(nb_dir)
|
||||
|
||||
# %%
|
||||
from config.models import AppCategories, Participant
|
||||
from setup import db_engine, session
|
||||
|
||||
# %%
|
||||
query_app_categories = session.query(AppCategories)
|
||||
with db_engine.connect() as connection:
|
||||
df_app_categories = pd.read_sql(query_app_categories.statement, connection)
|
||||
|
||||
# %%
|
||||
df_app_categories.head()
|
||||
|
||||
# %%
|
||||
df_app_categories["play_store_genre"].value_counts()
|
||||
|
||||
# %%
|
||||
df_category_not_found = df_app_categories[
|
||||
df_app_categories["play_store_genre"] == "not_found"
|
||||
]
|
||||
|
||||
# %%
|
||||
df_category_not_found["play_store_response"].value_counts()
|
||||
|
||||
# %%
|
||||
df_category_not_found["package_name"].value_counts()
|
||||
|
||||
# %%
|
||||
manufacturers = [
|
||||
"samsung",
|
||||
"oneplus",
|
||||
"huawei",
|
||||
"xiaomi",
|
||||
"lge",
|
||||
"motorola",
|
||||
"miui",
|
||||
"lenovo",
|
||||
"oppo",
|
||||
"mediatek",
|
||||
]
|
||||
custom_rom = ["coloros", "lineageos", "myos", "cyanogenmod", "foundation.e"]
|
||||
other = ["android", "wssyncmldm"]
|
||||
rows_os_manufacturer = df_category_not_found["package_name"].str.contains(
|
||||
"|".join(manufacturers + custom_rom + other), case=False
|
||||
)
|
||||
|
||||
# %%
|
||||
with pd.option_context("display.max_rows", None, "display.max_columns", None):
|
||||
display(df_category_not_found.loc[~rows_os_manufacturer])
|
|
@ -13,14 +13,15 @@
|
|||
# name: straw2analysis
|
||||
# ---
|
||||
|
||||
# %%
|
||||
import importlib
|
||||
|
||||
# %%
|
||||
# %matplotlib inline
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# %%
|
||||
import seaborn as sns
|
||||
|
||||
nb_dir = os.path.split(os.getcwd())[0]
|
||||
|
@ -28,21 +29,29 @@ if nb_dir not in sys.path:
|
|||
sys.path.append(nb_dir)
|
||||
|
||||
# %%
|
||||
from features.communication import *
|
||||
from features import communication, helper
|
||||
|
||||
# %%
|
||||
importlib.reload(communication)
|
||||
|
||||
# %% [markdown]
|
||||
# # Example of communication data and feature calculation
|
||||
|
||||
# %%
|
||||
df_calls = get_call_data(["nokia_0000003"])
|
||||
df_calls = communication.get_call_data(["nokia_0000003"])
|
||||
print(df_calls)
|
||||
|
||||
# %%
|
||||
count_comms(df_calls)
|
||||
df_calls = helper.get_date_from_timestamp(df_calls)
|
||||
communication.count_comms(df_calls, ["date_lj"])
|
||||
|
||||
# %%
|
||||
df_sms = get_sms_data(["nokia_0000003"])
|
||||
count_comms(df_sms)
|
||||
df_sms = communication.get_sms_data(["nokia_0000003"])
|
||||
df_sms = helper.get_date_from_timestamp(df_sms)
|
||||
communication.count_comms(df_sms, ["date_lj"])
|
||||
|
||||
# %%
|
||||
communication.calls_sms_features(df_calls, df_sms, ["date_lj"])
|
||||
|
||||
# %% [markdown]
|
||||
# # Call data
|
||||
|
|
|
@ -8,6 +8,43 @@ from setup import db_engine, session
|
|||
call_types = {1: "incoming", 2: "outgoing", 3: "missed"}
|
||||
sms_types = {1: "received", 2: "sent"}
|
||||
|
||||
FEATURES_CALLS = (
|
||||
["no_calls_all"]
|
||||
+ ["no_" + call_type for call_type in call_types.values()]
|
||||
+ ["duration_total_" + call_types.get(1), "duration_total_" + call_types.get(2)]
|
||||
+ ["duration_max_" + call_types.get(1), "duration_max_" + call_types.get(2)]
|
||||
+ ["no_" + call_types.get(1) + "_ratio", "no_" + call_types.get(2) + "_ratio"]
|
||||
+ ["no_contacts_calls"]
|
||||
)
|
||||
|
||||
# FEATURES_CALLS =
|
||||
# ["no_calls_all",
|
||||
# "no_incoming", "no_outgoing", "no_missed",
|
||||
# "duration_total_incoming", "duration_total_outgoing",
|
||||
# "duration_max_incoming", "duration_max_outgoing",
|
||||
# "no_incoming_ratio", "no_outgoing_ratio",
|
||||
# "no_contacts"]
|
||||
|
||||
FEATURES_SMS = (
|
||||
["no_sms_all"]
|
||||
+ ["no_" + sms_type for sms_type in sms_types.values()]
|
||||
+ ["no_" + sms_types.get(1) + "_ratio", "no_" + sms_types.get(2) + "_ratio"]
|
||||
+ ["no_contacts_sms"]
|
||||
)
|
||||
# FEATURES_SMS =
|
||||
# ["no_sms_all",
|
||||
# "no_received", "no_sent",
|
||||
# "no_received_ratio", "no_sent_ratio",
|
||||
# "no_contacts"]
|
||||
|
||||
FEATURES_CONTACT = [
|
||||
"proportion_calls_all",
|
||||
"proportion_calls_incoming",
|
||||
"proportion_calls_outgoing",
|
||||
"proportion_calls_contacts",
|
||||
"proportion_calls_missed_sms_received",
|
||||
]
|
||||
|
||||
|
||||
def get_call_data(usernames: Collection) -> pd.DataFrame:
|
||||
"""
|
||||
|
@ -98,7 +135,7 @@ def enumerate_contacts(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
return comm_df
|
||||
|
||||
|
||||
def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||
def count_comms(comm_df: pd.DataFrame, group_by=None) -> pd.DataFrame:
|
||||
"""
|
||||
Calculate frequencies (and duration) of messages (or calls), grouped by their types.
|
||||
|
||||
|
@ -106,6 +143,9 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
----------
|
||||
comm_df: pd.DataFrame
|
||||
A dataframe of calls or SMSes.
|
||||
group_by: list
|
||||
A list of strings, specifying by which parameters to group.
|
||||
By default, the features are calculated per participant, but could be "date_lj" etc.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -118,51 +158,42 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
* the number of messages by type (received, sent), and
|
||||
* the number of communication contacts by type.
|
||||
"""
|
||||
if group_by is None:
|
||||
group_by = []
|
||||
if "call_type" in comm_df:
|
||||
data_type = "calls"
|
||||
comm_counts = (
|
||||
comm_df.value_counts(subset=["participant_id", "call_type"])
|
||||
.unstack()
|
||||
comm_df.value_counts(subset=group_by + ["participant_id", "call_type"])
|
||||
.unstack(level="call_type", fill_value=0)
|
||||
.rename(columns=call_types)
|
||||
.add_prefix("no_")
|
||||
)
|
||||
# Count calls by type.
|
||||
comm_counts["no_all"] = comm_counts.sum(axis=1)
|
||||
comm_counts["no_calls_all"] = comm_counts.sum(axis=1)
|
||||
# Add a total count of calls.
|
||||
comm_counts = comm_counts.assign(
|
||||
no_incoming_ratio=lambda x: x.no_incoming / x.no_all,
|
||||
no_outgoing_ratio=lambda x: x.no_outgoing / x.no_all,
|
||||
no_incoming_ratio=lambda x: x.no_incoming / x.no_calls_all,
|
||||
no_outgoing_ratio=lambda x: x.no_outgoing / x.no_calls_all,
|
||||
)
|
||||
# Ratio of incoming and outgoing calls to all calls.
|
||||
comm_duration_total = (
|
||||
comm_df.groupby(["participant_id", "call_type"])
|
||||
comm_df.groupby(group_by + ["participant_id", "call_type"])
|
||||
.sum()["call_duration"]
|
||||
.unstack()
|
||||
.unstack(level="call_type", fill_value=0)
|
||||
.rename(columns=call_types)
|
||||
.add_prefix("duration_total_")
|
||||
)
|
||||
# Total call duration by type.
|
||||
comm_duration_max = (
|
||||
comm_df.groupby(["participant_id", "call_type"])
|
||||
comm_df.groupby(group_by + ["participant_id", "call_type"])
|
||||
.max()["call_duration"]
|
||||
.unstack()
|
||||
.unstack(level="call_type", fill_value=0)
|
||||
.rename(columns=call_types)
|
||||
.add_prefix("duration_max_")
|
||||
)
|
||||
# Max call duration by type
|
||||
comm_contacts_counts = (
|
||||
enumerate_contacts(comm_df)
|
||||
.groupby(["participant_id"])
|
||||
.nunique()["contact_id"]
|
||||
.reset_index()
|
||||
.rename(columns={"contact_id": "no_contacts"})
|
||||
)
|
||||
# Number of communication contacts
|
||||
comm_features = comm_counts.join(comm_duration_total)
|
||||
comm_features = comm_features.join(comm_duration_max)
|
||||
comm_features = comm_features.merge(
|
||||
comm_contacts_counts,
|
||||
on="participant_id"
|
||||
).set_index("participant_id")
|
||||
try:
|
||||
comm_features.drop(columns="duration_total_" + call_types[3], inplace=True)
|
||||
comm_features.drop(columns="duration_max_" + call_types[3], inplace=True)
|
||||
|
@ -172,33 +203,30 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
# If there were no missed calls, this exception is raised.
|
||||
# But we are dropping the column anyway, so no need to deal with the exception.
|
||||
elif "message_type" in comm_df:
|
||||
data_type = "sms"
|
||||
comm_counts = (
|
||||
comm_df.value_counts(subset=["participant_id", "message_type"])
|
||||
.unstack()
|
||||
comm_df.value_counts(subset=group_by + ["participant_id", "message_type"])
|
||||
.unstack(level="message_type", fill_value=0)
|
||||
.rename(columns=sms_types)
|
||||
.add_prefix("no_")
|
||||
)
|
||||
comm_counts["no_all"] = comm_counts.sum(axis=1)
|
||||
comm_counts["no_sms_all"] = comm_counts.sum(axis=1)
|
||||
# Add a total count of messages.
|
||||
comm_features = comm_counts.assign(
|
||||
no_received_ratio=lambda x: x.no_received / x.no_all,
|
||||
no_sent_ratio=lambda x: x.no_sent / x.no_all,
|
||||
no_received_ratio=lambda x: x.no_received / x.no_sms_all,
|
||||
no_sent_ratio=lambda x: x.no_sent / x.no_sms_all,
|
||||
)
|
||||
# Ratio of incoming and outgoing messages to all messages.
|
||||
comm_contacts_counts = (
|
||||
enumerate_contacts(comm_df)
|
||||
.groupby(["participant_id"])
|
||||
.nunique()["contact_id"]
|
||||
.reset_index()
|
||||
.rename(columns={"contact_id": "no_contacts"})
|
||||
)
|
||||
# Number of communication contacts
|
||||
comm_features = comm_features.merge(
|
||||
comm_contacts_counts,
|
||||
on="participant_id"
|
||||
).set_index("participant_id")
|
||||
else:
|
||||
raise KeyError("The dataframe contains neither call_type or message_type")
|
||||
comm_contacts_counts = (
|
||||
enumerate_contacts(comm_df)
|
||||
.groupby(group_by + ["participant_id"])
|
||||
.nunique()["contact_id"]
|
||||
.rename("no_contacts_" + data_type)
|
||||
)
|
||||
# Number of communication contacts
|
||||
comm_features = comm_features.join(comm_contacts_counts)
|
||||
return comm_features
|
||||
|
||||
|
||||
|
@ -211,7 +239,7 @@ def contact_features(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
|
||||
Parameters
|
||||
----------
|
||||
df_enumerated: pd.DataFrame
|
||||
comm_df: pd.DataFrame
|
||||
A dataframe of calls or SMSes.
|
||||
|
||||
Returns
|
||||
|
@ -221,33 +249,33 @@ def contact_features(comm_df: pd.DataFrame) -> pd.DataFrame:
|
|||
"""
|
||||
df_enumerated = enumerate_contacts(comm_df)
|
||||
contacts_count = (
|
||||
df_enumerated
|
||||
.groupby(["participant_id","contact_id"])
|
||||
.size()
|
||||
.reset_index()
|
||||
df_enumerated.groupby(["participant_id", "contact_id"]).size().reset_index()
|
||||
)
|
||||
# Check whether df contains calls or SMS data since some
|
||||
# features we want to calculate are type-specyfic
|
||||
# features we want to calculate are type-specific
|
||||
if "call_duration" in df_enumerated:
|
||||
# Add a column with the total duration of calls between two people
|
||||
duration_count = (
|
||||
df_enumerated
|
||||
.groupby(["participant_id", "contact_id"])
|
||||
df_enumerated.groupby(["participant_id", "contact_id"])
|
||||
# For each participant and for each caller, sum durations of their calls
|
||||
["call_duration"]
|
||||
.sum()
|
||||
.reset_index() # Make index (which is actually the participant id) a normal column
|
||||
.rename(columns={"call_duration": "total_call_duration"})
|
||||
)
|
||||
contacts_count = contacts_count.merge(duration_count, on=["participant_id", "contact_id"])
|
||||
contacts_count.rename(columns={0:"no_calls"}, inplace=True)
|
||||
contacts_count = contacts_count.merge(
|
||||
duration_count, on=["participant_id", "contact_id"]
|
||||
)
|
||||
contacts_count.rename(columns={0: "no_calls"}, inplace=True)
|
||||
else:
|
||||
contacts_count.rename(columns={0:"no_sms"}, inplace=True)
|
||||
contacts_count.rename(columns={0: "no_sms"}, inplace=True)
|
||||
# TODO:Determine work vs non-work contacts by work hours heuristics
|
||||
return contacts_count
|
||||
|
||||
|
||||
def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataFrame:
|
||||
def calls_sms_features(
|
||||
df_calls: pd.DataFrame, df_sms: pd.DataFrame, group_by=None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Calculates additional features relating calls and sms data.
|
||||
|
||||
|
@ -257,13 +285,16 @@ def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataF
|
|||
A dataframe of calls (return of get_call_data).
|
||||
df_sms: pd.DataFrame
|
||||
A dataframe of SMSes (return of get_sms_data).
|
||||
group_by: list
|
||||
A list of strings, specifying by which parameters to group.
|
||||
By default, the features are calculated per participant, but could be "date_lj" etc.
|
||||
|
||||
Returns
|
||||
-------
|
||||
df_calls_sms: pd.DataFrame
|
||||
The list of features relating calls and sms data for every participant.
|
||||
These are:
|
||||
* proportion_calls:
|
||||
* proportion_calls_all:
|
||||
proportion of calls in total number of communications
|
||||
* proportion_calls_incoming:
|
||||
proportion of incoming calls in total number of incoming/received communications
|
||||
|
@ -274,16 +305,22 @@ def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataF
|
|||
* proportion_calls_contacts:
|
||||
proportion of calls contacts in total number of communication contacts
|
||||
"""
|
||||
count_calls = count_comms(df_calls)
|
||||
count_sms = count_comms(df_sms)
|
||||
if group_by is None:
|
||||
group_by = []
|
||||
count_calls = count_comms(df_calls, group_by)
|
||||
count_sms = count_comms(df_sms, group_by)
|
||||
count_joined = (
|
||||
count_calls.merge(
|
||||
count_sms, on="participant_id", suffixes=("_calls", "_sms")
|
||||
) # Merge calls and sms features
|
||||
.reset_index() # Make participant_id a regular column
|
||||
count_sms,
|
||||
how="outer",
|
||||
left_index=True,
|
||||
right_index=True,
|
||||
validate="one_to_one",
|
||||
)
|
||||
.fillna(0, downcast="infer")
|
||||
.assign(
|
||||
proportion_calls=(
|
||||
lambda x: x.no_all_calls / (x.no_all_calls + x.no_all_sms)
|
||||
proportion_calls_all=(
|
||||
lambda x: x.no_calls_all / (x.no_calls_all + x.no_sms_all)
|
||||
),
|
||||
proportion_calls_incoming=(
|
||||
lambda x: x.no_incoming / (x.no_incoming + x.no_received)
|
||||
|
@ -295,18 +332,11 @@ def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataF
|
|||
lambda x: x.no_outgoing / (x.no_outgoing + x.no_sent)
|
||||
),
|
||||
proportion_calls_contacts=(
|
||||
lambda x: x.no_contacts_calls / (x.no_contacts_calls + x.no_contacts_sms)
|
||||
lambda x: x.no_contacts_calls
|
||||
/ (x.no_contacts_calls + x.no_contacts_sms)
|
||||
)
|
||||
# Calculate new features and create additional columns
|
||||
)[
|
||||
[
|
||||
"participant_id",
|
||||
"proportion_calls",
|
||||
"proportion_calls_incoming",
|
||||
"proportion_calls_outgoing",
|
||||
"proportion_calls_contacts",
|
||||
"proportion_calls_missed_sms_received",
|
||||
]
|
||||
] # Filter out only the relevant features
|
||||
)
|
||||
.fillna(0.5, downcast="infer")
|
||||
)
|
||||
return count_joined
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
import datetime
|
||||
from collections.abc import Collection
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pytz import timezone
|
||||
|
||||
from config.models import ESM, Participant
|
||||
from features import helper
|
||||
from setup import db_engine, session
|
||||
|
||||
TZ_LJ = timezone("Europe/Ljubljana")
|
||||
ESM_STATUS_ANSWERED = 2
|
||||
|
||||
GROUP_SESSIONS_BY = ["participant_id", "device_id", "esm_session"]
|
||||
|
@ -67,14 +65,8 @@ def preprocess_esm(df_esm: pd.DataFrame) -> pd.DataFrame:
|
|||
df_esm_preprocessed: pd.DataFrame
|
||||
A dataframe with added columns: datetime in Ljubljana timezone and all fields from ESM_JSON column.
|
||||
"""
|
||||
df_esm["datetime_lj"] = df_esm["double_esm_user_answer_timestamp"].apply(
|
||||
lambda x: datetime.datetime.fromtimestamp(x / 1000.0, tz=TZ_LJ)
|
||||
)
|
||||
df_esm = df_esm.assign(
|
||||
date_lj=lambda x: (x.datetime_lj - datetime.timedelta(hours=4)).dt.date
|
||||
)
|
||||
# Since daytime EMAs could *theoretically* last beyond midnight, but never after 4 AM,
|
||||
# the datetime is first translated to 4 h earlier.
|
||||
df_esm = helper.get_date_from_timestamp(df_esm)
|
||||
|
||||
df_esm_json = pd.json_normalize(df_esm["esm_json"]).drop(
|
||||
columns=["esm_trigger"]
|
||||
) # The esm_trigger column is already present in the main df.
|
||||
|
@ -256,9 +248,9 @@ def clean_up_esm(df_esm_preprocessed: pd.DataFrame) -> pd.DataFrame:
|
|||
ESM.ESM_TYPE.get("scale"),
|
||||
ESM.ESM_TYPE.get("number"),
|
||||
]
|
||||
df_esm_clean[df_esm_clean["esm_type"].isin(esm_type_numeric)] = df_esm_clean[
|
||||
df_esm_clean.loc[
|
||||
df_esm_clean["esm_type"].isin(esm_type_numeric)
|
||||
].assign(
|
||||
] = df_esm_clean.loc[df_esm_clean["esm_type"].isin(esm_type_numeric)].assign(
|
||||
esm_user_answer_numeric=lambda x: x.esm_user_answer.str.slice(stop=1).astype(
|
||||
int
|
||||
)
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
import datetime
|
||||
|
||||
import pandas as pd
|
||||
from pytz import timezone
|
||||
|
||||
TZ_LJ = timezone("Europe/Ljubljana")
|
||||
COLUMN_TIMESTAMP = "timestamp"
|
||||
COLUMN_TIMESTAMP_ESM = "double_esm_user_answer_timestamp"
|
||||
|
||||
|
||||
def get_date_from_timestamp(df_aware) -> pd.DataFrame:
|
||||
"""
|
||||
Transform a UNIX timestamp into a datetime (with Ljubljana timezone).
|
||||
Additionally, extract only the date part, where anything until 4 AM is considered the same day.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_aware: pd.DataFrame
|
||||
Any AWARE-type data as defined in models.py.
|
||||
|
||||
Returns
|
||||
-------
|
||||
df_aware: pd.DataFrame
|
||||
The same dataframe with datetime_lj and date_lj columns added.
|
||||
|
||||
"""
|
||||
if COLUMN_TIMESTAMP_ESM in df_aware:
|
||||
column_timestamp = COLUMN_TIMESTAMP_ESM
|
||||
else:
|
||||
column_timestamp = COLUMN_TIMESTAMP
|
||||
|
||||
df_aware["datetime_lj"] = df_aware[column_timestamp].apply(
|
||||
lambda x: datetime.datetime.fromtimestamp(x / 1000.0, tz=TZ_LJ)
|
||||
)
|
||||
df_aware = df_aware.assign(
|
||||
date_lj=lambda x: (x.datetime_lj - datetime.timedelta(hours=4)).dt.date
|
||||
)
|
||||
# Since daytime EMAs could *theoretically* last beyond midnight, but never after 4 AM,
|
||||
# the datetime is first translated to 4 h earlier.
|
||||
|
||||
return df_aware
|
|
@ -5,6 +5,8 @@ import pandas as pd
|
|||
from config.models import Participant, Proximity
|
||||
from setup import db_engine, session
|
||||
|
||||
FEATURES_PROXIMITY = ["freq_prox_near", "prop_prox_near"]
|
||||
|
||||
|
||||
def get_proximity_data(usernames: Collection) -> pd.DataFrame:
|
||||
"""
|
||||
|
@ -28,3 +30,65 @@ def get_proximity_data(usernames: Collection) -> pd.DataFrame:
|
|||
with db_engine.connect() as connection:
|
||||
df_proximity = pd.read_sql(query_proximity.statement, connection)
|
||||
return df_proximity
|
||||
|
||||
|
||||
def recode_proximity(df_proximity: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
This function recodes proximity from a double to a boolean value.
|
||||
Different proximity sensors report different values,
|
||||
but in our data only several distinct values have ever been found.
|
||||
These are therefore converted into "near" and "far" binary values.
|
||||
See expl_proximity.ipynb for additional info.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_proximity: pd.DataFrame
|
||||
A dataframe of proximity data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
df_proximity: pd.DataFrame
|
||||
The same dataframe with an additional column bool_prox_near,
|
||||
indicating whether "near" proximity was reported.
|
||||
False values correspond to "far" reported by this sensor.
|
||||
|
||||
"""
|
||||
df_proximity = df_proximity.assign(bool_prox_near=lambda x: x.double_proximity == 0)
|
||||
return df_proximity
|
||||
|
||||
|
||||
def count_proximity(
|
||||
df_proximity: pd.DataFrame, group_by: Collection = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
The function counts how many times a "near" value occurs in proximity
|
||||
and calculates the proportion of this counts to all proximity values (i.e. relative count).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_proximity: pd.DataFrame
|
||||
A dataframe of proximity data.
|
||||
group_by: Collection
|
||||
A list of strings, specifying by which parameters to group.
|
||||
By default, the features are calculated per participant, but could be "date_lj" etc.
|
||||
|
||||
Returns
|
||||
-------
|
||||
df_proximity_features: pd.DataFrame
|
||||
A dataframe with the count of "near" proximity values and their relative count.
|
||||
"""
|
||||
if group_by is None:
|
||||
group_by = ["participant_id"]
|
||||
if "bool_prox_near" not in df_proximity:
|
||||
df_proximity = recode_proximity(df_proximity)
|
||||
df_proximity["bool_prox_far"] = ~df_proximity["bool_prox_near"]
|
||||
df_proximity_features = df_proximity.groupby(group_by).sum()[
|
||||
["bool_prox_near", "bool_prox_far"]
|
||||
]
|
||||
df_proximity_features = df_proximity_features.assign(
|
||||
prop_prox_near=lambda x: x.bool_prox_near / (x.bool_prox_near + x.bool_prox_far)
|
||||
)
|
||||
df_proximity_features = df_proximity_features.rename(
|
||||
columns={"bool_prox_near": "freq_prox_near"}
|
||||
).drop(columns="bool_prox_far", inplace=False)
|
||||
return df_proximity_features
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
QUESTIONNAIRE_IDS = {"PANAS": {"PA": 8.0, "NA": 9.0}}
|
||||
|
||||
QUESTIONNAIRE_IDS_RENAME = {}
|
||||
|
||||
for questionnaire in QUESTIONNAIRE_IDS.items():
|
||||
for k, v in questionnaire[1].items():
|
||||
QUESTIONNAIRE_IDS_RENAME[v] = k
|
|
@ -0,0 +1,125 @@
|
|||
import datetime
|
||||
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
||||
import participants.query_db
|
||||
from features import esm, helper, proximity
|
||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||
|
||||
|
||||
class MachineLearningPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
labels_questionnaire,
|
||||
labels_scale,
|
||||
data_types,
|
||||
participants_usernames=None,
|
||||
feature_names=None,
|
||||
grouping_variable=None,
|
||||
):
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
self.participants_usernames = participants_usernames
|
||||
self.labels_questionnaire = labels_questionnaire
|
||||
self.data_types = data_types
|
||||
|
||||
if feature_names is None:
|
||||
self.feature_names = []
|
||||
self.df_features = pd.DataFrame()
|
||||
self.labels_scale = labels_scale
|
||||
self.df_labels = pd.DataFrame()
|
||||
self.grouping_variable = grouping_variable
|
||||
self.df_groups = pd.DataFrame()
|
||||
|
||||
self.model = None
|
||||
self.validation_method = None
|
||||
|
||||
self.df_esm = pd.DataFrame()
|
||||
self.df_esm_preprocessed = pd.DataFrame()
|
||||
self.df_esm_interest = pd.DataFrame()
|
||||
self.df_esm_clean = pd.DataFrame()
|
||||
|
||||
self.df_proximity = pd.DataFrame()
|
||||
|
||||
self.df_full_data_daily_means = pd.DataFrame()
|
||||
self.df_esm_daily_means = pd.DataFrame()
|
||||
self.df_proximity_daily_counts = pd.DataFrame()
|
||||
|
||||
def get_labels(self):
|
||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||
if self.labels_questionnaire == "PANAS":
|
||||
self.df_esm_interest = self.df_esm_preprocessed[
|
||||
(
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
||||
)
|
||||
| (
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
||||
)
|
||||
]
|
||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||
|
||||
def get_sensor_data(self):
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity = proximity.get_proximity_data(
|
||||
self.participants_usernames
|
||||
)
|
||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||
|
||||
def aggregate_daily(self):
|
||||
self.df_esm_daily_means = (
|
||||
self.df_esm_clean.groupby(["participant_id", "date_lj", "questionnaire_id"])
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
self.df_esm_daily_means = (
|
||||
self.df_esm_daily_means.pivot(
|
||||
index=["participant_id", "date_lj"],
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id", "date_lj"])
|
||||
)
|
||||
self.df_full_data_daily_means = self.df_esm_daily_means.copy()
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity_daily_counts = proximity.count_proximity(
|
||||
self.df_proximity, ["participant_id", "date_lj"]
|
||||
)
|
||||
self.df_full_data_daily_means = self.df_full_data_daily_means.join(
|
||||
self.df_proximity_daily_counts
|
||||
)
|
||||
|
||||
def assign_columns(self):
|
||||
self.df_features = self.df_full_data_daily_means[self.feature_names]
|
||||
self.df_labels = self.df_full_data_daily_means[self.labels_scale]
|
||||
if self.grouping_variable:
|
||||
self.df_groups = self.df_full_data_daily_means[self.grouping_variable]
|
||||
else:
|
||||
self.df_groups = None
|
||||
|
||||
def validate_model(self):
|
||||
if self.model is None:
|
||||
raise AttributeError(
|
||||
"Please, specify a machine learning model first, by setting the .model attribute."
|
||||
)
|
||||
if self.validation_method is None:
|
||||
raise AttributeError(
|
||||
"Please, specify a cross validation method first, by setting the .validation_method attribute."
|
||||
)
|
||||
cross_val_score(
|
||||
estimator=self.model,
|
||||
X=self.df_features,
|
||||
y=self.df_labels,
|
||||
groups=self.df_groups,
|
||||
cv=self.validation_method,
|
||||
n_jobs=-1,
|
||||
)
|
|
@ -5,7 +5,7 @@ import pandas as pd
|
|||
from numpy.random import default_rng
|
||||
from pandas.testing import assert_series_equal
|
||||
|
||||
from features.communication import count_comms, enumerate_contacts, get_call_data
|
||||
from features.communication import *
|
||||
|
||||
rng = default_rng()
|
||||
|
||||
|
@ -76,10 +76,18 @@ class CallsFeatures(unittest.TestCase):
|
|||
|
||||
def test_count_comms_calls(self):
|
||||
self.features = count_comms(self.calls)
|
||||
print(self.features)
|
||||
self.assertIsInstance(self.features, pd.DataFrame)
|
||||
self.assertCountEqual(self.features.columns.to_list(), FEATURES_CALLS)
|
||||
|
||||
def test_count_comms_sms(self):
|
||||
self.features = count_comms(self.sms)
|
||||
print(self.features)
|
||||
self.assertIsInstance(self.features, pd.DataFrame)
|
||||
self.assertCountEqual(self.features.columns.to_list(), FEATURES_SMS)
|
||||
|
||||
def test_calls_sms_features(self):
|
||||
self.features_call_sms = calls_sms_features(self.calls, self.sms)
|
||||
self.assertIsInstance(self.features_call_sms, pd.DataFrame)
|
||||
self.assertCountEqual(
|
||||
self.features_call_sms.columns.to_list(),
|
||||
FEATURES_CALLS + FEATURES_SMS + FEATURES_CONTACT,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import unittest
|
||||
|
||||
from features.proximity import *
|
||||
|
||||
|
||||
class ProximityFeatures(unittest.TestCase):
|
||||
df_proximity = pd.DataFrame()
|
||||
df_proximity_recoded = pd.DataFrame()
|
||||
df_proximity_features = pd.DataFrame()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.df_proximity = pd.read_csv("../data/example_proximity.csv")
|
||||
cls.df_proximity["participant_id"] = 99
|
||||
|
||||
def test_recode_proximity(self):
|
||||
self.df_proximity_recoded = recode_proximity(self.df_proximity)
|
||||
self.assertIn("bool_prox_near", self.df_proximity_recoded)
|
||||
# Is the recoded column present?
|
||||
self.assertIn(True, self.df_proximity_recoded.bool_prox_near)
|
||||
# Are there "near" values in the data?
|
||||
self.assertIn(False, self.df_proximity_recoded.bool_prox_near)
|
||||
# Are there "far" values in the data?
|
||||
|
||||
def test_count_proximity(self):
|
||||
self.df_proximity_recoded = recode_proximity(self.df_proximity)
|
||||
self.df_proximity_features = count_proximity(self.df_proximity_recoded)
|
||||
print(self.df_proximity_features.columns)
|
||||
self.assertCountEqual(
|
||||
self.df_proximity_features.columns.to_list(), FEATURES_PROXIMITY
|
||||
)
|
Loading…
Reference in New Issue