
Support Vector Machine in R

In this lab, we will learn how to use a Support Vector Machine (SVM) in R to build a classifier.

Kindly go through the video lectures to get a better understanding of this lab.

We will first create our own simulated data, and then use it to build an SVM classifier.

Let’s create the simulated data:

# Two independent variables
set.seed(1); x1 = rnorm(500, mean=5, sd=3); mx1 = mean(x1)
set.seed(10); x2 = rnorm(500, mean=10, sd=8); mx2 = mean(x2)

# Response variable with two classes, -1 and +1
y = ifelse((x2 >= mx2 & x1 >= mx1), 1, -1)
y[x1 < 2.5] = 1 # To add more non-linearity to the data

plot(x2, x1, pch=3+y, col=ifelse(y==1, "red", "blue")) # Plot the data
lines(c(20,0), c(0,5), lwd=5, col="orange") # A first attempted linear boundary
lines(c(30,0), c(0,16), lwd=5, col="green") # A second attempted linear boundary

From the plot, we can clearly see that the classes cannot be divided by a linear boundary. We tried two different linear boundaries (orange and green) to separate the data points into their two classes, but neither works.

We can also see that the data points of the two classes are mixed in such a way that even a simple non-linear boundary will not separate them properly. This is why we need an effective machine learning classifier such as SVM to do the job.

Train and Test Data set


X = as.matrix(cbind(x1, x2)); data = data.frame(x=X, y=as.factor(y))
m = nrow(data)

set.seed(1); train.idx = sample(1:m, 0.8*m, replace=F) # indices for an 80/20 split

train.n = length(train.idx)

test.n = m-train.n

train = data[train.idx,]
test = data[-train.idx,]

dim(train)
## [1] 400   3
dim(test)
## [1] 100   3

Training the SVM model on train data set

To train the SVM model in R, we will use the svm() function from the e1071 package. You need to install this package in R first.

library(e1071)

svm.model = svm(y~., data = train, kernel = "radial", scale = FALSE)
# scale = FALSE: the two variables' means and sds are fairly close, so we skip scaling
# kernel = "radial": a radial (RBF) kernel for this non-linear data

summary(svm.model)
## 
## Call:
## svm(formula = y ~ ., data = train, kernel = "radial", scale = FALSE)
## 
## 
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  radial 
##        cost:  1 
## 
## Number of Support Vectors:  218
## 
##  ( 105 113 )
## 
## 
## Number of Classes:  2 
## 
## Levels: 
##  -1 1

The summary tells us that the model used 218 data points as support vectors (105 from one class and 113 from the other).

You can see other components of the model:

names(svm.model)
##  [1] "call"            "type"            "kernel"          "cost"           
##  [5] "degree"          "gamma"           "coef0"           "nu"             
##  [9] "epsilon"         "sparse"          "scaled"          "x.scale"        
## [13] "y.scale"         "nclasses"        "levels"          "tot.nSV"        
## [17] "nSV"             "labels"          "SV"              "index"          
## [21] "rho"             "compprob"        "probA"           "probB"          
## [25] "sigma"           "coefs"           "na.action"       "fitted"         
## [29] "decision.values" "terms"

If you wish, you can learn about these different components by typing ?svm in R.

Plotting the model

Let’s plot the model

plot(svm.model, train)

Here, do not be confused by the shape of the data points: in this plot, shape has nothing to do with class. The crosses represent support vectors, not misclassified data points.

In this plot, colors represent the classes: maroon is 1 and light pink is -1.

From the plot, we can see that the model has done a great job of creating complex non-linear boundaries that separate almost all the data points according to their respective classes.

Visualising the data in a higher dimension

What SVM does is implicitly project the non-linearly separable data set into a higher dimension in which the data points can be separated linearly (the so-called kernel trick).
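The radial kernel makes this projection implicit: it computes similarities as if the data had been mapped into the higher-dimensional space, without ever constructing the mapped points. Here is a minimal sketch in base R of the standard RBF formula, K(u, v) = exp(-gamma * ||u - v||^2); the rbf_kernel name and the gamma value are illustrative choices of ours, not taken from e1071:

```r
# RBF (radial) kernel: similarity of two points in the implicit feature space
# K(u, v) = exp(-gamma * ||u - v||^2); gamma = 0.5 is an illustrative choice
rbf_kernel = function(u, v, gamma = 0.5) exp(-gamma * sum((u - v)^2))

u = c(1, 2); v = c(2, 4)
rbf_kernel(u, u)  # 1: every point is maximally similar to itself
rbf_kernel(u, v)  # between 0 and 1; shrinks as the points move apart
```

The SVM only ever needs these pairwise similarities, which is why the high-dimensional space never has to be built explicitly.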

To understand the above idea, let’s plot the data points in 3D space using the scatterplot3d() function from the scatterplot3d package. You need to install this package in R.

library(scatterplot3d)
s3d = scatterplot3d(train, pch="", col.grid = "lightblue", col.axis = "blue", highlight.3d = TRUE)
s3d$points3d(train, pch = as.integer(train$y) + 3) # plotting symbols by class
s3d$points3d(train, col = ifelse(as.integer(train$y) == 1, "red", "blue")) # colors by class

As you can see, this data set, which is non-linearly separable in 2D, is clearly separated in 3D. SVM places a hyperplane in the gap between the two classes and classifies them.
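To make the lifting idea concrete, here is a toy base-R sketch with a hand-picked quadratic feature map (not the map the radial kernel actually induces, and the p1/p2 names are ours): points inside and outside a circle are not linearly separable in 2D, but adding a third coordinate z = p1^2 + p2^2 lets a flat plane separate them.

```r
set.seed(2)
p1 = runif(200, -1, 1); p2 = runif(200, -1, 1) # 200 random 2D points
inside = (p1^2 + p2^2) < 0.5                   # class: is the point inside the circle?
z = p1^2 + p2^2                                # the added third coordinate

# In 3D, the flat plane z = 0.5 now separates the two classes perfectly:
all(z[inside] < 0.5) && all(z[!inside] >= 0.5) # TRUE by construction
```

A single plane in the lifted space corresponds to a circular boundary back in the original 2D space, which is exactly the kind of non-linear boundary we saw the SVM draw.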

Predictions

Performance on Train Data set

pred.train = fitted(svm.model)

table(train$y, pred.train)
##     pred.train
##       -1   1
##   -1 237   0
##   1    4 159
cat('\n Accuracy on train data set:\n',mean((pred.train == train$y))*100, '\n')
## 
##  Accuracy on train data set:
##  99

So, only 4 data points in the training set were misclassified. An accuracy of 99% shows that the model has done extremely well.
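Accuracy can also be read straight off the confusion matrix: the sum of the diagonal (correct predictions) divided by the total count. A quick base-R check using the counts from the table above:

```r
# Confusion matrix from the training results above (rows: actual, cols: predicted)
conf = matrix(c(237, 4, 0, 159), nrow = 2,
              dimnames = list(actual = c("-1", "1"), predicted = c("-1", "1")))
accuracy = sum(diag(conf)) / sum(conf) # correct predictions / all predictions
accuracy * 100                         # 99
```

The off-diagonal entries (0 and 4) are the misclassified points, which matches the count above.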

Performance on Test Data Set

pred.test = predict(svm.model, newdata=test)

table(test$y, pred.test)
##     pred.test
##      -1  1
##   -1 48  2
##   1   2 48
cat('\n Accuracy on test data set:\n',mean((pred.test == test$y))*100, '\n')
## 
##  Accuracy on test data set:
##  96

An accuracy of 96% on the test data is also great. Again, only 4 data points were misclassified.

Exercise: Try some other data sets with a two-class or multi-class response variable, for example from the datasets and MASS packages.

