# JAX — High-Performance Numerical Computing by Google

> A Python library for high-performance machine learning research combining NumPy-like syntax with automatic differentiation, XLA compilation, and hardware acceleration on GPU and TPU.

## Quick Use

```bash
pip install "jax[cuda12]"
python -c "import jax.numpy as jnp; x = jnp.ones((3, 3)); print(jnp.dot(x, x))"
```

## Introduction

JAX is a numerical computing library from Google that combines a NumPy-compatible API with composable function transformations for automatic differentiation, vectorization, JIT compilation via XLA, and parallelization. It is the foundation for many Google AI research projects and large-scale model training runs.

## What JAX Does

- Provides a NumPy-compatible API (`jax.numpy`) that runs on CPU, GPU, and TPU
- Computes gradients automatically with `jax.grad` for scalar-valued Python functions written with JAX operations
- Compiles functions to optimized machine code via XLA with `jax.jit`
- Vectorizes operations across batch dimensions with `jax.vmap`
- Parallelizes computation across multiple devices with `jax.pmap` and sharding APIs

## Architecture Overview

JAX traces Python functions into an intermediate representation called jaxpr, then lowers it to StableHLO, which XLA compiles and executes. During tracing, arguments are replaced by tracer objects that carry only abstract values (shapes and dtypes), so the computation graph is captured without running it on concrete data. Transformations such as `grad`, `jit`, `vmap`, and `pmap` compose functionally, so they can be nested in any combination. Device memory management is handled by the XLA runtime.
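The transformations described above can be combined in a few lines. The following is a minimal sketch assuming a toy linear model; the `loss` function and the data are hypothetical illustrations, not taken from the original text. It differentiates a scalar loss with `jax.grad`, maps that gradient over a batch with `jax.vmap`, compiles the composed function with `jax.jit`, and prints the traced jaxpr with `jax.make_jaxpr`.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Hypothetical squared error of a linear model for one example (scalar output)."""
    pred = jnp.dot(x, w)
    return (pred - y) ** 2


# grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# vmap maps over the leading axis of x and y while sharing w (in_axes=None),
# giving one gradient per example.
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit compiles the whole composed function with XLA.
fast_batched_grads = jax.jit(batched_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_batched_grads(w, xs, ys)
print(g)  # one gradient row per example, shape (3, 2)

# make_jaxpr shows the jaxpr that tracing produces for a single example.
print(jax.make_jaxpr(loss)(w, xs[0], ys[0]))
```

Because each transformation takes a function and returns a function, the nesting order is a design choice: `jax.grad` inside `jax.vmap` yields per-example gradients, whereas `jax.grad` of a batch-mean loss would yield a single averaged gradient.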
## Self-Hosting & Configuration

- Install via pip: `pip install "jax[cuda12]"` for NVIDIA GPUs or `pip install jax` for CPU
- TPU support is available on Google Cloud with `pip install "jax[tpu]"`
- Set `JAX_PLATFORMS=cpu` (or the older `JAX_PLATFORM_NAME=cpu`) to force CPU execution for debugging
- Set `XLA_PYTHON_CLIENT_PREALLOCATE=false` to stop JAX from preallocating 75% of GPU memory on first use, which helps when several processes share one GPU
- Use `jax.devices()` to inspect available accelerators at runtime

## Key Features

- Composable transformations: `jit`, `grad`, `vmap`, and `pmap` can be freely nested
- XLA compilation fuses operations into optimized kernels and, for many array workloads, approaches the performance of hand-tuned code
- Functional programming model with immutable arrays and explicit PRNG keys makes results easier to reproduce
- Native multi-device and multi-host parallelism for distributed training
- Foundation for Flax, Optax, Orbax, and the broader JAX ecosystem

## Comparison with Similar Tools

- **PyTorch** — Imperative and easier to debug; JAX is functional and built around whole-program compilation with XLA
- **TensorFlow** — Larger ecosystem but heavier; JAX is leaner with composable transforms
- **NumPy** — The API that `jax.numpy` mirrors, but without GPU support, autodiff, or JIT compilation
- **CuPy** — GPU-accelerated NumPy-compatible arrays without built-in autodiff or XLA-style whole-program compilation
- **Triton (OpenAI)** — Lower-level GPU kernel programming; JAX operates at the array level

## FAQ

**Q: Should I use JAX or PyTorch?**
A: JAX suits functional programming and research that needs composable transformations. PyTorch is more flexible for imperative-style development and has a larger third-party ecosystem.

**Q: Does JAX work on Apple Silicon?**
A: JAX supports CPU on Apple Silicon. GPU acceleration through the Metal plugin (`jax-metal`) is experimental and not officially supported.

**Q: What is the relationship between JAX and Flax?**
A: Flax is a neural network library built on JAX. JAX provides the numerical primitives and transformations; Flax adds neural network modules and training utilities, and is typically paired with Optax for optimizers.

**Q: Can JAX run on multiple GPUs?**
A: Yes. Use `jax.pmap` for data parallelism, or the newer sharding API (`jax.sharding` used with `jax.jit`) for flexible multi-device partitioning.
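A minimal sketch of the sharding approach from the last answer, assuming a 1-D device mesh over whatever `jax.devices()` reports (it also runs on a single CPU device). The mesh axis name `"data"`, the `normalize` function, and the batch shape are hypothetical choices for illustration. It also sets `XLA_PYTHON_CLIENT_PREALLOCATE` from the configuration list, before importing `jax`, so the setting is in place before the GPU backend initializes.

```python
import os

# Disable GPU memory preallocation; must be in place before JAX initializes
# its GPU backend. Harmless on CPU-only machines.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device; "data" is an arbitrary axis label.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
n = len(devices)

# Hypothetical batch whose row count is divisible by the number of devices.
batch = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)

# Split rows across the "data" mesh axis; replicate the column dimension.
sharding = NamedSharding(mesh, P("data", None))
sharded_batch = jax.device_put(batch, sharding)


@jax.jit
def normalize(x):
    # Subtract the per-column mean. Under jit, XLA inserts whatever
    # cross-device communication the global mean requires.
    return x - x.mean(axis=0)


out = normalize(sharded_batch)
print(jax.devices())
print(out.sharding)
```

The program is written once for the whole array; the sharding annotation, not the function body, decides how the work is split across devices.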
## Sources

- https://github.com/jax-ml/jax
- https://jax.readthedocs.io/

---

Source: https://tokrepo.com/en/workflows/7abd1277-3c92-11f1-9bc6-00163e2b0d79

Author: Script Depot