diff --git a/imgcat/imgcat.py b/imgcat/imgcat.py index e192667..9fb63fb 100644 --- a/imgcat/imgcat.py +++ b/imgcat/imgcat.py @@ -98,14 +98,14 @@ def to_content_buf(data: Any) -> bytes: # numpy ndarray: convert to png import numpy im: 'numpy.ndarray' = data + if im.dtype.kind == 'f': + # https://stackoverflow.com/a/66862750 + im = (im * 256).clip(0, 255).astype('uint8') if len(im.shape) == 2: mode = 'L' # 8-bit pixels, grayscale - im = im.astype(sys.modules['numpy'].uint8) elif len(im.shape) == 3 and im.shape[2] in (1, 3, 4): # (H, W, C) format mode = None # RGB/RGBA - if im.dtype.kind == 'f': - im = (im * 255).astype('uint8') if im.shape[2] == 1: mode = 'L' # 8-bit grayscale im = numpy.squeeze(im, axis=2) @@ -113,8 +113,6 @@ def to_content_buf(data: Any) -> bytes: # (C, H, W) format mode = None # RGB/RGBA im = numpy.rollaxis(im, 0, 3) # CHW -> HWC - if im.dtype.kind == 'f': - im = (im * 255).astype('uint8') if im.shape[2] == 1: mode = 'L' # 8-bit grayscale im = numpy.squeeze(im, axis=2)