Tensorflow:estimator训练

Tensorflow:estimator训练学习流程:Estimator封装了对机器学习不同阶段的控制,用户无需不断的为新机器学习任务重复编写训练、评估、预测的代码。可以专注于对网络结构的控制。数据导入:Estimator的数据导入也是由input_fn独立定义的。例如,用户可以非常方便的只通过改变input_fn的定义,来使用相同的网络结构学习不同的数据。网络结构:Estimator的网络结构是在model_fn中独…

大家好,欢迎来到IT知识分享网。

学习流程:Estimator 封装了对机器学习不同阶段的控制,用户无需不断的为新机器学习任务重复编写训练、评估、预测的代码。可以专注于对网络结构的控制。
数据导入:Estimator 的数据导入也是由 input_fn 独立定义的。例如,用户可以非常方便的只通过改变 input_fn 的定义,来使用相同的网络结构学习不同的数据。
网络结构:Estimator 的网络结构是在 model_fn 中独立定义的,用户创建的任何网络结构都可以在 Estimator 的控制下进行机器学习。这可以允许用户很方便的使用别人定义好的 model_fn。model_fn模型函数必须要有features, mode两个参数,可自己选择加入labels(可以把label也放进features中)。最后要返回特定的tf.estimator.EstimatorSpec()。模型有三个阶段都共用的正向传播部分,和由mode值来控制返回不同tf.estimator.EstimatorSpec的三个分支。

 

 

训练

输出信息解析

[Tensorflow:模型训练tensorflow.train]

在训练或评估中利用Hook打印中间信息

hooks:如果不送值,则训练过程中不会显示字典中的数值。

steps:指定了训练多少次,如果不送值,则训练到dataset API遍历完数据集为止。

max_steps:指定了最大训练次数。

# 在训练或评估的循环中,每50次print出一次字典中的数值
tensors_to_log = {“probabilities”: “softmax_tensor”}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

 

early stopping

函数原型

tf.contrib.estimator.stop_if_no_increase_hook(
    estimator,
    metric_name,
    max_steps_without_increase,
    eval_dir=None,
    min_steps=0,
    run_every_secs=60,
    run_every_steps=None
)

‘stop_if_no_decrease_hook’这个模块在tf 1.10才加入。hook可以看作一个管理训练过程的工具,比如说这里就是设置提前终止的条件,变量loss在100000步以内没有下降即终止,实际上更广泛的用法是用在对测试集的f1值上。

参数

metric_name: str类型,比如loss或者accuracy. hook中的参数metric_name=’acc’就是tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)中的eval_metric_ops,即tf模块代码中通过的for step, metrics in read_eval_metrics(eval_dir).items()得到的。但是训练好checkpoint后,就不能改,需要删除之前训练好的模型,重新训练。

max_steps_without_increase: int,如果没有增加的最大长是多少,如果超过了这个最大步长metric还是没有增加那么就会停止。

eval_dir:默认是使用estimator.eval_dir目录,用于存放评估的summary file。

min_steps:训练的最小步长,如果训练小于这个步长那么永远都不会停止。

run_every_secs和run_every_steps:表示多长时间获得步长调用一次should_stop_fn。

示例

        metrics = {

            ‘acc’: tf.metrics.accuracy(tf.argmax(labels), tf.argmax(pred_ids)),
            ‘precision’: tf.metrics.precision(tf.argmax(labels), tf.argmax(pred_ids)),
            ‘precision_’: tf_metrics.precision(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            ‘recall’: tf.metrics.recall(tf.argmax(labels), tf.argmax(pred_ids)),
            ‘recall_’: tf_metrics.recall(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            ‘f1_’: tf_metrics.f1(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            ‘auc’: tf.metrics.auc(labels, pred_ids),
        }

        for metric_name, op in metrics.items():
            tf.summary.scalar(metric_name, op[1])

        ”’ train and evaluate ”’
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
        elif mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.train.AdamOptimizer().minimize(loss=loss,
                                                         global_step=tf.train.get_or_create_global_step())

        hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, ‘f1’, max_steps_without_increase=1000,
                                                             min_steps=8000, run_every_secs=120) 
        train_spec = tf.estimator.TrainSpec(input_fn=train_inpf, hooks=[hook])

[简书tf.estimate]

from: -柚子皮-

ref:

 

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/23463.html

(0)

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信