Abstract

Deep Neural Networks (DNNs) usually work in an end-to-end manner. This makes the trained DNNs easy to use, but the decision process remains opaque for every test case. Unfortunately, the interpretability of decisions is crucial in some scenarios, such as medical or financial data mining and decision-making. In this paper, we propose a Tree-Network-Tree (TNT) learning framework for explainable decision-making, where knowledge is alternately transferred between tree models and DNNs. Specifically, the proposed TNT learning framework exerts the advantages of different models at different stages: (1) a novel James–Stein Decision Tree (JSDT) is proposed to generate better knowledge representations for DNNs, especially when the input data are low-frequency or low-quality; (2) the DNNs produce high-performing predictions from the knowledge-embedding inputs and behave as a teacher model for the following tree model; and (3) a novel distillable Gradient Boosted Decision Tree (dGBDT) is proposed to learn interpretable trees from the soft labels and make predictions comparable to those of the DNNs. Extensive experiments on various machine learning tasks demonstrate the effectiveness of the proposed method.
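Stage (1) motivates JSDT with low-frequency inputs, which suggests shrinkage estimation at the leaves. As a rough illustration (the paper's exact JSDT estimator is not given here), the following sketch applies classic positive-part James–Stein shrinkage to per-leaf sample means, pulling estimates from sparsely populated leaves toward the pooled mean; the function name and parameters are illustrative assumptions, not the paper's API.

```python
import numpy as np

def james_stein_shrink(means, sigma2, n):
    """Positive-part James-Stein shrinkage of per-group sample means
    toward the grand mean. Illustrative guess at the kind of estimator
    a JSDT leaf might use; not taken from the paper.

    means:  per-leaf sample means
    sigma2: assumed common observation variance
    n:      samples per leaf (smaller n -> stronger shrinkage)
    """
    means = np.asarray(means, dtype=float)
    p = means.size
    grand = means.mean()
    dev = means - grand
    ss = np.sum(dev ** 2)
    if p <= 3 or ss == 0.0:
        return means.copy()  # JS needs >= 3 free dimensions to help
    # Shrinkage factor in [0, 1]; the positive part avoids over-shrinking.
    shrink = max(0.0, 1.0 - (p - 3) * (sigma2 / n) / ss)
    return grand + shrink * dev

# Low-frequency leaves (small n) are pulled harder toward the pooled mean.
raw = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
few = james_stein_shrink(raw, sigma2=1.0, n=2)    # sparse leaves
many = james_stein_shrink(raw, sigma2=1.0, n=50)  # well-populated leaves
```

The intuition matches the abstract's claim: when leaves see few or noisy samples, shrinking their estimates toward the pooled statistic yields a more stable representation to feed the DNN.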

Highlights

  • Deep Neural Networks (DNNs) have achieved great success in many multimodal prediction tasks such as cross-modal embedding [1], image caption [2], and visual question answering [3]

  • As typical end-to-end models, DNNs usually work in a black-box paradigm [4,5] and the decision process is unknown for the test case, which limits the application of DNNs for some scenarios requiring explanation, such as medical or financial data mining and decision-making [6,7]

  • For the applications on tabular or structured data, we can adopt an ensemble of decision trees, such as the random forest [26], Gradient Boosted Decision Trees (GBDT), or Gradient Boosting Machine (GBM) [27], to learn the knowledge patterns


Summary

Introduction

Deep Neural Networks (DNNs) have achieved great success in many multimodal prediction tasks such as cross-modal embedding [1], image captioning [2], and visual question answering [3]. Given input data X and prediction target y, the Tree-Network-Tree (TNT) learning framework first trains a tree-based model (e.g., random forest, GBDT, or our proposed JSDT) on the training dataset {X, y} and extracts the decision paths of all trees to form an embedding representation XE. The advantages of TNT come from three parts: the first tree model is robust for representing the dark knowledge in the input data; the DNN model ensures good prediction performance; and because decision paths can be explicitly extracted from a distillable tree, the final model is interpretable for decision-making. Based on the proposed TNT framework, we further explore different implementations, including the choice of data flow and potential end-to-end differentiable structures. We evaluate all these candidate models on various machine learning tasks and conduct extensive experiments to show the interpretability of TNT in medical diagnosis scenarios.
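The tree → embedding → DNN → distilled-tree data flow described above can be sketched end to end. This is a minimal stand-in, assuming scikit-learn's stock GBDT and MLP in place of the paper's custom JSDT and dGBDT, and using leaf indices (one-hot encoded) as the decision-path embedding XE; all model choices and hyperparameters here are illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import OneHotEncoder

X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# Stage 1: tree model; its leaf indices encode each sample's decision paths.
# (Stands in for JSDT, which the paper proposes for this stage.)
tree1 = GradientBoostingClassifier(n_estimators=30, random_state=0).fit(X, y)
leaves = tree1.apply(X).reshape(len(X), -1)          # leaf id per tree
X_E = OneHotEncoder(handle_unknown="ignore").fit_transform(leaves).toarray()

# Stage 2: DNN teacher trained on the knowledge-embedding inputs.
dnn = MLPClassifier(hidden_layer_sizes=(64,), max_iter=500,
                    random_state=0).fit(X_E, y)
soft = dnn.predict_proba(X_E)[:, 1]                  # soft labels for distillation

# Stage 3: student tree regresses the teacher's soft labels on raw features,
# so its decision paths remain directly inspectable. (Stands in for dGBDT.)
student = GradientBoostingRegressor(n_estimators=30, random_state=0).fit(X, soft)
hard_pred = (student.predict(X) > 0.5).astype(int)
```

The key design point is that interpretability is read off the *student* tree's decision paths, while accuracy is inherited from the teacher DNN through the soft labels.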

Deep Models in Black Box
Tree Models
Knowledge Distillation
Proposed Tree Models
James–Stein Decision Trees
Distillable Gradient Boosted Decision Trees
Proposed TNT Framework
Tree-Network-Tree Learning Framework
Further Exploration
Experiments
Datasets and Setup
Robustness and Performance
Methods
Interpretability
Partial Dependence Plots
Classification Activation Mapping
Findings
Conclusions