Heim  >  Artikel  >  Backend-Entwicklung  >  Beherrschen Sie schnell die Hook-Funktion in Python

Beherrschen Sie schnell die Hook-Funktion in Python

coldplay.xixi
coldplay.xixinach vorne
2020-12-11 17:11:218295Durchsuche

Die Spalte „Python-Tutorial“ stellt die Hook-Hook-Funktion in Python vor. Viele kostenlose Lernempfehlungen finden Sie im Python-Tutorial 1. Ja Haken

Ich höre oft das Konzept der Hook-Funktion. Kürzlich habe ich mir das Open-Source-Framework mmdetection zur Zielerkennung angesehen, und es enthält auch viele Hook-Programmiermethoden. Welche Funktion hat der Haken? Beherrschen Sie schnell die Hook-Funktion in Python

Was ist Hook? Unter Haken versteht man, wie der Name schon sagt, einen Haken, der dazu dient, bei Bedarf etwas aufzuhängen. Die spezifische Erklärung lautet: Die Hook-Funktion besteht darin, unsere eigene implementierte Hook-Funktion zu einem bestimmten Zeitpunkt an den Ziel-Mount-Punkt anzuhängen. Die Rolle der Hook-Funktion Das Konzept des Hooks ist beispielsweise in der Windows-Desktop-Softwareentwicklung sehr verbreitet, insbesondere der Mechanismus verschiedener Ereignisauslöser. Im C++-MFC-Programm ist es beispielsweise erforderlich, die Zeit zu überwachen Wenn die linke Maustaste gedrückt wird, stellt MFC eine onLeftKeyDown-Hook-Funktion bereit. Offensichtlich implementiert das MFC-Framework die spezifische Operation von onLeftKeyDown nicht für uns, sondern stellt uns nur einen Hook zur Verfügung. Wenn wir es verarbeiten müssen, müssen wir diese Funktion nur neu schreiben und die Operation, die wir benötigen, in diesem Hook bereitstellen Nicht mounten, der MFC-Ereignisauslösemechanismus führt leere Vorgänge aus. Aus dem Obigen ist ersichtlich, dass

die Hook-Funktion eine vordefinierte Funktion im Programm ist. Diese Funktion befindet sich im ursprünglichen Programmprozess (Freilegen eines Hooks).
Wir müssen den Hook definieren Der vorhandene Prozess Um ein bestimmtes Detail im Funktionsblock zu implementieren, müssen wir unsere Implementierung in den Hook einbinden oder registrieren, um die Hook-Funktion für das Ziel verfügbar zu machen. Der Hook ist ein Programmiermechanismus und steht nicht in direktem Zusammenhang mit der spezifischen Sprache Beziehung

    Wenn Sie sich den Designmodus ansehen, ist der Hook-Modus eine Erweiterung der Vorlagenmethode
  • Der Hook wird nur verwendet, wenn er registriert ist, also im ursprünglichen Programmprozess, wenn keine Registrierung oder Montage erfolgt , die Ausführung ist leer (dh es wird keine Operation ausgeführt)
  • Dieser Artikel verwendet Python, um die Implementierung von Hooks zu erläutern, und zeigt die Anwendungsfälle von Hooks in Open-Source-Projekten. Die Funktion der Hook-Funktion ähnelt einem anderen Namen, den wir oft hören: Rückruffunktion (Callback-Funktion) und kann nach demselben Modell verstanden werden.

  • 2. Hook-Implementierungsbeispiel

    Soweit ich weiß, wird die Hook-Funktion am häufigsten in irgendeiner Art von Prozessverarbeitung verwendet. Dieser Prozess besteht oft aus vielen Schritten. In diesen Schritten werden häufig Hook-Funktionen eingebunden, um Flexibilität für das Hinzufügen zusätzlicher Vorgänge zu bieten.
  • Das Folgende ist ein einfaches Beispiel. Der Zweck dieses Beispiels besteht darin, eine universelle Funktion zum Einfügen von Inhalten in die Warteschlange zu implementieren. Es gibt 2 Prozessschritte

  • Sie müssen die Daten filtern, bevor Sie sie in die Warteschlange einfügen input_filter_fn

  • In die Warteschlange einfügen insert_queue

  • class ContentStash(object):
        """
        content stash for online operation
        pipeline is
        1. input_filter: filter some contents, no use to user
        2. insert_queue(redis or other broker): insert useful content to queue
        """
    
        def __init__(self):
            self.input_filter_fn = None
            self.broker = []
    
        def register_input_filter_hook(self, input_filter_fn):
            """
            register input filter function, parameter is content dict
            Args:
                input_filter_fn: input filter function
    
            Returns:
    
            """
            self.input_filter_fn = input_filter_fn
    
        def insert_queue(self, content):
            """
            insert content to queue
            Args:
                content: dict
    
            Returns:
    
            """
            self.broker.append(content)
    
        def input_pipeline(self, content, use=False):
            """
            pipeline of input for content stash
            Args:
                use: is use, defaul False
                content: dict
    
            Returns:
    
            """
            if not use:
                return
    
            # input filter
            if self.input_filter_fn:
                _filter = self.input_filter_fn(content)
                
            # insert to queue
            if not _filter:
                self.insert_queue(content)
    
    
    
    # test
    ## 实现一个你所需要的钩子实现:比如如果content 包含time就过滤掉,否则插入队列
    def input_filter_hook(content):
        """
        test input filter hook
        Args:
            content: dict
    
        Returns: None or content
    
        """
        if content.get('time') is None:
            return
        else:
            return content
    
    
    # 原有程序
    content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probility": 0.9}}
    content_stash = ContentStash('audit', work_dir='')
    
    # 挂上钩子函数, 可以有各种不同钩子函数的实现,但是要主要函数输入输出必须保持原有程序中一致,比如这里是content
    content_stash.register_input_filter_hook(input_filter_hook)
    
    # 执行流程
    content_stash.input_pipeline(content)
  • 3 . Der Hook ist Anwendungen in Open-Source-Frameworks

    3.1 Keras

    Im Deep-Learning-Trainingsprozess wird die Hook-Funktion vollständig reflektiert.
