1 TensorRT C++ API支持的模型输入维度

在TensorRT 7.0及以上版本,我们通常使用以下语句指定输入维度:

    // Tensor names used to look up the engine's input and output bindings.
    const std::string input_name = "input";
    const std::string output_name = "output";
    // Resolve each tensor name to its binding index on the engine.
    // NOTE(review): the indices are used without validation in this snippet; the TensorRT
    // docs describe getBindingIndex as returning -1 for an unknown name — confirm and check.
    const int inputIndex = m_TensorRT_Engine->getBindingIndex(input_name.c_str());
    const int outputIndex = m_TensorRT_Engine->getBindingIndex(output_name.c_str());
    // Set the input binding's dimensions on the context to the rank-3 shape (3, 100, 20).
    // NOTE(review): the result of setBindingDimensions is ignored here — verify whether
    // the call can fail and should be checked.
    m_TensorRT_Context->setBindingDimensions(inputIndex, Dims3(3, 100, 20));

其中Dims3表示该深度学习模型的输入tensor是三维tensor,shape为(3,100,20)。

一般的深度学习模型的输入维度通常为(C,H,W),这种输入数据是三维tensor。

另外,TensorRT C++ API自带的便捷维度类最高只到Dims4,用于支持4维tensor数据的模型输入。但是随着深度学习框架发展得越来越复杂,越来越多的深度学习模型需要5维、6维甚至更高维度的tensor作为网络输入,那么如何在现有的TensorRT API上扩展更高维度的输入tensor以满足我们自己的需要呢?

2 扩展TensorRT C++ API 模型输入维度

在TensorRT C++ API的include目录下的NvInferRuntimeCommon.h文件中定义了类Dims32:

//!
//! \class Dims32
//! \brief Rank and per-dimension extents of a tensor.
//!
//! TensorRT may report an invalid dims value: nbDims == -1 with every d[i] == 0.
//!
//! TensorRT may report an "unknown rank" dims value: nbDims == -1 with every d[i] == -1.
//!
class Dims32
{
public:
    //! Upper bound on the rank (number of dimensions) a tensor may have.
    static constexpr int32_t MAX_DIMS = 8;
    //! Rank, i.e. how many dimensions the tensor has.
    int32_t nbDims;
    //! Extent of each dimension; the array holds MAX_DIMS entries.
    int32_t d[MAX_DIMS];
};

该类用于定义tensor的输入维度,从类定义上看,该类支持的最大维度为8。

在TensorRT C++ API的include目录下的NvInferLegacyDims.h文件中,定义了目前TensorRT所支持的输入维度类型:

