/**
* Welford-style online (streaming) statistics for RevIN normalization.
*
* Mirrors the Python `update_running_stats()` in torch/util.py and
* flax/util.py.
*
* These utilities are called at every patch boundary during autoregressive
* decoding. Numerical errors here compound over long horizons, so we
* implement the numerically-stable two-pass algorithm and skip NaN/Inf
* values to prevent data corruption.
*/
// ---------------------------------------------------------------------------
// Running stats interface
// ---------------------------------------------------------------------------
/**
* Accumulated summary of all samples folded in so far.
*
* Instances are produced by `createRunningStats()` and advanced by
* `updateRunningStats()`, which returns a new object rather than mutating
* the one it is given.
*/
export interface RunningStats {
// Number of samples accumulated so far (only unmasked, finite values are counted).
n: number;
// Mean of the accumulated samples.
mu: number;
// Population standard deviation (variance divided by n, not n - 1) of the accumulated samples.
sigma: number;
}
/**
 * Build an empty accumulator: no samples seen, mean 0, standard deviation 0.
 *
 * @returns A newly allocated `RunningStats` on every call, so callers never
 *   share state through the returned object.
 */
export function createRunningStats(): RunningStats {
  const empty: RunningStats = { n: 0, mu: 0, sigma: 0 };
  return empty;
}
// ---------------------------------------------------------------------------
// Single-patch update (one batch element)
// ---------------------------------------------------------------------------
/**
 * Fold one patch of samples into a set of running statistics.
 *
 * A sample at index `i` contributes only when `mask[i] === 0` and the value
 * is finite. Any other mask value marks the position as padding, and
 * NaN / ±Infinity are dropped so that a single bad sample cannot poison the
 * mean and standard deviation. Indices past the end of `mask` are treated as
 * padding as well.
 *
 * The patch's own mean and population variance are computed with two passes
 * (sum of squared deviations from the mean), then pooled with the prior
 * statistics using the parallel variance-merge formula.
 *
 * @param stats Statistics accumulated so far. Never mutated.
 * @param values Samples of the new patch.
 * @param mask Per-sample flags; `0` means valid, non-zero means ignore.
 * @returns A pair whose two slots hold the *same* object (the pair shape
 *   mirrors the Python `update_running_stats()` convention). When the patch
 *   has no usable sample, that object is the `stats` argument itself;
 *   otherwise it is a newly created `RunningStats`.
 */
export function updateRunningStats(
  stats: RunningStats,
  values: Float32Array,
  mask: Uint8Array,
): [RunningStats, RunningStats] {
  const total = values.length;

  // Pass 1: count and sum the usable samples of this patch.
  let patchCount = 0;
  let patchTotal = 0;
  for (let idx = 0; idx < total; idx += 1) {
    if (mask[idx] !== 0) continue; // padding (or no mask entry at all)
    const sample = values[idx];
    if (Number.isFinite(sample)) {
      patchCount += 1;
      patchTotal += sample;
    }
  }

  // Nothing usable in this patch: hand back the prior statistics untouched.
  if (patchCount === 0) {
    return [stats, stats];
  }

  // Pass 2: squared deviations from the patch mean. This avoids the
  // catastrophic cancellation of the one-pass E[X²] - E[X]² form when the
  // values are large relative to their spread.
  const patchMean = patchTotal / patchCount;
  let squaredDevTotal = 0;
  for (let idx = 0; idx < total; idx += 1) {
    if (mask[idx] !== 0 || !Number.isFinite(values[idx])) continue;
    const offset = values[idx] - patchMean;
    squaredDevTotal += offset * offset;
  }
  const patchSigma = Math.sqrt(Math.max(0, squaredDevTotal / patchCount));

  // Pool the prior statistics with the patch statistics.
  const mergedCount = stats.n + patchCount;
  const mergedMean = (stats.n * stats.mu + patchCount * patchMean) / mergedCount;

  // σ²_new = (n1·σ1² + n2·σ2² + n1·(μ1 - μ_new)² + n2·(μ2 - μ_new)²) / N
  const priorShift = stats.mu - mergedMean;
  const patchShift = patchMean - mergedMean;
  const priorSpread = stats.n * stats.sigma * stats.sigma;
  const patchSpread = patchCount * patchSigma * patchSigma;
  const priorShiftSq = stats.n * priorShift * priorShift;
  const patchShiftSq = patchCount * patchShift * patchShift;
  const mergedVariance =
    (priorSpread + patchSpread + priorShiftSq + patchShiftSq) / mergedCount;

  const merged: RunningStats = {
    n: mergedCount,
    mu: mergedMean,
    // Clamp at zero so rounding can never put a negative value under the root.
    sigma: Math.sqrt(Math.max(0, mergedVariance)),
  };
  return [merged, merged];
}
// ---------------------------------------------------------------------------
// Convenience: compute stats for a whole array at once
// ---------------------------------------------------------------------------
/**
 * One-shot mean and population standard deviation of an array.
 *
 * Samples that are NaN or ±Infinity are skipped, as are positions whose mask
 * entry is anything other than `0` (when a mask is supplied). The variance is
 * obtained with two passes, summing squared deviations from the mean, and is
 * divided by the number of usable samples (not `n - 1`).
 *
 * @param values Samples to summarise.
 * @param mask Optional per-sample flags; `0` means valid, non-zero means
 *   ignore. When omitted, every finite sample is used.
 * @returns `{ mean, std }`, or `{ mean: 0, std: 0 }` when no sample is usable.
 */
export function computeStats(
  values: Float32Array,
  mask?: Uint8Array,
): { mean: number; std: number } {
  const size = values.length;

  // A sample is usable when it is not masked out and is a finite number.
  const isUsable = (idx: number): boolean =>
    (!mask || mask[idx] === 0) && Number.isFinite(values[idx]);

  // Pass 1: count and sum the usable samples.
  let count = 0;
  let total = 0;
  for (let idx = 0; idx < size; idx += 1) {
    if (isUsable(idx)) {
      count += 1;
      total += values[idx];
    }
  }

  if (count === 0) {
    return { mean: 0, std: 0 };
  }

  // Pass 2: squared deviations from the mean just computed.
  const mean = total / count;
  let squaredDevTotal = 0;
  for (let idx = 0; idx < size; idx += 1) {
    if (isUsable(idx)) {
      const offset = values[idx] - mean;
      squaredDevTotal += offset * offset;
    }
  }

  // Clamp at zero so rounding can never put a negative value under the root.
  const std = Math.sqrt(Math.max(0, squaredDevTotal / count));
  return { mean, std };
}
|