export_onnx
Utility to export a quantized torch model to quantized ONNX.
Functions
Name | Description
configure_linear_module_onnx_quantizers | Sets the ONNX export attributes for the given model.
export_fp4 | Export quantized model to FP4 ONNX.
export_fp8 | Export quantized model to FP8 ONNX.
export_fp8_mha | Export quantized fMHA to FP8 ONNX.
export_int8 | Export quantized model to INT8 ONNX.
export_mxfp8 | Export quantized model to MXFP8 ONNX.
scaled_dot_product_attention | Perform scaled dot product attention.
- configure_linear_module_onnx_quantizers(model)
Sets the ONNX export attributes for the quantizers of the linear modules in the given model.
- export_fp4(g, inputs, block_size, amax, num_bits, trt_high_precision_dtype, onnx_quantizer_type)
Export quantized model to FP4 ONNX.
- Parameters:
g (GraphContext)
inputs (Value)
block_size (int)
amax (Value)
num_bits (tuple[int, int])
trt_high_precision_dtype (str)
onnx_quantizer_type (str)
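A minimal pure-Python sketch of what block FP4 (E2M1) quantize/dequantize computes, under stated assumptions: each block of `block_size` values gets a scale that maps its amax onto 6.0 (the largest E2M1 magnitude), and values are rounded to the nearest E2M1 value with ties to even. It is not this module's implementation. `round_to_e2m1` and `fake_quant_fp4` are hypothetical names, `num_bits=(2, 1)` is assumed to mean (exponent bits, mantissa bits), and formats such as NVFP4 additionally store each block scale in FP8 together with a per-tensor scale, which the sketch leaves out.

```python
import math

# E2M1 (FP4) magnitudes; an even index corresponds to an even mantissa bit.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_to_e2m1(v: float) -> float:
    """Round to the nearest E2M1 value (ties to even), saturating at +/-6."""
    a = min(abs(v), 6.0)
    # Sort key: distance first, then prefer the even-mantissa neighbor on ties.
    best = min(range(len(E2M1_GRID)), key=lambda i: (abs(E2M1_GRID[i] - a), i % 2))
    return math.copysign(E2M1_GRID[best], v)


def fake_quant_fp4(values: list[float], block_size: int = 16) -> list[float]:
    """Hypothetical single-level block FP4 quantize/dequantize of a flat list."""
    out = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        scale = amax / 6.0 if amax > 0 else 1.0  # block amax maps to the E2M1 max
        out.extend(round_to_e2m1(v / scale) * scale for v in block)
    return out


print(fake_quant_fp4([6.0, -3.0, 1.4, 0.2], block_size=4))  # [6.0, -3.0, 1.5, 0.0]
```

Because every value in a block shares one scale, a single large outlier pushes the small values of its block toward zero; smaller blocks limit that effect at the cost of storing more scales.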
- export_fp8(g, inputs, amax, trt_high_precision_dtype)
Export quantized model to FP8 ONNX.
- Parameters:
g (GraphContext)
inputs (Value)
amax (float)
trt_high_precision_dtype (str)
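The arithmetic behind an FP8 Q/DQ pair can be sketched in pure Python, assuming a per-tensor scale of `amax / 448` (448 is the largest finite E4M3 value) and round-to-nearest-even with saturation. `round_to_e4m3` and `fake_quant_fp8` are hypothetical helpers, not part of this module, and the sketch ignores NaN handling.

```python
import math


def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(v), 448.0)
    if a == 0.0:
        return 0.0
    _, e = math.frexp(a)             # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)             # -6 is the smallest normal E4M3 exponent
    step = math.ldexp(1.0, exp - 3)  # spacing of values with 3 mantissa bits
    return math.copysign(round(a / step) * step, v)  # round() ties to even


def fake_quant_fp8(values: list[float], amax: float) -> list[float]:
    """Per-tensor quantize/dequantize: the scale maps amax onto 448."""
    scale = amax / 448.0
    return [round_to_e4m3(v / scale) * scale for v in values]


# -5.0 lies outside [-amax, amax], so it saturates to about -4.0.
print(fake_quant_fp8([1.0, 0.3, -5.0], amax=4.0))
```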
- export_fp8_mha(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, q_quantized_scale=1.0, k_quantized_scale=1.0, v_quantized_scale=1.0, high_precision_flag='Half', disable_fp8_mha=True)
Export quantized fMHA to FP8 ONNX.
FP8 ONNX graph:

    Q         K         V
    |         |         |
   QDQ       QDQ        |
    |         |         |
   Cast      Cast       |
      \      /          |
        BMM1            |
         |              |
        Cast           QDQ
         |              |
       SoftMax          |
         |              |
        QDQ             |
         |              |
        Cast           Cast
           \           /
               BMM2
                |
               Cast
- Parameters:
g (GraphContext)
query (Value)
key (Value)
value (Value)
attn_mask (Value | None)
dropout_p (float)
is_causal (bool)
scale (Value | None)
q_quantized_scale (float)
k_quantized_scale (float)
v_quantized_scale (float)
high_precision_flag (str)
disable_fp8_mha (bool)
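A single-head, pure-Python sketch of the FP8 attention graph, useful for seeing where quantization error enters. Assumptions not confirmed by this page: each `*_quantized_scale` plays the role of the amax for its FP8 Q/DQ pair, the softmax output is quantized with amax 1.0, `disable_fp8` mirrors `disable_fp8_mha`, and the Cast nodes are modeled as no-ops because Python floats stand in for the high-precision type. All function names are hypothetical.

```python
import math


def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(v), 448.0)
    if a == 0.0:
        return 0.0
    _, e = math.frexp(a)
    step = math.ldexp(1.0, max(e - 1, -6) - 3)
    return math.copysign(round(a / step) * step, v)


def fake_quant_fp8(values: list[float], amax: float) -> list[float]:
    """Hypothetical FP8 QDQ pair: the scale maps amax onto 448."""
    scale = amax / 448.0
    return [round_to_e4m3(v / scale) * scale for v in values]


def fp8_attention(q, k, v, q_amax, k_amax, v_amax, disable_fp8=False):
    """Single-head attention over nested lists: q is [L][E], k and v are [S][E]."""
    if not disable_fp8:
        # QDQ on the BMM1 inputs (Q, K) and on V before BMM2.
        q = [fake_quant_fp8(row, q_amax) for row in q]
        k = [fake_quant_fp8(row, k_amax) for row in k]
        v = [fake_quant_fp8(row, v_amax) for row in v]
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        # BMM1: scaled dot products of one query row with every key row.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # SoftMax, with the row max subtracted for numerical stability.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [x / total for x in exps]
        if not disable_fp8:
            # QDQ on the softmax output; probabilities lie in [0, 1], so amax = 1.0.
            weights = fake_quant_fp8(weights, 1.0)
        # BMM2: weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [4.0, 0.5]]
print(fp8_attention(q, k, v, 1.0, 1.0, 4.0, disable_fp8=True))
print(fp8_attention(q, k, v, 1.0, 1.0, 4.0))
```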
- export_int8(g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype)
Export quantized model to INT8 ONNX.
- Parameters:
g (GraphContext)
inputs (Value)
amax (Tensor)
num_bits (int)
unsigned (bool)
narrow_range (bool)
trt_high_precision_dtype (str)
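How `num_bits`, `unsigned`, and `narrow_range` shape the integer range can be sketched with symmetric quantization (zero point 0), assuming the scale is `amax` divided by the largest representable code and rounding is half-to-even as in ONNX `QuantizeLinear`. `quantize_int` and `dequantize_int` are hypothetical helpers, and applying `narrow_range` only to signed ranges is an assumption of the sketch.

```python
def quantize_int(values, amax, num_bits=8, unsigned=False, narrow_range=True):
    """Map floats to integer codes in the range implied by the flags."""
    if unsigned:
        qmin, qmax = 0, 2 ** num_bits - 1            # e.g. [0, 255]
    else:
        qmax = 2 ** (num_bits - 1) - 1               # e.g. 127
        qmin = -qmax if narrow_range else -qmax - 1  # [-127, 127] or [-128, 127]
    scale = amax / qmax                              # amax maps to the largest code
    # round() ties to even, like QuantizeLinear; then saturate to [qmin, qmax].
    codes = [min(max(round(v / scale), qmin), qmax) for v in values]
    return codes, scale


def dequantize_int(codes, scale):
    """DequantizeLinear with a zero point of 0."""
    return [c * scale for c in codes]


print(quantize_int([0.4, -200.0, 63.5, 64.5], 127.0))  # ([0, -127, 64, 64], 1.0)
```

Note that 63.5 and 64.5 both map to 64: ties go to the even code rather than away from zero.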
- export_mxfp8(g, inputs, onnx_quantizer_type, block_size, axis=-1)
Export quantized model to MXFP8 ONNX.
- Parameters:
g (GraphContext)
inputs (Tensor)
onnx_quantizer_type (str)
block_size (int)
axis (int)
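MXFP8, as defined in the OCP Microscaling (MX) specification, gives each block of `block_size` elements (32 in the spec) a shared power-of-two scale, and stores the elements themselves as FP8 E4M3. The pure-Python sketch below follows the spec's rule `shared_exp = floor(log2(block_amax)) - 8`, where 8 is the largest E4M3 exponent, and works on a flat list, i.e. the `axis=-1` case. The function names are hypothetical, the sketch skips the E8M0 exponent-range clamp, and it does not reflect how `export_mxfp8` builds its ONNX nodes.

```python
import math


def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(v), 448.0)
    if a == 0.0:
        return 0.0
    _, e = math.frexp(a)
    step = math.ldexp(1.0, max(e - 1, -6) - 3)
    return math.copysign(round(a / step) * step, v)


def mxfp8_fake_quant(values: list[float], block_size: int = 32):
    """Quantize/dequantize a flat list in blocks; returns (values, shared exponents)."""
    out, shared_exps = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            shared_exp = 0  # arbitrary choice for an all-zero block
        else:
            # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1.
            shared_exp = math.frexp(amax)[1] - 1 - 8
        scale = math.ldexp(1.0, shared_exp)  # power-of-two (E8M0-style) scale
        shared_exps.append(shared_exp)
        # Elements can land in [256, 512) after scaling, so the largest may saturate at 448.
        out.extend(round_to_e4m3(v / scale) * scale for v in block)
    return out, shared_exps


print(mxfp8_fake_quant([1.0, 0.5, 1000.0, 3.0], block_size=2))
# ([1.0, 0.5, 896.0, 3.0], [-8, 1])
```

The second block shows the saturation the spec allows for: 1000 / 2 = 500 exceeds 448, so it comes back as 896.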
- scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)
Perform scaled dot product attention.
- Parameters:
g (GraphContext)
query (Value)
key (Value)
value (Value)
attn_mask (Value | None)
dropout_p (float)
is_causal (bool)
scale (Value | None)
enable_gqa (bool)
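A reference sketch of scaled dot product attention in pure Python, covering the `attn_mask`, `is_causal`, `scale`, and `enable_gqa` parameters. It assumes PyTorch's conventions: a boolean mask where True means "attend", a default scale of `1/sqrt(E)`, a causal mask that lets query position i see key positions 0..i, and grouped-query attention in which consecutive query heads share one key/value head. Dropout is left out, as with `dropout_p=0.0`. `sdpa` is a hypothetical name; this is not the ONNX symbolic function itself.

```python
import math


def sdpa(query, key, value, attn_mask=None, is_causal=False, scale=None, enable_gqa=False):
    """query: [Hq][L][E]; key, value: [Hkv][S][E]; attn_mask: optional [L][S] booleans."""
    hq, hkv = len(query), len(key)
    if enable_gqa:
        if hq % hkv:
            raise ValueError("query heads must be a multiple of key/value heads")
        group = hq // hkv
    elif hq != hkv:
        raise ValueError("head counts differ; set enable_gqa=True")
    else:
        group = 1
    s = 1.0 / math.sqrt(len(query[0][0])) if scale is None else scale
    out = []
    for h in range(hq):
        # GQA: query heads h = g*group ... g*group + group - 1 all use KV head g.
        k, v = key[h // group], value[h // group]
        head_out = []
        for i, qi in enumerate(query[h]):
            scores = []
            for j, kj in enumerate(k):
                allowed = (not is_causal or j <= i) and (attn_mask is None or attn_mask[i][j])
                scores.append(s * sum(a * b for a, b in zip(qi, kj)) if allowed else -math.inf)
            m = max(scores)
            exps = [math.exp(x - m) for x in scores]  # masked entries become 0.0
            total = sum(exps)
            head_out.append([sum(w / total * vj[c] for w, vj in zip(exps, v))
                             for c in range(len(v[0]))])
        out.append(head_out)
    return out


eye = [[1.0, 0.0], [0.0, 1.0]]
print(sdpa([eye], [eye], [eye], is_causal=True)[0][0])  # [1.0, 0.0]
```

With `is_causal=True`, the first query row can only attend to the first key, so its output equals the first value row exactly.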