All files / timesfm-core/src preprocessor.ts

100% Statements 85/85
100% Branches 12/12
100% Functions 1/1
100% Lines 85/85

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169                                  1x 1x 1x 1x                                                                                   1x 253x 253x 253x 253x 253x 253x 253x     253x     253x 253x 253x   253x 344x 344x 344x   344x 344x 344x 344x     253x 253x   253x 344x 344x     344x 344x   344x 2508x 2508x 2508x 2508x   344x 344x 344x     253x 253x 253x   253x 344x 344x 344x   344x 2508x 2508x 2508x   2508x 2508x   2508x 2508x 2508x   344x 344x     253x 253x 253x 253x 253x 253x 253x 253x     253x 344x 344x 80256x 34194x 34194x 80256x 344x   253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x 253x  
/**
 * Complete data preprocessing pipeline for TimesFM.
 *
 * Takes raw user-provided time series and produces the patched, padded,
 * normalized tensors required by the model's forward pass.
 *
 * Pipeline:
 *   1. Clean each series (trailing NaN → leading NaN → interpolate internal)
 *   2. Pad/truncate each series to `maxContext` length
 *   3. Split into patches of `inputPatchLen`
 *   4. Compute per-patch running statistics (RevIN μ, σ)
 *   5. Apply RevIN normalization
 *   6. Zero out padded positions in the normalized patches
 *
 * Mirrors the logic in `TimesFM_2p5.forecast()` and `decode()`.
 */
 
import type { ForecastConfig, ModelConfig } from './types';
import { cleanSeries } from './utils/nan-handler';
import { createRunningStats, updateRunningStats, type RunningStats } from './utils/stats';
import { revinBatch } from './utils/revin';
import { leftPad, concat, concatUint8 } from './utils/tensor-utils';
 
// ---------------------------------------------------------------------------
// Result types
// ---------------------------------------------------------------------------
 
/**
 * Output of the full preprocessing pipeline — ready for model inference.
 *
 * Per-series arrays are indexed by batch position, in the same order as the
 * `inputs` passed to `preprocess`.
 */
export interface PreprocessedData {
  /**
   * Model-ready inputs: one flat array per batch entry of length
   * `numPatches * inputPatchLen` (patches laid out back to back). Values are
   * RevIN-normalized and padded positions are set to 0.
   */
  patchedInputs: Float32Array[];
  /**
   * Padding mask aligned element-for-element with `patchedInputs`: one flat
   * array per batch entry of length `numPatches * inputPatchLen`. A value of 1
   * marks a padded position.
   */
  patchedMasks: Uint8Array[];
  /**
   * Running means used for RevIN, needed later to reverse the normalization.
   * Holds `batchSize * numPatches` single-element arrays in batch-major order:
   * entry `b * numPatches + p` is the mean accumulated through patch `p` of
   * series `b`.
   */
  contextMu: Float32Array[];
  /** Running standard deviations, laid out exactly like `contextMu`. */
  contextSigma: Float32Array[];
  /** Running statistics of each series after its last patch, one per batch entry. */
  lastStats: RunningStats[];
  /** Number of patches per series: `floor(maxContext / inputPatchLen)`. */
  numPatches: number;
  /** Number of input series. */
  batchSize: number;
  /** The inputs after NaN cleaning, before any padding or truncation. */
  cleanedInputs: Float32Array[];
  /**
   * The cleaned inputs truncated to their last `maxContext` points (before padding).
   * @internal Reserved for future iterative refinement APIs.
   */
  truncatedInputs: Float32Array[];
}
 
// ---------------------------------------------------------------------------
// Main pipeline
// ---------------------------------------------------------------------------
 
