将C代码更改为JAVA代码[神经网络]

时间:2014-02-06 23:19:35

标签: java c

我有问题。我在C中找到了Artifical Neural Network代码,我想把它改成JAVA。所以我在功能上做了这些修改:

在C:

void ComputeFeedForwardSignals(double* MAT_INOUT, double* V_IN,double* V_OUT, double* V_BIAS,int size1,int size2,int layer)
{
  int row,col;
  for(row=0;row < size2; row++) 
    {
      V_OUT[row]=0.0;
      for(col=0;col<size1;col++)V_OUT[row]+=(*(MAT_INOUT+(row*size1)+col)*V_IN[col]);
      V_OUT[row]+=V_BIAS[row];
      if(layer==0) V_OUT[row] = tanh(V_OUT[row]);
   }
}

在JAVA中使用相同的代码:

private void  ComputeFeedForwardSignals (double[][] MAT_INOUT, double[] V_IN, double[] V_OUT, double[] V_BIAS, int size1, int size2, int layer) {
      int row,col;
      for(row=0;row < size2; row++) 
        {
          V_OUT[row]=0.0;
            for(col=0;col<size1;col++)V_OUT[row]+=(*(MAT_INOUT+(row*size1)+col)*V_IN[col]);
          V_OUT[row]+=V_BIAS[row];
          if(layer==0) V_OUT[row] = Math.tanh(V_OUT[row]);
       }
  }

...但我不知道,我怎么能取代这一行:

for(col=0;col<size1;col++)V_OUT[row]+=(*(MAT_INOUT+(row*size1)+col)*V_IN[col]);

请帮帮我。

...编辑

所以,这是原始的,完整的和有效的C代码(我只剪切大数组):

//Analysis Type - Classification 
#include <stdio.h>
#include <conio.h>
#include <math.h>
#include <stdlib.h>