Ein Trainingsprozess (ohne Datenvorbereitung) fragt den Trainingssatz mehrmals ab, jedes Mal wird er als Epoche bezeichnet und jede Epoche wird für das Training in mehrere Stapel unterteilt. Der Prozess ist unterteilt in:

Beherrschen Sie schnell die Hook-Funktion in PythonTraining starten

Vor dem Training einer Epoche

    Vor dem Training einer Charge
  • input_filter_fn

  • 插入队列 insert_queue

@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: Dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: Instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.test_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.predict_step`,
          it typically returns a dict with a key 'outputs' containing
          the model's outputs.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to `on_epoch_end()`
          is passed to this argument for this method but that may change in
          the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to
          `on_test_batch_end()` is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def _implements_train_batch_hooks(self):
    """Determines if this Callback should be called for each train batch."""
    return (not generic_utils.is_default(self.on_batch_begin) or
            not generic_utils.is_default(self.on_batch_end) or
            not generic_utils.is_default(self.on_train_batch_begin) or
            not generic_utils.is_default(self.on_train_batch_end))

3. hook在开源框架中的应用

3.1 keras

在深度学习训练流程中,hook函数体现的淋漓尽致。

一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:

  • 开始训练

  • 训练一个epoch前

  • 训练一个batch前

  • 训练一个batch后

  • 训练一个epoch后

  • 评估验证集

  • 结束训练

这些步骤是穿插在训练一个batch数据的过程中,这些可以理解成是钩子函数,我们可能需要在这些钩子函数中实现一些定制化的东西,比如在训练一个epoch后我们要保存下训练的模型,在结束训练时用最好的模型执行下测试集的效果等等。

keras中是通过各种回调函数来实现钩子hook功能的。这里放一个callback的父类,定制时只要继承这个父类,实现你过关注的钩子就可以了。

# Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=epochs,
            steps=data_handler.inferred_steps)

      ## I am hook
      callbacks.on_train_begin()
      training_logs = None
      # Handle fault-tolerance for multi-worker.
      # TODO(omalleyt): Fix the ordering issues that mean this has to
      # happen after `callbacks.on_train_begin`.
      data_handler._initial_epoch = (  # pylint: disable=protected-access
          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
      for epoch, iterator in data_handler.enumerate_epochs():
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace(
                'TraceContext',
                graph_type='train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size):
              ## I am hook
              callbacks.on_train_batch_begin(step)
              tmp_logs = train_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              callbacks.on_train_batch_end(end_step, logs)
        epoch_logs = copy.copy(logs)

        # Run validation.

        ## I am hook
        callbacks.on_epoch_end(epoch, epoch_logs)

这些钩子的原始程序是在模型训练流程中的

keras源码位置: tensorflowpythonkerasenginetraining.py

部分摘录如下(## I am hook):

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders

    # put model on gpus

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

3.2 mmdetection

mmdetection是一个目标检测的开源框架,集成了许多不同的目标检测深度学习算法(pytorch版),如faster-rcnn, fpn, retianet等。里面也大量使用了hook,暴露给应用实现流程中具体部分。

详见https://github.com/open-mmlab/mmdetection

Nach dem Training einer Charge🎜🎜🎜🎜Nach dem Training einer Epoche🎜🎜 🎜🎜Bewertung Überprüfungssatz🎜🎜🎜🎜Training beenden🎜🎜🎜🎜Diese Schritte sind in den Prozess des Trainings eines Datenstapels eingestreut. Diese können als Hook-Funktionen verstanden werden, z. B. in Nach dem Training für eine Epoche müssen wir das trainierte Modell speichern und nach Training beenden das beste Modell verwenden, um den Testsatzeffekt usw. auszuführen. 🎜🎜Die Hook-Funktion wird in Keras durch verschiedene Rückruffunktionen implementiert. Fügen Sie hier eine übergeordnete Rückrufklasse ein. Beim Anpassen müssen Sie nur diese übergeordnete Klasse erben und die Hooks implementieren, die Sie interessieren. 🎜rrreee🎜Die ursprünglichen Programme dieser Hooks befinden sich im Modelltrainingsprozess mmdetection 🎜mmdetection ist ein Open-Source-Framework zur Zielerkennung, das viele verschiedene Deep-Learning-Algorithmen zur Zielerkennung (Pytorch-Version) integriert, wie z. B. Faster-RCNN, FPN, Retianet usw. Hooks werden auch häufig verwendet, um bestimmte Teile des Anwendungsimplementierungsprozesses offenzulegen. 🎜🎜Weitere Informationen finden Sie unter https://github.com/open-mmlab/mmdetection🎜

这里看一个训练的调用例子(摘录)(https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders

    # put model on gpus

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

4. 总结

本文介绍了hook的概念和应用,并给出了python的实现细则。希望对比有帮助。总结如下:

  • hook函数是流程中预定义好的一个步骤,没有实现

  • 挂载或者注册时, 流程执行就会执行这个钩子函数

  • 回调函数和hook函数功能上是一致的

  • hook设计方式带来灵活性,如果流程中有一个步骤,你想让调用方来实现,你可以用hook函数

相关免费学习推荐:php编程(视频)

Das obige ist der detaillierte Inhalt vonBeherrschen Sie schnell die Hook-Funktion in Python. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:csdn.net. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen