
PERF: _api check function speedups #31071

Closed
scottshambaugh wants to merge 9 commits into matplotlib:main from scottshambaugh:check_size_speedup

Conversation

@scottshambaugh
Contributor

PR summary

This speeds up the check_* functions in _api, largely by caching the function setup and deferring error handling.

This necessarily comes with a big change in how these are called, as they now take in tuples and return functions. Before: _api.check_shape((None, 2), arg=arg), after: _api.check_shape((None, 2))("arg", arg). This also means that we cannot check multiple args in one call, although this is not common in the codebase and the caching eliminates the overhead.

  • check_shape - 3 us to 1 us (3x) - Cache and defer constructing the error message
  • check_isinstance - 2 us to 0.5 us (4x) - Cache and defer constructing the error message
  • check_in_list - 0.8 us to 0.5 us (1.5x) - Cast to tuple and cache
  • check_getitem - 1 us to 1 us (1x) - No speedup, updated to match new API
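The curried call pattern described above can be sketched as follows. This is a minimal illustration, not the code from this PR: the helper `_shape_error` and the stand-in class `FakeArray` are hypothetical names, and `functools.lru_cache` is assumed as the caching mechanism (which would be consistent with the `cache_clear()` calls in the profiling script below).

```python
import functools
import itertools


def _shape_error(name, shape, data_shape):
    # Hypothetical helper: builds the message only when a check fails.
    dim_labels = itertools.chain("NMLKJIH", (f"D{i}" for i in itertools.count()))
    text_shape = ", ".join(
        [str(n) if n is not None else next(dim_labels) for n in shape[::-1]][::-1])
    if len(shape) == 1:
        text_shape += ","
    return (f"{name!r} must be {len(shape)}D with shape ({text_shape}), "
            f"but your input has shape {data_shape}")


@functools.lru_cache(maxsize=None)
def check_shape(shape):
    """Return a checker for *shape*; repeated calls reuse the cached closure."""

    def checker(name, value):
        data_shape = value.shape
        if (len(data_shape) != len(shape)
                or any(s != t and t is not None
                       for s, t in zip(data_shape, shape))):
            raise ValueError(_shape_error(name, shape, data_shape))

    return checker


class FakeArray:
    """Stand-in for anything with a ``.shape`` attribute, e.g. a NumPy array."""

    def __init__(self, shape):
        self.shape = shape


check_points = check_shape((None, 2))
check_points("xy", FakeArray((5, 2)))  # passes silently
```

Because the shape tuple is hashable, a second `check_shape((None, 2))` call returns the same closure object instead of building a new one.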

If there's a lot of pushback on the new API, only a few of these are on hot paths so I don't expect the overall benefit to be enormous. Each function has its own commit for review.

Inspired by @anntzer's comment here: #31001 (comment)

Profiling script I had Claude throw together:

"""Profile the optimized _api check functions.

Compares OLD style (kwargs, inline check) vs NEW style (curried, cached).
"""

import difflib
import itertools
import timeit
import numpy as np

N = 500_000


# ============================================================================
# OLD IMPLEMENTATIONS (exact code from before the changes)
# ============================================================================

def old_check_isinstance(types, /, **kwargs):
    """Exact old implementation."""
    none_type = type(None)
    types = ((types,) if isinstance(types, type) else
             (none_type,) if types is None else
             tuple(none_type if tp is None else tp for tp in types))

    def type_name(tp):
        return ("None" if tp is none_type
                else tp.__qualname__ if tp.__module__ == "builtins"
                else f"{tp.__module__}.{tp.__qualname__}")

    for k, v in kwargs.items():
        if not isinstance(v, types):
            names = [*map(type_name, types)]
            if "None" in names:
                names.remove("None")
                names.append("None")
            raise TypeError(
                "{!r} must be an instance of {}, not a {}".format(
                    k,
                    ", ".join(names[:-1]) + " or " + names[-1]
                    if len(names) > 1 else names[0],
                    type_name(type(v))))


def old_check_in_list(values, /, *, _print_supported_values=True, **kwargs):
    """Exact old implementation."""
    if not kwargs:
        raise TypeError("No argument to check!")
    for key, val in kwargs.items():
        try:
            exists = val in values
        except ValueError:
            exists = False
        if not exists:
            msg = f"{val!r} is not a valid value for {key}"
            if _print_supported_values:
                msg += f"; supported values are {', '.join(map(repr, values))}"
            raise ValueError(msg)


def old_check_shape(shape, /, **kwargs):
    """Exact old implementation."""
    for k, v in kwargs.items():
        data_shape = v.shape

        if (len(data_shape) != len(shape)
                or any(s != t and t is not None for s, t in zip(data_shape, shape))):
            dim_labels = iter(itertools.chain(
                'NMLKJIH',
                (f"D{i}" for i in itertools.count())))
            text_shape = ", ".join([str(n) if n is not None else next(dim_labels)
                                    for n in shape[::-1]][::-1])
            if len(shape) == 1:
                text_shape += ","

            raise ValueError(
                f"{k!r} must be {len(shape)}D with shape ({text_shape}), "
                f"but your input has shape {v.shape}"
            )


def old_check_getitem(mapping, /, _error_cls=ValueError, **kwargs):
    """Exact old implementation."""
    if len(kwargs) != 1:
        raise ValueError("check_getitem takes a single keyword argument")
    (k, v), = kwargs.items()
    try:
        return mapping[v]
    except KeyError:
        if len(mapping) > 5:
            if len(best := difflib.get_close_matches(v, mapping.keys(), cutoff=0.5)):
                suggestion = f"Did you mean one of {best}?"
            else:
                suggestion = ""
        else:
            suggestion = f"Supported values are {', '.join(map(repr, mapping))}"
        raise _error_cls(f"{v!r} is not a valid value for {k}. {suggestion}") from None


# ============================================================================
# PROFILING
# ============================================================================

def profile_all():
    from matplotlib import _api

    print(f"Profiling _api check functions ({N:,} iterations)")
    print("=" * 75)
    print()
    print(f"{'Function':<20} {'Old':>10} {'New+Cache':>12} {'New-Cache':>12} {'Speedup':>10}")
    print("-" * 75)

    # --- check_isinstance ---
    types = (int, float)
    value = 1.0

    # Warm up
    old_check_isinstance(types, v=value)
    t_old = timeit.timeit(lambda: old_check_isinstance(types, v=value), number=N)

    t_new_hit = timeit.timeit(lambda: _api.check_isinstance(types)("v", value), number=N)

    def new_miss():
        _api.check_isinstance.cache_clear()
        return _api.check_isinstance(types)("v", value)
    t_new_miss = timeit.timeit(new_miss, number=N)

    old_us = t_old * 1e6 / N
    hit_us = t_new_hit * 1e6 / N
    miss_us = t_new_miss * 1e6 / N
    print(f"{'check_isinstance':<20} {old_us:>9.2f}µs {hit_us:>11.2f}µs {miss_us:>11.2f}µs {t_old/t_new_hit:>9.1f}x")

    # --- check_in_list ---
    values = ("left", "center", "right")
    value = "center"

    # Warm up
    old_check_in_list(values, v=value)
    t_old = timeit.timeit(lambda: old_check_in_list(values, v=value), number=N)

    t_new_hit = timeit.timeit(lambda: _api.check_in_list(values)("v", value), number=N)

    def new_miss():
        _api._check_in_tuple.cache_clear()
        return _api.check_in_list(values)("v", value)
    t_new_miss = timeit.timeit(new_miss, number=N)

    old_us = t_old * 1e6 / N
    hit_us = t_new_hit * 1e6 / N
    miss_us = t_new_miss * 1e6 / N
    print(f"{'check_in_list':<20} {old_us:>9.2f}µs {hit_us:>11.2f}µs {miss_us:>11.2f}µs {t_old/t_new_hit:>9.1f}x")

    # --- check_shape ---
    shape = (None, 2)
    arr = np.array([[1, 2], [3, 4]])

    # Warm up
    old_check_shape(shape, arr=arr)
    t_old = timeit.timeit(lambda: old_check_shape(shape, arr=arr), number=N)

    t_new_hit = timeit.timeit(lambda: _api.check_shape(shape)("arr", arr), number=N)

    def new_miss():
        _api.check_shape.cache_clear()
        return _api.check_shape(shape)("arr", arr)
    t_new_miss = timeit.timeit(new_miss, number=N)

    old_us = t_old * 1e6 / N
    hit_us = t_new_hit * 1e6 / N
    miss_us = t_new_miss * 1e6 / N
    print(f"{'check_shape':<20} {old_us:>9.2f}µs {hit_us:>11.2f}µs {miss_us:>11.2f}µs {t_old/t_new_hit:>9.1f}x")

    # --- check_getitem (no caching in new impl) ---
    mapping = {"left": 0, "center": 0.5, "right": 1}
    value = "center"

    # Warm up
    old_check_getitem(mapping, v=value)
    t_old = timeit.timeit(lambda: old_check_getitem(mapping, v=value), number=N)
    t_new = timeit.timeit(lambda: _api.check_getitem(mapping)("v", value), number=N)

    old_us = t_old * 1e6 / N
    new_us = t_new * 1e6 / N
    print(f"{'check_getitem':<20} {old_us:>9.2f}µs {new_us:>11.2f}µs {'N/A':>12} {t_old/t_new:>9.1f}x")

    print("-" * 75)
    print()
    print("Notes:")
    print("  - 'Old' = kwargs-based API (before changes)")
    print("  - 'New+Cache' = curried API with cache hit (typical usage)")
    print("  - 'New-Cache' = curried API with cache miss every call")
    print("  - 'Speedup' = Old / New+Cache")
    print()


if __name__ == "__main__":
    profile_all()


@timhoffm
Member

timhoffm commented Feb 4, 2026

check_getitem() is somewhat different from the other check_* methods. Both do some form of input validation and error out with a nice error message. But while the check_* only do an assertion, check_getitem() retrieves a value. It's also the only function that has a return value.

A function named check_... should not return something (or if it returns something, it'd intuitively return the result of the check). I therefore propose that we rename check_getitem() to getitem_checked() or similar. This also means we do not necessarily have to move it to the new API, since it's somewhat different, but we still could if similarity is desired.

@timhoffm
Member

timhoffm commented Feb 4, 2026

Random thought: The lists in check_in_list are ad-hoc generated, but they are actually program-level constants. We could formalize this into a concept like:

class Axes:

    _ORIENTATIONS = ValueGroup("orientation", ('horizontal', 'vertical'))

    def bar(..., orientation=...):
        self._ORIENTATIONS.assert_contains(orientation)

instead of

    def bar(..., orientation=...):
        _api.check_in_list(("horizontal", "vertical"))("orientation", orientation)

Whether _ORIENTATIONS lives in the class or in a separate module can be decided. It could be a common ground truth e.g. also for rcParams.

Advantages:

  • This should be even faster than the proposal here. No objects are created, it's only a method call and a __contains__ check.
  • More formalized concept for permissible parameter values.

Disadvantage:

  • While still being readable, it's less expressive because one cannot see the actual valid values at the usage site.
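As a rough sketch of the idea above: `ValueGroup` and `assert_contains` do not exist in Matplotlib, so the class below is a hypothetical illustration of the concept, and `Axes` here is a toy stand-in rather than the real class. The error message mirrors the wording of the old `check_in_list` shown in the profiling script.

```python
class ValueGroup:
    """Hypothetical named group of permissible values for one parameter."""

    def __init__(self, name, values):
        self.name = name
        self.values = tuple(values)

    def assert_contains(self, value):
        # Happy path: a single membership test, no temporary objects.
        if value not in self.values:
            raise ValueError(
                f"{value!r} is not a valid value for {self.name}; "
                f"supported values are {', '.join(map(repr, self.values))}")


class Axes:
    """Toy stand-in for the real Axes class."""

    _ORIENTATIONS = ValueGroup("orientation", ("horizontal", "vertical"))

    def bar(self, orientation="vertical"):
        self._ORIENTATIONS.assert_contains(orientation)
        return orientation


Axes().bar(orientation="horizontal")  # passes silently
```

The valid values are defined once at class level, so the same group could be reused by other methods or by rcParams validation.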

@scottshambaugh
Contributor Author

scottshambaugh commented Feb 4, 2026

That's a good point, renamed to getitem_checked. I kept the change in the API since we now no longer have the call pattern (args, /, **kwargs) anywhere else as far as I can see.
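For illustration, a renamed lookup with the curried call pattern might look roughly like this. It is a sketch adapted from the old `check_getitem` in the profiling script above, not the code in this PR; the inner name `lookup` is an assumption.

```python
import difflib


def getitem_checked(mapping, _error_cls=ValueError):
    """Return a function that looks up a key and raises a helpful error."""

    def lookup(name, key):
        try:
            return mapping[key]
        except KeyError:
            if len(mapping) > 5:
                # For large mappings, only suggest close matches.
                best = difflib.get_close_matches(key, mapping.keys(), cutoff=0.5)
                suggestion = f"Did you mean one of {best}?" if best else ""
            else:
                suggestion = (
                    f"Supported values are {', '.join(map(repr, mapping))}")
            raise _error_cls(
                f"{key!r} is not a valid value for {name}. {suggestion}"
            ) from None

    return lookup


alignments = {"left": 0, "center": 0.5, "right": 1}
offset = getitem_checked(alignments)("ha", "center")
```

Unlike the `check_*` functions, the call returns the retrieved value, which is the reason for the rename.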

For the lists of valid arguments, I'm so-so on a new class, but we should definitely consolidate the valid values anywhere that we reuse the same arguments in different places. Here's what I find:

  • ('horizontal', 'vertical') - Axes methods (4) + widget classes (3)
  • ('tick1', 'tick2', 'grid') - GeoAxes (2) and PolarAxes (2) axis transforms
  • ('major', 'minor', 'both') - Axis methods (3) + FigureBase.autofmt_xdate
  • (None, "equal", "transform") - ParasiteAxesBase methods (2)
  • ('z', 'dzdx', 'dzdy') - TriInterpolator subclasses (2)
  • ('x', 'y', 'both') - _AxesBase methods (2)
  • ('data', 'pixels') - RectangleSelector (2)
  • ('center', 'right', 'left') - Text alignment (2)
  • ('bottom', 'center', 'top') - Axes.bar_label, _AxesBase.set_ylabel (2)
  • ("center", "left", "right") - Legend methods (2)

@timhoffm
Member

timhoffm commented Feb 4, 2026

Let's put the check_in_list alternative to the side for now. It was more a spark of an idea. But given it seems not to be called too often, let's not completely reinvent checking.

On the original proposal, I find
_api.check_shape((None, 2))("arg", arg) is less readable than _api.check_shape((None, 2), arg=arg).
The separate call and stating the parameter once as a string and once as a variable look quite busy. Admittedly, the current form is also not very readable, but I can live with the arg=arg idiom.

Overall, I'm a bit unclear where the performance advantage comes from. The happy path is:

for k, v in kwargs.items():
    data_shape = v.shape
    if (len(data_shape) != len(shape)
            or any(s != t and t is not None for s, t in zip(data_shape, shape))):

It seems there is not much to be gained from caching.

How is the performance for

def check_shape_oldapi(shape, **kwargs):
    for k, v in kwargs.items():
        _api.check_shape(shape)(k, v)

Is it almost as fast as _api.check_shape(shape)("arg", arg)? If so, congrats, you've found a more efficient implementation. But that does not require a new API; we could do the caching internally.
Or is it as slow as the old implementation? In that case the kwargs handling is the culprit, and we could keep a single function call and just remove the multi-arg option to speed things up, e.g. _api.check_shape((None, 2), "arg", arg). Alternatively, change the old implementation to this signature and benchmark it directly.

Note: The multi-arg option came basically for free through the kwarg approach. But it's not something we really need or make substantial use of.

Edit: AI tells me that **kwargs are 2-3 times slower than positional arguments. Since the happy path doesn't do a lot, it might indeed be that **kwargs is a substantial factor. But t.b.d.
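One way to isolate the kwargs question is a micro-benchmark that runs the same happy-path check behind two signatures. The sketch below is an assumption-laden illustration: `check_kwargs`, `check_positional`, `_shape_ok`, and `FakeArray` are hypothetical names, and the timings depend on machine and Python version, so no numbers are claimed here.

```python
import timeit


def _shape_ok(shape, value):
    # Same happy-path logic as the old check_shape, expressed as a predicate.
    data_shape = value.shape
    return (len(data_shape) == len(shape)
            and all(s == t or t is None for s, t in zip(data_shape, shape)))


def check_kwargs(shape, /, **kwargs):
    """Old-style signature: arguments arrive as keyword arguments."""
    for name, value in kwargs.items():
        if not _shape_ok(shape, value):
            raise ValueError(f"{name!r} has unexpected shape {value.shape}")


def check_positional(shape, name, value):
    """Single-call alternative: name and value are positional."""
    if not _shape_ok(shape, value):
        raise ValueError(f"{name!r} has unexpected shape {value.shape}")


class FakeArray:
    """Stand-in for anything with a ``.shape`` attribute."""

    def __init__(self, shape):
        self.shape = shape


def benchmark(number=200_000):
    shape, arr = (None, 2), FakeArray((5, 2))
    t_kw = timeit.timeit(lambda: check_kwargs(shape, arg=arr), number=number)
    t_pos = timeit.timeit(
        lambda: check_positional(shape, "arg", arr), number=number)
    return {"kwargs_us": t_kw * 1e6 / number,
            "positional_us": t_pos * 1e6 / number}


if __name__ == "__main__":
    for label, micros in benchmark().items():
        print(f"{label}: {micros:.2f} us per call")
```

If the two timings are close, the kwargs handling is not the dominant cost and the gain would have to come from elsewhere.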

@scottshambaugh
Contributor Author

I'm less and less convinced this change is worth it. Profiling at the whole draw cycle level with the script in #31001, I don't see an appreciable change in the top-line percentage of call time spent in these functions (~0.8% of total time). The caching adds 20% overhead on each cache miss (i.e. the first time we call each function), which partially offsets the gains here. I think our time is better spent on other PRs.
