Package MILWRM

Multiplex Image Labeling With Regional Morphology

Authors: Harsimran Kaur & Cody N. Heiser

Expand source code
# -*- coding: utf-8 -*-
"""
Multiplex Image Labeling With Regional Morphology

Authors: [Harsimran Kaur](https://github.com/KSimi7) & [Cody N. Heiser](https://github.com/codyheiser)
"""
from .MILWRM import (
    mxif_labeler,
    st_labeler,
)
from .MxIF import img
from .ST import (
    blur_features_st,
    map_pixels,
    trim_image,
    assemble_pita,
    show_pita,
)

__all__ = [
    "img",
    "blur_features_st",
    "map_pixels",
    "trim_image",
    "assemble_pita",
    "show_pita",
    "mxif_labeler",
    "st_labeler",
]

from ._version import get_versions

__version__ = get_versions()["version"]
del get_versions
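
Everything listed in __all__ is re-exported at the package level, so the labelers, the img class, and the ST helpers can be imported directly from MILWRM. A minimal import sketch (illustrative only):

import MILWRM

print(MILWRM.__version__)  # version string provided by versioneer
from MILWRM import mxif_labeler, st_labeler, img          # tissue labelers and the MxIF image class
from MILWRM import map_pixels, assemble_pita, show_pita   # Visium (ST) helpers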

Sub-modules

MILWRM.MILWRM

Classes for assigning tissue domain IDs to multiplex immunofluorescence (MxIF) or 10X Visium spatial transcriptomic (ST) and histological imaging data

MILWRM.MxIF

Functions and classes for analyzing multiplex imaging data

MILWRM.ST

Functions and classes for manipulating 10X Visium spatial transcriptomic (ST) and histological imaging data

Functions

def assemble_pita(adata, features=None, use_rep=None, layer=None, plot_out=True, histo=None, verbose=True, **kwargs)

Cast feature into pixel space to construct gene expression image ("pita")

Parameters

adata : anndata.AnnData
the data
features : list of int or str
Names or indices of features to cast onto spot image. If None, cast all features. If plot_out, the first feature in the list (or feature index 0 when features is not specified) will be plotted.
use_rep : str
Key from adata.obsm to use for plotting. If None, use adata.X.
layer : str
Key from adata.layers to use for plotting. Ignored if use_rep is not None.
plot_out : bool
Show resulting image?
histo : str or None, optional (default=None)
Histology image to show along with pita in gridspec (i.e. "hires", "hires_trim", "lowres"). If None or if plot_out==False, ignore.
verbose : bool, optional (default=True)
Print updates to console
**kwargs
Arguments to pass to show_pita() function

Returns

assembled : np.array
Image of desired expression in pixel space
discrete_cols : dict or None
Dictionary mapping indices of categorical features to (number of categories, category labels) for plotting with show_pita(); None if no categorical features were found
Expand source code
def assemble_pita(
    adata,
    features=None,
    use_rep=None,
    layer=None,
    plot_out=True,
    histo=None,
    verbose=True,
    **kwargs,
):
    """
    Cast feature into pixel space to construct gene expression image ("pita")

    Parameters
    ----------
    adata : AnnData.anndata
        the data
    features : list of int or str
        Names or indices of features to cast onto spot image. If `None`, cast all
        features. If `plot_out`, first feature in list will be plotted. If not
        specified and `plot_out`, first feature (index 0) will be plotted.
    use_rep : str
        Key from `adata.obsm` to use for plotting. If `None`, use `adata.X`.
    layer :str
        Key from `adata.layers` to use for plotting. Ignored if `use_rep` is not `None`
    plot_out : bool
        Show resulting image?
    histo : str or `None`, optional (default=`None`)
        Histology image to show along with pita in gridspec (i.e. "hires",
        "hires_trim", "lowres"). If `None` or if `plot_out`==`False`, ignore.
    verbose : bool, optional (default=`True`)
        Print updates to console
    **kwargs
        Arguments to pass to `show_pita()` function

    Returns
    -------
    assembled : np.array
        Image of desired expression in pixel space
    """
    assert (
        adata.uns["pixel_map_params"] is not None
    ), "Pixel map not yet created. Run map_pixels() first."

    # coerce features to list if only single string
    if features and not isinstance(features, list):
        features = [features]

    if use_rep is None:
        # use all genes if no gene features specified
        if not features:
            features = adata.var_names  # [adata.var.highly_variable == 1].tolist()
        if layer is None:
            if verbose:
                print(
                    "Assembling pita with {} features from adata.X".format(
                        len(features)
                    )
                )
            mapper = pd.DataFrame(
                adata.X[:, [adata.var_names.get_loc(x) for x in features]],
                index=adata.obs_names,
            )
        else:
            if verbose:
                print(
                    "Assembling pita with {} features from adata.layers['{}']".format(
                        len(features), layer
                    )
                )
            mapper = pd.DataFrame(
                adata.layers[layer][:, [adata.var_names.get_loc(x) for x in features]],
                index=adata.obs_names,
            )
    elif use_rep in [".obs", "obs"]:
        assert features is not None, "Must provide feature(s) from adata.obs"
        if verbose:
            print(
                "Assembling pita with {} features from adata.obs".format(len(features))
            )
        if all(isinstance(x, int) for x in features):
            mapper = adata.obs.iloc[:, features].copy()
        else:
            mapper = adata.obs[features].copy()
    else:
        if not features:
            if verbose:
                print(
                    "Assembling pita with {} features from adata.obsm['{}']".format(
                        adata.obsm[use_rep].shape[1], use_rep
                    )
                )
            mapper = pd.DataFrame(adata.obsm[use_rep], index=adata.obs_names)
        else:
            assert all(
                isinstance(x, int) for x in features
            ), "Features must be integer indices if using rep from adata.obsm"
            if verbose:
                print(
                    "Assembling pita with {} features from adata.obsm['{}']".format(
                        len(features), use_rep
                    )
                )
            mapper = pd.DataFrame(
                adata.obsm[use_rep][:, features], index=adata.obs_names
            )

    # check for categorical columns to force into discrete plots
    discrete_cols = {}
    for col in mapper.columns:
        if pd.api.types.is_categorical_dtype(mapper[col]):
            cat_max = len(mapper[col].cat.categories)
            categories = mapper[col].cat.categories  # save original categories
            mapper[col] = mapper[col].replace(
                {v: k for k, v in dict(enumerate(mapper[col].cat.categories)).items()}
            )
            discrete_cols[mapper.columns.get_loc(col)] = (cat_max, categories)
    # if no categorical columns, pass None to discrete_cols
    if bool(discrete_cols) is False:
        discrete_cols = None

    # cast barcodes into pixel dimensions for reindexing
    if verbose:
        print(
            "Casting barcodes to pixel dimensions and saving to adata.uns['pixel_map']"
        )
    pixel_map = (
        adata.uns["pixel_map_df"].pivot(index="y", columns="x", values="barcode").values
    )

    assembled = np.array(
        [mapper.reindex(index=pixel_map[x], copy=True) for x in range(len(pixel_map))]
    ).squeeze()

    if plot_out:
        # determine where the histo image is in anndata
        if histo is not None:
            assert (
                histo
                in adata.uns["spatial"][list(adata.uns["spatial"].keys())[0]][
                    "images"
                ].keys()
            ), "Must provide one of {} for histo".format(
                adata.uns["spatial"][list(adata.uns["spatial"].keys())[0]][
                    "images"
                ].keys()
            )
            histo = adata.uns["spatial"][list(adata.uns["spatial"].keys())[0]][
                "images"
            ][histo]
        show_pita(
            pita=assembled,
            features=None,
            discrete_features=discrete_cols,
            histo=histo,
            **kwargs,
        )
    if verbose:
        print("Done!")
    return assembled, discrete_cols
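A minimal usage sketch for assemble_pita(). It assumes a 10X Visium AnnData object loaded with scanpy and that map_pixels() has already been run; the Space Ranger path and gene names below are hypothetical placeholders:

import scanpy as sc
from MILWRM.ST import map_pixels, assemble_pita

adata = sc.read_visium("path/to/spaceranger_output")  # hypothetical Visium dataset
adata = map_pixels(adata, filter_label="in_tissue", img_key="hires")
# cast two genes from adata.X into pixel space and plot the first one
assembled, discrete_cols = assemble_pita(
    adata, features=["EPCAM", "VIM"], plot_out=True
)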
def blur_features_st(adata, tmp, spatial_graph_key=None, n_rings=1)

Blur values in an AnnData object using spatial nearest neighbors

Parameters

adata : anndata.AnnData
AnnData object containing Visium data
tmp : pd.DataFrame
containing feature columns from adata.obs that will be blurred
spatial_graph_key : str, optional (default=None)
Key in adata.obsp containing spatial graph connectivities (i.e. "spatial_connectivities"). If None, compute new spatial graph using n_rings in squidpy.
n_rings : int, optional (default=1)
Number of hexagonal rings around each spatial transcriptomics spot to blur features by for capturing regional information. Assumes 10X Genomics Visium platform.

Returns

adata.obs is edited in place with new "blur_<feature>" columns; the blurred feature values are also returned as a pd.DataFrame
Expand source code
def blur_features_st(adata, tmp, spatial_graph_key=None, n_rings=1):
    """
    Blur values in an `AnnData` object using spatial nearest neighbors

    Parameters
    ----------
    adata : anndata.AnnData
        AnnData object containing Visium data
    tmp : pd.DataFrame
        containing feature columns from adata.obs that will be blurred
    spatial_graph_key : str, optional (default=`None`)
        Key in `adata.obsp` containing spatial graph connectivities (i.e.
        `"spatial_connectivities"`). If `None`, compute new spatial graph using
        `n_rings` in `squidpy`.
    n_rings : int, optional (default=1)
        Number of hexagonal rings around each spatial transcriptomics spot to blur
        features by for capturing regional information. Assumes 10X Genomics Visium
        platform.

    Returns
    -------
    adata.obs is edited in place with new blurred columns
    """
    if spatial_graph_key is not None:
        # use existing spatial graph
        assert (
            spatial_graph_key in adata.obsp.keys()
        ), "Spatial connectivities key '{}' not found.".format(spatial_graph_key)
    else:
        # create spatial graph
        print("Computing spatial graph with {} hexagonal rings".format(n_rings))
        sq.gr.spatial_neighbors(adata, coord_type="grid", n_rings=n_rings)
        spatial_graph_key = "spatial_connectivities"  # set key to expected output
    tmp2 = tmp.copy()  # copy of temporary dataframe for dropping blurred features into
    cols = tmp.columns  # get column names
    # perform blurring by nearest spot neighbors
    for x in range(len(tmp)):
        vals = tmp.iloc[
            list(
                np.argwhere(
                    adata.obsp[spatial_graph_key][
                        x,
                    ]
                )[:, 1]
            )
            + [x],
            :,
        ].mean()
        tmp2.iloc[x, :] = vals.values
    # add blurred features to anndata object
    adata.obs[[x for x in cols]] = tmp.loc[:, cols].values
    adata.obs[["blur_" + x for x in cols]] = tmp2.loc[:, cols].values
    return tmp2.loc[:, cols]
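A hedged sketch of blurring a couple of adata.obs columns over two hexagonal rings; the column names "usage_1" and "usage_2" are hypothetical placeholders for numeric features already present in adata.obs:

from MILWRM.ST import blur_features_st

cols = ["usage_1", "usage_2"]                      # hypothetical numeric columns in adata.obs
tmp = adata.obs[cols].copy()
blurred = blur_features_st(adata, tmp, n_rings=2)  # builds a squidpy spatial graph internally
# adata.obs now also contains "blur_usage_1" and "blur_usage_2"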
def map_pixels(adata, filter_label='in_tissue', img_key='hires', library_id=None, map_size=None)

Map spot IDs to 'pixel space' by assigning spot ID values to evenly spaced grid

Parameters

adata : anndata.AnnData
The data
filter_label : str or None
adata.obs column key that contains binary labels for filtering barcodes. If None, do not filter.
img_key : str
adata.uns key containing the image to use for mapping
library_id : str, optional (default=None)
Key for finding proper library from adata.uns["spatial"]. By default, find the key from adata.uns["spatial"].keys()
map_size : tuple of int, optional (default=None)
Shape of image to map to. By default, trim to ST coordinates. Can provide shape of whole hires image in adata.uns["spatial"] to yield pitas at full H&E image size.

Returns

adata : anndata.AnnData
with the following attributes:
adata.uns["pixel_map_df"] : pd.DataFrame
Long-form dataframe of Visium spot barcode IDs, pixel coordinates, and .obs metadata
adata.uns["pixel_map"] : np.array
Pixel space array of Visium spot barcode IDs
Expand source code
def map_pixels(
    adata,
    filter_label="in_tissue",
    img_key="hires",
    library_id=None,
    map_size=None,
):
    """
    Map spot IDs to 'pixel space' by assigning spot ID values to evenly spaced grid

    Parameters
    ----------
    adata : AnnData.anndata
        The data
    filter_label : str or None
        adata.obs column key that contains binary labels for filtering barcodes. If
        None, do not filter.
    img_key : str
        adata.uns key containing the image to use for mapping
    library_id : str, optional (default=None)
        Key for finding proper library from adata.uns["spatial"]. By default, find
        the key from adata.uns["spatial"].keys()
    map_size : tuple of int, optional (default=None)
        Shape of image to map to. By default, trim to ST coordinates. Can provide
        shape of whole hires image in adata.uns["spatial"] to yield pitas at full
        H&E image size.

    Returns
    -------
    adata : AnnData.anndata
        with the following attributes:
        adata.uns["pixel_map_df"] : pd.DataFrame
            Long-form dataframe of Visium spot barcode IDs, pixel coordinates, and
            .obs metadata
        adata.uns["pixel_map"] : np.array
            Pixel space array of Visium spot barcode IDs
    """
    adata.uns["pixel_map_params"] = {
        "img_key": img_key
    }  # create params dict for future use
    # add library_id key to params
    if library_id is None:
        library_id = adata.uns["pixel_map_params"]["library_id"] = list(
            adata.uns["spatial"].keys()
        )[0]
    else:
        adata.uns["pixel_map_params"]["library_id"] = library_id
    # first get center-to-face pixel distance of hexagonal Visium spots
    dist = euclidean_distances(adata.obsm["spatial"])
    adata.uns["pixel_map_params"]["ctr_to_face"] = (
        np.unique(dist)[np.unique(dist) != 0].min() / 2
    )
    # also save center-to-vertex pixel distance as vadata attribute
    adata.uns["pixel_map_params"]["ctr_to_vert"] = adata.uns["pixel_map_params"][
        "ctr_to_face"
    ] / np.cos(30 * (np.pi / 180))
    # get the spot radius from adata.uns["spatial"] as well
    adata.uns["pixel_map_params"]["radius"] = (
        adata.uns["spatial"][library_id]["scalefactors"]["spot_diameter_fullres"] / 2
    )
    # get scale factor from adata.uns["spatial"]
    adata.uns["pixel_map_params"]["scalef"] = adata.uns["spatial"][library_id][
        "scalefactors"
    ][f"tissue_{img_key}_scalef"]

    if filter_label is not None:
        # create frame of mock pixels to make edges look better
        # x and y deltas for moving rows and columns into a blank frame
        delta_x = (
            adata[adata.obs.array_col == 0, :].obsm["spatial"]
            - adata[adata.obs.array_col == 1, :].obsm["spatial"]
        )
        delta_x = np.mean(list(delta_x[:, 1])) * 2
        delta_y = (
            adata[adata.obs.array_row == 0, :].obsm["spatial"]
            - adata[adata.obs.array_row == 1, :].obsm["spatial"]
        )
        delta_y = np.mean(list(delta_y[:, 1])) * 2
        # left part of frame, translated
        left = adata[
            adata.obs.array_col.isin(
                [adata.obs.array_col.max() - 2, adata.obs.array_col.max() - 3]
            ),
            :,
        ].copy()
        left.obsm["spatial"][..., 0] -= delta_x.astype(int)
        del left.var
        del left.uns
        left.obs[filter_label] = 0
        left.obs_names = ["left" + str(x) for x in range(left.n_obs)]
        # right part of frame, translated
        right = adata[adata.obs.array_col.isin([2, 3]), :].copy()
        right.obsm["spatial"][..., 0] += delta_x.astype(int)
        del right.var
        del right.uns
        right.obs[filter_label] = 0
        right.obs_names = ["right" + str(x) for x in range(right.n_obs)]
        # add sides to orig
        a_sides = adata.concatenate(
            [left, right],
            index_unique=None,
        )
        a_sides.obs.drop(columns="batch", inplace=True)
        # bottom part of frame, translated
        bottom = a_sides[a_sides.obs.array_row == 1, :].copy()
        bottom.obsm["spatial"][..., 1] += delta_y.astype(int)
        bottom.obs_names = ["bottom" + str(x) for x in range(bottom.n_obs)]
        del bottom.var
        del bottom.uns
        bottom.obs[filter_label] = 0
        # top part of frame, translated
        top = a_sides[
            a_sides.obs.array_row == a_sides.obs.array_row.max() - 1, :
        ].copy()
        top.obsm["spatial"][..., 1] -= delta_y.astype(int)
        del top.var
        del top.uns
        top.obs[filter_label] = 0
        top.obs_names = ["top" + str(x) for x in range(top.n_obs)]
        # complete frame
        a_frame = a_sides.concatenate(
            [top, bottom],
            index_unique=None,
        )
        a_frame.uns = adata.uns
        a_frame.var = adata.var
        a_frame.obs.drop(columns="batch", inplace=True)
    else:
        a_frame = adata.copy()

    # determine pixel bounds from spot coords, adding center-to-face distance
    a_frame.uns["pixel_map_params"]["xmin_px"] = int(
        np.floor(
            a_frame.uns["pixel_map_params"]["scalef"]
            * (
                a_frame.obsm["spatial"][:, 0].min()
                - a_frame.uns["pixel_map_params"]["radius"]
            )
        )
    )
    a_frame.uns["pixel_map_params"]["xmax_px"] = int(
        np.ceil(
            a_frame.uns["pixel_map_params"]["scalef"]
            * (
                a_frame.obsm["spatial"][:, 0].max()
                + a_frame.uns["pixel_map_params"]["radius"]
            )
        )
    )
    a_frame.uns["pixel_map_params"]["ymin_px"] = int(
        np.floor(
            a_frame.uns["pixel_map_params"]["scalef"]
            * (
                a_frame.obsm["spatial"][:, 1].min()
                - a_frame.uns["pixel_map_params"]["radius"]
            )
        )
    )
    a_frame.uns["pixel_map_params"]["ymax_px"] = int(
        np.ceil(
            a_frame.uns["pixel_map_params"]["scalef"]
            * (
                a_frame.obsm["spatial"][:, 1].max()
                + a_frame.uns["pixel_map_params"]["radius"]
            )
        )
    )

    print("Creating pixel grid and mapping to nearest barcode coordinates")
    # define grid for pixel space
    if map_size is not None:
        # use provided size
        assert (
            map_size[1]
            >= a_frame.uns["pixel_map_params"]["ymax_px"]
            - a_frame.uns["pixel_map_params"]["ymin_px"]
        ), "Given map_size isn't large enough."
        assert (
            map_size[0]
            >= a_frame.uns["pixel_map_params"]["xmax_px"]
            - a_frame.uns["pixel_map_params"]["xmin_px"]
        ), "Given map_size isn't large enough."
        grid_y, grid_x = np.mgrid[
            0 : map_size[0],
            0 : map_size[1],
        ]
        # set min and max pixels to map_size
        a_frame.uns["pixel_map_params"]["ymin_px"] = 0
        a_frame.uns["pixel_map_params"]["ymax_px"] = map_size[0]
        a_frame.uns["pixel_map_params"]["xmin_px"] = 0
        a_frame.uns["pixel_map_params"]["xmax_px"] = map_size[1]

    else:
        # determine size from a.obsm["spatial"]
        grid_y, grid_x = np.mgrid[
            a_frame.uns["pixel_map_params"]["ymin_px"] : a_frame.uns[
                "pixel_map_params"
            ]["ymax_px"],
            a_frame.uns["pixel_map_params"]["xmin_px"] : a_frame.uns[
                "pixel_map_params"
            ]["xmax_px"],
        ]

    # map barcodes to pixel coordinates
    pixel_coords = np.column_stack((grid_x.ravel(order="C"), grid_y.ravel(order="C")))
    barcode_list = griddata(
        np.multiply(a_frame.obsm["spatial"], a_frame.uns["pixel_map_params"]["scalef"]),
        a_frame.obs_names,
        (pixel_coords[:, 0], pixel_coords[:, 1]),
        method="nearest",
    )
    # save grid_x and grid_y to adata.uns
    a_frame.uns["grid_x"], a_frame.uns["grid_y"] = grid_x, grid_y

    # put results into DataFrame for filtering and reindexing
    print("Saving barcode mapping to adata.uns['pixel_map_df'] and adding metadata")
    a_frame.uns["pixel_map_df"] = pd.DataFrame(pixel_coords, columns=["x", "y"])
    # add barcodes to long-form dataframe
    a_frame.uns["pixel_map_df"]["barcode"] = barcode_list
    # merge master df with self.adata.obs for metadata
    a_frame.uns["pixel_map_df"] = a_frame.uns["pixel_map_df"].merge(
        a_frame.obs, how="outer", left_on="barcode", right_index=True
    )
    # filter using label from adata.obs if desired (i.e. "in_tissue")
    if filter_label is not None:
        print(
            "Filtering barcodes using labels in self.adata.obs['{}']".format(
                filter_label
            )
        )
        # set empty pixels (no Visium spot) to "none"
        a_frame.uns["pixel_map_df"].loc[
            a_frame.uns["pixel_map_df"][filter_label] == 0,
            "barcode",
        ] = "none"
        # subset the entire anndata object using filter_label
        a_frame = a_frame[a_frame.obs[filter_label] == 1, :].copy()
        print("New size: {} spots x {} genes".format(a_frame.n_obs, a_frame.n_vars))

    print("Done!")
    return a_frame
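A minimal call sketch, assuming a standard Visium AnnData with adata.uns["spatial"] populated (e.g. from scanpy.read_visium); the defaults keep only in-tissue barcodes and map to the "hires" image scale:

from MILWRM.ST import map_pixels

adata = map_pixels(adata, filter_label="in_tissue", img_key="hires")
adata.uns["pixel_map_df"].head()  # long-form barcode-to-pixel mapping with .obs metadata
adata.uns["pixel_map_params"]     # scale factor, spot radius, and pixel bounds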
def show_pita(pita, features=None, discrete_features=None, RGB=False, histo=None, label='feature', ncols=4, figsize=(7, 7), cmap='plasma', save_to=None, **kwargs)

Plot assembled pita using plt.imshow()

Parameters

pita : np.array
Image of desired expression in pixel space from .assemble_pita()
features : list of int, optional (default=None)
List of features by index to show in plot. If None, use all features.
discrete_features : dict, optional (default=None)
Dictionary of feature indices (keys) containing discrete (categorical) values (i.e. MILWRM domain). Values are tuple of max_value to pass to plot_single_image_discrete for each discrete feature, and the ordered list of categories for legend plotting. If None, treat all features as continuous.
RGB : bool, optional (default=False)
Treat 3-dimensional array as RGB image
histo : np.array or None, optional (default=None)
Histology image to show along with pita in gridspec. If None, ignore.
label : str, optional (default="feature")
What to title each panel of the gridspec (i.e. "PC" or "usage") or each channel in RGB image. Can also pass list of names e.g. ["NeuN","GFAP", "DAPI"] corresponding to channels.
ncols : int, optional (default=4)
Number of columns for gridspec
figsize : tuple of float, optional (default=(7, 7))
Size in inches of output figure
cmap : str, optional (default="plasma")
Matplotlib colormap to use
save_to : str or None, optional (default=None)
Path to image file to save results. If None, show figure.
**kwargs
Arguments to pass to plt.imshow() function

Returns

Matplotlib object (if plotting one feature or RGB) or gridspec object (for multiple features). Saves plot to file if save_to is not None.

Expand source code
def show_pita(
    pita,
    features=None,
    discrete_features=None,
    RGB=False,
    histo=None,
    label="feature",
    ncols=4,
    figsize=(7, 7),
    cmap="plasma",
    save_to=None,
    **kwargs,
):
    """
    Plot assembled pita using `plt.imshow()`

    Parameters
    ----------
    pita : np.array
        Image of desired expression in pixel space from `.assemble_pita()`
    features : list of int, optional (default=`None`)
        List of features by index to show in plot. If `None`, use all features.
    discrete_features : dict, optional (default=`None`)
        Dictionary of feature indices (keys) containing discrete (categorical) values
        (i.e. MILWRM domain). Values are tuple of `max_value` to pass to
        `plot_single_image_discrete` for each discrete feature, and the ordered list
        of categories for legend plotting. If `None`, treat all features as continuous.
    RGB : bool, optional (default=`False`)
        Treat 3-dimensional array as RGB image
    histo : np.array or `None`, optional (default=`None`)
        Histology image to show along with pita in gridspec. If `None`, ignore.
    label : str, optional (default="feature")
        What to title each panel of the gridspec (i.e. "PC" or "usage") or each
        channel in RGB image. Can also pass list of names e.g. ["NeuN","GFAP",
        "DAPI"] corresponding to channels.
    ncols : int, optional (default=4)
        Number of columns for gridspec
    figsize : tuple of float, optional (default=(7, 7))
        Size in inches of output figure
    cmap : str, optional (default="plasma")
        Matplotlib colormap to use
    save_to : str or None, optional (default=`None`)
        Path to image file to save results. if `None`, show figure.
    **kwargs
        Arguments to pass to `plt.imshow()` function

    Returns
    -------
    Matplotlib object (if plotting one feature or RGB) or gridspec object (for
    multiple features). Saves plot to file if `save_to` is not `None`.
    """
    assert pita.ndim > 1, "Pita does not have enough dimensions: {} given".format(
        pita.ndim
    )
    assert pita.ndim < 4, "Pita has too many dimensions: {} given".format(pita.ndim)
    # if only one feature (2D), plot it quickly
    if (pita.ndim == 2) and histo is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        if discrete_features is not None:
            plot_single_image_discrete(
                image=pita,
                ax=ax,
                # use first value in dict as max
                max_val=list(discrete_features.values())[0][0],
                ticklabels=list(discrete_features.values())[0][1],
                label=label[0] if isinstance(label, list) else label,
                cmap=cmap,
                **kwargs,
            )
        else:
            plot_single_image(
                image=pita,
                ax=ax,
                label=label[0] if isinstance(label, list) else label,
                cmap=cmap,
                **kwargs,
            )
        plt.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig
    if (pita.ndim == 2) and histo is not None:
        n_rows, n_cols = 1, 2  # two images here, histo and RGB
        fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
        # arrange axes as subplots
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
        # add plots to axes
        ax = plt.subplot(gs[0])
        plot_single_image_rgb(
            image=histo,
            ax=ax,
            channels=None,
            label="Histology",
            **kwargs,
        )
        ax = plt.subplot(gs[1])
        if discrete_features is not None:
            plot_single_image_discrete(
                image=pita,
                ax=ax,
                # use first value in dict as max
                max_val=list(discrete_features.values())[0][0],
                ticklabels=list(discrete_features.values())[0][1],
                label=label[0] if isinstance(label, list) else label,
                cmap=cmap,
                **kwargs,
            )
        else:
            plot_single_image(
                image=pita,
                ax=ax,
                label=label[0] if isinstance(label, list) else label,
                cmap=cmap,
                **kwargs,
            )
        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig
    if RGB:
        # if third dim has 3 features, treat as RGB and plot it quickly
        assert (pita.ndim == 3) & (
            pita.shape[2] == 3
        ), "Need 3 dimensions and 3 given features for an RGB image; shape = {}; features given = {}".format(
            pita.shape, len(features)
        )
        print("Plotting pita as RGB image")
        if isinstance(label, str):
            # if label is single string, name channels numerically
            channels = ["{}_{}".format(label, x) for x in range(pita.shape[2])]
        else:
            assert (
                len(label) == 3
            ), "Please pass 3 channel names for RGB plot; {} labels given: {}".format(
                len(label), label
            )
            channels = label
        if histo is not None:
            n_rows, n_cols = 1, 2  # two images here, histo and RGB
            fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
            # arrange axes as subplots
            gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
            # add plots to axes
            ax = plt.subplot(gs[0])
            plot_single_image_rgb(
                image=histo,
                ax=ax,
                channels=None,
                label="Histology",
                **kwargs,
            )
            ax = plt.subplot(gs[1])
            plot_single_image_rgb(
                image=pita,
                ax=ax,
                channels=channels,
                label="",
                **kwargs,
            )
            fig.tight_layout()
            if save_to:
                plt.savefig(
                    fname=save_to, transparent=True, bbox_inches="tight", dpi=800
                )
            return fig
        else:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
            plot_single_image_rgb(
                image=pita,
                ax=ax,
                channels=channels,
                label="",
                **kwargs,
            )
            if save_to:
                plt.savefig(
                    fname=save_to, transparent=True, bbox_inches="tight", dpi=300
                )
            return fig
    # if pita has multiple features, plot them in gridspec
    if isinstance(features, int):  # force features into list if single integer
        features = [features]
    # if no features are given, use all of them
    if features is None:
        features = [x for x in range(pita.shape[2])]
    else:
        assert (
            pita.ndim > 2
        ), "Not enough features in pita: shape {}, expecting 3rd dim with length {}".format(
            pita.shape, len(features)
        )
        assert (
            len(features) <= pita.shape[2]
        ), "Too many features given: pita has {}, expected {}".format(
            pita.shape[2], len(features)
        )
    if isinstance(label, str):
        # if label is single string, name channels numerically
        labels = ["{}_{}".format(label, x) for x in features]
    else:
        assert len(label) == len(
            features
        ), "Please provide the same number of labels as features; {} labels given, {} features given.".format(
            len(label), len(features)
        )
        labels = label
    # calculate gridspec dimensions
    if histo is not None:
        labels = ["Histology"] + labels  # append histo to front of labels
        if len(features) + 1 <= ncols:
            n_rows, n_cols = 1, len(features) + 1
        else:
            n_rows, n_cols = ceil((len(features) + 1) / ncols), ncols
    else:
        if len(features) <= ncols:
            n_rows, n_cols = 1, len(features)
        else:
            n_rows, n_cols = ceil(len(features) / ncols), ncols
    fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
    # arrange axes as subplots
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
    # add plots to axes
    i = 0
    if histo is not None:
        # add histology plot to first axes
        ax = plt.subplot(gs[i])
        plot_single_image_rgb(
            image=histo,
            ax=ax,
            channels=None,
            label=labels[i],
            **kwargs,
        )
        i = i + 1
    for feature in features:
        ax = plt.subplot(gs[i])
        if discrete_features is not None:
            if feature in discrete_features:
                plot_single_image_discrete(
                    image=pita[:, :, feature],
                    ax=ax,
                    # use corresponding value in dict as max
                    max_val=discrete_features[feature][0],
                    ticklabels=discrete_features[feature][1],
                    label=labels[i],
                    cmap=cmap,
                    **kwargs,
                )
            else:
                plot_single_image(
                    image=pita[:, :, feature],
                    ax=ax,
                    label=labels[i],
                    cmap=cmap,
                    **kwargs,
                )
        else:
            plot_single_image(
                image=pita[:, :, feature],
                ax=ax,
                label=labels[i],
                cmap=cmap,
                **kwargs,
            )
        i = i + 1
    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
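A sketch of plotting two features from an assembled pita, treating feature index 0 as a discrete MILWRM domain label (e.g. a categorical column cast via use_rep="obs"); the category names and counts here are hypothetical:

from MILWRM.ST import show_pita

fig = show_pita(
    pita=assembled,                                        # from assemble_pita()
    features=[0, 1],
    discrete_features={0: (4, ["D0", "D1", "D2", "D3"])},  # feature 0 has 4 categories
    label=["tissue_domain", "PC_1"],
    ncols=2,
    cmap="plasma",
    save_to="pita.png",
)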
def trim_image(adata, distance_trim=False, threshold=None, channels=None, plot_out=True, **kwargs)

Trim pixels in image using pixel map output from Visium barcodes

Parameters

adata : anndata.AnnData
The data
distance_trim : bool
Manually trim pixels by distance to nearest Visium spot center
threshold : int or None
Number of pixels from nearest Visium spot center to call barcode ID. Ignored if distance_trim==False.
channels : list of str or None
Names of image channels in axis order. If None, channels are named "ch_0", "ch_1", etc.
plot_out : bool
Plot final trimmed image
**kwargs
Arguments to pass to show_pita() function if plot_out==True

Returns

adata.uns["pixel_map_trim"] : np.array
Contains image with unused pixels set to np.nan
adata.obsm["spatial_trim"] : np.array
Contains spatial coords with adjusted pixel values after image cropping
Expand source code
def trim_image(
    adata, distance_trim=False, threshold=None, channels=None, plot_out=True, **kwargs
):
    """
    Trim pixels in image using pixel map output from Visium barcodes

    Parameters
    ----------
    adata : AnnData.anndata
        The data
    distance_trim : bool
        Manually trim pixels by distance to nearest Visium spot center
    threshold : int or None
        Number of pixels from nearest Visium spot center to call barcode ID. Ignored
        if `distance_trim==False`.
    channels : list of str or None
        Names of image channels in axis order. If None, channels are named "ch_0",
        "ch_1", etc.
    plot_out : bool
        Plot final trimmed image
    **kwargs
        Arguments to pass to `show_pita()` function if `plot_out==True`

    Returns
    -------
    adata.uns["pixel_map_trim"] : np.array
        Contains image with unused pixels set to `np.nan`
    adata.obsm["spatial_trim"] : np.array
        Contains spatial coords with adjusted pixel values after image cropping
    """
    assert (
        adata.uns["pixel_map_params"] is not None
    ), "Pixel map not yet created. Run map_pixels() first."

    print(
        "Cropping image to pixel dimensions and adding values to adata.uns['pixel_map_df']"
    )
    cropped = adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]][
        "images"
    ][adata.uns["pixel_map_params"]["img_key"]].transpose(1, 0, 2)[
        int(adata.uns["pixel_map_params"]["xmin_px"]) : int(
            (adata.uns["pixel_map_params"]["xmax_px"])
        ),
        int(adata.uns["pixel_map_params"]["ymin_px"]) : int(
            (adata.uns["pixel_map_params"]["ymax_px"])
        ),
    ]
    # crop x,y coords and save to .obsm as well
    print("Cropping Visium spot coordinates and saving to adata.obsm['spatial_trim']")
    adata.obsm["spatial_trim"] = adata.obsm["spatial"] - np.repeat(
        [
            [
                adata.uns["pixel_map_params"]["xmin_px"],
                adata.uns["pixel_map_params"]["ymin_px"],
            ]
        ],
        adata.obsm["spatial"].shape[0],
        axis=0,
    )

    # manual trimming of pixels by distance if desired
    if distance_trim:
        print("Calculating pixel distances from spot centers for thresholding")
        tree = cKDTree(adata.obsm["spatial"])
        xi = interpnd._ndim_coords_from_arrays(
            (adata.uns["grid_x"], adata.uns["grid_y"]),
            ndim=adata.obsm["spatial"].shape[1],
        )
        dists, _ = tree.query(xi)

        # determine distance threshold
        if threshold is None:
            threshold = int(adata.uns["pixel_map_params"]["ctr_to_vert"] + 1)
            print(
                "Using distance threshold of {} pixels from adata.uns['pixel_map_params']['ctr_to_vert']".format(
                    threshold
                )
            )

        dist_mask = bin_threshold(dists, threshmax=threshold)
        if plot_out:
            # plot pixel distances from spot centers on image
            show_pita(pita=dists, figsize=(4, 4))
            # plot binary thresholded image
            show_pita(pita=dist_mask, figsize=(4, 4))

        print(
            "Trimming pixels by spot distance and adjusting labels in adata.uns['pixel_map_df']"
        )
        mask_df = pd.DataFrame(dist_mask.T.ravel(order="F"), columns=["manual_trim"])
        adata.uns["pixel_map_df"] = adata.uns["pixel_map_df"].merge(
            mask_df, left_index=True, right_index=True
        )
        adata.uns["pixel_map_df"].loc[
            adata.uns["pixel_map_df"]["manual_trim"] == 1, ["barcode"]
        ] = "none"  # set empty pixels to empty barcode
        adata.uns["pixel_map_df"].drop(
            columns="manual_trim", inplace=True
        )  # remove unneeded label

    if channels is None:
        # if channel names not specified, name them numerically
        channels = ["ch_{}".format(x) for x in range(cropped.shape[2])]
    # cast image intensity values to long-form and add to adata.uns["pixel_map_df"]
    rgb = pd.DataFrame(
        np.column_stack(
            [cropped[:, :, x].ravel(order="F") for x in range(cropped.shape[2])]
        ),
        columns=channels,
    )
    adata.uns["pixel_map_df"] = adata.uns["pixel_map_df"].merge(
        rgb, left_index=True, right_index=True
    )
    adata.uns["pixel_map_df"].loc[
        adata.uns["pixel_map_df"]["barcode"] == "none", channels
    ] = np.nan  # set empty pixels to invalid image intensity value

    # calculate mean image values for each channel and create .obsm key
    adata.obsm["image_means"] = (
        adata.uns["pixel_map_df"]
        .loc[adata.uns["pixel_map_df"]["barcode"] != "none", ["barcode"] + channels]
        .groupby("barcode")
        .mean()
        .values
    )

    print(
        "Saving cropped and trimmed image to adata.uns['spatial']['{}']['images']['{}_trim']".format(
            adata.uns["pixel_map_params"]["library_id"],
            adata.uns["pixel_map_params"]["img_key"],
        )
    )
    adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]]["images"][
        "{}_trim".format(adata.uns["pixel_map_params"]["img_key"])
    ] = np.dstack(
        [
            adata.uns["pixel_map_df"]
            .pivot(index="y", columns="x", values=[channels[x]])
            .values
            for x in range(len(channels))
        ]
    )
    # save scale factor as well
    adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]]["scalefactors"][
        "tissue_{}_trim_scalef".format(adata.uns["pixel_map_params"]["img_key"])
    ] = adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]][
        "scalefactors"
    ][
        "tissue_{}_scalef".format(adata.uns["pixel_map_params"]["img_key"])
    ]
    # plot results if desired
    if plot_out:
        if len(channels) == 3:
            show_pita(
                pita=adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]][
                    "images"
                ]["{}_trim".format(adata.uns["pixel_map_params"]["img_key"])],
                RGB=True,
                label=channels,
                **kwargs,
            )
        else:
            show_pita(
                pita=adata.uns["spatial"][adata.uns["pixel_map_params"]["library_id"]][
                    "images"
                ]["{}_trim".format(adata.uns["pixel_map_params"]["img_key"])],
                RGB=False,
                label=channels,
                **kwargs,
            )
    print("Done!")

Classes

class img (img_arr, channels=None, mask=None)

Initialize img class

Parameters

img_arr : np.ndarray
The image as a numpy array
channels : list of str or None, optional (default=None)
List of channel names corresponding to img.shape[2], i.e. ("DAPI", "GFAP", "NeuH"). If None, channels are named "ch_0", "ch_1", etc.
mask : np.ndarray
Mask defining pixels containing tissue in the image

Returns

img object
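A minimal construction sketch using a synthetic array; the random data and channel names here are placeholders:

import numpy as np
from MILWRM.MxIF import img

arr = np.random.rand(256, 256, 3)                 # synthetic 3-channel image
im = img(arr, channels=["DAPI", "GFAP", "NeuH"])  # channel names must match arr.shape[2]
print(im)                                         # summary of shape, dtype, and channels
dapi = im["DAPI"]                                 # slice a single channel by name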

Expand source code
class img:
    def __init__(self, img_arr, channels=None, mask=None):
        """
        Initialize img class

        Parameters
        ----------
        img_arr : np.ndarray
            The image as a numpy array
        channels : list of str or None, optional (default=`None`)
            List of channel names corresponding to img.shape[2]. i.e. `("DAPI","GFAP",
            "NeuH")`. If `None`, channels are named "ch_0", "ch_1", etc.
        mask : np.ndarray
            Mask defining pixels containing tissue in the image

        Returns
        -------
        `img` object
        """
        assert (
            img_arr.ndim > 1
        ), "Image does not have enough dimensions: {} given".format(img_arr.ndim)
        self.img = img_arr.astype("float64")  # save image array to .img attribute
        if img_arr.ndim > 2:
            self.n_ch = img_arr.shape[2]  # save number of channels to attribute
        else:
            self.n_ch = 1
        if channels is None:
            # if channel names not specified, name them numerically
            self.ch = ["ch_{}".format(x) for x in range(self.n_ch)]
        else:
            if not isinstance(channels, list):
                raise Exception("Channels must be given in a list")
            assert (
                len(channels) == self.n_ch
            ), "Number of channels must match img_arr.shape[2]"
            self.ch = channels
        if mask is not None:
            # validate that mask matches img_arr
            assert (
                mask.shape == img_arr.shape[:2]
            ), "Shape of mask must match the first two dimensions of img_arr"
        self.mask = mask  # set mask attribute, regardless of value given

    def __repr__(self) -> str:
        """Representation of contents of img object"""
        descr = "img object with {} of {} and shape {}px x {}px\n".format(
            type(self.img),
            self.img.dtype,
            self.img.shape[0],
            self.img.shape[1],
        ) + "{} image channels:\n\t{}".format(self.n_ch, self.ch)
        if self.mask is not None:
            descr += "\n\ntissue mask {} of {} and shape {}px x {}px".format(
                type(self.mask),
                self.mask.dtype,
                self.mask.shape[0],
                self.mask.shape[1],
            )
        return descr

    def copy(self) -> "img":
        """Full copy of img object"""
        new = {}
        for key in ["img", "ch", "mask"]:
            attr = self.__getattribute__(key)
            if attr is not None:
                new[key] = attr.copy()
            else:
                new[key] = None
        return img(img_arr=new["img"], channels=new["ch"], mask=new["mask"])

    def __getitem__(self, channels):
        """Slice img object with channel name(s)"""
        if isinstance(channels, int):  # force channels into list if single integer
            channels = [channels]
        if isinstance(channels, str):  # force channels into int if single string
            channels = [self.ch.index(channels)]
        if checktype(channels):  # force channels into list of int if list of strings
            channels = [self.ch.index(x) for x in channels]

        if channels is None:  # if no channels are given, use all of them
            channels = [x for x in range(self.n_ch)]

        return self.img[:, :, channels]

    @classmethod
    def from_tiffs(cls, tiffdir, channels, common_strings=None, mask=None):
        """
        Initialize img class from `.tif` files

        Parameters
        ----------
        tiffdir : str
            Path to directory containing `.tif` files for a multiplexed image
        channels : list of str
            List of channels present in `.tif` file names (case-sensitive)
            corresponding to img.shape[2] e.g. `("ACTG1","BCATENIN","DAPI",...)`
        common_strings : str, list of str, or `None`, optional (default=None)
            Strings to look for in all `.tif` files in `tiffdir` corresponding to
            `channels` e.g. `("WD86055_", "_region_001.tif")` for files named
            "WD86055_[MARKERNAME]_region_001.tif". If `None`, assume that only 1 image
            for each marker in `channels` is present in `tiffdir`.
        mask : str, optional (default=None)
            Name of mask defining pixels containing tissue in the image, present in
            `.tif` file names (case-sensitive) e.g. "_01_TISSUE_MASK.tif"

        Returns
        -------
        `img` object
        """
        if common_strings is not None:
            # coerce single string to list
            if isinstance(common_strings, str):
                common_strings = [common_strings]
        A = []  # list for dumping numpy arrays
        for channel in channels:
            if common_strings is None:
                # find file matching all common_strings and channel name
                f = [f for f in os.listdir(tiffdir) if channel in f]
            else:
                # find file matching all common_strings and channel name
                f = [
                    f
                    for f in os.listdir(tiffdir)
                    if all(x in f for x in common_strings + [channel])
                ]
            # assertions so we only get one file per channel
            assert len(f) != 0, "No file found with channel {}".format(channel)
            assert (
                len(f) == 1
            ), "More than one match found for file with channel {}".format(channel)
            f = os.path.join(tiffdir, f[0])  # get full path to file for reading
            print("Reading marker {} from {}".format(channel, f))
            tmp = imread(f)  # read in .tif file
            A.append(tmp)  # append numpy array to list
        A_arr = np.dstack(
            A
        )  # stack numpy arrays in new dimension (third dim is channel)
        print("Final image array of shape: {}".format(A_arr.shape))
        # read in tissue mask if available
        if mask is not None:
            f = [f for f in os.listdir(tiffdir) if mask in f]
            # assertions so we only get one mask file
            assert len(f) != 0, "No tissue mask file found"
            assert len(f) == 1, "More than one match found for tissue mask file"
            f = os.path.join(tiffdir, f[0])  # get full path to file for reading
            print("Reading tissue mask from {}".format(f))
            A_mask = imread(f)  # read in .tif file
            assert (
                A_mask.shape == A_arr.shape[:2]
            ), "Mask (shape: {}) is not the same shape as marker images (shape: {})".format(
                A_mask.shape, A_arr.shape[:2]
            )
            print("Final mask array of shape: {}".format(A_mask.shape))
        else:
            A_mask = None
        # generate img object
        return cls(img_arr=A_arr, channels=channels, mask=A_mask)

    @classmethod
    def from_npz(cls, file):
        """
        Initialize img class from `.npz` file

        Parameters
        ----------
        file : str
            Path to `.npz` file containing saved img object and metadata

        Returns
        -------
        `img` object
        """
        print("Loading img object from {}...".format(file))
        tmp = np.load(file)  # load from .npz compressed file
        assert (
            "img" in tmp.files
        ), "Unexpected files in .npz: {}, expected ['img','mask','ch'].".format(
            tmp.files
        )
        A_mask = tmp["mask"] if "mask" in tmp.files else None
        A_ch = list(tmp["ch"]) if "ch" in tmp.files else None
        # generate img object
        return cls(img_arr=tmp["img"], channels=A_ch, mask=A_mask)

    def to_npz(self, file):
        """
        Save img object to compressed `.npz` file

        Parameters
        ----------
        file : str
            Path to `.npz` file in which to save img object and metadata

        Returns
        -------
        Writes object to `file`
        """
        print("Saving img object to {}...".format(file))
        if self.mask is None:
            np.savez_compressed(file, img=self.img, ch=self.ch)
        else:
            np.savez_compressed(file, img=self.img, ch=self.ch, mask=self.mask)

    def clip(self, **kwargs):
        """
        Clips outlier values and rescales intensities

        Parameters
        ----------
        **kwargs
            Keyword args to pass to `clip_values()` function

        Returns
        -------
        Clips outlier values from `self.img`
        """
        self.img = clip_values(self.img, **kwargs)

    def scale(self, **kwargs):
        """
        Scales intensities to [0.0, 1.0]

        Parameters
        ----------
        **kwargs
            Keyword args to pass to `scale_rgb()` function

        Returns
        -------
        Scales intensities of `self.img`
        """
        self.img = scale_rgb(self.img, **kwargs)

    def equalize_hist(self, **kwargs):
        """
        Contrast Limited Adaptive Histogram Equalization (CLAHE)

        Parameters
        ----------
        **kwargs
            Keyword args to pass to `CLAHE()` function

        Returns
        -------
        `self.img` is updated with exposure-adjusted values
        """
        self.img = CLAHE(self.img, **kwargs)

    def blurring(self, filter_name="gaussian", sigma=2, **kwargs):
        """
        Apply a filter to the image

        Parameters
        ----------
        filter_name : str
            Type of filter to apply ("gaussian", "median", or "bilateral")

        sigma : int
            Parameter controlling the extent of smoothing
        """
        if filter_name == "gaussian":
            print("Applying gaussian filter")
            self.img = filters.gaussian(
                self.img,
                sigma=sigma,
                channel_axis=2,
                **kwargs,
            )
        elif filter_name == "median":
            print("Applying median filter")
            if isinstance(sigma, float):
                sigma = int(sigma)
            for i in range(self.img.shape[2]):
                image_array = self.img[:, :, i]
                image_array_blurred = filters.median(
                    image_array,
                    np.ones((sigma, sigma)),
                )
                self.img[:, :, i] = image_array_blurred
        elif filter_name == "bilateral":
            print("Applying biltaral filter")
            self.img = denoise_bilateral(
                self.img, sigma_spatial=sigma, channel_axis=2, **kwargs
            )
        else:
            raise Exception(
                "filter name should be either gaussian, median or bilateral"
            )

    def log_normalize(self, pseudoval=1, mean=None, mask=True):
        """
        Log-normalizes values for each marker with `log10(arr/arr.mean() + pseudoval)`

        Parameters
        ----------
        pseudoval : float
            Value to add to image values prior to log-transforming to avoid issues
            with zeros
        mean : list of float or None, optional (default=None)
            Pre-computed per-channel mean factors to use for normalization. If `None`,
            compute each channel's mean from `self.img`
        mask : bool, optional (default=True)
            Use tissue mask to determine marker mean factor for normalization. Default
            `True`.

        Returns
        -------
        Log-normalizes values in each channel of `self.img`
        """
        if mean is not None:
            if mask:
                assert self.mask is not None, "No tissue mask available"
                for i in range(self.img.shape[2]):
                    fact = mean[i]
                    self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
            else:
                print("WARNING: Performing normalization without a tissue mask.")
                for i in range(self.img.shape[2]):
                    fact = mean[i]
                    self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
        else:
            print("mean calculated to perform log normalization")
            if mask:
                assert self.mask is not None, "No tissue mask available"
                for i in range(self.img.shape[2]):
                    fact = self.img[:, :, i].mean()
                    self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
            else:
                print("WARNING: Performing normalization without a tissue mask.")
                for i in range(self.img.shape[2]):
                    fact = self.img[:, :, i].mean()
                    self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)

    def subsample_pixels(self, features, fract=0.2, random_state=16):
        """
        Sub-samples fraction of pixels from the image randomly for each channel

        Parameters
        ----------
        features : list of int or str
            Indices or names of MxIF channels to use for tissue labeling
        fract : float, optional (default=0.2)
            Fraction of cluster data from each image to randomly select
            for model building

        Returns
        -------
        tmp : np.array
            Clustering data from `image`

        """
        if isinstance(features, int):  # force features into list if single integer
            features = [features]
        if isinstance(features, str):  # force features into int if single string
            features = [self.ch.index(features)]
        if checktype(features):  # force features into list of int if list of strings
            features = [self.ch.index(x) for x in features]
        if features is None:  # if no features are given, use all of them
            features = [x for x in range(self.n_ch)]
        # subsample data for given image
        np.random.seed(random_state)
        tmp = []
        for i in range(self.img.shape[2]):
            tmp.append(self.img[:, :, i][self.mask != 0])
        tmp = np.column_stack(tmp)
        # select cluster data
        i = np.random.choice(tmp.shape[0], int(tmp.shape[0] * fract))
        tmp = tmp[np.ix_(i, features)]
        return tmp

    def downsample(self, fact, func=np.mean):
        """
        Downsamples image by applying `func` to `fact` pixels in both directions from
        each pixel

        Parameters
        ----------
        fact : int
            Number of pixels in each direction (x & y) to downsample with
        func : function
            Numpy function to apply to squares of size (fact, fact, :) for downsampling
            (e.g. `np.mean`, `np.max`, `np.sum`)

        Returns
        -------
        self.img and self.mask are downsampled accordingly in place
        """
        # downsample mask if mask available
        if self.mask is not None:
            self.mask = block_reduce(
                self.mask, block_size=(fact, fact), func=func, cval=0
            )
        # downsample image
        self.img = block_reduce(self.img, block_size=(fact, fact, 1), func=func, cval=0)

    def calculate_non_zero_mean(self):
        """
        Calculate mean estimator for the given image array avoiding mask pixels or
        pixels with value 0

        Parameters
        ----------

        Returns
        -------
        mean_estimator : list of float
            List of mean estimator (mean*pixel) values for each channel
        pixels : int
            pixel count for the image excluding masked pixels
        """
        image = self.img
        pixels = np.count_nonzero(image != 0)
        mean_estimator = []
        for i in range(image.shape[2]):
            ar = image[:, :, i]
            mean = ar[ar != 0].mean()
            mean_estimator.append(mean * pixels)
        return mean_estimator, pixels

    def create_tissue_mask(self, features=None, fract=0.2):
        """
        Create tissue mask

        Parameters
        ----------
        features : list of int or str
            Indices or names of MxIF channels to use for tissue labeling
        fract : float, optional (default=0.2)
            Fraction of cluster data from each image to randomly select
            for model building

        Returns
        -------
        a numpy array as tissue mask set to self.mask
        """
        # create a copy of the image
        image_cp = self.copy()
        # create a temporary tissue mask that covers no region
        w, h, d = image_cp.img.shape
        image_cp.mask = np.ones((w, h))
        # log normalization on image
        image_cp.log_normalize()
        # apply gaussian filter
        image_cp.img = filters.gaussian(image_cp.img, sigma=2, channel_axis=2)
        # subsample data to build kmeans model
        subsampled_data = image_cp.subsample_pixels(features, fract=fract)
        cluster_data = np.row_stack(subsampled_data)
        # reshape image for prediction
        image_ar_reshape = image_cp.img.reshape((w * h, d))
        # build kmeans model with 2 clusters
        kmeans = KMeans(n_clusters=2, random_state=18).fit(cluster_data)
        labels = kmeans.predict(image_ar_reshape).astype(float)
        tID = labels.reshape((w, h))
        # check if the background is labelled as 0 or 1
        scores = kmeans.cluster_centers_
        mean = scores.mean()
        std = scores.std()
        z_scores = (scores - mean) / std
        if z_scores[0].mean() > 0:
            where_0 = np.where(tID == 0.0)
            tID[where_0] = 0.5
            where_1 = np.where(tID == 1.0)
            tID[where_1] = 0.0
            where_05 = np.where(tID == 0.5)
            tID[where_05] = 1.0
        self.mask = tID

    def show(
        self,
        channels=None,
        RGB=False,
        cbar=False,
        mask_out=True,
        ncols=4,
        figsize=(7, 7),
        save_to=None,
        **kwargs,
    ):
        """
        Plot image

        Parameters
        ----------
        channels : tuple of int or None, optional (default=`None`)
            List of channels by index or name to show
        RGB : bool
            Treat 3- or 4-dimensional array as RGB image. If `False`, plot channels
            individually.
        cbar : bool
            Show colorbar for scale of image intensities if plotting individual
            channels.
        mask_out : bool, optional (default=`True`)
            Mask out non-tissue pixels prior to showing
        ncols : int
            Number of columns for gridspec if plotting individual channels.
        figsize : tuple of float
            Size in inches of output figure.
        save_to : str or None
            Path to image file to save results. If `None`, show figure.
        **kwargs
            Arguments to pass to `plt.imshow()` function.

        Returns
        -------
        Matplotlib object (if plotting one feature or RGB) or gridspec object (for
        multiple features). Saves plot to file if `save_to` is not `None`.
        """
        # if only one feature (2D), plot it quickly
        if self.img.ndim == 2:
            fig = plt.figure(figsize=figsize)
            if self.mask is not None and mask_out:
                im_tmp = self.img.copy()  # make copy for masking
                im_tmp[:, :][self.mask == 0] = np.nan  # area outside mask to NaN
                plt.imshow(im_tmp, **kwargs)
            else:
                plt.imshow(self.img, **kwargs)
            plt.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            if cbar:
                plt.colorbar(shrink=0.8)
            plt.tight_layout()
            if save_to:
                plt.savefig(
                    fname=save_to, transparent=True, bbox_inches="tight", dpi=800
                )
            return fig
        # if image has multiple channels, plot them in gridspec
        if isinstance(channels, int):  # force channels into list if single integer
            channels = [channels]
        if isinstance(channels, str):  # force channels into int if single string
            channels = [self.ch.index(channels)]
        if checktype(channels):  # force channels into list of int if list of strings
            channels = [self.ch.index(x) for x in channels]
        if channels is None:  # if no channels are given, use all of them
            channels = [x for x in range(self.n_ch)]
        assert (
            len(channels) <= self.n_ch
        ), "Too many channels given: image has {}, expected {}".format(
            self.n_ch, len(channels)
        )
        if RGB:
            # if third dim has 3 or 4 features, treat as RGB and plot it quickly
            assert (self.img.ndim == 3) & (
                len(channels) == 3
            ), "Need 3 dimensions and 3 given channels for an RGB image; shape = {}; channels given = {}".format(
                self.img.shape, len(channels)
            )
            fig = plt.figure(figsize=figsize)
            # rearrange channels to specified order
            im_tmp = np.dstack(
                [
                    self.img[:, :, channels[0]],
                    self.img[:, :, channels[1]],
                    self.img[:, :, channels[2]],
                ]
            )
            if self.mask is not None and mask_out:
                for i in [0, 1, 2]:  # for 3-channel image
                    im_tmp[:, :, i][self.mask == 0] = np.nan  # area outside mask NaN
            plt.imshow(im_tmp, **kwargs)
            # add legend for channel IDs
            custom_lines = [
                Line2D([0], [0], color=(1, 0, 0), lw=5),
                Line2D([0], [0], color=(0, 1, 0), lw=5),
                Line2D([0], [0], color=(0, 0, 1), lw=5),
            ]
            plt.legend(custom_lines, [self.ch[x] for x in channels], fontsize="medium")
            plt.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            plt.tight_layout()
            if save_to:
                plt.savefig(
                    fname=save_to, transparent=True, bbox_inches="tight", dpi=300
                )
            return fig
        # calculate gridspec dimensions
        if len(channels) <= ncols:
            n_rows, n_cols = 1, len(channels)
        else:
            n_rows, n_cols = ceil(len(channels) / ncols), ncols
        fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
        # arrange axes as subplots
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
        # add plots to axes
        i = 0
        for channel in channels:
            ax = plt.subplot(gs[i])
            if self.mask is not None and mask_out:
                im_tmp = self.img[:, :, channel].copy()  # make copy for masking
                im_tmp[self.mask == 0] = np.nan  # area outside mask NaN
                im = ax.imshow(im_tmp, **kwargs)
            else:
                im = ax.imshow(self.img[:, :, channel], **kwargs)
            ax.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            ax.set_title(
                label=self.ch[channel],
                loc="left",
                fontweight="bold",
                fontsize=16,
            )
            if cbar:
                _ = plt.colorbar(im, shrink=0.8)
            i = i + 1
        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig

    def plot_image_histogram(self, channels=None, ncols=4, save_to=None, **kwargs):

        """
        Plot image histogram

        Parameters
        ----------
        channels : tuple of int or None, optional (default=`None`)
            List of channels by index or name to show
        ncols : int
            Number of columns for gridspec if plotting individual channels.
        save_to : str or None
            Path to image file to save results. If `None`, show figure.
        **kwargs
            Arguments to pass to `plt.hist()` function.

        Returns
        -------
        Gridspec object (for multiple features). Saves plot to file if `save_to` is
        not `None`.
        """
        # calculate gridspec dimensions
        if len(channels) <= ncols:
            n_rows, n_cols = 1, len(channels)
        else:
            n_rows, n_cols = ceil(len(channels) / ncols), ncols
        fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
        # arrange axes as subplots
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
        # add plots to axes
        i = 0
        for channel in channels:
            ax = plt.subplot(gs[i])
            data = self[channel].copy()
            ax.hist(data.ravel(), bins=100, **kwargs)
            ax.set_title(channel, fontweight="bold", fontsize=16)
            i = i + 1
        fig.tight_layout()
        plt.show()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return gs

Static methods

def from_npz(file)

Initialize img class from .npz file

Parameters

file : str
Path to .npz file containing saved img object and metadata

Returns

img object

Expand source code
@classmethod
def from_npz(cls, file):
    """
    Initialize img class from `.npz` file

    Parameters
    ----------
    file : str
        Path to `.npz` file containing saved img object and metadata

    Returns
    -------
    `img` object
    """
    print("Loading img object from {}...".format(file))
    tmp = np.load(file)  # load from .npz compressed file
    assert (
        "img" in tmp.files
    ), "Unexpected files in .npz: {}, expected ['img','mask','ch'].".format(
        tmp.files
    )
    A_mask = tmp["mask"] if "mask" in tmp.files else None
    A_ch = list(tmp["ch"]) if "ch" in tmp.files else None
    # generate img object
    return cls(img_arr=tmp["img"], channels=A_ch, mask=A_mask)
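
A minimal usage sketch (the .npz path below is a placeholder):

from MILWRM.MxIF import img

# load a previously saved img object from a compressed .npz archive
im = img.from_npz("sample1_mxif.npz")
print(im.img.shape, im.ch)  # image array dimensions and channel names
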
def from_tiffs(tiffdir, channels, common_strings=None, mask=None)

Initialize img class from .tif files

Parameters

tiffdir : str
Path to directory containing .tif files for a multiplexed image
channels : list of str
List of channels present in .tif file names (case-sensitive) corresponding to img.shape[2] e.g. ("ACTG1","BCATENIN","DAPI",...)
common_strings : str, list of str, or None, optional (default=None)
Strings to look for in all .tif files in tiffdir corresponding to channels e.g. ("WD86055_", "_region_001.tif") for files named "WD86055_[MARKERNAME]_region_001.tif". If None, assume that only 1 image for each marker in channels is present in tiffdir.
mask : str, optional (default=None)
Name of mask defining pixels containing tissue in the image, present in .tif file names (case-sensitive) e.g. "_01_TISSUE_MASK.tif"

Returns

img object

Expand source code
@classmethod
def from_tiffs(cls, tiffdir, channels, common_strings=None, mask=None):
    """
    Initialize img class from `.tif` files

    Parameters
    ----------
    tiffdir : str
        Path to directory containing `.tif` files for a multiplexed image
    channels : list of str
        List of channels present in `.tif` file names (case-sensitive)
        corresponding to img.shape[2] e.g. `("ACTG1","BCATENIN","DAPI",...)`
    common_strings : str, list of str, or `None`, optional (default=None)
        Strings to look for in all `.tif` files in `tiffdir` corresponding to
        `channels` e.g. `("WD86055_", "_region_001.tif")` for files named
        "WD86055_[MARKERNAME]_region_001.tif". If `None`, assume that only 1 image
        for each marker in `channels` is present in `tiffdir`.
    mask : str, optional (default=None)
        Name of mask defining pixels containing tissue in the image, present in
        `.tif` file names (case-sensitive) e.g. "_01_TISSUE_MASK.tif"

    Returns
    -------
    `img` object
    """
    if common_strings is not None:
        # coerce single string to list
        if isinstance(common_strings, str):
            common_strings = [common_strings]
    A = []  # list for dumping numpy arrays
    for channel in channels:
        if common_strings is None:
            # find file matching channel name only
            f = [f for f in os.listdir(tiffdir) if channel in f]
        else:
            # find file matching all common_strings and channel name
            f = [
                f
                for f in os.listdir(tiffdir)
                if all(x in f for x in common_strings + [channel])
            ]
        # assertions so we only get one file per channel
        assert len(f) != 0, "No file found with channel {}".format(channel)
        assert (
            len(f) == 1
        ), "More than one match found for file with channel {}".format(channel)
        f = os.path.join(tiffdir, f[0])  # get full path to file for reading
        print("Reading marker {} from {}".format(channel, f))
        tmp = imread(f)  # read in .tif file
        A.append(tmp)  # append numpy array to list
    A_arr = np.dstack(
        A
    )  # stack numpy arrays in new dimension (third dim is channel)
    print("Final image array of shape: {}".format(A_arr.shape))
    # read in tissue mask if available
    if mask is not None:
        f = [f for f in os.listdir(tiffdir) if mask in f]
        # assertions so we only get one mask file
        assert len(f) != 0, "No tissue mask file found"
        assert len(f) == 1, "More than one match found for tissue mask file"
        f = os.path.join(tiffdir, f[0])  # get full path to file for reading
        print("Reading tissue mask from {}".format(f))
        A_mask = imread(f)  # read in .tif file
        assert (
            A_mask.shape == A_arr.shape[:2]
        ), "Mask (shape: {}) is not the same shape as marker images (shape: {})".format(
            A_mask.shape, A_arr.shape[:2]
        )
        print("Final mask array of shape: {}".format(A_mask.shape))
    else:
        A_mask = None
    # generate img object
    return cls(img_arr=A_arr, channels=channels, mask=A_mask)
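
A minimal usage sketch, reusing the file-name patterns from the docstring example above (directory and marker names are placeholders):

from MILWRM.MxIF import img

# assemble a multi-channel image from single-marker .tif files plus a tissue mask
im = img.from_tiffs(
    tiffdir="WD86055/",
    channels=["ACTG1", "BCATENIN", "DAPI"],
    common_strings=["WD86055_", "_region_001.tif"],
    mask="_01_TISSUE_MASK.tif",
)
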

Methods

def blurring(self, filter_name='gaussian', sigma=2, **kwargs)

Applying a filter to the image

Parameters

filter_name : str
Name of the filter to apply ("gaussian", "median", or "bilateral")
sigma : int
Parameter controlling the extent of smoothing

Expand source code
def blurring(self, filter_name="gaussian", sigma=2, **kwargs):
    """
    Applying a filter to the image

    Parameters
    ----------
    filter_name : str
        Name of the filter to apply ("gaussian", "median", or "bilateral")
    sigma : int
        Parameter controlling the extent of smoothing
    """
    if filter_name == "gaussian":
        print("Applying gaussian filter")
        self.img = filters.gaussian(
            self.img,
            sigma=sigma,
            channel_axis=2,
            **kwargs,
        )
    elif filter_name == "median":
        print("Applying median filter")
        if isinstance(sigma, float):
            sigma = int(sigma)
        for i in range(self.img.shape[2]):
            image_array = self.img[:, :, i]
            image_array_blurred = filters.median(
                image_array,
                np.ones((sigma, sigma)),
            )
            self.img[:, :, i] = image_array_blurred
    elif filter_name == "bilateral":
        print("Applying biltaral filter")
        self.img = denoise_bilateral(
            self.img, sigma_spatial=sigma, channel_axis=2, **kwargs
        )
    else:
        raise Exception(
            "filter name should be either gaussian, median or bilateral"
        )
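
A brief usage sketch, assuming `im` is an img object built as in the examples above:

# smooth each channel with a Gaussian kernel before clustering
im.blurring(filter_name="gaussian", sigma=2)
# alternatively, a median filter uses sigma as the footprint size:
# im.blurring(filter_name="median", sigma=3)
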
def calculate_non_zero_mean(self)

Calculate mean estimator for the given image array avoiding mask pixels or pixels with value 0

Parameters

Returns

mean_estimator : list of float
List of mean estimator (mean*pixel) values for each channel
pixels : int
pixel count for the image excluding masked pixels
Expand source code
def calculate_non_zero_mean(self):
    """
    Calculate mean estimator for the given image array avoiding mask pixels or
    pixels with value 0

    Parameters
    ----------

    Returns
    -------
    mean_estimator : list of float
        List of mean estimator (mean*pixel) values for each channel
    pixels : int
        pixel count for the image excluding masked pixels
    """
    image = self.img
    pixels = np.count_nonzero(image != 0)
    mean_estimator = []
    for i in range(image.shape[2]):
        ar = image[:, :, i]
        mean = ar[ar != 0].mean()
        mean_estimator.append(mean * pixels)
    return mean_estimator, pixels
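
A brief usage sketch, assuming `im` is an img object; the returned values feed the 'mean estimators' and 'pixels' columns expected by mxif_labeler:

mean_estimator, pixels = im.calculate_non_zero_mean()
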
def clip(self, **kwargs)

Clips outlier values and rescales intensities

Parameters

**kwargs
Keyword args to pass to clip_values() function

Returns

Clips outlier values from self.img

Expand source code
def clip(self, **kwargs):
    """
    Clips outlier values and rescales intensities

    Parameters
    ----------
    **kwargs
        Keyword args to pass to `clip_values()` function

    Returns
    -------
    Clips outlier values from `self.img`
    """
    self.img = clip_values(self.img, **kwargs)
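
A brief usage sketch, assuming `im` is an img object:

# clip outlier intensities (keyword args are forwarded to clip_values())
im.clip()
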
def copy(self) ‑> img

Full copy of img object

Expand source code
def copy(self) -> "img":
    """Full copy of img object"""
    new = {}
    for key in ["img", "ch", "mask"]:
        attr = self.__getattribute__(key)
        if attr is not None:
            new[key] = attr.copy()
        else:
            new[key] = None
    return img(img_arr=new["img"], channels=new["ch"], mask=new["mask"])
def create_tissue_mask(self, features=None, fract=0.2)

Create tissue mask

Parameters

features : list of int or str
Indices or names of MxIF channels to use for tissue labeling
fract : float, optional (default=0.2)
Fraction of cluster data from each image to randomly select for model building

Returns

Does not return anything. A tissue mask (np.array) is set to self.mask

Expand source code
def create_tissue_mask(self, features=None, fract=0.2):
    """
    Create tissue mask

    Parameters
    ----------
    features : list of int or str
        Indices or names of MxIF channels to use for tissue labeling
    fract : float, optional (default=0.2)
        Fraction of cluster data from each image to randomly select
        for model building

    Returns
    -------
    Does not return anything. A tissue mask (np.array) is set to `self.mask`
    """
    # create a copy of the image
    image_cp = self.copy()
    # create a temporary tissue mask that covers no region
    w, h, d = image_cp.img.shape
    image_cp.mask = np.ones((w, h))
    # log normalization on image
    image_cp.log_normalize()
    # apply gaussian filter
    image_cp.img = filters.gaussian(image_cp.img, sigma=2, channel_axis=2)
    # subsample data to build kmeans model
    subsampled_data = image_cp.subsample_pixels(features, fract=fract)
    cluster_data = np.row_stack(subsampled_data)
    # reshape image for prediction
    image_ar_reshape = image_cp.img.reshape((w * h, d))
    # build kmeans model with 2 clusters
    kmeans = KMeans(n_clusters=2, random_state=18).fit(cluster_data)
    labels = kmeans.predict(image_ar_reshape).astype(float)
    tID = labels.reshape((w, h))
    # check if the background is labelled as 0 or 1
    scores = kmeans.cluster_centers_
    mean = scores.mean()
    std = scores.std()
    z_scores = (scores - mean) / std
    if z_scores[0].mean() > 0:
        where_0 = np.where(tID == 0.0)
        tID[where_0] = 0.5
        where_1 = np.where(tID == 1.0)
        tID[where_1] = 0.0
        where_05 = np.where(tID == 0.5)
        tID[where_05] = 1.0
    self.mask = tID
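
A brief usage sketch, assuming `im` is an img object; the channel names are placeholders:

# derive a tissue mask by 2-cluster k-means on selected channels
im.create_tissue_mask(features=["DAPI", "BCATENIN"], fract=0.2)
# inspect the result by plotting with non-tissue pixels masked out
im.show(channels=["DAPI"], mask_out=True)
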
def downsample(self, fact, func=<function mean>)

Downsamples image by applying func to fact pixels in both directions from each pixel

Parameters

fact : int
Number of pixels in each direction (x & y) to downsample with
func : function
Numpy function to apply to squares of size (fact, fact, :) for downsampling (e.g. np.mean, np.max, np.sum)

Returns

self.img and self.mask are downsampled accordingly in place
 
Expand source code
def downsample(self, fact, func=np.mean):
    """
    Downsamples image by applying `func` to `fact` pixels in both directions from
    each pixel

    Parameters
    ----------
    fact : int
        Number of pixels in each direction (x & y) to downsample with
    func : function
        Numpy function to apply to squares of size (fact, fact, :) for downsampling
        (e.g. `np.mean`, `np.max`, `np.sum`)

    Returns
    -------
    self.img and self.mask are downsampled accordingly in place
    """
    # downsample mask if mask available
    if self.mask is not None:
        self.mask = block_reduce(
            self.mask, block_size=(fact, fact), func=func, cval=0
        )
    # downsample image
    self.img = block_reduce(self.img, block_size=(fact, fact, 1), func=func, cval=0)
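
A brief usage sketch, assuming `im` is an img object:

import numpy as np

# shrink the image (and mask, if set) by a factor of 8 in x and y, averaging each block
im.downsample(fact=8, func=np.mean)
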
def equalize_hist(self, **kwargs)

Contrast Limited Adaptive Histogram Equalization (CLAHE)

Parameters

**kwargs
Keyword args to pass to CLAHE() function

Returns

self.img is updated with exposure-adjusted values

Expand source code
def equalize_hist(self, **kwargs):
    """
    Contrast Limited Adaptive Histogram Equalization (CLAHE)

    Parameters
    ----------
    **kwargs
        Keyword args to pass to `CLAHE()` function

    Returns
    -------
    `self.img` is updated with exposure-adjusted values
    """
    self.img = CLAHE(self.img, **kwargs)
def log_normalize(self, pseudoval=1, mean=None, mask=True)

Log-normalizes values for each marker with log10(arr/arr.mean() + pseudoval)

Parameters

pseudoval : float
Value to add to image values prior to log-transforming to avoid issues with zeros
mean : list of float or None, optional (default=None)
Pre-computed per-channel normalization factors (e.g. batch-wise mean estimators). If None, the mean of each channel in self.img is used.
mask : bool, optional (default=True)
Use tissue mask to determine marker mean factor for normalization. Default True.

Returns

Log-normalizes values in each channel of self.img

Expand source code
def log_normalize(self, pseudoval=1, mean=None, mask=True):
    """
    Log-normalizes values for each marker with `log10(arr/arr.mean() + pseudoval)`

    Parameters
    ----------
    pseudoval : float
        Value to add to image values prior to log-transforming to avoid issues
        with zeros
    mean : list of float or None, optional (default=None)
        Pre-computed per-channel normalization factors (e.g. batch-wise mean
        estimators). If `None`, the mean of each channel in `self.img` is used.
    mask : bool, optional (default=True)
        Use tissue mask to determine marker mean factor for normalization. Default
        `True`.

    Returns
    -------
    Log-normalizes values in each channel of `self.img`
    """
    if mean is not None:
        if mask:
            assert self.mask is not None, "No tissue mask available"
            for i in range(self.img.shape[2]):
                fact = mean[i]
                self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
        else:
            print("WARNING: Performing normalization without a tissue mask.")
            for i in range(self.img.shape[2]):
                fact = mean[i]
                self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
    else:
        print("mean calculated to perform log normalization")
        if mask:
            assert self.mask is not None, "No tissue mask available"
            for i in range(self.img.shape[2]):
                fact = self.img[:, :, i].mean()
                self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
        else:
            print("WARNING: Performing normalization without a tissue mask.")
            for i in range(self.img.shape[2]):
                fact = self.img[:, :, i].mean()
                self.img[:, :, i] = np.log10(self.img[:, :, i] / fact + pseudoval)
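
A brief usage sketch, assuming `im` is an img object with a tissue mask already set:

# log10-transform each channel relative to its mean; pass mask=False to
# normalize without a tissue mask, or mean=[...] to supply batch-wise factors
im.log_normalize(pseudoval=1, mask=True)
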
def plot_image_histogram(self, channels=None, ncols=4, save_to=None, **kwargs)

Plot image histogram

Parameters

channels : tuple of int or None, optional (default=None)
List of channels by index or name to show
ncols : int
Number of columns for gridspec if plotting individual channels.
save_to : str or None
Path to image file to save results. If None, show figure.
**kwargs
Arguments to pass to plt.hist() function.

Returns

Gridspec object (for multiple features). Saves plot to file if save_to is not None.

Expand source code
def plot_image_histogram(self, channels=None, ncols=4, save_to=None, **kwargs):

    """
    Plot image histogram

    Parameters
    ----------
    channels : tuple of int or None, optional (default=`None`)
        List of channels by index or name to show
    ncols : int
        Number of columns for gridspec if plotting individual channels.
    save_to : str or None
        Path to image file to save results. If `None`, show figure.
    **kwargs
        Arguments to pass to `plt.hist()` function.

    Returns
    -------
    Gridspec object (for multiple features). Saves plot to file if `save_to` is
    not `None`.
    """
    # calculate gridspec dimensions
    if len(channels) <= ncols:
        n_rows, n_cols = 1, len(channels)
    else:
        n_rows, n_cols = ceil(len(channels) / ncols), ncols
    fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
    # arrange axes as subplots
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
    # add plots to axes
    i = 0
    for channel in channels:
        ax = plt.subplot(gs[i])
        data = self[channel].copy()
        ax.hist(data.ravel(), bins=100, **kwargs)
        ax.set_title(channel, fontweight="bold", fontsize=16)
        i = i + 1
    fig.tight_layout()
    plt.show()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return gs
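
A brief usage sketch, assuming `im` is an img object and that channels are passed by name (the channel names are placeholders):

# per-channel intensity histograms, two panels per row
im.plot_image_histogram(channels=["DAPI", "BCATENIN"], ncols=2)
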
def scale(self, **kwargs)

Scales intensities to [0.0, 1.0]

Parameters

**kwargs
Keyword args to pass to scale_rgb() function

Returns

Scales intensities of self.img

Expand source code
def scale(self, **kwargs):
    """
    Scales intensities to [0.0, 1.0]

    Parameters
    ----------
    **kwargs
        Keyword args to pass to `scale_rgb()` function

    Returns
    -------
    Scales intensities of `self.img`
    """
    self.img = scale_rgb(self.img, **kwargs)
def show(self, channels=None, RGB=False, cbar=False, mask_out=True, ncols=4, figsize=(7, 7), save_to=None, **kwargs)

Plot image

Parameters

channels : tuple of int or None, optional (default=None)
List of channels by index or name to show
RGB : bool
Treat 3- or 4-dimensional array as RGB image. If False, plot channels individually.
cbar : bool
Show colorbar for scale of image intensities if plotting individual channels.
mask_out : bool, optional (default=True)
Mask out non-tissue pixels prior to showing
ncols : int
Number of columns for gridspec if plotting individual channels.
figsize : tuple of float
Size in inches of output figure.
save_to : str or None
Path to image file to save results. If None, show figure.
**kwargs
Arguments to pass to plt.imshow() function.

Returns

Matplotlib object (if plotting one feature or RGB) or gridspec object (for multiple features). Saves plot to file if save_to is not None.

Expand source code
def show(
    self,
    channels=None,
    RGB=False,
    cbar=False,
    mask_out=True,
    ncols=4,
    figsize=(7, 7),
    save_to=None,
    **kwargs,
):
    """
    Plot image

    Parameters
    ----------
    channels : tuple of int or None, optional (default=`None`)
        List of channels by index or name to show
    RGB : bool
        Treat 3- or 4-dimensional array as RGB image. If `False`, plot channels
        individually.
    cbar : bool
        Show colorbar for scale of image intensities if plotting individual
        channels.
    mask_out : bool, optional (default=`True`)
        Mask out non-tissue pixels prior to showing
    ncols : int
        Number of columns for gridspec if plotting individual channels.
    figsize : tuple of float
        Size in inches of output figure.
    save_to : str or None
        Path to image file to save results. If `None`, show figure.
    **kwargs
        Arguments to pass to `plt.imshow()` function.

    Returns
    -------
    Matplotlib object (if plotting one feature or RGB) or gridspec object (for
    multiple features). Saves plot to file if `save_to` is not `None`.
    """
    # if only one feature (2D), plot it quickly
    if self.img.ndim == 2:
        fig = plt.figure(figsize=figsize)
        if self.mask is not None and mask_out:
            im_tmp = self.img.copy()  # make copy for masking
            im_tmp[:, :][self.mask == 0] = np.nan  # area outside mask to NaN
            plt.imshow(im_tmp, **kwargs)
        else:
            plt.imshow(self.img, **kwargs)
        plt.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        if cbar:
            plt.colorbar(shrink=0.8)
        plt.tight_layout()
        if save_to:
            plt.savefig(
                fname=save_to, transparent=True, bbox_inches="tight", dpi=800
            )
        return fig
    # if image has multiple channels, plot them in gridspec
    if isinstance(channels, int):  # force channels into list if single integer
        channels = [channels]
    if isinstance(channels, str):  # force channels into int if single string
        channels = [self.ch.index(channels)]
    if checktype(channels):  # force channels into list of int if list of strings
        channels = [self.ch.index(x) for x in channels]
    if channels is None:  # if no channels are given, use all of them
        channels = [x for x in range(self.n_ch)]
    assert (
        len(channels) <= self.n_ch
    ), "Too many channels given: image has {}, expected {}".format(
        self.n_ch, len(channels)
    )
    if RGB:
        # if third dim has 3 or 4 features, treat as RGB and plot it quickly
        assert (self.img.ndim == 3) & (
            len(channels) == 3
        ), "Need 3 dimensions and 3 given channels for an RGB image; shape = {}; channels given = {}".format(
            self.img.shape, len(channels)
        )
        fig = plt.figure(figsize=figsize)
        # rearrange channels to specified order
        im_tmp = np.dstack(
            [
                self.img[:, :, channels[0]],
                self.img[:, :, channels[1]],
                self.img[:, :, channels[2]],
            ]
        )
        if self.mask is not None and mask_out:
            for i in [0, 1, 2]:  # for 3-channel image
                im_tmp[:, :, i][self.mask == 0] = np.nan  # area outside mask NaN
        plt.imshow(im_tmp, **kwargs)
        # add legend for channel IDs
        custom_lines = [
            Line2D([0], [0], color=(1, 0, 0), lw=5),
            Line2D([0], [0], color=(0, 1, 0), lw=5),
            Line2D([0], [0], color=(0, 0, 1), lw=5),
        ]
        plt.legend(custom_lines, [self.ch[x] for x in channels], fontsize="medium")
        plt.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        plt.tight_layout()
        if save_to:
            plt.savefig(
                fname=save_to, transparent=True, bbox_inches="tight", dpi=300
            )
        return fig
    # calculate gridspec dimensions
    if len(channels) <= ncols:
        n_rows, n_cols = 1, len(channels)
    else:
        n_rows, n_cols = ceil(len(channels) / ncols), ncols
    fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
    # arrange axes as subplots
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
    # add plots to axes
    i = 0
    for channel in channels:
        ax = plt.subplot(gs[i])
        if self.mask is not None and mask_out:
            im_tmp = self.img[:, :, channel].copy()  # make copy for masking
            im_tmp[self.mask == 0] = np.nan  # area outside mask NaN
            im = ax.imshow(im_tmp, **kwargs)
        else:
            im = ax.imshow(self.img[:, :, channel], **kwargs)
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        ax.set_title(
            label=self.ch[channel],
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        if cbar:
            _ = plt.colorbar(im, shrink=0.8)
        i = i + 1
    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
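
A brief usage sketch, assuming `im` is an img object with at least three channels (channel names are placeholders):

# plot selected channels in a grid with colorbars, masking non-tissue pixels
im.show(channels=["DAPI", "BCATENIN", "ACTG1"], ncols=3, cbar=True)
# or overlay the first three channels as an RGB composite
im.show(channels=[0, 1, 2], RGB=True)
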
def subsample_pixels(self, features, fract=0.2, random_state=16)

Sub-samples fraction of pixels from the image randomly for each channel

Parameters

features : list of int or str
Indices or names of MxIF channels to use for tissue labeling
fract : float, optional (default=0.2)
Fraction of cluster data from each image to randomly select for model building
random_state : int, optional (default=16)
Seed for random pixel selection

Returns

tmp : np.array
Clustering data from image
Expand source code
def subsample_pixels(self, features, fract=0.2, random_state=16):
    """
    Sub-samples fraction of pixels from the image randomly for each channel

    Parameters
    ----------
    features : list of int or str
        Indices or names of MxIF channels to use for tissue labeling
    fract : float, optional (default=0.2)
        Fraction of cluster data from each image to randomly select
        for model building
    random_state : int, optional (default=16)
        Seed for random pixel selection

    Returns
    -------
    tmp : np.array
        Clustering data from `image`

    """
    if isinstance(features, int):  # force features into list if single integer
        features = [features]
    if isinstance(features, str):  # force features into int if single string
        features = [self.ch.index(features)]
    if checktype(features):  # force features into list of int if list of strings
        features = [self.ch.index(x) for x in features]
    if features is None:  # if no features are given, use all of them
        features = [x for x in range(self.n_ch)]
    # subsample data for given image
    np.random.seed(random_state)
    tmp = []
    for i in range(self.img.shape[2]):
        tmp.append(self.img[:, :, i][self.mask != 0])
    tmp = np.column_stack(tmp)
    # select cluster data
    i = np.random.choice(tmp.shape[0], int(tmp.shape[0] * fract))
    tmp = tmp[np.ix_(i, features)]
    return tmp
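
A brief usage sketch, assuming `im` is an img object with a tissue mask already set (channel names are placeholders):

# randomly sample 20% of in-mask pixels for the selected channels
pixel_sample = im.subsample_pixels(features=["DAPI", "BCATENIN"], fract=0.2)
print(pixel_sample.shape)  # (n_sampled_pixels, n_selected_features)
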
def to_npz(self, file)

Save img object to compressed .npz file

Parameters

file : str
Path to .npz file in which to save img object and metadata

Returns

Writes object to file

Expand source code
def to_npz(self, file):
    """
    Save img object to compressed `.npz` file

    Parameters
    ----------
    file : str
        Path to `.npz` file in which to save img object and metadata

    Returns
    -------
    Writes object to `file`
    """
    print("Saving img object to {}...".format(file))
    if self.mask is None:
        np.savez_compressed(file, img=self.img, ch=self.ch)
    else:
        np.savez_compressed(file, img=self.img, ch=self.ch, mask=self.mask)
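
A brief usage sketch, assuming `im` is an img object (the output path is a placeholder):

# write the image array, channel names, and mask (if set) to a compressed archive
im.to_npz("sample1_mxif.npz")
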
class mxif_labeler (image_df)

Tissue domain labeling class for multiplex immunofluorescence (MxIF) data

Initialize MxIF tissue labeler class

Parameters

image_df : pd.DataFrame object
DataFrame containing MILWRM.MxIF.img objects or str paths to compressed .npz files, batch names, mean estimators, and pixel counts for each image, in the following column order: ['Img', 'batch_names', 'mean estimators', 'pixels']

Returns

Does not return anything. self.image_df attribute is set, and self.cluster_data attribute is initiated as None.

Expand source code
class mxif_labeler(tissue_labeler):
    """
    Tissue domain labeling class for multiplex immunofluorescence (MxIF) data
    """

    def __init__(self, image_df):
        """
        Initialize MxIF tissue labeler class

        Parameters
        ----------
        image_df : pd.DataFrame object
            DataFrame containing MILWRM.MxIF.img objects or str paths to compressed
            .npz files, batch names, mean estimators, and pixel counts for each image,
            in the following column order: ['Img', 'batch_names', 'mean estimators', 'pixels']

        Returns
        -------
        Does not return anything. `self.image_df` attribute is set, and
        `self.cluster_data` attribute is initiated as `None`.
        """
        tissue_labeler.__init__(self)  # initialize parent class
        # validate the format of the image_df dataframe
        if np.all(
            image_df.columns == ["Img", "batch_names", "mean estimators", "pixels"]
        ):
            self.image_df = image_df
        else:
            raise Exception(
                "Image_df must be given with these columns in this format ['Img', 'batch_names', 'mean estimators', 'pixels']"
            )
        if self.image_df["Img"].apply(isinstance, args=[img]).all():
            self.use_paths = False
        elif self.image_df["Img"].apply(isinstance, args=[str]).all():
            self.use_paths = True
        else:
            raise Exception(
                "Img column in the dataframe should be either str for paths to the files or mxif.img object"
            )

    def prep_cluster_data(
        self, features, filter_name="gaussian", sigma=2, fract=0.2, path_save=None
    ):
        """
        Prepare master array for tissue level clustering

        Parameters
        ----------
        features : list of int or str
            Indices or names of MxIF channels to use for tissue labeling
        filter_name : str
            Name of the filter to use - gaussian, median or bilateral
        sigma : float, optional (default=2)
            Standard deviation of Gaussian kernel for blurring
        fract : float, optional (default=0.2)
            Fraction of cluster data from each image to randomly select for model
            building
        path_save : str, optional (default=None)
            Path to save final preprocessed files. Required when `self.use_paths` is
            `True`; leaving the default `None` in that case raises an Exception

        Returns
        -------
        Does not return anything. The images in `self.image_df` are normalized, blurred and scaled
        according to user parameters. `self.cluster_data` becomes master `np.array`
        for cluster training. Parameters are also captured as attributes for posterity.

        """
        if self.cluster_data is not None:
            print("WARNING: overwriting existing cluster data")
            self.cluster_data = None
        # save the hyperparams as object attributes
        self.model_features = features
        use_path = self.use_paths
        # calculate the batch wise means
        mean_for_each_batch = {}
        for batch in self.image_df["batch_names"].unique():
            list_mean_estimators = list(
                self.image_df[self.image_df["batch_names"] == batch]["mean estimators"]
            )
            mean_estimator_batch = sum(map(np.array, list_mean_estimators))
            pixels = sum(self.image_df[self.image_df["batch_names"] == batch]["pixels"])
            mean_for_each_batch[batch] = mean_estimator_batch / pixels
        # log_normalize, apply blurring filter, minmax scale each channel and subsample
        subsampled_data = []
        path_to_blurred_npz = []
        for image, batch in zip(self.image_df["Img"], self.image_df["batch_names"]):
            tmp = prep_data_single_sample_mxif(
                image,
                use_path=use_path,
                mean=mean_for_each_batch[batch],
                filter_name=filter_name,
                sigma=sigma,
                features=self.model_features,
                fract=fract,
                path_save=path_save,
            )
            if self.use_paths == True:
                subsampled_data.append(tmp[0])
                path_to_blurred_npz.append(tmp[1])
            else:
                subsampled_data.append(tmp)
        batch_labels = [
            [x] * len(subsampled_data[x]) for x in range(len(subsampled_data))
        ]  # batch labels for umap
        self.merged_batch_labels = list(itertools.chain(*batch_labels))
        if self.use_paths == True:
            self.image_df["Img"] = path_to_blurred_npz
        cluster_data = np.row_stack(subsampled_data)
        # perform z-score normalization on cluster_data
        scaler = StandardScaler()
        self.scaler = scaler.fit(cluster_data)
        scaled_data = scaler.transform(cluster_data)
        self.cluster_data = scaled_data

    def label_tissue_regions(
        self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1
    ):
        """
        Perform tissue-level clustering and label pixels in the corresponding
        images.

        Parameters
        ----------
        k : int, optional (default=None)
            Number of tissue regions to define. If `None`, an optimal k is
            determined via scaled inertia
        alpha : float, optional (default=0.05)
            Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
        plot_out : boolean, optional (default=True)
            Determines if scaled inertia plot should be output
        random_state : int, optional (default=18)
            Seed for k-means clustering model
        n_jobs : int
            Number of cores to parallelize k-choosing and tissue domain assignment across.
            Default all available cores.

        Returns
        -------
        Does not return anything. `self.tissue_ID` is added, containing image with
        final tissue region IDs. `self.kmeans` contains trained `sklearn` clustering
        model. Parameters are also captured as attributes for posterity.
        """
        # save the hyperparams as object attributes
        use_path = self.use_paths
        # find optimal k with parent class
        if k is None:
            print("Determining optimal cluster number k via scaled inertia")
            self.find_optimal_k(
                alpha=alpha,
                plot_out=plot_out,
                random_state=random_state,
                n_jobs=n_jobs,
            )
        # call k-means model from parent class
        self.find_tissue_regions(k=k, random_state=random_state)
        # loop through image objects and create tissue label images
        print("Creating tissue_ID images for image objects...")
        self.tissue_IDs = Parallel(n_jobs=n_jobs, verbose=10)(
            delayed(add_tissue_ID_single_sample_mxif)(
                image, use_path, self.model_features, self.kmeans, self.scaler
            )
            for image in self.image_df["Img"]
        )

    def plot_percentage_variance_explained(
        self, fig_size=(5, 5), R_square=False, save_to=None
    ):
        """
        Plot the percentage of variance explained (or not explained) by the k-means clustering

        Parameters
        ----------
        fig_size : tuple
            Size of the output figure
        R_square : bool
            If `True`, plot R-square (percentage of variance explained); otherwise
            plot S-square (percentage of variance not explained)
        save_to : str or None
            Path to image file to save results. If `None`, show figure.

        Returns
        -------
        Matplotlib object
        """
        scaler = self.scaler
        centroids = self.kmeans.cluster_centers_
        features = self.model_features
        use_path = self.use_paths
        S_squre_for_each_image = []
        R_squre_for_each_image = []
        for image, tissue_ID in zip(self.image_df["Img"], self.tissue_IDs):
            S_square = estimate_percentage_variance_mxif(
                image, use_path, scaler, centroids, features, tissue_ID
            )
            S_squre_for_each_image.append(S_square)
            R_squre_for_each_image.append(100 - S_square)

        if R_square:
            fig = plt.figure(figsize=fig_size)
            plt.scatter(
                range(len(R_squre_for_each_image)),
                R_squre_for_each_image,
                color="black",
            )
            plt.xlabel("images")
            plt.ylabel("percentage variance explained by Kmeans")
            plt.ylim((0, 100))
            plt.axhline(
                y=np.mean(R_squre_for_each_image),
                linestyle="dashed",
                linewidth=1,
                color="black",
            )

        else:
            fig = plt.figure(figsize=fig_size)
            plt.scatter(
                range(len(S_squre_for_each_image)),
                S_squre_for_each_image,
                color="black",
            )
            plt.xlabel("images")
            plt.ylabel("percentage variance not explained by Kmeans")
            plt.ylim((0, 100))
            plt.axhline(
                y=np.mean(S_squre_for_each_image),
                linestyle="dashed",
                linewidth=1,
                color="black",
            )

        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig

    def confidence_score_images(self):
        """
        Estimate confidence score for each image

        Parameters
        ----------

        Returns
        -------
        `self.confidence_IDs` and `self.confidence_score_df` are added, containing the
        confidence score for each tissue domain assignment and the mean confidence
        score for each tissue domain within each image, respectively
        """
        scaler = self.scaler
        centroids = self.kmeans.cluster_centers_
        features = self.model_features
        tissue_IDs = self.tissue_IDs
        use_path = self.use_paths
        # confidence score estimation for each image
        confidence_IDs = []
        confidence_score_df = pd.DataFrame()
        for i, image in enumerate(self.image_df["Img"]):
            cID, scores_dict = estimate_confidence_score_mxif(
                image, use_path, scaler, centroids, features, tissue_IDs[i]
            )
            confidence_IDs.append(cID)
            df = pd.DataFrame(scores_dict.values(), columns=[i])
            confidence_score_df = pd.concat(
                [confidence_score_df, df.T], ignore_index=True
            )
        # adding confidence_IDs and confidence_score_df to tissue labeller object
        self.confidence_IDs = confidence_IDs
        self.confidence_score_df = confidence_score_df

    def plot_mse_mxif(
        self,
        figsize=(5, 5),
        ncols=None,
        labels=None,
        legend_cols=2,
        titles=None,
        loc="lower right",
        bbox_coordinates=(0, 0, 1.5, 1.5),
        save_to=None,
    ):
        """
        Estimate and plot the mean square error for each marker within each tissue domain

        Parameters
        ----------
        figsize : tuple, optional (default=(5, 5))
            Size of each subplot panel
        ncols : int, optional (default=`None`)
            Number of columns for gridspec. If `None`, uses number of tissue domains k.
        labels : list of str, optional (default=`None`)
            Labels corresponding to each image in legend. If `None`, numeric index is
            used for each image
        legend_cols : int, optional (default = `2`)
            n_cols for legend
        titles : list of str, optional (default=`None`)
            Titles of plots corresponding to each MILWRM domain. If `None`, titles
            will be numbers 0 through k.
        loc : str, optional (default = 'lower right')
            str for legend position
        bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
            coordinates for the legend box
        save_to : str, optional (default=`None`)
            Path to image file to save plot

        Returns
        -------
        Matplotlib object
        """
        assert (
            self.kmeans is not None
        ), "No cluster results found. Run label_tissue_regions() first."
        images = self.image_df["Img"]
        use_path = self.use_paths
        scaler = self.scaler
        centroids = self.kmeans.cluster_centers_
        features = self.model_features
        k = self.k
        tissue_IDs = self.tissue_IDs
        mse_id = estimate_mse_mxif(
            images, use_path, tissue_IDs, scaler, centroids, features, k
        )
        if labels is None:
            labels = range(len(images))
        if titles is None:
            titles = ["tissue_ID " + str(x) for x in range(self.k)]
        n_panels = len(mse_id.keys())
        if ncols is None:
            ncols = len(titles)
        if n_panels <= ncols:
            n_rows, n_cols = 1, n_panels
        else:
            n_rows, n_cols = ceil(n_panels / ncols), ncols
        colors = plt.cm.tab20(np.linspace(0, 1, len(images)))
        fig = plt.figure(figsize=(n_cols * figsize[0], n_rows * figsize[1]))
        left, bottom = 0.1 / n_cols, 0.1 / n_rows
        gs = gridspec.GridSpec(
            nrows=n_rows,
            ncols=n_cols,
            left=left,
            bottom=bottom,
            right=1 - (n_cols - 1) * left - 0.01 / n_cols,
            top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
        )
        for i in mse_id.keys():
            plt.subplot(gs[i])
            df = pd.DataFrame.from_dict(mse_id[i])
            plt.boxplot(df, positions=range(len(features)), showfliers=False)
            plt.xticks(
                ticks=range(len(features)),
                labels=self.model_features,
                rotation=60,
                fontsize=8,
            )
            for col in df:
                for k in range(len(images)):
                    dots = plt.scatter(
                        col,
                        df[col][k],
                        s=k + 1,
                        color=colors[k],
                        label=labels[k] if col == 0 else "",
                    )
                    offsets = dots.get_offsets()
                    jittered_offsets = offsets
                    # only jitter in the x-direction
                    jittered_offsets[:, 0] += np.random.uniform(
                        -0.3, 0.3, offsets.shape[0]
                    )
                    dots.set_offsets(jittered_offsets)
            plt.xlabel("marker")
            plt.ylabel("mean square error")
            plt.title(titles[i])
        plt.legend(loc=loc, bbox_to_anchor=bbox_coordinates, ncol=legend_cols)
        gs.tight_layout(fig)
        if save_to:
            plt.savefig(fname=save_to, transparent=True, dpi=300)
        return fig

    def plot_tissue_ID_proportions_mxif(
        self,
        tID_labels=None,
        slide_labels=None,
        figsize=(5, 5),
        cmap="tab20",
        save_to=None,
    ):
        """
        Plot proportion of each tissue domain within each slide

        Parameters
        ----------
        tID_labels : list of str, optional (default=`None`)
            List of labels corresponding to MILWRM tissue domains for plotting legend
        slide_labels : list of str, optional (default=`None`)
            List of labels for each slide batch for labeling x-axis
        figsize : tuple of float, optional (default=(5,5))
            Size of matplotlib figure
        cmap : str, optional (default = `"tab20"`)
            Matplotlib colormap used to color tissue domains
        save_to : str, optional (default=`None`)
            Path to image file to save plot

        Returns
        -------
        `gridspec.GridSpec` if `save_to` is `None`, else saves plot to file
        """
        df_count = pd.DataFrame()
        for i in range(len(self.tissue_IDs)):
            unique, counts = np.unique(self.tissue_IDs[i], return_counts=True)
            dict_ = dict(zip(unique, counts))
            n_counts = []
            for k in range(self.k):
                if k not in dict_.keys():
                    n_counts.append(0)
                else:
                    n_counts.append(dict_[k])
            df = pd.DataFrame(n_counts, columns=[i])
            df_count = pd.concat([df_count, df], axis=1)
        df_count = df_count / df_count.sum()
        if tID_labels:
            assert (
                len(tID_labels) == df_count.shape[1]
            ), "Length of given tissue domain labels does not match number of tissue domains!"
            df_count.columns = tID_labels
        if slide_labels:
            assert (
                len(slide_labels) == df_count.shape[0]
            ), "Length of given slide labels does not match number of slides!"
            df_count.index = slide_labels
        self.tissue_ID_proportion = df_count
        ax = df_count.T.plot.bar(stacked=True, cmap=cmap, figsize=figsize)
        ax.legend(loc="best", bbox_to_anchor=(1, 1))
        ax.set_xlabel("images")
        ax.set_ylabel("tissue domain proportion")
        ax.set_ylim((0, 1))
        plt.tight_layout()
        if save_to is not None:
            ax.figure.savefig(save_to)
        else:
            return ax

    def make_umap(self, frac=None, cmap="tab20", save_to=None, alpha=0.8, dot_size_batch=0.1):
        """
        Plot UMAP embedding of the cluster data

        Parameters
        ----------
        frac : None or float
            If `None`, the entire cluster data is used to compute the UMAP;
            otherwise, that fraction of the cluster data is used.
        cmap : str
            Matplotlib colormap used for plotting. Default `"tab20"`.
        save_to : str or None
            Path to image file to save results. If `None`, show figure.
        alpha : float
            Opacity of UMAP scatter plot points (default=`0.8`)
        dot_size_batch : float
            Scatter plot dot size for the batch-label panel (default=`0.1`)

        Returns
        -------
        Matplotlib object
        """
        cluster_data = self.cluster_data
        centroids = self.kmeans.cluster_centers_
        batch_labels = self.merged_batch_labels
        kmeans_labels = self.kmeans.labels_
        k = self.k
        # perform umap on the cluster data
        umap_centroid_data, standard_embedding_1 = perform_umap(
            cluster_data=cluster_data,
            centroids=centroids,
            batch_labels=batch_labels,
            kmeans_labels=kmeans_labels,
            frac=frac,
        )
        # defining a size of datapoints for scatter plot and tick labels
        size = [0.01] * len(umap_centroid_data.index)
        size[-k:] = [10] * k
        ticks = np.unique(np.array(umap_centroid_data["Kmeans_labels"]))
        tick_label = list(np.unique(np.array(umap_centroid_data["Kmeans_labels"])))
        tick_label[-1] = "centroids"
        # plotting a fig with two subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
        # defining color_map
        disc_cmap_1 = plt.cm.get_cmap(
            cmap, len(np.unique(np.array(umap_centroid_data.index)))
        )
        disc_cmap_2 = plt.cm.get_cmap(
            cmap, len(np.unique(np.array(umap_centroid_data["Kmeans_labels"])))
        )
        plot_1 = ax1.scatter(
            standard_embedding_1[:, 0],
            standard_embedding_1[:, 1],
            s=dot_size_batch,
            c=umap_centroid_data.index,
            cmap=disc_cmap_1,
            alpha=alpha,
        )
        ax1.set_title("UMAP with batch labels", fontsize=24)
        ax1.set_xlabel("UMAP 1")
        ax1.set_ylabel("UMAP 2")
        ax1.set_xticks([])
        ax1.set_yticks([])
        cbar_1 = plt.colorbar(plot_1, ax=ax1)

        plot_2 = ax2.scatter(
            standard_embedding_1[:, 0],
            standard_embedding_1[:, 1],
            s=size,
            c=umap_centroid_data["Kmeans_labels"],
            cmap=disc_cmap_2,
            alpha=alpha,
        )
        ax2.set_title("UMAP with tissue domains", fontsize=24)
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax2.set_xlabel("UMAP 1")
        ax2.set_ylabel("UMAP 2")
        cbar_2 = plt.colorbar(plot_2, ax=ax2, ticks=ticks)
        cbar_2.ax.set_yticklabels(tick_label)
        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig

    def show_marker_overlay(
        self,
        image_index,
        channels=None,
        cmap="Set1",
        mask_out=True,
        ncols=4,
        save_to=None,
        **kwargs,
    ):
        """
        Plot tissue_ID with individual markers as alpha values to distinguish
        expression in identified tissue domains

        Parameters
        ----------
        image_index : int
            Index of image from `self.images` to plot overlays for (e.g. 0 for first
            image)
        channels : tuple of int or None, optional (default=`None`)
            List of channels by index or name to show
        cmap : str, optional (default="Set1")
            Matplotlib colormap to use for plotting tissue domains
        mask_out : bool, optional (default=`True`)
            Mask out non-tissue pixels prior to showing
        ncols : int
            Number of columns for gridspec if plotting individual channels.
        save_to : str or None
            Path to image file to save results. If `None`, show figure.
        **kwargs
            Arguments to pass to `plt.imshow()` function.

        Returns
        -------
        Matplotlib object (if plotting one feature or RGB) or gridspec object (for
        multiple features). Saves plot to file if `save_to` is not `None`.
        """
        # if image has multiple channels, plot them in gridspec
        if isinstance(channels, int):  # force channels into list if single integer
            channels = [channels]
        if isinstance(channels, str):  # force channels into int if single string
            channels = [self[image_index].ch.index(channels)]
        if checktype(channels):  # force channels into list of int if list of strings
            channels = [self[image_index].ch.index(x) for x in channels]
        if channels is None:  # if no channels are given, use all of them
            channels = [x for x in range(self[image_index].n_ch)]
        assert (
            len(channels) <= self[image_index].n_ch
        ), "Too many channels given: image has {}, expected {}".format(
            self[image_index].n_ch, len(channels)
        )
        # creating a copy of the image
        image_cp = self[image_index].copy()
        # re-scaling to set pixel value range between 0 to 1
        image_cp.scale()
        # defining cmap for discrete color bar
        cmap = plt.cm.get_cmap(cmap, self.k)
        # calculate gridspec dimensions
        if len(channels) + 1 <= ncols:
            n_rows, n_cols = 1, len(channels) + 1
        else:
            n_rows, n_cols = ceil((len(channels) + 1) / ncols), ncols
        fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
        # arrange axes as subplots
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
        # plot tissue_ID first with colorbar
        ax = plt.subplot(gs[0])
        im = ax.imshow(self.tissue_IDs[image_index], cmap=cmap, **kwargs)
        ax.set_title(
            label="tissue_ID",
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        # colorbar scale for tissue_IDs
        _ = plt.colorbar(im, ticks=range(self.k), shrink=0.7)
        # add plots to axes
        i = 1
        for channel in channels:
            ax = plt.subplot(gs[i])
            # make copy for alpha
            im_tmp = image_cp.img[:, :, channel].copy()
            if self[image_index].mask is not None and mask_out:
                # area outside mask NaN
                self.tissue_IDs[image_index][self[image_index].mask == 0] = np.nan
                im = ax.imshow(
                    self.tissue_IDs[image_index], cmap=cmap, alpha=im_tmp, **kwargs
                )
            else:
                ax.imshow(self.tissue_IDs[image_index], alpha=im_tmp, **kwargs)
            ax.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            ax.set_title(
                label=self[image_index].ch[channel],
                loc="left",
                fontweight="bold",
                fontsize=16,
            )
            i = i + 1
        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig

Ancestors

MILWRM.MILWRM.tissue_labeler

Methods

def confidence_score_images(self)

estimate confidence score for each image

Parameters

Returns

self.confidence_IDs and self.confidence_score_df are added, containing the
confidence score for each tissue domain assignment and the mean confidence score
for each tissue domain within each image.

Expand source code
def confidence_score_images(self):
    """
    estimate confidence score for each image

    Parameters
    ----------

    Returns
    -------
    self.confidence_IDs and self.confidence_score_df are added containing
    confidence score for each tissue domain assignment and mean confidence score for
    each tissue domain within each image
    """
    scaler = self.scaler
    centroids = self.kmeans.cluster_centers_
    features = self.model_features
    tissue_IDs = self.tissue_IDs
    use_path = self.use_paths
    # confidence score estimation for each image
    confidence_IDs = []
    confidence_score_df = pd.DataFrame()
    for i, image in enumerate(self.image_df["Img"]):
        cID, scores_dict = estimate_confidence_score_mxif(
            image, use_path, scaler, centroids, features, tissue_IDs[i]
        )
        confidence_IDs.append(cID)
        df = pd.DataFrame(scores_dict.values(), columns=[i])
        confidence_score_df = pd.concat(
            [confidence_score_df, df.T], ignore_index=True
        )
    # adding confidence_IDs and confidence_score_df to tissue labeller object
    self.confidence_IDs = confidence_IDs
    self.confidence_score_df = confidence_score_df
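
A brief usage sketch (assuming `mx_labeler` is an mxif_labeler instance on which prep_cluster_data() and label_tissue_regions() have already been run; the variable name is illustrative):

# estimate per-pixel confidence for every image in the labeler
mx_labeler.confidence_score_images()
confidence_images = mx_labeler.confidence_IDs    # list of confidence images, one per input image
scores = mx_labeler.confidence_score_df          # mean confidence per tissue domain, one row per image
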
def label_tissue_regions(self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1)

Perform tissue-level clustering and label pixels in the corresponding images.

Parameters

k : int, optional (default=None)
Number of tissue regions to define
alpha : float
Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
plot_out : boolean, optional (default=True)
Determines if scaled inertia plot should be output
random_state : int, optional (default=18)
Seed for k-means clustering model
n_jobs : int
Number of cores to parallelize k-choosing and tissue domain assignment across. Default all available cores.

Returns

Does not return anything. self.tissue_IDs is added, containing images with final tissue region IDs. self.kmeans contains trained sklearn clustering model. Parameters are also captured as attributes for posterity.

Expand source code
def label_tissue_regions(
    self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1
):
    """
    Perform tissue-level clustering and label pixels in the corresponding
    images.

    Parameters
    ----------
    k : int, optional (default=None)
        Number of tissue regions to define
    alpha: float
        Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
    plot_out : boolean, optional (default=True)
        Determines if scaled inertia plot should be output
    random_state : int, optional (default=18)
        Seed for k-means clustering model
    n_jobs : int
        Number of cores to parallelize k-choosing and tissue domain assignment across.
        Default all available cores.

    Returns
    -------
    Does not return anything. `self.tissue_IDs` is added, containing images with
    final tissue region IDs. `self.kmeans` contains trained `sklearn` clustering
    model. Parameters are also captured as attributes for posterity.
    """
    # save the hyperparams as object attributes
    use_path = self.use_paths
    # find optimal k with parent class
    if k is None:
        print("Determining optimal cluster number k via scaled inertia")
        self.find_optimal_k(
            alpha=alpha,
            plot_out=plot_out,
            random_state=random_state,
            n_jobs=n_jobs,
        )
    # call k-means model from parent class
    self.find_tissue_regions(k=k, random_state=random_state)
    # loop through image objects and create tissue label images
    print("Creating tissue_ID images for image objects...")
    self.tissue_IDs = Parallel(n_jobs=n_jobs, verbose=10)(
        delayed(add_tissue_ID_single_sample_mxif)(
            image, use_path, self.model_features, self.kmeans, self.scaler
        )
        for image in self.image_df["Img"]
    )
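
A hedged usage sketch (continuing with the illustrative `mx_labeler` instance):

# let MILWRM pick k via scaled inertia, or pass k explicitly to skip the search
mx_labeler.label_tissue_regions(k=None, alpha=0.05, plot_out=True, n_jobs=4)
# mx_labeler.label_tissue_regions(k=5, random_state=18)
tissue_images = mx_labeler.tissue_IDs  # list of tissue_ID images, one per input image
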
def make_umap(self, frac=None, cmap='tab20', save_to=None, alpha=0.8, dot_size_batch=0.1)

plot umap for the cluster data

Parameters

frac : None or float
If None, the entire cluster data is used to compute the UMAP; otherwise, that fraction of the cluster data is used.
cmap : str
str for cmap used for plotting. Default "tab20".
save_to : str or None
Path to image file to save results. If None, show figure.
alpha : float
Opacity of the UMAP scatter plot (default=0.8)
dot_size_batch : float
Scatter plot dot size (default=0.1)

Returns

Matplotlib object
 
Expand source code
def make_umap(self, frac=None, cmap="tab20", save_to=None, alpha=0.8, dot_size_batch=0.1):
    """
    plot umap for the cluster data

    Parameters
    ----------
    frac : None or float
        If `None`, the entire cluster data is used to compute the UMAP;
        otherwise, that fraction of the cluster data is used.
    cmap : str
        str for cmap used for plotting. Default `"tab20"`.
    save_to : str or None
        Path to image file to save results. If `None`, show figure.
    alpha : float
        Opacity of the UMAP scatter plot (default=`0.8`)
    dot_size_batch : float
        Scatter plot dot size (default=`0.1`)

    Returns
    -------
    Matplotlib object
    """
    cluster_data = self.cluster_data
    centroids = self.kmeans.cluster_centers_
    batch_labels = self.merged_batch_labels
    kmeans_labels = self.kmeans.labels_
    k = self.k
    # perform umap on the cluster data
    umap_centroid_data, standard_embedding_1 = perform_umap(
        cluster_data=cluster_data,
        centroids=centroids,
        batch_labels=batch_labels,
        kmeans_labels=kmeans_labels,
        frac=frac,
    )
    # defining a size of datapoints for scatter plot and tick labels
    size = [0.01] * len(umap_centroid_data.index)
    size[-k:] = [10] * k
    ticks = np.unique(np.array(umap_centroid_data["Kmeans_labels"]))
    tick_label = list(np.unique(np.array(umap_centroid_data["Kmeans_labels"])))
    tick_label[-1] = "centroids"
    # plotting a fig with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    # defining color_map
    disc_cmap_1 = plt.cm.get_cmap(
        cmap, len(np.unique(np.array(umap_centroid_data.index)))
    )
    disc_cmap_2 = plt.cm.get_cmap(
        cmap, len(np.unique(np.array(umap_centroid_data["Kmeans_labels"])))
    )
    plot_1 = ax1.scatter(
        standard_embedding_1[:, 0],
        standard_embedding_1[:, 1],
        s=dot_size_batch,
        c=umap_centroid_data.index,
        cmap=disc_cmap_1,
        alpha=alpha,
    )
    ax1.set_title("UMAP with batch labels", fontsize=24)
    ax1.set_xlabel("UMAP 1")
    ax1.set_ylabel("UMAP 2")
    ax1.set_xticks([])
    ax1.set_yticks([])
    cbar_1 = plt.colorbar(plot_1, ax=ax1)

    plot_2 = ax2.scatter(
        standard_embedding_1[:, 0],
        standard_embedding_1[:, 1],
        s=size,
        c=umap_centroid_data["Kmeans_labels"],
        cmap=disc_cmap_2,
        alpha=alpha,
    )
    ax2.set_title("UMAP with tissue domains", fontsize=24)
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xlabel("UMAP 1")
    ax2.set_ylabel("UMAP 2")
    cbar_2 = plt.colorbar(plot_2, ax=ax2, ticks=ticks)
    cbar_2.ax.set_yticklabels(tick_label)
    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
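
A small usage sketch for the UMAP overview (values are illustrative):

# embed a subsample of the clustering data (frac=0.2, interpreted as a fraction)
# and color points by batch and by tissue domain
fig = mx_labeler.make_umap(frac=0.2, cmap="tab20", alpha=0.6, dot_size_batch=0.5)
fig.savefig("milwrm_umap.png", dpi=300, bbox_inches="tight")
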
def plot_mse_mxif(self, figsize=(5, 5), ncols=None, labels=None, legend_cols=2, titles=None, loc='lower right', bbox_coordinates=(0, 0, 1.5, 1.5), save_to=None)

Estimate and plot the mean square error within each tissue domain

Parameters

figsize : tuple of float, optional (default=(5, 5))
Size of each panel of the plot
ncols : int, optional (default=None)
Number of columns for gridspec. If None, uses number of tissue domains k.
labels : list of str, optional (default=None)
Labels corresponding to each image in legend. If None, numeric index is used for each image
legend_cols : int, optional (default=2)
n_cols for legend
titles : list of str, optional (default=None)
Titles of plots corresponding to each MILWRM domain. If None, titles will be numbers 0 through k.
loc : str, optional (default = 'lower right')
str for legend position
bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
coordinates for the legend box
save_to : str, optional (default=None)
Path to image file to save plot

Returns

Matplotlib object
 
Expand source code
def plot_mse_mxif(
    self,
    figsize=(5, 5),
    ncols=None,
    labels=None,
    legend_cols=2,
    titles=None,
    loc="lower right",
    bbox_coordinates=(0, 0, 1.5, 1.5),
    save_to=None,
):
    """
    Estimate and plot the mean square error within each tissue domain

    Parameters
    ----------
    figsize : tuple of float, optional (default=(5, 5))
        Size of each panel of the plot
    ncols : int, optional (default=`None`)
        Number of columns for gridspec. If `None`, uses number of tissue domains k.
    labels : list of str, optional (default=`None`)
        Labels corresponding to each image in legend. If `None`, numeric index is
        used for each image
    legend_cols : int, optional (default=`2`)
        n_cols for legend
    titles : list of str, optional (default=`None`)
        Titles of plots corresponding to each MILWRM domain. If `None`, titles
        will be numbers 0 through k.
    loc : str, optional (default = 'lower right')
        str for legend position
    bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
        coordinates for the legend box
    save_to : str, optional (default=`None`)
        Path to image file to save plot

    Returns
    -------
    Matplotlib object
    """
    assert (
        self.kmeans is not None
    ), "No cluster results found. Run \
    label_tissue_regions() first."
    images = self.image_df["Img"]
    use_path = self.use_paths
    scaler = self.scaler
    centroids = self.kmeans.cluster_centers_
    features = self.model_features
    k = self.k
    tissue_IDs = self.tissue_IDs
    mse_id = estimate_mse_mxif(
        images, use_path, tissue_IDs, scaler, centroids, features, k
    )
    if labels is None:
        labels = range(len(images))
    if titles is None:
        titles = ["tissue_ID " + str(x) for x in range(self.k)]
    n_panels = len(mse_id.keys())
    if ncols is None:
        ncols = len(titles)
    if n_panels <= ncols:
        n_rows, n_cols = 1, n_panels
    else:
        n_rows, n_cols = ceil(n_panels / ncols), ncols
    colors = plt.cm.tab20(np.linspace(0, 1, len(images)))
    fig = plt.figure(figsize=(n_cols * figsize[0], n_rows * figsize[1]))
    left, bottom = 0.1 / n_cols, 0.1 / n_rows
    gs = gridspec.GridSpec(
        nrows=n_rows,
        ncols=n_cols,
        left=left,
        bottom=bottom,
        right=1 - (n_cols - 1) * left - 0.01 / n_cols,
        top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
    )
    for i in mse_id.keys():
        plt.subplot(gs[i])
        df = pd.DataFrame.from_dict(mse_id[i])
        plt.boxplot(df, positions=range(len(features)), showfliers=False)
        plt.xticks(
            ticks=range(len(features)),
            labels=self.model_features,
            rotation=60,
            fontsize=8,
        )
        for col in df:
            for j in range(len(images)):  # use j to avoid shadowing k (number of tissue domains)
                dots = plt.scatter(
                    col,
                    df[col][j],
                    s=j + 1,
                    color=colors[j],
                    label=labels[j] if col == 0 else "",
                )
                offsets = dots.get_offsets()
                jittered_offsets = offsets
                # only jitter in the x-direction
                jittered_offsets[:, 0] += np.random.uniform(
                    -0.3, 0.3, offsets.shape[0]
                )
                dots.set_offsets(jittered_offsets)
        plt.xlabel("marker")
        plt.ylabel("mean square error")
        plt.title(titles[i])
    plt.legend(loc=loc, bbox_to_anchor=bbox_coordinates, ncol=legend_cols)
    gs.tight_layout(fig)
    if save_to:
        plt.savefig(fname=save_to, transparent=True, dpi=300)
    return fig
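
A usage sketch for the per-domain error plot (labels and file name are illustrative):

fig = mx_labeler.plot_mse_mxif(
    figsize=(4, 4),
    labels=["slide_A", "slide_B"],  # one label per input image
    save_to="mse_per_domain.png",
)
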
def plot_percentage_variance_explained(self, fig_size=(5, 5), R_square=False, save_to=None)

Plot the percentage of variance explained (R_square) or not explained (S_square) by clustering

Parameters

fig_size : Tuple
Size of the figure
R_square : Boolean
If True, plot R_square (percentage of variance explained by clustering); otherwise plot S_square (percentage not explained)
save_to : str or None
Path to image file to save results. If None, show figure.

Returns

Matplotlib object
 
Expand source code
def plot_percentage_variance_explained(
    self, fig_size=(5, 5), R_square=False, save_to=None
):
    """
    Plot the percentage of variance explained (R_square) or not explained
    (S_square) by clustering

    Parameters
    ----------
    fig_size : Tuple
        Size of the figure
    R_square : Boolean
        If `True`, plot R_square (percentage of variance explained by clustering);
        otherwise plot S_square (percentage not explained)
    save_to : str or None
        Path to image file to save results. If `None`, show figure.

    Returns
    -------
    Matplotlib object
    """
    scaler = self.scaler
    centroids = self.kmeans.cluster_centers_
    features = self.model_features
    use_path = self.use_paths
    S_squre_for_each_image = []
    R_squre_for_each_image = []
    for image, tissue_ID in zip(self.image_df["Img"], self.tissue_IDs):
        S_square = estimate_percentage_variance_mxif(
            image, use_path, scaler, centroids, features, tissue_ID
        )
        S_squre_for_each_image.append(S_square)
        R_squre_for_each_image.append(100 - S_square)

    if R_square == True:
        fig = plt.figure(figsize=fig_size)
        plt.scatter(
            range(len(R_squre_for_each_image)),
            R_squre_for_each_image,
            color="black",
        )
        plt.xlabel("images")
        plt.ylabel("percentage variance explained by Kmeans")
        plt.ylim((0, 100))
        plt.axhline(
            y=np.mean(R_squre_for_each_image),
            linestyle="dashed",
            linewidth=1,
            color="black",
        )

    else:
        fig = plt.figure(figsize=fig_size)
        plt.scatter(
            range(len(S_squre_for_each_image)),
            S_squre_for_each_image,
            color="black",
        )
        plt.xlabel("images")
        plt.ylabel("percentage variance explained by Kmeans")
        plt.ylim((0, 100))
        plt.axhline(
            y=np.mean(S_squre_for_each_image),
            linestyle="dashed",
            linewidth=1,
            color="black",
        )

    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
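
A short usage sketch (illustrative):

# R_square=True shows percentage of variance explained per image;
# False shows the unexplained remainder (S_square)
fig = mx_labeler.plot_percentage_variance_explained(fig_size=(5, 5), R_square=True)
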
def plot_tissue_ID_proportions_mxif(self, tID_labels=None, slide_labels=None, figsize=(5, 5), cmap='tab20', save_to=None)

Plot proportion of each tissue domain within each slide

Parameters

tID_labels : list of str, optional (default=None)
List of labels corresponding to MILWRM tissue domains for plotting legend
slide_labels : list of str, optional (default=None)
List of labels for each slide batch for labeling x-axis
figsize : tuple of float, optional (default=(5,5))
Size of matplotlib figure
cmap : str, optional (default="tab20")
Colormap from matplotlib
save_to : str, optional (default=None)
Path to image file to save plot

Returns

gridspec.GridSpec if save_to is None, else saves plot to file

Expand source code
def plot_tissue_ID_proportions_mxif(
    self,
    tID_labels=None,
    slide_labels=None,
    figsize=(5, 5),
    cmap="tab20",
    save_to=None,
):
    """
    Plot proportion of each tissue domain within each slide

    Parameters
    ----------
    tID_labels : list of str, optional (default=`None`)
        List of labels corresponding to MILWRM tissue domains for plotting legend
    slide_labels : list of str, optional (default=`None`)
        List of labels for each slide batch for labeling x-axis
    figsize : tuple of float, optional (default=(5,5))
        Size of matplotlib figure
    cmap : str, optional (default=`"tab20"`)
        Colormap from matplotlib
    save_to : str, optional (default=`None`)
        Path to image file to save plot

    Returns
    -------
    `gridspec.GridSpec` if `save_to` is `None`, else saves plot to file
    """
    df_count = pd.DataFrame()
    for i in range(len(self.tissue_IDs)):
        unique, counts = np.unique(self.tissue_IDs[i], return_counts=True)
        dict_ = dict(zip(unique, counts))
        n_counts = []
        for k in range(self.k):
            if k not in dict_.keys():
                n_counts.append(0)
            else:
                n_counts.append(dict_[k])
        df = pd.DataFrame(n_counts, columns=[i])
        df_count = pd.concat([df_count, df], axis=1)
    df_count = df_count / df_count.sum()
    if tID_labels:
        assert (
            len(tID_labels) == df_count.shape[0]
        ), "Length of given tissue domain labels does not match number of tissue domains!"
        df_count.index = tID_labels
    if slide_labels:
        assert (
            len(slide_labels) == df_count.shape[1]
        ), "Length of given slide labels does not match number of slides!"
        df_count.columns = slide_labels
    self.tissue_ID_proportion = df_count
    ax = df_count.T.plot.bar(stacked=True, cmap=cmap, figsize=figsize)
    ax.legend(loc="best", bbox_to_anchor=(1, 1))
    ax.set_xlabel("images")
    ax.set_ylabel("tissue domain proportion")
    ax.set_ylim((0, 1))
    plt.tight_layout()
    if save_to is not None:
        ax.figure.savefig(save_to)
    else:
        return ax
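
A usage sketch for the proportion plot (domain and slide names are illustrative):

ax = mx_labeler.plot_tissue_ID_proportions_mxif(
    tID_labels=["stroma", "epithelium", "immune"],  # one label per tissue domain (k)
    slide_labels=["slide_A", "slide_B"],            # one label per image
)
proportions = mx_labeler.tissue_ID_proportion       # pd.DataFrame of per-image domain proportions
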
def prep_cluster_data(self, features, filter_name='gaussian', sigma=2, fract=0.2, path_save=None)

Prepare master array for tissue level clustering

Parameters

features : list of int or str
Indices or names of MxIF channels to use for tissue labeling
filter_name : str
Name of the filter to use - gaussian, median or bilateral
sigma : float, optional (default=2)
Standard deviation of Gaussian kernel for blurring
fract : float, optional (default=0.2)
Fraction of cluster data from each image to randomly select for model building
path_save : str (default=None)
Path to save final preprocessed files. If self.use_paths is True, leaving path_save at the default None will raise an Exception.

Returns

Does not return anything. self.images are normalized, blurred and scaled according to user parameters. self.cluster_data becomes master np.array for cluster training. Parameters are also captured as attributes for posterity.

Expand source code
def prep_cluster_data(
    self, features, filter_name="gaussian", sigma=2, fract=0.2, path_save=None
):
    """
    Prepare master array for tissue level clustering

    Parameters
    ----------
    features : list of int or str
        Indices or names of MxIF channels to use for tissue labeling
    filter_name : str
        Name of the filter to use - gaussian, median or bilateral
    sigma : float, optional (default=2)
        Standard deviation of Gaussian kernel for blurring
    fract : float, optional (default=0.2)
        Fraction of cluster data from each image to randomly select for model
        building
    path_save : str (default=`None`)
        Path to save final preprocessed files. If `self.use_paths` is `True`,
        leaving `path_save` at the default `None` will raise an Exception.

    Returns
    -------
    Does not return anything. `self.images` are normalized, blurred and scaled
    according to user parameters. `self.cluster_data` becomes master `np.array`
    for cluster training. Parameters are also captured as attributes for posterity.

    """
    if self.cluster_data is not None:
        print("WARNING: overwriting existing cluster data")
        self.cluster_data = None
    # save the hyperparams as object attributes
    self.model_features = features
    use_path = self.use_paths
    # calculate the batch wise means
    mean_for_each_batch = {}
    for batch in self.image_df["batch_names"].unique():
        list_mean_estimators = list(
            self.image_df[self.image_df["batch_names"] == batch]["mean estimators"]
        )
        mean_estimator_batch = sum(map(np.array, list_mean_estimators))
        pixels = sum(self.image_df[self.image_df["batch_names"] == batch]["pixels"])
        mean_for_each_batch[batch] = mean_estimator_batch / pixels
    # log_normalize, apply blurring filter, minmax scale each channel and subsample
    subsampled_data = []
    path_to_blurred_npz = []
    for image, batch in zip(self.image_df["Img"], self.image_df["batch_names"]):
        tmp = prep_data_single_sample_mxif(
            image,
            use_path=use_path,
            mean=mean_for_each_batch[batch],
            filter_name=filter_name,
            sigma=sigma,
            features=self.model_features,
            fract=fract,
            path_save=path_save,
        )
        if self.use_paths == True:
            subsampled_data.append(tmp[0])
            path_to_blurred_npz.append(tmp[1])
        else:
            subsampled_data.append(tmp)
    batch_labels = [
        [x] * len(subsampled_data[x]) for x in range(len(subsampled_data))
    ]  # batch labels for umap
    self.merged_batch_labels = list(itertools.chain(*batch_labels))
    if self.use_paths == True:
        self.image_df["Img"] = path_to_blurred_npz
    cluster_data = np.row_stack(subsampled_data)
    # perform z-score normalization on cluster_Data
    scaler = StandardScaler()
    self.scaler = scaler.fit(cluster_data)
    scaled_data = scaler.transform(cluster_data)
    self.cluster_data = scaled_data
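
A usage sketch for preparing the clustering matrix (channel names are illustrative and should match the markers present in your images):

mx_labeler.prep_cluster_data(
    features=["DAPI", "PANCK", "VIMENTIN"],  # MxIF channels to model
    filter_name="gaussian",
    sigma=2,
    fract=0.2,          # fraction of pixels subsampled per image
    path_save=None,     # must be provided when the labeler uses use_paths=True
)
cluster_matrix = mx_labeler.cluster_data  # z-scored np.array used for k-means training
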
def show_marker_overlay(self, image_index, channels=None, cmap='Set1', mask_out=True, ncols=4, save_to=None, **kwargs)

Plot tissue_ID with individual markers as alpha values to distinguish expression in identified tissue domains

Parameters

image_index : int
Index of image from self.images to plot overlays for (e.g. 0 for first image)
channels : tuple of int or None, optional (default=None)
List of channels by index or name to show
cmap : str, optional (default="plasma")
Matplotlib colormap to use for plotting tissue domains
mask_out : bool, optional (default=True)
Mask out non-tissue pixels prior to showing
ncols : int
Number of columns for gridspec if plotting individual channels.
save_to : str or None
Path to image file to save results. If None, show figure.
**kwargs
Arguments to pass to plt.imshow() function.

Returns

Matplotlib object (if plotting one feature or RGB) or gridspec object (for multiple features). Saves plot to file if save_to is not None.

Expand source code
def show_marker_overlay(
    self,
    image_index,
    channels=None,
    cmap="Set1",
    mask_out=True,
    ncols=4,
    save_to=None,
    **kwargs,
):
    """
    Plot tissue_ID with individual markers as alpha values to distinguish
    expression in identified tissue domains

    Parameters
    ----------
    image_index : int
        Index of image from `self.images` to plot overlays for (e.g. 0 for first
        image)
    channels : tuple of int or None, optional (default=`None`)
        List of channels by index or name to show
    cmap : str, optional (default="plasma")
        Matplotlib colormap to use for plotting tissue domains
    mask_out : bool, optional (default=`True`)
        Mask out non-tissue pixels prior to showing
    ncols : int
        Number of columns for gridspec if plotting individual channels.
    save_to : str or None
        Path to image file to save results. If `None`, show figure.
    **kwargs
        Arguments to pass to `plt.imshow()` function.

    Returns
    -------
    Matplotlib object (if plotting one feature or RGB) or gridspec object (for
    multiple features). Saves plot to file if `save_to` is not `None`.
    """
    # if image has multiple channels, plot them in gridspec
    if isinstance(channels, int):  # force channels into list if single integer
        channels = [channels]
    if isinstance(channels, str):  # force channels into int if single string
        channels = [self[image_index].ch.index(channels)]
    if checktype(channels):  # force channels into list of int if list of strings
        channels = [self[image_index].ch.index(x) for x in channels]
    if channels is None:  # if no channels are given, use all of them
        channels = [x for x in range(self[image_index].n_ch)]
    assert (
        len(channels) <= self[image_index].n_ch
    ), "Too many channels given: image has {}, expected {}".format(
        self[image_index].n_ch, len(channels)
    )
    # creating a copy of the image
    image_cp = self[image_index].copy()
    # re-scaling to set pixel value range between 0 to 1
    image_cp.scale()
    # defining cmap for discrete color bar
    cmap = plt.cm.get_cmap(cmap, self.k)
    # calculate gridspec dimensions
    if len(channels) + 1 <= ncols:
        n_rows, n_cols = 1, len(channels) + 1
    else:
        n_rows, n_cols = ceil((len(channels) + 1) / ncols), ncols
    fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
    # arrange axes as subplots
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
    # plot tissue_ID first with colorbar
    ax = plt.subplot(gs[0])
    im = ax.imshow(self.tissue_IDs[image_index], cmap=cmap, **kwargs)
    ax.set_title(
        label="tissue_ID",
        loc="left",
        fontweight="bold",
        fontsize=16,
    )
    ax.tick_params(labelbottom=False, labelleft=False)
    sns.despine(bottom=True, left=True)
    # colorbar scale for tissue_IDs
    _ = plt.colorbar(im, ticks=range(self.k), shrink=0.7)
    # add plots to axes
    i = 1
    for channel in channels:
        ax = plt.subplot(gs[i])
        # make copy for alpha
        im_tmp = image_cp.img[:, :, channel].copy()
        if self[image_index].mask is not None and mask_out:
            # area outside mask NaN
            self.tissue_IDs[image_index][self[image_index].mask == 0] = np.nan
            im = ax.imshow(
                self.tissue_IDs[image_index], cmap=cmap, alpha=im_tmp, **kwargs
            )
        else:
            ax.imshow(self.tissue_IDs[image_index], alpha=im_tmp, **kwargs)
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        ax.set_title(
            label=self[image_index].ch[channel],
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        i = i + 1
    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
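
A usage sketch for the marker overlay (channel names are illustrative):

fig = mx_labeler.show_marker_overlay(
    image_index=0,
    channels=["PANCK", "VIMENTIN"],  # by name or integer index
    cmap="Set1",
    save_to="overlay_image0.png",
)
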

Inherited members

class st_labeler (adatas)

Tissue domain labeling class for spatial transcriptomics (ST) data

Initialize ST tissue labeler class

Parameters

adatas : list of anndata.AnnData
Single anndata object or list of objects to label consensus tissue domains

Returns

Does not return anything. self.adatas attribute is updated, self.cluster_data attribute is initiated as None.

Expand source code
class st_labeler(tissue_labeler):
    """
    Tissue domain labeling class for spatial transcriptomics (ST) data
    """

    def __init__(self, adatas):
        """
        Initialize ST tissue labeler class

        Parameters
        ----------
        adatas : list of anndata.AnnData
            Single anndata object or list of objects to label consensus tissue domains

        Returns
        -------
        Does not return anything. `self.adatas` attribute is updated,
        `self.cluster_data` attribute is initiated as `None`.
        """
        tissue_labeler.__init__(self)  # initialize parent class
        if not isinstance(adatas, list):  # force single anndata object to list
            adatas = [adatas]
        print("Initiating ST labeler with {} anndata objects".format(len(adatas)))
        self.adatas = adatas
        self.raw = adatas.copy()

    def prep_cluster_data(
        self,
        use_rep,
        features=None,
        n_rings=1,
        histo=False,
        fluor_channels=None,
        spatial_graph_key=None,
        n_jobs=-1,
    ):
        """
        Prepare master dataframe for tissue-level clustering

        Parameters
        ----------
        use_rep : str
            Representation from `adata.obsm` to use as clustering data (e.g. "X_pca")
        features : list of int or None, optional (default=`None`)
            List of features to use from `adata.obsm[use_rep]` (e.g. [0,1,2,3,4] to
            use first 5 principal components when `use_rep`="X_pca"). If `None`, use
            all features from `adata.obsm[use_rep]`
        n_rings : int, optional (default=1)
            Number of hexagonal rings around each spatial transcriptomics spot to blur
            features by for capturing regional information. Assumes 10X Genomics Visium
            platform.
        histo : bool, optional (default `False`)
            Use histology data from Visium anndata object (R,G,B brightfield features)
            in addition to `adata.obsm[use_rep]`? If fluorescent imaging data rather
            than brightfield, use `fluor_channels` argument instead.
        fluor_channels : list of int or None, optional (default `None`)
            Channels from fluorescent image to use for model training (e.g. [1,3] for
            channels 1 and 3 of Visium fluorescent imaging data). If `None`, do not
            use imaging data for training.
        spatial_graph_key : str, optional (default=`None`)
            Key in `adata.obsp` containing spatial graph connectivities (i.e.
            `"spatial_connectivities"`). If `None`, compute new spatial graph using
            `n_rings` in `squidpy`.
        n_jobs : int, optional (default=-1)
            Number of cores to parallelize over. Default all available cores.

        Returns
        -------
        Does not return anything. `self.adatas` are updated, adding "blur_*" features
        to `.obs` if `n_rings > 0`.
        `self.cluster_data` becomes master `np.array` for cluster training.
        Parameters are also captured as attributes for posterity.
        """
        if self.cluster_data is not None:
            print("WARNING: overwriting existing cluster data")
            self.cluster_data = None
        if features is None:
            self.features = [x for x in range(self.adatas[0].obsm[use_rep].shape[1])]
        else:
            self.features = features
        # save the hyperparams as object attributes
        self.rep = use_rep
        self.histo = histo
        self.fluor_channels = fluor_channels
        self.n_rings = n_rings
        # collect clustering data from self.adatas in parallel
        print(
            "Collecting and blurring {} features from .obsm[{}]...".format(
                len(self.features),
                use_rep,
            )
        )
        cluster_data = Parallel(n_jobs=n_jobs, verbose=10)(
            delayed(prep_data_single_sample_st)(
                adata,
                adata_i,
                use_rep,
                self.features,
                histo,
                fluor_channels,
                spatial_graph_key,
                n_rings,
            )
            for adata_i, adata in enumerate(self.adatas)
        )
        batch_labels = [
            [x] * len(cluster_data[x]) for x in range(len(cluster_data))
        ]  # batch labels for umap
        self.merged_batch_labels = list(itertools.chain(*batch_labels))
        # concatenate blurred features into cluster_data df for cluster training
        subsampled_data = pd.concat(cluster_data)
        # perform z-scaling on final cluster data
        scaler = StandardScaler()
        self.scaler = scaler.fit(subsampled_data)
        scaled_data = scaler.transform(subsampled_data)
        self.cluster_data = scaled_data
        print("Collected clustering data of shape: {}".format(self.cluster_data.shape))

    def label_tissue_regions(
        self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1
    ):
        """
        Perform tissue-level clustering and label pixels in the corresponding
        `anndata` objects.

        Parameters
        ----------
        k : int, optional (default=None)
            Number of tissue regions to define
        alpha: float
            Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
        plot_out : boolean, optional (default=True)
            Determines if scaled inertia plot should be output
        random_state : int, optional (default=18)
            Seed for k-means clustering model.
        n_jobs : int
            Number of cores to parallelize k-choosing across

        Returns
        -------
        Does not return anything. `self.adatas` are updated, adding "tissue_ID" field
        to `.obs`. `self.kmeans` contains trained `sklearn` clustering model.
        Parameters are also captured as attributes for posterity.
        """
        # find optimal k with parent class
        if k is None:
            print("Determining optimal cluster number k via scaled inertia")
            self.find_optimal_k(
                plot_out=plot_out, alpha=alpha, random_state=random_state, n_jobs=n_jobs
            )
        # call k-means model from parent class
        self.find_tissue_regions(k=k, random_state=random_state)
        # loop through anndata object and add tissue labels to adata.obs dataframe
        start = 0
        print("Adding tissue_ID label to anndata objects")
        for i in range(len(self.adatas)):
            IDs = self.kmeans.labels_
            self.adatas[i].obs["tissue_ID"] = IDs[start : start + self.adatas[i].n_obs]
            self.adatas[i].obs["tissue_ID"] = (
                self.adatas[i].obs["tissue_ID"].astype("category")
            )
            self.adatas[i].obs["tissue_ID"] = (
                self.adatas[i].obs["tissue_ID"].cat.set_categories(np.unique(IDs))
            )
            start += self.adatas[i].n_obs

    def confidence_score(self):
        """
        estimate confidence score for each visium slide

        Parameters
        ----------

        Returns
        -------
        self.adatas[i].obs.confidence_IDs and self.confidence_score_df are added
        containing confidence score for each tissue domain assignment and mean confidence
        score for each tissue domain within each visium slide
        """
        assert (
            self.kmeans is not None
        ), "No cluster results found. Run \
        label_tissue_regions() first."
        i_slice = 0
        j_slice = 0
        confidence_score_df = pd.DataFrame()
        adatas = self.adatas
        cluster_data = self.cluster_data
        centroids = self.kmeans.cluster_centers_
        for i, adata in enumerate(adatas):
            j_slice = j_slice + adata.n_obs
            data = cluster_data[i_slice:j_slice]
            scores_dict = estimate_confidence_score_st(data, adata, centroids)
            df = pd.DataFrame(scores_dict.values(), columns=[i])
            confidence_score_df = pd.concat([confidence_score_df, df], axis=1)
            i_slice = i_slice + adata.n_obs
        self.confidence_score_df = confidence_score_df

    def plot_gene_loadings(
        self,
        PC_loadings,
        n_genes=10,
        ncols=None,
        titles=None,
        save_to=None,
    ):
        """
        Plot MILWRM loadings in gene space specifically for MILWRM done with PCs

        Parameters
        ----------
        PC_loadings : numpy.ndarray
            numpy.ndarray containing PC loadings shape format (genes, components)
        n_genes : int, optional (default=10)
            number of genes to plot
        ncols : int, optional (default=`None`)
            Number of columns for gridspec. If `None`, uses number of tissue domains k.
        titles : list of str, optional (default=`None`)
            Titles of plots corresponding to each MILWRM domain. If `None`, titles
            will be numbers 0 through k.
        save_to : str, optional (default=`None`)
            Path to image file to save plot

        Returns
        -------
        Matplotlib object and PC loadings in gene space set as self.gene_loadings_df
        """
        assert (
            self.kmeans is not None
        ), "No cluster results found. Run \
        label_tissue_regions() first."
        assert (
            PC_loadings.shape[0] == self.adatas[0].n_vars
        ), f"loadings matrix does not, \
        contain enough genes, there should be {self.adatas[0].n_vars} genes"
        assert (
            PC_loadings.shape[1] >= self.kmeans.cluster_centers_.shape[1]
        ), f"loadings matrix \
        does not contain enough components, there should be atleast {self.adatas[0].n_vars} components"
        if titles is None:
            titles = ["tissue_ID " + str(x) for x in range(self.k)]
        centroids = self.kmeans.cluster_centers_
        temp = PC_loadings.T
        loadings = temp[range(self.kmeans.cluster_centers_.shape[1])]
        gene_loadings = np.matmul(centroids, loadings)
        gene_loadings_df = pd.DataFrame(gene_loadings)
        gene_loadings_df = gene_loadings_df.T
        gene_loadings_df["genes"] = self.adatas[0].var_names
        self.gene_loadings_df = gene_loadings_df
        n_panels = self.k
        if ncols is None:
            ncols = self.k
        if n_panels <= ncols:
            n_rows, n_cols = 1, n_panels
        else:
            n_rows, n_cols = ceil(n_panels / ncols), ncols
        fig = plt.figure(figsize=((ncols * n_cols, ncols * n_rows)))
        left, bottom = 0.1 / n_cols, 0.1 / n_rows
        gs = gridspec.GridSpec(
            nrows=n_rows,
            ncols=n_cols,
            left=left,
            bottom=bottom,
            right=1 - (n_cols - 1) * left - 0.01 / n_cols,
            top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
        )
        for i in range(self.k):
            df = (
                gene_loadings_df[[i, "genes"]]
                .sort_values(i, axis=0, ascending=False)[:n_genes]
                .reset_index(drop=True)
            )
            plt.subplot(gs[i])
            df_rev = df.sort_values(i).reset_index(drop=True)
            for j, score in enumerate((df_rev[i])):
                plt.text(
                    x=score,
                    y=j + 0.1,
                    s=df_rev.loc[j, "genes"],
                    color="black",
                    verticalalignment="center",
                    horizontalalignment="right",
                    fontsize="medium",
                    fontstyle="italic",
                )
                plt.ylim([0, j + 1])
                plt.xlim([0, df.max().values[0] + 0.1])
                plt.tick_params(
                    axis="y",  # changes apply to the y-axis
                    which="both",  # both major and minor ticks are affected
                    left=False,
                    right=False,
                    labelleft=False,
                )
                plt.title(titles[i])
        gs.tight_layout(fig)
        if save_to is not None:
            print("Saving feature loadings to {}".format(save_to))
            plt.savefig(save_to)
        else:
            return gs

    def plot_percentage_variance_explained(
        self, fig_size=(5, 5), R_square=False, save_to=None
    ):
        """
        Plot the percentage of variance explained (R_square) or not explained
        (S_square) by clustering

        Parameters
        ----------
        fig_size : tuple of float, optional (default=(5, 5))
            Size of matplotlib figure
        R_square : Boolean
            If `True`, plot R_square (percentage of variance explained by clustering);
            otherwise plot S_square (percentage not explained)
        save_to : str or None
            Path to image file to save results. If `None`, show figure.

        Returns
        -------
        Matplotlib object
        """
        assert (
            self.kmeans is not None
        ), "No cluster results found. Run \
        label_tissue_regions() first."
        centroids = self.kmeans.cluster_centers_
        adatas = self.adatas
        cluster_data = self.cluster_data
        S_squre_for_each_st = []
        R_squre_for_each_st = []
        i_slice = 0
        j_slice = 0
        for adata in adatas:
            j_slice = j_slice + adata.n_obs
            sub_cluster_data = cluster_data[i_slice:j_slice]
            S_square = estimate_percentage_variance_st(
                sub_cluster_data, adata, centroids
            )
            S_squre_for_each_st.append(S_square)
            R_squre_for_each_st.append(100 - S_square)
            i_slice = i_slice + adata.n_obs

        if R_square:
            fig = plt.figure(figsize=fig_size)
            plt.scatter(
                range(len(R_squre_for_each_st)), R_squre_for_each_st, color="black"
            )
            plt.xlabel("images")
            plt.ylabel("percentage variance explained by Kmeans")
            plt.ylim((0, 100))
            plt.axhline(
                y=np.mean(R_squre_for_each_st),
                linestyle="dashed",
                linewidth=1,
                color="black",
            )

        else:
            fig = plt.figure(figsize=fig_size)
            plt.scatter(
                range(len(S_squre_for_each_st)), S_squre_for_each_st, color="black"
            )
            plt.xlabel("images")
            plt.ylabel("percentage variance explained by Kmeans")
            plt.ylim((0, 100))
            plt.axhline(
                y=np.mean(S_squre_for_each_st),
                linestyle="dashed",
                linewidth=1,
                color="black",
            )

        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig

    def plot_mse_st(
        self,
        figsize=(5, 5),
        ncols=None,
        labels=None,
        titles=None,
        loc="lower right",
        bbox_coordinates=(0, 0, 1.5, 1.5),
        save_to=None,
    ):
        """
        Estimate and plot the mean square error within each tissue domain

        Parameters
        ----------
        figsize : tuple of float, optional (default=(5, 5))
            Size of each panel of the plot
        ncols : int, optional (default=`None`)
            Number of columns for gridspec. If `None`, uses number of tissue domains k.
        labels : list of str, optional (default=`None`)
            Labels corresponding to each image in legend. If `None`, numeric index is
            used for each image
        titles : list of str, optional (default=`None`)
            Titles of plots corresponding to each MILWRM domain. If `None`, titles
            will be numbers 0 through k.
        loc : str, optional (default = 'lower right')
            str for legend position
        bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
            coordinates for the legend box
        save_to : str, optional (default=`None`)
            Path to image file to save plot

        Returns
        -------
        Matplotlib object
        """
        assert (
            self.kmeans is not None
        ), "No cluster results found. Run \
        label_tissue_regions() first."
        cluster_data = self.cluster_data
        adatas = self.adatas
        k = self.k
        features = self.features
        centroids = self.kmeans.cluster_centers_
        mse_id = estimate_mse_st(cluster_data, adatas, centroids, k)
        colors = plt.cm.tab20(np.linspace(0, 1, len(adatas)))
        if titles is None:
            titles = ["tissue_domain " + str(x) for x in range(self.k)]
        if labels is None:
            labels = range(len(adatas))
        n_panels = len(mse_id.keys())
        if ncols is None:
            ncols = len(titles)
        if n_panels <= ncols:
            n_rows, n_cols = 1, n_panels
        else:
            n_rows, n_cols = ceil(n_panels / ncols), ncols
        fig = plt.figure(figsize=(n_cols * figsize[0], n_rows * figsize[1]))
        left, bottom = 0.1 / n_cols, 0.1 / n_rows
        gs = gridspec.GridSpec(
            nrows=n_rows,
            ncols=n_cols,
            left=left,
            bottom=bottom,
            right=1 - (n_cols - 1) * left - 0.01 / n_cols,
            top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
        )
        for i in mse_id.keys():
            plt.subplot(gs[i])
            df = pd.DataFrame.from_dict(mse_id[i])
            plt.boxplot(df, positions=features, showfliers=False)
            for col in df:
                for j in range(len(df[col])):  # use j to avoid shadowing k (number of tissue domains)
                    dots = plt.scatter(
                        col,
                        df[col][j],
                        s=j + 1,
                        color=colors[j],
                        label=labels[j] if col == 0 else "",
                    )
                    offsets = dots.get_offsets()
                    jittered_offsets = offsets
                    # only jitter in the x-direction
                    jittered_offsets[:, 0] += np.random.uniform(
                        -0.3, 0.3, offsets.shape[0]
                    )
                    dots.set_offsets(jittered_offsets)
            plt.xlabel("PCs")
            plt.ylabel("mean square error")
            plt.title(titles[i])
        plt.legend(loc=loc, bbox_to_anchor=bbox_coordinates)
        gs.tight_layout(fig)
        if save_to:
            plt.savefig(fname=save_to, transparent=True, dpi=300)
        return fig

    def plot_tissue_ID_proportions_st(
        self,
        tID_labels=None,
        slide_labels=None,
        figsize=(5, 5),
        cmap="tab20",
        save_to=None,
    ):
        """
        Plot proportion of each tissue domain within each slide

        Parameters
        ----------
        tID_labels : list of str, optional (default=`None`)
            List of labels corresponding to MILWRM tissue domains for plotting legend
        slide_labels : list of str, optional (default=`None`)
            List of labels for each slide batch for labeling x-axis
        figsize : tuple of float, optional (default=(5,5))
            Size of matplotlib figure
        cmap : str, optional (default = `"tab20"`)
            Colormap from matplotlib
        save_to : str, optional (default=`None`)
            Path to image file to save plot

        Returns
        -------
        `gridspec.GridSpec` if `save_to` is `None`, else saves plot to file
        """
        df_count = pd.DataFrame()
        for adata in self.adatas:
            df = adata.obs["tissue_ID"].value_counts(normalize=True, sort=False)
            df_count = pd.concat([df_count, df], axis=1)
        df_count = df_count.T.reset_index(drop=True)
        if tID_labels:
            assert (
                len(tID_labels) == df_count.shape[1]
            ), "Length of given tissue domain labels does not match number of tissue domains!"
            df_count.columns = tID_labels
        if slide_labels:
            assert (
                len(slide_labels) == df_count.shape[0]
            ), "Length of given slide labels does not match number of slides!"
            df_count.index = slide_labels
        ax = df_count.plot.bar(stacked=True, cmap=cmap, figsize=figsize)
        ax.legend(loc="best", bbox_to_anchor=(1, 1))
        ax.set_xlabel("slides")
        ax.set_ylabel("tissue domain proportion")
        ax.set_ylim((0, 1))
        plt.tight_layout()
        if save_to is not None:
            ax.figure.savefig(save_to)
        else:
            return ax

    def show_feature_overlay(
        self,
        adata_index,
        pita,
        features=None,
        histo=None,
        cmap="tab20",
        label="feature",
        ncols=4,
        save_to=None,
        **kwargs,
    ):
        """
        Plot tissue_ID with individual pita features as alpha values to distinguish
        expression in identified tissue domains

        Parameters
        ----------
        adata_index : int
            Index of adata from `self.adatas` to plot overlays for (e.g. 0 for first
            adata object)
        pita : np.array
            Image of desired expression in pixel space from `.assemble_pita()`
        features : list of int, optional (default=`None`)
            List of features by index to show in plot. If `None`, use all features.
        histo : np.array or `None`, optional (default=`None`)
            Histology image to show along with pita in gridspec. If `None`, ignore.
        cmap : str, optional (default="tab20")
            Matplotlib colormap to use for plotting tissue domains
        label : str
            What to title each panel of the gridspec (i.e. "PC" or "usage") or each
            channel in RGB image. Can also pass list of names e.g. ["NeuN","GFAP",
            "DAPI"] corresponding to channels.
        ncols : int
            Number of columns for gridspec
        save_to : str or None
            Path to image file to save results. if `None`, show figure.
        **kwargs
            Arguments to pass to `plt.imshow()` function

        Returns
        -------
        Matplotlib object (if plotting one feature or RGB) or gridspec object (for
        multiple features). Saves plot to file if `save_to` is not `None`.
        """
        assert pita.ndim > 1, "Pita does not have enough dimensions: {} given".format(
            pita.ndim
        )
        assert pita.ndim < 4, "Pita has too many dimensions: {} given".format(pita.ndim)
        # create tissue_ID pita for plotting
        tIDs = assemble_pita(
            self.adatas[adata_index],
            features="tissue_ID",
            use_rep="obs",
            plot_out=False,
            verbose=False,
        )
        # if pita has multiple features, plot them in gridspec
        if isinstance(features, int):  # force features into list if single integer
            features = [features]
        # if no features are given, use all of them
        elif features is None:
            features = [x + 1 for x in range(pita.shape[2])]
        else:
            assert (
                pita.ndim > 2
            ), "Not enough features in pita: shape {}, expecting 3rd dim with length {}".format(
                pita.shape, len(features)
            )
            assert (
                len(features) <= pita.shape[2]
            ), "Too many features given: pita has {}, expected {}".format(
                pita.shape[2], len(features)
            )
        # min-max scale each feature in pita to convert to interpretable alpha values
        mms = MinMaxScaler()
        if pita.ndim == 3:
            pita_tmp = mms.fit_transform(
                pita.reshape((pita.shape[0] * pita.shape[1], pita.shape[2]))
            )
        elif pita.ndim == 2:
            pita_tmp = mms.fit_transform(
                pita.reshape((pita.shape[0] * pita.shape[1], 1))
            )
        # reshape back to original
        pita = pita_tmp.reshape(pita.shape)
        # figure out labels for gridspec plots
        if isinstance(label, str):
            # if label is single string, name channels numerically
            labels = ["{}_{}".format(label, x) for x in features]
        else:
            assert len(label) == len(
                features
            ), "Please provide the same number of labels as features; {} labels given, {} features given.".format(
                len(label), len(features)
            )
            labels = label
        # calculate gridspec dimensions
        if histo is not None:
            # determine where the histo image is in anndata
            assert (
                histo
                in self.adatas[adata_index]
                .uns["spatial"][
                    list(self.adatas[adata_index].uns["spatial"].keys())[0]
                ]["images"]
                .keys()
            ), "Must provide one of {} for histo".format(
                self.adatas[adata_index]
                .uns["spatial"][
                    list(self.adatas[adata_index].uns["spatial"].keys())[0]
                ]["images"]
                .keys()
            )
            histo = self.adatas[adata_index].uns["spatial"][
                list(self.adatas[adata_index].uns["spatial"].keys())[0]
            ]["images"][histo]
            if len(features) + 2 <= ncols:
                n_rows, n_cols = 1, len(features) + 2
            else:
                n_rows, n_cols = ceil((len(features) + 2) / ncols), ncols
            labels = ["Histology", "tissue_ID"] + labels  # append to front of labels
        else:
            if len(features) + 1 <= ncols:
                n_rows, n_cols = 1, len(features) + 1
            else:
                n_rows, n_cols = ceil((len(features) + 1) / ncols), ncols
            labels = ["tissue_ID"] + labels  # append to front of labels
        fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
        # arrange axes as subplots
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
        # add plots to axes
        i = 0
        if histo is not None:
            # add histology plot to first axes
            ax = plt.subplot(gs[i])
            im = ax.imshow(histo, **kwargs)
            ax.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            ax.set_title(
                label=labels[i],
                loc="left",
                fontweight="bold",
                fontsize=16,
            )
            i = i + 1
        # plot tissue_ID first with colorbar
        ax = plt.subplot(gs[i])
        im = ax.imshow(tIDs, cmap=cmap, **kwargs)
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        ax.set_title(
            label=labels[i],
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        # colorbar scale for tissue_IDs
        _ = plt.colorbar(im, shrink=0.7)
        i = i + 1
        for feature in features:
            ax = plt.subplot(gs[i])
            im = ax.imshow(tIDs, alpha=pita[:, :, feature - 1], cmap=cmap, **kwargs)
            ax.tick_params(labelbottom=False, labelleft=False)
            sns.despine(bottom=True, left=True)
            ax.set_title(
                label=labels[i],
                loc="left",
                fontweight="bold",
                fontsize=16,
            )
            i = i + 1
        fig.tight_layout()
        if save_to:
            plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
        return fig
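
A minimal end-to-end sketch for the ST workflow, assuming `adata_1` and `adata_2` are Visium anndata objects with PCA stored in .obsm["X_pca"] (object names and feature counts are illustrative):

from MILWRM import st_labeler

labeler = st_labeler([adata_1, adata_2])
labeler.prep_cluster_data(use_rep="X_pca", features=list(range(5)), n_rings=1)
labeler.label_tissue_regions(k=None, alpha=0.05, n_jobs=4)  # k chosen by scaled inertia
ax = labeler.plot_tissue_ID_proportions_st(slide_labels=["slide_1", "slide_2"])
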

Ancestors

MILWRM.MILWRM.tissue_labeler

Methods

def confidence_score(self)

estimate confidence score for each visium slide

Parameters

Returns

self.adatas[i].obs.confidence_IDs and self.confidence_score_df are added,
containing the confidence score for each tissue domain assignment and the mean
confidence score for each tissue domain within each visium slide.

Expand source code
def confidence_score(self):
    """
    Estimate a confidence score for each Visium slide

    Parameters
    ----------

    Returns
    -------
    self.adatas[i].obs.confidence_IDs and self.confidence_score_df are added,
    containing the confidence score for each tissue domain assignment and the mean
    confidence score for each tissue domain within each Visium slide
    """
    assert (
        self.kmeans is not None
    ), "No cluster results found. Run \
    label_tissue_regions() first."
    i_slice = 0
    j_slice = 0
    confidence_score_df = pd.DataFrame()
    adatas = self.adatas
    cluster_data = self.cluster_data
    centroids = self.kmeans.cluster_centers_
    for i, adata in enumerate(adatas):
        j_slice = j_slice + adata.n_obs
        data = cluster_data[i_slice:j_slice]
        scores_dict = estimate_confidence_score_st(data, adata, centroids)
        df = pd.DataFrame(scores_dict.values(), columns=[i])
        confidence_score_df = pd.concat([confidence_score_df, df], axis=1)
        i_slice = i_slice + adata.n_obs
    self.confidence_score_df = confidence_score_df
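
A minimal usage sketch, assuming labeler is an st_labeler whose tissue domains have already been assigned via label_tissue_regions(); the variable name is hypothetical:

# inspect confidence of the domain assignments after clustering
labeler.confidence_score()
print(labeler.confidence_score_df)              # mean confidence per tissue domain, one column per slide
print(labeler.adatas[0].obs["confidence_IDs"])  # per-spot confidence scores for the first slide
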
def label_tissue_regions(self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1)

Perform tissue-level clustering and label pixels in the corresponding anndata objects.

Parameters

k : int, optional (default=None)
Number of tissue regions to define
alpha : float
Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
plot_out : boolean, optional (default=True)
Determines if scaled inertia plot should be output
random_state : int, optional (default=18)
Seed for k-means clustering model.
n_jobs : int
Number of cores to parallelize k-choosing across

Returns

Does not return anything. self.adatas are updated, adding "tissue_ID" field to .obs. self.kmeans contains trained sklearn clustering model. Parameters are also captured as attributes for posterity.

Expand source code
def label_tissue_regions(
    self, k=None, alpha=0.05, plot_out=True, random_state=18, n_jobs=-1
):
    """
    Perform tissue-level clustering and label pixels in the corresponding
    `anndata` objects.

    Parameters
    ----------
    k : int, optional (default=None)
        Number of tissue regions to define
    alpha: float
        Manually tuned factor on [0.0, 1.0] that penalizes the number of clusters
    plot_out : boolean, optional (default=True)
        Determines if scaled inertia plot should be output
    random_state : int, optional (default=18)
        Seed for k-means clustering model.
    n_jobs : int
        Number of cores to parallelize k-choosing across

    Returns
    -------
    Does not return anything. `self.adatas` are updated, adding "tissue_ID" field
    to `.obs`. `self.kmeans` contains trained `sklearn` clustering model.
    Parameters are also captured as attributes for posterity.
    """
    # find optimal k with parent class
    if k is None:
        print("Determining optimal cluster number k via scaled inertia")
        self.find_optimal_k(
            plot_out=plot_out, alpha=alpha, random_state=random_state, n_jobs=n_jobs
        )
    # call k-means model from parent class
    self.find_tissue_regions(k=k, random_state=random_state)
    # loop through anndata object and add tissue labels to adata.obs dataframe
    start = 0
    print("Adding tissue_ID label to anndata objects")
    for i in range(len(self.adatas)):
        IDs = self.kmeans.labels_
        self.adatas[i].obs["tissue_ID"] = IDs[start : start + self.adatas[i].n_obs]
        self.adatas[i].obs["tissue_ID"] = (
            self.adatas[i].obs["tissue_ID"].astype("category")
        )
        self.adatas[i].obs["tissue_ID"] = (
            self.adatas[i].obs["tissue_ID"].cat.set_categories(np.unique(IDs))
        )
        start += self.adatas[i].n_obs
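
A minimal usage sketch, assuming labeler is an st_labeler on which prep_cluster_data() has already been run (the variable name is hypothetical):

# let MILWRM choose k via scaled inertia, or pin it explicitly with k=...
labeler.label_tissue_regions(alpha=0.05, random_state=18, n_jobs=-1)
print(labeler.adatas[0].obs["tissue_ID"].value_counts())  # spots per tissue domain in the first slide
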
def plot_gene_loadings(self, PC_loadings, n_genes=10, ncols=None, titles=None, save_to=None)

Plot MILWRM loadings in gene space, for MILWRM runs performed on principal components (PCs)

Parameters

PC_loadings : numpy.ndarray
Array of PC loadings with shape (genes, components)
n_genes : int, optional (default=10)
number of genes to plot
ncols : int, optional (default=None)
Number of columns for gridspec. If None, uses number of tissue domains k.
titles : list of str, optional (default=None)
Titles of plots corresponding to each MILWRM domain. If None, titles will be numbers 0 through k.
save_to : str, optional (default=None)
Path to image file to save plot

Returns

Matplotlib object and PC loadings in gene space set as self.gene_loadings_df
 
Expand source code
def plot_gene_loadings(
    self,
    PC_loadings,
    n_genes=10,
    ncols=None,
    titles=None,
    save_to=None,
):
    """
    Plot MILWRM loadings in gene space, for MILWRM runs performed on principal
    components (PCs)

    Parameters
    ----------
    PC_loadings : numpy.ndarray
        Array of PC loadings with shape (genes, components)
    n_genes : int, optional (default=10)
        number of genes to plot
    ncols : int, optional (default=`None`)
        Number of columns for gridspec. If `None`, uses number of tissue domains k.
    titles : list of str, optional (default=`None`)
        Titles of plots corresponding to each MILWRM domain. If `None`, titles
        will be numbers 0 through k.
    save_to : str, optional (default=`None`)
        Path to image file to save plot

    Returns
    -------
    Matplotlib object and PC loadings in gene space set as self.gene_loadings_df
    """
    assert (
        self.kmeans is not None
    ), "No cluster results found. Run \
    label_tissue_regions() first."
    assert (
        PC_loadings.shape[0] == self.adatas[0].n_vars
    ), f"loadings matrix does not contain enough genes; \
    there should be {self.adatas[0].n_vars} genes"
    assert (
        PC_loadings.shape[1] >= self.kmeans.cluster_centers_.shape[1]
    ), f"loadings matrix does not contain enough components; \
    there should be at least {self.kmeans.cluster_centers_.shape[1]} components"
    if titles is None:
        titles = ["tissue_ID " + str(x) for x in range(self.k)]
    centroids = self.kmeans.cluster_centers_
    temp = PC_loadings.T
    loadings = temp[range(self.kmeans.cluster_centers_.shape[1])]
    gene_loadings = np.matmul(centroids, loadings)
    gene_loadings_df = pd.DataFrame(gene_loadings)
    gene_loadings_df = gene_loadings_df.T
    gene_loadings_df["genes"] = self.adatas[0].var_names
    self.gene_loadings_df = gene_loadings_df
    n_panels = self.k
    if ncols is None:
        ncols = self.k
    if n_panels <= ncols:
        n_rows, n_cols = 1, n_panels
    else:
        n_rows, n_cols = ceil(n_panels / ncols), ncols
    fig = plt.figure(figsize=((ncols * n_cols, ncols * n_rows)))
    left, bottom = 0.1 / n_cols, 0.1 / n_rows
    gs = gridspec.GridSpec(
        nrows=n_rows,
        ncols=n_cols,
        left=left,
        bottom=bottom,
        right=1 - (n_cols - 1) * left - 0.01 / n_cols,
        top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
    )
    for i in range(self.k):
        df = (
            gene_loadings_df[[i, "genes"]]
            .sort_values(i, axis=0, ascending=False)[:n_genes]
            .reset_index(drop=True)
        )
        plt.subplot(gs[i])
        df_rev = df.sort_values(i).reset_index(drop=True)
        for j, score in enumerate((df_rev[i])):
            plt.text(
                x=score,
                y=j + 0.1,
                s=df_rev.loc[j, "genes"],
                color="black",
                verticalalignment="center",
                horizontalalignment="right",
                fontsize="medium",
                fontstyle="italic",
            )
            plt.ylim([0, j + 1])
            plt.xlim([0, df.max().values[0] + 0.1])
            plt.tick_params(
                axis="y",  # changes apply to the y-axis
                which="both",  # both major and minor ticks are affected
                left=False,
                right=False,
                labelleft=False,
            )
            plt.title(titles[i])
    gs.tight_layout(fig)
    if save_to is not None:
        print("Saving feature loadings to {}".format(save_to))
        plt.savefig(save_to)
    else:
        return gs
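
A minimal usage sketch; it assumes the upstream PCA was computed with scanpy, so a (genes x components) loadings matrix is available in .varm["PCs"] of each AnnData object (an assumption about the workflow, not a requirement of this method):

# project the k-means centroids back into gene space and plot the top genes per domain
pc_loadings = labeler.adatas[0].varm["PCs"]   # shape (n_genes, n_components)
labeler.plot_gene_loadings(pc_loadings, n_genes=10, save_to="gene_loadings.png")
print(labeler.gene_loadings_df.head())        # per-gene loadings, one column per tissue domain plus "genes"
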
def plot_mse_st(self, figsize=(5, 5), ncols=None, labels=None, titles=None, loc='lower right', bbox_coordinates=(0, 0, 1.5, 1.5), save_to=None)

Estimate and plot the mean square error within each tissue domain

Parameters

figsize : tuple of float, optional (default=(5, 5))
Size of each panel in the figure
ncols : int, optional (default=None)
Number of columns for gridspec. If None, uses number of tissue domains k.
labels : list of str, optional (default=None)
Labels corresponding to each image in legend. If None, numeric index is used for each image
titles : list of str, optional (default=None)
Titles of plots corresponding to each MILWRM domain. If None, titles will be numbers 0 through k.
loc : str, optional (default = 'lower right')
str for legend position
bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
coordinates for the legend box
save_to : str, optional (default=None)
Path to image file to save plot

Returns

Matplotlib object
 
Expand source code
def plot_mse_st(
    self,
    figsize=(5, 5),
    ncols=None,
    labels=None,
    titles=None,
    loc="lower right",
    bbox_coordinates=(0, 0, 1.5, 1.5),
    save_to=None,
):
    """
    Estimate and plot the mean square error within each tissue domain

    Parameters
    ----------
    figsize : tuple of float, optional (default=(5, 5))
        Size of each panel in the figure
    ncols : int, optional (default=`None`)
        Number of columns for gridspec. If `None`, uses number of tissue domains k.
    labels : list of str, optional (default=`None`)
        Labels corresponding to each image in legend. If `None`, numeric index is
        used for each image
    titles : list of str, optional (default=`None`)
        Titles of plots corresponding to each MILWRM domain. If `None`, titles
        will be numbers 0 through k.
    loc : str, optional (default = 'lower right')
        str for legend position
    bbox_coordinates : Tuple, optional (default = (0,0,1.5,1.5))
        coordinates for the legend box
    save_to : str, optional (default=`None`)
        Path to image file to save plot

    Returns
    -------
    Matplotlib object
    """
    assert (
        self.kmeans is not None
    ), "No cluster results found. Run \
    label_tissue_regions() first."
    cluster_data = self.cluster_data
    adatas = self.adatas
    k = self.k
    features = self.features
    centroids = self.kmeans.cluster_centers_
    mse_id = estimate_mse_st(cluster_data, adatas, centroids, k)
    colors = plt.cm.tab20(np.linspace(0, 1, len(adatas)))
    if titles is None:
        titles = ["tissue_domain " + str(x) for x in range(self.k)]
    if labels is None:
        labels = range(len(adatas))
    n_panels = len(mse_id.keys())
    if ncols is None:
        ncols = len(titles)
    if n_panels <= ncols:
        n_rows, n_cols = 1, n_panels
    else:
        n_rows, n_cols = ceil(n_panels / ncols), ncols
    fig = plt.figure(figsize=(n_cols * figsize[0], n_rows * figsize[1]))
    left, bottom = 0.1 / n_cols, 0.1 / n_rows
    gs = gridspec.GridSpec(
        nrows=n_rows,
        ncols=n_cols,
        left=left,
        bottom=bottom,
        right=1 - (n_cols - 1) * left - 0.01 / n_cols,
        top=1 - (n_rows - 1) * bottom - 0.1 / n_rows,
    )
    for i in mse_id.keys():
        plt.subplot(gs[i])
        df = pd.DataFrame.from_dict(mse_id[i])
        plt.boxplot(df, positions=features, showfliers=False)
        for col in df:
            for k in range(len(df[col])):
                dots = plt.scatter(
                    col,
                    df[col][k],
                    s=k + 1,
                    color=colors[k],
                    label=labels[k] if col == 0 else "",
                )
                offsets = dots.get_offsets()
                jittered_offsets = offsets
                # only jitter in the x-direction
                jittered_offsets[:, 0] += np.random.uniform(
                    -0.3, 0.3, offsets.shape[0]
                )
                dots.set_offsets(jittered_offsets)
        plt.xlabel("PCs")
        plt.ylabel("mean square error")
        plt.title(titles[i])
    plt.legend(loc=loc, bbox_to_anchor=bbox_coordinates)
    gs.tight_layout(fig)
    if save_to:
        plt.savefig(fname=save_to, transparent=True, dpi=300)
    return fig
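
A minimal usage sketch (labeler fit as above; the slide names passed to labels are hypothetical):

# one panel per tissue domain, boxplots of mean square error per PC across slides
fig = labeler.plot_mse_st(
    figsize=(5, 5),
    labels=["slide_A", "slide_B"],  # hypothetical names, one per slide, for the legend
    save_to="mse_by_domain.png",
)
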
def plot_percentage_variance_explained(self, fig_size=(5, 5), R_square=False, save_to=None)

Plot the percentage of variance explained (or not explained) by clustering

Parameters

fig_size : tuple of float, optional (default=(5, 5))
Size of matplotlib figure
R_square : bool, optional (default=False)
If True, plot the R-squared values (percentage of variance explained); if False, plot the S-squared values (percentage of variance not explained)
save_to : str or None
Path to image file to save results. If None, show figure.

Returns

Matplotlib object
 
Expand source code
def plot_percentage_variance_explained(
    self, fig_size=(5, 5), R_square=False, save_to=None
):
    """
    Plot the percentage of variance explained (or not explained) by clustering

    Parameters
    ----------
    fig_size : tuple of float, optional (default=(5, 5))
        Size of matplotlib figure
    R_square : bool, optional (default=`False`)
        If `True`, plot the R-squared values (percentage of variance explained);
        if `False`, plot the S-squared values (percentage of variance not explained)
    save_to : str or None
        Path to image file to save results. If `None`, show figure.

    Returns
    -------
    Matplotlib object
    """
    assert (
        self.kmeans is not None
    ), "No cluster results found. Run \
    label_tissue_regions() first."
    centroids = self.kmeans.cluster_centers_
    adatas = self.adatas
    cluster_data = self.cluster_data
    S_square_for_each_st = []
    R_square_for_each_st = []
    i_slice = 0
    j_slice = 0
    for adata in adatas:
        j_slice = j_slice + adata.n_obs
        sub_cluster_data = cluster_data[i_slice:j_slice]
        S_square = estimate_percentage_variance_st(
            sub_cluster_data, adata, centroids
        )
        S_square_for_each_st.append(S_square)
        R_square_for_each_st.append(100 - S_square)
        i_slice = i_slice + adata.n_obs

    if R_square:
        fig = plt.figure(figsize=fig_size)
        plt.scatter(
            range(len(R_square_for_each_st)), R_square_for_each_st, color="black"
        )
        plt.xlabel("images")
        plt.ylabel("percentage variance explained by Kmeans")
        plt.ylim((0, 100))
        plt.axhline(
            y=np.mean(R_square_for_each_st),
            linestyle="dashed",
            linewidth=1,
            color="black",
        )

    else:
        fig = plt.figure(figsize=fig_size)
        plt.scatter(
            range(len(S_square_for_each_st)), S_square_for_each_st, color="black"
        )
        plt.xlabel("images")
        plt.ylabel("percentage variance not explained by Kmeans")
        plt.ylim((0, 100))
        plt.axhline(
            y=np.mean(S_square_for_each_st),
            linestyle="dashed",
            linewidth=1,
            color="black",
        )

    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
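
A minimal usage sketch (labeler fit as above):

# per-slide R-squared values (percentage of variance captured by the k-means centroids)
fig = labeler.plot_percentage_variance_explained(fig_size=(5, 5), R_square=True)
fig.savefig("variance_explained.png", dpi=300)
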
def plot_tissue_ID_proportions_st(self, tID_labels=None, slide_labels=None, figsize=(5, 5), cmap='tab20', save_to=None)

Plot proportion of each tissue domain within each slide

Parameters

tID_labels : list of str, optional (default=None)
List of labels corresponding to MILWRM tissue domains for plotting legend
slide_labels : list of str, optional (default=None)
List of labels for each slide batch for labeling x-axis
figsize : tuple of float, optional (default=(5,5))
Size of matplotlib figure
cmap : str, optional (default ="tab20")
Colormap from matplotlib
save_to : str, optional (default=None)
Path to image file to save plot

Returns

matplotlib Axes object if save_to is None, else saves plot to file

Expand source code
def plot_tissue_ID_proportions_st(
    self,
    tID_labels=None,
    slide_labels=None,
    figsize=(5, 5),
    cmap="tab20",
    save_to=None,
):
    """
    Plot proportion of each tissue domain within each slide

    Parameters
    ----------
    tID_labels : list of str, optional (default=`None`)
        List of labels corresponding to MILWRM tissue domains for plotting legend
    slide_labels : list of str, optional (default=`None`)
        List of labels for each slide batch for labeling x-axis
    figsize : tuple of float, optional (default=(5,5))
        Size of matplotlib figure
    cmap : str, optional (default = `"tab20"`)
        Colormap from matplotlib
    save_to : str, optional (default=`None`)
        Path to image file to save plot

    Returns
    -------
    `matplotlib.axes.Axes` object if `save_to` is `None`, else saves plot to file
    """
    df_count = pd.DataFrame()
    for adata in self.adatas:
        df = adata.obs["tissue_ID"].value_counts(normalize=True, sort=False)
        df_count = pd.concat([df_count, df], axis=1)
    df_count = df_count.T.reset_index(drop=True)
    if tID_labels:
        assert (
            len(tID_labels) == df_count.shape[1]
        ), "Length of given tissue domain labels does not match number of tissue domains!"
        df_count.columns = tID_labels
    if slide_labels:
        assert (
            len(slide_labels) == df_count.shape[0]
        ), "Length of given slide labels does not match number of slides!"
        df_count.index = slide_labels
    ax = df_count.plot.bar(stacked=True, cmap=cmap, figsize=figsize)
    ax.legend(loc="best", bbox_to_anchor=(1, 1))
    ax.set_xlabel("slides")
    ax.set_ylabel("tissue domain proportion")
    ax.set_ylim((0, 1))
    plt.tight_layout()
    if save_to is not None:
        ax.figure.savefig(save_to)
    else:
        return ax
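
A minimal usage sketch (labeler fit as above; the domain and slide names are hypothetical and must match the number of tissue domains and slides):

# stacked-bar composition of tissue domains within each slide
ax = labeler.plot_tissue_ID_proportions_st(
    tID_labels=["domain_0", "domain_1", "domain_2"],  # hypothetical, one per tissue domain
    slide_labels=["slide_A", "slide_B"],              # hypothetical, one per slide
)
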
def prep_cluster_data(self, use_rep, features=None, n_rings=1, histo=False, fluor_channels=None, spatial_graph_key=None, n_jobs=-1)

Prepare master dataframe for tissue-level clustering

Parameters

use_rep : str
Representation from adata.obsm to use as clustering data (e.g. "X_pca")
features : list of int or None, optional (default=None)
List of features to use from adata.obsm[use_rep] (e.g. [0,1,2,3,4] to use first 5 principal components when use_rep="X_pca"). If None, use all features from adata.obsm[use_rep]
n_rings : int, optional (default=1)
Number of hexagonal rings around each spatial transcriptomics spot to blur features by for capturing regional information. Assumes 10X Genomics Visium platform.
histo : bool, optional (default False)
Use histology data from Visium anndata object (R,G,B brightfield features) in addition to adata.obsm[use_rep]? If fluorescent imaging data rather than brightfield, use fluor_channels argument instead.
fluor_channels : list of int or None, optional (default None)
Channels from fluorescent image to use for model training (e.g. [1,3] for channels 1 and 3 of Visium fluorescent imaging data). If None, do not use imaging data for training.
spatial_graph_key : str, optional (default=None)
Key in adata.obsp containing spatial graph connectivities (i.e. "spatial_connectivities"). If None, compute new spatial graph using n_rings in squidpy.
n_jobs : int, optional (default=-1)
Number of cores to parallelize over. Default all available cores.

Returns

Does not return anything. self.adatas are updated, adding "blur_*" features to .obs if n_rings > 0. self.cluster_data becomes master np.array for cluster training. Parameters are also captured as attributes for posterity.

Expand source code
def prep_cluster_data(
    self,
    use_rep,
    features=None,
    n_rings=1,
    histo=False,
    fluor_channels=None,
    spatial_graph_key=None,
    n_jobs=-1,
):
    """
    Prepare master dataframe for tissue-level clustering

    Parameters
    ----------
    use_rep : str
        Representation from `adata.obsm` to use as clustering data (e.g. "X_pca")
    features : list of int or None, optional (default=`None`)
        List of features to use from `adata.obsm[use_rep]` (e.g. [0,1,2,3,4] to
        use first 5 principal components when `use_rep`="X_pca"). If `None`, use
        all features from `adata.obsm[use_rep]`
    n_rings : int, optional (default=1)
        Number of hexagonal rings around each spatial transcriptomics spot to blur
        features by for capturing regional information. Assumes 10X Genomics Visium
        platform.
    histo : bool, optional (default `False`)
        Use histology data from Visium anndata object (R,G,B brightfield features)
        in addition to `adata.obsm[use_rep]`? If fluorescent imaging data rather
        than brightfield, use `fluor_channels` argument instead.
    fluor_channels : list of int or None, optional (default `None`)
        Channels from fluorescent image to use for model training (e.g. [1,3] for
        channels 1 and 3 of Visium fluorescent imaging data). If `None`, do not
        use imaging data for training.
    spatial_graph_key : str, optional (default=`None`)
        Key in `adata.obsp` containing spatial graph connectivities (i.e.
        `"spatial_connectivities"`). If `None`, compute new spatial graph using
        `n_rings` in `squidpy`.
    n_jobs : int, optional (default=-1)
        Number of cores to parallelize over. Default all available cores.

    Returns
    -------
    Does not return anything. `self.adatas` are updated, adding "blur_*" features
    to `.obs` if `n_rings > 0`.
    `self.cluster_data` becomes master `np.array` for cluster training.
    Parameters are also captured as attributes for posterity.
    """
    if self.cluster_data is not None:
        print("WARNING: overwriting existing cluster data")
        self.cluster_data = None
    if features is None:
        self.features = [x for x in range(self.adatas[0].obsm[use_rep].shape[1])]
    else:
        self.features = features
    # save the hyperparams as object attributes
    self.rep = use_rep
    self.histo = histo
    self.fluor_channels = fluor_channels
    self.n_rings = n_rings
    # collect clustering data from self.adatas in parallel
    print(
        "Collecting and blurring {} features from .obsm[{}]...".format(
            len(self.features),
            use_rep,
        )
    )
    cluster_data = Parallel(n_jobs=n_jobs, verbose=10)(
        delayed(prep_data_single_sample_st)(
            adata,
            adata_i,
            use_rep,
            self.features,
            histo,
            fluor_channels,
            spatial_graph_key,
            n_rings,
        )
        for adata_i, adata in enumerate(self.adatas)
    )
    batch_labels = [
        [x] * len(cluster_data[x]) for x in range(len(cluster_data))
    ]  # batch labels for umap
    self.merged_batch_labels = list(itertools.chain(*batch_labels))
    # concatenate blurred features into cluster_data df for cluster training
    subsampled_data = pd.concat(cluster_data)
    # perform z-scaling on final cluster data
    scaler = StandardScaler()
    self.scaler = scaler.fit(subsampled_data)
    scaled_data = scaler.transform(subsampled_data)
    self.cluster_data = scaled_data
    print("Collected clustering data of shape: {}".format(self.cluster_data.shape))
def show_feature_overlay(self, adata_index, pita, features=None, histo=None, cmap='tab20', label='feature', ncols=4, save_to=None, **kwargs)

Plot tissue_ID with individual pita features as alpha values to distinguish expression in identified tissue domains

Parameters

adata_index : int
Index of adata from self.adatas to plot overlays for (e.g. 0 for first adata object)
pita : np.array
Image of desired expression in pixel space from .assemble_pita()
features : list of int, optional (default=None)
List of features by index to show in plot. If None, use all features.
histo : str or None, optional (default=None)
Key of a histology image in adata.uns["spatial"] to show along with pita in gridspec (e.g. "hires", "hires_trim", "lowres"). If None, ignore.
cmap : str, optional (default="tab20")
Matplotlib colormap to use for plotting tissue domains
label : str
What to title each panel of the gridspec (i.e. "PC" or "usage") or each channel in RGB image. Can also pass list of names e.g. ["NeuN","GFAP", "DAPI"] corresponding to channels.
ncols : int
Number of columns for gridspec
save_to : str or None
Path to image file to save results. If None, show figure.
**kwargs
Arguments to pass to plt.imshow() function

Returns

Matplotlib object (if plotting one feature or RGB) or gridspec object (for multiple features). Saves plot to file if save_to is not None.

Expand source code
def show_feature_overlay(
    self,
    adata_index,
    pita,
    features=None,
    histo=None,
    cmap="tab20",
    label="feature",
    ncols=4,
    save_to=None,
    **kwargs,
):
    """
    Plot tissue_ID with individual pita features as alpha values to distinguish
    expression in identified tissue domains

    Parameters
    ----------
    adata_index : int
        Index of adata from `self.adatas` to plot overlays for (e.g. 0 for first
        adata object)
    pita : np.array
        Image of desired expression in pixel space from `.assemble_pita()`
    features : list of int, optional (default=`None`)
        List of features by index to show in plot. If `None`, use all features.
    histo : str or `None`, optional (default=`None`)
        Key of a histology image in `adata.uns["spatial"]` to show along with pita
        in gridspec (e.g. "hires", "hires_trim", "lowres"). If `None`, ignore.
    cmap : str, optional (default="tab20")
        Matplotlib colormap to use for plotting tissue domains
    label : str
        What to title each panel of the gridspec (i.e. "PC" or "usage") or each
        channel in RGB image. Can also pass list of names e.g. ["NeuN","GFAP",
        "DAPI"] corresponding to channels.
    ncols : int
        Number of columns for gridspec
    save_to : str or None
        Path to image file to save results. if `None`, show figure.
    **kwargs
        Arguments to pass to `plt.imshow()` function

    Returns
    -------
    Matplotlib object (if plotting one feature or RGB) or gridspec object (for
    multiple features). Saves plot to file if `save_to` is not `None`.
    """
    assert pita.ndim > 1, "Pita does not have enough dimensions: {} given".format(
        pita.ndim
    )
    assert pita.ndim < 4, "Pita has too many dimensions: {} given".format(pita.ndim)
    # create tissue_ID pita for plotting
    tIDs = assemble_pita(
        self.adatas[adata_index],
        features="tissue_ID",
        use_rep="obs",
        plot_out=False,
        verbose=False,
    )
    # if pita has multiple features, plot them in gridspec
    if isinstance(features, int):  # force features into list if single integer
        features = [features]
    # if no features are given, use all of them
    elif features is None:
        features = [x + 1 for x in range(pita.shape[2])]
    else:
        assert (
            pita.ndim > 2
        ), "Not enough features in pita: shape {}, expecting 3rd dim with length {}".format(
            pita.shape, len(features)
        )
        assert (
            len(features) <= pita.shape[2]
        ), "Too many features given: pita has {}, expected {}".format(
            pita.shape[2], len(features)
        )
    # min-max scale each feature in pita to convert to interpretable alpha values
    mms = MinMaxScaler()
    if pita.ndim == 3:
        pita_tmp = mms.fit_transform(
            pita.reshape((pita.shape[0] * pita.shape[1], pita.shape[2]))
        )
    elif pita.ndim == 2:
        pita_tmp = mms.fit_transform(
            pita.reshape((pita.shape[0] * pita.shape[1], 1))
        )
    # reshape back to original
    pita = pita_tmp.reshape(pita.shape)
    # figure out labels for gridspec plots
    if isinstance(label, str):
        # if label is single string, name channels numerically
        labels = ["{}_{}".format(label, x) for x in features]
    else:
        assert len(label) == len(
            features
        ), "Please provide the same number of labels as features; {} labels given, {} features given.".format(
            len(label), len(features)
        )
        labels = label
    # calculate gridspec dimensions
    if histo is not None:
        # determine where the histo image is in anndata
        assert (
            histo
            in self.adatas[adata_index]
            .uns["spatial"][
                list(self.adatas[adata_index].uns["spatial"].keys())[0]
            ]["images"]
            .keys()
        ), "Must provide one of {} for histo".format(
            self.adatas[adata_index]
            .uns["spatial"][
                list(self.adatas[adata_index].uns["spatial"].keys())[0]
            ]["images"]
            .keys()
        )
        histo = self.adatas[adata_index].uns["spatial"][
            list(self.adatas[adata_index].uns["spatial"].keys())[0]
        ]["images"][histo]
        if len(features) + 2 <= ncols:
            n_rows, n_cols = 1, len(features) + 2
        else:
            n_rows, n_cols = ceil((len(features) + 2) / ncols), ncols
        labels = ["Histology", "tissue_ID"] + labels  # append to front of labels
    else:
        if len(features) + 1 <= ncols:
            n_rows, n_cols = 1, len(features) + 1
        else:
            n_rows, n_cols = ceil((len(features) + 1) / ncols), ncols
        labels = ["tissue_ID"] + labels  # append to front of labels
    fig = plt.figure(figsize=(ncols * n_cols, ncols * n_rows))
    # arrange axes as subplots
    gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)
    # add plots to axes
    i = 0
    if histo is not None:
        # add histology plot to first axes
        ax = plt.subplot(gs[i])
        im = ax.imshow(histo, **kwargs)
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        ax.set_title(
            label=labels[i],
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        i = i + 1
    # plot tissue_ID first with colorbar
    ax = plt.subplot(gs[i])
    im = ax.imshow(tIDs, cmap=cmap, **kwargs)
    ax.tick_params(labelbottom=False, labelleft=False)
    sns.despine(bottom=True, left=True)
    ax.set_title(
        label=labels[i],
        loc="left",
        fontweight="bold",
        fontsize=16,
    )
    # colorbar scale for tissue_IDs
    _ = plt.colorbar(im, shrink=0.7)
    i = i + 1
    for feature in features:
        ax = plt.subplot(gs[i])
        im = ax.imshow(tIDs, alpha=pita[:, :, feature - 1], cmap=cmap, **kwargs)
        ax.tick_params(labelbottom=False, labelleft=False)
        sns.despine(bottom=True, left=True)
        ax.set_title(
            label=labels[i],
            loc="left",
            fontweight="bold",
            fontsize=16,
        )
        i = i + 1
    fig.tight_layout()
    if save_to:
        plt.savefig(fname=save_to, transparent=True, bbox_inches="tight", dpi=300)
    return fig
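
A minimal usage sketch; it assumes the slide has already been mapped to pixel space (see map_pixels() / assemble_pita()) and that clustering used .obsm["X_pca"], so the overlaid features are PCs:

from MILWRM import assemble_pita

# build a pixel-space image ("pita") of the PCs for the first slide, then overlay
pita = assemble_pita(labeler.adatas[0], use_rep="X_pca", plot_out=False)
fig = labeler.show_feature_overlay(
    adata_index=0,
    pita=pita,
    features=[1, 2],      # 1-based indices into pita's channel axis
    histo="hires_trim",   # optional histology panel, if present in adata.uns["spatial"]
    label="PC",
)
fig.savefig("pc_overlay.png", dpi=300)
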

Inherited members