Maison  >  Article  >  développement back-end  >  Comment extraire les règles de décision des arbres de décision scikit-learn ?

Comment extraire les règles de décision des arbres de décision scikit-learn ?

Mary-Kate Olsen
Mary-Kate Olsenoriginal
2024-10-27 09:14:03979parcourir

How to Extract Decision Rules from scikit-learn Decision Trees?

Extraction de règles de décision à partir des arbres de décision scikit-learn

Énoncé du problème :

Le les règles de décision sous-jacentes à un modèle d'arbre de décision entraîné doivent-elles être extraites sous forme de liste textuelle ?

Solution :

À l'aide de la fonction tree_to_code, il est possible de générer une fonction Python valide qui représente les règles de décision d'un arbre de décision scikit-learn :

<code class="python">from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)</code>

Exemple :

Pour un arbre de décision qui tente de renvoyer son entrée (un nombre compris entre 0 et 10), la fonction tree_to_code imprimerait la fonction Python suivante :

<code class="python">def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]</code>

Avertissements :

Évitez les problèmes courants suivants :

  • Ne comptez pas sur tree_.threshold == -2 pour identifier les nœuds feuilles ; vérifiez plutôt tree.feature ou tree.children_*.
  • Spécifiez correctement les noms de fonctionnalités, en évitant les entiers qui peuvent correspondre à des fonctionnalités non définies.
  • Des instructions simples si suffisent dans la fonction récursive.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn