Machine Learning 1.02: Decision Trees PDF
Document Details
Uploaded by ProperScholarship1342
University of Bath
Rohit Babbar
Summary
This document is a set of lecture notes covering decision trees in machine learning. It introduces supervised learning, classification, and regression, and covers decision stumps and continuous decision stumps. The notes also give examples of overfitting and briefly explain further categories of machine learning problems.
Full Transcript
Machine Learning 1.02: Decision Trees
Rohit Babbar  [email protected]

This lecture
Decision tree – supervised classification.
A good starter algorithm – easy to understand.
Next lecture is random forests, which build on decision trees (random forests were the best approach before deep learning).
Glossary of ML problem settings.

What is the goal? I
Learn y = f(x) from data:
– x ∈ R^n is an n-dimensional feature vector
– y ∈ {0, ..., K − 1} is the output class, when there are K possible output classes.
We can't consider every possible f(x) – there are infinitely many that fit any data set! Instead fθ(x) will belong to a space of functions parametrised by θ.

What is the goal? II
Learn y = fθ(x) from data, {(x1, y1), (x2, y2), ..., (xN, yN)}.
But how do we know which θ is best?
A loss function: L(yi, fθ(xi)).
Total loss: Σi L(yi, fθ(xi)), summed over i = 1, ..., N.
Goal: find the θ that minimises the total loss.

Decision stump I
Algorithm from last lecture!

# Parameters...
feature = 'teeth'
match = False

# Function...
def evaluate(fv):
    # Predict by testing whether the chosen feature equals the chosen value.
    return fv[feature] == match

# Fit to data...
# ('features', 'performance' and 'train' are defined in the last lecture.)
best = 0.0
for f in features:
    for m in [False, True]:
        accuracy = performance(f, m, train)
        if accuracy > best:
            best = accuracy   # remember the best accuracy seen so far
            feature = f
            match = m

Decision stump II
Converted to current notation:
Parameters: θ = {feature, match}
Function: fθ(x) = δ(xfeature, match)   (δ = Kronecker delta function)
Loss function: L(yi, fθ(xi)) = 0 if yi = fθ(xi), 1 otherwise (this is called 0–1 loss).

Decision stump III
What is the space of functions it learns?
The identity function or its inverse (not) – it selects the input feature most similar to the output! This is not very useful...
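To make the new notation concrete, here is a minimal sketch (not from the slides) of the same stump written as fθ(x) with 0–1 loss; the data format – a list of per-exemplar feature dictionaries plus a list of class labels – is an assumption for illustration.

def f_theta(x, theta):
    # theta = {'feature': ..., 'match': ...}; the Kronecker delta becomes an equality test
    return int(x[theta['feature']] == theta['match'])

def zero_one_loss(y_true, y_pred):
    # 0 if the prediction is correct, 1 otherwise
    return 0 if y_true == y_pred else 1

def total_loss(theta, X, y):
    # Sum the per-exemplar losses over the whole training set
    return sum(zero_one_loss(yi, f_theta(xi, theta)) for xi, yi in zip(X, y))

# Hypothetical two-exemplar data set:
X = [{'teeth': True}, {'teeth': False}]
y = [1, 0]
print(total_loss({'feature': 'teeth', 'match': True}, X, y))   # 0 – this θ fits perfectly

Fitting then means trying every θ and keeping the one with the lowest total loss, which is exactly what the loop on the previous slide does via accuracy.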
Continuous decision stump I
What if the input was continuous? (Figure: a 2D scatter plot; colours denote classes.)
We could instead split (partition) the space:
fθ(x) = 0 if xfeature < split, 1 if xfeature ≥ split.

Continuous decision stump II
(Same data, trying out candidate splits.)
feature = y, split = 0.5 → accuracy = 50.8% (accuracy = one minus average 0–1 loss)
feature = y, split = 0.3 → accuracy = 52.1%
feature = x, split = −0.5 → accuracy = 83.2%
feature = x, split = 0.2 → accuracy = 92.4%
feature = x, split = 0.0 → accuracy = 100%

Continuous decision stump III
How to find the best parameters? Brute force:
1. Sweep each dimension and consider every split (half way between each pair of exemplars when sorted along that axis).
2. Evaluate the loss function for every split.
3. Choose the best.
(Figure: accuracy as a function of split position along the x and y features.)
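A minimal sketch of this brute-force sweep (my own code, not from the slides), assuming the data sits in a NumPy array X of shape (N, n) with integer class labels y, and that – as on the slide – class 1 lies on the ≥ side of the split:

import numpy as np

def fit_stump(X, y):
    best = (None, None, -1.0)                         # (feature, split, accuracy)
    for feature in range(X.shape[1]):                 # sweep each dimension
        values = np.sort(X[:, feature])
        for split in (values[:-1] + values[1:]) / 2:  # half way between sorted exemplars
            pred = (X[:, feature] >= split).astype(int)
            accuracy = np.mean(pred == y)             # one minus average 0–1 loss
            if accuracy > best[2]:
                best = (feature, split, accuracy)
    return best

# Hypothetical 1D example: two well-separated blobs either side of zero
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-2, 0.5, (50, 1)), rng.normal(2, 0.5, (50, 1))])
y = np.concatenate([np.zeros(50, dtype=int), np.ones(50, dtype=int)])
print(fit_stump(X, y))   # expect a split near 0 with accuracy 1.0

In practice you would also try the flipped class assignment (class 1 on the < side), doubling the candidate set.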
Continuous decision stump IV
What about this? (Figure: a two-class data set that no single axis-aligned split separates well.)
Axis-aligned separation is rare in real data!

Decision trees I
What if we did this recursively?
(Figure: the same data; the background colour is shaded with the ratio of red/blue points in each region.)
First split – as before.
Can be represented as a tree.
Then the left half of the first split is split again.
Jumping ahead... a decision tree = recursive splitting.

Decision trees II
The function parameter, θ, is a binary tree.
You have two kinds of node:
internal nodes, which contain a split (the rounded rectangles in the diagram)
leaf nodes, which contain an answer (the big numbers)

Decision trees III
To evaluate the function, fθ(x):
Start at the top.
Move to the first split.
Go left or right based on the test.
Move to the next split, and so on...
Stop at a leaf and return its answer.

Decision trees IV
Greedy optimisation of the parameters, θ (brute force at each node).
Tree construction:
1. If the training data contains only one class → generate a leaf node.
2. Otherwise: try all features / splits, select the best (lowest total loss), then recurse (build another tree) for the left and right children, training each with the data that reaches it.
(A runnable sketch of this recursion appears after the early stopping slide below.)

Decision trees V
We need L(yi, fθ(xi)), the loss function.
0–1 loss works for decision stumps... but fails for decision trees! (Wasted splits; it can't refine the initial rectangle.)
Two that work: Gini impurity and information gain (illustrated by the two images on this slide).

Gini impurity
"The probability that if you select two items from a data set at random (with replacement) they will have a different class."
Lowest (0): when there is only one class.
Highest (< 1): when every data point has a different class.
pi = P(selecting class i from the data set)
G(p) = Σi pi (1 − pi) = 1 − Σi pi²
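As a quick check of those two extremes, here is a minimal sketch of Gini impurity in Python (the helper name and the use of collections.Counter are my own choices):

from collections import Counter

def gini(labels):
    # G(p) = 1 - sum_i p_i^2, with p_i estimated as class proportions
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

print(gini([0, 0, 0, 0]))   # 0.0  – only one class
print(gini([0, 1, 0, 1]))   # 0.5  – two equally common classes
print(gini([0, 1, 2, 3]))   # 0.75 – every point a different class, still < 1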
Weighting the split
Can calculate G(p(left)) and G(p(right)) for the two halves of a split... but we need a single number.
Weight by exemplar count:
L(split) = (nl / n) G(p(left)) + (nr / n) G(p(right))
n = total exemplar count
nl = exemplars travelling down the left branch
nr = exemplars travelling down the right branch
Intuitively: more data in a branch ⟹ more important.
Information gain is weighted the same way...

Information gain
Conceptually: how much you learn from traversing a split.
Entropy:
H(p) = −Σi pi log(pi)
Important – the internet wouldn't work without it! Entropy is the number of bits required on average to encode data with a given PDF (bits assumes log2; with loge the unit is nats).
Information gain = the number of bits/nats obtained from traversing the split:
I(split) = H(p(parent)) − (nl / n) H(p(left)) − (nr / n) H(p(right))
(You maximise this one!)

Which loss function?
Under typical conditions they are almost identical! Unsurprising, as their graphs are very similar. (Figure: red = Gini impurity, green = information gain.)
Gini: faster to compute (no log), ∴ the default.
Information gain: better for some problems; works when Gini cannot, e.g. regression; has theory behind it!

Overfitting
(Figure: a data set whose true class boundary is a circle, and the tree's fit.)
The boundary is a circle... but the tree does this...
This is called overfitting: modelling noise, not signal.
Can detect overfitting using a test set (the algorithm can't fit noise it hasn't seen).

Early stopping
Can avoid overfitting by stopping early:
Limit on tree depth.
Minimum leaf node size.
Prevents the function getting too complicated!
These extra parameters are called hyperparameters (the choice of Gini or information gain is also one).
They are also optimised! (Disturbingly often by hand.)
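Pulling the pieces together, here is a minimal sketch (my own, not from the lecture) of greedy recursive tree construction with the weighted Gini score, plus the two early-stopping hyperparameters above as arguments; the dictionary tree representation and all names are assumptions for illustration.

import numpy as np
from collections import Counter

def gini(y):
    # 1 - sum_i p_i^2 over the class proportions
    return 1.0 - sum((c / len(y)) ** 2 for c in Counter(y).values())

def split_score(y_left, y_right):
    # Weighted Gini impurity of a split (lower is better); swapping in entropy
    # and maximising information gain gives the other criterion.
    n = len(y_left) + len(y_right)
    return len(y_left) / n * gini(y_left) + len(y_right) / n * gini(y_right)

def build_tree(X, y, depth=0, max_depth=4, min_leaf=5):
    # Early stopping: make a leaf if the node is pure, too deep, or too small.
    if len(set(y)) == 1 or depth >= max_depth or len(y) <= min_leaf:
        return {'leaf': Counter(y).most_common(1)[0][0]}
    best = None
    for feature in range(X.shape[1]):                 # brute force at this node
        values = np.unique(X[:, feature])
        for split in (values[:-1] + values[1:]) / 2:  # half way between exemplars
            left = X[:, feature] < split
            score = split_score(y[left], y[~left])
            if best is None or score < best[0]:
                best = (score, feature, split)
    if best is None:                                  # no usable split: make a leaf
        return {'leaf': Counter(y).most_common(1)[0][0]}
    _, feature, split = best
    left = X[:, feature] < split
    return {'feature': feature, 'split': split,
            'left': build_tree(X[left], y[left], depth + 1, max_depth, min_leaf),
            'right': build_tree(X[~left], y[~left], depth + 1, max_depth, min_leaf)}

def predict(tree, x):
    # Start at the top, go left/right on each test, stop at a leaf.
    while 'leaf' not in tree:
        tree = tree['left'] if x[tree['feature']] < tree['split'] else tree['right']
    return tree['leaf']

# Hypothetical data: a noisy circular boundary, with a held-out test set
rng = np.random.default_rng(0)
X = rng.uniform(-8, 8, (400, 2))
y = (np.hypot(X[:, 0], X[:, 1]) < 5).astype(int)
tree = build_tree(X[:300], y[:300], max_depth=6, min_leaf=3)
test_acc = np.mean(np.array([predict(tree, x) for x in X[300:]]) == y[300:])
print(test_acc)   # accuracy on data the tree has never seen – a crude overfitting check

Raising max_depth and shrinking min_leaf lets the tree chase noise in the training set, which is exactly what the held-out accuracy is there to catch.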
A glossary of ML problem settings

Supervised learning
Learn a function y = f(x) from (many) examples of input (x) and output (y).
The majority of ML: classification or regression...

Supervised learning: Classification
Learn a function y = f(x), where y is discrete.
Identifying camera trap animals – input: an image; output: which animal (a peccary, in the example photo).
Predicting voting intention – input: demographics; output: preferred candidate (probabilistic). Run on the entire country → predict the election winner (YouGov, 2017-06-07).

Supervised learning: Regression
Learn a function y = f(x), where y is continuous.
Predicting the critical temperature of a superconductor – input: material properties; output: temperature.
Inferring particle paths (LHC) – input: detector energy spikes; output: particle paths; trained with simulation.

Supervised learning: Further kinds
Multi-label classification: y is a set.
e.g. identifying objects in an image
e.g. text summarisation (reusing source sentences)
Structured prediction: y is anything else!
e.g. sentence tagging: y is a sequence (such as part-of-speech tagging)
e.g. automated design: y is a CAD model

Unsupervised learning
No y! Finds patterns in the data.
Examples: clustering, density estimation, dimensionality reduction.

Unsupervised learning: Clustering
Clustering groups "similar" data points, for an arbitrary definition of similarity.
Identifying co-regulated genes – input: many expression level measurements; output: groups of genes that tend to express at the same time. (Figure: gene expression of SLC2A4.)
Discovering social groups – input: a friend graph; output: social groups (individuals may belong to several). (Figure: a Facebook friend graph.)

Unsupervised learning: Dimensionality reduction
Dimensionality reduction / manifold learning: reduce the number of dimensions while preserving information.
Also used for visualisation (important for verification).
Organising news – input: word vectors; output: position in a layout.
*-supervised
Collecting data is cheap; labelling data is expensive.
e.g. finding cats: "the image contains a cat" is fast to label, a box around the cat is slow.
Semi-supervised: some labelled data, lots of unlabelled data.
Precise labels are expensive; inaccurate labels are cheap.
Weakly-supervised: learns from "weak" labels, outputs "strong" labels.

Glossary
Supervised: classification, regression, multi-label classification, structured prediction, semi-supervised, weakly-supervised.
Unsupervised: clustering, density estimation / abnormality detection, dimensionality reduction / manifold learning, graphical models.
(Incomplete! e.g. causality.)

Further categories
Can also classify ML algorithms by...
Answer quality:
Point estimate – e.g. "The patient has cancer."
Probabilistic – e.g. "60% chance that the patient has cancer."
Workflow:
Batch learning – i.e. collect the data, then learn.
Incremental learning – i.e. learn as the data arrives.
Active learning – i.e. the algorithm selects the data to learn from! (This leads towards automating science...)
Area:
Traditional
Computer vision
Natural language processing (NLP)
Interactive

Summary
Introduced decision trees:
Good at ignoring useless data.
Can be interpretable.
Overfits most of the time.
Overfitting, testing and hyperparameters → future lectures.
Next lecture: extending decision trees to random forests (which are much, much better).