Skip to content

Alternative solution (Chapter 1, Exercise 3) #62

@labdmitriy

Description

@labdmitriy

Hi @rougier,

Because Discussions are not enabled for this repository, I decided to share my alternative solution to the ridgeline plot exercise as a separate issue, since it takes a slightly different approach. For clarity, I added comments for each step:

from matplotlib.patches import ConnectionPatch
import matplotlib.pyplot as plt
from matplotlib.transforms import blended_transform_factory
import numpy as np

# Initial function: generates one random ridge profile.
def curve(size=500):
    """Return one random ridge profile sampled on a regular grid over [-3, 3].

    The profile is a weighted sum of 1 to 4 Gaussian bumps. The bump centres
    are drawn from a standard normal distribution. The bump weights are
    normalised to sum to 1, so every value lies in [0, 1] (up to rounding).

    Parameters
    ----------
    size : int, optional
        Number of samples. The default of 500 is the original fixed length.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``size``; element ``k`` is the profile
        evaluated at ``np.linspace(-3, 3, size)[k]``.
    """
    # All draws come from NumPy's global RNG in this fixed order, so seeding
    # it (np.random.seed) makes the sequence of generated curves reproducible.
    n = np.random.randint(1, 5)  # upper bound is exclusive: 1..4 bumps
    centers = np.random.normal(0.0, 1.0, n)
    # ``width`` multiplies the squared distance below, so larger values give
    # narrower bumps; rescaled so the values sum to 10.
    widths = np.random.uniform(5.0, 50.0, n)
    widths = 10 * widths / widths.sum()
    # Positive weights normalised to sum to 1, keeping the peak at most ~1.
    scales = np.random.uniform(0.1, 1.0, n)
    scales /= scales.sum()

    x = np.linspace(-3, 3, size)
    X = np.zeros(size)
    for center, width, scale in zip(centers, widths, scales):
        X = X + scale * np.exp(-(x - center) * (x - center) * width)

    return X

# Seed NumPy's global RNG so the random curves are reproducible between runs.
RANDOM_STATE = 25
np.random.seed(RANDOM_STATE)

# Subplot grid: ``rows`` ridge rows by ``cols`` columns.
rows, cols = 50, 3

# Configure the figure: a ``rows`` x ``cols`` grid of axes without y ticks.
# The negative ``hspace`` makes vertically adjacent axes overlap, which is
# what produces the ridgeline effect.
# ``squeeze=False`` keeps ``axs`` a 2-D array even when ``rows`` or ``cols``
# is 1, so the 2-D indexing used throughout the script always works.
fig, axs = plt.subplots(
    rows, cols,
    figsize=(10, rows * 0.2),
    squeeze=False,
    subplot_kw={'yticks': []},
)
# Adjust this figure explicitly rather than whatever pyplot considers current.
fig.subplots_adjust(top=0.95, bottom=0.05, hspace=-0.5, wspace=0.1)

# Label columns with titles (first row only) and rows with horizontal y-axis
# labels (first column only). Row names count down, so the bottom row is
# "Serie 1".
# https://stackoverflow.com/a/25814386
row_names = [f'Serie {serie}' for serie in range(rows, 0, -1)]
col_names = [f'Value {value}' for value in range(1, cols + 1)]

for ax, name in zip(axs[:, 0], row_names):
    ax.set_ylabel(name, rotation=0, fontsize='small', loc='bottom', labelpad=40)

for ax, name in zip(axs[0], col_names):
    ax.set_title(name, fontsize='large', fontweight='bold', loc='left', pad=20)

# One fill colour per row, sampled evenly from the 'Spectral' colormap.
# ``c`` blends every RGBA channel toward 1.0 (white): 0 keeps the pure
# colormap colours, larger values give softer, more pastel colours.
# The samples are reversed so row 0 (top) gets the colormap's last colour.
# https://stackoverflow.com/a/72289062
c = 0
colors = (
    plt.get_cmap('Spectral')(np.linspace(0, 1, rows)) * (1. - c)
    + np.full((rows, 4), c)
)[::-1]

# Draw every ridge row by row (top to bottom), and within a row column by
# column (left to right). The order matters: each curve() call consumes
# random numbers, so it fixes which curve ends up in which panel.
for i, row_axes in enumerate(axs):
    is_bottom_row = i == rows - 1
    for ax in row_axes:
        # Transparent background and no frame, so overlapping neighbouring
        # axes do not hide each other's ridges.
        ax.set_facecolor('none')
        ax.spines[:].set_visible(False)
        # Only the bottom row keeps its x-axis.
        if not is_bottom_row:
            ax.xaxis.set_visible(False)

        y = curve()
        x = np.linspace(-3, 3, y.size)
        ax.plot(x, y, color='k', linewidth=1)
        ax.fill_between(x, y, color=colors[i])
        # Larger z-order further down: each row is drawn over the row above it.
        ax.set_zorder(i)

# Draw a dashed vertical reference line at x = 0 through each column of
# ridges, from near the top of the figure down to the bottom row's axes.
for col_idx in range(cols):
    top_ax = axs[0, col_idx]
    bottom_ax = axs[rows - 1, col_idx]
    # Blend the column's data x-coordinate with figure-fraction y. The top
    # end then sits exactly at data x = 0, not merely at the axes' centre,
    # which only coincides with x = 0 while the x-limits are symmetric. It is
    # placed at 99% of the figure height. The transform is evaluated at draw
    # time, so later changes to the x-limits are honoured.
    top_transform = blended_transform_factory(top_ax.transData, fig.transFigure)
    conn = ConnectionPatch(
        xyA=(0, 0.99), coordsA=top_transform,
        xyB=(0, 0), coordsB='data',
        axesA=top_ax, axesB=bottom_ax,
        # Draw above all ridge axes, whose z-orders are 0 .. rows - 1.
        zorder=rows, linewidth=0.5, linestyle=(0, (8, 2)), color='k',
    )
    fig.add_artist(conn)

plt.show()

You can find the output image in the Jupyter notebook with the exercises for Chapter 1.

Thank you.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions