export_onnx

Utilities for exporting a quantized PyTorch model to quantized ONNX.

Functions

configure_linear_module_onnx_quantizers

Sets the ONNX export attributes for the given model.

export_fp4

Export quantized model to FP4 ONNX.

export_fp8

Export quantized model to FP8 ONNX.

export_fp8_mha

Export quantized fMHA (fused multi-head attention) to FP8 ONNX.

export_int8

Export quantized model to INT8 ONNX.

export_mxfp8

Export quantized model to MXFP8 ONNX.

scaled_dot_product_attention

Perform scaled dot product attention.

configure_linear_module_onnx_quantizers(model)

Sets the ONNX export attributes for the given model.
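
A minimal sketch of where this call fits in an export flow, assuming the module lives at modelopt.torch.quantization.export_onnx and that the model already carries calibrated quantizers; the file name and opset are placeholders:

    import torch

    # Assumed import path for this module.
    from modelopt.torch.quantization.export_onnx import (
        configure_linear_module_onnx_quantizers,
    )

    quantized_model = ...  # a model whose Linear modules already hold quantizers
    dummy_input = torch.randn(1, 16)

    # Annotate each linear module's quantizers with the attributes that the
    # export_* symbolic helpers below read during torch.onnx.export.
    configure_linear_module_onnx_quantizers(quantized_model)
    torch.onnx.export(quantized_model, dummy_input, "model_quantized.onnx",
                      opset_version=17)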

export_fp4(g, inputs, block_size, amax, num_bits, trt_high_precision_dtype, onnx_quantizer_type)

Export quantized model to FP4 ONNX.

Parameters:
  • g (GraphContext)

  • inputs (Value)

  • block_size (int)

  • amax (Value)

  • num_bits (tuple[int, int])

  • trt_high_precision_dtype (str)

  • onnx_quantizer_type (str)
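
The signature suggests this is an ONNX symbolic helper, i.e. it is called with a graph context during tracing. Below is a hedged sketch of delegating to it from a custom autograd Function's symbolic(); the block size, quantizer-type string, and dtype string are illustrative values, and num_bits=(2, 1) is assumed to describe the FP4 E2M1 layout:

    import torch

    from modelopt.torch.quantization.export_onnx import export_fp4  # assumed path

    class FakeQuantFP4(torch.autograd.Function):
        """Toy fake-quant op whose ONNX lowering delegates to export_fp4."""

        @staticmethod
        def forward(ctx, x, amax):
            return x  # identity here; a real op would fake-quantize x

        @staticmethod
        def symbolic(g, x, amax):
            # amax arrives as a traced Value, matching the documented type.
            return export_fp4(g, x, block_size=16, amax=amax, num_bits=(2, 1),
                              trt_high_precision_dtype="Half",
                              onnx_quantizer_type="static")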

export_fp8(g, inputs, amax, trt_high_precision_dtype)

Export quantized model to FP8 ONNX.

Parameters:
  • g (GraphContext)

  • inputs (Value)

  • amax (float)

  • trt_high_precision_dtype (str)
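
The same pattern applies here, except amax is a plain Python float (e.g. a calibration statistic baked in at export time); 448.0 and "Half" are placeholder values in this sketch:

    import torch

    from modelopt.torch.quantization.export_onnx import export_fp8  # assumed path

    class FakeQuantFP8(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x  # identity stand-in for fake quantization

        @staticmethod
        def symbolic(g, x):
            # 448.0 is a placeholder amax (the FP8 E4M3 max-representable value).
            return export_fp8(g, x, amax=448.0, trt_high_precision_dtype="Half")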

export_fp8_mha(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, q_quantized_scale=1.0, k_quantized_scale=1.0, v_quantized_scale=1.0, high_precision_flag='Half', disable_fp8_mha=True)

Export quantized fMHA (fused multi-head attention) to FP8 ONNX.

FP8 ONNX graph:

Q            K           V
|            |           |
QDQ         QDQ          |
 |           |           |
Cast        Cast         |
   \        /            |
     BMM1                |
      |                  |
     Cast               QDQ
      |                  |
   SoftMax               |
      |                  |
     QDQ                 |
      |                  |
     Cast              Cast
        \              /
              BMM2
               |
              Cast

Parameters:
  • g (GraphContext)

  • query (Value)

  • key (Value)

  • value (Value)

  • attn_mask (Value | None)

  • dropout_p (float)

  • is_causal (bool)

  • scale (Value | None)

  • q_quantized_scale (float)

  • k_quantized_scale (float)

  • v_quantized_scale (float)

  • high_precision_flag (str)

  • disable_fp8_mha (bool)
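
Apart from the quantization scales and flags, this parameter list mirrors aten::scaled_dot_product_attention, so one plausible way to use it is behind a thin wrapper registered as that op's export symbolic. The wrapper, the unit scales, and the opset below are illustrative; per-tensor scales would normally come from calibration:

    import torch

    from modelopt.torch.quantization.export_onnx import export_fp8_mha  # assumed path

    def sdpa_to_fp8(g, query, key, value, attn_mask=None, dropout_p=0.0,
                    is_causal=False, scale=None, enable_gqa=False):
        # Fill in the FP8-specific arguments; the 1.0 scales are placeholders.
        return export_fp8_mha(g, query, key, value, attn_mask, dropout_p,
                              is_causal, scale,
                              q_quantized_scale=1.0,
                              k_quantized_scale=1.0,
                              v_quantized_scale=1.0,
                              high_precision_flag="Half",
                              disable_fp8_mha=False)

    torch.onnx.register_custom_op_symbolic(
        "aten::scaled_dot_product_attention", sdpa_to_fp8, opset_version=17)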

export_int8(g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype)

Export quantized model to INT8 ONNX.

Parameters:
  • g (GraphContext)

  • inputs (Value)

  • amax (Tensor)

  • num_bits (int)

  • unsigned (bool)

  • narrow_range (bool)

  • trt_high_precision_dtype (str)
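
A sketch for the INT8 path; note that amax is documented as a Tensor here rather than a traced Value, so this example captures it as a constant at export time. The narrow-range/signed settings and the dtype string are illustrative:

    import torch

    from modelopt.torch.quantization.export_onnx import export_int8  # assumed path

    AMAX = torch.tensor(2.5)  # placeholder calibration statistic

    class FakeQuantINT8(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x  # identity stand-in for fake quantization

        @staticmethod
        def symbolic(g, x):
            # Signed, narrow-range 8-bit quantization (range [-127, 127]).
            return export_int8(g, x, amax=AMAX, num_bits=8, unsigned=False,
                               narrow_range=True,
                               trt_high_precision_dtype="Float")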

export_mxfp8(g, inputs, onnx_quantizer_type, block_size, axis=-1)

Export quantized model to MXFP8 ONNX.

Parameters:
  • g (GraphContext)

  • inputs (Tensor)

  • onnx_quantizer_type (str)

  • block_size (int)

  • axis (int)
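
A sketch for the MXFP8 path. block_size=32 matches the block size used by the OCP microscaling (MX) formats, and axis=-1 keeps the documented default; the quantizer-type string is an illustrative value:

    import torch

    from modelopt.torch.quantization.export_onnx import export_mxfp8  # assumed path

    class FakeQuantMXFP8(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x  # identity stand-in for fake quantization

        @staticmethod
        def symbolic(g, x):
            # Quantize along the last axis in blocks of 32 shared-scale elements.
            return export_mxfp8(g, x, onnx_quantizer_type="dynamic",
                                block_size=32, axis=-1)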

scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)

Perform scaled dot product attention.

Parameters:
  • g (GraphContext)

  • query (Value)

  • key (Value)

  • value (Value)

  • attn_mask (Value | None)

  • dropout_p (float)

  • is_causal (bool)

  • scale (Value | None)

  • enable_gqa (bool)
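
Since this signature matches aten::scaled_dot_product_attention (including enable_gqa), one plausible use is registering it directly as the export symbolic, so attention is lowered to an explicit ONNX subgraph rather than a single fused op; the opset version is an arbitrary choice:

    import torch

    # Assumed import path for this module.
    from modelopt.torch.quantization.export_onnx import scaled_dot_product_attention

    torch.onnx.register_custom_op_symbolic(
        "aten::scaled_dot_product_attention",
        scaled_dot_product_attention,
        opset_version=17,
    )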