Machine-Learning Examples: scikit-learn versus TMVA (CERN ROOT)

Basic Overview

When you think machine learning + Python, you usually think of scikit-learn. But there are alternatives! One of these is TMVA, a machine-learning package that is part of CERN’s ROOT analysis software. Conveniently, ROOT classes can be accessed in Python with PyROOT.

For scikit-learn users, TMVA might have advantages, such as easy event weighting, interfaces to examine data correlations, and overtraining checks. These, along with the histogram-oriented tools of ROOT, make it a good choice for physicists doing high-statistics signal-background separation.

For TMVA users, scikit-learn can provide a straightforward interface to machine learning, with very easy application of the trained classifier to new datasets. And of course, you get all the benefits of working with numpy arrays.

The exercises below show the same project being done in both TMVA and scikit-learn, with the goal of helping people switch between the two according to their needs. As you can see from the plots, very similar MVA classifiers for the SVM can be produced with either toolkit, but you’ll notice that the code for each is quite different.

Equivalent operations can be done in TMVA and scikit-learn. The goal of this post is to provide a guided example to working with both tools.

Before we get to the examples, here is a little cheat-sheet to convert between TMVA code and scikit-learn code, assuming a data file that contains both the X (data) and Y (truth). The example here is an SVM with an RBF kernel.

| Action | scikit-learn | TMVA |
|---|---|---|
| Load tools | import numpy; from sklearn.svm import SVC | import ROOT; ROOT.TMVA.Tools.Instance() |
| Import data | data = numpy.load('file.npy') | data = ROOT.TFile.Open('file.root').Get('ntuple_name') |
| Create Factory | (n/a, user-guided) | factory = ROOT.TMVA.Factory("TMVAClassification", … ) |
| Initialize Data | (n/a, will access numpy array) | factory.AddSignalTree(data); factory.AddBackgroundTree(data) |
| Prep Truth | Y = data[:,truth_column]; X = data[:,[other_columns]] | factory.PrepareTrainingAndTestTree('truth==1','truth==0',OtherOptions) |
| Split Train/Test | r = numpy.random.rand(X.shape[0]); Xtrain = X[r>0.5]; Xtest = X[r<=0.5]; Ytrain = Y[r>0.5]; Ytest = Y[r<=0.5] | (n/a, done in factory OtherOptions with "SplitMode=Random") |
| Initialize SVM | method = SVC(C=10.0, kernel='rbf', tol=0.001, gamma=0.25) | method = factory.BookMethod(ROOT.TMVA.Types.kSVM, "SVM", "C=10.0:Gamma=0.25:Tol=0.001:VarTransform=Norm") |
| Do Training | method.fit(Xtrain, Ytrain) | factory.TrainAllMethods() |
| Do Testing | method.predict(Xtest) | factory.TestAllMethods() |

Making some pseudo-data to play with

Now we will examine an instance of machine learning in scikit-learn and in TMVA. To do so, we will need some (pseudo) data. In the following code, I make 20K signal and 20K background data points, with truth values of 1 and 0 respectively, which have some inherent separability because the two classes are drawn from Gaussian distributions with different means. The data is saved both as a numpy array and as a ROOT ntuple.

import random, numpy, ROOT

# Initialize ROOT data file
f_root = ROOT.TFile.Open('data.root','RECREATE')
# Declare ROOT ntuple, with 6 x-variables and one y (truth) variable
ntuple = ROOT.TNtuple("data","data","x1:x2:x3:x4:x5:x6:truth")

# Initialize list for numpy array
numpy_data = []

# Generate and fill signal data (truth = 1.0), x-values centered at 100
for a in range(20000):
    x = [random.gauss(100.,5.) for i in range(6)]
    ntuple.Fill(x[0],x[1],x[2],x[3],x[4],x[5],1.0)
    numpy_data.append(x + [1.0])

# Generate and fill background data (truth = 0.0), x-values centered at 95
for a in range(20000):
    x = [random.gauss(95.0,5.) for i in range(6)]
    ntuple.Fill(x[0],x[1],x[2],x[3],x[4],x[5],0.0)
    numpy_data.append(x + [0.0])

# Save ROOT file
f_root.Write()
f_root.Close()

# Get and save numpy array
numpy_data_as_array = numpy.array(numpy_data)
numpy.save('data.npy',numpy_data_as_array)

Now we have stored locally two identical datasets, ‘data.npy’ and ‘data.root’. It is worth noting that this is not the most efficient means of storage, and that the .npy data is 11MB whereas the .root file is 4MB.
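Because both classes are drawn from Gaussians of the same width, the "inherent separability" of this pseudo-data can be estimated before any training. The sketch below is not part of the original analysis: it regenerates the same kind of data with vectorized numpy calls (the names n_per_class, threshold, and best_accuracy are mine) and compares a simple cut on the sum of the six variables, which is the optimal rule for equal-width uncorrelated Gaussians, with the analytic expectation. Under those assumptions the ceiling comes out near 89%, which gives a yardstick for the classifiers trained below.

```python
import math
import numpy

# Assumed setup, mirroring the pseudo-data above: six Gaussian variables
# per event, width 5, centered at 100 (signal) or 95 (background).
n_per_class, n_dim = 20000, 6
mu_s, mu_b, sigma = 100.0, 95.0, 5.0

rng = numpy.random.RandomState(42)  # fixed seed so the numbers are reproducible
signal = rng.normal(mu_s, sigma, size=(n_per_class, n_dim))
background = rng.normal(mu_b, sigma, size=(n_per_class, n_dim))

# For equal-width, uncorrelated Gaussians the best possible rule is a cut
# on the sum of the variables, placed halfway between the two class means.
threshold = n_dim * (mu_s + mu_b) / 2.0
correct_s = (signal.sum(axis=1) > threshold).mean()
correct_b = (background.sum(axis=1) <= threshold).mean()
measured_accuracy = 0.5 * (correct_s + correct_b)

# Analytic expectation: the class means are sqrt(n_dim)*|mu_s - mu_b|/sigma
# standard deviations apart, and the cut sits at half that distance.
distance = math.sqrt(n_dim) * abs(mu_s - mu_b) / sigma
best_accuracy = 0.5 * (1.0 + math.erf(distance / 2.0 / math.sqrt(2.0)))

print('measured: %.3f, expected: %.3f' % (measured_accuracy, best_accuracy))
```

A trained classifier that lands close to this number on the testing sample is doing about as well as the data allows; one that scores well above it on the training sample alone is a hint of overtraining.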

A machine-learning example in scikit-learn

Scikit-learn has the benefit of straightforward syntax and vectorized manipulations in numpy, which is useful for complicated splitting of the training and testing sample. Results are available on-call with the fit() and predict() functions. The plot below and its code show how easily one can separate datasets with thousands of points. In this example, signal and background can be separated with about 90% accuracy.
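As one example of the kind of splitting that boolean masks make easy, the sketch below draws a 70/30 train/test split separately within each class, so that both samples keep the same signal-to-background ratio. It is an illustration rather than part of the analysis in this post: the helper name stratified_mask, the 70% fraction, and the toy arrays are all assumptions.

```python
import numpy

def stratified_mask(y, train_fraction, rng):
    """Return a boolean mask that is True for training events.

    The requested fraction is drawn separately within each class label,
    so every class is split in the same proportion.
    """
    mask = numpy.zeros(y.shape[0], dtype=bool)
    for label in numpy.unique(y):
        indices = numpy.flatnonzero(y == label)
        rng.shuffle(indices)  # random order within this class
        n_train = int(round(train_fraction * indices.size))
        mask[indices[:n_train]] = True
    return mask

# Toy stand-in for the (X, Y) arrays: 1000 signal and 3000 background events
rng = numpy.random.RandomState(0)
Y = numpy.concatenate([numpy.ones(1000), numpy.zeros(3000)])
X = rng.normal(size=(Y.shape[0], 6))

is_train = stratified_mask(Y, 0.7, rng)
Xtrain, Ytrain = X[is_train], Y[is_train]
Xtest, Ytest = X[~is_train], Y[~is_train]

print(Xtrain.shape, Xtest.shape)
```

The same mask can be reused to slice any other per-event array, such as event weights, which keeps all of the pieces aligned.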

The output SVM classifier calculated with scikit-learn SVC. Compatibility between dots (testing data) and histograms (training data) indicates that overtraining is not a problem.

import numpy
from sklearn.svm import SVC
from matplotlib import pyplot as plt

# Get the pseudo-data from the npy file
npy_data = numpy.load('data.npy')

# Split numpy array into X (data) and Y (truth)
# Y is just last column
Y = npy_data[:,-1]
# X is all but last column
X = npy_data[:,:-1]

# Random number list for splitting training and testing
r = numpy.random.rand(X.shape[0])

# Get training and testing samples, splitting in half (using 0.5)
Xtrain = X[r>0.5]; Xtest = X[r<=0.5]
Ytrain = Y[r>0.5]; Ytest = Y[r<=0.5]

# Create and run the SVC with an rbf kernel
# Some tuning has been done to get the gamma and C values.
method_sklearn = SVC(C = 1.0, kernel = 'rbf',tol=0.001,gamma=0.005)
method_sklearn.fit(Xtrain,Ytrain)

# The percent of correctly classified training and testing data
# should be roughly equivalent (i.e. not overtrained)
# and ideally, near 100%.  We will see about 90% success.
test_percent = 100.0*numpy.mean(method_sklearn.predict(Xtest)==Ytest)
train_percent = 100.0*numpy.mean(method_sklearn.predict(Xtrain)==Ytrain)
print('---------- Training/Testing info ----------')
print('Trained SVM correctly classifies %.1f%% of the testing data and %.1f%% of the training data.' % (test_percent, train_percent))
print('-'*50)

# Now let's view the trained classifier. To start, we will get the values of
# the classifier in training and testing.
# It may come out like [[a] [b] [c]], so we ravel it into [a b c]
# Total of four things: training and testing for signal and background (S and B)

Classifier_training_S = method_sklearn.decision_function(Xtrain[Ytrain>0.5]).ravel()
Classifier_training_B = method_sklearn.decision_function(Xtrain[Ytrain<0.5]).ravel() 
Classifier_testing_S = method_sklearn.decision_function(Xtest[Ytest>0.5]).ravel()
Classifier_testing_B = method_sklearn.decision_function(Xtest[Ytest<0.5]).ravel()

# This will be the min/max of our plots
c_max = 2.0
c_min = -2.0

# Get histograms of the classifiers
Histo_training_S = numpy.histogram(Classifier_training_S,bins=40,range=(c_min,c_max))
Histo_training_B = numpy.histogram(Classifier_training_B,bins=40,range=(c_min,c_max))
Histo_testing_S = numpy.histogram(Classifier_testing_S,bins=40,range=(c_min,c_max))
Histo_testing_B = numpy.histogram(Classifier_testing_B,bins=40,range=(c_min,c_max))

# Let's get the min/max of the histograms
AllHistos= [Histo_training_S,Histo_training_B,Histo_testing_S,Histo_testing_B]
h_max = max([histo[0].max() for histo in AllHistos])*1.2
h_min = min([histo[0].min() for histo in AllHistos])

# Get the histogram properties (binning, widths, centers)
bin_edges = Histo_training_S[1]
bin_centers = ( bin_edges[:-1] + bin_edges[1:]  ) /2.
bin_widths = (bin_edges[1:] - bin_edges[:-1])

# To make error bar plots for the data, take the Poisson uncertainty sqrt(N)
ErrorBar_testing_S = numpy.sqrt(Histo_testing_S[0])
ErrorBar_testing_B = numpy.sqrt(Histo_testing_B[0])

# Draw objects
ax1 = plt.subplot(111)

# Draw solid histograms for the training data
ax1.bar(bin_centers,Histo_training_S[0],align='center',facecolor='blue',linewidth=0,width=bin_widths,label='S (Train)',alpha=0.5)
ax1.bar(bin_centers,Histo_training_B[0],align='center',facecolor='red',linewidth=0,width=bin_widths,label='B (Train)',alpha=0.5)

# Draw error-bar histograms for the testing data
ax1.errorbar(bin_centers, Histo_testing_S[0], yerr=ErrorBar_testing_S, xerr=None, ecolor='blue',c='blue',fmt='o',label='S (Test)')
ax1.errorbar(bin_centers, Histo_testing_B[0], yerr=ErrorBar_testing_B, xerr=None, ecolor='red',c='red',fmt='o',label='B (Test)')

# Make a colorful backdrop to show the classification regions in red and blue
ax1.axvspan(0.0, c_max, color='blue',alpha=0.08)
ax1.axvspan(c_min,0.0, color='red',alpha=0.08)

# Adjust the axis boundaries (just cosmetic)
ax1.axis([c_min, c_max, h_min, h_max])

# Make labels and title
plt.title("Classification with scikit-learn")
plt.xlabel("Classifier, SVM [rbf kernel, C=1, gamma=0.005]")

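# Make legend with small font
legend = ax1.legend(loc='upper center', shadow=True,ncol=2)
for alabel in legend.get_texts():
    alabel.set_fontsize('small')

# Save the result to png (the file name here is a placeholder)
plt.savefig('sklearn_example.png')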
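The overtraining check in the plot is done by eye. A rough numerical version, sketched here under my own assumptions, compares a training and a testing distribution bin by bin with a chi-square per bin. The function name histogram_chi2 is hypothetical, and the toy Gaussian scores stand in for the decision_function output; TMVA's own overtraining plot reports a Kolmogorov-Smirnov test instead, so this is a stand-in for the idea rather than a reproduction of it.

```python
import numpy

def histogram_chi2(sample_a, sample_b, bins=40, value_range=(-2.0, 2.0)):
    """Chi-square per bin between two samples, after scaling the second
    histogram to the event count of the first. Values near 1 mean the two
    shapes agree within Poisson fluctuations."""
    counts_a, _ = numpy.histogram(sample_a, bins=bins, range=value_range)
    counts_b, _ = numpy.histogram(sample_b, bins=bins, range=value_range)
    scale = counts_a.sum() / float(counts_b.sum())
    # Poisson variance of the difference in each bin: N_a + scale^2 * N_b
    variance = counts_a + scale ** 2 * counts_b
    filled = variance > 0  # skip bins that are empty in both samples
    diff = counts_a[filled] - scale * counts_b[filled]
    return (diff ** 2 / variance[filled]).sum() / filled.sum()

# Toy classifier outputs (assumed shapes, not real SVM scores)
rng = numpy.random.RandomState(1)
train_scores = rng.normal(1.0, 0.5, size=10000)
test_scores = rng.normal(1.0, 0.5, size=10000)     # same shape: compatible
shifted_scores = rng.normal(0.7, 0.5, size=10000)  # different shape: mimics overtraining

chi2_same = histogram_chi2(train_scores, test_scores)
chi2_shifted = histogram_chi2(train_scores, shifted_scores)
print('compatible: %.2f, shifted: %.2f' % (chi2_same, chi2_shifted))
```

In the example above, the same function could be called with Classifier_training_S and Classifier_testing_S (and again for the background pair) in place of the toy arrays.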

The equivalent procedure in TMVA

As you will see from the code, everything relies heavily on the TMVA factory, which processes the training and testing and outputs all relevant information into a single file. From that file the plots can be easily produced, and a cut on the discriminant can be optimized for the desired signal purity.

Again we see no evidence of overtraining. TMVA does not classify as signal or background inherently - it is up to the user to choose a discriminant cutoff.

import ROOT

# Get the data from the ROOT file
root_data = ROOT.TFile.Open('data.root').Get('data')

# Useful output information will be stored in a new root file:
f_out = ROOT.TFile("LearningOutput.root","RECREATE")

# Create the TMVA factory
factory = ROOT.TMVA.Factory("TMVAClassification", f_out,"AnalysisType=Classification")

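# Add the six variables to the TMVA factory as floats
for x in ['x1','x2','x3','x4','x5','x6']:
    factory.AddVariable(x,"F")

# Link the signal and background to the root_data ntuple
# (recent ROOT 6 releases move these calls onto a TMVA.DataLoader object)
factory.AddSignalTree(root_data)
factory.AddBackgroundTree(root_data)

# Cuts defining the signal and background sample
sigCut = ROOT.TCut("truth > 0.5")
bgCut = ROOT.TCut("truth <= 0.5")

# Prepare the training/testing signal/background
# (the option string is an assumed minimal choice: a random split,
# which by default uses half the events for training and half for testing)
factory.PrepareTrainingAndTestTree(sigCut,bgCut,"SplitMode=Random:NormMode=NumEvents:!V")

# Book the SVM method and train/test
method = factory.BookMethod( ROOT.TMVA.Types.kSVM, "SVM", "C=1.0:Gamma=0.005:Tol=0.001:VarTransform=None" )
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()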

# Histogrammed results are already stored in a file for us! 
# We will open this file (LearningOutput.root) shortly.
# These are histogram (TH) one-dimensional double (1D) objects 
Histo_training_S = ROOT.TH1D('Histo_training_S','S (Train)',40,0.0,1.0) 
Histo_training_B = ROOT.TH1D('Histo_training_B','B (Train)',40,0.0,1.0) 
Histo_testing_S = ROOT.TH1D('Histo_testing_S','S (Test)',40,0.0,1.0) 
Histo_testing_B = ROOT.TH1D('Histo_testing_B','B (Test)',40,0.0,1.0) 

# Fetch the trees of events from the root file 
TrainTree = f_out.Get("TrainTree") 
TestTree = f_out.Get("TestTree") 

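# Cutting on these objects in the trees will allow us to separate true S/B
# (TMVA numbers classes in the order they are added, so signal is
# classID 0 and background is classID 1)
SCut_Tree = 'classID<0.5'
BCut_Tree = 'classID>0.5'

# Now let's project the tree information into those histograms
TrainTree.Project("Histo_training_S","SVM",SCut_Tree)
TrainTree.Project("Histo_training_B","SVM",BCut_Tree)
TestTree.Project("Histo_testing_S","SVM",SCut_Tree)
TestTree.Project("Histo_testing_B","SVM",BCut_Tree)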

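# Create the color styles
# (style values are assumed: signal in blue and background in red,
# to mirror the scikit-learn plot)
for h in [Histo_training_S,Histo_testing_S]:
    h.SetLineColor(ROOT.kBlue)
    h.SetMarkerColor(ROOT.kBlue)
for h in [Histo_training_B,Histo_testing_B]:
    h.SetLineColor(ROOT.kRed)
    h.SetMarkerColor(ROOT.kRed)
Histo_training_S.SetFillColor(ROOT.kBlue)
Histo_training_B.SetFillColor(ROOT.kRed)

# Histogram fill styles
Histo_training_S.SetFillStyle(3001)
Histo_training_B.SetFillStyle(3001)
Histo_testing_S.SetFillStyle(0)
Histo_testing_B.SetFillStyle(0)

# Histogram marker styles
Histo_testing_S.SetMarkerStyle(20)
Histo_testing_B.SetMarkerStyle(20)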
# Set titles
Histo_training_S.GetXaxis().SetTitle("Classifier, SVM [rbf kernel, C=1, gamma=0.005]")

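# Draw the objects
c1 = ROOT.TCanvas("c1","",800,600)
Histo_training_S.Draw("HIST")
Histo_training_B.Draw("HISTSAME")
Histo_testing_S.Draw("EPSAME")
Histo_testing_B.Draw("EPSAME")

# Reset the y-max of the plot
ymax = max([h.GetMaximum() for h in [Histo_training_S,Histo_training_B,Histo_testing_S,Histo_testing_B] ])
ymax *=1.2
Histo_training_S.SetMaximum(ymax)

# Create Legend
legend = ROOT.TLegend(0.42,  0.72,  0.57,  0.88)
legend.AddEntry(Histo_training_S,"S (Train)","f")
legend.AddEntry(Histo_training_B,"B (Train)","f")
legend.AddEntry(Histo_testing_S,"S (Test)","p")
legend.AddEntry(Histo_testing_B,"B (Test)","p")
legend.SetFillColor(0)
legend.Draw()

# Add custom title
l1 = ROOT.TLatex()
l1.SetNDC()
l1.DrawLatex(0.26,0.93,"Classification with TMVA (ROOT)")

# Finally, draw the figure (the file name here is a placeholder)
c1.Print('ROOT_example.png')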
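Since TMVA leaves the choice of discriminant cutoff to the user, here is a minimal numpy sketch of one way to pick it: scan the cut from loose to tight and stop at the first value where the selected sample reaches a target signal purity. The function name cut_for_purity, the 95% target, and the toy score distributions are assumptions for illustration; in practice the scores would be read from the "SVM" branch of the TestTree.

```python
import numpy

def cut_for_purity(scores_s, scores_b, target_purity, n_steps=200):
    """Scan cut values from low to high and return the first (loosest) cut
    for which the selected sample reaches the target signal purity.
    Returns (cut, purity, signal_efficiency), or None if no cut qualifies."""
    lo = min(scores_s.min(), scores_b.min())
    hi = max(scores_s.max(), scores_b.max())
    for cut in numpy.linspace(lo, hi, n_steps):
        n_s = (scores_s > cut).sum()
        n_b = (scores_b > cut).sum()
        if n_s + n_b == 0:
            break  # nothing passes this cut or any tighter one
        purity = n_s / float(n_s + n_b)
        if purity >= target_purity:
            return cut, purity, n_s / float(scores_s.size)
    return None

# Toy discriminant values: signal peaks high, background peaks low
rng = numpy.random.RandomState(3)
scores_s = rng.normal(0.7, 0.15, size=5000)
scores_b = rng.normal(0.3, 0.15, size=5000)

cut, purity, efficiency = cut_for_purity(scores_s, scores_b, 0.95)
print('cut at %.3f gives purity %.3f with signal efficiency %.3f'
      % (cut, purity, efficiency))
```

Taking the loosest qualifying cut keeps as much signal as possible for the requested purity; with unequal signal and background yields, the counts would need to be weighted accordingly.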

Exploring the TMVA Gui

TMVA has the benefit of outputting results in a neatly-stored ROOT file which can be examined in a GUI. There are pre-existing tools like TMVAGui to look at correlation studies, overtraining checks, and much more.

It is very easy to load the TMVA Gui and make plots from the menu. Simply run:

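import ROOT
# Load the TMVAGui macro and open the training output file with it.
# (The macro path is the ROOT 5 location and is an assumption about the
# installation; ROOT 6 provides the same GUI as ROOT.TMVA.TMVAGui.)
ROOT.gROOT.LoadMacro('$ROOTSYS/tmva/test/TMVAGui.C')
ROOT.TMVAGui('LearningOutput.root')
raw_input('Press Enter to exit')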

The first thing you’ll notice is a window of point-and-click options, to create a suite of plots. These include information about the signal-efficiency and background rejection, histograms of the classifier (like we made earlier), correlation matrices, and parallel-coordinate visualizations of the SVM. Basically, it has everything you could want to see.
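The signal-efficiency versus background-rejection curve from the GUI can also be computed by hand from classifier outputs, which is handy when comparing against scikit-learn. The sketch below uses only numpy; the function name efficiency_rejection and the toy scores are my own assumptions, and the area is integrated with a hand-written trapezoid rule.

```python
import numpy

def efficiency_rejection(scores_s, scores_b, n_steps=101):
    """For a grid of cut values, return the signal efficiency (fraction of
    signal at or above the cut) and the background rejection (fraction of
    background below the cut)."""
    lo = min(scores_s.min(), scores_b.min())
    hi = max(scores_s.max(), scores_b.max())
    cuts = numpy.linspace(lo, hi, n_steps)
    eff_s = numpy.array([(scores_s >= c).mean() for c in cuts])
    rej_b = numpy.array([(scores_b < c).mean() for c in cuts])
    return cuts, eff_s, rej_b

# Toy discriminant values standing in for real classifier output
rng = numpy.random.RandomState(5)
scores_s = rng.normal(0.7, 0.15, size=5000)
scores_b = rng.normal(0.3, 0.15, size=5000)

cuts, eff_s, rej_b = efficiency_rejection(scores_s, scores_b)

# Efficiency falls as the cut tightens, so reverse the arrays to integrate
# rejection over increasing efficiency (trapezoid rule).
x, y = eff_s[::-1], rej_b[::-1]
area = numpy.sum(0.5 * (y[1:] + y[:-1]) * numpy.diff(x))
print('area under the efficiency/rejection curve: %.3f' % area)
```

An area near 1 means the classifier separates the two classes almost perfectly, while 0.5 corresponds to random guessing.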


Either scikit-learn or TMVA will get the job done in most instances, but each has its own useful features.

Some of the benefits of TMVA are automated sample splitting for training and testing, many built-in learning methods, and a GUI with many useful visualizations. If that interests you, TMVA is worth looking into.

If you need to quickly run analysis on common data formats, including text files, want all the benefits of manipulating arrays in numpy, and want the flexibility to do your training and testing individually, scikit-learn is a good choice.