
unaiverse.modules.cnu.cnus

What this module does 🔴

Implements the core Conditional Neural Units (CNUs) module: a key-addressable memory bank with top-k attention, key normalization, and scrambling logic for dynamically generating layer weights.

cnus

A Collectionless AI Project (https://collectionless.ai)
Registration/Login: https://unaiverse.io
Code Repositories: https://github.com/collectionlessai/
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi

CNUs

CNUs(q=1, d=2, m=3, u=4, delta=3, gamma_alpha=0.1, tau_alpha=0.5, tau_mu=100, tau_eta=100, upd_m='WTA', upd_k='ad_hoc_WTA', beta_k=0.001, psi_fn='identity', scramble=False)

Bases: Module

Contextual Neural Units: an attention-based key-value memory module for PyTorch.

Each CNUs layer holds q neurons. Every neuron maintains a bank of m learnable keys (each of size d) and m corresponding memory units (each of size u). At inference time the module maps an input vector to the key space via a configurable projection function psi, computes dot-product attention scores against all keys, selects the top-delta winning keys, and blends the associated memory units by their softmax attention weights to produce a u-dimensional output per neuron.
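
The read path described above can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not the library's implementation: the function name `topk_memory_read` is hypothetical, the projection psi is assumed to be the identity followed by L2 normalization, and the dot products are computed with a single `einsum`, which may differ from the module's internal attention routine.

```python
import math
import torch
import torch.nn.functional as F


def topk_memory_read(x, K, M, delta=3, gamma_alpha=0.1):
    """Blend the top-delta memories of each neuron for a batch of inputs.

    x: (b, d) inputs, K: (q, m, d) keys, M: (q, m, u) memories.
    Returns a (b, q, u) tensor. Hypothetical helper, for illustration only.
    """
    b, d = x.shape
    q, m, u = M.shape
    k = min(delta, m)

    # Unit-norm inputs and keys, so dot products are cosine similarities
    x = F.normalize(x, dim=1)
    K = F.normalize(K, dim=2)

    # Dot product of every input against every key of every neuron: (b, q, m)
    scores = torch.einsum("bd,qmd->bqm", x, K)

    # Keep the k best-matching keys per (sample, neuron) pair
    top_scores, top_idx = torch.topk(scores, k, dim=2)

    # Scaled softmax over the selected scores only
    alpha = torch.softmax((gamma_alpha / math.sqrt(d)) * top_scores, dim=2)

    # Gather the selected memories: (b, q, k, u)
    M_exp = M.unsqueeze(0).expand(b, q, m, u)
    top_M = torch.gather(M_exp, 2, top_idx.unsqueeze(-1).expand(b, q, k, u))

    # Attention-weighted sum over the k selected memories: (b, q, u)
    return torch.matmul(alpha.unsqueeze(2), top_M).squeeze(2)


torch.manual_seed(0)
W = topk_memory_read(torch.randn(32, 4), torch.randn(2, 8, 4), torch.randn(2, 8, 16))
print(W.shape)  # torch.Size([32, 2, 16])
```

With `delta=1` the softmax weight is exactly 1, so the output reduces to the single best-matching memory of each neuron.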

Two complementary learning regimes are supported:

  • upd_k="ad_hoc_WTA" -- keys are updated online without backpropagation through the key-matching step: the winning key is nudged toward the current input by beta_k. An optional scrambling mechanism (scramble=True) replaces under-used, stale keys with fresh inputs sampled from the current mini-batch.
  • upd_k="grad_WTA" / upd_k=None -- keys are standard nn.Parameter tensors trained by gradient descent through the full module.

Memory units (self.M) are always nn.Parameter tensors optimised by the surrounding optimizer, regardless of the key-update strategy.
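
To make the "ad_hoc_WTA" regime concrete, here is a hedged sketch of a winner-takes-all key update for a single neuron. The exact rule used by the library is not shown on this page, so the update below (a step of size `beta_k` along `x - key`, followed by re-normalization) and the helper name `wta_key_update` are assumptions chosen only to illustrate the idea of nudging the winning key toward the input without backpropagation.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def wta_key_update(K, x, winners, beta_k=0.001):
    """Move each sample's winning key a small step toward that sample.

    K: (m, d) unit-norm keys of one neuron (updated in place),
    x: (b, d) unit-norm inputs, winners: (b,) index of the top-1 key per sample.
    Hypothetical helper: the library's actual rule may differ.
    """
    # Accumulate the displacement (x - key) of every sample onto its winning key
    step = torch.zeros_like(K)
    step.index_add_(0, winners, x - K[winners])
    K.add_(beta_k * step)
    # Project the keys back onto the unit sphere
    K.copy_(F.normalize(K, dim=1))
    return K


torch.manual_seed(0)
K = F.normalize(torch.randn(8, 4), dim=1)
x = F.normalize(torch.randn(32, 4), dim=1)
winners = (x @ K.t()).argmax(dim=1)

before = (x * K[winners]).sum(dim=1).mean()
wta_key_update(K, x, winners, beta_k=0.01)
after = (x * K[winners]).sum(dim=1).mean()
print(float(before), float(after))  # mean similarity to the winning keys
```

Because the update runs under `torch.no_grad()` on a plain tensor, it mirrors the fact that keys are stored as a buffer rather than an `nn.Parameter` in this regime.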

Attributes:

Name Type Description
q

Number of neurons.

d

Dimensionality of each key vector.

m

Number of key-memory pairs per neuron.

u

Dimensionality of each memory unit (and of the per-neuron output).

delta

Number of top attention responses to select (top-delta).

gamma_alpha

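Softmax scale applied to the top-delta dot products: scores are multiplied by gamma_alpha / sqrt(d) before the softmax, so larger values give sharper attention.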

tau_alpha

Attention-score threshold below which scrambling may be triggered.

tau_mu

Usage count threshold; keys used fewer than tau_mu times are considered under-used.

tau_eta

Age threshold; keys whose age exceeds tau_eta steps are considered stale.

scramble

Whether the key/memory scrambling routine is active.

upd_m

Memory update strategy (None or "WTA").

upd_k

Key update strategy (None, "ad_hoc_WTA", or "grad_WTA").

beta_k

Learning rate used by the "ad_hoc_WTA" key-update rule.

psi_fn

Name of the projection function used to map inputs to the key space.

M

Memory tensor of shape (q, m, u); always an nn.Parameter.

K

Key tensor of shape (q, m, d); an nn.Parameter unless upd_k="ad_hoc_WTA", in which case it is a registered buffer.

mu

Per-key usage counters of shape (q, m), one per key of each neuron; None unless upd_k="ad_hoc_WTA".

eta

Per-key age counters of shape (q, m), one per key of each neuron; None unless upd_k="ad_hoc_WTA".

scrambling_count

Number of scrambling operations performed per neuron, shape (q,).

reset_memories

If True, reset_parameters also re-initializes M. Defaults to True.

Examples:

Create a CNUs layer with 2 neurons, 8 four-dimensional keys per neuron, and 16-dimensional memories, using the default ad-hoc WTA key-update strategy and with scrambling enabled:

>>> import torch
>>> from unaiverse.modules.cnu.cnus import CNUs
>>> layer = CNUs(q=2, d=4, m=8, u=16, delta=3,
...              gamma_alpha=0.1, tau_alpha=0.5,
...              tau_mu=100, tau_eta=100,
...              upd_m="WTA", upd_k="ad_hoc_WTA",
...              beta_k=0.001, psi_fn="identity", scramble=True)
>>> layer
CNUs()

Initialize the CNUs layer and allocate all keys, memories, and counters.

Parameters are validated with assertions before any tensor allocation takes place. Keys are initialised uniformly in [-1/sqrt(d), 1/sqrt(d)] and then L2-normalized along the key dimension. Memories are initialised uniformly in [-1/sqrt(u), 1/sqrt(u)]. Usage counters mu are zeroed and age counters eta are set to tau_eta so that all keys are immediately considered available for scrambling.
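
The initialization scheme described above can be reproduced outside the class as follows. This is a sketch under stated assumptions: `init_cnu_state` is a hypothetical helper that returns plain tensors, whereas the real constructor stores them as parameters and buffers on the module.

```python
import math
import torch
import torch.nn.functional as F


def init_cnu_state(q=1, d=2, m=3, u=4, tau_eta=100):
    """Allocate keys, memories and counters following the documented scheme."""
    # Keys: uniform in [-1/sqrt(d), 1/sqrt(d)], then unit L2 norm along the key dimension
    K = torch.empty(q, m, d).uniform_(-1.0 / math.sqrt(d), 1.0 / math.sqrt(d))
    K = F.normalize(K, p=2, dim=2)
    # Memories: uniform in [-1/sqrt(u), 1/sqrt(u)]
    M = torch.empty(q, m, u).uniform_(-1.0 / math.sqrt(u), 1.0 / math.sqrt(u))
    # Usage counters start at zero, age counters start at tau_eta
    mu = torch.zeros(q, m)
    eta = torch.full((q, m), float(tau_eta))
    return K, M, mu, eta


K, M, mu, eta = init_cnu_state(q=2, d=4, m=8, u=16)
print(K.norm(dim=2))  # every key has unit norm (up to rounding)
```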

Parameters:

Name Type Description Default
q

Number of neurons. Defaults to 1.

1
d

Dimensionality of each key vector. Defaults to 2.

2
m

Number of key-memory pairs per neuron. Defaults to 3.

3
u

Dimensionality of each memory unit (and of the per-neuron output). Defaults to 4.

4
delta

Number of top-scoring keys to select at each forward step. Clamped to min(delta, m). Defaults to 3.

3
gamma_alpha

Softmax scale applied to the top-delta dot-product scores before blending memories; scores are multiplied by gamma_alpha / sqrt(d), so larger values give sharper attention. Defaults to 0.1.

0.1
tau_alpha

Threshold on the top-1 attention response below which scrambling is considered (requires scramble=True). Defaults to 0.5.

0.5
tau_mu

Usage count threshold. Keys whose cumulative usage count is below this value are deemed under-used. Defaults to 100.

100
tau_eta

Age threshold in steps. Keys older than this value are deemed stale. Defaults to 100.

100
upd_m

Memory update strategy. Must be None (gradient only) or "WTA" (winner-takes-all gradient). Defaults to "WTA".

'WTA'
upd_k

Key update strategy. Must be None (gradient only), "ad_hoc_WTA" (online WTA without backprop), or "grad_WTA" (gradient-based WTA). If upd_m="WTA", upd_k must not be None. Defaults to "ad_hoc_WTA".

'ad_hoc_WTA'
beta_k

Learning rate for the "ad_hoc_WTA" key update. Ignored when upd_k is not "ad_hoc_WTA". Defaults to 0.001.

0.001
psi_fn

Name of the function (from psi) used to project the input tensor to the key space before dot-product attention. Defaults to "identity".

'identity'
scramble

Whether to enable the key/memory scrambling routine. When True and upd_k="ad_hoc_WTA", under-used stale keys are periodically replaced with current inputs. Defaults to False.

False
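
The three thresholds tau_alpha, tau_mu, and tau_eta can be read as boolean tests on the current mini-batch and on the counters. The sketch below is an assumption-laden illustration for one neuron: the helper name `scramble_candidates` is hypothetical, and so is the way the conditions are combined. In particular, the age test uses `>=` so that freshly initialized keys (whose age equals tau_eta) qualify, which matches the constructor description but may not match the library's exact comparison.

```python
import torch


def scramble_candidates(top1_response, mu, eta, tau_alpha=0.5, tau_mu=100, tau_eta=100):
    """Return (poorly matched samples, replaceable keys) for one neuron.

    top1_response: (b,) best dot product per sample,
    mu: (m,) usage counters, eta: (m,) age counters.
    Hypothetical helper, for illustration only.
    """
    # Samples whose best key still matches poorly
    weak_samples = top1_response < tau_alpha
    # Keys that are both under-used and stale (assumed combination)
    replaceable_keys = (mu < tau_mu) & (eta >= tau_eta)
    return weak_samples, replaceable_keys


weak, replaceable = scramble_candidates(
    top1_response=torch.tensor([0.9, 0.2, 0.4]),
    mu=torch.tensor([0.0, 500.0, 3.0]),
    eta=torch.tensor([150.0, 150.0, 10.0]),
)
print(weak.tolist())         # [False, True, True]
print(replaceable.tolist())  # [True, False, False]
```

Under this reading, a scrambling step would overwrite some of the replaceable keys with inputs drawn from the weak samples.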

Raises:

Type Description
AssertionError

If upd_m is neither None nor "WTA".

AssertionError

If upd_k is not one of None, "ad_hoc_WTA", or "grad_WTA".

AssertionError

If upd_m="WTA" and upd_k=None.
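
The three checks can be tried in isolation with a small stand-alone function that mirrors the constructor's assertions. `check_update_modes` is a hypothetical helper written for this example and is not part of the module.

```python
def check_update_modes(upd_m="WTA", upd_k="ad_hoc_WTA"):
    """Replicate the three constructor assertions on upd_m and upd_k."""
    assert upd_m in (None, "WTA"), "upd_m must be None or 'WTA'"
    assert upd_k in (None, "ad_hoc_WTA", "grad_WTA"), "unknown value for upd_k"
    assert upd_m is None or upd_k is not None, "upd_m='WTA' requires a non-None upd_k"
    return True


for combo in [(None, None), (None, "grad_WTA"), ("WTA", "ad_hoc_WTA"), ("WTA", None)]:
    try:
        check_update_modes(*combo)
        print(combo, "accepted")
    except AssertionError as err:
        print(combo, "rejected:", err)
```

Only the last combination is rejected: winner-takes-all memory updates need a key-update strategy that is also winner-takes-all.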

Examples:

>>> import torch
>>> from unaiverse.modules.cnu.cnus import CNUs
>>> layer = CNUs(q=1, d=4, m=8, u=16)
>>> x = torch.randn(32, 4)  # batch of 32, input dim 4
Source code in unaiverse/modules/cnu/cnus.py
def __init__(self, q=1, d=2, m=3, u=4, delta=3,
             gamma_alpha=0.1, tau_alpha=0.5, tau_mu=100, tau_eta=100,
             upd_m="WTA", upd_k="ad_hoc_WTA",
             beta_k=0.001,
             psi_fn="identity",
             scramble=False):
    """Initialize the CNUs layer and allocate all keys, memories, and counters.

    Parameters are validated with assertions before any tensor allocation takes
    place. Keys are initialised uniformly in ``[-1/sqrt(d), 1/sqrt(d)]`` and then
    L2-normalized along the key dimension. Memories are initialised uniformly in
    ``[-1/sqrt(u), 1/sqrt(u)]``. Usage counters ``mu`` are zeroed and age counters
    ``eta`` are set to ``tau_eta`` so that all keys are immediately considered
    available for scrambling.

    Args:
        q: Number of neurons. Defaults to 1.
        d: Dimensionality of each key vector. Defaults to 2.
        m: Number of key-memory pairs per neuron. Defaults to 3.
        u: Dimensionality of each memory unit (and of the per-neuron output).
            Defaults to 4.
        delta: Number of top-scoring keys to select at each forward step.
            Clamped to ``min(delta, m)``. Defaults to 3.
        gamma_alpha: Softmax scale applied to the top-``delta`` dot-product
            scores before blending memories (scores are multiplied by
            ``gamma_alpha / sqrt(d)``). Defaults to 0.1.
        tau_alpha: Threshold on the top-1 attention response below which
            scrambling is considered (requires ``scramble=True``). Defaults to 0.5.
        tau_mu: Usage count threshold. Keys whose cumulative usage count is below
            this value are deemed under-used. Defaults to 100.
        tau_eta: Age threshold in steps. Keys older than this value are deemed
            stale. Defaults to 100.
        upd_m: Memory update strategy. Must be ``None`` (gradient only) or
            ``"WTA"`` (winner-takes-all gradient). Defaults to ``"WTA"``.
        upd_k: Key update strategy. Must be ``None`` (gradient only),
            ``"ad_hoc_WTA"`` (online WTA without backprop), or ``"grad_WTA"``
            (gradient-based WTA). If ``upd_m="WTA"``, ``upd_k`` must not be
            ``None``. Defaults to ``"ad_hoc_WTA"``.
        beta_k: Learning rate for the ``"ad_hoc_WTA"`` key update. Ignored when
            ``upd_k`` is not ``"ad_hoc_WTA"``. Defaults to 0.001.
        psi_fn: Name of the function (from ``psi``) used to project the input
            tensor to the key space before dot-product attention. Defaults to
            ``"identity"``.
        scramble: Whether to enable the key/memory scrambling routine. When
            ``True`` and ``upd_k="ad_hoc_WTA"``, under-used stale keys are
            periodically replaced with current inputs. Defaults to ``False``.

    Raises:
        AssertionError: If ``upd_m`` is not ``None`` or ``"WTA"``.
        AssertionError: If ``upd_k`` is not ``None``, ``"ad_hoc_WTA"``, or
            ``"grad_WTA"``.
        AssertionError: If ``upd_m="WTA"`` and ``upd_k=None``.

    Examples:
        >>> import torch
        >>> from unaiverse.modules.cnu.cnus import CNUs
        >>> layer = CNUs(q=1, d=4, m=8, u=16)
        >>> x = torch.randn(32, 4)  # batch of 32, input dim 4
    """

    super(CNUs, self).__init__()
    assert upd_m in (None, 'WTA'), "Unknown value for upd_m, it must be None or 'WTA'"
    assert upd_k in (None, 'ad_hoc_WTA', 'grad_WTA'), "Unknown value for upd_k, it must be " \
                                                      "None, 'ad_hoc_WTA', or 'grad_WTA'"
    assert upd_m is None or (upd_m == 'WTA' and upd_k is not None), \
        "If upd_m is 'WTA', then upd_k must be ad_hoc_WTA or grad_WTA (it cannot be None)"
    self.q = q
    self.d = d
    self.m = m
    self.u = u
    self.gamma_alpha = gamma_alpha
    self.tau_alpha = tau_alpha
    self.tau_mu = tau_mu
    self.tau_eta = tau_eta
    self.scramble = scramble
    self.delta = min(delta, self.m)
    self.upd_m = upd_m
    self.upd_k = upd_k
    self.beta_k = beta_k
    self.psi_fn = psi_fn
    self.debug = False  # Temporarily used
    self.reset_memories = True

    # Creating keys (self.K) and memories (self.M)
    self.M = torch.nn.Parameter(torch.empty((self.q, self.m, self.u), dtype=torch.float32))
    if self.upd_k == "ad_hoc_WTA":
        self.register_buffer('K', torch.zeros((self.q, self.m, self.d)))
    else:
        self.K = torch.nn.Parameter(torch.empty((self.q, self.m, self.d), dtype=torch.float32))

    # Buffers for ad_hoc_WTA key updates (average usefulness register buffer "mu" and age "eta")
    if self.upd_k == "ad_hoc_WTA":
        self.register_buffer('mu', torch.zeros(self.q, m, dtype=torch.float))
        self.register_buffer('eta', torch.ones((self.q, m), dtype=torch.float) * self.tau_eta)
        if self.debug:
            self.register_buffer('key_counter', torch.zeros(self.q, m, dtype=torch.float))
    else:
        self.mu = None
        self.eta = None

    # Scrambling stats
    self.register_buffer('scrambling_count', torch.zeros(self.q, dtype=torch.long))

    # Initializing memories and keys
    self.reset_parameters()

q instance-attribute

q = q

d instance-attribute

d = d

m instance-attribute

m = m

u instance-attribute

u = u

gamma_alpha instance-attribute

gamma_alpha = gamma_alpha

tau_alpha instance-attribute

tau_alpha = tau_alpha

tau_mu instance-attribute

tau_mu = tau_mu

tau_eta instance-attribute

tau_eta = tau_eta

scramble instance-attribute

scramble = scramble

delta instance-attribute

delta = min(delta, m)

upd_m instance-attribute

upd_m = upd_m

upd_k instance-attribute

upd_k = upd_k

beta_k instance-attribute

beta_k = beta_k

psi_fn instance-attribute

psi_fn = psi_fn

debug instance-attribute

debug = False

reset_memories instance-attribute

reset_memories = True

M instance-attribute

M = Parameter(empty((q, m, u), dtype=float32))

K instance-attribute

K = Parameter(empty((q, m, d), dtype=float32))

mu instance-attribute

mu = None

eta instance-attribute

eta = None

reset_parameters

reset_parameters()

Reset all learnable parameters and internal counters to their initial state.

Keys are re-initialized uniformly in [-1/sqrt(d), 1/sqrt(d)] and L2-normalized. Memories are re-initialized uniformly in [-1/sqrt(u), 1/sqrt(u)] only when reset_memories is True. Usage counters mu and age counters eta are reset to their construction-time defaults (zero and tau_eta respectively).

Note

Memory re-initialization is conditional on self.reset_memories so that callers can preserve learned memories while refreshing only keys and counters.

Source code in unaiverse/modules/cnu/cnus.py
def reset_parameters(self):
    """Reset all learnable parameters and internal counters to their initial state.

    Keys are re-initialized uniformly in ``[-1/sqrt(d), 1/sqrt(d)]`` and
    L2-normalized. Memories are re-initialized uniformly in
    ``[-1/sqrt(u), 1/sqrt(u)]`` only when ``reset_memories`` is ``True``. Usage
    counters ``mu`` and age counters ``eta`` are reset to their construction-time
    defaults (zero and ``tau_eta`` respectively).

    Note:
        Memory re-initialization is conditional on ``self.reset_memories`` so that
        callers can preserve learned memories while refreshing only keys and
        counters.
    """
    self.__reset_keys()
    if self.reset_memories:
        self.__reset_memories()
    self.__reset_counters()

compute_weights

compute_weights(x)

Compute the attention-weighted memory blend for each neuron given an input batch.

The input x is first projected to the key space by psi, then dot-product attention is computed against all m keys of each neuron. The top-delta responses are selected and passed through a scaled softmax to obtain attention weights alpha. The corresponding memory units are blended by these weights to produce the per-neuron output.

When upd_k="ad_hoc_WTA" the input is detached from the computational graph before key-matching, so no gradient flows to layers below through this step. When training with upd_k="ad_hoc_WTA", the keys and their counters are updated in-place before the final memory blend is computed.

When upd_m="WTA", the top-1 memory is scaled by its own attention weight (with a live gradient) and the remaining top-(delta-1) memories are blended from a detached copy of M. When upd_m=None all memories are blended with live gradients.
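
The effect of the "WTA" memory blend on gradients can be checked with a small stand-alone experiment. The sketch below follows the gather-and-detach pattern described above on random scores (an assumption made for brevity: no keys or inputs are involved, and the softmax is unscaled) and verifies that only memories that won the top-1 slot for at least one sample receive a gradient.

```python
import torch

torch.manual_seed(0)
b, q, m, u, k = 5, 2, 6, 3, 3
M = torch.nn.Parameter(torch.randn(q, m, u))

# Random attention scores stand in for the key-matching step
scores = torch.randn(b, q, m)
top_scores, top_idx = torch.topk(scores, k, dim=2)
alpha = torch.softmax(top_scores, dim=2)

M_exp = M.view(1, q, m, u).expand(b, q, m, u)

# Top-1 memory: gathered from the live parameter, so it receives gradients
idx1 = top_idx[..., 0:1].view(b, q, 1, 1).expand(b, q, 1, u)
W_top1 = (alpha[..., 0:1].view(b, q, 1, 1) * torch.gather(M_exp, 2, idx1)).squeeze(2)

# Remaining k-1 memories: gathered from a detached view, so they act as constants
idx_rest = top_idx[..., 1:].view(b, q, k - 1, 1).expand(b, q, k - 1, u)
rest_M = torch.gather(M_exp.detach(), 2, idx_rest)
W_rest = torch.matmul(alpha[..., 1:].view(b, q, 1, k - 1), rest_M).squeeze(2)

W = W_top1 + W_rest
W.sum().backward()

# Mark the (neuron, key) pairs that were a top-1 winner for at least one sample
won = torch.zeros(q, m, dtype=torch.bool)
won[torch.arange(q).unsqueeze(1), top_idx[..., 0].t()] = True

# Memories with a non-zero gradient should be exactly the winners
has_grad = M.grad.abs().sum(dim=2) > 0
print(torch.equal(has_grad, won))
```

Memories ranked second or lower still contribute to the output value, but the optimizer only moves the winner.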

Parameters:

Name Type Description Default
x

Input tensor of shape (batch_size, d). Each row is one input sample to be matched against the key bank.

required

Returns:

Type Description

A tensor of shape (batch_size, q, u) where entry [b, n, :] is the attention-weighted blend of the u-dimensional memory units of neuron n for batch element b.

Raises:

Type Description
NotImplementedError

If upd_m holds an unrecognised value (should not occur under normal construction constraints).

Examples:

>>> import torch
>>> from unaiverse.modules.cnu.cnus import CNUs
>>> layer = CNUs(q=2, d=4, m=8, u=16, upd_m=None, upd_k=None)
>>> x = torch.randn(32, 4)
>>> W = layer.compute_weights(x)  # shape: (32, 2, 16)
Source code in unaiverse/modules/cnu/cnus.py
def compute_weights(self, x):
    """Compute the attention-weighted memory blend for each neuron given an input batch.

    The input ``x`` is first projected to the key space by ``psi``, then
    dot-product attention is computed against all ``m`` keys of each neuron. The
    top-``delta`` responses are selected and passed through a scaled softmax to
    obtain attention weights ``alpha``. The corresponding memory units are blended
    by these weights to produce the per-neuron output.

    When ``upd_k="ad_hoc_WTA"`` the input is detached from the computational graph
    before key-matching, so no gradient flows to layers below through this step.
    When training with ``upd_k="ad_hoc_WTA"``, the keys and their counters are
    updated in-place before the final memory blend is computed.

    When ``upd_m="WTA"``, the top-1 memory is scaled by its own attention weight
    (with a live gradient) and the remaining top-``(delta-1)`` memories are blended
    from a detached copy of ``M``. When ``upd_m=None`` all memories are blended
    with live gradients.

    Args:
        x: Input tensor of shape ``(batch_size, d)``. Each row is one input
            sample to be matched against the key bank.

    Returns:
        A tensor of shape ``(batch_size, q, u)`` where entry ``[b, n, :]``
        is the attention-weighted blend of the ``u``-dimensional memory units of
        neuron ``n`` for batch element ``b``.

    Raises:
        NotImplementedError: If ``upd_m`` holds an unrecognised value (should not
            occur under normal construction constraints).

    Examples:
        >>> import torch
        >>> from unaiverse.modules.cnu.cnus import CNUs
        >>> layer = CNUs(q=2, d=4, m=8, u=16, upd_m=None, upd_k=None)
        >>> x = torch.randn(32, 4)
        >>> W = layer.compute_weights(x)  # shape: (32, 2, 16)
    """

    # Shortcuts (notice that "self.delta" is called "k" in shortcuts, while "self.delta-1" is called "z")
    q, m, u, d, k = self.q, self.m, self.u, self.d, self.delta
    b = x.shape[0]
    M_qmu = self.M

    # Ensuring keys are normalized (not needed with ad_hoc_WTA updates)
    if self.upd_k != 'ad_hoc_WTA':
        self.__normalize_keys()
    else:
        x = x.detach()  # In ad-hoc WTA, no gradient is propagated to the layers below through key-matching

    # Mapping the input to the key space using the psi function
    x_bd = psi(x, self.psi_fn, key_size=d, normalize=True)

    # Finding the top responses and indices for the attention procedure
    top_responses_bqk, top_indices_bqk = self.__top_k_attention(x_bd)

    # Probabilities
    top_alpha_bqk = torch.softmax((self.gamma_alpha / math.sqrt(d)) * top_responses_bqk, dim=2)

    if self.debug:

        # Getting the top-1 indices for the current mini-batch
        top1_indices_qb = top_indices_bqk[..., 0].t()
        self.key_counter.data.scatter_add_(dim=1,
                                           index=top1_indices_qb,
                                           src=torch.ones_like(top1_indices_qb, dtype=self.key_counter.dtype))

    # Updating keys with the ad-hoc scheme (also refreshing top-stuff: responses, indices, alpha)
    if self.training and self.upd_k == 'ad_hoc_WTA':
        top_responses_bqk, top_indices_bqk, top_alpha_bqk = \
            self.__update_keys_and_counters(x_bd, top_responses_bqk, top_indices_bqk, top_alpha_bqk)

    # Reading memories and blending them
    if self.upd_m is None:

        # Preparing to read memory units and to blend them
        M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)

        # Getting top memory units
        top_M_bqku = torch.gather(M_exp_bqmu, dim=2,
                                  index=top_indices_bqk.view(b, q, k, 1).expand(b, q, k, u))

        # Mixing memory units by attention scores
        # -> top_alpha_bqk: [b,q,k], that we un-squeeze to [b,q,1,k]
        # -> top_M_bqku: [b,q,k,u]
        # -> W_bqu: matmul([(b,q),1,k], [(b,q),k,u]) = [b,q,1,u] that we squeeze to [b,q,u]
        W_bqu = torch.matmul(top_alpha_bqk.view(b, q, 1, k), top_M_bqku).squeeze(2)

    elif self.upd_m == 'WTA':

        # Preparing to read memory units and to blend them
        M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)

        # Dealing with top-1 stuff
        top1_M_exp_bq1u = torch.gather(M_exp_bqmu, dim=2,
                                       index=top_indices_bqk[..., 0:1].view(b, q, 1, 1).expand(b, q, 1, u))

        # Mixing memory units by attention scores
        # -> top_alpha_bqk: [b,q,k], from which we select the top-1 column [b,q,1], un-squeezed to [b,q,1,1]
        # -> top1_M_exp_bq1u: [b,q,1,u]
        # -> top1_W_bqu: [b,q,1,1] * [b,q,1,u] = [b,q,1,u] that we squeeze to [b,q,u]
        top1_W_bqu = (top_alpha_bqk[..., 0:1].view(b, q, 1, 1) * top1_M_exp_bq1u).squeeze(2)

        # Dealing with top-2-and-following stuff
        top2on_M_exp_bqzu = torch.gather(M_exp_bqmu.detach(), dim=2,
                                         index=top_indices_bqk[..., 1:].view(b, q, k-1, 1).expand(b, q, k-1, u))
        top2on_alpha_bqz = top_alpha_bqk[:, :, 1:]
        if self.upd_k == 'grad_WTA':
            top2on_alpha_bqz = top2on_alpha_bqz.detach()

        # Mixing memory units by attention scores
        # -> top2on_alpha_bqz: [b,q,k-1], that we un-squeeze to [b,q,1,k-1]
        # -> top2on_M_exp_bqzu: [b,q,k-1,u]
        # -> W_bqu: matmul([(b,q),1,k-1], [(b,q),k-1,u]) = [b,q,1,u] that we squeeze to [b,q,u]
        top2on_W_bqu = torch.matmul(top2on_alpha_bqz.view(b, q, 1, k-1), top2on_M_exp_bqzu).squeeze(2)

        # Merging top1 and top-2-and-following stuff
        W_bqu = top1_W_bqu + top2on_W_bqu

    else:

        # What is going on?
        raise NotImplementedError

    return W_bqu

forward

forward(x)

Perform the forward pass.

CNUs is an abstract base class. Concrete subclasses must override this method to define how the attention-weighted memory blends produced by compute_weights are combined into the final module output.

Parameters:

Name Type Description Default
x

Input tensor passed through the module.

required

Raises:

Type Description
NotImplementedError

Always, because CNUs does not implement forward directly.
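
As an illustration of what a concrete forward might do with the output of compute_weights, the sketch below treats each neuron's blended memory as the weight vector of a linear unit. It is an assumption, not a description of any subclass shipped with the library: the function name `linear_from_generated_weights` is hypothetical, and it requires the memory size u to equal the input size.

```python
import torch


def linear_from_generated_weights(x, W_bqu):
    """Use each neuron's blended memory as the weight vector of a linear unit.

    x: (b, u) inputs, W_bqu: (b, q, u) per-sample weights, e.g. the tensor
    returned by compute_weights. Returns (b, q) activations.
    Hypothetical helper, for illustration only.
    """
    # One dot product per (sample, neuron) pair
    return torch.einsum("bqu,bu->bq", W_bqu, x)


torch.manual_seed(0)
x = torch.randn(5, 3)
W_bqu = torch.randn(5, 2, 3)  # stand-in for layer.compute_weights(x)
y = linear_from_generated_weights(x, W_bqu)
print(y.shape)  # torch.Size([5, 2])
```

A subclass built this way would call `self.compute_weights(x)` inside `forward` and apply the resulting per-sample weights to the same input.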

Source code in unaiverse/modules/cnu/cnus.py
def forward(self, x):
    """Perform the forward pass.

    ``CNUs`` is an abstract base class. Concrete subclasses must override this
    method to define how the attention-weighted memory blends produced by
    ``compute_weights`` are combined into the final module output.

    Args:
        x: Input tensor passed through the module.

    Raises:
        NotImplementedError: Always, because ``CNUs`` does not implement
            ``forward`` directly.
    """
    raise NotImplementedError

reset_counter

reset_counter()

Reset the debug key-usage counter to zero.

When self.debug is True, key_counter accumulates how many times each key has been selected as the top-1 winner across forward passes. This method zeros that buffer so counts can be measured over a fresh window.

Note

This method has no effect when self.debug is False.
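
The accumulate-then-reset cycle of the debug counter can be reproduced on a toy tensor. The sketch uses the same `scatter_add_` pattern as the debug branch of compute_weights; the index values are made up for the example.

```python
import torch

q, m = 2, 4
key_counter = torch.zeros(q, m)

# Hypothetical top-1 key index for each (sample, neuron) pair, shape (b, q)
top1_indices_bq = torch.tensor([[0, 3], [0, 3], [1, 3], [2, 0], [0, 0], [1, 3]])

# Transpose to (q, b) so that dim=1 of the index addresses the keys of each neuron
top1_indices_qb = top1_indices_bq.t()
key_counter.scatter_add_(dim=1, index=top1_indices_qb,
                         src=torch.ones_like(top1_indices_qb, dtype=key_counter.dtype))
print(key_counter.tolist())  # [[3.0, 2.0, 1.0, 0.0], [2.0, 0.0, 0.0, 4.0]]

# Equivalent of reset_counter(): start a fresh measurement window
key_counter = torch.zeros_like(key_counter)
print(key_counter.sum().item())  # 0.0
```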

Source code in unaiverse/modules/cnu/cnus.py
def reset_counter(self):
    """Reset the debug key-usage counter to zero.

    When ``self.debug`` is ``True``, ``key_counter`` accumulates how many times
    each key has been selected as the top-1 winner across forward passes. This
    method zeros that buffer so counts can be measured over a fresh window.

    Note:
        This method has no effect when ``self.debug`` is ``False``.
    """
    if self.debug:
        self.key_counter.data = torch.zeros_like(self.key_counter)