Functional effects of mutations across replicates

This notebook aggregates the global epistasis fits for individual replicates to obtain the functional effects of mutations on viral entry. It analyzes both the latent and observed phenotypes from the global epistasis models.

First, import Python modules:

[1]:
import os

import altair as alt

import dms_variants.utils

import numpy

import pandas as pd

import polyclonal
import polyclonal.alphabets
import polyclonal.plot

import yaml
[2]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

Get configuration information:

[3]:
# If you are running this notebook interactively rather than in the pipeline that
# handles working directories, you may first have to `os.chdir` to the appropriate directory.

with open("config.yaml") as f:
    config = yaml.safe_load(f)
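
Because later cells index `config` directly, a missing key only surfaces as a `KeyError` partway through the notebook. The sketch below checks up front for the keys that the cells in this notebook read. The helper name `missing_config_keys` and the example values are hypothetical, and the key list is collected from this notebook rather than from any schema of the pipeline:

```python
# Keys that the cells in this notebook read from the config. The list is
# gathered from the notebook itself, not from a schema shipped with the pipeline.
REQUIRED_KEYS = [
    "site_numbering_map",
    "functional_selections",
    "globalepistasis_dir",
    "muteffects_avg_method",
    "muteffects_observed",
    "muteffects_latent",
    "muteffects_plot_kwargs",
]


def missing_config_keys(config):
    """Return the required keys absent from `config` (hypothetical helper)."""
    return [key for key in REQUIRED_KEYS if key not in config]


# Toy config with made-up values, only to show the check.
example_config = {
    "site_numbering_map": "data/site_numbering_map.csv",
    "functional_selections": "data/func_selections.csv",
    "muteffects_avg_method": "median",
}
print(missing_config_keys(example_config))
```

In the notebook itself this could be used as `assert not missing_config_keys(config)` right after loading the YAML file.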

Read the sequential-to-reference site numbering map:

[4]:
sitenumbering_map = pd.read_csv(config["site_numbering_map"])

Read the mutation effects

The functional selections data frame:

[5]:
func_selections = pd.read_csv(config["functional_selections"])

The mutation effects:

[6]:
phenotypes = ["observed", "latent"]

muteffects = pd.concat(
    [
        pd.read_csv(
            os.path.join(
                config["globalepistasis_dir"],
                f"{selection_name}_muteffects_{phenotype}.csv",
            )
        ).assign(
            selection_name=selection_name,
            phenotype=phenotype,
            times_seen=lambda x: x["times_seen"].astype("Int64"),
            mutation=lambda x: x["wildtype"]
            + x["sequential_site"].astype(str)
            + x["mutant"],
        )
        for selection_name in func_selections["selection_name"]
        for phenotype in phenotypes
    ],
    ignore_index=True,
).merge(
    func_selections,
    on="selection_name",
    how="left",
    validate="many_to_one",
)

assert len(muteffects) == len(muteffects.drop_duplicates())
assert muteffects.drop(columns="times_seen").notnull().all().all()
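
The cell above does two things that are easy to miss: it builds a `mutation` label from wildtype, site and mutant, and it uses `validate="many_to_one"` so the merge fails if a selection name is duplicated in the selections table. A minimal sketch on toy data (all names and values below are made up) shows both behaviors:

```python
import pandas as pd

# Toy per-replicate effects; all names and values here are made up.
effects = pd.DataFrame(
    {
        "selection_name": ["rep1", "rep1", "rep2"],
        "wildtype": ["M", "M", "M"],
        "sequential_site": [1, 1, 1],
        "mutant": ["A", "C", "A"],
        "effect": [-0.5, -1.2, -0.4],
    }
).assign(
    # same construction as in the cell above: wildtype + site + mutant
    mutation=lambda x: x["wildtype"] + x["sequential_site"].astype(str) + x["mutant"]
)

# One row per selection, so a many-to-one merge succeeds.
selections = pd.DataFrame({"selection_name": ["rep1", "rep2"], "library": ["A", "B"]})
merged = effects.merge(
    selections, on="selection_name", how="left", validate="many_to_one"
)
print(merged[["mutation", "selection_name", "library"]])

# A duplicated selection name on the right side violates many-to-one, so
# pandas raises MergeError rather than silently duplicating rows.
bad_selections = pd.concat([selections, selections.iloc[[0]]], ignore_index=True)
try:
    effects.merge(bad_selections, on="selection_name", validate="many_to_one")
except pd.errors.MergeError as err:
    print(f"merge rejected: {err}")
```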

Correlations among mutation effects

Correlations among replicates:

[7]:
corrs = (
    dms_variants.utils.tidy_to_corr(
        df=muteffects,
        sample_col="selection_name",
        label_col="mutation",
        value_col="effect",
        group_cols="phenotype",
    )
    .assign(r2=lambda x: x["correlation"] ** 2)
    .drop(columns="correlation")
)

for phenotype, phenotype_corr in corrs.groupby("phenotype"):
    corr_chart = (
        alt.Chart(phenotype_corr)
        .encode(
            alt.X("selection_name_1", title=None),
            alt.Y("selection_name_2", title=None),
            color=alt.Color("r2", scale=alt.Scale(zero=True)),
            tooltip=[
                alt.Tooltip(c, format=".3g") if c == "r2" else c
                for c in ["phenotype", "selection_name_1", "selection_name_2", "r2"]
            ],
        )
        .mark_rect(stroke="black")
        .properties(width=alt.Step(15), height=alt.Step(15), title=phenotype)
        .configure_axis(labelLimit=500)
    )

    display(corr_chart)
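
For intuition about what the heatmaps show, the same kind of pairwise r2 can be computed with plain pandas by pivoting the tidy data to one column per replicate. This is only a sketch on made-up values for a single phenotype, not the `dms_variants` implementation:

```python
import pandas as pd

# Toy tidy data frame of per-replicate effects; values are made up.
tidy = pd.DataFrame(
    {
        "selection_name": ["rep1"] * 3 + ["rep2"] * 3,
        "mutation": ["M1A", "M1C", "K2R"] * 2,
        "effect": [0.0, -1.0, -2.0, 0.0, -2.0, -4.0],
    }
)

# One column per replicate, one row per mutation.
wide = tidy.pivot(index="mutation", columns="selection_name", values="effect")

# Pairwise Pearson correlation between replicate columns, then squared.
r2 = wide.corr(method="pearson") ** 2
print(r2)
```

Mutations measured in only one of two replicates end up as missing values in `wide`, and `DataFrame.corr` excludes them pairwise, so each r2 is based only on the mutations the two replicates share.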

Compute average mutation effects

Compute averages for each library individually and across all replicates of both libraries. Note that the cross-library average does not weight the libraries equally: each library is weighted by its number of replicates.

[8]:
muteffects_avg_method = config["muteffects_avg_method"]
print(f"Defining the average as the {muteffects_avg_method} across replicates")
assert muteffects_avg_method in {"median", "mean"}

n_selections = muteffects["selection_name"].nunique()
assert n_selections == len(func_selections)

groupcols = ["sequential_site", "reference_site", "wildtype", "mutant", "phenotype"]
muteffects_avg = (
    muteffects.groupby(groupcols, as_index=False)
    .aggregate(
        effect=pd.NamedAgg("effect", muteffects_avg_method),
        effect_std=pd.NamedAgg("effect", "std"),
        times_seen=pd.NamedAgg("times_seen", lambda s: s.sum() / n_selections),
        n_libraries=pd.NamedAgg("library", "nunique"),
    )
    .assign(
        times_seen=lambda x: x["times_seen"].where(x["wildtype"] != x["mutant"], pd.NA)
    )
    # add per-library effects
    .merge(
        muteffects.groupby(["library", *groupcols], as_index=False)
        .aggregate(
            effect=pd.NamedAgg("effect", muteffects_avg_method),
        )
        .assign(library=lambda x: x["library"].astype(str) + " effect")
        .pivot_table(index=groupcols, columns="library", values="effect"),
        on=groupcols,
        validate="one_to_one",
        how="left",
    )
)
Defining the average as the mean across replicates
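
A small worked example makes the weighting concrete. The library names and numbers below are made up, and the mean is used as the average. Pooling replicates lets the library with more replicates dominate, whereas averaging per library first would weight the libraries equally. The last lines mirror how `times_seen` is computed above: the sum over replicates is divided by the total number of selections, so a replicate lacking the mutation effectively counts as zero:

```python
from statistics import mean

# Toy effects of one mutation; library names and values are made up.
effects_by_library = {
    "LibA": [-1.0, -1.0, -1.0],  # three replicates
    "LibB": [1.0],  # one replicate
}

# Pooling all replicates, as in the cell above: LibA dominates.
pooled = mean(e for effects in effects_by_library.values() for e in effects)

# Alternative (not used above): average per library, then across libraries.
per_library = mean(mean(effects) for effects in effects_by_library.values())

print(pooled, per_library)

# times_seen: sum over the replicates that have the mutation, divided by the
# total number of selections (here the fourth replicate lacks the mutation).
n_selections = 4
times_seen = [3, 5, 4]
avg_times_seen = sum(times_seen) / n_selections
print(avg_times_seen)
```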

Write average mutation effects to CSVs:

[9]:
for phenotype, df in muteffects_avg.groupby("phenotype"):
    outfile = config[f"muteffects_{phenotype}"]
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    print(f"Writing {phenotype}-phenotype mutation effects to {outfile}")
    df.to_csv(outfile, index=False, float_format="%.4f")
Writing latent-phenotype mutation effects to results/muteffects_functional/muteffects_latent.csv
Writing observed-phenotype mutation effects to results/muteffects_functional/muteffects_observed.csv

Plot average mutational effects

These are interactive plots. The times_seen is averaged across all replicates, and you can select how many libraries must have data for the mutation. The tooltips show library-specific values as well. Plot using the reference site numbering:

[10]:
plot_kwargs = config["muteffects_plot_kwargs"]

df_to_plot = muteffects_avg.rename(columns={"reference_site": "site"})

if "addtl_slider_stats" not in plot_kwargs:
    plot_kwargs["addtl_slider_stats"] = {}

if "times_seen" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["times_seen"] = 1

if "n_libraries" not in plot_kwargs["addtl_slider_stats"]:
    plot_kwargs["addtl_slider_stats"]["n_libraries"] = 1

if "region" in sitenumbering_map.columns:
    df_to_plot = df_to_plot.merge(
        sitenumbering_map.rename(columns={"reference_site": "site"})[["site", "region"]]
    )
    plot_kwargs["site_zoom_bar_color_col"] = "region"

if "addtl_tooltip_stats" not in plot_kwargs:
    plot_kwargs["addtl_tooltip_stats"] = []

plot_kwargs["addtl_tooltip_stats"].append("effect_std")

if any(df_to_plot["site"] != df_to_plot["sequential_site"]):
    if "sequential_site" not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append("sequential_site")

libraries = sorted(muteffects["library"].unique())
for lib in libraries:
    plot_kwargs["addtl_tooltip_stats"].append(f"{lib} effect")

for phenotype, df in df_to_plot.groupby("phenotype"):
    print(f"\n{phenotype} phenotype\n")

    plot_kwargs["plot_title"] = f"functional effects ({phenotype} phenotype)"

    chart = polyclonal.plot.lineplot_and_heatmap(
        data_df=df,
        stat_col="effect",
        category_col="phenotype",
        alphabet=polyclonal.alphabets.biochem_order_aas(
            polyclonal.AAS_WITHSTOP_WITHGAP
        ),
        sites=sitenumbering_map.sort_values("sequential_site")[
            "reference_site"
        ].tolist(),
        **plot_kwargs,
    )

    heatmapfile = (
        os.path.splitext(config[f"muteffects_{phenotype}"])[0]
        + "_heatmap_unformatted.html"
    )
    print(f"Saving to {heatmapfile}")
    chart.save(heatmapfile)

    display(chart)

latent phenotype

Saving to results/muteffects_functional/muteffects_latent_heatmap_unformatted.html

observed phenotype

Saving to results/muteffects_functional/muteffects_observed_heatmap_unformatted.html
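
The slider and tooltip defaults in the cell above are filled in with a series of `if ... not in ...` checks. As a possible simplification, the same pattern can be written with `dict.setdefault`. This is a sketch on a toy dictionary standing in for `config["muteffects_plot_kwargs"]`, with made-up values:

```python
# Toy stand-in for config["muteffects_plot_kwargs"]; the values are made up.
plot_kwargs = {"addtl_slider_stats": {"times_seen": 3}}

# setdefault returns the existing value, or inserts and returns the default.
slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
slider_stats.setdefault("times_seen", 1)  # keeps the configured value of 3
slider_stats.setdefault("n_libraries", 1)  # fills in the missing default
plot_kwargs.setdefault("addtl_tooltip_stats", [])

print(plot_kwargs)
```

Values already set in the configuration are left untouched, so the configured `times_seen` cutoff still takes precedence over the default of 1.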

Plot distributions of mutation effects

Make plots showing the distribution of mutation effects. We group amino-acid and deletion mutations as missense. You can mouse over points and use the slider to set the minimum times seen. We also draw a box showing the median and the first and third quartiles:

[11]:
def assign_mut_type(row):
    wt = row["wildtype"]
    m = row["mutant"]
    if wt == m:
        return "synonymous"
    elif m == "*":
        return "stop codon"
    else:
        return "missense"


jitter_sd = 0.12  # how much to jitter points
bar_extent = 2.5 * jitter_sd  # bars extend this much in each direction
random_seed = 1  # random number seed

dist_df = (
    muteffects_avg.query("wildtype != mutant")  # do not plot synonymous
    .query("wildtype != '*'")  # do not plot mutation at stop codon sites
    .assign(
        mutation=lambda x: (
            x["wildtype"] + x["reference_site"].astype(str) + x["mutant"]
        ),
        mut_type=lambda x: x.apply(assign_mut_type, axis=1),
    )
)

mut_types = sorted(dist_df["mut_type"].unique())

for phenotype, df in dist_df.groupby("phenotype"):
    print(f"\nChart for {phenotype=}")

    numpy.random.seed(random_seed)
    df = (
        df[["mutation", *[c for c in df.columns if c != "mutation"]]]
        .assign(
            x=lambda x: (
                x["mut_type"].map(lambda m: mut_types.index(m))
                + numpy.random.normal(0, jitter_sd, len(x)).clip(
                    min=-bar_extent,
                    max=bar_extent,
                )
            ),
            x_start=lambda x: x["mut_type"].map(
                lambda m: mut_types.index(m) - 1.3 * bar_extent
            ),
            x_end=lambda x: x["mut_type"].map(
                lambda m: mut_types.index(m) + 1.3 * bar_extent
            ),
        )
        .drop(
            columns=[
                "reference_site",
                "wildtype",
                "mutant",
                "phenotype",
                "sequential_site",
            ]
        )
    )

    # convert library-specific measurements to str, since otherwise nulls
    # are displayed as 0 in the tooltip
    for col in df.columns:
        if col.endswith(" effect"):
            df[col] = df[col].map(lambda x: f"{x:.3g}")

    # build labelExpr as here: https://github.com/vega/vega-lite/issues/7045
    labelExpr = []
    for i, mut_type in enumerate(mut_types):
        if i == len(mut_types) - 1:
            labelExpr.append(f"'{mut_type}'")
        else:
            labelExpr.append(f"datum.label == {i} ? '{mut_type}'")
    labelExpr = " : ".join(labelExpr)

    if (
        "slider_binding_range_kwargs" in plot_kwargs
        and "times_seen" in plot_kwargs["slider_binding_range_kwargs"]
    ):
        binding_range_kwargs = plot_kwargs["slider_binding_range_kwargs"]["times_seen"]
    else:
        binding_range_kwargs = {"min": 1, "max": df["times_seen"].max(), "step": 1}

    times_seen_slider = alt.selection_point(
        fields=["cutoff"],
        value=[{"cutoff": plot_kwargs["addtl_slider_stats"]["times_seen"]}],
        bind=alt.binding_range(name="minimum times seen", **binding_range_kwargs),
    )

    chart_base = (
        alt.Chart(df)
        .transform_filter(alt.datum["times_seen"] >= times_seen_slider["cutoff"])
        .transform_joinaggregate(
            effect_median="median(effect)",
            effect_q1="q1(effect)",
            effect_q3="q3(effect)",
            groupby=["mut_type"],
        )
    )

    chart_points = chart_base.encode(
        x=alt.X(
            "x",
            title="mutation type",
            scale=alt.Scale(domain=[-0.5, len(mut_types) - 0.5], nice=False),
            axis=alt.Axis(values=list(range(len(mut_types))), labelExpr=labelExpr),
        ),
        y=alt.Y("effect", title=f"functional effect ({phenotype} phenotype)"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if df[c].dtype == float else c
            for c in df.columns.tolist()
            if not c.startswith("x")
        ],
    ).mark_circle(opacity=0.15, color="black", size=15)

    chart_median = chart_base.encode(
        x=alt.X("x_start"),
        x2=alt.X2("x_end"),
        y=alt.Y("effect_median:Q"),
    ).mark_rule(color="red", strokeWidth=2)

    chart_box = chart_base.encode(
        x=alt.X("x_start"),
        x2=alt.X2("x_end"),
        y=alt.Y("effect_q1:Q"),
        y2=alt.Y2("effect_q3:Q"),
    ).mark_bar(color="red", filled=False)

    chart = (
        (chart_points + chart_median + chart_box)
        .add_params(times_seen_slider)
        .properties(height=200, width=90 * len(mut_types))
        .configure_axis(grid=False)
    )

    display(chart)

Chart for phenotype='latent'

Chart for phenotype='observed'
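
The x-axis in these charts is numeric (each mutation type sits at an integer position with jitter added), so the tick labels are mapped back to category names with a Vega expression. The loop that builds `labelExpr` above produces a nested ternary. Pulled out into a standalone function (the name `build_label_expr` is hypothetical), the logic can be sketched as:

```python
def build_label_expr(categories):
    """Map integer tick values to category names as a nested Vega ternary."""
    parts = []
    for i, category in enumerate(categories):
        if i == len(categories) - 1:
            # the last category is the final "else" branch
            parts.append(f"'{category}'")
        else:
            parts.append(f"datum.label == {i} ? '{category}'")
    return " : ".join(parts)


print(build_label_expr(["missense", "stop codon"]))
```

For the two mutation types plotted here this prints `datum.label == 0 ? 'missense' : 'stop codon'`, which is the string passed as `labelExpr` to `alt.Axis`.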