Ort::Session
은 inference하기 위한 거의 마지막 작업이다. Session은 모델을 로드하고 tensor를 입력 받아 inference한다.
Ort::Session
생성 및 사용
이전 포스트에서 다뤘던 Ort::Env
, Ort::SessionOptions
와 model path를 인자로 받는다. 특히 Ort::Env
의 경우 Session과 생명주기가 같아야 한다. 즉 Env가 소멸되지 않아야 Session을 사용할 수 있다.
Session 생성
이전 포스트에서 다룬 객체를 초기화했다면 다음의 생성자를 이용해 Session을 생성할 수 있다.
Ort::Session(Ort::Env&, const char *model_path, const Ort::SessionOptions&);
이 생성자는 모델 파일을 불러오지만 메모리 상에 존재하는 모델을 이용하는 생성자도 있다. 필요하면 차후에 작성하도록 하겠다. (2022-01-29)
Model 정보 얻기
필요한 Model의 정보는 다음과 같다.
- input/output의 개수
- input/output의 모양
- input/output의 이름
- input/output의 type
- input/output의 개수
size_t input_number = session.GetInputCount(); size_t output_number = session.GetOutputCount();
매우 간단하게 사용가능하다.
- input/output의 이름
Ort::AllocatorWithDefaultOptions allocator; // 0번째 input의 이름 const char *input_name = session.GetInputName(0, allocator); // 0번째 output의 이름 const char *output_name = session.GetOutputName(0, allocator); // C가 생각나는 끔찍한 방식의 메모리 해제 allocator.Free((void*) input_name); allocator.Free((void*) output_name);
allocator로 할당한 이름을
Free
로 해제하는 이 끔찍한 상황은 ONNX Runtime C++ API가 C wrapper이기 때문에 발생한 문제이다. - input/output의 type, shape
Ort::TypeInfo input_type_info = session.GetInputTypeInfo(0); // 0번째 입력에 대한 type 정보 Ort::TensorTypeAndShapeInfo input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
이렇게 모델이 요구하는 type에 대한 정보를 알아낼 수 있다.
Ort::TypeInfo
는 이것 이외에도 다른 정보도 제공하지만 생략하겠다.ONNXTensorElementDataType elem_type = input_tensor_info.GetElementType(); if (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT == elem_type) { std::cout << "tensor which consists of float!\n"; }
위의 방법으로 요구하는 입력의 자료형을 알아낼 수 있다.
ONNXTensorElementDataType의 종류는 다음과 같다.(enum ONNXTensorElementDataType)
- ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
- ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
- ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8
- ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
- ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
- ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
- ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
- ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE
- ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32
- ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64
- ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
- ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
- ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
다만 지원하지 않는 type이 끼여 있는데 complex64, complex128은 지원하지 않는다고 한다.
std::vector<int64_t> input_shape = input_tensor_info.GetShape();
이렇게 tensor의 모양까지 알아낼 수 있었다. output의 경우 처음 과정에
GetOutputInfo
를 사용하면 된다.
이제 남은 부분은 inference다!
inference는 다음과 같이 진행된다.
- 전처리
- tensor 생성
- inference
- 결과 사용