diff --git a/src/features/phone_esm/straw/process_user_event_related_segments.py b/src/features/phone_esm/straw/process_user_event_related_segments.py index 37a48897..788d99ec 100644 --- a/src/features/phone_esm/straw/process_user_event_related_segments.py +++ b/src/features/phone_esm/straw/process_user_event_related_segments.py @@ -72,6 +72,7 @@ def extract_ers(esm_df, device_id): elif targets_method == "stress_event": # Get and join required data extracted_ers = esm_df.groupby(["device_id", "esm_session"])['timestamp'].apply(lambda x: math.ceil((x.max() - x.min()) / 1000)).reset_index().rename(columns={'timestamp': 'session_length'}) # questionnaire end timestamp + extracted_ers = extracted_ers[extracted_ers["session_length"] <= 15 * 60].reset_index(drop=True) # ensure that the longest duration of the questionnaire anwsering is 15 min session_end_timestamp = esm_df.groupby(['device_id', 'esm_session'])['timestamp'].max().to_frame().rename(columns={'timestamp': 'session_end_timestamp'}) # questionnaire end timestamp se_time = esm_df[esm_df.questionnaire_id == 90.].set_index(['device_id', 'esm_session'])['esm_user_answer'].to_frame().rename(columns={'esm_user_answer': 'se_time'}) se_duration = esm_df[esm_df.questionnaire_id == 91.].set_index(['device_id', 'esm_session'])['esm_user_answer'].to_frame().rename(columns={'esm_user_answer': 'se_duration'}) @@ -87,9 +88,8 @@ def extract_ers(esm_df, device_id): # Transform data into its final form, ready for the extraction extracted_ers.reset_index(inplace=True) - extracted_ers["label"] = f"straw_event_{targets_method}_" + snakemake.params["pid"] + "_" + extracted_ers.index.astype(str).str.zfill(3) - time_before_event = 10 * 60 # in seconds (10 minutes) + time_before_event = 5 * 60 # in seconds (5 minutes) extracted_ers['event_timestamp'] = pd.to_datetime(extracted_ers['se_time']).apply(lambda x: x.timestamp() * 1000).astype('int64') extracted_ers['shift_direction'] = -1 @@ -103,6 +103,9 @@ def extract_ers(esm_df, device_id): extracted_ers['se_duration'] = \ extracted_ers['se_duration'].apply(lambda x: math.ceil(x / 1000) if isinstance(x, int) else (pd.to_datetime(x).hour * 60 + pd.to_datetime(x).minute) * 60) + time_before_event + extracted_ers = extracted_ers[extracted_ers["se_duration"] <= 2.5 * 60 * 60].reset_index(drop=True) # Exclude events that are longer than 2.5 hours + + extracted_ers["label"] = f"straw_event_{targets_method}_" + snakemake.params["pid"] + "_" + extracted_ers.index.astype(str).str.zfill(3) extracted_ers['shift'] = format_timestamp(time_before_event) extracted_ers['length'] = extracted_ers['se_duration'].apply(lambda x: format_timestamp(x)) @@ -127,7 +130,7 @@ if snakemake.params["stage"] == "extract": extracted_ers.to_csv(snakemake.output[0], index=False) elif snakemake.params["stage"] == "merge": - + input_data_files = dict(snakemake.input) straw_events = pd.DataFrame(columns=["label", "event_timestamp", "length", "shift", "shift_direction", "device_id"]) stress_events_targets = pd.DataFrame(columns=["label", "intensity"])