How to Extract Decision Rules from scikit-learn Decision Trees?

Mary-Kate Olsen
2024-10-27 09:14:03

Decision Rule Extraction from scikit-learn Decision Trees

Problem Statement:

Can the decision rules underlying a trained decision tree model be extracted as a textual list?

Solution:

Yes. The helper tree_to_code defined below is not part of scikit-learn; it walks the fitted model's tree_ attribute and prints a valid Python function that reproduces the tree's decision rules:

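<code class="python">from sklearn.tree import _tree  # private module; provides TREE_UNDEFINED (-2)

def tree_to_code(tree, feature_names):
    """Print the rules of a fitted scikit-learn decision tree as a Python function."""
    tree_ = tree.tree_
    # Map each node to the name of the feature it splits on.
    # Leaf nodes store TREE_UNDEFINED instead of a feature index.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Split node: samples with feature <= threshold go to the left child.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: emit the value array stored for this node.
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)  # node 0 is the root; depth 1 indents the function body</code>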
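
One way to check that generated rules really match the model is to build the source as a string, execute it, and compare the result with predict. The following is only a sketch under stated assumptions, not part of the original solution: tree_to_source is a hypothetical name, the training data (the integers 0 to 10) is made up for illustration, and each leaf returns a single float, which only fits a single-output regressor.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree


def tree_to_source(model, feature_names):
    """Hypothetical variant of tree_to_code that returns source text."""
    tree_ = model.tree_
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = float(tree_.threshold[node])
            lines.append("{}if {} <= {!r}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # {} > {!r}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Assumes a single-output regressor: one scalar per leaf.
            leaf_value = float(tree_.value[node][0][0])
            lines.append("{}return {!r}".format(indent, leaf_value))

    recurse(0, 1)
    return "\n".join(lines)


# Illustrative data (assumed): learn to reproduce the integers 0..10.
X = np.arange(11, dtype=float).reshape(-1, 1)
y = np.arange(11, dtype=float)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

source = tree_to_source(model, ["f0"])
namespace = {}
exec(source, namespace)  # defines namespace["tree"]
generated_tree = namespace["tree"]

# Compare the generated function with the model on every training input.
matches = all(
    generated_tree(x) == model.predict([[x]])[0] for x in X[:, 0]
)
print(source)
print("generated function agrees with model.predict:", matches)
```

Returning a string instead of printing keeps the helper easy to test; for classifiers or multi-output trees the leaf handling would need to change.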

Example:

For a decision tree trained to return its input (a number between 0 and 10), tree_to_code prints the following Python function. Each leaf returns the raw tree_.value entry for that node, a 2-D array, which is why the values appear in nested brackets:

<code class="python">def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]</code>
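
When a readable listing is enough and executable Python is not required, scikit-learn's own sklearn.tree.export_text may be a simpler route. The sketch below uses made-up training data (the integers 0 to 10) and an assumed feature name f0; it is an alternative to tree_to_code rather than a replacement for it.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Illustrative data (assumed): the integers 0..10 mapped to themselves.
X = np.arange(11, dtype=float).reshape(-1, 1)
y = np.arange(11, dtype=float)
model = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# export_text returns the rules as an indented, human-readable string.
report = export_text(model, feature_names=["f0"])
print(report)
```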

Caveats:

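Avoid the following common mistakes:

  • Do not rely on tree_.threshold == -2 to identify leaf nodes, because -2 can also be a legitimate split threshold; check tree_.feature or tree_.children_* instead.
  • Do not index feature_names directly with the values in tree_.feature. Leaf nodes store _tree.TREE_UNDEFINED (-2) there, so an unguarded lookup either picks the wrong name or raises an IndexError.
  • A single if/else in the recursive function is enough, since every node is either a split or a leaf.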
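
The leaf-detection caveat can be illustrated with a tiny tree whose only split happens to fall at -2.0. This is a minimal sketch on made-up data (two samples at -3 and -1), comparing the threshold test with the feature and children tests.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, _tree

# Two samples at -3 and -1: the only possible split lies midway, at -2.0.
X = np.array([[-3.0], [-1.0]])
y = np.array([0, 1])
tree_ = DecisionTreeClassifier(random_state=0).fit(X, y).tree_

# Unreliable: the root's real threshold (-2.0) looks like the leaf marker.
by_threshold = tree_.threshold == -2
# Reliable: leaves have no split feature ...
by_feature = tree_.feature == _tree.TREE_UNDEFINED
# ... and both child pointers hold the same TREE_LEAF marker.
by_children = tree_.children_left == tree_.children_right

print("nodes:", tree_.node_count)
print("leaves by threshold:", int(by_threshold.sum()))
print("leaves by feature:  ", int(by_feature.sum()))
print("leaves by children: ", int(by_children.sum()))
```

Here the threshold test would also flag the root as a leaf, while the feature and children tests agree on the two actual leaves.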
