#!/usr/bin/env python
"""Interactive perceptron demo.

Each round:
1. Draws a random unit "oracle" weight vector in 2-D.
2. Samples points and keeps only those whose |oracle . x| exceeds a margin, so
   the set is linearly separable by a line through the origin.
3. Trains a bias-free perceptron on the kept points.
4. Plots both weight vectors with their decision boundaries.

Press Enter for a new round; Ctrl-D / Ctrl-C quits. Written so that it runs
under both Python 2 and Python 3.
"""
import os
import numpy as np
from matplotlib import pyplot

# Half-width of the square sampling / plotting window [-LIM, LIM] x [-LIM, LIM].
LIM = 10


def plot_weight_boundary(w, color, label, linetype='-'):
    """Draw weight vector ``w`` as an arrow from the origin, plus its boundary.

    The boundary is the line through the origin orthogonal to ``w``, i.e. the
    points x with ``w . x == 0``.

    :param w: length-2 weight vector.
    :param color: matplotlib color for both the arrow and the boundary line.
    :param label: legend label attached to the boundary line.
    :param linetype: matplotlib line-style string for the boundary ('-', '--').
    """
    # Arrow from (0,0) to w, drawn in data coordinates at true length.
    pyplot.quiver(0, 0, w[0], w[1], angles='xy', scale_units='xy', scale=1,
                  color=color)
    if w[1] == 0:
        # w is horizontal, so the boundary w . x = 0 is the vertical line x = 0.
        # The slope form below would divide by zero here.
        pyplot.axvline(0, linestyle=linetype, color=color, label=label)
    else:
        # w0*x + w1*y = 0  =>  y = -(w0 / w1) * x. A straight line needs only
        # its two end points.
        x = np.array([-LIM, LIM], dtype=float)
        pyplot.plot(x, -w[0] / w[1] * x, linetype, color=color, label=label)
    pyplot.legend(loc='upper left')


def gen(n_points=100, margin=1.0):
    """Generate, plot and return a random linearly separable 2-D data set.

    :param n_points: number of candidate points sampled before filtering.
    :param margin: candidates with ``|oracle . x| <= margin`` are discarded.
        Because the oracle is a unit vector, this is a geometric margin.
    :returns: ``(X, Y)``. ``X`` is an (m, 2) array of the kept points,
        m <= n_points. ``Y = X . oracle`` holds their real-valued signed
        labels, each with ``|Y| > margin``.
    """
    w = np.random.rand(2) - 0.5   # from [0,1) to [-0.5, +0.5)
    w /= np.linalg.norm(w)        # normalize to unit vector
    print("unit oracle w= %s" % (w,))
    plot_weight_boundary(w, 'blue', 'oracle', '--')

    # Sample uniformly inside the visible window. The former randn()*10 put
    # roughly a third of each coordinate outside [-LIM, LIM], i.e. off-screen.
    X = np.random.uniform(-LIM, LIM, size=(n_points, 2))
    Y = X.dot(w)                  # signed distance to the oracle boundary
    keep = np.abs(Y) > margin     # enforce the margin
    X, Y = X[keep], Y[keep]

    # Draw data points.
    xx, yy = X[Y > margin].T
    pyplot.plot(xx, yy, 'ob')     # positives: blue circles
    xx, yy = X[Y < -margin].T
    pyplot.plot(xx, yy, 'sr')     # negatives: red squares

    pyplot.plot([-LIM, LIM], [0, 0], lw=0.3, color='black')  # x-axis
    pyplot.plot([0, 0], [-LIM, LIM], lw=0.3, color='black')  # y-axis
    pyplot.xlim(-LIM, LIM)
    pyplot.ylim(-LIM, LIM)
    # gca() reuses the current axes. On recent matplotlib, pyplot.axes()
    # creates a new, empty axes on top, which would hide everything above.
    pyplot.gca().set_aspect('equal')  # equal xy scale
    return X, Y


def perc(data, max_epochs=None):
    """Train a bias-free perceptron on 2-D data and plot the learned boundary.

    Makes full passes over ``data`` until a pass makes no mistakes, or until
    ``max_epochs`` passes have run. A point counts as a mistake when
    ``y * (w . x) <= 0``; the update is ``w += y * x``. With the real-valued
    labels from :func:`gen`, each step is therefore scaled by ``|y|``.

    :param data: iterable of ``(x, y)`` pairs, with x a length-2 array and y a
        signed scalar label. It is materialised into a list so it can be
        scanned once per epoch; a Python 3 ``zip`` would otherwise be
        exhausted after the first epoch and falsely "converge".
    :param max_epochs: optional cap on the number of passes. ``None`` (the
        default) loops until convergence, which only terminates for linearly
        separable data.
    :returns: the learned weight vector normalised to unit length. It is left
        as zeros if no update was ever made (e.g. empty data).
    """
    data = list(data)
    w = np.zeros(2)
    epoch, total, updates = 0, 0, 1
    while updates > 0 and (max_epochs is None or epoch < max_epochs):
        epoch += 1
        updates = 0
        for x, y in data:
            if x.dot(w) * y <= 0:     # misclassified or on the boundary
                w += y * x
                updates += 1
        print("epoch %d, updates %d" % (epoch, updates))
        total += updates
    if updates > 0:
        print("stopped after %d epochs without converging, %d updates, w: %s"
              % (epoch, total, w))
    else:
        print("converged after %d updates, w: %s" % (total, w))

    norm = np.linalg.norm(w)
    if norm == 0:
        # Nothing was learned, so there is no direction to normalise or plot.
        print("perceptron never updated; w is zero")
        print("")
        return w
    w /= norm                     # normalize to unit vector
    plot_weight_boundary(w, 'black', 'perc')
    print("unit perceptron w= %s" % (w,))
    print("")
    return w


def main():
    """Repeat generate -> train -> plot until stdin closes or Ctrl-C at the prompt."""
    try:
        read_line = raw_input     # Python 2
    except NameError:
        read_line = input         # Python 3
    pyplot.ion()                  # interactive mode: drawing does not block
    while True:
        X, Y = gen()              # generate random data set
        perc(zip(X, Y))           # train perceptron
        pyplot.show()
        # Give the GUI event loop a moment to render before blocking on input.
        pyplot.pause(0.001)
        try:
            read_line()           # keyboard prompt
        except (EOFError, KeyboardInterrupt):
            break
        pyplot.clf()              # clear figure


if __name__ == "__main__":
    main()