|
3 | 3 | #' Read a tree model text dump and plot the model.
|
4 | 4 | #'
|
5 | 5 | #' @details
|
6 |
| -#' When using `style="xgboost"`, the content of each node is visualized as follows: |
7 |
| -#' - For non-terminal nodes, it will display the split condition (number or name if |
8 |
| -#' available, and the condition that would decide to which node to go next). |
9 |
| -#' - Those nodes will be connected to their children by arrows that indicate whether the |
10 |
| -#' branch corresponds to the condition being met or not being met. |
| 6 | +#' The content of each node is visualized as follows: |
| 7 | +#' - For non-terminal nodes, it will display the split condition (number or name |
| 8 | +#' if available, and the condition that would decide to which node to go |
| 9 | +#' next). |
| 10 | +#' - Those nodes will be connected to their children by arrows that indicate |
| 11 | +#' whether the branch corresponds to the condition being met or not being met. |
11 | 12 | #' - Terminal (leaf) nodes contain the margin to add when ending there.
|
12 | 13 | #'
|
13 |
| -#' When using `style="R"`, the content of each node is visualized like this: |
14 |
| -#' - *Feature name*. |
15 |
| -#' - *Cover:* The sum of second order gradients of training data. |
16 |
| -#' For the squared loss, this simply corresponds to the number of instances in the node. |
17 |
| -#' The deeper in the tree, the lower the value. |
18 |
| -#' - *Gain* (for split nodes): Information gain metric of a split |
19 |
| -#' (corresponds to the importance of the node in the model). |
20 |
| -#' - *Value* (for leaves): Margin value that the leaf may contribute to the prediction. |
21 |
| -#' |
22 |
| -#' The tree root nodes also indicate the tree index (0-based). |
23 |
| -#' |
24 | 14 | #' The "Yes" branches are marked by the "< split_value" label.
|
25 | 15 | #' The branches also used for missing values are marked as bold
|
26 | 16 | #' (as in "carrying extra capacity").
|
27 | 17 | #'
|
28 |
| -#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR backend. |
| 18 | +#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR |
| 19 | +#' backend. |
29 | 20 | #'
|
30 |
| -#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through |
31 |
| -#' [setinfo()], they will be used in the output from this function. |
32 |
| -#' @param trees An integer vector of tree indices that should be used. |
33 |
| -#' The default (`NULL`) uses all trees. |
34 |
| -#' Useful, e.g., in multiclass classification to get only |
35 |
| -#' the trees of one class. *Important*: the tree index in XGBoost models |
36 |
| -#' is zero-based (e.g., use `trees = 0:2` for the first three trees). |
| 21 | +#' @param model Object of class `xgb.Booster`. If it contains feature names |
| 22 | +#' (they can be set through [setinfo()], they will be used in the |
| 23 | +#' output from this function. |
| 24 | +#' @param tree_idx An integer of the tree index that should be used. This |
| 25 | +#' is an 1-based index. |
37 | 26 | #' @param plot_width,plot_height Width and height of the graph in pixels.
|
38 | 27 | #' The values are passed to `DiagrammeR::render_graph()`.
|
39 |
| -#' @param render Should the graph be rendered or not? The default is `TRUE`. |
40 |
| -#' @param show_node_id a logical flag for whether to show node id's in the graph. |
41 |
| -#' @param style Style to use for the plot: |
42 |
| -#' - `"xgboost"`: will use the plot style defined in the core XGBoost library, |
43 |
| -#' which is shared between different interfaces through the 'dot' format. This |
44 |
| -#' style was not available before version 2.1.0 in R. It always plots the trees |
45 |
| -#' vertically (from top to bottom). |
46 |
| -#' - `"R"`: will use the style defined from XGBoost's R interface, which predates |
47 |
| -#' the introducition of the standardized style from the core library. It might plot |
48 |
| -#' the trees horizontally (from left to right). |
49 |
| -#' |
50 |
| -#' Note that `style="xgboost"` is only supported when all of the following conditions are met: |
51 |
| -#' - Only a single tree is being plotted. |
52 |
| -#' - Node IDs are not added to the graph. |
53 |
| -#' - The graph is being returned as `htmlwidget` (`render=TRUE`). |
| 28 | +#' @param with_stats Whether to dump some additional statistics about the |
| 29 | +#' splits. When this option is on, the model dump contains two additional |
| 30 | +#' values: gain is the approximate loss function gain we get in each split; |
| 31 | +#' cover is the sum of second order gradient in each node. |
54 | 32 | #' @param ... Currently not used.
|
55 | 33 | #' @return
|
56 |
| -#' The value depends on the `render` parameter: |
57 |
| -#' - If `render = TRUE` (default): Rendered graph object which is an htmlwidget of |
58 |
| -#' class `grViz`. Similar to "ggplot" objects, it needs to be printed when not |
59 |
| -#' running from the command line. |
60 |
| -#' - If `render = FALSE`: Graph object which is of DiagrammeR's class `dgr_graph`. |
61 |
| -#' This could be useful if one wants to modify some of the graph attributes |
62 |
| -#' before rendering the graph with `DiagrammeR::render_graph()`. |
| 34 | +#' |
| 35 | +#' Rendered graph object which is an htmlwidget of ' class `grViz`. Similar to |
| 36 | +#' "ggplot" objects, it needs to be printed when not running from the command |
| 37 | +#' line. |
63 | 38 | #'
|
64 | 39 | #' @examples
|
65 | 40 | #' data(agaricus.train, package = "xgboost")
|
|
73 | 48 | #' objective = "binary:logistic"
|
74 | 49 | #' )
|
75 | 50 | #'
|
76 |
| -#' # plot the first tree, using the style from xgboost's core library |
77 |
| -#' # (this plot should look identical to the ones generated from other |
78 |
| -#' # interfaces like the python package for xgboost) |
79 |
| -#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost") |
80 |
| -#' |
81 |
| -#' # plot all the trees |
82 |
| -#' xgb.plot.tree(model = bst, trees = NULL) |
| 51 | +#' # plot the first tree |
| 52 | +#' xgb.plot.tree(model = bst, tree_idx = 1) |
83 | 53 | #'
|
84 |
| -#' # plot only the first tree and display the node ID: |
85 |
| -#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE) |
86 | 54 | #'
|
87 | 55 | #' \dontrun{
|
88 | 56 | #' # Below is an example of how to save this plot to a file.
|
89 |
| -#' # Note that for export_graph() to work, the {DiagrammeRsvg} |
90 |
| -#' # and {rsvg} packages must also be installed. |
91 | 57 | #'
|
92 | 58 | #' library(DiagrammeR)
|
93 | 59 | #'
|
94 |
| -#' gr <- xgb.plot.tree(model = bst, trees = 0:1, render = FALSE) |
95 |
| -#' export_graph(gr, "tree.pdf", width = 1500, height = 1900) |
96 |
| -#' export_graph(gr, "tree.png", width = 1500, height = 1900) |
| 60 | +#' gr <- xgb.plot.tree(model = bst, tree_idx = 1) |
| 61 | +#' htmlwidgets::saveWidget(gr, 'plot.html') |
97 | 62 | #' }
|
98 | 63 | #'
|
99 | 64 | #' @export
|
100 |
| -xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, |
101 |
| - render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) { |
| 65 | +xgb.plot.tree <- function(model, |
| 66 | + tree_idx = 1, |
| 67 | + plot_width = NULL, |
| 68 | + plot_height = NULL, |
| 69 | + with_stats = FALSE, ...) { |
102 | 70 | check.deprecation(...)
|
103 | 71 | if (!inherits(model, "xgb.Booster")) {
|
104 |
| - stop("model: Has to be an object of class xgb.Booster") |
| 72 | + stop("model has to be an object of the class xgb.Booster") |
105 | 73 | }
|
106 |
| - |
107 | 74 | if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
|
108 |
| - stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE) |
109 |
| - } |
110 |
| - |
111 |
| - style <- as.character(head(style, 1L)) |
112 |
| - stopifnot(style %in% c("R", "xgboost")) |
113 |
| - if (style == "xgboost") { |
114 |
| - if (NROW(trees) != 1L || !render || show_node_id) { |
115 |
| - stop("style='xgboost' is only supported for single, rendered tree, without node IDs.") |
116 |
| - } |
117 |
| - |
118 |
| - txt <- xgb.dump(model, dump_format = "dot") |
119 |
| - return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height)) |
120 |
| - } |
121 |
| - |
122 |
| - dt <- xgb.model.dt.tree(model = model, trees = trees) |
123 |
| - |
124 |
| - dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)] |
125 |
| - if (show_node_id) |
126 |
| - dt[, label := paste0(ID, ": ", label)] |
127 |
| - dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)] |
128 |
| - dt[, shape := "rectangle"][Feature == "Leaf", shape := "oval"] |
129 |
| - dt[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"] |
130 |
| - # in order to draw the first tree on top: |
131 |
| - dt <- dt[order(-Tree)] |
132 |
| - |
133 |
| - nodes <- DiagrammeR::create_node_df( |
134 |
| - n = nrow(dt), |
135 |
| - ID = dt$ID, |
136 |
| - label = dt$label, |
137 |
| - fillcolor = dt$filledcolor, |
138 |
| - shape = dt$shape, |
139 |
| - data = dt$Feature, |
140 |
| - fontcolor = "black") |
141 |
| - |
142 |
| - if (nrow(dt[Feature != "Leaf"]) != 0) { |
143 |
| - edges <- DiagrammeR::create_edge_df( |
144 |
| - from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID), |
145 |
| - to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID), |
146 |
| - label = c( |
147 |
| - dt[Feature != "Leaf", paste("<", Split)], |
148 |
| - rep("", nrow(dt[Feature != "Leaf"])) |
149 |
| - ), |
150 |
| - style = c( |
151 |
| - dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")], |
152 |
| - dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")] |
153 |
| - ), |
154 |
| - rel = "leading_to") |
155 |
| - } else { |
156 |
| - edges <- NULL |
| 75 | + stop("The DiagrammeR package is required for xgb.plot.tree", call. = FALSE) |
157 | 76 | }
|
158 | 77 |
|
159 |
| - graph <- DiagrammeR::create_graph( |
160 |
| - nodes_df = nodes, |
161 |
| - edges_df = edges, |
162 |
| - attr_theme = NULL |
163 |
| - ) |
164 |
| - graph <- DiagrammeR::add_global_graph_attrs( |
165 |
| - graph = graph, |
166 |
| - attr_type = "graph", |
167 |
| - attr = c("layout", "rankdir"), |
168 |
| - value = c("dot", "LR") |
| 78 | + txt <- xgb.dump(model, dump_format = "dot", with_stats = with_stats) |
| 79 | + DiagrammeR::grViz( |
| 80 | + txt[[tree_idx]], width = plot_width, height = plot_height |
169 | 81 | )
|
170 |
| - graph <- DiagrammeR::add_global_graph_attrs( |
171 |
| - graph = graph, |
172 |
| - attr_type = "node", |
173 |
| - attr = c("color", "style", "fontname"), |
174 |
| - value = c("DimGray", "filled", "Helvetica") |
175 |
| - ) |
176 |
| - graph <- DiagrammeR::add_global_graph_attrs( |
177 |
| - graph = graph, |
178 |
| - attr_type = "edge", |
179 |
| - attr = c("color", "arrowsize", "arrowhead", "fontname"), |
180 |
| - value = c("DimGray", "1.5", "vee", "Helvetica") |
181 |
| - ) |
182 |
| - |
183 |
| - if (!render) return(invisible(graph)) |
184 |
| - |
185 |
| - DiagrammeR::render_graph(graph, width = plot_width, height = plot_height) |
186 | 82 | }
|
187 |
| - |
188 |
| -# Avoid error messages during CRAN check. |
189 |
| -# The reason is that these variables are never declared |
190 |
| -# They are mainly column names inferred by Data.table... |
191 |
| -globalVariables(c("Feature", "ID", "Cover", "Gain", "Split", "Yes", "No", "Missing", ".", "shape", "filledcolor", "label")) |
0 commit comments