Skip to content
2 changes: 1 addition & 1 deletion fastplotlib/graphics/line_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def cmap(self, args):
if isinstance(args, str):
name = args
transform, alpha = None, 1.0
if len(args) == 1:
elif len(args) == 1:
name = args[0]
transform, alpha = None, None

Expand Down
1 change: 1 addition & 0 deletions fastplotlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .functions import *
from .gpu import enumerate_adapters, select_adapter, print_wgpu_report
from ._plot_helpers import *


@dataclass
Expand Down
53 changes: 53 additions & 0 deletions fastplotlib/utils/_plot_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Sequence

import numpy as np

from ..graphics._base import Graphic
from ..graphics._collection_base import GraphicCollection


def get_nearest_graphics(
pos: tuple[float, float] | tuple[float, float, float],
graphics: Sequence[Graphic] | GraphicCollection,
) -> np.ndarray[Graphic]:
"""
Returns the nearest ``graphics`` to the passed position ``pos`` in world space.
Uses the distance between ``pos`` and the center of the bounding sphere for each graphic.

Parameters
----------
pos: (x, y) | (x, y, z)
position in world space, z-axis is ignored when calculating L2 norms if ``pos`` is 2D

graphics: Sequence, i.e. array, list, tuple, etc. of Graphic | GraphicCollection
the graphics from which to return a sorted array of graphics in order of closest
to furthest graphic

Returns
-------
tuple[Graphic]
nearest graphics to ``pos`` in order

"""

if isinstance(graphics, GraphicCollection):
graphics = graphics.graphics

if not all(isinstance(g, Graphic) for g in graphics):
raise TypeError("all elements of `graphics` must be Graphic objects")
Comment on lines +36 to +37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there also be a check for if all the graphics are in the same subplot? Or would that not make a difference?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We leave that to the user, so it's just a very simple function unaware of plot areas


pos = np.asarray(pos)

if pos.shape != (2,) or not pos.shape != (3,):
raise TypeError

# get centers
centers = np.empty(shape=(len(graphics), len(pos)))
for i in range(centers.shape[0]):
centers[i] = graphics[i].world_object.get_world_bounding_sphere()[: len(pos)]

# l2
distances = np.linalg.norm(centers[:, : len(pos)] - pos, ord=2, axis=1)

sort_indices = np.argsort(distances)
return np.asarray(graphics)[sort_indices]
33 changes: 33 additions & 0 deletions tests/test_plot_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
import fastplotlib as fpl


def make_circle(center, radius: float, n_points: int = 75) -> np.ndarray:
theta = np.linspace(0, 2 * np.pi, n_points)
xs = radius * np.sin(theta)
ys = radius * np.cos(theta)

return np.column_stack([xs, ys]) + center


def test_get_nearest_graphics():
circles = list()

centers = [[0, 0], [0, 20], [20, 0], [20, 20]]

for center in centers:
circles.append(make_circle(center, 5, n_points=75))

fig = fpl.Figure()

lines = fig[0, 0].add_line_collection(circles, cmap="jet", thickness=5)

fig[0, 0].add_scatter(np.array([[0, 12, 0]]))

# check distances
nearest = fpl.utils.get_nearest_graphics((0, 12), lines)
assert nearest[0] is lines[1] # closest
assert nearest[1] is lines[0]
assert nearest[2] is lines[3]
assert nearest[3] is lines[2] # furthest
assert nearest[-1] is lines[2]