在 Java 中读取 MNIST 数据集时出现大端/小端(字节序)整数转换的问题

时间:2018-02-16 21:24:00

标签: java mnist

我一直在尝试读取 MNIST 数据集,以便将其格式化后用于神经网络。但是我始终无法让大端到小端的字节序转换正常工作。

当我读取数据时,读到的第一个整数是 529205256;将其转换为小端格式后得到 134777631,仍然远大于预期的“幻数”2051。

无论我尝试哪种解决方案,我都会得到相同的错误号码,所以如果有人能够向我解释我的错误,我真的很感激。

部分代码是从github借来的。

这是我的代码中发生错误的部分:

/**
 * Reads every image contained in the given image file.
 *
 * <p>The file is loaded into a buffer; the first integer is checked against
 * {@code IMAGE_FILE_MAGIC_NUMBER}, and the next three integers are read as the
 * image count, the rows per image and the columns per image. The remaining
 * bytes are consumed one image at a time.
 *
 * @param infile path of the image file to read
 * @return one {@code int[rows][columns]} array per image, in file order
 */
public static List<int[][]> getImages(String infile) {
    final ByteBuffer data = loadFileToByteBuffer(infile);

    // Header: magic number, image count, rows per image, columns per image.
    final int magic = data.getInt();
    assertMagicNumber(IMAGE_FILE_MAGIC_NUMBER, magic);
    final int imageCount = data.getInt();
    final int rowCount = data.getInt();
    final int columnCount = data.getInt();

    final List<int[][]> result = new ArrayList<>();
    int index = 0;
    while (index < imageCount) {
        result.add(readImage(rowCount, columnCount, data));
        index++;
    }
    return result;
}

当调用 bb.getInt() 时,它返回整数 529205256;即使使用下面这段代码进行字节序转换之后

/**
 * Reverses the order of the four bytes of {@code value}, converting between
 * big-endian and little-endian representations.
 *
 * @param value the integer whose bytes should be reversed
 * @return {@code value} with its bytes in the opposite order
 */
public static int swap(int value) {
    // The standard library already provides exactly this byte reversal.
    return Integer.reverseBytes(value);
}

仍然没有产生正确的数字,因此 assertMagicNumber 会抛出异常,因为值不相等。

如果有必要,这是该类的其余部分:

package core;

import static java.lang.String.format;


import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

/**
 * Static helpers for reading MNIST label and image files and for converting
 * the decoded data into arrays suitable as neural-network input and output.
 *
 * <p>Header integers are read with {@link ByteBuffer#getInt()} on a buffer
 * created by {@link ByteBuffer#wrap(byte[])}, i.e. in the buffer's default
 * (big-endian) byte order.
 */
public class MnistReader {
    /** Value expected as the first integer of a label file. */
    public static final int LABEL_FILE_MAGIC_NUMBER = 2049;
    /** Value expected as the first integer of an image file. */
    public static final int IMAGE_FILE_MAGIC_NUMBER = 2051;

    /**
     * Reads all labels from the given label file.
     *
     * @param infile path of the label file
     * @return the labels in file order, each as an unsigned byte value (0..255)
     * @throws RuntimeException if the file cannot be read or its first integer
     *         is not {@link #LABEL_FILE_MAGIC_NUMBER}
     */
    public static int[] getLabels(String infile) {
        ByteBuffer bb = loadFileToByteBuffer(infile);

        assertMagicNumber(LABEL_FILE_MAGIC_NUMBER, bb.getInt());

        int numLabels = bb.getInt();
        int[] labels = new int[numLabels];
        for (int i = 0; i < numLabels; ++i) {
            labels[i] = bb.get() & 0xFF; // mask so the signed byte becomes 0..255
        }
        return labels;
    }

    /**
     * Reads all images from the given image file.
     *
     * <p>After the magic number, the header supplies the image count, the rows
     * per image and the columns per image; pixel bytes follow.
     *
     * @param infile path of the image file
     * @return one {@code int[rows][columns]} array per image, in file order
     * @throws RuntimeException if the file cannot be read or its first integer
     *         is not {@link #IMAGE_FILE_MAGIC_NUMBER}
     */
    public static List<int[][]> getImages(String infile) {
        ByteBuffer bb = loadFileToByteBuffer(infile);

        assertMagicNumber(IMAGE_FILE_MAGIC_NUMBER, bb.getInt());
        int numImages = bb.getInt();
        int numRows = bb.getInt();
        int numColumns = bb.getInt();

        List<int[][]> images = new ArrayList<>();
        for (int i = 0; i < numImages; i++) {
            images.add(readImage(numRows, numColumns, bb));
        }
        return images;
    }

    /** Reads {@code numRows} consecutive rows from the buffer into one image. */
    private static int[][] readImage(int numRows, int numCols, ByteBuffer bb) {
        int[][] image = new int[numRows][];
        for (int row = 0; row < numRows; row++) {
            image[row] = readRow(numCols, bb);
        }
        return image;
    }

    /** Reads {@code numCols} bytes from the buffer as unsigned pixel values. */
    private static int[] readRow(int numCols, ByteBuffer bb) {
        int[] row = new int[numCols];
        for (int col = 0; col < numCols; ++col) {
            row[col] = bb.get() & 0xFF; // mask so the signed byte becomes 0..255
        }
        return row;
    }

