Create a model with R that has validation
voxsim
New Altair Community Member
Answers
-
I merged the decision tree learner and the Execute Script (R) operator.
This runs well. I put the result in a variable called `result`. The R script is an implementation of C4.5, and its result is an associative array.
package com.rapidminer.operator.r;
import java.net.URL;
import java.util.List;
import org.rosuda.REngine.REXP;
import org.rosuda.REngine.REXPMismatchException;
import org.rosuda.REngine.RList;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.parameter.ParameterTypeText;
import com.rapidminer.parameter.TextType;
import com.rapidminer.tools.r.RPlotPainter;
import com.rapidminer.tools.r.RSession;
import com.rapidminer.tools.r.RSessionListener;
import com.rapidminer.tools.r.RSessionManager;
import com.rapidminer.tools.r.translation.RTranslations;
import com.rapidminer.tools.r.translation.RTranslator;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.tree.GreaterSplitCondition;
import com.rapidminer.operator.learner.tree.LessEqualsSplitCondition;
import com.rapidminer.operator.learner.tree.NominalSplitCondition;
import com.rapidminer.operator.learner.tree.SplitCondition;
import com.rapidminer.operator.learner.tree.Tree;
import com.rapidminer.operator.learner.tree.TreeModel;
/**
 * Learner that runs a user-supplied R script on the input {@link ExampleSet} and converts the
 * R variable {@code result} into a RapidMiner {@link TreeModel}.
 *
 * <p>The R result is read as a nested named structure whose levels alternate:
 * <ul>
 *   <li>a node level: a named list whose key is the name of the split attribute;</li>
 *   <li>an edge level: a named character vector or named list whose keys are branch values
 *       (nominal values, or numeric thresholds with a comparison prefix) and whose values are
 *       either leaf labels (strings) or further node levels.</li>
 * </ul>
 */
public class RDecisionTreeLearner extends AbstractLearner implements RSessionListener {

    public static final String PARAMETER_R_SCRIPT = "script";

    public static final String PARAMETER_INPUT = "input";

    public static final String PARAMETER_VARIABLE_NAME = "name_of_variable";

    /** Set by {@link #informErrors(String[])} when R reports an error during the current run. */
    private transient boolean isErrorOccurred = false;

    /** Accumulated R error text for the current run; {@code null} if none. */
    private transient String errorOccured;

    public RDecisionTreeLearner(OperatorDescription description) {
        super(description);
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return TreeModel.class;
    }

    /** Adds the R script and the input variable name parameters to the learner's parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterType type = new ParameterTypeText(PARAMETER_R_SCRIPT, "This script will be executed on one of the available R servers.", TextType.PLAIN, true);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeString(PARAMETER_INPUT, "This assigns each input port a variable name. If the type of input object is supported by the R translation, it will be accessible under this variable name.");
        type.setExpert(false);
        types.add(type);
        return types;
    }

    /**
     * Labels {@code node} with the most frequent label value (the MODE statistic) of
     * {@code exampleSet} and records one count per label value on it.
     *
     * <p>If the MODE statistic is -1 (presumably an empty example set), the node is left unchanged.
     * NOTE(review): this overwrites any leaf label set earlier, e.g. the label coming from R.
     *
     * @param node       the tree node to turn into a leaf
     * @param exampleSet the training examples that reach {@code node}
     */
    public void changeTreeToLeaf(Tree node, ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        exampleSet.recalculateAttributeStatistics(label);
        int labelValue = (int) exampleSet.getStatistics(label, Statistics.MODE);
        if (labelValue != -1) {
            node.setLeaf(label.getMapping().mapIndex(labelValue));
            for (String value : label.getMapping().getValues()) {
                int count = (int) exampleSet.getStatistics(label, Statistics.COUNT, value);
                node.addCount(value, count);
            }
        }
    }

    /**
     * Adds one child to {@code current} for every branch described by an edge-level R payload.
     *
     * @param current       the node that splits on {@code bestAttribute}
     * @param bestAttribute the attribute the node splits on
     * @param exampleSet    the training examples that reach {@code current}
     * @param payload       either a named character vector (names are branch values, elements
     *                      are leaf labels) or a named list (names are branch values, elements
     *                      are subtrees or leaf labels)
     * @param depth         forwarded unchanged to {@link #buildTreeFromR}
     * @return {@code current}, with its children added
     * @throws REXPMismatchException if an R value cannot be converted to the expected type
     */
    public Tree getChilds(Tree current, Attribute bestAttribute, ExampleSet exampleSet, REXP payload, int depth) throws REXPMismatchException {
        String[] branchNames;
        String[] leafLabels = null;
        RList subtrees = null;
        if (payload.isString()) {
            // Named character vector: every branch ends directly in a leaf label.
            // Look "names" up by name instead of assuming it is the first attribute.
            REXP namesAttribute = payload.getAttribute("names");
            if (namesAttribute == null) {
                return current;
            }
            branchNames = namesAttribute.asStrings();
            leafLabels = payload.asStrings();
        } else if (payload.isList()) {
            subtrees = payload.asList();
            branchNames = new String[payload.length()];
            for (int i = 0; i < branchNames.length; i++) {
                branchNames[i] = subtrees.keyAt(i);
            }
        } else {
            return current;
        }

        // The nominal split does not depend on the branch, so compute it once.
        SplittedExampleSet nominalSplit = bestAttribute.isNominal()
                ? SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute)
                : null;
        // Created from the threshold of the first usable numeric branch.
        SplittedExampleSet numericSplit = null;

