大家好,欢迎来到IT知识分享网。
学习流程:Estimator 封装了对机器学习不同阶段的控制,用户无需不断的为新机器学习任务重复编写训练、评估、预测的代码。可以专注于对网络结构的控制。
数据导入:Estimator 的数据导入也是由 input_fn 独立定义的。例如,用户可以非常方便的只通过改变 input_fn 的定义,来使用相同的网络结构学习不同的数据。
网络结构:Estimator 的网络结构是在 model_fn 中独立定义的,用户创建的任何网络结构都可以在 Estimator 的控制下进行机器学习。这可以允许用户很方便的使用别人定义好的 model_fn。model_fn模型函数必须要有features, mode两个参数,可自己选择加入labels(可以把label也放进features中)。最后要返回特定的tf.estimator.EstimatorSpec()。模型有三个阶段都共用的正向传播部分,和由mode值来控制返回不同tf.estimator.EstimatorSpec的三个分支。
训练
输出信息解析
[Tensorflow:模型训练tensorflow.train]
在训练或评估中利用Hook打印中间信息
hooks:如果不送值,则训练过程中不会显示字典中的数值。
steps:指定了训练多少次,如果不送值,则训练到dataset API遍历完数据集为止。
max_steps:指定了最大训练次数。
# 在训练或评估的循环中,每50次print出一次字典中的数值
tensors_to_log = {“probabilities”: “softmax_tensor”}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
early stopping
函数原型
tf.contrib.estimator.stop_if_no_increase_hook(
estimator,
metric_name,
max_steps_without_increase,
eval_dir=None,
min_steps=0,
run_every_secs=60,
run_every_steps=None
)
‘stop_if_no_decrease_hook’这个模块在tf 1.10才加入。hook可以看作一个管理训练过程的工具,比如说这里就是设置提前终止的条件,变量loss在100000步以内没有下降即终止,实际上更广泛的用法是用在对测试集的f1值上。
参数
metric_name: str类型,比如loss或者accuracy. hook中的参数metric_name=’acc’就是tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)中的eval_metric_ops,即tf模块代码中通过的for step, metrics in read_eval_metrics(eval_dir).items()得到的。但是训练好checkpoint后,就不能改,需要删除之前训练好的模型,重新训练。
max_steps_without_increase: int,如果没有增加的最大长是多少,如果超过了这个最大步长metric还是没有增加那么就会停止。
eval_dir:默认是使用estimator.eval_dir目录,用于存放评估的summary file。
min_steps:训练的最小步长,如果训练小于这个步长那么永远都不会停止。
run_every_secs和run_every_steps:表示多长时间获得步长调用一次should_stop_fn。
示例
metrics = {
‘acc’: tf.metrics.accuracy(tf.argmax(labels), tf.argmax(pred_ids)),
‘precision’: tf.metrics.precision(tf.argmax(labels), tf.argmax(pred_ids)),
‘precision_’: tf_metrics.precision(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
‘recall’: tf.metrics.recall(tf.argmax(labels), tf.argmax(pred_ids)),
‘recall_’: tf_metrics.recall(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
‘f1_’: tf_metrics.f1(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
‘auc’: tf.metrics.auc(labels, pred_ids),
}
for metric_name, op in metrics.items():
tf.summary.scalar(metric_name, op[1])
”’ train and evaluate ”’
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
elif mode == tf.estimator.ModeKeys.TRAIN:
train_op = tf.train.AdamOptimizer().minimize(loss=loss,
global_step=tf.train.get_or_create_global_step())
…
hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, ‘f1’, max_steps_without_increase=1000,
min_steps=8000, run_every_secs=120)
train_spec = tf.estimator.TrainSpec(input_fn=train_inpf, hooks=[hook])
[简书tf.estimate]
from: -柚子皮-
ref:
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/23463.html