    /**
     * Verifies that the magic number read from a file matches the expected one.
     * Both values are printed to standard output before the comparison.
     *
     * <p>Troubleshooting hint: a found value of 529205256 (0x1F8B0808) matches
     * the start of a gzip stream, which suggests the file has not been
     * decompressed yet.
     *
     * @param expectedMagicNumber the value the file should start with
     * @param magicNumber the value actually read from the file
     * @throws RuntimeException if the two values differ
     */
    public static void assertMagicNumber(int expectedMagicNumber, int magicNumber) {
        System.out.println(expectedMagicNumber);
        System.out.println(magicNumber);

        if (expectedMagicNumber != magicNumber) {
            switch (expectedMagicNumber) {
            case LABEL_FILE_MAGIC_NUMBER:
                throw new RuntimeException("This is not a label file.");
            case IMAGE_FILE_MAGIC_NUMBER:
                throw new RuntimeException("This is not an image file.");
            default:
                throw new RuntimeException(
                        format("Expected magic number %d, found %d", expectedMagicNumber, magicNumber));
            }
        }
    }

    /**
     * Loads the whole file and wraps its bytes in a {@link ByteBuffer}.
     *
     * @param infile path of the file to load
     * @return a buffer positioned at the first byte of the file content
     * @throws RuntimeException if the file cannot be read
     */
    public static ByteBuffer loadFileToByteBuffer(String infile) {
        return ByteBuffer.wrap(loadFile(infile));
    }

    /**
     * Reads the complete content of a file into a byte array.
     *
     * @param infile path of the file to load
     * @return all bytes of the file
     * @throws RuntimeException wrapping any failure (missing file, I/O error,
     *         file larger than a Java array can hold, premature end of file)
     */
    public static byte[] loadFile(String infile) {
        // try-with-resources closes channel and file even when reading fails.
        try (RandomAccessFile file = new RandomAccessFile(infile, "r");
                FileChannel channel = file.getChannel()) {
            long fileSize = channel.size();
            if (fileSize > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                        format("File %s is too large to load: %d bytes", infile, fileSize));
            }

            ByteBuffer buffer = ByteBuffer.allocate((int) fileSize);
            // A single read() may return fewer bytes than requested, so keep
            // reading until the buffer is full or the channel reports EOF.
            while (buffer.hasRemaining()) {
                if (channel.read(buffer) == -1) {
                    break;
                }
            }
            if (buffer.hasRemaining()) {
                throw new IllegalStateException(
                        format("Unexpected end of file %s: read %d of %d bytes",
                                infile, buffer.position(), fileSize));
            }
            // The heap buffer's backing array is exactly fileSize bytes long.
            return buffer.array();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Renders an image as ASCII art, one text line per image row, each line
     * framed by {@code |} characters.
     *
     * <p>A pixel of 0 is drawn as a space, values below 85 as {@code .},
     * values below 170 as {@code x} and everything else as {@code X}.
     *
     * @param image pixel values indexed as {@code image[row][col]}
     * @return the rendered image, each row terminated by a newline
     */
    public static String renderImage(int[][] image) {
        StringBuilder sb = new StringBuilder();

        for (int row = 0; row < image.length; row++) {
            sb.append("|");
            for (int col = 0; col < image[row].length; col++) {
                int pixelVal = image[row][col];
                if (pixelVal == 0) {
                    sb.append(" ");
                } else if (pixelVal < 256 / 3) {
                    sb.append(".");
                } else if (pixelVal < 2 * (256 / 3)) {
                    sb.append("x");
                } else {
                    sb.append("X");
                }
            }
            sb.append("|\n");
        }

        return sb.toString();
    }

    /**
     * Concatenates {@code n} copies of {@code s}.
     *
     * @param s the string to repeat
     * @param n the number of copies; values below 1 yield an empty string
     * @return the repeated string
     */
    public static String repeat(String s, int n) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) {
            sb.append(s);
        }
        return sb.toString();
    }

    /**
     * Flattens a 2-dimensional image into a 1-dimensional array in row-major
     * order and scales each pixel value by 1/255, so that pixel values in
     * 0..255 map to 0.0..1.0.
     *
     * @param source pixel values indexed as {@code source[row][col]}
     * @return the scaled pixels; the length equals the total number of pixels
     *         in {@code source} (784 for a 28x28 image)
     */
    public static double[] convertImage(int[][] source) {
        int pixelCount = 0;
        for (int[] row : source) {
            pixelCount += row.length;
        }

        double[] convertedImage = new double[pixelCount];
        int currentPos = 0;
        for (int[] row : source) {
            for (int pixel : row) {
                // Divide by 255.0 (not 255) to avoid integer division, and
                // advance the write position for every pixel.
                convertedImage[currentPos++] = pixel / 255.0;
            }
        }
        return convertedImage;
    }

    /**
     * Converts a label into a vector of length 10 that is 1.0 at index
     * {@code label} and 0.0 elsewhere, usable as target output for the
     * neural network.
     *
     * @param label the label, used as an index into the result
     * @return the vector for {@code label}
     * @throws ArrayIndexOutOfBoundsException if {@code label} is not in 0..9
     */
    public static double[] convertLabel(int label) {
        double[] convertedLabel = new double[10];
        convertedLabel[label] = 1;
        return convertedLabel;
    }

    /**
     * Reverses the order of the four bytes of {@code value}, converting between
     * big-endian and little-endian representations.
     *
     * @param value the integer whose bytes should be reversed
     * @return {@code value} with its bytes in the opposite order
     */
    public static int swap(int value) {
        return Integer.reverseBytes(value);
    }

}

我真的不知道错误出在哪里,如能提供任何帮助,我将不胜感激。

链接到MNIST:http://yann.lecun.com/exdb/mnist/

编辑:事实证明是文件的解压缩本身出了问题。修复之后,一切都按预期正常工作了。

1 个答案:

答案 0 :(得分:0)

我可以建议一个解决方案。在这个class中,您可以找到以下这些行:

IsActive

还有 3 个额外的类,可以轻松复用。