PyBrain - out = fnn.activateOnDataset(griddata)

Time: 2015-10-25 07:54:09

Tags: python image-processing neural-network pybrain

I have been adapting the neural network from the PyBrain tutorial to classify images:

http://pybrain.org/docs/tutorial/fnn.html

The image data is fed in as PNG files, and each image is assigned a specific class.

It works fine until:

out = fnn.activateOnDataset(griddata)

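It fails with the message: AssertionError: (3, 2)

I'm fairly sure the problem is in how I declare the griddata dataset, but I don't know exactly what is wrong.

The tutorial version runs fine.

My code: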

from pybrain.datasets            import ClassificationDataSet
from pybrain.utilities           import percentError
from pybrain.tools.shortcuts     import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.structure.modules   import SoftmaxLayer

from pylab import ion, ioff, figure, draw, contourf, clf, show, hold, plot
from scipy import diag, arange, meshgrid, where
from numpy.random import multivariate_normal
import cv2
import glob                      # used below by glob.glob
from PIL import Image            # used below by Image.open

from pyroc import *

#Creates cover type array based on color of pixels in roadmap

coverType = [(255,225,104,3), #Road
             (254,254,253,0), #Other
             (254,254,254,3), #Road
             (253,254,253,0), #Other
             (253,225,158,0), #Other
] # there are more cover types; only a sample is included here

coverTypes = len(coverType)

print coverTypes #to count

#Creates dataset

alldata = ClassificationDataSet(3,1,nb_classes=10)

"""Classifies Roadmap Sub-Images by type and loads matching Satellite Sub-Image 
with classification into dataset."""

for eachFile in glob.glob('Roadmap Sub-Images/*'):
    image = Image.open(eachFile)
    fileName = eachFile
    newFileName = fileName.replace("Roadmap Sub-Images", "Satellite  Sub-Images")      

    colors = image.convert('RGB').getcolors() #Finds all colors in image and their frequency
    colors.sort() #Sorts colors in image by their frequency
    colorMostFrequent = colors[-1][1] #Finds last element in array, the most frequent color

    for eachColor in range(1,151): #151 is the number of elements in the full coverType array
        if colorMostFrequent[0] == coverType[eachColor][0] and colorMostFrequent[1] == coverType[eachColor][1] and colorMostFrequent[2] == coverType[eachColor][2]:
            print newFileName #Check new path
            image = cv2.imread(newFileName)
            meanImage = cv2.mean(image) #Take average color
            meanImageRGB = meanImage[:3] #Keep the three color channels (OpenCV order is BGR), dropping the fourth value
            print meanImageRGB #Check average colors
            alldata.addSample(meanImageRGB,coverType[eachColor][3])



tstdata, trndata = alldata.splitWithProportion( 0.25 )

trndata._convertToOneOfMany( )
tstdata._convertToOneOfMany( )

fnn = buildNetwork( trndata.indim, 5, trndata.outdim, outclass=SoftmaxLayer )

trainer = BackpropTrainer( fnn, dataset=trndata, momentum=0.1, verbose=True, weightdecay=0.01)

ticks = arange(-3.,6.,0.2)

X, Y = meshgrid(ticks, ticks)

# I think everything is fine up to here; the problem seems to be with the griddata dataset.

# need column vectors in dataset, not arrays

griddata = ClassificationDataSet(2,1, nb_classes=4)

for i in xrange(X.size):
    griddata.addSample([X.ravel()[i],Y.ravel()[i]], [0])

griddata._convertToOneOfMany()  # this is still needed to make the fnn feel comfy

for i in range(20):
    trainer.trainEpochs( 1 )

    trnresult = percentError( trainer.testOnClassData(),
                          trndata['class'] )
    tstresult = percentError( trainer.testOnClassData(
       dataset=tstdata ), tstdata['class'] )

    print "epoch: %4d" % trainer.totalepochs, \
          "  train error: %5.2f%%" % trnresult, \
          "  test error: %5.2f%%" % tstresult


    out = fnn.activateOnDataset(alldata)

    out = out.argmax(axis=1)  # the highest output activation gives the class
    out = out.reshape(X.shape)

    figure(1)
    ioff()  # interactive graphics off
    clf()   # clear the plot
    hold(True) # overplot on
    for c in [0,1,2]:
        here, _ = where(tstdata['class']==c)
        plot(tstdata['input'][here,0],tstdata['input'][here,1],'o')
    if out.max()!=out.min():  # safety check against flat field
        contourf(X, Y, out)   # plot the contour
    ion()   # interactive graphics on
    draw()  # update the plot

ioff()
show()

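1 Answer:

Answer 0 (score: 0)

I think this is because the dimensions of the initial dataset do not match those of griddata.

alldata = ClassificationDataSet(3,1,nb_classes=10)
griddata = ClassificationDataSet(2,1, nb_classes=4)

Both should be declared as 3, 1. However, when I make that change, my code fails at a later stage, so I'm curious about this too.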
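To make the mismatch concrete, here is a minimal sketch in plain Python, without PyBrain. It assumes that the pair in the error message means (input size the network expects, input size it received), which fits a network built from 3-input training data being fed 2-input grid samples. The helper names `check_input_dim` and `make_grid_samples`, and the idea of holding the third feature constant, are illustrative assumptions and not part of PyBrain or of the original code.

```python
def check_input_dim(expected_dim, sample):
    """Mimic an input-length assertion whose message is (expected, got).

    Hypothetical helper: it only imitates the kind of check that
    would produce a message like "(3, 2)".
    """
    assert expected_dim == len(sample), str((expected_dim, len(sample)))


def make_grid_samples(ticks, third_value):
    """Build one 3-feature sample per (x, y) grid point.

    The third feature is held constant (an assumption made here so
    that a 2-D grid can be fed to a 3-input network). The ordering
    matches meshgrid(ticks, ticks) flattened row by row: x varies
    fastest, y slowest.
    """
    return [(x, y, third_value) for y in ticks for x in ticks]


if __name__ == "__main__":
    ticks = [0.0, 128.0, 255.0]  # illustrative values only
    samples = make_grid_samples(ticks, third_value=128.0)

    for s in samples:
        check_input_dim(3, s)  # passes: every sample has 3 features

    try:
        check_input_dim(3, (0.0, 128.0))  # 2 features, like griddata
    except AssertionError as err:
        print("AssertionError: %s" % err)
```

Each tuple could then be passed to `griddata.addSample(sample, [0])` on a dataset declared with 3 inputs. Because the list holds exactly one sample per grid point, the number of outputs still equals `X.size`, which a later `out.reshape(X.shape)` requires. Whether a constant third channel makes sense for mean-color inputs depends on the data, so treat this as a starting point for debugging.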