Using wot from Python


This notebook provides a running example with simulated data to walk you through the most important tools of wot.

You can read the code and the detailed comments below and simply copy-paste them into a Python interpreter, or download the whole Python script for a specific section by clicking the download button next to the section title.

The examples below will show you how to generate your own simulated data using wot’s simulate module, and compute transport maps for it to generate your own plots.

Alternatively, we also provide simulated data with all transport maps precomputed, and an archive with all of the scripts presented in this notebook. Just click the button below to download the whole archive.


Download python code

Simulating data


import numpy as np
import wot.simulate
from numpy.random import randint
from numpy.random import random

# ------ Configuration variables -------
# Output files written at the end of this script.
matrix_file = 'matrix.txt'
days_file = 'days.txt'
covariate_file = 'covariate.txt'
gene_sets_file = 'gene_sets.gmt'

number_of_timepoints = 51
covariates_count = 5
average_cell_count_per_timepoint = 3000

# Gene set name -> list of gene names, written to gene_sets_file.
gene_sets = {
    'Stem cells': ['Stem_gene'],
    'Myeloid stem cells': ['Myeloid_gene'],
    'Red blood cells': ['RBC_gene'],
    'Granulocytes': ['Granulo_gene'],
    'Lymphocytes': ['Lympho_gene'],
}
# --------------------------------------

# Column names of the simulated matrix; each waypoint below has one value
# per entry of this list, in the same order.
gene_names = ['X_gene', 'Y_gene',
              'RBC_gene', 'Granulo_gene', 'Lympho_gene',
              'Myeloid_gene', 'Stem_gene']
# Three paths of five waypoints each. The first two waypoints are identical
# in all three paths; the paths then diverge.
tips = [
    [[-.3, 0, 0, 0, 0, 0, 10], [0, 2.5, 0, 0, 0, 2, 3], [-1, 5, 0, 0, 0, 4, 1], [-2, 7.5, 5, 0, 0, 0, 0],
     [-2.5, 10, 10, 0, 0, 0, 0]],
    [[-.3, 0, 0, 0, 0, 0, 10], [0, 2.5, 0, 0, 0, 2, 3], [-1, 5, 0, 0, 0, 4, 1], [-.5, 7.5, 0, 5, 0, 0, 0],
     [-.5, 10, 0, 10, 0, 0, 0]],
    [[-.3, 0, 0, 0, 0, 0, 10], [0, 2.5, 0, 0, 0, 2, 3], [.3, 5, 0, 0, 5, 1, 1], [1, 7.5, 0, 0, 9, 0, 0],
     [2, 10, 0, 0, 10, 0, 0]]
]
# The waypoints of each path are spread evenly over [0, 1].
times = [np.linspace(0, 1, num=len(k)) for k in tips]

N = number_of_timepoints
timepoints = np.linspace(0, 1, num=N)

