Press n or j to go to the next uncovered block, b, p or k for the previous block.
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | 1x 1x 1x 1x 1x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 344x 344x 344x 344x 344x 344x 344x 253x 253x 253x 344x 344x 344x 344x 344x 2508x 2508x 2508x 2508x 344x 344x 344x 253x 253x 253x 253x 344x 344x 344x 344x 2508x 2508x 2508x 2508x 2508x 2508x 2508x 2508x 344x 344x 253x 253x 253x 253x 253x 253x 253x 253x 253x 344x 344x 80256x 34194x 34194x 80256x 344x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x | /**
* Complete data preprocessing pipeline for TimesFM.
*
* Takes raw user-provided time series and produces the patched, padded,
* normalized tensors required by the model's forward pass.
*
* Pipeline:
* 1. Clean each series (trailing NaN → leading NaN → interpolate internal)
* 2. Left-pad/truncate each series to the context length derived from `maxContext`
* 3. Split into patches of `inputPatchLen`
* 4. Compute per-patch running statistics (RevIN μ, σ)
* 5. Apply RevIN normalization, then zero out padded positions
*
* Mirrors the logic in `TimesFM_2p5.forecast()` and `decode()`.
*/
import type { ForecastConfig, ModelConfig } from './types';
import { cleanSeries } from './utils/nan-handler';
import { createRunningStats, updateRunningStats, type RunningStats } from './utils/stats';
import { revinBatch } from './utils/revin';
import { leftPad, concat, concatUint8 } from './utils/tensor-utils';
// ---------------------------------------------------------------------------
// Result types
// ---------------------------------------------------------------------------
/**
 * Output of the full preprocessing pipeline — ready for model inference.
 *
 * Every flat per-series array holds `numPatches * inputPatchLen` values,
 * laid out patch after patch.
 */
export interface PreprocessedData {
  /**
   * RevIN-normalized patched input, one flat array per batch entry
   * (`numPatches * inputPatchLen` values). Positions flagged as padding in
   * `patchedMasks` are set to 0.
   */
  patchedInputs: Float32Array[];
  /**
   * Padding mask aligned element-for-element with `patchedInputs`;
   * `1` marks a left-padded position.
   */
  patchedMasks: Uint8Array[];
  /**
   * Running mean after each patch: `batchSize * numPatches` single-element
   * arrays in batch-major order (entry `b * numPatches + p`). Needed to undo
   * RevIN normalization.
   */
  contextMu: Float32Array[];
  /** Running standard deviation after each patch, laid out exactly like `contextMu`. */
  contextSigma: Float32Array[];
  /** Per-batch-element running stats after the last patch (shallow copies). */
  lastStats: RunningStats[];
  /** Number of patches per series (each patch is `inputPatchLen` values). */
  numPatches: number;
  /** Number of input series. */
  batchSize: number;
  /** The cleaned inputs, before padding/truncation (for post-processing reference). */
  cleanedInputs: Float32Array[];
  /**
   * The most recent context-length points of each cleaned series (truncated,
   * never padded). When no truncation was needed this is the same array
   * object as the matching `cleanedInputs` entry.
   * @internal Reserved for future iterative refinement APIs.
   */
  truncatedInputs: Float32Array[];
}
// ---------------------------------------------------------------------------
// Main pipeline
// ---------------------------------------------------------------------------
/**
 * Run the full preprocessing pipeline on a batch of raw time series.
 *
 * The effective context length is `fc.maxContext` rounded UP to a whole
 * number of `mc.inputPatchLen`-sized patches. Series are left-padded, so the
 * newest observations sit at the END of each buffer. Rounding down would
 * silently cut them off whenever `maxContext` is not a patch multiple.
 * When `maxContext` is already a multiple of the patch length, the rounding
 * is a no-op. NOTE(review): assumes this matches the Python reference, which
 * rounds `max_context` up to a patch multiple — confirm.
 *
 * @param inputs Raw 1-D time series (any length, may contain NaN).
 * @param fc Forecast configuration (controls maxContext).
 * @param mc Model architecture config (controls patch sizes).
 * @returns Normalized, mask-zeroed patches plus the per-patch statistics
 *   needed to reverse the normalization.
 * @throws RangeError if `mc.inputPatchLen` is not a positive integer or
 *   `fc.maxContext` is not a non-negative integer.
 */
