summaryrefslogtreecommitdiff
path: root/ext/onnx/gstonnxobjectdetector.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ext/onnx/gstonnxobjectdetector.cpp')
-rw-r--r--ext/onnx/gstonnxobjectdetector.cpp670
1 files changed, 670 insertions, 0 deletions
diff --git a/ext/onnx/gstonnxobjectdetector.cpp b/ext/onnx/gstonnxobjectdetector.cpp
new file mode 100644
index 000000000..28f4cf2fa
--- /dev/null
+++ b/ext/onnx/gstonnxobjectdetector.cpp
@@ -0,0 +1,670 @@
+/*
+ * GStreamer gstreamer-onnxobjectdetector
+ * Copyright (C) 2021 Collabora Ltd.
+ *
+ * gstonnxobjectdetector.c
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Library General Public
+ * License as published by the Free Software Foundation; either
+ * version 2 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Library General Public License for more details.
+ *
+ * You should have received a copy of the GNU Library General Public
+ * License along with this library; if not, write to the
+ * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
+ * Boston, MA 02110-1301, USA.
+ */
+
+/**
+ * SECTION:element-onnxobjectdetector
+ * @short_description: Detect objects in video frame
+ *
+ * This element can apply a generic ONNX object detection model such as YOLO or SSD
+ * to each video frame.
+ *
+ * To install ONNX on your system, recursively clone this repository
+ * https://github.com/microsoft/onnxruntime.git
+ *
+ * and build and install with cmake:
+ *
+ * CPU:
+ *
+ * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
+ * $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
+ *
+ *
+ * GPU :
+ *
+ * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
+ * -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
+ * $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
+ *
+ *
+ * where :
+ *
+ * 1. $SRC_DIR and $BUILD_DIR are local source and build directories
+ * 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
+ * $CUDA_PATH is an environment variable set to the CUDA root path.
+ * On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA.
+ *
+ *
+ * ## Example launch command:
+ *
+ * (note: an object detection model has 3 or 4 output nodes, but there is no naming convention
+ * to indicate which node outputs the bounding box, which node outputs the label, etc.
+ * So, the `onnxobjectdetector` element has properties to map each node's functionality to its
+ * respective node index in the specified model )
+ *
+ * ```
+ * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
+ * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
+ * videoconvert ! \
+ * onnxobjectdetector \
+ * box-node-index=0 \
+ * class-node-index=1 \
+ * score-node-index=2 \
+ * detection-node-index=3 \
+ * execution-provider=cpu \
+ * model-file=model.onnx \
+ * label-file=COCO_classes.txt ! \
+ * videoconvert ! \
+ * autovideosink
+ * ```
+ */
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "gstonnxobjectdetector.h"
+#include "gstonnxclient.h"
+
+#include <gst/gst.h>
+#include <gst/video/video.h>
+#include <gst/video/gstvideometa.h>
+#include <stdlib.h>
+#include <string.h>
+#include <glib.h>
+
+GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug);
+#define GST_CAT_DEFAULT onnx_object_detector_debug
+#define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr))
+GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector",
+ GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR);
+
+/* GstOnnxObjectDetector properties */
+enum
+{
+ PROP_0,
+ PROP_MODEL_FILE,
+ PROP_LABEL_FILE,
+ PROP_SCORE_THRESHOLD,
+ PROP_DETECTION_NODE_INDEX,
+ PROP_BOUNDING_BOX_NODE_INDEX,
+ PROP_SCORE_NODE_INDEX,
+ PROP_CLASS_NODE_INDEX,
+ PROP_INPUT_IMAGE_FORMAT,
+ PROP_OPTIMIZATION_LEVEL,
+ PROP_EXECUTION_PROVIDER
+};
+
+
+#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
+#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
+#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
+
+static GstStaticPadTemplate gst_onnx_object_detector_src_template =
+GST_STATIC_PAD_TEMPLATE ("src",
+ GST_PAD_SRC,
+ GST_PAD_ALWAYS,
+ GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
+ );
+
+static GstStaticPadTemplate gst_onnx_object_detector_sink_template =
+GST_STATIC_PAD_TEMPLATE ("sink",
+ GST_PAD_SINK,
+ GST_PAD_ALWAYS,
+ GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
+ );
+
+static void gst_onnx_object_detector_set_property (GObject * object,
+ guint prop_id, const GValue * value, GParamSpec * pspec);
+static void gst_onnx_object_detector_get_property (GObject * object,
+ guint prop_id, GValue * value, GParamSpec * pspec);
+static void gst_onnx_object_detector_finalize (GObject * object);
+static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform *
+ trans, GstBuffer * buf);
+static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans,
+ GstBuffer * buf);
+static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans);
+static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform *
+ trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
+
+G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector,
+ GST_TYPE_BASE_TRANSFORM);
+
+static void
+gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass)
+{
+ GObjectClass *gobject_class = (GObjectClass *) klass;
+ GstElementClass *element_class = (GstElementClass *) klass;
+ GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
+
+ GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector",
+ 0, "onnx_objectdetector");
+ gobject_class->set_property = gst_onnx_object_detector_set_property;
+ gobject_class->get_property = gst_onnx_object_detector_get_property;
+ gobject_class->finalize = gst_onnx_object_detector_finalize;
+
+ /**
+ * GstOnnxObjectDetector:model-file
+ *
+ * ONNX model file
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
+ g_param_spec_string ("model-file",
+ "ONNX model file", "ONNX model file", NULL, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:label-file
+ *
+ * Label file for ONNX model
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
+ g_param_spec_string ("label-file",
+ "Label file", "Label file associated with model", NULL, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+
+ /**
+ * GstOnnxObjectDetector:detection-node-index
+ *
+ * Index of model detection node
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_DETECTION_NODE_INDEX,
+ g_param_spec_int ("detection-node-index",
+ "Detection node index",
+ "Index of neural network output node corresponding to number of detected objects",
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
+ GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+
+ /**
+ * GstOnnxObjectDetector:bounding-box-node-index
+ *
+ * Index of model bounding box node
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_BOUNDING_BOX_NODE_INDEX,
+ g_param_spec_int ("box-node-index",
+ "Bounding box node index",
+ "Index of neural network output node corresponding to bounding box",
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
+ GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:score-node-index
+ *
+ * Index of model score node
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_SCORE_NODE_INDEX,
+ g_param_spec_int ("score-node-index",
+ "Score node index",
+ "Index of neural network output node corresponding to score",
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
+ GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:class-node-index
+ *
+ * Index of model class (label) node
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_CLASS_NODE_INDEX,
+ g_param_spec_int ("class-node-index",
+ "Class node index",
+ "Index of neural network output node corresponding to class (label)",
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
+ GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+
+ /**
+ * GstOnnxObjectDetector:score-threshold
+ *
+ * Threshold for deciding when to remove boxes based on score
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
+ g_param_spec_float ("score-threshold",
+ "Score threshold",
+ "Threshold for deciding when to remove boxes based on score",
+ 0.0, 1.0,
+ GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:input-image-format
+ *
+ * Model input image format
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_INPUT_IMAGE_FORMAT,
+ g_param_spec_enum ("input-image-format",
+ "Input image format",
+ "Input image format",
+ GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
+ GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:optimization-level
+ *
+ * ONNX optimization level
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_OPTIMIZATION_LEVEL,
+ g_param_spec_enum ("optimization-level",
+ "Optimization level",
+ "ONNX optimization level",
+ GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
+ GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ /**
+ * GstOnnxObjectDetector:execution-provider
+ *
+ * ONNX execution provider
+ *
+ * Since: 1.20
+ */
+ g_object_class_install_property (G_OBJECT_CLASS (klass),
+ PROP_EXECUTION_PROVIDER,
+ g_param_spec_enum ("execution-provider",
+ "Execution provider",
+ "ONNX execution provider",
+ GST_TYPE_ONNX_EXECUTION_PROVIDER,
+ GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
+ (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
+ gst_element_class_set_static_metadata (element_class, "onnxobjectdetector",
+ "Filter/Effect/Video",
+ "Apply neural network to detect objects in video frames",
+ "Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
+ gst_element_class_add_pad_template (element_class,
+ gst_static_pad_template_get (&gst_onnx_object_detector_sink_template));
+ gst_element_class_add_pad_template (element_class,
+ gst_static_pad_template_get (&gst_onnx_object_detector_src_template));
+ basetransform_class->transform_ip =
+ GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip);
+ basetransform_class->transform_caps =
+ GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps);
+}
+
+static void
+gst_onnx_object_detector_init (GstOnnxObjectDetector * self)
+{
+ self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient ();
+ self->onnx_disabled = false;
+}
+
+static void
+gst_onnx_object_detector_finalize (GObject * object)
+{
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
+
+ g_free (self->model_file);
+ delete GST_ONNX_MEMBER (self);
+ G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object);
+}
+
+static void
+gst_onnx_object_detector_set_property (GObject * object, guint prop_id,
+ const GValue * value, GParamSpec * pspec)
+{
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
+ const gchar *filename;
+ auto onnxClient = GST_ONNX_MEMBER (self);
+
+ switch (prop_id) {
+ case PROP_MODEL_FILE:
+ filename = g_value_get_string (value);
+ if (filename
+ && g_file_test (filename,
+ (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
+ if (self->model_file)
+ g_free (self->model_file);
+ self->model_file = g_strdup (filename);
+ } else {
+ GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
+ gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
+ }
+ break;
+ case PROP_LABEL_FILE:
+ filename = g_value_get_string (value);
+ if (filename
+ && g_file_test (filename,
+ (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
+ if (self->label_file)
+ g_free (self->label_file);
+ self->label_file = g_strdup (filename);
+ } else {
+ GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
+ }
+ break;
+ case PROP_SCORE_THRESHOLD:
+ GST_OBJECT_LOCK (self);
+ self->score_threshold = g_value_get_float (value);
+ GST_OBJECT_UNLOCK (self);
+ break;
+ case PROP_OPTIMIZATION_LEVEL:
+ self->optimization_level =
+ (GstOnnxOptimizationLevel) g_value_get_enum (value);
+ break;
+ case PROP_EXECUTION_PROVIDER:
+ self->execution_provider =
+ (GstOnnxExecutionProvider) g_value_get_enum (value);
+ break;
+ case PROP_DETECTION_NODE_INDEX:
+ onnxClient->setOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
+ g_value_get_int (value));
+ break;
+ case PROP_BOUNDING_BOX_NODE_INDEX:
+ onnxClient->setOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
+ g_value_get_int (value));
+ break;
+ break;
+ case PROP_SCORE_NODE_INDEX:
+ onnxClient->setOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
+ g_value_get_int (value));
+ break;
+ break;
+ case PROP_CLASS_NODE_INDEX:
+ onnxClient->setOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
+ g_value_get_int (value));
+ break;
+ case PROP_INPUT_IMAGE_FORMAT:
+ onnxClient->setInputImageFormat ((GstMlModelInputImageFormat)
+ g_value_get_enum (value));
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
+ break;
+ }
+}
+
+static void
+gst_onnx_object_detector_get_property (GObject * object, guint prop_id,
+ GValue * value, GParamSpec * pspec)
+{
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
+ auto onnxClient = GST_ONNX_MEMBER (self);
+
+ switch (prop_id) {
+ case PROP_MODEL_FILE:
+ g_value_set_string (value, self->model_file);
+ break;
+ case PROP_LABEL_FILE:
+ g_value_set_string (value, self->label_file);
+ break;
+ case PROP_SCORE_THRESHOLD:
+ GST_OBJECT_LOCK (self);
+ g_value_set_float (value, self->score_threshold);
+ GST_OBJECT_UNLOCK (self);
+ break;
+ case PROP_OPTIMIZATION_LEVEL:
+ g_value_set_enum (value, self->optimization_level);
+ break;
+ case PROP_EXECUTION_PROVIDER:
+ g_value_set_enum (value, self->execution_provider);
+ break;
+ case PROP_DETECTION_NODE_INDEX:
+ g_value_set_int (value,
+ onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION));
+ break;
+ case PROP_BOUNDING_BOX_NODE_INDEX:
+ g_value_set_int (value,
+ onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX));
+ break;
+ break;
+ case PROP_SCORE_NODE_INDEX:
+ g_value_set_int (value,
+ onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE));
+ break;
+ break;
+ case PROP_CLASS_NODE_INDEX:
+ g_value_set_int (value,
+ onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS));
+ break;
+ case PROP_INPUT_IMAGE_FORMAT:
+ g_value_set_enum (value, onnxClient->getInputImageFormat ());
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
+ break;
+ }
+}
+
+static gboolean
+gst_onnx_object_detector_create_session (GstBaseTransform * trans)
+{
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
+ auto onnxClient = GST_ONNX_MEMBER (self);
+
+ GST_OBJECT_LOCK (self);
+ if (self->onnx_disabled || onnxClient->hasSession ()) {
+ GST_OBJECT_UNLOCK (self);
+
+ return TRUE;
+ }
+ if (self->model_file) {
+ gboolean ret = GST_ONNX_MEMBER (self)->createSession (self->model_file,
+ self->optimization_level,
+ self->execution_provider);
+ if (!ret) {
+ GST_ERROR_OBJECT (self,
+ "Unable to create ONNX session. Detection disabled.");
+ } else {
+ auto outputNames = onnxClient->getOutputNodeNames ();
+
+ for (size_t i = 0; i < outputNames.size (); ++i)
+ GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i,
+ outputNames[i]);
+ if (outputNames.size () < 3) {
+ GST_ERROR_OBJECT (self,
+ "Number of output tensor nodes %d does not match the 3 or 4 nodes "
+ "required for an object detection model. Detection is disabled.",
+ (gint) outputNames.size ());
+ self->onnx_disabled = TRUE;
+ }
+ // sanity check on output node indices
+ if (onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) ==
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
+ GST_ERROR_OBJECT (self,
+ "Output detection node index not set. Detection disabled.");
+ self->onnx_disabled = TRUE;
+ }
+ if (onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) ==
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
+ GST_ERROR_OBJECT (self,
+ "Output bounding box node index not set. Detection disabled.");
+ self->onnx_disabled = TRUE;
+ }
+ if (onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) ==
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
+ GST_ERROR_OBJECT (self,
+ "Output score node index not set. Detection disabled.");
+ self->onnx_disabled = TRUE;
+ }
+ if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex
+ (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) ==
+ GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
+ GST_ERROR_OBJECT (self,
+ "Output class node index not set. Detection disabled.");
+ self->onnx_disabled = TRUE;
+ }
+ // model is not usable, so fail
+ if (self->onnx_disabled) {
+ GST_ELEMENT_WARNING (self, RESOURCE, FAILED,
+ ("ONNX model cannot be used for object detection"), (NULL));
+
+ return FALSE;
+ }
+ }
+ } else {
+ self->onnx_disabled = TRUE;
+ }
+ GST_OBJECT_UNLOCK (self);
+ if (self->onnx_disabled){
+ gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
+ }
+
+ return TRUE;
+}
+
+
+static GstCaps *
+gst_onnx_object_detector_transform_caps (GstBaseTransform *
+ trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
+{
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
+ auto onnxClient = GST_ONNX_MEMBER (self);
+ GstCaps *other_caps;
+ guint i;
+
+ if ( !gst_onnx_object_detector_create_session (trans) )
+ return NULL;
+ GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
+
+ if (gst_base_transform_is_passthrough (trans)
+ || (!onnxClient->isFixedInputImageSize ()))
+ return gst_caps_ref (caps);
+
+ other_caps = gst_caps_new_empty ();
+ for (i = 0; i < gst_caps_get_size (caps); ++i) {
+ GstStructure *structure, *new_structure;
+
+ structure = gst_caps_get_structure (caps, i);
+ new_structure = gst_structure_copy (structure);
+ gst_structure_set (new_structure, "width", G_TYPE_INT,
+ onnxClient->getWidth (), "height", G_TYPE_INT,
+ onnxClient->getHeight (), NULL);
+ GST_LOG_OBJECT (self,
+ "transformed structure %2d: %" GST_PTR_FORMAT " => %"
+ GST_PTR_FORMAT, i, structure, new_structure);
+ gst_caps_append_structure (other_caps, new_structure);
+ }
+
+ if (!gst_caps_is_empty (other_caps) && filter_caps) {
+ GstCaps *tmp = gst_caps_intersect_full (other_caps,filter_caps,
+ GST_CAPS_INTERSECT_FIRST);
+ gst_caps_replace (&other_caps, tmp);
+ gst_caps_unref (tmp);
+ }
+
+ return other_caps;
+}
+
+
+static GstFlowReturn
+gst_onnx_object_detector_transform_ip (GstBaseTransform * trans,
+ GstBuffer * buf)
+{
+ if (!gst_base_transform_is_passthrough (trans)
+ && !gst_onnx_object_detector_process (trans, buf)){
+ GST_ELEMENT_WARNING (trans, STREAM, FAILED,
+ ("ONNX object detection failed"), (NULL));
+ return GST_FLOW_ERROR;
+ }
+
+ return GST_FLOW_OK;
+}
+
+static gboolean
+gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
+{
+ GstMapInfo info;
+ GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
+
+ if (!vmeta) {
+ GST_WARNING_OBJECT (trans, "missing video meta");
+ return FALSE;
+ }
+ if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
+ GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
+ auto boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
+ self->label_file ? self->label_file : "",
+ self->score_threshold);
+ for (auto & b:boxes) {
+ auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
+ GST_ONNX_OBJECT_DETECTOR_META_NAME,
+ b.x0, b.y0,
+ b.width,
+ b.height);
+ if (!vroi_meta) {
+ GST_WARNING_OBJECT (trans,
+ "Unable to attach GstVideoRegionOfInterestMeta to buffer");
+ return FALSE;
+ }
+ auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME,
+ GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL,
+ G_TYPE_STRING,
+ b.label.c_str (),
+ GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE,
+ G_TYPE_DOUBLE,
+ b.score,
+ NULL);
+ gst_video_region_of_interest_meta_add_param (vroi_meta, s);
+ GST_DEBUG_OBJECT (self,
+ "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
+ b.label.c_str (), b.score, b.x0, b.y0,
+ b.x0 + b.width, b.y0 + b.height);
+ }
+ gst_buffer_unmap (buf, &info);
+ }
+
+ return TRUE;
+}