Table of Contents
- Introduction
- Prerequisites
- Project Setup
- Architecture Overview
- Exporting Models to ONNX
- Loading an ONNX Model
- Text Generation Pipeline
- Building the CLI Chat Interface
- Going Further
- Deployment Considerations
- Conclusion
- Frequently Asked Questions (FAQ)
TLDR
- The full source code and a quick README are on GitHub: ndemir/rust-onnx-chat
Introduction
ONNX (Open Neural Network Exchange) provides a framework-agnostic way to serialize machine-learning models. This means we can train a model in PyTorch or TensorFlow and run it anywhere ONNX Runtime is supported, including Rust.
What You’ll Build
In this tutorial we will build a terminal-based chat bot that:
- Loads a Transformer model in ONNX format
- Tokenizes user input
- Generates replies token-by-token
- Streams the response back to the console
Key Learning Outcomes
During this journey, you’ll master:
- Model Export: Converting HuggingFace models to ONNX format
- Rust Integration: Loading and running ONNX models in Rust
- Text Generation: Implementing token-by-token text generation
- Advanced Topics: Conversation history, temperature sampling, and GPU acceleration
Prerequisites
Since we are writing Rust, we naturally need a Rust toolchain. We also need some basic ML knowledge, because we will implement tokenization and sampling ourselves.
What | Why |
---|---|
Rust 1.70+ | std::io::IsTerminal is stabilized |
ONNX Runtime library | Actual model inference |
Basic ML knowledge | Understand tokenization & sampling |
Install Rust with rustup and make sure cargo is on your $PATH.
Project Setup
Setting up the project should be straightforward if you have some Rust experience. If you don't, no worries; the steps below cover everything.
# Create the project
cargo new rust-onnx-chat
cd rust-onnx-chat
Cargo.toml
[package]
name = "onnx-chat"
version = "0.1.0"
edition = "2021"
[dependencies]
ort = "1.16" # ONNX Runtime
ndarray = "0.15" # Tensors
rand = "0.8" # Sampling
anyhow = "1.0" # Error handling
tokio = { version = "1", features = ["full"] }
minijinja = "2.0" # Simple prompt templating
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
tokenizers = "0.20" # HuggingFace tokenizers
Architecture Overview
A typical chat pipeline reads user input, tokenizes it into IDs, runs the model to predict the next token, and decodes the generated tokens back into text.
We’ll encapsulate the heavy lifting inside two modules:
- chat.rs – model loading, inference, and conversation handling
- main.rs – ergonomic CLI wrapper
Exporting Models to ONNX
Before we can use a model with ONNX Runtime, we need to export it from its original framework (PyTorch, TensorFlow, etc.) to ONNX format. This section covers how to convert a HuggingFace model to ONNX.
Step-by-Step: Export HuggingFace Models to ONNX
Step 1: Install Required Dependencies
pip install optimum[onnxruntime]
Step 2: Export the Model
optimum-cli export onnx --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --task text-generation-with-past tinyllama_onnx/
Step 3: Verify Export Success
After export, you’ll have these files:
- model.onnx – The actual neural network (typically 1-5 GB)
- tokenizer.json – HuggingFace tokenizer configuration
- config.json – Model metadata and parameters
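If you want to sanity-check the export from Rust before wiring anything else up, a minimal sketch works (the tinyllama_onnx directory name matches the export command above; the helper is purely illustrative):
use std::path::Path;

/// Illustrative helper: returns true if the three files produced by
/// `optimum-cli export onnx` are present in the given directory.
fn export_looks_complete(dir: &Path) -> bool {
    ["model.onnx", "tokenizer.json", "config.json"]
        .iter()
        .all(|name| dir.join(name).exists())
}

fn main() {
    let ok = export_looks_complete(Path::new("tinyllama_onnx"));
    println!("export complete: {ok}");
}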
Supported Models and Tasks
Model Type | Export Command | Use Case |
---|---|---|
GPT-2/GPT-J | --task text-generation-with-past | Text generation, chat bots |
BERT | --task feature-extraction | Text classification, embeddings |
T5 | --task text2text-generation-with-past | Translation, summarization |
LLaMA | --task text-generation-with-past | Large language model tasks |
What happens during export?
Optimum uses the Torch ONNX exporter internally. The specific implementation can be found in the codebase here: Optimum ONNX Convert Implementation
The export process (PyTorch ONNX Documentation):
- Traces the PyTorch model: “executes the model once with the given args and records all operations that happen during that execution”
- Converts the computational graph to ONNX format: “exports the traced model to the specified file”
- Applies optimizations like constant folding when enabled (optional)
Directory Structure
After export, you’ll have:
tinyllama_onnx/
├── model.onnx # The neural network
├── tokenizer.json # Tokenizer configuration
└── config.json # Model metadata
Loading an ONNX Model
There are two main calls involved in loading the model and the tokenizer:
- SessionBuilder::new(&env).with_model_from_file(model_path) loads the model
- tokenizers::Tokenizer::from_file(tokenizer_path) loads the tokenizer
We will implement a ChatModel struct that wraps these two calls and exposes helper methods for managing them.
use ort::{Environment, SessionBuilder};
use std::path::Path;

pub struct ChatModel {
    session: ort::Session,
    tokenizer: tokenizers::Tokenizer,
}

impl ChatModel {
    pub fn new(model_path: &Path, tokenizer_path: &Path) -> anyhow::Result<Self> {
        // 1️⃣ Create an ONNX Runtime environment (ort 1.x sessions expect an Arc'd environment)
        let env = Environment::builder()
            .with_name("chat_model")
            .build()?
            .into_arc();

        // 2️⃣ Load the model into a session
        let session = SessionBuilder::new(&env)?
            .with_model_from_file(model_path)?;

        // 3️⃣ Load the HuggingFace tokenizer
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;

        Ok(Self { session, tokenizer })
    }
}
In summary, ChatModel:
- Encapsulates the model, tokenizer, and helper methods
- Keeps the public API clean (generate, generate_streaming, etc.)
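As a quick illustration, construction might look like this (a minimal sketch; the tinyllama_onnx/ paths assume the export directory from earlier):
use std::path::Path;

fn main() -> anyhow::Result<()> {
    // Paths follow the directory produced by the optimum-cli export above.
    let model = ChatModel::new(
        Path::new("tinyllama_onnx/model.onnx"),
        Path::new("tinyllama_onnx/tokenizer.json"),
    )?;
    // `model` is now ready for the generation methods described next.
    Ok(())
}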
Text Generation Pipeline
Once we have loaded the model and the tokenizer, we need a way to generate text based on the given inputs. The high-level steps are:
- Encode the prompt into input IDs
- For each step:
  - Create input tensors (input_ids, attention_mask, position_ids)
  - Run the model
  - Sample a token (argmax, temperature, or nucleus sampling)
  - Append the token to the sequence
- Decode newly-generated tokens back to UTF-8 text
impl ChatModel {
    pub fn generate(&self, prompt: &str, max_len: usize) -> anyhow::Result<String> {
        let encoding = self.tokenizer.encode(prompt, false)
            .map_err(|e| anyhow::anyhow!("Failed to encode prompt: {e}"))?;
        let mut ids = encoding.get_ids().to_vec();

        for _ in 0..max_len {
            // The model expects i64 token ids
            let input_ids: Vec<i64> = ids.iter().map(|&id| id as i64).collect();
            // Build the input tensors and run one forward pass; this helper
            // (not shown here) returns the logits for the last position.
            let logits = self.run_inference(&input_ids)?;
            let next_token = self.argmax(&logits); // or temperature / top-p sampling
            if self.is_eos(next_token) { break; }
            ids.push(next_token);
        }

        // Decode only the newly generated tokens, skipping the prompt
        let text = self.tokenizer.decode(&ids[encoding.len()..], true)
            .map_err(|e| anyhow::anyhow!("Failed to decode tokens: {e}"))?;
        Ok(text)
    }
}
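generate leans on a few small helpers. The inference call depends on the ort tensor API and is omitted here, but argmax and is_eos are plain Rust. The sketch below assumes an eos_token_id field on ChatModel, which in practice would be read from config.json at load time:
impl ChatModel {
    /// Pick the index of the largest logit, i.e. the most likely next token.
    fn argmax(&self, logits: &[f32]) -> u32 {
        let mut best = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in logits.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best = i;
            }
        }
        best as u32
    }

    /// True when the model emits its end-of-sequence token.
    /// Assumes an `eos_token_id: u32` field on the struct (taken from config.json).
    fn is_eos(&self, token: u32) -> bool {
        token == self.eos_token_id
    }
}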
Building the CLI Chat Interface
The CLI lives in src/main.rs and acts as a thin proxy that relays all heavy work to ChatBot:
mod chat;

use std::io::{IsTerminal, Read, Write};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let model_name = std::env::args().nth(1).unwrap_or_else(|| "tinyllama".into());
    let mut bot = chat::ChatBot::new(&model_name).await?;

    if std::io::stdin().is_terminal() {
        // Interactive REPL
        println!("🤖 ONNX Chat Bot — type 'quit' to exit");
        loop {
            print!("\n> ");
            std::io::stdout().flush()?;
            let mut input = String::new();
            std::io::stdin().read_line(&mut input)?;
            if input.trim().eq_ignore_ascii_case("quit") { break; }
            let answer = bot.generate_response(input.trim()).await?;
            println!("🤖 {answer}");
        }
    } else {
        // Piped input
        let mut buf = String::new();
        std::io::stdin().read_to_string(&mut buf)?;
        if !buf.trim().is_empty() {
            println!("{}", bot.generate_response(buf.trim()).await?);
        }
    }
    Ok(())
}
Going Further
What we have so far lays the foundation. It could already power a simple production setup, but there is always room for improvement. Let's look at a few directions.
1. Conversation Memory
Store the last n (user, assistant) pairs so replies remain contextual.
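A minimal sketch of such a memory, with illustrative names that are not part of the ChatBot shown earlier:
/// Keeps the most recent (user, assistant) exchanges.
struct ConversationHistory {
    turns: Vec<(String, String)>,
    max_turns: usize,
}

impl ConversationHistory {
    fn push(&mut self, user: String, assistant: String) {
        self.turns.push((user, assistant));
        // Drop the oldest exchanges once we exceed the window.
        while self.turns.len() > self.max_turns {
            self.turns.remove(0);
        }
    }

    /// Flatten the history into a prompt prefix for the next generation call.
    fn as_prompt(&self) -> String {
        self.turns
            .iter()
            .map(|(u, a)| format!("User: {u}\nAssistant: {a}\n"))
            .collect()
    }
}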
2. Temperature & Top-p Sampling
Argmax is deterministic but boring. For creative replies, divide the logits by a temperature and sample only from the top-p cumulative probability mass, as sketched below.
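Here is a minimal, self-contained sketch of combined temperature and top-p sampling over a raw logits slice. It is not tied to the ChatModel API above and uses the rand crate already listed in Cargo.toml:
use rand::Rng;

/// Sample a token id from raw logits using temperature + nucleus (top-p) sampling.
fn sample_top_p(logits: &[f32], temperature: f32, top_p: f32) -> u32 {
    // 1. Temperature-scale the logits and softmax them into probabilities.
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut probs: Vec<(usize, f32)> =
        exps.iter().enumerate().map(|(i, &e)| (i, e / sum)).collect();

    // 2. Sort descending and keep the smallest prefix whose mass reaches top_p.
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    let mut cumulative = 0.0;
    let mut cutoff = probs.len();
    for (i, &(_, p)) in probs.iter().enumerate() {
        cumulative += p;
        if cumulative >= top_p {
            cutoff = i + 1;
            break;
        }
    }
    probs.truncate(cutoff);

    // 3. Renormalize the kept tokens and draw one proportionally to its probability.
    let total: f32 = probs.iter().map(|&(_, p)| p).sum();
    let mut r = rand::thread_rng().gen::<f32>() * total;
    for &(id, p) in &probs {
        if r <= p {
            return id as u32;
        }
        r -= p;
    }
    probs.last().map(|&(id, _)| id as u32).unwrap_or(0)
}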
3. Streaming Tokens
Render tokens as soon as they are produced for a snappier UX:
pub fn generate_streaming<F>(&self, prompt: &str, cb: F) -> anyhow::Result<()> where F: FnMut(&str) { /* ... */ }
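A caller might stream straight to the terminal like this (illustrative usage; assumes a ChatModel value named model and the signature above):
use std::io::Write;

fn stream_to_stdout(model: &ChatModel, prompt: &str) -> anyhow::Result<()> {
    model.generate_streaming(prompt, |token| {
        // Print each token as soon as it is produced and flush immediately.
        print!("{token}");
        let _ = std::io::stdout().flush();
    })
}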
4. Performance Optimizations
- Batching multiple sequences
- KV-Cache for Transformers
- Quantization (int8) with ONNX Runtime tooling
- GPU providers (CUDA, DirectML, ROCm); see the sketch below
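As an example of the last point, ort registers execution providers when the environment and session are built. The exact names differ between ort releases; this sketch assumes the 1.16-style ExecutionProvider enum and that the crate's cuda feature is enabled in Cargo.toml:
use ort::{Environment, ExecutionProvider, SessionBuilder};

// Assumes ort 1.16; the execution-provider API changed in ort 2.x.
fn build_gpu_session(model_path: &std::path::Path) -> anyhow::Result<ort::Session> {
    let env = Environment::builder()
        .with_name("chat_model")
        // Prefer CUDA when available; ONNX Runtime falls back to CPU otherwise.
        .with_execution_providers([ExecutionProvider::CUDA(Default::default())])
        .build()?
        .into_arc();

    let session = SessionBuilder::new(&env)?
        .with_model_from_file(model_path)?;
    Ok(session)
}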
Deployment Considerations
Depending on what you need and where your model will run, there are a few topics to consider. The following table lists them with recommended approaches.
Concern | Recommendation |
---|---|
Model size | Quantize or distill large models |
Memory | Monitor RSS; use streaming outputs |
Concurrency | Multi-threaded Tokio or Actix-web API |
Device | CPU by default, enable CUDA for speed-ups |
Conclusion
Pairing ONNX Runtime with Rust lets you deploy high-performance and memory-safe ML applications. Once the core pipeline is in place you can iterate rapidly and experiment with larger models, richer prompts, or even expose the bot as a REST or gRPC service.
Frequently Asked Questions (FAQ)
What is ONNX and why should I use it for chat bots?
ONNX (Open Neural Network Exchange) provides a framework-agnostic way to serialize machine-learning models. This means you can train a model in PyTorch or TensorFlow and run it anywhere ONNX Runtime is supported, including Rust. This gives you the flexibility to use the best training frameworks while deploying with high-performance runtime environments.
What are the prerequisites for building an ONNX chat bot with Rust?
You need Rust 1.70+ (for std::io::IsTerminal), ONNX Runtime library for model inference, and basic ML knowledge to understand tokenization and sampling. Install Rust with rustup and ensure cargo is on your PATH.
How do I export a HuggingFace model to ONNX format?
First install optimum: pip install optimum[onnxruntime]. Then run the export command: optimum-cli export onnx --model MODEL_NAME --task text-generation-with-past output_dir/. This works for most HuggingFace models including GPT-2, GPT-J, LLaMA, and T5.
What files are created when exporting a model to ONNX?
The export process creates three essential files: model.onnx (the actual neural network, typically 1-5GB), tokenizer.json (HuggingFace tokenizer configuration), and config.json (model metadata and parameters).
How does the text generation pipeline work?
The pipeline follows these steps: 1) Encode the prompt into input IDs using the tokenizer, 2) For each generation step, create input tensors and run the model, 3) Sample a token using argmax, temperature, or nucleus sampling, 4) Append the token to the sequence, 5) Decode newly-generated tokens back to UTF-8 text.
What performance optimizations can I implement?
Key optimizations include: batching multiple sequences together, implementing KV-Cache for Transformers, using quantization (int8) with ONNX Runtime tooling, and enabling GPU providers like CUDA, DirectML, or ROCm for hardware acceleration.
How do I add conversation memory to my chat bot?
Store the last n (user, assistant) pairs so replies remain contextual. This allows the bot to reference previous parts of the conversation and provide more coherent, context-aware responses.
What are the main deployment considerations?
Consider model size (quantize or distill large models), memory usage (monitor RSS and use streaming outputs), concurrency (use multi-threaded Tokio or Actix-web API), and target device (CPU by default, enable CUDA for speed-ups).
How do I implement streaming token generation?
Use a callback function approach with generate_streaming method that renders tokens as soon as they are produced. This provides a snappier user experience by showing partial responses in real-time rather than waiting for the complete response.
What sampling techniques can I use instead of argmax?
For more creative replies, you can implement temperature sampling (divide logits by temperature) and top-p sampling (sample only from the top-p cumulative probability mass). Argmax is deterministic but boring, while these techniques add controlled randomness to responses.