/**
 * Run the full preprocessing pipeline on a batch of raw time series.
 *
 * Each series is cleaned with `cleanSeries`, then left-padded/truncated to
 * the patch-aligned context length and split into `inputPatchLen`-sized
 * patches. Running statistics are accumulated patch by patch, the patches
 * are RevIN-normalized with those statistics, and padded positions are zeroed.
 *
 * The context fed to the model is `numPatches * inputPatchLen` points long,
 * where `numPatches = floor(maxContext / inputPatchLen)`. When `maxContext` is
 * a multiple of `inputPatchLen` this is exactly `maxContext`. Otherwise the
 * series is fitted to the shorter aligned length, so the last patch always
 * ends at the most recent observation rather than the newest points being
 * cut off.
 *
 * @param inputs  Raw 1-D time series (any length, may contain NaN).
 * @param fc      Forecast configuration (controls maxContext).
 * @param mc      Model architecture config (controls patch sizes).
 * @returns       Normalized patched inputs, masks, and the statistics needed
 *                to reverse the normalization.
 * @throws RangeError if `mc.inputPatchLen` is not a positive integer.
 */
export function preprocess(
  inputs: Float32Array[],
  fc: ForecastConfig,
  mc: ModelConfig,
): PreprocessedData {
  const batchSize = inputs.length;
  const { inputPatchLen } = mc;
  if (!Number.isInteger(inputPatchLen) || inputPatchLen <= 0) {
    // A zero patch length would make numPatches Infinity and never terminate.
    throw new RangeError(`inputPatchLen must be a positive integer, got ${inputPatchLen}`);
  }
  const numPatches = Math.floor(fc.maxContext / inputPatchLen);
  // Number of points actually covered by the patches (<= maxContext).
  const contextLen = numPatches * inputPatchLen;

  // ---- Step 1: Clean each series ----
  const cleanedInputs = inputs.map((s) => cleanSeries(s));

  const patchedInputs: Float32Array[] = [];
  const patchedMasks: Uint8Array[] = [];
  const truncatedInputs: Float32Array[] = [];
  const contextMu: Float32Array[] = [];
  const contextSigma: Float32Array[] = [];
  const lastStats: RunningStats[] = [];

  for (const series of cleanedInputs) {
    // ---- Step 2: Pad/truncate to the patch-aligned context length ----
    // Padding to contextLen (not maxContext) keeps the patches aligned with
    // the right edge of the series. Reading patches from offset 0 of a
    // maxContext-long buffer would drop the newest points whenever maxContext
    // is not a multiple of inputPatchLen.
    // NOTE(review): assumes leftPad keeps the most recent values right-aligned,
    // as the truncatedInputs bookkeeping below implies.
    const { padded, mask } = leftPad(series, contextLen);
    // Record the truncated version (last maxContext points)
    truncatedInputs.push(
      series.length > fc.maxContext ? series.slice(series.length - fc.maxContext) : series,
    );

    // ---- Step 3: Patch layout ----
    // Patches are contiguous, so the flat patched layout is the padded series
    // itself. Take copies so in-place work downstream never aliases the
    // cleaned/raw inputs.
    patchedInputs.push(padded.slice(0, contextLen));
    patchedMasks.push(mask.slice(0, contextLen));

    // ---- Step 4: Per-patch running statistics ----
    // Views avoid a copy per patch. The returned arrays were already copied
    // above, so these views are never exposed.
    let stats = createRunningStats();
    for (let p = 0; p < numPatches; p++) {
      const start = p * inputPatchLen;
      const end = start + inputPatchLen;
      const [updated] = updateRunningStats(
        stats,
        padded.subarray(start, end),
        mask.subarray(start, end),
      );
      stats = updated;
      // Batch-major layout: entry b * numPatches + p.
      contextMu.push(new Float32Array([stats.mu]));
      contextSigma.push(new Float32Array([stats.sigma]));
    }
    lastStats.push({ ...stats });
  }

  // ---- Step 5: Apply RevIN normalization ----
  const normed = revinBatch(
    patchedInputs,
    contextMu,
    contextSigma,
    false, // forward normalization
    numPatches,
    inputPatchLen,
  );

  // Zero out padded positions (mask value 1) so padding carries no signal.
  for (let b = 0; b < batchSize; b++) {
    const row = normed[b];
    const rowMask = patchedMasks[b];
    for (let i = 0; i < row.length; i++) {
      if (rowMask[i] === 1) {
        row[i] = 0;
      }
    }
  }

  return {
    patchedInputs: normed,
    patchedMasks,
    contextMu,
    contextSigma,
    lastStats,
    numPatches,
    batchSize,
    cleanedInputs,
    truncatedInputs,
  };
}