• Home
  • About
    • C-H Kim's blog photo

      C-H Kim's blog

      Anyone is welcome who loving C / C++, Python, java... ect.

    • Learn More
    • Email
    • LinkedIn
    • Github
    • StackOverflow
    • Steam
  • Posts
    • All Posts
    • All Tags
  • Projects

ONNX Runtime C++ Ort::Session 사용하기

29 Jan 2022

Reading time ~1 minute

Ort::Session은 inference하기 위한 거의 마지막 작업이다. Session은 모델을 로드하고 tensor를 입력 받아 inference한다.

Ort::Session 생성 및 사용

이전 포스트에서 다뤘던 Ort::Env, Ort::SessionOptions와 model path를 인자로 받는다. 특히 Ort::Env의 경우 Session과 생명주기가 같아야 한다. 즉 Env가 소멸되지 않아야 Session을 사용할 수 있다.

Session 생성

이전 포스트에서 다룬 객체를 초기화했다면 다음의 생성자를 이용해 Session을 생성할 수 있다.

Ort::Session(Ort::Env&, const char *model_path, const Ort::SessionOptions&);

이 생성자는 모델 파일을 불러오지만 메모리 상에 존재하는 모델을 이용하는 생성자도 있다. 필요하면 차후에 작성하도록 하겠다. (2022-01-29)

Model 정보 얻기

필요한 Model의 정보는 다음과 같다.

  • input/output의 개수
  • input/output의 모양
  • input/output의 이름
  • input/output의 type
  1. input/output의 개수
    size_t input_number = session.GetInputCount();
    size_t output_number = session.GetOutputCount();
    

    매우 간단하게 사용가능하다.

  2. input/output의 이름
    Ort::AllocatorWithDefaultOptions allocator;
    
    // 0번째 input의 이름
    const char *input_name = session.GetInputName(0, allocator);
    // 0번째 output의 이름
    const char *output_name = session.GetOutputName(0, allocator);
    
    // C가 생각나는 끔찍한 방식의 메모리 해제
    allocator.Free((void*) input_name);
    allocator.Free((void*) output_name);
    

    allocator로 할당한 이름을 Free로 해제하는 이 끔찍한 상황은 ONNX Runtime C++ API가 C wrapper이기 때문에 발생한 문제이다.

  3. input/output의 type, shape
    Ort::TypeInfo input_type_info = session.GetInputTypeInfo(0); // 0번째 입력에 대한 type 정보
    Ort::TensorTypeAndShapeInfo input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
    

    이렇게 모델이 요구하는 type에 대한 정보를 알아낼 수 있다. Ort::TypeInfo는 이것 이외에도 다른 정보도 제공하지만 생략하겠다.

    ONNXTensorElementDataType elem_type = input_tensor_info.GetElementType();
    
    if (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT == elem_type)
      {
        std::cout << "tensor which consists of float!\n";
      }
    

    위의 방법으로 요구하는 입력의 자료형을 알아낼 수 있다.

    ONNXTensorElementDataType의 종류는 다음과 같다.(enum ONNXTensorElementDataType)

    • ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
    • ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16

    다만 지원하지 않는 type이 끼여 있는데 complex64, complex128은 지원하지 않는다고 한다.

    std::vector<int64_t> input_shape = input_tensor_info.GetShape();
    

    이렇게 tensor의 모양까지 알아낼 수 있었다. output의 경우 처음 과정에 GetOutputInfo를 사용하면 된다.

이제 남은 부분은 inference다!
inference는 다음과 같이 진행된다.

  1. 전처리
  2. tensor 생성
  3. inference
  4. 결과 사용
Index: ONNX Runtime C++ API 사용법 Next: ONNX Runtime C++ Inference 하기


C++ONNX Runtime Share Tweet