Matek Matek - 1 year ago 157
R Question

R mlr package - is it possible to save all models from Parameter tuning?

I wanted to ask whether it is possible to save all models that are created during parameter tuning, e.g. with mlr's tuning function. I'd like to save the models from every fold of cross-validation for every hyperparameter set.

I can see that there is a parameter for this in two other mlr functions, but I can't find one in the tuning function or similar functions, and I can't really figure out a way to imitate this behaviour using other functions (I'm new to mlr).

Is there a way to do this?

PS I know it might sound crazy, nevertheless I need it for some internal validation.

PS2 Unfortunately, it seems there is no "mlr" tag yet and I don't have enough rep to create one.

Answer Source

There may be shorter solutions, but the following is not too hacky. We use a wrapper to get hold of each trained model so that we can append it to a list in the global environment. Alternatively, you can change the line that appends to `stored.models` to something more sophisticated and save each model to the hard disk. This might be worthwhile because models can get quite big.


library(mlr)

# Define the tuning problem
ps = makeParamSet(
  makeDiscreteParam("C", values = 2^(-2:2)),
  makeDiscreteParam("sigma", values = 2^(-2:2))
)
ctrl = makeTuneControlGrid()
rdesc = makeResampleDesc("Holdout")
lrn = makeLearner("classif.ksvm")

# Define a wrapper to save all models that were trained with it
makeSaveWrapper = function(learner) {
  # makeBaseWrapper is an internal mlr constructor, accessed via `:::`
  mlr:::makeBaseWrapper(
    id = paste(learner$id, "save", sep = "."),
    type = learner$type,
    next.learner = learner,
    par.set = makeParamSet(),
    par.vals = list(),
    learner.subclass = "SaveWrapper",
    model.subclass = "SaveModel")
}

trainLearner.SaveWrapper = function(.learner, .task, .subset, ...) {
  # Train the wrapped learner, then keep a copy of the fitted model
  m = train(.learner$next.learner, task = .task, subset = .subset)
  stored.models <<- c(stored.models, list(m)) # not very efficient, maybe you want to save on hard disk here?
  mlr:::makeChainModel(next.model = m, cl = "SaveModel")
}

predictLearner.SaveWrapper = function(.learner, .model, .newdata, ...) {
  # Delegate prediction to the wrapped model
  NextMethod(.newdata = .newdata)
}

stored.models = list() # initialize empty list to store results
lrn.saver = makeSaveWrapper(lrn)

res = tuneParams(lrn.saver, task = iris.task, resampling = rdesc, par.set = ps, control = ctrl)

stored.models[[1]] # the normal mlr trained model
stored.models[[1]]$learner.model # the underlying model
getLearnerParVals(stored.models[[1]]$learner) # the hyper parameter settings
stored.models[[1]]$subset # the indices used to train the model
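
If the models get too big to keep in memory, the "save it on the hard disk" variant mentioned above could look roughly like the sketch below. It uses only base R (`saveRDS` / `readRDS`). The helper name `saveModelToDisk`, the directory `model.dir`, and the file naming scheme are assumptions made for illustration and are not part of mlr. An `lm` fit stands in for the trained mlr model so the sketch runs without mlr installed.

```r
# Hypothetical directory for the stored models (an assumption, pick your own)
model.dir = file.path(tempdir(), "stored-models")
dir.create(model.dir, showWarnings = FALSE, recursive = TRUE)

# Hypothetical helper: write one fitted model to its own .rds file
# instead of appending it to an in-memory list
saveModelToDisk = function(model, dir = model.dir) {
  # Number the file by how many models are already in the directory
  n = length(list.files(dir, pattern = "\\.rds$"))
  path = file.path(dir, sprintf("model-%04d.rds", n + 1))
  saveRDS(model, path)
  invisible(path)
}

# Stand-in for a trained model, so the sketch runs on its own
fit = lm(Sepal.Length ~ Sepal.Width, data = iris)
path = saveModelToDisk(fit)

# Read the model back later, e.g. for the internal validation
restored = readRDS(path)
```

Inside `trainLearner.SaveWrapper`, the line `stored.models <<- c(stored.models, list(m))` would then be replaced by `saveModelToDisk(m)`. Counting existing files to pick the next name is only safe when models are trained sequentially; with parallel tuning a different naming scheme would be needed.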