unaiverse.modules.cnu.psi

What this module does 🟢

Provides the psi feature-mapping function and 1D/2D resize helpers used by the CNU module to project input features into the fixed-size key space via interpolation and normalization.

psi

A Collectionless AI Project (https://collectionless.ai)
Registration/Login: https://unaiverse.io
Code Repositories: https://github.com/collectionlessai/
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi

psi

psi(x, mode, key_size, normalize=True)

Apply a projection function (psi) that maps a tensor to a fixed-size key vector.

The psi function is a configurable projection layer used in Contextual Neural Units (CNU) to encode an input tensor into a fixed-length representation of size key_size. Several projection strategies are supported, ranging from flat identity mappings to spatially-aware resizing and sign-binarisation. After projection an optional L2 normalisation is applied along the feature dimension.
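For intuition, the optional normalisation step can be sketched in plain Python. The sketch assumes the documented behaviour of torch.nn.functional.normalize, which divides each row by max(||row||_2, eps). The helper name l2_normalize_rows is hypothetical and not part of the module.

```python
import math


def l2_normalize_rows(rows, eps=1e-12):
    """Hypothetical pure-Python mirror of F.normalize(o, p=2.0, dim=1, eps=eps).

    Each row (one sample's feature vector) is divided by max(L2 norm, eps),
    so an all-zero row stays all zeros instead of producing NaNs.
    """
    out = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        denom = max(norm, eps)  # eps guards against division by zero
        out.append([v / denom for v in row])
    return out


print(l2_normalize_rows([[3.0, 4.0], [0.0, 0.0]]))  # [[0.6, 0.8], [0.0, 0.0]]
```

Because the division is row-wise, every non-zero key returned with normalize=True has unit length, so the dot product of two such keys equals their cosine similarity.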

The mode determines which internal helper is called:

  • "identity": flatten the input and use it as-is.
  • "sign": flatten and map each element to -1, 0 or +1 via torch.sign (exact zeros stay 0).
  • "resize1d": 1-D linear interpolation (see resize1d).
  • "resize2d": 2-D bilinear interpolation then flatten (see resize2d).
  • "resize2d_sign": same as "resize2d" followed by torch.sign.
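The "sign" and "resize2d_sign" modes rely on torch.sign, which returns -1 for negative values, 0 for zeros and +1 for positive values. The sketch below reproduces that rule in plain Python for one already-flattened sample and shows what the default L2 normalisation does to a sign vector; sign and sign_key are hypothetical helpers, not part of the module.

```python
import math


def sign(v):
    """Same rule as torch.sign for real (non-NaN) numbers: -1, 0 or +1."""
    return (v > 0) - (v < 0)


def sign_key(features, normalize=True):
    """Hypothetical sketch of the "sign" mode for one already-flattened sample."""
    key = [float(sign(v)) for v in features]
    if normalize:
        # For a sign vector the L2 norm equals sqrt(number of non-zero entries)
        norm = math.sqrt(sum(v * v for v in key))
        key = [v / max(norm, 1e-12) for v in key]
    return key


print(sign_key([2.3, -0.7, 0.0, 5.1], normalize=False))  # [1.0, -1.0, 0.0, 1.0]
print(sign_key([2.3, -0.7, 0.4, 5.1]))  # [0.5, -0.5, 0.5, 0.5]
```

With k non-zero entries, each component becomes +1/sqrt(k) or -1/sqrt(k) after normalisation; when two keys of length k contain no zeros, their dot product is (agreements - disagreements) / k.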

Parameters:

  • x (required): Input tensor of shape (batch, ...). For "resize1d" the shape must be (batch, features); for "resize2d" and "resize2d_sign" it must be (batch, channels, height, width).
  • mode (required): Projection strategy. Must be one of "identity", "sign", "resize1d", "resize2d", or "resize2d_sign".
  • key_size (required): Target dimensionality of the output key vector. The projected tensor must have exactly this many features after flattening.
  • normalize (default: True): If True, the output is L2-normalised along the feature dimension (dim=1) before returning.

Returns:

A tensor of shape (batch, key_size), optionally L2-normalised.

Raises:

  • NotImplementedError: If mode is not one of the recognised strings.
  • AssertionError: If the projection result does not match key_size (e.g., the chosen mode cannot map the input to the requested size).
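The AssertionError is easiest to trigger with "identity" and "sign", which never resize: they only flatten everything after the batch dimension, so key_size must equal the product of the remaining dimensions. A small pre-flight check, sketched here with hypothetical helper names, predicts that outcome from the input shape alone.

```python
from math import prod


def flattened_features(shape):
    """Features per sample after x.flatten(start_dim=1): product of all non-batch dims."""
    return prod(shape[1:])


def non_resizing_mode_fits(shape, key_size):
    """Hypothetical check: True if "identity"/"sign" would pass psi's key_size assertion."""
    return flattened_features(shape) == key_size


shape = (4, 3, 8, 8)  # batch of 4 RGB 8x8 images, as in the Examples
print(flattened_features(shape))  # 192
print(non_resizing_mode_fits(shape, 64))  # False: a resize mode is needed instead
print(non_resizing_mode_fits(shape, 192))  # True
```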

Examples:

>>> import torch
>>> x = torch.randn(4, 3, 8, 8)  # batch of 4 RGB 8x8 images
>>> key = psi(x, mode="resize2d", key_size=64)
>>> key.shape
torch.Size([4, 64])
>>> key_bin = psi(x, mode="resize2d_sign", key_size=64, normalize=False)
>>> key_bin.shape
torch.Size([4, 64])
Source code in unaiverse/modules/cnu/psi.py
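# Module-level imports used by psi, resize1d and resize2d
import math

import torch
import torch.nn.functional as F


def psi(x, mode, key_size, normalize=True):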
    """Apply a projection function (psi) that maps a tensor to a fixed-size key vector.

    The psi function is a configurable projection layer used in Contextual Neural Units
    (CNU) to encode an input tensor into a fixed-length representation of size
    ``key_size``. Several projection strategies are supported, ranging from flat
    identity mappings to spatially-aware resizing and sign-binarisation. After
    projection an optional L2 normalisation is applied along the feature dimension.

    The mode determines which internal helper is called:

    - ``"identity"``: flatten the input and use it as-is.
    - ``"sign"``: flatten and map each element to -1, 0 or +1 via ``torch.sign``
      (exact zeros stay 0).
    - ``"resize1d"``: 1-D linear interpolation (see ``resize1d``).
    - ``"resize2d"``: 2-D bilinear interpolation then flatten (see ``resize2d``).
    - ``"resize2d_sign"``: same as ``"resize2d"`` followed by ``torch.sign``.

    Args:
        x: Input tensor of shape ``(batch, ...)``. For ``"resize1d"`` the shape must
            be ``(batch, features)``; for ``"resize2d"`` and ``"resize2d_sign"`` it
            must be ``(batch, channels, height, width)``.
        mode: Projection strategy. Must be one of ``"identity"``, ``"sign"``,
            ``"resize1d"``, ``"resize2d"``, or ``"resize2d_sign"``.
        key_size: Target dimensionality of the output key vector. The projected
            tensor must have exactly this many features after flattening.
        normalize: If ``True``, the output is L2-normalised along the feature
            dimension (``dim=1``) before returning. Defaults to ``True``.

    Returns:
        A tensor of shape ``(batch, key_size)``, optionally L2-normalised.

    Raises:
        NotImplementedError: If ``mode`` is not one of the recognised strings.
        AssertionError: If the projection result does not match ``key_size``
            (e.g., the chosen mode cannot map the input to the requested size).

    Examples:
        >>> import torch
        >>> x = torch.randn(4, 3, 8, 8)  # batch of 4 RGB 8x8 images
        >>> key = psi(x, mode="resize2d", key_size=64)
        >>> key.shape
        torch.Size([4, 64])
        >>> key_bin = psi(x, mode="resize2d_sign", key_size=64, normalize=False)
        >>> key_bin.shape
        torch.Size([4, 64])
    """
    if mode == "identity":
        o = x.flatten(start_dim=1)
    elif mode == "sign":
        o = torch.sign(x.flatten(start_dim=1))
    elif mode == "resize1d":
        o = resize1d(x, key_size)
    elif mode == "resize2d":
        o = resize2d(x, key_size)
    elif mode == "resize2d_sign":
        o = torch.sign(resize2d(x, key_size))
    else:
        raise NotImplementedError(f"Unknown psi mode: {mode}")
    assert o.shape[1] == key_size, (
        f"The selected psi function ({mode}) cannot map data to the target "
        f"key_size (data_size: {o.shape[1]}, key_size: {key_size})"
    )
    if normalize:
        o = F.normalize(o, p=2.0, dim=1, eps=1e-12, out=None)
    return o

resize1d

resize1d(I, key_size)

Resize a 1-D feature tensor to a target length via linear interpolation.

If the input already has the required number of features, it is returned unchanged. Otherwise torch.nn.functional.interpolate is used with mode="linear" to up- or down-sample along the feature dimension. The channel dimension expected by interpolate is handled transparently with unsqueeze/squeeze so the caller does not need to reshape the tensor.
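To see what mode="linear" computes, here is a pure-Python sketch of 1-D linear resampling for a single row. It assumes PyTorch's default align_corners=False convention (resize1d does not pass align_corners), under which output position i samples the input at (i + 0.5) * in_len / key_size - 0.5, clamped at 0. The function name resize_row is hypothetical; real tensors should go through resize1d.

```python
import math


def resize_row(row, key_size):
    """Hypothetical mirror of F.interpolate(..., size=key_size, mode="linear") on one row."""
    in_len = len(row)
    if in_len == key_size:
        return list(row)  # resize1d returns the input unchanged in this case
    scale = in_len / key_size
    out = []
    for i in range(key_size):
        # Map the centre of output cell i back to input coordinates (align_corners=False)
        src = max((i + 0.5) * scale - 0.5, 0.0)
        i0 = min(int(math.floor(src)), in_len - 1)
        i1 = min(i0 + 1, in_len - 1)  # clamp the right neighbour at the last element
        lam = src - i0
        out.append((1.0 - lam) * row[i0] + lam * row[i1])
    return out


print(resize_row([0.0, 1.0], 4))  # [0.0, 0.25, 0.75, 1.0]
print(resize_row([0.0, 1.0, 2.0, 3.0], 2))  # [0.5, 2.5]
```

Downsampling this way reads only the two nearest input values per output position and does not average over the whole window, so detail between sample points can be skipped.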

Parameters:

  • I (required): Input tensor of shape (batch, features).
  • key_size (required): Target number of features in the output.

Returns:

A tensor of shape (batch, key_size) with values resampled by linear interpolation.

Source code in unaiverse/modules/cnu/psi.py
def resize1d(I, key_size):
    """Resize a 1-D feature tensor to a target length via linear interpolation.

    If the input already has the required number of features, it is returned
    unchanged. Otherwise ``torch.nn.functional.interpolate`` is used with
    ``mode="linear"`` to up- or down-sample along the feature dimension. The
    channel dimension expected by ``interpolate`` is handled transparently with
    ``unsqueeze``/``squeeze`` so the caller does not need to reshape the tensor.

    Args:
        I: Input tensor of shape ``(batch, features)``.
        key_size: Target number of features in the output.

    Returns:
        A tensor of shape ``(batch, key_size)`` with values resampled by linear
        interpolation.
    """
    if I.shape[1] != key_size:
        # interpolate expects (batch, channels, length): add a singleton channel, then drop it
        I = F.interpolate(I.unsqueeze(1), size=key_size, mode="linear").squeeze(1)
    return I

resize2d

resize2d(I, key_size)

Resize a 4-D image tensor to a target flat key size via bilinear interpolation.

The target key_size is divided evenly across the channel dimension to yield a per-channel spatial budget (spatial_key_size = key_size // c). A new (height, width) is computed that keeps the original aspect ratio as closely as possible while matching that budget. The image is then resized with bilinear interpolation and flattened. If the aspect-ratio rounding causes the flattened size to fall short of key_size, the remainder is zero-padded so the output always has exactly key_size features.
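The size bookkeeping in resize2d is plain integer arithmetic, so it can be checked without any tensors. The sketch below repeats the same formulas in plain Python (resize2d_plan is a hypothetical name) and reports the resampled (height, width) together with how many zero features must be appended.

```python
import math


def resize2d_plan(c, h, w, key_size):
    """Hypothetical re-statement of the size arithmetic used by resize2d."""
    spatial_key_size = key_size // c  # per-channel budget
    ratio = float(spatial_key_size) / float(w * h)
    new_w = int(round(math.sqrt(ratio) * w))  # scale width by sqrt of the area ratio
    new_h = spatial_key_size // new_w  # fill the budget given that width
    remainder = key_size - c * new_h * new_w  # features left to zero-pad
    return new_h, new_w, remainder


print(resize2d_plan(3, 8, 8, 64))  # (4, 5, 4): 3 * 4 * 5 = 60 values plus 4 zeros
print(resize2d_plan(3, 8, 8, 50))  # (4, 4, 2): budget met exactly, 2 features still missing
```

The second call shows why key_size should be divisible by the number of channels: the per-channel budget of 16 is met exactly, yet 50 - 3 * 16 = 2 features remain, so the need for padding is determined by the remainder rather than by comparing h * w with the budget.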

Parameters:

  • I (required): Input tensor of shape (batch, channels, height, width).
  • key_size (required): Target number of features after flattening. Must be divisible by the number of channels for an exact spatial split.

Returns:

A tensor of shape (batch, key_size). Values beyond the bilinearly resampled region (if any) are zero-padded.

Source code in unaiverse/modules/cnu/psi.py
def resize2d(I, key_size):
    """Resize a 4-D image tensor to a target flat key size via bilinear interpolation.

    The target ``key_size`` is divided evenly across the channel dimension to yield
    a per-channel spatial budget (``spatial_key_size = key_size // c``). A new
    ``(height, width)`` is computed that keeps the original aspect ratio as closely
    as possible while matching that budget. The image is then resized with bilinear
    interpolation and flattened. If the aspect-ratio rounding causes the flattened
    size to fall short of ``key_size``, the remainder is zero-padded so the output
    always has exactly ``key_size`` features.

    Args:
        I: Input tensor of shape ``(batch, channels, height, width)``.
        key_size: Target number of features after flattening. Must be divisible by
            the number of channels for an exact spatial split.

    Returns:
        A tensor of shape ``(batch, key_size)``. Values beyond the bilinearly
        resampled region (if any) are zero-padded.
    """
    b, c, h, w = I.shape
    spatial_key_size = key_size // c
    ratio = float(spatial_key_size) / float(w * h)
    w = int(round(math.sqrt(ratio) * w))
    h = spatial_key_size // w
    remainder = key_size - (c * h * w)
    o = F.interpolate(I, size=(h, w), mode="bilinear").flatten(start_dim=1)
    # Pad whenever features are missing, including when key_size is not divisible by c
    if remainder > 0:
        o = torch.cat([o, torch.zeros((b, remainder), device=o.device, dtype=o.dtype)], dim=1)
    return o
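Because flatten(start_dim=1) is applied to a (batch, channels, height, width) tensor, each output key lists all pixels of channel 0 in row-major order, then channel 1, and so on, with any zero padding appended at the very end rather than per channel. The pure-Python sketch below (flatten_and_pad is a hypothetical helper operating on nested lists for one sample) makes that layout explicit.

```python
def flatten_and_pad(image, key_size):
    """Hypothetical sketch: flatten one (channels, height, width) sample, then zero-pad.

    Mirrors the order produced by tensor.flatten(start_dim=1) followed by
    torch.cat([o, zeros], dim=1) in resize2d.
    """
    flat = [v for channel in image for row in channel for v in row]  # channel, then row, then column
    if len(flat) > key_size:
        raise ValueError("image has more values than key_size")
    return flat + [0.0] * (key_size - len(flat))  # padding goes at the tail only


sample = [
    [[1.0, 2.0], [3.0, 4.0]],  # channel 0 (2x2)
    [[5.0, 6.0], [7.0, 8.0]],  # channel 1 (2x2)
]
print(flatten_and_pad(sample, 10))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.0, 0.0]
```

After "resize2d_sign" the padded tail stays 0 because torch.sign(0) is 0, so those positions never contribute to dot products between keys.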