Test Model ? #4

Open
Alissonerdx opened this issue May 22, 2018 · 2 comments

Comments

@Alissonerdx

I'm sorry for the question, but how do I test the generated model?

@Lebhoryi

Lebhoryi commented Nov 5, 2018

Me too. Have you solved it, bro?

@Asuna88

Asuna88 commented May 14, 2021

> I'm sorry for the question, but how do I test the generated model?

Add the following code to `train.py` to run predictions:
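```python
# Imports needed by the functions below (train.py may already have some of them)
import os

import numpy as np
import keras
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import CustomObjectScope


def generate_test(batch, size):
    ptest = 'data/test_'
    datagen2 = ImageDataGenerator(rescale=1. / 255)
    # shuffle=False keeps the prediction order aligned with test_generator.filenames
    test_generator = datagen2.flow_from_directory(
        ptest,
        target_size=(size, size),
        batch_size=batch,
        class_mode='categorical',
        shuffle=False)

    # Count every file under the test directory (one file per image)
    count = 0
    for root, dirs, files in os.walk(ptest):
        count += len(files)
    return test_generator, count


def test(weights, batch=1, size=224):
    test_gen, count = generate_test(batch, size)
    # MobileNet's relu6 activation and DepthwiseConv2D layer must be registered
    # as custom objects so that load_model can deserialize the saved model
    with CustomObjectScope({'relu6': keras.layers.ReLU(max_value=6, name="relu6"),
                            'DepthwiseConv2D': keras.layers.DepthwiseConv2D}):
        model = load_model(weights)
    print("test")

    # One row of class scores per test image
    predictions = model.predict(test_gen, steps=count // batch, verbose=1)
    # Loss followed by the metrics the model was compiled with
    evaluate_result = model.evaluate(test_gen, steps=count // batch, verbose=1)
    print("Predicted classes:", np.argmax(predictions, axis=-1),
          "Prediction shape:", predictions.shape)
    print(dict(zip(model.metrics_names, evaluate_result)))
    print("Done!")
```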

Finally, call `test()` in the main block:

```python
if __name__ == '__main__':
    main(sys.argv)
    weights = 'trained_model/all_weights.h5'
    test(weights)
```
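`generate_test` counts every file under `ptest` with `os.walk`, and `steps=count // batch` uses floor division. When the number of test images is not a multiple of `batch`, the final partial batch is therefore never predicted or evaluated. With the default `batch=1`, nothing is lost. The sketch below uses the hypothetical helpers `count_and_steps` and `make_demo_dir`, which are not part of the repository. It compares the floor step count with a `math.ceil` step count that covers every image:

```python
import math
import os
import tempfile


def count_and_steps(root, batch):
    """Hypothetical helper: count every file under `root` the same way the
    os.walk loop in generate_test does, then return
    (count, floor_steps, ceil_steps)."""
    count = sum(len(files) for _, _, files in os.walk(root))
    floor_steps = count // batch           # as in steps=count//batch: drops a partial batch
    ceil_steps = math.ceil(count / batch)  # visits every image; the last batch may be smaller
    return count, floor_steps, ceil_steps


def make_demo_dir(root, counts):
    """Hypothetical helper: create empty placeholder files in one subfolder per
    class, mimicking the layout flow_from_directory expects."""
    for cls, n in counts.items():
        os.makedirs(os.path.join(root, cls), exist_ok=True)
        for i in range(n):
            open(os.path.join(root, cls, f"{i}.jpg"), "w").close()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        make_demo_dir(root, {"cat": 6, "dog": 4})  # 10 files in total
        print(count_and_steps(root, 4))            # (10, 2, 3)
```

With `batch=4` and 10 images, floor division gives 2 steps (8 images), while `math.ceil` gives 3. Keras directory iterators serve a smaller final batch, so the ceiling value is generally the one you want for `steps`.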

In short:

`model.fit()` is used for training:

```python
hist = model.fit(train_generator, validation_data=validation_generator,
                 steps_per_epoch=count1 // batch, validation_steps=count2 // batch,
                 epochs=epochs, callbacks=[earlystop])
```

`model.predict()` returns the predicted values:

```python
predictions = model.predict(test_gen, steps=count // batch, verbose=1)
```
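`model.predict` returns one row of class scores per image, and `np.argmax(..., axis=-1)` turns each row into a class index. To see which folder name each index stands for, the index can be mapped back through the generator's `class_indices` dict, which `flow_from_directory` builds as folder name → index. This only lines up with `test_gen.filenames` when the generator was created with `shuffle=False`, because the default is `shuffle=True`. A minimal sketch with a hypothetical helper `label_predictions`:

```python
import numpy as np


def label_predictions(predictions, class_indices, filenames):
    """Hypothetical helper: pair each filename with its predicted class name.

    predictions:   array of shape (n_images, n_classes) from model.predict
    class_indices: dict folder_name -> index, e.g. test_gen.class_indices
    filenames:     list of paths, e.g. test_gen.filenames (needs shuffle=False)
    zip() stops at the shorter input, e.g. when steps skipped a partial batch.
    """
    index_to_class = {idx: name for name, idx in class_indices.items()}
    predicted_idx = np.argmax(predictions, axis=-1)  # most probable class per row
    return [(f, index_to_class[int(i)]) for f, i in zip(filenames, predicted_idx)]


# Example with made-up scores for three images and two classes
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
print(label_predictions(probs, {"cat": 0, "dog": 1},
                        ["cat/1.jpg", "dog/2.jpg", "dog/3.jpg"]))
# [('cat/1.jpg', 'cat'), ('dog/2.jpg', 'dog'), ('dog/3.jpg', 'cat')]
```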

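`model.evaluate()` returns the metric values (the loss followed by the compiled metrics, in the order of `model.metrics_names`):

```python
evaluate_result = model.evaluate(test_gen, steps=count // batch, verbose=1)
```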
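The compile settings are not shown in this thread. Assuming `loss='categorical_crossentropy'` and `metrics=['accuracy']`, the two numbers `model.evaluate` reports can be recomputed from the `model.predict` output as a sanity check. When `shuffle=False`, one-hot labels in generator order could come from `np.eye(num_classes)[test_gen.classes]`. `manual_metrics` is a hypothetical helper. Keras also clips probabilities internally, so expect small floating-point differences rather than bit-exact agreement:

```python
import numpy as np


def manual_metrics(predictions, y_true, eps=1e-7):
    """Hypothetical helper: recompute categorical cross-entropy and accuracy.

    Assumes loss='categorical_crossentropy', metrics=['accuracy'], and that
    y_true holds one-hot labels in the same row order as `predictions`.
    """
    p = np.clip(predictions, eps, 1.0 - eps)  # avoid log(0)
    loss = float(np.mean(-np.sum(y_true * np.log(p), axis=-1)))
    hits = np.argmax(predictions, axis=-1) == np.argmax(y_true, axis=-1)
    return {"loss": loss, "accuracy": float(np.mean(hits))}


probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = np.eye(2)[[0, 1, 1]]  # one-hot rows for classes 0, 1, 1
print(manual_metrics(probs, labels))  # accuracy is 2/3: the third image is misclassified
```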
