import java.beans.*
import java.io.Serializable
import java.util.Vector
import java.util.Enumeration
import org.pentaho.dm.kf.KFGroovyScript
import org.pentaho.dm.kf.GroovyHelper
import weka.core.*
import weka.gui.Logger
import weka.gui.beans.*
import weka.classifiers.bayes.NaiveBayes
import weka.classifiers.functions.Logistic
import weka.classifiers.Evaluation
import weka.classifiers.Classifier
import groovy.swing.SwingBuilder
import javax.swing.*
import java.awt.*
// add further imports here if necessary
/**
* Example Groovy script that generates a learning curve for a classifier.
* Allows the classifier to be connected via a "configuration" event or
* specified via an environment variable (CLASSIFIER_NAME). Classifier options
* and the parameters of the learning curve could be specified via environment
* variables as well through just minor changes to the script.
*
* Generates both a "TextEvent" containing the curve information and a
* "DataSetEvent". The latter can be visualized in a DataVisualizer component.
*
* Also demonstrates how to allow the user to set options for the script
* via a graphical pop-up window.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
*/
class LearningCurve
    implements
        KFGroovyScript,
        EnvironmentHandler,
        BeanCommon,
        EventConstraints,
        UserRequestAcceptor,
        TrainingSetListener,
        TestSetListener,
        DataSourceListener,
        InstanceListener,
        TextListener,
        BatchClassifierListener,
        IncrementalClassifierListener,
        BatchClustererListener,
        GraphListener,
        ChartListener,
        ThresholdDataListener,
        VisualizableErrorListener,
        ConfigurationListener,
        Serializable {

    /** Name of the single user request offered by this component. */
    private static final String SET_OPTIONS_REQUEST = "Set options..."

    /** Don't delete!!
     * GroovyHelper has the following useful methods:
     *
     * notifyListenerType(Object event) - GroovyHelper will pass on event
     * appropriate listener type for you
     * ArrayList<TrainingSetListener> getTrainingSetListeners() - get
     * a list of any directly connected components that are listening
     * for TrainingSetEvents from us
     * ArrayList<TestSetListener> getTestSetListeners()
     * ArrayList<InstanceListener> getInstanceListeners()
     * ArrayList<TextListener> getTextListeners()
     * ArrayList<DataSourceListener> getDataSourceListeners()
     * ArrayList<BatchClassifierListener> getBatchClassifierListeners()
     * ArrayList<IncrementalClassifierListener> getIncrementalClassifierListeners()
     * ArrayList<BatchClustererListener> getBatchClustererListeners()
     * ArrayList<GraphListener> getGraphListeners()
     * ArrayList<ChartListener> getChartListeners()
     * ArrayList<ThresholdDataListener> getThresholdDataListeners()
     * ArrayList<VisualizableErrorListener> getVisualizableErrorListeners()
     */
    GroovyHelper m_helper

    /** Log to post messages to; may be null if no log has been supplied. */
    Logger m_log = null

    /** Environment used to resolve ${VARIABLE} references in the options below. */
    Environment m_env = Environment.getSystemWide()

    /** Percentage of the (shuffled) data held out for evaluation; may contain environment variables. */
    String m_holdoutSize = "33.0"

    /** Number of training instances added at each step; may contain environment variables. */
    String m_stepSize = "100"

    /** Maximum number of points on the curve; may contain environment variables. */
    String m_numSteps = "10"

    /** Class name of the classifier to use when no "configuration" connection is present. */
    String m_classifierName = "\${CLASSIFIER_NAME}"

    /** Option string for the named classifier (null for defaults). */
    String m_classifierOptions = null

    /** Source of the incoming "trainingSet" connection (null if not connected). */
    Object m_incomingConnection = null

    /** Classifier bean connected via a "configuration" connection (null if not connected). */
    weka.gui.beans.Classifier m_connectedConfigurable = null

    /** Don't delete!! */
    void setManager(GroovyHelper manager) { m_helper = manager }

    /** Alter or add to in order to tell the KnowledgeFlow
     * environment whether a certain incoming connection type is allowed.
     * Accepts at most one "trainingSet" and at most one "configuration"
     * connection.
     */
    boolean connectionAllowed(String eventName) {
        if ("trainingSet".equals(eventName)) {
            return m_incomingConnection == null
        }
        if ("configuration".equals(eventName)) {
            return m_connectedConfigurable == null
        }
        return false
    }

    /** Alter or add to in order to tell the KnowledgeFlow
     * environment whether a certain incoming connection type is allowed
     */
    boolean connectionAllowed(EventSetDescriptor esd) {
        return connectionAllowed(esd.getName())
    }

    /** Add (optional) code to do something when you have been
     * registered as a listener with a source for the named event.
     * Remembers the training set source and, for "configuration"
     * connections, the connected classifier bean.
     */
    void connectionNotification(String eventName, Object source) {
        if ("trainingSet".equals(eventName)) {
            m_incomingConnection = source
        }
        if ("configuration".equals(eventName)) {
            // only a classifier bean can supply the classifier template
            if (source instanceof weka.gui.beans.Classifier) {
                m_connectedConfigurable = (weka.gui.beans.Classifier) source
            } else {
                logError("Connected configurable is not a classifier!!", null)
            }
        }
    }

    /** Add (optional) code to do something when you have been
     * deregistered as a listener with a source for the named event
     */
    void disconnectionNotification(String eventName, Object source) {
        if ("trainingSet".equals(eventName)) {
            m_incomingConnection = null
        }
        if ("configuration".equals(eventName)) {
            m_connectedConfigurable = null
        }
    }

    /** Custom name of this component. Do something with it if you
     * like. GroovyHelper already stores it and alters the icon text
     * for you */
    void setCustomName(String name) { }

    /** Custom name of this component. No need to return anything
     * GroovyHelper already stores it and alters the icon text
     * for you */
    String getCustomName() { return null }

    /** Add code to return true when you are busy doing something.
     * This script always reports that it is not busy.
     */
    boolean isBusy() { return false }

    /** Store and use this logging object in order to post messages
     * to the log
     */
    void setLog(Logger logger) {
        m_log = logger
    }

    /** Store and use this Environment object in order to lookup and
     * use the values of environment variables
     */
    void setEnvironment(Environment env) {
        m_env = env
    }

    /** Stop any processing (if possible). Not supported by this script.
     */
    void stop() { }

    /** Alter or add to in order to tell the KnowledgeFlow
     * whether, at the current time, the named event could
     * be generated. This script produces "text" and "dataSet" events.
     */
    boolean eventGeneratable(String eventName) {
        return "text".equals(eventName) || "dataSet".equals(eventName)
    }

    /** Implement this to tell KnowledgeFlow about any methods
     * that the user could invoke (i.e. to show a popup visualization
     * or something).
     */
    Enumeration enumerateRequests() {
        Vector<String> items = new Vector<String>(1)
        items.add(SET_OPTIONS_REQUEST)
        return items.elements()
    }

    /** Make the user-requested action happen here.
     * "Set options..." pops up a small window for editing the holdout
     * size, the number of steps and the step size. Values are stored as
     * typed (they may contain environment variables) and are only parsed
     * when a training set arrives.
     */
    void performRequest(String requestName) {
        if (!SET_OPTIONS_REQUEST.equals(requestName)) {
            return
        }

        def swing = new SwingBuilder()

        // Declared up front so that every builder closure below shares the
        // same explicit references rather than implicit builder variables.
        def hSizeField = null
        def nStepsField = null
        def sSizeField = null
        def optionsFrame = null

        def holderP1 = {
            swing.panel() {
                borderLayout()
                label(text:'Holdout set size: ', constraints:BorderLayout.WEST)
                // pressing Enter in the field commits the value immediately
                hSizeField = textField(text:m_holdoutSize, columns:6,
                    actionPerformed: {
                        m_holdoutSize = hSizeField.text
                    }, constraints:BorderLayout.CENTER)
            }
        }
        def holderP2 = {
            swing.panel() {
                borderLayout()
                label(text:'Number of steps: ', constraints:BorderLayout.WEST)
                nStepsField = textField(text:m_numSteps, columns:6,
                    actionPerformed: {
                        m_numSteps = nStepsField.text
                    }, constraints:BorderLayout.CENTER)
            }
        }
        def holderP3 = {
            swing.panel() {
                borderLayout()
                label(text:'Step size: ', constraints:BorderLayout.WEST)
                sSizeField = textField(text:m_stepSize, columns:6,
                    actionPerformed: {
                        m_stepSize = sSizeField.text
                    }, constraints:BorderLayout.CENTER)
            }
        }
        def holderP4 = {
            swing.panel() {
                boxLayout(axis:BoxLayout.Y_AXIS)
                widget(holderP1())
                widget(holderP2())
                widget(holderP3())
            }
        }
        def holderP5 = {
            swing.panel() {
                boxLayout(axis:BoxLayout.X_AXIS)
                button(text:'OK',
                    actionPerformed: {
                        m_holdoutSize = hSizeField.text
                        m_numSteps = nStepsField.text
                        m_stepSize = sSizeField.text
                        optionsFrame.dispose()
                    })
                button(text:"CANCEL",
                    actionPerformed: {
                        optionsFrame.dispose()
                    })
            }
        }

        optionsFrame = swing.frame(title:'Learning Curve Options', size:[300,600]) {
            borderLayout()
            widget(holderP4(), constraints:BorderLayout.NORTH)
            widget(holderP5(), constraints:BorderLayout.SOUTH)
        }
        optionsFrame.pack()
        optionsFrame.setVisible(true)
    }

    //--------------- Incoming events ------------------
    //--------------- Implement as necessary -----------

    /**
     * Computes a learning curve from the supplied training set.
     *
     * The data is copied and shuffled with a fixed seed; the last
     * holdout-percent of the instances form the test set and the remaining
     * instances are added to the training set step-size at a time. After each
     * step a fresh copy of the classifier is built and its percent correct on
     * the holdout set is recorded. The results are sent on as a TextEvent
     * (one "numInstances,percentCorrect" line per step) and as a DataSetEvent.
     *
     * Structure-only events are ignored. Configuration problems (no
     * classifier, unparsable or out-of-range settings, too little data) are
     * reported to the log and no events are generated. A failure while
     * building or evaluating a model is logged and then rethrown.
     */
    void acceptTrainingSet(TrainingSetEvent e) {
        if (e.isStructureOnly()) {
            return
        }

        // Work on a shuffled copy; the fixed seed keeps the curve reproducible.
        Instances insts = new Instances(e.getTrainingSet())
        if (insts.classIndex() < 0) {
            logError("No class attribute is set in the incoming training set", null)
            return
        }
        insts.randomize(new java.util.Random(1))

        String hSize = substituteEnv(m_holdoutSize, "holdout size")
        String sSize = substituteEnv(m_stepSize, "step size")
        String nSteps = substituteEnv(m_numSteps, "number of steps")

        weka.classifiers.Classifier classifierToUse = resolveClassifier()
        if (classifierToUse == null) {
            return // problem already reported
        }

        double holdoutPercent
        int sS
        int nS
        try {
            holdoutPercent = Double.parseDouble((hSize ?: "").trim())
            sS = Integer.parseInt((sSize ?: "").trim())
            nS = Integer.parseInt((nSteps ?: "").trim())
        } catch (NumberFormatException ex) {
            logError("Holdout size, step size and number of steps must be numeric "
                + "(holdout='" + hSize + "', step size='" + sSize
                + "', steps='" + nSteps + "')", ex)
            return
        }
        // written as a negated range test so that NaN is rejected as well
        if (!(holdoutPercent > 0 && holdoutPercent < 100) || sS < 1 || nS < 1) {
            logError("Holdout size must be between 0 and 100 (exclusive) and "
                + "step size / number of steps must be at least 1", null)
            return
        }

        int total = insts.numInstances()
        int numInHoldout = (int) ((holdoutPercent / 100) * total)
        int numTrainAvailable = total - numInHoldout
        if (numInHoldout < 1 || numTrainAvailable < 1) {
            logError("Not enough instances (" + total + ") to create both a "
                + "training and a holdout set", null)
            return
        }

        // the holdout set is the tail of the shuffled data
        Instances holdoutI = new Instances(insts, numInHoldout)
        for (int i = numTrainAvailable; i < total; i++) {
            holdoutI.add(insts.instance(i))
        }

        String classifierSetUpString = classifierToUse.getClass().getName() + " "
        if (classifierToUse instanceof OptionHandler) {
            classifierSetUpString +=
                Utils.joinOptions(((OptionHandler) classifierToUse).getOptions())
        }
        logInfo("Using classifier " + classifierSetUpString)

        // create the instances structure to hold the learning curve results
        Attribute setSizeAtt = new Attribute("NumInstances")
        Attribute pctCorrectAtt = new Attribute("PercentCorrect")
        FastVector atts = new FastVector()
        atts.addElement(setSizeAtt)
        atts.addElement(pctCorrectAtt)
        // The preceding "__" tells the DataVisualizer to connect the points with lines
        Instances learnCInstances =
            new Instances("__Learning curve: " + classifierSetUpString, atts, 0)

        StringBuilder buff = new StringBuilder()
        Instances training = new Instances(insts, 0)
        boolean done = false
        for (int i = 0; i < nS && !done; i++) {
            postStatus("Processing set " + (i + 1))

            int numInThisStep = (i + 1) * sS
            if (numInThisStep >= numTrainAvailable) {
                // never let the training set run into the holdout region
                numInThisStep = numTrainAvailable
                done = true
            }
            // the training set grows cumulatively from step to step
            for (int k = i * sS; k < numInThisStep; k++) {
                training.add(insts.instance(k))
            }

            double pc
            try {
                // train a fresh copy on this set and score it on the holdout set
                weka.classifiers.Classifier newModel =
                    weka.classifiers.Classifier.makeCopy(classifierToUse)
                newModel.buildClassifier(training)
                Evaluation eval = new Evaluation(holdoutI)
                eval.evaluateModel(newModel, holdoutI)
                pc = (1.0 - eval.errorRate()) * 100.0
            } catch (Exception ex) {
                logError("Problem building/evaluating the classifier on "
                    + numInThisStep + " instances", ex)
                throw ex
            }

            buff.append("" + numInThisStep + "," + pc + "\n")
            Instance newInst = new Instance(2)
            newInst.setValue(0, (double) numInThisStep)
            newInst.setValue(1, pc)
            learnCInstances.add(newInst)
        }
        postStatus("Finished.")

        m_helper.notifyTextListeners(new TextEvent(this, buff.toString(), "learning curve"))
        m_helper.notifyDataSourceListeners(new DataSetEvent(this, learnCInstances))
    }

    void acceptTestSet(TestSetEvent e) { }
    void acceptDataSet(DataSetEvent e) { }
    void acceptInstance(InstanceEvent e) { }
    void acceptText(TextEvent e) { }
    void acceptClassifier(BatchClassifierEvent e) { }
    void acceptClassifier(IncrementalClassifierEvent e) { }
    void acceptClusterer(BatchClustererEvent e) { }
    void acceptGraph(GraphEvent e) { }
    void acceptDataPoint(ChartEvent e) { }
    void acceptDataSet(ThresholdDataEvent e) { }
    void acceptDataSet(VisualizableErrorEvent e) { }
    void acceptConfiguration(ConfigurationEvent e) { }

    //--------------- Private helpers ------------------

    /**
     * Returns the classifier to generate the curve for: a copy of the
     * connected configurable's template if there is one, otherwise an
     * instance created from m_classifierName / m_classifierOptions.
     * Returns null (after reporting the problem) if none can be obtained.
     */
    private weka.classifiers.Classifier resolveClassifier() {
        if (m_connectedConfigurable != null) {
            try {
                return weka.classifiers.Classifier.makeCopy(
                    m_connectedConfigurable.getClassifierTemplate())
            } catch (Exception ex) {
                logError("Unable to copy the connected classifier", ex)
                return null
            }
        }

        // try and instantiate from the supplied classifier name
        String classifierName = substituteEnv(m_classifierName, "classifier name")
        if (classifierName == null || classifierName.length() == 0) {
            logError("No classifier supplied!", null)
            return null
        }

        String classifierOptions = substituteEnv(m_classifierOptions, "classifier options")
        String[] splitOptions = null
        if (classifierOptions != null && classifierOptions.length() > 0) {
            try {
                splitOptions = Utils.splitOptions(classifierOptions)
            } catch (Exception ex) {
                logError("Problem parsing classifier options", ex)
                return null
            }
        }

        try {
            return weka.classifiers.Classifier.forName(classifierName, splitOptions)
        } catch (Exception ex) {
            logError("Unable to instantiate classifier '" + classifierName + "'", ex)
            return null
        }
    }

    /**
     * Replaces environment variables in value. If there is no environment,
     * the value is empty, or substitution fails, the value is returned
     * unchanged (a failure is noted in the log).
     */
    private String substituteEnv(String value, String description) {
        if (m_env == null || value == null || value.length() == 0) {
            return value
        }
        try {
            return m_env.substitute(value)
        } catch (Exception ex) {
            logInfo("Unable to substitute environment variables in " + description
                + " '" + value + "': " + ex.getMessage())
            return value
        }
    }

    /** Posts a status line for this component (no-op without a log). */
    private void postStatus(String message) {
        if (m_log != null) {
            m_log.statusMessage("LearningCurve\$" + hashCode() + "|" + message)
        }
    }

    /** Writes an informational message to the log (no-op without a log). */
    private void logInfo(String message) {
        if (m_log != null) {
            m_log.logMessage("[LearningCurve] " + message)
        }
    }

    /**
     * Reports an error: flags it in the status area and writes the details
     * to the log, or to System.err when no log is available.
     */
    private void logError(String message, Throwable cause) {
        String detail = (cause == null) ? message : message + " (" + cause + ")"
        if (m_log != null) {
            postStatus("ERROR (see log for details)")
            m_log.logMessage("[LearningCurve] " + detail)
        } else {
            System.err.println("[LearningCurve] " + detail)
        }
    }
}