-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Open
Description
Hi @rougier,
Because Discussions are not enabled for this repository, I am sharing my alternative solution to the ridgeline plot exercise as a separate issue, since it takes a slightly different approach. For clarity, I added comments explaining each step:
from matplotlib.patches import ConnectionPatch
import matplotlib.pyplot as plt
import numpy as np
# Initial function
def curve(size=500):
    """Return a random smooth curve sampled on a uniform grid over [-3, 3].

    The curve is a sum of 1 to 4 Gaussian-shaped bumps with random centers,
    widths and scales. The scales are normalized to sum to 1, so every sample
    lies in [0, 1].

    Randomness is drawn from NumPy's global random state (``np.random``), so
    the result is reproducible after ``np.random.seed(...)``.

    Parameters
    ----------
    size : int, optional
        Number of samples on the [-3, 3] grid. Defaults to 500, the length
        previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(size,)`` holding the curve values.
    """
    # randint's upper bound is exclusive: 1..4 bumps
    n = np.random.randint(1, 5)
    centers = np.random.normal(0.0, 1.0, n)
    # 'width' multiplies the squared distance below, so larger values give
    # narrower bumps; rescale so the factors sum to 10
    widths = np.random.uniform(5.0, 50.0, n)
    widths = 10 * widths / widths.sum()
    scales = np.random.uniform(0.1, 1.0, n)
    scales /= scales.sum()
    x = np.linspace(-3, 3, size)
    X = np.zeros(size)
    for center, width, scale in zip(centers, widths, scales):
        X = X + scale * np.exp(-(x - center) ** 2 * width)
    return X
# Seed NumPy's global generator so the generated curves (and the figure) are
# reproducible, then fix the subplot grid: one row per series, one column per value
RANDOM_STATE = 25
np.random.seed(RANDOM_STATE)
rows, cols = 50, 3
# Create the rows x cols grid of axes without y ticks; the negative vertical
# spacing (hspace) makes neighbouring rows overlap, giving the ridgeline look
fig, axs = plt.subplots(
    nrows=rows,
    ncols=cols,
    figsize=(10, rows * 0.2),
    subplot_kw=dict(yticks=[]),
)
fig.subplots_adjust(top=0.95, bottom=0.05, hspace=-0.5, wspace=0.1)
# Label columns with titles and rows with y-axis labels
# https://stackoverflow.com/a/25814386
# Rows are numbered from the top down as 'Serie <rows>' ... 'Serie 1'
row_names = [f'Serie {serie}' for serie in range(rows, 0, -1)]
col_names = [f'Value {value}' for value in range(1, cols + 1)]
# Only the first column carries row labels; they are drawn horizontally
# (rotation=0) and pushed left of the axes with labelpad
for ax, name in zip(axs[:, 0], row_names):
    ax.set_ylabel(name, rotation=0, fontsize='small', loc='bottom', labelpad=40)
# Only the top row carries column titles
for ax, name in zip(axs[0], col_names):
    ax.set_title(name, fontsize='large', fontweight='bold', loc='left', pad=20)
# Pick one fill colour per row from the 'Spectral' colormap, optionally
# blended toward opaque white by the factor c (0 = pure colormap, 1 = white);
# the order is reversed so the top row takes the end of the colormap
# https://stackoverflow.com/a/72289062
c = 0
colors = (
    (1. - c) * plt.get_cmap('Spectral')(np.linspace(0, 1, rows))
    + c * np.ones((rows, 4))
)[::-1]
# Plot graphs from left to right, from top to bottom. The row-major order
# matters: every curve() call consumes NumPy's global random state, so a
# different order would put different curves in each subplot.
for i in range(rows):
    for j in range(cols):
        ax = axs[i, j]
        # Transparent background and no spines so overlapping rows show through
        ax.set_facecolor('none')
        ax.spines[['left', 'top', 'right', 'bottom']].set_visible(False)
        # Only the bottom row keeps its x axis
        if i != rows - 1:
            ax.get_xaxis().set_visible(False)
        y = curve()
        x = np.linspace(-3, 3, y.size)
        ax.plot(x, y, color='k', linewidth=1)
        ax.fill_between(x, y, color=colors[i])
        # Lower rows get a higher z-order, so each row is drawn over the one above it
        ax.set_zorder(i)
# Plot a dashed vertical line for each column of subplots, running from near
# the top of the figure down to the data point (0, 0) of the bottom-row axes
for col_idx in range(cols):
    coords = axs[0, col_idx].get_position()
    # Horizontal centre of the column in figure-fraction coordinates, used as
    # the x = 0 position. This assumes each subplot's x-limits are symmetric
    # about 0 (the curves are plotted over [-3, 3]).
    zero_x_coords = (coords.x1 + coords.x0) / 2
    conn = ConnectionPatch(
        xyA=(zero_x_coords, 0.99), coordsA='figure fraction', axesA=axs[0, col_idx],
        xyB=(0, 0), coordsB='data', axesB=axs[rows - 1, col_idx],
        # zorder=rows puts the line above every subplot (axes z-orders are 0..rows-1)
        zorder=rows, linewidth=0.5, linestyle=(0, (8, 2)), color='k',
    )
    fig.add_artist(conn)
plt.show()
# The output image can be found in the Jupyter notebook with the exercises for Chapter 1.
Thank you.
Metadata
Metadata
Assignees
Labels
No labels