#include #include #include #include "TH1F.h" #include "TFile.h" #include "TTree.h" #include "TMVA/Factory.h" #include "TMVA/DataLoader.h" #include "TMVA/Reader.h" #include "TMVA/Types.h" TFile *results; int bestmatch; int correct = 0; int wrong = 0; //Define a structure for all the parameters of the wines, //this can be inserted into a ROOT TTree struct winedata{ Float_t alcohol; Float_t malic; Float_t ash; Float_t alcalinity; Float_t magnesium; Float_t phenols; Float_t flavanoids; Float_t nonflavanoids; Float_t proanthocyanins; Float_t colorint; Float_t hue; Float_t OD; Float_t proline; }; int Train(){ //Function for training the classifiers //Prepare a ROOT-file for output histograms results = new TFile("results.root","RECREATE"); //Create the objet myWineData, which is of the type winedata (the struct that was defined earlier) winedata myWineData; //Create pointers for TTrees for each of the grape cultivars TTree *class1; TTree *class2; TTree *class3; //Reserve memory class1 = new TTree("C1","Class 1 wines"); class2 = new TTree("C2","Class 2 wines"); class3 = new TTree("C3","Class 3 wines"); //Open data file ifstream wines("wine.data"); std::string tmp; int theclass; //Create branches "wine" into all of the TTrees class1->Branch("wine",&myWineData.alcohol,"alcohol/F:malic:ash:alcalinity:magnesium:phenols:flavanoids:nonflavanoids:proanthocyanins:colorint:hue:OD:proline"); class2->Branch("wine",&myWineData.alcohol,"alcohol/F:malic:ash:alcalinity:magnesium:phenols:flavanoids:nonflavanoids:proanthocyanins:colorint:hue:OD:proline"); class3->Branch("wine",&myWineData.alcohol,"alcohol/F:malic:ash:alcalinity:magnesium:phenols:flavanoids:nonflavanoids:proanthocyanins:colorint:hue:OD:proline"); //Read the input data while(wines.good()){ getline(wines,tmp,','); theclass = atoi(tmp.c_str()); getline(wines,tmp,','); myWineData.alcohol = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.malic = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.ash = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.alcalinity = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.magnesium = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.phenols = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.flavanoids = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.nonflavanoids = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.proanthocyanins = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.colorint = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.hue = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.OD = atof(tmp.c_str()); getline(wines,tmp,'\n'); myWineData.proline = atof(tmp.c_str()); switch (theclass) { case 1: class1->Fill(); break; case 2: class2->Fill(); break; case 3: class3->Fill(); break; default: break; } } //Prepare a TMVA factory TMVA::Factory *factory = new TMVA::Factory("factory",results,"!V:!Silent:!Color:DrawProgressBar:AnalysisType=multiclass"); TMVA::DataLoader *loader = new TMVA::DataLoader("loader"); //Add variables loader->AddVariable("alcohol",'F'); loader->AddVariable("malic",'F'); loader->AddVariable("ash",'F'); loader->AddVariable("alcalinity",'F'); loader->AddVariable("magnesium",'F'); loader->AddVariable("phenols",'F'); loader->AddVariable("flavanoids",'F'); loader->AddVariable("nonflavanoids",'F'); loader->AddVariable("proanthocyanins",'F'); loader->AddVariable("colorint",'F'); loader->AddVariable("hue",'F'); loader->AddVariable("OD",'F'); loader->AddVariable("proline",'F'); //Add the three different classes of data loader->AddTree(class1,"Class 1"); loader->AddTree(class2,"Class 2"); loader->AddTree(class3,"Class 3"); //Prepare the training and test trees loader->PrepareTrainingAndTestTree("","SplitMode=Random:NormMode=NumEvents:!V"); //Book MVA methods factory->BookMethod(loader,TMVA::Types::kMLP,"MLP","HiddenLayers=10:NeuronType=linear:VarTransform=Norm"); factory->BookMethod(loader,TMVA::Types::kBDT,"BDT","NTrees=5:BoostType=Grad"); //Train them factory->TrainAllMethods(); //Test them factory->TestAllMethods(); //Evaluate how well they perform factory->EvaluateAllMethods(); //Save histograms, TTrees etc. on disk results->Write(); delete factory; wines.close(); return 0; } int Classify(){ //Create a reader TMVA::Reader *reader = new TMVA::Reader(); //Open output file results = TFile::Open("results.root","UPDATE"); //Create myWineData winedata myWineData; //Open input files ifstream wines("wine.data"); std::string tmp; int theclass; //Create response histograms TH1F *response1; response1 = new TH1F("Response1","Classifier response",10,0,1); TH1F *response2; response2 = new TH1F("Response2","Classifier response",10,0,1); TH1F *response3; response3 = new TH1F("Response3","Classifier response",10,0,1); //Add variables to the reader, with same names and in same order as for the factory reader->AddVariable("alcohol",&myWineData.alcohol); reader->AddVariable("malic",&myWineData.malic); reader->AddVariable("ash",&myWineData.ash); reader->AddVariable("alcalinity",&myWineData.alcalinity); reader->AddVariable("magnesium",&myWineData.magnesium); reader->AddVariable("phenols",&myWineData.phenols); reader->AddVariable("flavanoids",&myWineData.flavanoids); reader->AddVariable("nonflavanoids",&myWineData.nonflavanoids); reader->AddVariable("proanthocyanins",&myWineData.proanthocyanins); reader->AddVariable("colorint",&myWineData.colorint); reader->AddVariable("hue",&myWineData.hue); reader->AddVariable("OD",&myWineData.OD); reader->AddVariable("proline",&myWineData.proline); //Book the MVAs and tell where the weight files from training are reader->BookMVA("MLP","weights/factory_MLP.weights.xml"); reader->BookMVA("BDT","weights/factory_BDT.weights.xml"); //Read & classify events while(wines){ getline(wines,tmp,','); theclass = atoi(tmp.c_str()); if(theclass==0) break; getline(wines,tmp,','); myWineData.alcohol = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.malic = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.ash = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.alcalinity = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.magnesium = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.phenols = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.flavanoids = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.nonflavanoids = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.proanthocyanins = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.colorint = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.hue = atof(tmp.c_str()); getline(wines,tmp,','); myWineData.OD = atof(tmp.c_str()); getline(wines,tmp,'\n'); myWineData.proline = atof(tmp.c_str()); response1->Fill(reader->EvaluateMulticlass(0,"MLP")); response2->Fill(reader->EvaluateMulticlass(1,"MLP")); response3->Fill(reader->EvaluateMulticlass(2,"MLP")); printf("New wine: %f %f %f\n",reader->EvaluateMulticlass(0,"MLP"),reader->EvaluateMulticlass(1,"MLP"),reader->EvaluateMulticlass(2,"MLP")); if(reader->EvaluateMulticlass(0,"MLP")>reader->EvaluateMulticlass(1,"MLP")&&reader->EvaluateMulticlass(0,"MLP")>reader->EvaluateMulticlass(2,"MLP")){ bestmatch = 1; } else if(reader->EvaluateMulticlass(1,"MLP")>reader->EvaluateMulticlass(0,"MLP")&&reader->EvaluateMulticlass(1,"MLP")>reader->EvaluateMulticlass(2,"MLP")){ bestmatch = 2; } else if(reader->EvaluateMulticlass(2,"MLP")>reader->EvaluateMulticlass(1,"MLP")&&reader->EvaluateMulticlass(2,"MLP")>reader->EvaluateMulticlass(0,"MLP")){ bestmatch = 3; } if(bestmatch==theclass) correct++; else wrong++; printf("Wine classified as: %d, true class: %d\n",bestmatch,theclass); } printf("Correct:%d/%d\n",correct,correct+wrong); //Save histograms etc. results->Write(); return 0; }