        for (int i = 0; i < branchNames.length; i++) {
            String branchName = branchNames[i];
            if (branchName == null) {
                continue;
            }
            if (leafLabels != null && (i >= leafLabels.length || leafLabels[i] == null)) {
                continue;
            }

            SplitCondition condition;
            ExampleSet branchExamples;
            if (bestAttribute.isNominal()) {
                // The condition tests the branch value, not the leaf label.
                condition = new NominalSplitCondition(bestAttribute, branchName);
                branchExamples = selectNominalBranch(nominalSplit, bestAttribute, branchName);
                if (branchExamples == null) {
                    logWarning("R tree has a branch '" + branchName + "' for attribute '"
                            + bestAttribute.getName() + "' that matches no training example; branch ignored.");
                    continue;
                }
            } else {
                double threshold = parseThreshold(branchName);
                int subset;
                // Convention of the R output: the first branch is "<=", later ones are ">".
                if (i == 0) {
                    condition = new LessEqualsSplitCondition(bestAttribute, threshold);
                    subset = 0;
                } else {
                    condition = new GreaterSplitCondition(bestAttribute, threshold);
                    subset = 1;
                }
                if (numericSplit == null) {
                    numericSplit = SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute, threshold);
                }
                numericSplit.selectSingleSubset(subset);
                branchExamples = (ExampleSet) numericSplit.clone();
            }

