+ add tflite export

+ add Quantisierung
This commit is contained in:
Vincent Hanewinkel 2025-07-15 17:13:11 +02:00
parent a696025186
commit 78ad1dd29e
6 changed files with 114 additions and 4 deletions

View File

@ -7,4 +7,10 @@ dir = os.path.dirname(os.path.realpath(__file__))
model = YOLO(dir + "/best.pt") model = YOLO(dir + "/best.pt")
# Exportiere es ins TFLite-Format # Exportiere es ins TFLite-Format
model.export(format="tflite") model.export(
format="tflite",
dynamic=False,
inference_type="uint8",
quantize=True,
calibration_data="path/to/rep_dataset/"
)

Binary file not shown.

89
requirements.txt Normal file
View File

@ -0,0 +1,89 @@
absl-py==2.3.0
ai-edge-litert==1.3.0
astunparse==1.6.3
backports.strenum==1.2.8
certifi==2025.4.26
charset-normalizer==3.4.2
coloredlogs==15.0.1
contourpy==1.3.2
cycler==0.12.1
filelock==3.18.0
flatbuffers==25.2.10
fonttools==4.58.2
fsspec==2025.5.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.73.0
h5py==3.14.0
humanfriendly==10.0
idna==3.10
Jinja2==3.1.6
keras==3.10.0
kiwisolver==1.4.8
libclang==18.1.1
Markdown==3.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.3
mdurl==0.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
namex==0.1.0
networkx==3.5
numpy==2.1.3
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
onnx==1.17.0
onnx2tf==1.27.10
onnx_graphsurgeon==0.5.8
onnxruntime==1.22.0
onnxslim==0.1.57
opencv-python==4.11.0.86
opt_einsum==3.4.0
optree==0.16.0
packaging==25.0
pandas==2.3.0
pillow==11.2.1
protobuf==5.29.5
psutil==7.0.0
py-cpuinfo==9.0.0
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
requests==2.32.4
rich==14.0.0
scipy==1.15.3
six==1.17.0
sng4onnx==1.0.4
sympy==1.14.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorflow==2.19.0
tensorflow-io-gcs-filesystem==0.37.1
termcolor==3.1.0
tf_keras==2.19.0
torch==2.7.1
torchvision==0.22.1
tqdm==4.67.1
triton==3.3.1
typing_extensions==4.14.0
tzdata==2025.2
ultralytics==8.3.154
ultralytics-thop==2.0.14
urllib3==2.4.0
Werkzeug==3.1.3
wrapt==1.17.2

4
test_tflite.py Normal file
View File

@ -0,0 +1,4 @@
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="yolo_training/NAO_detector/weights/best_saved_model/best_float16.tflite")
interpreter.allocate_tensors()

View File

@ -9,6 +9,8 @@ import numpy as np
import torch import torch
import torchvision import torchvision
dir = os.path.dirname(os.path.realpath(__file__))
def verify_images(dataset_path): def verify_images(dataset_path):
"""Überprüfe alle Bilder auf Lesbarkeit und korrekte Dimensionen.""" """Überprüfe alle Bilder auf Lesbarkeit und korrekte Dimensionen."""
print("\nÜberprüfe Bilder...") print("\nÜberprüfe Bilder...")
@ -79,7 +81,7 @@ def train_yolo():
results = model.train( results = model.train(
data=dataset_yaml, data=dataset_yaml,
epochs=50, epochs=50,
imgsz=640, # Bildgröße imgsz=320, # Bildgröße
batch=16, # Batch-Größe batch=16, # Batch-Größe
device=0, # Verwende CPU (oder 'cuda' für GPU) device=0, # Verwende CPU (oder 'cuda' für GPU)
patience=5, # Early Stopping patience=5, # Early Stopping
@ -92,10 +94,19 @@ def train_yolo():
print("\nExportiere Modell in verschiedene Formate...") print("\nExportiere Modell in verschiedene Formate...")
# ONNX Format (gut für eingebettete Systeme) # ONNX Format (gut für eingebettete Systeme)
model.export(format='onnx', imgsz=640) model.export(format='onnx', imgsz=320)
# OpenVINO Format (optimiert für Intel Hardware) # OpenVINO Format (optimiert für Intel Hardware)
model.export(format='openvino', imgsz=640) model.export(format='openvino', imgsz=320)
model.export(
format="tflite",
dynamic=False,
inference_type="uint8",
quantize=True,
imgsz=320,
calibration_data=os.path.join(dir, "calib_images")
)
print("\nTraining abgeschlossen. Die Modelle wurden im 'yolo_training/NAO_detector' Ordner gespeichert.") print("\nTraining abgeschlossen. Die Modelle wurden im 'yolo_training/NAO_detector' Ordner gespeichert.")