Image Classification
BiWAKO.ResNet
Basic ResNet V2 model trained on ImageNet.
Attributes:
Name | Type | Description |
---|---|---|
model_path | str | Path to the model file. If automatic download is not enabled, this path is used to save the file. |
model | onnxruntime.InferenceSession | Inference session for the model. |
input_name | str | Name of the input node. |
output_name | str | Name of the output node. |
input_shape | tuple | Shape of the input node. |
label | dict | Dictionary of class labels. The key is the class index and the value is the class name. |
mean | np.ndarray | Mean used for input normalization. |
var | np.ndarray | Variance used for input normalization. |
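The `mean` and `var` attributes suggest that the input image is normalized per channel before it is fed to the inference session. The sketch below shows one common way to turn an RGB image into a normalized 1x3xHxW tensor. It is not BiWAKO's code: the `to_nchw_input` helper, the ImageNet statistics, and dividing by per-channel standard deviations (the usual ImageNet convention, even though the attribute is named `var`) are all assumptions. Resizing to `input_shape` is omitted, and a cv2 image (BGR) would need its channels reversed first, for example with `image[..., ::-1]`.

```python
import numpy as np

# Assumed ImageNet statistics; not read from BiWAKO.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
# Assumption: per-channel standard deviations, used as the divisor below.
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_nchw_input(image_rgb: np.ndarray, mean: np.ndarray = MEAN, std: np.ndarray = STD) -> np.ndarray:
    """Normalize an HxWx3 uint8 RGB image and return a 1x3xHxW float32 tensor.

    Hypothetical helper; assumes the image is already resized to the model input size.
    """
    x = image_rgb.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    x = (x - mean) / std                      # per-channel normalization, broadcast over H and W
    x = np.transpose(x, (2, 0, 1))            # HWC -> CHW
    return x[np.newaxis, ...]                 # add the batch dimension -> NCHW
```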
__init__(self, model='resnet18v2')
Initialize ResNet.
Available models: "resnet152v2", "resnet101v2", "resnet50v2", or "resnet18v2".
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | str | Model name from the list above, or path to a downloaded ONNX file. If the file has not been downloaded yet, automatic download is triggered. Defaults to "resnet18v2". | 'resnet18v2' |
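The `model` argument accepts either a model name or a path to an ONNX file, and a missing file triggers a download. The following is a minimal sketch of how such an argument could be resolved; the `resolve_model` helper, the `cache_dir` parameter, and the `<name>.onnx` file naming are hypothetical, and the download itself is not implemented, only flagged.

```python
import os

# Model names listed in the documentation above.
AVAILABLE_MODELS = ("resnet152v2", "resnet101v2", "resnet50v2", "resnet18v2")


def resolve_model(model: str = "resnet18v2", cache_dir: str = ".") -> tuple[str, bool]:
    """Map the `model` argument to (onnx_path, needs_download).

    Hypothetical helper: the "<name>.onnx" naming and `cache_dir` are assumptions,
    not BiWAKO's actual behavior.
    """
    if model in AVAILABLE_MODELS:
        path = os.path.join(cache_dir, f"{model}.onnx")
    elif model.endswith(".onnx"):
        path = model  # treat the argument as a path to an ONNX file
    else:
        raise ValueError(f"Unknown model {model!r}; choose one of {AVAILABLE_MODELS} or pass an .onnx path")
    # A download would only be needed when the file does not exist yet.
    return path, not os.path.exists(path)
```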
predict(self, image)
Return the prediction of the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image | Image | Image to be predicted. Accepts a file path or a cv2 image. | required |
Returns:
Type | Description |
---|---|
np.ndarray | 1 x 1000 array of raw class scores. Softmax is not applied. |
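Because `predict()` returns raw scores, converting them to probabilities requires a softmax. A minimal NumPy sketch follows; the random `scores` array is only a stand-in for a real `predict()` output.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the last axis, e.g. for a (1, 1000) score array."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Stand-in for the output of predict(); real scores come from the model.
scores = np.random.default_rng(0).normal(size=(1, 1000)).astype(np.float32)
probs = softmax(scores)
top1 = int(np.argmax(probs, axis=-1)[0])  # class index; look it up in the `label` dict
print(top1, float(probs[0, top1]))
```

Softmax is monotonic, so the top-1 index is the same whether it is taken from the raw scores or the probabilities; the softmax is only needed when the probability values themselves matter.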
render(self, prediction, image, topk=5, **kwargs)
Return the original image annotated with the predicted class names.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prediction | np.ndarray | Prediction returned by predict(). | required |
image | Image | Image to annotate. Accepts a file path or a cv2 image. | required |
topk | int | Number of highest-scoring classes to display. Defaults to 5. | 5 |
Returns:
Type | Description |
---|---|
np.ndarray | Image annotated with the predicted class names. |
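Before drawing, `render()` has to pick the `topk` classes to show. The sketch below covers only that selection step, using the `label` dictionary layout described above (class index to class name); the `topk_labels` helper and the "name: probability" text format are assumptions, and the actual drawing onto the image is omitted.

```python
import numpy as np


def topk_labels(prediction: np.ndarray, label: dict[int, str], topk: int = 5) -> list[str]:
    """Return "name: probability" strings for the `topk` highest-scoring classes.

    Hypothetical helper; the output format is an assumption.
    """
    scores = prediction.reshape(-1).astype(np.float64)  # (1, N) -> (N,)
    probs = np.exp(scores - scores.max())
    probs /= probs.sum()                                # softmax, since predict() returns raw scores
    order = np.argsort(probs)[::-1][:topk]              # indices in descending probability order
    return [f"{label[int(i)]}: {probs[i]:.3f}" for i in order]


# Tiny illustrative label dict and scores (not real model output).
label = {0: "tench", 1: "goldfish", 2: "great white shark"}
print(topk_labels(np.array([[0.1, 2.0, 1.0]]), label, topk=2))
```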