前言 提到森林,就不得不聯想到樹,因為正是一棵棵的樹構成了龐大的森林,而在本篇文章中的樹,指的就是Decision Tree-----決策樹。隨機森林就是一棵棵決策樹的組合,也就是說隨機森林=boosting+決策樹,這樣就好理解多了吧,再來說說GBDT,GBDT全稱是Gradie
提到森林,就不得不聯想到樹,因為正是一棵棵的樹構成了龐大的森林,而在本篇文章中的”樹“,指的就是Decision Tree-----決策樹。隨機森林就是一棵棵決策樹的組合,也就是說隨機森林=boosting+決策樹,這樣就好理解多了吧,再來說說GBDT,GBDT全稱是Gradient Boosting Decision Tree,就是梯度提升決策樹,與隨機森林的思想很像,但是比隨機森林稍稍的難一點,當然效果相對于前者而言,也會好許多。由于本人才疏學淺,本文只會詳細講述Random Forest算法的部分,至于GBDT我會給出一小段篇幅做介紹引導,讀者能夠如果有興趣的話,可以自行學習。
要想理解隨機森林算法,就不得不提決策樹,什么是決策樹,如何構造決策樹,簡單的回答就是數據的分類以樹形結構的方式所展現,每個子分支都代表著不同的分類情況,比如下面的這個圖所示:
當然決策樹的每個節點分支不一定是三元的,可以有2個或者更多。分類的終止條件為,沒有可以再拿來分類的屬性條件或者說分到的數據的分類已經完全一致的情況。決策樹分類的標準和依據是什么呢,下面介紹主要的2種劃分標準。
1、信息增益。這是ID3算法系列所用的方法,C4.5算法在這上面做了少許的改進,用信息增益率來作為劃分的標準,可以稍稍減小數據過于擬合的缺點。
2、基尼指數。這是CART分類回歸樹所用的方法。也是類似于信息增益的一個定義,最終都是根據數據劃分后的純度來做比較,這個純度,你也可以理解為熵的變化,當然我們所希望的情況就是分類后數據的純度更純,也就是說,前后劃分分類之后的熵的差越大越好。不過CART算法比較好的一點是樹構造好后,還有剪枝的操作,剪枝操作的種類就比較多了,我之前在實現CART算法時用的是代價復雜度的剪枝方法。
這2種決策算法在我之前的博文中已經有所提及,不理解的可以點擊我的ID3系列算法介紹和我的CART分類回歸樹算法。
原本不打算將Boosting單獨拉出來講的,后來想想還是有很多內容可談的。Boosting本身不是一種算法,他更應該說是一種思想,首先對數據構造n個弱分類器,最后通過組合n個弱分類器對于某個數據的判斷結果作為最終的分類結果,就變成了一個強分類器,效果自然要好過單一分類器的分類效果。他可以理解為是一種提升算法,舉一個比較常見的Boosting思想的算法AdaBoost,他在訓練每個弱分類器的時候,提高了對于之前分錯數據的權重值,最終能夠組成一批相互互補的分類器集合。詳細可以查看我的AdaBoost算法學習。
OK,2個重要的概念都已經介紹完畢,終于可以介紹主角Random Forest的出現了,正如前言中所說Random Forest=Decision Trees + Boosting,這里的每個弱分類器就是一個決策樹了,不過這里的決策樹都是二叉樹,就是只有2個孩子分支,自然我立刻想到的做法就是用CART算法來構建,因為人家算法就是二元分支的。隨機算法,隨機算法,當然重在隨機2個字上面,下面是2個方面體現了隨機性。對于數據樣本的采集量,比如我數據由100條,我可以每次隨機取出其中的20條,作為我構造決策樹的源數據,采取又放回的方式,并不是第一次抽到的數據,第二次不能重復,第二隨機性體現在對于數據屬性的隨機采集,比如一行數據總共有10個特征屬性,我每次隨機采用其中的4個。正是由于對于數據的行壓縮和列壓縮,使得數據的隨機性得以保證,就很難出現之前的數據過擬合的問題了,也就不需要在決策樹最后進行剪枝操作了,這個是與一般的CART算法所不同的,尤其需要注意。
下面是隨機森林算法的構造過程:
1、通過給定的原始數據,選出其中部分數據進行決策樹的構造,數據選取是”有放回“的過程,我在這里用的是CART分類回歸樹。
2、隨機森林構造完成之后,給定一組測試數據,使得每個分類器對其結果分類進行評估,最后取評估結果的眾數最為最終結果。
算法非常的好理解,在Boosting算法和決策樹之上做了一個集成,下面給出算法的實現,很多資料上只有大篇幅的理論,我還是希望能帶給大家一點實在的東西。
輸入數據(之前決策樹算法時用過的)input.txt:
Rid Age Income Student CreditRating BuysComputer 1 Youth High No Fair No 2 Youth High No Excellent No 3 MiddleAged High No Fair Yes 4 Senior Medium No Fair Yes 5 Senior Low Yes Fair Yes 6 Senior Low Yes Excellent No 7 MiddleAged Low Yes Excellent Yes 8 Youth Medium No Fair No 9 Youth Low Yes Fair Yes 10 Senior Medium Yes Fair Yes 11 Youth Medium Yes Excellent Yes 12 MiddleAged Medium No Excellent Yes 13 MiddleAged High Yes Fair Yes 14 Senior Medium No Excellent No
樹節點類TreeNode.java:
package DataMining_RandomForest; import java.util.ArrayList; /** * 回歸分類樹節點 * * @author lyq * */ public class TreeNode { // 節點屬性名字 private String attrName; // 節點索引標號 private int nodeIndex; //包含的葉子節點數 private int leafNum; // 節點誤差率 private double alpha; // 父親分類屬性值 private String parentAttrValue; // 孩子節點 private TreeNode[] childAttrNode; // 數據記錄索引 private ArrayListdataIndex; public String getAttrName() { return attrName; } public void setAttrName(String attrName) { this.attrName = attrName; } public int getNodeIndex() { return nodeIndex; } public void setNodeIndex(int nodeIndex) { this.nodeIndex = nodeIndex; } public double getAlpha() { return alpha; } public void setAlpha(double alpha) { this.alpha = alpha; } public String getParentAttrValue() { return parentAttrValue; } public void setParentAttrValue(String parentAttrValue) { this.parentAttrValue = parentAttrValue; } public TreeNode[] getChildAttrNode() { return childAttrNode; } public void setChildAttrNode(TreeNode[] childAttrNode) { this.childAttrNode = childAttrNode; } public ArrayList getDataIndex() { return dataIndex; } public void setDataIndex(ArrayList dataIndex) { this.dataIndex = dataIndex; } public int getLeafNum() { return leafNum; } public void setLeafNum(int leafNum) { this.leafNum = leafNum; } }
package DataMining_RandomForest; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; /** * 決策樹 * * @author lyq * */ public class DecisionTree { // 樹的根節點 TreeNode rootNode; // 數據的屬性列名稱 String[] featureNames; // 這棵樹所包含的數據 ArrayListdatas; // 決策樹構造的的工具類 CARTTool tool; public DecisionTree(ArrayList datas) { this.datas = datas; this.featureNames = datas.get(0); tool = new CARTTool(datas); // 通過CART工具類進行決策樹的構建,并返回樹的根節點 rootNode = tool.startBuildingTree(); } /** * 根據給定的數據特征描述進行類別的判斷 * * @param features * @return */ public String decideClassType(String features) { String classType = ""; // 查詢屬性組 String[] queryFeatures; // 在本決策樹中對應的查詢的屬性值描述 ArrayList featureStrs; featureStrs = new ArrayList<>(); queryFeatures = features.split(","); String[] array; for (String name : featureNames) { for (String featureValue : queryFeatures) { array = featureValue.split("="); // 將對應的屬性值加入到列表中 if (array[0].equals(name)) { featureStrs.add(array); } } } // 開始從根據節點往下遞歸搜索 classType = recusiveSearchClassType(rootNode, featureStrs); return classType; } /** * 遞歸搜索樹,查詢屬性的分類類別 * * @param node * 當前搜索到的節點 * @param remainFeatures * 剩余未判斷的屬性 * @return */ private String recusiveSearchClassType(TreeNode node, ArrayList remainFeatures) { String classType = null; // 如果節點包含了數據的id索引,說明已經分類到底了 if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { classType = judgeClassType(node.getDataIndex()); return classType; } // 取出剩余屬性中的一個匹配屬性作為當前的判斷屬性名稱 String[] currentFeature = null; for (String[] featureValue : remainFeatures) { if (node.getAttrName().equals(featureValue[0])) { currentFeature = featureValue; break; } } for (TreeNode childNode : node.getChildAttrNode()) { // 尋找子節點中屬于此屬性值的分支 if (childNode.getParentAttrValue().equals(currentFeature[1])) { remainFeatures.remove(currentFeature); classType = recusiveSearchClassType(childNode, remainFeatures); // 如果找到了分類結果,則直接挑出循環 break; }else{ //進行第二種情況的判斷加上!符號的情況 String value = childNode.getParentAttrValue(); if(value.charAt(0) == '!'){ //去掉第一個!字符 value = value.substring(1, value.length()); if(!value.equals(currentFeature[1])){ remainFeatures.remove(currentFeature); classType = recusiveSearchClassType(childNode, remainFeatures); break; } } } } return classType; } /** * 根據得到的數據行分類進行類別的決策 * * @param dataIndex * 根據分類的數據索引號 * @return */ public String judgeClassType(ArrayList dataIndex) { // 結果類型值 String resultClassType = ""; String classType = ""; int count = 0; int temp = 0; Map type2Num = new HashMap (); for (String index : dataIndex) { temp = Integer.parseInt(index); // 取最后一列的決策類別數據 classType = datas.get(temp)[featureNames.length - 1]; if (type2Num.containsKey(classType)) { // 如果類別已經存在,則使其計數加1 count = type2Num.get(classType); count++; } else { count = 1; } type2Num.put(classType, count); } // 選出其中類別支持計數最多的一個類別值 count = -1; for (Map.Entry entry : type2Num.entrySet()) { if ((int) entry.getValue() > count) { count = (int) entry.getValue(); resultClassType = (String) entry.getKey(); } } return resultClassType; } }
package DataMining_RandomForest; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.Random; /** * 隨機森林算法工具類 * * @author lyq * */ public class RandomForestTool { // 測試數據文件地址 private String filePath; // 決策樹的樣本占總數的占比率 private double sampleNumRatio; // 樣本數據的采集特征數量占總特征的比例 private double featureNumRatio; // 決策樹的采樣樣本數 private int sampleNum; // 樣本數據的采集采樣特征數 private int featureNum; // 隨機森林中的決策樹的數目,等于總的數據數/用于構造每棵樹的數據的數量 private int treeNum; // 隨機數產生器 private Random random; // 樣本數據列屬性名稱行 private String[] featureNames; // 原始的總的數據 private ArrayListtotalDatas; // 決策樹森林 private ArrayList decisionForest; public RandomForestTool(String filePath, double sampleNumRatio, double featureNumRatio) { this.filePath = filePath; this.sampleNumRatio = sampleNumRatio; this.featureNumRatio = featureNumRatio; readDataFile(); } /** * 從文件中讀取數據 */ private void readDataFile() { File file = new File(filePath); ArrayList dataArray = new ArrayList (); try { BufferedReader in = new BufferedReader(new FileReader(file)); String str; String[] tempArray; while ((str = in.readLine()) != null) { tempArray = str.split(" "); dataArray.add(tempArray); } in.close(); } catch (IOException e) { e.getStackTrace(); } totalDatas = dataArray; featureNames = totalDatas.get(0); sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio); //算屬性數量的時候需要去掉id屬性和決策屬性,用條件屬性計算 featureNum = (int) ((featureNames.length -2) * featureNumRatio); // 算數量的時候需要去掉首行屬性名稱行 treeNum = (totalDatas.size() - 1) / sampleNum; } /** * 產生決策樹 */ private DecisionTree produceDecisionTree() { int temp = 0; DecisionTree tree; String[] tempData; //采樣數據的隨機行號組 ArrayList sampleRandomNum; //采樣屬性特征的隨機列號組 ArrayList featureRandomNum; ArrayList datas; sampleRandomNum = new ArrayList<>(); featureRandomNum = new ArrayList<>(); datas = new ArrayList<>(); for(int i=0; i 0){ array[0] = temp + ""; } temp++; } tree = new DecisionTree(datas); return tree; } /** * 構造隨機森林 */ public void constructRandomTree() { DecisionTree tree; random = new Random(); decisionForest = new ArrayList<>(); System.out.println("下面是隨機森林中的決策樹:"); // 構造決策樹加入森林中 for (int i = 0; i < treeNum; i++) { System.out.println("\n決策樹" + (i+1)); tree = produceDecisionTree(); decisionForest.add(tree); } } /** * 根據給定的屬性條件進行類別的決策 * * @param features * 給定的已知的屬性描述 * @return */ public String judgeClassType(String features) { // 結果類型值 String resultClassType = ""; String classType = ""; int count = 0; Map type2Num = new HashMap (); for (DecisionTree tree : decisionForest) { classType = tree.decideClassType(features); if (type2Num.containsKey(classType)) { // 如果類別已經存在,則使其計數加1 count = type2Num.get(classType); count++; } else { count = 1; } type2Num.put(classType, count); } // 選出其中類別支持計數最多的一個類別值 count = -1; for (Map.Entry entry : type2Num.entrySet()) { if ((int) entry.getValue() > count) { count = (int) entry.getValue(); resultClassType = (String) entry.getKey(); } } return resultClassType; } }
package DataMining_RandomForest; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.Queue; /** * CART分類回歸樹算法工具類 * * @author lyq * */ public class CARTTool { // 類標號的值類型 private final String YES = "Yes"; private final String NO = "No"; // 所有屬性的類型總數,在這里就是data源數據的列數 private int attrNum; private String filePath; // 初始源數據,用一個二維字符數組存放模仿表格數據 private String[][] data; // 數據的屬性行的名字 private String[] attrNames; // 每個屬性的值所有類型 private HashMap> attrValue; public CARTTool(ArrayList dataArray) { attrValue = new HashMap<>(); readData(dataArray); } /** * 根據隨機選取的樣本數據進行初始化 * @param dataArray * 已經讀入的樣本數據 */ public void readData(ArrayList dataArray) { data = new String[dataArray.size()][]; dataArray.toArray(data); attrNum = data[0].length; attrNames = data[0]; } /** * 首先初始化每種屬性的值的所有類型,用于后面的子類熵的計算時用 */ public void initAttrValue() { ArrayList tempValues; // 按照列的方式,從左往右找 for (int j = 1; j < attrNum; j++) { // 從一列中的上往下開始尋找值 tempValues = new ArrayList<>(); for (int i = 1; i < data.length; i++) { if (!tempValues.contains(data[i][j])) { // 如果這個屬性的值沒有添加過,則添加 tempValues.add(data[i][j]); } } // 一列屬性的值已經遍歷完畢,復制到map屬性表中 attrValue.put(data[0][j], tempValues); } } /** * 計算機基尼指數 * * @param remainData * 剩余數據 * @param attrName * 屬性名稱 * @param value * 屬性值 * @param beLongValue * 分類是否屬于此屬性值 * @return */ public double computeGini(String[][] remainData, String attrName, String value, boolean beLongValue) { // 實例總數 int total = 0; 【本文來自鴻網互聯 (http://www.68idc.cn)】 // 正實例數 int posNum = 0; // 負實例數 int negNum = 0; // 基尼指數 double gini = 0; // 還是按列從左往右遍歷屬性 for (int j = 1; j < attrNames.length; j++) { // 找到了指定的屬性 if (attrName.equals(attrNames[j])) { for (int i = 1; i < remainData.length; i++) { // 統計正負實例按照屬于和不屬于值類型進行劃分 if ((beLongValue && remainData[i][j].equals(value)) || (!beLongValue && !remainData[i][j].equals(value))) { if (remainData[i][attrNames.length - 1].equals(YES)) { // 判斷此行數據是否為正實例 posNum++; } else { negNum++; } } } } } total = posNum + negNum; double posProbobly = (double) posNum / total; double negProbobly = (double) negNum / total; gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly; // 返回計算基尼指數 return gini; } /** * 計算屬性劃分的最小基尼指數,返回最小的屬性值劃分和最小的基尼指數,保存在一個數組中 * * @param remainData * 剩余誰 * @param attrName * 屬性名稱 * @return */ public String[] computeAttrGini(String[][] remainData, String attrName) { String[] str = new String[2]; // 最終該屬性的劃分類型值 String spiltValue = ""; // 臨時變量 int tempNum = 0; // 保存屬性的值劃分時的最小的基尼指數 double minGini = Integer.MAX_VALUE; ArrayList valueTypes = attrValue.get(attrName); // 屬于此屬性值的實例數 HashMap belongNum = new HashMap<>(); for (String string : valueTypes) { // 重新計數的時候,數字歸0 tempNum = 0; // 按列從左往右遍歷屬性 for (int j = 1; j < attrNames.length; j++) { // 找到了指定的屬性 if (attrName.equals(attrNames[j])) { for (int i = 1; i < remainData.length; i++) { // 統計正負實例按照屬于和不屬于值類型進行劃分 if (remainData[i][j].equals(string)) { tempNum++; } } } } belongNum.put(string, tempNum); } double tempGini = 0; double posProbably = 1.0; double negProbably = 1.0; for (String string : valueTypes) { tempGini = 0; posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); negProbably = 1 - posProbably; tempGini += posProbably * computeGini(remainData, attrName, string, true); tempGini += negProbably * computeGini(remainData, attrName, string, false); if (tempGini < minGini) { minGini = tempGini; spiltValue = string; } } str[0] = spiltValue; str[1] = minGini + ""; return str; } public void buildDecisionTree(TreeNode node, String parentAttrValue, String[][] remainData, ArrayList remainAttr, boolean beLongParentValue) { // 屬性劃分值 String valueType = ""; // 劃分屬性名稱 String spiltAttrName = ""; double minGini = Integer.MAX_VALUE; double tempGini = 0; // 基尼指數數組,保存了基尼指數和此基尼指數的劃分屬性值 String[] giniArray; if (beLongParentValue) { node.setParentAttrValue(parentAttrValue); } else { node.setParentAttrValue("!" + parentAttrValue); } if (remainAttr.size() == 0) { if (remainData.length > 1) { ArrayList indexArray = new ArrayList<>(); for (int i = 1; i < remainData.length; i++) { indexArray.add(remainData[i][0]); } node.setDataIndex(indexArray); } // System.out.println("attr remain null"); return; } for (String str : remainAttr) { giniArray = computeAttrGini(remainData, str); tempGini = Double.parseDouble(giniArray[1]); if (tempGini < minGini) { spiltAttrName = str; minGini = tempGini; valueType = giniArray[0]; } } // 移除劃分屬性 remainAttr.remove(spiltAttrName); node.setAttrName(spiltAttrName); // 孩子節點,分類回歸樹中,每次二元劃分,分出2個孩子節點 TreeNode[] childNode = new TreeNode[2]; String[][] rData; boolean[] bArray = new boolean[] { true, false }; for (int i = 0; i < bArray.length; i++) { // 二元劃分屬于屬性值的劃分 rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); boolean sameClass = true; ArrayList indexArray = new ArrayList<>(); for (int k = 1; k < rData.length; k++) { indexArray.add(rData[k][0]); // 判斷是否為同一類的 if (!rData[k][attrNames.length - 1] .equals(rData[1][attrNames.length - 1])) { // 只要有1個不相等,就不是同類型的 sameClass = false; break; } } childNode[i] = new TreeNode(); if (!sameClass) { // 創建新的對象屬性,對象的同個引用會出錯 ArrayList rAttr = new ArrayList<>(); for (String str : remainAttr) { rAttr.add(str); } buildDecisionTree(childNode[i], valueType, rData, rAttr, bArray[i]); } else { String pAtr = (bArray[i] ? valueType : "!" + valueType); childNode[i].setParentAttrValue(pAtr); childNode[i].setDataIndex(indexArray); } } node.setChildAttrNode(childNode); } /** * 屬性劃分完畢,進行數據的移除 * * @param srcData * 源數據 * @param attrName * 劃分的屬性名稱 * @param valueType * 屬性的值類型 * @parame beLongValue 分類是否屬于此值類型 */ private String[][] removeData(String[][] srcData, String attrName, String valueType, boolean beLongValue) { String[][] desDataArray; ArrayList desData = new ArrayList<>(); // 待刪除數據 ArrayList selectData = new ArrayList<>(); selectData.add(attrNames); // 數組數據轉化到列表中,方便移除 for (int i = 0; i < srcData.length; i++) { desData.add(srcData[i]); } // 還是從左往右一列列的查找 for (int j = 1; j < attrNames.length; j++) { if (attrNames[j].equals(attrName)) { for (int i = 1; i < desData.size(); i++) { if (desData.get(i)[j].equals(valueType)) { // 如果匹配這個數據,則移除其他的數據 selectData.add(desData.get(i)); } } } } if (beLongValue) { desDataArray = new String[selectData.size()][]; selectData.toArray(desDataArray); } else { // 屬性名稱行不移除 selectData.remove(attrNames); // 如果是劃分不屬于此類型的數據時,進行移除 desData.removeAll(selectData); desDataArray = new String[desData.size()][]; desData.toArray(desDataArray); } return desDataArray; } /** * 構造分類回歸樹,并返回根節點 * @return */ public TreeNode startBuildingTree() { initAttrValue(); ArrayList remainAttr = new ArrayList<>(); // 添加屬性,除了最后一個類標號屬性 for (int i = 1; i < attrNames.length - 1; i++) { remainAttr.add(attrNames[i]); } TreeNode rootNode = new TreeNode(); buildDecisionTree(rootNode, "", data, remainAttr, false); setIndexAndAlpah(rootNode, 0, false); showDecisionTree(rootNode, 1); return rootNode; } /** * 顯示決策樹 * * @param node * 待顯示的節點 * @param blankNum * 行空格符,用于顯示樹型結構 */ private void showDecisionTree(TreeNode node, int blankNum) { System.out.println(); for (int i = 0; i < blankNum; i++) { System.out.print(" "); } System.out.print("--"); // 顯示分類的屬性值 if (node.getParentAttrValue() != null && node.getParentAttrValue().length() > 0) { System.out.print(node.getParentAttrValue()); } else { System.out.print("--"); } System.out.print("--"); if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { String i = node.getDataIndex().get(0); System.out.print("【" + node.getNodeIndex() + "】類別:" + data[Integer.parseInt(i)][attrNames.length - 1]); System.out.print("["); for (String index : node.getDataIndex()) { System.out.print(index + ", "); } System.out.print("]"); } else { // 遞歸顯示子節點 System.out.print("【" + node.getNodeIndex() + ":" + node.getAttrName() + "】"); if (node.getChildAttrNode() != null) { for (TreeNode childNode : node.getChildAttrNode()) { showDecisionTree(childNode, 2 * blankNum); } } else { System.out.print("【 Child Null】"); } } } /** * 為節點設置序列號,并計算每個節點的誤差率,用于后面剪枝 * * @param node * 開始的時候傳入的是根節點 * @param index * 開始的索引號,從1開始 * @param ifCutNode * 是否需要剪枝 */ private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) { TreeNode tempNode; // 最小誤差代價節點,即將被剪枝的節點 TreeNode minAlphaNode = null; double minAlpah = Integer.MAX_VALUE; Queue nodeQueue = new LinkedList (); nodeQueue.add(node); while (nodeQueue.size() > 0) { index++; // 從隊列頭部獲取首個節點 tempNode = nodeQueue.poll(); tempNode.setNodeIndex(index); if (tempNode.getChildAttrNode() != null) { for (TreeNode childNode : tempNode.getChildAttrNode()) { nodeQueue.add(childNode); } computeAlpha(tempNode); if (tempNode.getAlpha() < minAlpah) { minAlphaNode = tempNode; minAlpah = tempNode.getAlpha(); } else if (tempNode.getAlpha() == minAlpah) { // 如果誤差代價值一樣,比較包含的葉子節點個數,剪枝有多葉子節點數的節點 if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { minAlphaNode = tempNode; } } } } if (ifCutNode) { // 進行樹的剪枝,讓其左右孩子節點為null minAlphaNode.setChildAttrNode(null); } } /** * 為非葉子節點計算誤差代價,這里的后剪枝法用的是CCP代價復雜度剪枝 * * @param node * 待計算的非葉子節點 */ private void computeAlpha(TreeNode node) { double rt = 0; double Rt = 0; double alpha = 0; // 當前節點的數據總數 int sumNum = 0; // 最少的偏差數 int minNum = 0; ArrayList dataIndex; ArrayList leafNodes = new ArrayList<>(); addLeafNode(node, leafNodes); node.setLeafNum(leafNodes.size()); for (TreeNode attrNode : leafNodes) { dataIndex = attrNode.getDataIndex(); int num = 0; sumNum += dataIndex.size(); for (String s : dataIndex) { // 統計分類數據中的正負實例數 if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { num++; } } minNum += num; // 取小數量的值部分 if (1.0 * num / dataIndex.size() > 0.5) { num = dataIndex.size() - num; } rt += (1.0 * num / (data.length - 1)); } //同樣取出少偏差的那部分 if (1.0 * minNum / sumNum > 0.5) { minNum = sumNum - minNum; } Rt = 1.0 * minNum / (data.length - 1); alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1); node.setAlpha(alpha); } /** * 篩選出節點所包含的葉子節點數 * * @param node * 待篩選節點 * @param leafNode * 葉子節點列表容器 */ private void addLeafNode(TreeNode node, ArrayList leafNode) { ArrayList dataIndex; if (node.getChildAttrNode() != null) { for (TreeNode childNode : node.getChildAttrNode()) { dataIndex = childNode.getDataIndex(); if (dataIndex != null && dataIndex.size() > 0) { // 說明此節點為葉子節點 leafNode.add(childNode); } else { // 如果還是非葉子節點則繼續遞歸調用 addLeafNode(childNode, leafNode); } } } } }
package DataMining_RandomForest; import java.text.MessageFormat; /** * 隨機森林算法測試場景 * * @author lyq * */ public class Client { public static void main(String[] args) { String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair"; String resultClassType = ""; // 決策樹的樣本占總數的占比率 double sampleNumRatio = 0.4; // 樣本數據的采集特征數量占總特征的比例 double featureNumRatio = 0.5; RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio, featureNumRatio); tool.constructRandomTree(); resultClassType = tool.judgeClassType(queryStr); System.out.println(); System.out .println(MessageFormat.format( "查詢屬性描述{0},預測的分類
下面是隨機森林中的決策樹: 決策樹1 --!--【1:Income】 --Medium--【2】類別:Yes[1, 2, ] --!Medium--【3:Student】 --No--【4】類別:No[3, 5, ] --!No--【5】類別:Yes[4, ] 決策樹2 --!--【1:Student】 --No--【2】類別:No[1, 3, ] --!No--【3】類別:Yes[2, 4, 5, ] 查詢屬性描述Age=Youth,Income=Low,Student=No,CreditRating=Fair,預測的分類
輸出的結果決策樹建議從左往右看,從上往下,【】符號表示一個節點,---XX---表示屬性值的劃分,你就應該能看懂這棵樹了,在console上想展示漂亮的樹形效果的確很難。。這里說一個算法的重大不足,數據太少,導致選擇的樣本數據不足,所選屬性太少,,構造的決策樹數量過少,自然分類的準確率不見得會有多準,博友只要能領會代碼中所表達的算法的思想即可。
下面來說說隨機森林的兄弟算法GBDT,梯度提升決策樹,他有很多的決策樹,他也有組合的思想,但是他不是隨機森林算法2,GBDT的關鍵在于Gradient Boosting,梯度提升。這個詞語理解起來就不容易了。學術的描述,每一次建立模型是在之前建立模型的損失函數的梯度下降方向。GBDT的核心在于,每一棵樹學的是之前所有樹結論和的殘差,這個殘差你可以理解為與預測值的差值。舉個例子:比如預測張三的年齡,張三的真實年齡18歲,第一棵樹預測張的年齡12歲,此時殘差為18-12=6歲,因此在第二棵樹中,我們把張的年齡作為6歲去學習,如果預測成功了,則張的真實年齡就是A樹和B樹的結果預測值的和,但是如果B預測成了5歲,那么殘差就變成了6-5=1歲,那么此時需要構建第三樹對1歲做預測,后面一樣的道理。每棵樹都是對之前失敗預測的一個補充,用公式的表達就是如下的這個樣子:
F0在這里是初始值,Ti是一棵棵的決策樹,不同的問題選擇不同的損失函數和初始值。在阿里內部對于此算法的叫法為TreeLink。所以下次聽到什么Treelink算法了指的就是梯度提升樹算法,其實我在這里省略了很大篇幅的數學推導過程,再加上自己還不是專家,無法徹底解釋清數學的部分,所以就沒有提及,希望以后有時間可以深入學習此方面的知識。
聲明:本網頁內容旨在傳播知識,若有侵權等問題請及時與本網聯系,我們將在第一時間刪除處理。TEL:177 7030 7066 E-MAIL:11247931@qq.com