Package EIX is a set of tools for exploring the structure
of XGBoost and LightGBM models. It includes functions for finding strong
interactions and for checking the importance of single variables and
interactions with several different measures. EIX also provides
several functions to visualize the results.
Almost all EIX functions require only two parameters: an
XGBoost or LightGBM model and the data table used as the training dataset. The
exceptions are the waterfall function and its plot. The
first one requires an XGBoost model and the observation whose
prediction is to be explained, given both as a row of the sparse matrix
and as a row of the original data. These two functions support only
XGBoost models. All plots are created with the ggplot2 package.
Most of them use the theme_mi2 plot theme from
DALEX.
This vignette shows how to use the EIX package. It
explains an XGBoost model that predicts whether employees leave the
company, using HR_data. The dataset was taken from Kaggle and consists of 14999
observations and 10 variables. It is also available in the
EIX package, where it is described in more detail.
# devtools::install_github("ModelOriented/EIX")
library("EIX")
set.seed(4)
knitr::kable(head(HR_data))
| satisfaction_level | last_evaluation | number_project | average_montly_hours | time_spend_company | Work_accident | left | promotion_last_5years | sales | salary |
|---|---|---|---|---|---|---|---|---|---|
| 0.38 | 0.53 | 2 | 157 | 3 | 0 | 1 | 0 | sales | low |
| 0.80 | 0.86 | 5 | 262 | 6 | 0 | 1 | 0 | sales | medium |
| 0.11 | 0.88 | 7 | 272 | 4 | 0 | 1 | 0 | sales | medium |
| 0.72 | 0.87 | 5 | 223 | 5 | 0 | 1 | 0 | sales | low |
| 0.37 | 0.52 | 2 | 159 | 3 | 0 | 1 | 0 | sales | low |
| 0.41 | 0.50 | 2 | 153 | 3 | 0 | 1 | 0 | sales | low |
To build a correct XGBoost model, convert the categorical features to factors and then convert the data frame to a sparse matrix. In this step the categorical features are encoded as 0/1 indicator columns.
library("Matrix")
# `left` is the label; `- 1` drops the intercept column
sparse_matrix <- sparse.model.matrix(left ~ . - 1, data = HR_data)
head(sparse_matrix)
## 6 x 19 sparse Matrix of class "dgCMatrix"
##
## 1 0.38 0.53 2 157 3 . . . . . . . . . 1 . . 1 .
## 2 0.80 0.86 5 262 6 . . . . . . . . . 1 . . . 1
## 3 0.11 0.88 7 272 4 . . . . . . . . . 1 . . . 1
## 4 0.72 0.87 5 223 5 . . . . . . . . . 1 . . 1 .
## 5 0.37 0.52 2 159 3 . . . . . . . . . 1 . . 1 .
## 6 0.41 0.50 2 153 3 . . . . . . . . . 1 . . 1 .
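To see what the encoding step produces, here is a minimal base-R sketch on a small made-up data frame (`toy`, its columns and its values are hypothetical). It uses `model.matrix`, the dense counterpart of `sparse.model.matrix`, with the same formula.

```r
# Made-up data: one numeric predictor and one categorical (factor) predictor.
toy <- data.frame(
  left   = c(1, 0, 0, 1),
  hours  = c(157, 262, 272, 223),
  salary = factor(c("low", "medium", "high", "low"),
                  levels = c("low", "medium", "high"))
)

# `left ~ . - 1`: every other column becomes a predictor, no intercept column.
X <- model.matrix(left ~ . - 1, data = toy)
print(X)
# The numeric column is copied as-is; the single factor becomes one
# 0/1 column per level (salarylow, salarymedium, salaryhigh).
```

When the formula has no intercept, R gives the full set of indicator columns only to the first factor; later factors keep contrast coding and drop their reference level. That is consistent with the 19 columns of `sparse_matrix` above: 7 numeric columns, all 10 levels of `sales` and 2 of the 3 levels of `salary`.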
EIX uses the table generated by
xgboost::xgb.model.dt.tree, which contains information about the trees,
their nodes and leaves.
library("xgboost")
param <- list(objective = "binary:logistic", max_depth = 2)
# Target: whether the employee left; `y` replaces the deprecated `label` argument
xgb_model <- xgboost(sparse_matrix, y = HR_data[, left] == 1, objective = "binary:logistic", max_depth = 2, nrounds = 50, verbosity = 0)
| Tree | Node | ID | Feature | Split | Yes | No | Missing | Gain | Cover |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0-0 | satisfaction_level | 0.47 | 0-1 | 0-2 | 0-2 | 4302.0146500 | 2720.8074 |
| 0 | 1 | 0-1 | number_project | 3.00 | 0-3 | 0-4 | 0-4 | 1224.4882800 | 758.7931 |
| 0 | 2 | 0-2 | time_spend_company | 5.00 | 0-5 | 0-6 | 0-6 | 1769.9522700 | 1962.0144 |
| 0 | 3 | 0-3 | Leaf | NA | NA | NA | NA | 1.0564061 | 315.9975 |
| 0 | 4 | 0-4 | Leaf | NA | NA | NA | NA | 0.2831307 | 442.7956 |
| 0 | 5 | 0-5 | Leaf | NA | NA | NA | NA | -0.3695500 | 1602.4811 |
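Each split node in this table points to its two children through the Yes and No columns, so parent-child pairs of variables can be read directly from it. A minimal base-R sketch using the split nodes of tree 0 shown above; this illustrates the idea and is not EIX's own code.

```r
# Split nodes of tree 0, copied from the table above (leaves left out).
nodes <- data.frame(
  ID      = c("0-0", "0-1", "0-2"),
  Feature = c("satisfaction_level", "number_project", "time_spend_company"),
  Yes     = c("0-1", "0-3", "0-5"),
  No      = c("0-2", "0-4", "0-6"),
  Gain    = c(4302.01465, 1224.48828, 1769.95227),
  stringsAsFactors = FALSE
)

# Every split node has two outgoing edges: to its Yes child and its No child.
edges <- data.frame(
  parent = c(nodes$ID, nodes$ID),
  child  = c(nodes$Yes, nodes$No),
  stringsAsFactors = FALSE
)
# Keep only edges whose child is itself a split node; edges ending in a
# leaf do not form a pair of variables.
edges <- edges[edges$child %in% nodes$ID, ]

# Look up the variable and the Gain at both ends of each edge.
edges$parent_feature <- nodes$Feature[match(edges$parent, nodes$ID)]
edges$child_feature  <- nodes$Feature[match(edges$child,  nodes$ID)]
edges$parent_gain    <- nodes$Gain[match(edges$parent, nodes$ID)]
edges$child_gain     <- nodes$Gain[match(edges$child,  nodes$ID)]
print(edges)
```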
The function xgboost::xgb.importance shows the importance of
single variables. EIX adds new measures of variable
importance and also shows the importance of interactions.
| Feature | Gain | Cover | Frequency |
|---|---|---|---|
| satisfaction_level | 0.4454403 | 0.3139911 | 0.3000000 |
| time_spend_company | 0.2334310 | 0.1867941 | 0.1692308 |
| number_project | 0.1626944 | 0.1286851 | 0.1461538 |
| last_evaluation | 0.0739606 | 0.1451796 | 0.1461538 |
| average_montly_hours | 0.0682220 | 0.1635922 | 0.1769231 |
| Work_accident | 0.0085868 | 0.0285113 | 0.0230769 |
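In the xgb.importance output, Gain and Frequency are shares of the whole model rather than raw sums (Cover is normalized in the same way). A minimal base-R sketch of that normalization, with made-up split gains:

```r
# Hypothetical split nodes: the variable used at each split and its Gain.
splits <- data.frame(
  Feature = c("satisfaction_level", "number_project", "satisfaction_level",
              "last_evaluation", "satisfaction_level"),
  Gain    = c(50, 20, 10, 15, 5),
  stringsAsFactors = FALSE
)

# Share of the model's total Gain contributed by each variable.
gain_share <- tapply(splits$Gain, splits$Feature, sum) / sum(splits$Gain)
# Share of all splits that use each variable.
freq_share <- tapply(splits$Gain, splits$Feature, length) / nrow(splits)

print(round(gain_share, 3))
print(round(freq_share, 3))
```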
The lollipop plot visualizes the model in a
way that makes the most important variables and interactions visible.
The x-axis shows tree numbers and the y-axis shows the Gain of each node. Each segment is one tree of the model and each point is one node; all nodes except leaves are shown. The shape of a point indicates the depth of the node. All roots are connected by a red line. If, within one segment, a node at a greater depth lies above a node at a smaller depth (that is, it has a higher Gain), this indicates an interaction.
The labeling can be changed. By default, with the option
labels = "topAll", the plot shows the most important variables in
roots (horizontal labels) and interactions (vertical labels). There
are two additional options: labels = "roots" for variables in
roots only, and labels = "interactions" for interactions only.
The number of labels visible on the plot can be changed with the
parameter threshold (range from 0 to 1, default 0.1). The plot uses
a logarithmic scale because the initial trees are usually the most
important. You can change the scale of the plot by setting the parameter
log_scale = FALSE.
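Node depth, which the lollipop plot encodes as point shape, is not a column of the tree table, but it can be recovered from the Yes/No links. A minimal base-R sketch; the tree and the helper `node_depth` are hypothetical and are not part of EIX.

```r
# Split nodes of a hypothetical depth-2 tree, with the IDs of their children.
tree <- data.frame(
  ID  = c("0-0", "0-1", "0-2"),
  Yes = c("0-1", "0-3", "0-5"),
  No  = c("0-2", "0-4", "0-6"),
  stringsAsFactors = FALSE
)

# Breadth-first walk from the root: each child is one level deeper than its
# parent. IDs that never appear in the ID column are leaves (no children).
node_depth <- function(tree, root = tree$ID[1]) {
  depth <- setNames(0, root)
  queue <- root
  while (length(queue) > 0) {
    id <- queue[1]
    queue <- queue[-1]
    row <- tree[tree$ID == id, ]
    if (nrow(row) == 0) next
    for (child in c(row$Yes, row$No)) {
      depth[child] <- depth[[id]] + 1
      queue <- c(queue, child)
    }
  }
  depth
}

print(node_depth(tree))
```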
We can consider interactions in two ways. In the first approach we
explore all pairs of variables that occur in the model one directly
above the other. This approach has a drawback: we cannot tell whether a
pair of variables is a real interaction, because a high gain of the pair
can simply result from a high gain of the lower variable (the child). To
explore pairs of variables, generate a table with the
function interactions and the parameter
option = "pairs". This table includes the sumGain
measure and the number of occurrences (frequency) of each pair. You can
also use the function plot to visualize the sumGain
measure.
## Parent Child sumGain frequency
## <char> <char> <num> <int>
## 1: satisfaction_level time_spend_company 4025.8846 4
## 2: satisfaction_level number_project 3140.5500 3
## 3: satisfaction_level satisfaction_level 910.0398 7
## 4: last_evaluation satisfaction_level 871.0141 6
## 5: last_evaluation average_montly_hours 694.8370 1
## 6: last_evaluation number_project 610.4091 2
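The pairs table can be thought of as a group-by over parent-child edges. A minimal base-R sketch with made-up edges, assuming (as the paragraph above suggests) that each pair is credited with the Gain of its child node; this is an illustration, not EIX's implementation.

```r
# Hypothetical parent-child edges with the Gain of the child node.
# Assumption: a pair is credited with its child's Gain.
edges <- data.frame(
  Parent = c("satisfaction_level", "satisfaction_level",
             "last_evaluation", "satisfaction_level"),
  Child  = c("time_spend_company", "number_project",
             "satisfaction_level", "time_spend_company"),
  Gain   = c(100, 40, 30, 60),
  stringsAsFactors = FALSE
)

# Group identical (Parent, Child) pairs; order within a pair matters.
key       <- paste(edges$Parent, edges$Child, sep = " -> ")
sumGain   <- tapply(edges$Gain, key, sum)
frequency <- tapply(edges$Gain, key, length)

pairs <- data.frame(pair      = names(sumGain),
                    sumGain   = as.numeric(sumGain),
                    frequency = as.integer(frequency),
                    stringsAsFactors = FALSE)
pairs <- pairs[order(pairs$sumGain, decreasing = TRUE), ]
print(pairs)
```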
The interactions plot is a matrix plot with the child from
the pair on the x-axis and the parent on the y-axis. The color of the
square at the intersection of two variables shows the value of the
sumGain measure: the darker the square, the higher the
sumGain of the pair. The range of the
sumGain measure is divided into four equal parts:
very low, low, medium and high.
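Splitting a range into four equal parts can be sketched with base R's `cut()`. This assumes equal-width intervals over the observed range of sumGain; EIX's exact boundaries may differ, and the values below are made up.

```r
# Hypothetical sumGain values for six variable pairs.
sumGain <- c(0, 10, 30, 60, 90, 100)

# breaks = 4 splits the observed range into four intervals of equal width.
level <- cut(sumGain, breaks = 4,
             labels = c("very low", "low", "medium", "high"))
print(data.frame(sumGain, level))
```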
In the second approach, to find strong interactions, we consider only
those pairs of variables in which the lower variable (child) has a
higher gain than the upper variable (parent). We can also create a
ranking of interactions using the function importance with the
parameter option = "interactions". More details are given in the next
section.
## Parent Child sumGain frequency
## <char> <char> <num> <int>
## 1: satisfaction_level time_spend_company 1903.1588 2
## 2: last_evaluation satisfaction_level 706.4140 4
## 3: last_evaluation average_montly_hours 694.8370 1
## 4: last_evaluation number_project 603.4144 1
## 5: satisfaction_level number_project 475.4262 1
## 6: number_project last_evaluation 387.5845 3
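The filtering rule of the second approach is short to express in code. A minimal base-R sketch on made-up edges (the data frame and its gain columns are hypothetical), again assuming that an interaction is credited with the child node's Gain:

```r
# Hypothetical parent-child edges with the Gain of both nodes.
edges <- data.frame(
  Parent      = c("satisfaction_level", "last_evaluation",
                  "satisfaction_level", "number_project"),
  Child       = c("time_spend_company", "satisfaction_level",
                  "number_project", "last_evaluation"),
  parent_gain = c(300, 150, 900, 100),
  child_gain  = c(500, 400, 200, 120),
  stringsAsFactors = FALSE
)

# Keep only pairs in which the child has a higher Gain than its parent.
strong <- edges[edges$child_gain > edges$parent_gain, ]

# Rank the remaining interactions by the summed Gain of their child nodes.
key     <- paste(strong$Parent, strong$Child, sep = ":")
totals  <- tapply(strong$child_gain, key, sum)
ranking <- data.frame(interaction = names(totals),
                      sumGain     = as.numeric(totals),
                      stringsAsFactors = FALSE)
ranking <- ranking[order(ranking$sumGain, decreasing = TRUE), ]
print(ranking)
```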
For exploring the importance of variables and interactions, EIX
provides the function importance and its
plot method, which can be used with radar = TRUE or
radar = FALSE. With EIX we can compare the
importance of single variables and interactions. The function
importance can return three kinds of output, depending on
the option parameter:
option = "variables" - only single
variables
option = "interactions" - only interactions
option = "both" - the importance of both single
variables and interactions.
NOTE: the result of option = "both" is not simply a combination of
option = "variables" and
option = "interactions", because in this mode the importance
values counted towards an interaction are not added again to the
importance of the single variables involved.
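In EIX the following measures are available:
sumGain - sum of Gain over all nodes in which the variable or interaction occurs
sumCover - sum of Cover over those nodes
meanGain - mean Gain per occurrence (sumGain divided by frequency)
meanCover - mean Cover per occurrence (sumCover divided by frequency)
frequency - number of nodes in which the variable or interaction occurs
mean5Gain - mean Gain of the five occurrences with the highest Gain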
EIX also provides additional importance measures that apply to single
variables only.
The function importance returns a table with all
available importance measures for the given option. The table is sorted by
descending sumGain.
The function plot, called with
radar = FALSE on a result of the importance
function, shows two importance measures, which can be
chosen with the xmeasure and ymeasure parameters. The
parameter top sets how many positions are
included in the plot.
## Feature sumGain sumCover meanGain meanCover
## <char> <num> <num> <num> <num>
## 1: satisfaction_level 10460.0 32980 326.90 1031.0
## 2: time_spend_company 3504.0 18910 194.70 1050.0
## 3: number_project 3006.0 11290 214.70 806.2
## 4: satisfaction_level:time_spend_company 1903.0 2878 951.60 1439.0
## 5: last_evaluation 1146.0 15330 81.85 1095.0
## 6: average_montly_hours 893.5 18570 42.55 884.2
## frequency mean5Gain
## <num> <num>
## 1: 32 1729.0
## 2: 18 574.4
## 3: 14 566.6
## 4: 2 951.6
## 5: 14 165.7
## 6: 21 105.0
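The measures in this table are simple summaries of the nodes in which a variable (or interaction) occurs. A minimal base-R sketch under that reading, with made-up node data; the helper `importance_measures` is hypothetical, and the mean5Gain rule (mean of the at most five largest Gains) is an assumption.

```r
# Hypothetical split nodes: the variable used, its Gain and its Cover.
nodes <- data.frame(
  Feature = c(rep("satisfaction_level", 7), rep("number_project", 2)),
  Gain    = c(10, 50, 20, 40, 30, 5, 60, 100, 80),
  Cover   = c(100, 200, 100, 200, 100, 200, 100, 300, 500),
  stringsAsFactors = FALSE
)

# Measures for one variable from the Gain and Cover of every node using it.
importance_measures <- function(feature, gain, cover) {
  top5 <- sort(gain, decreasing = TRUE)[seq_len(min(5, length(gain)))]
  data.frame(Feature   = feature,
             sumGain   = sum(gain),  sumCover  = sum(cover),
             meanGain  = mean(gain), meanCover = mean(cover),
             frequency = length(gain), mean5Gain = mean(top5),
             stringsAsFactors = FALSE)
}

by_feature <- split(nodes, nodes$Feature)
tab <- do.call(rbind, lapply(names(by_feature), function(f)
  importance_measures(f, by_feature[[f]]$Gain, by_feature[[f]]$Cover)))
# Sort by descending sumGain, as the importance table does.
tab <- tab[order(tab$sumGain, decreasing = TRUE), ]
print(tab)
```

The relation meanGain = sumGain / frequency matches the real output above: for satisfaction_level, 10460 / 32 is about 326.9.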
The function plot with the parameter
radar = TRUE compares different measures of
variable and interaction importance on a radar plot from the
ggiraphExtra package. Below is an example of a radar
plot. The names of variables or interactions are placed around the
outside of the circle, and colored lines represent the different
importance measures. The positions on the plot are sorted in decreasing
order: the variable with the highest sumGain value is just to the
right of 12 o’clock, and sumGain decreases in the clockwise
direction. The parameter text_start_point (range
from 0 to 1, default 0.5) changes where the feature names start, and
the parameter text_size changes the size of this text.
To explain single predictions, EIX uses two
packages: xgboostExplainer and breakDown. The
package xgboostExplainer is a tool for interpreting
predictions of xgboost models. EIX uses its code
and modifies it to include interactions. The methodology of plot
creation comes from the breakDown package.
The function waterfall returns a table with the
impact of variables on the model's prediction. Depending on the parameter
option, the table includes interactions
(option = "interactions", the default) or not
(option = "variables"). The function plot, applied to a
waterfall object, visualizes this table. The y-axis shows
the intercept (the model's baseline for an observation from the training
dataset), the variables that have an impact on the prediction, and the
final prediction of the model. The x-axis shows the impact of each
variable in log-odds.
# Explain the prediction for the 9th employee
data <- HR_data[9,]
new_observation <- sparse_matrix[9,]
wf <- waterfall(xgb_model, new_observation, data, option = "interactions")
wf
## contribution
## xgboost: intercept -0.240
## xgboost: last_evaluation = 1 1.818
## xgboost: time_spend_company = 5 1.270
## xgboost: satisfaction_level:time_spend_company = 0.89:5 0.764
## xgboost: Work_accident = 0 -0.571
## xgboost: satisfaction_level = 0.89 -0.359
## xgboost: average_montly_hours = 224 -0.271
## xgboost: last_evaluation:average_montly_hours = 1:224 0.215
## xgboost: last_evaluation:time_spend_company = 1:5 0.210
## xgboost: salary = 2 0.167
## xgboost: average_montly_hours:last_evaluation = 224:1 -0.164
## xgboost: time_spend_company:satisfaction_level = 5:0.89 0.161
## xgboost: number_project:average_montly_hours = 5:224 0.115
## xgboost: number_project = 5 -0.100
## xgboost: number_project:last_evaluation = 5:1 0.092
## xgboost: satisfaction_level:number_project = 0.89:5 -0.090
## xgboost: last_evaluation:satisfaction_level = 1:0.89 0.060
## xgboost: average_montly_hours:number_project = 224:5 0.050
## xgboost: average_montly_hours:satisfaction_level = 224:0.89 0.036
## xgboost: prediction 3.164
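The contributions are additive on the log-odds scale: the intercept plus all variable and interaction rows gives the prediction row. A short check in R with the values printed above; base R's `plogis()` (the logistic function) converts the log-odds to a probability, which comes out at roughly 0.96 here.

```r
# Contributions printed by `wf` above, in log-odds, intercept first.
contribution <- c(-0.240, 1.818, 1.270, 0.764, -0.571, -0.359, -0.271,
                  0.215, 0.210, 0.167, -0.164, 0.161, 0.115, -0.100,
                  0.092, -0.090, 0.060, 0.050, 0.036)

# Intercept plus all contributions: the model output in log-odds
# (3.164 in the table above, up to rounding of the printed values).
log_odds <- sum(contribution)

# The logistic function maps log-odds to a probability.
probability <- plogis(log_odds)
print(c(log_odds = log_odds, probability = probability))
```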