Skip to content

Commit d130006

Browse files
committed
simplify to_rgba() by extracting the part relating to RGBA data
1 parent e7d53e0 commit d130006

File tree

3 files changed

+52
-49
lines changed

3 files changed

+52
-49
lines changed

lib/matplotlib/artist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ def set_array(self, A):
14461446
A : array-like or None
14471447
The values that are mapped to colors.
14481448
1449-
The base class `.VectorMappable` does not make any assumptions on
1449+
The base class `.ColorizingArtist` does not make any assumptions on
14501450
the dimensionality and shape of the value array *A*.
14511451
"""
14521452
if A is None:
@@ -1466,7 +1466,7 @@ def get_array(self):
14661466
"""
14671467
Return the array of values, that are mapped to colors.
14681468
1469-
The base class `.VectorMappable` does not make any assumptions on
1469+
The base class `.ColorizingArtist` does not make any assumptions on
14701470
the dimensionality and shape of the array.
14711471
"""
14721472
return self._A

lib/matplotlib/cm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ class ScalarMappable(colorizer.ColorizerShim):
279279
"""
280280
A mixin class to map one or multiple sets of scalar data to RGBA.
281281
282-
The VectorMappable applies data normalization before returning RGBA colors
282+
The ScalarMappable applies data normalization before returning RGBA colors
283283
from the given `~matplotlib.colors.Colormap`, `~matplotlib.colors.BivarColormap`,
284284
or `~matplotlib.colors.MultivarColormap`.
285285
"""

lib/matplotlib/colorizer.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
class Colorizer():
2525
"""
2626
Class that holds the data to color pipeline
27-
accessible via `.to_rgba(A)` and executed via
28-
the `.norm` and `.cmap` attributes.
27+
accessible via `Colorizer.to_rgba(A)` and executed via
28+
the `Colorizer.norm` and `Colorizer.cmap` attributes.
2929
"""
3030
def __init__(self, cmap=None, norm=None):
3131

@@ -125,56 +125,59 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
125125
126126
"""
127127
# First check for special case, image input:
128-
# First check for special case, image input:
129-
try:
130-
if x.ndim == 3:
131-
if x.shape[2] == 3:
132-
if alpha is None:
133-
alpha = 1
134-
if x.dtype == np.uint8:
135-
alpha = np.uint8(alpha * 255)
136-
m, n = x.shape[:2]
137-
xx = np.empty(shape=(m, n, 4), dtype=x.dtype)
138-
xx[:, :, :3] = x
139-
xx[:, :, 3] = alpha
140-
elif x.shape[2] == 4:
141-
xx = x
142-
else:
143-
raise ValueError("Third dimension must be 3 or 4")
144-
if xx.dtype.kind == 'f':
145-
# If any of R, G, B, or A is nan, set to 0
146-
if np.any(nans := np.isnan(x)):
147-
if x.shape[2] == 4:
148-
xx = xx.copy()
149-
xx[np.any(nans, axis=2), :] = 0
150-
151-
if norm and (xx.max() > 1 or xx.min() < 0):
152-
raise ValueError("Floating point image RGB values "
153-
"must be in the 0..1 range.")
154-
if bytes:
155-
xx = (xx * 255).astype(np.uint8)
156-
elif xx.dtype == np.uint8:
157-
if not bytes:
158-
xx = xx.astype(np.float32) / 255
159-
else:
160-
raise ValueError("Image RGB array must be uint8 or "
161-
"floating point; found %s" % xx.dtype)
162-
# Account for any masked entries in the original array
163-
# If any of R, G, B, or A are masked for an entry, we set alpha to 0
164-
if np.ma.is_masked(x):
165-
xx[np.any(np.ma.getmaskarray(x), axis=2), 3] = 0
166-
return xx
167-
except AttributeError:
168-
# e.g., x is not an ndarray; so try mapping it
169-
pass
170-
171-
# This is the normal case, mapping a scalar array:
128+
if isinstance(x, np.ndarray) and x.ndim == 3:
129+
return self._pass_image_data(x, alpha, bytes, norm)
130+
131+
# Otherwise run norm -> colormap pipeline
172132
x = ma.asarray(x)
173133
if norm:
174134
x = self.norm(x)
175135
rgba = self.cmap(x, alpha=alpha, bytes=bytes)
176136
return rgba
177137

138+
@staticmethod
139+
def _pass_image_data(x, alpha=None, bytes=False, norm=True):
140+
"""
141+
Helper function to pass ndarray of shape (...,3) or (..., 4)
142+
through `to_rgba()`, see `to_rgba()` for docstring.
143+
"""
144+
if x.shape[2] == 3:
145+
if alpha is None:
146+
alpha = 1
147+
if x.dtype == np.uint8:
148+
alpha = np.uint8(alpha * 255)
149+
m, n = x.shape[:2]
150+
xx = np.empty(shape=(m, n, 4), dtype=x.dtype)
151+
xx[:, :, :3] = x
152+
xx[:, :, 3] = alpha
153+
elif x.shape[2] == 4:
154+
xx = x
155+
else:
156+
raise ValueError("Third dimension must be 3 or 4")
157+
if xx.dtype.kind == 'f':
158+
# If any of R, G, B, or A is nan, set to 0
159+
if np.any(nans := np.isnan(x)):
160+
if x.shape[2] == 4:
161+
xx = xx.copy()
162+
xx[np.any(nans, axis=2), :] = 0
163+
164+
if norm and (xx.max() > 1 or xx.min() < 0):
165+
raise ValueError("Floating point image RGB values "
166+
"must be in the 0..1 range.")
167+
if bytes:
168+
xx = (xx * 255).astype(np.uint8)
169+
elif xx.dtype == np.uint8:
170+
if not bytes:
171+
xx = xx.astype(np.float32) / 255
172+
else:
173+
raise ValueError("Image RGB array must be uint8 or "
174+
"floating point; found %s" % xx.dtype)
175+
# Account for any masked entries in the original array
176+
# If any of R, G, B, or A are masked for an entry, we set alpha to 0
177+
if np.ma.is_masked(x):
178+
xx[np.any(np.ma.getmaskarray(x), axis=2), 3] = 0
179+
return xx
180+
178181
def normalize(self, x):
179182
"""
180183
Normalize the data in x.

0 commit comments

Comments
 (0)