Docjar: A Java Source and Document Engine    com.*    java.*    javax.*    org.*    all    new    plug-in

Quick Search    Search Deep

Source code: AI/NeuralNetworks/DistributedBackErrorPropagation.java


1   /* DistributedBackErrorPropagation.java */
2   
3   package AI.NeuralNetworks;
4   
5   import java.util.*;
6   import java.net.*;
7   import java.io.*;
8   //import MPI.*;
9   
10  /** 
11    This Class is a nerual network with a multiple layer feed forward architecture where
12    a distributed version of the back error propagation algorithm could be run
13  */
14  public class DistributedBackErrorPropagation extends BackErrorPropagation implements NeuralNetworkTeacher{
15  //  protected boolean master;
16    protected Set hosts;
17    protected Socket sock;
18    
19    protected BufferedWriter bw;
20    
21  //  protected int rootProcess = 0;
22  //  protected Comm communicator;
23    
24  //  public static int ROOT = 0;
25  //  public static String LOAD_STRUCTURE = "";
26    
27  
28    /**
29     Creates a new DistributedBackErrorPropagation instance
30     */
31    public DistributedBackErrorPropagation(FeedForwardNetwork network, BufferedWriter bw){
32      ListIterator originalNet;
33      
34      layers = new LinkedList();
35      
36      originalNet = network.layers.listIterator();
37      layers.add(new FeedForwardInputLayer((FeedForwardInputLayer) originalNet.next()));
38      while(originalNet.hasNext()){
39        layers.add(new DistributedFeedForwardWeightedTrainingLayer((FeedForwardWeightedLayer) originalNet.next() /*, communicator, rootProcess*/  ));
40      }
41      this.bw = bw;
42    }
43  
44    /**
45     Creates a new DistributedBackErrorPropagation instance
46     */
47    public DistributedBackErrorPropagation(FeedForwardNetwork network){
48      this(network, null);
49    }
50    /**
51     Trains the network with the  given patterns
52     @param patterns are the pairs of input-expected output that the network must learn
53     */
54    public float trainNetwork(LinkedList trainingPatterns) throws UnexpectedInputArraySizeException{
55      ListIterator patternList;
56      PatternPair pattern;
57      float[] outputError;
58  //    int epoch=0;
59      float error = 0;
60  //    float[] tempError = new float[1];
61          
62      patternList = trainingPatterns.listIterator(0);
63      
64      resetDeltaWeights();
65      error = 0;
66      
67  //System.out.println(this.toString());    
68      
69      while(patternList.hasNext()){
70        pattern = (PatternPair) patternList.next();
71      
72  //        System.out.println("\t"+r+"\tepoch: "+epoch+"\tpattern: "+niceArray(pattern.input())); //debug 
73        adjustDeltaWeights(pattern.input(), pattern.output());
74  
75        outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError();
76        for (int e=0; e<outputError.length; e++){
77          error += outputError[e] * outputError[e];
78  //System.out.println("errorAcum " + error );
79        }
80      }
81  //    outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getError();
82  //    tempError[0] = error;
83  //    tempError = communicator.AllReduceSumFloat(tempError, rootProcess);
84  //    error = (float) Math.sqrt(Math.abs(tempError[0] / (trainingPatterns.size() * outputError.length)));
85  
86  //    addDeltaWeights();      
87  
88      sendError(error);
89  //System.out.println("errorSend " + error );
90      sendDeltaWeights();      
91          
92      return error;
93    }
94    
95    public void sendError(float error){
96      try{
97        /**/
98        int errorBits = Float.floatToIntBits(error);
99        bw.write( errorBits & 0x000000ff);
100       bw.write((errorBits & 0x0000ff00) >>>8);
101       bw.write((errorBits & 0x00ff0000) >>>16);
102       bw.write((errorBits & 0xff000000) >>>24);
103 //System.out.println("sendError = "+ error);//debug      
104       /**/
105       //bw.write(""+error); // cambiar aca que manda el error como String
106       //bw.newLine();
107       bw.flush();
108     } catch (IOException ioe) {
109       ioe.printStackTrace();
110     }
111   }
112   
113   public void sendDeltaWeights(){// cambiar aca que manda los pesos como String
114     ListIterator layer;
115     
116     layer = layers.listIterator(1); //the first layer does not have weights
117     
118     while(layer.hasNext()){
119       ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).sendDeltaWeights(bw);
120     }    
121   }
122   
123   /**
124     Adds the all deltaWeights to the weights
125   */
126   public void addDeltaWeights(){
127     ListIterator layer;
128     
129     layer = layers.listIterator(1); //the first layer does not have weights
130     
131     while(layer.hasNext()){
132       ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).addDeltaWeights(learningRate);
133     }
134   }  
135   
136   /**
137     adds the delta weights form br to the local delta weights
138   */
139   public void parseDeltaWeights(BufferedReader br){
140     ListIterator layer;
141     
142     layer = layers.listIterator(1); //the first layer does not have weights
143     
144     while(layer.hasNext()){
145       ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).parseDeltaWeights(br);
146     }
147   }
148   
149   
150   
151   static private String niceArray(float array[]){
152     String s = "[ ";
153     for(int i=0; i<array.length; i++){
154       s += array[i] + " ";
155     }
156     s+= "]";
157     return s;
158   }
159   static private String niceArray(int array[]){
160     String s = "[ ";
161     for(int i=0; i<array.length; i++){
162       s += array[i] + " ";
163     }
164     s+= "]";
165     return s;
166   }
167   
168   static private void showList(LinkedList pat){ //debug
169     ListIterator li = pat.listIterator(0);
170     
171     while(li.hasNext()){
172       System.out.println(((PatternPair) li.next()).toString());
173     }
174   
175   }
176 
177 }