# Class that predicts a tag (class label) for an image.
class ImageTagExtractor(object):
    """Predict an image's tag from a trained classifier and a visual codebook.

    Both artifacts are unpickled from files, presumably written by a separate
    training step (e.g. ``ERFTrainer``) -- confirm against the training code.
    """

    def __init__(self, model_file, codebook_file):
        """Load the trained classifier and the codebook.

        Args:
            model_file: path to a pickle holding the trained classifier. It
                must expose a ``classify(feature_vector)`` method.
            codebook_file: path to a pickle holding a ``(kmeans, centroids)``
                pair used when building the feature vector.
        """
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file -- only load model/codebook files from a trusted source.
        with open(model_file, 'rb') as f:
            self.erf = pickle.load(f)
        with open(codebook_file, 'rb') as f:
            self.kmeans, self.centroids = pickle.load(f)

    def predict(self, img, scaling_size):
        """Return the predicted tag for ``img`` using the loaded model.

        Args:
            img: image array as returned by ``cv2.imread``.
            scaling_size: size passed to ``bf.resize_image`` before feature
                extraction.

        Returns:
            The first element of ``self.erf.classify(feature_vector)``.

        Raises:
            ValueError: if ``img`` is None (what ``cv2.imread`` returns when it
                cannot read a file).
        """
        if img is None:
            raise ValueError("img is None; the input image could not be loaded")
        img = bf.resize_image(img, scaling_size)
        feature_vector = bf.BagOfWords().construct_feature(
            img, self.kmeans, self.centroids)
        # Run the classifier exactly once; its first element is taken as the tag.
        return self.erf.classify(feature_vector)[0]
# Main program (主程序)
import argparse
import pickle as pickle
import cv2
import numpy as np
from python_machine_learn.c9 import build_feature as bf
from python_machine_learn.c9.train import ERFTrainer
from sklearn.ensemble import ExtraTreesClassifier
from sklearn import preprocessing
from sklearn.cluster import KMeans
def run_demo(input_image='image/car_2.jpg',
             model_file='model/train_1.pkl',
             codebook_file='codebook/9_8.pkl',
             scaling_size=200):
    """Classify a sample image using hard-coded default artifact paths.

    Kept outside the ``__main__`` guard so that running the script executes
    only the argparse entry point; call this explicitly for a quick manual
    check with the bundled sample files.

    Args:
        input_image: path of the image to classify.
        model_file: path of the pickled trained model.
        codebook_file: path of the pickled ``(kmeans, centroids)`` codebook.
        scaling_size: size passed through to ``ImageTagExtractor.predict``.

    Returns:
        The tag returned by ``ImageTagExtractor.predict`` (also printed).

    Raises:
        ValueError: if the image cannot be read.
    """
    img = cv2.imread(input_image)
    # cv2.imread returns None rather than raising on a missing/unreadable file.
    if img is None:
        raise ValueError("could not read input image: %s" % input_image)
    tag = ImageTagExtractor(model_file, codebook_file).predict(img, scaling_size)
    print("\nOutput:", tag)
    return tag
# Build the command-line argument parser.
def build_arg_parser():
    """Return the argument parser for the image-tagging command line.

    All three options are required; ``parse_args`` exits with a usage error
    (``SystemExit``) if any of them is missing.
    """
    parser = argparse.ArgumentParser(
        description='Extracts features from an input image and classifies it '
                    'with a trained model')
    parser.add_argument('--input-image', dest='input_image', required=True,
                        help='Input image to be classified')
    parser.add_argument('--model-file', dest='model_file', required=True,
                        help='Input file containing the trained model')
    parser.add_argument('--codebook-file', dest='codebook_file', required=True,
                        help='Input file containing the codebook')
    return parser
if __name__ == '__main__':
    # Command-line entry point: tag one image using a saved model and codebook.
    parser = build_arg_parser()
    args = parser.parse_args()
    input_image = cv2.imread(args.input_image)
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    if input_image is None:
        parser.error('could not read input image: %s' % args.input_image)
    # NOTE(review): presumably must match the size used at training time -- confirm.
    scaling_size = 200
    extractor = ImageTagExtractor(args.model_file, args.codebook_file)
    print("\nOutput:", extractor.predict(input_image, scaling_size))
# Result (结果)