export function preprocess(
  inputs: Float32Array[],
  fc: ForecastConfig,
  mc: ModelConfig,
): PreprocessedData {
  const batchSize = inputs.length;
  const { inputPatchLen } = mc;
  // Guard: a zero/negative patch length would make numPatches Infinity and
  // the patch loop below would never terminate.
  if (!Number.isInteger(inputPatchLen) || inputPatchLen <= 0) {
    throw new RangeError(`inputPatchLen must be a positive integer, got ${inputPatchLen}`);
  }
  if (!Number.isInteger(fc.maxContext) || fc.maxContext < 0) {
    throw new RangeError(`maxContext must be a non-negative integer, got ${fc.maxContext}`);
  }
  // Round up to whole patches so the most recent points are never dropped.
  const numPatches = Math.ceil(fc.maxContext / inputPatchLen);
  const contextLen = numPatches * inputPatchLen;

  // ---- Step 1: Clean each series ----
  const cleanedInputs = inputs.map((s) => cleanSeries(s));

  // ---- Step 2: Left-pad / truncate to contextLen ----
  const padded: Float32Array[] = [];
  const fullMasks: Uint8Array[] = [];
  const truncatedInputs: Float32Array[] = [];
  for (const series of cleanedInputs) {
    const { padded: p, mask: m } = leftPad(series, contextLen);
    padded.push(p);
    fullMasks.push(m);
    // Record the truncated (unpadded) version: the last contextLen points.
    truncatedInputs.push(
      series.length > contextLen ? series.slice(series.length - contextLen) : series,
    );
  }

  // ---- Steps 3 & 4: Split into patches and accumulate running stats ----
  // A single pass per batch element; each patch is sliced once and used both
  // for the flat patched array and for the running-statistics update.
  const patchedInputs: Float32Array[] = [];
  const patchedMasks: Uint8Array[] = [];
  const contextMu: Float32Array[] = [];
  const contextSigma: Float32Array[] = [];
  const lastStats: RunningStats[] = [];
  for (let b = 0; b < batchSize; b++) {
    const flatInput = padded[b];
    const flatMask = fullMasks[b];
    const patchValues: Float32Array[] = [];
    const patchMasks: Uint8Array[] = [];
    let stats = createRunningStats();
    for (let p = 0; p < numPatches; p++) {
      const offset = p * inputPatchLen;
      const values = flatInput.slice(offset, offset + inputPatchLen);
      const mask = flatMask.slice(offset, offset + inputPatchLen);
      patchValues.push(values);
      patchMasks.push(mask);
      [stats] = updateRunningStats(stats, values, mask);
      // Batch-major layout: entry b * numPatches + p.
      contextMu.push(new Float32Array([stats.mu]));
      contextSigma.push(new Float32Array([stats.sigma]));
    }
    patchedInputs.push(concat(patchValues));
    patchedMasks.push(concatUint8(patchMasks));
    lastStats.push({ ...stats });
  }

  // ---- Step 5: Apply RevIN normalization ----
  const normed = revinBatch(
    patchedInputs,
    contextMu,
    contextSigma,
    false, // forward normalization
    numPatches,
    inputPatchLen,
  );

  // Zero out padded positions (mask value 1) so padding carries no signal.
  for (let b = 0; b < batchSize; b++) {
    const mask = patchedMasks[b];
    const row = normed[b];
    for (let i = 0; i < row.length; i++) {
      if (mask[i] === 1) {
        row[i] = 0;
      }
    }
  }

  return {
    patchedInputs: normed,
    patchedMasks,
    contextMu,
    contextSigma,
    lastStats,
    numPatches,
    batchSize,
    cleanedInputs,
    truncatedInputs,
  };
}
|