如何存储神经网络权重

时间:2018-12-10 23:38:25

标签: javascript neural-network synaptic.js

我知道这个问题与另一个帖子重复:How to store neural network knowledge data?

但是那个帖子的回答没有帮到我,所以我再问一次,希望能得到解答。

我使用的库与那个帖子的提问者相同(synaptic.js)。

训练完成后,如何把神经网络的权重保存为JSON文件?我按照synaptic文档中的说明使用了以下方法,但文件目录中并没有出现任何文件:

// Snippet from the synaptic documentation: export the network, then build a
// new Network from the exported value.
// NOTE(review): nothing here writes to disk; `exported` only lives in a
// variable, so persisting it needs an explicit storage step.
var exported = myNetwork.toJSON();
var imported = Network.fromJSON(exported);

我不知道该怎么办。

我希望能够保存网络并在之后重新加载,而不必每次使用时都重新训练。

我在编程方面还是个业余爱好者,以下是我目前写的全部代码:

// SAVESTATE switch
var saveState = true;

// DEBUGGING I/O SWITCH: when true, progress and intermediate values are logged.
var debug = false;

// Shorthands for the synaptic classes used throughout this script.
var Layer = synaptic.Layer;
var Network = synaptic.Network;
var Trainer = synaptic.Trainer;

// 784 inputs (the 28 x 28 pixels drawn by testNN) -> 100 hidden -> 10 outputs
// (the output index with the highest activation is shown as the digit).
var inputLayer = new Layer(784);
var hiddenLayer = new Layer(100);
var outputLayer = new Layer(10);

// Connect the layers in sequence: input -> hidden -> output.
inputLayer.project(hiddenLayer);
hiddenLayer.project(outputLayer);

// Assemble the layers into a single trainable network.
var myNetwork = new Network({
  input: inputLayer,
  hidden: [hiddenLayer],
  output: outputLayer
});

if (debug) {
  console.log("Network created.");
}

// MNIST samples; presumably 700 training and 200 test samples (confirm
// against the mnist package docs).
// Named mnistSet rather than `set`: a top-level `var set` clashes with p5.js's
// global set() function, which p5 installs on window in global mode.
var mnistSet = mnist.set(700, 200);
var trainingSet = mnistSet.training;
var testSet = mnistSet.test;

// Kept as `var`: other code reads these globals and may rebuild them.
var trainer = new Trainer(myNetwork);
// Index of the next MNIST pixel to draw; reset and advanced by testNN().
var drawPos;

// Restore a network saved by an earlier training run, so it does not have to
// be retrained on every page load. (The previous `if (saveState = true)`
// assigned instead of comparing, and discarded the Network built by fromJSON.)
if (saveState === true) {
  try {
    var savedNetwork = typeof localStorage === 'undefined'
      ? null
      : localStorage.getItem('neuralNetworkWeights');
    if (savedNetwork !== null) {
      myNetwork = Network.fromJSON(JSON.parse(savedNetwork));
      // `trainer` was built with the previous network object
      // (new Trainer(myNetwork)); rebuild it around the restored one.
      trainer = new Trainer(myNetwork);
    }
  } catch (err) {
    // Corrupt data or blocked storage: keep the freshly built network.
    console.warn('Could not restore saved network weights:', err);
  }
}

// DEBUG: confirm setup finished and print a sample random test index.
if (debug) {
  console.log("Trainer ready.");
  // generateNum's upper bound is inclusive, so the last valid index is length - 1.
  console.log("RANDOM NUMBER: " + generateNum(0, testSet.length - 1));
}

/*
 * MAIN NEURAL NETWORK FUNCTION
 *
 * Used for
 */
/*
 * MAIN NEURAL NETWORK FUNCTION
 *
 * Trains myNetwork (through `trainer`) on the MNIST training set. When
 * saveState is on, it then stores the trained network in localStorage under
 * 'neuralNetworkWeights', so a later page load can restore it instead of
 * retraining.
 *
 * @param {number} iterations - maximum number of training iterations.
 * @returns {*} whatever trainer.train() returns.
 */
function neuralNetwork(iterations) {
  if (debug) {
    console.log("Training network.");
  }

  var results = trainer.train(trainingSet, {
    rate: 0.2,
    iterations: iterations,
    error: 0.1,
    shuffle: true,
    log: 1,
    cost: Trainer.cost.CROSS_ENTROPY
  });

  if (saveState === true) {
    try {
      if (typeof localStorage !== 'undefined') {
        localStorage.setItem('neuralNetworkWeights', JSON.stringify(myNetwork.toJSON()));
      }
    } catch (err) {
      // NOTE(review): a 784-100-10 network likely serializes to a very large
      // JSON string, so setItem may exceed the storage quota; report it
      // rather than abort after training has already finished.
      console.warn('Could not save network weights:', err);
    }
  }

  if (debug) {
    console.log("Network done training.");
    console.log(myNetwork.activate(testSet[0].input));
    console.log(testSet[0].output);
  }

  return results;
}

/*
 * p5.js setup hook: creates the 56x56 drawing surface, attaches it to the
 * element with id "canvas", and paints a grey (level 50) background.
 */
function setup() {
  createCanvas(56, 56).parent('canvas');
  background(50);
}

/*
 * Draws a random test sample on the canvas and runs it through the network.
 * It writes the network's prediction into #result and the sample's real
 * label into #fakeResult.
 */