/*
 * Copyright 1993-2021 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H

#include "NvInferRuntimeCommon.h"

//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!

//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Dims specialization that describes two-dimensional data.
//!
class Dims2 : public Dims
{
public:
    //!
    //! \brief Build a Dims2 whose two extents are both zero.
    //!
    Dims2()
        : Dims2(0, 0)
    {
    }

    //!
    //! \brief Build a Dims2 from two explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //!
    Dims2(int32_t d0, int32_t d1)
        : Dims{2, {d0, d1}}
    {
    }
};

//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
    //!
    //! \brief Construct an empty DimsHW object.
    //!
    DimsHW()
        : Dims2()
    {
    }

    //!
    //! \brief Construct a DimsHW given height and width.
    //!
    //! \param height the height of the data
    //! \param width the width of the data
    //!
    DimsHW(int32_t height, int32_t width)
        : Dims2(height, width)
    {
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t& h()
    {
        return d[0];
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t h() const
    {
        return d[0];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t& w()
    {
        return d[1];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t w() const
    {
        return d[1];
    }
};

//!
//! \class Dims3
//! \brief Dims specialization that describes three-dimensional data.
//!
class Dims3 : public Dims
{
public:
    //!
    //! \brief Build a Dims3 whose three extents are all zero.
    //!
    Dims3()
        : Dims3(0, 0, 0)
    {
    }

    //!
    //! \brief Build a Dims3 from three explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //! \param d2 Extent stored in d[2].
    //!
    Dims3(int32_t d0, int32_t d1, int32_t d2)
        : Dims{3, {d0, d1, d2}}
    {
    }
};

//!
//! \class Dims4
//! \brief Dims specialization that describes four-dimensional data.
//!
class Dims4 : public Dims
{
public:
    //!
    //! \brief Build a Dims4 whose four extents are all zero.
    //!
    Dims4()
        : Dims4(0, 0, 0, 0)
    {
    }

    //!
    //! \brief Build a Dims4 from four explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //! \param d2 Extent stored in d[2].
    //! \param d3 Extent stored in d[3].
    //!
    Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
        : Dims{4, {d0, d1, d2, d3}}
    {
    }
};

} // namespace nvinfer1

#endif // NV_INFER_LEGACY_DIMS_H

从上述文件的代码看,构建输入维度只需要继承类Dims,然后按定义进行初始化即可。所以,为了让TensorRT可以支持Dims5、Dims6、Dims7、Dims8等更高的输入维度,需要自行扩展以上维度类,扩展后的NvInferLegacyDims.h文件内容如下所示:

/*
 * Copyright 1993-2021 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H

#include "NvInferRuntimeCommon.h"

//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!

//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Dims specialization that describes two-dimensional data.
//!
class Dims2 : public Dims
{
public:
    //!
    //! \brief Build a Dims2 whose two extents are both zero.
    //!
    Dims2()
        : Dims2(0, 0)
    {
    }

    //!
    //! \brief Build a Dims2 from two explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //!
    Dims2(int32_t d0, int32_t d1)
        : Dims{2, {d0, d1}}
    {
    }
};

//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
    //!
    //! \brief Construct an empty DimsHW object.
    //!
    DimsHW()
        : Dims2()
    {
    }

    //!
    //! \brief Construct a DimsHW given height and width.
    //!
    //! \param height the height of the data
    //! \param width the width of the data
    //!
    DimsHW(int32_t height, int32_t width)
        : Dims2(height, width)
    {
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t& h()
    {
        return d[0];
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t h() const
    {
        return d[0];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t& w()
    {
        return d[1];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t w() const
    {
        return d[1];
    }
};

//!
//! \class Dims3
//! \brief Dims specialization that describes three-dimensional data.
//!
class Dims3 : public Dims
{
public:
    //!
    //! \brief Build a Dims3 whose three extents are all zero.
    //!
    Dims3()
        : Dims3(0, 0, 0)
    {
    }

    //!
    //! \brief Build a Dims3 from three explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //! \param d2 Extent stored in d[2].
    //!
    Dims3(int32_t d0, int32_t d1, int32_t d2)
        : Dims{3, {d0, d1, d2}}
    {
    }
};

//!
//! \class Dims4
//! \brief Dims specialization that describes four-dimensional data.
//!
class Dims4 : public Dims
{
public:
    //!
    //! \brief Build a Dims4 whose four extents are all zero.
    //!
    Dims4()
        : Dims4(0, 0, 0, 0)
    {
    }

    //!
    //! \brief Build a Dims4 from four explicit extents.
    //!
    //! \param d0 Extent stored in d[0].
    //! \param d1 Extent stored in d[1].
    //! \param d2 Extent stored in d[2].
    //! \param d3 Extent stored in d[3].
    //!
    Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
        : Dims{4, {d0, d1, d2, d3}}
    {
    }
};

//!
//! \class Dims5
//! \brief Descriptor for four-dimensional data.
//!
class Dims5 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims5 object.
    //!
    Dims5()
    {
        nbDims = 5;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims5 from 5 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //!
    Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
    {
        nbDims = 5;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};

//!
//! \class Dims6
//! \brief Descriptor for four-dimensional data.
//!
class Dims6 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims5 object.
    //!
    Dims6()
    {
        nbDims = 6;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims5 from 5 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //!
    Dims6(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5)
    {
        nbDims = 6;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};

//!
//! \class Dims7
//! \brief Descriptor for four-dimensional data.
//!
class Dims7 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims5 object.
    //!
    Dims7()
    {
        nbDims = 7;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims5 from 5 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //! \param d6 The seventh element.
    //!
    Dims7(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6)
    {
        nbDims = 7;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        d[6] = d6;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};


//!
//! \class Dims8
//! \brief Descriptor for four-dimensional data.
//!
class Dims8 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims5 object.
    //!
    Dims8()
    {
        nbDims = 8;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims5 from 5 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //! \param d6 The seventh element.
    //! \param d7 The eighth element.
    //!
    Dims8(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6, int32_t d7)
    {
        nbDims = 8;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        d[6] = d6;
        d[7] = d7;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};




} // namespace nvinfer1

#endif // NV_INFER_LEGACY_DIMS_H

将NvInferLegacyDims.h修改之后,重新编译自己的工程,即可使用所扩展的Dims5、Dims6、Dims7、Dims8,分别对应5维、6维、7维、8维的网络输入维度。如果不希望改动SDK自带的头文件,也可以把这些类定义放在自己工程的头文件中,效果相同。