Numerically Stable Softmax and Cross Entropy

December 15, 2022


In this post, we'll take a look at softmax and cross entropy loss. We'll see that naive implementations result in numerical instability, and then we'll derive implementations that are numerically stable.

Symbols


\(x\): a vector of real-valued inputs (logits)
\(\hat{y}\): a vector of predicted logits for a single example
\(y\): the index of the true class
\(p\), \(q\): probability distributions represented as probability vectors

Softmax


The softmax function is defined as:
\[ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} \]

In code:

import numpy as np

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([1.2, 2, -4, 0.0])
softmax(x)
# outputs: [0.28310553, 0.63006295, 0.00156177, 0.08526975]

For large inputs we start seeing some numerical instability:

x = np.array([1.2, 2000, -4000, 0.0])
softmax(x)
# outputs: [0., nan, 0.,  0.]

Floating point numbers aren't magic; they have limits:

np.finfo(np.float64).max
# 1.7976931348623157e+308, largest positive number

np.finfo(np.float64).tiny
# 2.2250738585072014e-308, smallest positive number at full precision

np.finfo(np.float64).smallest_subnormal
# 5e-324, smallest positive number

When we go beyond these limits, we start seeing funky behaviour:

np.finfo(np.float64).max * 2
# inf, overflow error

np.inf - np.inf
# nan, not a number error

np.finfo(np.float64).smallest_subnormal / 2
# 0.0, underflow error

Looking back at our softmax example that resulted in [0., nan, 0., 0.], we can see that the overflow of np.exp(2000) = np.inf is causing the nan, since we end up with np.inf / np.inf = nan. If we want to avoid nans, we need to avoid infs. To avoid infs, we need to avoid overflows. To avoid overflows, we need to prevent our numbers from growing too large.
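
To make the failure concrete, here's what NumPy actually computes for that input (values shown approximately):

np.exp(np.array([1.2, 2000, -4000, 0.0]))
# roughly [3.32, inf, 0., 1.], with an overflow RuntimeWarning for the 2000 entry

# the denominator sums to inf, so the second entry is inf / inf = nan
# and every other entry is a finite number divided by inf, which is 0
np.inf / np.inf
# nan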

Underflows, on the other hand, aren't nearly as detrimental. Worst case, we get a result of 0 and lose all precision (e.g. np.exp(-4000) = 0). While this is not ideal, it's a lot better than running into inf and nan.
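
For example, an extremely negative logit simply ends up with a probability of exactly 0, while the rest of the output stays sensible (values shown approximately):

x = np.array([1.2, 2, -4000, 0.0])
softmax(x)
# roughly [0.2835, 0.631, 0., 0.0854]; the underflowed entry is clamped to 0, no nans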

Given the relative stability of floating point underflows vs overflows, how can we fix softmax?

Let's revisit our softmax equation and apply some tricks:
\[ \begin{align} \text{softmax}(x)_i &= \frac{e^{x_i}}{\sum_j e^{x_j}} \\ &= 1\cdot \frac{e^{x_i}}{\sum_j e^{x_j}} \\ &= \frac{C}{C}\frac{e^{x_i}}{\sum_j e^{x_j}} \\ &= \frac{Ce^{x_i}}{\sum_j Ce^{x_j}} \\ &= \frac{e^{x_i + \log C}}{\sum_j e^{x_j + \log C}} \\ \end{align} \]
Here, we're taking advantage of the rule \(a\cdot b^x = b^{x + \log_b a}\), which lets us offset our inputs by any constant we choose. So, if we want to prevent large numbers, we can simply set \(\log C = -\max(x)\):
\[ \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \]
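
As a quick sanity check, shifting the inputs by a constant really does leave the output unchanged, at least for inputs small enough that the naive version doesn't overflow:

x = np.array([1.2, 2, -4, 0.0])
np.allclose(softmax(x), softmax(x - np.max(x)))
# True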

This version is numerically stable: after subtracting the max, the largest exponent is exactly 0, so every \(e^{x_i - \max(x)}\) lies in \((0, 1]\) and can never overflow. The worst that can happen is that some terms underflow to 0, which we've already established is tolerable.

In code:

def softmax(x):
    # assumes x is a vector
    x = x - np.max(x)  # shift so the largest value is 0, preventing overflow in np.exp
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([1.2, 2, -4, 0])
softmax(x)
# outputs: [0.28310553, 0.63006295, 0.00156177, 0.08526975]

# works for large numbers!!!
x = np.array([1.2, 2, -4, 0]) * 1000
softmax(x)
# outputs: [0., 1., 0., 0.]
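
If SciPy happens to be installed, we can also sanity-check our implementation against scipy.special.softmax, which is likewise implemented in a numerically stable way:

from scipy.special import softmax as scipy_softmax  # assumes scipy is available

x = np.array([1.2, 2, -4, 0]) * 1000
np.allclose(softmax(x), scipy_softmax(x))
# True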

Cross Entropy and Log Softmax


The cross entropy between two probability distributions is defined as:
\[ H(p, q) = -\sum_i p_i\log(q_i) \]
where \(p\) and \(q\) are our probability distributions represented as probability vectors (that is, \(p_i\) and \(q_i\) are the probabilities of event \(i\) occurring under \(p\) and \(q\) respectively).
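
As a quick sketch of the general formula in code (the function name is mine, and it assumes \(q\) has no zero entries so the log is defined):

def general_cross_entropy(p, q):
    # assumes p and q are probability vectors of the same length
    # and that q has no zero entries (otherwise the log blows up)
    return -np.sum(p * np.log(q))

p = np.array([0.0, 1.0, 0.0])
q = np.array([0.1, 0.7, 0.2])
general_cross_entropy(p, q)
# 0.35667494..., i.e. -log(0.7)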

Roughly speaking, cross entropy measures the similarity between two probability distributions, so in the context of machine learning, we use cross entropy as a loss function where:

\(p\) is the true distribution: a one-hot vector with \(p_y = 1\) for the true class \(y\) and \(p_i = 0\) everywhere else
\(q\) is the model's predicted distribution: \(q = \text{softmax}(\hat{y})\), where \(\hat{y}\) is the vector of predicted logits

In this setup, our cross entropy simplifies to:
\[ \begin{align} H(p, q) &= -\sum_i p_i\log(q_i) \\ &= -p_y\cdot\log(q_y) -\sum_{i \neq y} p_i\log(q_i) \\ &= -1\cdot\log(q_y) -\sum_{i \neq y} 0\cdot\log(q_i) \\ &= -\log(q_y) - 0 \sum_{i \neq y} \log(q_i) \\ &= -\log(q_y) \\ &= -\log(\text{softmax}(\hat{y})_y) \end{align} \]

In code:

def cross_entropy(y_hat, y_true):
    # assume y_hat is a vector and y_true is an integer
    return -np.log(softmax(y_hat)[y_true])

cross_entropy(
    y_hat=np.random.normal(size=(10)),
    y_true=3,
)
# 2.580982279204241

For large numbers in y_hat, we start seeing inf:

cross_entropy(
    y_hat = np.array([-1000, 1000]),
    y_true = 0,
)
# inf

In this case softmax(y_hat) = [0, 1], so with y_true = 0 we get -log(0) = inf. To make log(softmax(x)) stable, we need to prevent log(0). Let's see if we can employ some math tricks to make this happen:
\[ \begin{align} \log(\text{softmax}(x)_i) & = \log(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}) \\ &= \log(e^{x_i - \max(x)}) - \log(\sum_j e^{x_j - \max(x)}) \\ &= (x_i - \max(x))\log(e) - \log(\sum_j e^{x_j - \max(x)}) \\ &= (x_i - \max(x))\cdot 1 - \log(\sum_j e^{x_j - \max(x)}) \\ &= x_i - \max(x) - \log(\sum_j e^{x_j - \max(x)}) \\ \end{align} \]
This new equation guarantees that the sum inside the log is always \(\geq 1\), since the \(j = \arg\max(x)\) term contributes \(e^0 = 1\) and every other term is positive, so we no longer need to worry about log(0) errors.

In code:

def log_softmax(x):
    # assumes x is a vector
    x_max = np.max(x)
    return x - x_max - np.log(np.sum(np.exp(x - x_max)))

def cross_entropy(y_hat, y_true):
    # assume y_hat is a vector and y_true is an integer
    return -log_softmax(y_hat)[y_true]

cross_entropy(
    y_hat=np.random.normal(size=(10)),
    y_true=3,
)
# 2.580982279204241

# works for large inputs!!!!
cross_entropy(
    y_hat = np.array([-1000, 1000]),
    y_true = 0,
)
# 2000.0
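
As a final sanity check, for moderate inputs where nothing overflows or underflows, log_softmax agrees with naively taking np.log of the (stable) softmax:

x = np.array([1.2, 2, -4, 0.0])
np.allclose(log_softmax(x), np.log(softmax(x)))
# True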