Commit 5826b02

[R] Drop plot tree style support. (#10989)
1 parent 7d8da41 commit 5826b02

13 files changed

+214
-275
lines changed

R-package/R/xgb.plot.multi.trees.R

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
2323
#' @inheritParams xgb.plot.tree
2424
#' @param features_keep Number of features to keep in each position of the multi trees,
2525
#' by default 5.
26+
#' @param render Should the graph be rendered or not? The default is `TRUE`.
2627
#' @inherit xgb.plot.tree return
2728
#'
2829
#' @examples

R-package/R/xgb.plot.tree.R

Lines changed: 35 additions & 144 deletions
@@ -3,63 +3,38 @@
33
#' Read a tree model text dump and plot the model.
44
#'
55
#' @details
6-
#' When using `style="xgboost"`, the content of each node is visualized as follows:
7-
#' - For non-terminal nodes, it will display the split condition (number or name if
8-
#' available, and the condition that would decide to which node to go next).
9-
#' - Those nodes will be connected to their children by arrows that indicate whether the
10-
#' branch corresponds to the condition being met or not being met.
6+
#' The content of each node is visualized as follows:
7+
#' - For non-terminal nodes, it will display the split condition (number or name
8+
#' if available, and the condition that would decide to which node to go
9+
#' next).
10+
#' - Those nodes will be connected to their children by arrows that indicate
11+
#' whether the branch corresponds to the condition being met or not being met.
1112
#' - Terminal (leaf) nodes contain the margin to add when ending there.
1213
#'
13-
#' When using `style="R"`, the content of each node is visualized like this:
14-
#' - *Feature name*.
15-
#' - *Cover:* The sum of second order gradients of training data.
16-
#' For the squared loss, this simply corresponds to the number of instances in the node.
17-
#' The deeper in the tree, the lower the value.
18-
#' - *Gain* (for split nodes): Information gain metric of a split
19-
#' (corresponds to the importance of the node in the model).
20-
#' - *Value* (for leaves): Margin value that the leaf may contribute to the prediction.
21-
#'
22-
#' The tree root nodes also indicate the tree index (0-based).
23-
#'
2414
#' The "Yes" branches are marked by the "< split_value" label.
2515
#' The branches also used for missing values are marked as bold
2616
#' (as in "carrying extra capacity").
2717
#'
28-
#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR backend.
18+
#' This function uses [GraphViz](https://www.graphviz.org/) as DiagrammeR
19+
#' backend.
2920
#'
30-
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
31-
#' [setinfo()], they will be used in the output from this function.
32-
#' @param trees An integer vector of tree indices that should be used.
33-
#' The default (`NULL`) uses all trees.
34-
#' Useful, e.g., in multiclass classification to get only
35-
#' the trees of one class. *Important*: the tree index in XGBoost models
36-
#' is zero-based (e.g., use `trees = 0:2` for the first three trees).
21+
#' @param model Object of class `xgb.Booster`. If it contains feature names
22+
#' (they can be set through [setinfo()]), they will be used in the
23+
#' output from this function.
24+
#' @param tree_idx An integer of the tree index that should be used. This
25+
#' is a 1-based index.
3726
#' @param plot_width,plot_height Width and height of the graph in pixels.
3827
#' The values are passed to `DiagrammeR::render_graph()`.
39-
#' @param render Should the graph be rendered or not? The default is `TRUE`.
40-
#' @param show_node_id a logical flag for whether to show node id's in the graph.
41-
#' @param style Style to use for the plot:
42-
#' - `"xgboost"`: will use the plot style defined in the core XGBoost library,
43-
#' which is shared between different interfaces through the 'dot' format. This
44-
#' style was not available before version 2.1.0 in R. It always plots the trees
45-
#' vertically (from top to bottom).
46-
#' - `"R"`: will use the style defined from XGBoost's R interface, which predates
47-
#' the introducition of the standardized style from the core library. It might plot
48-
#' the trees horizontally (from left to right).
49-
#'
50-
#' Note that `style="xgboost"` is only supported when all of the following conditions are met:
51-
#' - Only a single tree is being plotted.
52-
#' - Node IDs are not added to the graph.
53-
#' - The graph is being returned as `htmlwidget` (`render=TRUE`).
28+
#' @param with_stats Whether to dump some additional statistics about the
29+
#' splits. When this option is on, the model dump contains two additional
30+
#' values: gain is the approximate loss function gain we get in each split;
31+
#' cover is the sum of second order gradient in each node.
5432
#' @param ... Currently not used.
5533
#' @return
56-
#' The value depends on the `render` parameter:
57-
#' - If `render = TRUE` (default): Rendered graph object which is an htmlwidget of
58-
#' class `grViz`. Similar to "ggplot" objects, it needs to be printed when not
59-
#' running from the command line.
60-
#' - If `render = FALSE`: Graph object which is of DiagrammeR's class `dgr_graph`.
61-
#' This could be useful if one wants to modify some of the graph attributes
62-
#' before rendering the graph with `DiagrammeR::render_graph()`.
34+
#'
35+
#' Rendered graph object which is an htmlwidget of class `grViz`. Similar to
36+
#' "ggplot" objects, it needs to be printed when not running from the command
37+
#' line.
6338
#'
6439
#' @examples
6540
#' data(agaricus.train, package = "xgboost")
@@ -73,119 +48,35 @@
7348
#' objective = "binary:logistic"
7449
#' )
7550
#'
76-
#' # plot the first tree, using the style from xgboost's core library
77-
#' # (this plot should look identical to the ones generated from other
78-
#' # interfaces like the python package for xgboost)
79-
#' xgb.plot.tree(model = bst, trees = 1, style = "xgboost")
80-
#'
81-
#' # plot all the trees
82-
#' xgb.plot.tree(model = bst, trees = NULL)
51+
#' # plot the first tree
52+
#' xgb.plot.tree(model = bst, tree_idx = 1)
8353
#'
84-
#' # plot only the first tree and display the node ID:
85-
#' xgb.plot.tree(model = bst, trees = 0, show_node_id = TRUE)
8654
#'
8755
#' \dontrun{
8856
#' # Below is an example of how to save this plot to a file.
89-
#' # Note that for export_graph() to work, the {DiagrammeRsvg}
90-
#' # and {rsvg} packages must also be installed.
9157
#'
9258
#' library(DiagrammeR)
9359
#'
94-
#' gr <- xgb.plot.tree(model = bst, trees = 0:1, render = FALSE)
95-
#' export_graph(gr, "tree.pdf", width = 1500, height = 1900)
96-
#' export_graph(gr, "tree.png", width = 1500, height = 1900)
60+
#' gr <- xgb.plot.tree(model = bst, tree_idx = 1)
61+
#' htmlwidgets::saveWidget(gr, 'plot.html')
9762
#' }
9863
#'
9964
#' @export
100-
xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
101-
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
65+
xgb.plot.tree <- function(model,
66+
tree_idx = 1,
67+
plot_width = NULL,
68+
plot_height = NULL,
69+
with_stats = FALSE, ...) {
10270
check.deprecation(...)
10371
if (!inherits(model, "xgb.Booster")) {
104-
stop("model: Has to be an object of class xgb.Booster")
72+
stop("model has to be an object of the class xgb.Booster")
10573
}
106-
10774
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
108-
stop("DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
109-
}
110-
111-
style <- as.character(head(style, 1L))
112-
stopifnot(style %in% c("R", "xgboost"))
113-
if (style == "xgboost") {
114-
if (NROW(trees) != 1L || !render || show_node_id) {
115-
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
116-
}
117-
118-
txt <- xgb.dump(model, dump_format = "dot")
119-
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
120-
}
121-
122-
dt <- xgb.model.dt.tree(model = model, trees = trees)
123-
124-
dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)]
125-
if (show_node_id)
126-
dt[, label := paste0(ID, ": ", label)]
127-
dt[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
128-
dt[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
129-
dt[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
130-
# in order to draw the first tree on top:
131-
dt <- dt[order(-Tree)]
132-
133-
nodes <- DiagrammeR::create_node_df(
134-
n = nrow(dt),
135-
ID = dt$ID,
136-
label = dt$label,
137-
fillcolor = dt$filledcolor,
138-
shape = dt$shape,
139-
data = dt$Feature,
140-
fontcolor = "black")
141-
142-
if (nrow(dt[Feature != "Leaf"]) != 0) {
143-
edges <- DiagrammeR::create_edge_df(
144-
from = match(rep(dt[Feature != "Leaf", c(ID)], 2), dt$ID),
145-
to = match(dt[Feature != "Leaf", c(Yes, No)], dt$ID),
146-
label = c(
147-
dt[Feature != "Leaf", paste("<", Split)],
148-
rep("", nrow(dt[Feature != "Leaf"]))
149-
),
150-
style = c(
151-
dt[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")],
152-
dt[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
153-
),
154-
rel = "leading_to")
155-
} else {
156-
edges <- NULL
75+
stop("The DiagrammeR package is required for xgb.plot.tree", call. = FALSE)
15776
}
15877

159-
graph <- DiagrammeR::create_graph(
160-
nodes_df = nodes,
161-
edges_df = edges,
162-
attr_theme = NULL
163-
)
164-
graph <- DiagrammeR::add_global_graph_attrs(
165-
graph = graph,
166-
attr_type = "graph",
167-
attr = c("layout", "rankdir"),
168-
value = c("dot", "LR")
78+
txt <- xgb.dump(model, dump_format = "dot", with_stats = with_stats)
79+
DiagrammeR::grViz(
80+
txt[[tree_idx]], width = plot_width, height = plot_height
16981
)
170-
graph <- DiagrammeR::add_global_graph_attrs(
171-
graph = graph,
172-
attr_type = "node",
173-
attr = c("color", "style", "fontname"),
174-
value = c("DimGray", "filled", "Helvetica")
175-
)
176-
graph <- DiagrammeR::add_global_graph_attrs(
177-
graph = graph,
178-
attr_type = "edge",
179-
attr = c("color", "arrowsize", "arrowhead", "fontname"),
180-
value = c("DimGray", "1.5", "vee", "Helvetica")
181-
)
182-
183-
if (!render) return(invisible(graph))
184-
185-
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
18682
}
187-
188-
# Avoid error messages during CRAN check.
189-
# The reason is that these variables are never declared
190-
# They are mainly column names inferred by Data.table...
191-
globalVariables(c("Feature", "ID", "Cover", "Gain", "Split", "Yes", "No", "Missing", ".", "shape", "filledcolor", "label"))
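
Migration note: the removed `style = "xgboost"` branch indexed the dot dump with `txt[[trees + 1]]`, so the old `trees` argument was zero-based, while the new code uses `txt[[tree_idx]]` directly, so `tree_idx` is one-based. A minimal sketch of converting old call sites is shown below. The helper name `to_tree_idx` is hypothetical and not part of xgboost; it only illustrates the off-by-one shift and the loss of multi-tree plotting.

```r
# Hypothetical helper (NOT part of xgboost): convert a zero-based tree index,
# as accepted by the removed `trees` argument, into the one-based `tree_idx`
# expected by xgb.plot.tree() after this commit.
to_tree_idx <- function(trees) {
  # `trees = NULL` used to mean "plot all trees"; the new signature plots
  # exactly one tree, so there is nothing sensible to convert to.
  if (is.null(trees)) {
    stop("`trees = NULL` (all trees) has no single-tree equivalent")
  }
  # Vectors such as 0:1 are no longer supported either.
  if (length(trees) != 1L || !is.numeric(trees) || is.na(trees) ||
        trees < 0 || trees != floor(trees)) {
    stop("`trees` must be a single non-negative whole number")
  }
  # Shift from zero-based to one-based indexing.
  as.integer(trees) + 1L
}

# Old call:  xgb.plot.tree(model = bst, trees = 0)
# New call:  xgb.plot.tree(model = bst, tree_idx = to_tree_idx(0))
first_tree <- to_tree_idx(0)
third_tree <- to_tree_idx(2)
```

Under these assumptions, code that previously looped over `trees = 0:(n - 1)` would now call `xgb.plot.tree()` once per tree with `tree_idx` running from 1 to `n`.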

R-package/R/xgb.train.R

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@
103103
#' objective is non-convex.
104104
#'
105105
#' See the tutorials [Custom Objective and Evaluation Metric](https://xgboost.readthedocs.io/en/stable/tutorials/custom_metric_obj.html)
106-
#' and [Advanced Usage of Custom Objectives](https://xgboost.readthedocs.io/en/stable/tutorials/advanced_custom_obj)
106+
#' and [Advanced Usage of Custom Objectives](https://xgboost.readthedocs.io/en/latest/tutorials/advanced_custom_obj.html)
107107
#' for more information about custom objectives.
108108
#'
109109
#' - `base_score`: The initial prediction score of all instances, global bias. Default: 0.5.

R-package/man/xgb.plot.multi.trees.Rd

Lines changed: 6 additions & 11 deletions
Some generated files are not rendered by default.
