
How to Extract Decision Rules from a Scikit-Learn Decision Tree?


Extracting Decision Rules from Scikit-Learn Decision Tree

Decision trees are a widely used machine learning model that represents a decision-making process as a hierarchy of threshold rules. A fitted Scikit-Learn tree, however, stores this structure in low-level arrays rather than as readable rules, so extracting the rules explicitly takes a little work. This article outlines an approach to extracting textual decision rules from a trained Scikit-Learn decision tree.

Python Code for Decision Rule Extraction

The following Python function walks the fitted tree's internal structure (tree.tree_) and prints the decision paths as human-readable Python code:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    """Print a fitted decision tree as a nested if/else Python function."""
    tree_ = tree.tree_
    # Map each node's feature index to its name; leaf nodes get a placeholder.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: emit the split condition, then both subtrees.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: emit the predicted value(s).
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)

Creating a Valid Python Function

This code traverses the tree recursively, printing each split condition and threshold along the way. Provided the feature names are valid Python identifiers, the printed output is a valid Python function that reproduces the decision-making process of the trained tree.
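As a minimal usage sketch (the toy dataset and the feature names f0 and f1 below are made up for illustration), you can fit a small classifier and pass it to tree_to_code together with the feature names:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: two numeric features, a binary label.
X = np.array([[0, 0], [1, 1], [2, 0], [3, 1], [4, 0], [5, 1]])
y = np.array([0, 0, 0, 1, 1, 1])

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# Prints a Python function mirroring the fitted tree's splits.
tree_to_code(clf, ["f0", "f1"])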

Example Output

For instance, consider a tree that attempts to return its input, a number between 0 and 10. The generated Python function would look like this:

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
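The article does not show how that tree was trained; one plausible setup (an assumption on our part, and the exact thresholds will depend on the fit) is a depth-limited regression tree fitted on the integers 0 through 9:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Single feature: the integers 0..9, each mapped to itself as the target.
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

reg = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
tree_to_code(reg, ["f0"])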

Benefits and Cautions

This method provides a clear and testable representation of the tree's decision rules. Note, however, that the code assumes each internal node performs a binary split on a single numeric feature with a <= threshold, which is how standard Scikit-Learn trees are built; trees from other sources with multi-way or categorical splits would require adapting the traversal accordingly.
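If you only need the rules as text rather than as executable Python, newer Scikit-Learn releases (0.21 and later) also include a built-in helper, sklearn.tree.export_text, which renders the splits as indented text. A short sketch using the iris dataset:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text prints an indented, human-readable view of the splits.
print(export_text(clf, feature_names=iris.feature_names))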