double input_hidden_weights[100][320]=
{
 {5.80887333084651e-001, 1.30476168251902e+000, 5.80288623607794e-001, 8.14671389077252e-001, 5.43029117736068e-001, 5.15946547751079e-001, 4.86324144066176e-001, 3.26116870507742e-001, 4.02847954975450e-001, 1.66380273429940e-001, 3.39504086983093e-001, 5.26249449907226e-002, 2.56145034309448e-001, -9.70569690724137e-002, 1.59242465161706e-001, -2.33214192739307e-001, 1.87648652577582e-002, -3.34041203799010e-001, -1.11393479757180e-001, -3.73367046579334e-001, -2.17964256784897e-001, -4.00192132238685e-001, -2.81209945053773e-001, -3.76012204270845e-001, -3.10614624230684e-001, -3.15921671544855e-001, -2.81529041789346e-001, -2.93125194852085e-001, -2.34501635198919e-001, -2.53483381689582e-001, -1.48248548624413e-001, -2.39004623741234e-001, -3.76904804526471e-002, -2.41802215751481e-001, 9.19371048359130e-002, -2.45416703676273e-001, 2.10015748048206e-001, -2.69629916475520e-001, 3.06584661619282e-001, -2.55206735127512e-001, 3.40664514242493e-001, -2.31156686442331e-001, 3.48619703956421e-001, -1.70205811278937e-001, 2.58540184919054e-001, -9.02211837839767e-002, 1.58219383097399e-001, 2.27214780937834e-002, 4.83874056064942e-002, 1.12670862875882e-001, -7.03828317119573e-002, 2.29852689788459e-001, -1.77576747217408e-001, 3.39353599529732e-001, -2.85410462161884e-001, 4.32495256767594e-001, -3.92336294795819e-001, //...//, 2.05201376126660e+000, 3.66183884164477e-001, -3.13236249649404e-001, -6.90288932727299e-001, -4.03350305378540e-001, -5.18217246345780e-001, -2.11389978349476e-002 }
};

double output_bias[23]={ 9.26583664006470e-001, -5.12649304998720e-001, 8.50167833795463e-001, -3.12847753694448e+000, 3.03755707027426e+000, -2.37819656326885e+000, 2.64402241833182e+000, 2.54741061276513e+000, -1.52779355163716e+000, -1.24383441583121e+000, 3.72867458316797e+000, 4.15165486353238e+000, -4.88008803506142e+000, -3.35612469382435e+000, -2.74153025313899e+000, 3.11815950006976e+000, 1.49738419343937e+000, -3.43954471203446e+000, 3.03236240807163e-001, -4.77180501599233e+000, 3.08664451646140e+000, -1.66680993545569e-002, 2.12529133729690e+000 };

double max_input[320]={ 1.54200000000000e+003, //...//,  1.86000000000000e+002, 2.24000000000000e+002 };

double min_input[320]={ -1.72800000000000e+003, -1.92000000000000e+003, -2.50800000000000e+003, -2.95900000000000e+003, //...//,  -1.13000000000000e+002, -1.01000000000000e+002, -1.19000000000000e+002, -9.30000000000000e+001, -1.18000000000000e+002, -1.01000000000000e+002 };

double input[320] = {618, 1067, 499, 1179, 358, 1187, //...// , 0, 0, 0, 0};
double hidden[100];
double output[23];

double MeanInputs[320]={ 1.24955952380952e+002, 5.69701190476191e+002, 1.37892261904762e+002, //...//,  3.82738095238095e-001, 6.29166666666667e-001 };

void FindMax(double* vec, double* max, long* maxIndex,int len)
{
  long i;
  *max = vec[0];
  *maxIndex = 0;
  for(i=1; i<len; i++)
  {
    if(vec[i]>*max)
    {
      *max = vec[i];
      *maxIndex = i;
    }
  }

}

void ScaleInputs(double* input, double minimum, double maximum, int size)
{
 double delta;
 long i;
 for(i=0; i<size; i++)
 {
    delta = (maximum-minimum)/(max_input[i]-min_input[i]);
    input[i] = minimum - delta*min_input[i]+ delta*input[i];
 }
}

void softmax(double* vec,int len)
{

  long i, j;
  double sum=0.0;
  for(i=0; i<len; i++)
  {
    if(vec[i]>200)
    {
      double max;
      long maxIndex;
      FindMax(vec, &max, &maxIndex,len);
      for(j=0; j<len; j++)
      {        if(j==maxIndex) vec[j] = 1.0;
        else vec[j] = 0.0;
      }
      return;
    }
    else
    {
      vec[i] = exp(vec[i]);
    }
    sum += vec[i];
  }
  if(sum==0)
  {
   long a = 1;
  }
  if(sum!=0.0)
  {
    for(i=0; i<len; i++) vec[i] = vec[i]/sum;
  }
  else for(i=0; i<len; i++) vec[i] = 1.0/(double)len;
}

void ComputeFeedForwardSignals(double* MAT_INOUT,double* V_IN,double* V_OUT, double* V_BIAS,int size1,int size2,int layer)
{
  int row,col;     
  for(row=0;row < size2; row++) 
    {
      V_OUT[row]=0.0;
      for(col=0;col<size1;col++) V_OUT[row]+=(*(MAT_INOUT+(row*size1)+col)*V_IN[col]);

      V_OUT[row]+=V_BIAS[row];
      if(layer==0) V_OUT[row] = tanh(V_OUT[row]);
   }
}

void RunNeuralNet_Classification () 
{ 

  ComputeFeedForwardSignals((double*)input_hidden_weights,input,hidden,hidden_bias,320, 100,0);
  ComputeFeedForwardSignals((double*)hidden_output_wts,hidden,output,output_bias,100, 23,1);
}

int main()
{
  int cont_inps;
  int index;
  int i=0;
  int keyin=1;
  double max;
  while(1)
  {
    max=3.e-300;

    for(cont_inps=0;cont_inps<320;cont_inps++)
    {
     //Substitution of missing continuous variables
     if(input[cont_inps] == -9999)
      input[cont_inps]=MeanInputs[cont_inps];
    }
    ScaleInputs(input,0,1,320);
    RunNeuralNet_Classification();
   //Output Activation is Softmax;
    softmax(output, 23);
    for(i=0;i<23;i++)
    {
      if(max<output[i])
      {
        max=output[i];
        index=i+1;
      }
    }
    printf("\n%s","Predicted category = ");

    switch(index)
    {
        case 1: printf("%s\n","A"); break;
        case 2: printf("%s\n","B"); break;
        case 3: printf("%s\n","C"); break;
        case 4: printf("%s\n","D"); break;
        case 5: printf("%s\n","E"); break;
        case 6: printf("%s\n","F"); break;
        case 7: printf("%s\n","G"); break;
        case 8: printf("%s\n","H"); break;
        case 9: printf("%s\n","I"); break;
        case 10: printf("%s\n","J"); break;
        case 11: printf("%s\n","K"); break;
        case 12: printf("%s\n","L"); break;
        case 13: printf("%s\n","M"); break;
        case 14: printf("%s\n","N"); break;
        case 15: printf("%s\n","O"); break;
        case 16: printf("%s\n","P"); break;
        case 17: printf("%s\n","R"); break;
        case 18: printf("%s\n","S"); break;
        case 19: printf("%s\n","T"); break;
        case 20: printf("%s\n","U"); break;
        case 21: printf("%s\n","W"); break;
        case 22: printf("%s\n","Y"); break;
        case 23: printf("%s\n","Z"); break;
        default: break;
    }
    printf("\n%s%.14f","Confidence level = ",max);
    printf("\n\n%s\n","Press any key to make another prediction or enter 0 to quit the program.");
    keyin=getch();
    if(keyin==48)break;
  }
    return 0;
}

JAVA中的代码效果不佳:

public class MAIN {



    public static void main(String[] args) {
    network Neurons = new network();
    Neurons.Read();
    Neurons.SSN();
    }

      }






import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;


public class network {


    private double[][] Input_hidden_weights = new double[100][320]; 
    private double[] Hidden_bias = new double[100]; 
    private double[][] Hidden_output_wts = new double[23][100];
    private double[] Output_bias = new double[23];
    private double[] Max_input = new double[320];   
    private double[] Min_input = new double[320];   
    private double[] Input = new double[320];
    private double[] Hidden = new double[100];
    private double[] Output = new double[23];
    private double[] MeanInputs = new double[320];

    public void Read() {

        this.Read_Input_hidden_weights();
        this.Hidden_bias();
        this.Hidden_output_wts();
        this.Output_bias();
        this.Max_input();
        this.Min_input();
        this.Input();
        this.MeanInputs();
    }




    private void Read_Input_hidden_weights() {

        FileReader fr = null;
        String linia = "";
        int a=0;
        int b=0;

           try {
             fr = new FileReader("Ihw.txt");
           } catch (FileNotFoundException e) {
               System.out.println("ERR");
               System.exit(1);
           }

           BufferedReader bfr = new BufferedReader(fr);


           try {
             while((linia = bfr.readLine()) != null){
                 Input_hidden_weights[a][b] = Double.parseDouble(linia); 
             b++;

             if (b==320) {
                 b=0;
                 a++;
             }

             }
            } catch (IOException e) {
                System.out.println("ERR");
                System.exit(2);
           }

           try {
             fr.close();
            } catch (IOException e) {
                 System.out.println("ERR");
                 System.exit(3);
                }
            }






    private void Hidden_bias() {

        FileReader fr = null;
        String linia = "";
        int a=0;

           try {
             fr = new FileReader("Hb.txt");
           } catch (FileNotFoundException e) {
               System.out.println("ERR");
               System.exit(1);
           }

           BufferedReader bfr = new BufferedReader(fr);


           try {
             while((linia = bfr.readLine()) != null){
                 Hidden_bias[a] = Double.parseDouble(linia); 
                a++;


             }
            } catch (IOException e) {
                System.out.println("ERR");
                System.exit(2);
           }

           try {
             fr.close();
            } catch (IOException e) {
                 System.out.println("ERR");
                 System.exit(3);
                }
            }






  private void Hidden_output_wts() {

    FileReader fr = null;
    String linia = "";
    int a=0;
    int b=0;

       try {
         fr = new FileReader("How.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
            Hidden_output_wts[a][b] = Double.parseDouble(linia); 
            b++;

         if (b==100) {
             b=0;
             a++;
         }

         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }





  private void Output_bias() {

    FileReader fr = null;
    String linia = "";
    int a=0;

       try {
         fr = new FileReader("Ob.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
             Output_bias[a] = Double.parseDouble(linia); 
            a++;


         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }





  private void Max_input() {

    FileReader fr = null;
    String linia = "";
    int a=0;

       try {
         fr = new FileReader("Mi.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
             Max_input[a] = Double.parseDouble(linia); 
            a++;

         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }





  private void Min_input() {

    FileReader fr = null;
    String linia = "";
    int a=0;

       try {
         fr = new FileReader("Mini.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
             Min_input[a] = Double.parseDouble(linia); 
            a++;

         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }






  private void Input() {

    FileReader fr = null;
    String linia = "";
    int a=0;

       try {
         fr = new FileReader("I.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
             Input[a] = Double.parseDouble(linia); 
            a++;


         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }





  private void MeanInputs() {

    FileReader fr = null;
    String linia = "";
    int a=0;


       try {
         fr = new FileReader("Mei.txt");
       } catch (FileNotFoundException e) {
           System.out.println("ERR");
           System.exit(1);
       }

       BufferedReader bfr = new BufferedReader(fr);


       try {
         while((linia = bfr.readLine()) != null){
             MeanInputs[a] = Double.parseDouble(linia); 
            a++;


         }
        } catch (IOException e) {
            System.out.println("ERR");
            System.exit(2);
       }

       try {
         fr.close();
        } catch (IOException e) {
             System.out.println("ERR");
             System.exit(3);
            }
        }




public void SSN() {

    int cont_inps;
    int index = 0;
    int i = 0;
    double max;


max=3.e-300;


    for (cont_inps=0; cont_inps<320; cont_inps++) {
        if (Input[cont_inps] == -9999) Input[cont_inps] = MeanInputs[cont_inps];    
    }

    ScaleInputs(Input, 0, 1, 320);

    RunNeuralNet_Classification();

    //Output Activation is Softmax;
    Softmax(Output, 23);    

        for(i=0; i<23; i++) {

            if(max < Output[i]) {
                max = Output[i];
                index = i+1;
            }
         }  

        System.out.printf("Predicted category = ");

    switch(index)
        {
            case 1: System.out.printf("A"); break;
            case 2: System.out.printf("B"); break;
            case 3: System.out.printf("C"); break;
            case 4: System.out.printf("D"); break;
            case 5: System.out.printf("E"); break;
            case 6: System.out.printf("F"); break;
            case 7: System.out.printf("G"); break;
            case 8: System.out.printf("H"); break;
            case 9: System.out.printf("I"); break;
            case 10: System.out.printf("J"); break;
            case 11: System.out.printf("K"); break;
            case 12: System.out.printf("L"); break;
            case 13: System.out.printf("M"); break;
            case 14: System.out.printf("N"); break;
            case 15: System.out.printf("O"); break;
            case 16: System.out.printf("P"); break;
            case 17: System.out.printf("R"); break;
            case 18: System.out.printf("S"); break;
            case 19: System.out.printf("T"); break;
            case 20: System.out.printf("U"); break;
            case 21: System.out.printf("W"); break;
            case 22: System.out.printf("Y"); break;
            case 23: System.out.printf("Z"); break;
            default: break;
       }    

    System.out.printf(", Confidence level = " + max);

  }






void FindMax(double vec[], double max[], long maxIndex[], int len)
{
  int i;

  max[0] = vec[0];
  maxIndex[0] = 0;

  for(i=1; i<len; i++) {

    if(vec[i] > max[0]) {
      max[0] = vec[i];
      maxIndex[0] = i;

    }
  }
}




  private void ScaleInputs(double Input[], double minimum, double maximum, int size) {

      double delta;
      int i=0;

      for (i=0; i<size; i++) {
          delta = (maximum-minimum)/(this.Max_input[i] - this.Min_input[i]); 
          Input[i] = minimum - delta * this.Min_input[i] + delta * Input[i];
      }
  }




  private void Softmax (double vec[], int len) {


        int i, j;
        double sum=0.0;

        for(i=0; i<len; i++)
        {
          if(vec[i]>200)
          {
            double[] max = new double[1]; 
            long[] maxIndex = new long[1];

            FindMax(vec, max, maxIndex, len);

            for(j=0; j<len; j++)
            {        if(j==maxIndex[0]) vec[j] = 1.0;
              else vec[j] = 0.0;
            }
            return;
          }
          else
          {
            vec[i] = Math.exp(vec[i]);
          }
          sum += vec[i];
        }
        if(sum==0)
        {
         //long a = 1;
        }
        if(sum!=0.0)
        {
          for(i=0; i<len; i++) vec[i] = vec[i]/sum;
        }
        else for(i=0; i<len; i++) vec[i] = 1.0/(double)len;

}



  private void  ComputeFeedForwardSignals (double[][] MAT_INOUT, double[] V_IN, double[] V_OUT, double[] V_BIAS, int size1, int size2, int layer) {
      int row, col;
      for(row=0;row < size2; row++) 
        {
          V_OUT[row]=0.0;
          for(col=0; col<size1; col++) V_OUT[row] =+ MAT_INOUT[row][col] * V_IN[col];

          V_OUT[row]+=V_BIAS[row];

          if(layer==0) V_OUT[row] = Math.tanh(V_OUT[row]);
       }
  }


  private void RunNeuralNet_Classification() 
  { 
     ComputeFeedForwardSignals(Input_hidden_weights, Input, Hidden, Hidden_bias, 320, 100, 0);

     ComputeFeedForwardSignals(Hidden_output_wts, Hidden, Output, Output_bias, 100, 23, 1);    
  }



}

在我正在打开的文件中,我放了所有数据,具有以下结构:

    5.80887333084651e-001
    1.30476168251902e+000
    5.80288623607794e-001
    8.14671389077252e-001
    5.43029117736068e-001
    5.15946547751079e-001
    4.86324144066176e-001
    3.26116870507742e-001
    4.02847954975450e-001
    1.66380273429940e-001
    3.39504086983093e-001
    5.26249449907226e-002
    2.56145034309448e-001
..... etc.

这些是您可以在C代码中的数组中看到的值。什么是坏事?

1 个答案:

答案 0 :(得分:3)

计算

V_OUT[row]+=(*(MAT_INOUT+(row*size1)+col)*V_IN[col]);

包含[row] [column] 2d数组索引对的手动转换为单维数组索引。这行可以写成

V_OUT[row] += MAT_INOUT[row * size1 + col] * V_IN[col];

我想这会更容易理解!

在(旧方言)C中需要这种代码来实现可变大小的2d数组。在Java中你需要

V_OUT[row] += MAT_INOUT[row][col] * V_IN[col];