An exclusive raffle opportunity for active members like you! Complete your profile, answer questions and get your first accepted badge to enter the raffle.
package com.rapidminer.operator.r;import java.net.URL;import java.util.List;import org.rosuda.REngine.REXP;import org.rosuda.REngine.REXPMismatchException;import org.rosuda.REngine.RList;import com.rapidminer.example.Attribute;import com.rapidminer.example.ExampleSet;import com.rapidminer.example.Statistics;import com.rapidminer.example.set.SplittedExampleSet;import com.rapidminer.operator.IOObject;import com.rapidminer.operator.Model;import com.rapidminer.operator.OperatorCapability;import com.rapidminer.operator.OperatorDescription;import com.rapidminer.operator.OperatorException;import com.rapidminer.operator.UserError;import com.rapidminer.parameter.ParameterType;import com.rapidminer.parameter.ParameterTypeString;import com.rapidminer.parameter.ParameterTypeText;import com.rapidminer.parameter.TextType;import com.rapidminer.tools.r.RPlotPainter;import com.rapidminer.tools.r.RSession;import com.rapidminer.tools.r.RSessionListener;import com.rapidminer.tools.r.RSessionManager;import com.rapidminer.tools.r.translation.RTranslations;import com.rapidminer.tools.r.translation.RTranslator;import com.rapidminer.operator.learner.AbstractLearner;import com.rapidminer.operator.learner.PredictionModel;import com.rapidminer.operator.learner.tree.GreaterSplitCondition;import com.rapidminer.operator.learner.tree.LessEqualsSplitCondition;import com.rapidminer.operator.learner.tree.NominalSplitCondition;import com.rapidminer.operator.learner.tree.SplitCondition;import com.rapidminer.operator.learner.tree.Tree;import com.rapidminer.operator.learner.tree.TreeModel;public class RDecisionTreeLearner extends AbstractLearner implements RSessionListener { public static final String PARAMETER_R_SCRIPT = "script"; public static final String PARAMETER_INPUT = "input"; public static final String PARAMETER_VARIABLE_NAME = "name_of_variable"; private transient boolean isErrorOccurred = false; private transient String errorOccured; public RDecisionTreeLearner(OperatorDescription description) {Â Â Â Â super(description);Â Â }@Override public Class<? extends PredictionModel> getModelClass() { return TreeModel.class; } public List<ParameterType> getParameterTypes() { List<ParameterType> types = super.getParameterTypes(); ParameterType type = new ParameterTypeText(PARAMETER_R_SCRIPT, "This script will be executed on one of the available R servers.", TextType.PLAIN, true); type.setExpert(false); types.add(type); type = new ParameterTypeString(PARAMETER_INPUT, "This assigns each input port a variable name. If the type of input object is supported by the R translation, it will be accessible under this variable name."); type.setExpert(false); types.add(type); return types; } public void changeTreeToLeaf(Tree node, ExampleSet exampleSet) {Â Â Â Â Attribute label = exampleSet.getAttributes().getLabel();Â Â Â Â exampleSet.recalculateAttributeStatistics(label);Â Â Â Â int labelValue = (int)exampleSet.getStatistics(label, Statistics.MODE);Â Â Â Â Â Â Â Â if(labelValue != -1) {Â Â Â Â String labelName = label.getMapping().mapIndex(labelValue);Â Â Â Â node.setLeaf(labelName);Â Â Â Â for (String value : label.getMapping().getValues()) {Â Â Â Â int count = (int)exampleSet.getStatistics(label, Statistics.COUNT, value);Â Â Â Â node.addCount(value, count);Â Â Â Â }Â Â Â Â }Â Â } public Tree getChilds(Tree current, Attribute bestAttribute, ExampleSet exampleSet, REXP payload, int depth) throws REXPMismatchException { int i = 0; int j = -1; SplittedExampleSet splitted = null; if(payload.isString()) { String[] names = payload._attr().asList().at(0).asStrings(); String[] values = payload.asStrings(); while (i < names.length) { String name = names; String value = values; if (name!=null && value!=null) { SplitCondition condition = null; if (bestAttribute.isNominal()) { condition = new NominalSplitCondition(bestAttribute, value); splitted = SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute); j += 1; } else { double bestSplitValue = Double.valueOf(name.substring(2)).doubleValue(); if (i == 0) { condition = new LessEqualsSplitCondition(bestAttribute, bestSplitValue); splitted = SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute, bestSplitValue); j = 0; } else { condition = new GreaterSplitCondition(bestAttribute, bestSplitValue); j = 1; } } splitted.selectSingleSubset(j); ExampleSet eSet = (ExampleSet)splitted.clone(); Tree child = new Tree(eSet); child.setLeaf(value); changeTreeToLeaf(child, eSet); current.addChild(child, condition); } i++; } return current; } if(payload.isList()) { RList rList = payload.asList(); while (i < payload.length()) { String name = rList.keyAt(i); if (name!=null) { SplitCondition condition = null; if (bestAttribute.isNominal()) { condition = new NominalSplitCondition(bestAttribute, name); splitted = SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute); j += 1; } else { double bestSplitValue = Double.valueOf(name.substring(2)).doubleValue(); if (i == 0) { condition = new LessEqualsSplitCondition(bestAttribute, bestSplitValue); splitted = SplittedExampleSet.splitByAttribute(exampleSet, bestAttribute, bestSplitValue); j = 0; } else { condition = new GreaterSplitCondition(bestAttribute, bestSplitValue); j = 1; } } splitted.selectSingleSubset(j); Tree child = buildTreeFromR((ExampleSet)splitted.clone(), rList.at(i), depth); current.addChild(child, condition); } i++; } return current; } return current; } // Pari un nodo, dispari un arco public Tree buildTreeFromR(ExampleSet exampleSet, REXP payload, int depth) throws REXPMismatchException { Tree current = null; if(payload.isString() && payload.length() == 1) { String name = payload.asString(); current = new Tree((ExampleSet) exampleSet.clone()); current.setLeaf(name); changeTreeToLeaf(current, exampleSet); return current; } int i = 0; if(payload.isList()) { RList rList = payload.asList(); while (i < payload.length()) { String name = rList.keyAt(i); if (name!=null) { current = new Tree((ExampleSet) exampleSet.clone()); current.setLeaf(name); Attribute bestAttribute = exampleSet.getAttributes().get(name); current = getChilds(current, bestAttribute, exampleSet, rList.at(i), depth+1); } i++; } return current; } return current; } public Tree translateRTreeToTree(ExampleSet exampleSet) throws OperatorException { isErrorOccurred = false; errorOccured = null; Tree root = new Tree((ExampleSet)exampleSet.clone()); root.setLeaf("Errore"); /*Attribute label = exampleSet.getAttributes().get("TargetRelativo"); exampleSet.getAttributes().setLabel(label);*/ // try retrieving the connection RSession rSession = null; rSession = RSessionManager.acquireSession(); rSession.registerSessionListener(this); try { // making RapidMiner Input available in R String inputVariableName = getParameterAsString(PARAMETER_INPUT); RTranslator<? extends IOObject> translator = RTranslations.getTranslators(ExampleSet.class); if (translator != null) { translator.exportObject(rSession, inputVariableName, (IOObject) exampleSet); } else { throw new UserError(this, "r.no_translator_available", ExampleSet.class.getSimpleName()); } if (isParameterSet(PARAMETER_R_SCRIPT)) rSession.execute(getParameterAsString(PARAMETER_R_SCRIPT)); // checking for errors during execution occurred on log if (isErrorOccurred) throw new UserError(this, new Throwable(errorOccured), "r.r_error"); REXP payload = rSession.eval("result"); root = buildTreeFromR(exampleSet, payload, 0); } catch (REXPMismatchException e) { e.printStackTrace(); } finally { RSessionManager.releaseSession(rSession); } return root; }@Override public Model learn(ExampleSet exampleSet) throws OperatorException { // learn tree Tree root = translateRTreeToTree(exampleSet); // create and return model return new TreeModel(exampleSet, root); }@Override public boolean supportsCapability(OperatorCapability capability) { switch (capability) {Â Â Â Â case BINOMINAL_ATTRIBUTES:Â Â Â Â case POLYNOMINAL_ATTRIBUTES:Â Â Â Â case NUMERICAL_ATTRIBUTES:Â Â Â Â case POLYNOMINAL_LABEL:Â Â Â Â case BINOMINAL_LABEL:Â Â Â Â case WEIGHTED_EXAMPLES:Â Â Â Â case MISSING_VALUES:Â Â Â Â Â Â return true;Â Â Â Â default:Â Â Â Â Â Â return false;Â Â Â Â } }@Override public void informErrors(String[] errors) { StringBuilder builder = new StringBuilder(); for (String error : errors) { isErrorOccurred = true; builder.append(error); logError(error); } errorOccured = builder.toString(); }@Override public void informWarnings(String[] warnings) { for (String warning : warnings) logWarning(warning); }@Override public void informOutput(String[] text) { for (String warning : text) logNote(warning); } /* Ignore following events */@Override public void informAssignment(RSession session) { }@Override public void informEvaluation(RSession session) { }@Override public void informExecution(RSession session) { }@Override public void informInterpretation(RSession session) { }@Override public void notifyPlotListener(RPlotPainter plotPainter) { }@Override public void informHelpChange(URL helpPage) { }}