            Tree child;
            if (leafLabels != null) {
                child = new Tree(branchExamples);
                child.setLeaf(leafLabels[i]);
                changeTreeToLeaf(child, branchExamples);
            } else {
                child = buildTreeFromR(branchExamples, subtrees.at(i), depth);
                if (child == null) {
                    logWarning("R tree branch '" + branchName + "' for attribute '"
                            + bestAttribute.getName() + "' is neither a leaf nor a subtree; branch ignored.");
                    continue;
                }
            }
            current.addChild(child, condition);
        }
        return current;
    }

    /**
     * Returns a copy of the subset of {@code splitted} whose examples have {@code value} for
     * {@code attribute}, or {@code null} if there is no such subset.
     *
     * <p>Subsets are matched by their value rather than by position. This way the result does
     * not depend on how {@link SplittedExampleSet#splitByAttribute} numbers its subsets.
     */
    private ExampleSet selectNominalBranch(SplittedExampleSet splitted, Attribute attribute, String value) {
        int wanted = attribute.getMapping().getIndex(value);
        if (wanted < 0) {
            return null;
        }
        for (int subset = 0; subset < splitted.getNumberOfSubsets(); subset++) {
            splitted.selectSingleSubset(subset);
            if (splitted.size() > 0) {
                double first = splitted.getExample(0).getValue(attribute);
                if (!Double.isNaN(first) && (int) first == wanted) {
                    return (ExampleSet) splitted.clone();
                }
            }
        }
        return null;
    }

    /**
     * Parses the numeric threshold from a branch name that has a comparison prefix. For example,
     * {@code "<=5.3"}, {@code "> 5.3"} and {@code ">5.3"} all yield 5.3.
     *
     * <p>All leading characters that cannot start a number are skipped, so one- and two-character
     * prefixes are both handled.
     *
     * @throws NumberFormatException if no number follows the prefix
     */
    private static double parseThreshold(String branchName) {
        int start = 0;
        while (start < branchName.length() && !isNumberStart(branchName.charAt(start))) {
            start++;
        }
        return Double.parseDouble(branchName.substring(start));
    }

    /** Returns whether {@code c} may be the first character of a decimal number literal. */
    private static boolean isNumberStart(char c) {
        return Character.isDigit(c) || c == '-' || c == '+' || c == '.';
    }

    /**
     * Builds a tree (or subtree) from a node-level R payload.
     *
     * <p>In the R structure, nesting levels alternate: even levels describe nodes (the split
     * attribute) and odd levels describe edges (branch values). A string of length one is a
     * leaf. A named list maps the split attribute name to its edge-level payload. If the list
     * has several named entries, the last one wins.
     *
     * @param exampleSet the training examples that reach this node
     * @param payload    the node-level R value
     * @param depth      incremented and passed to {@link #getChilds}; not otherwise used here
     * @return the built tree, or {@code null} if {@code payload} is neither a single string nor
     *         a list with a named entry
     * @throws REXPMismatchException    if an R value cannot be converted to the expected type
     * @throws IllegalArgumentException if R names an attribute that is not in {@code exampleSet}
     */
    public Tree buildTreeFromR(ExampleSet exampleSet, REXP payload, int depth) throws REXPMismatchException {
        if (payload.isString() && payload.length() == 1) {
            Tree leaf = new Tree((ExampleSet) exampleSet.clone());
            leaf.setLeaf(payload.asString());
            changeTreeToLeaf(leaf, exampleSet);
            return leaf;
        }
        if (!payload.isList()) {
            return null;
        }
        RList rList = payload.asList();
        Tree current = null;
        for (int i = 0; i < payload.length(); i++) {
            String attributeName = rList.keyAt(i);
            if (attributeName == null) {
                continue;
            }
            Attribute bestAttribute = exampleSet.getAttributes().get(attributeName);
            if (bestAttribute == null) {
                throw new IllegalArgumentException("R tree splits on attribute '" + attributeName
                        + "', which is not part of the example set.");
            }
            current = new Tree((ExampleSet) exampleSet.clone());
            current.setLeaf(attributeName);
            current = getChilds(current, bestAttribute, exampleSet, rList.at(i), depth + 1);
        }
        return current;
    }

    /**
     * Runs the configured R script on {@code exampleSet} and translates the R variable
     * {@code result} into a tree.
     *
     * @param exampleSet the training data, exported to R under the {@link #PARAMETER_INPUT} name
     * @return the root of the translated tree; never {@code null}
     * @throws OperatorException if no translator is available, R reports an error, or
     *                           {@code result} cannot be translated into a tree
     */
    public Tree translateRTreeToTree(ExampleSet exampleSet) throws OperatorException {
        isErrorOccurred = false;
        errorOccured = null;
        RSession rSession = RSessionManager.acquireSession();
        try {
            // Registering inside the try ensures the session is released even if this fails.
            rSession.registerSessionListener(this);

            // Make the RapidMiner input available in R.
            String inputVariableName = getParameterAsString(PARAMETER_INPUT);
            RTranslator<? extends IOObject> translator = RTranslations.getTranslators(ExampleSet.class);
            if (translator == null) {
                throw new UserError(this, "r.no_translator_available", ExampleSet.class.getSimpleName());
            }
            translator.exportObject(rSession, inputVariableName, (IOObject) exampleSet);

            if (isParameterSet(PARAMETER_R_SCRIPT)) {
                rSession.execute(getParameterAsString(PARAMETER_R_SCRIPT));
            }
            // R errors are delivered through informErrors() on this listener.
            if (isErrorOccurred) {
                throw new UserError(this, new Throwable(errorOccured), "r.r_error");
            }

            REXP payload = rSession.eval("result");
            if (isErrorOccurred) {
                throw new UserError(this, new Throwable(errorOccured), "r.r_error");
            }
            Tree root = payload == null ? null : buildTreeFromR(exampleSet, payload, 0);
            if (root == null) {
                throw new UserError(this, new Throwable("The R variable 'result' does not describe a decision tree."), "r.r_error");
            }
            return root;
        } catch (REXPMismatchException e) {
            // Report the failure instead of silently returning a placeholder tree.
            throw new UserError(this, e, "r.r_error");
        } finally {
            RSessionManager.releaseSession(rSession);
        }
    }

    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // learn tree
        Tree root = translateRTreeToTree(exampleSet);
        // create and return model
        return new TreeModel(exampleSet, root);
    }

    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case BINOMINAL_ATTRIBUTES:
            case POLYNOMINAL_ATTRIBUTES:
            case NUMERICAL_ATTRIBUTES:
            case POLYNOMINAL_LABEL:
            case BINOMINAL_LABEL:
            case WEIGHTED_EXAMPLES:
            case MISSING_VALUES:
                return true;
            default:
                return false;
        }
    }

    /**
     * Logs each R error and appends it, one per line, to the error text of the current run.
     * An empty or {@code null} array leaves the previously recorded errors untouched.
     */
    @Override
    public void informErrors(String[] errors) {
        if (errors == null || errors.length == 0) {
            return;
        }
        StringBuilder builder = new StringBuilder(errorOccured == null ? "" : errorOccured);
        for (String error : errors) {
            if (builder.length() > 0) {
                builder.append('\n');
            }
            builder.append(error);
            logError(error);
        }
        isErrorOccurred = true;
        errorOccured = builder.toString();
    }

    /** Forwards R warnings to the operator log. */
    @Override
    public void informWarnings(String[] warnings) {
        for (String warning : warnings) {
            logWarning(warning);
        }
    }

    /** Forwards R console output to the operator log as notes. */
    @Override
    public void informOutput(String[] text) {
        for (String line : text) {
            logNote(line);
        }
    }

    /* Ignore following events */

    @Override
    public void informAssignment(RSession session) {
    }

    @Override
    public void informEvaluation(RSession session) {
    }

    @Override
    public void informExecution(RSession session) {
    }

    @Override
    public void informInterpretation(RSession session) {
    }

    @Override
    public void notifyPlotListener(RPlotPainter plotPainter) {
    }

    @Override
    public void informHelpChange(URL helpPage) {
    }
}
Simon0