Docjar: A Java Source and Document Engine    com.*    java.*    javax.*    org.*    all    new    plug-in

Quick Search    Search Deep

Source code: AI/NeuralNetworks/DistributedFeedForwardWeightedTrainingLayer.java


1   /* DistributedFeedForwardWeightedTrainingLayer.java */
2   
3   package AI.NeuralNetworks;
4   
5   import java.util.*;
6   import java.io.*;
7   //import MPI.*;
8   
9   /**
10    This class is used to represent the inner and output layers of a multilayer feed forward neural network
11  */
12  public class DistributedFeedForwardWeightedTrainingLayer extends FeedForwardWeightedTrainingLayer {
13  //  protected int rootProcess;
14  //  protected Comm communicator;
15    
16    /**
17      creates a new insatnce of FeedForwardWeightedLayer
18      @param size is the size of the input layer without counting the bias
19      @param previousSize is the size of the previous layer without counting the bias
20    */
21  /*  public DistributedFeedForwardWeightedTrainingLayer(int size, int previousSize, Comm communicator, int root){
22      super(size, previousSize);
23      this.communicator = communicator;
24      this.rootProcess = root;
25    }
26  */  
27    /**
28      creates a new insatnce of FeedForwardWeightedLayer
29      @param layer is the FeedForwardWeightedLayer from which the new layer is created
30    */
31    public DistributedFeedForwardWeightedTrainingLayer(FeedForwardWeightedLayer layer/*, Comm communicator, int root*/  ){
32      super(layer);
33  //    this.communicator = communicator;
34  //    this.rootProcess = root;
35    }
36    
37    /**
38      Adds the deltaWeights to the weights
39      @param learninRate is the constan by wich is multiplied the deltaWeight before it is added to the weights
40    */
41  /*  public void addDeltaWeights(float learningRate){
42      float[] tempWeight;
43      
44      for(int row=0; row<deltaWeights.length; row++){
45        deltaWeights[row] = communicator.AllReduceSumFloat(deltaWeights[row], rootProcess);
46        for(int col=0; col<deltaWeights[row].length; col++){
47          weights[row][col] += deltaWeights[row][col] * learningRate;
48        }
49      }
50    }  
51  */  
52    
53    /**
54      Sendsthe deltaWeights to the weights in the BufferedWriter
55      @param bw is the BufferedWriter where to wright the weights
56    */
57    public void sendDeltaWeights(BufferedWriter bw){
58  //    float[] tempWeight;
59      int weightBits;
60      
61      try{
62        
63        for(int row=0; row<deltaWeights.length; row++){
64    //      deltaWeights[row] = communicator.AllReduceSumFloat(deltaWeights[row], rootProcess);
65          for(int col=0; col<deltaWeights[row].length; col++){
66            //bw.write("\t"+deltaWeights[row][col]);
67  //System.out.println("SdeltaWeight["+row+"]["+col+"]="+deltaWeights[row][col]);//debug
68            weightBits = Float.floatToIntBits(deltaWeights[row][col]);
69  //System.out.println("SdeltaWeight["+row+"]["+col+"]="+weightBits + "-->" + deltaWeights[row][col]);//debug          
70            bw.write( weightBits & 0x000000ff);
71            bw.write((weightBits & 0x0000ff00) >>>8);
72            bw.write((weightBits & 0x00ff0000) >>>16);
73            bw.write((weightBits & 0xff000000) >>>24);
74          }
75    //      bw.newLine();
76        }
77        bw.flush();
78      }catch (IOException ioe){
79        ioe.printStackTrace();
80      }
81    }  
82  
83  
84    /**
85      adds the delta weights form br to the local delta weights
86      @param br is the BufferedReader from where to parse the weights
87    */
88    public void parseDeltaWeights(BufferedReader br){
89  //    StringTokenizer st;
90      int lolo, lohi, hilo, hihi, receivedDWBits;
91      
92      try{
93        
94        for(int row=0; row<deltaWeights.length; row++){
95  //        st = new StringTokenizer( br.readLine() );
96          for(int col=0; col<deltaWeights[row].length; col++){
97            //deltaWeights[row][col] += Float.parseFloat( st.nextToken() );
98            lolo = br.read();
99            lohi = br.read();
100           hilo = br.read();
101           hihi = br.read();
102           receivedDWBits = lolo | (lohi<<8) | (hilo<<16) | (hihi<<24);
103           deltaWeights[row][col] = Float.intBitsToFloat(receivedDWBits);
104 //System.out.println("RdeltaWeight["+row+"]["+col+"]="+ receivedDWBits + "-->" +deltaWeights[row][col]);//debug          
105         }
106       }
107     }catch (IOException ioe){
108       ioe.printStackTrace();
109     }
110     
111   }
112 
113 }