Skip to content

Mondrian

The lines of code where ninejs is used are highlighted (only four, including the import); everything else is part of the original code.

import polars as pl
import numpy as np
from plotnine import (
    ggplot,
    aes,
    geom_rect,
    theme_minimal,
    scale_fill_manual,
    theme,
    element_blank,
)
from typing import List, Tuple
from enum import Enum
from ninejs import interactive, save


class MondrianColour(Enum):
    BLACK = "#000000"
    YELLOW = "#FDDE06"
    BLUE = "#0300AD"
    RED = "#E70503"
    WHITE = "#ffffff"


class Node:
    def __init__(
        self, depth: int, x_range: Tuple[float, float], y_range: Tuple[float, float]
    ):
        self.depth = depth
        self.x_range = x_range
        self.y_range = y_range
        self.left = None
        self.right = None
        self.split_value = None
        self.is_vertical = np.random.choice([True, False])


def generate_tree(
    node: Node, max_depth: int, min_size: float, force_split: bool = False
) -> None:
    width = node.x_range[1] - node.x_range[0]
    height = node.y_range[1] - node.y_range[0]

    if not force_split:
        if node.depth >= max_depth or (np.random.random() < 0.1 and node.depth > 1):
            return

        if width < min_size and height < min_size:
            return

    if node.is_vertical and width >= min_size:
        node.split_value = np.random.uniform(
            node.x_range[0] + min_size, node.x_range[1] - min_size
        )
        node.left = Node(
            node.depth + 1, (node.x_range[0], node.split_value), node.y_range
        )
        node.right = Node(
            node.depth + 1, (node.split_value, node.x_range[1]), node.y_range
        )
    elif not node.is_vertical and height >= min_size:
        node.split_value = np.random.uniform(
            node.y_range[0] + min_size, node.y_range[1] - min_size
        )
        node.left = Node(
            node.depth + 1, node.x_range, (node.y_range[0], node.split_value)
        )
        node.right = Node(
            node.depth + 1, node.x_range, (node.split_value, node.y_range[1])
        )
    else:
        return

    generate_tree(node.left, max_depth, min_size)
    generate_tree(node.right, max_depth, min_size)


def initial_splits(root: Node, min_size: float) -> None:
    # Vertical split
    root.is_vertical = True
    root.split_value = np.random.uniform(
        root.x_range[0] + min_size, root.x_range[1] - min_size
    )
    root.left = Node(1, (root.x_range[0], root.split_value), root.y_range)
    root.right = Node(1, (root.split_value, root.x_range[1]), root.y_range)

    # Horizontal splits
    root.left.is_vertical = False
    root.left.split_value = np.random.uniform(
        root.left.y_range[0] + min_size, root.left.y_range[1] - min_size
    )
    root.left.left = Node(
        2, root.left.x_range, (root.left.y_range[0], root.left.split_value)
    )
    root.left.right = Node(
        2, root.left.x_range, (root.left.split_value, root.left.y_range[1])
    )

    root.right.is_vertical = False
    root.right.split_value = np.random.uniform(
        root.right.y_range[0] + min_size, root.right.y_range[1] - min_size
    )
    root.right.left = Node(
        2, root.right.x_range, (root.right.y_range[0], root.right.split_value)
    )
    root.right.right = Node(
        2, root.right.x_range, (root.right.split_value, root.right.y_range[1])
    )


def tree_to_rectangles(node: Node, rectangles: List[dict]) -> None:
    if node.left is None and node.right is None:
        rectangles.append(
            {
                "xmin": node.x_range[0],
                "xmax": node.x_range[1],
                "ymin": node.y_range[0],
                "ymax": node.y_range[1],
                "depth": node.depth,
            }
        )
    else:
        tree_to_rectangles(node.left, rectangles)
        tree_to_rectangles(node.right, rectangles)


np.random.seed(42)

root = Node(0, (0, 1), (0, 1))
min_size = 0.05
max_depth = 12

# Perform initial splits
initial_splits(root, min_size)

# Continue generating the tree
generate_tree(root.left.left, max_depth, min_size)  # pyrefly: ignore
generate_tree(root.left.right, max_depth, min_size)  # pyrefly: ignore
generate_tree(root.right.left, max_depth, min_size)  # pyrefly: ignore
generate_tree(root.right.right, max_depth, min_size)  # pyrefly: ignore

rectangles = []
tree_to_rectangles(root, rectangles)

colours = pl.Series(
    name="colour",
    values=np.random.choice(
        [colour.value for colour in MondrianColour], size=len(rectangles)
    ),
)

df = pl.DataFrame(rectangles).with_columns(colours)

plot = (
    ggplot(
        df,
        aes(
            xmin="xmin",
            xmax="xmax",
            ymin="ymin",
            ymax="ymax",
            fill="colour",
            tooltip="colour",
            data_id="colour",
        ),
    )
    + geom_rect(color="black", size=2)
    + scale_fill_manual(values=[colour.value for colour in MondrianColour])
    + theme_minimal()
    + theme(
        legend_position="none",
        aspect_ratio=1,
        axis_text=element_blank(),
        axis_ticks=element_blank(),
        panel_grid=element_blank(),
        figure_size=(10, 10),
    )
)

interactive(plot) + save("docs/iframes/mondrian.html")