Skip to content

Commit

Permalink
[ATLAS] Update ATLAS, support new APIs like get input shape map
Browse files Browse the repository at this point in the history
  • Loading branch information
doxutx committed Feb 29, 2024
1 parent 70836d0 commit 97e7c1b
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 66 deletions.
8 changes: 7 additions & 1 deletion include/tnn/core/instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ class PUBLIC Instance {
std::shared_ptr<AbstractModelInterpreter> GetInterpreter();
#endif // end of GET_INTERP_ENABLE

#ifdef GET_NETWORK_ENABLE
AbstractNetwork *GetNetwork();
#endif

// tnn instance network infer async.
// device gpu, all layer infer complete will call Callback.
Status ForwardAsync(Callback call_back);
Expand Down Expand Up @@ -122,8 +126,10 @@ class PUBLIC Instance {
NetworkConfig net_config_;
ModelConfig model_config_;

#ifndef GET_NETWORK_ENABLE
AbstractNetwork *GetNetwork();

#endif

//Mat interface for simple use
public:
// set input Mat, if input_name is not set, take the first input as default
Expand Down
82 changes: 28 additions & 54 deletions scripts/build_atlas.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,82 +1,56 @@
#!/bin/bash

#export DDK_PATH=/data1/Ascend/ascend-toolkit/latest
#export NPU_HOST_LIB=/data1/Ascend/ascend-toolkit/latest/acllib/lib64

ARM="ON"
OPENMP="ON"
DEBUG=0
SHARED_LIB="ON"
BENCHMARK="OFF"
TNN_TEST="ON"
TARGET_ARCH=aarch64

TNN_BUILD_PATH=$PWD
if [ -z $TNN_ROOT_PATH ]
then
TNN_ROOT_PATH=$(cd `dirname $0`; pwd)/..
fi

TNN_BUILD_DIR=${TNN_ROOT_PATH}/scripts/build_atlas
TNN_INSTALL_DIR=${TNN_ROOT_PATH}/scripts/release_atlas
if [ $DEBUG == "ON" ]; then
TNN_BUILD_DIR=${TNN_ROOT_PATH}/scripts/build_atlas_debug
TNN_INSTALL_DIR=${TNN_ROOT_PATH}/scripts/release_atlas_debug
fi

TNN_VERSION_PATH=$TNN_ROOT_PATH/scripts/version
echo $TNN_ROOT_PATH
echo $TNN_VERSION_PATH
echo ' '
echo '******************** step 1: update version.h ********************'
cd $TNN_VERSION_PATH
source $TNN_VERSION_PATH/version.sh
source $TNN_VERSION_PATH/add_version_attr.sh

echo ' '
echo '******************** step 2: start build atlas ********************'
#删除旧SDK文件
cd $TNN_BUILD_PATH
if [ -x "build_atlas" ];then
rm -r build_atlas
fi

#新建build目录
mkdir build_atlas
cd build_atlas
mkdir -p ${TNN_BUILD_DIR}
cd ${TNN_BUILD_DIR}


cmake ${TNN_ROOT_PATH} \
-DCMAKE_BUILD_TYPE=Release \
-DDEBUG=$DEBUG \
-DTNN_TEST_ENABLE:BOOL=$TNN_TEST \
-DTNN_BENCHMARK_MODE:BOOL=$BENCHMARK \
-DTNN_TEST_ENABLE:BOOL="ON" \
-DTNN_BENCHMARK_MODE:BOOL="OFF" \
-DTNN_CPU_ENABLE:BOOL="ON" \
-DTNN_ARM_ENABLE:BOOL=$ARM \
-DTNN_OPENMP_ENABLE:BOOL=$OPENMP \
-DTNN_ARM_ENABLE:BOOL="ON" \
-DTNN_OPENMP_ENABLE:BOOL="ON" \
-DTNN_X86_ENABLE:BOOL="OFF" \
-DTNN_BUILD_SHARED:BOOL=$SHARED_LIB \
-DTNN_BUILD_SHARED:BOOL="ON" \
-DCMAKE_SYSTEM_PROCESSOR=$TARGET_ARCH \
-DTNN_ATLAS_ENABLE:BOOL="ON"
make -j8

echo ' '
echo '******************** step 3: add version attr ********************'
#添加版本信息到库文件
cd $TNN_BUILD_PATH
if [ "$SHARED_LIB" = "ON" ];then
AddAllVersionAttr "$TNN_BUILD_PATH/build_atlas/libTNN.so"
AddAllVersionAttr "$TNN_BUILD_PATH/build64/libTNN.so"
else
AddAllVersionAttr "$TNN_BUILD_PATH/build_atlas/libTNN.a"
AddAllVersionAttr "$TNN_BUILD_PATH/build64/libTNN.a"
fi
echo "Building TNN on ATLAS ..."
make -j $(nproc)


echo '******************** step 4: copy to release ********************'
cd $TNN_BUILD_PATH
mkdir -p release_atlas
cd release_atlas
rm -rf *
mkdir lib
cd ..
if [ "$SHARED_LIB" = "ON" ];then
cp -d build_atlas/libTNN.so* release_atlas/lib
else
cp build_atlas/libTNN.a release_atlas/lib
if [ -d ${TNN_INSTALL_DIR} ]
then
rm -rf ${TNN_INSTALL_DIR}
fi
cp -r ${TNN_ROOT_PATH}/include release_atlas
mkdir ${TNN_INSTALL_DIR}
mkdir ${TNN_INSTALL_DIR}/lib
mkdir ${TNN_INSTALL_DIR}/bin

cp -r ${TNN_ROOT_PATH}/include ${TNN_INSTALL_DIR}/
cp ${TNN_BUILD_DIR}/libTNN.so* ${TNN_INSTALL_DIR}/lib
cp ${TNN_BUILD_DIR}/test/TNNTest ${TNN_INSTALL_DIR}/bin

echo "build done!"
echo "Building TNN on ATLAS ... done!"
18 changes: 17 additions & 1 deletion source/tnn/device/atlas/atlas_network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ Status AtlasNetwork::GetAllOutputBlobs(BlobMap &blobs) {
blobs = output_blob_map_;
return TNN_OK;
}

// @brief get atlas model id of current network
uint32_t AtlasNetwork::GetModelId() const {
return this->model_id_;
}

// @brief get atlas model desc of current network
aclmdlDesc* AtlasNetwork::GetModelDesc() const {
return this->model_desc_;
}

Status AtlasNetwork::Reshape(const InputShapesMap &inputs) {
aclError ret = aclrtSetCurrentContext(context_);
Expand Down Expand Up @@ -277,6 +287,12 @@ Status AtlasNetwork::Forward() {
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context & synchronized failed");
}

ret = aclrtSynchronizeStream(stream_);
if (ret != ACL_ERROR_NONE) {
LOGE("before forward synchronize stream failed\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed");
}

ret = aclmdlExecute(model_id_, input_, output_);
if (ret != ACL_ERROR_NONE) {
LOGE("execute model failed, modelId is %u\n", model_id_);
Expand Down Expand Up @@ -619,7 +635,7 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si
LOGE("get batch size failed\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get batch size failed");
}
output_dim0_map_[blob_name] = (int)acl_dims.dims[0] / max_batch;
output_dim0_map_[blob_name] = std::max(1, (int)acl_dims.dims[0] / max_batch);
}
// get data type
data_type = aclmdlGetOutputDataType(model_desc_, index);
Expand Down
6 changes: 6 additions & 0 deletions source/tnn/device/atlas/atlas_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class AtlasNetwork : public AbstractNetwork {
// @param blobs output blobs name map
virtual Status GetAllOutputBlobs(BlobMap &blobs);

// @brief get atlas model id of current network
uint32_t GetModelId() const;

// @brief get atlas model desc of current network
aclmdlDesc* GetModelDesc() const;

private:
// @brief load model from om file
Status LoadModelFromFile(const std::string &om_file);
Expand Down
107 changes: 97 additions & 10 deletions source/tnn/device/atlas/tnn_impl_atlas.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright 2019 Tencent. All Rights Reserved

#include "tnn_impl_atlas.h"
#include "atlas_network.h"
#include "atlas_utils.h"
#include <fstream>
#include "tnn/core/instance.h"
#include "tnn/interpreter/abstract_model_interpreter.h"
Expand Down Expand Up @@ -28,28 +30,101 @@ Status TNNImplAtlas::DeInit() {
}

Status TNNImplAtlas::AddOutput(const std::string& layer_name, int output_index) {
LOGE("Atlas not support this api (AddOutput)\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "Atlas not support this api (AddOutput)");
LOGE("AddOutput() API not supported on TNN ATLAS.\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "AddOutput() API not supported on TNN ATLAS.\n");
}

Status TNNImplAtlas::GetModelInputNames(std::vector<std::string>& input_names) {
LOGE("Atlas not support this api (GetModelInputNames)\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "Atlas not support this api (GetModelInputNames)");
if (this->model_desc_of_the_first_instance_ == nullptr) {
LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
}

size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
std::vector<std::string> in_names_vec;
for (size_t i=0; i<num_inputs; i++) {
std::string input_name;
input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
in_names_vec.emplace_back(input_name);
}
input_names = in_names_vec;

return TNN_OK;
}

Status TNNImplAtlas::GetModelOutputNames(std::vector<std::string>& output_names) {
LOGE("Atlas not support this api (GetModelOutputNames)\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "Atlas not support this api (GetModelOutputNames)");
if (this->model_desc_of_the_first_instance_ == nullptr) {
LOGE("Fail to Get TNN Atlas ModelOutputNames, model desc missing.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelOutputNames, model desc missing.");
}

size_t num_outputs = aclmdlGetNumOutputs(this->model_desc_of_the_first_instance_);
std::vector<std::string> out_names_vec;
for (size_t i=0; i<num_outputs; i++) {
std::string output_name;
output_name.assign(aclmdlGetOutputNameByIndex(this->model_desc_of_the_first_instance_, i));
out_names_vec.emplace_back(output_name);
}
output_names = out_names_vec;

return TNN_OK;
}

Status TNNImplAtlas::GetModelInputShapesMap(InputShapesMap& shapes_map) {
LOGE("Atlas not support this api (GetModelInputShapesMap)\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "Atlas not support this api (GetModelInputShapesMap)");
if (this->model_desc_of_the_first_instance_ == nullptr) {
LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
}

size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
InputShapesMap in_shapes_map;
for (size_t i=0; i<num_inputs; i++) {
aclmdlIODims acl_dims;
aclError acl_ret = aclmdlGetInputDims(this->model_desc_of_the_first_instance_, i, &acl_dims);
if (acl_ret != ACL_ERROR_NONE) {
LOGE("acl get input dim failed (acl error code: %d)\n", acl_ret);
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input dim falied");
}
std::string input_name;
input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
std::vector<int> in_dims;
for (int d=0; d<std::min(int(acl_dims.dimCount),7); d++) { // Max Dim Allowed is 6.
if (acl_dims.dims[d]!=0) {
in_dims.push_back(acl_dims.dims[d]);
} else {
break;
}
}
in_shapes_map[input_name] = in_dims;
}
shapes_map = in_shapes_map;

return TNN_OK;
}

Status TNNImplAtlas::GetModelInputDataTypeMap(InputDataTypeMap& data_type_map) {
LOGE("Atlas not support this api (GetModelInputDataTypeMap)\n");
return Status(TNNERR_DEVICE_NOT_SUPPORT, "Atlas not support this api (GetModelInputDataTypeMap)");
if (this->model_desc_of_the_first_instance_ == nullptr) {
LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
}

size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
InputDataTypeMap in_dtype_map;
for (size_t i=0; i<num_inputs; i++) {
std::string input_name;
input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
aclDataType acl_dtype = aclmdlGetInputDataType(this->model_desc_of_the_first_instance_, i);
DataType tnn_dtype;
aclError acl_ret = ConvertFromAclDataTypeToTnnDataType(acl_dtype, tnn_dtype);
if (acl_ret != ACL_ERROR_NONE) {
LOGE("acl get input data type failed, maybe unsupported data type (acl error code: %d)\n", acl_ret);
return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input data type failed");
}
in_dtype_map[input_name] = tnn_dtype;
}
data_type_map = in_dtype_map;

return TNN_OK;
}

std::shared_ptr<Instance> TNNImplAtlas::CreateInst(NetworkConfig& net_config, Status& status,
Expand All @@ -63,6 +138,18 @@ std::shared_ptr<Instance> TNNImplAtlas::CreateInst(NetworkConfig& net_config, St
InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, InputDataTypeMap inputs_data_type) {
auto instance = std::make_shared<Instance>(net_config, model_config_);
status = instance->Init(interpreter_, min_inputs_shape, max_inputs_shape, inputs_data_type);

AtlasNetwork* atlas_net = reinterpret_cast<AtlasNetwork*>(instance->GetNetwork());
if (this->model_id_of_the_first_instance_ == 0) {
this->model_id_of_the_first_instance_ = atlas_net->GetModelId();
LOGD("TNNImplAtlas init the first Instance, get model id.\n");
}

if (this->model_desc_of_the_first_instance_ == nullptr) {
this->model_desc_of_the_first_instance_ = atlas_net->GetModelDesc();
LOGD("TNNImplAtlas init the first Instance, get model desc.\n");
}

return instance;
}

Expand Down
7 changes: 7 additions & 0 deletions source/tnn/device/atlas/tnn_impl_atlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifndef TNN_SOURCE_DEVICE_ATLAS_TNN_IMPL_ATLAS_H_
#define TNN_SOURCE_DEVICE_ATLAS_TNN_IMPL_ATLAS_H_

#include "acl/acl.h"
#include "tnn/core/macro.h"
#include "tnn/core/tnn_impl.h"

Expand Down Expand Up @@ -68,6 +69,12 @@ class TNNImplAtlas : public TNNImpl {

private:
std::shared_ptr<AbstractModelInterpreter> interpreter_;

// Model Desc and Model id for the first instance.
// Set when the first Effective CreateInst is called.
// Usage: Get input/output names, shapes, datatypes ... etc.
uint32_t model_id_of_the_first_instance_ = 0;
aclmdlDesc* model_desc_of_the_first_instance_ = nullptr;
};

} // namespace TNN_NS
Expand Down

0 comments on commit 97e7c1b

Please sign in to comment.