# Interpolate every path at the N timepoints ...
means = np.array(
    [wot.simulate.interp(timepoints, times[k], tips[k],
                         method='linear', smooth=(N // 10))
     for k in range(len(tips))])
# ... then re-index so the first axis runs over timepoints instead of paths.
means = np.asarray([means[:, t] for t in range(N)])

covs = [.08, .1, .04, .04, .04, .03, .05]
# Scale each base value by a random factor in [0.5, 1.5). The scaled list is
# drawn once and repeated, so every path shares the same values.
covs = [[c * (random() + .5) for c in covs]] * len(tips)

# Cell count per timepoint: the configured average plus an offset in [-100, 100).
sizes = [average_cell_count_per_timepoint + randint(-100, 100) for _ in range(N)]

data = [wot.simulate.multivariate_normal_mixture(mean, covs, size=size)
        for mean, size in zip(means, sizes)]


def data_to_dataset(i):
    """Build the dataset for timepoint ``i`` from ``data[i]``.

    Rows are prefixed with ``cell_g<i>_`` (``i`` zero-padded to two digits)
    and columns are named after ``gene_names``.
    """
    return wot.dataset_from_x(data[i],
                              row_prefix="cell_g{:02}_".format(i),
                              columns=gene_names)


dataset_list = [data_to_dataset(i) for i in range(N)]

# Tag every cell with its timepoint index and a random covariate
# in [0, covariates_count).
for day, (dataset, size) in enumerate(zip(dataset_list, sizes)):
    wot.set_cell_metadata(dataset, 'day', day)
    covariates = randint(0, covariates_count, size=size)
    wot.set_cell_metadata(dataset, 'covariate', covariates)
ds = wot.merge_datasets(*dataset_list)

wot.io.write_gene_sets(gene_sets, gene_sets_file, "gmt")
wot.io.write_dataset(ds, matrix_file)
wot.io.write_dataset_metadata(ds.obs, days_file, 'day')
wot.io.write_dataset_metadata(ds.obs, covariate_file, 'covariate')

Download python code

Plotting two features


import numpy as np
from matplotlib import pyplot

import wot.graphics

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
days_file = 'days.txt'
gene_x_plot = 0
gene_y_plot = 1
destination_file = "generated_data.png"
# --------------------------------------

color1 = [.08, .34, .59]  # first color
color2 = [.08, .59, .34]  # final color

ds = wot.io.read_dataset(matrix_file)
wot.io.add_row_metadata_to_dataset(ds, days_path=days_file)

# Color each cell by its day; any other column or metadata field of ds.obs
# could be used here instead.
cell_colors = np.asarray(ds.obs['day'])
# Vectorized maximum: the builtin max() would loop over every cell in Python.
last_day = cell_colors.max()
# Rescale to [0, 1] for color mixing; skip the division when every day is 0,
# which would otherwise produce NaN.
if not np.isclose(last_day, 0):
    cell_colors = cell_colors / last_day
cell_colors = [wot.graphics.color_mix(color1, color2, d)
               for d in cell_colors]
wot.set_cell_metadata(ds, 'color', cell_colors)

pyplot.figure(figsize=(5, 5))
pyplot.axis('off')
wot.graphics.plot_2d_dataset(pyplot, ds,
                             x=gene_x_plot, y=gene_y_plot)
pyplot.autoscale(enable=True, tight=True)
pyplot.tight_layout(pad=0)
pyplot.savefig(destination_file)

Generated data

Download python code

Cell sets


from matplotlib import pyplot

import wot.graphics

# ------ Configuration variables -------
# Input files
matrix_file = 'matrix.txt'
days_file = 'days.txt'
gene_sets_file = 'gene_sets.gmt'
# Cell set computation
quantile_for_cell_sets = .88
cell_sets_file = 'cell_sets.gmt'
# Plot appearance: [color, cell set name] pairs, also used for the legend
bg_color = "#80808020"
cell_sets_to_color = [
    ['red', 'Red blood cells'],
    ['blue', 'Granulocytes'],
    ['green', 'Lymphocytes'],
    ['purple', 'Myeloid stem cells'],
    ['black', 'Stem cells'],
]
gene_x_plot = 0
gene_y_plot = 1
destination_file = "cell_sets.png"
# --------------------------------------


ds = wot.io.read_dataset(matrix_file)
wot.io.add_row_metadata_to_dataset(ds, days_path=days_file)

# Compute the cell sets for the configured quantile and save them

gene_names = ds.var.index.values
gene_sets = wot.io.read_sets(gene_sets_file, gene_names)
cell_sets = wot.get_cells_in_gene_sets(
    gene_sets, ds, quantile=quantile_for_cell_sets)
wot.io.write_gene_sets(cell_sets, cell_sets_file, "gmt")

# Plot: every cell starts with the background color, then the members of
# each listed cell set are recolored in list order

wot.set_cell_metadata(ds, 'color', bg_color)

for set_color, set_name in cell_sets_to_color:
    members = cell_sets[set_name]
    wot.set_cell_metadata(ds, 'color', set_color, indices=members)

pyplot.figure(figsize=(5, 5))
pyplot.axis('off')
wot.graphics.plot_2d_dataset(pyplot, ds, x=gene_x_plot, y=gene_y_plot)
wot.graphics.legend_figure(pyplot, cell_sets_to_color)
pyplot.autoscale(enable=True, tight=True)
pyplot.tight_layout(pad=0)
pyplot.savefig(destination_file)

Cell sets plots

Download python code

Create Transport Maps


If you have downloaded our precomputed transport maps, you can skip this section.

To compute ancestors and descendants, you first need to compute transport maps using an OT model, which stores all the parameters for the transport maps. You can initialize it in Python for future use:

import wot.ot

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
days_file = 'days.txt'
# --------------------------------------

# Keyword parameters handed to the OT model at initialization.
ot_parameters = dict(
    growth_iters=1,
    epsilon=.02,
    lambda1=10,
    lambda2=80,
    local_pca=0,
)
ot_model = wot.ot.initialize_ot_model(matrix_file, days_file, **ot_parameters)
ot_model.compute_all_transport_maps()

You can then use wot.load_ot_model to load this model, with the same parameters.

Download python code

Ancestors of a cell set


import numpy as np
from matplotlib import pyplot

import wot.graphics

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
bg_color = "#80808080"
gene_x_plot = 0
gene_y_plot = 1
cell_sets_file = 'cell_sets.gmt'
target_cell_set = "Red blood cells"
target_timepoint = 50
destination_file = "ancestors.png"
# --------------------------------------

ds = wot.io.read_dataset(matrix_file)

# Transport maps are read from the 'tmaps' directory.
tmap_model = wot.tmap.TransportMapModel.from_directory('tmaps')


def transparent(x):
    """Return the hex string for RGB (.08, .34, .59) with alpha ``x``."""
    return wot.graphics.hexstring_of_rgba((.08, .34, .59, x))


def color_cells(population):
    """Color the cells of ``population`` in ``ds`` according to their weight.

    The weights ``population.p`` are divided by their maximum (unless that
    maximum is close to zero) and each one is used as the alpha value of
    the color returned by ``transparent``. The colors are stored in the
    'color' metadata of the module-level ``ds``, for the cell ids that
    ``tmap_model`` reports for the population.
    """
    p = population.p
    # Compute the maximum once, vectorized, instead of calling the builtin
    # max() twice on the array.
    peak = np.max(p)
    if not np.isclose(peak, 0):
        p = p / peak
    color = [transparent(x) for x in p]
    wot.set_cell_metadata(ds, 'color', color,
                          indices=tmap_model.cell_ids(population))


pyplot.figure(figsize=(5, 5))
pyplot.axis('off')

# First layer: every cell drawn in the background color.
wot.set_cell_metadata(ds, 'color', bg_color)
wot.graphics.plot_2d_dataset(pyplot, ds, x=gene_x_plot, y=gene_y_plot)

# Start from the target cell set at the target timepoint.
cell_sets = wot.io.read_sets(cell_sets_file, as_dict=True)
target_ids = cell_sets[target_cell_set]
population = tmap_model.population_from_ids(
    target_ids, at_time=target_timepoint)[0]

# Color the current population, then keep pulling it back and coloring the
# result for as long as the model allows it.
while True:
    color_cells(population)
    if not tmap_model.can_pull_back(population):
        break
    population = tmap_model.pull_back(population)

# Second layer: same dataset, now carrying the ancestor colors.
wot.graphics.plot_2d_dataset(pyplot, ds, x=gene_x_plot, y=gene_y_plot)
legend_entries = [["#316DA2", "Ancestors of {}".format(target_cell_set)]]
wot.graphics.legend_figure(pyplot, legend_entries)
pyplot.autoscale(enable=True, tight=True)
pyplot.tight_layout(pad=0)
pyplot.savefig(destination_file)

Ancestors plot

Download python code

Shared ancestry


import numpy as np
from matplotlib import pyplot

import wot.graphics

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
bg_color = "#80808050"
gene_x_plot = 0
gene_y_plot = 1
cell_set_1 = "Red blood cells"
cell_set_2 = "Granulocytes"
cell_sets_file = 'cell_sets.gmt'
target_timepoint = 50
destination_file = "shared_ancestry.png"
# --------------------------------------


tmap_model = wot.tmap.TransportMapModel.from_directory('tmaps')
cell_sets = wot.io.read_sets(cell_sets_file, as_dict=True)
populations = tmap_model.population_from_cell_sets(cell_sets,
                                                   at_time=target_timepoint)

trajectories = tmap_model.compute_trajectories(populations)
ds = wot.io.read_dataset(matrix_file)
wot.set_cell_metadata(ds, 'color', bg_color)

pyplot.figure(figsize=(5, 5))
pyplot.axis('off')
# First layer: every cell in the background color. The configured axes are
# passed explicitly so that gene_x_plot / gene_y_plot are honored.
wot.graphics.plot_2d_dataset(pyplot, ds,
                             x=gene_x_plot, y=gene_y_plot)

# Keep only the two trajectory columns for the cell sets being compared.
probabilities = trajectories.X[:, trajectories.var.index.get_indexer_for([cell_set_1, cell_set_2])]
# Opacity: the larger of the two values for each cell.
alphas = np.amax(probabilities, axis=1)
# Hue: log-ratio of the two values (1e-9 keeps the ratio and the log finite),
# mapped linearly so that -4 -> 0 and +4 -> 1, and clipped to [0, 1].
t = np.log((probabilities.T[0] + 1e-9) / (probabilities.T[1] + 1e-9))
t = np.clip(t / 8 + .5, 0, 1)
# Normalize the opacities; skip the division when the maximum is (close to)
# zero, which would otherwise produce NaN alphas.
peak = np.max(alphas)
if not np.isclose(peak, 0):
    alphas = alphas / peak
# t drives the red channel and 1 - t the blue channel, matching the legend.
colors = [wot.graphics.hexstring_of_rgba([hue, 0, 1 - hue, alpha])
          for hue, alpha in zip(t, alphas)]
ds.obs.loc[:, 'color'] = colors

# Second layer: same dataset, now carrying the ancestry colors.
wot.graphics.plot_2d_dataset(pyplot, ds,
                             x=gene_x_plot, y=gene_y_plot)
wot.graphics.legend_figure(pyplot,
                           [["#A00000", "Ancestors of " + cell_set_1],
                            ["#0000A0", "Ancestors of " + cell_set_2]],
                           loc=3)
pyplot.autoscale(enable=True, tight=True)
pyplot.tight_layout(pad=0)
pyplot.savefig(destination_file)

Shared ancestry plot

Download python code

Trajectory trends


import numpy
from matplotlib import pyplot

import wot

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
days_file = 'days.txt'
cell_sets_file = 'cell_sets.gmt'
target_cell_set = 'Red blood cells'
target_timepoint = 50
# Number of leading columns of the matrix that are left out of the plot.
skip_first_n_genes = 2
destination_file = "trajectory_trends.png"
# --------------------------------------

ds = wot.io.read_dataset(matrix_file)
tmap_model = wot.tmap.TransportMapModel.from_directory('tmaps')
cell_sets = wot.io.read_sets(cell_sets_file, as_dict=True)
all_populations = tmap_model.population_from_cell_sets(cell_sets,
                                                       at_time=target_timepoint)
population = all_populations[target_cell_set]

# Compute the trajectory of the target population only, then its trends.
trajectory_ds = tmap_model.compute_trajectories({target_cell_set: population})
results = wot.ot.compute_trajectory_trends_from_trajectory(trajectory_ds, ds)
# One (means, variances) pair per trajectory; only one was requested.
means, variances = results[0]
timepoints = means.obs.index.values
pyplot.figure(figsize=(5, 5))
means = means.X
stds = numpy.sqrt(variances.X)
genes = ds.var.index
# One curve per gene, with a band of one standard deviation around the mean.
for i in range(skip_first_n_genes, means.shape[1]):
    pyplot.plot(timepoints, means[:, i], label=genes[i])
    pyplot.fill_between(timepoints, means[:, i] - stds[:, i],
                        means[:, i] + stds[:, i], alpha=.5)

pyplot.xlabel("Time")
pyplot.ylabel("Gene expression")
pyplot.title("Trajectory trend of {} from time {}"
             .format(target_cell_set, target_timepoint))
pyplot.legend()
pyplot.savefig(destination_file)

Trajectory trends plot

Download python code

Ancestor census


from matplotlib import pyplot

import wot

# ------ Configuration variables -------
matrix_file = 'matrix.txt'
cell_sets_file = 'cell_sets.gmt'
target_cell_set = 'Red blood cells'
target_timepoint = 50
destination_file = "ancestor_census.png"
# --------------------------------------


tmap_model = wot.tmap.TransportMapModel.from_directory('tmaps')

# The cell sets are needed in two forms: as read from the file (for the
# census) and as a dict (to build the populations).
cs_matrix = wot.io.read_sets(cell_sets_file)
cell_sets = wot.io.convert_binary_dataset_to_dict(cs_matrix)
all_populations = tmap_model.population_from_cell_sets(
    cell_sets, at_time=target_timepoint)
population = all_populations[target_cell_set]

timepoints, census = tmap_model.compute_ancestor_census(cs_matrix, population)

pyplot.figure(figsize=(5, 5))

# One curve per slice along the last axis of the census, labelled with the
# cell set name at the same position.
cset_names = list(cell_sets)
for set_index in range(census.shape[2]):
    proportions = census[:, :, set_index].T
    pyplot.plot(timepoints, proportions, label=cset_names[set_index])

plot_title = "Ancestor census of {} from time {}".format(
    target_cell_set, target_timepoint)
pyplot.xlabel("Time")
pyplot.ylabel("Proportion of ancestors")
pyplot.title(plot_title)
pyplot.legend()
pyplot.savefig(destination_file)

Ancestor census plot

Download python code

Validation summary


import numpy as np
from matplotlib import pyplot

import wot.commands
import wot.graphics
# ------ Configuration variables -------
matrix_file = 'matrix.txt'
days_file = 'days.txt'
covariate_file = 'covariate.txt'
destination_file = 'validation_summary.png'
# --------------------------------------

ot_model = wot.ot.initialize_ot_model(matrix_file, days_file,
                                      covariate=covariate_file, growth_iters=1, tmap_prefix='val')
vs = wot.commands.compute_validation_summary(ot_model)
# Place each measurement at the midpoint of its interval.
vs['time'] = (vs['interval_start'] + vs['interval_end']) / 2
# The first character of 'pair0' selects the legend entry below.
vs['type'] = vs['pair0'].astype(str).str[0]
# Aggregations are named by string: passing np.mean / np.std to agg() is
# deprecated in pandas. The resulting columns are still 'mean' and 'std'.
res = vs.groupby(['time', 'type'])['distance'] \
    .agg(['mean', 'std'])

# Type letter -> [color, legend label].
legend = {
    'P': ["#f000f0", "between real batches"],
    'R': ["#00f000", "between random and real"],
    'I': ["#f00000", "between interpolated and real"],
    'F': ["#00f0f0", "between first and real"],
    'L': ["#f0f000", "between last and real"],
}

pyplot.figure(figsize=(10, 10))
pyplot.title("Validation of the OT model")
pyplot.xlabel("time")
pyplot.ylabel("distance")
wot.graphics.legend_figure(pyplot, legend.values())
for p, d in res.groupby('type'):
    # Types without a legend entry are not plotted.
    if p not in legend:
        continue
    t = np.asarray(d.index.get_level_values('time'))
    m = np.asarray(d['mean'])
    s = np.asarray(d['std'])
    pyplot.plot(t, m, '-o', color=legend[p][0])
    # Appending "50" to the hex color adds an alpha channel for the band.
    pyplot.fill_between(t, m - s, m + s, color=legend[p][0] + "50")
pyplot.savefig(destination_file)

Validation summary plot