function testNN() {
  drawPos = 0;

  // generateNum's upper bound is inclusive; passing testSet.length could
  // select one past the last sample and crash below.
  var randomNum = generateNum(0, testSet.length - 1);

  if (debug) {
    console.log(randomNum);
  }

  // 28 x 28 grid of 2x2 squares on the 56x56 canvas, one per input value.
  for (var screenY = 0; screenY < 56; screenY += 2) {
    for (var screenX = 0; screenX < 56; screenX += 2) {
      generatePixel(screenX, screenY, drawPos, randomNum);
      drawPos++;
    }
  }

  var sample = testSet[randomNum];
  var arrayResult = myNetwork.activate(sample.input);

  if (debug) {
    console.log(arrayResult);
  }

  // Prediction: index of the strongest output activation.
  var result = arrayResult.indexOf(Math.max.apply(Math, arrayResult));

  // Real label: (last) non-zero position of the expected output vector.
  var label = -1;
  for (var i = 0; i < sample.output.length; i++) {
    if (sample.output[i] != 0) {
      label = i;
    }
  }
  if (label !== -1) {
    document.getElementById('fakeResult').innerHTML = "<span class='h2'>" + label + "</span>";
  }
  document.getElementById('result').innerHTML = "<span class='h2'>" + result + "</span>";
}

/*
 * Draws one input value of a test sample as a filled grey square.
 *
 * @param {number} x   - left edge of the square, in canvas pixels.
 * @param {number} y   - top edge of the square, in canvas pixels.
 * @param {number} pos - index of the value within testSet[val].input.
 * @param {number} val - index of the sample within testSet.
 */
function generatePixel(x, y, pos, val) {
  // Side length of each square; testNN() steps through the canvas by 2.
  var cellSize = 2;

  strokeWeight(0);
  // Grey level = input value * 100 (presumably inputs lie in 0..1; confirm).
  var c = testSet[val].input[pos] * 100;
  var greyScale = color(c);

  if (debug) {
    console.log(c);
    console.log(pos);
  }

  fill(greyScale);
  // In p5's default rectMode(CORNER) the 3rd/4th arguments are width and
  // height, not a second corner. The old (x + 1, x + 1) drew 1x1 squares in
  // the first column and oversized squares elsewhere.
  rect(x, y, cellSize, cellSize);
}

//Function for generating a random number
/*
 * Returns a random integer between min and max, both inclusive
 * (for integer bounds).
 */
function generateNum(min, max) {
  var span = max - min + 1;
  return Math.floor(min + Math.random() * span);
}

/*
 * Submit handler: reads the iteration count from #iterations and trains the
 * network. Empty or invalid input falls back to 20 (the field's placeholder)
 * and a number, not the raw input string, is passed on.
 */
function formSubmit() {
  if (debug) {
    console.log("Form Submitted. Awaiting network.");
  }
  var a = document.getElementById('iterations').value;
  neuralNetwork(parseIterations(a, 20));
}

// Parses a user-entered iteration count; returns `fallback` unless it is a positive integer.
function parseIterations(raw, fallback) {
  var n = parseInt(raw, 10);
  return n > 0 ? n : fallback;
}
/* Remove the browser's default page margin. */
body {
  margin: 0;
}

/* White title bar with a solid 5px black rule underneath. */
#header {
  min-height: 90px;
  background-color: #fff;
  border-bottom: 5px solid #000;
}

/* Testing panel floats to the left of the page. */
#interface {
  display: inline-block;
  float: left;
}

/* Captions placed to the left of the canvas and the result boxes. */
#interfaceText {
  float: left;
  margin-right: 10px;
  padding-top: 19px;
}

#canvas {
  float: left;
}

/* Large page title, positioned over the header. */
.h1 {
  position: absolute;
  margin: 16px 0 0 16px;
  color: #000;
  font-family: Verdana;
  font-size: 50px;
  cursor: context-menu;
}

/* Smaller caption text. */
.h2 {
  color: #000;
  font-family: Verdana;
  font-size: 15px;
  cursor: context-menu;
}
<!DOCTYPE html>
<html>

<head>
  <meta charset="utf-8">
  <title>Neural Network</title>
  <script src="scripts/node_modules/synaptic/dist/synaptic.js"></script>
  <script src="scripts/p5.js"></script>
  <script src="scripts/node_modules/mnist/dist/mnist.js"></script>
  <link rel="stylesheet" href="stylesheets/main.css">
</head>

<body>
  <section id="header">
    <span class="h1">Neural Network</span>
  </section>
  <section id="container">
    <!-- Pressing Enter in the only text field would submit the form and reload
         the page, discarding the in-memory network; run training instead. -->
    <form id="form" onsubmit="event.preventDefault(); formSubmit();">
      <span class="h2">Training:</span><br>
      <label for="iterations">Iterations:</label>
      <input id="iterations" type="number" min="1" step="1" placeholder="20">
    </form>
    <div class="button"><button type="button" onclick="formSubmit()">Submit</button></div>
    <script src="scripts/network.js"></script>
    <br>
  </section>
  <section id="interface">
    <!-- NOTE(review): id="interfaceText" is repeated three times (invalid HTML).
         main.css styles it via #interfaceText, so switch to a class in both
         files together. -->
    <div id="interfaceText">
      <span class="h2">Current picture: </span>
    </div>
    <div id="canvas">
    </div><br><br><br><br>
    <div id="interfaceText">
      <span class="h2">Network result:</span>
    </div>
    <div id="result">
      <span class="errorText">This feature will be buggy if the network is not trained</span>
    </div><br><br>
    <div id="interfaceText">
      <span class="h2">Real value:</span>
    </div>
    <div id="fakeResult">
    </div><br><br><br>
    <div class="button">
      <button type="button" onclick="testNN()">Generate</button>
    </div>
  </section>
  <section id="footer">
  </section>
</body>

</html>

请帮助:)

0 个答案:

没有答案