Refer to the example code below:
.
from functools import partial
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTQuantizer, ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import AutoQuantizationConfig, AutoCalibrationConfig
# Static quantization of the SST-2 DistilBERT classifier named by `model_id`.
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
# export=True exports the checkpoint to ONNX on load, so it can be quantized.
onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantizer = ORTQuantizer.from_pretrained(onnx_model)
# ARM64 target, static mode, per-tensor (per_channel=False) parameters.
# Static mode needs tensor ranges, which the calibration step below computes.
qconfig = AutoQuantizationConfig.arm64(is_static=True, per_channel=False)
def preprocess_fn(ex, tokenizer):
    """Tokenize the ``"sentence"`` field of one calibration example.

    Args:
        ex: A single dataset example (a mapping) that has a ``"sentence"`` key.
        tokenizer: Callable applied to the sentence text; in this script it is
            the Hugging Face tokenizer, bound via ``functools.partial``.

    Returns:
        Whatever ``tokenizer`` returns for the sentence.
    """
    return tokenizer(ex["sentence"])
# Calibration data: 50 examples from the GLUE "sst2" train split, tokenized by
# preprocess_fn with the tokenizer bound through functools.partial.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=50,
    dataset_split="train",
)
# Use the min-max calibration method on that dataset.
calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
# Run calibration to compute tensor ranges, restricted to the operator types
# the quantization config will quantize.
ranges = quantizer.fit(
    dataset=calibration_dataset,
    calibration_config=calibration_config,
    operators_to_quantize=qconfig.operators_to_quantize,
)
# Quantize with the calibrated ranges and write the result to save_dir
# ("path/to/output/model" is a placeholder; replace it with a real directory).
model_quantized_path = quantizer.quantize(
    save_dir="path/to/output/model",
    calibration_tensors_range=ranges,
    quantization_config=qconfig,
)
..
The CLI offers options for several instruction sets:
.
optimum-cli onnxruntime quantize --help
usage: optimum-cli <command> [<args>] onnxruntime quantize [-h] --onnx_model ONNX_MODEL -o OUTPUT [--per_channel] (--arm64 | --avx2 | --avx512 | --avx512_vnni | --tensorrt | -c CONFIG)
options:
-h, --help show this help message and exit
--arm64 Quantization for the ARM64 architecture.
--avx2 Quantization with AVX-2 instructions.
--avx512 Quantization with AVX-512 instructions.
--avx512_vnni Quantization with AVX-512 and VNNI instructions.
--tensorrt Quantization for NVIDIA TensorRT optimizer.
-c CONFIG, --config CONFIG
`ORTConfig` file to use to optimize the model.
Required arguments:
--onnx_model ONNX_MODEL
Path to the repository where the ONNX models to quantize are located.
-o OUTPUT, --output OUTPUT
Path to the directory where to store generated ONNX model.
Optional arguments:
--per_channel Compute the quantization parameters on a per-channel basis.
..
Refer to this page for details:
https://huggingface.co/docs/optimum/onnxruntime/usage_guides/quantization#quantize-seq2seq-models
Refer to this code as well:
.
It may give you some ideas.
from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# NOTE(review): `model_path`, `onnx_path`, `output_path` and `device` are not
# defined in this snippet; they are assumed to be set by the reader beforehand.

# Export to ONNX.
# `export=True` is the current flag; the older `from_transformers=True` is a
# deprecated alias for it, so it is no longer passed alongside.
model = ORTModelForSeq2SeqLM.from_pretrained(
    model_path,
    export=True,
    provider='CUDAExecutionProvider',
).to(device)
model.save_pretrained(onnx_path)

# quantization code
# The seq2seq export yields three ONNX graphs; each gets its own quantizer.
encoder_quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name='encoder_model.onnx')
decoder_quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name='decoder_model.onnx')
decoder_wp_quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name='decoder_with_past_model.onnx')
quantizers = [encoder_quantizer, decoder_quantizer, decoder_wp_quantizer]

# Dynamic quantization (is_static=False): unlike the static example, no
# calibration dataset or fit() step is used.
# NOTE(review): the AVX-512 VNNI config targets CPU kernels; when the quantized
# model is later run with CUDAExecutionProvider, quantized ops may fall back to
# CPU. Confirm this against the ONNX Runtime EP operator support.
dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
for q in quantizers:
    # All three quantized graphs are written into the same output directory.
    q.quantize(save_dir=output_path, quantization_config=dqconfig)
#inference code
# Load the quantized graphs from `output_path`, the save_dir the quantizers
# wrote to. `model_path` is the source checkpoint that was exported to ONNX,
# so the *_quantized.onnx files are not there.
# NOTE(review): this assumes ORTQuantizer.quantize() also saved the model
# config into `output_path`; confirm for your Optimum version.
model = ORTModelForSeq2SeqLM.from_pretrained(
    model_id=output_path,
    encoder_file_name='encoder_model_quantized.onnx',
    decoder_file_name='decoder_model_quantized.onnx',
    decoder_with_past_file_name='decoder_with_past_model_quantized.onnx',
    provider='CUDAExecutionProvider',
    use_io_binding=True,
).to(self.device)  # NOTE(review): `self` implies this runs inside a class method.
# NOTE(review): the tokenizer is hard-coded to flan-t5-large; make sure it
# matches the checkpoint at `model_path`.
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large')
...
# NOTE(review): this fragment comes from a class method (it uses `self`);
# `List`, `OUTPUT_TYPE` and `DataLoader` must be imported/defined there.
dataset = self.dataset(input_dict)
# Keep only the model inputs and format them as torch tensors on self.device.
dataset.set_format(type='torch', device=self.device, columns=['input_ids', 'attention_mask'])
data_loader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.data_collator)
generated_outputs: List[OUTPUT_TYPE] = []
for batch in data_loader:
    # The collator may create new tensors (e.g. when padding), so move the
    # batch to the target device again before generating.
    _batch = {key: val.to(self.device) for key, val in batch.items()}
    outputs = self.model.generate(**_batch, generation_config=self.generation_config)
    decoded_outputs = self.tokenizer.batch_decode(outputs.cpu().tolist(), skip_special_tokens=True)
    # TODO(review): in the code shown, `decoded_outputs` is overwritten every
    # batch and never added to `generated_outputs`; presumably
    # `generated_outputs.extend(decoded_outputs)` (or similar) belongs here.
.
Thank you.
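Note: quantization and optimization are different things. Quantization lowers numeric precision (for example to int8), while optimization rewrites the ONNX graph (for example by fusing operators).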
No comments:
Post a Comment