Source code: AI/NeuralNetworks/DistributedFeedForwardWeightedTrainingLayer.java
/* DistributedFeedForwardWeightedTrainingLayer.java */
2
3 package AI.NeuralNetworks;
4
5 import java.util.*;
6 import java.io.*;
7 //import MPI.*;
8
9 /**
10 This class is used to represent the inner and output layers of a multilayer feed forward neural network
11 */
12 public class DistributedFeedForwardWeightedTrainingLayer extends FeedForwardWeightedTrainingLayer {
13 // protected int rootProcess;
14 // protected Comm communicator;
15
16 /**
17 creates a new insatnce of FeedForwardWeightedLayer
18 @param size is the size of the input layer without counting the bias
19 @param previousSize is the size of the previous layer without counting the bias
20 */
21 /* public DistributedFeedForwardWeightedTrainingLayer(int size, int previousSize, Comm communicator, int root){
22 super(size, previousSize);
23 this.communicator = communicator;
24 this.rootProcess = root;
25 }
26 */
27 /**
28 creates a new insatnce of FeedForwardWeightedLayer
29 @param layer is the FeedForwardWeightedLayer from which the new layer is created
30 */
31 public DistributedFeedForwardWeightedTrainingLayer(FeedForwardWeightedLayer layer/*, Comm communicator, int root*/ ){
32 super(layer);
33 // this.communicator = communicator;
34 // this.rootProcess = root;
35 }
36
37 /**
38 Adds the deltaWeights to the weights
39 @param learninRate is the constan by wich is multiplied the deltaWeight before it is added to the weights
40 */
41 /* public void addDeltaWeights(float learningRate){
42 float[] tempWeight;
43
44 for(int row=0; row<deltaWeights.length; row++){
45 deltaWeights[row] = communicator.AllReduceSumFloat(deltaWeights[row], rootProcess);
46 for(int col=0; col<deltaWeights[row].length; col++){
47 weights[row][col] += deltaWeights[row][col] * learningRate;
48 }
49 }
50 }
51 */
52
53 /**
54 Sendsthe deltaWeights to the weights in the BufferedWriter
55 @param bw is the BufferedWriter where to wright the weights
56 */
57 public void sendDeltaWeights(BufferedWriter bw){
58 // float[] tempWeight;
59 int weightBits;
60
61 try{
62
63 for(int row=0; row<deltaWeights.length; row++){
64 // deltaWeights[row] = communicator.AllReduceSumFloat(deltaWeights[row], rootProcess);
65 for(int col=0; col<deltaWeights[row].length; col++){
66 //bw.write("\t"+deltaWeights[row][col]);
67 //System.out.println("SdeltaWeight["+row+"]["+col+"]="+deltaWeights[row][col]);//debug
68 weightBits = Float.floatToIntBits(deltaWeights[row][col]);
69 //System.out.println("SdeltaWeight["+row+"]["+col+"]="+weightBits + "-->" + deltaWeights[row][col]);//debug
70 bw.write( weightBits & 0x000000ff);
71 bw.write((weightBits & 0x0000ff00) >>>8);
72 bw.write((weightBits & 0x00ff0000) >>>16);
73 bw.write((weightBits & 0xff000000) >>>24);
74 }
75 // bw.newLine();
76 }
77 bw.flush();
78 }catch (IOException ioe){
79 ioe.printStackTrace();
80 }
81 }
82
83
84 /**
85 adds the delta weights form br to the local delta weights
86 @param br is the BufferedReader from where to parse the weights
87 */
88 public void parseDeltaWeights(BufferedReader br){
89 // StringTokenizer st;
90 int lolo, lohi, hilo, hihi, receivedDWBits;
91
92 try{
93
94 for(int row=0; row<deltaWeights.length; row++){
95 // st = new StringTokenizer( br.readLine() );
96 for(int col=0; col<deltaWeights[row].length; col++){
97 //deltaWeights[row][col] += Float.parseFloat( st.nextToken() );
98 lolo = br.read();
99 lohi = br.read();
100 hilo = br.read();
101 hihi = br.read();
102 receivedDWBits = lolo | (lohi<<8) | (hilo<<16) | (hihi<<24);
103 deltaWeights[row][col] = Float.intBitsToFloat(receivedDWBits);
104 //System.out.println("RdeltaWeight["+row+"]["+col+"]="+ receivedDWBits + "-->" +deltaWeights[row][col]);//debug
105 }
106 }
107 }catch (IOException ioe){
108 ioe.printStackTrace();
109 }
110
111 }
112
113 }