感知机原理参考博客:
【机器学习】感知机原理详解
感知机模型:f(x)=sign(w*x+b)
sign是符号函数
感知机模型的其中一个超平面是:w*x+b=0
w是超平面的法向量,b是超平面的截距
这个超平面把样本分为正负两类(结合sign()函数)
编号 | 宽度 | 长度 | 检验类别 |
---|---|---|---|
1 | 3 | 3 | 正品 |
2 | 4 | 3 | 正品 |
3 | 1 | 1 | 次品 |
将正负样本区分开
x = np.array([[3,3],[4,3],[1,1]],dtype=np.float64)
y = np.array([1,1,-1],dtype=np.float64)
l=len(x)
x_positive = []
x_negetive = []
for i in range(l):
if y[i]==1:
x_positive.append(x[i])
else:
x_negetive.append(x[i])
x_positive = np.array(x_positive)
x_negetive = np.array(x_negetive)
2.赋值w_0,b_0
w=np.array([0.0,0.0])
b=np.array([0.0])
lr = 0.5
square = lambda x:x*x
def sign(v):
if v>=0:
return 1
else:
return -1
flag=0
3.选取数据点(x_i,y_i)
4.判断该数据点是否为当前模型的误分类点,并进行更新
y_i(w*x_i+b)<=0时更新
while flag==0:
for i in range(3):
y_pred = sign(np.matmul(w, x[i])+b)
if y[i] * y_pred < 0: #是误分类点。
partial_w = (-y[i]*x[i])
partial_b = (-y[i])
#更新:
w = w-lr*partial_w
b = b-lr*partial_b
5.若为误分类点则转到第三步,直到没有误分类点为止
完整代码如下:
import time
import numpy as np
import matplotlib.pyplot as plt
#训练数据
x = np.array([[3,3],[4,3],[1,1]],dtype=np.float64)
y = np.array([1,1,-1],dtype=np.float64)
l=len(x)
x_positive = []
x_negetive = []
for i in range(l):
if y[i]==1:
x_positive.append(x[i])
else:
x_negetive.append(x[i])
x_positive = np.array(x_positive)
x_negetive = np.array(x_negetive)
#参数定义
w=np.array([0.0,0.0])
b=np.array([0.0])
lr = 0.5
square = lambda x:x*x
def sign(v):
if v>=0:
return 1
else:
return -1
flag=0
#感知机学习算法是有误分类驱动的
while flag==0:
for i in range(3):
y_pred = sign(np.matmul(w, x[i])+b)
if y[i] * y_pred < 0: #是误分类点。
partial_w = (-y[i]*x[i])
partial_b = (-y[i])
#更新:
w = w-lr*partial_w
b = b-lr*partial_b
#画图
if w[1]==0:
continue
plt.xlim(0,10)
plt.ylim(-5,5)
plt.scatter(x_positive[:,0],x_positive[:,1],color='blue')
plt.scatter(x_negetive[:,0],x_negetive[:,1],color='red')
px=np.linspace(0,10,10)
py=(-w[0]/w[1])*px-b/w[1]
plt.plot(px,py)
plt.pause(0.1)
plt.clf()
#检验
rst=[]
for i in range(3):
if y[i]*sign(np.matmul(w, x[i])+b)>0:
#分类正确,则记为1
rst.append(1)
else:
rst.append(0)
#如果全都分类正确
if min(rst)==1:
flag=1
time.sleep(3)
break
print('w=',w)
print('b=',b)
文章评论