00001
00002 #ifndef MVAUtils_BDT_H
00003 #define MVAUtils_BDT_H
00004
00005 #include<vector>
00006 #include "TString.h"
00007
00008 #include "MVAUtils/Node.h"
00009
00010 namespace TMVA { // forward declarations: this header only uses pointers to these types
00011 class DecisionTreeNode; // used by BDT::newTree(const TMVA::DecisionTreeNode*)
00012 class MethodBDT; // used by BDT::BDT(TMVA::MethodBDT*)
00013 }
00014 class TTree; // ROOT tree; likewise only used through pointers (constructor, WriteTree())
00015
00016 namespace MVAUtils
00017 {
00018 /**
00019  * Compact boosted decision tree (BDT) forest.
00020  *
00021  * Built either from a ROOT TTree or from a TMVA::MethodBDT. All nodes are
00022  * stored in one vector (m_nodes); m_forest keeps one Node::index_t per
00023  * tree, presumably the position of that tree's root node in m_nodes.
00024  *
00025  * Inputs are passed by value (std::vector<float>) or through pointers
00026  * (std::vector<float*>). A pointer set can be stored with SetPointers()
00027  * so that GetResponse() can be called without arguments. The class has
00028  * no destructor that frees those pointers, so the pointed-to floats stay
00029  * owned by the caller and must outlive every read through them.
00030  *
00031  * NOTE(review): the single-argument constructors are not `explicit`;
00032  * they are left as-is so existing implicit conversions keep compiling.
00033  */
00034 class BDT
00035 {
00036 public:
00037 BDT(TTree *tree); ///< Build the forest from a TTree (presumably one produced by WriteTree()).
00038 BDT(TMVA::MethodBDT *bdt); ///< Build the forest from a TMVA BDT method.
00039
00040 /// Append a tree described by per-node variable indices and values (encoding defined in the implementation).
00041 void newTree(const std::vector<int>& vars, const std::vector<float>& values);
00042
00043 /// Append a tree converted from a TMVA decision-tree node (presumably the tree's root).
00044 void newTree(const TMVA::DecisionTreeNode *node);
00045
00046 /// Number of trees in the forest (one m_forest entry per tree).
00047 unsigned int GetNTrees() const { return m_forest.size(); }
00048
00049 /// Offset stored with the forest (m_offset).
00050 float GetOffset() const { return m_offset; }
00051
00052 /// Evaluate the forest on input variables passed by value, or read through pointers.
00053 float GetResponse(const std::vector<float>& values) const;
00054 float GetResponse(const std::vector<float*>& pointers) const;
00055
00056 /// Evaluate the forest on the pointers stored by SetPointers(); returns -9999 when none are stored.
00057 float GetResponse() const {
00058 return m_pointers.empty() ? -9999.f : GetResponse(m_pointers);
00059 }
00060
00061 /// Gradient-boost MVA value for the given inputs; the transformation applied is defined in the implementation.
00062 float GetGradBoostMVA(const std::vector<float>& values) const;
00063 float GetGradBoostMVA(const std::vector<float*>& pointers) const;
00064
00065 /// Multiclass output for numClasses classes (presumably one response per class).
00066 std::vector<float> GetMultiResponse(const std::vector<float>& values, unsigned int numClasses) const;
00067 std::vector<float> GetMultiResponse(const std::vector<float*>& pointers, unsigned int numClasses) const;
00068
00069 /// Current input values, presumably read through the stored pointers.
00070 std::vector<float> GetValues() const;
00071
00072 /// Copy of the stored input pointers.
00073 std::vector<float*> GetPointers() const { return m_pointers; }
00074
00075 /// Store input pointers for GetResponse(): the vector is copied, the pointees are not owned; const& also accepts temporaries.
00076 void SetPointers(const std::vector<float*>& pointers) { m_pointers = pointers; }
00077
00078 /// Write the forest into a TTree called `name` and return it (ownership of the result: see implementation).
00079 TTree* WriteTree(TString name = "BDT");
00080
00081 /// Print the whole forest, or a single tree identified by a node index.
00082 void PrintForest() const;
00083 void PrintTree(Node::index_t index) const;
00084
00085 private:
00086 /// Response of a single tree starting from node `index` (presumably that tree's root).
00087 float GetTreeResponse(const std::vector<float>& values, Node::index_t index) const;
00088 float GetTreeResponse(const std::vector<float*>& pointers, Node::index_t index) const;
00089
00090 float m_offset{0.f}; ///< See GetOffset(); default-initialised so it is never read indeterminate if a constructor skips it.
00091 std::vector<Node::index_t> m_forest; ///< One entry per tree (see GetNTrees()).
00092 std::vector<float*> m_pointers; ///< Non-owning input pointers stored by SetPointers().
00093 std::vector<Node> m_nodes; ///< Node storage, presumably shared by all trees and indexed by Node::index_t.
00094
00095 };
00096 }
00097
00098 #endif