A neural network compiler and runtime that runs entirely in the browser. A compiler core written in ReScript generates optimized WGSL compute shaders. WebGPU executes them on your GPU. Autograd handles backpropagation. No server required.
160+ ops defined as algebraic types, each carrying shape, dtype, and graph information
Shape inference with broadcasting, reduction, and convolution rules
3500+ lines of WGSL shader generation covering every op
Graph compilation, buffer allocation, dispatch scheduling
WebGPU execution, pipeline caching, buffer management
Unary, binary, trigonometric, and activation functions — all compiled to parallel WGSL shaders with broadcasting.
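Broadcasting follows the standard NumPy-style rule: shapes are aligned from the right, and size-1 dimensions stretch to match the other operand. A plain-JS sketch of the shape rule (the helper name is illustrative, not part of the library API):

// Standard broadcasting: align shapes from the right; dimensions must match
// or be 1, and size-1 dimensions stretch to the other operand's size.
function broadcastShapes(a, b) {
  const rank = Math.max(a.length, b.length);
  const out = [];
  for (let i = 1; i <= rank; i++) {
    const da = a[a.length - i] ?? 1;
    const db = b[b.length - i] ?? 1;
    if (da !== db && da !== 1 && db !== 1) throw new Error('shapes not broadcastable');
    out.unshift(Math.max(da, db));
  }
  return out;
}

broadcastShapes([32, 16], [16]);     // → [32, 16]  (bias added to a batch)
broadcastShapes([8, 1, 64], [4, 1]); // → [8, 4, 64]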
Tiled matrix multiplication, batched matmul, INT4 quantized matmul with configurable group size, Gemm, Einsum.
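In GPTQ-style INT4 quantization, weights are stored as 4-bit integers plus a scale and zero point per group of weights, and are typically dequantized on the fly inside the matmul shader. A plain-JS sketch of that dequantization (field names and the group size of 128 are illustrative):

// GPTQ-style group-wise dequantization: w[i] = (q[i] - zero[g]) * scale[g],
// where g = floor(i / groupSize).
function dequantizeColumn(qvals, scales, zeros, groupSize = 128) {
  return qvals.map((q, i) => {
    const g = Math.floor(i / groupSize);
    return (q - zeros[g]) * scales[g];
  });
}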
Sum, Mean, Max, Min, Prod, L1/L2 norm, LogSumExp — with arbitrary axis selection and keepDims.
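Reducing over an axis removes it from the output shape unless keepDims retains it as a length-1 dimension. A shape-level sketch (the helper name is illustrative):

// Shape of a reduction: drop each reduced axis, or keep it as 1 with keepDims.
function reducedShape(shape, axes, keepDims = false) {
  return shape
    .map((d, i) => (axes.includes(i) ? (keepDims ? 1 : null) : d))
    .filter((d) => d !== null);
}

reducedShape([8, 128, 64], [1]);        // → [8, 64]
reducedShape([8, 128, 64], [1], true);  // → [8, 1, 64]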
BatchNorm, LayerNorm, RMSNorm, InstanceNorm, GroupNorm, LRN — critical for stable training and inference.
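These layers share one core computation: subtract a mean, divide by the square root of a variance plus epsilon, then scale and shift; they differ mainly in which axes the statistics are taken over. A CPU reference for LayerNorm over a single row (for illustration only, not the shader):

// Reference LayerNorm over the last axis:
// y = (x - mean) / sqrt(var + eps) * gamma + beta
function layerNormRow(x, gamma, beta, eps = 1e-5) {
  const mean = x.reduce((a, b) => a + b, 0) / x.length;
  const variance = x.reduce((a, b) => a + (b - mean) ** 2, 0) / x.length;
  const inv = 1 / Math.sqrt(variance + eps);
  return x.map((v, i) => (v - mean) * inv * gamma[i] + beta[i]);
}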
1D/2D/3D convolutions, transposed convolutions, depthwise separable, max/avg/global/adaptive pooling.
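Output spatial sizes follow the standard convolution arithmetic, which is presumably what the shape-inference rules encode; for one spatial dimension:

// Standard output-size formula for a convolution along one dimension.
function convOutDim(inDim, kernel, stride = 1, pad = 0, dilation = 1) {
  return Math.floor((inDim + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1;
}

convOutDim(224, 3, 1, 1); // → 224 (3x3 "same" convolution)
convOutDim(224, 3, 2, 1); // → 112 (strided downsampling)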
Scaled dot-product attention with causal masking, multi-head attention, RNN, LSTM, GRU cells.
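Scaled dot-product attention computes softmax(QKᵀ/√d)·V, with the causal mask forcing scores for future positions to −∞ so each token attends only to earlier ones. A single-head CPU reference (illustrative, not the WGSL shader):

// Reference scaled dot-product attention for one head.
// q, k, v: arrays of shape [seqLen][headDim].
const dot = (a, b) => a.reduce((s, x, i) => s + x * b[i], 0);

function softmax(xs) {
  const m = Math.max(...xs);
  const exps = xs.map((x) => Math.exp(x - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

function attention(q, k, v, causal = true) {
  const scale = 1 / Math.sqrt(q[0].length);
  return q.map((qi, i) => {
    // Causal masking: scores for positions j > i become -Infinity before softmax.
    const weights = softmax(
      k.map((kj, j) => (causal && j > i ? -Infinity : dot(qi, kj) * scale))
    );
    // Output row i is the weighted combination of the value rows.
    return v[0].map((_, c) => weights.reduce((s, w, j) => s + w * v[j][c], 0));
  });
}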
Full automatic differentiation with a gradient tape and backward kernel generation for every differentiable op. Adam, AdamW, and SGD optimizers.
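The Adam update itself is the standard rule: exponential moving averages of the gradient and its square, bias correction, then a per-parameter scaled step. A plain-JS reference of that math (in this project the update runs as a WGSL shader; the names below are illustrative):

// Standard Adam update for one parameter tensor (CPU reference).
// m, v: per-parameter moment buffers; t: 1-based step count.
function adamStep(param, grad, m, v, t, lr = 1e-3, b1 = 0.9, b2 = 0.999, eps = 1e-8) {
  for (let i = 0; i < param.length; i++) {
    m[i] = b1 * m[i] + (1 - b1) * grad[i];            // first-moment EMA
    v[i] = b2 * v[i] + (1 - b2) * grad[i] * grad[i];  // second-moment EMA
    const mHat = m[i] / (1 - b1 ** t);                // bias correction
    const vHat = v[i] / (1 - b2 ** t);
    param[i] -= lr * mHat / (Math.sqrt(vHat) + eps);
  }
}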
PyTorch-style module system — Sequential, Linear, Conv2D, LSTM, Embedding, BatchNorm, Dropout, with .forward() and .backward().
import { Tensor, nn, optim, init } from './nn.js';

await init(); // Initialize WebGPU

// Define model
const model = new nn.Sequential(
  new nn.Linear(2, 16),
  new nn.Tanh(),
  new nn.Linear(16, 1),
  new nn.Sigmoid()
);

const optimizer = new optim.Adam(model.parameters(), 0.01);

// Training loop (x and target are input/label tensors prepared beforehand)
for (let i = 0; i < 100; i++) {
  const y = await model.forward(x);
  const loss = await nn.mseLoss(y, target);
  await loss.backward();   // Autograd computes gradients on GPU
  await optimizer.step();  // Adam update via WGSL shader
}
// 1. Define operations as typed variants in ReScript
type op =
  | MatMul
  | ReLU
  | Softmax({ axis: int })
  | LayerNorm({ axes: array<int>, epsilon: float })
  | ... // 160+ operations

// 2. Shape inference validates dimensions at compile time
Shape.infer(MatMul, [[32, 128], [128, 64]])
// → Some([32, 64])

// 3. Codegen emits optimized WGSL compute shaders
Codegen.generate(MatMul, [[32, 128], [128, 64]])
// → { wgsl: "@compute @workgroup_size(16,16)...", dispatch: [2, 4, 1] }

// 4. Runtime executes on GPU via WebGPU API
Type-safe compiler core
GPU compute shaders
Browser GPU runtime
JS bundling
GPTQ model support
LLM inference target